Last modified: 2023-12-03 14:47:28.343000             Author: Mango

Sklearn KMeans MNIST - Python

Introduction

This guide demonstrates how to run K-means clustering on the MNIST dataset of handwritten digits using the scikit-learn library (sklearn) in Python. K-means is an unsupervised machine learning algorithm that partitions data into a chosen number of clusters, k, by assigning each point to the cluster whose centroid (mean) is nearest.
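
To make the idea concrete, here is a minimal NumPy sketch of the standard K-means iteration (often called Lloyd's algorithm): assign each point to its nearest centroid, then move each centroid to the mean of its assigned points. The function name simple_kmeans, its parameters, and the toy data are illustrative assumptions. scikit-learn's KMeans uses a smarter initialization (k-means++) and other optimizations, so this is not its implementation.

```python
import numpy as np

def simple_kmeans(X, k, n_iter=20, seed=0):
    """Illustrative K-means (Lloyd's algorithm); not scikit-learn's implementation."""
    rng = np.random.default_rng(seed)
    # Start from k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: squared distance from every point to every centroid
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Update step: move each centroid to the mean of its assigned points
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # keep the old centroid if a cluster is empty
                centroids[j] = members.mean(axis=0)
    return labels, centroids

# Toy data (assumption): two well-separated 2-D blobs
rng = np.random.default_rng(1)
toy = np.vstack([rng.normal(0.0, 0.5, (50, 2)), rng.normal(5.0, 0.5, (50, 2))])
toy_labels, toy_centroids = simple_kmeans(toy, k=2)
print(toy_centroids)
```

On this toy data the two blobs end up in separate clusters. On real image data the result depends on the initial centroids, which is why fixing a random seed matters for reproducibility.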

Prerequisites
  1. Python (version 3 or above)
  2. scikit-learn, NumPy, and matplotlib (pip install scikit-learn numpy matplotlib)
  3. MNIST dataset (downloaded from OpenML by fetch_openml on first use, so an internet connection is needed)
Code Example
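# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import fetch_openml

# Fetch MNIST dataset as NumPy arrays (as_frame=False avoids a pandas DataFrame,
# which would break the integer indexing and reshape calls below)
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

# Prepare the data: scale pixel values from 0-255 to 0-1
X = mnist.data / 255.0

# Initialize the K-means model
# (n_init is set explicitly because its default differs between scikit-learn versions)
kmeans = KMeans(n_clusters=10, n_init=10, random_state=42)

# Fit the model to the data
kmeans.fit(X)

# Get the cluster label assigned to each image
labels = kmeans.labels_

# Plot one random image from each cluster
fig, axs = plt.subplots(2, 5, figsize=(12, 6))
for i in range(10):
    cluster_samples = X[labels == i]
    random_sample = cluster_samples[np.random.randint(cluster_samples.shape[0])]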
    axs[i // 5, i % 5].imshow(random_sample.reshape(28, 28), cmap='gray')
    axs[i // 5, i % 5].axis('off')
    axs[i // 5, i % 5].set_title(f'Cluster {i}')
plt.tight_layout()
plt.show()
Explanation
  1. The required libraries are imported: the KMeans class from sklearn.cluster, the fetch_openml function from sklearn.datasets, and the pyplot module from matplotlib. The plotting step also relies on NumPy (np) for random sampling, so it must be imported as well.
  2. The MNIST dataset (70,000 images of 28×28 = 784 pixels each) is fetched from OpenML with the fetch_openml function. By default this can return the data as a pandas DataFrame; passing as_frame=False returns NumPy arrays instead, which is what the indexing and reshape calls in the plotting step expect.
  3. The pixel values are normalized by dividing by 255, which scales them from the range 0-255 to the range 0-1.
  4. A KMeans model is created with n_clusters=10 (one cluster per digit class, although the clusters found need not match the digits exactly) and random_state=42 for reproducibility.
  5. The model is fitted to the normalized data using the fit method.
  6. The cluster label assigned to each image is read from the labels_ attribute of the fitted model.
  7. One randomly chosen image from each cluster is displayed in a 2×5 grid using matplotlib's imshow function.
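
Because K-means never sees the digit labels, a cluster index such as 3 has no built-in connection to the digit 3. One way to check how well the clusters line up with the digits is to map each cluster to the most common true digit among its members and measure how many images agree with that mapping (sometimes called purity). The sketch below is an illustration under assumptions: it uses scikit-learn's small bundled digits dataset (load_digits) as a stand-in for MNIST so that it runs quickly, and the names cluster_to_digit and purity are made up for this example. With MNIST, mnist.target holds the labels as strings, so they would need converting with astype(int) first.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits

# Stand-in data (assumption): scikit-learn's bundled 8x8 digits, not MNIST
digits = load_digits()
X_small = digits.data / 16.0
y_true = digits.target  # true digit labels, used only for evaluation

km = KMeans(n_clusters=10, n_init=10, random_state=42).fit(X_small)

# For each cluster, find the most common true digit among its members
cluster_to_digit = {}
for c in range(10):
    members = y_true[km.labels_ == c]
    cluster_to_digit[c] = int(np.bincount(members, minlength=10).argmax())

# Fraction of images whose true digit matches their cluster's majority digit
mapped = np.array([cluster_to_digit[c] for c in km.labels_])
purity = float((mapped == y_true).mean())
print(cluster_to_digit)
print(f"purity: {purity:.3f}")
```

Several clusters may map to the same digit, and some digits may not be the majority in any cluster, which shows that 10 clusters do not automatically mean 10 digit classes.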
Conclusion

With scikit-learn's KMeans, K-means clustering on the MNIST dataset takes only a few lines of code. The algorithm groups visually similar digit images without using any labels, although the resulting clusters do not necessarily correspond one-to-one to the ten digits. K-means can also be applied to other tasks such as image compression and anomaly detection.