Keras 3 API 文件 / 層 API / 注意力機制層 / MultiHeadAttention 層

MultiHeadAttention 層

[原始碼]

MultiHeadAttention 類別

keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

MultiHeadAttention 層。

這是多頭注意力機制的實作,如論文「Attention is all you Need」Vaswani 等人,2017 中所述。如果 querykeyvalue 相同,則這是自我注意力機制。query 中的每個時間步都會關注 key 中的對應序列,並傳回固定寬度的向量。

此層首先投影 querykeyvalue。這些(實際上)是長度為 num_attention_heads 的張量列表,其中對應的形狀為 (batch_size, <query 維度>, key_dim)(batch_size, <key/value 維度>, key_dim)(batch_size, <key/value 維度>, value_dim)

然後,查詢和鍵張量進行點積和縮放。這些會經過 softmax 處理以獲得注意力機率。然後,值張量會根據這些機率進行內插,然後串聯回單一張量。

最後,最後一個維度為 value_dim 的結果張量可以進行線性投影並傳回。

引數

  • num_heads:注意力頭的數量。
  • key_dim:查詢和鍵的每個注意力頭的大小。
  • value_dim:值的每個注意力頭的大小。
  • dropout:Dropout 機率。
  • use_bias:布林值,指示密集層是否使用偏置向量/矩陣。
  • output_shape:輸出張量的預期形狀,除了批次和序列維度之外。如果未指定,則投影回查詢特徵維度(查詢輸入的最後一個維度)。
  • attention_axes:應用注意力的軸。None 表示注意力應用於所有軸,但不包括批次、頭和特徵軸。
  • flash_attention:如果為 None,則此層會嘗試使用快閃注意力機制,以便在可能的情況下更快且更有效率地進行記憶體注意力計算。此行為可以使用 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 進行配置。
  • kernel_initializer:密集層核心的初始化器。
  • bias_initializer:密集層偏置的初始化器。
  • kernel_regularizer:密集層核心的正規化器。
  • bias_regularizer:密集層偏置的正規化器。
  • activity_regularizer:密集層活動的正規化器。
  • kernel_constraint:密集層核心的約束。
  • bias_constraint:密集層核心的約束。
  • seed:用於設定 dropout 層種子的選用整數。

呼叫引數

  • query:形狀為 (B, T, dim) 的查詢張量,其中 B 是批次大小,T 是目標序列長度,而 dim 是特徵維度。
  • value:形狀為 (B, S, dim) 的值張量,其中 B 是批次大小,S 是來源序列長度,而 dim 是特徵維度。
  • key:選用的鍵張量,形狀為 (B, S, dim)。如果未給定,將使用 value 作為 keyvalue,這是最常見的情況。
  • attention_mask:形狀為 (B, T, S) 的布林遮罩,可防止注意力機制關注某些位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示關注,0 表示不關注。廣播可以發生在遺失的批次維度和頭維度上。
  • return_attention_scores:布林值,指示輸出是否應為 (attention_output, attention_scores)(如果為 True),或 attention_output(如果為 False)。預設為 False
  • training:Python 布林值,指示此層應在訓練模式(新增 dropout)或推論模式(無 dropout)下運作。將使用父層/模型的訓練模式,如果沒有父層,則使用 False(推論)。
  • use_causal_mask:布林值,指示是否套用因果遮罩以防止符記關注未來的符記(例如,在解碼器 Transformer 中使用)。

傳回

  • attention_output:計算結果,形狀為 (B, T, E),其中 T 用於目標序列形狀,而 E 是查詢輸入的最後一個維度(如果 output_shapeNone)。否則,多頭輸出會投影到 output_shape 指定的形狀。
  • attention_scores:(選用)注意力軸上的多頭注意力係數。