Keras 3 API 文件 / 回呼 API / 學習率排程器

學習率排程器

[原始碼]

LearningRateScheduler 類別

keras.callbacks.LearningRateScheduler(schedule, verbose=0)

學習率排程器。

在每個 epoch 開始時,此回呼從 __init__ 中提供的 schedule 函數取得更新後的學習率值,包含當前 epoch 和當前學習率,並將更新後的學習率應用於優化器。

參數

  • schedule: 一個函數,接受 epoch 索引(整數,從 0 開始索引)和當前學習率(浮點數)作為輸入,並返回新的學習率作為輸出(浮點數)。
  • verbose: 整數。0:靜默,1:記錄更新訊息。

範例

>>> # This function keeps the initial learning rate for the first ten epochs
>>> # and decreases it exponentially after that.
>>> def scheduler(epoch, lr):
...     if epoch < 10:
...         return lr
...     else:
...         return lr * ops.exp(-0.1)
>>>
>>> model = keras.models.Sequential([keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> round(model.optimizer.learning_rate, 5)
0.01
>>> callback = keras.callbacks.LearningRateScheduler(scheduler)
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
...                     epochs=15, callbacks=[callback], verbose=0)
>>> round(model.optimizer.learning_rate, 5)
0.00607