Author: Khalid Salama
Date created: 2021/04/30
Last modified: 2023/12/30
Description: Implementing the Perceiver model for image classification.
This example implements the Perceiver: General Perception with Iterative Attention model by Andrew Jaegle et al. for image classification, and demonstrates it on the CIFAR-100 dataset.
The Perceiver model leverages an asymmetric attention mechanism to iteratively distill inputs into a tight latent bottleneck, allowing it to scale to handle very large inputs.
In other words: let's assume that your input data array (e.g. an image) has M elements (i.e. patches), where M is large. In a standard Transformer model, a self-attention operation is performed over the M elements, with a complexity of O(M^2). The Perceiver model instead creates a latent array of N elements, where N << M, and iteratively performs two operations:

1. Cross-attention between the latent array and the data array, with complexity O(M.N).
2. Self-attention over the latent array, with complexity O(N^2).
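To make the savings concrete, here is a rough, illustrative calculation (not part of the original example) that counts attention-score entries using the values configured later in this example, M = 1024 patches and N = 256 latent elements:

M, N = 1024, 256  # data elements (patches) and latent elements
standard_self_attention = M * M  # O(M^2): 1,048,576 score entries
perceiver_per_iteration = M * N + N * N  # O(M.N) + O(N^2): 262,144 + 65,536 = 327,680
print(standard_self_attention, perceiver_per_iteration)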
This example requires Keras 3.0 or higher.
import keras
from keras import layers, activations, ops
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 64
num_epochs = 2 # You should actually use 50 epochs!
dropout_rate = 0.2
image_size = 64 # We'll resize input images to this size.
patch_size = 2  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2 # Size of the data array.
latent_dim = 256 # Size of the latent array.
projection_dim = 256 # Embedding size of each element in the data and latent arrays.
num_heads = 8 # Number of Transformer heads.
ffn_units = [
    projection_dim,
    projection_dim,
] # Size of the Transformer Feedforward network.
num_transformer_blocks = 4
num_iterations = 2 # Repetitions of the cross-attention and Transformer modules.
classifier_units = [
    projection_dim,
    num_classes,
] # Size of the Feedforward network of the final classifier.
print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
print(f"Latent array shape: {latent_dim} X {projection_dim}")
print(f"Data array shape: {num_patches} X {projection_dim}")
Image size: 64 X 64 = 4096
Patch size: 2 X 2 = 4
Patches per image: 1024
Elements per patch (3 channels): 12
Latent array shape: 256 X 256
Data array shape: 1024 X 256
Note that, in order to use each pixel as an individual input in the data array, set patch_size to 1.
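For instance, here is a hypothetical per-pixel configuration (illustrative only; the training below keeps patch_size = 2):

pixel_patch_size = 1  # hypothetical: every pixel becomes its own data element
print((image_size // pixel_patch_size) ** 2)  # 64 * 64 = 4096 elements of 1 * 1 * 3 = 3 values each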
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
def create_ffn(hidden_units, dropout_rate):
    ffn_layers = []
    for units in hidden_units[:-1]:
        ffn_layers.append(layers.Dense(units, activation=activations.gelu))

    ffn_layers.append(layers.Dense(units=hidden_units[-1]))
    ffn_layers.append(layers.Dropout(dropout_rate))

    ffn = keras.Sequential(ffn_layers)
    return ffn
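As a quick sanity check (an illustrative snippet, not part of the original example), the helper maps the last dimension of its input to the final entry of hidden_units:

ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
print(ffn(ops.ones((1, 4, projection_dim))).shape)  # (1, 4, 256)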
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = ops.shape(images)[0]
        patches = ops.image.extract_patches(
            image=images,
            size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            dilation_rate=1,
            padding="valid",
        )
        patch_dims = patches.shape[-1]
        patches = ops.reshape(patches, [batch_size, -1, patch_dims])
        return patches
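An illustrative shape check (assuming the hyperparameters defined above): a 64 x 64 RGB image yields 1024 patches of 2 * 2 * 3 = 12 values each.

dummy_images = ops.ones((1, image_size, image_size, 3))
print(Patches(patch_size)(dummy_images).shape)  # (1, 1024, 12)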
The PatchEncoder layer linearly transforms the patches by projecting them into vectors of size projection_dim. In addition, it adds a learnable position embedding to the projected vectors.
Note that the original Perceiver paper uses Fourier feature positional encodings.
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        positions = ops.arange(start=0, stop=self.num_patches, step=1)
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded
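Another illustrative check: encoding lifts each 12-dimensional patch to projection_dim channels while keeping the patch count unchanged.

dummy_patches = ops.ones((1, num_patches, (patch_size ** 2) * 3))
print(PatchEncoder(num_patches, projection_dim)(dummy_patches).shape)  # (1, 1024, 256)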
The Perceiver consists of two modules: a cross-attention module and a standard Transformer with self-attention.
The cross-attention module expects a (latent_dim, projection_dim) latent array and a (data_dim, projection_dim) data array as inputs, and produces a (latent_dim, projection_dim) latent array as output. To apply cross-attention, the query vectors are generated from the latent array, while the key and value vectors are generated from the encoded image.
Note that the data array in this example is the image, where data_dim is set to num_patches.
def create_cross_attention_module(
    latent_dim, data_dim, projection_dim, ffn_units, dropout_rate
):
    inputs = {
        # Receive the latent array as an input of shape [1, latent_dim, projection_dim].
        "latent_array": layers.Input(
            shape=(latent_dim, projection_dim), name="latent_array"
        ),
        # Receive the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
        "data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
    }

    # Apply layer norm to the inputs.
    latent_array = layers.LayerNormalization(epsilon=1e-6)(inputs["latent_array"])
    data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])

    # Create query tensor: [1, latent_dim, projection_dim].
    query = layers.Dense(units=projection_dim)(latent_array)
    # Create key tensor: [batch_size, data_dim, projection_dim].
    key = layers.Dense(units=projection_dim)(data_array)
    # Create value tensor: [batch_size, data_dim, projection_dim].
    value = layers.Dense(units=projection_dim)(data_array)

    # Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].
    attention_output = layers.Attention(use_scale=True, dropout=0.1)(
        [query, key, value], return_attention_scores=False
    )
    # Skip connection 1.
    attention_output = layers.Add()([attention_output, latent_array])

    # Apply layer norm.
    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)
    # Apply Feedforward network.
    ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
    outputs = ffn(attention_output)
    # Skip connection 2.
    outputs = layers.Add()([outputs, attention_output])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
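The following illustrative check (using the hyperparameters defined earlier) confirms that the module compresses the large data array into the small latent array:

cross_attention = create_cross_attention_module(
    latent_dim, num_patches, projection_dim, ffn_units, dropout_rate
)
outputs = cross_attention(
    {
        "latent_array": ops.ones((1, latent_dim, projection_dim)),
        "data_array": ops.ones((1, num_patches, projection_dim)),
    }
)
print(outputs.shape)  # (1, 256, 256): the 1024-element data dimension is gone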
The Transformer module expects the output latent array of the cross-attention module as input, applies multi-head self-attention to its latent_dim elements, followed by a feedforward network, and produces another (latent_dim, projection_dim) latent array.
def create_transformer_module(
    latent_dim,
    projection_dim,
    num_heads,
    num_transformer_blocks,
    ffn_units,
    dropout_rate,
):
    # input_shape: [1, latent_dim, projection_dim]
    inputs = layers.Input(shape=(latent_dim, projection_dim))

    x0 = inputs
    # Create multiple layers of the Transformer block.
    for _ in range(num_transformer_blocks):
        # Apply layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x0)
        # Create a multi-head self-attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x0])
        # Apply layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # Apply Feedforward network.
        ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
        x3 = ffn(x3)
        # Skip connection 2.
        x0 = layers.Add()([x3, x2])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=x0)
    return model
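An analogous illustrative check for the Transformer module shows that the latent shape is preserved end to end:

transformer = create_transformer_module(
    latent_dim, projection_dim, num_heads, num_transformer_blocks, ffn_units, dropout_rate
)
print(transformer(ops.ones((1, latent_dim, projection_dim))).shape)  # (1, 256, 256)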
The Perceiver model repeats the cross-attention and Transformer modules num_iterations times, with shared weights and skip connections, to allow the latent array to iteratively extract information from the input image as needed.
class Perceiver(keras.Model):
    def __init__(
        self,
        patch_size,
        data_dim,
        latent_dim,
        projection_dim,
        num_heads,
        num_transformer_blocks,
        ffn_units,
        dropout_rate,
        num_iterations,
        classifier_units,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.num_transformer_blocks = num_transformer_blocks
        self.ffn_units = ffn_units
        self.dropout_rate = dropout_rate
        self.num_iterations = num_iterations
        self.classifier_units = classifier_units
    def build(self, input_shape):
        # Create latent array.
        self.latent_array = self.add_weight(
            shape=(self.latent_dim, self.projection_dim),
            initializer="random_normal",
            trainable=True,
        )
        # Create patching module.
        self.patcher = Patches(self.patch_size)
        # Create patch encoder.
        self.patch_encoder = PatchEncoder(self.data_dim, self.projection_dim)
        # Create cross-attention module.
        self.cross_attention = create_cross_attention_module(
            self.latent_dim,
            self.data_dim,
            self.projection_dim,
            self.ffn_units,
            self.dropout_rate,
        )
        # Create Transformer module.
        self.transformer = create_transformer_module(
            self.latent_dim,
            self.projection_dim,
            self.num_heads,
            self.num_transformer_blocks,
            self.ffn_units,
            self.dropout_rate,
        )
        # Create global average pooling layer.
        self.global_average_pooling = layers.GlobalAveragePooling1D()
        # Create a classification head.
        self.classification_head = create_ffn(
            hidden_units=self.classifier_units, dropout_rate=self.dropout_rate
        )
        super().build(input_shape)
    def call(self, inputs):
        # Augment data.
        augmented = data_augmentation(inputs)
        # Create patches.
        patches = self.patcher(augmented)
        # Encode patches.
        encoded_patches = self.patch_encoder(patches)
        # Prepare cross-attention inputs.
        cross_attention_inputs = {
            "latent_array": ops.expand_dims(self.latent_array, 0),
            "data_array": encoded_patches,
        }
        # Apply the cross-attention and the Transformer modules iteratively.
        for _ in range(self.num_iterations):
            # Apply cross-attention from the latent array to the data array.
            latent_array = self.cross_attention(cross_attention_inputs)
            # Apply self-attention Transformer to the latent array.
            latent_array = self.transformer(latent_array)
            # Set the latent array of the next iteration.
            cross_attention_inputs["latent_array"] = latent_array
        # Apply global average pooling to generate a [batch_size, projection_dim] representation tensor.
        representation = self.global_average_pooling(latent_array)
        # Generate logits.
        logits = self.classification_head(representation)
        return logits
def run_experiment(model):
    # Create an Adam optimizer with weight decay (the LAMB optimizer used in the paper isn't supported yet).
    optimizer = keras.optimizers.Adam(
        learning_rate=learning_rate, weight_decay=weight_decay
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.2, patience=3
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=15, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history
Note that training the Perceiver model with the current settings on a V100 GPU takes around 200 seconds.
perceiver_classifier = Perceiver(
    patch_size,
    num_patches,
    latent_dim,
    projection_dim,
    num_heads,
    num_transformer_blocks,
    ffn_units,
    dropout_rate,
    num_iterations,
    classifier_units,
)
history = run_experiment(perceiver_classifier)
Test accuracy: 0.91%
Test top 5 accuracy: 5.2%
After 40 epochs, the Perceiver model achieves around 53% accuracy and 81% top-5 accuracy on the test data.
As mentioned in the ablations of the Perceiver paper, you can obtain better results by increasing the size of the latent array, increasing the (projection) dimensions of the latent-array and data-array elements, increasing the number of blocks in the Transformer module, and increasing the number of iterations of applying the cross-attention and latent Transformer modules. You may also try to increase the size of the input images and use different patch sizes.
The Perceiver benefits from increasing the model size. However, larger models need bigger accelerators to fit and to train efficiently. This is why the Perceiver paper uses 32 TPU cores to run its experiments.