Author: Sayak Paul
Date created: 2021/10/20
Last modified: 2024/02/11
Description: MobileViT for image classification, combining the benefits of convolutions and Transformers.

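ⓘ This example uses Keras 3

In this example, we implement the MobileViT architecture (Mehta et al.), which combines the benefits of Transformers (Vaswani et al.) and convolutions. With Transformers, we can capture long-range dependencies that result in global representations. With convolutions, we can capture spatial relationships that model locality.

Besides combining the properties of Transformers and convolutions, the authors introduce MobileViT as a general-purpose, mobile-friendly backbone for different image recognition tasks. Their findings suggest that, performance-wise, MobileViT is better than other models with the same or higher complexity (MobileNetV3, for example), while being efficient on mobile devices.

Note: This example should be run with TensorFlow 2.13 or higher.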


import os
import tensorflow as tf

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import backend

import tensorflow_datasets as tfds



# Values are from table 4.
patch_size = 4  # 2x2, for the Transformer blocks.
image_size = 256
expansion_factor = 2  # expansion factor for the MobileNetV2 blocks.

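## MobileViT utilities

The MobileViT architecture consists of the following blocks:

  • Strided 3x3 convolutions that process the input image.
  • MobileNetV2-style inverted residual blocks for downsampling the resolution of the intermediate feature maps.
  • MobileViT blocks that combine the benefits of Transformers and convolutions. The original paper illustrates this block in a figure.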

def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters, kernel_size, strides=strides, activation=keras.activations.swish, padding="same"
    )
    return conv_layer(x)

# Reference: https://github.com/keras-team/keras/blob/e3858739d178fe16a0c77ce7fab88b0be6dbbdc7/keras/applications/imagenet_utils.py#L413C17-L435

def correct_pad(inputs, kernel_size):
    img_dim = 2 if backend.image_data_format() == "channels_first" else 1
    input_size = inputs.shape[img_dim : (img_dim + 2)]
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )
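
To see what `correct_pad` computes for the static shapes used in this example, the sketch below repeats its arithmetic in plain Python for a single spatial dimension. The helper names `pad_for` and `valid_conv_output` are hypothetical and not part of Keras, and the numbers assume a 3x3 kernel with stride 2.

```python
def pad_for(input_size, kernel_size=3):
    """(before, after) zero padding for one spatial dim, as in correct_pad."""
    adjust = 1 - input_size % 2
    correct = kernel_size // 2
    return (correct - adjust, correct)


def valid_conv_output(input_size, kernel_size=3, stride=2):
    """Output length of a 'valid' convolution applied after pad_for."""
    before, after = pad_for(input_size, kernel_size)
    padded = input_size + before + after
    return (padded - kernel_size) // stride + 1


print(pad_for(128))            # (0, 1): even sizes are padded on one side only
print(valid_conv_output(128))  # 64
print(pad_for(129))            # (1, 1)
```

With an even input such as 128, the padded size is 129 and the strided depthwise convolution halves the resolution to 64, which agrees with the `zero_padding2d` and `depthwise_conv2d_1` rows of the model summary.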

# Reference: https://git.io/JKgtC

def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m
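
As a rough sanity check on `inverted_residual_block`, the following sketch counts its parameters by hand. The function name `inverted_residual_params` is an assumption made for illustration. It relies on the convolutions being bias-free (`use_bias=False`) and on each `BatchNormalization` layer holding four values per channel.

```python
def inverted_residual_params(in_channels, expanded_channels, output_channels):
    """Parameter count of one inverted residual block (bias-free convolutions)."""
    expand = in_channels * expanded_channels        # 1x1 expansion conv
    depthwise = 3 * 3 * expanded_channels           # 3x3 depthwise conv
    project = expanded_channels * output_channels   # 1x1 projection conv
    # Each BatchNormalization layer stores gamma, beta, moving mean and variance.
    batch_norm = 4 * (2 * expanded_channels + output_channels)
    return expand + depthwise + project + batch_norm


print(inverted_residual_params(16, 32, 16))  # 1632
print(inverted_residual_params(16, 32, 24))  # 1920
```

The first value is the sum of the `conv2d_1` through `batch_normalization_2` rows in the model summary printed later in this example (512 + 128 + 288 + 128 + 512 + 64).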

# Reference:
# https://keras.dev.org.tw/examples/vision/image_classification_with_vision_transformer/

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(
            x3, hidden_units=[x.shape[-1] * 2, x.shape[-1]], dropout_rate=0.1
        )
        # Skip connection 2.
        x = layers.Add()([x3, x2])

    return x
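
Because `key_dim=projection_dim` sets the size of each attention head, the attention layer is wider than the feature dimension it operates on. The sketch below counts the parameters of one `MultiHeadAttention` layer under that setting. `mha_params` is a hypothetical helper, and it assumes the value dimension defaults to `key_dim` and the output is projected back to the input dimension.

```python
def mha_params(input_dim, num_heads, key_dim):
    """Parameter count of multi-head attention when value_dim equals key_dim."""
    # Query, key and value projections: one weight matrix plus bias each.
    qkv = 3 * (input_dim * num_heads * key_dim + num_heads * key_dim)
    # Output projection back to the input dimension.
    out = num_heads * key_dim * input_dim + input_dim
    return qkv + out


for dim in (64, 80, 96):
    print(dim, mha_params(dim, num_heads=2, key_dim=dim))
# 64 33216
# 80 51760
# 96 74400
```

These counts agree with the `multi_head_attention` rows of the model summary for the three MobileViT stages.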

def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )
    return local_global_features

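More on the MobileViT block:

  • First, the feature representations (A) go through convolution blocks that capture local relationships. The expected shape of a single entry here is (h, w, num_channels).
  • Then they are unfolded into another vector of shape (p, n, num_channels), where p is the area of a small patch and n is (h * w) / p. So, we end up with n non-overlapping patches.
  • This unfolded vector is then passed through a Transformer block that captures global relationships between the patches.
  • The output vector (B) is folded back into a vector of shape (h, w, num_channels), resembling a feature map coming out of convolutions.

Vectors A and B are then passed through two more convolutional layers to fuse the local and global representations. Notice that the spatial resolution of the final vector remains unchanged at this point. The authors also explain how the MobileViT block resembles a convolution block of a CNN. For more details, please refer to the original paper.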
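
The unfold and fold steps can be illustrated with a small NumPy sketch of the shapes involved. It mirrors the two `Reshape` layers in `mobilevit_block`, using NumPy as a stand-in for Keras. The sizes (a 32x32 feature map with 64 channels and a patch area of 4) are taken from the first MobileViT stage of this example.

```python
import numpy as np

h, w, num_channels = 32, 32, 64
patch_area = 4  # p, the area of a 2x2 patch

features = np.arange(h * w * num_channels, dtype=np.float32).reshape(h, w, num_channels)

# Unfold: (h, w, c) -> (p, n, c) with n = (h * w) / p.
num_patches = (h * w) // patch_area
unfolded = features.reshape(patch_area, num_patches, num_channels)

# Fold: (p, n, c) -> (h, w, c).
folded = unfolded.reshape(h, w, num_channels)

print(unfolded.shape)                    # (4, 256, 64)
print(np.array_equal(features, folded))  # True
```

Folding exactly undoes unfolding, so no information is lost between the two steps. Note that a plain reshape splits the flattened h * w positions into p contiguous groups, so this is a sketch of the tensor shapes rather than of a spatial 2x2 tiling.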

Next, we combine these blocks and implement the MobileViT architecture (XXS variant). The original paper presents a schematic of the architecture in a figure.

def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)

mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()
Model: "model"
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
rescaling (Rescaling)           (None, 256, 256, 3)  0           input_1[0][0]                    
conv2d (Conv2D)                 (None, 128, 128, 16) 448         rescaling[0][0]                  
conv2d_1 (Conv2D)               (None, 128, 128, 32) 512         conv2d[0][0]                     
batch_normalization (BatchNorma (None, 128, 128, 32) 128         conv2d_1[0][0]                   
tf.nn.silu (TFOpLambda)         (None, 128, 128, 32) 0           batch_normalization[0][0]        
depthwise_conv2d (DepthwiseConv (None, 128, 128, 32) 288         tf.nn.silu[0][0]                 
batch_normalization_1 (BatchNor (None, 128, 128, 32) 128         depthwise_conv2d[0][0]           
tf.nn.silu_1 (TFOpLambda)       (None, 128, 128, 32) 0           batch_normalization_1[0][0]      
conv2d_2 (Conv2D)               (None, 128, 128, 16) 512         tf.nn.silu_1[0][0]               
batch_normalization_2 (BatchNor (None, 128, 128, 16) 64          conv2d_2[0][0]                   
add (Add)                       (None, 128, 128, 16) 0           batch_normalization_2[0][0]      
conv2d_3 (Conv2D)               (None, 128, 128, 32) 512         add[0][0]                        
batch_normalization_3 (BatchNor (None, 128, 128, 32) 128         conv2d_3[0][0]                   
tf.nn.silu_2 (TFOpLambda)       (None, 128, 128, 32) 0           batch_normalization_3[0][0]      
zero_padding2d (ZeroPadding2D)  (None, 129, 129, 32) 0           tf.nn.silu_2[0][0]               
depthwise_conv2d_1 (DepthwiseCo (None, 64, 64, 32)   288         zero_padding2d[0][0]             
batch_normalization_4 (BatchNor (None, 64, 64, 32)   128         depthwise_conv2d_1[0][0]         
tf.nn.silu_3 (TFOpLambda)       (None, 64, 64, 32)   0           batch_normalization_4[0][0]      
conv2d_4 (Conv2D)               (None, 64, 64, 24)   768         tf.nn.silu_3[0][0]               
batch_normalization_5 (BatchNor (None, 64, 64, 24)   96          conv2d_4[0][0]                   
conv2d_5 (Conv2D)               (None, 64, 64, 48)   1152        batch_normalization_5[0][0]      
batch_normalization_6 (BatchNor (None, 64, 64, 48)   192         conv2d_5[0][0]                   
tf.nn.silu_4 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_6[0][0]      
depthwise_conv2d_2 (DepthwiseCo (None, 64, 64, 48)   432         tf.nn.silu_4[0][0]               
batch_normalization_7 (BatchNor (None, 64, 64, 48)   192         depthwise_conv2d_2[0][0]         
tf.nn.silu_5 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_7[0][0]      
conv2d_6 (Conv2D)               (None, 64, 64, 24)   1152        tf.nn.silu_5[0][0]               
batch_normalization_8 (BatchNor (None, 64, 64, 24)   96          conv2d_6[0][0]                   
add_1 (Add)                     (None, 64, 64, 24)   0           batch_normalization_8[0][0]      
conv2d_7 (Conv2D)               (None, 64, 64, 48)   1152        add_1[0][0]                      
batch_normalization_9 (BatchNor (None, 64, 64, 48)   192         conv2d_7[0][0]                   
tf.nn.silu_6 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_9[0][0]      
depthwise_conv2d_3 (DepthwiseCo (None, 64, 64, 48)   432         tf.nn.silu_6[0][0]               
batch_normalization_10 (BatchNo (None, 64, 64, 48)   192         depthwise_conv2d_3[0][0]         
tf.nn.silu_7 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_10[0][0]     
conv2d_8 (Conv2D)               (None, 64, 64, 24)   1152        tf.nn.silu_7[0][0]               
batch_normalization_11 (BatchNo (None, 64, 64, 24)   96          conv2d_8[0][0]                   
add_2 (Add)                     (None, 64, 64, 24)   0           batch_normalization_11[0][0]     
conv2d_9 (Conv2D)               (None, 64, 64, 48)   1152        add_2[0][0]                      
batch_normalization_12 (BatchNo (None, 64, 64, 48)   192         conv2d_9[0][0]                   
tf.nn.silu_8 (TFOpLambda)       (None, 64, 64, 48)   0           batch_normalization_12[0][0]     
zero_padding2d_1 (ZeroPadding2D (None, 65, 65, 48)   0           tf.nn.silu_8[0][0]               
depthwise_conv2d_4 (DepthwiseCo (None, 32, 32, 48)   432         zero_padding2d_1[0][0]           
batch_normalization_13 (BatchNo (None, 32, 32, 48)   192         depthwise_conv2d_4[0][0]         
tf.nn.silu_9 (TFOpLambda)       (None, 32, 32, 48)   0           batch_normalization_13[0][0]     
conv2d_10 (Conv2D)              (None, 32, 32, 48)   2304        tf.nn.silu_9[0][0]               
batch_normalization_14 (BatchNo (None, 32, 32, 48)   192         conv2d_10[0][0]                  
conv2d_11 (Conv2D)              (None, 32, 32, 64)   27712       batch_normalization_14[0][0]     
conv2d_12 (Conv2D)              (None, 32, 32, 64)   4160        conv2d_11[0][0]                  
reshape (Reshape)               (None, 4, 256, 64)   0           conv2d_12[0][0]                  
layer_normalization (LayerNorma (None, 4, 256, 64)   128         reshape[0][0]                    
multi_head_attention (MultiHead (None, 4, 256, 64)   33216       layer_normalization[0][0]        
add_3 (Add)                     (None, 4, 256, 64)   0           multi_head_attention[0][0]       
layer_normalization_1 (LayerNor (None, 4, 256, 64)   128         add_3[0][0]                      
dense (Dense)                   (None, 4, 256, 128)  8320        layer_normalization_1[0][0]      
dropout (Dropout)               (None, 4, 256, 128)  0           dense[0][0]                      
dense_1 (Dense)                 (None, 4, 256, 64)   8256        dropout[0][0]                    
dropout_1 (Dropout)             (None, 4, 256, 64)   0           dense_1[0][0]                    
add_4 (Add)                     (None, 4, 256, 64)   0           dropout_1[0][0]                  
layer_normalization_2 (LayerNor (None, 4, 256, 64)   128         add_4[0][0]                      
multi_head_attention_1 (MultiHe (None, 4, 256, 64)   33216       layer_normalization_2[0][0]      
add_5 (Add)                     (None, 4, 256, 64)   0           multi_head_attention_1[0][0]     
layer_normalization_3 (LayerNor (None, 4, 256, 64)   128         add_5[0][0]                      
dense_2 (Dense)                 (None, 4, 256, 128)  8320        layer_normalization_3[0][0]      
dropout_2 (Dropout)             (None, 4, 256, 128)  0           dense_2[0][0]                    
dense_3 (Dense)                 (None, 4, 256, 64)   8256        dropout_2[0][0]                  
dropout_3 (Dropout)             (None, 4, 256, 64)   0           dense_3[0][0]                    
add_6 (Add)                     (None, 4, 256, 64)   0           dropout_3[0][0]                  
reshape_1 (Reshape)             (None, 32, 32, 64)   0           add_6[0][0]                      
conv2d_13 (Conv2D)              (None, 32, 32, 48)   3120        reshape_1[0][0]                  
concatenate (Concatenate)       (None, 32, 32, 96)   0           batch_normalization_14[0][0]     
conv2d_14 (Conv2D)              (None, 32, 32, 64)   55360       concatenate[0][0]                
conv2d_15 (Conv2D)              (None, 32, 32, 128)  8192        conv2d_14[0][0]                  
batch_normalization_15 (BatchNo (None, 32, 32, 128)  512         conv2d_15[0][0]                  
tf.nn.silu_10 (TFOpLambda)      (None, 32, 32, 128)  0           batch_normalization_15[0][0]     
zero_padding2d_2 (ZeroPadding2D (None, 33, 33, 128)  0           tf.nn.silu_10[0][0]              
depthwise_conv2d_5 (DepthwiseCo (None, 16, 16, 128)  1152        zero_padding2d_2[0][0]           
batch_normalization_16 (BatchNo (None, 16, 16, 128)  512         depthwise_conv2d_5[0][0]         
tf.nn.silu_11 (TFOpLambda)      (None, 16, 16, 128)  0           batch_normalization_16[0][0]     
conv2d_16 (Conv2D)              (None, 16, 16, 64)   8192        tf.nn.silu_11[0][0]              
batch_normalization_17 (BatchNo (None, 16, 16, 64)   256         conv2d_16[0][0]                  
conv2d_17 (Conv2D)              (None, 16, 16, 80)   46160       batch_normalization_17[0][0]     
conv2d_18 (Conv2D)              (None, 16, 16, 80)   6480        conv2d_17[0][0]                  
reshape_2 (Reshape)             (None, 4, 64, 80)    0           conv2d_18[0][0]                  
layer_normalization_4 (LayerNor (None, 4, 64, 80)    160         reshape_2[0][0]                  
multi_head_attention_2 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_4[0][0]      
add_7 (Add)                     (None, 4, 64, 80)    0           multi_head_attention_2[0][0]     
layer_normalization_5 (LayerNor (None, 4, 64, 80)    160         add_7[0][0]                      
dense_4 (Dense)                 (None, 4, 64, 160)   12960       layer_normalization_5[0][0]      
dropout_4 (Dropout)             (None, 4, 64, 160)   0           dense_4[0][0]                    
dense_5 (Dense)                 (None, 4, 64, 80)    12880       dropout_4[0][0]                  
dropout_5 (Dropout)             (None, 4, 64, 80)    0           dense_5[0][0]                    
add_8 (Add)                     (None, 4, 64, 80)    0           dropout_5[0][0]                  
layer_normalization_6 (LayerNor (None, 4, 64, 80)    160         add_8[0][0]                      
multi_head_attention_3 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_6[0][0]      
add_9 (Add)                     (None, 4, 64, 80)    0           multi_head_attention_3[0][0]     
layer_normalization_7 (LayerNor (None, 4, 64, 80)    160         add_9[0][0]                      
dense_6 (Dense)                 (None, 4, 64, 160)   12960       layer_normalization_7[0][0]      
dropout_6 (Dropout)             (None, 4, 64, 160)   0           dense_6[0][0]                    
dense_7 (Dense)                 (None, 4, 64, 80)    12880       dropout_6[0][0]                  
dropout_7 (Dropout)             (None, 4, 64, 80)    0           dense_7[0][0]                    
add_10 (Add)                    (None, 4, 64, 80)    0           dropout_7[0][0]                  
layer_normalization_8 (LayerNor (None, 4, 64, 80)    160         add_10[0][0]                     
multi_head_attention_4 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_8[0][0]      
add_11 (Add)                    (None, 4, 64, 80)    0           multi_head_attention_4[0][0]     
layer_normalization_9 (LayerNor (None, 4, 64, 80)    160         add_11[0][0]                     
dense_8 (Dense)                 (None, 4, 64, 160)   12960       layer_normalization_9[0][0]      
dropout_8 (Dropout)             (None, 4, 64, 160)   0           dense_8[0][0]                    
dense_9 (Dense)                 (None, 4, 64, 80)    12880       dropout_8[0][0]                  
dropout_9 (Dropout)             (None, 4, 64, 80)    0           dense_9[0][0]                    
add_12 (Add)                    (None, 4, 64, 80)    0           dropout_9[0][0]                  
layer_normalization_10 (LayerNo (None, 4, 64, 80)    160         add_12[0][0]                     
multi_head_attention_5 (MultiHe (None, 4, 64, 80)    51760       layer_normalization_10[0][0]     
add_13 (Add)                    (None, 4, 64, 80)    0           multi_head_attention_5[0][0]     
layer_normalization_11 (LayerNo (None, 4, 64, 80)    160         add_13[0][0]                     
dense_10 (Dense)                (None, 4, 64, 160)   12960       layer_normalization_11[0][0]     
dropout_10 (Dropout)            (None, 4, 64, 160)   0           dense_10[0][0]                   
dense_11 (Dense)                (None, 4, 64, 80)    12880       dropout_10[0][0]                 
dropout_11 (Dropout)            (None, 4, 64, 80)    0           dense_11[0][0]                   
add_14 (Add)                    (None, 4, 64, 80)    0           dropout_11[0][0]                 
reshape_3 (Reshape)             (None, 16, 16, 80)   0           add_14[0][0]                     
conv2d_19 (Conv2D)              (None, 16, 16, 64)   5184        reshape_3[0][0]                  
concatenate_1 (Concatenate)     (None, 16, 16, 128)  0           batch_normalization_17[0][0]     
conv2d_20 (Conv2D)              (None, 16, 16, 80)   92240       concatenate_1[0][0]              
conv2d_21 (Conv2D)              (None, 16, 16, 160)  12800       conv2d_20[0][0]                  
batch_normalization_18 (BatchNo (None, 16, 16, 160)  640         conv2d_21[0][0]                  
tf.nn.silu_12 (TFOpLambda)      (None, 16, 16, 160)  0           batch_normalization_18[0][0]     
zero_padding2d_3 (ZeroPadding2D (None, 17, 17, 160)  0           tf.nn.silu_12[0][0]              
depthwise_conv2d_6 (DepthwiseCo (None, 8, 8, 160)    1440        zero_padding2d_3[0][0]           
batch_normalization_19 (BatchNo (None, 8, 8, 160)    640         depthwise_conv2d_6[0][0]         
tf.nn.silu_13 (TFOpLambda)      (None, 8, 8, 160)    0           batch_normalization_19[0][0]     
conv2d_22 (Conv2D)              (None, 8, 8, 80)     12800       tf.nn.silu_13[0][0]              
batch_normalization_20 (BatchNo (None, 8, 8, 80)     320         conv2d_22[0][0]                  
conv2d_23 (Conv2D)              (None, 8, 8, 96)     69216       batch_normalization_20[0][0]     
conv2d_24 (Conv2D)              (None, 8, 8, 96)     9312        conv2d_23[0][0]                  
reshape_4 (Reshape)             (None, 4, 16, 96)    0           conv2d_24[0][0]                  
layer_normalization_12 (LayerNo (None, 4, 16, 96)    192         reshape_4[0][0]                  
multi_head_attention_6 (MultiHe (None, 4, 16, 96)    74400       layer_normalization_12[0][0]     
add_15 (Add)                    (None, 4, 16, 96)    0           multi_head_attention_6[0][0]     
layer_normalization_13 (LayerNo (None, 4, 16, 96)    192         add_15[0][0]                     
dense_12 (Dense)                (None, 4, 16, 192)   18624       layer_normalization_13[0][0]     
dropout_12 (Dropout)            (None, 4, 16, 192)   0           dense_12[0][0]                   
dense_13 (Dense)                (None, 4, 16, 96)    18528       dropout_12[0][0]                 
dropout_13 (Dropout)            (None, 4, 16, 96)    0           dense_13[0][0]                   
add_16 (Add)                    (None, 4, 16, 96)    0           dropout_13[0][0]                 
layer_normalization_14 (LayerNo (None, 4, 16, 96)    192         add_16[0][0]                     
multi_head_attention_7 (MultiHe (None, 4, 16, 96)    74400       layer_normalization_14[0][0]     
add_17 (Add)                    (None, 4, 16, 96)    0           multi_head_attention_7[0][0]     
layer_normalization_15 (LayerNo (None, 4, 16, 96)    192         add_17[0][0]                     
dense_14 (Dense)                (None, 4, 16, 192)   18624       layer_normalization_15[0][0]     
dropout_14 (Dropout)            (None, 4, 16, 192)   0           dense_14[0][0]                   
dense_15 (Dense)                (None, 4, 16, 96)    18528       dropout_14[0][0]                 
dropout_15 (Dropout)            (None, 4, 16, 96)    0           dense_15[0][0]                   
add_18 (Add)                    (None, 4, 16, 96)    0           dropout_15[0][0]                 
layer_normalization_16 (LayerNo (None, 4, 16, 96)    192         add_18[0][0]                     
multi_head_attention_8 (MultiHe (None, 4, 16, 96)    74400       layer_normalization_16[0][0]     
add_19 (Add)                    (None, 4, 16, 96)    0           multi_head_attention_8[0][0]     
layer_normalization_17 (LayerNo (None, 4, 16, 96)    192         add_19[0][0]                     
dense_16 (Dense)                (None, 4, 16, 192)   18624       layer_normalization_17[0][0]     
dropout_16 (Dropout)            (None, 4, 16, 192)   0           dense_16[0][0]                   
dense_17 (Dense)                (None, 4, 16, 96)    18528       dropout_16[0][0]                 
dropout_17 (Dropout)            (None, 4, 16, 96)    0           dense_17[0][0]                   
add_20 (Add)                    (None, 4, 16, 96)    0           dropout_17[0][0]                 
reshape_5 (Reshape)             (None, 8, 8, 96)     0           add_20[0][0]                     
conv2d_25 (Conv2D)              (None, 8, 8, 80)     7760        reshape_5[0][0]                  
concatenate_2 (Concatenate)     (None, 8, 8, 160)    0           batch_normalization_20[0][0]     
conv2d_26 (Conv2D)              (None, 8, 8, 96)     138336      concatenate_2[0][0]              
conv2d_27 (Conv2D)              (None, 8, 8, 320)    31040       conv2d_26[0][0]                  
global_average_pooling2d (Globa (None, 320)          0           conv2d_27[0][0]                  
dense_18 (Dense)                (None, 5)            1605        global_average_pooling2d[0][0]   
Total params: 1,307,621
Trainable params: 1,305,077
Non-trainable params: 2,544
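
The `Reshape` output shapes in the summary can be cross-checked with a few lines of arithmetic. This is a minimal sketch that assumes only what `create_mobilevit` shows: one stride-2 convolution stem followed by four stride-2 inverted residual blocks, with MobileViT blocks placed after the last three.

```python
image_size = 256
patch_size = 4  # patch area p

# One stride-2 conv stem followed by four stride-2 inverted residual blocks.
resolutions = [image_size // 2**k for k in range(1, 6)]
print(resolutions)  # [128, 64, 32, 16, 8]

# The three MobileViT blocks operate on the last three resolutions.
unfolded_shapes = []
for side, dim in zip(resolutions[2:], (64, 80, 96)):
    num_patches = (side * side) // patch_size
    unfolded_shapes.append((patch_size, num_patches, dim))

print(unfolded_shapes)  # [(4, 256, 64), (4, 64, 80), (4, 16, 96)]
```

The number of patches shrinks by a factor of four at each stage, which keeps the attention cost low in the deeper, wider stages.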

## Dataset preparation

We will be using the `tf_flowers`
dataset to demonstrate the model. Unlike other Transformer-based architectures,
MobileViT uses a simple augmentation pipeline primarily because it has the properties
of a CNN.

batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5

def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp

def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)
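
The training-time augmentation amounts to a random crop followed by a random horizontal flip. The sketch below shows the same idea in NumPy, as a stand-in for the `tf.image` calls rather than a replacement for them. The function `random_crop_and_flip` is hypothetical, and the input is assumed to be already resized to `resize_bigger` (280) pixels per side.

```python
import numpy as np


def random_crop_and_flip(image, crop_size, rng):
    """Take a random square crop, then flip it horizontally half of the time."""
    height, width, _ = image.shape
    top = rng.integers(0, height - crop_size + 1)
    left = rng.integers(0, width - crop_size + 1)
    crop = image[top : top + crop_size, left : left + crop_size, :]
    if rng.random() < 0.5:
        crop = crop[:, ::-1, :]
    return crop


rng = np.random.default_rng(0)
resized = rng.random((280, 280, 3))  # stands in for an image resized to 280x280
augmented = random_crop_and_flip(resized, crop_size=256, rng=rng)
print(augmented.shape)  # (256, 256, 3)
```

Cropping from a slightly larger image gives the model a different 256x256 view of each training image on every pass, while validation images are only resized.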

The authors use a multi-scale data sampler to help the model learn representations at varied scales. In this example, we omit this part.

train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)

num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367

## Train a MobileViT (XXS) model

learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)

def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

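    # When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath, monitor="val_accuracy", save_best_only=True, save_weights_only=True
    )

    mobilevit_xxs.fit(
        train_dataset, validation_data=val_dataset, epochs=epochs, callbacks=[checkpoint_callback]
    )
    # Restore the best checkpoint before the final evaluation.
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs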

mobilevit_xxs = run_experiment()
Epoch 1/30
52/52 [==============================] - 47s 459ms/step - loss: 1.3397 - accuracy: 0.4832 - val_loss: 1.7250 - val_accuracy: 0.1662
Epoch 2/30
52/52 [==============================] - 21s 404ms/step - loss: 1.1167 - accuracy: 0.6210 - val_loss: 1.9844 - val_accuracy: 0.1907
Epoch 3/30
52/52 [==============================] - 21s 403ms/step - loss: 1.0217 - accuracy: 0.6709 - val_loss: 1.8187 - val_accuracy: 0.1907
Epoch 4/30
52/52 [==============================] - 21s 409ms/step - loss: 0.9682 - accuracy: 0.7048 - val_loss: 2.0329 - val_accuracy: 0.1907
Epoch 5/30
52/52 [==============================] - 21s 408ms/step - loss: 0.9552 - accuracy: 0.7196 - val_loss: 2.1150 - val_accuracy: 0.1907
Epoch 6/30
52/52 [==============================] - 21s 407ms/step - loss: 0.9186 - accuracy: 0.7318 - val_loss: 2.9713 - val_accuracy: 0.1907
Epoch 7/30
52/52 [==============================] - 21s 407ms/step - loss: 0.8986 - accuracy: 0.7457 - val_loss: 3.2062 - val_accuracy: 0.1907
Epoch 8/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8831 - accuracy: 0.7542 - val_loss: 3.8631 - val_accuracy: 0.1907
Epoch 9/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8433 - accuracy: 0.7714 - val_loss: 1.8029 - val_accuracy: 0.3542
Epoch 10/30
52/52 [==============================] - 21s 408ms/step - loss: 0.8489 - accuracy: 0.7763 - val_loss: 1.7920 - val_accuracy: 0.4796
Epoch 11/30
52/52 [==============================] - 21s 409ms/step - loss: 0.8256 - accuracy: 0.7884 - val_loss: 1.4992 - val_accuracy: 0.5477
Epoch 12/30
52/52 [==============================] - 21s 407ms/step - loss: 0.7859 - accuracy: 0.8123 - val_loss: 0.9236 - val_accuracy: 0.7330
Epoch 13/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7702 - accuracy: 0.8159 - val_loss: 0.8059 - val_accuracy: 0.8011
Epoch 14/30
52/52 [==============================] - 21s 403ms/step - loss: 0.7670 - accuracy: 0.8153 - val_loss: 1.1535 - val_accuracy: 0.7084
Epoch 15/30
52/52 [==============================] - 21s 408ms/step - loss: 0.7332 - accuracy: 0.8344 - val_loss: 0.7746 - val_accuracy: 0.8147
Epoch 16/30
52/52 [==============================] - 21s 404ms/step - loss: 0.7284 - accuracy: 0.8335 - val_loss: 1.0342 - val_accuracy: 0.7330
Epoch 17/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7484 - accuracy: 0.8262 - val_loss: 1.0523 - val_accuracy: 0.7112
Epoch 18/30
52/52 [==============================] - 21s 408ms/step - loss: 0.7209 - accuracy: 0.8450 - val_loss: 0.8146 - val_accuracy: 0.8174
Epoch 19/30
52/52 [==============================] - 21s 409ms/step - loss: 0.7141 - accuracy: 0.8435 - val_loss: 0.8016 - val_accuracy: 0.7875
Epoch 20/30
52/52 [==============================] - 21s 410ms/step - loss: 0.7075 - accuracy: 0.8435 - val_loss: 0.9352 - val_accuracy: 0.7439
Epoch 21/30
52/52 [==============================] - 21s 406ms/step - loss: 0.7066 - accuracy: 0.8504 - val_loss: 1.0171 - val_accuracy: 0.7139
Epoch 22/30
52/52 [==============================] - 21s 405ms/step - loss: 0.6913 - accuracy: 0.8532 - val_loss: 0.7059 - val_accuracy: 0.8610
Epoch 23/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6681 - accuracy: 0.8671 - val_loss: 0.8007 - val_accuracy: 0.8147
Epoch 24/30
52/52 [==============================] - 21s 409ms/step - loss: 0.6636 - accuracy: 0.8747 - val_loss: 0.9490 - val_accuracy: 0.7302
Epoch 25/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6637 - accuracy: 0.8722 - val_loss: 0.6913 - val_accuracy: 0.8556
Epoch 26/30
52/52 [==============================] - 21s 406ms/step - loss: 0.6443 - accuracy: 0.8837 - val_loss: 1.0483 - val_accuracy: 0.7139
Epoch 27/30
52/52 [==============================] - 21s 407ms/step - loss: 0.6555 - accuracy: 0.8695 - val_loss: 0.9448 - val_accuracy: 0.7602
Epoch 28/30
52/52 [==============================] - 21s 409ms/step - loss: 0.6409 - accuracy: 0.8807 - val_loss: 0.9337 - val_accuracy: 0.7302
Epoch 29/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6300 - accuracy: 0.8910 - val_loss: 0.7461 - val_accuracy: 0.8256
Epoch 30/30
52/52 [==============================] - 21s 408ms/step - loss: 0.6093 - accuracy: 0.8968 - val_loss: 0.8651 - val_accuracy: 0.7766
6/6 [==============================] - 0s 65ms/step - loss: 0.7059 - accuracy: 0.8610
Validation accuracy: 86.1%
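
The `label_smoothing_factor` passed to the loss in the training code softens the one-hot targets before the cross-entropy is computed. A minimal sketch of that transformation for the five flower classes follows. The helper `smooth_labels` is hypothetical and only illustrates the blending with a uniform distribution.

```python
def smooth_labels(one_hot, factor):
    """Blend a one-hot target with the uniform distribution over classes."""
    num_classes = len(one_hot)
    return [y * (1.0 - factor) + factor / num_classes for y in one_hot]


smoothed = smooth_labels([0, 0, 1, 0, 0], factor=0.1)
print([round(v, 2) for v in smoothed])  # [0.02, 0.02, 0.92, 0.02, 0.02]
```

Because the targets never reach exactly 0 or 1, the training loss cannot approach zero even when the accuracy is high, which is consistent with the loss values in the log above.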
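
## Results and TFLite conversion

With about one million parameters, reaching roughly 86% top-1 accuracy on the validation split at 256x256 resolution is a strong result. This MobileViT model is fully compatible with TensorFlow Lite (TFLite) and can be converted with the following code:
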
# Serialize the model as a SavedModel.
tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")

# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
]
tflite_model = converter.convert()
with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
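
To give a feel for what storing weights in 8 bits involves, here is a simplified NumPy sketch of symmetric per-tensor quantization. It is an illustration under stated assumptions, not TFLite's actual implementation, which may choose scales differently. The function `quantize_int8` and the random weight matrix are made up for the example.

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric per-tensor quantization of float weights to int8."""
    scale = float(np.max(np.abs(weights))) / 127.0
    quantized = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return quantized, scale


rng = np.random.default_rng(0)
weights = rng.normal(size=(64, 64)).astype(np.float32)
quantized, scale = quantize_int8(weights)
restored = quantized.astype(np.float32) * scale

# int8 storage is four times smaller than float32 storage.
print(quantized.dtype, weights.nbytes // quantized.nbytes)  # int8 4
# The rounding error per weight is bounded by half a quantization step.
print(float(np.max(np.abs(weights - restored))) <= scale / 2 + 1e-6)  # True
```

The trade-off is a small, bounded rounding error per weight in exchange for a model file that is roughly a quarter of the float32 size.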

To learn more about the different quantization recipes available in TFLite and about running inference with TFLite models, check out [this official resource](https://tensorflow.dev.org.tw/lite/performance/post_training_quantization). You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT).