
Data-efficient GANs with Adaptive Discriminator Augmentation

Author: András Béres
Date created: 2021/10/28
Last modified: 2021/10/28
Description: Generating images from limited data using the Caltech Birds dataset.

ⓘ This example uses Keras 2



Introduction

GANs

Generative Adversarial Networks (GANs) are a popular class of generative deep learning models, commonly used for image generation. They consist of a pair of dueling neural networks, called the discriminator and the generator. The discriminator's task is to distinguish real images from generated (fake) ones, while the generator network tries to fool the discriminator by generating more and more realistic images. If the discriminator is however too easy or too hard to fool, it might fail to provide useful learning signal for the generator, which is why training GANs is usually considered a difficult task.

Data augmentation for GANs

Data augmentation, a popular technique in deep learning, is the process of randomly applying semantics-preserving transformations to the input data to generate multiple realistic versions of it, thereby effectively multiplying the amount of training data available. The simplest example is left-right flipping an image, which preserves its contents while generating a second unique training sample. Data augmentation is commonly used in supervised learning to prevent overfitting and enhance generalization.

The authors of StyleGAN2-ADA show that discriminator overfitting can be an issue in GANs, especially when only low amounts of training data are available. They propose Adaptive Discriminator Augmentation to mitigate this issue.

Applying data augmentation to GANs is however not straightforward. Since the generator is updated using the discriminator's gradients, if the generated images are augmented, the augmentation pipeline has to be differentiable, and it also has to be GPU-compatible for computational efficiency. Luckily, the Keras image augmentation layers fulfill both these requirements, and are therefore very well suited for this task.
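
As a quick illustration of this point (not part of the original example), gradients are defined through a Keras preprocessing layer, so an augmentation applied to generated images does not block the generator update:

import tensorflow as tf
from tensorflow.keras import layers

# minimal sketch: check that a Keras augmentation layer is differentiable
augmentation = layers.RandomFlip("horizontal")
images = tf.random.uniform(shape=(4, 64, 64, 3))

with tf.GradientTape() as tape:
    tape.watch(images)
    augmented = augmentation(images, training=True)
    loss = tf.reduce_mean(augmented)

# the gradient is defined, so it can propagate back into a generator
print(tape.gradient(loss, images) is not None)  # True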

Invertible data augmentation

A possible difficulty when using data augmentation in generative models is the issue of "leaky augmentations" (section 2.2), namely when the model generates images that are already augmented. This would mean that it was not able to separate the augmentation from the underlying data distribution, which can be caused by using non-invertible data transformations. For example, if rotations of 0, 90, 180 or 270 degrees are performed with equal probability, the original orientation of the images becomes impossible to infer, and this information is destroyed.

A simple trick to make data augmentations invertible is to only apply them with some probability. That way the original version of the images will be more common, and the data distribution can be inferred. By properly choosing this probability, one can effectively regularize the discriminator without making the augmentations leaky.
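
Below is a minimal sketch of this trick with an assumed, fixed probability value (the full version, with an adaptively updated probability, is implemented in the AdaptiveAugmenter class later in this example): each image in a batch is replaced by its augmented version only with probability p, so the unaugmented distribution remains recoverable.

import tensorflow as tf
from tensorflow.keras import layers

p = 0.3  # illustrative, fixed augmentation probability
augmentation = layers.RandomRotation(factor=0.25)  # an example augmentation

def maybe_augment(images):
    augmented = augmentation(images, training=True)
    # per-image coin flips decide whether the augmented or the original version is kept
    coins = tf.random.uniform(shape=(tf.shape(images)[0], 1, 1, 1))
    return tf.where(coins < p, augmented, images)

# usage: most images in the batch stay unaugmented
mixed_batch = maybe_augment(tf.random.uniform(shape=(8, 64, 64, 3)))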


Setup

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from tensorflow.keras import layers

Hyperparameters

# data
num_epochs = 10  # train for 400 epochs for good results
image_size = 64
# resolution of Kernel Inception Distance measurement, see related section
kid_image_size = 75
padding = 0.25
dataset_name = "caltech_birds2011"

# adaptive discriminator augmentation
max_translation = 0.125
max_rotation = 0.125
max_zoom = 0.25
target_accuracy = 0.85
integration_steps = 1000

# architecture
noise_size = 64
depth = 4
width = 128
leaky_relu_slope = 0.2
dropout_rate = 0.4

# optimization
batch_size = 128
learning_rate = 2e-4
beta_1 = 0.5  # not using the default value of 0.9 is important
ema = 0.99

Data pipeline

In this example, we will use the Caltech Birds (2011) dataset for generating images of birds, a diverse natural dataset containing less than 6000 images for training. When working with such low amounts of data, one has to take extra care to retain as high data quality as possible. In this example, we use the provided bounding boxes of the birds to cut out square crops while preserving their aspect ratios when possible.

def round_to_int(float_value):
    return tf.cast(tf.math.round(float_value), dtype=tf.int32)


def preprocess_image(data):
    # unnormalize bounding box coordinates
    height = tf.cast(tf.shape(data["image"])[0], dtype=tf.float32)
    width = tf.cast(tf.shape(data["image"])[1], dtype=tf.float32)
    bounding_box = data["bbox"] * tf.stack([height, width, height, width])

    # calculate center and length of longer side, add padding
    target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
    target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
    target_size = tf.maximum(
        (1.0 + padding) * (bounding_box[2] - bounding_box[0]),
        (1.0 + padding) * (bounding_box[3] - bounding_box[1]),
    )

    # modify crop size to fit into image
    target_height = tf.reduce_min(
        [target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
    )
    target_width = tf.reduce_min(
        [target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
    )

    # crop image
    image = tf.image.crop_to_bounding_box(
        data["image"],
        offset_height=round_to_int(target_center_y - 0.5 * target_height),
        offset_width=round_to_int(target_center_x - 0.5 * target_width),
        target_height=round_to_int(target_height),
        target_width=round_to_int(target_width),
    )

    # resize and clip
    # for image downsampling, area interpolation is the preferred method
    image = tf.image.resize(
        image, size=[image_size, image_size], method=tf.image.ResizeMethod.AREA
    )
    return tf.clip_by_value(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID calculation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")

After preprocessing, the training images look like the following: birds dataset


Kernel Inception Distance

Kernel Inception Distance (KID) was proposed as a replacement for the popular Fréchet Inception Distance (FID) metric for measuring image generation quality. Both metrics measure the difference between the generated and training distributions in the representation space of an InceptionV3 network pretrained on ImageNet.

According to the paper, KID was proposed because FID has no unbiased estimator: its expected value is higher when it is measured on fewer images. KID is more suitable for small datasets because its expected value does not depend on the number of samples it is measured on. In my experience it is also computationally lighter, numerically more stable, and simpler to implement because it can be estimated in a per-batch manner.

In this example, the images are evaluated at the minimal possible resolution of the Inception network (75x75 instead of 299x299), and the metric is only measured on the validation set for computational efficiency.
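
For reference, the per-batch estimator implemented below combines a cubic polynomial kernel on the Inception features with the standard unbiased estimator of the squared maximum mean discrepancy (here n is the batch size, d the feature dimension, and r_i, g_i the real and generated feature vectors):

k(a, b) = \left( \frac{a^\top b}{d} + 1 \right)^3

\widehat{\mathrm{KID}} = \frac{1}{n(n-1)} \sum_{i \neq j} k(r_i, r_j) + \frac{1}{n(n-1)} \sum_{i \neq j} k(g_i, g_j) - \frac{2}{n^2} \sum_{i, j} k(r_i, g_j)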

class KID(keras.metrics.Metric):
    def __init__(self, name="kid", **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean()

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
        return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = tf.shape(real_features)[0]
        batch_size_f = tf.cast(batch_size, dtype=tf.float32)
        mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = tf.reduce_sum(
            kernel_generated * (1.0 - tf.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = tf.reduce_mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()

Adaptive Discriminator Augmentation

The authors of StyleGAN2-ADA propose to change the augmentation probability adaptively during training. Though it is explained differently in the paper, they use integral control on the augmentation probability to keep the discriminator's accuracy on real images close to a target value. Note that the variable they control is actually the average sign of the discriminator logits (r_t in the paper), which corresponds to 2 * accuracy - 1.

This method requires two hyperparameters:

  1. target_accuracy: the target value for the discriminator's accuracy on real images. I recommend selecting its value from the 80-90% range.
  2. integration_steps: the number of update steps required for an accuracy error of 100% to turn into an augmentation probability increase of 100%. Intuitively, this defines how quickly the augmentation probability changes. I recommend setting it to a relatively high value (1000 in this case), so that the augmentation strength is only adjusted slowly.

The main motivation for this procedure is that the optimal value of the target accuracy is similar across different dataset sizes (see figures 4 and 5 in the paper), so it does not have to be re-tuned, because the process automatically applies stronger data augmentation when it is needed.
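
Written out, the update performed in the AdaptiveAugmenter.update method below is a clipped integral controller on the augmentation probability p, where acc_t is the discriminator's accuracy on real images at step t:

p_{t+1} = \operatorname{clip}\left( p_t + \frac{\mathrm{acc}_t - \mathrm{acc}_{\mathrm{target}}}{\mathrm{integration\_steps}},\; 0,\; 1 \right)

With the values used here (target_accuracy = 0.85, integration_steps = 1000), a real accuracy of 0.95 raises the augmentation probability by (0.95 - 0.85) / 1000 = 0.0001 per training step.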

# "hard sigmoid", useful for binary accuracy calculation from logits
def step(values):
    # negative values -> 0.0, positive values -> 1.0
    return 0.5 * (1.0 + tf.sign(values))


# augments images with a probability that is dynamically updated during training
class AdaptiveAugmenter(keras.Model):
    def __init__(self):
        super().__init__()

        # stores the current probability of an image being augmented
        self.probability = tf.Variable(0.0)

        # the corresponding augmentation names from the paper are shown above each layer
        # the authors show (see figure 4), that the blitting and geometric augmentations
        # are the most helpful in the low-data regime
        self.augmenter = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                # blitting/x-flip:
                layers.RandomFlip("horizontal"),
                # blitting/integer translation:
                layers.RandomTranslation(
                    height_factor=max_translation,
                    width_factor=max_translation,
                    interpolation="nearest",
                ),
                # geometric/rotation:
                layers.RandomRotation(factor=max_rotation),
                # geometric/isotropic and anisotropic scaling:
                layers.RandomZoom(
                    height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
                ),
            ],
            name="adaptive_augmenter",
        )

    def call(self, images, training):
        if training:
            augmented_images = self.augmenter(images, training)

            # during training either the original or the augmented images are selected
            # based on self.probability
            augmentation_values = tf.random.uniform(
                shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
            )
            augmentation_bools = tf.math.less(augmentation_values, self.probability)

            images = tf.where(augmentation_bools, augmented_images, images)
        return images

    def update(self, real_logits):
        current_accuracy = tf.reduce_mean(step(real_logits))

        # the augmentation probability is updated based on the discriminator's
        # accuracy on real images
        accuracy_error = current_accuracy - target_accuracy
        self.probability.assign(
            tf.clip_by_value(
                self.probability + accuracy_error / integration_steps, 0.0, 1.0
            )
        )

Network architecture

Here we specify the architecture of the two networks:

  • generator: maps a random vector to an image, which should be as realistic as possible
  • discriminator: maps an image to a scalar score, which should be high for real images and low for generated ones

GANs tend to be sensitive to the network architecture. I implemented a DCGAN architecture in this example, because it is relatively stable during training while being simple to implement. We use a constant number of filters throughout the network, use sigmoid instead of tanh in the last layer of the generator, and use default initialization instead of random normal as further simplifications.

As a good practice, we disable the learnable scale parameter in the batch normalization layers, because on the one hand the following relu + convolutional layers make it redundant (as noted in the documentation), and on the other hand it should be disabled, based on theory, when using spectral normalization (section 4.1), which is not used here but is common in GANs. We also disable the bias in the fully connected and convolutional layers, because the following batch normalization makes it redundant.

# DCGAN generator
def get_generator():
    noise_input = keras.Input(shape=(noise_size,))
    x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
    x = layers.BatchNormalization(scale=False)(x)
    x = layers.ReLU()(x)
    x = layers.Reshape(target_shape=(4, 4, width))(x)
    for _ in range(depth - 1):
        x = layers.Conv2DTranspose(
            width, kernel_size=4, strides=2, padding="same", use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.ReLU()(x)
    image_output = layers.Conv2DTranspose(
        3, kernel_size=4, strides=2, padding="same", activation="sigmoid",
    )(x)

    return keras.Model(noise_input, image_output, name="generator")


# DCGAN discriminator
def get_discriminator():
    image_input = keras.Input(shape=(image_size, image_size, 3))
    x = image_input
    for _ in range(depth):
        x = layers.Conv2D(
            width, kernel_size=4, strides=2, padding="same", use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    output_score = layers.Dense(1)(x)

    return keras.Model(image_input, output_score, name="discriminator")

GAN model

class GAN_ADA(keras.Model):
    def __init__(self):
        super().__init__()

        self.augmenter = AdaptiveAugmenter()
        self.generator = get_generator()
        self.ema_generator = keras.models.clone_model(self.generator)
        self.discriminator = get_discriminator()

        self.generator.summary()
        self.discriminator.summary()

    def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
        super().compile(**kwargs)

        # separate optimizers for the two networks
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
        self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
        self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
        self.kid = KID()

    @property
    def metrics(self):
        return [
            self.generator_loss_tracker,
            self.discriminator_loss_tracker,
            self.real_accuracy,
            self.generated_accuracy,
            self.augmentation_probability_tracker,
            self.kid,
        ]

    def generate(self, batch_size, training):
        latent_samples = tf.random.normal(shape=(batch_size, noise_size))
        # use ema_generator during inference
        if training:
            generated_images = self.generator(latent_samples, training)
        else:
            generated_images = self.ema_generator(latent_samples, training)
        return generated_images

    def adversarial_loss(self, real_logits, generated_logits):
        # this is usually called the non-saturating GAN loss

        real_labels = tf.ones(shape=(batch_size, 1))
        generated_labels = tf.zeros(shape=(batch_size, 1))

        # the generator tries to produce images that the discriminator considers as real
        generator_loss = keras.losses.binary_crossentropy(
            real_labels, generated_logits, from_logits=True
        )
        # the discriminator tries to determine if images are real or generated
        discriminator_loss = keras.losses.binary_crossentropy(
            tf.concat([real_labels, generated_labels], axis=0),
            tf.concat([real_logits, generated_logits], axis=0),
            from_logits=True,
        )

        return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)

    def train_step(self, real_images):
        real_images = self.augmenter(real_images, training=True)

        # use persistent gradient tape because gradients will be calculated twice
        with tf.GradientTape(persistent=True) as tape:
            generated_images = self.generate(batch_size, training=True)
            # gradient is calculated through the image augmentation
            generated_images = self.augmenter(generated_images, training=True)

            # separate forward passes for the real and generated images, meaning
            # that batch normalization is applied separately
            real_logits = self.discriminator(real_images, training=True)
            generated_logits = self.discriminator(generated_images, training=True)

            generator_loss, discriminator_loss = self.adversarial_loss(
                real_logits, generated_logits
            )

        # calculate gradients and update weights
        generator_gradients = tape.gradient(
            generator_loss, self.generator.trainable_weights
        )
        discriminator_gradients = tape.gradient(
            discriminator_loss, self.discriminator.trainable_weights
        )
        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_weights)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_weights)
        )

        # update the augmentation probability based on the discriminator's performance
        self.augmenter.update(real_logits)

        self.generator_loss_tracker.update_state(generator_loss)
        self.discriminator_loss_tracker.update_state(discriminator_loss)
        self.real_accuracy.update_state(1.0, step(real_logits))
        self.generated_accuracy.update_state(0.0, step(generated_logits))
        self.augmentation_probability_tracker.update_state(self.augmenter.probability)

        # track the exponential moving average of the generator's weights to decrease
        # variance in the generation quality
        for weight, ema_weight in zip(
            self.generator.weights, self.ema_generator.weights
        ):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, real_images):
        generated_images = self.generate(batch_size, training=False)

        self.kid.update_state(real_images, generated_images)

        # only KID is measured during the evaluation phase for computational efficiency
        return {self.kid.name: self.kid.result()}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
        # plot random generated images for visual evaluation of generation quality
        if epoch is None or (epoch + 1) % interval == 0:
            num_images = num_rows * num_cols
            generated_images = self.generate(num_images, training=False)

            plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
            for row in range(num_rows):
                for col in range(num_cols):
                    index = row * num_cols + col
                    plt.subplot(num_rows, num_cols, index + 1)
                    plt.imshow(generated_images[index])
                    plt.axis("off")
            plt.tight_layout()
            plt.show()
            plt.close()

Training

One can see from the metrics during training that if the real accuracy (the discriminator's accuracy on real images) is below the target accuracy, the augmentation probability is increased, and vice versa. In my experience, during a healthy GAN training the discriminator accuracy should stay in the 80-95% range. Below that, the discriminator is too weak; above that, it is too strong.

Note that we track the exponential moving average of the generator's weights, and use that for image generation and KID evaluation.

# create and compile the model
model = GAN_ADA()
model.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
)

# save the best model based on the validation KID metric
checkpoint_path = "gan_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 64)]              0         
_________________________________________________________________
dense (Dense)                (None, 2048)              131072    
_________________________________________________________________
batch_normalization (BatchNo (None, 2048)              6144      
_________________________________________________________________
re_lu (ReLU)                 (None, 2048)              0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 8, 8, 128)         262144    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 128)         384       
_________________________________________________________________
re_lu_1 (ReLU)               (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 128)       262144    
_________________________________________________________________
batch_normalization_2 (Batch (None, 16, 16, 128)       384       
_________________________________________________________________
re_lu_2 (ReLU)               (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 128)       262144    
_________________________________________________________________
batch_normalization_3 (Batch (None, 32, 32, 128)       384       
_________________________________________________________________
re_lu_3 (ReLU)               (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 64, 64, 3)         6147      
=================================================================
Total params: 930,947
Trainable params: 926,083
Non-trainable params: 4,864
_________________________________________________________________
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 64, 64, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 32, 32, 128)       6144      
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 128)       384       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       262144    
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 128)       384       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 128)         262144    
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 128)         384       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 4, 4, 128)         262144    
_________________________________________________________________
batch_normalization_7 (Batch (None, 4, 4, 128)         384       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 4, 4, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 796,161
Trainable params: 795,137
Non-trainable params: 1,024
_________________________________________________________________
Epoch 1/10
46/46 [==============================] - 36s 307ms/step - g_loss: 3.3293 - d_loss: 0.1576 - real_acc: 0.9387 - gen_acc: 0.9579 - aug_p: 0.0020 - val_kid: 9.0999
Epoch 2/10
46/46 [==============================] - 10s 215ms/step - g_loss: 4.9824 - d_loss: 0.0912 - real_acc: 0.9704 - gen_acc: 0.9798 - aug_p: 0.0077 - val_kid: 8.3523
Epoch 3/10
46/46 [==============================] - 10s 218ms/step - g_loss: 5.0587 - d_loss: 0.1248 - real_acc: 0.9530 - gen_acc: 0.9625 - aug_p: 0.0131 - val_kid: 6.8116
Epoch 4/10
46/46 [==============================] - 10s 221ms/step - g_loss: 4.2580 - d_loss: 0.1002 - real_acc: 0.9686 - gen_acc: 0.9740 - aug_p: 0.0179 - val_kid: 5.2327
Epoch 5/10
46/46 [==============================] - 10s 225ms/step - g_loss: 4.6022 - d_loss: 0.0847 - real_acc: 0.9655 - gen_acc: 0.9852 - aug_p: 0.0234 - val_kid: 3.9004

png

Epoch 6/10
46/46 [==============================] - 10s 224ms/step - g_loss: 4.9362 - d_loss: 0.0671 - real_acc: 0.9791 - gen_acc: 0.9895 - aug_p: 0.0291 - val_kid: 6.6020
Epoch 7/10
46/46 [==============================] - 10s 222ms/step - g_loss: 4.4272 - d_loss: 0.1184 - real_acc: 0.9570 - gen_acc: 0.9657 - aug_p: 0.0345 - val_kid: 3.3644
Epoch 8/10
46/46 [==============================] - 10s 220ms/step - g_loss: 4.5060 - d_loss: 0.1635 - real_acc: 0.9421 - gen_acc: 0.9594 - aug_p: 0.0392 - val_kid: 3.1381
Epoch 9/10
46/46 [==============================] - 10s 219ms/step - g_loss: 3.8264 - d_loss: 0.1667 - real_acc: 0.9383 - gen_acc: 0.9484 - aug_p: 0.0433 - val_kid: 2.9423
Epoch 10/10
46/46 [==============================] - 10s 219ms/step - g_loss: 3.4063 - d_loss: 0.1757 - real_acc: 0.9314 - gen_acc: 0.9475 - aug_p: 0.0473 - val_kid: 2.9112

png

<keras.callbacks.History at 0x7fefcc2cb9d0>

Inference

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()

png


Results

By running the training for 400 epochs (which takes 2-3 hours in a Colab notebook), one can get high quality image generations with this code example.

The evolution of a random batch of images over a 400-epoch training (ema=0.999 for animation smoothness): birds evolution gif

Latent-space interpolation between a batch of selected images: birds interpolation gif

I also recommend trying out training on other datasets, such as CelebA for example. In my experience good results can be achieved without changing any hyperparameters (though discriminator augmentation might not be necessary).


GAN tips and tricks

My goal with this example was to find a good tradeoff between ease of implementation and generation quality for GANs. During preparation, I have run numerous ablation studies using this repository.

In this section I list the lessons learned and my recommendations, in my subjective order of importance.

I recommend checking out the DCGAN paper, this NeurIPS talk, and this large-scale GAN study for others' takes on this subject.

Architectural tips

  • Resolution: training GANs at higher resolutions tends to get more difficult; I recommend experimenting at 32x32 or 64x64 resolution initially.
  • Initialization: if you see strong colorful patterns early on in training, the initialization might be the issue. Set the kernel_initializer parameter of the layers to random normal, and decrease the standard deviation (recommended value: 0.02, following DCGAN) until the issue disappears.
  • Upsampling: there are two main methods for upsampling in the generator. Transposed convolution is faster, but can lead to checkerboard artifacts, which can be reduced by using a kernel size that is divisible by the stride (a kernel size of 4 is recommended for a stride of 2). Upsampling + standard convolution can have slightly lower quality, but checkerboard artifacts are not an issue. I recommend nearest-neighbor interpolation over bilinear for it. (Both variants appear in the sketch after this list.)
  • Batch normalization in the discriminator: sometimes has a high impact; I recommend trying out both ways.
  • Spectral normalization: a popular technique for training GANs that can help with stability. I recommend disabling batch normalization's learnable scale parameter along with it.
  • Residual connections: while residual discriminators behave similarly, residual generators are more difficult to train in my experience. They are, however, necessary for training large and deep architectures. I recommend starting with non-residual architectures.
  • Dropout: using dropout before the discriminator's last layer improves generation quality in my experience. The recommended dropout rate is below 0.5.
  • Leaky ReLU: use leaky ReLU activations in the discriminator to make its gradients less sparse. The recommended slope/alpha is 0.2, following DCGAN.
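
The sketch below (not part of the original example; the filter counts and initializer value are illustrative) contrasts the two upsampling options and shows the random normal initialization mentioned above:

from tensorflow import keras
from tensorflow.keras import layers

init = keras.initializers.RandomNormal(stddev=0.02)  # DCGAN-style initialization

def upsample_transposed(x, filters=128):
    # transposed convolution: faster, kernel size divisible by the stride (4 vs 2)
    # to reduce checkerboard artifacts
    return layers.Conv2DTranspose(
        filters, kernel_size=4, strides=2, padding="same", kernel_initializer=init
    )(x)

def upsample_nearest(x, filters=128):
    # nearest-neighbor upsampling + standard convolution: no checkerboard artifacts
    x = layers.UpSampling2D(size=2, interpolation="nearest")(x)
    return layers.Conv2D(
        filters, kernel_size=3, padding="same", kernel_initializer=init
    )(x)

# example: either block doubles an 8x8 feature map to 16x16
x = keras.Input(shape=(8, 8, 128))
y = upsample_transposed(x)  # shape (None, 16, 16, 128)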

Algorithmic tips

  • Loss functions: numerous losses have been proposed over the years for training GANs, promising improved performance and stability. I have implemented 5 of them in this repository, and my experience is in line with this GAN study: no loss seems to consistently outperform the default non-saturating GAN loss. I recommend using that as a default.
  • Adam's beta_1 parameter: the beta_1 parameter in Adam can be interpreted as the momentum of the mean gradient estimate. Using 0.5 or even 0.0 instead of the default value of 0.9 was proposed in DCGAN and is important: this example would not work with the default value.
  • Separate batch normalization for generated and real images: the discriminator's forward passes should be run separately for the generated and real images. Doing otherwise can lead to artifacts (45-degree stripes in my case) and degraded performance.
  • Exponential moving average of the generator's weights: this helps to reduce the variance of the KID measurement, and helps to average out the rapid color-palette changes during training.
  • Different learning rates for the generator and discriminator: if one has the resources, it can help to tune the learning rates of the two networks separately. A similar idea is to update the weights of one network (usually the discriminator) multiple times for every update of the other. I recommend using the same learning rate of 2e-4 (Adam) for both networks, following DCGAN, and only updating each of them once by default.
  • Label noise: one-sided label smoothing (using less than 1.0 for real labels) or adding noise to the labels can regularize the discriminator so that it does not become overconfident, though in my case they did not improve performance. (A minimal label-smoothing sketch follows this list.)
  • Adaptive data augmentation: since it adds another dynamic component to the training process, disable it by default, and only enable it when the other components already work well.
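
As a minimal sketch of one-sided label smoothing (not used in this example; the smoothing value of 0.9 is an assumption for illustration), only the real labels are lowered below 1.0 in the discriminator loss:

import tensorflow as tf
from tensorflow import keras

def smoothed_discriminator_loss(real_logits, generated_logits, smoothing=0.9):
    real_labels = smoothing * tf.ones_like(real_logits)  # 0.9 instead of 1.0
    generated_labels = tf.zeros_like(generated_logits)  # generated labels stay at 0.0
    loss = keras.losses.binary_crossentropy(
        tf.concat([real_labels, generated_labels], axis=0),
        tf.concat([real_logits, generated_logits], axis=0),
        from_logits=True,
    )
    return tf.reduce_mean(loss)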

Other GAN-related Keras code examples

Modern GAN architecture lines

Concurrent papers on discriminator data augmentation: 1, 2, 3

Recent literature overview on GANs: talk