Seq2SeqLM
類別keras_hub.models.Seq2SeqLM()
序列到序列語言建模任務的基底類別。
Seq2SeqLM
任務封裝了 keras_hub.models.Backbone
和 keras_hub.models.Preprocessor
,以建立可用於生成和生成式微調的模型,當生成條件是在序列到序列設定中的額外輸入序列時。
Seq2SeqLM
任務提供了一個額外的、高階的 generate()
函數,可用於逐 token 自動迴歸地取樣輸出序列。Seq2SeqLM
類別的 compile()
方法包含一個額外的 sampler
參數,可用於傳遞 keras_hub.samplers.Sampler
來控制如何取樣預測的分佈。
當呼叫 fit()
時,每個輸入應包含一個輸入和輸出序列。模型將被訓練以使用因果遮罩逐 token 預測輸出序列,類似於 keras_hub.models.CausalLM
任務。與 CausalLM
任務不同,必須傳遞輸入序列,並且輸出序列中的所有 token 都可以完全關注它。
所有 Seq2SeqLM
任務都包含一個 from_preset()
建構子,可用於載入預訓練的配置和權重。
範例
# Load a Bart backbone with pre-trained weights.
seq_2_seq_lm = keras_hub.models.Seq2SeqLM.from_preset(
"bart_base_en",
)
seq_2_seq_lm.compile(sampler="top_k")
# Generate conditioned on the `"The quick brown fox."` as an input sequence.
seq_2_seq_lm.generate("The quick brown fox.", max_length=30)
from_preset
方法Seq2SeqLM.from_preset(preset, load_weights=True, **kwargs)
從模型預設實例化 keras_hub.models.Task
。
預設是一個配置、權重和其他檔案資產的目錄,用於儲存和載入預訓練模型。preset
可以作為以下之一傳遞
'bert_base_en'
'kaggle://user/bert/keras/bert_base_en'
'hf://user/bert_base_en'
'./bert_base_en'
對於任何 Task
子類別,您可以執行 cls.presets.keys()
以列出該類別上所有可用的內建預設。
此建構子可以透過兩種方式之一呼叫。可以從任務特定的基底類別(例如 keras_hub.models.CausalLM.from_preset()
)呼叫,也可以從模型類別(例如 keras_hub.models.BertTextClassifier.from_preset()
)呼叫。如果從基底類別呼叫,則返回物件的子類別將從預設目錄中的配置推斷出來。
參數
True
,則儲存的權重將載入到模型架構中。如果為 False
,則所有權重將隨機初始化。範例
# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gemma_2b_en",
)
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
"bert_base_en",
num_classes=2,
)
預設 | 參數 | 描述 |
---|---|---|
bart_base_en | 139.42M | 6 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。 |
bart_large_en | 406.29M | 12 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。 |
bart_large_en_cnn | 406.29M | 在 CNN+DM 摘要資料集上微調的 bart_large_en 主幹模型。 |
compile
方法Seq2SeqLM.compile(
optimizer="auto", loss="auto", weighted_metrics="auto", sampler="top_k", **kwargs
)
配置 CausalLM
任務以進行訓練和生成。
CausalLM
任務擴展了 keras.Model.compile
的預設編譯簽名,並為 optimizer
、loss
和 weighted_metrics
設定了預設值。若要覆寫這些預設值,請在編譯期間將任何值傳遞給這些參數。
CausalLM
任務在 compile
中新增了一個新的 sampler
,可用於控制與 generate
函數一起使用的取樣策略。
請注意,由於訓練輸入包含從損失中排除的填充 token,因此幾乎總是建議使用 weighted_metrics
而不是 metrics
進行編譯。
參數
"auto"
、最佳化器名稱或 keras.Optimizer
實例。預設為 "auto"
,它使用給定模型和任務的預設最佳化器。請參閱 keras.Model.compile
和 keras.optimizers
以取得有關可能的 optimizer
值的更多資訊。"auto"
、損失名稱或 keras.losses.Loss
實例。預設為 "auto"
,其中將為 token 分類 CausalLM
任務套用 keras.losses.SparseCategoricalCrossentropy
損失。請參閱 keras.Model.compile
和 keras.losses
以取得有關可能的 loss
值的更多資訊。"auto"
,或在訓練和測試期間由模型評估的指標列表。預設為 "auto"
,其中將套用 keras.metrics.SparseCategoricalAccuracy
以追蹤模型在猜測遮罩 token 值時的準確性。請參閱 keras.Model.compile
和 keras.metrics
以取得有關可能的 weighted_metrics
值的更多資訊。keras_hub.samplers.Sampler
實例。配置在 generate()
呼叫期間使用的取樣方法。請參閱 keras_hub.samplers
以取得內建取樣策略的完整列表。keras.Model.compile
以取得 compile
方法支援的完整參數列表。generate
方法Seq2SeqLM.generate(
inputs, max_length=None, stop_token_ids="auto", strip_prompt=False
)
根據提示 inputs
生成文字。
此方法根據給定的 inputs
生成文字。用於生成的取樣方法可以透過 compile()
方法設定。
如果 inputs
是 tf.data.Dataset
,則輸出將「逐批次」生成並串聯。否則,所有輸入將被視為單一批次處理。
如果模型附加了 preprocessor
,則 inputs
將在 generate()
函數內部進行預處理,並且應符合 preprocessor
層預期的結構(通常是原始字串)。如果未附加 preprocessor
,則輸入應符合 backbone
預期的結構。請參閱上面的範例用法以示範每種情況。
參數
tf.data.Dataset
。如果模型附加了 preprocessor
,則 inputs
應符合 preprocessor
層預期的結構。如果未附加 preprocessor
,則輸入應符合 backbone
模型預期的結構。preprocessor
的最大配置 sequence_length
。如果 preprocessor
為 None
,則應將 inputs
填充到所需的最大長度,並且此參數將被忽略。None
、"auto"
或 token id 元組。預設為 "auto"
,它使用 preprocessor.tokenizer.end_token_id
。未指定處理器將產生錯誤。None
在生成 max_length
個 token 後停止生成。您也可以指定模型應停止的 token id 列表。請注意,token 序列將被解釋為停止 token,不支援多 token 停止序列。save_to_preset
方法Seq2SeqLM.save_to_preset(preset_dir)
將任務儲存到預設目錄。
參數
preprocessor
屬性keras_hub.models.Seq2SeqLM.preprocessor
用於預處理輸入的 keras_hub.models.Preprocessor
層。
backbone
屬性keras_hub.models.Seq2SeqLM.backbone
具有核心架構的 keras_hub.models.Backbone
模型。