Trainer pattern

Author: nkovela1
Date created: 2022/09/19
Last modified: 2022/09/26
Description: Guide on how to share a custom training step across multiple Keras models.

ⓘ This example uses Keras 3


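Introduction

This example shows how to use the "Trainer pattern" to create a custom training step that can then be shared across multiple Keras models. The pattern overrides the train_step() method of the keras.Model class, which allows training loops that go beyond plain supervised learning.

Because the custom training step lives in the Trainer class definition, the Trainer pattern also adapts easily to more complex models with larger custom training steps, such as an end-to-end GAN model.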


Setup

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

# Load the MNIST dataset and scale pixel values to the [0, 1] range
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
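
To make the effect of the division by 255.0 concrete, here is a minimal sketch on synthetic data. The array `fake_images` is a hypothetical stand-in for a batch of MNIST images (uint8 pixels), so nothing is downloaded. It illustrates that the result lies in [0, 1] and that NumPy promotes it to float64; the optional cast to float32 is an assumption about what you may want, not something the original example does.

```python
import numpy as np

# Hypothetical stand-in for a batch of MNIST images: uint8 pixels in [0, 255].
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same operation as in the setup code above.
scaled = fake_images / 255.0

print(scaled.dtype)               # dtype after dividing by a Python float
print(scaled.min(), scaled.max())  # both values fall inside [0, 1]

# Optional: cast to float32 to halve the memory footprint.
scaled32 = scaled.astype("float32")
```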

Define the Trainer class

A custom training and evaluation step can be created by overriding the train_step() and test_step() methods of a Model subclass:

class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Create loss and metrics here.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        # List metrics here.
        return [self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data

        # Inference step
        y_pred = self.model(x, training=False)

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x
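
The trainer above reports only accuracy. As a hedged sketch of how the same pattern could be extended, the variant below also tracks the mean loss with a `keras.metrics.Mean` object. The class name `LossTrackingTrainer`, the `demo_model`, and the random `x_demo`/`y_demo` arrays are all hypothetical and exist only so the block runs on its own without downloading MNIST.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import numpy as np
import tensorflow as tf
import keras


class LossTrackingTrainer(keras.Model):
    """Hypothetical variant of MyTrainer that also reports the mean loss."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        # Running average of the per-batch loss. Listing it in `metrics`
        # lets Keras reset it at the start of every epoch.
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        return [self.loss_tracker, self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)
            loss = self.loss_fn(y, y_pred)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.loss_tracker.update_state(loss)
        self.accuracy_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self.model(x, training=False)
        self.loss_tracker.update_state(self.loss_fn(y, y_pred))
        self.accuracy_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        return self.model(x)


# Tiny random dataset (an assumption for illustration, not real MNIST).
rng = np.random.default_rng(0)
x_demo = rng.random((64, 28, 28)).astype("float32")
y_demo = rng.integers(0, 10, size=(64,))

demo_model = keras.Sequential(
    [
        keras.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

trainer = LossTrackingTrainer(demo_model)
trainer.compile(optimizer=keras.optimizers.SGD())
history = trainer.fit(x_demo, y_demo, epochs=1, batch_size=16, verbose=0)
results = trainer.evaluate(x_demo, y_demo, verbose=0, return_dict=True)
print(sorted(history.history))
print(sorted(results))
```

Since test_step() returns the same keys, passing validation_data to fit() with this variant would also report the validation loss under a val_ prefix.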

Define multiple models to share the custom training step

Let's define two different models that can share our Trainer class and its custom train_step():

# A model defined using Sequential API
model_a = keras.models.Sequential(
    [
        keras.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

# A model defined using Functional API
func_input = keras.Input(shape=(28, 28, 1))
x = keras.layers.Flatten()(func_input)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.Dropout(0.4)(x)
func_output = keras.layers.Dense(10, activation="softmax")(x)

model_b = keras.Model(func_input, func_output)
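
Two architecturally different models can share one trainer because the training step only relies on a common contract: a batch of 28x28 images goes in, and a batch of 10 class probabilities comes out. The sketch below checks that contract on a dummy batch. The helper `check_classifier_contract` and the model `tiny_model` are hypothetical names introduced here for illustration.

```python
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import numpy as np
import keras


def check_classifier_contract(model, num_classes=10):
    """Hypothetical helper: verify the output a shared trainer relies on."""
    dummy_batch = np.zeros((2, 28, 28), dtype="float32")
    probs = model.predict(dummy_batch, verbose=0)
    # One row of class scores per input image.
    assert probs.shape == (2, num_classes), probs.shape
    # Softmax rows should sum to 1, up to floating-point error.
    assert np.allclose(probs.sum(axis=1), 1.0, atol=1e-5)
    return probs.shape


# Hypothetical small model standing in for model_a or model_b.
tiny_model = keras.Sequential(
    [
        keras.Input(shape=(28, 28)),
        keras.layers.Flatten(),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

output_shape = check_classifier_contract(tiny_model)
print(output_shape)
```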

Create Trainer class objects from the models

trainer_1 = MyTrainer(model_a)
trainer_2 = MyTrainer(model_b)

Compile and fit the models to the MNIST dataset

trainer_1.compile(optimizer=keras.optimizers.SGD())
trainer_1.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)

trainer_2.compile(optimizer=keras.optimizers.Adam())
trainer_2.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)

Epoch 1/5
...
Epoch 4/5
 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - sparse_categorical_accuracy: 0.9770 - val_sparse_categorical_accuracy: 0.9770
Epoch 5/5
 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - sparse_categorical_accuracy: 0.9805 - val_sparse_categorical_accuracy: 0.9789

<keras.src.callbacks.history.History at 0x7efe405fe560>
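
The step count of 938 in the log follows from the dataset size and the batch size. A quick arithmetic sketch, assuming the standard 60,000-image MNIST training split and the batch_size=64 passed to fit() above:

```python
import math

num_train_samples = 60000  # size of the MNIST training split
batch_size = 64            # value passed to fit() above

# Keras runs one step per batch; the last batch may be smaller.
steps_per_epoch = math.ceil(num_train_samples / batch_size)
last_batch_size = num_train_samples - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)  # 938
print(last_batch_size)  # 32
```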