📜  针对CIFAR-10数据集的LeNet模型的测试(1)

📅  最后修改于: 2023-12-03 15:42:07.189000             🧑  作者: Mango

针对CIFAR-10数据集的LeNet模型的测试

本文主要介绍针对CIFAR-10数据集的LeNet模型的测试方法及其实现过程。CIFAR-10数据集是一个常用的图像分类数据集,其中包含10个不同类别的图像,每个类别包含6000张32 x 32的彩色图像。LeNet模型是神经网络中的一个经典模型,其结构简单,适合用于图像分类问题。

数据集准备

首先,我们需要从官网上下载CIFAR-10数据集,下载地址为:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

下载完成后,我们需要解压缩,得到6个文件:

cifar-10-batches-py/
|-- batches.meta
|-- data_batch_1
|-- data_batch_2
|-- data_batch_3
|-- data_batch_4
|-- data_batch_5
|-- readme.html
|-- test_batch

其中,batches.meta是一个包含标签名称的文件,data_batch_*包含训练图像及其对应标签,test_batch包含测试图像及其对应标签。我们可以使用Python的pickle模块将它们加载到内存中进行处理。

import pickle

# 加载训练数据
with open('cifar-10-batches-py/data_batch_1', mode='rb') as file:
    train_data = pickle.load(file, encoding='bytes')
    
# 加载测试数据
with open('cifar-10-batches-py/test_batch', mode='rb') as file:
    test_data = pickle.load(file, encoding='bytes')
数据集预处理

加载数据后,我们需要对其进行预处理,将其转换为模型可接受的输入格式。首先,我们需要将图像数据转换为tensor类型,并进行归一化处理。其次,我们需要将图像数据和标签分别打包成一个batch,并将其划分为训练集和验证集。

import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# 将图像数据转换为tensor类型,并进行归一化处理
x_train = torch.FloatTensor(train_data[b'data']).view(-1, 3, 32, 32)/255.0
y_train = torch.LongTensor(train_data[b'labels'])
x_test = torch.FloatTensor(test_data[b'data']).view(-1, 3, 32, 32)/255.0
y_test = torch.LongTensor(test_data[b'labels'])

# 将图像数据和标签分别打包成一个TensorDataset
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

# 将训练集划分为训练集和验证集
train_dataset, valid_dataset = train_test_split(train_dataset, test_size=0.2, random_state=42)

# 使用DataLoader将数据加载到模型中
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
LeNet模型的实现

LeNet模型的结构比较简单,总共包含7层,其中包含2个卷积层和3个全连接层。我们可以使用PyTorch搭建如下的LeNet模型。

import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)
    
    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = self.pool1(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16*5*5)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建一个LeNet模型实例
net = LeNet()
训练模型

创建好模型后,我们需要对其进行训练。在训练模型时,我们需要定义损失函数和优化器,并在每个epoch结束时计算验证集的准确率,以便确定模型是否过拟合。

import torch.optim as optim

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(20):  # 共训练20个epoch
    running_loss = 0.0
    for i_batch, (inputs, labels) in enumerate(train_loader, 0):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # 每个epoch结束后计算在验证集上的准确率
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in valid_loader:
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = correct/total
    print('Epoch %d loss: %.3f valid acc: %.3f' % (epoch+1, running_loss/len(train_loader), acc))
测试模型

最后,我们需要使用测试集对模型进行测试,并得到其在测试集上的准确率。

# 在测试集上测试模型
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
acc = correct/total
print('Test acc: %.3f' % acc)

以上就是针对CIFAR-10数据集的LeNet模型的测试的完整实现过程。