📜  Tensor.expand_as - Python (1)

📅  最后修改于: 2023-12-03 15:35:16.803000             🧑  作者: Mango

Tensor.expand_as - Python

Tensor.expand_as() 是 PyTorch 的一个 Tensor 扩展方法,用于将一个 Tensor 沿着维度进行重复,使其形状和另一个指定的 Tensor 形状相同。

语法
torch.Tensor.expand_as(other)
参数
  • other (Tensor):用于确定新形状的 Tensor。
返回值

一个新的 Tensor,其形状与 other 形状相同,但数据按重复原始 Tensor 得到。

示例
import torch

# 创建一个形状为(2, 1, 3)的Tensor
x = torch.randn(2, 1, 3)
print(x.shape)
# torch.Size([2, 1, 3])

# 创建一个形状为(2, 4, 3)的Tensor
y = torch.randn(2, 4, 3)
print(y.shape)
# torch.Size([2, 4, 3])

# 将x沿着第二个维度重复4遍,并输出形状
z = x.expand_as(y)
print(z.shape)
# torch.Size([2, 4, 3])

在上面的示例中,我们创建了两个不同形状的 Tensor x 和 y。然后,我们使用 expand_as 将 x 沿着第二个维度重复了4遍,得到形状与 y 相同的 Tensor z。

总结

Tensor.expand_as 是一个方便的 Tensor 扩展方法,可以轻松地将一个 Tensor 沿着维度重复,以匹配另一个指定的 Tensor 的形状。这个方法是 PyTorch 中强大而灵活的 Tensor 操作之一,通常用于模型训练和数据处理中。