作者: fchollet
建立日期 2019/04/29
上次修改日期 2023/12/21
描述: 一個簡單的 DCGAN,透過覆寫 `train_step` 並使用 `fit()`,在 CelebA 圖像上進行訓練。
import keras
import tensorflow as tf
from keras import layers
from keras import ops
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile
我們將使用 CelebA 資料集中的人臉圖像,調整大小為 64x64。
# Download the CelebA archive and unpack it under ./celeba_gan.
# exist_ok=True keeps this step re-runnable: a bare os.makedirs() raises
# FileExistsError when the directory is left over from a previous run.
os.makedirs("celeba_gan", exist_ok=True)

url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
output = "celeba_gan/data.zip"
gdown.download(url, output, quiet=True)

# `output` is the archive path gdown was just asked to write; extract its
# contents into the same directory.
with ZipFile(output, "r") as zipobj:
    zipobj.extractall("celeba_gan")
從我們的資料夾建立一個資料集,並將圖像重新縮放到 [0-1] 範圍
# Read the images under celeba_gan/ as unlabeled batches of 32, each image
# resized to 64x64.
raw_images = keras.utils.image_dataset_from_directory(
    directory="celeba_gan",
    label_mode=None,
    image_size=(64, 64),
    batch_size=32,
)

# Divide every pixel value by 255 before the images reach the models.
dataset = raw_images.map(lambda batch: batch / 255.0)
Found 202599 files.
讓我們顯示一個範例圖像
# Plot the first image of the first batch the dataset yields, then stop.
for x in dataset:
    first_image = x.numpy()[0]
    plt.axis("off")
    # Undo the /255 scaling and cast to integers for display.
    plt.imshow((first_image * 255).astype("int32"))
    break
鑑別器將 64x64 的圖像映射到二元分類分數。
# Discriminator: three stride-2 Conv2D + LeakyReLU stages followed by a
# flatten / dropout / single-unit sigmoid head.
# Layers are instantiated in the same order as the stack itself, so their
# auto-generated names match the summary output.
_discriminator_stack = [keras.Input(shape=(64, 64, 3))]
for _filters in (64, 128, 128):
    _discriminator_stack.append(
        layers.Conv2D(_filters, kernel_size=4, strides=2, padding="same")
    )
    _discriminator_stack.append(layers.LeakyReLU(negative_slope=0.2))
_discriminator_stack.append(layers.Flatten())
_discriminator_stack.append(layers.Dropout(0.2))
_discriminator_stack.append(layers.Dense(1, activation="sigmoid"))

discriminator = keras.Sequential(_discriminator_stack, name="discriminator")
discriminator.summary()
Model: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 32, 32, 64)        │      3,136 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu (LeakyReLU)         │ (None, 32, 32, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 16, 16, 128)       │    131,200 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_1 (LeakyReLU)       │ (None, 16, 16, 128)       │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 8, 8, 128)         │    262,272 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_2 (LeakyReLU)       │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 8192)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 8192)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 1)                 │      8,193 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 404,801 (1.54 MB)
Trainable params: 404,801 (1.54 MB)
Non-trainable params: 0 (0.00 B)
生成器的結構與鑑別器互為鏡像,將 `Conv2D` 層替換為 `Conv2DTranspose` 層。
latent_dim = 128

# Generator: project the latent vector to an 8x8x128 feature map, upsample it
# through three stride-2 Conv2DTranspose + LeakyReLU stages, then map to
# 3 channels with a sigmoid-activated Conv2D.
# Layers are instantiated in the same order as the stack itself, so their
# auto-generated names match the summary output.
_generator_stack = [
    keras.Input(shape=(latent_dim,)),
    layers.Dense(8 * 8 * 128),
    layers.Reshape((8, 8, 128)),
]
for _filters in (128, 256, 512):
    _generator_stack.append(
        layers.Conv2DTranspose(_filters, kernel_size=4, strides=2, padding="same")
    )
    _generator_stack.append(layers.LeakyReLU(negative_slope=0.2))
_generator_stack.append(
    layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")
)

generator = keras.Sequential(_generator_stack, name="generator")
generator.summary()
Model: "generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ dense_1 (Dense)                 │ (None, 8192)              │  1,056,768 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ reshape (Reshape)               │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_transpose                │ (None, 16, 16, 128)       │    262,272 │
│ (Conv2DTranspose)               │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_3 (LeakyReLU)       │ (None, 16, 16, 128)       │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_transpose_1              │ (None, 32, 32, 256)       │    524,544 │
│ (Conv2DTranspose)               │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_4 (LeakyReLU)       │ (None, 32, 32, 256)       │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_transpose_2              │ (None, 64, 64, 512)       │  2,097,664 │
│ (Conv2DTranspose)               │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_5 (LeakyReLU)       │ (None, 64, 64, 512)       │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 64, 64, 3)         │     38,403 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 3,979,651 (15.18 MB)
Trainable params: 3,979,651 (15.18 MB)
Non-trainable params: 0 (0.00 B)
覆寫 `train_step`
class GAN(keras.Model):
    """Pair a discriminator and a generator and train them with `fit()`.

    `train_step` is overridden so that each batch performs one discriminator
    update followed by one generator update.

    Label convention used throughout: generated images are labelled 1 and
    real images are labelled 0.

    Args:
        discriminator: Model called on image batches; its output is passed
            to `loss_fn` as the prediction.
        generator: Model called on latent batches of shape
            `(batch_size, latent_dim)`; its output is fed to the
            discriminator.
        latent_dim: Size of each latent vector sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Single seed stream shared by every random draw in train_step.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers and the loss, and create loss trackers.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable invoked as `loss_fn(labels, predictions)`.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Return the discriminator and generator loss trackers."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update, then one generator update.

        Args:
            real_images: One batch of real images from the dataset.

        Returns:
            Dict with the running means of `d_loss` and `g_loss`.
        """
        # Sample random points in the latent space
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images (fakes first, then reals)
        combined_images = ops.concatenate([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images:
        # ones for the generated half, zeros for the real half
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Add random noise to the labels - important trick!
        # Drawn through the model's SeedGenerator, like the latent samples,
        # rather than from an unseeded backend-specific RNG.
        labels += 0.05 * keras.random.uniform(
            ops.shape(labels), seed=self.seed_generator
        )

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images" (zeros, per the
        # convention above)
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)! Gradients are taken only with respect to the
        # generator's weights.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves generated sample images at the end of each epoch.

    Args:
        num_img: Number of images to generate and save per epoch.
        latent_dim: Size of each latent vector fed to the generator.
    """

    def __init__(self, num_img=3, latent_dim=128):
        # Let the base Callback set up its own state before adding ours.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(42)

    def on_epoch_end(self, epoch, logs=None):
        """Generate `num_img` images and write them as PNG files.

        Files are written to the current working directory as
        `generated_img_<epoch, zero-padded to 3 digits>_<index>.png`.
        """
        random_latent_vectors = keras.random.normal(
            shape=(self.num_img, self.latent_dim), seed=self.seed_generator
        )
        generated_images = self.model.generator(random_latent_vectors)
        # Scale by 255 and keep the converted NumPy array; the conversion
        # result was previously computed and thrown away.
        generated_images = ops.convert_to_numpy(generated_images * 255)
        for i, image in enumerate(generated_images):
            img = keras.utils.array_to_img(image)
            img.save(f"generated_img_{epoch:03d}_{i}.png")
epochs = 1  # In practice, use ~100 epochs

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)

# Each network gets its own Adam optimizer with the same learning rate.
_d_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
_g_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
gan.compile(
    d_optimizer=_d_optimizer,
    g_optimizer=_g_optimizer,
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# Save 10 generated samples at the end of every epoch while training.
_monitor = GANMonitor(num_img=10, latent_dim=latent_dim)
gan.fit(dataset, epochs=epochs, callbacks=[_monitor])
2/6332 [37m━━━━━━━━━━━━━━━━━━━━ 9:54 94ms/step - d_loss: 0.6792 - g_loss: 0.7880
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704214667.959762 1319 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
6332/6332 ━━━━━━━━━━━━━━━━━━━━ 557s 84ms/step - d_loss: 0.5616 - g_loss: 1.4099
<keras.src.callbacks.history.History at 0x7f251d32bc40>
以下是在大約第 30 個 epoch 時最後生成的一些圖像(之後結果會持續改善)