開發者指南 / 使用 PyTorch 自訂 `fit()` 中的行為

使用 PyTorch 自訂 fit() 中的行為

作者: fchollet
建立日期 2023/06/27
上次修改日期 2024/08/01
描述: 使用 PyTorch 覆寫 Model 類別的訓練步驟。

在 Colab 中檢視 GitHub 原始碼


簡介

當您進行監督式學習時,可以使用 fit(),一切都會順利運作。

當您需要控制每個小細節時,您可以從頭開始編寫自己的訓練迴圈。

但是,如果您需要自訂的訓練演算法,但仍然想要受益於 fit() 的便捷功能(例如回呼、內建分散式支援或步驟融合),該怎麼辦?

Keras 的核心原則是複雜性的漸進揭露。您應該始終能夠以漸進的方式進入較低層級的工作流程。如果高階功能與您的使用案例不完全匹配,您不應該從懸崖上掉下來。您應該能夠在保留相當程度的高階便利性的同時,獲得對小細節的更多控制。

當您需要自訂 fit() 的行為時,您應該覆寫 Model 類別的訓練步驟函數。這是 fit() 針對每個資料批次呼叫的函數。然後,您將能夠像平常一樣呼叫 fit() – 它將執行您自己的學習演算法。

請注意,此模式不會阻止您使用函數式 API 建立模型。無論您是建立 Sequential 模型、函數式 API 模型還是子類別模型,都可以執行此操作。

讓我們看看它是如何運作的。


設定

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np

第一個簡單範例

讓我們從一個簡單的範例開始

  • 我們建立一個新的類別,該類別繼承自 keras.Model
  • 我們只覆寫方法 train_step(self, data)
  • 我們傳回一個字典,將指標名稱(包括損失)對應到它們的目前值。

輸入引數 data 是作為訓練資料傳遞給 fit 的內容

  • 如果您傳遞 NumPy 陣列,透過呼叫 fit(x, y, ...),則 data 將會是元組 (x, y)
  • 如果您傳遞 torch.utils.data.DataLoadertf.data.Dataset,透過呼叫 fit(dataset, ...),則 data 將會是 dataset 在每個批次產生的內容。

train_step() 方法的主體中,我們實作一個常規的訓練更新,類似於您已經熟悉的內容。重要的是,我們透過 self.compute_loss() 計算損失,它會包裝傳遞給 compile() 的損失函數。

同樣地,我們在來自 self.metrics 的指標上呼叫 metric.update_state(y, y_pred),以更新傳遞至 compile() 的指標的狀態,並在最後從 self.metrics 查詢結果,以擷取它們的目前值。

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

讓我們試試看

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3    
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.3410 - loss: 0.1772
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.3336 - loss: 0.1695
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - mae: 0.3170 - loss: 0.1511

<keras.src.callbacks.history.History at 0x7f48a3255710>

進入更低層級

當然,您可以直接跳過在 compile() 中傳遞損失函數,而是在 train_step手動執行所有操作。指標也是如此。

以下是一個較低層級的範例,僅使用 compile() 來設定最佳化器

  • 我們首先建立 Metric 實例,以追蹤我們的損失和 MAE 分數(在 __init__() 中)。
  • 我們實作一個自訂的 train_step(),以更新這些指標的狀態(透過在它們上面呼叫 update_state()),然後查詢它們(透過 result())以傳回它們的目前平均值,以由進度列顯示並傳遞給任何回呼。
  • 請注意,我們需要在每個 epoch 之間在我們的指標上呼叫 reset_states()!否則,呼叫 result() 將傳回自訓練開始以來的平均值,而我們通常使用每個 epoch 的平均值。幸運的是,該架構可以為我們做到這一點:只需在模型的 metrics 屬性中列出任何您想要重設的指標即可。模型將會在每個 fit() epoch 開始時,或是在呼叫 evaluate() 開始時,在 此處列出的任何物件上呼叫 reset_states()
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.loss_fn(y, y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Compute our own metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.6173 - mae: 0.6607
Epoch 2/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.2340 - mae: 0.3883
Epoch 3/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.1922 - mae: 0.3517
Epoch 4/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.1802 - mae: 0.3411
Epoch 5/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.1862 - mae: 0.3505

<keras.src.callbacks.history.History at 0x7f48975ccbd0>

支援 sample_weightclass_weight

您可能已經注意到,我們第一個基本範例沒有提及任何樣本權重。如果您想要支援 fit() 引數 sample_weightclass_weight,您只需執行下列操作

  • data 引數中解包 sample_weight
  • 將其傳遞給 compute_lossupdate_state(當然,如果您不依賴 compile() 來處理損失和指標,您也可以手動套用它)
  • 就是這樣。
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
        )

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.3216 - loss: 0.0827
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.3156 - loss: 0.0803
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.3085 - loss: 0.0760

<keras.src.callbacks.history.History at 0x7f48975d7bd0>

提供您自己的評估步驟

如果您想要對 model.evaluate() 的呼叫執行相同的操作,該怎麼辦?那麼您將以完全相同的方式覆寫 test_step。以下是它的樣子

class CustomModel(keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Updates the metrics tracking the loss
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)

1/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.8706 - loss: 0.9344



32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - mae: 0.8959 - loss: 0.9952

[1.0077838897705078, 0.8984771370887756]

總結:端對端 GAN 範例

讓我們逐步瀏覽一個利用您剛學到的一切的端對端範例。

讓我們考慮

  • 一個旨在產生 28x28x1 影像的產生器網路。
  • 一個旨在將 28x28x1 影像分類為兩個類別(「假」和「真」)的判別器網路。
  • 每個網路一個最佳化器。
  • 一個用於訓練判別器的損失函數。
# Create the discriminator
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

以下是一個功能完整的 GAN 類別,它覆寫 compile() 以使用其自己的簽章,並在 train_step 中以 17 行實作整個 GAN 演算法

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(real_images, tuple) or isinstance(real_images, list):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        real_images = torch.tensor(real_images, device=device)
        combined_images = torch.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = torch.concat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            axis=0,
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample random points in the latent space
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        grads = g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

讓我們測試一下它

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
# Create a DataLoader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gan.fit(dataloader, epochs=1)
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 394s 360ms/step - d_loss: 0.2436 - g_loss: 4.7259
<keras.src.callbacks.history.History at 0x7f489760a490>

深度學習背後的概念很簡單,所以為什麼它們的實作應該很痛苦呢?