Authors: Aritra Roy Gosthipaty, Sayak Paul (equal contribution), converted to Keras 3 by Muhammad Anas Raza
Date created: 2021/12/10
Last modified: 2023/08/14
Description: Adaptively generating a smaller number of tokens for Vision Transformers.
Vision Transformers (Dosovitskiy et al.) and many other Transformer-based architectures (Liu et al., Yuan et al., etc.) have shown strong results in image recognition. Briefly, in the Vision Transformer architecture for image classification, an image is split into small patches, the patches are linearly projected and augmented with positional embeddings, and the resulting tokens are processed by a stack of Transformer blocks before a classification head.
If we take a 224x224 image and extract 16x16 patches, we get a total of 196 patches (also called tokens) per image. The number of patches increases as we raise the resolution, leading to a higher memory footprint. Could we use a reduced number of patches without compromising performance? Ryoo et al. investigated this question in TokenLearner: Adaptive Space-Time Tokenization for Videos. They introduced a novel module called TokenLearner that can help reduce the number of patches used by a Vision Transformer (ViT). By incorporating TokenLearner into the standard ViT architecture, they were able to reduce the amount of compute (measured in FLOPS) used by the model.
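To make the memory argument concrete, here is a quick back-of-the-envelope check (a small illustrative snippet, not part of the model code below): the token count grows quadratically with the image resolution.

# Number of tokens for a square image split into non-overlapping 16x16 patches.
# Doubling the resolution quadruples the number of tokens.
for image_size in (224, 448):
    num_tokens = (image_size // 16) ** 2
    print(f"{image_size}x{image_size} image -> {num_tokens} tokens")
# 224x224 image -> 196 tokens
# 448x448 image -> 784 tokens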
In this example, we implement the TokenLearner module and demonstrate its performance with a mini ViT on the CIFAR-10 dataset. We made use of the official TokenLearner code and related reference material.
import keras
from keras import layers
from keras import ops
from tensorflow import data as tf_data
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import math
Feel free to change the hyperparameters and check your results. The best way to build intuition about an architecture is to experiment with it.
# DATA
BATCH_SIZE = 256
AUTO = tf_data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10
# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
# TRAINING
EPOCHS = 1
# AUGMENTATION
IMAGE_SIZE = 48 # We will resize input images to this size.
PATCH_SIZE = 6 # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4
MLP_UNITS = [
PROJECTION_DIM * 2,
PROJECTION_DIM,
]
# TOKENLEARNER
NUM_TOKENS = 4
# Load the CIFAR-10 dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[:40000], y_train[:40000]),
(x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")
# Convert to tf.data.Dataset objects.
train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)
val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)
test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
Training samples: 40000
Validation samples: 10000
Testing samples: 10000
The augmentation pipeline consists of rescaling, resizing, random cropping, and random horizontal flipping:
data_augmentation = keras.Sequential(
[
layers.Rescaling(1 / 255.0),
layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
layers.RandomFlip("horizontal"),
],
name="data_augmentation",
)
Note that image data augmentation layers do not apply data transformations at inference time. This means that when these layers are called with training=False they behave differently. Refer to the documentation for more details.
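As a quick illustration of that behavior (reusing the x_train array and the data_augmentation pipeline defined above; exact outputs depend on the random state), the random layers only take effect when the pipeline is called with training=True:

# With training=False only the deterministic rescaling/resizing/cropping is
# applied, so repeated calls on the same batch are identical.
sample = x_train[:4]
print(np.allclose(data_augmentation(sample, training=False),
                  data_augmentation(sample, training=False)))  # True
# With training=True the random crop and flip are active, so two calls will
# generally produce different outputs.
print(np.allclose(data_augmentation(sample, training=True),
                  data_augmentation(sample, training=True)))  # Usually False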
The Transformer architecture consists mainly of multi-head self-attention layers and fully-connected feed-forward networks (MLPs). Both of these components are permutation invariant: they are not aware of feature order.
To overcome this problem, we inject the tokens with positional information. The position_embedding layer adds this positional information to the linearly projected tokens.
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super().__init__()
self.num_patches = num_patches
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = ops.expand_dims(
ops.arange(start=0, stop=self.num_patches, step=1), axis=0
)
encoded = patch + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config()
config.update({"num_patches": self.num_patches})
return config
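As a quick sanity check (illustrative only, using the hyperparameters defined above), the encoder leaves the token shape untouched; it only adds learned positional embeddings:

# Positional encoding preserves the (num_patches, projection_dim) token shape.
dummy_tokens = keras.Input(shape=(NUM_PATCHES, PROJECTION_DIM))
encoded_tokens = PatchEncoder(num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM)(
    dummy_tokens
)
print(encoded_tokens.shape)  # (None, 64, 128) with the defaults used here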
This serves as the fully-connected feed-forward block for our Transformer.
def mlp(x, dropout_rate, hidden_units):
# Iterate over the hidden units and
# add Dense => Dropout.
for units in hidden_units:
x = layers.Dense(units, activation=ops.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
The following figure presents a pictorial overview of the module (source).
The TokenLearner module takes an image-shaped tensor as input. It passes it through multiple single-channel convolutional layers that extract different spatial attention maps, each focusing on a different part of the input. These attention maps are then element-wise multiplied with the input, and the results are aggregated with pooling. This pooled output can be treated as a summary of the input and has far fewer patches (for example, 8) than the original (for example, 196).
Using multiple convolution layers helps with expressivity. Imposing a form of spatial attention helps retain relevant information from the inputs. Both of these components are crucial to make TokenLearner work, especially when we significantly reduce the number of patches.
def token_learner(inputs, number_of_tokens=NUM_TOKENS):
# Layer normalize the inputs.
x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs) # (B, H, W, C)
# Applying Conv2D => Reshape => Permute
# The reshape and permute is done to help with the next steps of
# multiplication and Global Average Pooling.
attention_maps = keras.Sequential(
[
# 3 layers of conv with gelu activation as suggested
# in the paper.
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation=ops.gelu,
padding="same",
use_bias=False,
),
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation=ops.gelu,
padding="same",
use_bias=False,
),
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation=ops.gelu,
padding="same",
use_bias=False,
),
# This conv layer will generate the attention maps
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation="sigmoid", # Note sigmoid for [0, 1] output
padding="same",
use_bias=False,
),
# Reshape and Permute
layers.Reshape((-1, number_of_tokens)), # (B, H*W, num_of_tokens)
layers.Permute((2, 1)),
]
)(
x
) # (B, num_of_tokens, H*W)
# Reshape the input to align it with the output of the conv block.
num_filters = inputs.shape[-1]
inputs = layers.Reshape((1, -1, num_filters))(inputs) # inputs == (B, 1, H*W, C)
# Element-Wise multiplication of the attention maps and the inputs
attended_inputs = (
ops.expand_dims(attention_maps, axis=-1) * inputs
) # (B, num_tokens, H*W, C)
# Global average pooling the element wise multiplication result.
outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)
return outputs
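To see the token reduction in action, here is a quick symbolic shape check (illustrative only): an 8x8 grid of projected patches, i.e. 64 tokens, is summarized into NUM_TOKENS = 4 tokens.

# 8 x 8 = 64 spatial tokens in, NUM_TOKENS summarized tokens out.
dummy_patches = keras.Input(shape=(8, 8, PROJECTION_DIM))  # (B, 8, 8, 128)
print(token_learner(dummy_patches, NUM_TOKENS).shape)      # (None, 4, 128)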
def transformer(encoded_patches):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
# Multi Head Self Attention layer 1.
attention_output = layers.MultiHeadAttention(
num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
# MLP layer 1.
x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x4, x2])
return encoded_patches
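Note that the block maps a (B, num_tokens, PROJECTION_DIM) tensor back to the same shape, so blocks can be stacked freely and the number of tokens can shrink mid-network, which is exactly what TokenLearner exploits. A quick symbolic check (illustrative only; each call below builds fresh, untrained layers):

# The Transformer block preserves the token shape for any token count.
print(transformer(keras.Input(shape=(NUM_PATCHES, PROJECTION_DIM))).shape)  # (None, 64, 128)
print(transformer(keras.Input(shape=(NUM_TOKENS, PROJECTION_DIM))).shape)   # (None, 4, 128)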
def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
inputs = layers.Input(shape=INPUT_SHAPE) # (B, H, W, C)
# Augment data.
augmented = data_augmentation(inputs)
    # Create patches and project the patches.
projected_patches = layers.Conv2D(
filters=PROJECTION_DIM,
kernel_size=(PATCH_SIZE, PATCH_SIZE),
strides=(PATCH_SIZE, PATCH_SIZE),
padding="VALID",
)(augmented)
_, h, w, c = projected_patches.shape
projected_patches = layers.Reshape((h * w, c))(
projected_patches
) # (B, number_patches, projection_dim)
# Add positional embeddings to the projected patches.
encoded_patches = PatchEncoder(
num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
)(
projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = layers.Dropout(0.1)(encoded_patches)
# Iterate over the number of layers and stack up blocks of
# Transformer.
for i in range(NUM_LAYERS):
# Add a Transformer block.
encoded_patches = transformer(encoded_patches)
# Add TokenLearner layer in the middle of the
# architecture. The paper suggests that anywhere
        # between 1/2 and 3/4 of the layers will work well.
if use_token_learner and i == NUM_LAYERS // 2:
_, hh, c = encoded_patches.shape
h = int(math.sqrt(hh))
encoded_patches = layers.Reshape((h, h, c))(
encoded_patches
) # (B, h, h, projection_dim)
encoded_patches = token_learner(
encoded_patches, token_learner_units
) # (B, num_tokens, c)
# Layer normalization and Global average pooling.
representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
representation = layers.GlobalAvgPool1D()(representation)
# Classify outputs.
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=outputs)
return model
As shown in the TokenLearner paper, it is almost always advantageous to include the TokenLearner module in the middle of the network.
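For reference, we can build the classifier with and without the TokenLearner module and compare parameter counts (a small illustrative check; the FLOPs figures reported later were measured with an external tool, not with this snippet):

# TokenLearner adds a few parameters (the extra conv layers) while reducing
# the number of tokens processed by the second half of the network.
vit_with_tl = create_vit_classifier(use_token_learner=True)
vit_without_tl = create_vit_classifier(use_token_learner=False)
print("Params with TokenLearner:   ", vit_with_tl.count_params())
print("Params without TokenLearner:", vit_without_tl.count_params())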
def run_experiment(model):
# Initialize the AdamW optimizer.
optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
# Compile the model with the optimizer, loss function
# and the metrics.
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Define callbacks
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
# Train the model.
_ = model.fit(
train_ds,
epochs=EPOCHS,
validation_data=val_ds,
callbacks=[checkpoint_callback],
)
model.load_weights(checkpoint_filepath)
_, accuracy, top_5_accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
vit_token_learner = create_vit_classifier()
run_experiment(vit_token_learner)
157/157 ━━━━━━━━━━━━━━━━━━━━ 303s 2s/step - accuracy: 0.1158 - loss: 2.4798 - top-5-accuracy: 0.5352 - val_accuracy: 0.2206 - val_loss: 2.0292 - val_top-5-accuracy: 0.7688
40/40 ━━━━━━━━━━━━━━━━━━━━ 5s 133ms/step - accuracy: 0.2298 - loss: 2.0179 - top-5-accuracy: 0.7723
Test accuracy: 22.9%
Test top 5 accuracy: 77.22%
We experimented with and without the TokenLearner in the mini ViT we implemented (with the same hyperparameters presented in this example). Here are our results:
| TokenLearner | # tokens in TokenLearner | Top-1 accuracy (averaged across 5 runs) | GFLOPs | TensorBoard |
| --- | --- | --- | --- | --- |
| N | - | 56.112% | 0.0184 | Link |
| Y | 8 | 56.55% | 0.0153 | Link |
| N | - | 56.37% | 0.0184 | Link |
| Y | 4 | 56.4980% | 0.0147 | Link |
| N | - (# Transformer layers: 8) | 55.36% | 0.0359 | Link |
TokenLearner is able to consistently outperform our mini ViT without the module. Interestingly, it also outperforms a deeper version of our mini ViT (with 8 layers). The authors report similar observations in the paper and attribute them to the adaptiveness of TokenLearner.
One should also note that the FLOPs count decreases considerably with the addition of the TokenLearner module. With a lower FLOPs count, the TokenLearner module is able to deliver better results. This aligns very well with the authors' findings.
Additionally, the authors introduced a newer version of TokenLearner for smaller amounts of training data. Quoting the authors:
Instead of using 4 conv. layers with small channels to implement spatial attention, this version uses 2 grouped conv. layers with more channels. It also uses softmax instead of sigmoid. We confirmed that this version works better when having limited training data, such as training with ImageNet1K from scratch.
We experimented with this module and summarize the results in the following table:
| # Groups | # Tokens | Top-1 accuracy | GFLOPs | TensorBoard |
| --- | --- | --- | --- | --- |
| 4 | 4 | 54.638% | 0.0149 | Link |
| 8 | 8 | 54.898% | 0.0146 | Link |
| 4 | 8 | 55.196% | 0.0149 | Link |
Note that we used the same hyperparameters presented in this example. Our implementation is available in this notebook. We acknowledge that the results with this new TokenLearner module are slightly off from what would be expected, and this might be mitigated with hyperparameter tuning.
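For readers who want a starting point, below is a minimal sketch of how such a module could look, based purely on our reading of the quote above. The 1x1 kernel size, the channel width, the GELU activation, and the group count are our assumptions; the actual implementation used for the table above lives in the linked notebook.

def token_learner_v11(inputs, number_of_tokens=NUM_TOKENS, num_groups=4):
    # Layer normalize the inputs.
    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)  # (B, H, W, C)
    # Two grouped convolutions with more channels, as described in the quote.
    # Assumes the channel count is divisible by `num_groups`.
    attention_maps = keras.Sequential(
        [
            layers.Conv2D(
                filters=PROJECTION_DIM,
                kernel_size=(1, 1),
                groups=num_groups,
                activation=ops.gelu,
                use_bias=False,
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(1, 1),
                use_bias=False,
            ),
            # Reshape and permute to (B, num_tokens, H*W), then normalize each
            # token's attention map with a softmax over the spatial positions.
            layers.Reshape((-1, number_of_tokens)),
            layers.Permute((2, 1)),
            layers.Softmax(axis=-1),
        ]
    )(x)
    # Element-wise multiply the attention maps with the inputs and pool,
    # exactly as in the original module above.
    num_filters = inputs.shape[-1]
    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # (B, 1, H*W, C)
    attended_inputs = ops.expand_dims(attention_maps, axis=-1) * inputs
    outputs = ops.mean(attended_inputs, axis=2)  # (B, num_tokens, C)
    return outputs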
Note: To compute the FLOPs of our models, we used this utility from this repository.
You may have noticed that adding the TokenLearner module increases the number of parameters of the base network. But that does not mean it is less efficient, as shown by Dehghani et al. Similar findings were reported by Bello et al. as well. The TokenLearner module helps reduce the FLOPs of the overall network and thereby helps reduce the memory footprint.
We are grateful to JarvisLabs and the Google Developers Experts program for helping with GPU credits. Also, we are thankful to Michael Ryoo (first author of TokenLearner) for fruitful discussions.