Author: Sayak Paul
Date created: 2021/08/08
Last modified: 2025/01/03
Description: Training a multimodal model for predicting entailment.
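In this example, we will build and train a model for predicting multimodal entailment. We will be using the multimodal entailment dataset recently introduced by Google Research.
On social media platforms, to audit and moderate content we may want to find answers to the following questions in near real-time:
Does a given piece of information contradict the other?
Does a given piece of information imply the other?
In NLP, this task is called analyzing textual entailment. However, that only applies when the information comes from text content. In practice, the available information often comes not just from text but from a multimodal combination of text, images, audio, video, and so on. Multimodal entailment is simply the extension of textual entailment to a variety of new input modalities.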
This example uses KerasHub to load the BERT model (Devlin et al.) and its preprocessing layer, together with TensorFlow Text. TensorFlow Text can be installed using the following command:
!pip install -q tensorflow_text
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os
os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch
import keras
import keras_hub
from keras.utils import PyDataset
label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
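The original dataset is available here. It comes with URLs of images that are hosted on Twitter's photo storage system called Photo Blob Storage (PBS for short). We will be working with the downloaded images along with the additional data that comes with the original dataset. Thanks to Nilabhra Roy Chowdhury, who worked on preparing the image data.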
image_base_path = keras.utils.get_file(
"tweet_images",
"https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
untar=True,
)
# Keep only the first 1,000 rows to conserve resources, since this is an
# example rather than a state-of-the-art setup.
df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[0:1000]
df.sample(10)
 | id_1 | text_1 | image_1 | id_2 | text_2 | image_2 | label |
---|---|---|---|---|---|---|---|
815 | 1370730009921343490 | Sticky bombs are a threat because they have magnets... | http://pbs.twimg.com/media/EwXOFrgVIAEkfjR.jpg | 1370731764906295307 | Sticky bombs are a threat because they have magnets... | http://pbs.twimg.com/media/EwXRK_3XEAA6Q6F.jpg | NoEntailment |
615 | 1364119737446395905 | #Cancer 2.23.21 Daily Horoscope ♊️❤️✨ #Hor... | http://pbs.twimg.com/media/Eu5Te44VgAIo1jZ.jpg | 1365218087906078720 | #Cancer 2.26.21 Daily Horoscope ♊️❤️✨ #Hor... | http://pbs.twimg.com/media/EvI6nW4WQAA4_E_.jpg | NoEntailment |
624 | 1335542260923068417 | Reindeer racing is back, and this year's race is... | http://pbs.twimg.com/media/Eoi99DyXEAE0AFV.jpg | 1335872932267122689 | Put on your red nose and antlers and join the 2020... | http://pbs.twimg.com/media/Eon5Wk7XUAE-CxN.jpg | NoEntailment |
970 | 1345058844439949312 | Participants needed for an online survey!\n\nTopic... | http://pbs.twimg.com/media/Eqqb4_MXcAA-Pvu.jpg | 1361211461792632835 | Participants needed for top research on Su... | http://pbs.twimg.com/media/EuPz0GwXMAMDklt.jpg | NoEntailment |
456 | 1379831489043521545 | A commission drawn for @NanoBiteTSF, enjoy bros and... | http://pbs.twimg.com/media/EyVf0_VXMAMtRaL.jpg | 1380660763749142531 | Another commission drawn for @NanoBiteTSF, hope you... | http://pbs.twimg.com/media/EykW0iXXAAA2SBC.jpg | NoEntailment |
917 | 1336180735191891968 | (2/10)\n(Seoul Jung-gu) Market cluster ->\n... | http://pbs.twimg.com/media/EosRFpGVQAIeuYG.jpg | 1356113330536996866 | (3/11)\n(Seoul Dongdaemun-gu) Gosiwon cluster... | http://pbs.twimg.com/media/EtHhj7QVcAAibvF.jpg | NoEntailment |
276 | 1339270210029834241 | Today the message of freedom reached Kisoro, Ru... | http://pbs.twimg.com/media/EpVK3pfXcAAZ5Du.jpg | 1340881971132698625 | Today the message of freedom will be taken to... | http://pbs.twimg.com/media/EpvDorkXYAEyz4g.jpg | Implies |
35 | 1360186999836200961 | Bitcoin in Argentina - Google Trends https://t... | http://pbs.twimg.com/media/EuBa3UxXYAMb99_.jpg | 1382778703055228929 | Argentina wants #Bitcoin https://127.0.0.1/9lNxJdxX... | http://pbs.twimg.com/media/EzCbUFNXMAABwPD.jpg | Implies |
762 | 1370824756400959491 | $HSBA.L: The long-term trend is positive and... | http://pbs.twimg.com/media/EwYl2hPWYAE2niq.png | 1374347458126475269 | Although the technical rating is only medium... | http://pbs.twimg.com/media/ExKpuwrWgAAktg4.png | NoEntailment |
130 | 1373789433607172097 | I just watched Ted Lasso S01\|E05... | http://pbs.twimg.com/media/ExCuNbDXAAQaPiL.jpg | 1374913509662806016 | I just watched Ted Lasso S01\|E06... | http://pbs.twimg.com/media/ExSsjRQWgAUVRPz.jpg | Contradictory |
The columns we are interested in are the following:
text_1
image_1
text_2
image_2
label
The entailment task is formulated as follows:
Given the pairs of (text_1, image_1) and (text_2, image_2), do they entail (or not entail, or contradict) each other?
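We have already downloaded the images. image_1 is saved with id_1 as its file name and image_2 with id_2 as its file name, each keeping the file extension of its original URL. In the next step, we will add two extra columns to df containing the file paths of image_1 and image_2.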
images_one_paths = []
images_two_paths = []
for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    # Reuse the file extension of the original image URL (e.g. "jpg" or "png").
    extension_one = current_row["image_1"].split(".")[-1]
    extension_two = current_row["image_2"].split(".")[-1]
    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")
    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)
df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths
# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])
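The file naming convention described above can be captured in a small helper. This is a sketch rather than part of the original pipeline: image_path_for is a hypothetical name, and it assumes every image URL ends in a file extension such as .jpg or .png, as in the sample rows shown earlier.

```python
import os


def image_path_for(base_path, tweet_id, image_url):
    """Hypothetical helper: local path of a downloaded tweet image.

    The file is named after the tweet id and keeps the extension of the
    original Photo Blob Storage URL (e.g. "jpg" or "png").
    """
    extension = image_url.rsplit(".", 1)[-1]
    return os.path.join(base_path, f"{tweet_id}.{extension}")


print(
    image_path_for(
        "tweet_images",
        1370730009921343490,
        "http://pbs.twimg.com/media/EwXOFrgVIAEkfjR.jpg",
    )
)
```

With pandas, the same helper could be applied row-wise, for example df.apply(lambda r: image_path_for(image_base_path, r["id_1"], r["image_1"]), axis=1), instead of the explicit loop.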
def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_2)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")
random_idx = random.choice(range(len(df)))
visualize(random_idx)
random_idx = random.choice(range(len(df)))
visualize(random_idx)
Text one: World #water day reminds that we should follow the #guidelines to save water for us. This Day is an #opportunity to learn more about water related issues, be #inspired to tell others and take action to make a difference. Just remember, every #drop counts.
#WorldWaterDay2021 https://127.0.0.1/bQ9Hp53qUj
Text two: Water is an extremely precious resource without which life would be impossible. We need to ensure that water is used judiciously, this #WorldWaterDay, let us pledge to reduce water wastage and conserve it.
#WorldWaterDay2021 https://127.0.0.1/0KWnd8Kn8r
Label: NoEntailment
Text one: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟬: 𝗗𝗬𝗟𝗔𝗡 𝗙𝗜𝗧𝗭𝗦𝗜𝗠𝗢𝗡𝗦
Dylan Fitzsimons is a young passionate greyhound supporter.
He and @Drakesport enjoy a great chat about everything greyhounds!
Listen: https://127.0.0.1/B2XgMp0yaO
#GoGreyhoundRacing #ThisRunsDeep #TalkingDogs https://127.0.0.1/crBiSqHUvp
Text two: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟳: 𝗣𝗜𝗢 𝗕𝗔𝗥𝗥𝗬 🎧
Well known within greyhound circles, Pio Barry shares some wonderful greyhound racing stories with @Drakesport in this podcast episode.
A great chat.
Listen: https://127.0.0.1/mJTVlPHzp0
#TalkingDogs #GoGreyhoundRacing #ThisRunsDeep https://127.0.0.1/QbxtCpLcGm
Label: NoEntailment
The dataset suffers from class imbalance. We can confirm that in the following cell.
df["label"].value_counts()
label
NoEntailment 819
Contradictory 92
Implies 89
Name: count, dtype: int64
To account for this, we will use a stratified split.
# 10% for test
train_df, test_df = train_test_split(
df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% of the remaining examples for validation
train_df, val_df = train_test_split(
train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)
print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 855
Total validation examples: 45
Total test examples: 100
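Note that the second split takes 5% of the 900 rows left after the test split, not 5% of the full dataset. The sizes printed above can be reproduced with a little arithmetic, assuming scikit-learn's rule for a float test_size: the test part is rounded up and the rest goes to training. split_sizes is a hypothetical helper used only for illustration.

```python
import math


def split_sizes(n_samples, test_size):
    """(n_train, n_test) for a float test_size, assuming the test part is
    rounded up and all remaining rows go to the training part."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


# First split: 10% of the 1,000 rows for testing.
n_rest, n_test = split_sizes(1000, 0.1)
# Second split: 5% of the remaining 900 rows for validation.
n_train, n_val = split_sizes(n_rest, 0.05)
print(n_train, n_val, n_test)  # 855 45 100
```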
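KerasHub provides a variety of BERT-family models. Each of these models comes with a corresponding preprocessing layer. You can learn more about these models and their preprocessing layers from this resource.
To keep the runtime of this example relatively short, we will use the base uncased variant of the original BERT model (the bert_base_en_uncased preset).
Text preprocessing using KerasHub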
text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
"bert_base_en_uncased",
sequence_length=128,
)
idx = random.choice(range(len(train_df)))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")
test_text = [sample_text_1, sample_text_2]
text_preprocessed = text_preprocessor(test_text)
print("Keys : ", list(text_preprocessed.keys()))
print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
print("Token Ids : ", text_preprocessed["token_ids"][0, :16])
print(" Shape Padding Mask : ", text_preprocessed["padding_mask"].shape)
print("Padding Mask : ", text_preprocessed["padding_mask"][0, :16])
print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
print("Segment Ids : ", text_preprocessed["segment_ids"][0, :16])
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Text 1: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered 02 bags on 20.02.2021 at Station platform and in T/No.08310 Spl. respectively and handed over to their actual owner correctly. @RPF_INDIA https://127.0.0.1/bdEBl2egIc
Text 2: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered 02 bags on 20.02.2021 at Station platform and in T/No.08310 (JAT-SBP) Spl. respectively and handed over to their actual owner correctly. @RPF_INDIA https://127.0.0.1/Q5l2AtA4uq
Keys : ['token_ids', 'padding_mask', 'segment_ids']
Shape Token Ids : (2, 128)
Token Ids : [ 101 1996 1054 14376 8840 11783 16098 1998 6045 2401 2695 1997
8086 2072 2407 2031]
Shape Padding Mask : (2, 128)
Padding Mask : [ True True True True True True True True True True True True
True True True True]
Shape Segment Ids : (2, 128)
Segment Ids : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
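The preprocessor treats the list [sample_text_1, sample_text_2] as a batch of two independent sequences, which is why each output has shape (2, 128) and every segment id is 0. The preprocess_text() helper used later in the data pipeline flattens these into vectors of length 2 × 128 = 256, which matches the shape=(256,) inputs of the text encoder. The pure-Python toy below mimics that packing with made-up token ids; the special ids 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]) follow the BERT uncased vocabulary, while pack_single and pack_like_preprocess_text are hypothetical names, not KerasHub APIs.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # [CLS], [SEP], [PAD] in the BERT uncased vocab


def pack_single(token_ids, seq_len):
    """Toy packer: [CLS] + tokens + [SEP], truncated and padded to seq_len."""
    ids = [CLS_ID] + list(token_ids[: seq_len - 2]) + [SEP_ID]
    mask = [True] * len(ids)
    num_pad = seq_len - len(ids)
    return ids + [PAD_ID] * num_pad, mask + [False] * num_pad


def pack_like_preprocess_text(ids_1, ids_2, seq_len=128):
    """Pack each text as its own sequence, then concatenate (flatten) them."""
    tokens_1, mask_1 = pack_single(ids_1, seq_len)
    tokens_2, mask_2 = pack_single(ids_2, seq_len)
    return {
        "token_ids": tokens_1 + tokens_2,
        "padding_mask": mask_1 + mask_2,
        # Each text forms a single-segment sequence, so every segment id is 0.
        "segment_ids": [0] * (2 * seq_len),
    }


packed = pack_like_preprocess_text([5, 6], [7], seq_len=4)
print(packed["token_ids"])  # [101, 5, 6, 102, 101, 7, 102, 0]
print(len(pack_like_preprocess_text([5, 6], [7])["token_ids"]))  # 256
```

A design alternative, not used in this example, would be to pack both texts into one 128-token sentence pair distinguished by segment ids 0 and 1, which is how BERT was pre-trained on text pairs.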
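We will now create keras.utils.PyDataset objects from the dataframes.
Note that the text inputs will be preprocessed as part of the data input pipeline. But the preprocessing modules can also be a part of their corresponding BERT models. This helps reduce training/serving skew and lets our models operate with raw text inputs. Follow this tutorial to learn more about how to incorporate the preprocessing modules directly inside the models.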
def dataframe_to_dataset(dataframe):
    # Wrap the dataframe in the Keras PyDataset defined below; batches are
    # assembled (and the texts tokenized) on the fly in __getitem__.
    ds = UnifiedPyDataset(
        dataframe,
        batch_size=32,
        workers=4,
    )
    return ds
bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
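def preprocess_text(text_1, text_2):
    # Tokenize both texts as a batch of two 128-token sequences, then flatten
    # every feature into a single vector of length 2 * 128 = 256.
    output = text_preprocessor([text_1, text_2])
    output = {
        feature: keras.ops.reshape(output[feature], [-1])
        for feature in bert_input_features
    }
    return output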
class UnifiedPyDataset(PyDataset):
    """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""

    def __init__(
        self,
        df,
        batch_size=32,
        workers=4,
        use_multiprocessing=False,
        max_queue_size=10,
        **kwargs,
    ):
        """
        Args:
            df: pandas DataFrame with data
            batch_size: Batch size for dataset
            workers: Number of workers to use for parallel loading (Keras)
            use_multiprocessing: Whether to use multiprocessing
            max_queue_size: Maximum size of the data queue for parallel loading
        """
        super().__init__(
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            **kwargs,
        )
        self.dataframe = df
        self.batch_size = batch_size
        # Image file paths.
        self.image_x_1 = self.dataframe["image_1_path"]
        self.image_x_2 = self.dataframe["image_2_path"]
        # Texts.
        self.text_x_1 = self.dataframe["text_1"]
        self.text_x_2 = self.dataframe["text_2"]
        # Integer labels, shared by all inputs of a row.
        self.labels = self.dataframe["label_idx"]

    @staticmethod
    def _load_image(file_name):
        # preserve_range=True keeps pixel values in [0, 255], the range that
        # keras.applications.resnet_v2.preprocess_input in the model expects.
        img = resize(imread(file_name), (128, 128), preserve_range=True)
        if img.ndim == 2:
            # Grayscale: replicate the single channel three times.
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 4:
            # RGBA: drop the alpha channel.
            img = img[..., :3]
        return img.astype("float32")

    def __getitem__(self, index):
        """Fetches a batch of data from the dataset at the given index."""
        low = index * self.batch_size
        # Cap the upper bound at the dataset length; the last batch may be
        # smaller if the number of rows is not a multiple of the batch size.
        high = min(low + self.batch_size, len(self.dataframe))

        # Positional slices, since the split dataframes keep their original
        # (shuffled) index labels.
        batch_image_x_1 = self.image_x_1.iloc[low:high]
        batch_image_x_2 = self.image_x_2.iloc[low:high]
        batch_text_x_1 = self.text_x_1.iloc[low:high]
        batch_text_x_2 = self.text_x_2.iloc[low:high]
        batch_labels = self.labels.iloc[low:high]

        # Image inputs, each of shape (batch, 128, 128, 3).
        image_1 = np.array([self._load_image(f) for f in batch_image_x_1])
        image_2 = np.array([self._load_image(f) for f in batch_image_x_2])

        # Text inputs: one dict of flattened BERT features per row, stacked
        # into arrays of shape (batch, 256) for the BERT backbone.
        text_features = [
            preprocess_text(text_1, text_2)
            for text_1, text_2 in zip(batch_text_x_1, batch_text_x_2)
        ]
        text = {
            key: np.array([features[key] for features in text_features])
            for key in bert_input_features
        }

        return (
            {
                "image_1": image_1,
                "image_2": image_2,
                "padding_mask": text["padding_mask"],
                "segment_ids": text["segment_ids"],
                "token_ids": text["token_ids"],
            },
            # Target labels.
            np.array(batch_labels),
        )

    def __len__(self):
        """Returns the number of batches in the dataset."""
        return math.ceil(len(self.dataframe) / self.batch_size)
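The batching arithmetic of __len__ and __getitem__ can be checked without loading any data. The pure-Python sketch below (batch_bounds is a hypothetical name) lists the (low, high) row ranges visited for each split with a batch size of 32.

```python
import math


def batch_bounds(num_rows, batch_size=32):
    """(low, high) row ranges that UnifiedPyDataset.__getitem__ slices."""
    num_batches = math.ceil(num_rows / batch_size)  # same rule as __len__
    return [
        (index * batch_size, min((index + 1) * batch_size, num_rows))
        for index in range(num_batches)
    ]


for name, num_rows in [("train", 855), ("validation", 45), ("test", 100)]:
    bounds = batch_bounds(num_rows)
    low, high = bounds[-1]
    print(f"{name}: {len(bounds)} batches, last batch has {high - low} rows")
```

This prints 27, 2 and 4 batches, with last batches of 23, 13 and 4 rows; the 27/27 and 4/4 step counters in the training and evaluation logs correspond to these batch counts.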
Create the training, validation, and test datasets
def prepare_dataset(dataframe):
    ds = dataframe_to_dataset(dataframe)
    return ds
train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df)
test_ds = prepare_dataset(test_df)
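Our final model will accept two images along with their text counterparts. While the images will be fed directly to the model, the text inputs will first be preprocessed and then fed to the model. Below is a visual illustration of this approach: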
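The model consists of the following elements:
A standalone encoder for the images. We will use a ResNet50V2 pre-trained on the ImageNet-1k dataset for this.
A standalone encoder for the texts. A pre-trained BERT will be used for this.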
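After extracting the individual embeddings, they will be projected into an identical space. Finally, their projections will be concatenated and fed to the final classification layer.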
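To make the data flow concrete, the sketch below lists the width of the per-example feature vector at each stage. The widths are assumptions drawn from the architectures used: ResNet50V2 with include_top=False and pooling="avg" yields 2,048 features, BERT-base has a 768-dimensional pooled output, and 256 is the default projection_dims of create_multimodal_model(). feature_widths is a hypothetical helper that only does the bookkeeping.

```python
def feature_widths(projection_dims=256, num_classes=3):
    """Width of the per-example feature vector at each stage of the model."""
    resnet_pooled = 2048  # ResNet50V2, include_top=False, pooling="avg"
    bert_pooled = 768  # BERT-base hidden size (pooled_output)
    return {
        "two image embeddings, concatenated": 2 * resnet_pooled,
        "vision projection": projection_dims,
        "BERT pooled_output": bert_pooled,
        "text projection": projection_dims,
        "both projections, concatenated": 2 * projection_dims,
        "softmax output": num_classes,
    }


for stage, width in feature_widths().items():
    print(f"{stage}: {width}")
```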
This is a multi-class classification problem involving the following classes:
NoEntailment
Implies
Contradictory
The project_embeddings(), create_vision_encoder(), and create_text_encoder() utilities are adapted from this example.
Projection utilities
def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = keras.ops.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings
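Written as equations, with $e$ the input embedding, $K$ = num_projection_layers, and $W_k, b_k$ the weights and bias of the $k$-th (unshared) Dense layer, project_embeddings() computes the following and returns $p_K$. Each loop iteration is a residual block followed by layer normalization; with the default $K = 1$ there is a single such block.

$$
\begin{aligned}
p_0 &= W_0\, e + b_0, \\
p_k &= \mathrm{LayerNorm}\Big(p_{k-1} + \mathrm{Dropout}\big(W_k\, \mathrm{GELU}(p_{k-1}) + b_k\big)\Big), \qquad k = 1, \dots, K.
\end{aligned}
$$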
Vision encoder utilities
def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input images (pixels in [0, 255] are scaled to [-1, 1]).
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for both images with the shared ResNet50V2
    # model and concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")
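keras.applications.resnet_v2.preprocess_input uses the "tf" preprocessing mode, which scales pixels from [0, 255] to [-1, 1]. That is why the images reaching the vision encoder should stay in the 0 to 255 range: skimage's resize() returns floats in [0, 1] for 8-bit images unless preserve_range=True is passed, and such inputs would all be squeezed into roughly [-1, -0.99]. The sketch below mirrors the per-pixel formula; resnet_v2_scale is a hypothetical name.

```python
def resnet_v2_scale(pixel):
    """Per-pixel scaling of the "tf" preprocessing mode: [0, 255] -> [-1, 1]."""
    return pixel / 127.5 - 1.0


print([resnet_v2_scale(p) for p in (0, 127.5, 255)])  # [-1.0, 0.0, 1.0]
# A white pixel that was already rescaled to 1.0 ends up close to -1.
print(resnet_v2_scale(1.0))
```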
Text encoder utilities
def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT backbone using KerasHub. A backbone has no
    # classification head, so no num_classes argument is passed.
    bert = keras_hub.models.BertBackbone.from_preset("bert_base_en_uncased")
    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs: each feature holds two 128-token sequences
    # flattened to length 256 (see preprocess_text()).
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")
Multimodal model utilities
def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs. Keep them in a dict so that their structure
    # matches the dict inputs of the text encoder.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }

    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)

    # Dict inputs keyed by name line up with the batches yielded by
    # UnifiedPyDataset.
    inputs = {"image_1": image_1, "image_2": image_2, **text_inputs}
    return keras.Model(inputs, outputs)
multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)
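You can also inspect the structure of the individual encoders by setting the expand_nested argument of plot_model() to True. You are encouraged to experiment with the different hyperparameters involved in building this model and observe how the final performance is affected.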
multimodal_model.compile(
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)
27/27 ━━━━━━━━━━━━━━━━━━━━ 2508s 92s/step - accuracy: 0.7231 - loss: 1.8382 - val_accuracy: 0.8222 - val_loss: 1.7304
_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
4/4 ━━━━━━━━━━━━━━━━━━━━ 256s 49s/step - accuracy: 0.8113 - loss: 1.8000
Accuracy on the test set: 82.0%.
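Incorporating regularization:
When trained for more epochs than the single epoch used above, the model may start to overfit and could benefit from regularization. Dropout (Srivastava et al.) is a simple yet powerful regularization technique that we can use in our model. But how should we apply it here?
We could always introduce Dropout (keras.layers.Dropout) between different layers of the model. But here is another recipe. Our model expects inputs from two different data modalities. What if either of the modalities is not present during inference? To account for this, we can introduce Dropout to the individual projections just before they get concatenated: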
# Inside create_multimodal_model(); rate is the dropout rate to use (e.g. 0.2).
vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
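Each Dropout layer in the snippet acts element-wise on its projection and only during training: every entry is zeroed with probability rate, and the surviving entries are scaled by 1 / (1 - rate) so that the expected value is unchanged; at inference the layer passes values through untouched. The pure-Python sketch below illustrates that "inverted dropout" behaviour; inverted_dropout is a hypothetical name, and Keras' own implementation draws its random mask differently.

```python
import random


def inverted_dropout(values, rate, rng):
    """Training-time dropout: zero each entry with probability `rate` and
    scale the survivors by 1 / (1 - rate). Requires 0 <= rate < 1."""
    keep_scale = 1.0 / (1.0 - rate)
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


rng = random.Random(0)
projection = [0.5, -1.0, 2.0, 0.25]
print(inverted_dropout(projection, rate=0.5, rng=rng))
```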
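Focusing on what matters most:
Do all parts of the images correspond equally to their textual counterparts? That is likely not the case. To make our model focus only on the parts of the images that relate most closely to their corresponding texts, we can use "cross-attention":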
# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)
# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
[vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])
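keras.layers.Attention implements Luong-style dot-product attention: scores are query · keyᵀ (multiplied by a learned scalar when use_scale=True), a softmax turns them into weights over the value positions, and the output is the weighted sum of the values; with a two-element input list, the value also serves as the key. The layer expects 3D tensors of shape (batch, steps, dim), so the 2D projections in the snippet would likely need an extra steps axis (for example via keras.ops.expand_dims) first. Below is a pure-Python sketch for a single query vector; dot_product_attention is a hypothetical name, the fixed scale stands in for the learned scalar, and the dropout on the attention weights is omitted.

```python
import math


def dot_product_attention(query, values, scale=1.0):
    """Luong-style attention for one query vector over a list of value vectors.

    The values double as keys, as when keras.layers.Attention receives
    [query, value].
    """
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in values]
    max_score = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    context = [
        sum(w * value[d] for w, value in zip(weights, values))
        for d in range(len(values[0]))
    ]
    return context, weights


context, weights = dot_product_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
# The value aligned with the query receives the larger weight.
print(weights)
```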
To see this in action, refer to this notebook.
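Dealing with class imbalance:
The dataset suffers from class imbalance. Investigating the confusion matrix of the above model reveals that it performs poorly on the minority classes. If we had used a weighted loss, the training would have been more guided. You can check out this notebook, which takes class imbalance into account during model training.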
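One common way to weight the loss is the "balanced" heuristic also used by scikit-learn, n_samples / (n_classes × count). The sketch below computes such weights from the class counts printed earlier (819, 92 and 89 over 1,000 rows), keyed by the integer ids in label_map. balanced_class_weights is a hypothetical helper, and the notebook linked above may weight the classes differently.

```python
def balanced_class_weights(counts):
    """Weight n_samples / (n_classes * count) for each class."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}


# Counts from df["label"].value_counts(), keyed by the ids in label_map.
counts = {0: 92, 1: 89, 2: 819}  # Contradictory, Implies, NoEntailment
weights = balanced_class_weights(counts)
print({label: round(w, 3) for label, w in weights.items()})
# {0: 3.623, 1: 3.745, 2: 0.407}
```

A dict like this could then be passed as multimodal_model.fit(train_ds, ..., class_weight=weights); ideally the counts would be recomputed on train_df rather than on the full dataframe.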
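Using only text inputs:
Also, what if we had only used text inputs for the entailment task? Because of the nature of the text inputs encountered on social media platforms, text inputs alone would have hurt the final performance. Under a similar training setup, by only using text inputs we get a top-1 accuracy of 67.14% on the same test set. Refer to this notebook for details.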
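Finally, here is a table comparing the different approaches taken for the entailment task (top-1 accuracy on the test set):
Type | Standard cross-entropy | Loss-weighted cross-entropy | Focal loss |
---|---|---|---|
Multimodal | 77.86% | 67.86% | 86.43% |
Only text | 67.14% | 11.43% | 37.86% |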
You can check out this repository to learn more about how the experiments were conducted to obtain these numbers.
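The focal loss column refers to a loss that multiplies each example's cross-entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class, so that confidently classified examples contribute less. Below is a minimal pure-Python sketch for a single example. focal_loss is a hypothetical name, gamma=2 is the value commonly used in the focal loss paper, the optional per-class alpha factor is omitted, and the exact variant used in the repository may differ. In Keras itself, keras.losses.CategoricalFocalCrossentropy provides a focal loss for one-hot targets.

```python
import math


def focal_loss(probs, label, gamma=2.0):
    """Focal loss for one example: -(1 - p_t) ** gamma * log(p_t).

    probs: predicted class probabilities; label: integer class id.
    With gamma=0 this reduces to the standard (sparse) cross-entropy.
    """
    p_t = probs[label]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


probs = [0.1, 0.2, 0.7]
print(focal_loss(probs, 2, gamma=0.0))  # equals -log(0.7)
print(focal_loss(probs, 2))  # well-classified example: strongly down-weighted
print(focal_loss(probs, 0))  # poorly classified example: weight close to 1
```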
You can use the trained model hosted on the Hugging Face Hub and try the demo on Hugging Face Spaces.