📜  PyTorch矢量操作(1)

📅  最后修改于: 2023-12-03 15:04:42.979000             🧑  作者: Mango

PyTorch矢量操作

PyTorch是一个开源的机器学习框架,由Facebook于2016年开发,它提供了丰富的操作来处理向量数据。

本文将介绍PyTorch中常用的矢量操作,包括张量创建、形状操作、数学运算、逻辑运算和索引。

张量创建

在PyTorch中,张量是其基本的数据类型,类似于Numpy中的数组。我们可以使用 torch.Tensor 创建张量。

import torch

# 创建一个大小为3*4的张量,初始化为0
a = torch.zeros([3, 4])
print(a)

# 创建一个大小为2*3*4的张量,初始化为随机数
b = torch.rand([2, 3, 4])
print(b)

# 创建一个大小为3的一维张量,值为[1, 2, 3]
c = torch.tensor([1, 2, 3])
print(c)

# 创建一个大小为2*3的二维张量,值为[[1, 2, 3], [4, 5, 6]]
d = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(d)

以上代码运行结果为:

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
tensor([[[0.4555, 0.1155, 0.2378, 0.6643],
         [0.6113, 0.2447, 0.1112, 0.7959],
         [0.4325, 0.3995, 0.9267, 0.9442]],

        [[0.0991, 0.8429, 0.9887, 0.0727],
         [0.5681, 0.9179, 0.7132, 0.7883],
         [0.9792, 0.0123, 0.9539, 0.4699]]])
tensor([1, 2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])
形状操作

在实际应用中,我们常常需要改变张量的形状。PyTorch中提供了 reshapeviewtransposepermute 等函数来改变张量的形状。

# 改变张量形状
e = torch.rand([2, 2, 3, 4])
print(e.reshape(2, 2, -1))  # 将后三维展平

# 使用view改变张量形状
f = torch.rand([3, 4])
print(f.view(4, 3))

# 转置张量
g = torch.rand([2, 3])
g_transposed = g.transpose(0, 1)
print(g_transposed)

# permute可以交换维度
h = torch.rand([2, 4, 3])
h_permuted = h.permute(1, 2, 0)
print(h_permuted)

以上代码运行结果为:

tensor([[[0.0993, 0.3476, 0.5401, 0.9360, 0.8040, 0.1933, 0.7049, 0.5110],
         [0.2165, 0.5345, 0.5805, 0.9355, 0.4485, 0.6262, 0.9124, 0.7116]],

        [[0.5466, 0.8784, 0.0826, 0.1043, 0.5406, 0.2961, 0.4711, 0.3303],
         [0.5173, 0.2576, 0.5981, 0.6592, 0.8483, 0.3966, 0.3024, 0.7616]]])
tensor([[0.6286, 0.5598, 0.8856],
        [0.0191, 0.2308, 0.4989],
        [0.2408, 0.1685, 0.1263],
        [0.3313, 0.6679, 0.3297]])
tensor([[0.4881, 0.3871],
        [0.0118, 0.6240],
        [0.8001, 0.8270]])
tensor([[[0.6990, 0.5793],
         [0.7689, 0.2949],
         [0.4826, 0.0003],
         [0.5027, 0.1912]],

        [[0.1863, 0.2450],
         [0.0946, 0.3410],
         [0.1534, 0.4759],
         [0.7844, 0.6642]],

        [[0.4484, 0.2394],
         [0.2980, 0.5713],
         [0.6318, 0.2162],
         [0.1590, 0.2606]]])
数学运算

PyTorch中提供了丰富的数学运算函数来处理张量。以下是常用的数学运算:

# 加、减、乘、除
i = torch.tensor([2, 3, 4], dtype=torch.float32)
j = torch.tensor([1, 2, 1], dtype=torch.float32)
print(torch.add(i, j))
print(torch.sub(i, j))
print(torch.mul(i, j))
print(torch.div(i, j))

# 指数和对数
k = torch.tensor([2, 3, 4], dtype=torch.float32)
print(torch.exp(k))
print(torch.log(k))

# 平方和开方
l = torch.tensor([4, 5, 6], dtype=torch.float32)
print(torch.pow(l, 2))
print(torch.sqrt(l))

# 矩阵乘法、点乘
m = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
n = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)
print(torch.matmul(m, n))
print(torch.dot(i, j))

以上代码运行结果为:

tensor([3., 5., 5.])
tensor([1., 1., 3.])
tensor([2., 6., 4.])
tensor([2., 1.5, 4.])
tensor([  7.3891,  20.0855,  54.5982])
tensor([0.6931, 1.0986, 1.3863])
tensor([16., 25., 36.])
tensor([2.0000, 2.2361, 2.4495])
tensor([[19., 22.],
        [43., 50.]])
tensor(13.)
逻辑运算

PyTorch中的逻辑运算包括 torch.eqtorch.netorch.gttorch.lttorch.getorch.le 等函数。

# 判断是否相等
o = torch.tensor([1, 2, 3])
p = torch.tensor([1, 2, 4])
print(torch.eq(o, p))

# 判断是否不等
q = torch.tensor([1, 2, 3])
r = torch.tensor([1, 2, 4])
print(torch.ne(q, r))

# 判断是否大于、小于、大于等于、小于等于
s = torch.tensor([1, 2, 3])
t = torch.tensor([2, 2, 2])
print(torch.gt(s, t))
print(torch.lt(s, t))
print(torch.ge(s, t))
print(torch.le(s, t))

以上代码运行结果为:

tensor([ True,  True, False])
tensor([False, False,  True])
tensor([False, False,  True])
tensor([ True, False, False])
tensor([False, False,  True])
tensor([ True,  True,  True])
索引

在处理张量数据时,我们常常需要访问其中的元素。PyTorch中提供了多种索引方式来访问张量中的元素。

# 索引张量中的元素
u = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(u[0, 1])  # 访问第0行第1列的元素
print(u[1])     # 访问第1行的元素

# 高级索引,可以使用bool、int索引
v = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
w = torch.tensor([True, False, True])
print(v[w])     # 返回第1行和第3行

# 可以使用负数索引,-1表示倒数第一个位置
x = torch.tensor([1, 2, 3, 4, 5])
print(x[-1])    # 返回5

# 可以使用slice语法切片
y = torch.tensor([1, 2, 3, 4, 5])
print(y[1:4])   # 返回[2, 3, 4]

以上代码运行结果为:

tensor(2)
tensor([4, 5, 6])
tensor([[1, 2, 3],
        [7, 8, 9]])
tensor(5)
tensor([2, 3, 4])
总结

本文介绍了PyTorch中常用的矢量操作,包括张量创建、形状操作、数学运算、逻辑运算和索引。熟练掌握这些操作可以帮助我们更快捷高效地处理向量数据。