Keras 3 API 文件 / 層 API / 正規化層 / 批次正規化層

批次正規化層 (BatchNormalization layer)

[原始碼]

BatchNormalization 類別

keras.layers.BatchNormalization(
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=True,
    scale=True,
    beta_initializer="zeros",
    gamma_initializer="ones",
    moving_mean_initializer="zeros",
    moving_variance_initializer="ones",
    beta_regularizer=None,
    gamma_regularizer=None,
    beta_constraint=None,
    gamma_constraint=None,
    synchronized=False,
    **kwargs
)

正規化其輸入的層。

批次正規化會套用一個轉換,使輸出平均值接近 0,輸出標準差接近 1。

重要的是,批次正規化在訓練期間和推論期間的運作方式不同。

在訓練期間(即當使用 fit() 或以 training=True 參數呼叫層/模型時),該層會使用目前輸入批次的平均值和標準差來正規化其輸出。也就是說,對於每個被正規化的通道,該層會返回 gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta,其中

  • epsilon 是一個小的常數(可在建構函式參數中設定)
  • gamma 是一個已學習的縮放因子(初始化為 1),可以透過將 scale=False 傳遞給建構函式來停用。
  • beta 是一個已學習的偏移因子(初始化為 0),可以透過將 center=False 傳遞給建構函式來停用。

在推論期間(即當使用 evaluate()predict() 或以 training=False 參數呼叫層/模型時(這是預設值)),該層會使用在訓練期間看到的批次的平均值和標準差的移動平均值來正規化其輸出。也就是說,它會返回 gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta

self.moving_meanself.moving_var 是不可訓練的變數,每次在訓練模式下呼叫該層時都會更新,如下所示

  • moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
  • moving_var = moving_var * momentum + var(batch) * (1 - momentum)

因此,該層只會在經過在與推論資料具有相似統計數據的資料上訓練後,才會在推論期間正規化其輸入。

參數

  • axis:整數,應該被正規化的軸(通常是特徵軸)。例如,在具有 data_format="channels_first"Conv2D 層之後,請使用 axis=1
  • momentum:用於移動平均的動量。
  • epsilon:添加到變異數以避免除以零的小浮點數。
  • center:如果為 True,則將 beta 的偏移量添加到正規化的張量。如果為 False,則會忽略 beta
  • scale:如果為 True,則乘以 gamma。如果為 False,則不使用 gamma。當下一層是線性層時,可以停用此選項,因為縮放將由下一層完成。
  • beta_initializer:beta 權重的初始化器。
  • gamma_initializer:gamma 權重的初始化器。
  • moving_mean_initializer:移動平均的初始化器。
  • moving_variance_initializer:移動變異數的初始化器。
  • beta_regularizer:beta 權重的可選正規化器。
  • gamma_regularizer:gamma 權重的可選正規化器。
  • beta_constraint:beta 權重的可選約束。
  • gamma_constraint:gamma 權重的可選約束。
  • synchronized:僅適用於 TensorFlow 後端。如果為 True,則在分散式訓練策略中,會在每個訓練步驟同步所有裝置的層的整體批次統計數據(平均值和變異數)。如果為 False,則每個副本使用其自己的本地批次統計數據。
  • **kwargs:基礎層的關鍵字參數(例如 namedtype)。

呼叫參數

  • inputs:輸入張量(任何階數)。
  • training:Python 布林值,指示該層應在訓練模式還是推論模式下運作。
    • training=True:該層將使用目前輸入批次的平均值和變異數來正規化其輸入。
    • training=False:該層將使用其在訓練期間學習的移動統計數據的平均值和變異數來正規化其輸入。
  • mask:形狀可廣播到 inputs 張量的二元張量,其中 True 值表示應該計算平均值和變異數的位置。目前輸入的遮罩元素在訓練期間不會被納入平均值和變異數計算中。任何先前的未遮罩元素值都將被納入考量,直到它們的動量過期。

參考資料

關於在 BatchNormalization 層上設定 layer.trainable = False

設定 layer.trainable = False 的意義在於凍結該層,即其內部狀態在訓練期間不會改變:其可訓練的權重不會在 fit()train_on_batch() 期間更新,並且其狀態更新將不會執行。

通常,這不一定表示該層在推論模式下執行(這通常由呼叫層時可以傳遞的 training 參數控制)。「凍結狀態」和「推論模式」是兩個不同的概念。

但是,對於 BatchNormalization 層,在該層上設定 trainable = False 表示該層隨後將在推論模式下執行(這表示它將使用移動平均值和移動變異數來正規化目前批次,而不是使用目前批次的平均值和變異數)。

請注意

  • 在包含其他層的模型上設定 trainable 會遞迴地設定所有內部層的 trainable 值。
  • 如果在對模型呼叫 compile() 後更改 trainable 屬性的值,則直到再次呼叫 compile() 後,新值才會對此模型生效。