BartSeq2SeqLM
類別keras_hub.models.BartSeq2SeqLM(backbone, preprocessor=None, **kwargs)
用於序列到序列語言建模的端到端 BART 模型。
序列到序列語言模型 (LM) 是一種編碼器-解碼器模型,用於條件文本生成。編碼器被給予「上下文」文本(饋送到編碼器),而解碼器基於編碼器輸入和先前的詞元預測下一個詞元。您可以微調 BartSeq2SeqLM
以為任何序列到序列任務(例如,翻譯或摘要)生成文本。
此模型具有 generate()
方法,該方法基於編碼器輸入和解碼器的可選提示生成文本。使用的生成策略由傳遞給 compile()
的額外 sampler
參數控制。您可以使用不同的 keras_hub.samplers
物件重新編譯模型以控制生成。預設情況下,將使用 "top_k"
採樣。
此模型可以選擇性地配置 preprocessor
層,在這種情況下,它將在 fit()
、predict()
、evaluate()
和 generate()
期間自動將預處理應用於字串輸入。當使用 from_preset()
建立模型時,預設會執行此操作。
免責聲明:預訓練模型以「現狀」基礎提供,不提供任何形式的保證或條件。底層模型由第三方提供,並受單獨的許可證約束,可在此處取得 here。
引數
keras_hub.models.BartBackbone
實例。keras_hub.models.BartSeq2SeqLMPreprocessor
或 None
。如果為 None
,則此模型將不應用預處理,並且應在呼叫模型之前預處理輸入。範例
使用 generate()
執行文本生成,給定輸入上下文。
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.generate("The quick brown fox", max_length=30)
# Generate with batched inputs.
bart_lm.generate(["The quick brown fox", "The whale"], max_length=30)
使用自訂採樣器編譯 generate()
函數。
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.compile(sampler="greedy")
bart_lm.generate("The quick brown fox", max_length=30)
將 generate()
與編碼器輸入和不完整的解碼器輸入(提示)一起使用。
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.generate(
{
"encoder_text": "The quick brown fox",
"decoder_text": "The fast"
}
)
在不進行預處理的情況下使用 generate()
。
# Preprocessed inputs, with encoder inputs corresponding to
# "The quick brown fox", and the decoder inputs to "The fast". Use
# `"padding_mask"` to indicate values that should not be overridden.
prompt = {
"encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
"encoder_padding_mask": np.array(
[[1, 1, 1, 1, 1, 1, 0, 0]]
),
"decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
"decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0]])
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
"bart_base_en",
preprocessor=None,
)
bart_lm.generate(prompt)
在單一批次上呼叫 fit()
。
features = {
"encoder_text": ["The quick fox jumped.", "I forgot my homework."],
"decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
}
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
bart_lm.fit(x=features, batch_size=2)
在不進行預處理的情況下呼叫 fit()
。
x = {
"encoder_token_ids": np.array([[0, 133, 2119, 2, 1]] * 2),
"encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
"decoder_token_ids": np.array([[2, 0, 133, 1769, 2]] * 2),
"decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
}
y = np.array([[0, 133, 1769, 2, 1]] * 2)
sw = np.array([[1, 1, 1, 1, 0]] * 2)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
"bart_base_en",
preprocessor=None,
)
bart_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
自訂骨幹網路和詞彙表。
features = {
"encoder_text": [" afternoon sun"],
"decoder_text": ["noon sun"],
}
vocab = {
"<s>": 0,
"<pad>": 1,
"</s>": 2,
"Ġafter": 5,
"noon": 6,
"Ġsun": 7,
}
merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
tokenizer = keras_hub.models.BartTokenizer(
vocabulary=vocab,
merges=merges,
)
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor(
tokenizer=tokenizer,
encoder_sequence_length=128,
decoder_sequence_length=128,
)
backbone = keras_hub.models.BartBackbone(
vocabulary_size=50265,
num_layers=6,
num_heads=12,
hidden_dim=768,
intermediate_dim=3072,
max_sequence_length=128,
)
bart_lm = keras_hub.models.BartSeq2SeqLM(
backbone=backbone,
preprocessor=preprocessor,
)
bart_lm.fit(x=features, batch_size=2)
from_preset
方法BartSeq2SeqLM.from_preset(preset, load_weights=True, **kwargs)
從模型預設實例化 keras_hub.models.Task
。
預設是配置、權重和其他檔案資產的目錄,用於儲存和載入預訓練模型。preset
可以作為以下其中之一傳遞
'bert_base_en'
'kaggle://user/bert/keras/bert_base_en'
'hf://user/bert_base_en'
'./bert_base_en'
對於任何 Task
子類別,您可以執行 cls.presets.keys()
以列出類別上可用的所有內建預設。
此建構子可以透過兩種方式之一呼叫。可以從特定於任務的基底類別(如 keras_hub.models.CausalLM.from_preset()
)呼叫,也可以從模型類別(如 keras_hub.models.BertTextClassifier.from_preset()
)呼叫。如果從基底類別呼叫,則傳回物件的子類別將從預設目錄中的配置推斷出來。
引數
True
,則儲存的權重將載入到模型架構中。如果為 False
,則所有權重都將隨機初始化。範例
# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gemma_2b_en",
)
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
"bert_base_en",
num_classes=2,
)
預設 | 參數 | 描述 |
---|---|---|
bart_base_en | 139.42M | 6 層 BART 模型,其中保留了大小寫。在 BookCorpus、English Wikipedia 和 CommonCrawl 上訓練。 |
bart_large_en | 406.29M | 12 層 BART 模型,其中保留了大小寫。在 BookCorpus、English Wikipedia 和 CommonCrawl 上訓練。 |
bart_large_en_cnn | 406.29M | 在 CNN+DM 摘要資料集上微調的 bart_large_en 骨幹網路模型。 |
generate
方法BartSeq2SeqLM.generate(
inputs, max_length=None, stop_token_ids="auto", strip_prompt=False
)
給定提示 inputs
生成文本。
此方法基於給定的 inputs
生成文本。用於生成的採樣方法可以透過 compile()
方法設定。
如果 inputs
是 tf.data.Dataset
,則輸出將「逐批次」生成並串連。否則,所有輸入都將作為單一批次處理。
如果 preprocessor
已附加到模型,則 inputs
將在 generate()
函數內部進行預處理,並且應符合 preprocessor
層預期的結構(通常為原始字串)。如果未附加 preprocessor
,則輸入應符合 backbone
預期的結構。請參閱上面的範例用法以示範每個用法。
引數
tf.data.Dataset
。如果 preprocessor
已附加到模型,則 inputs
應符合 preprocessor
層預期的結構。如果未附加 preprocessor
,則 inputs
應符合 backbone
模型預期的結構。preprocessor
的最大配置 sequence_length
。如果 preprocessor
為 None
,則 inputs
應填充到所需的最大長度,並且將忽略此引數。None
、「auto」或詞元 ID 元組。預設為「auto」,它使用 preprocessor.tokenizer.end_token_id
。不指定處理器將產生錯誤。None 在生成 max_length
個詞元後停止生成。您也可以指定模型應停止的詞元 ID 列表。請注意,詞元序列都將被解釋為停止詞元,並且不支援多詞元停止序列。backbone
屬性keras_hub.models.BartSeq2SeqLM.backbone
具有核心架構的 keras_hub.models.Backbone
模型。
preprocessor
屬性keras_hub.models.BartSeq2SeqLM.preprocessor
用於預處理輸入的 keras_hub.models.Preprocessor
層。