Keras 2 API 文件 / 層 API / 注意力層 / MultiHeadAttention 層

MultiHeadAttention 層

[原始碼]

MultiHeadAttention 類別

tf_keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)

MultiHeadAttention 層。

這是論文「Attention is all you Need」(Vaswani 等人,2017)中所描述的多頭注意力機制的實作。如果 querykeyvalue 相同,則這就是自我注意力。query 中的每個時間步都會關注 key 中的對應序列,並傳回一個固定寬度的向量。

此層首先會投影 querykeyvalue。這些實際上是長度為 num_attention_heads 的張量列表,其中對應的形狀為 (batch_size, <query 維度>, key_dim)(batch_size, <key/value 維度>, key_dim)(batch_size, <key/value 維度>, value_dim)

然後,查詢 (query) 和鍵 (key) 張量進行點積和縮放。這些會經過 softmax 處理以獲得注意力機率。然後,值 (value) 張量會根據這些機率進行插值,然後串接回單個張量。

最後,最後一個維度為 value_dim 的結果張量可以進行線性投影並傳回。

當在自訂層中使用 MultiHeadAttention 時,自訂層必須實作自己的 build() 方法,並在那裡呼叫 MultiHeadAttention_build_from_signature()。這使得權重在載入模型時可以正確還原。

範例

對具有注意力遮罩的兩個序列輸入執行 1D 交叉注意力。傳回各頭的額外注意力權重。

>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
...                                return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)

對軸 2 和 3 上的 5D 輸入張量執行 2D 自我注意力。

>>> layer = MultiHeadAttention(
...     num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)

引數

  • num_heads:注意力頭的數量。
  • key_dim:查詢和鍵的每個注意力頭的大小。
  • value_dim:值的每個注意力頭的大小。
  • dropout:dropout 機率。
  • use_bias:布林值,指示密集層是否使用偏差向量/矩陣。
  • output_shape:除了批次和序列維度之外,輸出張量的預期形狀。如果未指定,則會投影回查詢特徵維度(查詢輸入的最後一個維度)。
  • attention_axes:應用注意力的軸。None 表示對所有軸(但不包括批次、頭和特徵)應用注意力。
  • kernel_initializer:密集層核心的初始化器。
  • bias_initializer:密集層偏差的初始化器。
  • kernel_regularizer:密集層核心的正規化器。
  • bias_regularizer:密集層偏差的正規化器。
  • activity_regularizer:密集層活動的正規化器。
  • kernel_constraint:密集層核心的約束。
  • bias_constraint:密集層核心的約束。

呼叫引數

  • query:形狀為 (B, T, dim) 的查詢 Tensor
  • value:形狀為 (B, S, dim) 的值 Tensor
  • key:可選的鍵 Tensor,形狀為 (B, S, dim)。如果未給定,將會使用 value 作為 keyvalue,這是最常見的情況。
  • attention_mask:形狀為 (B, T, S) 的布林遮罩,可防止關注特定位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示關注,0 表示不關注。廣播可以發生在缺少的批次維度和頭維度上。
  • return_attention_scores:一個布林值,指示如果為 True,輸出應為 (attention_output, attention_scores),如果為 False,則為 attention_output。預設為 False
  • training:Python 布林值,指示層應在訓練模式(新增 dropout)還是推理模式(無 dropout)下運作。如果沒有父層,則會使用父層/模型的訓練模式,或 False(推理)。
  • use_causal_mask:一個布林值,指示是否套用因果遮罩以防止權杖關注未來的權杖(例如,在解碼器 Transformer 中使用)。

傳回

  • attention_output:計算結果,形狀為 (B, T, E),其中 T 是目標序列形狀,如果 output_shapeNone,則 E 是查詢輸入的最後一個維度。否則,多頭輸出會投影到 output_shape 指定的形狀。
  • attention_scores:[可選] 注意力軸上的多頭注意力係數。