LayerNormalization
類別tf_keras.layers.LayerNormalization(
axis=-1,
epsilon=0.001,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
**kwargs
)
Layer Normalization 層 (Ba et al., 2016)。
獨立正規化批次中每個給定範例的前一層的激活,而不是像批次正規化一樣跨批次進行正規化。也就是說,它應用一種轉換,使每個範例內的平均激活接近 0,並且使激活標準差接近 1。
給定一個張量 inputs
,會計算矩,並在 axis
中指定的軸上執行正規化。
範例
>>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
>>> print(data)
tf.Tensor(
[[ 0. 10.]
[20. 30.]
[40. 50.]
[60. 70.]
[80. 90.]], shape=(5, 2), dtype=float32)
>>> layer = tf.keras.layers.LayerNormalization(axis=1)
>>> output = layer(data)
>>> print(output)
tf.Tensor(
[[-1. 1.]
[-1. 1.]
[-1. 1.]
[-1. 1.]
[-1. 1.]], shape=(5, 2), dtype=float32)
請注意,使用 Layer Normalization 時,正規化是在每個範例內的軸上進行,而不是跨批次中的不同範例進行。
如果啟用 scale
或 center
,則層會通過可訓練變數 gamma
廣播來縮放正規化輸出,並通過可訓練變數 beta
廣播來使輸出居中。gamma
的預設值為全一張量,beta
的預設值為全零張量,因此在訓練開始之前,居中和縮放是不做任何操作的。
因此,啟用縮放和居中後,正規化方程式如下:
令小批次的內部激活為 inputs
。
對於 inputs
中具有 k
個特徵的每個樣本 x_i
,我們會計算樣本的平均值和變異數
mean_i = sum(x_i[j] for j in range(k)) / k
var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
然後計算正規化的 x_i_normalized
,包括一個小的因子 epsilon
以實現數值穩定性。
x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
最後,x_i_normalized
會通過 gamma
和 beta
進行線性轉換,這些都是學習的參數
output_i = x_i_normalized * gamma + beta
gamma
和 beta
將跨越 axis
中指定的 inputs
軸,且輸入的形狀的這部分必須完全定義。
例如
>>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
>>> layer.build([5, 20, 30, 40])
>>> print(layer.beta.shape)
(20, 30, 40)
>>> print(layer.gamma.shape)
(20, 30, 40)
請注意,Layer Normalization 的其他實作方式可能會選擇在與正規化所跨越的軸不同的軸集合上定義 gamma
和 beta
。例如,群組正規化 (Wu et al. 2018) 群組大小為 1 對應於一個 Layer Normalization,它會在高度、寬度和通道上正規化,且 gamma
和 beta
僅跨越通道維度。因此,此 Layer Normalization 實作方式不會與群組大小設定為 1 的群組正規化層相符。
引數
-1
是輸入中的最後一個維度。預設值為 -1
。beta
的偏移量新增至正規化張量。如果為 False,則會忽略 beta
。預設值為 True
。gamma
。如果為 False,則不會使用 gamma
。當下一層是線性的時 (例如 nn.relu
),則可以停用此項,因為縮放將由下一層完成。預設值為 True
。輸入形狀
任意。當使用此層作為模型中的第一層時,請使用關鍵字引數 input_shape
(整數元組,不包含樣本軸)。
輸出形狀
與輸入形狀相同。
參考文獻