作者: Hongyu Chiu、Abheesht Sharma、Matthew Watson
建立日期 2024/08/06
上次修改日期 2024/08/06
描述: 使用 KerasHub 和 LoRA 及 QLoRA 來微調 Gemma LLM。
大型語言模型 (LLM) 已被證明在各種 NLP 任務中都有效。LLM 首先以自監督的方式在大量的文字語料庫上進行預訓練。預訓練有助於 LLM 學習通用的知識,例如單字之間的統計關係。然後可以針對感興趣的下游任務(例如情感分析)微調 LLM。
然而,LLM 的規模非常龐大,我們在微調時不需要訓練模型中的所有參數,尤其是在模型微調所用的資料集相對較小時。換句話說,LLM 在微調方面是過度參數化的。這就是 低秩適應 (LoRA) 的用武之地;它顯著減少了可訓練參數的數量。這會在保持輸出品質的同時,減少訓練時間和 GPU 記憶體使用量。
此外,量化低秩適應 (QLoRA) 延伸了 LoRA,透過量化技術來增強效率,而不會降低效能。
在本範例中,我們將使用 LoRA 和 QLoRA,針對下一個符號預測任務微調 KerasHub 的 Gemma 模型。
請注意,此範例在 Keras 支援的所有後端上執行。TensorFlow 僅用於資料前處理。
在我們開始實作管線之前,讓我們先安裝並匯入我們需要的所有程式庫。我們將使用 KerasHub 程式庫。
其次,讓我們將精確度設定為 bfloat16。這將有助於我們減少記憶體使用量和訓練時間。
此外,請確保已正確設定 KAGGLE_USERNAME
和 KAGGLE_KEY
,以存取 Gemma 模型。
# We might need the latest code from Keras and KerasHub
!pip install -q git+https://github.com/keras-team/keras.git git+https://github.com/keras-team/keras-hub.git
import gc
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # Suppress verbose logging from TF
# os.environ["KAGGLE_USERNAME"] = "..."
# os.environ["KAGGLE_KEY"] = "..."
import keras
import keras_hub
import tensorflow as tf
import tensorflow_datasets as tfds
keras.config.set_dtype_policy("bfloat16")
我們將使用 MTNT (機器翻譯雜訊文字) 資料集,該資料集可從 TensorFlow 資料集取得。在本範例中,我們將使用資料集的法語到英語部分。
train_ds = tfds.load("mtnt/fr-en", split="train")
我們可以列印一些範例。資料集中的每個範例都包含兩個項目
examples = train_ds.take(3)
examples = examples.as_numpy_iterator()
for idx, example in enumerate(examples):
print(f"Example {idx}:")
for key, val in example.items():
print(f"{key}: {val}")
print()
Example 0:
dst: b'Yep, serious...'
src: b"Le journal l'est peut-\xc3\xaatre, mais m\xc3\xaame moi qui suit de droite je les trouve limite de temps en temps..."
Example 1:
dst: b'Finally, I explained to you in what context this copy-pasting is relevant: when we are told padamalgame etc.'
src: b"Enfin je t'ai expliqu\xc3\xa9 dans quel cadre ce copypasta est pertinent : quand on nous dit padamalgame etc."
Example 2:
dst: b'Gift of Ubiquity: Fran\xc3\xa7ois Baroin is now advisor to the Barclays Bank, mayor, president of the agglomeration, professor at HEC Paris, president of the Association of Mayors of France and Advocate Counselor, it must take him half a day each month.'
src: b"Don d'Ubiquit\xc3\xa9 : Fran\xc3\xa7ois Baroin est d\xc3\xa9sormais conseiller \xc3\xa0 la Banque Barclays, maire, pr\xc3\xa9sident d'agglom\xc3\xa9ration, professeur \xc3\xa0 HEC Paris, pr\xc3\xa9sident de l'association des maires de France et avocat Conseiller, \xc3\xa7a doit lui prendre une demi journ\xc3\xa9e par mois."
由於我們將微調模型以執行法語到英語的翻譯任務,因此我們應該格式化指令微調的輸入。例如,我們可以像這樣格式化此範例中的翻譯任務
<start_of_turn>user
Translate French into English:
{src}<end_of_turn>
<start_of_turn>model
{dst}<end_of_turn>
諸如 <start_of_turn>user
、<start_of_turn>model
和 <end_of_turn>
等特殊符號用於 Gemma 模型。您可以從 https://ai.google.dev/gemma/docs/formatting 了解更多資訊
train_ds = train_ds.map(
lambda x: tf.strings.join(
[
"<start_of_turn>user\n",
"Translate French into English:\n",
x["src"],
"<end_of_turn>\n",
"<start_of_turn>model\n",
"Translation:\n",
x["dst"],
"<end_of_turn>",
]
)
)
examples = train_ds.take(3)
examples = examples.as_numpy_iterator()
for idx, example in enumerate(examples):
print(f"Example {idx}:")
print(example)
print()
Example 0:
b"<start_of_turn>user\nTranslate French into English:\nLe journal l'est peut-\xc3\xaatre, mais m\xc3\xaame moi qui suit de droite je les trouve limite de temps en temps...<end_of_turn>\n<start_of_turn>model\nTranslation:\nYep, serious...<end_of_turn>"
Example 1:
b"<start_of_turn>user\nTranslate French into English:\nEnfin je t'ai expliqu\xc3\xa9 dans quel cadre ce copypasta est pertinent : quand on nous dit padamalgame etc.<end_of_turn>\n<start_of_turn>model\nTranslation:\nFinally, I explained to you in what context this copy-pasting is relevant: when we are told padamalgame etc.<end_of_turn>"
Example 2:
b"<start_of_turn>user\nTranslate French into English:\nDon d'Ubiquit\xc3\xa9 : Fran\xc3\xa7ois Baroin est d\xc3\xa9sormais conseiller \xc3\xa0 la Banque Barclays, maire, pr\xc3\xa9sident d'agglom\xc3\xa9ration, professeur \xc3\xa0 HEC Paris, pr\xc3\xa9sident de l'association des maires de France et avocat Conseiller, \xc3\xa7a doit lui prendre une demi journ\xc3\xa9e par mois.<end_of_turn>\n<start_of_turn>model\nTranslation:\nGift of Ubiquity: Fran\xc3\xa7ois Baroin is now advisor to the Barclays Bank, mayor, president of the agglomeration, professor at HEC Paris, president of the Association of Mayors of France and Advocate Counselor, it must take him half a day each month.<end_of_turn>"
為了本範例的目的,我們將選取資料集的一個子集。
train_ds = train_ds.batch(1).take(100)
KerasHub 提供了許多熱門模型架構的實作。在本範例中,我們將使用 GemmaCausalLM
,這是一個用於因果語言建模的端對端 Gemma 模型。因果語言模型會根據先前的符號預測下一個符號。
請注意,sequence_length
設定為 256
以加快擬合速度。
preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
"gemma_1.1_instruct_2b_en", sequence_length=256
)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
"gemma_1.1_instruct_2b_en", preprocessor=preprocessor
)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gemma_tokenizer (GemmaTokenizer) │ 256,000 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gemma_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ token_ids (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ gemma_backbone │ (None, None, 2048) │ 2,506,172,416 │ padding_mask[0][0], │ │ (GemmaBackbone) │ │ │ token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ token_embedding │ (None, None, 256000) │ 524,288,000 │ gemma_backbone[0][0] │ │ (ReversibleEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,506,172,416 (4.67 GB)
Trainable params: 2,506,172,416 (4.67 GB)
Non-trainable params: 0 (0.00 B)
低秩適應 (LoRA) 是一種針對 LLM 的參數高效微調技術。它會凍結 LLM 的權重,並插入可訓練的秩分解矩陣。讓我們更清楚地了解這一點。
假設我們有一個 n x n
預訓練密集層(或權重矩陣)W0
。我們初始化兩個密集層 A
和 B
,其形狀分別為 n x rank
和 rank x n
。rank
遠小於 n
。在論文中,介於 1 和 4 之間的值顯示效果良好。
原始方程式為 output = W0x + b0
,其中 x
是輸入,W0
和 b0
是原始密集層的權重矩陣和偏差項(凍結)。LoRA 方程式為:output = W0x + b0 + BAx
,其中 A
和 B
是秩分解矩陣。
LoRA 是基於以下概念,即由於預訓練語言模型是過度參數化的,因此預訓練語言模型的權重更新具有較低的「固有秩」。即使將 W0
的更新限制為低秩分解矩陣,也可以複製完整微調的預測效能。
讓我們快速計算一下。假設 n
是 768,而 rank
是 4。W0
有 768 x 768 = 589,824
個參數,而 LoRA 層 A
和 B
一起有 768 x 4 + 4 x 768 = 6,144
個參數。因此,對於密集層,我們從 589,824
個可訓練參數減少到 6,144
個可訓練參數!
即使參數的總數增加(因為我們正在新增 LoRA 層),記憶體佔用空間也會減少,因為可訓練參數的數量減少了。讓我們深入探討一下。
模型的記憶體使用量可以分為四個部分
由於使用 LoRA 後,可訓練參數的數量大幅減少,因此最佳化工具記憶體和儲存 LoRA 梯度所需的記憶體遠小於原始模型。這就是大部分記憶體節省發生的原因。
使用 KerasHub 時,我們可以使用單行 API 來啟用 LoRA:enable_lora(rank=4)
從 gemma_lm.summary()
中,我們可以看到啟用 LoRA 可以顯著減少可訓練參數的數量(從 25 億個減少到 130 萬個)。
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gemma_tokenizer (GemmaTokenizer) │ 256,000 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gemma_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ token_ids (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ gemma_backbone │ (None, None, 2048) │ 2,507,536,384 │ padding_mask[0][0], │ │ (GemmaBackbone) │ │ │ token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ token_embedding │ (None, None, 256000) │ 524,288,000 │ gemma_backbone[0][0] │ │ (ReversibleEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,507,536,384 (4.67 GB)
Trainable params: 1,363,968 (2.60 MB)
Non-trainable params: 2,506,172,416 (4.67 GB)
讓我們微調 LoRA 模型。
# To save memory, use the SGD optimizer instead of the usual AdamW optimizer.
# For this specific example, SGD is more than enough.
optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, epochs=1)
微調後,回應將遵循提示中提供的指示。
template = (
"<start_of_turn>user\n"
"Translate French into English:\n"
"{inputs}"
"<end_of_turn>\n"
"<start_of_turn>model\n"
"Translation:\n"
)
prompt = template.format(inputs="Bonjour, je m'appelle Morgane.")
outputs = gemma_lm.generate(prompt, max_length=256)
print("Translation:\n", outputs.replace(prompt, ""))
Translation:
Hello, my name is Morgane.
釋放記憶體。
del preprocessor
del gemma_lm
del optimizer
gc.collect()
量化低秩適應 (QLoRA) 擴展了 LoRA,透過將模型權重從高精確度資料類型(例如 float32)量化為較低精確度資料類型(例如 int8)來增強效率。這會導致減少記憶體使用量和更快的計算速度。儲存的模型權重也小得多。
請注意,此處的 QLoRA 實作是與原始實作相比的簡化版本。差異如下
若要在 KerasHub 中啟用 QLoRA,請遵循下列步驟
步驟 2 和 3 可透過單行 API 完成
quantize("int8")
enable_lora(...)
preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
"gemma_1.1_instruct_2b_en", sequence_length=256
)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
"gemma_1.1_instruct_2b_en", preprocessor=preprocessor
)
gemma_lm.quantize("int8")
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gemma_tokenizer (GemmaTokenizer) │ 256,000 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gemma_causal_lm_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ token_ids (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ gemma_backbone │ (None, None, 2048) │ 2,508,502,016 │ padding_mask[0][0], │ │ (GemmaBackbone) │ │ │ token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ token_embedding │ (None, None, 256000) │ 524,544,000 │ gemma_backbone[0][0] │ │ (ReversibleEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,508,502,016 (2.34 GB)
Trainable params: 1,363,968 (2.60 MB)
Non-trainable params: 2,507,138,048 (2.34 GB)
讓我們微調 QLoRA 模型。
如果您使用的裝置支援 int8 加速,您應該會看到訓練速度有所提升。
optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, epochs=1)
您應該會得到類似的 QLoRA 微調輸出。
prompt = template.format(inputs="Bonjour, je m'appelle Morgane.")
outputs = gemma_lm.generate(prompt, max_length=256)
print("Translation:\n", outputs.replace(prompt, ""))
Translation:
Hello, my name is Morgane.
這樣就全部完成了!
請注意,為了示範起見,此範例僅在資料集的一小部分子集上微調模型一個 epoch,並使用較低的 LoRA 秩值。若要從微調模型獲得更好的回應,您可以嘗試
learning_rate
和 weight_decay
。