StartEndPacker
類別keras_nlp.layers.StartEndPacker(
sequence_length,
start_value=None,
end_value=None,
pad_value=None,
return_padding_mask=False,
name=None,
**kwargs
)
將起始和結束標記添加到序列中,並填充到固定長度。
此層在為翻譯等任務對輸入進行分詞時非常有用,其中每個序列都應包含起始和結束標記。它應該在分詞後被呼叫。該層將首先修剪輸入以使其適配,然後添加起始/結束標記,最後如有必要,填充到 sequence_length
。
輸入資料應以張量、tf.RaggedTensor
或列表的形式傳遞。對於批次輸入,輸入應為列表的列表或二階張量。對於非批次輸入,每個元素應為列表或一階張量。
參數
None
,則不會添加起始值。None
,則不會添加結束值。None
,則會根據輸入張量的資料類型添加 0 或 ""。pad_value
。呼叫參數
tf.Tensor
、tf.RaggedTensor
或 Python 字串列表。sequence_length
。False
以不為此輸入附加起始值。False
表示不要附加此輸入的結束值。範例
未批次處理的輸入(整數)。
>>> inputs = [5, 6, 7]
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=7, start_value=1, end_value=2,
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs)
array([1, 5, 6, 7, 2, 0, 0], dtype=int32)
批次處理的輸入(整數)。
>>> inputs = [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14]]
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value=1, end_value=2,
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs)
array([[ 1, 5, 6, 7, 2, 0],
[ 1, 8, 9, 10, 11, 2]], dtype=int32)
未批次處理的輸入(字串)。
>>> inputs = tf.constant(["this", "is", "fun"])
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs).astype("U")
array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')
批次處理的輸入(字串)。
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs).astype("U")
array([['<s>', 'this', 'is', 'fun', '</s>', '<pad>'],
['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')
多個起始標記。
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
... pad_value="<pad>"
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs).astype("U")
array([['</s>', '<s>', 'this', 'is', 'fun', '</s>'],
['</s>', '<s>', 'awesome', '</s>', '<pad>', '<pad>']], dtype='<U7')