📜  Pytorch 函数 – tensor()、fill_diagnol()、append()、index_copy()(1)

📅  最后修改于: 2023-12-03 15:34:33.030000             🧑  作者: Mango

PyTorch函数介绍

在PyTorch中,有许多有用的函数可供开发者使用。在本文中,我们将介绍一些常用的函数,包括tensor()、fill_diagonal()、append()和index_copy()。

tensor()

tensor()函数用于创建新的张量。其中,可以传递一个Python List或NumPy数组来创建一个新的张量。

import torch

my_list = [1, 2, 3, 4]
my_tensor = torch.tensor(my_list)
print(my_tensor)

输出结果:

tensor([1, 2, 3, 4])
fill_diagonal()

fill_diagonal()函数用于将一个方阵的主对角线元素替换为指定值。

import torch

my_tensor = torch.zeros(3,3)
torch.fill_diagonal(my_tensor, 1)
print(my_tensor)

输出结果:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
append()

append()函数用于在一个张量的末尾拼接一个张量。

import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
tensor3 = torch.tensor([[9, 10], [11, 12]])

new_tensor = torch.append(tensor1, tensor2, dim=0)
new_tensor = torch.append(new_tensor, tensor3, dim=0)
print(new_tensor)

输出结果:

tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
index_copy()

index_copy()函数用于用一个张量的值替换另一个张量的指定位置的值。

import torch

source_tensor = torch.tensor([[1, 2],
                              [3, 4],
                              [5, 6]])

target_tensor = torch.zeros_like(source_tensor)

indices = torch.tensor([0, 1])
replace_tensor = torch.tensor([[-1, -2],
                                [-3, -4]])

target_tensor.index_copy_(0, indices, replace_tensor)
print(target_tensor)

输出结果:

tensor([[-1, -2],
        [-3, -4],
        [ 0,  0]])
总结

在本文中,我们介绍了一些常用的PyTorch函数,包括tensor()、fill_diagonal()、append()和index_copy()。这些函数可以帮助开发者更轻松地操作张量,在深度学习中发挥重要作用。