Keras 3 API 文件 / 層 API / 注意力機制層 / GroupQueryAttention

分組查詢注意力機制

[原始碼]

GroupedQueryAttention 類別

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

分組查詢注意力機制層。

這是 Ainslie 等人於 2023 年提出的分組查詢注意力機制的實作。此處 num_key_value_heads 表示群組數量,將 num_key_value_heads 設定為 1 等同於多查詢注意力機制,而當 num_key_value_heads 等於 num_query_heads 時,則等同於多頭注意力機制。

此層首先投影 querykeyvalue 張量。然後,重複 keyvalue 以匹配 query 的頭數。

接著,將 query 縮放並與 key 張量進行點積運算。這些結果會經過 softmax 函數以獲得注意力權重。然後,值張量會根據這些權重進行插值,並串接回單一張量。

參數

  • head_dim:每個注意力頭的大小。
  • num_query_heads:查詢注意力頭的數量。
  • num_key_value_heads:鍵和值注意力頭的數量。
  • dropout:Dropout 機率。
  • use_bias:布林值,指示密集層是否使用偏差向量/矩陣。
  • flash_attention:若為 None,則此層會在可能的情況下嘗試使用快閃注意力機制,以實現更快且更節省記憶體的注意力計算。此行為可以使用 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 進行配置。
  • kernel_initializer:密集層核心的初始化器。
  • bias_initializer:密集層偏差的初始化器。
  • kernel_regularizer:密集層核心的正規化器。
  • bias_regularizer:密集層偏差的正規化器。
  • activity_regularizer:密集層活動的正規化器。
  • kernel_constraint:密集層核心的約束。
  • bias_constraint:密集層核心的約束。
  • seed:用於設定 dropout 層種子的可選整數。

呼叫參數

  • query:形狀為 (batch_dim, target_seq_len, feature_dim) 的查詢張量,其中 batch_dim 是批次大小,target_seq_len 是目標序列的長度,而 feature_dim 是特徵的維度。
  • value:形狀為 (batch_dim, source_seq_len, feature_dim) 的值張量,其中 batch_dim 是批次大小,source_seq_len 是來源序列的長度,而 feature_dim 是特徵的維度。
  • key:可選的鍵張量,形狀為 (batch_dim, source_seq_len, feature_dim)。若未提供,將使用 value 同時作為 keyvalue,這是最常見的情況。
  • attention_mask:形狀為 (batch_dim, target_seq_len, source_seq_len) 的布林遮罩,用於防止注意力機制關注某些位置。此布林遮罩指定哪些查詢元素可以關注哪些鍵元素,其中 1 表示關注,0 表示不關注。廣播機制可以應用於缺失的批次維度和頭維度。
  • return_attention_scores:布林值,指示若為 True,輸出應為 (attention_output, attention_scores),若為 False,則輸出為 attention_output。預設為 False
  • training:Python 布林值,指示此層應以訓練模式(加入 dropout)或推論模式(不加入 dropout)運作。將使用父層/模型的訓練模式,或在沒有父層的情況下使用 False(推論模式)。
  • use_causal_mask:布林值,指示是否應用因果遮罩,以防止 tokens 關注未來的 tokens(例如,用於解碼器 Transformer)。

返回

  • attention_output:計算結果,形狀為 (batch_dim, target_seq_len, feature_dim),其中 target_seq_len 代表目標序列長度,而 feature_dim 則是查詢輸入的最後一個維度。
  • attention_scores:(可選)注意力係數,形狀為 (batch_dim, num_query_heads, target_seq_len, source_seq_len)