程式碼範例 / 自然語言處理 / 多模態蘊涵

多模態蘊涵

作者: Sayak Paul
建立日期 2021/08/08
上次修改日期 2025/01/03
描述: 訓練用於預測蘊涵關係的多模態模型。

ⓘ 這個範例使用 Keras 2

在 Colab 中檢視 GitHub 原始碼


簡介

在這個範例中,我們將建立並訓練一個用於預測多模態蘊涵的模型。我們將使用 Google Research 最近推出的 多模態蘊涵資料集

什麼是多模態蘊涵?

在社群媒體平台上,為了稽核和調節內容,我們可能需要近乎即時地找出以下問題的答案:

  • 給定的資訊是否與另一個資訊相矛盾?
  • 給定的資訊是否暗示另一個資訊?

在自然語言處理中,這個任務稱為分析文本蘊涵。然而,這僅在資訊來自文字內容時成立。在實務中,通常可用資訊不僅來自文字內容,還來自文字、圖像、音訊、影片等的多模態組合。多模態蘊涵只是將文本蘊涵擴展到各種新的輸入模態。

需求

這個範例需要 TensorFlow 2.5 或更高版本。此外,BERT 模型 (Devlin 等人) 需要 TensorFlow Hub 和 TensorFlow Text。這些函式庫可以使用以下命令安裝:

!pip install -q tensorflow_text
 [notice] A new release of pip is available: 24.0 -> 24.3.1
 [notice] To update, run: pip install --upgrade pip

匯入

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os

os.environ["KERAS_BACKEND"] = "jax"  # or tensorflow, or torch

import keras
import keras_hub
from keras.utils import PyDataset

定義標籤對應

label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}

收集資料集

原始資料集可在這裡取得。它帶有圖像的 URL,這些圖像託管在 Twitter 的照片儲存系統 Photo Blob Storage (簡稱 PBS) 上。我們將使用下載的圖像以及原始資料集附帶的其他資料。感謝 Nilabhra Roy Chowdhury 處理圖像資料的準備工作。

image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)

讀取資料集並應用基本預處理

df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[
    0:1000
]  # Resources conservation since these are examples and not SOTA
df.sample(10)
id_1 text_1 image_1 id_2 text_2 image_2 標籤
815 1370730009921343490 黏性炸彈是威脅,因為它們有磁鐵... http://pbs.twimg.com/media/EwXOFrgVIAEkfjR.jpg 1370731764906295307 黏性炸彈是威脅,因為它們有磁鐵... http://pbs.twimg.com/media/EwXRK_3XEAA6Q6F.jpg 無蘊涵
615 1364119737446395905 #巨蟹座 2.23.21 每日星座運勢 ♊️❤️✨ #Hor... http://pbs.twimg.com/media/Eu5Te44VgAIo1jZ.jpg 1365218087906078720 #巨蟹座 2.26.21 每日星座運勢 ♊️❤️✨ #Hor... http://pbs.twimg.com/media/EvI6nW4WQAA4_E_.jpg 無蘊涵
624 1335542260923068417 馴鹿賽跑又回來了,今年的賽跑是... http://pbs.twimg.com/media/Eoi99DyXEAE0AFV.jpg 1335872932267122689 戴上你的紅鼻子和鹿角,參加 2020 年的... http://pbs.twimg.com/media/Eon5Wk7XUAE-CxN.jpg 無蘊涵
970 1345058844439949312 線上調查需要參與者!\n\n主題... http://pbs.twimg.com/media/Eqqb4_MXcAA-Pvu.jpg 1361211461792632835 針對蘇... 的頂級研究需要參與者 http://pbs.twimg.com/media/EuPz0GwXMAMDklt.jpg 無蘊涵
456 1379831489043521545 為 @NanoBiteTSF 繪製的委託畫作,享受兄弟們和... http://pbs.twimg.com/media/EyVf0_VXMAMtRaL.jpg 1380660763749142531 為 @NanoBiteTSF 繪製的另一幅委託畫作,希望你... http://pbs.twimg.com/media/EykW0iXXAAA2SBC.jpg 無蘊涵
917 1336180735191891968 (2/10)\n(首爾中區) 市場群聚 ->\n... http://pbs.twimg.com/media/EosRFpGVQAIeuYG.jpg 1356113330536996866 (3/11)\n(首爾東大門區) 考試院群聚... http://pbs.twimg.com/media/EtHhj7QVcAAibvF.jpg 無蘊涵
276 1339270210029834241 今天自由的訊息傳到基索羅、盧... http://pbs.twimg.com/media/EpVK3pfXcAAZ5Du.jpg 1340881971132698625 今天自由的訊息將傳到... http://pbs.twimg.com/media/EpvDorkXYAEyz4g.jpg 暗示
35 1360186999836200961 阿根廷的比特幣 - Google 趨勢 https://t... http://pbs.twimg.com/media/EuBa3UxXYAMb99_.jpg 1382778703055228929 阿根廷想要 #比特幣 https://127.0.0.1/9lNxJdxX... http://pbs.twimg.com/media/EzCbUFNXMAABwPD.jpg 暗示
762 1370824756400959491 $HSBA.L:長期趨勢是正向的,且... http://pbs.twimg.com/media/EwYl2hPWYAE2niq.png 1374347458126475269 雖然技術評級僅為中等,但... http://pbs.twimg.com/media/ExKpuwrWgAAktg4.png 無蘊涵
130 1373789433607172097 我剛看完《泰德拉索》S01 | E05 集... http://pbs.twimg.com/media/ExCuNbDXAAQaPiL.jpg 1374913509662806016 我剛看完《泰德拉索》S01 | E06 集... http://pbs.twimg.com/media/ExSsjRQWgAUVRPz.jpg 矛盾

我們感興趣的欄位如下:

  • text_1
  • image_1
  • text_2
  • image_2
  • 標籤

蘊涵任務的公式如下:

給定 (text_1, image_1) 和 (text_2, image_2) 的配對,它們是否蘊含(或不蘊含或矛盾)彼此?

我們已經下載了圖像。image_1 以下載為 id1 作為其檔案名稱,而 image2 以下載為 id2 作為其檔案名稱。在下一步中,我們將在 df 中新增兩個欄位 - image_1image_2 的檔案路徑。

images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    extentsion_one = current_row["image_1"].split(".")[-1]
    extentsion_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extentsion_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extentsion_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths

# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])

資料集視覺化

def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")


random_idx = random.choice(range(len(df)))
visualize(random_idx)

random_idx = random.choice(range(len(df)))
visualize(random_idx)

png

Text one: World #water day reminds that we should follow the #guidelines to save water for us. This Day is an #opportunity to learn more about water related issues, be #inspired to tell others and take action to make a difference. Just remember, every #drop counts.
#WorldWaterDay2021 https://127.0.0.1/bQ9Hp53qUj
Text two: Water is an extremely precious resource without which life would be impossible. We need to ensure that water is used judiciously, this #WorldWaterDay, let us pledge to reduce water wastage and conserve it.
#WorldWaterDay2021 https://127.0.0.1/0KWnd8Kn8r
Label: NoEntailment

png

Text one: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟬: 𝗗𝗬𝗟𝗔𝗡 𝗙𝗜𝗧𝗭𝗦𝗜𝗠𝗢𝗡𝗦
Dylan Fitzsimons is a young passionate greyhound supporter. 
He and @Drakesport enjoy a great chat about everything greyhounds!
Listen: https://127.0.0.1/B2XgMp0yaO
#GoGreyhoundRacing #ThisRunsDeep #TalkingDogs https://127.0.0.1/crBiSqHUvp
Text two: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟳: 𝗣𝗜𝗢 𝗕𝗔𝗥𝗥𝗬 🎧
Well known within greyhound circles, Pio Barry shares some wonderful greyhound racing stories with @Drakesport in this podcast episode.
A great chat. 
Listen: https://127.0.0.1/mJTVlPHzp0
#TalkingDogs #GoGreyhoundRacing #ThisRunsDeep https://127.0.0.1/QbxtCpLcGm
Label: NoEntailment

訓練/測試分割

資料集受到類別不平衡問題的影響。我們可以在以下儲存格中確認這一點。

df["label"].value_counts()
label
NoEntailment     819
Contradictory     92
Implies           89
Name: count, dtype: int64

為了考量這一點,我們將進行分層分割。

# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 855
Total validation examples: 45
Total test examples: 100

資料輸入管道

Keras Hub 提供各種 BERT 系列模型。每個模型都帶有對應的預處理層。您可以從此資源中瞭解有關這些模型及其預處理層的更多資訊。

為了使此範例的執行時間相對較短,我們將使用原始 BERT 模型的基本未經編譯變體。

使用 KerasHub 進行文字預處理

text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=128,
)

在範例輸入上執行預處理器

idx = random.choice(range(len(train_df)))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")

test_text = [sample_text_1, sample_text_2]
text_preprocessed = text_preprocessor(test_text)

print("Keys           : ", list(text_preprocessed.keys()))
print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
print("Token Ids       : ", text_preprocessed["token_ids"][0, :16])
print(" Shape Padding Mask     : ", text_preprocessed["padding_mask"].shape)
print("Padding Mask     : ", text_preprocessed["padding_mask"][0, :16])
print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
print("Segment Ids       : ", text_preprocessed["segment_ids"][0, :16])
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Text 1: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered  02 bags on 20.02.2021 at Station platform and in T/No.08310 Spl. respectively and  handed over to their actual owner correctly. @RPF_INDIA https://127.0.0.1/bdEBl2egIc
Text 2: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered  02 bags on 20.02.2021 at Station platform and in T/No.08310 (JAT-SBP) Spl. respectively and  handed over to their actual owner correctly. @RPF_INDIA https://127.0.0.1/Q5l2AtA4uq
Keys           :  ['token_ids', 'padding_mask', 'segment_ids']
Shape Token Ids :  (2, 128)
Token Ids       :  [  101  1996  1054 14376  8840 11783 16098  1998  6045  2401  2695  1997
  8086  2072  2407  2031]
 Shape Padding Mask     :  (2, 128)
Padding Mask     :  [ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]
Shape Segment Ids :  (2, 128)
Segment Ids       :  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

我們現在將從資料框架中建立 tf.data.Dataset 物件。

請注意,文字輸入將作為資料輸入管道的一部分進行預處理。但是,預處理模組也可以是其對應的 BERT 模型的一部分。這有助於減少訓練/服務偏差,並讓我們的模型使用原始文字輸入進行操作。請遵循本教學課程以瞭解有關如何直接將預處理模組合併到模型中的更多資訊。

def dataframe_to_dataset(dataframe):
    columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
    ds = UnifiedPyDataset(
        dataframe,
        batch_size=32,
        workers=4,
    )
    return ds

預處理工具

bert_input_features = ["padding_mask", "segment_ids", "token_ids"]


def preprocess_text(text_1, text_2):
    output = text_preprocessor([text_1, text_2])
    output = {
        feature: keras.ops.reshape(output[feature], [-1])
        for feature in bert_input_features
    }
    return output

建立最終資料集,方法改編自 PyDataset 文件字串。

class UnifiedPyDataset(PyDataset):
    """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""

    def __init__(
        self,
        df,
        batch_size=32,
        workers=4,
        use_multiprocessing=False,
        max_queue_size=10,
        **kwargs,
    ):
        """
        Args:
            df: pandas DataFrame with data
            batch_size: Batch size for dataset
            workers: Number of workers to use for parallel loading (Keras)
            use_multiprocessing: Whether to use multiprocessing
            max_queue_size: Maximum size of the data queue for parallel loading
        """
        super().__init__(**kwargs)
        self.dataframe = df
        columns = ["image_1_path", "image_2_path", "text_1", "text_2"]

        # image files
        self.image_x_1 = self.dataframe["image_1_path"]
        self.image_x_2 = self.dataframe["image_1_path"]
        self.image_y = self.dataframe["label_idx"]

        # text files
        self.text_x_1 = self.dataframe["text_1"]
        self.text_x_2 = self.dataframe["text_2"]
        self.text_y = self.dataframe["label_idx"]

        # general
        self.batch_size = batch_size
        self.workers = workers
        self.use_multiprocessing = use_multiprocessing
        self.max_queue_size = max_queue_size

    def __getitem__(self, index):
        """
        Fetches a batch of data from the dataset at the given index.
        """

        # Return x, y for batch idx.
        low = index * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.

        high_image_1 = min(low + self.batch_size, len(self.image_x_1))
        high_image_2 = min(low + self.batch_size, len(self.image_x_2))

        high_text_1 = min(low + self.batch_size, len(self.text_x_1))
        high_text_2 = min(low + self.batch_size, len(self.text_x_1))

        # images files
        batch_image_x_1 = self.image_x_1[low:high_image_1]
        batch_image_y_1 = self.image_y[low:high_image_1]

        batch_image_x_2 = self.image_x_2[low:high_image_2]
        batch_image_y_2 = self.image_y[low:high_image_2]

        # text files
        batch_text_x_1 = self.text_x_1[low:high_text_1]
        batch_text_y_1 = self.text_y[low:high_text_1]

        batch_text_x_2 = self.text_x_2[low:high_text_2]
        batch_text_y_2 = self.text_y[low:high_text_2]

        # image number 1 inputs
        image_1 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1
        ]
        image_1 = [
            (  # exeperienced some shapes which were different from others.
                np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
                if img.shape[2] == 4
                else img
            )
            for img in image_1
        ]
        image_1 = np.array(image_1)

        # Both text inputs to the model, return a dict for inputs to BertBackbone
        text = {
            key: np.array(
                [
                    d[key]
                    for d in [
                        preprocess_text(file_path1, file_path2)
                        for file_path1, file_path2 in zip(
                            batch_text_x_1, batch_text_x_2
                        )
                    ]
                ]
            )
            for key in ["padding_mask", "token_ids", "segment_ids"]
        }

        # Image number 2 model inputs
        image_2 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2
        ]
        image_2 = [
            (  # exeperienced some shapes which were different from others
                np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
                if img.shape[2] == 4
                else img
            )
            for img in image_2
        ]
        # Stack the list comprehension to an nd.array
        image_2 = np.array(image_2)

        return (
            {
                "image_1": image_1,
                "image_2": image_2,
                "padding_mask": text["padding_mask"],
                "segment_ids": text["segment_ids"],
                "token_ids": text["token_ids"],
            },
            # Target lables
            np.array(batch_image_y_1),
        )

    def __len__(self):
        """
        Returns the number of batches in the dataset.
        """
        return math.ceil(len(self.dataframe) / self.batch_size)

建立訓練、驗證和測試資料集

def prepare_dataset(dataframe):
    ds = dataframe_to_dataset(dataframe)
    return ds


train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df)
test_ds = prepare_dataset(test_df)

模型建立工具

我們的最終模型將接受兩個圖像及其文字對應項。雖然圖像將直接饋送到模型中,但文字輸入將首先進行預處理,然後才會進入模型。以下是此方法的視覺說明:

模型包含以下元素:

  • 圖像的獨立編碼器。我們將使用在 ImageNet-1k 資料集上預先訓練的 ResNet50V2 作為此編碼器。
  • 圖像的獨立編碼器。預先訓練的 BERT 將用於此編碼器。

提取個別嵌入後,它們將會投影到相同的空間中。最後,它們的投影將會串連在一起,並饋送到最終分類層。

這是一個涉及以下類別的多類別分類問題:

  • 無蘊涵
  • 暗示
  • 矛盾

project_embeddings()create_vision_encoder()create_text_encoder() 工具程式參考自此範例

投影工具程式

def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = keras.ops.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings

視覺編碼器工具程式

def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input image.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model
    # concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")

文字編碼器工具程式

def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT BackBone using KerasHub.
    bert = keras_hub.models.BertBackbone.from_preset(
        "bert_base_en_uncased", num_classes=3
    )

    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")

多模態模型工具程式

def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }
    text_inputs = list(text_inputs.values())
    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, *text_inputs], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)

png

您也可以透過將 plot_model()expand_nested 參數設定為 True 來檢查各個編碼器的結構。建議您嘗試建構此模型時使用的不同超參數,並觀察最終效能如何受到影響。


編譯並訓練模型

multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

1/27 ━━━━━━━━━━━━━━━━━━━━ 45:45 106秒/步 - 準確度:0.0625 - 損失:1.6335

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



2/27 ━━━━━━━━━━━━━━━━━━━━ 42:14 101秒/步 - 準確度:0.2422 - 損失:1.9508

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



3/27 ━━━━━━━━━━━━━━━━━━━━ 38:49 97秒/步 - 準確度:0.3524 - 損失:2.0126

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



4/27 ━━━━━━━━━━━━━━━━━━━━ 37:09 97秒/步 - 準確度:0.4284 - 損失:1.9870

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



5/27 ━━━━━━━━━━━━━━━━━━━━ 35:08 96秒/步 - 準確度:0.4815 - 損失:1.9855

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



6/27 ━━━━━━━━━━━━━━━━━━━━ 31:56 91秒/步 - 準確度:0.5210 - 損失:1.9939



7/27 ━━━━━━━━━━━━━━━━━━━━ 29:30 89秒/步 - 準確度:0.5512 - 損失:1.9980



8/27 ━━━━━━━━━━━━━━━━━━━━ 27:12 86秒/步 - 準確度:0.5750 - 損失:2.0061



9/27 ━━━━━━━━━━━━━━━━━━━━ 25:15 84秒/步 - 準確度:0.5956 - 損失:1.9959



10/27 ━━━━━━━━━━━━━━━━━━━━ 23:33 83秒/步 - 準確度:0.6120 - 損失:1.9738

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



11/27 ━━━━━━━━━━━━━━━━━━━━ 22:09 83秒/步 - 準確度:0.6251 - 損失:1.9579

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



12/27 ━━━━━━━━━━━━━━━━━━━━ 20:59 84秒/步 - 準確度:0.6357 - 損失:1.9524

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



13/27 ━━━━━━━━━━━━━━━━━━━━ 19:44 85秒/步 - 準確度:0.6454 - 損失:1.9439

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



14/27 ━━━━━━━━━━━━━━━━━━━━ 18:22 85秒/步 - 準確度:0.6540 - 損失:1.9346

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(23, 256))', 'Tensor(shape=(23, 256))', 'Tensor(shape=(23, 256))']
  warnings.warn(msg)



15/27 ━━━━━━━━━━━━━━━━━━━━ 16:52 84秒/步 - 準確度:0.6621 - 損失:1.9213

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



16/27 ━━━━━━━━━━━━━━━━━━━━ 15:29 85秒/步 - 準確度:0.6693 - 損失:1.9101

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



17/27 ━━━━━━━━━━━━━━━━━━━━ 14:08 85秒/步 - 準確度:0.6758 - 損失:1.9021

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



18/27 ━━━━━━━━━━━━━━━━━━━━ 12:45 85秒/步 - 準確度:0.6819 - 損失:1.8916

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



19/27 ━━━━━━━━━━━━━━━━━━━━ 11:24 86秒/步 - 準確度:0.6874 - 損失:1.8851

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



20/27 ━━━━━━━━━━━━━━━━━━━━ 10:00 86秒/步 - 準確度:0.6925 - 損失:1.8791

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



21/27 ━━━━━━━━━━━━━━━━━━━━ 8:36 86秒/步 - 準確度:0.6976 - 損失:1.8699

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



22/27 ━━━━━━━━━━━━━━━━━━━━ 7:11 86秒/步 - 準確度:0.7020 - 損失:1.8623

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



23/27 ━━━━━━━━━━━━━━━━━━━━ 5:46 87秒/步 - 準確度:0.7061 - 損失:1.8573

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



24/27 ━━━━━━━━━━━━━━━━━━━━ 4:20 87秒/步 - 準確度:0.7100 - 損失:1.8534

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



25/27 ━━━━━━━━━━━━━━━━━━━━ 2:54 87秒/步 - 準確度:0.7136 - 損失:1.8494

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



26/27 ━━━━━━━━━━━━━━━━━━━━ 1:27 87秒/步 - 準確度:0.7170 - 損失:1.8449

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



27/27 ━━━━━━━━━━━━━━━━━━━━ 0秒 88秒/步 - 準確度:0.7201 - 損失:1.8414

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(13, 256))', 'Tensor(shape=(13, 256))', 'Tensor(shape=(13, 256))']
  warnings.warn(msg)



27/27 ━━━━━━━━━━━━━━━━━━━━ 2508秒 92秒/步 - 準確度:0.7231 - 損失:1.8382 - 驗證準確度:0.8222 - 驗證損失:1.7304


評估模型

_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

1/4 ━━━━━━━━━━━━━━━━━━━━ 5:32 111秒/步 - 準確度:0.7812 - 損失:1.9384



2/4 ━━━━━━━━━━━━━━━━━━━━ 2:10 65秒/步 - 準確度:0.7969 - 損失:1.8931



3/4 ━━━━━━━━━━━━━━━━━━━━ 1:05 65秒/步 - 準確度:0.8056 - 損失:1.8200

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(4, 256))', 'Tensor(shape=(4, 256))', 'Tensor(shape=(4, 256))']
  warnings.warn(msg)



4/4 ━━━━━━━━━━━━━━━━━━━━ 0秒 49秒/步 - 準確度:0.8092 - 損失:1.8075



4/4 ━━━━━━━━━━━━━━━━━━━━ 256秒 49秒/步 - 準確度:0.8113 - 損失:1.8000

Accuracy on the test set: 82.0%.

關於訓練的其他注意事項

加入正規化:

訓練記錄顯示模型開始過度擬合,且可能受益於正規化。Dropout(Srivastava 等人)是一種簡單但功能強大的正規化技術,我們可以在模型中使用。但是我們應該如何在此處應用它呢?

我們始終可以在模型的不同層之間引入 Dropout(keras.layers.Dropout)。但這裡還有另一個方法。我們的模型預期來自兩種不同資料模式的輸入。如果在推論期間其中一種模式不存在會怎麼樣?為了考量到這一點,我們可以將 Dropout 引入到個別的投影中,就在它們被串連在一起之前

vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])

關注重要事項:

圖片的所有部分是否都與它們的文字對應部分相等?情況可能不是這樣。為了讓我們的模型只專注於圖片中與其對應文字部分最相關的重要位元,我們可以使用「交叉注意力」。

# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)

# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
    [vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])

若要查看實際運作情形,請參閱此筆記本

處理類別不平衡:

此資料集存在類別不平衡問題。調查上述模型的混淆矩陣會發現,它在少數類別上的表現很差。如果我們使用加權損失,那麼訓練會更有指導性。您可以查看此筆記本,其中在模型訓練期間會考量類別不平衡問題。

僅使用文字輸入:

另外,如果我們只在蘊含任務中使用文字輸入會怎麼樣?由於在社群媒體平台上遇到的文字輸入的本質,僅文字輸入會損害最終效能。在類似的訓練設定下,僅使用文字輸入,我們在相同的測試集上達到 67.14% 的前 1 準確度。請參閱此筆記本以了解詳細資訊。

最後,這是一個比較蘊含任務所採取不同方法的表格

類型 標準
交叉熵
損失加權
交叉熵
焦點損失
多模態 77.86% 67.86% 86.43%
僅文字 67.14% 11.43% 37.86%

您可以查看此儲存庫,以深入瞭解如何進行實驗以取得這些數字。


最後的說明

  • 在此範例中使用的架構對於可用的訓練資料點數量而言太大。它將受益於更多資料。
  • 我們使用了原始 BERT 模型的較小變體。使用較大變體很可能會提高此效能。TensorFlow Hub 提供許多不同的 BERT 模型,您可以試驗它們。
  • 我們讓預先訓練的模型保持凍結。在多模態蘊含任務中微調它們可能會產生更好的效能。
  • 我們為多模態蘊含任務建立了一個簡單的基準模型。已經提出了各種方法來解決蘊含問題。此簡報投影片組來自識別多模態蘊含教學課程,提供全面的概述。

您可以使用託管在Hugging Face Hub上的訓練模型,並在Hugging Face Spaces上試用示範。