Keras 3 API 文件 / 層 API / 注意力層 / 多頭注意力層

多頭注意力層

[原始碼]

MultiHeadAttention 類別

keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

多頭注意力層。

這是論文「Attention is all you Need」Vaswani et al., 2017 中描述的多頭注意力的實現。如果 querykeyvalue 相同,則為自我注意力。 query 中的每個時間步都會關注 key 中的對應序列,並返回一個固定寬度的向量。

此層首先投影 querykeyvalue。這些(實際上)是長度為 num_attention_heads 的張量列表,其中對應的形狀為 (batch_size, <query 維度>, key_dim)(batch_size, <key/value 維度>, key_dim)(batch_size, <key/value 維度>, value_dim)

然後,將查詢張量和鍵張量點乘並縮放。這些經過 softmax 處理以獲得注意力機率。然後,值張量會根據這些機率進行插值,然後連接回單個張量。

最後,最後一個維度為 value_dim 的結果張量可以採用線性投影並返回。

引數

  • num_heads:注意力頭的數量。
  • key_dim:查詢和鍵的每個注意力頭的大小。
  • value_dim:值的每個注意力頭的大小。
  • dropout:Dropout 機率。
  • use_bias:布林值,表示密集層是否使用偏差向量/矩陣。
  • output_shape:輸出張量的預期形狀,除了批次和序列維度之外。如果未指定,則會投影回查詢特徵維度(查詢輸入的最後一個維度)。
  • attention_axes:應用注意力的軸。 None 表示對所有軸應用注意力,但不包括批次、頭部和特徵。
  • flash_attention:如果為 None,則該層會在可能的情況下嘗試使用快閃注意力,以實現更快、更省記憶體的注意力運算。可以使用 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 配置此行為。
  • kernel_initializer:密集層核心的初始化器。
  • bias_initializer:密集層偏差的初始化器。
  • kernel_regularizer:密集層核心的正規化器。
  • bias_regularizer:密集層偏差的正規化器。
  • activity_regularizer:密集層活動的正規化器。
  • kernel_constraint:密集層核心的約束。
  • bias_constraint:密集層核心的約束。
  • seed:用於設定 dropout 層種子的可選整數。

呼叫引數

  • query:形狀為 (B, T, dim) 的查詢張量,其中 B 是批次大小,T 是目標序列長度,而 dim 是特徵維度。
  • value:形狀為 (B, S, dim) 的值張量,其中 B 是批次大小,S 是來源序列長度,而 dim 是特徵維度。
  • key:形狀為 (B, S, dim) 的可選鍵張量。如果未提供,則會將 value 用於 keyvalue,這是最常見的情況。
  • attention_mask:形狀為 (B, T, S) 的布林遮罩,可防止注意力集中在某些位置。布林遮罩指定哪些查詢元素可以關注哪些鍵元素,1 表示注意力,0 表示不關注。遺失的批次維度和頭部維度可以進行廣播。
  • return_attention_scores:布林值,表示如果為 True,則輸出應為 (attention_output, attention_scores),如果為 False,則輸出應為 attention_output。預設為 False
  • training:Python 布林值,表示該層應在訓練模式(新增 dropout)還是推論模式(無 dropout)中運行。將使用父層/模型的訓練模式,或者如果沒有父層,則使用 False(推論)。
  • use_causal_mask:布林值,表示是否套用因果遮罩,以防止標記關注未來標記(例如,用於解碼器 Transformer)。

返回

  • attention_output:計算結果,形狀為 (B, T, E),其中 T 代表目標序列形狀,如果 output_shapeNone,則 E 代表查詢輸入的最後一個維度。否則,多頭輸出會投影到 output_shape 指定的形狀。
  • attention_scores:(可選)在注意力軸上的多頭注意力係數。