📜  Python PyTorch – rsqrt() 方法(1)

📅  最后修改于: 2023-12-03 15:18:59.394000             🧑  作者: Mango

Python PyTorch - rsqrt() 方法

在 PyTorch 中,rsqrt() 是一个用于计算张量元素的倒数平方根的方法。它返回一个新的张量,其中每个元素都是输入张量对应元素的倒数平方根。

语法
torch.rsqrt(input, out=None) -> Tensor
  • input:输入张量。
  • out:可选的输出张量。
示例
import torch

# 创建一个张量
x = torch.tensor([4, 9, 16], dtype=torch.float)

# 使用 rsqrt() 方法计算倒数平方根
output = torch.rsqrt(x)

print(output)

输出:

tensor([0.5000, 0.3333, 0.2500])

上述示例中,我们创建了一个张量 x,其值为 [4, 9, 16]。然后使用 rsqrt() 方法计算 x 中每个元素的倒数平方根,最后打印输出结果。

注意事项
  • 输入张量的元素类型应为浮点数类型,否则将引发异常。
  • 如果传递了 out 参数,则输出结果将保存在这个参数对应的张量中。
  • 如果输入张量的元素为负值,则将返回 NaN (not a number)。
  • 如果输入张量的元素为零,则将返回正无穷大 (positive infinity)。

更多关于 PyTorch 中 rsqrt() 方法的详细信息,请参考官方文档

希望这个介绍能帮助到你理解和使用 PyTorch 中的 rsqrt() 方法。