GroupedQueryAttention
類別keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
分組查詢注意力機制層。
這是 Ainslie 等人於 2023 年提出的分組查詢注意力機制的實作。此處 num_key_value_heads
表示群組數量,將 num_key_value_heads
設定為 1 等同於多查詢注意力機制,而當 num_key_value_heads
等於 num_query_heads
時,則等同於多頭注意力機制。
此層首先投影 query
、key
和 value
張量。然後,重複 key
和 value
以匹配 query
的頭數。
接著,將 query
縮放並與 key
張量進行點積運算。這些結果會經過 softmax 函數以獲得注意力權重。然後,值張量會根據這些權重進行插值,並串接回單一張量。
參數
None
,則此層會在可能的情況下嘗試使用快閃注意力機制,以實現更快且更節省記憶體的注意力計算。此行為可以使用 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
進行配置。呼叫參數
(batch_dim, target_seq_len, feature_dim)
的查詢張量,其中 batch_dim
是批次大小,target_seq_len
是目標序列的長度,而 feature_dim
是特徵的維度。(batch_dim, source_seq_len, feature_dim)
的值張量,其中 batch_dim
是批次大小,source_seq_len
是來源序列的長度,而 feature_dim
是特徵的維度。(batch_dim, source_seq_len, feature_dim)
。若未提供,將使用 value
同時作為 key
和 value
,這是最常見的情況。(batch_dim, target_seq_len, source_seq_len)
的布林遮罩,用於防止注意力機制關注某些位置。此布林遮罩指定哪些查詢元素可以關注哪些鍵元素,其中 1 表示關注,0 表示不關注。廣播機制可以應用於缺失的批次維度和頭維度。True
,輸出應為 (attention_output, attention_scores)
,若為 False
,則輸出為 attention_output
。預設為 False
。False
(推論模式)。返回
(batch_dim, target_seq_len, feature_dim)
,其中 target_seq_len
代表目標序列長度,而 feature_dim
則是查詢輸入的最後一個維度。(batch_dim, num_query_heads, target_seq_len, source_seq_len)
。