作者: Hongyu Chiu
建立日期 2024/05/14
上次修改日期 2024/05/14
描述: 使用 float8 量化訓練簡單的 Transformer 模型。
隨著 Transformer 模型中的參數數量持續增加,訓練和推論變得非常耗費記憶體和運算資源。因此,引入了 8 位元浮點數 (FP8),在效能上比 16 位元浮點數有所改進,且準確度幾乎沒有下降。
詳細而言,FP8 有兩種不同的類型:E4M3 和 E5M2,在訓練的不同部分很有用。
通常,E4M3 最適合在正向傳播期間使用,因為激活和權重需要更高的精度。然而,在反向傳播中,會使用 E5M2,因為梯度較不易受到精度損失的影響,但需要更高的動態範圍。
值得注意的是,FP8 推論部署大大簡化了,因為推論和訓練使用相同的資料類型。這與使用 32 位元或 16 位元浮點數訓練的網路進行 INT8 推論形成對比,後者需要訓練後量化 (PTQ) 校準,甚至需要感知量化訓練 (QAT) 才能維持模型準確度。
在此範例中,我們將建立一個簡單的 Transformer 模型,並使用 FP16 和 FP8 精度進行訓練。您會觀察到,準確度不會隨著精度的降低而下降。
注意:您需要一個具有 FP8 Tensor Core 支援的優良 GPU,才能獲得預期的效能改進。
我們將使用 KerasHub 程式庫來簡化模型實作。此外,使用混合精度訓練來縮短訓練時間。
注意:僅資料處理需要依賴 TensorFlow。
!pip install -q --upgrade keras-hub
!pip install -q --upgrade keras # Upgrade to Keras 3.
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import re
import keras
import keras_hub
import tensorflow as tf
keras.config.set_dtype_policy("mixed_bfloat16")
定義一些超參數。
EPOCHS = 3
BATCH_SIZE = 32
VOCABULARY_SIZE = 20000
MAX_SEQUENCE_LENGTH = 200
MODEL_KWARGS = dict(
vocabulary_size=VOCABULARY_SIZE,
max_sequence_length=MAX_SEQUENCE_LENGTH,
hidden_dim=32, # Hidden size for each token
num_heads=2, # Number of attention heads
intermediate_dim=32, # Intermediate size in feedforward network
dropout=0.1, # Dropout rate
)
首先,讓我們下載 IMDB 資料集並解壓縮它。
!mkdir -p datasets
!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -q -O datasets/aclImdb_v1.tar.gz
!mkdir -p datasets/aclImdb
!tar -xzf datasets/aclImdb_v1.tar.gz -C datasets
!rm -rf datasets/aclImdb/train/unsup
我們將使用 keras.utils.text_dataset_from_directory
工具來從文字檔案產生標籤化的 tf.data.Dataset
資料集。
train_ds = keras.utils.text_dataset_from_directory(
"datasets/aclImdb/train",
batch_size=BATCH_SIZE,
validation_split=0.2,
subset="training",
seed=42,
)
val_ds = keras.utils.text_dataset_from_directory(
"datasets/aclImdb/train",
batch_size=BATCH_SIZE,
validation_split=0.2,
subset="validation",
seed=42,
)
test_ds = keras.utils.text_dataset_from_directory(
"datasets/aclImdb/test", batch_size=BATCH_SIZE
)
Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
Found 25000 files belonging to 2 classes.
我們現在將文字轉換為小寫。
train_ds = train_ds.map(lambda x, y: (tf.strings.lower(x), y))
val_ds = val_ds.map(lambda x, y: (tf.strings.lower(x), y))
test_ds = test_ds.map(lambda x, y: (tf.strings.lower(x), y))
讓我們列印一些樣本。
for text_batch, label_batch in train_ds.take(1):
for i in range(3):
print(f"Text: {text_batch.numpy()[i]}")
print(f"Label: {label_batch.numpy()[i]}")
Text: b'"pandemonium" is a horror movie spoof that comes off more stupid than funny. believe me when i tell you, i love comedies. especially comedy spoofs. "airplane", "the naked gun" trilogy, "blazing saddles", "high anxiety", and "spaceballs" are some of my favorite comedies that spoof a particular genre. "pandemonium" is not up there with those films. most of the scenes in this movie had me sitting there in stunned silence because the movie wasn\'t all that funny. there are a few laughs in the film, but when you watch a comedy, you expect to laugh a lot more than a few times and that\'s all this film has going for it. geez, "scream" had more laughs than this film and that was more of a horror film. how bizarre is that?<br /><br />*1/2 (out of four)'
Label: 0
Text: b"david mamet is a very interesting and a very un-equal director. his first movie 'house of games' was the one i liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.<br /><br />so is 'homicide' which from the title tries to set the mind of the viewer to the usual crime drama. the principal characters are two cops, one jewish and one irish who deal with a racially charged area. the murder of an old jewish shop owner who proves to be an ancient veteran of the israeli independence war triggers the jewish identity in the mind and heart of the jewish detective.<br /><br />this is were the flaws of the film are the more obvious. the process of awakening is theatrical and hard to believe, the group of jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. the end of the film itself is mamet-like smart, but disappoints from a human emotional perspective.<br /><br />joe mantegna and william macy give strong performances, but the flaws of the story are too evident to be easily compensated."
Label: 0
Text: b'great documentary about the lives of ny firefighters during the worst terrorist attack of all time.. that reason alone is why this should be a must see collectors item.. what shocked me was not only the attacks, but the"high fat diet" and physical appearance of some of these firefighters. i think a lot of doctors would agree with me that,in the physical shape they were in, some of these firefighters would not of made it to the 79th floor carrying over 60 lbs of gear. having said that i now have a greater respect for firefighters and i realize becoming a firefighter is a life altering job. the french have a history of making great documentary\'s and that is what this is, a great documentary.....'
Label: 1
我們將使用 keras_hub.tokenizers.WordPieceTokenizer
層來對文字進行符號化。keras_hub.tokenizers.WordPieceTokenizer
採用 WordPiece 詞彙表,並具有用於對文字進行符號化和將符號序列取消符號化的功能。
在我們定義符號化工具之前,我們首先需要使用我們的資料集來訓練它。WordPiece 符號化演算法是一種子詞符號化演算法;在語料庫上訓練它會產生子詞詞彙表。子詞符號化工具是字詞符號化工具(字詞符號化工具需要非常大的詞彙表才能良好涵蓋輸入字詞)和字元符號化工具(字元不像字詞一樣真正編碼意義)之間的折衷方案。幸運的是,KerasHub 使用 keras_hub.tokenizers.compute_word_piece_vocabulary
工具在語料庫上訓練 WordPiece 非常簡單。
def train_word_piece(ds, vocab_size, reserved_tokens):
word_piece_ds = ds.unbatch().map(lambda x, y: x)
vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
word_piece_ds.batch(1000).prefetch(2),
vocabulary_size=vocab_size,
reserved_tokens=reserved_tokens,
)
return vocab
每個詞彙表都有一些特殊的保留符號。我們有兩個這樣的符號
"[PAD]"
- 填充符號。當輸入序列長度短於最大序列長度時,會在輸入序列長度中附加填充符號。"[UNK]"
- 未知符號。reserved_tokens = ["[PAD]", "[UNK]"]
train_sentences = [element[0] for element in train_ds]
vocab = train_word_piece(train_ds, VOCABULARY_SIZE, reserved_tokens)
讓我們看看一些符號!
print("Tokens: ", vocab[100:110])
Tokens: ['à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é']
現在,讓我們定義符號化工具。我們將使用上面訓練的詞彙表來配置符號化工具。我們將定義最大序列長度,以便將所有序列填充到相同的長度(如果序列的長度小於指定的序列長度)。否則,序列將被截斷。
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
vocabulary=vocab,
lowercase=False,
sequence_length=MAX_SEQUENCE_LENGTH,
)
讓我們嘗試對資料集中的一個樣本進行符號化!為了驗證文字是否已正確符號化,我們也可以將符號清單取消符號化回原始文字。
input_sentence_ex = train_ds.take(1).get_single_element()[0][0]
input_tokens_ex = tokenizer(input_sentence_ex)
print("Sentence: ", input_sentence_ex)
print("Tokens: ", input_tokens_ex)
print("Recovered text after detokenizing: ", tokenizer.detokenize(input_tokens_ex))
Sentence: tf.Tensor(b'great movie - especially the music - etta james - "at last". this speaks volumes when you have finally found that special someone.', shape=(), dtype=string)
Tokens:
[ 218 150 14 393 137 356 14 4917 2941 719 14 3
164 370 3 15 145 2705 11670 186 155 160 557 391
146 452 416 15 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0]
Recovered text after detokenizing: tf.Tensor(b'great movie - especially the music - etta james - " at last " . this speaks volumes when you have finally found that special someone . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', shape=(), dtype=string)
接下來,我們將以將被饋送到模型的形式來格式化我們的資料集。我們需要將文字符號化。
def format_dataset(sentence, label):
sentence = tokenizer(sentence)
return ({"input_ids": sentence}, label)
def make_dataset(dataset):
dataset = dataset.map(format_dataset, num_parallel_calls=tf.data.AUTOTUNE)
return dataset.shuffle(512).prefetch(tf.data.AUTOTUNE).cache()
train_ds = make_dataset(train_ds)
val_ds = make_dataset(val_ds)
test_ds = make_dataset(test_ds)
讓我們建立一個簡單的 Transformer 模型。我們將使用 KerasHub 程式庫中的 TokenAndPositionEmbedding
和 TransformerDecoder
。TokenAndPositionEmbedding
代表句子中的字詞及其順序,而 TransformerDecoder
會為我們輸入序列的每個時間步驟輸出一個向量。在這裡,我們取所有時間步驟的平均值,並在其上使用前饋網路來分類文字。
def build_model(
vocabulary_size=20000,
max_sequence_length=200,
hidden_dim=32,
num_heads=2,
intermediate_dim=32,
dropout=0.1,
):
token_id_input = keras.layers.Input(shape=(None,), dtype="int32", name="input_ids")
x = keras_hub.layers.TokenAndPositionEmbedding(
vocabulary_size=vocabulary_size,
sequence_length=max_sequence_length,
embedding_dim=hidden_dim,
)(token_id_input)
x = keras.layers.Dropout(rate=dropout)(x)
x = keras_hub.layers.TransformerDecoder(
intermediate_dim=intermediate_dim,
num_heads=num_heads,
dropout=dropout,
)(x)
x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(dropout)(x)
x = keras.layers.Dense(intermediate_dim, activation="relu")(x)
x = keras.layers.Dropout(dropout)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
return keras.Model(inputs=token_id_input, outputs=outputs)
首先,我們使用混合精度 ("mixed_bfloat16"
) 訓練和評估模型。之後,我們將結果與 FP8 訓練/推論進行比較。
model = build_model(**MODEL_KWARGS)
model.summary()
model.compile(
optimizer="adam",
loss="binary_crossentropy",
metrics=["accuracy"],
)
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
result = model.evaluate(test_ds)
print(f"Accuracy (mixed_bfloat16): {result[1]:.2%}")
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_ids (InputLayer) │ (None, None) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ token_and_position_embedding │ (None, None, 32) │ 646,400 │ │ (TokenAndPositionEmbedding) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout (Dropout) │ (None, None, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ transformer_decoder │ (None, None, 32) │ 6,464 │ │ (TransformerDecoder) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling1d │ (None, 32) │ 0 │ │ (GlobalAveragePooling1D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_2 (Dropout) │ (None, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (Dense) │ (None, 32) │ 1,056 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_3 (Dropout) │ (None, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (Dense) │ (None, 1) │ 33 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 653,953 (2.49 MB)
Trainable params: 653,953 (2.49 MB)
Non-trainable params: 0 (0.00 B)
Accuracy (mixed_bfloat16): 75.56%
我們可以使用單行 API 來啟用 FP8 訓練/推論:model.quantize("float8")
。
model = build_model(**MODEL_KWARGS)
model.quantize("float8")
為了檢查是否發生 FP8 訓練,我們可以列印一些與 FP8 訓練相關的變數
*_scale
:將輸入、權重和梯度的分佈移至 FP8 可表示範圍內的縮放係數。預設為 1.0
*_amax_history
:用於縮放係數計算的 amax 歷史視窗。預設為 0.0
,長度為 1024。pattern = r"(transformer).+(multi_head).+(query).+(scale|amax_history)"
for v in model.trainable_variables:
if re.findall(pattern, v.path):
print(v.path)
print(keras.ops.convert_to_numpy(v.value))
FP8 層的 dtype 原則也已修改。
for layer in model._flatten_layers(recursive=True):
if "float8" in str(layer.dtype_policy):
print(f"{layer.name}: {layer.dtype_policy}")
feedforward_output_dense: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
feedforward_intermediate_dense: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
attention_output: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
value: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
key: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
query: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
dense_2: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
dense_3: <QuantizedFloat8DTypePolicy "float8_from_mixed_bfloat16">
讓我們訓練模型並查看結果。我們可以驗證,使用 FP8 訓練後,準確度不會下降,且包含 FP8 資訊的變數會改變。
model.compile(
optimizer="adam",
loss="binary_crossentropy",
metrics=["accuracy"],
)
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
result = model.evaluate(test_ds)
print(f"Accuracy (float8): {result[1]:.2%}")
for v in model.trainable_variables:
if re.findall(pattern, v.path):
print(v.path)
print(keras.ops.convert_to_numpy(v.value))
Accuracy (float8): 74.16%