Keras 3 API 文件 / 回呼 API / SwapEMAWeights

SwapEMAWeights

[原始碼]

SwapEMAWeights 類別

keras.callbacks.SwapEMAWeights(swap_on_epoch=False)

在評估之前和之後交換模型權重和 EMA 權重。

此回呼在模型評估之前,將模型的權重值替換為最佳化器的 EMA 權重值(過去模型權重值的指數移動平均,實作「Polyak 平均」),並在評估後還原先前的權重。

SwapEMAWeights 回呼應與設定 use_ema=True 的最佳化器搭配使用。

請注意,權重是以原地交換的方式進行,以節省記憶體。如果您在其他回呼中修改 EMA 權重或模型權重,則行為未定義。

範例

# Remember to set `use_ema=True` in the optimizer
optimizer = SGD(use_ema=True)
model.compile(optimizer=optimizer, loss=..., metrics=...)

# Metrics will be computed with EMA weights
model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])

# If you want to save model checkpoint with EMA weights, you can set
# `swap_on_epoch=True` and place ModelCheckpoint after SwapEMAWeights.
model.fit(
    X_train,
    Y_train,
    callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)]
)

引數

  • swap_on_epoch:是否在 on_epoch_begin()on_epoch_end() 執行交換。如果您想將 EMA 權重用於其他回呼(例如 ModelCheckpoint),這會很有用。預設值為 False