Keras 3 API 文件 / 回呼 API / ModelCheckpoint

模型檢查點

[原始碼]

ModelCheckpoint 類別

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    initial_value_threshold=None,
)

用於在特定頻率儲存 Keras 模型或模型權重的回呼。

ModelCheckpoint 回呼與使用 model.fit() 進行訓練結合使用,以在特定間隔儲存模型或權重(在檢查點檔案中),以便稍後載入模型或權重,從儲存的狀態繼續訓練。

此回呼提供的一些選項包括

  • 是否僅保留到目前為止達到「最佳效能」的模型,或者是否無論效能如何,在每個 epoch 結束時都儲存模型。
  • 「最佳」的定義;要監控的量,以及應最大化還是最小化。
  • 應該儲存的頻率。目前,此回呼支援在每個 epoch 結束時或在固定數量的訓練批次後儲存。
  • 是僅儲存權重,還是儲存整個模型。

範例

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model (that are considered the best) can be loaded as -
keras.models.load_model(checkpoint_filepath)

# Alternatively, one could checkpoint just the model weights as -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) can be loaded as -
model.load_weights(checkpoint_filepath)

引數

  • filepath:字串或 PathLike,儲存模型檔案的路徑。filepath 可以包含具名格式化選項,這些選項將會填入 epoch 的值和 logs 中的鍵(在 on_epoch_end 中傳遞)。當 save_weights_only=True 時,filepath 名稱需要以 ".weights.h5" 結尾;當檢查點儲存整個模型時(預設),則應以 ".keras"".h5" 結尾。例如:如果 filepath"{epoch:02d}-{val_loss:.2f}.keras" 或 "{epoch:02d}-{val_loss:.2f}.weights.h5"`,則模型檢查點將會以 epoch 號碼和驗證損失儲存在檔名中。filepath 的目錄不應被任何其他回呼重複使用,以避免衝突。
  • monitor:要監控的指標名稱。通常指標由 Model.compile 方法設定。注意
    • 在名稱前加上 "val_" 以監控驗證指標。
    • 使用 "loss""val_loss" 監控模型的總損失。
    • 如果您將指標指定為字串,例如 "accuracy",請傳遞相同的字串(帶或不帶 "val_" 前綴)。
    • 如果您傳遞 metrics.Metric 物件,monitor 應設定為 metric.name
    • 如果您不確定指標名稱,可以查看 history = model.fit() 傳回的 history.history 字典的內容
    • 多輸出模型在指標名稱上設定額外的前綴。
  • verbose:Verbosity 模式,0 或 1。模式 0 為靜音,模式 1 在回呼採取動作時顯示訊息。
  • save_best_only:如果 save_best_only=True,則僅在模型被視為「最佳」時儲存,並且根據監控量最新的最佳模型將不會被覆寫。如果 filepath 不包含格式化選項(如 {epoch}),則 filepath 將會被每個新的更佳模型覆寫。
  • mode:{"auto", "min", "max"} 之一。如果 save_best_only=True,則是否覆寫目前儲存檔案的決定是根據監控量的最大化或最小化來決定的。對於 val_acc,這應該是 "max";對於 val_loss,這應該是 "min",依此類推。在 "auto" 模式中,如果監控量是 "acc" 或以 "fmeasure" 開頭,則模式設定為 "max";對於其餘量,則設定為 "min"
  • save_weights_only:如果為 True,則僅儲存模型的權重 (model.save_weights(filepath)),否則儲存完整模型 (model.save(filepath))。
  • save_freq"epoch" 或整數。當使用 "epoch" 時,回呼會在每個 epoch 後儲存模型。當使用整數時,回呼會在這麼多批次結束時儲存模型。如果 Model 是使用 steps_per_execution=N 編譯的,則將每 N 個批次檢查儲存條件。請注意,如果儲存未與 epoch 對齊,則監控的指標可能較不可靠(它可能僅反映 1 個批次,因為指標會在每個 epoch 重設)。預設為 "epoch"
  • initial_value_threshold:浮點數初始「最佳」值,用於監控的指標。僅在 save_best_value=True 時適用。僅當目前模型的效能優於此值時,才會覆寫已儲存的模型權重。