MultiSegmentPacker
類別keras_hub.layers.MultiSegmentPacker(
sequence_length,
start_value,
end_value,
sep_value=None,
pad_value=None,
truncate="round_robin",
**kwargs
)
將多個序列打包成單一固定寬度的模型輸入。
此層將多個輸入序列打包成包含起始和結束分隔符號的單一固定寬度序列,形成適用於 BERT 和類 BERT 模型分類任務的密集輸入。
接受 token 段落的元組作為輸入。每個元組元素應包含段落的 token,以 tensors、tf.RaggedTensor
s 或列表的形式傳遞。對於批次輸入,段落元組中的每個元素都應為列表的列表或秩為 2 的 tensor。對於非批次輸入,每個元素都應為列表或秩為 1 的 tensor。
此層將依以下步驟處理輸入:- 根據 truncate
策略,截斷所有輸入段落以符合 sequence_length
限制。 - 連接所有輸入段落,在整個序列的開頭添加單個 start_value
,並在每個段落的末尾添加多個 end_value
。 - 使用 pad_tokens
將結果序列填充到 sequence_length
。 - 計算一個單獨的「段落 ID」tensor,其整數類型和形狀與打包的 token 輸出相同,其中每個整數索引表示 token 的來源段落。 start_value
的段落 ID 始終為 0,而每個 end_value
的段落 ID 是其前面的段落。
引數
None
,則使用 end_value
。dtype 必須與層的輸入 tensors 的 dtype 相符。"round_robin"
或 "waterfall"
"round_robin"
:可用空間以循環方式一次分配一個 token 給仍然需要空間的輸入,直到達到限制。"waterfall"
:「瀑布式」演算法用於預算分配,該演算法以從左到右的方式分配配額並填滿儲存桶,直到預算用完。它支援任意數量的段落。回傳值
包含兩個元素的元組。第一個是密集的、打包的 token 序列。第二個是相同形狀的整數 tensor,包含段落 ID。
範例
為分類打包單個輸入。
>>> seq1 = [1, 2, 3, 4]
>>> packer = keras_hub.layers.MultiSegmentPacker(
... sequence_length=8, start_value=101, end_value=102
... )
>>> token_ids, segment_ids = packer((seq1,))
>>> np.array(token_ids)
array([101, 1, 2, 3, 4, 102, 0, 0], dtype=int32)
>>> np.array(segment_ids)
array([0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
為分類打包多個輸入。
>>> seq1 = [1, 2, 3, 4]
>>> seq2 = [11, 12, 13, 14]
>>> packer = keras_hub.layers.MultiSegmentPacker(
... sequence_length=8, start_value=101, end_value=102
... )
>>> token_ids, segment_ids = packer((seq1, seq2))
>>> np.array(token_ids)
array([101, 1, 2, 3, 102, 11, 12, 102], dtype=int32)
>>> np.array(segment_ids)
array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)
為分類打包具有不同 sep token 的多個輸入。
>>> seq1 = [1, 2, 3, 4]
>>> seq2 = [11, 12, 13, 14]
>>> packer = keras_hub.layers.MultiSegmentPacker(
... sequence_length=8,
... start_value=101,
... end_value=102,
... sep_value=[102, 102],
... )
>>> token_ids, segment_ids = packer((seq1, seq2))
>>> np.array(token_ids)
array([101, 1, 2, 102, 102, 11, 12, 102], dtype=int32)
>>> np.array(segment_ids)
array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)
參考文獻