CausalLM

[原始碼]

CausalLM 類別

keras_hub.models.CausalLM()

生成式語言建模任務的基礎類別。

CausalLM 任務封裝了 keras_hub.models.Backbonekeras_hub.models.Preprocessor,以建立可用於生成和生成式微調的模型。

CausalLM 任務提供了一個額外的高階 generate() 函數,可用於以字串輸入、字串輸出的簽章,以自動迴歸方式逐個 token 取樣模型。所有 CausalLM 類別的 compile() 方法都包含一個額外的 sampler 參數,可用於傳遞 keras_hub.samplers.Sampler,以控制將如何取樣預測的分佈。

當呼叫 fit() 時,token 化的輸入將逐個 token 預測,並套用因果遮罩,這為控制推論時間生成提供了預訓練和監督式微調設定。

所有 CausalLM 任務都包含一個 from_preset() 建構子,可用於載入預訓練的組態和權重。

範例

# Load a GPT2 backbone with pre-trained weights.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gpt2_base_en",
)
causal_lm.compile(sampler="top_k")
causal_lm.generate("Keras is a", max_length=64)

# Load a Mistral instruction tuned checkpoint at bfloat16 precision.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "mistral_instruct_7b_en",
    dtype="bfloat16",
)
causal_lm.compile(sampler="greedy")
causal_lm.generate("Keras is a", max_length=64)

[原始碼]

from_preset 方法

CausalLM.from_preset(preset, load_weights=True, **kwargs)

從模型預設實例化 keras_hub.models.Task

預設是一個組態、權重和其他檔案資產的目錄,用於儲存和載入預訓練模型。preset 可以作為以下其中一種傳遞:

  1. 內建預設識別符,例如 'bert_base_en'
  2. Kaggle Models 句柄,例如 'kaggle://user/bert/keras/bert_base_en'
  3. Hugging Face 句柄,例如 'hf://user/bert_base_en'
  4. 本機預設目錄的路徑,例如 './bert_base_en'

對於任何 Task 子類別,您可以執行 cls.presets.keys() 以列出類別上所有可用的內建預設。

此建構子可以透過兩種方式之一呼叫。可以從特定任務的基礎類別呼叫,例如 keras_hub.models.CausalLM.from_preset(),也可以從模型類別呼叫,例如 keras_hub.models.BertTextClassifier.from_preset()。如果從基礎類別呼叫,則傳回物件的子類別將從預設目錄中的組態推斷出來。

引數

  • preset:字串。內建預設識別符、Kaggle Models 句柄、Hugging Face 句柄或本機目錄的路徑。
  • load_weights:布林值。如果為 True,則儲存的權重將載入到模型架構中。如果為 False,則所有權重將隨機初始化。

範例

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)

# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
預設 參數 描述
bart_base_en 139.42M 6 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。
bart_large_en 406.29M 12 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。
bart_large_en_cnn 406.29M 在 CNN+DM 摘要資料集上微調的 bart_large_en 主幹模型。
bloom_560m_multi 559.21M 24 層 Bloom 模型,隱藏維度為 1024。在 45 種自然語言和 12 種程式語言上訓練。
bloomz_560m_multi 559.21M 24 層 Bloom 模型,隱藏維度為 1024。在跨語言任務混合 (xP3) 資料集上進行微調。
bloom_1.1b_multi 1.07B 24 層 Bloom 模型,隱藏維度為 1536。在 45 種自然語言和 12 種程式語言上訓練。
bloomz_1.1b_multi 1.07B 24 層 Bloom 模型,隱藏維度為 1536。在跨語言任務混合 (xP3) 資料集上進行微調。
bloom_1.7b_multi 1.72B 24 層 Bloom 模型,隱藏維度為 2048。在 45 種自然語言和 12 種程式語言上訓練。
bloomz_1.7b_multi 1.72B 24 層 Bloom 模型,隱藏維度為 2048。在跨語言任務混合 (xP3) 資料集上進行微調。
bloom_3b_multi 3.00B 30 層 Bloom 模型,隱藏維度為 2560。在 45 種自然語言和 12 種程式語言上訓練。
bloomz_3b_multi 3.00B 30 層 Bloom 模型,隱藏維度為 2560。在跨語言任務混合 (xP3) 資料集上進行微調。
falcon_refinedweb_1b_en 1.31B 24 層 Falcon 模型(具有 1B 參數的 Falcon),在 350B 個 RefinedWeb 資料集的 token 上訓練。
gemma_2b_en 2.51B 20 億參數、18 層、基礎 Gemma 模型。
gemma_instruct_2b_en 2.51B 20 億參數、18 層、指令調整 Gemma 模型。
gemma_1.1_instruct_2b_en 2.51B 20 億參數、18 層、指令調整 Gemma 模型。1.1 更新改進了模型品質。
code_gemma_1.1_2b_en 2.51B 20 億參數、18 層、CodeGemma 模型。此模型已在程式碼完成的填空中間 (FIM) 任務上進行訓練。1.1 更新改進了模型品質。
code_gemma_2b_en 2.51B 20 億參數、18 層、CodeGemma 模型。此模型已在程式碼完成的填空中間 (FIM) 任務上進行訓練。
gemma2_2b_en 2.61B 20 億參數、26 層、基礎 Gemma 模型。
gemma2_instruct_2b_en 2.61B 20 億參數、26 層、指令調整 Gemma 模型。
shieldgemma_2b_en 2.61B 20 億參數、26 層、ShieldGemma 模型。
gemma_7b_en 8.54B 70 億參數、28 層、基礎 Gemma 模型。
gemma_instruct_7b_en 8.54B 70 億參數、28 層、指令調整 Gemma 模型。
gemma_1.1_instruct_7b_en 8.54B 70 億參數、28 層、指令調整 Gemma 模型。1.1 更新改進了模型品質。
code_gemma_7b_en 8.54B 70 億參數、28 層、CodeGemma 模型。此模型已在程式碼完成的填空中間 (FIM) 任務上進行訓練。
code_gemma_instruct_7b_en 8.54B 70 億參數、28 層、指令調整 CodeGemma 模型。此模型已針對與程式碼相關的聊天用例進行訓練。
code_gemma_1.1_instruct_7b_en 8.54B 70 億參數、28 層、指令調整 CodeGemma 模型。此模型已針對與程式碼相關的聊天用例進行訓練。1.1 更新改進了模型品質。
gemma2_9b_en 9.24B 90 億參數、42 層、基礎 Gemma 模型。
gemma2_instruct_9b_en 9.24B 90 億參數、42 層、指令調整 Gemma 模型。
shieldgemma_9b_en 9.24B 90 億參數、42 層、ShieldGemma 模型。
gemma2_27b_en 27.23B 270 億參數、42 層、基礎 Gemma 模型。
gemma2_instruct_27b_en 27.23B 270 億參數、42 層、指令調整 Gemma 模型。
shieldgemma_27b_en 27.23B 270 億參數、42 層、ShieldGemma 模型。
gpt2_base_en 124.44M 12 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。
gpt2_base_en_cnn_dailymail 124.44M 12 層 GPT-2 模型,其中保留大小寫。在 CNN/DailyMail 摘要資料集上微調。
gpt2_medium_en 354.82M 24 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。
gpt2_large_en 774.03M 36 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。
gpt2_extra_large_en 1.56B 48 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。
llama2_7b_en 6.74B 70 億參數、32 層、基礎 LLaMA 2 模型。
llama2_instruct_7b_en 6.74B 70 億參數、32 層、指令調整 LLaMA 2 模型。
vicuna_1.5_7b_en 6.74B 70 億參數、32 層、指令調整 Vicuna v1.5 模型。
llama2_7b_en_int8 6.74B 70 億參數、32 層、基礎 LLaMA 2 模型,其激活和權重量化為 int8。
llama2_instruct_7b_en_int8 6.74B 70 億參數、32 層、指令調整 LLaMA 2 模型,其激活和權重量化為 int8。
llama3_8b_en 8.03B 80 億參數、32 層、基礎 LLaMA 3 模型。
llama3_instruct_8b_en 8.03B 80 億參數、32 層、指令調整 LLaMA 3 模型。
llama3_8b_en_int8 8.03B 80 億參數、32 層、基礎 LLaMA 3 模型,其激活和權重量化為 int8。
llama3_instruct_8b_en_int8 8.03B 80 億參數、32 層、指令調整 LLaMA 3 模型,其激活和權重量化為 int8。
mistral_7b_en 7.24B Mistral 7B 基礎模型
mistral_instruct_7b_en 7.24B Mistral 7B 指令模型
mistral_0.2_instruct_7b_en 7.24B Mistral 7B 指令版本 0.2 模型
opt_125m_en 125.24M 12 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。
opt_1.3b_en 1.32B 24 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。
opt_2.7b_en 2.70B 32 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。
opt_6.7b_en 6.70B 32 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。
pali_gemma_3b_mix_224 2.92B 影像大小 224,混合微調,文字序列長度為 256
pali_gemma_3b_224 2.92B 影像大小 224,預訓練,文字序列長度為 128
pali_gemma_3b_mix_448 2.92B 影像大小 448,混合微調,文字序列長度為 512
pali_gemma_3b_448 2.92B 影像大小 448,預訓練,文字序列長度為 512
pali_gemma_3b_896 2.93B 影像大小 896,預訓練,文字序列長度為 512
pali_gemma2_mix_3b_224 3.03B 30 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在廣泛的視覺語言任務和領域上進行微調。
pali_gemma2_pt_3b_224 3.03B 30 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在資料集混合上進行預訓練。
pali_gemma_2_ft_docci_3b_448 3.03B 30 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在 DOCCI 資料集上進行微調,以改進具有精細細節的描述。
pali_gemma2_mix_3b_448 3.03B 30 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在廣泛的視覺語言任務和領域上進行微調。
pali_gemma2_pt_3b_448 3.03B 30 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_pt_3b_896 3.04B 30 億參數,影像大小 896,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_mix_10b_224 9.66B 100 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在廣泛的視覺語言任務和領域上進行微調。
pali_gemma2_pt_10b_224 9.66B 100 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_ft_docci_10b_448 9.66B 100 億參數,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在 DOCCI 資料集上進行微調,以改進具有精細細節的描述。
pali_gemma2_mix_10b_448 9.66B 100 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在廣泛的視覺語言任務和領域上進行微調。
pali_gemma2_pt_10b_448 9.66B 100 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_pt_10b_896 9.67B 100 億參數,影像大小 896,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_mix_28b_224 27.65B 280 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在廣泛的視覺語言任務和領域上進行微調。
pali_gemma2_mix_28b_448 27.65B 280 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在廣泛的視覺語言任務和領域上進行微調。
pali_gemma2_pt_28b_224 27.65B 280 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_pt_28b_448 27.65B 280 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在資料集混合上進行預訓練。
pali_gemma2_pt_28b_896 27.65B 280 億參數,影像大小 896,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在資料集混合上進行預訓練。
phi3_mini_4k_instruct_en 3.82B 38 億參數、32 層、4k 上下文長度、Phi-3 模型。該模型是使用 Phi-3 資料集訓練的。此資料集包括合成資料和經過濾的公開可用網站資料,重點是高品質和推理密集型屬性。
phi3_mini_128k_instruct_en 3.82B 38 億參數、32 層、128k 上下文長度、Phi-3 模型。該模型是使用 Phi-3 資料集訓練的。此資料集包括合成資料和經過濾的公開可用網站資料,重點是高品質和推理密集型屬性。

[原始碼]

compile 方法

CausalLM.compile(
    optimizer="auto", loss="auto", weighted_metrics="auto", sampler="top_k", **kwargs
)

設定 CausalLM 任務以進行訓練和生成。

CausalLM 任務使用 optimizerlossweighted_metrics 的預設值擴展了 keras.Model.compile 的預設編譯簽章。若要覆寫這些預設值,請在編譯期間將任何值傳遞給這些引數。

CausalLM 任務將新的 sampler 新增至 compile,可用於控制與 generate 函數一起使用的取樣策略。

請注意,由於訓練輸入包括從損失中排除的填充 token,因此幾乎始終建議使用 weighted_metrics 而不是 metrics 進行編譯。

引數

  • optimizer"auto"、最佳化器名稱或 keras.Optimizer 實例。預設為 "auto",它使用給定模型和任務的預設最佳化器。有關可能的 optimizer 值的更多資訊,請參閱 keras.Model.compilekeras.optimizers
  • loss"auto"、損失名稱或 keras.losses.Loss 實例。預設為 "auto",其中 keras.losses.SparseCategoricalCrossentropy 損失將應用於 token 分類 CausalLM 任務。有關可能的 loss 值的更多資訊,請參閱 keras.Model.compilekeras.losses
  • weighted_metrics"auto",或在模型訓練和測試期間要評估的指標列表。預設為 "auto",其中將應用 keras.metrics.SparseCategoricalAccuracy 以追蹤模型在猜測遮罩 token 值時的準確性。有關可能的 weighted_metrics 值的更多資訊,請參閱 keras.Model.compilekeras.metrics
  • sampler:取樣器名稱或 keras_hub.samplers.Sampler 實例。設定在 generate() 呼叫期間使用的取樣方法。有關內建取樣策略的完整列表,請參閱 keras_hub.samplers
  • **kwargs:有關 compile 方法支援的引數的完整列表,請參閱 keras.Model.compile

[原始碼]

generate 方法

CausalLM.generate(inputs, max_length=None, stop_token_ids="auto", strip_prompt=False)

產生給定提示 inputs 的文字。

此方法根據給定的 inputs 產生文字。用於生成的取樣方法可以透過 compile() 方法設定。

如果 inputstf.data.Dataset,則輸出將「逐批次」生成並串聯。否則,所有輸入都將作為單一批次處理。

如果模型附加了 preprocessor,則 inputs 將在 generate() 函數內部進行預處理,並且應符合 preprocessor 層預期的結構(通常是原始字串)。如果未附加 preprocessor,則輸入應符合 backbone 預期的結構。有關每個結構的示範,請參閱上面的範例用法。

引數

  • inputs:python 資料、張量資料或 tf.data.Dataset。如果模型附加了 preprocessor,則 inputs 應符合 preprocessor 層預期的結構。如果未附加 preprocessor,則 inputs 應符合 backbone 模型預期的結構。
  • max_length:選用。整數。生成序列的最大長度。預設為 preprocessor 的最大設定 sequence_length。如果 preprocessorNone,則 inputs 應填充到所需的最大長度,並且此引數將被忽略。
  • stop_token_ids:選用。None、「auto」或 token ID 元組。預設為「auto」,它使用 preprocessor.tokenizer.end_token_id。不指定處理器將產生錯誤。None 在生成 max_length 個 token 後停止生成。您也可以指定模型應停止的一系列 token ID。請注意,token 序列將被視為停止 token,不支援多 token 停止序列。
  • strip_prompt:選用。預設情況下,generate() 會傳回完整提示,後跟模型產生的完成。如果將此選項設定為 True,則僅傳回新產生的文字。

[原始碼]

save_to_preset 方法

CausalLM.save_to_preset(preset_dir)

將任務儲存到預設目錄。

引數

  • preset_dir:本機模型預設目錄的路徑。

preprocessor 屬性

keras_hub.models.CausalLM.preprocessor

用於預處理輸入的 keras_hub.models.Preprocessor 層。


backbone 屬性

keras_hub.models.CausalLM.backbone

具有核心架構的 keras_hub.models.Backbone 模型。