BatchNormalization
類別tf_keras.layers.BatchNormalization(
axis=-1,
momentum=0.99,
epsilon=0.001,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
moving_mean_initializer="zeros",
moving_variance_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
synchronized=False,
**kwargs
)
正規化其輸入的層。
批次正規化應用一種轉換,使平均輸出接近 0,且輸出標準差接近 1。
重要的是,批次正規化在訓練和推論期間的工作方式不同。
在訓練期間(即當使用 fit()
或以參數 training=True
呼叫層/模型時),該層使用目前輸入批次的平均值和標準差來正規化其輸出。也就是說,對於每個被正規化的通道,該層會返回 gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta
,其中
epsilon
是一個小常數(可作為建構子參數的一部分設定)gamma
是一個學習到的縮放因子(初始化為 1),可以通過將 scale=False
傳遞給建構子來禁用。beta
是一個學習到的偏移因子(初始化為 0),可以通過將 center=False
傳遞給建構子來禁用。在推論期間(即當使用 evaluate()
或 predict()
時,或以參數 training=False
(這是預設值)呼叫層/模型時),該層會使用在訓練期間所看到的批次的平均值和標準差的移動平均值來正規化其輸出。也就是說,它返回 gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta
。
self.moving_mean
和 self.moving_var
是不可訓練的變數,每次在訓練模式中呼叫層時都會更新,如下所示
moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
moving_var = moving_var * momentum + var(batch) * (1 - momentum)
因此,該層只會在在與推論資料具有相似統計資料的資料上訓練後,才會在推論期間正規化其輸入。
當設定 synchronized=True
,且此層在 tf.distribute
策略中使用時,將會有一個 allreduce
呼叫,以在每個訓練步驟中聚合所有副本的批次統計資料。當模型在未指定任何分佈策略的情況下進行訓練時,設定 synchronized
沒有任何影響。
使用範例
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16))
model.add(tf.keras.layers.BatchNormalization(synchronized=True))
引數
data_format="channels_first"
的 Conv2D
層之後,在 BatchNormalization
中設定 axis=1
。beta
的偏移量加到正規化的張量。如果為 False,則會忽略 beta
。gamma
。如果為 False,則不使用 gamma
。當下一層是線性的(例如 nn.relu
),則可以禁用此設定,因為縮放將由下一層完成。tf.distribute
策略中使用時才相關。呼叫引數
training=True
:該層將使用目前輸入批次的平均值和變異數來正規化其輸入。training=False
:該層將使用在訓練期間學習到的移動統計資料的平均值和變異數來正規化其輸入。輸入形狀
任意。當將此層用作模型中的第一層時,請使用關鍵字引數 input_shape
(整數元組,不包含樣本軸)。
輸出形狀
與輸入相同的形狀。
參考文獻
關於在 BatchNormalization
層上設定 layer.trainable = False
設定 layer.trainable = False
的含義是凍結該層,也就是說,其內部狀態在訓練期間不會變更:其可訓練的權重在 fit()
或 train_on_batch()
期間不會更新,並且其狀態更新不會執行。
通常,這並不一定表示該層在推論模式下執行(這通常由呼叫層時可以傳遞的 training
引數控制)。「凍結狀態」和「推論模式」是兩個獨立的概念。
但是,在 BatchNormalization
層的情況下,在該層上設定 trainable = False
表示該層隨後將在推論模式下執行(表示它將使用移動平均值和移動變異數來正規化目前的批次,而不是使用目前批次的平均值和變異數)。
此行為已在 TensorFlow 2.0 中引入,以便在卷積網路微調用例中,啟用 layer.trainable = False
以產生最常預期的行為。
請注意:- 在包含其他層的模型上設定 trainable
將會遞迴設定所有內部層的 trainable
值。- 如果在模型的 compile()
呼叫之後變更 trainable
屬性的值,則新值直到再次呼叫 compile()
後才會對此模型生效。