MultiHeadAttention
類別tf_keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
)
MultiHeadAttention 層。
這是論文「Attention is all you Need」(Vaswani 等人,2017)中所描述的多頭注意力機制的實作。如果 query
、key
、value
相同,則這就是自我注意力。query
中的每個時間步都會關注 key
中的對應序列,並傳回一個固定寬度的向量。
此層首先會投影 query
、key
和 value
。這些實際上是長度為 num_attention_heads
的張量列表,其中對應的形狀為 (batch_size, <query 維度>, key_dim)
、(batch_size, <key/value 維度>, key_dim)
、(batch_size, <key/value 維度>, value_dim)
。
然後,查詢 (query) 和鍵 (key) 張量進行點積和縮放。這些會經過 softmax 處理以獲得注意力機率。然後,值 (value) 張量會根據這些機率進行插值,然後串接回單個張量。
最後,最後一個維度為 value_dim 的結果張量可以進行線性投影並傳回。
當在自訂層中使用 MultiHeadAttention
時,自訂層必須實作自己的 build()
方法,並在那裡呼叫 MultiHeadAttention
的 _build_from_signature()
。這使得權重在載入模型時可以正確還原。
範例
對具有注意力遮罩的兩個序列輸入執行 1D 交叉注意力。傳回各頭的額外注意力權重。
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
... return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)
對軸 2 和 3 上的 5D 輸入張量執行 2D 自我注意力。
>>> layer = MultiHeadAttention(
... num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)
引數
None
表示對所有軸(但不包括批次、頭和特徵)應用注意力。呼叫引數
(B, T, dim)
的查詢 Tensor
。(B, S, dim)
的值 Tensor
。Tensor
,形狀為 (B, S, dim)
。如果未給定,將會使用 value
作為 key
和 value
,這是最常見的情況。(B, T, S)
的布林遮罩,可防止關注特定位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示關注,0 表示不關注。廣播可以發生在缺少的批次維度和頭維度上。True
,輸出應為 (attention_output, attention_scores)
,如果為 False
,則為 attention_output
。預設為 False
。傳回
(B, T, E)
,其中 T
是目標序列形狀,如果 output_shape
為 None
,則 E
是查詢輸入的最後一個維度。否則,多頭輸出會投影到 output_shape
指定的形狀。