Author: Ankur Singh
Date created: 2020/09/18
Last modified: 2024/03/15
Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.
Masked Language Modeling is a fill-in-the-blank task, where a model uses the context words surrounding a mask token to try to predict what the masked word should be.
For an input that contains one or more mask tokens, the model will generate the most likely substitution for each.
Example:
Input: "I have watched this [MASK] and it was awesome."
Output: "I have watched this movie and it was awesome."
Masked language modeling is a great way to train a language model in a self-supervised setting (without human-annotated labels). Such a model can then be fine-tuned to accomplish various supervised NLP tasks.
This example teaches you how to build a BERT model from scratch, train it with the masked language modeling task, and then fine-tune this model on a sentiment classification task.
We will use the Keras TextVectorization and MultiHeadAttention layers to create a BERT Transformer-Encoder network architecture.
Note: This example should be run with tf-nightly. You can install it via pip install tf-nightly.
import os
os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow
import keras_hub
import keras
from keras import layers
from keras.layers import TextVectorization
from dataclasses import dataclass
import pandas as pd
import numpy as np
import glob
import re
from pprint import pprint
@dataclass
class Config:
    MAX_LEN = 256
    BATCH_SIZE = 32
    LR = 0.001
    VOCAB_SIZE = 30000
    EMBED_DIM = 128
    NUM_HEAD = 8  # used in bert model
    FF_DIM = 128  # used in bert model
    NUM_LAYERS = 1


config = Config()
We will first download the IMDB data and load it into a Pandas dataframe.
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz
def get_text_list_from_files(files):
    text_list = []
    for name in files:
        with open(name) as f:
            for line in f:
                text_list.append(line)
    return text_list


def get_data_from_text_files(folder_name):
    pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
    pos_texts = get_text_list_from_files(pos_files)
    neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
    neg_texts = get_text_list_from_files(neg_files)
    df = pd.DataFrame(
        {
            "review": pos_texts + neg_texts,
            # Note: in this example positive reviews get label 0 and negative reviews label 1.
            "sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),
        }
    )
    df = df.sample(len(df)).reset_index(drop=True)
    return df
train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")
all_data = pd.concat([train_df, test_df], ignore_index=True)
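As a quick, optional check (not part of the original example), you can peek at the combined dataframe before moving on; the snippet below only assumes the all_data frame built above:
# Optional sanity check: inspect the combined dataframe built above.
print(all_data.shape)  # 25,000 train + 25,000 test reviews -> (50000, 2)
print(all_data.sentiment.value_counts())  # label counts (0 = positive, 1 = negative here)
print(all_data.review.iloc[0][:200])  # first 200 characters of one review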
We will use the TextVectorization layer to vectorize the text into integer token ids. It transforms a batch of strings into either a sequence of token indices (one sample = 1D array of integer token indices, in order) or a dense representation (one sample = 1D array of float values encoding an unordered set of tokens).
Below, we define 3 preprocessing functions.
1. The get_vectorize_layer function builds the TextVectorization layer.
2. The encode function encodes raw text into integer token ids.
3. The get_masked_input_and_labels function masks input token ids. It masks 15% of all input tokens in each sequence at random.

# For data pre-processing and tf.data.Dataset
import tensorflow as tf
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape("!#$%&'()*+,-./:;<=>?@\^_`{|}~"), ""
    )


def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):
    """Build Text vectorization layer

    Args:
      texts (list): List of string i.e input texts
      vocab_size (int): vocab size
      max_seq (int): Maximum sequence length.
      special_tokens (list, optional): List of special tokens. Defaults to ['[MASK]'].

    Returns:
        layers.Layer: Return TextVectorization Keras Layer
    """
    vectorize_layer = TextVectorization(
        max_tokens=vocab_size,
        output_mode="int",
        standardize=custom_standardization,
        output_sequence_length=max_seq,
    )
    vectorize_layer.adapt(texts)

    # Insert mask token in vocabulary
    vocab = vectorize_layer.get_vocabulary()
    vocab = vocab[2 : vocab_size - len(special_tokens)] + ["[mask]"]
    vectorize_layer.set_vocabulary(vocab)
    return vectorize_layer
vectorize_layer = get_vectorize_layer(
    all_data.review.values.tolist(),
    config.VOCAB_SIZE,
    config.MAX_LEN,
    special_tokens=["[mask]"],
)

# Get mask token id for masked language model
mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]
def encode(texts):
    encoded_texts = vectorize_layer(texts)
    return encoded_texts.numpy()


def get_masked_input_and_labels(encoded_texts):
    # 15% BERT masking
    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
    # Do not mask special tokens
    inp_mask[encoded_texts <= 2] = False
    # Set targets to -1 by default, it means ignore
    labels = -1 * np.ones(encoded_texts.shape, dtype=int)
    # Set labels for masked tokens
    labels[inp_mask] = encoded_texts[inp_mask]

    # Prepare input
    encoded_texts_masked = np.copy(encoded_texts)
    # Set input to [MASK] which is the last token for the 90% of tokens
    # This means leaving 10% unchanged
    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
    encoded_texts_masked[inp_mask_2mask] = (
        mask_token_id  # mask token is the last in the dict
    )

    # Set 10% to a random token
    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
    encoded_texts_masked[inp_mask_2random] = np.random.randint(
        3, mask_token_id, inp_mask_2random.sum()
    )

    # Prepare sample_weights to pass to .fit() method
    sample_weights = np.ones(labels.shape)
    sample_weights[labels == -1] = 0

    # y_labels would be same as encoded_texts i.e input tokens
    y_labels = np.copy(encoded_texts)

    return encoded_texts_masked, y_labels, sample_weights
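If you want to see what the masking step produces, the optional snippet below (not part of the original example) pushes a single made-up review through encode and get_masked_input_and_labels and prints the pieces that will later feed model.fit:
# Optional illustration: mask one encoded review and inspect the outputs.
demo_encoded = encode(["this movie was a fantastic surprise from start to finish"])
demo_masked, demo_labels, demo_weights = get_masked_input_and_labels(demo_encoded)
print(demo_masked.shape)  # (1, config.MAX_LEN) array of token ids, some replaced by [mask]
print(int(demo_weights.sum()), "position(s) will contribute to the MLM loss")
print(demo_masked[0][:12])  # first few ids; masked positions hold mask_token_id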
# We have 25000 examples for training
x_train = encode(train_df.review.values)  # encode reviews with vectorizer
y_train = train_df.sentiment.values
train_classifier_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(config.BATCH_SIZE)
)

# We have 25000 examples for testing
x_test = encode(test_df.review.values)
y_test = test_df.sentiment.values
test_classifier_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
    config.BATCH_SIZE
)

# Dataset for end to end model input (will be used at the end)
test_raw_classifier_ds = test_df

# Prepare data for masked language model
x_all_review = encode(all_data.review.values)
x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
    x_all_review
)

mlm_ds = tf.data.Dataset.from_tensor_slices(
    (x_masked_train, y_masked_labels, sample_weights)
)
mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)
We will create a BERT-like pretraining model architecture using the MultiHeadAttention layer. It will take token ids as inputs (including masked tokens) and it will predict the correct ids for the masked input tokens.
def bert_module(query, key, value, i):
    # Multi headed self-attention
    attention_output = layers.MultiHeadAttention(
        num_heads=config.NUM_HEAD,
        key_dim=config.EMBED_DIM // config.NUM_HEAD,
        name="encoder_{}_multiheadattention".format(i),
    )(query, key, value)
    attention_output = layers.Dropout(0.1, name="encoder_{}_att_dropout".format(i))(
        attention_output
    )
    attention_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}_att_layernormalization".format(i)
    )(query + attention_output)

    # Feed-forward layer
    ffn = keras.Sequential(
        [
            layers.Dense(config.FF_DIM, activation="relu"),
            layers.Dense(config.EMBED_DIM),
        ],
        name="encoder_{}_ffn".format(i),
    )
    ffn_output = ffn(attention_output)
    ffn_output = layers.Dropout(0.1, name="encoder_{}_ffn_dropout".format(i))(
        ffn_output
    )
    sequence_output = layers.LayerNormalization(
        epsilon=1e-6, name="encoder_{}_ffn_layernormalization".format(i)
    )(attention_output + ffn_output)
    return sequence_output
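Since bert_module is just a function that stacks Keras layers, you can sanity-check it in isolation. This optional snippet (not part of the original example) confirms that the encoder block preserves the (batch, MAX_LEN, EMBED_DIM) shape:
# Optional shape check on a random batch of "embeddings".
dummy = keras.random.normal((2, config.MAX_LEN, config.EMBED_DIM))
print(bert_module(dummy, dummy, dummy, i=99).shape)  # -> (2, 256, 128)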
loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction=None)
loss_tracker = keras.metrics.Mean(name="loss")


class MaskedLanguageModel(keras.Model):

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        loss = loss_fn(y, y_pred, sample_weight)
        loss_tracker.update_state(loss, sample_weight=sample_weight)
        return keras.ops.sum(loss)

    def compute_metrics(self, x, y, y_pred, sample_weight):
        # Return a dict mapping metric names to current value
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker]
def create_masked_language_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype="int64")

    word_embeddings = layers.Embedding(
        config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
    )(inputs)
    position_embeddings = keras_hub.layers.PositionEmbedding(
        sequence_length=config.MAX_LEN
    )(word_embeddings)
    embeddings = word_embeddings + position_embeddings

    encoder_output = embeddings
    for i in range(config.NUM_LAYERS):
        encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)

    mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
        encoder_output
    )
    mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")

    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    mlm_model.compile(optimizer=optimizer)
    return mlm_model
id2token = dict(enumerate(vectorize_layer.get_vocabulary()))
token2id = {y: x for x, y in id2token.items()}
class MaskedTextGenerator(keras.callbacks.Callback):
    def __init__(self, sample_tokens, top_k=5):
        self.sample_tokens = sample_tokens
        self.k = top_k

    def decode(self, tokens):
        return " ".join([id2token[t] for t in tokens if t != 0])

    def convert_ids_to_tokens(self, id):
        return id2token[id]

    def on_epoch_end(self, epoch, logs=None):
        prediction = self.model.predict(self.sample_tokens)

        masked_index = np.where(self.sample_tokens == mask_token_id)
        masked_index = masked_index[1]
        mask_prediction = prediction[0][masked_index]

        top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
        values = mask_prediction[0][top_indices]

        for i in range(len(top_indices)):
            p = top_indices[i]
            v = values[i]
            tokens = np.copy(sample_tokens[0])
            tokens[masked_index[0]] = p
            result = {
                "input_text": self.decode(sample_tokens[0].numpy()),
                "prediction": self.decode(tokens),
                "probability": v,
                "predicted mask token": self.convert_ids_to_tokens(p),
            }
            pprint(result)
sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
generator_callback = MaskedTextGenerator(sample_tokens.numpy())
bert_masked_model = create_masked_language_bert_model()
bert_masked_model.summary()
Model: "masked_bert_model"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 256)       │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ word_embedding      │ (None, 256, 128)  │  3,840,000 │ input_layer[0][0] │
│ (Embedding)         │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ position_embedding  │ (None, 256, 128)  │     32,768 │ word_embedding[0… │
│ (PositionEmbedding) │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 256, 128)  │          0 │ word_embedding[0… │
│                     │                   │            │ position_embeddi… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_multihea… │ (None, 256, 128)  │     66,048 │ add[0][0],        │
│ (MultiHeadAttentio… │                   │            │ add[0][0],        │
│                     │                   │            │ add[0][0]         │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_att_drop… │ (None, 256, 128)  │          0 │ encoder_0_multih… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 256, 128)  │          0 │ add[0][0],        │
│                     │                   │            │ encoder_0_att_dr… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_att_laye… │ (None, 256, 128)  │        256 │ add_1[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_ffn       │ (None, 256, 128)  │     33,024 │ encoder_0_att_la… │
│ (Sequential)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_ffn_drop… │ (None, 256, 128)  │          0 │ encoder_0_ffn[0]… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 256, 128)  │          0 │ encoder_0_att_la… │
│                     │                   │            │ encoder_0_ffn_dr… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ encoder_0_ffn_laye… │ (None, 256, 128)  │        256 │ add_2[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ mlm_cls (Dense)     │ (None, 256,       │  3,870,000 │ encoder_0_ffn_la… │
│                     │ 30000)            │            │                   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 7,842,352 (29.92 MB)
Trainable params: 7,842,352 (29.92 MB)
Non-trainable params: 0 (0.00 B)
bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])
bert_masked_model.save("bert_mlm_imdb.keras")
{'input_text': 'i have watched this [mask] and it was awesome',
'predicted mask token': 'a',
'prediction': 'i have watched this a and it was awesome',
'probability': 0.0013674975}
{'input_text': 'i have watched this [mask] and it was awesome',
'predicted mask token': 'i',
'prediction': 'i have watched this i and it was awesome',
'probability': 0.0012694978}
{'input_text': 'i have watched this [mask] and it was awesome',
'predicted mask token': 'is',
'prediction': 'i have watched this is and it was awesome',
'probability': 0.0012668626}
{'input_text': 'i have watched this [mask] and it was awesome',
'predicted mask token': 'to',
'prediction': 'i have watched this to and it was awesome',
'probability': 0.0012651902}
{'input_text': 'i have watched this [mask] and it was awesome',
'predicted mask token': 'of',
'prediction': 'i have watched this of and it was awesome',
'probability': 0.0011966776}
16/16 ━━━━━━━━━━━━━━━━━━━━ 261s 17s/step - loss: 9.9656
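Before moving on to fine-tuning, you can also query the pretrained MLM directly on a new masked sentence. The following is a minimal sketch (not part of the original example) that reuses vectorize_layer, mask_token_id and id2token from above:
# Optional: ask the trained MLM for the most likely fillers of a new [mask] token.
query = vectorize_layer(["this [mask] was full of great performances"]).numpy()
probs = bert_masked_model.predict(query)  # shape: (1, MAX_LEN, VOCAB_SIZE)
mask_position = np.where(query == mask_token_id)[1][0]
top5 = probs[0, mask_position].argsort()[-5:][::-1]
print([id2token[int(i)] for i in top5])  # five most likely replacement tokens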
We will fine-tune our self-supervised model on a downstream task of sentiment classification. To do this, let's create a classifier by adding a pooling layer and a Dense layer on top of the pretrained BERT features.
# Load pretrained bert model
mlm_model = keras.models.load_model(
    "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
)
pretrained_bert_model = keras.Model(
    mlm_model.input, mlm_model.get_layer("encoder_0_ffn_layernormalization").output
)

# Freeze it
pretrained_bert_model.trainable = False


def create_classifier_bert_model():
    inputs = layers.Input((config.MAX_LEN,), dtype="int64")
    sequence_output = pretrained_bert_model(inputs)
    pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
    hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
    outputs = layers.Dense(1, activation="sigmoid")(hidden_layer)
    classifer_model = keras.Model(inputs, outputs, name="classification")
    optimizer = keras.optimizers.Adam()
    classifer_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return classifer_model
classifer_model = create_classifier_bert_model()
classifer_model.summary()

# Train the classifier with frozen BERT stage
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)

# Unfreeze the BERT model for fine-tuning
pretrained_bert_model.trainable = True
optimizer = keras.optimizers.Adam()
classifer_model.compile(
    optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
)
classifer_model.fit(
    train_classifier_ds,
    epochs=5,
    validation_data=test_classifier_ds,
)
Model: "classification"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_2 (InputLayer)      │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ functional_3 (Functional)       │ (None, 256, 128)       │     3,972,352 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_max_pooling1d            │ (None, 128)            │             0 │
│ (GlobalMaxPooling1D)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 64)             │         8,256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 1)              │            65 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 3,980,673 (15.19 MB)
Trainable params: 8,321 (32.50 KB)
Non-trainable params: 3,972,352 (15.15 MB)
8/8 ━━━━━━━━━━━━━━━━━━━━ 2s 288ms/step - accuracy: 0.5366 - loss: 0.7073 - val_accuracy: 0.4920 - val_loss: 0.6975
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 650ms/step - accuracy: 0.5565 - loss: 0.7210 - val_accuracy: 0.5160 - val_loss: 0.6895
<keras.src.callbacks.history.History at 0x7a0e5fd9bf50>
When you want to deploy a model, it's best if it already includes its preprocessing pipeline, so that you don't have to reimplement the preprocessing logic in your production environment. Let's create an end-to-end model that incorporates the TextVectorization layer inside its evaluate method, and let's evaluate it. We will pass raw strings as input.
# We create a custom Model to override the evaluate method so
# that it first pre-process text data
class ModelEndtoEnd(keras.Model):

    def evaluate(self, inputs):
        features = encode(inputs.review.values)
        labels = inputs.sentiment.values
        test_classifier_ds = (
            tf.data.Dataset.from_tensor_slices((features, labels))
            .shuffle(1000)
            .batch(config.BATCH_SIZE)
        )
        return super().evaluate(test_classifier_ds)

    # Build the model
    def build(self, input_shape):
        self.built = True


def get_end_to_end(model):
    inputs = classifer_model.inputs[0]
    outputs = classifer_model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    end_to_end_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model
end_to_end_classification_model = get_end_to_end(classifer_model)
# Pass raw text dataframe to the model
end_to_end_classification_model.evaluate(test_raw_classifier_ds)
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 166ms/step - accuracy: 0.5285 - loss: 0.6886
[0.6894814372062683, 0.515999972820282]
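Note that the end-to-end wrapper above only overrides evaluate. If you also want class predictions for raw review strings, a minimal sketch along the same lines (reusing the encode helper and the trained classifer_model; the predict_sentiment name is hypothetical) could look like this:
# Hypothetical helper: classify raw review strings with the fine-tuned classifier.
def predict_sentiment(raw_texts):
    token_ids = encode(list(raw_texts))  # reuse the TextVectorization preprocessing
    probs = classifer_model.predict(token_ids)  # sigmoid outputs in [0, 1]
    # Remember: in this example label 0 = positive and label 1 = negative.
    return ["negative" if p >= 0.5 else "positive" for p in probs[:, 0]]


print(predict_sentiment(["I have watched this movie and it was awesome"]))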