作者: fchollet
建立日期 2020/04/15
最後修改日期 2023/06/25
描述: Keras 中遷移學習與微調的完整指南。
import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
遷移學習 包括採用在一個問題上學習到的特徵,並將它們應用於新的、類似的問題。例如,一個已學習識別浣熊的模型中的特徵,可能對啟動一個旨在識別狸的模型很有用。
遷移學習通常用於資料集資料不足,無法從頭開始訓練完整規模模型的任務。
深度學習環境中最常見的遷移學習形式是以下工作流程
最後一個可選步驟是微調,它包括解凍您在上面獲得的整個模型(或部分模型),並以非常低的學習率在新資料上重新訓練它。這可能會通過逐步調整預訓練的特徵以適應新資料,從而實現有意義的改進。
首先,我們將詳細介紹 Keras 的 trainable
API,這是大多數遷移學習和微調工作流程的基礎。
然後,我們將通過採用一個在 ImageNet 資料集上預訓練的模型,並在 Kaggle「貓與狗」分類資料集上重新訓練它來展示典型的工作流程。
這是從 Deep Learning with Python 和 2016 年的部落格文章 「使用極少資料建立強大的影像分類模型」 改編而來。
trainable
屬性層和模型具有三個權重屬性
weights
是該層所有權重變數的列表。trainable_weights
是那些旨在通過梯度下降更新(以最大程度減少訓練期間的損失)的列表。non_trainable_weights
是那些不打算訓練的列表。通常它們在正向傳遞期間由模型更新。範例:Dense
層具有 2 個可訓練的權重(核心和偏差)
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0
一般來說,所有權重都是可訓練的權重。唯一具有不可訓練權重的內建層是 BatchNormalization
層。它使用不可訓練的權重來追蹤訓練期間輸入的平均值和變異數。若要了解如何在您自己的自訂層中使用不可訓練的權重,請參閱從頭開始撰寫新層的指南。
範例:BatchNormalization
層具有 2 個可訓練權重和 2 個不可訓練權重
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2
層和模型還具有布林屬性 trainable
。其值可以更改。將 layer.trainable
設定為 False
會將該層的所有權重從可訓練移動到不可訓練。這稱為「凍結」該層:凍結層的狀態在訓練期間不會更新(無論是使用 fit()
進行訓練,還是使用任何依賴 trainable_weights
來應用梯度更新的自訂迴圈進行訓練)。
範例:將 trainable
設定為 False
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2
當可訓練權重變成不可訓練時,其值在訓練期間不再更新。
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 766ms/step - loss: 0.0615
不要將 layer.trainable
屬性與 layer.__call__()
中的參數 training
混淆(它控制該層應該以推論模式還是訓練模式執行正向傳遞)。如需詳細資訊,請參閱Keras 常見問題。
trainable
屬性如果您在模型或任何具有子層的層上設定 trainable = False
,則所有子層也會變成不可訓練的。
範例
inner_model = keras.Sequential(
[
keras.Input(shape=(3,)),
keras.layers.Dense(3, activation="relu"),
keras.layers.Dense(3, activation="relu"),
]
)
model = keras.Sequential(
[
keras.Input(shape=(3,)),
inner_model,
keras.layers.Dense(3, activation="sigmoid"),
]
)
model.trainable = False # Freeze the outer model
assert inner_model.trainable == False # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False # `trainable` is propagated recursively
這引導我們了解如何在 Keras 中實作典型的遷移學習工作流程
trainable = False
來凍結基本模型中的所有層。請注意,另一種更輕量的工作流程也可能是
第二個工作流程的一個主要優勢是您只需在資料上執行基本模型一次,而不是每個訓練週期執行一次。因此,它更快、更便宜。
但是,第二個工作流程的一個問題是它不允許您在訓練期間動態修改新模型的輸入資料,這在進行資料擴充時是必需的。當您的新資料集資料不足,無法從頭開始訓練完整規模模型時,通常會使用遷移學習,並且在這種情況下,資料擴充非常重要。因此,在接下來的內容中,我們將重點介紹第一個工作流程。
以下是 Keras 中第一個工作流程的外觀
首先,例示具有預訓練權重的基本模型。
base_model = keras.applications.Xception(
weights='imagenet', # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False) # Do not include the ImageNet classifier at the top.
然後,凍結基本模型。
base_model.trainable = False
在頂部建立新模型。
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
在新資料上訓練模型。
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
一旦您的模型在新資料上收斂,您可以嘗試解凍基本模型的全部或部分,並以非常低的學習率重新訓練整個模型端到端。
這是一個可選的最後一步,可能會給您帶來逐步的改進。它也可能導致快速過度擬合 — 請記住這一點。
僅在訓練具有凍結層的模型收斂後執行此步驟至關重要。如果您將隨機初始化的可訓練層與保留預訓練特徵的可訓練層混合,則隨機初始化的層將在訓練期間引起非常大的梯度更新,這將破壞您的預訓練特徵。
在這個階段使用非常低的學習率也至關重要,因為您正在訓練比第一輪訓練更大的模型,並且在一個通常非常小的資料集上進行訓練。因此,如果您應用較大的權重更新,您可能會很快面臨過度擬合的風險。在這裡,您只想以漸進的方式重新調整預訓練的權重。
這是在整個基本模型上實作微調的方式
# Unfreeze the base model
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
關於 compile()
和 trainable
的重要注意事項
在模型上呼叫 compile()
是為了「凍結」該模型的行為。這意味著在編譯模型時的 trainable
屬性值應在該模型的整個生命週期內保留,直到再次呼叫 compile
為止。因此,如果您變更任何 trainable
值,請確保在您的模型上再次呼叫 compile()
,以便您的變更生效。
關於 BatchNormalization
層的重要注意事項
許多影像模型都包含 BatchNormalization
層。該層在每個可想像的方面都是一個特例。以下是一些需要記住的事項。
BatchNormalization
包含 2 個在訓練期間更新的不可訓練權重。這些是追蹤輸入的平均值和變異數的變數。bn_layer.trainable = False
時,BatchNormalization
層將在推論模式下執行,並且不會更新其平均值和變異數統計資訊。這與一般情況下的其他層不同,因為權重可訓練性和推論/訓練模式是兩個正交的概念。但是,在 BatchNormalization
層的情況下,這兩者是綁定的。BatchNormalization
層的模型以進行微調時,在呼叫基礎模型時應傳遞 training=False
,讓 BatchNormalization
層保持在推論模式。否則,對不可訓練權重套用的更新會突然破壞模型已學習的內容。您將在本指南末尾的端對端範例中看到此模式的實際應用。
為了鞏固這些概念,讓我們逐步引導您完成一個具體的端對端遷移學習和微調範例。我們將載入在 ImageNet 上預先訓練的 Xception 模型,並將其用於 Kaggle 的「貓 vs. 狗」分類資料集。
首先,讓我們使用 TFDS 取得貓狗資料集。如果您有自己的資料集,您可能會想使用 keras.utils.image_dataset_from_directory
工具,從磁碟上分類到特定類別資料夾的一組影像中產生類似的標籤資料集物件。
當處理非常小的資料集時,遷移學習最有用。為了保持資料集較小,我們將使用原始訓練資料(25,000 張影像)的 40% 用於訓練,10% 用於驗證,10% 用於測試。
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
"cats_vs_dogs",
# Reserve 10% for validation and 10% for test
split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
as_supervised=True, # Include labels
)
print(f"Number of training samples: {train_ds.cardinality()}")
print(f"Number of validation samples: {validation_ds.cardinality()}")
print(f"Number of test samples: {test_ds.cardinality()}")
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0...
WARNING:absl:1738 images were corrupted and were skipped
Dataset cats_vs_dogs downloaded and prepared to /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326
這些是訓練資料集中前 9 張影像 – 如您所見,它們的大小都不同。
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image)
plt.title(int(label))
plt.axis("off")
我們也可以看到標籤 1 是「狗」,標籤 0 是「貓」。
我們的原始影像大小不一。此外,每個像素都包含介於 0 到 255 之間的 3 個整數值(RGB 級別值)。這不適合用於饋送神經網路。我們需要做兩件事
Normalization
層來執行此操作。一般而言,開發將原始資料作為輸入的模型是一種很好的做法,而不是開發將預先處理過的資料作為輸入的模型。原因是,如果您的模型期望預先處理過的資料,則每次您匯出模型以便在其他地方(在網頁瀏覽器、行動應用程式中)使用時,您都需要重新實作完全相同的預處理流程。這很快就會變得非常棘手。因此,我們應該在接觸模型之前盡可能少做預處理。
在這裡,我們將在資料流程中進行影像大小調整(因為深度神經網路只能處理連續的資料批次),並且我們將在建立模型時將輸入值縮放作為模型的一部分。
讓我們將影像調整為 150x150
resize_fn = keras.layers.Resizing(150, 150)
train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))
當您沒有大型影像資料集時,通過對訓練影像應用隨機但逼真的轉換,例如隨機水平翻轉或小的隨機旋轉,人工引入樣本多樣性是一種很好的做法。這有助於使模型接觸到訓練資料的不同方面,同時減緩過擬合。
augmentation_layers = [
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
]
def data_augmentation(x):
for layer in augmentation_layers:
x = layer(x)
return x
train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))
讓我們對資料進行批次處理,並使用預取來最佳化載入速度。
from tensorflow import data as tf_data
batch_size = 64
train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
讓我們視覺化第一個批次的第一個影像在經過各種隨機轉換後的外觀
for images, labels in train_ds.take(1):
plt.figure(figsize=(10, 10))
first_image = images[0]
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
augmented_image = data_augmentation(np.expand_dims(first_image, 0))
plt.imshow(np.array(augmented_image[0]).astype("int32"))
plt.title(int(labels[0]))
plt.axis("off")
現在,讓我們建立一個遵循我們先前解釋的藍圖的模型。
請注意
Rescaling
層,以將輸入值(最初在 [0, 255]
範圍內)縮放到 [-1, 1]
範圍。Dropout
層,用於正規化。training=False
,使其在推論模式下執行,以便即使在我們解凍基礎模型以進行微調後,批次正規化統計資訊也不會更新。base_model = keras.applications.Xception(
weights="imagenet", # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False,
) # Do not include the ImageNet classifier at the top.
# Freeze the base_model
base_model.trainable = False
# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary(show_trainable=True)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Model: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Trai… ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩ │ input_layer_4 (InputLayer) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ rescaling (Rescaling) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ xception (Functional) │ (None, 5, 5, 2048) │ 20,861… │ N │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ global_average_pooling2d │ (None, 2048) │ 0 │ - │ │ (GlobalAveragePooling2D) │ │ │ │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dropout (Dropout) │ (None, 2048) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dense_7 (Dense) │ (None, 1) │ 2,049 │ Y │ └─────────────────────────────┴──────────────────────────┴─────────┴───────┘
Total params: 20,863,529 (79.59 MB)
Trainable params: 2,049 (8.00 KB)
Non-trainable params: 20,861,480 (79.58 MB)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 2
print("Fitting the top layer of the model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Fitting the top layer of the model
Epoch 1/2
78/146 ━━━━━━━━━━[37m━━━━━━━━━━ 15s 226ms/step - binary_accuracy: 0.7995 - loss: 0.4088
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
136/146 ━━━━━━━━━━━━━━━━━━[37m━━ 2s 231ms/step - binary_accuracy: 0.8430 - loss: 0.3298
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
143/146 ━━━━━━━━━━━━━━━━━━━[37m━ 0s 231ms/step - binary_accuracy: 0.8464 - loss: 0.3235
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
144/146 ━━━━━━━━━━━━━━━━━━━[37m━ 0s 231ms/step - binary_accuracy: 0.8468 - loss: 0.3226
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
146/146 ━━━━━━━━━━━━━━━━━━━━ 0s 260ms/step - binary_accuracy: 0.8478 - loss: 0.3209
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
146/146 ━━━━━━━━━━━━━━━━━━━━ 54s 317ms/step - binary_accuracy: 0.8482 - loss: 0.3200 - val_binary_accuracy: 0.9667 - val_loss: 0.0877
Epoch 2/2
146/146 ━━━━━━━━━━━━━━━━━━━━ 7s 51ms/step - binary_accuracy: 0.9483 - loss: 0.1232 - val_binary_accuracy: 0.9705 - val_loss: 0.0786
<keras.src.callbacks.history.History at 0x7fc8b7f1db70>
最後,讓我們解凍基礎模型,並以低學習率端對端訓練整個模型。
重要的是,儘管基礎模型變得可訓練,但由於我們在建立模型時呼叫它時傳遞了 training=False
,它仍然在推論模式下執行。這表示內部的批次正規化層不會更新其批次統計資訊。如果它們更新了,它們會破壞模型到目前為止學習到的表示。
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary(show_trainable=True)
model.compile(
optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 1
print("Fitting the end-to-end model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Trai… ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩ │ input_layer_4 (InputLayer) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ rescaling (Rescaling) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ xception (Functional) │ (None, 5, 5, 2048) │ 20,861… │ Y │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ global_average_pooling2d │ (None, 2048) │ 0 │ - │ │ (GlobalAveragePooling2D) │ │ │ │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dropout (Dropout) │ (None, 2048) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dense_7 (Dense) │ (None, 1) │ 2,049 │ Y │ └─────────────────────────────┴──────────────────────────┴─────────┴───────┘
Total params: 20,867,629 (79.60 MB)
Trainable params: 20,809,001 (79.38 MB)
Non-trainable params: 54,528 (213.00 KB)
Optimizer params: 4,100 (16.02 KB)
Fitting the end-to-end model
146/146 ━━━━━━━━━━━━━━━━━━━━ 75s 327ms/step - binary_accuracy: 0.8487 - loss: 0.3760 - val_binary_accuracy: 0.9494 - val_loss: 0.1160
<keras.src.callbacks.history.History at 0x7fcd1c755090>
經過 10 個 epoch 後,微調在這裡為我們帶來了很好的改進。讓我們在測試資料集上評估模型
print("Test dataset evaluation")
model.evaluate(test_ds)
Test dataset evaluation
11/37 ━━━━━[37m━━━━━━━━━━━━━━━ 1s 52ms/step - binary_accuracy: 0.9407 - loss: 0.1155
Corrupt JPEG data: 99 extraneous bytes before marker 0xd9
37/37 ━━━━━━━━━━━━━━━━━━━━ 2s 47ms/step - binary_accuracy: 0.9427 - loss: 0.1259
[0.13755160570144653, 0.941300630569458]