Keras 3 API 文件 / 回呼 API / 基礎回呼類別

基礎回呼類別

[來源]

Callback 類別

keras.callbacks.Callback()

用於建立新回呼的基礎類別。

回呼可以傳遞給諸如 fit()evaluate()predict() 等 Keras 方法,以便在模型訓練、評估和推論生命週期的各個階段掛鉤。

若要建立自訂回呼,請子類別化 keras.callbacks.Callback 並覆寫與目標階段相關聯的方法。

範例

>>> training_finished = False
>>> class MyCallback(Callback):
...   def on_train_end(self, logs=None):
...     global training_finished
...     training_finished = True
>>> model = Sequential([
...     layers.Dense(1, input_shape=(1,))])
>>> model.compile(loss='mean_squared_error')
>>> model.fit(np.array([[1.0]]), np.array([[1.0]]),
...           callbacks=[MyCallback()])
>>> assert training_finished == True

如果您想在自訂訓練迴圈中使用 Callback 物件

  1. 您應該將所有回呼打包成單一的 callbacks.CallbackList,以便它們可以一起呼叫。
  2. 您需要手動在迴圈中的適當位置呼叫所有 on_* 方法。像這樣

範例

callbacks =  keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
    callbacks.on_epoch_begin(epoch)
    for i, data in dataset.enumerate():
    callbacks.on_train_batch_begin(i)
    batch_logs = model.train_step(data)
    callbacks.on_train_batch_end(i, batch_logs)
    epoch_logs = ...
    callbacks.on_epoch_end(epoch, epoch_logs)
final_logs=...
callbacks.on_train_end(final_logs)

屬性

  • params:Dict。訓練參數(例如:詳細程度、批次大小、epoch 數量...)。
  • modelModel 的實例。正在訓練的模型參考。

回呼方法作為引數的 logs 字典將包含與目前批次或 epoch 相關的數量鍵(請參閱特定方法的說明字串)。