
Text classification using FNet

Author: Abheesht Sharma
Date created: 2022/06/01
Last modified: 2022/12/21
Description: Text classification on the IMDb dataset using the keras_hub.layers.FNetEncoder layer.

This example uses Keras 3.



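Introduction

In this example, we demonstrate that FNet can achieve results comparable to a vanilla Transformer model on a text classification task. We use the IMDb dataset, a collection of movie reviews labelled as either positive or negative (sentiment analysis).

To build the tokenizer, the model, and so on, we use components from KerasHub. KerasHub makes life easier for people who want to build NLP pipelines! :)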

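Model

Transformer-based language models (LMs) such as BERT, RoBERTa, and XLNet have demonstrated the effectiveness of self-attention for computing rich embeddings of input text. However, self-attention is an expensive operation: its time complexity is O(n^2), where n is the number of tokens in the input. Hence, there have been efforts to reduce the time complexity of self-attention and improve performance without compromising the quality of the results.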

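In 2021, a paper titled FNet: Mixing Tokens with Fourier Transforms replaced the self-attention layer in BERT with a simple Fourier Transform layer for "token mixing". This resulted in comparable accuracy and a speed-up during training. In particular, a couple of points from the paper stand out:

  • The authors claim that FNet is 80% faster than BERT on GPUs and 70% faster on TPUs. The reason for this speed-up is two-fold: a) the Fourier Transform layer is unparametrized, i.e., it has no parameters, and b) the authors use the Fast Fourier Transform (FFT), which reduces the time complexity from O(n^2) (in the case of self-attention) to O(n log n).
  • FNet manages to achieve 92-97% of BERT's accuracy on the GLUE benchmark.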
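To make the token-mixing idea concrete, here is a minimal NumPy sketch of the parameter-free Fourier sublayer as described in the paper: a 1D FFT along the hidden dimension, a 1D FFT along the sequence dimension, keeping only the real part. `fourier_token_mixing` is a hypothetical helper name; it is not the KerasHub implementation, and it leaves out the residual connection, layer normalization, and feed-forward block that surround the mixing step in an FNet encoder.

```python
import numpy as np


def fourier_token_mixing(x):
    """Parameter-free FNet-style mixing: Re(FFT_seq(FFT_hidden(x))).

    x: float array of shape (seq_len, hidden_dim).
    """
    # np.fft.fft2 applies a 1D FFT along each of the last two axes,
    # i.e. along the hidden dimension and along the sequence dimension.
    return np.real(np.fft.fft2(x))


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))  # 8 tokens, hidden size 4
y = fourier_token_mixing(x)
print(y.shape)  # (8, 4): same shape as the input, and no weights were involved

# Applying the two 1D FFTs one after the other gives the same result.
y_two_step = np.real(np.fft.fft(np.fft.fft(x, axis=-1), axis=0))
print(np.allclose(y, y_two_step))

# Changing only the first token changes the output at every position,
# which is what "token mixing" means.
x2 = x.copy()
x2[0] += 1.0
changed_rows = np.any(~np.isclose(fourier_token_mixing(x2), y), axis=1)
print(changed_rows.all())
```

Because the transform itself has no weights, all the learning in an FNet encoder happens in the feed-forward and normalization layers around it.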

Setup

Before we start with the implementation, let's import all the necessary packages.

!pip install -q --upgrade keras-hub
!pip install -q --upgrade keras  # Upgrade to Keras 3.
import keras_hub
import keras
import tensorflow as tf
import os

keras.utils.set_random_seed(42)

Let's also define our hyperparameters.

BATCH_SIZE = 64
EPOCHS = 3
MAX_SEQUENCE_LENGTH = 512
VOCAB_SIZE = 15000

EMBED_DIM = 128
INTERMEDIATE_DIM = 512

Loading the dataset

First, let's download the IMDB dataset and extract it.

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xzf aclImdb_v1.tar.gz
--2023-11-22 17:59:33--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’
aclImdb_v1.tar.gz   100%[===================>]  80.23M  93.3MB/s    in 0.9s    
2023-11-22 17:59:34 (93.3 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]

The samples are stored as text files. Let's inspect the structure of the directory.

print(os.listdir("./aclImdb"))
print(os.listdir("./aclImdb/train"))
print(os.listdir("./aclImdb/test"))
['README', 'imdb.vocab', 'imdbEr.txt', 'train', 'test']
['neg', 'unsup', 'pos', 'unsupBow.feat', 'urls_unsup.txt', 'urls_neg.txt', 'urls_pos.txt', 'labeledBow.feat']
['neg', 'pos', 'urls_neg.txt', 'urls_pos.txt', 'labeledBow.feat']

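The directory contains two sub-directories: train and test. Each sub-directory in turn contains two folders, pos and neg, for positive and negative reviews, respectively. Before we load the dataset, let's delete the ./aclImdb/train/unsup folder, since it contains unlabelled samples.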

!rm -rf aclImdb/train/unsup

We'll use the keras.utils.text_dataset_from_directory utility to generate our labelled tf.data.Dataset datasets from the text files.
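With the default `labels="inferred"`, the utility derives integer labels from the sub-directory names sorted alphanumerically, so `neg` becomes 0 and `pos` becomes 1; that is how the 0/1 labels in the samples printed later should be read. The pure-Python sketch below imitates only that labelling rule on a throwaway directory. `infer_labels` is a hypothetical helper, and it ignores the batching, shuffling, and validation splitting that the real utility performs.

```python
import os
import tempfile


def infer_labels(directory):
    """Hypothetical helper: label .txt files by their sorted sub-directory name."""
    class_names = sorted(
        name
        for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    samples = []
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(directory, class_name)
        for filename in sorted(os.listdir(class_dir)):
            if filename.endswith(".txt"):
                samples.append((os.path.join(class_dir, filename), label))
    return class_names, samples


with tempfile.TemporaryDirectory() as root:
    # Mimic the layout of aclImdb/train after removing "unsup".
    for class_name, text in [("pos", "great film"), ("neg", "dull film")]:
        os.makedirs(os.path.join(root, class_name))
        with open(os.path.join(root, class_name, "0_1.txt"), "w") as f:
            f.write(text)
    # Top-level files (like urls_pos.txt) are not class folders.
    open(os.path.join(root, "urls_pos.txt"), "w").close()
    class_names, samples = infer_labels(root)

print(class_names)  # ['neg', 'pos']
print([label for _, label in samples])  # [0, 1]
```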

train_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    subset="training",
    seed=42,
)
val_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    subset="validation",
    seed=42,
)
test_ds = keras.utils.text_dataset_from_directory("aclImdb/test", batch_size=BATCH_SIZE)
Found 25000 files belonging to 2 classes.
Using 20000 files for training.
Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
Found 25000 files belonging to 2 classes.
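These counts also determine how many batches make up an epoch, which is where the 313/313 and 391/391 progress bars in the training and evaluation logs come from. A quick check of the arithmetic, assuming `validation_split=0.2` holds out `int(0.2 * 25000)` files and the last batch of each split may be partial:

```python
import math

NUM_TRAIN_FILES = 25000  # "Found 25000 files" for aclImdb/train
NUM_TEST_FILES = 25000  # "Found 25000 files" for aclImdb/test
VALIDATION_SPLIT = 0.2
BATCH_SIZE = 64

num_val = int(VALIDATION_SPLIT * NUM_TRAIN_FILES)
num_train = NUM_TRAIN_FILES - num_val

# A partial final batch still counts as a step, so round up.
train_steps = math.ceil(num_train / BATCH_SIZE)
val_steps = math.ceil(num_val / BATCH_SIZE)
test_steps = math.ceil(NUM_TEST_FILES / BATCH_SIZE)

print(num_train, num_val)  # 20000 5000
print(train_steps, val_steps, test_steps)  # 313 79 391
```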

We will now convert the text to lowercase.

train_ds = train_ds.map(lambda x, y: (tf.strings.lower(x), y))
val_ds = val_ds.map(lambda x, y: (tf.strings.lower(x), y))
test_ds = test_ds.map(lambda x, y: (tf.strings.lower(x), y))

Let's print a few samples.

for text_batch, label_batch in train_ds.take(1):
    for i in range(3):
        print(text_batch.numpy()[i])
        print(label_batch.numpy()[i])
b'an illegal immigrant resists the social support system causing dire consequences for many. well filmed and acted even though the story is a bit forced, yet the slow pacing really sets off the conclusion. the feeling of being lost in the big city is effectively conveyed. the little person lost in the big society is something to which we can all relate, but i cannot endorse going out of your way to see this movie.'
0
b"to get in touch with the beauty of this film pay close attention to the sound track, not only the music, but the way all sounds help to weave the imagery. how beautifully the opening scene leading to the expulsion of gino establishes the theme of moral ambiguity! note the way music introduces the characters as we are led inside giovanna's marriage. don't expect to find much here of the political life of italy in 1943. that's not what this is about. on the other hand, if you are susceptible to the music of images and sounds, you will be led into a word that reaches beyond neo-realism. by the end of the film we there are moments antonioni-like landscape that has more to do with the inner life of the characters than with real places. this is one of my favorite visconti films."
1
b'"hollywood hotel" has relationships to many films like "ella cinders" and "merton of the movies" about someone winning a contest including a contract to make films in hollywood, only to find the road to stardom either paved with pitfalls or non-existent. in fact, as i was watching it tonight, on turner classic movies, i was considering whether or not the authors of the later musical classic "singing in the rain" may have taken some of their ideas from "hollywood hotel", most notably a temperamental leading lady star in a movie studio and a conclusion concerning one person singing a film score while another person got the credit by mouthing along on screen.<br /><br />"hollywood hotel" is a fascinating example of movie making in the 1930s. among the supporting players is louella parsons, playing herself (and, despite some negative comments i\'ve seen, she has a very ingratiating personality on screen and a natural command of her lines). she is not the only real person in the script. make-up specialist perc westmore briefly appears as himself to try to make one character resemble another.<br /><br />this film also was one of the first in the career of young mr. ronald reagan, playing a radio interviewer at a movie premiere. reagan actually does quite nicely in his brief scenes - particularly when he realizes that nobody dick powell is about to take over the microphone when it should be used with more important people.<br /><br />dick powell has won a hollywood contract in a contest, and is leaving his job as a saxophonist in benny goodman\'s band. the beginning of this film, by the way, is quite impressive, as the band drives in a parade of trucks to give a proper goodbye to powell. they end up singing "hooray for hollywood". the interesting thing about this wonderful number is that a lyric has been left out on purpose. throughout the johnny mercer lyrics are references to such hollywood as max factor the make-up king, rin tin tin, and even a hint of tarzan. 
but the original song lyric referred to looking like tyrone power. obviously jack warner and his brothers were not going to advertise the leading man of 20th century fox, and the name donald duck was substituted. in any event the number showed the singers and instrumentalists of goodman\'s orchestra at their best. so did a later five minute section of the film, where the band is rehearsing.<br /><br />powell leaves the band and his girl friend (frances langford) and goes to hollywood, only to find he is a contract player (most likely for musicals involving saxophonists). he is met by allen joslyn, the publicist of the studio (the owner is grant mitchell). joslyn is not a bad fellow, but he is busy and he tends to slough off people unless it is necessary to speak to them. he parks powell at a room at the hollywood hotel, which is also where the studio\'s temperamental star (lola lane) lives with her father (hugh herbert), her sister (mabel todd), and her sensible if cynical assistant (glenda farrell). lane is like jean hagen in "singing in the rain", except her speaking voice is good. her version of "dan lockwood" is one "alexander dupre" (alan mowbray, scene stealing with ease several times). the only difference is that mowbray is not a nice guy like gene kelly was, and lane (when not wrapped up in her ego) is fully aware of it. having a fit on being by-passed for an out-of-the ordinary role she wanted, she refuses to attend the premiere of her latest film. joslyn finds a double for her (lola\'s real life sister rosemary lane), and rosemary is made up to play the star at the premiere and the follow-up party. but she attends with powell (joslyn wanting someone who doesn\'t know the real lola). this leads to powell knocking down mowbray when the latter makes a pest of himself. 
but otherwise the evening is a success, and when the two are together they start finding each other attractive.<br /><br />the complications deal with lola coming back and slapping powell in the face, after mowbray complains he was attacked by powell ("and his gang of hoodlums"). powell\'s contract is bought out. working with photographer turned agent ted healey (actually not too bad in this film - he even tries to do a jolson imitation at one point), the two try to find work, ending up as employees at a hamburger stand run by bad tempered edgar kennedy (the number of broken dishes and singing customers in the restaurant give edgar plenty of time to do his slow burns with gusto). eventually powell gets a "break" by being hired to be dupre\'s singing voice in a rip-off of "gone with the wind". this leads to the final section of the film, when rosemary lane, herbert, and healey help give powell his chance to show it\'s his voice, not mowbrays.<br /><br />it\'s quite a cute and appealing film even now. the worst aspects are due to it\'s time. several jokes concerning african-americans are no longer tolerable (while trying to photograph powell as he arrives in hollywood, healey accidentally photographs a porter, and mentions to joslyn to watch out, powell photographs too darkly - get the point?). also a bit with curt bois as a fashion designer for lola lane, who is (shall we say) too high strung is not very tolerable either. herbert\'s "hoo-hoo"ing is a bit much (too much of the time) but it was really popular in 1937. and an incident where healey nearly gets into a brawl at the premiere (this was one of his last films) reminds people of the tragic, still mysterious end of the comedian in december 1937. but most of the film is quite good, and won\'t disappoint the viewer in 2008.'
1

Tokenizing the data

We'll use the keras_hub.tokenizers.WordPieceTokenizer layer to tokenize the text. keras_hub.tokenizers.WordPieceTokenizer takes a WordPiece vocabulary and has functions for tokenizing text and for detokenizing sequences of tokens.
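For intuition about what a WordPiece tokenizer does with a trained vocabulary, here is a pure-Python sketch of greedy longest-match-first splitting of a single word, the scheme popularized by BERT. `wordpiece_tokenize` and `toy_vocab` are hypothetical. The `##` prefix for word-internal pieces is assumed to mirror the default `suffix_indicator` of `keras_hub.tokenizers.WordPieceTokenizer`, and the real layer also splits on whitespace and punctuation before this step.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", suffix_indicator="##"):
    """Greedy longest-match-first WordPiece split of one pre-split word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = suffix_indicator + piece  # mark word-internal pieces
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            # No prefix of the remainder is in the vocabulary.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"[PAD]", "[UNK]", "un", "##red", "##eem", "##able", "film"}
print(wordpiece_tokenize("unredeemable", toy_vocab))  # ['un', '##red', '##eem', '##able']
print(wordpiece_tokenize("film", toy_vocab))  # ['film']
print(wordpiece_tokenize("xyz", toy_vocab))  # ['[UNK]']
```

A rare word is thus still represented by known subwords instead of collapsing to `[UNK]`, as long as its pieces are in the vocabulary.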

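Before we define the tokenizer, we first need to train it on our dataset. WordPiece is a subword tokenization algorithm; training it on a corpus gives us a vocabulary of subwords. A subword tokenizer is a compromise between word tokenizers (which need very large vocabularies to cover the input words well) and character tokenizers (characters don't really encode meaning the way words do). Luckily, KerasHub makes it very simple to train WordPiece on a corpus with the keras_hub.tokenizers.compute_word_piece_vocabulary utility.

Note: The official implementation of FNet uses a SentencePiece tokenizer.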

def train_word_piece(ds, vocab_size, reserved_tokens):
    # Drop the labels and keep only the review text.
    word_piece_ds = ds.unbatch().map(lambda x, y: x)
    # Learn a WordPiece vocabulary from the text, 1000 reviews at a time.
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
        word_piece_ds.batch(1000).prefetch(2),
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )
    return vocab

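Every vocabulary has a few special, reserved tokens. We have two such tokens:

  • "[PAD]" - Padding token. Padding tokens are appended to the input sequence when its length is shorter than the maximum sequence length.
  • "[UNK]" - Unknown token, used for words that cannot be represented with the vocabulary.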
reserved_tokens = ["[PAD]", "[UNK]"]
train_sentences = [element[0] for element in train_ds]
vocab = train_word_piece(train_ds, VOCAB_SIZE, reserved_tokens)
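The order of the reserved tokens matters: they are assumed to occupy the first vocabulary slots, so `[PAD]` gets ID 0 (the trailing zeros in the tokenized sample printed later, which detokenize to `[PAD]`, are consistent with this). ID 0 is also what `mask_zero=True` in the embedding layer treats as padding. The sketch below illustrates the ID assignment with a hypothetical hand-written vocabulary, not the one produced by `compute_word_piece_vocabulary`.

```python
reserved_tokens = ["[PAD]", "[UNK]"]
learned_tokens = ["the", "a", "film", "##s"]  # hypothetical learned subwords

# Reserved tokens come first, so their IDs are 0 and 1.
vocab = reserved_tokens + learned_tokens
token_to_id = {token: i for i, token in enumerate(vocab)}
print(token_to_id["[PAD]"], token_to_id["[UNK]"])  # 0 1


def lookup(tokens):
    # Out-of-vocabulary tokens fall back to the [UNK] id.
    return [token_to_id.get(t, token_to_id["[UNK]"]) for t in tokens]


print(lookup(["the", "film", "##s", "zzz"]))  # [2, 4, 5, 1]
```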

Let's see some tokens!

print("Tokens: ", vocab[100:110])
Tokens:  ['à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é']

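Now, let's define the tokenizer. We configure it with the vocabulary trained above and set a maximum sequence length so that every sequence ends up with the same length: sequences shorter than the maximum are padded, and longer ones are truncated.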
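A minimal sketch of the fixed-length behaviour, assuming right-padding with ID 0 (matching the trailing zeros in the tokenized sample printed later) and truncation that keeps the leading tokens. `pad_or_truncate` is a hypothetical helper, not part of KerasHub.

```python
def pad_or_truncate(token_ids, sequence_length, pad_id=0):
    """Return exactly `sequence_length` ids: truncate long inputs, right-pad short ones."""
    token_ids = list(token_ids[:sequence_length])
    return token_ids + [pad_id] * (sequence_length - len(token_ids))


print(pad_or_truncate([145, 576, 608], 5))  # [145, 576, 608, 0, 0]
print(pad_or_truncate([145, 576, 608, 228, 140, 58], 5))  # [145, 576, 608, 228, 140]
```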

tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    # The text was already lowercased above.
    lowercase=False,
    # Pad or truncate every sequence to this many tokens.
    sequence_length=MAX_SEQUENCE_LENGTH,
)

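Let's try and tokenize a sample from our dataset! To verify whether the text has been tokenized correctly, we can also detokenize the list of tokens back to text. Apart from the spaces inserted around punctuation and the trailing [PAD] tokens, the recovered text should match the input.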

input_sentence_ex = train_ds.take(1).get_single_element()[0][0]
input_tokens_ex = tokenizer(input_sentence_ex)

print("Sentence: ", input_sentence_ex)
print("Tokens: ", input_tokens_ex)
print("Recovered text after detokenizing: ", tokenizer.detokenize(input_tokens_ex))
Sentence:  tf.Tensor(b'this picture seemed way to slanted, it\'s almost as bad as the drum beating of the right wing kooks who say everything is rosy in iraq. it paints a picture so unredeemable that i can\'t help but wonder about it\'s legitimacy and bias. also it seemed to meander from being about the murderous carnage of our troops to the lack of health care in the states for ptsd. to me the subject matter seemed confused, it only cared about portraying the military in a bad light, as a) an organzation that uses mind control to turn ordinary peace loving civilians into baby killers and b) an organization that once having used and spent the bodies of it\'s soldiers then discards them to the despotic bureacracy of the v.a. this is a legitimate argument, but felt off topic for me, almost like a movie in and of itself. i felt that "the war tapes" and "blood of my brother" were much more fair and let the viewer draw some conclusions of their own rather than be beaten over the head with the film makers viewpoint. f-', shape=(), dtype=string)
Tokens:  [  145   576   608   228   140    58 13343    13   143     8    58   360
   148   209   148   137  9759  3681   139   137   344  3276    50 12092
   164   169   269   424   141    57  2093   292   144  5115    15   143
  7890    40   576   170  2970  2459  2412 10452   146    48   184     8
    59   478   152   733   177   143     8    58  4060  8069 13355   138
  8557    15   214   143   608   140   526  2121   171   247   177   137
  4726  7336   139   395  4985   140   137   711   139  3959   597   144
   137  1844   149    55  1175   288    15   140   203   137  1009   686
   608  1701    13   143   197  3979   177  2514   137  1442   144    40
   209   776    13   148    40    10   168 14198 13928   146  1260   470
  1300   140   604  2118  2836  1873  9991   217  1006  2318   138    41
    10   168  8469   146   422   400   480   138  1213   137  2541   139
   143     8    58  1487   227  4319 10720   229   140   137  6310  8532
   862    41  2215  6547 10768   139   137    61    15    40    15   145
   141    40  7738  4120    13   152   569   260  3297   149   203    13
   360   172    40   150   144   138   139   561    15    48   569   146
     3   137   466  6192     3   138     3   665   139   193   707     3
   204   207   185  1447   138   417   137   643  2731   182  8421   139
   199   342   385   206   161  3920   253   137   566   151   137   153
  1340  8845    15    45    14     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0]
Recovered text after detokenizing:  tf.Tensor(b'this picture seemed way to slanted , it \' s almost as bad as the drum beating of the right wing kooks who say everything is rosy in iraq . it paints a picture so unredeemable that i can \' t help but wonder about it \' s legitimacy and bias . also it seemed to meander from being about the murderous carnage of our troops to the lack of health care in the states for ptsd . to me the subject matter seemed confused , it only cared about portraying the military in a bad light , as a ) an organzation that uses mind control to turn ordinary peace loving civilians into baby killers and b ) an organization that once having used and spent the bodies of it \' s soldiers then discards them to the despotic bureacracy of the v . a . this is a legitimate argument , but felt off topic for me , almost like a movie in and of itself . i felt that " the war tapes " and " blood of my brother " were much more fair and let the viewer draw some conclusions of their own rather than be beaten over the head with the film makers viewpoint . 
f - [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', shape=(), dtype=string)

Formatting the dataset

Next, we'll format our datasets in the form that will be fed to the models. We need to tokenize the text.

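def format_dataset(sentence, label):
    # Map raw review text to fixed-length token ids.
    sentence = tokenizer(sentence)
    return ({"input_ids": sentence}, label)


def make_dataset(dataset):
    dataset = dataset.map(format_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache before shuffling so that every epoch is reshuffled;
    # caching after shuffle would replay the first epoch's order.
    return dataset.cache().shuffle(512).prefetch(16)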


train_ds = make_dataset(train_ds)
val_ds = make_dataset(val_ds)
test_ds = make_dataset(test_ds)

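Building the model

Now, let's move on to the exciting part: defining our model! We first need an embedding layer, i.e., a layer that maps every token in the input sequence to a vector. This embedding layer can be initialised randomly. We also need a positional embedding layer, which encodes the word order in the sequence. The convention is to add, i.e., sum, these two embeddings. KerasHub has a keras_hub.layers.TokenAndPositionEmbedding layer which does all of the above steps for us.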
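A minimal NumPy sketch of the idea, assuming two independently (randomly) initialized lookup tables whose outputs are summed; the table names and the initializer scale are illustrative, not taken from KerasHub. Counting the table entries reproduces the 1,985,536 parameters reported for token_and_position_embedding in the model summary.

```python
import numpy as np

VOCAB_SIZE = 15000
MAX_SEQUENCE_LENGTH = 512
EMBED_DIM = 128

rng = np.random.default_rng(42)
# Two randomly initialized lookup tables (illustrative initializer).
token_table = rng.normal(scale=0.02, size=(VOCAB_SIZE, EMBED_DIM))
position_table = rng.normal(scale=0.02, size=(MAX_SEQUENCE_LENGTH, EMBED_DIM))


def token_and_position_embed(token_ids):
    """token_ids: int array (batch, seq_len) with seq_len <= MAX_SEQUENCE_LENGTH."""
    seq_len = token_ids.shape[1]
    token_vectors = token_table[token_ids]  # (batch, seq_len, EMBED_DIM)
    position_vectors = position_table[:seq_len]  # (seq_len, EMBED_DIM), broadcast over batch
    return token_vectors + position_vectors


batch = np.array([[145, 576, 608, 0, 0]])
print(token_and_position_embed(batch).shape)  # (1, 5, 128)

# Trainable weights: one row per vocabulary entry plus one row per position.
print(token_table.size + position_table.size)  # 1985536
```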

Our FNet classification model consists of three keras_hub.layers.FNetEncoder layers, followed by global average pooling, dropout, and a keras.layers.Dense layer on top.
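Because the Fourier sublayer has no weights, each encoder block's parameters come only from its feed-forward network and its layer normalizations. Assuming that layout (a Dense layer to INTERMEDIATE_DIM, a Dense layer back to EMBED_DIM, and two LayerNormalization layers, each with a scale and an offset), the arithmetic below reproduces the 132,224 parameters per FNetEncoder and the 2,382,337 total shown in the model summary. The helper functions are hypothetical.

```python
EMBED_DIM = 128
INTERMEDIATE_DIM = 512


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias


def layer_norm_params(dim):
    return 2 * dim  # scale (gamma) + offset (beta)


fourier_mixing = 0  # the Fourier transform has no weights
feed_forward = dense_params(EMBED_DIM, INTERMEDIATE_DIM) + dense_params(
    INTERMEDIATE_DIM, EMBED_DIM
)
fnet_encoder = fourier_mixing + feed_forward + 2 * layer_norm_params(EMBED_DIM)
print(fnet_encoder)  # 132224

embedding = (15000 + 512) * EMBED_DIM  # token table + position table
classifier_head = dense_params(EMBED_DIM, 1)  # pooling and dropout add nothing
print(embedding + 3 * fnet_encoder + classifier_head)  # 2382337
```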

Note: For FNet, masking the padding tokens has a minimal effect on results. In the official implementation, the padding tokens are not masked.
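One visible consequence: FNetEncoder does not support masking (hence the UserWarning printed when the model is built), so the GlobalAveragePooling1D layer in this model averages over padding positions as well, whereas in the Transformer model the summary shows a padding mask (not_equal_1) feeding into the pooling layer. The NumPy sketch below contrasts the two behaviours on toy values; `average_pool` is a hypothetical helper assumed to match how Keras pooling treats a mask.

```python
import numpy as np


def average_pool(x, mask=None):
    """Mean over the sequence axis; with a mask, only real tokens are averaged.

    x: (batch, seq_len, dim) floats; mask: (batch, seq_len) booleans.
    """
    if mask is None:
        return x.mean(axis=1)
    weights = mask[..., None].astype(x.dtype)  # (batch, seq_len, 1)
    return (x * weights).sum(axis=1) / weights.sum(axis=1)


# One sequence with two real tokens followed by two padding positions.
token_ids = np.array([[145, 576, 0, 0]])
x = np.array([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [0.0, 0.0]]])
mask = token_ids != 0  # ID 0 is padding, as with mask_zero=True

unmasked = average_pool(x)  # averages over all 4 positions: [[1.0, 1.5]]
masked = average_pool(x, mask)  # averages over the 2 real tokens: [[2.0, 3.0]]
print(unmasked, masked)
```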

input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(input_ids)

x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)


x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

fnet_classifier = keras.Model(input_ids, outputs, name="fnet_classifier")
/home/matt/miniconda3/envs/keras-io/lib/python3.10/site-packages/keras/src/layers/layer.py:861: UserWarning: Layer 'f_net_encoder' (of type FNetEncoder) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.
  warnings.warn(

Training our model

We'll use accuracy to monitor training progress on the validation data. Let's train our model for 3 epochs.
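With a single sigmoid output and the binary_crossentropy loss, Keras is expected to resolve the "accuracy" string to binary accuracy with a 0.5 threshold. The NumPy sketch below shows what the loss and the metric compute on a toy batch; both functions are hypothetical re-implementations for illustration, not the Keras code.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))


def binary_accuracy(y_true, y_pred, threshold=0.5):
    # A sigmoid output above the threshold counts as a positive prediction.
    return float(np.mean((y_pred > threshold) == (y_true == 1)))


y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0.2, 0.9, 0.4, 0.1])  # sigmoid outputs
print(binary_accuracy(y_true, y_pred))  # 0.75 (the third review is misclassified)
print(binary_crossentropy(y_true, y_pred))
```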

fnet_classifier.summary()
fnet_classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
fnet_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
Model: "fnet_classifier"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_ids (InputLayer)          │ (None, None)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ token_and_position_embedding    │ (None, None, 128)         │  1,985,536 │
│ (TokenAndPositionEmbedding)     │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ f_net_encoder (FNetEncoder)     │ (None, None, 128)         │    132,224 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ f_net_encoder_1 (FNetEncoder)   │ (None, None, 128)         │    132,224 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ f_net_encoder_2 (FNetEncoder)   │ (None, None, 128)         │    132,224 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_average_pooling1d        │ (None, 128)               │          0 │
│ (GlobalAveragePooling1D)        │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 128)               │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 1)                 │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 2,382,337 (9.09 MB)
 Trainable params: 2,382,337 (9.09 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/3

/home/matt/miniconda3/envs/keras-io/lib/python3.10/site-packages/keras/src/backend/jax/core.py:64: UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return jnp.array(x, dtype=dtype)

 313/313 ━━━━━━━━━━━━━━━━━━━━ 8s 18ms/step - accuracy: 0.5916 - loss: 0.6542 - val_accuracy: 0.8479 - val_loss: 0.3536
Epoch 2/3
 313/313 ━━━━━━━━━━━━━━━━━━━━ 4s 12ms/step - accuracy: 0.8776 - loss: 0.2916 - val_accuracy: 0.8532 - val_loss: 0.3387
Epoch 3/3
 313/313 ━━━━━━━━━━━━━━━━━━━━ 4s 12ms/step - accuracy: 0.9442 - loss: 0.1543 - val_accuracy: 0.8534 - val_loss: 0.4018

<keras.src.callbacks.history.History at 0x7feb7169c0d0>

We get ~92% training accuracy and ~85% validation accuracy. Moreover, for 3 epochs, it takes around 86 seconds to train the model (on Colab with a 16 GB Tesla T4 GPU).

Let's calculate the test accuracy.

fnet_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
 391/391 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.8412 - loss: 0.4281

[0.4198716878890991, 0.8427909016609192]

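Comparison with Transformer model

Let's compare our FNet classifier model with a Transformer classifier model. We keep all the parameters/hyperparameters the same: the only architectural change is that each FNetEncoder is replaced by a keras_hub.layers.TransformerEncoder, so we again use three encoder layers.

We set the number of attention heads to 2.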
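With the hidden size and feed-forward size fixed, the only structural difference per layer is the self-attention sublayer, which does have weights. Assuming each head gets EMBED_DIM // NUM_HEADS = 64 dimensions and every projection has a bias, the arithmetic below reproduces the 198,272 parameters per TransformerEncoder shown in the summary, i.e. 66,048 more than an FNetEncoder block. The helper and variable names are hypothetical.

```python
EMBED_DIM = 128
INTERMEDIATE_DIM = 512
NUM_HEADS = 2


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias


key_dim = EMBED_DIM // NUM_HEADS  # 64 dimensions per head (assumed split)
# Query, key and value projections each map EMBED_DIM -> NUM_HEADS * key_dim;
# an output projection maps the concatenated heads back to EMBED_DIM.
attention = 3 * dense_params(EMBED_DIM, NUM_HEADS * key_dim) + dense_params(
    NUM_HEADS * key_dim, EMBED_DIM
)
feed_forward = dense_params(EMBED_DIM, INTERMEDIATE_DIM) + dense_params(
    INTERMEDIATE_DIM, EMBED_DIM
)
layer_norms = 2 * (2 * EMBED_DIM)  # two LayerNormalization layers
transformer_encoder = attention + feed_forward + layer_norms
print(attention, transformer_encoder)  # 66048 198272

fnet_encoder = feed_forward + layer_norms  # same block, Fourier mixing has no weights
print(transformer_encoder - fnet_encoder)  # 66048
```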

NUM_HEADS = 2
input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")


x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(input_ids)

x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)


x = keras.layers.GlobalAveragePooling1D()(x)
x = keras.layers.Dropout(0.1)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)

transformer_classifier = keras.Model(input_ids, outputs, name="transformer_classifier")


transformer_classifier.summary()
transformer_classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
transformer_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
Model: "transformer_classifier"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃ Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_ids           │ (None, None)      │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ token_and_position… │ (None, None, 128) │ 1,985,… │ input_ids[0][0]      │
│ (TokenAndPositionE… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformer_encoder │ (None, None, 128) │ 198,272 │ token_and_position_… │
│ (TransformerEncode… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformer_encode… │ (None, None, 128) │ 198,272 │ transformer_encoder… │
│ (TransformerEncode… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformer_encode… │ (None, None, 128) │ 198,272 │ transformer_encoder… │
│ (TransformerEncode… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ not_equal_1         │ (None, None)      │       0 │ input_ids[0][0]      │
│ (NotEqual)          │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_average_poo… │ (None, 128)       │       0 │ transformer_encoder… │
│ (GlobalAveragePool… │                   │         │ not_equal_1[0][0]    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dropout_4 (Dropout) │ (None, 128)       │       0 │ global_average_pool… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_1 (Dense)     │ (None, 1)         │     129 │ dropout_4[0][0]      │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 2,580,481 (9.84 MB)
 Trainable params: 2,580,481 (9.84 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/3
 313/313 ━━━━━━━━━━━━━━━━━━━━ 14s 38ms/step - accuracy: 0.5895 - loss: 0.7401 - val_accuracy: 0.8912 - val_loss: 0.2694
Epoch 2/3
 313/313 ━━━━━━━━━━━━━━━━━━━━ 9s 29ms/step - accuracy: 0.9051 - loss: 0.2382 - val_accuracy: 0.8853 - val_loss: 0.2984
Epoch 3/3
 313/313 ━━━━━━━━━━━━━━━━━━━━ 9s 29ms/step - accuracy: 0.9496 - loss: 0.1366 - val_accuracy: 0.8730 - val_loss: 0.3607

<keras.src.callbacks.history.History at 0x7feaf9c56ad0>

We get ~94% training accuracy and ~86.5% validation accuracy. It takes around 146 seconds to train the model (on Colab with a 16 GB Tesla T4 GPU).

Let's calculate the test accuracy.

transformer_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
 391/391 ━━━━━━━━━━━━━━━━━━━━ 4s 11ms/step - accuracy: 0.8399 - loss: 0.4579

[0.4496161639690399, 0.8423193097114563]

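Let's make a table and compare the two models. We can see that FNet significantly speeds up training (about 1.7x), with only a small sacrifice in accuracy (a drop of 0.75 percentage points in test accuracy). The figures in the table come from a run on Colab with a Tesla T4 GPU; the logs shown above were produced in a different run, so their exact numbers differ.

                     FNet Classifier   Transformer Classifier
Training Time        86 seconds        146 seconds
Train Accuracy       92.34%            93.85%
Validation Accuracy  85.21%            86.42%
Test Accuracy        83.94%            84.69%
#Params              2,321,921         2,520,065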
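The speed-up and accuracy figures quoted before the table can be recomputed from its rows; the accuracy drop is measured in percentage points of test accuracy, not as a relative change.

```python
# Figures copied from the comparison table.
fnet = {"train_time_s": 86, "test_acc": 83.94}
transformer = {"train_time_s": 146, "test_acc": 84.69}

speedup = transformer["train_time_s"] / fnet["train_time_s"]
accuracy_drop = transformer["test_acc"] - fnet["test_acc"]  # percentage points

print(f"speed-up: {speedup:.2f}x")  # speed-up: 1.70x
print(f"test accuracy drop: {accuracy_drop:.2f} points")  # test accuracy drop: 0.75 points
```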