📜  RuntimeError:预期的类型为 torch.FloatTensor 的对象,但发现类型为 torch.cuda.FloatTensor 作为参数 - Python 代码示例

📅  最后修改于: 2022-03-11 14:47:06.351000             🧑  作者: Mango

代码示例1
# 第二个类型是 torch.cuda.FloatTensor,这意味着它是已经移到 GPU 的张量。
# 它想获得类型为 torch.FloatTensor 的张量,但是没有 .cuda,因此该张量应该在 CPU 上。
# PyTorch 只能对位于相同设备上的张量进行运算,因此必须同时位于 CPU 或 GPU 上。如果你要在 GPU 
# 上运行网络,一定要使用 .to(device) 将模型和所有必要张量移到 GPU 上,其中 device 为 "cuda" 或 "cpu"。

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device) # the same model.cuda()

for images, labels in trainloader:
    images, labels = images.to(device), labels.to(device)