程式碼範例 / 生成式深度學習 / 去噪擴散隱式模型

去噪擴散隱式模型

作者: András Béres
建立日期 2022/06/24
上次修改日期 2022/06/24
說明: 使用去噪擴散隱式模型生成花朵圖像。

ⓘ 這個範例使用 Keras 3

在 Colab 中檢視 GitHub 原始碼


簡介

什麼是擴散模型?

最近,去噪擴散模型,包括基於分數的生成模型,作為一類強大的生成模型而廣受歡迎,它們的圖像合成品質甚至可以媲美生成對抗網路 (GANs)。它們傾向於產生更多樣化的樣本,同時訓練穩定且易於擴展。最近的大型擴散模型,例如DALL-E 2Imagen,展現了令人難以置信的文字轉圖像生成能力。然而,它們的缺點之一是它們的採樣速度較慢,因為它們需要多次前向傳遞才能生成圖像。

擴散是指逐步將結構化訊號(圖像)轉換為雜訊的過程。透過模擬擴散,我們可以從訓練圖像產生雜訊圖像,並可以訓練神經網路嘗試對它們進行去噪。使用訓練後的網路,我們可以模擬與擴散相反的過程,即逆向擴散,這是圖像從雜訊中出現的過程。

diffusion process gif

一句話總結:擴散模型經過訓練可以對雜訊圖像進行去噪,並且可以透過迭代地對純雜訊進行去噪來生成圖像。

本範例的目標

此程式碼範例旨在成為擴散模型的最小但功能完整(具有生成品質指標)的實作,具有適度的計算要求和合理的效能。我的實作選擇和超參數調整都是以這些目標為考量。

由於目前擴散模型的文獻在數學上相當複雜,具有多個理論框架(分數匹配微分方程馬可夫鏈),有時甚至出現互相矛盾的符號(請參閱附錄 C.2),因此嘗試理解它們可能會令人望而卻步。我在這個範例中對這些模型的看法是,它們學習將雜訊圖像分離為其圖像和高斯雜訊成分。

在本範例中,我努力將所有長的數學表達式分解成易於理解的部分,並為所有變數提供了說明性的名稱。我也加入了許多相關文獻的連結,以幫助感興趣的讀者更深入地了解這個主題,希望這個程式碼範例將成為從業者學習擴散模型的一個良好起點。

在以下章節中,我們將實作去噪擴散隱式模型 (DDIM)的連續時間版本,並採用確定性採樣。


設定

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras import layers
from keras import ops

超參數

# data
dataset_name = "oxford_flowers102"
dataset_repetitions = 5
num_epochs = 1  # train for at least 50 epochs for good results
image_size = 64
# KID = Kernel Inception Distance, see related section
kid_image_size = 75
kid_diffusion_steps = 5
plot_diffusion_steps = 20

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2

# optimization
batch_size = 64
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4

資料管道

我們將使用牛津花卉 102資料集來產生花卉圖像,這是一個包含約 8,000 張圖像的多樣化自然資料集。不幸的是,官方拆分是不平衡的,因為大多數圖像都包含在測試拆分中。我們使用Tensorflow Datasets 切片 API建立新的拆分(80% 訓練、20% 驗證)。我們應用中心裁剪作為預處理,並多次重複資料集(原因在下一節中給出)。

def preprocess_image(data):
    # center crop image
    height = ops.shape(data["image"])[0]
    width = ops.shape(data["image"])[1]
    crop_size = ops.minimum(height, width)
    image = tf.image.crop_to_bounding_box(
        data["image"],
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )

    # resize and clip
    # for image downsampling it is important to turn on antialiasing
    image = tf.image.resize(image, size=[image_size, image_size], antialias=True)
    return ops.clip(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID estimation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .repeat(dataset_repetitions)
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


# load dataset
train_dataset = prepare_dataset("train[:80%]+validation[:80%]+test[:80%]")
val_dataset = prepare_dataset("train[80%:]+validation[80%:]+test[80%:]")

核心起始距離

核心起始距離 (KID) 是一種圖像品質指標,它被提出作為流行的弗雷歇起始距離 (FID) 的替代品。我偏好 KID 而不是 FID,因為它實作起來更簡單,可以按批次估計,並且計算量更輕。更多詳細資訊在此

在本範例中,圖像以起始網路的最小可能解析度 (75x75 而非 299x299) 進行評估,並且該指標僅針對計算效率在驗證集上測量。基於相同的原因,我們也將評估時的採樣步驟數量限制為 5。

由於資料集相對較小,我們會在每個 epoch 中多次瀏覽訓練和驗證拆分,因為 KID 估計是嘈雜且計算密集型的,因此我們希望僅在多次迭代後進行評估,但要進行多次迭代。

@keras.saving.register_keras_serializable()
class KID(keras.metrics.Metric):
    def __init__(self, name, **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean(name="kid_tracker")

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                keras.Input(shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = ops.cast(ops.shape(features_1)[1], dtype="float32")
        return (
            features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
        ) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = real_features.shape[0]
        batch_size_f = ops.cast(batch_size, dtype="float32")
        mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = ops.sum(
            kernel_generated * (1.0 - ops.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = ops.mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()

網路架構

在這裡,我們指定將用於去噪的神經網路的架構。我們使用相同的輸入和輸出維度建構U-Net。U-Net 是一種流行的語義分割架構,其主要思想是逐步對輸入圖像進行降採樣,然後再升採樣,並在具有相同解析度的圖層之間新增跳接連線。這些有助於梯度流,並避免引入表示瓶頸,這與通常的自動編碼器不同。基於此,可以將擴散模型視為沒有瓶頸的去噪自動編碼器

網路接收兩個輸入,即雜訊圖像及其雜訊成分的變異數。後者是必需的,因為對訊號去噪需要在不同的雜訊程度下執行不同的操作。我們使用正弦嵌入來轉換雜訊變異數,類似於轉換器NeRF中使用的位置編碼。這有助於網路對雜訊等級高度敏感,這對於良好的效能至關重要。我們使用Lambda 圖層實作正弦嵌入。

其他一些考量

  • 我們使用Keras 函數式 API建構網路,並使用閉包以一致的樣式建構圖層區塊。
  • 擴散模型會嵌入擴散過程的時間步長索引,而不是雜訊變異數,而基於分數的模型(表 1)通常使用雜訊等級的某個函數。我偏好後者,以便我們可以在推論時變更採樣排程,而無需重新訓練網路。
  • 擴散模型會將嵌入單獨輸入到每個卷積區塊。為了簡單起見,我們僅在網路的開始處輸入它,根據我的經驗,這幾乎不會降低效能,因為跳接和殘差連線有助於資訊在網路上正確傳播。
  • 在文獻中,通常會在較低的解析度下使用注意力圖層,以獲得更好的全域連貫性。為了簡單起見,我省略了它。
  • 我們停用批次標準化圖層的可學習中心和縮放參數,因為後續的卷積圖層會使其變得多餘。
  • 我們將最後一個卷積的核初始化為全零,這是一個很好的做法,使網路在初始化後僅預測零,這也是其目標的平均值。這將改善訓練開始時的行為,並使均方誤差損失從正好 1 開始。
@keras.saving.register_keras_serializable()
def sinusoidal_embedding(x):
    embedding_min_frequency = 1.0
    frequencies = ops.exp(
        ops.linspace(
            ops.log(embedding_min_frequency),
            ops.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32")
    embeddings = ops.concatenate(
        [ops.sin(angular_speeds * x), ops.cos(angular_speeds * x)], axis=3
    )
    return embeddings


def ResidualBlock(width):
    def apply(x):
        input_width = x.shape[3]
        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(width, kernel_size=1)(x)
        x = layers.BatchNormalization(center=False, scale=False)(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same", activation="swish")(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownBlock(width, block_depth):
    def apply(x):
        x, skips = x
        for _ in range(block_depth):
            x = ResidualBlock(width)(x)
            skips.append(x)
        x = layers.AveragePooling2D(pool_size=2)(x)
        return x

    return apply


def UpBlock(width, block_depth):
    def apply(x):
        x, skips = x
        x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            x = layers.Concatenate()([x, skips.pop()])
            x = ResidualBlock(width)(x)
        return x

    return apply


def get_network(image_size, widths, block_depth):
    noisy_images = keras.Input(shape=(image_size, image_size, 3))
    noise_variances = keras.Input(shape=(1, 1, 1))

    e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, 32))(noise_variances)
    e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)

    x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = layers.Concatenate()([x, e])

    skips = []
    for width in widths[:-1]:
        x = DownBlock(width, block_depth)([x, skips])

    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)

    for width in reversed(widths[:-1]):
        x = UpBlock(width, block_depth)([x, skips])

    x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)

    return keras.Model([noisy_images, noise_variances], x, name="residual_unet")

這展示了函數式 API 的強大功能。請注意我們如何在 80 行程式碼中建構具有跳接連線、殘差區塊、多個輸入和正弦嵌入的相對複雜的 U-Net!


擴散模型

擴散排程

讓我們假設擴散過程從時間 = 0 開始,到時間 = 1 結束。此變數將稱為擴散時間,可以是離散的(在擴散模型中很常見)或連續的(在基於分數的模型中很常見)。我選擇後者,以便可以在推論時變更採樣步驟的數量。

我們需要一個函數來告訴我們,在擴散過程的每個時間點,雜訊圖像的雜訊等級和訊號等級,對應於實際的擴散時間。這將稱為擴散排程(請參閱 diffusion_schedule())。

此排程會輸出兩個量:noise_ratesignal_rate(分別對應於 DDIM 論文中的 sqrt(1 - alpha) 和 sqrt(alpha))。我們透過依其對應的速率加權隨機雜訊和訓練圖像,並將它們加在一起,來產生雜訊圖像。

由於(標準常態)隨機雜訊和(標準化)影像都具有零均值和單位變異數,因此雜訊率和訊號率可以解釋為它們在含雜訊影像中成分的標準差,而它們的平方率可以解釋為它們的變異數(或在訊號處理意義上的功率)。這些比率的設定始終會使它們的平方和為 1,這表示含雜訊影像的變異數始終為 1,就像其未縮放的成分一樣。

我們將使用一個簡化的連續版本 餘弦排程(第 3.2 節),該排程在文獻中相當常見。此排程是對稱的,在擴散過程的開始和結束時較慢,並且它還具有一個很好的幾何解釋,使用 單位圓的三角特性

diffusion schedule gif

訓練過程

去噪擴散模型的訓練過程(請參閱 train_step()denoise())如下:我們均勻地採樣隨機擴散時間,並以對應於擴散時間的比率將訓練影像與隨機高斯雜訊混合。然後,我們訓練模型以將含雜訊的影像分離為其兩個成分。

通常,訓練神經網路是為了預測未縮放的雜訊成分,然後可以使用訊號率和雜訊率計算出預測的影像成分。理論上應該使用逐像素的 均方誤差,但我建議改用 平均絕對誤差(類似於 這個 實作),這在本數據集上產生更好的結果。

採樣(反向擴散)

當採樣時(請參閱 reverse_diffusion()),在每個步驟中,我們獲取先前對含雜訊影像的估計值,並使用我們的網路將其分離為影像和雜訊。然後,我們使用下一步的訊號率和雜訊率重新組合這些成分。

雖然在 DDIM 的方程式 12 中顯示了類似的觀點,但我認為以上對採樣方程式的解釋並未廣為人知。

此範例僅實作 DDIM 的確定性採樣程序,對應於論文中的 eta = 0。也可以使用隨機採樣(在這種情況下,模型會變成 去噪擴散機率模型 (DDPM)),其中一部分預測的雜訊會被相同或較大量的隨機雜訊取代(請參閱方程式 16 及以下)。

隨機採樣可以在不重新訓練網路的情況下使用(因為兩個模型的訓練方式相同),並且可以提高樣本品質,但另一方面通常需要更多的採樣步驟。

@keras.saving.register_keras_serializable()
class DiffusionModel(keras.Model):
    def __init__(self, image_size, widths, block_depth):
        super().__init__()

        self.normalizer = layers.Normalization()
        self.network = get_network(image_size, widths, block_depth)
        self.ema_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        super().compile(**kwargs)

        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
        self.kid = KID(name="kid")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]

    def denormalize(self, images):
        # convert the pixel values back to 0-1 range
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return ops.clip(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        # diffusion times -> angles
        start_angle = ops.cast(ops.arccos(max_signal_rate), "float32")
        end_angle = ops.cast(ops.arccos(min_signal_rate), "float32")

        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

        # angles -> signal and noise rates
        signal_rates = ops.cos(diffusion_angles)
        noise_rates = ops.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1

        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        # the exponential moving average weights are used at evaluation
        if training:
            network = self.network
        else:
            network = self.ema_network

        # predict noise component and calculate the image component using it
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        # reverse diffusion = sampling
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images

            # separate the current noisy image to its components
            diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )
            # network used in eval mode

            # remix the predicted components using the next signal and noise rates
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )
            # this new noisy image will be used in the next step

        return pred_images

    def generate(self, num_images, diffusion_steps):
        # noise -> images -> denormalized images
        initial_noise = keras.random.normal(
            shape=(num_images, image_size, image_size, 3)
        )
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=True)
        noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))

        # sample uniform random diffusion times
        diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # track the exponential moving averages of weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=False)
        noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))

        # sample uniform random diffusion times
        diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        # use the network to separate noisy images to their components
        pred_noises, pred_images = self.denoise(
            noisy_images, noise_rates, signal_rates, training=False
        )

        noise_loss = self.loss(noises, pred_noises)
        image_loss = self.loss(images, pred_images)

        self.image_loss_tracker.update_state(image_loss)
        self.noise_loss_tracker.update_state(noise_loss)

        # measure KID between real and generated images
        # this is computationally demanding, kid_diffusion_steps has to be small
        images = self.denormalize(images)
        generated_images = self.generate(
            num_images=batch_size, diffusion_steps=kid_diffusion_steps
        )
        self.kid.update_state(images, generated_images)

        return {m.name: m.result() for m in self.metrics}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
        # plot random generated images for visual evaluation of generation quality
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            diffusion_steps=plot_diffusion_steps,
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()

訓練

# create and compile the model
model = DiffusionModel(image_size, widths, block_depth)
# below tensorflow 2.9:
# pip install tensorflow_addons
# import tensorflow_addons as tfa
# optimizer=tfa.optimizers.AdamW
model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss

# save the best model based on the validation KID metric
checkpoint_path = "checkpoints/diffusion_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(train_dataset)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
 87910968/87910968 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

 511/511 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step - i_loss: 0.6896 - n_loss: 0.2961

png

 511/511 ━━━━━━━━━━━━━━━━━━━━ 110s 138ms/step - i_loss: 0.6891 - n_loss: 0.2959 - kid: 0.0000e+00 - val_i_loss: 2.5650 - val_kid: 2.0372 - val_n_loss: 0.7914

<keras.src.callbacks.history.History at 0x7f521b149870>

推論

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()

png


結果

透過運行訓練至少 50 個 epoch(在 T4 GPU 上需要 2 小時,在 A100 GPU 上需要 30 分鐘),可以使用此程式碼範例獲得高品質的影像生成。

一批影像在 80 個 epoch 訓練過程中的演變(色彩失真是由於 GIF 壓縮)

flowers training gif

使用 1 到 20 個採樣步驟從相同的初始雜訊產生的影像

flowers sampling steps gif

初始雜訊樣本之間的插值(球面)

flowers interpolation gif

確定性採樣過程(頂部為含雜訊影像,底部為預測影像,40 個步驟)

flowers deterministic generation gif

隨機採樣過程(頂部為含雜訊影像,底部為預測影像,80 個步驟)

flowers stochastic generation gif


學到的教訓

在準備此程式碼範例期間,我使用 這個儲存庫 進行了許多實驗。在本節中,我將依重要性的主觀順序列出學到的教訓和建議。

演算法提示

  • 最小和最大訊號率:我發現最小訊號率是一個重要的超參數。將其設定得太低會使生成的影像過度飽和,而將其設定得太高會使其欠飽和。我建議仔細調整它。此外,將其設定為 0 會導致除以零錯誤。最大訊號率可以設定為 1,但我發現將其設定得較低會稍微提高生成品質。
  • 損失函數:雖然大型模型傾向於使用均方誤差 (MSE) 損失,但我建議在此數據集上使用平均絕對誤差 (MAE)。根據我的經驗,MSE 損失會產生更多樣化的樣本(它似乎也會產生更多第 3 節),而 MAE 損失會產生更平滑的影像。我建議兩種都嘗試。
  • 權重衰減:當擴大模型時,我偶爾會遇到發散的訓練,並且發現權重衰減有助於避免在效能成本較低的情況下出現不穩定。這就是為什麼我在這個範例中使用 AdamW 而不是 Adam 的原因。
  • 權重的指數移動平均:這有助於減少 KID 指標的變異數,並有助於平均訓練期間的短期變化。
  • 影像增強:雖然我在此範例中未使用影像增強,但根據我的經驗,在訓練中新增水平翻轉會提高生成效能,而隨機裁切則不會。由於我們使用監督去噪損失,過擬合可能會是一個問題,因此影像增強可能在小型數據集上很重要。還應注意不要使用 洩漏增強,例如,可以按照 這種方法(第 5 節末尾) 進行。
  • 數據標準化:在文獻中,影像的像素值通常會轉換為 -1 到 1 的範圍。為了理論上的正確性,我將影像標準化為具有零均值和單位變異數,就像隨機雜訊一樣。
  • 雜訊等級輸入:我選擇將雜訊變異數輸入到網路中,因為在我們的採樣排程下它是對稱的。也可以輸入雜訊率(類似效能)、訊號率(較低效能),甚至是 對數訊號雜訊比(附錄 B.1)(沒有嘗試,因為它的範圍高度取決於最小和最大訊號率,並且需要相應地調整最小嵌入頻率)。
  • 梯度裁剪:使用值為 1 的全域梯度裁剪有助於大型模型的訓練穩定性,但根據我的經驗,效能顯著降低。
  • 殘差連線縮減:對於 更深層的模型(附錄 B),使用 1/sqrt(2) 縮放殘差連線可能會有所幫助,但我的情況沒有幫助。
  • 學習率:對我來說,Adam 優化器 的預設學習率 1e-3 效果非常好,但較低的學習率在 文獻中更常見(表 11-13)

架構提示

  • 正弦嵌入:在網路的雜訊等級輸入上使用正弦嵌入對於獲得良好的效能至關重要。我建議將最小嵌入頻率設定為此輸入範圍的倒數,由於我們在此範例中使用雜訊變異數,因此可以始終將其保持在 1。最大嵌入頻率控制網路將敏感的雜訊變異數中的最小變化,而嵌入維度設定嵌入中的頻率成分數。根據我的經驗,效能對這些值不太敏感。
  • 跳躍連線:在網路架構中使用跳躍連線絕對至關重要,沒有它們,模型將無法學會在良好的效能下去噪。
  • 殘差連線:根據我的經驗,殘差連線也顯著提高了效能,但這可能是由於我們僅將雜訊等級嵌入輸入到網路的第一層,而不是所有層。
  • 標準化:當擴大模型時,我偶爾會遇到發散的訓練,使用標準化層有助於緩解此問題。在文獻中,通常在網路中使用 群組標準化(例如,使用 8 個群組)或 層標準化,但我選擇使用 批次標準化,因為它在我的實驗中提供了類似的好處,但計算量較輕。
  • 激活:激活函數的選擇對生成品質的影響比我預期的要大。在我的實驗中,使用非單調激活函數的效果優於單調激活函數(例如 ReLU),其中 Swish 的效果最佳(這也是 Imagen 使用的,第 41 頁)。
  • 注意力:如前所述,在文獻中,通常在低解析度下使用 注意力層 以獲得更好的全域一致性。為了簡單起見,我省略了它們。
  • 上採樣:網路中的 雙線性上採樣和最近鄰上採樣 的效果相似,但我沒有嘗試 轉置卷積

有關 GAN 的類似列表,請查看 此 Keras 教學課程


接下來嘗試什麼?

如果您想深入研究該主題,我建議查看我為準備此程式碼範例而建立的 此儲存庫,該儲存庫以類似的風格實作了更廣泛的功能,例如