Keras 3 API 文件 / KerasHub / 建模層 / 位置嵌入層

位置嵌入層 (PositionEmbedding layer)

[來源]

PositionEmbedding 類別

keras_hub.layers.PositionEmbedding(
    sequence_length, initializer="glorot_uniform", **kwargs
)

一個為輸入序列學習位置嵌入的層。

此類別假設在輸入張量中,最後一個維度對應於特徵,而倒數第二個維度對應於序列。

此層不支援遮罩,但可以與 keras.layers.Embedding 結合以支援填充遮罩。

參數

  • sequence_length:動態序列的最大長度。
  • initializer:用於嵌入權重的初始化器。預設為 "glorot_uniform"
  • seq_axis:我們在其中添加嵌入的輸入張量的軸。
  • **kwargs:傳遞給 keras.layers.Layer 的其他關鍵字參數,包括 nametrainabledtype 等。

呼叫參數

  • inputs:要計算嵌入的張量輸入,形狀為 (batch_size, sequence_length, hidden_dim)。由於位置嵌入不依賴於輸入序列內容,因此只會使用輸入形狀。
  • start_index:一個整數或整數張量。開始計算位置嵌入的起始位置。這在快取解碼時很有用,其中每個位置都在迴圈中單獨預測。

範例

直接在輸入上呼叫。

>>> layer = keras_hub.layers.PositionEmbedding(sequence_length=10)
>>> layer(np.zeros((8, 10, 16)))

與詞彙嵌入結合。

seq_length = 50
vocab_size = 5000
embed_dim = 128
inputs = keras.Input(shape=(seq_length,))
token_embeddings = keras.layers.Embedding(
    input_dim=vocab_size, output_dim=embed_dim
)(inputs)
position_embeddings = keras_hub.layers.PositionEmbedding(
    sequence_length=seq_length
)(token_embeddings)
outputs = token_embeddings + position_embeddings

參考