📅  最后修改于: 2023-12-03 15:34:32.947000             🧑  作者: Mango
在PyTorch中,经常需要将输入数据填充为正方形的形状,以便进行卷积和池化。本文将介绍如何使用PyTorch对任意大小的输入进行填充,使其变为正方形。
以下代码片段显示了如何使用PyTorch对输入数据进行填充,使其变为正方形。
import torch
def pad_to_square(image, pad_value):
_, h, w = image.shape
if h == w:
return image
elif h > w:
padding = [0, 0, (h - w) // 2, (h - w) // 2]
else:
padding = [(w - h) // 2, (w - h) // 2, 0, 0]
return torch.nn.functional.pad(image, padding, mode='constant', value=pad_value)
以上函数接受两个参数:一个输入张量和一个表示填充值的标量。如果输入张量是正方形,则直接返回它。如果不是,则计算需要添加到最终形状的填充量,然后使用torch.nn.functional.pad
函数进行填充。
可以使用以下代码对上述函数进行测试:
# Generate a random image of size 3x4
image = torch.randn((1, 3, 4))
# Pad the image to a square of size 4x4 with zero padding
padded_image = pad_to_square(image, 0)
# Print the padded image shape
print(padded_image.shape) # Output: torch.Size([1, 3, 4, 4])
此代码生成一个大小为3x4的随机张量,然后使用pad_to_square
函数将其填充为一个大小为4x4的正方形,并将填充值设置为0。最后,输出填充后的张量形状。
使用本文所述的pad_to_square
函数,可以轻松地将任意大小的输入填充为正方形,以便进行卷积和池化。