📜  残差网络 (ResNet) – 深度学习(1)

📅  最后修改于: 2023-12-03 15:26:53.116000             🧑  作者: Mango

残差网络 (ResNet) – 深度学习

概述

残差网络(ResNet)是由何凯明等人提出的,是深度学习中非常重要的一类网络,通过引入残差块解决了深度神经网络中的梯度弥散问题,使训练时的错误率更低,训练更为稳定,精度更高。在图像分类,目标检测,语音识别等领域取得了非常优秀的结果。

原理

ResNet的主要思路是引入了残差学习(residual learning)的概念,直接拟合残差,即用输入的值减去期望的值得到残差,再学习如何拟合这个残差。在网络的某些层中,不是直接拟合函数,而是拟合输入函数和残差之和的函数,这种方式被称为残差块(residual block)。

残差块 (Residual Block) 示意图:

alt text

残差网络的优势在于,它能够训练比传统网络更深更复杂的结构,而不会导致网络性能退化,因为它减轻了梯度弥散问题。

ResNet 架构

ResNet 的主流架构有 ResNet-18,ResNet-34,ResNet-50,ResNet-101,ResNet-152。其中,“18”代表网络的层数。 换句话说, ResNet-34 模型的层数是 34, ResNet-50 模型的层数是 50。我们在实战中通常使用 ResNet-50 等深度网络。

ResNet-50 架构示意图:

alt text

PyTorch 实现

反映这个架构的具体代码实现通常非常复杂,使用 PyTorch 框架可以使代码量更少,已经提供了预训练模型的实现。

代码示例:

import torch
import torch.nn as nn
import torch.optim as optim

class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        
        self.resnet = nn.Sequential(
            nn.Conv2d(3,64,7,stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,128,3,stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128,256,3,stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,512,3,stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1,1))
        )
        
        self.fc = nn.Linear(512, 1000)

    def forward(self, x):
        y = self.resnet(x)
        y = y.view(y.shape[0], -1)
        y = self.fc(y)
        return y

在这个例子中,我们使用 PyTorch 框架,实现了 ResNet 骨干网络的代码。

总结

残差网络是深度学习中非常重要的一类网络,通过引入残差块解决了深度神经网络中的梯度弥散问题,取得了非常优秀的结果,同时深度学习框架 PyTorch 已提供了相应的实现,使得开发和使用 ResNet 变得更加容易。