
End-to-end Masked Language Modeling with BERT

Author: Ankur Singh
Date created: 2020/09/18
Last modified: 2024/03/15
Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.

ⓘ This example uses Keras 3



Introduction

Masked language modeling is a fill-in-the-blank task, where a model uses the context words surrounding a mask token to try to predict what the masked word should be.

For an input that contains one or more mask tokens, the model will generate the most likely substitution for each.

Example:

  • Input: "I have watched this [MASK] and it was awesome."
  • Output: "I have watched this movie and it was awesome."

Masked language modeling is a great way to train a language model in a self-supervised setting (without human-annotated labels). Such a model can then be fine-tuned to accomplish various supervised NLP tasks.

This example teaches you how to build a BERT model from scratch, train it with the masked language modeling task, and then fine-tune it on a sentiment classification task.

We will use the Keras TextVectorization and MultiHeadAttention layers to create a BERT Transformer-Encoder network architecture.

Note: this example runs on Keras 3 and works with any backend (the setup below selects torch). TensorFlow is still required, because the tf.data preprocessing pipeline and the TextVectorization layer rely on it.


Setup

Install the required packages via pip install keras keras-hub tensorflow.

import os

os.environ["KERAS_BACKEND"] = "torch"  # or jax, or tensorflow

import keras_hub

import keras
from keras import layers
from keras.layers import TextVectorization

from dataclasses import dataclass
import pandas as pd
import numpy as np
import glob
import re
from pprint import pprint

Set-up Configuration

@dataclass
class Config:
    MAX_LEN = 256
    BATCH_SIZE = 32
    LR = 0.001
    VOCAB_SIZE = 30000
    EMBED_DIM = 128
    NUM_HEAD = 8  # used in bert model
    FF_DIM = 128  # used in bert model
    NUM_LAYERS = 1


config = Config()

Load the data

We will first download the IMDB data and load it into a Pandas dataframe.

!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0   378k      0  0:03:37  0:03:37 --:--:--  473k

def get_text_list_from_files(files):
    text_list = []
    for name in files:
        with open(name) as f:
            for line in f:
                text_list.append(line)
    return text_list


def get_data_from_text_files(folder_name):
    pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
    pos_texts = get_text_list_from_files(pos_files)
    neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
    neg_texts = get_text_list_from_files(neg_files)
    df = pd.DataFrame(
        {
            "review": pos_texts + neg_texts,
            "sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),
        }
    )
    df = df.sample(len(df)).reset_index(drop=True)
    return df


train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")

all_data = pd.concat([train_df, test_df], ignore_index=True)

Dataset preparation

We will use the TextVectorization layer to vectorize the text into integer token ids. It transforms a batch of strings into either a sequence of token indices (one sample = 1D array of integer token indices, in order) or a dense representation (one sample = 1D array of float values encoding an unordered set of tokens).
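
As a quick illustration (a minimal standalone sketch; the demo_vectorizer name, toy corpus, vocabulary size and sequence length are purely illustrative and separate from the layer we adapt below), here is what the layer produces in "int" output mode:

# Minimal standalone sketch of TextVectorization in "int" mode.
# The real layer for this example is built later by get_vectorize_layer().
demo_vectorizer = TextVectorization(
    max_tokens=20, output_mode="int", output_sequence_length=8
)
demo_vectorizer.adapt(["the movie was great", "the movie was terrible"])
print(demo_vectorizer(["the movie was great"]))
# -> a (1, 8) tensor of integer token ids, right-padded with 0s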

Below, we define 3 preprocessing functions.

  1. The get_vectorize_layer function builds the TextVectorization layer.
  2. The encode function encodes raw text into integer token ids.
  3. The get_masked_input_and_labels function masks input token ids: it randomly masks 15% of all input tokens in each sequence.

# For data pre-processing and tf.data.Dataset
import tensorflow as tf


def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape("!#$%&'()*+,-./:;<=>?@\^_`{|}~"), ""
    )


def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):
    """Build Text vectorization layer

    Args:
      texts (list): List of string i.e input texts
      vocab_size (int): vocab size
      max_seq (int): Maximum sequence length.
      special_tokens (list, optional): List of special tokens. Defaults to ['[MASK]'].

    Returns:
        layers.Layer: Return TextVectorization Keras Layer
    """
    vectorize_layer = TextVectorization(
        max_tokens=vocab_size,
        output_mode="int",
        standardize=custom_standardization,
        output_sequence_length=max_seq,
    )
    vectorize_layer.adapt(texts)

    # Insert mask token in vocabulary
    vocab = vectorize_layer.get_vocabulary()
    vocab = vocab[2 : vocab_size - len(special_tokens)] + ["[mask]"]
    vectorize_layer.set_vocabulary(vocab)
    return vectorize_layer


vectorize_layer = get_vectorize_layer(
    all_data.review.values.tolist(),
    config.VOCAB_SIZE,
    config.MAX_LEN,
    special_tokens=["[mask]"],
)

# Get mask token id for masked language model
mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]


def encode(texts):
    encoded_texts = vectorize_layer(texts)
    return encoded_texts.numpy()


def get_masked_input_and_labels(encoded_texts):
    # 15% BERT masking
    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
    # Do not mask special tokens
    inp_mask[encoded_texts <= 2] = False
    # Set targets to -1 by default, it means ignore
    labels = -1 * np.ones(encoded_texts.shape, dtype=int)
    # Set labels for masked tokens
    labels[inp_mask] = encoded_texts[inp_mask]

    # Prepare input
    encoded_texts_masked = np.copy(encoded_texts)
    # Set input to [MASK] which is the last token for the 90% of tokens
    # This means leaving 10% unchanged
    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
    encoded_texts_masked[inp_mask_2mask] = (
        mask_token_id  # mask token is the last in the dict
    )

    # Set 10% to a random token
    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
    encoded_texts_masked[inp_mask_2random] = np.random.randint(
        3, mask_token_id, inp_mask_2random.sum()
    )

    # Prepare sample_weights to pass to .fit() method
    sample_weights = np.ones(labels.shape)
    sample_weights[labels == -1] = 0

    # y_labels would be same as encoded_texts i.e input tokens
    y_labels = np.copy(encoded_texts)

    return encoded_texts_masked, y_labels, sample_weights


# We have 25000 examples for training
x_train = encode(train_df.review.values)  # encode reviews with vectorizer
y_train = train_df.sentiment.values
train_classifier_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(config.BATCH_SIZE)
)

# We have 25000 examples for testing
x_test = encode(test_df.review.values)
y_test = test_df.sentiment.values
test_classifier_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
    config.BATCH_SIZE
)

# Dataset for end to end model input (will be used at the end)
test_raw_classifier_ds = test_df

# Prepare data for masked language model
x_all_review = encode(all_data.review.values)
x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
    x_all_review
)

mlm_ds = tf.data.Dataset.from_tensor_slices(
    (x_masked_train, y_masked_labels, sample_weights)
)
mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)
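
As a quick sanity check (illustrative only, not part of the original pipeline; masked_fraction is a name introduced here), we can verify that the fraction of positions selected for prediction is close to the 15% masking rate. It comes out lower than 0.15 because padding and special tokens are never selected:

# Illustrative sanity check on the masking rate. sample_weights is 1 only at
# positions whose token the model must predict, so its mean approximates the
# effective masking rate (padding/special tokens are excluded).
masked_fraction = sample_weights.sum() / sample_weights.size
print(f"Fraction of positions selected for prediction: {masked_fraction:.3f}")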

Create BERT model (Pretraining Model) for masked language modeling

We will create a BERT-like pretraining model architecture using the MultiHeadAttention layer. It will take token ids as inputs (including masked tokens), and it will predict the correct ids for the masked input tokens.

def bert_module(query, key, value, i):
    # Multi headed self-attention
    attention_output = layers.MultiHeadAttention(
        num_heads=config.NUM_HEAD,
        key_dim=config.EMBED_DIM // config.NUM_HEAD,
        name="encoder_{}_multiheadattention".format(i),
    )(query, key, value)
    attention_output = layers.Dropout(0.1, name="encoder_{}_att_dropout".format(i))(
        attention_output
    )
    attention_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}_att_layernormalization".format(i)
    )(query + attention_output)

    # Feed-forward layer
    ffn = keras.Sequential(
        [
            layers.Dense(config.FF_DIM, activation="relu"),
            layers.Dense(config.EMBED_DIM),
        ],
        name="encoder_{}_ffn".format(i),
    )
    ffn_output = ffn(attention_output)
    ffn_output = layers.Dropout(0.1, name="encoder_{}_ffn_dropout".format(i))(
        ffn_output
    )
    sequence_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}_ffn_layernormalization".format(i)
    )(attention_output + ffn_output)
    return sequence_output


loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction=None)
loss_tracker = keras.metrics.Mean(name="loss")


class MaskedLanguageModel(keras.Model):

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):

        loss = loss_fn(y, y_pred, sample_weight)
        loss_tracker.update_state(loss, sample_weight=sample_weight)
        return keras.ops.sum(loss)

    def compute_metrics(self, x, y, y_pred, sample_weight):

        # Return a dict mapping metric names to current value
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker]


def create_masked_language_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype="int64")

    word_embeddings = layers.Embedding(
        config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
    )(inputs)
    position_embeddings = keras_hub.layers.PositionEmbedding(
        sequence_length=config.MAX_LEN
    )(word_embeddings)
    embeddings = word_embeddings + position_embeddings

    encoder_output = embeddings
    for i in range(config.NUM_LAYERS):
        encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)

    mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
        encoder_output
    )
    mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")

    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    mlm_model.compile(optimizer=optimizer)
    return mlm_model


id2token = dict(enumerate(vectorize_layer.get_vocabulary()))
token2id = {y: x for x, y in id2token.items()}


class MaskedTextGenerator(keras.callbacks.Callback):
    def __init__(self, sample_tokens, top_k=5):
        self.sample_tokens = sample_tokens
        self.k = top_k

    def decode(self, tokens):
        return " ".join([id2token[t] for t in tokens if t != 0])

    def convert_ids_to_tokens(self, id):
        return id2token[id]

    def on_epoch_end(self, epoch, logs=None):
        prediction = self.model.predict(self.sample_tokens)

        masked_index = np.where(self.sample_tokens == mask_token_id)
        masked_index = masked_index[1]
        mask_prediction = prediction[0][masked_index]

        top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
        values = mask_prediction[0][top_indices]

        for i in range(len(top_indices)):
            p = top_indices[i]
            v = values[i]
            tokens = np.copy(sample_tokens[0])
            tokens[masked_index[0]] = p
            result = {
                "input_text": self.decode(sample_tokens[0].numpy()),
                "prediction": self.decode(tokens),
                "probability": v,
                "predicted mask token": self.convert_ids_to_tokens(p),
            }
            pprint(result)


sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
generator_callback = MaskedTextGenerator(sample_tokens.numpy())

bert_masked_model = create_masked_language_bert_model()
bert_masked_model.summary()
Model: "masked_bert_model"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 256)       │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ word_embedding      │ (None, 256, 128)  │  3,840,000 │ input_layer[0][0] │
│ (Embedding)         │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ position_embedding  │ (None, 256, 128)  │     32,768 │ word_embedding[0… │
│ (PositionEmbedding) │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 256, 128)  │          0 │ word_embedding[0… │
│                     │                   │            │ position_embeddi… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_multihea… │ (None, 256, 128)  │     66,048 │ add[0][0],        │
│ (MultiHeadAttentio… │                   │            │ add[0][0],        │
│                     │                   │            │ add[0][0]         │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_att_drop… │ (None, 256, 128)  │          0 │ encoder_0_multih… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 256, 128)  │          0 │ add[0][0],        │
│                     │                   │            │ encoder_0_att_dr… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_att_laye… │ (None, 256, 128)  │        256 │ add_1[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_ffn       │ (None, 256, 128)  │     33,024 │ encoder_0_att_la… │
│ (Sequential)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_ffn_drop… │ (None, 256, 128)  │          0 │ encoder_0_ffn[0]… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 256, 128)  │          0 │ encoder_0_att_la… │
│                     │                   │            │ encoder_0_ffn_dr… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_ffn_laye… │ (None, 256, 128)  │        256 │ add_2[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ mlm_cls (Dense)     │ (None, 256,       │  3,870,000 │ encoder_0_ffn_la… │
│                     │ 30000)            │            │                   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 7,842,352 (29.92 MB)
 Trainable params: 7,842,352 (29.92 MB)
 Non-trainable params: 0 (0.00 B)

Train and Save

bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])
bert_masked_model.save("bert_mlm_imdb.keras")


{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'a',
 'prediction': 'i have watched this a and it was awesome',
 'probability': 0.0013674975}
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'i',
 'prediction': 'i have watched this i and it was awesome',
 'probability': 0.0012694978}
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'is',
 'prediction': 'i have watched this is and it was awesome',
 'probability': 0.0012668626}
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'to',
 'prediction': 'i have watched this to and it was awesome',
 'probability': 0.0012651902}
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'of',
 'prediction': 'i have watched this of and it was awesome',
 'probability': 0.0011966776}


16/16 ━━━━━━━━━━━━━━━━━━━━ 261s 17s/step - loss: 9.9656
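
The callback above only prints predictions during training. As a small sketch of how the saved MLM could also be queried directly (it reuses vectorize_layer, mask_token_id and id2token defined earlier; the loaded_mlm name and the sample sentence are illustrative):

# Illustrative sketch: reload the saved MLM and inspect the top-5
# predictions for a single [mask] position.
loaded_mlm = keras.models.load_model(
    "bert_mlm_imdb.keras",
    custom_objects={"MaskedLanguageModel": MaskedLanguageModel},
)
sample = vectorize_layer(["this [mask] was a complete waste of time"]).numpy()
probs = loaded_mlm.predict(sample)[0]  # shape: (MAX_LEN, VOCAB_SIZE)
mask_pos = np.where(sample[0] == mask_token_id)[0][0]
top5 = probs[mask_pos].argsort()[-5:][::-1]
print([id2token[i] for i in top5])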


Fine-tune a sentiment classification model

We will fine-tune our self-supervised model on the downstream task of sentiment classification. To do this, let's create a classifier by adding a pooling layer and a Dense layer on top of the pretrained BERT features.

# Load pretrained bert model
mlm_model = keras.models.load_model(
    "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
)
pretrained_bert_model = keras.Model(
    mlm_model.input, mlm_model.get_layer("encoder_0_ffn_layernormalization").output
)

# Freeze it
pretrained_bert_model.trainable = False


def create_classifier_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype="int64")
    sequence_output = pretrained_bert_model(inputs)
    pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
    hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
    outputs = layers.Dense(1, activation="sigmoid")(hidden_layer)
    classifer_model = keras.Model(inputs, outputs, name="classification")
    optimizer = keras.optimizers.Adam()
    classifer_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return classifer_model


classifer_model = create_classifier_bert_model()
classifer_model.summary()

# Train the classifier with frozen BERT stage
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)

# Unfreeze the BERT model for fine-tuning
pretrained_bert_model.trainable = True
optimizer = keras.optimizers.Adam()
classifer_model.compile(
    optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
)
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)
Model: "classification"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_2 (InputLayer)      │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ functional_3 (Functional)       │ (None, 256, 128)       │     3,972,352 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_max_pooling1d            │ (None, 128)            │             0 │
│ (GlobalMaxPooling1D)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 64)             │         8,256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 1)              │            65 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 3,980,673 (15.19 MB)
 Trainable params: 8,321 (32.50 KB)
 Non-trainable params: 3,972,352 (15.15 MB)

8/8 ━━━━━━━━━━━━━━━━━━━━ 2s 288ms/step - accuracy: 0.5366 - loss: 0.7073 - val_accuracy: 0.4920 - val_loss: 0.6975

8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 650ms/step - accuracy: 0.5565 - loss: 0.7210 - val_accuracy: 0.5160 - val_loss: 0.6895

<keras.src.callbacks.history.History at 0x7a0e5fd9bf50>

Create an end-to-end model and evaluate it

When you want to deploy a model, it's best if it already includes its preprocessing pipeline, so that you don't have to reimplement the preprocessing logic in your production environment. Let's create an end-to-end model that incorporates the TextVectorization step into its evaluate method, and evaluate it by passing raw strings as input.

# We create a custom Model to override the evaluate method so
# that it first pre-processes the text data
class ModelEndtoEnd(keras.Model):

    def evaluate(self, inputs):
        features = encode(inputs.review.values)
        labels = inputs.sentiment.values
        test_classifier_ds = (
            tf.data.Dataset.from_tensor_slices((features, labels))
            .shuffle(1000)
            .batch(config.BATCH_SIZE)
        )
        return super().evaluate(test_classifier_ds)

    # Build the model
    def build(self, input_shape):
        self.built = True


def get_end_to_end(model):
    inputs = classifer_model.inputs[0]
    outputs = classifer_model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    end_to_end_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


end_to_end_classification_model = get_end_to_end(classifer_model)
# Pass raw text dataframe to the model
end_to_end_classification_model.evaluate(test_raw_classifier_ds)

8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 166ms/step - accuracy: 0.5285 - loss: 0.6886

[0.6894814372062683, 0.515999972820282]
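
The end-to-end wrapper above only overrides evaluate. As a final illustrative sketch (the raw_reviews strings are made up; it reuses the encode helper and the fine-tuned classifier defined earlier), raw strings can also be scored directly:

# Illustrative sketch: score raw review strings with the fine-tuned
# classifier by reusing the encode() helper defined earlier.
raw_reviews = [
    "I have watched this movie and it was awesome",
    "this was a complete waste of time",
]
probs = classifer_model.predict(encode(raw_reviews))
# Per get_data_from_text_files, label 0 = positive folder and 1 = negative
# folder, so an output close to 1 means the classifier leans negative.
for text, p in zip(raw_reviews, probs[:, 0]):
    print(f"{p:.3f}  {text}")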