fit()
中的行為作者: fchollet
建立日期 2020/04/15
最後修改日期 2023/06/27
描述: 使用 TensorFlow 覆寫 Model 類別的訓練步驟。
當您進行監督式學習時,可以使用 fit()
,一切都會順利進行。
當您需要控制每個細節時,您可以從頭開始撰寫自己的訓練迴圈。
但是,如果您需要自訂的訓練演算法,但仍想從 fit()
的便利功能中獲益,例如回呼函數、內建的分散式支援或步驟融合呢?
Keras 的核心原則是逐步揭露複雜性。您應該始終能夠以漸進的方式進入較低層級的工作流程。如果高階功能與您的用例不完全匹配,您不應該遇到困難。您應該能夠在保留相應數量的高階便利性的同時,獲得對細節的更多控制權。
當您需要自訂 fit()
的行為時,您應該覆寫 Model
類別的訓練步驟函數。這是 fit()
針對每個資料批次呼叫的函數。然後,您將能夠像往常一樣呼叫 fit()
,並且它將執行您自己的學習演算法。
請注意,此模式不會阻止您使用函數式 API 建立模型。無論您建立的是 Sequential
模型、函數式 API 模型還是子類別化模型,都可以這樣做。
讓我們看看它是如何運作的。
import os
# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
from keras import layers
import numpy as np
讓我們從一個簡單的範例開始
keras.Model
的新類別。train_step(self, data)
。輸入參數 data
是傳遞給 fit 作為訓練資料的內容
fit(x, y, ...)
,則 data
將是元組 (x, y)
tf.data.Dataset
,透過呼叫 fit(dataset, ...)
,則 data
將是 dataset
在每個批次產生的內容。在 train_step()
方法的主體中,我們實作一個常規的訓練更新,類似於您已經熟悉的方式。重要的是,我們透過 self.compute_loss()
計算損失,它會包裝傳遞給 compile()
的損失函數。
類似地,我們對 self.metrics
中的指標呼叫 metric.update_state(y, y_pred)
,以更新 compile()
中傳入的指標狀態,並在最後查詢 self.metrics
的結果以擷取它們目前的值。
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compute_loss(y=y, y_pred=y_pred)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply(gradients, trainable_vars)
# Update metrics (includes the metric that tracks the loss)
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
讓我們試試看
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.5089 - loss: 0.3778
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 318us/step - mae: 0.3986 - loss: 0.2466
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 372us/step - mae: 0.3848 - loss: 0.2319
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699222602.443035 1 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
<keras.src.callbacks.history.History at 0x2a5599f00>
當然,您可以直接跳過在 compile()
中傳遞損失函數,而是在 train_step
中手動完成所有操作。指標也是如此。
這是一個較低階的範例,僅使用 compile()
來設定最佳化器
Metric
實例來追蹤我們的損失和 MAE 分數(在 __init__()
中)。train_step()
,它會更新這些指標的狀態(透過對它們呼叫 update_state()
),然後查詢它們(透過 result()
)以傳回它們目前的平均值,以顯示在進度列中並傳遞給任何回呼函數。reset_states()
!否則,呼叫 result()
將傳回自訓練開始以來的平均值,而我們通常使用每個 epoch 的平均值。幸運的是,框架可以為我們執行此操作:只需在模型的 metrics
屬性中列出任何您想要重設的指標即可。模型會在每個 fit()
epoch 的開始或呼叫 evaluate()
的開始,對此處列出的任何物件呼叫 reset_states()
。class CustomModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_tracker = keras.metrics.Mean(name="loss")
self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
self.loss_fn = keras.losses.MeanSquaredError()
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute our own loss
loss = self.loss_fn(y, y_pred)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply(gradients, trainable_vars)
# Compute our own metrics
self.loss_tracker.update_state(loss)
self.mae_metric.update_state(y, y_pred)
return {
"loss": self.loss_tracker.result(),
"mae": self.mae_metric.result(),
}
@property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
return [self.loss_tracker, self.mae_metric]
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 4.0292 - mae: 1.9270
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 385us/step - loss: 2.2155 - mae: 1.3920
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 336us/step - loss: 1.1863 - mae: 0.9700
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 373us/step - loss: 0.6510 - mae: 0.6811
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 330us/step - loss: 0.4059 - mae: 0.5094
<keras.src.callbacks.history.History at 0x2a7a02860>
sample_weight
& class_weight
您可能已經注意到我們第一個基本範例沒有提及樣本加權。如果您想要支援 fit()
引數 sample_weight
和 class_weight
,您只需執行下列操作
data
引數解包 sample_weight
compute_loss
和 update_state
(當然,如果您不依賴 compile()
來計算損失和指標,您也可以手動套用它)class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
sample_weight = None
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value.
# The loss function is configured in `compile()`.
loss = self.compute_loss(
y=y,
y_pred=y_pred,
sample_weight=sample_weight,
)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply(gradients, trainable_vars)
# Update the metrics.
# Metrics are configured in `compile()`.
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred, sample_weight=sample_weight)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.4228 - loss: 0.1420
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 449us/step - mae: 0.3751 - loss: 0.1058
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 337us/step - mae: 0.3478 - loss: 0.0951
<keras.src.callbacks.history.History at 0x2a7491780>
如果您想對 model.evaluate()
的呼叫執行相同的操作怎麼辦?然後,您將以完全相同的方式覆寫 test_step
。以下是它的外觀
class CustomModel(keras.Model):
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_pred = self(x, training=False)
# Updates the metrics tracking the loss
loss = self.compute_loss(y=y, y_pred=y_pred)
# Update the metrics.
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 927us/step - mae: 0.8518 - loss: 0.9166
[0.912325382232666, 0.8567370176315308]
讓我們逐步介紹一個利用您剛學到的一切的端對端範例。
讓我們考慮
# Create the discriminator
discriminator = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
],
name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
# We want to generate 128 coefficients to reshape into a 7x7x128 map
layers.Dense(7 * 7 * 128),
layers.LeakyReLU(negative_slope=0.2),
layers.Reshape((7, 7, 128)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
這是一個功能完整的 GAN 類別,覆寫 compile()
以使用自己的簽名,並在 train_step
中以 17 行程式碼實作整個 GAN 演算法
class GAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super().__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
self.seed_generator = keras.random.SeedGenerator(1337)
@property
def metrics(self):
return [self.d_loss_tracker, self.g_loss_tracker]
def compile(self, d_optimizer, g_optimizer, loss_fn):
super().compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, real_images):
if isinstance(real_images, tuple):
real_images = real_images[0]
# Sample random points in the latent space
batch_size = tf.shape(real_images)[0]
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=self.seed_generator
)
# Decode them to fake images
generated_images = self.generator(random_latent_vectors)
# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * keras.random.uniform(
tf.shape(labels), seed=self.seed_generator
)
# Train the discriminator
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply(grads, self.discriminator.trainable_weights)
# Sample random points in the latent space
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=self.seed_generator
)
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply(grads, self.generator.trainable_weights)
# Update metrics and return their value.
self.d_loss_tracker.update_state(d_loss)
self.g_loss_tracker.update_state(g_loss)
return {
"d_loss": self.d_loss_tracker.result(),
"g_loss": self.g_loss_tracker.result(),
}
讓我們試用一下
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 500ms/step - d_loss: 0.5645 - g_loss: 0.7434
<keras.src.callbacks.history.History at 0x14a4f1b10>
深度學習背後的概念很簡單,那麼為什麼它們的實作應該很痛苦呢?