📜  Python Pytorch full() 方法(1)

📅  最后修改于: 2023-12-03 15:04:08.059000             🧑  作者: Mango

Python Pytorch full() 方法

在 Pytorch 中,torch.full(size, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) 方法可以用于创建具有指定形状和填充值的新张量。

其中参数含义如下:

  • size:张量的形状。可以是一个整数,表示一个具有该值的标量张量;也可以是一个包含多个整数的元组,以表示具有该形状的张量。
  • fill_value:张量的填充值。必须是一个标量,表示将张量的所有元素都设置为该值。
  • out:张量的输出张量。可以是预先分配的张量来接收输出。默认为None
  • dtype:输出张量的数据类型。默认为None,表示将使用默认数据类型。
  • layout:默认为 torch.strided。已知渐进传输的代价为线性传输,这种设定中,张量被视为一个连续的一维张量,在其中找到元素,并用 stride 属性明确约定轴的大小和步长。如果将 torch.memory_format 作为张量参数传递,则将使用与该格式相对应的新分配内存布局。
  • device:分配张量的设备。默认为None,表示使用当前设备。
  • requires_grad:是否计算梯度。默认为False

以下是使用torch.full()方法创建张量的示例代码:

import torch

# 创建一个形状为(2, 3)、元素值为4的张量
a = torch.full((2, 3), 4)
print(a)

# 输出结果
# tensor([[4, 4, 4],
#         [4, 4, 4]])

# 创建一个形状为(1,)、元素值为5.5的张量
b = torch.full((1,), 5.5)
print(b)

# 输出结果
# tensor([5.5000])

在上面的代码中,使用torch.full()方法创建了两个张量ab。张量a的形状为(2, 3),元素值都是4。张量b的形状为(1,),元素值为5.5。