CachedMultiHeadAttention
類別keras_hub.layers.CachedMultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
具有快取支援的多頭注意力層。
此層適用於自迴歸解碼。它可用於快取解碼器自我注意力和交叉注意力。前向傳播可以發生在以下三種模式之一
cache_update_index
為 None)。在這種情況下,將使用快取的鍵/值投影,並且輸入值將被忽略。cache_update_index
不為 None)。在這種情況下,將使用輸入計算新的鍵/值投影,並在指定的索引處拼接回快取中。請注意,快取僅在推論期間有用,不應在訓練期間使用。
我們在下面使用符號 B
、T
、S
,其中 B
是批次維度,T
是目標序列長度,S
是來源序列長度。請注意,在生成式解碼期間,T
通常為 1(您正在生成長度為 1 的目標序列以預測下一個符記)。
呼叫參數
Tensor
,形狀為 (B, T, dim)
。Tensor
,形狀為 (B, S*, dim)
。如果 cache
為 None,
S*必須等於
S並且與
attention_mask的形狀匹配。如果快取
不為 None
,則 S*
可以是任何小於 S
的長度,並且計算出的值將在 cache_update_index
處拼接回 cache
中。Tensor
,形狀為 (B, S*, dim)
。如果 cache
為 None
,則 S*
必須等於 S
並且與 attention_mask
的形狀匹配。如果 cache
不為 None
,則 S*
可以是任何小於 S
的長度,並且計算出的值將在 cache_update_index
處拼接回 cache
中。(B, T, S)
的布林遮罩。attention_mask
防止注意力集中在某些位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示注意力,0 表示不注意。廣播可以發生在遺失的批次維度和頭部維度上。[B, 2, S, num_heads, key_dims]
,其中 S
必須與 attention_mask
形狀一致。此參數旨在在生成期間使用,以避免重新計算中間狀態。cache
的索引(通常是在執行生成時正在處理的目前符記的索引)。如果設定了 cache
,但 cache_update_index=None
,則不會更新快取。返回
一個 (attention_output, cache)
元組。attention_output
是計算結果,形狀為 (B, T, dim)
,其中 T
用於目標序列形狀,如果 output_shape
為 None
,則 dim
是查詢輸入的最後一個維度。否則,多頭輸出將投影到 output_shape
指定的形狀。cache
是更新後的快取。