Customizing what happens in fit() with JAX
Author: fchollet
Date created: 2023/06/27
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with JAX.
When you're doing supervised learning, you can use fit() and everything works smoothly.
When you need to take control of every little detail, you can write your own training loop entirely from scratch.
But what if you need a custom training algorithm, but you still want to benefit from the convenient features of fit(), such as callbacks, built-in distribution support, or step fusing?
A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
When you need to customize what fit() does, you should override the training step function of the Model class. This is the function that is called by fit() for every batch of data. You will then be able to call fit() as usual, and it will be running your own learning algorithm.
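Schematically, on the JAX backend the training step is fully stateless: fit() passes the current model state in explicitly and expects the updated state back. The hypothetical SkeletonModel below is only a minimal sketch of that expected shape; the working version is built up step by step in the rest of this guide:

import os

os.environ["KERAS_BACKEND"] = "jax"  # the signature below is JAX-specific
import keras


class SkeletonModel(keras.Model):
    def train_step(self, state, data):
        # `state` bundles all model state explicitly:
        # (trainable_variables, non_trainable_variables,
        #  optimizer_variables, metrics_variables)
        # `data` is one batch of training data.
        logs = {}  # maps metric names to values shown in the progress bar
        # ...compute the loss, gradients, and updated variables here...
        return logs, state  # no-op step: the state is returned unchanged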
Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building Sequential models, Functional API models, or subclassed models.
Let's see how that works.
import os
# This guide can only be run with the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras
import numpy as np
Let's start from a simple example:

- We create a new class that subclasses keras.Model.
- We implement a fully-stateless compute_loss_and_updates() method to compute the loss as well as the updated values for the non-trainable variables of the model. Internally, it calls stateless_call() and the built-in compute_loss().
- We implement a fully-stateless train_step() method to compute the current metric values (including the loss) as well as the updated values for the trainable variables, the optimizer variables, and the metric variables.

Note that you can also take into account the sample_weight argument by:

- Unpacking the data as x, y, sample_weight = data
- Passing sample_weight to compute_loss()
- Passing sample_weight alongside y and y_pred to the metrics in stateless_update_state() (see the sketch right after this list)
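For illustration, here is a sketch of what that sample_weight plumbing could look like. This is a hypothetical variant (the example below does not use sample weights, and WeightedCustomModel is not used again), assuming the same stateless APIs as the rest of this guide:

class WeightedCustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=training
        )
        # Forward the per-sample weights to the built-in loss computation.
        loss = self.compute_loss(x, y, y_pred, sample_weight)
        return loss, (y_pred, non_trainable_variables)

    # In train_step, you would then unpack the batch as
    # `x, y, sample_weight = data`, pass sample_weight through to
    # compute_loss_and_updates via grad_fn, and update each metric with
    # metric.stateless_update_state(vars, y, y_pred, sample_weight).

With that aside, here is the plain (unweighted) example: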
class CustomModel(keras.Model):
def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
x,
y,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=training,
)
loss = self.compute_loss(x, y, y_pred)
return loss, (y_pred, non_trainable_variables)
def train_step(self, state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
) = state
x, y = data
# Get the gradient function.
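        # `has_aux=True` because compute_loss_and_updates returns
        # (loss, aux): only the scalar loss is differentiated.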
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
# Compute the gradients.
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
x,
y,
training=True,
)
# Update trainable variables and optimizer variables.
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# Update metrics.
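        # `metrics_variables` is one flat list holding the variables of all
        # metrics in order, so slice out each metric's own variables.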
new_metrics_vars = []
logs = {}
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars
# Return metric logs and updated state variables.
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
new_metrics_vars,
)
return logs, state
Let's try this out:
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - mae: 1.0022 - loss: 1.2464
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 198us/step - mae: 0.5811 - loss: 0.4912
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 231us/step - mae: 0.4386 - loss: 0.2905
<keras.src.callbacks.history.History at 0x14da599c0>
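Since training still goes through fit(), everything mentioned earlier (callbacks, built-in distribution support, step fusing) keeps working. As a quick illustration (not part of the original run), you could attach a built-in callback:

model.fit(
    x,
    y,
    epochs=3,
    callbacks=[keras.callbacks.EarlyStopping(monitor="loss", patience=2)],
)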
Of course, you could just skip passing a loss function in compile(), and instead do everything manually in train_step. Likewise for metrics.
Here's a lower-level example that only uses compile() to configure the optimizer:
class CustomModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_tracker = keras.metrics.Mean(name="loss")
self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
self.loss_fn = keras.losses.MeanSquaredError()
def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
x,
y,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=training,
)
loss = self.loss_fn(y, y_pred)
return loss, (y_pred, non_trainable_variables)
def train_step(self, state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
) = state
x, y = data
# Get the gradient function.
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
# Compute the gradients.
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
x,
y,
training=True,
)
# Update trainable variables and optimizer variables.
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# Update metrics.
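        # `metrics_variables` concatenates the variables of `self.metrics`
        # in order: the loss tracker's first, then the MAE metric's.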
loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]
loss_tracker_vars = self.loss_tracker.stateless_update_state(
loss_tracker_vars, loss
)
mae_metric_vars = self.mae_metric.stateless_update_state(
mae_metric_vars, y, y_pred
)
logs = {}
logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
loss_tracker_vars
)
logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)
new_metrics_vars = loss_tracker_vars + mae_metric_vars
# Return metric logs and updated state variables.
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
new_metrics_vars,
)
return logs, state
@property
def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
return [self.loss_tracker, self.mae_metric]
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.6085 - mae: 0.6580
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 215us/step - loss: 0.2630 - mae: 0.4141
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 202us/step - loss: 0.2271 - mae: 0.3835
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 192us/step - loss: 0.2093 - mae: 0.3714
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 194us/step - loss: 0.2188 - mae: 0.3818
<keras.src.callbacks.history.History at 0x14de01420>
What if you want to do the same for calls to model.evaluate()? Then you would override test_step in exactly the same way. Here's what it looks like:
class CustomModel(keras.Model):
def test_step(self, state, data):
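        # Note: at evaluation time the state tuple contains no
        # optimizer variables.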
# Unpack the data.
x, y = data
(
trainable_variables,
non_trainable_variables,
metrics_variables,
) = state
# Compute predictions and loss.
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=False,
)
loss = self.compute_loss(x, y, y_pred)
# Update metrics.
        logs = {}
        new_metrics_vars = []
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
            logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars
# Return metric logs and updated state variables.
state = (
trainable_variables,
non_trainable_variables,
new_metrics_vars,
)
return logs, state
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 973us/step - mae: 0.7887 - loss: 0.8385
[0.8385222554206848, 0.7956181168556213]
That's it!