Seq2SeqLM

[原始碼]

Seq2SeqLM 類別

keras_hub.models.Seq2SeqLM()

序列到序列語言建模任務的基底類別。

Seq2SeqLM 任務封裝了 keras_hub.models.Backbonekeras_hub.models.Preprocessor,以建立可用於生成和生成式微調的模型,當生成條件是在序列到序列設定中的額外輸入序列時。

Seq2SeqLM 任務提供了一個額外的、高階的 generate() 函數,可用於逐 token 自動迴歸地取樣輸出序列。Seq2SeqLM 類別的 compile() 方法包含一個額外的 sampler 參數,可用於傳遞 keras_hub.samplers.Sampler 來控制如何取樣預測的分佈。

當呼叫 fit() 時,每個輸入應包含一個輸入和輸出序列。模型將被訓練以使用因果遮罩逐 token 預測輸出序列,類似於 keras_hub.models.CausalLM 任務。與 CausalLM 任務不同,必須傳遞輸入序列,並且輸出序列中的所有 token 都可以完全關注它。

所有 Seq2SeqLM 任務都包含一個 from_preset() 建構子,可用於載入預訓練的配置和權重。

範例

# Load a Bart backbone with pre-trained weights.
seq_2_seq_lm = keras_hub.models.Seq2SeqLM.from_preset(
    "bart_base_en",
)
seq_2_seq_lm.compile(sampler="top_k")
# Generate conditioned on the `"The quick brown fox."` as an input sequence.
seq_2_seq_lm.generate("The quick brown fox.", max_length=30)

[原始碼]

from_preset 方法

Seq2SeqLM.from_preset(preset, load_weights=True, **kwargs)

從模型預設實例化 keras_hub.models.Task

預設是一個配置、權重和其他檔案資產的目錄,用於儲存和載入預訓練模型。preset 可以作為以下之一傳遞

  1. 內建預設識別符,例如 'bert_base_en'
  2. Kaggle 模型句柄,例如 'kaggle://user/bert/keras/bert_base_en'
  3. Hugging Face 句柄,例如 'hf://user/bert_base_en'
  4. 本地預設目錄的路徑,例如 './bert_base_en'

對於任何 Task 子類別,您可以執行 cls.presets.keys() 以列出該類別上所有可用的內建預設。

此建構子可以透過兩種方式之一呼叫。可以從任務特定的基底類別(例如 keras_hub.models.CausalLM.from_preset())呼叫,也可以從模型類別(例如 keras_hub.models.BertTextClassifier.from_preset())呼叫。如果從基底類別呼叫,則返回物件的子類別將從預設目錄中的配置推斷出來。

參數

  • preset:字串。內建預設識別符、Kaggle 模型句柄、Hugging Face 句柄或本地目錄的路徑。
  • load_weights:布林值。如果為 True,則儲存的權重將載入到模型架構中。如果為 False,則所有權重將隨機初始化。

範例

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)

# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
預設 參數 描述
bart_base_en 139.42M 6 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。
bart_large_en 406.29M 12 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。
bart_large_en_cnn 406.29M 在 CNN+DM 摘要資料集上微調的 bart_large_en 主幹模型。

[原始碼]

compile 方法

Seq2SeqLM.compile(
    optimizer="auto", loss="auto", weighted_metrics="auto", sampler="top_k", **kwargs
)

配置 CausalLM 任務以進行訓練和生成。

CausalLM 任務擴展了 keras.Model.compile 的預設編譯簽名,並為 optimizerlossweighted_metrics 設定了預設值。若要覆寫這些預設值,請在編譯期間將任何值傳遞給這些參數。

CausalLM 任務在 compile 中新增了一個新的 sampler,可用於控制與 generate 函數一起使用的取樣策略。

請注意,由於訓練輸入包含從損失中排除的填充 token,因此幾乎總是建議使用 weighted_metrics 而不是 metrics 進行編譯。

參數

  • optimizer"auto"、最佳化器名稱或 keras.Optimizer 實例。預設為 "auto",它使用給定模型和任務的預設最佳化器。請參閱 keras.Model.compilekeras.optimizers 以取得有關可能的 optimizer 值的更多資訊。
  • loss"auto"、損失名稱或 keras.losses.Loss 實例。預設為 "auto",其中將為 token 分類 CausalLM 任務套用 keras.losses.SparseCategoricalCrossentropy 損失。請參閱 keras.Model.compilekeras.losses 以取得有關可能的 loss 值的更多資訊。
  • weighted_metrics"auto",或在訓練和測試期間由模型評估的指標列表。預設為 "auto",其中將套用 keras.metrics.SparseCategoricalAccuracy 以追蹤模型在猜測遮罩 token 值時的準確性。請參閱 keras.Model.compilekeras.metrics 以取得有關可能的 weighted_metrics 值的更多資訊。
  • sampler:取樣器名稱或 keras_hub.samplers.Sampler 實例。配置在 generate() 呼叫期間使用的取樣方法。請參閱 keras_hub.samplers 以取得內建取樣策略的完整列表。
  • **kwargs:請參閱 keras.Model.compile 以取得 compile 方法支援的完整參數列表。

[原始碼]

generate 方法

Seq2SeqLM.generate(
    inputs, max_length=None, stop_token_ids="auto", strip_prompt=False
)

根據提示 inputs 生成文字。

此方法根據給定的 inputs 生成文字。用於生成的取樣方法可以透過 compile() 方法設定。

如果 inputstf.data.Dataset,則輸出將「逐批次」生成並串聯。否則,所有輸入將被視為單一批次處理。

如果模型附加了 preprocessor,則 inputs 將在 generate() 函數內部進行預處理,並且應符合 preprocessor 層預期的結構(通常是原始字串)。如果未附加 preprocessor,則輸入應符合 backbone 預期的結構。請參閱上面的範例用法以示範每種情況。

參數

  • inputs:python 資料、張量資料或 tf.data.Dataset。如果模型附加了 preprocessor,則 inputs 應符合 preprocessor 層預期的結構。如果未附加 preprocessor,則輸入應符合 backbone 模型預期的結構。
  • max_length:可選。整數。生成序列的最大長度。預設為 preprocessor 的最大配置 sequence_length。如果 preprocessorNone,則應將 inputs 填充到所需的最大長度,並且此參數將被忽略。
  • stop_token_ids:可選。None"auto" 或 token id 元組。預設為 "auto",它使用 preprocessor.tokenizer.end_token_id。未指定處理器將產生錯誤。None 在生成 max_length 個 token 後停止生成。您也可以指定模型應停止的 token id 列表。請注意,token 序列將被解釋為停止 token,不支援多 token 停止序列。
  • strip_prompt:可選。預設情況下,generate() 會傳回完整提示,後接模型生成的完成文字。如果此選項設定為 True,則僅傳回新生成的文字。

[原始碼]

save_to_preset 方法

Seq2SeqLM.save_to_preset(preset_dir)

將任務儲存到預設目錄。

參數

  • preset_dir:本地模型預設目錄的路徑。

preprocessor 屬性

keras_hub.models.Seq2SeqLM.preprocessor

用於預處理輸入的 keras_hub.models.Preprocessor 層。


backbone 屬性

keras_hub.models.Seq2SeqLM.backbone

具有核心架構的 keras_hub.models.Backbone 模型。