作者: Abheesht Sharma, Matthew Watson
建立日期 2023/05/27
最後修改日期 2023/05/27
描述: 使用 KerasHub 和 LoRA 微調 GPT-2 LLM。
大型語言模型 (LLM) 已被證明在各種 NLP 任務中非常有效。LLM 首先以自我監督的方式在大型文本語料庫上進行預訓練。預訓練有助於 LLM 學習通用知識,例如單字之間的統計關係。然後,可以針對感興趣的下游任務(例如情感分析)微調 LLM。
然而,LLM 的尺寸非常大,並且我們不需要在微調時訓練模型中的所有參數,特別是因為微調模型的資料集相對較小。換句話說,LLM 對於微調來說是過度參數化的。這就是 低秩適應 (LoRA) 的用武之地;它顯著減少了可訓練參數的數量。這會減少訓練時間和 GPU 記憶體使用量,同時保持輸出的品質。
在此範例中,我們將用技術術語解釋 LoRA,展示技術解釋如何轉換為程式碼,破解 KerasHub 的 GPT-2 模型,並使用 LoRA 在下一個權杖預測任務上對其進行微調。我們將在生成文字的品質、訓練時間和 GPU 記憶體使用量方面比較 LoRA GPT-2 和完全微調的 GPT-2。
注意:此範例僅在 TensorFlow 後端上執行,原因是使用 tf.config.experimental.get_memory_info
API 來輕鬆繪製記憶體使用量。除了記憶體使用量回呼之外,此範例將在 jax
和 torch
後端上執行。
在我們開始實作管線之前,讓我們先安裝並導入我們需要的所有程式庫。我們將使用 KerasHub 程式庫。
其次,讓我們啟用混合精度訓練。這將有助於我們減少訓練時間。
!pip install -q --upgrade keras-hub
!pip install -q --upgrade keras # Upgrade to Keras 3.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras_hub
import keras
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import time
keras.mixed_precision.set_global_policy("mixed_float16")
讓我們也定義我們的超參數。
# General hyperparameters
BATCH_SIZE = 32
NUM_BATCHES = 500
EPOCHS = 1 # Can be set to a higher value for better results
MAX_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 200
GPT2_PRESET = "gpt2_base_en"
# LoRA-specific hyperparameters
RANK = 4
ALPHA = 32.0
讓我們載入 Reddit 資料集。我們將在這個資料集的子集上微調 GPT-2 模型和 LoRA GPT-2 模型。目標是產生風格類似於 Reddit 貼文的文字。
reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
資料集有兩個欄位:document
和 title
。
for document, title in reddit_ds:
print(document.numpy())
print(title.numpy())
break
b"me and a friend decided to go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. \n\nnow i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. \n\nwe arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. \n\nfunny thing about seafood. it runs through me faster than a kenyan \n\nwe arrived and walked around a bit. it was about 45min since we arrived at the beach when i felt a rumble from the depths of my stomach. i ignored it i didn't want my stomach to ruin our fun. i pushed down the feeling and continued. about 15min later the feeling was back and stronger than before. again i ignored it and continued. 5min later it felt like a nuclear reactor had just exploded in my stomach. i started running. i yelled to my friend to hurry the fuck up. \n\nrunning in sand is extremely hard if you did not know this. we got in his car and i yelled at him to floor it. my stomach was screaming and if he didn't hurry i was gonna have this baby in his car and it wasn't gonna be pretty. after a few red lights and me screaming like a woman in labor we made it to the store. \n\ni practically tore his car door open and ran inside. i ran to the bathroom opened the door and barely got my pants down before the dam burst and a flood of shit poured from my ass. \n\ni finished up when i felt something wet on my ass. i rubbed it thinking it was back splash. no, mass was covered in the after math of me abusing the toilet. i grabbed all the paper towels i could and gave my self a whores bath right there. \n\ni sprayed the bathroom down with the air freshener and left. an elderly lady walked in quickly and closed the door. i was just about to walk away when i heard gag. instead of walking i ran. i got to the car and told him to get the hell out of there."
b'liking seafood'
我們現在將對資料集進行批次處理,並僅保留 document
欄位,因為我們正在針對下一個單字預測任務微調模型。為了此範例的目的,取資料集的子集。
train_ds = (
reddit_ds.map(lambda document, _: document)
.batch(BATCH_SIZE)
.cache()
.prefetch(tf.data.AUTOTUNE)
)
train_ds = train_ds.take(NUM_BATCHES)
在我們開始微調模型之前,讓我們先定義一些輔助函數和類別。
我們將定義一個自訂回呼函數,以追蹤 GPU 記憶體使用量。回呼函數使用 TensorFlow 的 tf.config.experimental.get_memory_info
API。
在這裡,我們假設我們正在使用單一 GPU,GPU:0
。
class GPUMemoryCallback(keras.callbacks.Callback):
def __init__(
self,
target_batches,
print_stats=False,
**kwargs,
):
super().__init__(**kwargs)
self.target_batches = target_batches
self.print_stats = print_stats
self.memory_usage = []
self.labels = []
def _compute_memory_usage(self):
memory_stats = tf.config.experimental.get_memory_info("GPU:0")
# Convert bytes to GB and store in list.
peak_usage = round(memory_stats["peak"] / (2**30), 3)
self.memory_usage.append(peak_usage)
def on_epoch_begin(self, epoch, logs=None):
self._compute_memory_usage()
self.labels.append(f"epoch {epoch} start")
def on_train_batch_begin(self, batch, logs=None):
if batch in self.target_batches:
self._compute_memory_usage()
self.labels.append(f"batch {batch}")
def on_epoch_end(self, epoch, logs=None):
self._compute_memory_usage()
self.labels.append(f"epoch {epoch} end")
以下是用於產生文字的輔助函數。
def generate_text(model, input_text, max_length=200):
start = time.time()
output = model.generate(input_text, max_length=max_length)
print("\nOutput:")
print(output)
end = time.time()
print(f"Total Time Elapsed: {end - start:.2f}s")
我們將使用 AdamW 最佳化器和交叉熵損失來訓練兩個模型。
def get_optimizer_and_loss():
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
epsilon=1e-6,
global_clipnorm=1.0, # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
return optimizer, loss
讓我們先載入模型和預處理器。我們使用 128 的序列長度,而不是 1024(這是預設序列長度)。這會限制我們預測長序列的能力,但可以讓我們在 Colab 上快速執行此範例。
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en",
sequence_length=MAX_SEQUENCE_LENGTH,
)
gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
"gpt2_base_en", preprocessor=preprocessor
)
gpt2_lm.summary()
Preprocessor: "gpt2_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gpt2_tokenizer (GPT2Tokenizer) │ 50,257 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gpt2_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_ids (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ gpt2_backbone (GPT2Backbone) │ (None, None, 768) │ 124,439,808 │ padding_mask[0][0], │ │ │ │ │ token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_embedding │ (None, None, 50257) │ 38,597,376 │ gpt2_backbone[0][0] │ │ (ReversibleEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
Total params: 124,439,808 (474.70 MB)
Trainable params: 124,439,808 (474.70 MB)
Non-trainable params: 0 (0.00 B)
初始化 GPU 記憶體追蹤器回呼物件,並編譯模型。我們使用具有線性衰減學習率的 Adam 最佳化器。
gpu_memory_callback = GPUMemoryCallback(
target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
print_stats=True,
)
optimizer, loss = get_optimizer_and_loss()
gpt2_lm.compile(
optimizer=optimizer,
loss=loss,
weighted_metrics=["accuracy"],
)
我們都準備好訓練模型了!
gpt2_lm.fit(train_ds, epochs=EPOCHS, callbacks=[gpu_memory_callback])
gpt2_lm_memory_usage = gpu_memory_callback.memory_usage
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1701128462.076856 38706 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
W0000 00:00:1701128462.146837 38706 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
500/500 ━━━━━━━━━━━━━━━━━━━━ 114s 128ms/step - accuracy: 0.3183 - loss: 3.3682
最後一步,讓我們產生一些文字。我們將利用 XLA 的強大功能。第一次呼叫 generate()
會因為 XLA 編譯而速度較慢,但後續呼叫會非常快。:)
generate_text(gpt2_lm, "I like basketball", max_length=MAX_GENERATION_LENGTH)
generate_text(gpt2_lm, "That Italian restaurant is", max_length=MAX_GENERATION_LENGTH)
Output:
I like basketball, but this one actually happened a few months ago.
i was on my way to a party in the city when i noticed a group of guys were playing basketball. one of my friends, a guy named "jenny," was playing. jenny's mom, a very nice girl, was sitting on her couch.
jenny and jenny were sitting in a circle around her, and i started to play some of my favorite basketball games. i got to the end of the circle and jenny started to run. i didn't know how jenny was doing. she ran, but it
Total Time Elapsed: 6.66s
Output:
That Italian restaurant is a bit of a mystery, because the place is closed.
so i was at my friends house and i went to grab some food, so i got the usual pizza and some chicken, but it wasn't really the pizza, so i just grabbed my friend's pizza.
i had a lot of chicken, but i was hungry, so i decided to grab a few of the other pizza's that were already in there.
i was eating the pizza with some friends and i was eating the pizza and then i got a knock on the door.
the guy in front of me is
Total Time Elapsed: 0.22s
在本節中,我們將討論 LoRA 的技術細節、建立 LoRA GPT-2 模型、對其進行微調並產生文字。
LoRA 是 LLM 的參數高效微調技術。它會凍結 LLM 的權重,並注入可訓練的秩分解矩陣。讓我們更清楚地了解這一點。
假設我們有一個 n x n
的預訓練密集層(或權重矩陣),W0
。我們初始化兩個密集層,A
和 B
,形狀分別為 n x rank
和 rank x n
。rank
比 n
小得多。在論文中,顯示 1 到 4 之間的值效果很好。
原始方程式是 output = W0x + b0
,其中 x
是輸入,W0
和 b0
是原始密集層(凍結)的權重矩陣和偏差項。LoRA 方程式是:output = W0x + b0 + BAx
,其中 A
和 B
是秩分解矩陣。
LoRA 基於這樣的概念:由於預訓練語言模型是過度參數化的,因此對預訓練語言模型的權重的更新具有較低的「內在秩」。即使將 W0
的更新限制為低秩分解矩陣,也可以複製完全微調的預測效能。
讓我們做一些快速數學計算。假設 n
是 768,而 rank
是 4。W0
有 768 x 768 = 589,824
個參數,而 LoRA 層 A
和 B
合計有 768 x 4 + 4 x 768 = 6,144
個參數。因此,對於密集層,我們從 589,824
個可訓練參數變成 6,144
個可訓練參數!
即使參數總數增加(因為我們正在新增 LoRA 層),記憶體佔用量也會減少,因為可訓練參數的數量會減少。讓我們更深入地探討這一點。
模型的記憶體使用量可以分為四個部分
由於使用 LoRA 後,可訓練參數的數量大幅減少,因此 LoRA 的最佳化器記憶體和儲存梯度所需的記憶體遠低於 GPT-2。這是大部分記憶體節省發生的原因。
根據上述技術描述,讓我們來建立一個 LoRA 層。在 Transformer 模型中,LoRA 層會被建立並注入到查詢(query)和值(value)投影矩陣中。在 keras.layers.MultiHeadAttention
中,查詢/值投影層是 keras.layers.EinsumDense
層。
import math
class LoraLayer(keras.layers.Layer):
def __init__(
self,
original_layer,
rank=8,
alpha=32,
trainable=False,
**kwargs,
):
# We want to keep the name of this layer the same as the original
# dense layer.
original_layer_config = original_layer.get_config()
name = original_layer_config["name"]
kwargs.pop("name", None)
super().__init__(name=name, trainable=trainable, **kwargs)
self.rank = rank
self.alpha = alpha
self._scale = alpha / rank
self._num_heads = original_layer_config["output_shape"][-2]
self._hidden_dim = self._num_heads * original_layer_config["output_shape"][-1]
# Layers.
# Original dense layer.
self.original_layer = original_layer
# No matter whether we are training the model or are in inference mode,
# this layer should be frozen.
self.original_layer.trainable = False
# LoRA dense layers.
self.A = keras.layers.Dense(
units=rank,
use_bias=False,
# Note: the original paper mentions that normal distribution was
# used for initialization. However, the official LoRA implementation
# uses "Kaiming/He Initialization".
kernel_initializer=keras.initializers.VarianceScaling(
scale=math.sqrt(5), mode="fan_in", distribution="uniform"
),
trainable=trainable,
name=f"lora_A",
)
# B has the same `equation` and `output_shape` as the original layer.
# `equation = abc,cde->abde`, where `a`: batch size, `b`: sequence
# length, `c`: `hidden_dim`, `d`: `num_heads`,
# `e`: `hidden_dim//num_heads`. The only difference is that in layer `B`,
# `c` represents `rank`.
self.B = keras.layers.EinsumDense(
equation=original_layer_config["equation"],
output_shape=original_layer_config["output_shape"],
kernel_initializer="zeros",
trainable=trainable,
name=f"lora_B",
)
def call(self, inputs):
original_output = self.original_layer(inputs)
if self.trainable:
# If we are fine-tuning the model, we will add LoRA layers' output
# to the original layer's output.
lora_output = self.B(self.A(inputs)) * self._scale
return original_output + lora_output
# If we are in inference mode, we "merge" the LoRA layers' weights into
# the original layer's weights - more on this in the text generation
# section!
return original_output
我們現在將改造原始的 GPT-2 模型,並將 LoRA 層注入其中。在執行此操作之前,我們先做幾件事:
tf.config.experimental.reset_memory_stats
重置「峰值」GPU 記憶體使用量;del gpt2_lm
del optimizer
del loss
# This resets "peak" memory usage to "current" memory usage.
tf.config.experimental.reset_memory_stats("GPU:0")
# Load the original model.
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en",
sequence_length=128,
)
lora_model = keras_hub.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
preprocessor=preprocessor,
)
我們現在將使用新的 LoRA 層覆蓋原始的查詢/值投影矩陣。
for layer_idx in range(lora_model.backbone.num_layers):
# Change query dense layer.
decoder_layer = lora_model.backbone.get_layer(f"transformer_layer_{layer_idx}")
self_attention_layer = decoder_layer._self_attention_layer
# Allow mutation to Keras layer state.
self_attention_layer._tracker.locked = False
# Change query dense layer.
self_attention_layer._query_dense = LoraLayer(
self_attention_layer._query_dense,
rank=RANK,
alpha=ALPHA,
trainable=True,
)
# Change value dense layer.
self_attention_layer._value_dense = LoraLayer(
self_attention_layer._value_dense,
rank=RANK,
alpha=ALPHA,
trainable=True,
)
現在執行一次前向傳遞,以確保我們仍然有一個有效的計算鏈。
lora_model(preprocessor(["LoRA is very useful for quick LLM finetuning"])[0])
pass
凍結整個 LLM,只有 LoRA 層應該是可訓練的。
for layer in lora_model._flatten_layers():
lst_of_sublayers = list(layer._flatten_layers())
if len(lst_of_sublayers) == 1: # "leaves of the model"
if layer.name in ["lora_A", "lora_B"]:
layer.trainable = True
else:
layer.trainable = False
列印模型的摘要,並檢查不可訓練的參數數量和總參數數量是否正確。
在先前的章節中,我們計算出與 LoRA 層相關的參數數量為 6,144。模型中可訓練的參數總數應為 num_layers * (query, value) * 6,144 = 12 * 2 * 6,144 = 147,456
。不可訓練的參數數量應與原始 GPT-2 模型中的參數總數相同,即 124,439,808
。
lora_model.summary()
Preprocessor: "gpt2_causal_lm_preprocessor_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gpt2_tokenizer_1 (GPT2Tokenizer) │ 50,257 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gpt2_causal_lm_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_ids (InputLayer) │ (None, None) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ gpt2_backbone_1 │ (None, None, 768) │ 124,587,264 │ padding_mask[0][0], │ │ (GPT2Backbone) │ │ │ token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_embedding │ (None, None, 50257) │ 38,597,376 │ gpt2_backbone_1[0][0] │ │ (ReversibleEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
Total params: 124,587,264 (475.26 MB)
Trainable params: 147,456 (576.00 KB)
Non-trainable params: 124,439,808 (474.70 MB)
現在我們已經改造並驗證了 LoRA GPT-2 模型,讓我們來訓練它!
gpu_memory_callback = GPUMemoryCallback(
target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
print_stats=True,
)
optimizer, loss = get_optimizer_and_loss()
lora_model.compile(
optimizer=optimizer,
loss=loss,
weighted_metrics=["accuracy"],
)
lora_model.fit(
train_ds,
epochs=EPOCHS,
callbacks=[gpu_memory_callback],
)
lora_model_memory_usage = gpu_memory_callback.memory_usage
2/500 [37m━━━━━━━━━━━━━━━━━━━━ 41s 84ms/step - accuracy: 0.2828 - loss: 3.7188
W0000 00:00:1701128576.353742 38699 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
500/500 ━━━━━━━━━━━━━━━━━━━━ 80s 81ms/step - accuracy: 0.2930 - loss: 3.6158
我們已經完成模型的微調!在生成文字之前,讓我們先比較兩個模型的訓練時間和記憶體使用量。GPT-2 在 16 GB Tesla T4(Colab)上的訓練時間為 7 分鐘,而 LoRA 為 5 分鐘,減少了 30%。LoRA GPT-2 的記憶體使用量大約比 GPT-2 少 35%。
plt.bar(
["GPT-2", "LoRA GPT-2"],
[max(gpt2_lm_memory_usage), max(lora_model_memory_usage)],
color=["red", "blue"],
)
plt.xlabel("Time")
plt.ylabel("GPU Memory Usage (in GB)")
plt.title("GPU Memory Usage Comparison")
plt.legend()
plt.show()
WARNING:matplotlib.legend:No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
LoRA 相較於其他適配器方法的最大優勢之一是它不會產生額外的推論延遲。讓我們了解為什麼。
回想一下我們的 LoRA 方程式:output = W0x + b0 + BAx
。我們可以將其改寫為:output = Wx + b0 = (W0 + BA)x + b0
,其中 W = W0 + BA
。這表示如果我們合併原始模型和適配器的權重,我們實際上將會執行與原始模型相同的計算!
for layer_idx in range(lora_model.backbone.num_layers):
self_attention_layer = lora_model.backbone.get_layer(
f"transformer_layer_{layer_idx}"
)._self_attention_layer
# Merge query dense layer.
query_lora_layer = self_attention_layer._query_dense
A_weights = query_lora_layer.A.kernel # (768, 1) (a, b)
B_weights = query_lora_layer.B.kernel # (1, 12, 64) (b, c, d)
increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
query_lora_layer.original_layer.kernel.assign_add(increment_weights)
# Merge value dense layer.
value_lora_layer = self_attention_layer._value_dense
A_weights = value_lora_layer.A.kernel # (768, 1) (a, b)
B_weights = value_lora_layer.B.kernel # (1, 12, 64) (b, c, d)
increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
value_lora_layer.original_layer.kernel.assign_add(increment_weights)
# Put back in place the original layers with updated weights
self_attention_layer._query_dense = query_lora_layer.original_layer
self_attention_layer._value_dense = value_lora_layer.original_layer
我們現在都準備好使用我們的 LoRA 模型生成文字了 :)。
# Freezing weights not necessary during generation since no weights are updated.
generate_text(lora_model, "I like basketball", max_length=MAX_GENERATION_LENGTH)
generate_text(
lora_model, "That Italian restaurant is", max_length=MAX_GENERATION_LENGTH
)
Output:
I like basketball. i've played this game for about a week and i'm pretty tired. today, i'm playing with my friend, who is a really good player. i'm a little older than the average player and i'm a bit too young.
Total Time Elapsed: 6.81s
Output:
That Italian restaurant is in the city center and is located on a street that was recently renovated for the summer.
i was in a group of friends and had a great time.
Total Time Elapsed: 0.32s
我們都完成了!