程式碼範例 / Keras 快速入門範例 / 使用函數式子類別化方式封裝 Keras 模型以廣泛發布

使用函數式子類別化方式封裝 Keras 模型以廣泛發布

作者: Martin Görner
建立日期 2023-12-13
上次修改 2023-12-13
說明: 當您分享深度學習模型時,請使用函數式子類別化模式將其封裝。

ⓘ 此範例使用 Keras 3

在 Colab 中檢視 GitHub 原始碼


簡介

Keras 是分享您尖端深度學習模型的理想框架,無論是預先訓練(或非預先訓練)的模型庫。數百萬的 ML 工程師都精通熟悉的 Keras API,讓您的模型可以被全球社群存取,無論他們偏好哪個後端(Jax、PyTorch 或 TensorFlow)。

Keras API 的優點之一是它允許使用者以程式化的方式檢查或編輯模型,此功能在建立基於預先訓練模型的新架構或工作流程時是必要的。

在分發模型時,Keras 團隊建議使用函數式子類別化模式封裝它們。以這種方式實現的模型結合了兩種優點

  • 它們可以以正常的 Python 方式實例化
    model = model_collection_xyz.AmazingModel()
  • 它們是 Keras 函數式模型,這意味著它們具有可透過程式存取的層圖,用於內省或模型手術。

本指南說明如何使用函數式子類別化模式,並展示其在程式化模型內省模型手術方面的優點。它還展示了其他兩個可共享 Keras 模型的最佳實務:設定模型以支援最廣泛的輸入範圍,例如各種大小的圖像,以及使用字典輸入來提高更複雜模型的清晰度。


設定

import keras
import tensorflow as tf  # only for tf.data

print("Keras version", keras.version())
print("Keras is running on", keras.config.backend())
Keras version 3.0.1
Keras is running on tensorflow

資料集

讓我們載入 MNIST 資料集,以便我們有一些東西可以訓練。

# tf.data is a great API for putting together a data stream.
# It works whether you use the TensorFlow, PyTorch or Jax backend,
# as long as you use it in the data stream only and not inside of a model.

BATCH_SIZE = 256

(x_train, train_labels), (x_test, test_labels) = keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.map(
    lambda x, y: (tf.expand_dims(x, axis=-1), y)
)  # 1-channel monochrome
train_data = train_data.batch(BATCH_SIZE)
train_data = train_data.cache()
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)
train_data = train_data.repeat()

test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))
test_data = test_data.map(
    lambda x, y: (tf.expand_dims(x, axis=-1), y)
)  # 1-channel monochrome
test_data = test_data.batch(10000)
test_data = test_data.cache()

STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE
EPOCHS = 5

函數式子類別化模型

該模型被包裝在一個類別中,以便最終使用者可以通過呼叫建構函式 MnistModel() 而不是呼叫工廠函式來正常實例化它。

class MnistModel(keras.Model):
    def __init__(self, **kwargs):
        # Keras Functional model definition. This could have used Sequential as
        # well. Sequential is just syntactic sugar for simple functional models.

        # 1-channel monochrome input
        inputs = keras.layers.Input(shape=(None, None, 1), dtype="uint8")
        # pixel format conversion from uint8 to float32
        y = keras.layers.Rescaling(1 / 255.0)(inputs)

        # 3 convolutional layers
        y = keras.layers.Conv2D(
            filters=16, kernel_size=3, padding="same", activation="relu"
        )(y)
        y = keras.layers.Conv2D(
            filters=32, kernel_size=6, padding="same", activation="relu", strides=2
        )(y)
        y = keras.layers.Conv2D(
            filters=48, kernel_size=6, padding="same", activation="relu", strides=2
        )(y)

        # 2 dense layers
        y = keras.layers.GlobalAveragePooling2D()(y)
        y = keras.layers.Dense(48, activation="relu")(y)
        y = keras.layers.Dropout(0.4)(y)
        outputs = keras.layers.Dense(
            10, activation="softmax", name="classification_head"  # 10 classes
        )(y)

        # A Keras Functional model is created by calling keras.Model(inputs, outputs)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

讓我們實例化並訓練這個模型。

model = MnistModel()

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

history = model.fit(
    train_data,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=test_data,
)
Epoch 1/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 9s 33ms/step - loss: 1.8916 - sparse_categorical_accuracy: 0.2933 - val_loss: 0.4278 - val_sparse_categorical_accuracy: 0.8864
Epoch 2/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - loss: 0.5723 - sparse_categorical_accuracy: 0.8201 - val_loss: 0.2703 - val_sparse_categorical_accuracy: 0.9248
Epoch 3/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - loss: 0.4063 - sparse_categorical_accuracy: 0.8772 - val_loss: 0.2010 - val_sparse_categorical_accuracy: 0.9400
Epoch 4/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - loss: 0.3391 - sparse_categorical_accuracy: 0.8996 - val_loss: 0.1869 - val_sparse_categorical_accuracy: 0.9427
Epoch 5/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - loss: 0.2989 - sparse_categorical_accuracy: 0.9120 - val_loss: 0.1513 - val_sparse_categorical_accuracy: 0.9557

不受限制的輸入

請注意,在上面的模型定義中,輸入是以未定義的維度指定的:Input(shape=(None, None, 1)

這允許模型接受任何圖像大小作為輸入。但是,這只有在鬆散定義的形狀可以透過所有層傳播並仍然確定所有權重的大小時才有效。

  • 因此,如果您有一個可以處理具有相同權重的不同輸入大小的模型架構(像這裡),那麼您的使用者將能夠在沒有參數的情況下實例化它
    model = MnistModel()
  • 另一方面,如果模型必須為不同的輸入大小配置不同的權重,您將必須要求使用者在建構函式中指定大小
    model = ModelXYZ(input_size=...)

模型內省

Keras 為每個模型維護一個可透過程式存取的層圖。它可以用於內省,並透過 model.layerslayer.layers 屬性存取。公用程式函式 model.summary() 也在內部使用此機制。

model = MnistModel()

# Model summary works
model.summary()


# Recursively walking the layer graph works as well
def walk_layers(layer):
    if hasattr(layer, "layers"):
        for layer in layer.layers:
            walk_layers(layer)
    else:
        print(layer.name)


print("\nWalking model layers:\n")
walk_layers(model)
Model: "mnist_model_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, None, None, 1)     │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ rescaling_1 (Rescaling)         │ (None, None, None, 1)     │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, None, None, 16)    │        160 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, None, None, 32)    │     18,464 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_5 (Conv2D)               │ (None, None, None, 48)    │     55,344 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_average_pooling2d_1      │ (None, 48)                │          0 │
│ (GlobalAveragePooling2D)        │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 48)                │      2,352 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 48)                │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ classification_head (Dense)     │ (None, 10)                │        490 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 76,810 (300.04 KB)
 Trainable params: 76,810 (300.04 KB)
 Non-trainable params: 0 (0.00 B)
Walking model layers:
input_layer_1
rescaling_1
conv2d_3
conv2d_4
conv2d_5
global_average_pooling2d_1
dense_1
dropout_1
classification_head

模型手術

最終使用者可能想要從您的程式庫中實例化模型,但在使用前修改它。函數式模型具有可透過程式存取的層圖。可以透過切片和拼接圖形並建立新的函數式模型來進行編輯。

另一種方法是分叉模型程式碼並進行修改,但這會強迫使用者無限期地維護他們的分叉。

範例:實例化模型,但將分類頭變更為執行二元分類,「0」或「非 0」,而不是原始的 10 向數字分類。

model = MnistModel()

input = model.input
# cut before the classification head
y = model.get_layer("classification_head").input

# add a new classification head
output = keras.layers.Dense(
    1,  # single class for binary classification
    activation="sigmoid",
    name="binary_classification_head",
)(y)

# create a new functional model
binary_model = keras.Model(input, output)

binary_model.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer_2 (InputLayer)      │ (None, None, None, 1)     │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ rescaling_2 (Rescaling)         │ (None, None, None, 1)     │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_6 (Conv2D)               │ (None, None, None, 16)    │        160 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_7 (Conv2D)               │ (None, None, None, 32)    │     18,464 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_8 (Conv2D)               │ (None, None, None, 48)    │     55,344 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_average_pooling2d_2      │ (None, 48)                │          0 │
│ (GlobalAveragePooling2D)        │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_2 (Dense)                 │ (None, 48)                │      2,352 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 48)                │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ binary_classification_head      │ (None, 1)                 │         49 │
│ (Dense)                         │                           │            │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 76,369 (298.32 KB)
 Trainable params: 76,369 (298.32 KB)
 Non-trainable params: 0 (0.00 B)

我們現在可以將新模型訓練為二元分類器。

# new dataset with 0 / 1 labels (1 = digit '0', 0 = all other digits)
bin_train_data = train_data.map(
    lambda x, y: (x, tf.cast(tf.math.equal(y, tf.zeros_like(y)), dtype=tf.uint8))
)
bin_test_data = test_data.map(
    lambda x, y: (x, tf.cast(tf.math.equal(y, tf.zeros_like(y)), dtype=tf.uint8))
)

# appropriate loss and metric for binary classification
binary_model.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=["binary_accuracy"]
)

history = binary_model.fit(
    bin_train_data,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=bin_test_data,
)
Epoch 1/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 9s 33ms/step - binary_accuracy: 0.8926 - loss: 0.3635 - val_binary_accuracy: 0.9235 - val_loss: 0.1777
Epoch 2/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - binary_accuracy: 0.9411 - loss: 0.1620 - val_binary_accuracy: 0.9766 - val_loss: 0.0748
Epoch 3/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - binary_accuracy: 0.9751 - loss: 0.0794 - val_binary_accuracy: 0.9884 - val_loss: 0.0414
Epoch 4/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - binary_accuracy: 0.9848 - loss: 0.0480 - val_binary_accuracy: 0.9915 - val_loss: 0.0292
Epoch 5/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 31ms/step - binary_accuracy: 0.9910 - loss: 0.0326 - val_binary_accuracy: 0.9917 - val_loss: 0.0286

使用字典輸入的模型

在具有多個輸入的更複雜模型中,將輸入結構化為字典可以提高可讀性和可用性。使用函數式模型可以簡單地做到這一點

class MnistDictModel(keras.Model):
    def __init__(self, **kwargs):
        #
        # The input is a dictionary
        #
        inputs = {
            "image": keras.layers.Input(
                shape=(None, None, 1),  # 1-channel monochrome
                dtype="uint8",
                name="image",
            )
        }

        # pixel format conversion from uint8 to float32
        y = keras.layers.Rescaling(1 / 255.0)(inputs["image"])

        # 3 conv layers
        y = keras.layers.Conv2D(
            filters=16, kernel_size=3, padding="same", activation="relu"
        )(y)
        y = keras.layers.Conv2D(
            filters=32, kernel_size=6, padding="same", activation="relu", strides=2
        )(y)
        y = keras.layers.Conv2D(
            filters=48, kernel_size=6, padding="same", activation="relu", strides=2
        )(y)

        # 2 dense layers
        y = keras.layers.GlobalAveragePooling2D()(y)
        y = keras.layers.Dense(48, activation="relu")(y)
        y = keras.layers.Dropout(0.4)(y)
        outputs = keras.layers.Dense(
            10, activation="softmax", name="classification_head"  # 10 classes
        )(y)

        # A Keras Functional model is created by calling keras.Model(inputs, outputs)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

我們現在可以訓練模型,其輸入結構化為字典。

model = MnistDictModel()

# reformat the dataset as a dictionary
dict_train_data = train_data.map(lambda x, y: ({"image": x}, y))
dict_test_data = test_data.map(lambda x, y: ({"image": x}, y))

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

history = model.fit(
    dict_train_data,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=dict_test_data,
)
Epoch 1/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 9s 34ms/step - loss: 1.8702 - sparse_categorical_accuracy: 0.3175 - val_loss: 0.4505 - val_sparse_categorical_accuracy: 0.8779
Epoch 2/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 8s 32ms/step - loss: 0.5991 - sparse_categorical_accuracy: 0.8131 - val_loss: 0.2582 - val_sparse_categorical_accuracy: 0.9245
Epoch 3/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 7s 32ms/step - loss: 0.3916 - sparse_categorical_accuracy: 0.8846 - val_loss: 0.1938 - val_sparse_categorical_accuracy: 0.9422
Epoch 4/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 8s 33ms/step - loss: 0.3109 - sparse_categorical_accuracy: 0.9089 - val_loss: 0.1450 - val_sparse_categorical_accuracy: 0.9566
Epoch 5/5
 234/234 ━━━━━━━━━━━━━━━━━━━━ 8s 32ms/step - loss: 0.2775 - sparse_categorical_accuracy: 0.9197 - val_loss: 0.1316 - val_sparse_categorical_accuracy: 0.9608