📜  pytorch plt.imshow - Python (1)

📅  最后修改于: 2023-12-03 14:46:48.422000             🧑  作者: Mango

PyTorch plt.imshow - Python

PyTorch是一个基于Python的科学计算包,主要针对两类人群:深度学习研究人员和企业客户,提供了Tensor(张量)这一通用的数据结构以及许多工具,用于构建深度学习模型。

plt.imshow是Matplotlib中的函数之一,用于展示图片。结合PyTorch可以方便地展示模型输出的图片。

安装PyTorch
pip install torch torchvision
样例代码
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# 加载模型
model = torch.load('model.pth')
# 准备数据
img = Image.open('test.jpg')
img_tensor = transforms.ToTensor()(img)
# 预测
with torch.no_grad():
    output = model(img_tensor.unsqueeze(0))
# 展示结果
plt.imshow(output[0].squeeze().cpu().numpy(), cmap='gray')
plt.show()
代码说明
  1. 加载PyTorch库

    import torch
    from torchvision import transforms
    from PIL import Image
    import matplotlib.pyplot as plt
    
  2. 加载模型

    model = torch.load('model.pth')
    

    加载已训练好的PyTorch模型

  3. 准备数据

    img = Image.open('test.jpg')
    img_tensor = transforms.ToTensor()(img)
    

    加载测试图像,并将其转换成PyTorch中的Tensor

  4. 预测

    with torch.no_grad():
        output = model(img_tensor.unsqueeze(0))
    

    将测试图像送入模型中进行预测,并得到输出

  5. 展示结果

    plt.imshow(output[0].squeeze().cpu().numpy(), cmap='gray')
    plt.show()
    

    将输出转换成可展示的图像,并展示出来

参考资料