程式碼範例 / 電腦視覺 / 使用 NNCLR 進行自監督對比學習

使用 NNCLR 進行自監督對比學習

作者: Rishit Dagli
建立日期 2021/09/13
上次修改日期 2024/01/22
描述: NNCLR 的實作,一種用於電腦視覺的自監督學習方法。

ⓘ 這個範例使用 Keras 3

在 Colab 中檢視 GitHub 原始碼

簡介

自監督學習

自監督表示學習旨在從原始資料中獲得樣本的穩健表示,而無需昂貴的標籤或註釋。該領域的早期方法側重於定義預訓練任務,這些任務涉及在具有充足弱監督標籤的領域中進行替代任務。經過訓練以解決此類任務的編碼器,預期會學習通用的特徵,這些特徵可能對其他需要昂貴註釋的下游任務(例如圖像分類)有用。

對比學習

自監督學習技術中的一個廣泛類別是使用對比損失的技術,這些損失已廣泛用於電腦視覺應用中,例如圖像相似度降維(DrLIM)人臉驗證/識別。這些方法學習一個潛在空間,該空間將正樣本聚在一起,同時將負樣本分開。

NNCLR

在這個範例中,我們實作了 NNCLR,如 Google Research 和 DeepMind 的論文 With a Little Help from My Friends: Nearest-Neighbor Contrastive Learning of Visual Representations 中所提出。

NNCLR 學習的自監督表示超越了單一實例的正樣本,這允許學習更好的特徵,這些特徵對於不同的視點、變形,甚至類內變異是不變的。基於分群的方法為超越單一實例正樣本提供了一個很好的方法,但假設整個分群都是正樣本可能會由於早期過度泛化而損害效能。相反,NNCLR 使用學習表示空間中的最近鄰居作為正樣本。此外,NNCLR 提高了現有對比學習方法(如 SimCLRKeras 範例))的效能,並減少了自監督方法對資料增強策略的依賴。

以下是論文作者提供的精美視覺化,展示 NNCLR 如何建立在 SimCLR 的想法之上

我們可以看到 SimCLR 使用同一圖像的兩個視圖作為正對。這兩個視圖是使用隨機資料增強產生的,通過編碼器饋送以獲得正嵌入對,我們最終使用了兩種增強。相反,NNCLR 會保留一個代表完整資料分佈的嵌入支援集,並使用最近鄰居形成正對。在訓練期間,支援集用作記憶體,類似於 MoCo 中的佇列(即先進先出)。

這個範例需要 tensorflow_datasets,可以使用以下命令安裝

!pip install tensorflow-datasets

設定

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers

超參數

如原始論文所示,較大的 queue_size 很可能意味著更好的效能,但會引入顯著的運算負擔。作者指出,NNCLR 的最佳結果是在佇列大小為 98,304(他們實驗中最大的 queue_size)時達成的。我們在這裡使用 10,000 作為一個可運作的範例。

AUTOTUNE = tf.data.AUTOTUNE
shuffle_buffer = 5000
# The below two values are taken from https://tensorflow.dev.org.tw/datasets/catalog/stl10
labelled_train_images = 5000
unlabelled_images = 100000

temperature = 0.1
queue_size = 10000
contrastive_augmenter = {
    "brightness": 0.5,
    "name": "contrastive_augmenter",
    "scale": (0.2, 1.0),
}
classification_augmenter = {
    "brightness": 0.2,
    "name": "classification_augmenter",
    "scale": (0.5, 1.0),
}
input_shape = (96, 96, 3)
width = 128
num_epochs = 5  # Use 25 for better results
steps_per_epoch = 50  # Use 200 for better results

載入資料集

我們從 TensorFlow Datasets 載入 STL-10 資料集,這是一個用於開發非監督式特徵學習、深度學習、自學演算法的影像辨識資料集。它的靈感來自 CIFAR-10 資料集,但進行了一些修改。

dataset_name = "stl10"


def prepare_dataset():
    unlabeled_batch_size = unlabelled_images // steps_per_epoch
    labeled_batch_size = labelled_train_images // steps_per_epoch
    batch_size = unlabeled_batch_size + labeled_batch_size

    unlabeled_train_dataset = (
        tfds.load(
            dataset_name, split="unlabelled", as_supervised=True, shuffle_files=True
        )
        .shuffle(buffer_size=shuffle_buffer)
        .batch(unlabeled_batch_size, drop_remainder=True)
    )
    labeled_train_dataset = (
        tfds.load(dataset_name, split="train", as_supervised=True, shuffle_files=True)
        .shuffle(buffer_size=shuffle_buffer)
        .batch(labeled_batch_size, drop_remainder=True)
    )
    test_dataset = (
        tfds.load(dataset_name, split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=AUTOTUNE)
    )
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=AUTOTUNE)

    return batch_size, train_dataset, labeled_train_dataset, test_dataset


batch_size, train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

資料擴增

其他自我監督技術,如 SimCLRBYOLSwAV 等,都非常依賴精心設計的資料擴增流程才能獲得最佳效能。然而,NNCLR 對於複雜的擴增的依賴程度較低,因為最近鄰方法已經提供了豐富的樣本變化。一些常用的技術經常包含在擴增流程中,包括:

  • 隨機調整大小的裁剪
  • 多種顏色失真
  • 高斯模糊

由於 NNCLR 對於複雜擴增的依賴程度較低,我們只會使用隨機裁剪和隨機亮度來擴增輸入影像。

準備擴增模組

def augmenter(brightness, name, scale):
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            keras_cv.layers.RandomCropAndResize(
                target_size=(input_shape[0], input_shape[1]),
                crop_area_factor=scale,
                aspect_ratio_factor=(3 / 4, 4 / 3),
            ),
            keras_cv.layers.RandomBrightness(factor=brightness, value_range=(0.0, 1.0)),
        ],
        name=name,
    )

編碼器架構

在文獻中,使用 ResNet-50 作為編碼器架構是很常見的。在原始論文中,作者使用 ResNet-50 作為編碼器架構,並在空間上平均 ResNet-50 的輸出。但是,請記住,更強大的模型不僅會增加訓練時間,還會需要更多的記憶體,並限制您可以使用的最大批次大小。為了這個範例的目的,我們只使用四個卷積層。

def encoder():
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )

用於對比預訓練的 NNCLR 模型

我們使用對比損失訓練一個未標記影像的編碼器。一個非線性投影頭被附加到編碼器的頂部,因為它可以提高編碼器表示的品質。

class NNCLR(keras.Model):
    def __init__(
        self, temperature, queue_size,
    ):
        super().__init__()
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_augmenter = augmenter(**contrastive_augmenter)
        self.classification_augmenter = augmenter(**classification_augmenter)
        self.encoder = encoder()
        self.projection_head = keras.Sequential(
            [
                layers.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
        )
        self.temperature = temperature

        feature_dimensions = self.encoder.output_shape[1]
        self.feature_queue = keras.Variable(
            keras.utils.normalize(
                keras.random.normal(shape=(queue_size, feature_dimensions)),
                axis=1,
                order=2,
            ),
            trainable=False,
        )

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

    def nearest_neighbour(self, projections):
        support_similarities = ops.matmul(projections, ops.transpose(self.feature_queue))
        nn_projections = ops.take(
            self.feature_queue, ops.argmax(support_similarities, axis=1), axis=0
        )
        return projections + ops.stop_gradient(nn_projections - projections)

    def update_contrastive_accuracy(self, features_1, features_2):
        features_1 = keras.utils.normalize(features_1, axis=1, order=2)
        features_2 = keras.utils.normalize(features_2, axis=1, order=2)
        similarities = ops.matmul(features_1, ops.transpose(features_2))
        batch_size = ops.shape(features_1)[0]
        contrastive_labels = ops.arange(batch_size)
        self.contrastive_accuracy.update_state(
            ops.concatenate([contrastive_labels, contrastive_labels], axis=0),
            ops.concatenate([similarities, ops.transpose(similarities)], axis=0),
        )

    def update_correlation_accuracy(self, features_1, features_2):
        features_1 = (features_1 - ops.mean(features_1, axis=0)) / ops.std(
            features_1, axis=0
        )
        features_2 = (features_2 - ops.mean(features_2, axis=0)) / ops.std(
            features_2, axis=0
        )

        batch_size = ops.shape(features_1)[0]
        cross_correlation = (
            ops.matmul(ops.transpose(features_1), features_2) / batch_size
        )

        feature_dim = ops.shape(features_1)[1]
        correlation_labels = ops.arange(feature_dim)
        self.correlation_accuracy.update_state(
            ops.concatenate([correlation_labels, correlation_labels], axis=0),
            ops.concatenate(
                [cross_correlation, ops.transpose(cross_correlation)], axis=0
            ),
        )

    def contrastive_loss(self, projections_1, projections_2):
        projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
        projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)

        similarities_1_2_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_1), ops.transpose(projections_2)
            )
            / self.temperature
        )
        similarities_1_2_2 = (
             ops.matmul(
                projections_2, ops.transpose(self.nearest_neighbour(projections_1))
            )
            / self.temperature
        )

        similarities_2_1_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_2), ops.transpose(projections_1)
            )
            / self.temperature
        )
        similarities_2_1_2 = (
            ops.matmul(
                projections_1, ops.transpose(self.nearest_neighbour(projections_2))
            )
            / self.temperature
        )

        batch_size = ops.shape(projections_1)[0]
        contrastive_labels = ops.arange(batch_size)
        loss = keras.losses.sparse_categorical_crossentropy(
            ops.concatenate(
                [
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                ],
                axis=0,
            ),
            ops.concatenate(
                [
                    similarities_1_2_1,
                    similarities_1_2_2,
                    similarities_2_1_1,
                    similarities_2_1_2,
                ],
                axis=0,
            ),
            from_logits=True,
        )

        self.feature_queue.assign(
            ops.concatenate([projections_1, self.feature_queue[:-batch_size]], axis=0)
        )
        return loss

    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        augmented_images_1 = self.contrastive_augmenter(images)
        augmented_images_2 = self.contrastive_augmenter(images)

        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1)
            features_2 = self.encoder(augmented_images_2)
            projections_1 = self.projection_head(features_1)
            projections_2 = self.projection_head(features_2)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.update_contrastive_accuracy(features_1, features_2)
        self.update_correlation_accuracy(features_1, features_2)
        preprocessed_images = self.classification_augmenter(labeled_images)

        with tf.GradientTape() as tape:
            features = self.encoder(preprocessed_images)
            class_logits = self.linear_probe(features)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_accuracy.update_state(labels, class_logits)

        return {
            "c_loss": contrastive_loss,
            "c_acc": self.contrastive_accuracy.result(),
            "r_acc": self.correlation_accuracy.result(),
            "p_loss": probe_loss,
            "p_acc": self.probe_accuracy.result(),
        }

    def test_step(self, data):
        labeled_images, labels = data

        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)

        self.probe_accuracy.update_state(labels, class_logits)
        return {"p_loss": probe_loss, "p_acc": self.probe_accuracy.result()}

預訓練 NNCLR

我們使用論文中建議的 0.1 的 temperature 和之前解釋的 10,000 的 queue_size 來訓練網路。我們使用 Adam 作為我們的對比和探針優化器。在這個範例中,我們只訓練模型 30 個 epoch,但為了獲得更好的效能,應該訓練更多的 epoch。

以下兩個指標可用於監控預訓練效能,我們也會記錄它們(取自這個 Keras 範例

  • 對比準確度:自我監督的指標,即一個影像的表示比目前批次中任何其他影像的表示更相似於其不同擴增版本表示的情況的比例。即使在沒有標記範例的情況下,自我監督的指標也可以用於超參數調整。
  • 線性探針準確度:線性探針是評估自我監督分類器的常用指標。它被計算為在編碼器特徵之上訓練的邏輯迴歸分類器的準確度。在我們的例子中,這是通過在凍結的編碼器之上訓練一個單一密集層來完成的。請注意,與傳統方法在預訓練階段之後訓練分類器相反,在這個範例中,我們在預訓練期間訓練它。這可能會稍微降低其準確性,但這樣我們可以在訓練期間監控其值,這有助於實驗和除錯。
model = NNCLR(temperature=temperature, queue_size=queue_size)
model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
    jit_compile=False,
)
pretrain_history = model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)

當您只能訪問非常有限的標記訓練資料,但您可以設法建立大量的未標記資料語料庫時,自我監督學習特別有用,如先前的方法,如 SEERSimCLRSwAV 等所示。

您還應該查看這些論文的部落格文章,這些文章清楚地顯示,通過首先在大型未標記資料集上進行預訓練,然後在較小的標記資料集上進行微調,可以通過少量類別標籤獲得良好的結果

建議您也查看原始論文

非常感謝 NNCLR 論文的主要作者 Debidatta Dwibedi(Google Research),他為這個範例提供了非常有見地的評論。這個範例也從SimCLR Keras 範例中汲取了靈感。