
Writing a training loop from scratch in JAX

Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in JAX.



Setup

import os

# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax

# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np

Introduction

Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.

If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().

Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.


A first end-to-end example

To write a custom training loop, we need the following ingredients:

  • A model to train, of course.
  • An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package.
  • A loss function.
  • A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.

Let's line them up.

First, let's get the model and the MNIST dataset:

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

Next, here's the loss function and the optimizer. We will use a Keras optimizer in this case.

# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

Getting gradients in JAX

Let's train our model using mini-batch gradients with a custom training loop.

In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for that first function.

So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:

def compute_loss(x, y):
    ...
    return loss

Once you have such a function, you can compute gradients via metaprogramming as follows:

grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)

Typically, you don't just want the gradient values; you also want the loss value. You can get both by using jax.value_and_grad instead of jax.grad:

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)

JAX computation is purely stateless

In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that has been updated during the forward pass must be returned as a function output. The function has no side effects.
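As a tiny illustration of this style (a hedged sketch, not part of the original guide), here is what "state in, new state out" looks like for a trivial counter: every stateful quantity is threaded through explicitly instead of being mutated in place.

def stateless_increment(counter_state, amount):
    # No side effects: the new state is computed and returned to the caller.
    return counter_state + amount


counter = jax.numpy.array(0.0)
counter = stateless_increment(counter, 1.0)  # the caller re-threads the state
counter = stateless_increment(counter, 1.0)
print(counter)  # 2.0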

During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables

Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)

Now that we have established the basics, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.
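As a quick standalone sanity check (a sketch, not part of the original guide), you can call stateless_call directly on one batch and inspect the output shape, passing the model's current variables explicitly just as the training step will do later:

x_batch = x_train[:batch_size]
y_pred, _ = model.stateless_call(
    model.trainable_variables, model.non_trainable_variables, x_batch
)
print(y_pred.shape)  # (32, 10): one row of logits per sample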

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

Let's get the gradient function:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

The training step function

Next, let's implement the end-to-end training step: the function that runs the forward pass, computes the loss, computes the gradients, and also uses the optimizer to update the trainable variables. This function also needs to be stateless, so it will take as input a state tuple that includes every state element we're going to use:

  • trainable_variables and non_trainable_variables: the model's variables.
  • optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it requires always passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer_variables.

def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

Make it fast with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow: eager mode is better used as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.
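Before jitting the actual training step, here is a tiny standalone illustration (a sketch, assuming nothing beyond the imports above) of what compilation does:

@jax.jit
def scaled_sum(a, b):
    # Traced and compiled to XLA on the first call; later calls with the same
    # input shapes and dtypes reuse the compiled program.
    return jax.numpy.sum(a * 2.0 + b)


a = jax.numpy.ones((3,))
b = jax.numpy.zeros((3,))
print(scaled_sum(a, b))  # first call: traces and compiles
print(scaled_sum(a, b))  # second call: runs the cached compiled program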

When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:

@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).

Note that:

  • We convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
  • All variables must be built beforehand: the model must be built and the optimizer must be built. Since we're using a Functional API model, it's already built, but if it were a subclassed model you'd need to call it on a batch of data to build it (see the short sketch below).
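For reference, here is a short hedged sketch (not part of the original guide) of what building a hypothetical subclassed model on a batch of data would look like:

class MySubclassedModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense_1 = keras.layers.Dense(64, activation="relu")
        self.dense_2 = keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


subclassed_model = MySubclassedModel()
subclassed_model(x_train[:batch_size])  # calling it once builds all the weights
# subclassed_model.trainable_variables is now populated and usable as below.
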
# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples

A key thing to notice here is that the loop is entirely stateless: the variables attached to the model (model.weights) are never updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving the model, you should reattach the new variable values to the model.

Just call variable.assign(new_value) on each model variable you want to update:

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
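With the updated values reattached, the model can then be saved or used for inference as usual, for example (the filename here is purely illustrative):

model.save("final_model.keras")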

Low-level handling of metrics

Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

  • Instantiate the metric at the start of the loop.
  • Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
  • Call metric.stateless_update_state() in the compute_loss_and_updates function. It's the stateless equivalent of update_state().
  • When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
  • Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch).

Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )

We'll also prepare an evaluation step function:

@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )

Here are our loops:

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
# Start the eval loop from the validation metric's own (fresh) variables.
metric_variables = val_acc_metric.variables
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples

Low-level handling of losses tracked by the model

Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to use these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs

Let's build a really simple model that uses it:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

Now, our compute_loss_and_updates function should look like this:

  • Pass return_losses=True to model.stateless_call().
  • Sum the resulting losses and add them to the main loss.
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
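As before (a brief follow-up sketch, not shown in the original), the gradient function is re-derived from this updated compute_loss_and_updates; the jitted train_step from the previous section can then be reused unchanged:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)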

That's it!