📜  pytorch 获取非诊断元素 - Python (1)

📅  最后修改于: 2023-12-03 15:19:37.120000             🧑  作者: Mango

PyTorch获取非诊断元素

PyTorch是一个开源的机器学习框架,它旨在提供与Numpy类似的灵活性和易用性,同时又能利用GPU加速计算。在PyTorch中,非诊断元素是指数值的原始值,而不是经过自动微分计算的梯度值。

在PyTorch中获取非诊断元素有多种方法,下面分别介绍。

1.使用.data属性

PyTorch Tensor有一个.data属性,它包含原始的Tensor。可以使用此属性轻松获取非诊断元素。

import torch

a = torch.randn((3,3))
print(a)

# 获取非诊断元素
print(a.data)

输出:

tensor([[-0.4130, -0.7468, -0.8859],
        [ 0.3699,  0.1145,  0.4282],
        [ 0.4439, -0.9954,  0.8953]])
tensor([[-0.4130, -0.7468, -0.8859],
        [ 0.3699,  0.1145,  0.4282],
        [ 0.4439, -0.9954,  0.8953]])
2.使用.numpy()方法

可以使用.numpy()方法将PyTorch Tensor转换为NumPy数组,因为NumPy数组是基本的Python数据类型,因此可以直接访问非诊断元素。

import torch

a = torch.randn((3,3))
print(a)

# 获取非诊断元素
print(a.numpy())

输出:

tensor([[-0.2827, -0.7321, -1.2669],
        [ 1.6779, -0.8000,  0.3645],
        [-1.0587, -0.3588,  1.6982]])
[[-0.28274167 -0.73209715 -1.2668867 ]
 [ 1.6778553  -0.7999647   0.3644662 ]
 [-1.0587276  -0.35880426  1.6981963 ]]
3.使用.detach()方法

可以使用.detach()方法从计算图中分离Tensor,以获取非诊断元素。

import torch

a = torch.randn((3,3), requires_grad=True)
print(a)

# 获取非诊断元素
print(a.detach())

输出:

tensor([[-1.1620, -1.5022,  0.2724],
        [-0.3703, -0.3313, -0.3007],
        [ 0.5315, -0.6505, -0.2275]], requires_grad=True)
tensor([[-1.1620, -1.5022,  0.2724],
        [-0.3703, -0.3313, -0.3007],
        [ 0.5315, -0.6505, -0.2275]])
结论

无论哪种方法,都可以轻松地从PyTorch Tensor中获取非诊断元素。在实际编程中,可以根据情况选择不同的方法来完成特定的任务。