AdditiveAttention
類別tf_keras.layers.AdditiveAttention(use_scale=True, **kwargs)
加法注意力層,又稱 Bahdanau 式注意力。
輸入為形狀為 [batch_size, Tq, dim]
的 query
張量、形狀為 [batch_size, Tv, dim]
的 value
張量,以及形狀為 [batch_size, Tv, dim]
的 key
張量。計算步驟如下:
query
和 key
分別重塑為形狀 [batch_size, Tq, 1, dim]
和 [batch_size, 1, Tv, dim]
。[batch_size, Tq, Tv]
的分數,作為非線性總和: scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)
[batch_size, Tq, Tv]
的分布: distribution = tf.nn.softmax(scores)
。distribution
建立形狀為 [batch_size, Tq, dim]
的 value
線性組合: return tf.matmul(distribution, value)
。引數
True
,將建立一個變數來縮放注意力分數。0.0
。呼叫引數
[batch_size, Tq, dim]
的 Query Tensor
。[batch_size, Tv, dim]
的 Value Tensor
。Tensor
,形狀為 [batch_size, Tv, dim]
。如果未給定,將使用 value
作為 key
和 value
,這是最常見的情況。[batch_size, Tq]
的布林遮罩 Tensor
。如果給定,輸出在 mask==False
的位置將為零。[batch_size, Tv]
的布林遮罩 Tensor
。如果給定,將套用遮罩,使 mask==False
位置的值不會影響結果。True
,則將注意力分數 (在遮罩和 softmax 後) 作為額外輸出引數傳回。True
。新增一個遮罩,使位置 i
無法關注位置 j > i
。這可防止資訊從未來流向過去。預設為 False
。輸出
Attention outputs of shape `[batch_size, Tq, dim]`.
[Optional] Attention scores after masking and softmax with shape
`[batch_size, Tq, Tv]`.
query
、value
和 key
的含義取決於應用。例如,在文字相似性的情況下,query
是第一段文字的序列嵌入,而 value
是第二段文字的序列嵌入。 key
通常與 value
張量相同。
以下是在 CNN+注意力網路中使用 AdditiveAttention
的程式碼範例
# Variable-length int sequences.
query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32')
# Embedding lookup.
token_embedding = tf.keras.layers.Embedding(max_tokens, dimension)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input)
# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
filters=100,
kernel_size=4,
# Use 'same' padding so outputs have the same shape as inputs.
padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings)
# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
[query_seq_encoding, value_seq_encoding])
# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
query_value_attention_seq)
# Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
[query_encoding, query_value_attention])
# Add DNN layers, and create Model.
# ...