📜  pytorch pad to square - Python (1)

📅  最后修改于: 2023-12-03 15:34:32.947000             🧑  作者: Mango

PyTorch pad to square - Python

在PyTorch中,经常需要将输入数据填充为正方形的形状,以便进行卷积和池化。本文将介绍如何使用PyTorch对任意大小的输入进行填充,使其变为正方形。

算法实现

以下代码片段显示了如何使用PyTorch对输入数据进行填充,使其变为正方形。

import torch

def pad_to_square(image, pad_value):
    _, h, w = image.shape
    if h == w:
        return image
    elif h > w:
        padding = [0, 0, (h - w) // 2, (h - w) // 2]
    else:
        padding = [(w - h) // 2, (w - h) // 2, 0, 0]
    return torch.nn.functional.pad(image, padding, mode='constant', value=pad_value)

以上函数接受两个参数:一个输入张量和一个表示填充值的标量。如果输入张量是正方形,则直接返回它。如果不是,则计算需要添加到最终形状的填充量,然后使用torch.nn.functional.pad函数进行填充。

使用示例

可以使用以下代码对上述函数进行测试:

# Generate a random image of size 3x4
image = torch.randn((1, 3, 4))

# Pad the image to a square of size 4x4 with zero padding
padded_image = pad_to_square(image, 0)

# Print the padded image shape
print(padded_image.shape) # Output: torch.Size([1, 3, 4, 4])

此代码生成一个大小为3x4的随机张量,然后使用pad_to_square函数将其填充为一个大小为4x4的正方形,并将填充值设置为0。最后,输出填充后的张量形状。

结论

使用本文所述的pad_to_square函数,可以轻松地将任意大小的输入填充为正方形,以便进行卷积和池化。