Keras 3 API 文件 / KerasHub / 建模層 / CachedMultiHeadAttention 層

CachedMultiHeadAttention 層

[來源]

CachedMultiHeadAttention 類別

keras_hub.layers.CachedMultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

支援快取的多頭注意力層。

此層適用於自回歸解碼。它可用於快取解碼器自注意力和交叉注意力。正向傳遞可以在以下三種模式之一中進行

  • 無快取,與常規的多頭注意力相同。
  • 靜態快取 (cache_update_index 為 None)。在這種情況下,將使用快取的鍵/值投影,並忽略輸入值。
  • 更新的快取 (cache_update_index 不是 None)。在這種情況下,將使用輸入計算新的鍵/值投影,並將其拼接至指定索引處的快取中。

請注意,快取僅在推理期間有用,不應在訓練期間使用。

我們在下方使用符號 BTS,其中 B 是批次維度,T 是目標序列長度,S 是來源序列長度。請注意,在生成解碼期間,T 通常為 1(您正在生成長度為一的目標序列以預測下一個詞彙)。

呼叫參數

  • query:形狀為 (B, T, dim) 的查詢 Tensor
  • value:形狀為 (B, S*, dim) 的值 Tensor。如果 cache 為 None,則 S* 必須等於 S 並與 attention_mask 的形狀匹配。如果 cache 不是 None,則 S* 可以是任何小於 S 的長度,並且計算出的值將在 cache_update_index 處拼接至 cache 中。
  • key:形狀為 (B, S*, dim) 的可選金鑰 Tensor。如果 cacheNone,則 S* 必須等於 S 並與 attention_mask 的形狀相符。如果 cache 不為 None,則 S* 可以是任何小於 S 的長度,並且計算值將在 cache_update_index 處拼接至 cache
  • attention_mask:形狀為 (B, T, S) 的布林遮罩。attention_mask 會防止注意力集中在某些位置。布林遮罩指定哪些查詢元素可以參與哪些金鑰元素,1 表示參與,0 表示不參與。廣播可以發生在缺失的批次維度和頭部維度。
  • cache:一個密集的浮點數 Tensor。金鑰/值快取,形狀為 [B, 2, S, num_heads, key_dims],其中 S 必須與 attention_mask 形狀一致。此參數旨在在生成過程中使用,以避免重新計算中間狀態。
  • cache_update_index:一個整數或整數 Tensor,用於更新 cache 的索引(通常是在執行生成時正在處理的當前標記的索引)。如果在設定 cache 的同時 cache_update_index=None,則不會更新快取。
  • training:一個布林值,指示層應該在訓練模式還是推論模式下運行。

傳回值

一個 (attention_output, cache) 元組。attention_output 是計算結果,形狀為 (B, T, dim),其中 T 表示目標序列形狀,如果 output_shapeNone,則 dim 為查詢輸入的最後一個維度。否則,多頭輸出將投影到 output_shape 指定的形狀。cache 是更新後的快取。