程式碼範例 / 生成式深度學習 / 使用 Stable Diffusion 3 漫步於潛在空間

使用 Stable Diffusion 3 漫步於潛在空間

作者: Hongyu Chiu、Ian Stenbit、fcholletlukewood
建立日期 2024/11/11
上次修改日期 2024/11/11
說明: 探索 Stable Diffusion 3 的潛在多樣性。

ⓘ 這個範例使用 Keras 3

在 Colab 中檢視 GitHub 原始碼


概觀

生成式影像模型會學習視覺世界的「潛在多樣性」:一個低維向量空間,其中每個點都會對應到一個影像。從多樣性上的某個點回到可顯示的影像稱為「解碼」– 在 Stable Diffusion 模型中,這由「解碼器」模型處理。

Stable Diffusion 3 Medium Architecture

這個影像的潛在多樣性是連續且可內插的,表示

  1. 在多樣性上稍微移動只會稍微改變對應的影像(連續性)。
  2. 對於多樣性上的任何兩個點 A 和 B(即任何兩個影像),可以透過路徑從 A 移動到 B,其中每個中間點也都在多樣性上(即也是有效的影像)。中間點會被稱為兩個起始影像之間的「內插」。

不過,Stable Diffusion 不只是影像模型,它也是自然語言模型。它有兩個潛在空間:訓練期間使用的編碼器學習的影像表示空間,以及使用預訓練和訓練時微調結合學習的提示潛在空間。

潛在空間漫步潛在空間探索是取樣潛在空間中的點,並逐步變更潛在表示的過程。最常見的應用是產生動畫,其中每個取樣點都會饋送到解碼器,並儲存為最終動畫中的影格。對於高品質的潛在表示,這會產生外觀連貫的動畫。這些動畫可以深入了解潛在空間的特徵圖,並最終提升訓練流程。以下顯示這類 GIF

dog_to_cat_64.gif

在本指南中,我們將說明如何利用 KerasHub 中的 TextToImage API 來執行提示內插,以及在 Stable Diffusion 3 的視覺潛在多樣性中進行循環漫步,以及在文字編碼器的潛在多樣性中進行循環漫步。

本指南假設讀者對 Stable Diffusion 3 有高階的了解。如果您還沒看過,應該先閱讀 KerasHub 中的 Stable Diffusion 3

此外,值得注意的是,預設的「stable_diffusion_3_medium」排除了 T5XXL 文字編碼器,因為它需要更多 GPU 記憶體。在大多數情況下,效能下降的幅度微乎其微。包含 T5XXL 的權重很快就會在 KerasHub 上提供。

!# Use the latest version of KerasHub
!!pip install -Uq git+https://github.com/keras-team/keras-hub.git
import math

import keras
import keras_hub
import matplotlib.pyplot as plt
from keras import ops
from keras import random
from PIL import Image

height, width = 512, 512
num_steps = 28
guidance_scale = 7.0
dtype = "float16"

# Instantiate the Stable Diffusion 3 model and the preprocessor
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", image_shape=(height, width, 3), dtype=dtype
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
)

讓我們為這個範例定義一些輔助函式。

def get_text_embeddings(prompt):
    """Get the text embeddings for a given prompt."""
    token_ids = preprocessor.generate_preprocess([prompt])
    negative_token_ids = preprocessor.generate_preprocess([""])
    (
        positive_embeddings,
        negative_embeddings,
        positive_pooled_embeddings,
        negative_pooled_embeddings,
    ) = backbone.encode_text_step(token_ids, negative_token_ids)
    return (
        positive_embeddings,
        negative_embeddings,
        positive_pooled_embeddings,
        negative_pooled_embeddings,
    )


def decode_to_images(x, height, width):
    """Concatenate and normalize the images to uint8 dtype."""
    x = ops.concatenate(x, axis=0)
    x = ops.reshape(x, (-1, height, width, 3))
    x = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
    return ops.cast(ops.round(ops.multiply(x, 255.0)), "uint8")


def generate_with_latents_and_embeddings(
    latents, embeddings, num_steps, guidance_scale
):
    """Generate images from latents and text embeddings."""

    def body_fun(step, latents):
        return backbone.denoise_step(
            latents,
            embeddings,
            step,
            num_steps,
            guidance_scale,
        )

    latents = ops.fori_loop(0, num_steps, body_fun, latents)
    return backbone.decode_step(latents)


def export_as_gif(filename, images, frames_per_second=10, no_rubber_band=False):
    if not no_rubber_band:
        images += images[2:-1][::-1]  # Makes a rubber band: A->B->A
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

我們將使用自訂潛在向量和嵌入來產生影像,因此我們需要實作 generate_with_latents_and_embeddings 函式。此外,將此函式編譯以加速產生程序非常重要。

if keras.config.backend() == "torch":
    import torch

    @torch.no_grad()
    def wrapped_function(*args, **kwargs):
        return generate_with_latents_and_embeddings(*args, **kwargs)

    generate_function = wrapped_function
elif keras.config.backend() == "tensorflow":
    import tensorflow as tf

    generate_function = tf.function(
        generate_with_latents_and_embeddings, jit_compile=True
    )
elif keras.config.backend() == "jax":
    import itertools

    import jax

    @jax.jit
    def compiled_function(state, *args, **kwargs):
        (trainable_variables, non_trainable_variables) = state
        mapping = itertools.chain(
            zip(backbone.trainable_variables, trainable_variables),
            zip(backbone.non_trainable_variables, non_trainable_variables),
        )
        with keras.StatelessScope(state_mapping=mapping):
            return generate_with_latents_and_embeddings(*args, **kwargs)

    def wrapped_function(*args, **kwargs):
        state = (
            [v.value for v in backbone.trainable_variables],
            [v.value for v in backbone.non_trainable_variables],
        )
        return compiled_function(state, *args, **kwargs)

    generate_function = wrapped_function

在文字提示之間內插

在 Stable Diffusion 3 中,文字提示會編碼成多個向量,然後這些向量會用來引導擴散程序。這些潛在編碼向量的形狀分別為 154x4096 和 2048(針對正面和負面提示)– 相當大!當我們將文字提示輸入到 Stable Diffusion 3 時,我們會在這個潛在多樣性上的單一點產生影像。

若要探索這個多樣性的更多部分,我們可以內插兩個文字編碼之間的向量,並在那些內插點產生影像

prompt_1 = "A cute dog in a beautiful field of lavander colorful flowers "
prompt_1 += "everywhere, perfect lighting, leica summicron 35mm f2.0, kodak "
prompt_1 += "portra 400, film grain"
prompt_2 = prompt_1.replace("dog", "cat")
interpolation_steps = 5

encoding_1 = get_text_embeddings(prompt_1)
encoding_2 = get_text_embeddings(prompt_2)


# Show the size of the latent manifold
print(f"Positive embeddings shape: {encoding_1[0].shape}")
print(f"Negative embeddings shape: {encoding_1[1].shape}")
print(f"Positive pooled embeddings shape: {encoding_1[2].shape}")
print(f"Negative pooled embeddings shape: {encoding_1[3].shape}")
Positive embeddings shape: (1, 154, 4096)
Negative embeddings shape: (1, 154, 4096)
Positive pooled embeddings shape: (1, 2048)
Negative pooled embeddings shape: (1, 2048)

在這個範例中,我們想要使用球面線性內插 (slerp) 而不是簡單的線性內插。Slerp 通常用於電腦圖形中,以便平滑地產生旋轉動畫,也可以應用於內插高維度資料點,例如生成模型中使用的潛在向量。

來源來自 Andrej Karpathy 的要點:https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355

如需更詳細的方法說明,請參閱:https://en.wikipedia.org/wiki/Slerp

def slerp(v1, v2, num):
    ori_dtype = v1.dtype
    # Cast to float32 for numerical stability.
    v1 = ops.cast(v1, "float32")
    v2 = ops.cast(v2, "float32")

    def interpolation(t, v1, v2, dot_threshold=0.9995):
        """helper function to spherically interpolate two arrays."""
        dot = ops.sum(
            v1 * v2 / (ops.linalg.norm(ops.ravel(v1)) * ops.linalg.norm(ops.ravel(v2)))
        )
        if ops.abs(dot) > dot_threshold:
            v2 = (1 - t) * v1 + t * v2
        else:
            theta_0 = ops.arccos(dot)
            sin_theta_0 = ops.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = ops.sin(theta_t)
            s0 = ops.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v1 + s1 * v2
        return v2

    t = ops.linspace(0, 1, num)
    interpolated = ops.stack([interpolation(t[i], v1, v2) for i in range(num)], axis=0)
    return ops.cast(interpolated, ori_dtype)


interpolated_positive_embeddings = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
# We don't use negative prompts in this example, so there’s no need to
# interpolate them.
negative_embeddings = encoding_1[1]
negative_pooled_embeddings = encoding_1[3]

內插編碼後,我們可以從每個點產生影像。請注意,為了維持產生影像之間的一些穩定性,我們會在影像之間保持擴散潛在向量不變。

latents = random.normal((1, height // 8, width // 8, 16), seed=42)

images = []
progbar = keras.utils.Progbar(interpolation_steps)
for i in range(interpolation_steps):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == interpolation_steps - 1)

現在我們已產生一些內插影像,讓我們來看看這些影像!

在本教學課程中,我們會將影像序列匯出為 GIF,方便您在一些時間脈絡中輕鬆檢視。對於第一個影像和最後一個影像在概念上不一致的影像序列,我們會將 GIF 做橡皮筋效果。

如果您在 Colab 中執行,可以執行以下程式碼來檢視自己的 GIF

from IPython.display import Image as IImage
IImage("dog_to_cat_5.gif")
images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "dog_to_cat_5.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

dog_to_cat_5.gif

結果可能看起來令人驚訝。一般來說,在提示之間內插會產生外觀連貫的影像,並且通常會呈現兩個提示內容之間的概念逐漸轉變。這表示表示空間的品質很高,可以緊密鏡射視覺世界的自然結構。

為了充分視覺化這一點,我們應該使用更多步驟進行更細緻的內插。

interpolation_steps = 64
batch_size = 4
batches = interpolation_steps // batch_size

interpolated_positive_embeddings = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
positive_embeddings_shape = ops.shape(encoding_1[0])
positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
interpolated_positive_embeddings = ops.reshape(
    interpolated_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
interpolated_positive_pooled_embeddings = ops.reshape(
    interpolated_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "dog_to_cat_64.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

dog_to_cat_64.gif

產生的 GIF 顯示兩個提示之間更清晰且更連貫的轉變。試試看您自己的提示並實驗看看!

我們甚至可以將這個概念延伸到一個以上的影像。例如,我們可以在四個提示之間內插

prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
prompt_3 = "The eiffel tower in the style of starry night"
prompt_4 = "An architectural sketch of a skyscraper"

interpolation_steps = 8
batch_size = 4
batches = (interpolation_steps**2) // batch_size

encoding_1 = get_text_embeddings(prompt_1)
encoding_2 = get_text_embeddings(prompt_2)
encoding_3 = get_text_embeddings(prompt_3)
encoding_4 = get_text_embeddings(prompt_4)

positive_embeddings_shape = ops.shape(encoding_1[0])
positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
interpolated_positive_embeddings_12 = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_embeddings_34 = slerp(
    encoding_3[0], encoding_4[0], interpolation_steps
)
interpolated_positive_embeddings = slerp(
    interpolated_positive_embeddings_12,
    interpolated_positive_embeddings_34,
    interpolation_steps,
)
interpolated_positive_embeddings = ops.reshape(
    interpolated_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
interpolated_positive_pooled_embeddings_12 = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
interpolated_positive_pooled_embeddings_34 = slerp(
    encoding_3[2], encoding_4[2], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    interpolated_positive_pooled_embeddings_12,
    interpolated_positive_pooled_embeddings_34,
    interpolation_steps,
)
interpolated_positive_pooled_embeddings = ops.reshape(
    interpolated_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

讓我們將產生的影像顯示在網格中,以便更輕鬆解讀。

def plot_grid(images, path, grid_size, scale=2):
    fig, axs = plt.subplots(
        grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)
    )
    fig.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.axis("off")
    for ax in axs.flat:
        ax.axis("off")

    for i in range(min(grid_size * grid_size, len(images))):
        ax = axs.flat[i]
        ax.imshow(images[i])
        ax.axis("off")

    for i in range(len(images), grid_size * grid_size):
        axs.flat[i].axis("off")
        axs.flat[i].remove()

    plt.savefig(
        fname=path,
        pad_inches=0,
        bbox_inches="tight",
        transparent=False,
        dpi=60,
    )


images = ops.convert_to_numpy(decode_to_images(images, height, width))
plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)

png

我們也可以在允許擴散潛在向量變化的同時進行內插,方法是捨棄 seed 參數

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    # Vary diffusion latents for each input.
    latents = random.normal((batch_size, height // 8, width // 8, 16))
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
plot_grid(images, "4-way-interpolation-varying-latent.jpg", interpolation_steps)

png

接下來 – 讓我們進行一些漫步!


繞著文字提示漫步

我們的下一個實驗將會從特定提示產生的點開始,繞著潛在多樣性漫步。

walk_steps = 64
batch_size = 4
batches = walk_steps // batch_size
step_size = 0.01
prompt = "The eiffel tower in the style of starry night"
encoding = get_text_embeddings(prompt)

positive_embeddings = encoding[0]
positive_pooled_embeddings = encoding[2]
negative_embeddings = encoding[1]
negative_pooled_embeddings = encoding[3]

# The shape of `positive_embeddings`: (1, 154, 4096)
# The shape of `positive_pooled_embeddings`: (1, 2048)
positive_embeddings_delta = ops.ones_like(positive_embeddings) * step_size
positive_pooled_embeddings_delta = ops.ones_like(positive_pooled_embeddings) * step_size
positive_embeddings_shape = ops.shape(positive_embeddings)
positive_pooled_embeddings_shape = ops.shape(positive_pooled_embeddings)

walked_positive_embeddings = []
walked_positive_pooled_embeddings = []
for step_index in range(walk_steps):
    walked_positive_embeddings.append(positive_embeddings)
    walked_positive_pooled_embeddings.append(positive_pooled_embeddings)
    positive_embeddings += positive_embeddings_delta
    positive_pooled_embeddings += positive_pooled_embeddings_delta
walked_positive_embeddings = ops.stack(walked_positive_embeddings, axis=0)
walked_positive_pooled_embeddings = ops.stack(walked_positive_pooled_embeddings, axis=0)
walked_positive_embeddings = ops.reshape(
    walked_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
walked_positive_pooled_embeddings = ops.reshape(
    walked_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                walked_positive_embeddings[i],
                negative_embeddings,
                walked_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "eiffel-tower-starry-night.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

eiffel-tower-starry-night.gif

或許不令人意外地,離編碼器的潛在多樣性太遠的漫步會產生外觀不連貫的影像。您可以設定自己的提示並調整 step_size 來增加或減少漫步的幅度,自行試試看。請注意,當漫步的幅度變大時,漫步通常會進入會產生極度雜訊影像的區域。


針對單一提示在擴散潛在空間中進行循環漫步

我們的最後一個實驗是維持一個提示,並探索擴散模型可以從該提示產生的各種影像。我們會透過控制用來設定擴散程序種子的雜訊來執行此操作。

我們會建立兩個雜訊元件 xy,並從 0 到 2π 進行漫步,加總 x 元件的餘弦值和 y 元件的正弦值來產生雜訊。使用這種方法,漫步的終點會到達我們開始漫步時的相同雜訊輸入,因此我們會得到「可循環」的結果!

walk_steps = 64
batch_size = 4
batches = walk_steps // batch_size
prompt = "An oil paintings of cows in a field next to a windmill in Holland"
encoding = get_text_embeddings(prompt)

walk_latent_x = random.normal((1, height // 8, width // 8, 16))
walk_latent_y = random.normal((1, height // 8, width // 8, 16))
walk_scale_x = ops.cos(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
walk_scale_y = ops.sin(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
latent_x = ops.tensordot(walk_scale_x, walk_latent_x, axes=0)
latent_y = ops.tensordot(walk_scale_y, walk_latent_y, axes=0)
latents = ops.add(latent_x, latent_y)
latents = ops.reshape(latents, (batches, batch_size, height // 8, width // 8, 16))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents[i],
            (
                ops.tile(encoding[0], (batch_size, 1, 1)),
                ops.tile(encoding[1], (batch_size, 1, 1)),
                ops.tile(encoding[2], (batch_size, 1)),
                ops.tile(encoding[3], (batch_size, 1)),
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "cows.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=4,
    no_rubber_band=True,
)

cows.gif

嘗試使用您自己的提示和不同的參數值進行實驗!


結論

Stable Diffusion 3 不僅僅提供單純的文字轉圖像生成。探索文字編碼器的潛在流形和擴散模型的潛在空間是體驗此模型強大的兩種有趣方式,而 KerasHub 使這一切變得容易!