ModelCheckpoint
類別keras.callbacks.ModelCheckpoint(
filepath,
monitor="val_loss",
verbose=0,
save_best_only=False,
save_weights_only=False,
mode="auto",
save_freq="epoch",
initial_value_threshold=None,
)
用於在特定頻率儲存 Keras 模型或模型權重的回呼。
ModelCheckpoint
回呼與使用 model.fit()
進行訓練結合使用,以在特定間隔儲存模型或權重(在檢查點檔案中),以便稍後載入模型或權重,從儲存的狀態繼續訓練。
此回呼提供的一些選項包括
範例
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model (that are considered the best) can be loaded as -
keras.models.load_model(checkpoint_filepath)
# Alternatively, one could checkpoint just the model weights as -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) can be loaded as -
model.load_weights(checkpoint_filepath)
引數
PathLike
,儲存模型檔案的路徑。filepath
可以包含具名格式化選項,這些選項將會填入 epoch
的值和 logs
中的鍵(在 on_epoch_end
中傳遞)。當 save_weights_only=True
時,filepath
名稱需要以 ".weights.h5"
結尾;當檢查點儲存整個模型時(預設),則應以 ".keras"
或 ".h5"
結尾。例如:如果 filepath
是 "{epoch:02d}-{val_loss:.2f}.keras"
或 "{epoch:02d}-{val_loss:.2f}.weights.h5"`,則模型檢查點將會以 epoch 號碼和驗證損失儲存在檔名中。filepath 的目錄不應被任何其他回呼重複使用,以避免衝突。Model.compile
方法設定。注意"val_"
以監控驗證指標。"loss"
或 "val_loss"
監控模型的總損失。"accuracy"
,請傳遞相同的字串(帶或不帶 "val_"
前綴)。metrics.Metric
物件,monitor
應設定為 metric.name
history = model.fit()
傳回的 history.history
字典的內容save_best_only=True
,則僅在模型被視為「最佳」時儲存,並且根據監控量最新的最佳模型將不會被覆寫。如果 filepath
不包含格式化選項(如 {epoch}
),則 filepath
將會被每個新的更佳模型覆寫。"auto"
, "min"
, "max"
} 之一。如果 save_best_only=True
,則是否覆寫目前儲存檔案的決定是根據監控量的最大化或最小化來決定的。對於 val_acc
,這應該是 "max"
;對於 val_loss
,這應該是 "min"
,依此類推。在 "auto"
模式中,如果監控量是 "acc"
或以 "fmeasure"
開頭,則模式設定為 "max"
;對於其餘量,則設定為 "min"
。True
,則僅儲存模型的權重 (model.save_weights(filepath)
),否則儲存完整模型 (model.save(filepath)
)。"epoch"
或整數。當使用 "epoch"
時,回呼會在每個 epoch 後儲存模型。當使用整數時,回呼會在這麼多批次結束時儲存模型。如果 Model
是使用 steps_per_execution=N
編譯的,則將每 N 個批次檢查儲存條件。請注意,如果儲存未與 epoch 對齊,則監控的指標可能較不可靠(它可能僅反映 1 個批次,因為指標會在每個 epoch 重設)。預設為 "epoch"
。save_best_value=True
時適用。僅當目前模型的效能優於此值時,才會覆寫已儲存的模型權重。