作者: Abheesht Sharma
建立日期 2023/07/08
上次修改日期 2024/03/20
描述: 使用 KerasHub 在抽象式摘要任務上微調 BART。
在資訊超載的時代,從長篇文件或對話中提取重點,並用幾句話表達出來,變得至關重要。由於摘要在不同領域有廣泛的應用,近年來它已成為一項關鍵且經過深入研究的 NLP 任務。
雙向自回歸 Transformer (BART) 是一種基於 Transformer 的編碼器-解碼器模型,通常用於序列到序列的任務,例如摘要和神經機器翻譯。BART 以自我監督的方式在大型文本語料庫上進行預訓練。在預訓練期間,文本會被損壞,而 BART 被訓練來重建原始文本(因此稱為「去噪自編碼器」)。一些預訓練任務包括標記遮罩、標記刪除、句子排列(打亂句子並訓練 BART 來修正順序)等等。
在本範例中,我們將示範如何使用 KerasHub 在抽象式摘要任務(針對對話!)上微調 BART,並使用微調後的模型生成摘要。
在開始實作流程之前,讓我們先安裝和匯入我們需要的所有程式庫。我們將使用 KerasHub 程式庫。我們也需要一些實用程式庫。
!pip install git+https://github.com/keras-team/keras-hub.git py7zr -q
Installing build dependencies ... [?25l[?25hdone
Getting requirements to build wheel ... [?25l[?25hdone
Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB [31m1.4 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB [31m34.8 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB [31m30.4 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB [31m15.1 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB [31m5.8 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB [31m61.4 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB [31m10.1 MB/s eta [36m0:00:00
[?25h Building wheel for keras-hub (pyproject.toml) ... [?25l[?25hdone
此範例使用 Keras 3 在 "tensorflow"
、"jax"
或 "torch"
的任一後端中運作。Keras 3 的支援已內建於 KerasHub 中,只需變更 "KERAS_BACKEND"
環境變數即可選擇您選擇的後端。我們在下面選擇 JAX 後端。
import os
os.environ["KERAS_BACKEND"] = "jax"
匯入所有必要的程式庫。
import py7zr
import time
import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
Using JAX backend.
讓我們也定義我們的超參數。
BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1 # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40
讓我們載入 SAMSum 資料集。此資料集包含大約 15,000 對的對話/對白和摘要。
# Download the dataset.
filename = keras.utils.get_file(
"corpus.7z",
origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)
# Extract the `.7z` file.
with py7zr.SevenZipFile(filename, mode="r") as z:
z.extractall(path="/root/tensorflow_datasets/downloads/manual")
# Load data using TFDS.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
Downloading data from https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z
2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...
Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s]
Generating train examples...: 0%| | 0/14732 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-train.tfrecord*...: 0%| | …
Generating validation examples...: 0%| | 0/818 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-validation.tfrecord*...: 0%| …
Generating test examples...: 0%| | 0/819 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-test.tfrecord*...: 0%| | 0…
Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.
資料集有兩個欄位:dialogue
和 summary
。讓我們看看一個範例。
for dialogue, summary in samsum_ds:
print(dialogue.numpy())
print(summary.numpy())
break
b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'
我們現在將對資料集進行批次處理,並且僅保留資料集的一個子集以用於本範例的目的。對話會被輸入到編碼器,而對應的摘要則作為解碼器的輸入。因此,我們會將資料集的格式變更為具有兩個鍵的字典:"encoder_text"
和 "decoder_text"
。這就是 keras_hub.models.BartSeq2SeqLMPreprocessor
所期望的輸入格式。
train_ds = (
samsum_ds.map(
lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
)
.batch(BATCH_SIZE)
.cache()
)
train_ds = train_ds.take(NUM_BATCHES)
讓我們先載入模型和預處理器。我們針對編碼器和解碼器分別使用 512 和 128 的序列長度,而不是 1024(預設序列長度)。這可讓我們在 Colab 上快速執行此範例。
如果您仔細觀察,預處理器會附加到模型上。這表示我們不必擔心預處理文字輸入;一切都會在內部完成。預處理器會將編碼器文字和解碼器文字進行標記化、新增特殊標記並進行填補。為了產生用於自回歸訓練的標籤,預處理器會將解碼器文字向右移動一個位置。這樣做的原因是在每個時間步,模型都會被訓練來預測下一個標記。
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
"bart_base_en",
encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
"bart_base_en", preprocessor=preprocessor
)
bart_lm.summary()
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/vocab.json
898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/merges.txt
456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/model.h5
557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step
Preprocessor: "bart_seq2_seq_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ bart_tokenizer (BartTokenizer) │ 50,265 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "bart_seq2_seq_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ decoder_padding_mask │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ decoder_token_ids │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ encoder_padding_mask │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ encoder_token_ids │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ bart_backbone (BartBackbone) │ [(None, None, 768), │ 139,417,344 │ decoder_padding_mask[0][0], │ │ │ (None, None, 768)] │ │ decoder_token_ids[0][0], │ │ │ │ │ encoder_padding_mask[0][0], │ │ │ │ │ encoder_token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ reverse_embedding │ (None, 50265) │ 38,603,520 │ bart_backbone[0][0] │ │ (ReverseEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
Total params: 139,417,344 (4.15 GB)
Trainable params: 139,417,344 (4.15 GB)
Non-trainable params: 0 (0.00 B)
定義最佳化器和損失。我們使用 Adam 最佳化器和線性衰減學習速率。編譯模型。
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
epsilon=1e-6,
global_clipnorm=1.0, # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
bart_lm.compile(
optimizer=optimizer,
loss=loss,
weighted_metrics=["accuracy"],
)
讓我們訓練模型!
bart_lm.fit(train_ds, epochs=EPOCHS)
600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330
<keras_core.src.callbacks.history.History at 0x7ae2faf3e110>
現在模型已經過訓練,讓我們開始有趣的部分 - 實際產生摘要!讓我們從驗證集中挑選前 100 個樣本,並為它們產生摘要。我們將使用預設的解碼策略,即貪婪搜尋。
KerasHub 中的生成經過高度最佳化。它背後有 XLA 的強大支援。其次,會快取解碼器中自我注意力層和交叉注意力層中的鍵/值張量,以避免在每個時間步重新計算。
def generate_text(model, input_text, max_length=200, print_time_taken=False):
start = time.time()
output = model.generate(input_text, max_length=max_length)
end = time.time()
print(f"Total Time Elapsed: {end - start:.2f}s")
return output
# Load the dataset.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)
dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
dialogues.append(dialogue.numpy())
ground_truth_summaries.append(summary.numpy())
# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
# Generate summaries.
generated_summaries = generate_text(
bart_lm,
val_ds.map(lambda dialogue, _: dialogue).batch(8),
max_length=MAX_GENERATION_LENGTH,
print_time_taken=True,
)
Total Time Elapsed: 21.22s
Total Time Elapsed: 49.00s
讓我們看看一些摘要。
for dialogue, generated_summary, ground_truth_summary in zip(
dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
print("Dialogue:", dialogue)
print("Generated Summary:", generated_summary)
print("Ground Truth Summary:", ground_truth_summary)
print("=============================")
Dialogue: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.'
Generated Summary: Tony will let Claire know when her boss comes.
Ground Truth Summary: b"The boss isn't in yet. Claire will let Tony know when he comes."
=============================
Dialogue: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!"
Generated Summary: Mary's girlfriend is birthday. James and Tim are going to ask Sandy to buy her.
Ground Truth Summary: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy."
=============================
Dialogue: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Israel is expensive, Monica was there on vacation last year, she complained about how pricey it is. Are you going to the Dead Sea before it dies? ahahahha\r\nKate: ahahhaha yup, in few days."
Generated Summary: Kate is on vacation in Tel Aviv. Mary will visit the Dead Sea in a few days.
Ground Truth Summary: b'Mary and Kate discuss how expensive Israel is. Kate is in Tel Aviv now, planning to travel to Jerusalem tomorrow, and to the Dead Sea few days later.'
=============================
Dialogue: b"Giny: do we have rice?\r\nRiley: nope, it's finished\r\nGiny: fuck!\r\nGiny: ok, I'll buy"
Generated Summary: Giny wants to buy rice from Riley.
Ground Truth Summary: b"Giny and Riley don't have any rice left. Giny will buy some."
=============================
Dialogue: b"Jude: i'll be in warsaw at the beginning of december so we could meet again\r\nLeon: !!!\r\nLeon: at the beginning means...?\r\nLeon: cuz I won't be here during the first weekend\r\nJude: 10\r\nJude: but i think it's a monday, so never mind i guess :D\r\nLeon: yeah monday doesn't really work for me :D\r\nLeon: :<\r\nJude: oh well next time :d\r\nLeon: yeah...!"
Generated Summary: Jude and Leon will meet again this weekend at 10 am.
Ground Truth Summary: b'Jude is coming to Warsaw on the 10th of December and wants to see Leon. Leon has no time.'
=============================
產生的摘要看起來很棒!對於僅訓練 1 個 epoch 且使用 5000 個範例的模型來說,還不錯 :)