
Large-scale multi-label text classification

Authors: Sayak Paul, Soumik Rakshit
Date created: 2020/09/25
Last modified: 2020/12/23
Description: Implementing a large-scale multi-label text classification model.

ⓘ This example uses Keras 2



Introduction

In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like OpenReview. Given a paper abstract, the portal could provide suggestions for which areas the paper would best belong to.

The dataset was collected using the arXiv Python library, which provides a wrapper around the original arXiv API. To learn more about the data collection process, please refer to this notebook. Additionally, you can also find the dataset on Kaggle.


Imports

from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf

from sklearn.model_selection import train_test_split
from ast import literal_eval

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Perform exploratory data analysis

In this section, we first load the dataset into a pandas dataframe and then perform some basic exploratory data analysis (EDA).

arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()
    titles                                              summaries                                           terms
0   Survey on Semantic Stereo Matching / Semantic ...  Stereo matching is one of the widely used tech...  ['cs.CV', 'cs.LG']
1   FUTURE-AI: Guiding Principles and Consensus Re...  The recent advancements in artificial intellig...  ['cs.CV', 'cs.AI', 'cs.LG']
2   Enforcing Mutual Consistency of Hard Regions f...  In this paper, we proposed a novel mutual cons...  ['cs.CV', 'cs.AI']
3   Parameter Decoupling Strategy for Semi-supervi...  Consistency training has proved to be an advan...  ['cs.CV']
4   Background-Foreground Segmentation for Interio...  To ensure safety in automated driving, the cor...  ['cs.CV', 'cs.LG']

Our text features are present in the summaries column and their corresponding labels are in terms. As you can notice, there are multiple categories associated with a particular entry.

print(f"There are {len(arxiv_data)} rows in the dataset.")
There are 51774 rows in the dataset.

Real-world data is noisy. One of the most commonly observed sources of noise is data duplication. Here we notice that our initial dataset has about 13k duplicate entries.

total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")
There are 12802 duplicate titles.

Before proceeding further, we drop these entries.

arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

# How many unique terms?
print(arxiv_data["terms"].nunique())
There are 38972 rows in the deduplicated dataset.
2321
3157

As observed above, out of 3,157 unique combinations of terms, 2,321 occur only once. To prepare our train, validation, and test sets with stratified sampling, we need to drop these rare terms.

# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape
(36651, 3)

Convert the string labels to lists of strings

The initial labels are represented as raw strings. Here we make them List[str] for a more compact representation.

arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]
array([list(['cs.CV', 'cs.LG']), list(['cs.CV', 'cs.AI', 'cs.LG']),
       list(['cs.CV', 'cs.AI']), list(['cs.CV']),
       list(['cs.CV', 'cs.LG'])], dtype=object)

Use stratified splits because of class imbalance

The dataset has a class imbalance problem. So, to have a fair evaluation result, we need to ensure the datasets are sampled with stratification. To know more about different strategies for dealing with class imbalance, you can follow this tutorial. For an end-to-end demonstration of classification with imbalanced data, refer to Imbalanced classification: credit card fraud detection.

test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")
Number of rows in training set: 32985
Number of rows in validation set: 1833
Number of rows in test set: 1833

Multi-label binarization

Now we preprocess our labels using the StringLookup layer.

terms = tf.ragged.constant(train_df["terms"].values)
lookup = tf.keras.layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()


def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)


print("Vocabulary:\n")
print(vocab)
Vocabulary:
['[UNK]', 'cs.CV', 'cs.LG', 'stat.ML', 'cs.AI', 'eess.IV', 'cs.RO', 'cs.CL', 'cs.NE', 'cs.CR', 'math.OC', 'eess.SP', 'cs.GR', 'cs.SI', 'cs.MM', 'cs.SY', 'cs.IR', 'cs.MA', 'eess.SY', 'cs.HC', 'math.IT', 'cs.IT', 'cs.DC', 'cs.CY', 'stat.AP', 'stat.TH', 'math.ST', 'stat.ME', 'eess.AS', 'cs.SD', 'q-bio.QM', 'q-bio.NC', 'cs.DS', 'cs.GT', 'cs.CG', 'cs.SE', 'cs.NI', 'I.2.6', 'stat.CO', 'math.NA', 'cs.NA', 'physics.chem-ph', 'cs.DB', 'q-bio.BM', 'cs.PL', 'cs.LO', 'cond-mat.dis-nn', '68T45', 'math.PR', 'physics.comp-ph', 'I.2.10', 'cs.CE', 'cs.AR', 'q-fin.ST', 'cond-mat.stat-mech', '68T05', 'quant-ph', 'math.DS', 'physics.data-an', 'cs.CC', 'I.4.6', 'physics.soc-ph', 'physics.ao-ph', 'cs.DM', 'econ.EM', 'q-bio.GN', 'physics.med-ph', 'astro-ph.IM', 'I.4.8', 'math.AT', 'cs.PF', 'cs.FL', 'I.4', 'q-fin.TR', 'I.5.4', 'I.2', '68U10', 'hep-ex', 'cond-mat.mtrl-sci', '68T10', 'physics.optics', 'physics.geo-ph', 'physics.flu-dyn', 'math.CO', 'math.AP', 'I.4; I.5', 'I.4.9', 'I.2.6; I.2.8', '68T01', '65D19', 'q-fin.CP', 'nlin.CD', 'cs.MS', 'I.2.6; I.5.1', 'I.2.10; I.4; I.5', 'I.2.0; I.2.6', '68T07', 'q-fin.GN', 'cs.SC', 'cs.ET', 'K.3.2', 'I.2.8', '68U01', '68T30', 'q-fin.EC', 'q-bio.MN', 'econ.GN', 'I.4.9; I.5.4', 'I.4.5', 'I.2; I.5', 'I.2; I.4; I.5', 'I.2.6; I.2.7', 'I.2.10; I.4.8', '68T99', '68Q32', '68', '62H30', 'q-fin.RM', 'q-fin.PM', 'q-bio.TO', 'q-bio.OT', 'physics.bio-ph', 'nlin.AO', 'math.LO', 'math.FA', 'hep-ph', 'cond-mat.soft', 'I.4.6; I.4.8', 'I.4.4', 'I.4.3', 'I.4.0', 'I.2; J.2', 'I.2; I.2.6; I.2.7', 'I.2.7', 'I.2.6; I.5.4', 'I.2.6; I.2.9', 'I.2.6; I.2.7; H.3.1; H.3.3', 'I.2.6; I.2.10', 'I.2.6, I.5.4', 'I.2.1; J.3', 'I.2.10; I.5.1; I.4.8', 'I.2.10; I.4.8; I.5.4', 'I.2.10; I.2.6', 'I.2.1', 'H.3.1; I.2.6; I.2.7', 'H.3.1; H.3.3; I.2.6; I.2.7', 'G.3', 'F.2.2; I.2.7', 'E.5; E.4; E.2; H.1.1; F.1.1; F.1.3', '68Txx', '62H99', '62H35', '14J60 (Primary) 14F05, 14J26 (Secondary)']

Here we are separating the individual unique classes available from the label pool and then using this information to represent a given label set with 0's and 1's. Below is an example.

sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")
Original label: ['cs.LG', 'cs.CV', 'eess.IV']
Label-binarized representation: [[0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Data preprocessing and tf.data.Dataset objects

We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment.

train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()
count    32985.000000
mean       156.497105
std         41.528225
min          5.000000
25%        128.000000
50%        154.000000
75%        183.000000
max        462.000000
Name: summaries, dtype: float64

Notice that 50% of the abstracts have a length of 154 (you may get a different number based on the split). So, any number close to that value is a good enough approximation of the maximum sequence length.

Now, we implement utilities to prepare our datasets.

max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)

Now we can prepare the tf.data.Dataset objects.

train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

Dataset preview

text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")
Abstract: b"In this paper we show how using satellite images can improve the accuracy of\nhousing price estimation models. Using Los Angeles County's property assessment\ndataset, by transferring learning from an Inception-v3 model pretrained on\nImageNet, we could achieve an improvement of ~10% in R-squared score compared\nto two baseline models that only use non-image features of the house."
Label(s): ['cs.LG' 'stat.ML']

Abstract: b'Learning from data streams is an increasingly important topic in data mining,\nmachine learning, and artificial intelligence in general. A major focus in the\ndata stream literature is on designing methods that can deal with concept\ndrift, a challenge where the generating distribution changes over time. A\ngeneral assumption in most of this literature is that instances are\nindependently distributed in the stream. In this work we show that, in the\ncontext of concept drift, this assumption is contradictory, and that the\npresence of concept drift necessarily implies temporal dependence; and thus\nsome form of time series. This has important implications on model design and\ndeployment. We explore and highlight the these implications, and show that\nHoeffding-tree based ensembles, which are very popular for learning in streams,\nare not naturally suited to learning \\emph{within} drift; and can perform in\nthis scenario only at significant computational cost of destructive adaptation.\nOn the other hand, we develop and parameterize gradient-descent methods and\ndemonstrate how they can perform \\emph{continuous} adaptation with no explicit\ndrift-detection mechanism, offering major advantages in terms of accuracy and\nefficiency. As a consequence of our theoretical discussion and empirical\nobservations, we outline a number of recommendations for deploying methods in\nconcept-drifting streams.'
Label(s): ['cs.LG' 'stat.ML']

Abstract: b"As reinforcement learning (RL) achieves more success in solving complex\ntasks, more care is needed to ensure that RL research is reproducible and that\nalgorithms herein can be compared easily and fairly with minimal bias. RL\nresults are, however, notoriously hard to reproduce due to the algorithms'\nintrinsic variance, the environments' stochasticity, and numerous (potentially\nunreported) hyper-parameters. In this work we investigate the many issues\nleading to irreproducible research and how to manage those. We further show how\nto utilise a rigorous and standardised evaluation approach for easing the\nprocess of documentation, evaluation and fair comparison of different\nalgorithms, where we emphasise the importance of choosing the right measurement\nmetrics and conducting proper statistics on the results, for unbiased reporting\nof the results."
Label(s): ['cs.LG' 'stat.ML' 'cs.AI' 'cs.RO']

Abstract: b'Estimating dense correspondences between images is a long-standing image\nunder-standing task. Recent works introduce convolutional neural networks\n(CNNs) to extract high-level feature maps and find correspondences through\nfeature matching. However,high-level feature maps are in low spatial resolution\nand therefore insufficient to provide accurate and fine-grained features to\ndistinguish intra-class variations for correspondence matching. To address this\nproblem, we generate robust features by dynamically selecting features at\ndifferent scales. To resolve two critical issues in feature selection,i.e.,how\nmany and which scales of features to be selected, we frame the feature\nselection process as a sequential Markov decision-making process (MDP) and\nintroduce an optimal selection strategy using reinforcement learning (RL). We\ndefine an RL environment for image matching in which each individual action\neither requires new features or terminates the selection episode by referring a\nmatching score. Deep neural networks are incorporated into our method and\ntrained for decision making. Experimental results show that our method achieves\ncomparable/superior performance with state-of-the-art methods on three\nbenchmarks, demonstrating the effectiveness of our feature selection strategy.'
Label(s): ['cs.CV']

Abstract: b'Dense reconstructions often contain errors that prior work has so far\nminimised using high quality sensors and regularising the output. Nevertheless,\nerrors still persist. This paper proposes a machine learning technique to\nidentify errors in three dimensional (3D) meshes. Beyond simply identifying\nerrors, our method quantifies both the magnitude and the direction of depth\nestimate errors when viewing the scene. This enables us to improve the\nreconstruction accuracy.\n  We train a suitably deep network architecture with two 3D meshes: a\nhigh-quality laser reconstruction, and a lower quality stereo image\nreconstruction. The network predicts the amount of error in the lower quality\nreconstruction with respect to the high-quality one, having only view the\nformer through its input. We evaluate our approach by correcting\ntwo-dimensional (2D) inverse-depth images extracted from the 3D model, and show\nthat our method improves the quality of these depth reconstructions by up to a\nrelative 10% RMSE.'
Label(s): ['cs.CV' 'cs.RO']

Vectorization

Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). For that purpose, we will use the TextVectorization layer. It can operate as a part of your main model so that the model is excluded from the core preprocessing logic. This greatly reduces the chances of training / serving skew during inference.

We first calculate the number of unique words present in the abstracts.

# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)
153338

Now we create our vectorization layer and map() it on the tf.data.Dataset objects created earlier.

text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)

A batch of raw text will first go through the TextVectorization layer and it will generate their integer representations. Internally, the TextVectorization layer will first create bi-grams out of the sequences and then represent them using TF-IDF. The output representations will then be passed to the shallow model responsible for text classification.

To learn more about other possible configurations with TextVectorizer, please consult the official documentation.

Note: Setting the max_tokens argument to a pre-calculated vocabulary size is not a requirement.
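
For instance, here is a minimal sketch of what the adapted layer yields for a single raw string (the input sentence is made up for illustration):

# Hypothetical input, just to inspect the layer's output.
sample_vector = text_vectorizer(["this paper proposes a deep learning method"])
print(sample_vector.shape)  # (1, vocabulary_size): one TF-IDF weighted vector
print(sample_vector.numpy()[0][:10])  # first few TF-IDF weights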


Create a text classification model

We will keep our model simple: it will be a small stack of fully-connected layers with ReLU as the non-linearity.

def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model

Train the model

We will train our model using the binary crossentropy loss. This is because the labels are not disjoint: for a given abstract, we may have multiple categories. So, we will divide the prediction task into a series of multiple binary classification problems. This is also why we kept the activation function of the classification layer in our model as sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in Exploring the Limits of Weakly Supervised Pretraining, Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.
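
To make the distinction concrete, below is a small sketch with made-up logits for three labels, contrasting the two activation choices:

# Hypothetical logits for three labels of a single sample.
logits = tf.constant([[2.0, -1.0, 0.5]])

# Sigmoid scores each label independently; the probabilities need not
# sum to 1, so several labels can be "on" at once.
print(tf.sigmoid(logits).numpy())  # [[0.88 0.27 0.62]]

# Softmax produces one mutually exclusive distribution that sums to 1.
print(tf.nn.softmax(logits).numpy())  # [[0.79 0.04 0.18]]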

There are several options of metrics that can be used in multi-label classification. To keep this code example narrow, we decided to use the binary accuracy metric. For an explanation of why this metric is used, we refer to this pull request. There are also other suitable metrics for multi-label classification, like F1 Score or Hamming loss.
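
For intuition, here is a quick sketch (with hypothetical targets and predictions for one sample) of what binary accuracy measures: the fraction of individual label slots that are correct after thresholding the predicted probabilities at 0.5.

# Hypothetical multi-hot target and predicted probabilities.
y_true = np.array([[0.0, 1.0, 1.0, 0.0]])
y_pred = np.array([[0.1, 0.9, 0.4, 0.2]])
print(np.mean((y_pred >= 0.5).astype("float64") == y_true))  # 0.75: 3 of 4 slots match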

epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")
Epoch 1/20
258/258 [==============================] - 87s 332ms/step - loss: 0.0326 - binary_accuracy: 0.9893 - val_loss: 0.0189 - val_binary_accuracy: 0.9943
Epoch 2/20
258/258 [==============================] - 100s 387ms/step - loss: 0.0033 - binary_accuracy: 0.9990 - val_loss: 0.0271 - val_binary_accuracy: 0.9940
Epoch 3/20
258/258 [==============================] - 99s 384ms/step - loss: 7.8393e-04 - binary_accuracy: 0.9999 - val_loss: 0.0328 - val_binary_accuracy: 0.9939
Epoch 4/20
258/258 [==============================] - 109s 421ms/step - loss: 3.0132e-04 - binary_accuracy: 1.0000 - val_loss: 0.0366 - val_binary_accuracy: 0.9939
Epoch 5/20
258/258 [==============================] - 105s 405ms/step - loss: 1.6006e-04 - binary_accuracy: 1.0000 - val_loss: 0.0399 - val_binary_accuracy: 0.9939
Epoch 6/20
258/258 [==============================] - 107s 414ms/step - loss: 1.2400e-04 - binary_accuracy: 1.0000 - val_loss: 0.0412 - val_binary_accuracy: 0.9939
Epoch 7/20
258/258 [==============================] - 110s 425ms/step - loss: 7.7131e-05 - binary_accuracy: 1.0000 - val_loss: 0.0439 - val_binary_accuracy: 0.9940
Epoch 8/20
258/258 [==============================] - 105s 405ms/step - loss: 5.5611e-05 - binary_accuracy: 1.0000 - val_loss: 0.0446 - val_binary_accuracy: 0.9940
Epoch 9/20
258/258 [==============================] - 103s 397ms/step - loss: 4.5994e-05 - binary_accuracy: 1.0000 - val_loss: 0.0454 - val_binary_accuracy: 0.9940
Epoch 10/20
258/258 [==============================] - 105s 405ms/step - loss: 3.5126e-05 - binary_accuracy: 1.0000 - val_loss: 0.0472 - val_binary_accuracy: 0.9939
Epoch 11/20
258/258 [==============================] - 109s 422ms/step - loss: 2.9927e-05 - binary_accuracy: 1.0000 - val_loss: 0.0466 - val_binary_accuracy: 0.9940
Epoch 12/20
258/258 [==============================] - 133s 516ms/step - loss: 2.5748e-05 - binary_accuracy: 1.0000 - val_loss: 0.0484 - val_binary_accuracy: 0.9940
Epoch 13/20
258/258 [==============================] - 129s 497ms/step - loss: 4.3529e-05 - binary_accuracy: 1.0000 - val_loss: 0.0500 - val_binary_accuracy: 0.9940
Epoch 14/20
258/258 [==============================] - 158s 611ms/step - loss: 8.1068e-04 - binary_accuracy: 0.9998 - val_loss: 0.0377 - val_binary_accuracy: 0.9936
Epoch 15/20
258/258 [==============================] - 144s 558ms/step - loss: 0.0016 - binary_accuracy: 0.9995 - val_loss: 0.0418 - val_binary_accuracy: 0.9935
Epoch 16/20
258/258 [==============================] - 131s 506ms/step - loss: 0.0018 - binary_accuracy: 0.9995 - val_loss: 0.0479 - val_binary_accuracy: 0.9931
Epoch 17/20
258/258 [==============================] - 127s 491ms/step - loss: 0.0012 - binary_accuracy: 0.9997 - val_loss: 0.0521 - val_binary_accuracy: 0.9931
Epoch 18/20
258/258 [==============================] - 153s 594ms/step - loss: 6.3144e-04 - binary_accuracy: 0.9998 - val_loss: 0.0549 - val_binary_accuracy: 0.9934
Epoch 19/20
258/258 [==============================] - 142s 550ms/step - loss: 3.1753e-04 - binary_accuracy: 0.9999 - val_loss: 0.0589 - val_binary_accuracy: 0.9934
Epoch 20/20
258/258 [==============================] - 153s 594ms/step - loss: 2.0258e-04 - binary_accuracy: 1.0000 - val_loss: 0.0585 - val_binary_accuracy: 0.9933

[Plot: training and validation loss over epochs]

[Plot: training and validation binary_accuracy over epochs]

While training, we notice an initial sharp fall in the loss followed by a gradual decay.

Evaluate the model

_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Categorical accuracy on the test set: {round(binary_acc * 100, 2)}%.")
15/15 [==============================] - 3s 196ms/step - loss: 0.0580 - binary_accuracy: 0.9933
Binary accuracy on the test set: 99.33%.

The trained model gives us an evaluation accuracy of ~99%.


Inference

An important feature of the preprocessing layers provided by Keras is that they can be included inside a tf.keras.Model. We will export an inference model by including the text_vectorizer layer on top of shallow_mlp_model. This will allow our inference model to directly operate on raw strings.

Note that during training, it is always preferable to use these preprocessing layers as a part of the data input pipeline rather than the model, to avoid surfacing bottlenecks for the hardware accelerators. This also allows for asynchronous data processing.

# Create a model for inference.
model_for_inference = keras.Sequential([text_vectorizer, shallow_mlp_model])

# Create a small dataset just for demoing inference.
inference_dataset = make_dataset(test_df.sample(100), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)

# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")
4/4 [==============================] - 0s 62ms/step
Abstract: b'We investigate the training of sparse layers that use different parameters\nfor different inputs based on hashing in large Transformer models.\nSpecifically, we modify the feedforward layer to hash to different sets of\nweights depending on the current token, over all tokens in the sequence. We\nshow that this procedure either outperforms or is competitive with\nlearning-to-route mixture-of-expert methods such as Switch Transformers and\nBASE Layers, while requiring no routing parameters or extra terms in the\nobjective function such as a load balancing loss, and no sophisticated\nassignment algorithm. We study the performance of different hashing techniques,\nhash sizes and input features, and show that balanced and random hashes focused\non the most local features work best, compared to either learning clusters or\nusing longer-range context. We show our approach works well both on large\nlanguage modeling and dialogue tasks, and on downstream fine-tuning tasks.'
Label(s): ['cs.LG' 'cs.CL']
Predicted Label(s): (cs.LG, cs.CL, stat.ML)

Abstract: b'We present the first method capable of photorealistically reconstructing\ndeformable scenes using photos/videos captured casually from mobile phones. Our\napproach augments neural radiance fields (NeRF) by optimizing an additional\ncontinuous volumetric deformation field that warps each observed point into a\ncanonical 5D NeRF. We observe that these NeRF-like deformation fields are prone\nto local minima, and propose a coarse-to-fine optimization method for\ncoordinate-based models that allows for more robust optimization. By adapting\nprinciples from geometry processing and physical simulation to NeRF-like\nmodels, we propose an elastic regularization of the deformation field that\nfurther improves robustness. We show that our method can turn casually captured\nselfie photos/videos into deformable NeRF models that allow for photorealistic\nrenderings of the subject from arbitrary viewpoints, which we dub "nerfies." We\nevaluate our method by collecting time-synchronized data using a rig with two\nmobile phones, yielding train/validation images of the same pose at different\nviewpoints. We show that our method faithfully reconstructs non-rigidly\ndeforming scenes and reproduces unseen views with high fidelity.'
Label(s): ['cs.CV' 'cs.GR']
Predicted Label(s): (cs.CV, cs.GR, cs.RO)

Abstract: b'We propose to jointly learn multi-view geometry and warping between views of\nthe same object instances for robust cross-view object detection. What makes\nmulti-view object instance detection difficult are strong changes in viewpoint,\nlighting conditions, high similarity of neighbouring objects, and strong\nvariability in scale. By turning object detection and instance\nre-identification in different views into a joint learning task, we are able to\nincorporate both image appearance and geometric soft constraints into a single,\nmulti-view detection process that is learnable end-to-end. We validate our\nmethod on a new, large data set of street-level panoramas of urban objects and\nshow superior performance compared to various baselines. Our contribution is\nthreefold: a large-scale, publicly available data set for multi-view instance\ndetection and re-identification; an annotation tool custom-tailored for\nmulti-view instance detection; and a novel, holistic multi-view instance\ndetection and re-identification method that jointly models geometry and\nappearance across views.'
Label(s): ['cs.CV' 'cs.LG' 'stat.ML']
Predicted Label(s): (cs.CV, cs.RO, cs.MM)

Abstract: b'Learning graph convolutional networks (GCNs) is an emerging field which aims\nat generalizing deep learning to arbitrary non-regular domains. Most of the\nexisting GCNs follow a neighborhood aggregation scheme, where the\nrepresentation of a node is recursively obtained by aggregating its neighboring\nnode representations using averaging or sorting operations. However, these\noperations are either ill-posed or weak to be discriminant or increase the\nnumber of training parameters and thereby the computational complexity and the\nrisk of overfitting. In this paper, we introduce a novel GCN framework that\nachieves spatial graph convolution in a reproducing kernel Hilbert space\n(RKHS). The latter makes it possible to design, via implicit kernel\nrepresentations, convolutional graph filters in a high dimensional and more\ndiscriminating space without increasing the number of training parameters. The\nparticularity of our GCN model also resides in its ability to achieve\nconvolutions without explicitly realigning nodes in the receptive fields of the\nlearned graph filters with those of the input graphs, thereby making\nconvolutions permutation agnostic and well defined. Experiments conducted on\nthe challenging task of skeleton-based action recognition show the superiority\nof the proposed method against different baselines as well as the related work.'
Label(s): ['cs.CV']
Predicted Label(s): (cs.LG, cs.CV, cs.NE)

Abstract: b'Recurrent meta reinforcement learning (meta-RL) agents are agents that employ\na recurrent neural network (RNN) for the purpose of "learning a learning\nalgorithm". After being trained on a pre-specified task distribution, the\nlearned weights of the agent\'s RNN are said to implement an efficient learning\nalgorithm through their activity dynamics, which allows the agent to quickly\nsolve new tasks sampled from the same distribution. However, due to the\nblack-box nature of these agents, the way in which they work is not yet fully\nunderstood. In this study, we shed light on the internal working mechanisms of\nthese agents by reformulating the meta-RL problem using the Partially\nObservable Markov Decision Process (POMDP) framework. We hypothesize that the\nlearned activity dynamics is acting as belief states for such agents. Several\nillustrative experiments suggest that this hypothesis is true, and that\nrecurrent meta-RL agents can be viewed as agents that learn to act optimally in\npartially observable environments consisting of multiple related tasks. This\nview helps in understanding their failure cases and some interesting\nmodel-based results reported in the literature.'
Label(s): ['cs.LG' 'cs.AI']
Predicted Label(s): (stat.ML, cs.LG, cs.AI)

The prediction results are not that great, but not below par for a simple model like ours. We can improve this performance with models that consider word order, like LSTMs, or even with models that use Transformers (Vaswani et al.).
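
As a hedged sketch of that direction (not trained here; it assumes a hypothetical integer-output vectorizer in place of the TF-IDF one used above, and the helper name make_sequence_model is ours), an order-aware variant could look like this:

# Assumes a vectorizer such as:
# layers.TextVectorization(max_tokens=vocabulary_size, output_mode="int",
#                          output_sequence_length=max_seqlen)
# so that inputs are padded sequences of integer token ids.
def make_sequence_model():
    return keras.Sequential(
        [
            layers.Embedding(vocabulary_size, 128),  # token ids -> dense vectors
            layers.Bidirectional(layers.LSTM(64)),   # order-aware encoder
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]
    )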


Acknowledgements

We would like to thank Matt Watson for helping us tackle the multi-label binarization part and the inverse-transforming of the processed labels to their original form.

Thanks to Cingis Kratochvil for suggesting and extending this code example with the binary accuracy metric.