Author: fchollet
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in TensorFlow.
import time
import os
# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.
If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().
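For reference, a minimal sketch of that approach might look like the following (CustomModel is a hypothetical name, and this sketch assumes the Keras 3 API; the rest of this guide does not use it):
# A minimal sketch of overriding train_step() to customize what fit() does.
# CustomModel is a hypothetical name; compile() must still supply the loss.
class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data  # assumes fit() receives (inputs, targets) tuples
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply(grads, self.trainable_weights)
        return {"loss": loss}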
Now, if you want very low-level control over training & evaluation, you should write your own training & evaluation loops from scratch. This is what this guide is about.
Let's consider a simple MNIST model:
def get_model():
inputs = keras.Input(shape=(784,), name="digits")
x1 = keras.layers.Dense(64, activation="relu")(inputs)
x2 = keras.layers.Dense(64, activation="relu")(x1)
outputs = keras.layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = get_model()
Let's train it using mini-batch gradient descent with a custom training loop.
First, we're going to need an optimizer, a loss function, and a dataset:
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Calling a model inside a GradientTape scope enables you to retrieve the gradients of the trainable weights of the layers with respect to a loss value. Using an optimizer instance, you can use these gradients to update these variables (which you can retrieve via model.trainable_weights).
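Before writing the full loop, here is a tiny standalone illustration of the mechanism (with a toy variable and loss assumed purely for this sketch):
# A toy example of tf.GradientTape: operations inside the scope are recorded,
# and gradients can then be retrieved with respect to any trainable variable.
w = tf.Variable(2.0)
with tf.GradientTape() as tape:
    loss = w * w  # a stand-in "loss"
grad = tape.gradient(loss, w)
print(float(grad))  # d(w^2)/dw = 2w = 4.0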
Here's our training loop, step by step:

- We open a for loop that iterates over epochs
- For each epoch, we open a for loop that iterates over the dataset, in batches
- For each batch, we open a GradientTape() scope
- Inside this scope, we call the model (forward pass) and compute the loss
- Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
- Finally, we use the optimizer to update the weights of the model based on the gradients

epochs = 3
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
# Open a GradientTape to record the operations run
# during the forward pass, which enables auto-differentiation.
with tf.GradientTape() as tape:
# Run the forward pass of the layer.
# The operations that the layer applies
# to its inputs are going to be recorded
# on the GradientTape.
logits = model(x_batch_train, training=True) # Logits for this minibatch
# Compute the loss value for this minibatch.
loss_value = loss_fn(y_batch_train, logits)
# Use the gradient tape to automatically retrieve
# the gradients of the trainable variables with respect to the loss.
grads = tape.gradient(loss_value, model.trainable_weights)
# Run one step of gradient descent by updating
# the value of the variables to minimize the loss.
optimizer.apply(grads, model.trainable_weights)
# Log every 100 batches.
if step % 100 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
Start of epoch 0
Training loss (for 1 batch) at step 0: 95.3300
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.5622
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.1138
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6748
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.3308
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.9813
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8640
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 1.0696
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3662
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.9556
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7459
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.0468
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.7392
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.8435
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.3859
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.4156
Seen so far: 48032 samples
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.4045
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.5983
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.3154
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.7911
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.2607
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2303
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.6048
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.7041
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3669
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6389
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7739
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3888
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8133
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.2034
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.0768
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1544
Seen so far: 48032 samples
Start of epoch 2
Training loss (for 1 batch) at step 0: 0.1250
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.0152
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.0917
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1330
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.0884
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2656
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.4375
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2246
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.0748
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1765
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.0130
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.4030
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.0667
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 1.0553
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6513
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.0599
Seen so far: 48032 samples
Let's add metrics monitoring to this basic loop.

You can readily reuse the built-in metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop
- Call metric.update_state() after each batch
- Call metric.result() when you need to display the current value of the metric
- Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch)
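As a quick standalone illustration of this flow (with toy labels and predictions assumed purely for the sketch):
# Toy illustration of the metric lifecycle: update -> result -> reset.
m = keras.metrics.SparseCategoricalAccuracy()
m.update_state([1, 2], [[0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])
print(float(m.result()))  # 1.0 -- both argmax predictions match the labels
m.reset_state()  # clear accumulated state, e.g. at the end of an epoch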
Let's use this knowledge to compute SparseCategoricalAccuracy on training and validation data at the end of each epoch:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
Here's our training & evaluation loop:
epochs = 2
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
start_time = time.time()
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
with tf.GradientTape() as tape:
logits = model(x_batch_train, training=True)
loss_value = loss_fn(y_batch_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply(grads, model.trainable_weights)
# Update training metric.
train_acc_metric.update_state(y_batch_train, logits)
# Log every 100 batches.
if step % 100 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print(f"Training acc over epoch: {float(train_acc):.4f}")
# Reset training metrics at the end of each epoch
train_acc_metric.reset_state()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in val_dataset:
val_logits = model(x_batch_val, training=False)
# Update val metrics
val_acc_metric.update_state(y_batch_val, val_logits)
val_acc = val_acc_metric.result()
val_acc_metric.reset_state()
print(f"Validation acc: {float(val_acc):.4f}")
print(f"Time taken: {time.time() - start_time:.2f}s")
Start of epoch 0
Training loss (for 1 batch) at step 0: 89.1303
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 1.0351
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 2.9143
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.7842
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.9583
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.1100
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 2.1144
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6801
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.6202
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 1.2570
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3638
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.8402
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.7836
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.5147
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4798
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1653
Seen so far: 48032 samples
Training acc over epoch: 0.7961
Validation acc: 0.8825
Time taken: 46.06s
Start of epoch 1
Training loss (for 1 batch) at step 0: 1.3917
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2600
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.7206
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.4987
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.3410
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.6788
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.1355
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1762
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.1801
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.3515
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.4344
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.2027
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4649
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6848
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4594
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.3548
Seen so far: 48032 samples
Training acc over epoch: 0.8896
Validation acc: 0.9094
Time taken: 43.49s
Speeding up your training step with tf.function

The default runtime in TensorFlow is eager execution. As such, our training loop above executes eagerly.

This is great for debugging, but graph compilation has a definite performance advantage. Describing your computation as a static graph enables the framework to apply global performance optimizations. This is impossible when the framework is constrained to greedily execute one operation after another, with no knowledge of what comes next.
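To make the mechanism concrete, here is a tiny compiled function unrelated to our model (a toy sketch only; the compiled training step follows below):
# A toy function compiled to a static graph with @tf.function. The first call
# traces the Python function; later calls with matching input signatures reuse
# the compiled graph.
@tf.function
def scaled_sum(a, b):
    return 2.0 * (a + b)

print(scaled_sum(tf.constant(1.0), tf.constant(2.0)))  # tf.Tensor(6.0, shape=(), dtype=float32)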
You can compile into a static graph any function that takes tensors as input. Just add a @tf.function decorator on it, like this:
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = model(x, training=True)
loss_value = loss_fn(y, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply(grads, model.trainable_weights)
train_acc_metric.update_state(y, logits)
return loss_value
Let's do the same with the evaluation step:
@tf.function
def test_step(x, y):
val_logits = model(x, training=False)
val_acc_metric.update_state(y, val_logits)
Now, let's re-run our training loop with this compiled training step:
epochs = 2
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
start_time = time.time()
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
loss_value = train_step(x_batch_train, y_batch_train)
# Log every 100 batches.
if step % 100 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print(f"Training acc over epoch: {float(train_acc):.4f}")
# Reset training metrics at the end of each epoch
train_acc_metric.reset_state()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in val_dataset:
test_step(x_batch_val, y_batch_val)
val_acc = val_acc_metric.result()
val_acc_metric.reset_state()
print(f"Validation acc: {float(val_acc):.4f}")
print(f"Time taken: {time.time() - start_time:.2f}s")
Start of epoch 0
Training loss (for 1 batch) at step 0: 0.5366
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2732
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.2478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.0263
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4845
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2239
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.2242
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2122
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.2856
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1957
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.2946
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3080
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.2326
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6514
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.2018
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2812
Seen so far: 48032 samples
Training acc over epoch: 0.9104
Validation acc: 0.9199
Time taken: 5.73s
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.3080
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.3943
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.1657
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1463
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.5359
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.1894
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.1801
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1724
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3997
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6017
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1539
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.1078
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8731
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.3110
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6092
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2046
Seen so far: 48032 samples
Training acc over epoch: 0.9189
Validation acc: 0.9358
Time taken: 3.17s
Much faster, isn't it?
Layers & models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.
Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
def call(self, inputs):
self.add_loss(1e-2 * tf.reduce_sum(inputs))
return inputs
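As a quick sanity check (a toy call, not part of the model we're about to build), you can invoke the layer directly and inspect the loss it records:
# Calling the layer records its activity regularization loss in layer.losses.
layer = ActivityRegularizationLayer()
_ = layer(tf.ones((2, 2)))
print(layer.losses)  # a single scalar: 1e-2 * sum of a 2x2 tensor of ones = 0.04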
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Here's what our training step should look like now:
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = model(x, training=True)
loss_value = loss_fn(y, logits)
# Add any extra losses created during the forward pass.
loss_value += sum(model.losses)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply(grads, model.trainable_weights)
train_acc_metric.update_state(y, logits)
return loss_value
Now you know everything there is to know about using built-in training loops and writing your own from scratch.

To conclude, here's a simple end-to-end example that ties together everything you've learned in this guide: a DCGAN trained on MNIST digits.
You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new images that look almost real, by learning the latent distribution of a training dataset of images (the "latent space" of the images).

A GAN is made of two parts: a "generator" model that maps points in the latent space to points in image space, and a "discriminator" model, a classifier that can tell the difference between real images (from the training dataset) and fake images (the output of the generator network).

Schematically, the GAN training loop looks like this:
1) Train the discriminator. - Sample a batch of random points in the latent space. - Turn the points into fake images via the "generator" model. - Get a batch of real images and combine them with the generated images. - Train the "discriminator" model to classify generated vs. real images.

2) Train the generator. - Sample random points in the latent space. - Turn the points into fake images via the "generator" network. - Train the "generator" model to "fool" the discriminator and classify the fake images as real.
For a much more detailed overview of how GANs work, see Deep Learning with Python.

Let's implement this training loop. First, create the discriminator meant to classify fake vs real digits:
discriminator = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.GlobalMaxPooling2D(),
keras.layers.Dense(1),
],
name="discriminator",
)
discriminator.summary()
Model: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 14, 14, 64)        │        640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu (LeakyReLU)         │ (None, 14, 14, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 7, 7, 128)         │     73,856 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_1 (LeakyReLU)       │ (None, 7, 7, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_max_pooling2d            │ (None, 128)               │          0 │
│ (GlobalMaxPooling2D)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_6 (Dense)                 │ (None, 1)                 │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 74,625 (291.50 KB)
Trainable params: 74,625 (291.50 KB)
Non-trainable params: 0 (0.00 B)
Then let's create a generator network, that turns latent vectors into outputs of shape (28, 28, 1) (representing MNIST digits):
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
# We want to generate 128 coefficients to reshape into a 7x7x128 map
keras.layers.Dense(7 * 7 * 128),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Reshape((7, 7, 128)),
keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
Here's the key bit: the training loop. As you can see, it is quite straightforward. The training step function only takes 17 lines.
# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)
# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
@tf.function
def train_step(real_images):
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
# Decode them to fake images
generated_images = generator(random_latent_vectors)
# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * tf.random.uniform(labels.shape)
# Train the discriminator
with tf.GradientTape() as tape:
predictions = discriminator(combined_images)
d_loss = loss_fn(labels, predictions)
grads = tape.gradient(d_loss, discriminator.trainable_weights)
d_optimizer.apply(grads, discriminator.trainable_weights)
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = discriminator(generator(random_latent_vectors))
g_loss = loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, generator.trainable_weights)
g_optimizer.apply(grads, generator.trainable_weights)
return d_loss, g_loss, generated_images
Let's train our GAN, by repeatedly calling train_step on batches of images.

Since our discriminator and generator are convnets, you're going to want to run this code on a GPU.
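If you're not sure whether a GPU is visible to TensorFlow, a quick check (a convenience snippet, not part of the original loop) is:
# List the GPUs TensorFlow can see; an empty list means everything runs on CPU.
print(tf.config.list_physical_devices("GPU"))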
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
epochs = 1 # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"
for epoch in range(epochs):
print(f"\nStart epoch {epoch}")
for step, real_images in enumerate(dataset):
# Train the discriminator & generator on one batch of real images.
d_loss, g_loss, generated_images = train_step(real_images)
# Logging.
if step % 100 == 0:
# Print metrics
print(f"discriminator loss at step {step}: {d_loss:.2f}")
print(f"adversarial loss at step {step}: {g_loss:.2f}")
# Save one generated image
img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
img.save(os.path.join(save_dir, f"generated_img_{step}.png"))
# To limit execution time we stop after 10 steps.
# Remove the lines below to actually train the model!
if step > 10:
break
Start epoch 0
discriminator loss at step 0: 0.69
adversarial loss at step 0: 0.69
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on a Colab GPU.