Keras 3 API 文件 / 層 API / 正規化層 / 批次正規化層

批次正規化層

[原始碼]

BatchNormalization 類別

keras.layers.BatchNormalization(
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=True,
    scale=True,
    beta_initializer="zeros",
    gamma_initializer="ones",
    moving_mean_initializer="zeros",
    moving_variance_initializer="ones",
    beta_regularizer=None,
    gamma_regularizer=None,
    beta_constraint=None,
    gamma_constraint=None,
    synchronized=False,
    **kwargs
)

正規化其輸入的層。

批次正規化應用一種轉換,保持平均輸出接近 0,輸出標準差接近 1。

重要的是,批次正規化在訓練和推論期間的工作方式不同。

在訓練期間(即當使用 fit() 或以參數 training=True 呼叫層/模型時),該層使用當前輸入批次的平均值和標準差來正規化其輸出。也就是說,對於每個正在正規化的通道,該層返回 gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta,其中

  • epsilon 是一個小常數(可在建構子參數中配置)
  • gamma 是一個學習到的縮放因子(初始化為 1),可以透過傳遞 scale=False 給建構子來停用。
  • beta 是一個學習到的偏移因子(初始化為 0),可以透過傳遞 center=False 給建構子來停用。

在推論期間(即當使用 evaluate()predict() 或以參數 training=False 呼叫層/模型時(這是預設值)),該層使用在訓練期間看到的批次的平均值和標準差的移動平均來正規化其輸出。也就是說,它返回 gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta

self.moving_meanself.moving_var 是不可訓練的變數,每次在訓練模式下呼叫層時都會更新,如下所示

  • moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
  • moving_var = moving_var * momentum + var(batch) * (1 - momentum)

因此,該層僅在已在具有與推論資料相似統計數據的資料上訓練後,才會在推論期間正規化其輸入。

參數

  • axis:整數,應該正規化的軸(通常是特徵軸)。例如,在使用 data_format="channels_first"Conv2D 層之後,使用 axis=1
  • momentum:移動平均的動量。
  • epsilon:添加到變異數的小浮點數,以避免除以零。
  • center:如果為 True,則將 beta 偏移量添加到正規化張量。如果為 False,則忽略 beta
  • scale:如果為 True,則乘以 gamma。如果為 False,則不使用 gamma。當下一層是線性層時,可以停用此項,因為縮放將由下一層完成。
  • beta_initializer:beta 權重的初始化器。
  • gamma_initializer:gamma 權重的初始化器。
  • moving_mean_initializer:移動平均值的初始化器。
  • moving_variance_initializer:移動變異數的初始化器。
  • beta_regularizer:beta 權重的可選正規化器。
  • gamma_regularizer:gamma 權重的可選正規化器。
  • beta_constraint:beta 權重的可選約束。
  • gamma_constraint:gamma 權重的可選約束。
  • synchronized:僅適用於 TensorFlow 後端。如果為 True,則在分散式訓練策略中的每個訓練步驟,跨所有裝置同步層的全域批次統計資訊(平均值和變異數)。如果為 False,則每個副本使用其自己的本機批次統計資訊。
  • **kwargs:基礎層關鍵字參數(例如 namedtype)。

呼叫參數

  • inputs:輸入張量(任何秩)。
  • training:Python 布林值,指示層應以訓練模式還是推論模式運作。
    • training=True:該層將使用當前輸入批次的平均值和變異數來正規化其輸入。
    • training=False:該層將使用其移動統計資訊的平均值和變異數來正規化其輸入,這些統計資訊是在訓練期間學習到的。
  • mask:形狀可廣播到 inputs 張量的二元張量,其中 True 值指示應計算平均值和變異數的位置。在訓練期間,當前輸入的遮罩元素不納入平均值和變異數計算。任何先前的未遮罩元素值都將被納入考量,直到其動量過期。

參考文獻

關於在 BatchNormalization 層上設定 layer.trainable = False

設定 layer.trainable = False 的含義是凍結該層,即其內部狀態在訓練期間不會改變:其可訓練權重在 fit()train_on_batch() 期間不會更新,並且其狀態更新也不會執行。

通常,這不一定表示該層在推論模式下執行(這通常由呼叫層時可以傳遞的 training 參數控制)。「凍結狀態」和「推論模式」是兩個獨立的概念。

但是,在 BatchNormalization 層的情況下,在層上設定 trainable = False 表示該層後續將在推論模式下執行(表示它將使用移動平均值和移動變異數來正規化當前批次,而不是使用當前批次的平均值和變異數)。

請注意

  • 在包含其他層的模型上設定 trainable 將遞迴地設定所有內部層的 trainable 值。
  • 如果在模型上呼叫 compile() 之後更改 trainable 屬性的值,則新值在此模型上不會生效,直到再次呼叫 compile()