Callback
類別keras.callbacks.Callback()
用於建立新回呼的基礎類別。
回呼可以傳遞給諸如 fit()
、evaluate()
和 predict()
等 Keras 方法,以便在模型訓練、評估和推論生命週期的各個階段掛鉤。
若要建立自訂回呼,請子類別化 keras.callbacks.Callback
並覆寫與目標階段相關聯的方法。
範例
>>> training_finished = False
>>> class MyCallback(Callback):
... def on_train_end(self, logs=None):
... global training_finished
... training_finished = True
>>> model = Sequential([
... layers.Dense(1, input_shape=(1,))])
>>> model.compile(loss='mean_squared_error')
>>> model.fit(np.array([[1.0]]), np.array([[1.0]]),
... callbacks=[MyCallback()])
>>> assert training_finished == True
如果您想在自訂訓練迴圈中使用 Callback
物件
callbacks.CallbackList
,以便它們可以一起呼叫。on_*
方法。像這樣範例
callbacks = keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
callbacks.on_epoch_begin(epoch)
for i, data in dataset.enumerate():
callbacks.on_train_batch_begin(i)
batch_logs = model.train_step(data)
callbacks.on_train_batch_end(i, batch_logs)
epoch_logs = ...
callbacks.on_epoch_end(epoch, epoch_logs)
final_logs=...
callbacks.on_train_end(final_logs)
屬性
Model
的實例。正在訓練的模型參考。回呼方法作為引數的 logs
字典將包含與目前批次或 epoch 相關的數量鍵(請參閱特定方法的說明字串)。