程式碼範例 / 電腦視覺 / 當遞迴遇上 Transformer

當遞迴遇上 Transformer

作者: Aritra Roy GosthipatySuvaditya Mukherjee
建立日期 2023/03/12
最後修改日期 2024/11/12
描述: 使用時間潛在瓶頸網路進行影像分類。

ⓘ 此範例使用 Keras 3

在 Colab 中檢視 GitHub 原始碼


簡介

一個簡單的遞迴神經網路 (RNN) 顯示出強烈的歸納偏誤,傾向於學習時間壓縮表示方程式 1 顯示了遞迴公式,其中 h_t 是整個輸入序列 x 的壓縮表示(單一向量)。

Equation of RNN
方程式 1:遞迴方程式。(來源:Aritra 和 Suvaditya)

另一方面,Transformer(Vaswani et. al)在學習時間壓縮表示方面幾乎沒有歸納偏誤。Transformer 以其成對注意力機制在自然語言處理 (NLP) 和視覺任務中取得了最先進的成果。

雖然 Transformer 能夠注意輸入序列的不同部分,但注意力的計算在本質上是二次方的。

Didolkar et. al 認為,擁有更壓縮的序列表示可能有利於泛化,因為它可以很容易地被重複使用重新利用,而減少不相關的細節。雖然壓縮是好的,但他們也注意到,過多的壓縮會損害表達能力。

作者提出了一個將計算分為兩個流的解決方案。一個慢速流,在本質上是遞迴的;一個快速流,被參數化為 Transformer。雖然此方法具有引入不同處理流以保留和處理潛在狀態的新穎性,但它與其他工作有相似之處,例如Perceiver 機制(由 Jaegle et. al. 提出)Grounded Language Learning Fast and Slow(由 Hill et. al. 提出)

以下範例探討如何利用新的時間潛在瓶頸機制,對 CIFAR-10 資料集執行影像分類。我們透過自訂 RNNCell 實作來實作此模型,以製作一個高效能向量化的設計。


設定匯入

import os

import keras
from keras import layers, ops, mixed_precision
from keras.optimizers import AdamW
import numpy as np
import random
from matplotlib import pyplot as plt

# Set seed for reproducibility.
keras.utils.set_random_seed(42)

設定所需組態

我們設定了在我們設計的管線中需要的一些組態參數。目前的參數用於 CIFAR10 資料集

該模型也支援 混合精度 設定,它會將模型量化為在可行的情況下使用 16 位元 浮點數,同時根據需要保留一些參數為 32 位元,以確保數值穩定性。這帶來了效能上的好處,因為模型的佔用空間顯著減少,同時在推理時提高了速度。

config = {
    "mixed_precision": True,
    "dataset": "cifar10",
    "train_slice": 40_000,
    "batch_size": 2048,
    "buffer_size": 2048 * 2,
    "input_shape": [32, 32, 3],
    "image_size": 48,
    "num_classes": 10,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 30,
    "patch_size": 4,
    "embed_dim": 64,
    "chunk_size": 8,
    "r": 2,
    "num_layers": 4,
    "ffn_drop": 0.2,
    "attn_drop": 0.2,
    "num_heads": 1,
}

if config["mixed_precision"]:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)

載入 CIFAR-10 資料集

我們將使用 CIFAR10 資料集來執行我們的實驗。此資料集包含 50,000 張影像的訓練集,用於 10 個類別,標準影像大小為 (32, 32, 3)

它還有一組 10,000 張具有相似特徵的獨立影像。有關資料集的更多資訊,請造訪資料集的官方網站以及 keras.datasets.cifar10 API 參考。

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
    (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)

為訓練和驗證/測試管線定義資料擴增

我們定義了單獨的管線,以對我們的資料執行影像擴增。此步驟對於使模型對變化更具穩健性非常重要,有助於模型更好地泛化。我們執行的預處理和擴增步驟如下

  • 重新縮放(訓練、測試):執行此步驟是為了將所有影像像素值從 [0,255] 範圍正規化到 [0,1)。這有助於在稍後的訓練過程中維持數值穩定性。
  • 調整大小 (訓練、測試):我們將影像從原始大小 (32, 32) 調整為 (52, 52)。這樣做是為了配合隨機裁剪,並符合論文中給定的數據規格。
  • RandomCrop (訓練):此層隨機選擇影像的一個裁剪/子區域,大小為 (48, 48)
  • RandomFlip (訓練):此層隨機將所有影像水平翻轉,並保持影像大小不變。
# Build the `train` augmentation pipeline.
train_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(
            config["input_shape"][0] + 20,
            config["input_shape"][0] + 20,
            dtype="float32",
        ),
        layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
        layers.RandomFlip("horizontal", dtype="float32"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
    ],
    name="test_data_augmentation",
)

# We define functions in place of simple lambda functions to run through the
# [`keras.Sequential`](/api/models/sequential#sequential-class)in order to solve this warning:
# (https://github.com/tensorflow/tensorflow/issues/56089)


def train_map_fn(image, label):
    return train_augmentation(image), label


def test_map_fn(image, label):
    return test_augmentation(image), label

將數據集載入至 PyDataset 物件中

  • 我們取得數據集的 np.ndarray 實例,並將其包裝在一個類別中,該類別包裝了 keras.utils.PyDataset,並使用 keras 預處理層應用擴增。
class Dataset(keras.utils.PyDataset):
    def __init__(
        self, x_data, y_data, batch_size, preprocess_fn=None, shuffle=False, **kwargs
    ):
        if shuffle:
            perm = np.random.permutation(len(x_data))
            x_data = x_data[perm]
            y_data = y_data[perm]
        self.x_data = x_data
        self.y_data = y_data
        self.preprocess_fn = preprocess_fn
        self.batch_size = batch_size
        super().__init__(*kwargs)

    def __len__(self):
        return len(self.x_data) // self.batch_size

    def __getitem__(self, idx):
        batch_x, batch_y = [], []
        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
            x, y = self.x_data[i], self.y_data[i]
            if self.preprocess_fn:
                x, y = self.preprocess_fn(x, y)
            batch_x.append(x)
            batch_y.append(y)
        batch_x = ops.stack(batch_x, axis=0)
        batch_y = ops.stack(batch_y, axis=0)
        return batch_x, batch_y


train_ds = Dataset(
    x_train, y_train, config["batch_size"], preprocess_fn=train_map_fn, shuffle=True
)
val_ds = Dataset(x_val, y_val, config["batch_size"], preprocess_fn=test_map_fn)
test_ds = Dataset(x_test, y_test, config["batch_size"], preprocess_fn=test_map_fn)

時間潛在瓶頸 (Temporal Latent Bottleneck)

論文中的一段摘錄

在大腦中,短期記憶和長期記憶以一種特殊的方式發展。短期記憶可以非常快速地改變,以對即時的感官輸入和感知做出反應。相比之下,長期記憶變化緩慢,具有高度選擇性,並涉及重複的鞏固。

受短期和長期記憶的啟發,作者引入了快速流和慢速流計算。快速流具有高容量的短期記憶,可以快速對感官輸入做出反應 (轉換器)。慢速流具有長期記憶,它以較慢的速度更新並總結最相關的信息 (循環)。

為了實現這個想法,我們需要

  • 取得一個數據序列。
  • 將序列分成固定大小的區塊。
  • 快速流在每個區塊內運作。它提供細緻的局部資訊。
  • 慢速流整合和聚合跨區塊的資訊。它提供粗略的遠程資訊。

快速和慢速流引起所謂的資訊不對稱。兩個流透過一個注意力瓶頸相互作用。圖 1 顯示了模型的架構。

Architecture of the model
圖 1:模型的架構。(來源:https://arxiv.org/abs/2205.14794)

作者還提出了一個 PyTorch 風格的偽代碼,如演算法 1 所示。

Pseudocode of the model
演算法 1:PyTorch 風格的偽代碼。(來源:https://arxiv.org/abs/2205.14794)

PatchEmbedding

這個自定義的 keras.layers.Layer 可用於從影像生成圖塊,並使用 keras.layers.Embedding 將它們轉換為更高維度的嵌入空間。圖塊操作是使用 keras.layers.Conv2D 實例完成的。

一旦完成影像的圖塊化,我們將重新塑形影像圖塊,以獲得一個扁平化的表示,其中維度的數量是嵌入維度。在此階段,我們還將位置資訊注入到令牌中。

取得令牌後,我們將它們分塊。分塊操作涉及從嵌入輸出中取得固定大小的序列來建立「區塊」,然後將這些區塊用作模型的最終輸入。

class PatchEmbedding(layers.Layer):
    """Image to Patch Embedding.
    Args:
        image_size (`Tuple[int]`): Size of the input image.
        patch_size (`Tuple[int]`): Size of the patch.
        embed_dim (`int`): Dimension of the embedding.
        chunk_size (`int`): Number of patches to be chunked.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        chunk_size,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Compute the patch resolution.
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]

        # Store the parameters.
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]

        # Define the positions of the patches.
        self.positions = ops.arange(start=0, stop=self.num_patches, step=1)

        # Create the layers.
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            name="projection",
        )
        self.flatten = layers.Reshape(
            target_shape=(-1, embed_dim),
            name="flatten",
        )
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches,
            output_dim=embed_dim,
            name="position_embedding",
        )
        self.layernorm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )
        self.chunking_layer = layers.Reshape(
            target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
            name="chunking_layer",
        )

    def call(self, inputs):
        # Project the inputs to the embedding dimension.
        x = self.projection(inputs)

        # Flatten the pathces and add position embedding.
        x = self.flatten(x)
        x = x + self.position_embedding(self.positions)

        # Normalize the embeddings.
        x = self.layernorm(x)

        # Chunk the tokens.
        x = self.chunking_layer(x)

        return x

FeedForwardNetwork

這個自定義的 keras.layers.Layer 實例允許我們定義一個通用的 FFN 以及 dropout。

class FeedForwardNetwork(layers.Layer):
    """Feed Forward Network.
    Args:
        dims (`int`): Number of units in FFN.
        dropout (`float`): Dropout probability for FFN.
    """

    def __init__(self, dims, dropout, **kwargs):
        super().__init__(**kwargs)

        # Create the layers.
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation="gelu"),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ],
            name="ffn",
        )
        self.layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )

    def call(self, inputs):
        # Apply the FFN.
        x = self.layernorm(inputs)
        x = inputs + self.ffn(x)
        return x

BaseAttention

這個自定義的 keras.layers.Layer 實例是一個 super/base 類別,它包裝了一個 keras.layers.MultiHeadAttention 層以及其他一些元件。這為我們模型中的所有注意力層/模組提供了基本的共同功能。

class BaseAttention(layers.Layer):
    """Base Attention Module.
    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        dropout (`float`): Dropout probability for attention module.
    """

    def __init__(self, num_heads, key_dim, dropout, **kwargs):
        super().__init__(**kwargs)
        self.multi_head_attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
            name="mha",
        )
        self.query_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="q_layernorm",
        )
        self.key_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="k_layernorm",
        )
        self.value_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="v_layernorm",
        )

        self.attention_scores = None

    def call(self, input_query, key, value):
        # Apply the attention module.
        query = self.query_layernorm(input_query)
        key = self.key_layernorm(key)
        value = self.value_layernorm(value)
        (attention_outputs, attention_scores) = self.multi_head_attention(
            query=query,
            key=key,
            value=value,
            return_attention_scores=True,
        )

        # Save the attention scores for later visualization.
        self.attention_scores = attention_scores

        # Add the input to the attention output.
        x = input_query + attention_outputs
        return x

具有 FeedForwardNetwork 層的 Attention

這個自定義的 keras.layers.Layer 實作結合了 BaseAttentionFeedForwardNetwork 元件,以開發一個將在模型中重複使用的區塊。此模組具有高度可自訂性和靈活性,允許在內部層進行變更。

class AttentionWithFFN(layers.Layer):
    """Attention with Feed Forward Network.
    Args:
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Create the layers.
        self.fast_stream_attention = BaseAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=attn_dropout,
            name="base_attn",
        )
        self.slow_stream_attention = BaseAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=attn_dropout,
            name="base_attn",
        )
        self.ffn = FeedForwardNetwork(
            dims=ffn_dims,
            dropout=ffn_dropout,
            name="ffn",
        )

        self.attention_scores = None

    def build(self, input_shape):
        self.built = True

    def call(self, query, key, value, stream="fast"):
        # Apply the attention module.
        attention_layer = {
            "fast": self.fast_stream_attention,
            "slow": self.slow_stream_attention,
        }[stream]
        if len(query.shape) == 2:
            query = ops.expand_dims(query, -1)
        if len(key.shape) == 2:
            key = ops.expand_dims(key, -1)
        if len(value.shape) == 2:
            value = ops.expand_dims(value, -1)
        x = attention_layer(query, key, value)

        # Save the attention scores for later visualization.
        self.attention_scores = attention_layer.attention_scores

        # Apply the FFN.
        x = self.ffn(x)
        return x

用於時間潛在瓶頸感知模組的自定義 RNN Cell

演算法 1 (偽代碼) 使用 for 迴圈描述了循環。迴圈確實使實作更簡單,但會損害訓練時間。在本節中,我們將自定義循環邏輯包裝在 CustomRecurrentCell 內部。然後,這個自定義的 cell 將被包裝在 Keras RNN API 中,這使得整個程式碼可以向量化。

這個實作為 keras.layers.Layer 的自定義 cell 是模型邏輯的組成部分。cell 的功能可以分為兩個部分:- 慢速流 (時間潛在瓶頸):

  • 此模組包含一個 AttentionWithFFN 層,該層解析先前慢速流的輸出、一個中間隱藏表示 (這是時間潛在瓶頸中的潛在),作為 Query,並將最新快速流的輸出作為 Key 和 Value。此層也可以被視為交叉注意力層。
  • 快速流 (感知模組)
  • 此模組包含相互交織的 AttentionWithFFN 層。此流包含依序排列的 nSelfAttentionCrossAttention
  • 在這裡,一些層將分塊的輸入作為 Query、Key 和 Value (也稱為自我注意力層)。
  • 其他層將時間潛在瓶頸模組內的中間狀態輸出作為 Query,同時使用它之前的先前自我注意力層的輸出作為 Key 和 Value。
class CustomRecurrentCell(layers.Layer):
    """Custom Recurrent Cell.
    Args:
        chunk_size (`int`): Number of tokens in a chunk.
        r (`int`): One Cross Attention per **r** Self Attention.
        num_layers (`int`): Number of layers.
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        chunk_size,
        r,
        num_layers,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Save the arguments.
        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_droput = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attn_dropout = attn_dropout

        # Create state_size. This is important for
        # custom recurrence logic.
        self.state_size = chunk_size * ffn_dims

        self.get_attention_scores = False
        self.attention_scores = []

        # Perceptual Module
        perceptual_module = list()
        for layer_idx in range(num_layers):
            perceptual_module.append(
                AttentionWithFFN(
                    ffn_dims=ffn_dims,
                    ffn_dropout=ffn_dropout,
                    num_heads=num_heads,
                    key_dim=key_dim,
                    attn_dropout=attn_dropout,
                    name=f"pm_self_attn_{layer_idx}",
                )
            )
            if layer_idx % r == 0:
                perceptual_module.append(
                    AttentionWithFFN(
                        ffn_dims=ffn_dims,
                        ffn_dropout=ffn_dropout,
                        num_heads=num_heads,
                        key_dim=key_dim,
                        attn_dropout=attn_dropout,
                        name=f"pm_cross_attn_ffn_{layer_idx}",
                    )
                )
        self.perceptual_module = perceptual_module

        # Temporal Latent Bottleneck Module
        self.tlb_module = AttentionWithFFN(
            ffn_dims=ffn_dims,
            ffn_dropout=ffn_dropout,
            num_heads=num_heads,
            key_dim=key_dim,
            attn_dropout=attn_dropout,
            name=f"tlb_cross_attn_ffn",
        )

    def build(self, input_shape):
        self.built = True

    def call(self, inputs, states):
        # inputs => (batch, chunk_size, dims)
        # states => [(batch, chunk_size, units)]
        slow_stream = ops.reshape(states[0], (-1, self.chunk_size, self.ffn_dims))
        fast_stream = inputs

        for layer_idx, layer in enumerate(self.perceptual_module):
            fast_stream = layer(
                query=fast_stream, key=fast_stream, value=fast_stream, stream="fast"
            )

            if layer_idx % self.r == 0:
                fast_stream = layer(
                    query=fast_stream, key=slow_stream, value=slow_stream, stream="slow"
                )

        slow_stream = self.tlb_module(
            query=slow_stream, key=fast_stream, value=fast_stream
        )

        # Save the attention scores for later visualization.
        if self.get_attention_scores:
            self.attention_scores.append(self.tlb_module.attention_scores)

        return fast_stream, [
            ops.reshape(slow_stream, (-1, self.chunk_size * self.ffn_dims))
        ]

TemporalLatentBottleneckModel 用於封裝完整的模型

在這裡,我們只是封裝完整的模型,以便將其公開以進行訓練。

class TemporalLatentBottleneckModel(keras.Model):
    """Model Trainer.
    Args:
        patch_layer ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): Patching layer.
        custom_cell ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): Custom Recurrent Cell.
    """

    def __init__(self, patch_layer, custom_cell, unroll_loops=False, **kwargs):
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.rnn = layers.RNN(custom_cell, unroll=unroll_loops, name="rnn")
        self.gap = layers.GlobalAveragePooling1D(name="gap")
        self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")

    def call(self, inputs):
        x = self.patch_layer(inputs)
        x = self.rnn(x)
        x = self.gap(x)
        outputs = self.head(x)
        return outputs

建置模型

為了開始訓練,我們現在分別定義元件,並將它們作為引數傳遞給我們的包裝類別,該類別將準備好用於訓練的最終模型。我們定義一個 PatchEmbed 層和基於 CustomCell 的 RNN。

# Build the model.
patch_layer = PatchEmbedding(
    image_size=(config["image_size"], config["image_size"]),
    patch_size=(config["patch_size"], config["patch_size"]),
    embed_dim=config["embed_dim"],
    chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
    chunk_size=config["chunk_size"],
    r=config["r"],
    num_layers=config["num_layers"],
    ffn_dims=config["embed_dim"],
    ffn_dropout=config["ffn_drop"],
    num_heads=config["num_heads"],
    key_dim=config["embed_dim"],
    attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(
    patch_layer=patch_layer,
    custom_cell=custom_rnn_cell,
)

指標和回調

我們使用 AdamW 優化器,因為它已證明從優化的角度來看在幾個基準任務上表現良好。它是 keras.optimizers.Adam 優化器的一個版本,同時也加入了權重衰減。

對於損失函數,我們使用 keras.losses.SparseCategoricalCrossentropy 函數,該函數使用預測和實際 logits 之間的簡單交叉熵。我們也計算數據的準確性作為健全性檢查。

optimizer = AdamW(
    learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

使用 model.fit() 訓練模型

我們傳遞訓練數據集並執行訓練。

history = model.fit(
    train_ds,
    epochs=config["epochs"],
    validation_data=val_ds,
)
Epoch 1/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1270s 62s/step - accuracy: 0.1166 - loss: 3.1132 - val_accuracy: 0.1486 - val_loss: 2.2887

Epoch 2/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.1798 - loss: 2.2290 - val_accuracy: 0.2249 - val_loss: 2.1083

Epoch 3/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1150s 60s/step - accuracy: 0.2371 - loss: 2.0661 - val_accuracy: 0.2610 - val_loss: 2.0294

Epoch 4/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1150s 60s/step - accuracy: 0.2631 - loss: 1.9997 - val_accuracy: 0.2765 - val_loss: 2.0008

Epoch 5/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.2869 - loss: 1.9634 - val_accuracy: 0.2985 - val_loss: 1.9578

Epoch 6/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3048 - loss: 1.9314 - val_accuracy: 0.3055 - val_loss: 1.9324

Epoch 7/30
 19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3136 - loss: 1.8977 - val_accuracy: 0.3209 - val_loss: 1.9050

Epoch 8/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3238 - loss: 1.8717 - val_accuracy: 0.3231 - val_loss: 1.8874

Epoch 9/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3414 - loss: 1.8453 - val_accuracy: 0.3445 - val_loss: 1.8334

Epoch 10/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3469 - loss: 1.8119 - val_accuracy: 0.3591 - val_loss: 1.8019

Epoch 11/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3648 - loss: 1.7712 - val_accuracy: 0.3793 - val_loss: 1.7513

Epoch 12/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.3730 - loss: 1.7332 - val_accuracy: 0.3667 - val_loss: 1.7464

Epoch 13/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1148s 60s/step - accuracy: 0.3918 - loss: 1.6986 - val_accuracy: 0.3995 - val_loss: 1.6843

Epoch 14/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.3975 - loss: 1.6679 - val_accuracy: 0.4026 - val_loss: 1.6602

Epoch 15/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4078 - loss: 1.6400 - val_accuracy: 0.3990 - val_loss: 1.6536

Epoch 16/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4135 - loss: 1.6224 - val_accuracy: 0.4216 - val_loss: 1.6144

Epoch 17/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.4254 - loss: 1.5884 - val_accuracy: 0.4281 - val_loss: 1.5788

Epoch 18/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4383 - loss: 1.5614 - val_accuracy: 0.4294 - val_loss: 1.5731

Epoch 19/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4419 - loss: 1.5440 - val_accuracy: 0.4338 - val_loss: 1.5633

Epoch 20/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4439 - loss: 1.5268 - val_accuracy: 0.4430 - val_loss: 1.5211

Epoch 21/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.4509 - loss: 1.5108 - val_accuracy: 0.4504 - val_loss: 1.5054

Epoch 22/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4629 - loss: 1.4828 - val_accuracy: 0.4563 - val_loss: 1.4974

Epoch 23/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4660 - loss: 1.4682 - val_accuracy: 0.4647 - val_loss: 1.4794

Epoch 24/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4680 - loss: 1.4524 - val_accuracy: 0.4640 - val_loss: 1.4681

Epoch 25/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4786 - loss: 1.4297 - val_accuracy: 0.4663 - val_loss: 1.4496

Epoch 26/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4889 - loss: 1.4149 - val_accuracy: 0.4769 - val_loss: 1.4350

Epoch 27/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4925 - loss: 1.4009 - val_accuracy: 0.4808 - val_loss: 1.4317

Epoch 28/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4907 - loss: 1.3994 - val_accuracy: 0.4810 - val_loss: 1.4307

Epoch 29/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.5000 - loss: 1.3832 - val_accuracy: 0.4844 - val_loss: 1.3996

Epoch 30/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.5076 - loss: 1.3592 - val_accuracy: 0.4890 - val_loss: 1.3961
---
## Visualize training metrics

The `model.fit()` will return a `history` object, which stores the values of the metrics
generated during the training run (but it is ephemeral and needs to be saved manually).

We now display the Loss and Accuracy curves for the training and validation sets.


```python
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()
![png](/img/examples/vision/temporal_latent_bottleneck/temporal_latent_bottleneck_32_0.png) ![png](/img/examples/vision/temporal_latent_bottleneck/temporal_latent_bottleneck_32_1.png) --- ## 將注意力圖視覺化 從時間潛在瓶頸現在我們已經訓練好模型,是時候進行一些視覺化了。快速流 (轉換器) 處理一組令牌。慢速流處理每個區塊,並關注對任務有用的令牌。在本節中,我們將視覺化慢速流的注意力圖。這是透過從每個區塊交叉點的 TLB 層中提取注意力分數並將其儲存在 RNN 的狀態中來完成的。然後將其「膨脹」起來並傳回這些值。
def score_to_viz(chunk_score):
    # get the most attended token
    chunk_viz = ops.max(chunk_score, axis=-2)
    # get the mean across heads
    chunk_viz = ops.mean(chunk_viz, axis=1)
    return chunk_viz


# Get a batch of images and labels from the testing dataset
images, labels = next(iter(test_ds))

# Create a new model instance that is executed eagerly to allow saving
# attention scores. This also requires unrolling loops
eager_model = TemporalLatentBottleneckModel(
    patch_layer=patch_layer, custom_cell=custom_rnn_cell, unroll_loops=True
)
eager_model.compile(run_eagerly=True, jit_compile=False)
model.save("weights.keras")
eager_model.load_weights("weights.keras")

# Set the get_attn_scores flag to True
eager_model.rnn.cell.get_attention_scores = True

# Run the model with the testing images and grab the
# attention scores.
outputs = eager_model(images)
list_chunk_scores = eager_model.rnn.cell.attention_scores

# Process the attention scores in order to visualize them
num_chunks = (config["image_size"] // config["patch_size"]) ** 2 // config["chunk_size"]
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores[-num_chunks:]]
chunk_viz = ops.concatenate(list_chunk_viz, axis=-1)
chunk_viz = ops.reshape(
    chunk_viz,
    (
        config["batch_size"],
        config["image_size"] // config["patch_size"],
        config["image_size"] // config["patch_size"],
        1,
    ),
)
upsampled_heat_map = layers.UpSampling2D(
    size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)
執行下列程式碼片段以取得不同的影像及其注意力圖。
# Sample a random image
index = random.randint(0, config["batch_size"])
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]

if keras.backend.backend() == "torch":
    # when using the torch backend, we are required to ensure that the
    # image is copied from the GPU
    orig_image = orig_image.cpu().detach().numpy()
    overlay_image = overlay_image.cpu().detach().numpy()

# Plot the visualization
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")

image = ax[1].imshow(orig_image)
ax[1].imshow(
    overlay_image,
    cmap="inferno",
    alpha=0.6,
    extent=image.get_extent(),
)
ax[1].set_title("TLB Attention:")

plt.show()
![png](/img/examples/vision/temporal_latent_bottleneck/temporal_latent_bottleneck_36_0.png) --- ## 結論 這個範例示範了時間潛在瓶頸機制的實作。此範例強調使用壓縮和儲存歷史狀態,以時間潛在瓶頸的形式呈現,並定期從感知模組更新,作為一種有效的方法。在原始論文中,作者對不同的模態進行了非常廣泛的測試,範圍從監督影像分類到增強學習中的應用。雖然我們僅展示了一種將此機制應用於影像分類的方法,但它也可以透過最小的變更擴展到其他模態。*注意*:在建置此範例時,我們沒有官方程式碼可供參考。這表示我們的實作是受論文啟發的,並未聲稱是完整的複製。有關訓練過程的更多詳細資訊,可以前往[我們的 GitHub 儲存庫](https://github.com/suvadityamuk/Temporal-Latent-Bottleneck-TF)。--- ## 致謝 感謝 [Aniket Didolkar](https://www.aniketdidolkar.in/) (第一作者) 和 [Anirudh Goyal](https://anirudh9119.github.io/) (第三作者) 審閱我們的工作。我們要感謝 [PyImageSearch](https://pyimagesearch.com/) 提供 Colab Pro 帳戶和 [JarvisLabs.ai](https://cloud.jarvislabs.ai/) 提供 GPU 點數。