程式碼範例 / 電腦視覺 / 梯度中心化以提升訓練效能

梯度中心化以提升訓練效能

作者: Rishit Dagli
建立日期 06/18/21
最後修改日期 07/25/23
描述: 實作梯度中心化以改善深度神經網路的訓練效能。

ⓘ 此範例使用 Keras 3

在 Colab 中檢視 GitHub 來源


簡介

此範例實作由 Yong 等人提出的深度神經網路新優化技術梯度中心化,並在 Laurence Moroney 的 馬或人類數據集上展示。梯度中心化可以加速訓練過程並改善深度神經網路的最終泛化效能。它透過將梯度向量集中到零平均值來直接對梯度進行操作。梯度中心化還可以改善損失函數及其梯度的 Lipschitz 性,使訓練過程更有效率和穩定。

此範例需要 tensorflow_datasets,可以使用此命令安裝

pip install tensorflow-datasets

設定

from time import time

import keras
from keras import layers
from keras.optimizers import RMSprop
from keras import ops

from tensorflow import data as tf_data
import tensorflow_datasets as tfds

準備資料

在此範例中,我們將使用馬或人類數據集

num_classes = 2
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf_data.AUTOTUNE

(train_ds, test_ds), metadata = tfds.load(
    name=dataset_name,
    split=[tfds.Split.TRAIN, tfds.Split.TEST],
    with_info=True,
    as_supervised=True,
)

print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")
Image shape: (300, 300, 3)
Training images: 1027
Test images: 256

使用數據擴增

我們將數據重新縮放到 [0, 1] 並對我們的數據執行簡單的擴增。

rescale = layers.Rescaling(1.0 / 255)

data_augmentation = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.3),
    layers.RandomZoom(0.2),
]


# Helper to apply augmentation
def apply_aug(x):
    for aug in data_augmentation:
        x = aug(x)
    return x


def prepare(ds, shuffle=False, augment=False):
    # Rescale dataset
    ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)

    # Batch dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(
            lambda x, y: (apply_aug(x), y),
            num_parallel_calls=AUTOTUNE,
        )

    # Use buffered prefecting
    return ds.prefetch(buffer_size=AUTOTUNE)

重新縮放並擴增數據

train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

定義模型

在本節中,我們將定義卷積神經網路。

model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        layers.Conv2D(16, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(512, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ]
)

實作梯度中心化

我們現在將子類化 RMSProp 優化器類別,修改 keras.optimizers.Optimizer.get_gradients() 方法,我們現在在其中實作梯度中心化。在高層次上,其想法是,假設我們透過反向傳播獲得密集層或卷積層的梯度,然後我們計算權重矩陣的列向量的平均值,然後從每個列向量中刪除平均值。

本篇論文中針對各種應用(包括一般圖像分類、細粒度圖像分類、偵測和分割以及 Person ReID)的實驗表明,GC 可以始終如一地提高深度神經網路學習的效能。

此外,為了簡單起見,目前我們沒有實作梯度裁剪功能,但這很容易實作。

目前我們只是為 RMSProp 優化器建立子類別,但是您可以輕鬆地以相同的方式為任何其他優化器或自訂優化器重現此功能。我們將在稍後的部分使用此類別,在其中使用梯度中心化訓練模型。

class GCRMSprop(RMSprop):
    def get_gradients(self, loss, params):
        # We here just provide a modified get_gradients() function since we are
        # trying to just compute the centralized gradients.

        grads = []
        gradients = super().get_gradients()
        for grad in gradients:
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
                grad -= ops.mean(grad, axis=axis, keep_dims=True)
            grads.append(grad)

        return grads


optimizer = GCRMSprop(learning_rate=1e-4)

訓練公用程式

我們還將建立一個回呼,使我們能夠輕鬆測量總訓練時間和每個週期所花費的時間,因為我們有興趣比較梯度中心化對上面建立的模型的效果。

class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time()

    def on_epoch_end(self, batch, logs={}):
        self.times.append(time() - self.epoch_time_start)

訓練沒有 GC 的模型

我們現在訓練我們先前建立的沒有梯度中心化的模型,我們可以將其與使用梯度中心化訓練的模型的訓練效能進行比較。

time_callback_no_gc = TimeHistory()
model.compile(
    loss="binary_crossentropy",
    optimizer=RMSprop(learning_rate=1e-4),
    metrics=["accuracy"],
)

model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 298, 298, 16)      │        448 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 149, 149, 16)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 147, 147, 32)      │      4,640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 147, 147, 32)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 73, 73, 32)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 71, 71, 64)        │     18,496 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 71, 71, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 35, 35, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 33, 33, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, 14, 14, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_4 (MaxPooling2D)  │ (None, 7, 7, 64)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 512)               │  1,606,144 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 1)                 │        513 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 1,704,097 (6.50 MB)
 Trainable params: 1,704,097 (6.50 MB)
 Non-trainable params: 0 (0.00 B)

我們也會儲存歷史記錄,因為我們稍後想要比較我們使用和未使用梯度中心化訓練的模型

history_no_gc = model.fit(
    train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)
Epoch 1/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 24s 778ms/step - accuracy: 0.4772 - loss: 0.7405
Epoch 2/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 597ms/step - accuracy: 0.5434 - loss: 0.6861
Epoch 3/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 700ms/step - accuracy: 0.5402 - loss: 0.6911
Epoch 4/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 586ms/step - accuracy: 0.5884 - loss: 0.6788
Epoch 5/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 588ms/step - accuracy: 0.6570 - loss: 0.6564
Epoch 6/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 591ms/step - accuracy: 0.6671 - loss: 0.6395
Epoch 7/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - accuracy: 0.7010 - loss: 0.6161
Epoch 8/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 593ms/step - accuracy: 0.6946 - loss: 0.6129
Epoch 9/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 699ms/step - accuracy: 0.6972 - loss: 0.5987
Epoch 10/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 11s 623ms/step - accuracy: 0.6839 - loss: 0.6197

訓練具有 GC 的模型

我們現在將訓練相同的模型,這次使用梯度中心化 (Gradient Centralization),請注意這次我們的優化器是使用梯度中心化的。

time_callback_gc = TimeHistory()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 298, 298, 16)      │        448 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 149, 149, 16)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 147, 147, 32)      │      4,640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 147, 147, 32)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 73, 73, 32)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 71, 71, 64)        │     18,496 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 71, 71, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 35, 35, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 33, 33, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, 14, 14, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_4 (MaxPooling2D)  │ (None, 7, 7, 64)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 512)               │  1,606,144 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 1)                 │        513 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 1,704,097 (6.50 MB)
 Trainable params: 1,704,097 (6.50 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 12s 649ms/step - accuracy: 0.7118 - loss: 0.5594
Epoch 2/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 592ms/step - accuracy: 0.7249 - loss: 0.5817
Epoch 3/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/step - accuracy: 0.8060 - loss: 0.4448
Epoch 4/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 693ms/step - accuracy: 0.8472 - loss: 0.4051
Epoch 5/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - accuracy: 0.8386 - loss: 0.3978
Epoch 6/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 593ms/step - accuracy: 0.8442 - loss: 0.3976
Epoch 7/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 585ms/step - accuracy: 0.7409 - loss: 0.6626
Epoch 8/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 587ms/step - accuracy: 0.8191 - loss: 0.4357
Epoch 9/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/step - accuracy: 0.8248 - loss: 0.3974
Epoch 10/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 646ms/step - accuracy: 0.8022 - loss: 0.4589

效能比較

print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")

print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")
Not using Gradient Centralization
Loss: 0.5345584154129028
Accuracy: 0.7604166865348816
Training Time: 112.48799777030945
Using Gradient Centralization
Loss: 0.4014038145542145
Accuracy: 0.8153935074806213
Training Time: 98.31573963165283

鼓勵讀者在不同領域的不同資料集上嘗試梯度中心化,並實驗其效果。 強烈建議您也查閱原始論文 - 作者提出了幾項關於梯度中心化的研究,顯示它如何能改善整體效能、泛化能力、訓練時間以及更有效率。

非常感謝 Ali Mustufa Shaikh 審閱此實作。