Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in JAX.
import os
# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.
If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().
Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. That is what this guide is about.
To write a custom training loop, we need the following ingredients:

- A model to train.
- An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package.
- A loss function.
- A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.

Let's line them all up.
First, let's get the model and the MNIST dataset:
def get_model():
inputs = keras.Input(shape=(784,), name="digits")
x1 = keras.layers.Dense(64, activation="relu")(inputs)
x2 = keras.layers.Dense(64, activation="relu")(x1)
outputs = keras.layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = get_model()
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case (a sketch of the optax alternative follows the code below).
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
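If you preferred the optax route mentioned earlier, the setup and a single update step would look roughly like the following. This is a minimal standalone sketch, assuming optax is installed; the names params, optax_optimizer and opt_state are placeholders, and the rest of this guide sticks with the Keras optimizer.
import optax
import jax.numpy as jnp

# Hypothetical toy parameters (not the Keras model above). optax optimizers
# are pure functions over (params, optimizer state), which fits the
# stateless style used throughout this guide.
params = {"w": jnp.zeros((784, 10)), "b": jnp.zeros((10,))}
optax_optimizer = optax.adam(learning_rate=1e-3)
opt_state = optax_optimizer.init(params)

# Given gradients with the same structure as `params` (all ones here,
# purely as a placeholder), one update step looks like this:
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = optax_optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)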
Let's train our model using mini-batch gradients with a custom training loop.
In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for that first function.
So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:
def compute_loss(x, y):
...
return loss
Once you have such a function, you can compute gradients via metaprogramming like this:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
Usually, you don't just want to get the gradient values, you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
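As a quick standalone illustration of this pattern (a hypothetical toy function, unrelated to the model in this guide):
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Toy loss: mean squared error of the linear prediction x * w.
    return jnp.mean((x * w - y) ** 2)

# jax.value_and_grad differentiates with respect to the first argument, w.
loss_and_grad_fn = jax.value_and_grad(squared_error)
loss, grad = loss_and_grad_fn(2.0, jnp.array([1.0, 2.0]), jnp.array([2.0, 5.0]))
# loss == 0.5, grad == -2.0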
In JAX, everything must be a stateless function -- so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that gets updated during the forward pass must be returned as a function output. The function has no side effects.
During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
...
return loss, non_trainable_variables
Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
Now that we have the basics in place, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x, training=True
)
loss = loss_fn(y, y_pred)
return loss, non_trainable_variables
Let's get the gradient function:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
Next, let's implement the end-to-end training step: the function that runs the forward pass, computes the loss, computes the gradients, and also uses the optimizer to update the trainable variables. This function also needs to be stateless, so it will get as input a state tuple that includes every state element we're going to use:

- trainable_variables and non_trainable_variables: the model's variables.
- optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it requires always passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer_variables.
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# Return updated state
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
Make it fast with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow -- eager mode is better suited as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.
When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:
@jax.jit
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# Return updated state
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).
Note: we convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.

# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for step, data in enumerate(train_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = train_step(state, data)
# Log every 100 batches.
if step % 100 == 0:
print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples
A key thing to notice here is that the loop is entirely stateless -- the variables attached to the model (model.weights) are never getting updated during the loop. Their new values are only stored in the state tuple. That means that, at some point before saving your model, you should re-attach the new variable values to the model.

Just call variable.assign(new_value) on each model variable you want to update:
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
variable.assign(value)
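With the updated values re-attached, the model can be used as usual in eager code -- for instance, you could now save it to disk (the filename here is just an example):
# Save the model with the freshly re-attached weights.
model.save("mnist_custom_loop.keras")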
Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

- Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
- Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state() -- only stateless.
- When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
- Call metric.reset_state() whenever you need to clear the state of the metric, typically at the end of an epoch (a sketch of this appears after the evaluation output below).

Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, metric_variables, x, y
):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
metric_variables = train_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
) = state
x, y = data
(loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
trainable_variables, non_trainable_variables, metric_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# Return updated state
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
)
We'll also prepare an evaluation step function:
@jax.jit
def eval_step(state, data):
trainable_variables, non_trainable_variables, metric_variables = state
x, y = data
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
metric_variables = val_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, (
trainable_variables,
non_trainable_variables,
metric_variables,
)
Here are our loops:
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
)
# Training loop
for step, data in enumerate(train_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = train_step(state, data)
# Log every 100 batches.
if step % 100 == 0:
print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
_, _, _, metric_variables = state
for variable, value in zip(train_acc_metric.variables, metric_variables):
variable.assign(value)
print(f"Training accuracy: {train_acc_metric.result()}")
print(f"Seen so far: {(step + 1) * batch_size} samples")
metric_variables = val_acc_metric.variables
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# Eval loop
for step, data in enumerate(val_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = eval_step(state, data)
# Log every 100 batches.
if step % 100 == 0:
print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
_, _, metric_variables = state
for variable, value in zip(val_acc_metric.variables, metric_variables):
variable.assign(value)
print(f"Validation accuracy: {val_acc_metric.result()}")
print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples
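This run covers a single epoch, so the metric state is never cleared. If you were training for several epochs, you would typically call reset_state() on the metric objects between epochs and rebuild the metric part of the state tuple from the freshly reset variables. A rough sketch, assuming the state elements unpacked above are still in scope:
# Hypothetical epoch boundary: clear the accumulated metric state...
train_acc_metric.reset_state()
val_acc_metric.reset_state()
# ...and rebuild the training state tuple with the reset metric variables.
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)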
Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
def call(self, inputs):
self.add_loss(1e-2 * jax.numpy.sum(inputs))
return inputs
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Now, here's what our compute_loss_and_updates function should look like:

- Pass return_losses=True to model.stateless_call().
- Sum the resulting losses and add them to the main loss.

def compute_loss_and_updates(
trainable_variables, non_trainable_variables, metric_variables, x, y
):
y_pred, non_trainable_variables, losses = model.stateless_call(
trainable_variables, non_trainable_variables, x, return_losses=True
)
loss = loss_fn(y, y_pred)
if losses:
loss += jax.numpy.sum(losses)
metric_variables = train_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
    # Keep the auxiliary outputs in a single tuple so the function still
    # returns a (loss, aux) pair, as required by has_aux=True.
    return loss, (non_trainable_variables, metric_variables)
That's it!