Keras 3 API 文件 / 最佳化器

最佳化器

可用的最佳化器


compile() & fit() 的使用

最佳化器是編譯 Keras 模型所需的兩個參數之一

import keras
from keras import layers

model = keras.Sequential()
model.add(layers.Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(layers.Activation('softmax'))

opt = keras.optimizers.Adam(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=opt)

您可以先實例化一個最佳化器,再將其傳遞給 model.compile(),如以上範例所示,或者您可以透過其字串識別碼傳遞它。在後一種情況下,將會使用最佳化器的預設參數。

# pass optimizer by name: default parameters will be used
model.compile(loss='categorical_crossentropy', optimizer='adam')

學習率衰減/排程

您可以使用學習率排程來調整最佳化器的學習率隨時間的變化

lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    decay_rate=0.9)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)

請查看學習率排程 API 文件,以取得可用排程的列表。


基本最佳化器 API

這些方法和屬性是所有 Keras 最佳化器通用的。

[原始碼]

Optimizer 類別

keras.optimizers.Optimizer()

抽象最佳化器基礎類別。

如果您打算建立自己的最佳化演算法,請繼承此類別並覆寫以下方法

  • build:建立您的最佳化器相關變數,例如 SGD 最佳化器中的動量變數。
  • update_step:實作您的最佳化器的變數更新邏輯。
  • get_config:最佳化器的序列化。

範例

class SGD(Optimizer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.momentum = 0.9

    def build(self, variables):
        super().build(variables)
        self.momentums = []
        for variable in variables:
            self.momentums.append(
                self.add_variable_from_reference(
                    reference_variable=variable, name="momentum"
                )
            )

    def update_step(self, gradient, variable, learning_rate):
        learning_rate = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
        m = self.momentums[self._get_variable_index(variable)]
        self.assign(
            m,
            ops.subtract(
                ops.multiply(m, ops.cast(self.momentum, variable.dtype)),
                ops.multiply(gradient, learning_rate),
            ),
        )
        self.assign_add(variable, m)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "momentum": self.momentum,
                "nesterov": self.nesterov,
            }
        )
        return config

[原始碼]

apply_gradients 方法

Optimizer.apply_gradients(grads_and_vars)

variables 屬性

keras.optimizers.Optimizer.variables