CausalLM
類別keras_hub.models.CausalLM()
生成式語言建模任務的基礎類別。
CausalLM
任務封裝了 keras_hub.models.Backbone
和 keras_hub.models.Preprocessor
,以建立可用於生成和生成式微調的模型。
CausalLM
任務提供了一個額外的高階 generate()
函數,可用於以字串輸入、字串輸出的簽章,以自動迴歸方式逐個 token 取樣模型。所有 CausalLM
類別的 compile()
方法都包含一個額外的 sampler
參數,可用於傳遞 keras_hub.samplers.Sampler
,以控制將如何取樣預測的分佈。
當呼叫 fit()
時,token 化的輸入將逐個 token 預測,並套用因果遮罩,這為控制推論時間生成提供了預訓練和監督式微調設定。
所有 CausalLM
任務都包含一個 from_preset()
建構子,可用於載入預訓練的組態和權重。
範例
# Load a GPT2 backbone with pre-trained weights.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gpt2_base_en",
)
causal_lm.compile(sampler="top_k")
causal_lm.generate("Keras is a", max_length=64)
# Load a Mistral instruction tuned checkpoint at bfloat16 precision.
causal_lm = keras_hub.models.CausalLM.from_preset(
"mistral_instruct_7b_en",
dtype="bfloat16",
)
causal_lm.compile(sampler="greedy")
causal_lm.generate("Keras is a", max_length=64)
from_preset
方法CausalLM.from_preset(preset, load_weights=True, **kwargs)
從模型預設實例化 keras_hub.models.Task
。
預設是一個組態、權重和其他檔案資產的目錄,用於儲存和載入預訓練模型。preset
可以作為以下其中一種傳遞:
'bert_base_en'
'kaggle://user/bert/keras/bert_base_en'
'hf://user/bert_base_en'
'./bert_base_en'
對於任何 Task
子類別,您可以執行 cls.presets.keys()
以列出類別上所有可用的內建預設。
此建構子可以透過兩種方式之一呼叫。可以從特定任務的基礎類別呼叫,例如 keras_hub.models.CausalLM.from_preset()
,也可以從模型類別呼叫,例如 keras_hub.models.BertTextClassifier.from_preset()
。如果從基礎類別呼叫,則傳回物件的子類別將從預設目錄中的組態推斷出來。
引數
True
,則儲存的權重將載入到模型架構中。如果為 False
,則所有權重將隨機初始化。範例
# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gemma_2b_en",
)
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
"bert_base_en",
num_classes=2,
)
預設 | 參數 | 描述 |
---|---|---|
bart_base_en | 139.42M | 6 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。 |
bart_large_en | 406.29M | 12 層 BART 模型,其中保留大小寫。在 BookCorpus、英文維基百科和 CommonCrawl 上訓練。 |
bart_large_en_cnn | 406.29M | 在 CNN+DM 摘要資料集上微調的 bart_large_en 主幹模型。 |
bloom_560m_multi | 559.21M | 24 層 Bloom 模型,隱藏維度為 1024。在 45 種自然語言和 12 種程式語言上訓練。 |
bloomz_560m_multi | 559.21M | 24 層 Bloom 模型,隱藏維度為 1024。在跨語言任務混合 (xP3) 資料集上進行微調。 |
bloom_1.1b_multi | 1.07B | 24 層 Bloom 模型,隱藏維度為 1536。在 45 種自然語言和 12 種程式語言上訓練。 |
bloomz_1.1b_multi | 1.07B | 24 層 Bloom 模型,隱藏維度為 1536。在跨語言任務混合 (xP3) 資料集上進行微調。 |
bloom_1.7b_multi | 1.72B | 24 層 Bloom 模型,隱藏維度為 2048。在 45 種自然語言和 12 種程式語言上訓練。 |
bloomz_1.7b_multi | 1.72B | 24 層 Bloom 模型,隱藏維度為 2048。在跨語言任務混合 (xP3) 資料集上進行微調。 |
bloom_3b_multi | 3.00B | 30 層 Bloom 模型,隱藏維度為 2560。在 45 種自然語言和 12 種程式語言上訓練。 |
bloomz_3b_multi | 3.00B | 30 層 Bloom 模型,隱藏維度為 2560。在跨語言任務混合 (xP3) 資料集上進行微調。 |
falcon_refinedweb_1b_en | 1.31B | 24 層 Falcon 模型(具有 1B 參數的 Falcon),在 350B 個 RefinedWeb 資料集的 token 上訓練。 |
gemma_2b_en | 2.51B | 20 億參數、18 層、基礎 Gemma 模型。 |
gemma_instruct_2b_en | 2.51B | 20 億參數、18 層、指令調整 Gemma 模型。 |
gemma_1.1_instruct_2b_en | 2.51B | 20 億參數、18 層、指令調整 Gemma 模型。1.1 更新改進了模型品質。 |
code_gemma_1.1_2b_en | 2.51B | 20 億參數、18 層、CodeGemma 模型。此模型已在程式碼完成的填空中間 (FIM) 任務上進行訓練。1.1 更新改進了模型品質。 |
code_gemma_2b_en | 2.51B | 20 億參數、18 層、CodeGemma 模型。此模型已在程式碼完成的填空中間 (FIM) 任務上進行訓練。 |
gemma2_2b_en | 2.61B | 20 億參數、26 層、基礎 Gemma 模型。 |
gemma2_instruct_2b_en | 2.61B | 20 億參數、26 層、指令調整 Gemma 模型。 |
shieldgemma_2b_en | 2.61B | 20 億參數、26 層、ShieldGemma 模型。 |
gemma_7b_en | 8.54B | 70 億參數、28 層、基礎 Gemma 模型。 |
gemma_instruct_7b_en | 8.54B | 70 億參數、28 層、指令調整 Gemma 模型。 |
gemma_1.1_instruct_7b_en | 8.54B | 70 億參數、28 層、指令調整 Gemma 模型。1.1 更新改進了模型品質。 |
code_gemma_7b_en | 8.54B | 70 億參數、28 層、CodeGemma 模型。此模型已在程式碼完成的填空中間 (FIM) 任務上進行訓練。 |
code_gemma_instruct_7b_en | 8.54B | 70 億參數、28 層、指令調整 CodeGemma 模型。此模型已針對與程式碼相關的聊天用例進行訓練。 |
code_gemma_1.1_instruct_7b_en | 8.54B | 70 億參數、28 層、指令調整 CodeGemma 模型。此模型已針對與程式碼相關的聊天用例進行訓練。1.1 更新改進了模型品質。 |
gemma2_9b_en | 9.24B | 90 億參數、42 層、基礎 Gemma 模型。 |
gemma2_instruct_9b_en | 9.24B | 90 億參數、42 層、指令調整 Gemma 模型。 |
shieldgemma_9b_en | 9.24B | 90 億參數、42 層、ShieldGemma 模型。 |
gemma2_27b_en | 27.23B | 270 億參數、42 層、基礎 Gemma 模型。 |
gemma2_instruct_27b_en | 27.23B | 270 億參數、42 層、指令調整 Gemma 模型。 |
shieldgemma_27b_en | 27.23B | 270 億參數、42 層、ShieldGemma 模型。 |
gpt2_base_en | 124.44M | 12 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。 |
gpt2_base_en_cnn_dailymail | 124.44M | 12 層 GPT-2 模型,其中保留大小寫。在 CNN/DailyMail 摘要資料集上微調。 |
gpt2_medium_en | 354.82M | 24 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。 |
gpt2_large_en | 774.03M | 36 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。 |
gpt2_extra_large_en | 1.56B | 48 層 GPT-2 模型,其中保留大小寫。在 WebText 上訓練。 |
llama2_7b_en | 6.74B | 70 億參數、32 層、基礎 LLaMA 2 模型。 |
llama2_instruct_7b_en | 6.74B | 70 億參數、32 層、指令調整 LLaMA 2 模型。 |
vicuna_1.5_7b_en | 6.74B | 70 億參數、32 層、指令調整 Vicuna v1.5 模型。 |
llama2_7b_en_int8 | 6.74B | 70 億參數、32 層、基礎 LLaMA 2 模型,其激活和權重量化為 int8。 |
llama2_instruct_7b_en_int8 | 6.74B | 70 億參數、32 層、指令調整 LLaMA 2 模型,其激活和權重量化為 int8。 |
llama3_8b_en | 8.03B | 80 億參數、32 層、基礎 LLaMA 3 模型。 |
llama3_instruct_8b_en | 8.03B | 80 億參數、32 層、指令調整 LLaMA 3 模型。 |
llama3_8b_en_int8 | 8.03B | 80 億參數、32 層、基礎 LLaMA 3 模型,其激活和權重量化為 int8。 |
llama3_instruct_8b_en_int8 | 8.03B | 80 億參數、32 層、指令調整 LLaMA 3 模型,其激活和權重量化為 int8。 |
mistral_7b_en | 7.24B | Mistral 7B 基礎模型 |
mistral_instruct_7b_en | 7.24B | Mistral 7B 指令模型 |
mistral_0.2_instruct_7b_en | 7.24B | Mistral 7B 指令版本 0.2 模型 |
opt_125m_en | 125.24M | 12 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。 |
opt_1.3b_en | 1.32B | 24 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。 |
opt_2.7b_en | 2.70B | 32 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。 |
opt_6.7b_en | 6.70B | 32 層 OPT 模型,其中保留大小寫。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 語料庫上訓練。 |
pali_gemma_3b_mix_224 | 2.92B | 影像大小 224,混合微調,文字序列長度為 256 |
pali_gemma_3b_224 | 2.92B | 影像大小 224,預訓練,文字序列長度為 128 |
pali_gemma_3b_mix_448 | 2.92B | 影像大小 448,混合微調,文字序列長度為 512 |
pali_gemma_3b_448 | 2.92B | 影像大小 448,預訓練,文字序列長度為 512 |
pali_gemma_3b_896 | 2.93B | 影像大小 896,預訓練,文字序列長度為 512 |
pali_gemma2_mix_3b_224 | 3.03B | 30 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在廣泛的視覺語言任務和領域上進行微調。 |
pali_gemma2_pt_3b_224 | 3.03B | 30 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma_2_ft_docci_3b_448 | 3.03B | 30 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在 DOCCI 資料集上進行微調,以改進具有精細細節的描述。 |
pali_gemma2_mix_3b_448 | 3.03B | 30 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在廣泛的視覺語言任務和領域上進行微調。 |
pali_gemma2_pt_3b_448 | 3.03B | 30 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_pt_3b_896 | 3.04B | 30 億參數,影像大小 896,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 2B 語言模型為 26 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_mix_10b_224 | 9.66B | 100 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在廣泛的視覺語言任務和領域上進行微調。 |
pali_gemma2_pt_10b_224 | 9.66B | 100 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_ft_docci_10b_448 | 9.66B | 100 億參數,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在 DOCCI 資料集上進行微調,以改進具有精細細節的描述。 |
pali_gemma2_mix_10b_448 | 9.66B | 100 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在廣泛的視覺語言任務和領域上進行微調。 |
pali_gemma2_pt_10b_448 | 9.66B | 100 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_pt_10b_896 | 9.67B | 100 億參數,影像大小 896,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 9B 語言模型為 42 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_mix_28b_224 | 27.65B | 280 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在廣泛的視覺語言任務和領域上進行微調。 |
pali_gemma2_mix_28b_448 | 27.65B | 280 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在廣泛的視覺語言任務和領域上進行微調。 |
pali_gemma2_pt_28b_224 | 27.65B | 280 億參數,影像大小 224,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_pt_28b_448 | 27.65B | 280 億參數,影像大小 448,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在資料集混合上進行預訓練。 |
pali_gemma2_pt_28b_896 | 27.65B | 280 億參數,影像大小 896,SigLIP-So400m 視覺編碼器為 27 層,Gemma2 27B 語言模型為 46 層。此模型已在資料集混合上進行預訓練。 |
phi3_mini_4k_instruct_en | 3.82B | 38 億參數、32 層、4k 上下文長度、Phi-3 模型。該模型是使用 Phi-3 資料集訓練的。此資料集包括合成資料和經過濾的公開可用網站資料,重點是高品質和推理密集型屬性。 |
phi3_mini_128k_instruct_en | 3.82B | 38 億參數、32 層、128k 上下文長度、Phi-3 模型。該模型是使用 Phi-3 資料集訓練的。此資料集包括合成資料和經過濾的公開可用網站資料,重點是高品質和推理密集型屬性。 |
compile
方法CausalLM.compile(
optimizer="auto", loss="auto", weighted_metrics="auto", sampler="top_k", **kwargs
)
設定 CausalLM
任務以進行訓練和生成。
CausalLM
任務使用 optimizer
、loss
和 weighted_metrics
的預設值擴展了 keras.Model.compile
的預設編譯簽章。若要覆寫這些預設值,請在編譯期間將任何值傳遞給這些引數。
CausalLM
任務將新的 sampler
新增至 compile
,可用於控制與 generate
函數一起使用的取樣策略。
請注意,由於訓練輸入包括從損失中排除的填充 token,因此幾乎始終建議使用 weighted_metrics
而不是 metrics
進行編譯。
引數
"auto"
、最佳化器名稱或 keras.Optimizer
實例。預設為 "auto"
,它使用給定模型和任務的預設最佳化器。有關可能的 optimizer
值的更多資訊,請參閱 keras.Model.compile
和 keras.optimizers
。"auto"
、損失名稱或 keras.losses.Loss
實例。預設為 "auto"
,其中 keras.losses.SparseCategoricalCrossentropy
損失將應用於 token 分類 CausalLM
任務。有關可能的 loss
值的更多資訊,請參閱 keras.Model.compile
和 keras.losses
。"auto"
,或在模型訓練和測試期間要評估的指標列表。預設為 "auto"
,其中將應用 keras.metrics.SparseCategoricalAccuracy
以追蹤模型在猜測遮罩 token 值時的準確性。有關可能的 weighted_metrics
值的更多資訊,請參閱 keras.Model.compile
和 keras.metrics
。keras_hub.samplers.Sampler
實例。設定在 generate()
呼叫期間使用的取樣方法。有關內建取樣策略的完整列表,請參閱 keras_hub.samplers
。keras.Model.compile
。generate
方法CausalLM.generate(inputs, max_length=None, stop_token_ids="auto", strip_prompt=False)
產生給定提示 inputs
的文字。
此方法根據給定的 inputs
產生文字。用於生成的取樣方法可以透過 compile()
方法設定。
如果 inputs
是 tf.data.Dataset
,則輸出將「逐批次」生成並串聯。否則,所有輸入都將作為單一批次處理。
如果模型附加了 preprocessor
,則 inputs
將在 generate()
函數內部進行預處理,並且應符合 preprocessor
層預期的結構(通常是原始字串)。如果未附加 preprocessor
,則輸入應符合 backbone
預期的結構。有關每個結構的示範,請參閱上面的範例用法。
引數
tf.data.Dataset
。如果模型附加了 preprocessor
,則 inputs
應符合 preprocessor
層預期的結構。如果未附加 preprocessor
,則 inputs
應符合 backbone
模型預期的結構。preprocessor
的最大設定 sequence_length
。如果 preprocessor
為 None
,則 inputs
應填充到所需的最大長度,並且此引數將被忽略。None
、「auto」或 token ID 元組。預設為「auto」,它使用 preprocessor.tokenizer.end_token_id
。不指定處理器將產生錯誤。None
在生成 max_length
個 token 後停止生成。您也可以指定模型應停止的一系列 token ID。請注意,token 序列將被視為停止 token,不支援多 token 停止序列。save_to_preset
方法CausalLM.save_to_preset(preset_dir)
將任務儲存到預設目錄。
引數
preprocessor
屬性keras_hub.models.CausalLM.preprocessor
用於預處理輸入的 keras_hub.models.Preprocessor
層。
backbone
屬性keras_hub.models.CausalLM.backbone
具有核心架構的 keras_hub.models.Backbone
模型。