MultiHeadAttention
類別keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
多頭注意力層。
這是論文「Attention is all you Need」Vaswani et al., 2017 中描述的多頭注意力的實現。如果 query
、key
、value
相同,則為自我注意力。 query
中的每個時間步都會關注 key
中的對應序列,並返回一個固定寬度的向量。
此層首先投影 query
、key
和 value
。這些(實際上)是長度為 num_attention_heads
的張量列表,其中對應的形狀為 (batch_size, <query 維度>, key_dim)
、(batch_size, <key/value 維度>, key_dim)
、(batch_size, <key/value 維度>, value_dim)
。
然後,將查詢張量和鍵張量點乘並縮放。這些經過 softmax 處理以獲得注意力機率。然後,值張量會根據這些機率進行插值,然後連接回單個張量。
最後,最後一個維度為 value_dim
的結果張量可以採用線性投影並返回。
引數
None
表示對所有軸應用注意力,但不包括批次、頭部和特徵。None
,則該層會在可能的情況下嘗試使用快閃注意力,以實現更快、更省記憶體的注意力運算。可以使用 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
配置此行為。呼叫引數
(B, T, dim)
的查詢張量,其中 B
是批次大小,T
是目標序列長度,而 dim 是特徵維度。(B, S, dim)
的值張量,其中 B
是批次大小,S
是來源序列長度,而 dim 是特徵維度。(B, S, dim)
的可選鍵張量。如果未提供,則會將 value
用於 key
和 value
,這是最常見的情況。(B, T, S)
的布林遮罩,可防止注意力集中在某些位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示注意力,0 表示不關注。遺失的批次維度和頭部維度可以進行廣播。True
,則輸出應為 (attention_output, attention_scores)
,如果為 False
,則輸出應為 attention_output
。預設為 False
。False
(推論)。返回
(B, T, E)
,其中 T
代表目標序列形狀,如果 output_shape
為 None
,則 E
代表查詢輸入的最後一個維度。否則,多頭輸出會投影到 output_shape
指定的形狀。