📜  TensorFlow 中的 CIFAR-10 图像分类(1)

📅  最后修改于: 2023-12-03 15:05:32.363000             🧑  作者: Mango

TensorFlow 中的 CIFAR-10 图像分类

简介

CIFAR-10 是一个经典的图像分类数据集,包括 60,000 个 32x32 像素的彩色图像,分为 10 个不同的类别,每类有 6,000 个图片。本文将介绍如何使用 TensorFlow 进行 CIFAR-10 图像分类。

步骤
  1. 准备数据

    • 下载 CIFAR-10 数据集。可以在 CIFAR-10 官网 上下载。
    • 解压并放置文件夹,并确保你的代码可以读取。
    • 使用 TensorFlow 的 tf.data.Dataset API 加载数据。具体实现可以参照以下代码片段:
    import tensorflow as tf
    
    BATCH_SIZE = 128
    
    def load_data():
        (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
        # Normalize pixel values to be between 0 and 1
        train_images, test_images = train_images / 255.0, test_images / 255.0
        # Convert labels to categorical one-hot encoding
        train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
        test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
        train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
        train_dataset = train_dataset.shuffle(buffer_size=50000).batch(BATCH_SIZE)
        test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
        test_dataset = test_dataset.batch(BATCH_SIZE)
        return train_dataset, test_dataset
    
  2. 建立模型

    • 使用 TensorFlow 中的 tf.keras 建立卷积神经网络模型,可以参照以下代码片段:
    from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
    
    NUM_CLASSES = 10
    
    def build_model():
        model = tf.keras.Sequential([
            Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
            MaxPooling2D((2, 2)),
            Conv2D(64, (3, 3), activation='relu'),
            MaxPooling2D((2, 2)),
            Conv2D(64, (3, 3), activation='relu'),
            Flatten(),
            Dense(64, activation='relu'),
            Dropout(0.5),
            Dense(NUM_CLASSES, activation='softmax')
        ])
        return model
    
  3. 训练模型

    • 定义损失函数、优化器和评价指标。
    • 训练模型,使用 tf.keras.Model 中的 fit() 函数,可以参照以下代码片段:
    EPOCHS = 10
    
    def train(model, train_dataset, test_dataset):
        model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        model.fit(train_dataset, epochs=EPOCHS,
                  validation_data=test_dataset, verbose=2)
    
  4. 评估模型

    • 使用 tf.keras.Model 中的 evaluate() 函数评估模型,可以参照以下代码片段:
    def evaluate(model, test_dataset):
        loss, accuracy = model.evaluate(test_dataset, verbose=2)
        print(f'Test loss: {loss}, Test accuracy: {accuracy}')
    
  5. 预测

    • 使用 tf.keras.Model 中的 predict() 函数预测新的图片,可以参照以下代码片段:
    def predict(model, new_images):
        predictions = model.predict(new_images)
        return predictions
    
总结

本文介绍了如何使用 TensorFlow 对 CIFAR-10 图像分类,涉及数据加载、模型建立、训练、评估和预测。如果你希望深入了解 TensorFlow,建议查看 TensorFlow 的官方文档。