📜  Python Pytorch eye() 方法(1)

📅  最后修改于: 2023-12-03 15:34:03.921000             🧑  作者: Mango

Python Pytorch eye() 方法介绍

简介

eye() 是 PyTorch 提供的方法之一,用于创建一个二维矩阵,并将其对角线上的元素赋值为 1,其余元素为 0。该方法的语法如下:

torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

其中,参数 n 表示生成矩阵的行数,参数 m 表示生成矩阵的列数,如果 m 未指定,则默认与 n 相同,参数 dtype 表示生成矩阵的数据类型,默认为 torch.float32

举例
import torch

# 创建一个 3 行 3 列的单位矩阵
out = torch.eye(3)
print(out)

以上代码输出结果如下:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

在使用 eye() 方法时,可以通过指定 out 参数来使用一个已有的 Tensor 对象来接收生成的矩阵。

import torch

# 创建一个 3 行 3 列的单位矩阵,使用一个已有的 Tensor 对象接收生成的矩阵
out = torch.empty(3, 3)
torch.eye(3, out=out)
print(out)

以上代码输出结果如下:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

通过指定 dtype 参数,可以生成指定类型的矩阵。

import torch

# 创建一个 3 行 3 列的单位矩阵,数据类型为 torch.int64
out = torch.eye(3, dtype=torch.int64)
print(out)

以上代码输出结果如下:

tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])
总结

通过本文的介绍,我们学习了 PyTorch 的 eye() 方法,了解了其语法和用法,并且通过举例展示了如何使用该方法,帮助程序员更好地使用 PyTorch 进行矩阵操作。