📜  如何导入 mnist 数据集 keras - Python (1)

📅  最后修改于: 2023-12-03 15:24:42.722000             🧑  作者: Mango

如何导入 MNIST 数据集到 Keras

如果您正在使用 Keras 构建神经网络,您可能会需要导入它的默认数据集之一:MNIST,这是一个手写数字数据集,有 60,000 张训练图片和 10,000 张测试图片。

以下是在 Python 中使用 Keras 导入 MNIST 数据集的步骤。

首先,我们需要导入 Keras 库和 MNIST 数据集:

from keras.datasets import mnist

# 加载 MNIST 数据集,并将数据集分为训练数据和测试数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

在这里,我们使用 mnist.load_data() 方法来下载和加载 MNIST 数据集。返回的数据集具有以下属性:

  • train_imagestrain_labels:这是训练集,即用于训练模型的输入数据和真实标签。
  • test_imagestest_labels:这是测试集,即用于评估模型的输入数据和真实标签。

现在,我们已经成功加载了 MNIST 数据集。在继续之前,让我们先看一下数据集的结构。

# 查看训练数据集的形状
print(train_images.shape)    # (60000, 28, 28)
print(train_labels.shape)    # (60000,)

# 查看测试数据集的形状
print(test_images.shape)     # (10000, 28, 28)
print(test_labels.shape)     # (10000,)

我们可以看到训练数据集和测试数据集形状相同,都有 28x28 像素的图像,但它们的数量不同。

接下来,我们可以将数据集可视化,以便更好地了解它们的结构和特征。在这个例子中,我们将显示前 10 张图像,并添加它们的真实标签。

import matplotlib.pyplot as plt

# 将数据集可视化
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i, ax in enumerate(axes.flat):
    ax.imshow(train_images[i], cmap='binary', interpolation='nearest')
    ax.set(xticks=[], yticks=[], xlabel=train_labels[i])

MNIST 数据集可视化

我们可以看到这些手写数字都在黑白图像中呈现,它们的分辨率相对较低,但我们对它们的形状和特征可以做出准确的预测。

到此为止,您已经学会了如何导入 MNIST 数据集到 Keras 中,如何查看数据集的结构和特征,并可视化这些图像。接下来,您可以利用这些数据训练神经网络模型并进行预测。