Authors: Aritra Roy Gosthipaty, Ayush Thakur (equal contribution)
Date created: 2022/01/12
Last modified: 2024/01/15
Description: A Transformer-based architecture for video classification.
Videos are sequences of images. Suppose you have an image representation model (CNN, ViT, etc.) and a sequence model (RNN, LSTM, etc.) at hand, and you are asked to adapt them for video classification. The simplest approach would be to apply the image model to individual frames, use the sequence model to learn sequences of image features, and then apply a classification head to the learned sequence representation. The Keras example Video Classification with a CNN-RNN Architecture explains this approach in detail. Alternatively, you can build a hybrid Transformer-based video classification model, as shown in the Keras example Video Classification with Transformers.
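To contrast with the pure-Transformer model built below, here is a minimal, hypothetical sketch of that hybrid "image model per frame + sequence model" baseline. The frame count, frame size, and layer widths are illustrative assumptions, not values taken from this example.
import keras
from keras import layers

def build_cnn_rnn_baseline(num_frames=16, frame_shape=(64, 64, 3), num_classes=10):
    """Toy CNN-RNN video classifier: per-frame CNN -> GRU -> softmax head."""
    frames = keras.Input(shape=(num_frames, *frame_shape))
    # The same small CNN encodes every frame independently.
    frame_encoder = keras.Sequential(
        [
            layers.Conv2D(32, 3, activation="relu"),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, activation="relu"),
            layers.GlobalAveragePooling2D(),
        ]
    )
    per_frame_features = layers.TimeDistributed(frame_encoder)(frames)
    # A sequence model summarizes the per-frame features over time.
    video_features = layers.GRU(64)(per_frame_features)
    outputs = layers.Dense(num_classes, activation="softmax")(video_features)
    return keras.Model(frames, outputs)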
In this example, we minimally implement ViViT: A Video Vision Transformer by Arnab et al., a pure Transformer-based model for video classification. The authors propose a novel embedding scheme and a number of Transformer variants for modeling video clips. For simplicity, we implement the embedding scheme and one variant of the Transformer architecture.
This example requires the medmnist package, which you can install by running the code cell below.
!pip install -qq medmnist
import os
import io
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf # for data preprocessing only
import keras
from keras import layers, ops
# Setting seed for reproducibility
SEED = 42
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)
The hyperparameters were chosen via hyperparameter search. You can learn more about the process in the "Conclusion" section.
# DATA
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11
# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
# TRAINING
EPOCHS = 60
# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 3  # 3 x 3 x 3 tubelets for a 28^3 volume
# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8
For our example we use the MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification dataset. The videos are lightweight and easy to train on.
def download_and_prepare_dataset(data_info: dict):
"""Utility function to download the dataset.
Arguments:
data_info (dict): Dataset metadata.
"""
data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])
with np.load(data_path) as data:
# Get videos
train_videos = data["train_images"]
valid_videos = data["val_images"]
test_videos = data["test_images"]
# Get labels
train_labels = data["train_labels"].flatten()
valid_labels = data["val_labels"].flatten()
test_labels = data["test_labels"].flatten()
return (
(train_videos, train_labels),
(valid_videos, valid_labels),
(test_videos, test_labels),
)
# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]
# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]
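As an optional check of our own (not part of the original example), you can inspect the raw splits; each videos array holds 28x28x28 volumes and each labels array holds one class index per volume.
print("Train:", train_videos.shape, train_labels.shape)
print("Valid:", valid_videos.shape, valid_labels.shape)
print("Test: ", test_videos.shape, test_labels.shape)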
The `tf.data` pipeline
def preprocess(frames: tf.Tensor, label: tf.Tensor):
"""Preprocess the frames tensors and parse the labels."""
# Preprocess images
frames = tf.image.convert_image_dtype(
frames[
..., tf.newaxis
], # The new axis is to help for further processing with Conv3D layers
tf.float32,
)
# Parse label
label = tf.cast(label, tf.float32)
return frames, label
def prepare_dataloader(
videos: np.ndarray,
labels: np.ndarray,
loader_type: str = "train",
batch_size: int = BATCH_SIZE,
):
"""Utility function to prepare the dataloader."""
dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
if loader_type == "train":
dataset = dataset.shuffle(BATCH_SIZE * 2)
dataloader = (
dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
.batch(batch_size)
.prefetch(tf.data.AUTOTUNE)
)
return dataloader
trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
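As another optional sanity check (again, our own addition), pull one batch from the pipeline and confirm that the frames have shape (BATCH_SIZE, 28, 28, 28, 1) in float32, with one label per video.
sample_frames, sample_labels = next(iter(trainloader))
print(sample_frames.shape)  # (BATCH_SIZE, 28, 28, 28, 1)
print(sample_labels.shape)  # (BATCH_SIZE,)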
In ViT, an image is divided into patches, which are then spatially flattened; this process is known as tokenization. For a video, one can repeat this process for individual frames. Uniform frame sampling, as suggested by the authors, is a tokenization scheme in which we sample frames from the video clip and perform simple ViT tokenization.
Uniform Frame Sampling (source: ViViT paper)
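For illustration only, here is a hypothetical sketch of what a uniform frame sampling tokenizer could look like: sample evenly spaced frames, then apply ViT-style non-overlapping 2D patching to each sampled frame (here patch_size is a 2D tuple such as (8, 8)). This layer is our own assumption and is not used in the rest of the example, which relies on tubelet embedding instead.
class UniformFrameSampling(layers.Layer):
    """Hypothetical tokenizer: uniformly sampled frames + per-frame 2D patches."""

    def __init__(self, embed_dim, patch_size, num_sampled_frames, **kwargs):
        super().__init__(**kwargs)
        self.num_sampled_frames = num_sampled_frames
        # ViT-style non-overlapping 2D patch projection, shared across frames.
        self.projection = layers.TimeDistributed(
            layers.Conv2D(
                filters=embed_dim,
                kernel_size=patch_size,
                strides=patch_size,
                padding="VALID",
            )
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def build(self, input_shape):
        # Evenly spaced frame indices along the temporal axis.
        num_frames = input_shape[1]
        self.frame_indices = np.linspace(
            0, num_frames - 1, self.num_sampled_frames
        ).astype("int32")

    def call(self, videos):
        sampled_frames = ops.take(videos, self.frame_indices, axis=1)
        projected_patches = self.projection(sampled_frames)
        return self.flatten(projected_patches)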
Tubelet embedding differs in terms of capturing temporal information from the video. First, we extract volumes from the video; these volumes contain patches of frames together with temporal information. The volumes are then flattened to build video tokens.
Tubelet Embedding (source: ViViT paper)
class TubeletEmbedding(layers.Layer):
def __init__(self, embed_dim, patch_size, **kwargs):
super().__init__(**kwargs)
self.projection = layers.Conv3D(
filters=embed_dim,
kernel_size=patch_size,
strides=patch_size,
padding="VALID",
)
self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
def call(self, videos):
projected_patches = self.projection(videos)
flattened_patches = self.flatten(projected_patches)
return flattened_patches
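As a quick shape check of our own (not part of the original example): with 28x28x28 inputs and 8x8x8 tubelets, the strided Conv3D produces a 3x3x3 grid of tubelets, i.e. 27 tokens per video after flattening.
dummy_videos = np.zeros((1, *INPUT_SHAPE), dtype="float32")
tubelet_embedder = TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE)
print(tubelet_embedder(dummy_videos).shape)  # (1, 27, 128): 3 * 3 * 3 tubelet tokens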
This layer adds positional information to the encoded video tokens.
class PositionalEncoder(layers.Layer):
def __init__(self, embed_dim, **kwargs):
super().__init__(**kwargs)
self.embed_dim = embed_dim
def build(self, input_shape):
_, num_tokens, _ = input_shape
self.position_embedding = layers.Embedding(
input_dim=num_tokens, output_dim=self.embed_dim
)
self.positions = ops.arange(0, num_tokens, 1)
def call(self, encoded_tokens):
# Encode the positions and add it to the encoded tokens
encoded_positions = self.position_embedding(self.positions)
encoded_tokens = encoded_tokens + encoded_positions
return encoded_tokens
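Continuing the shape check above (again, our own illustration), the positional encoder adds a learned position embedding to each token and leaves the token tensor's shape unchanged.
positional_encoder = PositionalEncoder(embed_dim=PROJECTION_DIM)
encoded_tokens = positional_encoder(tubelet_embedder(dummy_videos))
print(encoded_tokens.shape)  # still (1, 27, 128)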
The authors suggest 4 variants of the Vision Transformer for video: spatio-temporal attention, factorized encoder, factorized self-attention, and factorized dot-product attention. In this example, we implement the Spatio-Temporal Attention model for simplicity. The following code snippet is heavily inspired by Image classification with Vision Transformer. You can also refer to the official repository of ViViT, which contains all the variants, implemented in JAX.
def create_vivit_classifier(
tubelet_embedder,
positional_encoder,
input_shape=INPUT_SHAPE,
transformer_layers=NUM_LAYERS,
num_heads=NUM_HEADS,
embed_dim=PROJECTION_DIM,
layer_norm_eps=LAYER_NORM_EPS,
num_classes=NUM_CLASSES,
):
# Get the input layer
inputs = layers.Input(shape=input_shape)
# Create patches.
patches = tubelet_embedder(inputs)
# Encode patches.
encoded_patches = positional_encoder(patches)
# Create multiple layers of the Transformer block.
for _ in range(transformer_layers):
# Layer normalization and MHSA
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
)(x1, x1)
# Skip connection
x2 = layers.Add()([attention_output, encoded_patches])
# Layer Normalization and MLP
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
x3 = keras.Sequential(
[
layers.Dense(units=embed_dim * 4, activation=ops.gelu),
layers.Dense(units=embed_dim, activation=ops.gelu),
]
)(x3)
# Skip connection
encoded_patches = layers.Add()([x3, x2])
# Layer normalization and Global average pooling.
representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
representation = layers.GlobalAvgPool1D()(representation)
# Classify outputs.
outputs = layers.Dense(units=num_classes, activation="softmax")(representation)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=outputs)
return model
def run_experiment():
# Initialize model
model = create_vivit_classifier(
tubelet_embedder=TubeletEmbedding(
embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
),
positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
)
# Compile the model with the optimizer, loss function
# and the metrics.
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Train the model.
_ = model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)
_, accuracy, top_5_accuracy = model.evaluate(testloader)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
return model
model = run_experiment()
Test accuracy: 76.72%
Test top 5 accuracy: 97.54%
NUM_SAMPLES_VIZ = 25
testsamples, labels = next(iter(testloader))
testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]
ground_truths = []
preds = []
videos = []
for i, (testsample, label) in enumerate(zip(testsamples, labels)):
# Generate gif
testsample = np.reshape(testsample.numpy(), (-1, 28, 28))
with io.BytesIO() as gif:
imageio.mimsave(gif, (testsample * 255).astype("uint8"), "GIF", fps=5)
videos.append(gif.getvalue())
# Get model prediction
output = model.predict(ops.expand_dims(testsample, axis=0))[0]
pred = np.argmax(output, axis=0)
ground_truths.append(label.numpy().astype("int"))
preds.append(pred)
def make_box_for_grid(image_widget, fit):
"""Make a VBox to hold caption/image for demonstrating option_fit values.
Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
"""
# Make the caption
if fit is not None:
fit_str = "'{}'".format(fit)
else:
fit_str = str(fit)
h = ipywidgets.HTML(value="" + str(fit_str) + "")
# Make the green box with the image widget inside it
boxb = ipywidgets.widgets.Box()
boxb.children = [image_widget]
# Compose into a vertical box
vb = ipywidgets.widgets.VBox()
vb.layout.align_items = "center"
vb.children = [h, boxb]
return vb
boxes = []
for i in range(NUM_SAMPLES_VIZ):
ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
true_class = info["label"][str(ground_truths[i])]
pred_class = info["label"][str(preds[i])]
caption = f"T: {true_class} | P: {pred_class}"
boxes.append(make_box_for_grid(ib, caption))
ipywidgets.widgets.GridBox(
boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
)
With a basic implementation, we achieve ~79-80% Top-1 accuracy on the test dataset.
The hyperparameters used in this tutorial were finalized by running a hyperparameter search with W&B Sweeps. You can find our sweeps result here and our quick analysis of the result here.
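For readers unfamiliar with W&B Sweeps, a hypothetical sweep configuration could look like the sketch below; the search method, parameter grid, metric name, project name, and train_fn are illustrative assumptions, not the configuration actually used for this example.
import wandb

sweep_config = {
    "method": "random",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"values": [1e-3, 1e-4, 1e-5]},
        "projection_dim": {"values": [64, 128]},
        "num_heads": {"values": [4, 8]},
        "num_layers": {"values": [4, 8]},
    },
}
sweep_id = wandb.sweep(sweep_config, project="vivit-medmnist")
# wandb.agent(sweep_id, function=train_fn, count=20)  # train_fn would wrap run_experiment with these values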
For further improvement, you could look into the following: using data augmentation for videos, using a better regularization scheme for training, and applying different variants of the Transformer model as in the paper.
We would like to thank Anurag Arnab (first author of ViViT) for helpful discussions. We are grateful to the Weights and Biases program for providing GPU credits.
You can use the trained model hosted on Hugging Face Hub and try out the demo on Hugging Face Spaces.