Keras 3 API 文件 / KerasHub / 建模層 / RotaryEmbedding 層

RotaryEmbedding 層

[來源]

RotaryEmbedding 類別

keras_hub.layers.RotaryEmbedding(
    max_wavelength=10000, scaling_factor=1.0, sequence_axis=1, feature_axis=-1, **kwargs
)

旋轉位置編碼層。

此層使用旋轉矩陣對絕對位置資訊進行編碼。它使用正弦和餘弦函數的混合計算旋轉編碼,這些函數的波長呈幾何級數增長。在 RoFormer:增強型旋轉位置嵌入 Transformer 中定義和制定。輸入必須是一個具有序列維度和特徵維度的張量。通常,這將是一個形狀為 (批次大小, 序列長度, 特徵長度)(批次大小, 序列長度, 頭數, 特徵長度) 的輸入。此層將返回一個新的張量,其中應用旋轉嵌入到輸入張量。

參數

  • max_wavelength:int。正弦/餘弦曲線的最大角波長。
  • scaling_factor:float。用於縮放權杖位置的縮放因子。
  • sequence_axis:int。輸入張量中的序列軸。
  • feature_axis:int。輸入張量中的特徵軸。
  • **kwargs:傳遞給 keras.layers.Layer 的其他關鍵字參數,包括 nametrainabledtype 等。

呼叫參數

  • inputs:要應用嵌入的張量輸入。這可以有任何形狀,但必須同時包含序列軸和特徵軸。旋轉嵌入將應用於 inputs 並返回。
  • start_index:整數或整數張量。開始計算旋轉嵌入的位置。這在快取解碼時很有用,其中每個位置都在迴圈中單獨預測。

範例

batch_size = 16
feature_length = 18
sequence_length = 256
num_heads = 8

# No multi-head dimension.
tensor = np.ones((batch_size, sequence_length, feature_length))
rot_emb_layer = RotaryEmbedding()
tensor_rot = rot_emb_layer(tensor)

# With multi-head dimension.
tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
tensor_rot = rot_emb_layer(tensor)

參考