📜  使用 Keras 构建生成对抗网络(1)

📅  最后修改于: 2023-12-03 15:36:31.496000             🧑  作者: Mango

使用 Keras 构建生成对抗网络

什么是生成对抗网络(GAN)

生成对抗网络(GAN)是一种深度学习模型,用于生成和合成现实世界中的图像、声音和文本等数据。GAN 由两个深度神经网络组成,即生成器和判别器。生成器将输入的随机向量转换为新图像或数据,而判别器则尝试区分生成器生成的图像和真实的图像。因此,生成器必须尝试生成逼真的图像,以欺骗判别器,而判别器必须尝试尽可能准确地区分这些图像。

如何使用 Keras 构建 GAN

首先,我们需要导入必要的库,包括 Keras、NumPy 和 Matplotlib。代码示例如下:

import numpy as np
from tensorflow import keras
from matplotlib import pyplot as plt
构建生成器模型

我们通过 Keras 的 Sequential API 来构建生成器模型。生成器的目的是将一个随机向量转换为一个类似真实图像的向量。代码示例如下:

def build_generator():
    model = keras.Sequential()
    model.add(keras.layers.Dense(256, input_shape=(100,), activation='relu'))
    model.add(keras.layers.Dense(512, activation='relu'))
    model.add(keras.layers.Dense(784, activation='tanh'))
    model.add(keras.layers.Reshape((28,28)))
    return model

在上面的代码中,我们首先定义了一个 Sequential 对象,然后添加了三个 Dense 层,分别包含 256、512 和 784 个神经元。激活函数为 ReLU,输出使用了 tanh 激活函数,并通过 Reshape 将输出转换为 28x28 的图像。

构建判别器模型

同样地,我们可以使用 Keras 的 Sequential API 来构建判别器模型。判别器的目的是将输入的图像分类为真实图像或生成图像。代码示例如下:

def build_discriminator():
    model = keras.Sequential()
    model.add(keras.layers.Flatten(input_shape=(28,28)))
    model.add(keras.layers.Dense(512, activation='relu'))
    model.add(keras.layers.Dense(256, activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    return model

在上面的代码中,我们首先使用 Flatten 层将输入的图像转换为一维向量,然后添加了三个 Dense 层,分别包含 512、256 和 1 个神经元。激活函数为 ReLU,输出使用了 sigmoid 激活函数。

构建完整的 GAN 模型

现在,我们已经构建了生成器和判别器模型。接下来,我们需要将它们组合在一起,构建一个完整的 GAN 模型。代码示例如下:

def build_gan(generator, discriminator):
    model = keras.Sequential()
    model.add(generator)
    discriminator.trainable = False
    model.add(discriminator)
    return model

在上面的代码中,我们首先定义了一个 Sequential 对象,然后将生成器和判别器模型依次添加到其中。我们还将判别器的可训练性设置为 False,以确保在训练 GAN 时只更新生成器的权重。

GAN 的训练和生成

现在,我们已经构建完成了 GAN 模型,我们可以开始训练 GAN 并生成新图像了。对于 MNIST 数据集,每张图像都是 28x28 像素的灰度图像。我们可以使用以下代码加载 MNIST 数据集:

from keras.datasets import mnist
(train_images, _), (_, _) = mnist.load_data()

# 像素值归一化
train_images = (train_images - 127.5) / 127.5

接下来,我们定义一些训练 GAN 的超参数,并使用以下代码进行训练:

# 超参数
epochs = 50
batch_size = 256

# 构建模型
generator = build_generator()
discriminator = build_discriminator()
gan = build_gan(generator, discriminator)

# 编译模型
generator.compile(loss='binary_crossentropy', optimizer='adam')
gan.compile(loss='binary_crossentropy', optimizer='adam')

# 训练模型
for epoch in range(1, epochs+1):
    print('[%d/%d]' % (epoch, epochs))
    for batch in range(train_images.shape[0] // batch_size):
        # 生成噪声
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        # 生成图像
        generated_images = generator.predict(noise)
        # 从真实图像中随机选择一批数据
        real_images = train_images[np.random.randint(0, train_images.shape[0], size=batch_size)]
        # 合并数据
        images = np.concatenate([generated_images, real_images])
        # 标签应为 0(生成图像)或 1(真实图像)
        labels = np.zeros(2 * batch_size)
        labels[:batch_size] = 1
        # 训练判别器
        discriminator.trainable = True
        loss_d = discriminator.train_on_batch(images, labels)
        # 生成噪声
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        # 标签应为 1,因为我们希望生成图像被误认为是真实图像
        labels = np.ones(batch_size)
        # 训练生成器
        discriminator.trainable = False
        loss_g = gan.train_on_batch(noise, labels)
        # 打印损失
        print('batch: %d, loss_d: %f, loss_g: %f' % (batch+1, loss_d, loss_g))

通过以上代码进行训练,我们可以看到 GAN 不断生成更加逼真的数字图像。您可以使用以下代码生成新图像:

# 生成新图像
noise = np.random.normal(0, 1, size=(16, 100))
generated_images = generator.predict(noise)

# 缩放像素值到 0-1 范围
generated_images = (generated_images + 1) / 2

# 显示图像
fig, axes = plt.subplots(4, 4, figsize=(10,10))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated_images[i], cmap='gray')
    ax.axis('off')
plt.show()
总结

这篇文章向您介绍了如何使用 Keras 构建生成对抗网络(GAN),并生成 MNIST 数据集的图像。除此之外,您还可以使用 GAN 生成其他类型的图像、声音或文本数据。GAN 可用于许多领域,如数字艺术、语音生成、数据合成等。如果您对深度学习和 GAN 感兴趣,这篇文章会是一个不错的起点。