MultiHeadAttention
類別keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
MultiHeadAttention 層。
這是多頭注意力機制的實作,如論文「Attention is all you Need」Vaswani 等人,2017 中所述。如果 query
、key
、value
相同,則這是自我注意力機制。query
中的每個時間步都會關注 key
中的對應序列,並傳回固定寬度的向量。
此層首先投影 query
、key
和 value
。這些(實際上)是長度為 num_attention_heads
的張量列表,其中對應的形狀為 (batch_size, <query 維度>, key_dim)
、(batch_size, <key/value 維度>, key_dim)
、(batch_size, <key/value 維度>, value_dim)
。
然後,查詢和鍵張量進行點積和縮放。這些會經過 softmax 處理以獲得注意力機率。然後,值張量會根據這些機率進行內插,然後串聯回單一張量。
最後,最後一個維度為 value_dim
的結果張量可以進行線性投影並傳回。
引數
None
表示注意力應用於所有軸,但不包括批次、頭和特徵軸。None
,則此層會嘗試使用快閃注意力機制,以便在可能的情況下更快且更有效率地進行記憶體注意力計算。此行為可以使用 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
進行配置。呼叫引數
(B, T, dim)
的查詢張量,其中 B
是批次大小,T
是目標序列長度,而 dim 是特徵維度。(B, S, dim)
的值張量,其中 B
是批次大小,S
是來源序列長度,而 dim 是特徵維度。(B, S, dim)
。如果未給定,將使用 value
作為 key
和 value
,這是最常見的情況。(B, T, S)
的布林遮罩,可防止注意力機制關注某些位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示關注,0 表示不關注。廣播可以發生在遺失的批次維度和頭維度上。(attention_output, attention_scores)
(如果為 True
),或 attention_output
(如果為 False
)。預設為 False
。False
(推論)。傳回
(B, T, E)
,其中 T
用於目標序列形狀,而 E
是查詢輸入的最後一個維度(如果 output_shape
為 None
)。否則,多頭輸出會投影到 output_shape
指定的形狀。