KerasHub:預訓練模型 / 開發者指南 / KerasHub 中的 Stable Diffusion 3!

KerasHub 中的 Stable Diffusion 3!

作者: Hongyu Chiufcholletlukewooddivamgupta
建立日期 2024/10/09
上次修改日期 2024/10/24
描述: 使用 KerasHub 的 Stable Diffusion 3 模型產生影像。

在 Colab 中檢視 GitHub 原始碼


概述

Stable Diffusion 3 是一個強大的開源潛在擴散模型(LDM),旨在根據文字提示生成高品質的新穎影像。由 Stability AI 發布,它在 10 億張影像上進行預訓練,並在 3300 萬張高品質美學和偏好影像上進行微調,與先前版本的 Stable Diffusion 模型相比,效能大幅提升。

在本指南中,我們將探索 KerasHub 對 Stable Diffusion 3 Medium 的實作,包括文字轉影像、影像轉影像和修復任務。

首先,讓我們安裝一些相依性,並取得用於演示的影像

!pip install -Uq keras
!pip install -Uq git+https://github.com/keras-team/keras-hub.git
!wget -O mountain_dog.png https://raw.githubusercontent.com/keras-team/keras-io/master/guides/img/stable_diffusion_3_in_keras_hub/mountain_dog.png
!wget -O mountain_dog_mask.png https://raw.githubusercontent.com/keras-team/keras-io/master/guides/img/stable_diffusion_3_in_keras_hub/mountain_dog_mask.png
import os

os.environ["KERAS_BACKEND"] = "jax"

import time

import keras
import keras_hub
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

簡介

在深入研究潛在擴散模型的工作原理之前,讓我們先使用 KerasHub 的 API 生成一些影像。

為了避免為不同的任務重新初始化變數,我們將使用 KerasHub 的 from_preset 工廠方法來實例化並載入已訓練的 backbonepreprocessor。如果您只想一次執行一個任務,可以使用更簡單的 API,例如這樣

text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)

這將自動為您載入並配置已訓練的 backbonepreprocessor

請注意,在本指南中,我們將使用 image_shape=(512, 512, 3) 以加快影像生成速度。為了獲得更高品質的輸出,建議使用預設大小 1024。由於整個主幹約有 30 億個參數,這對於消費級 GPU 來說可能難以容納,因此我們將 dtype="float16" 設定為減少 GPU 記憶體的使用量 – 官方發布的權重也是 float16。

還值得注意的是,預設的「stable_diffusion_3_medium」不包含 T5XXL 文字編碼器,因為它需要更多的 GPU 記憶體。在大多數情況下,效能的降低可以忽略不計。包含 T5XXL 的權重將很快在 KerasHub 上提供。

def display_generated_images(images):
    """Helper function to display the images from the inputs.

    This function accepts the following input formats:
    - 3D numpy array.
    - 4D numpy array: concatenated horizontally.
    - List of 3D numpy arrays: concatenated horizontally.
    """
    display_image = None
    if isinstance(images, np.ndarray):
        if images.ndim == 3:
            display_image = Image.fromarray(images)
        elif images.ndim == 4:
            concated_images = np.concatenate(list(images), axis=1)
            display_image = Image.fromarray(concated_images)
    elif isinstance(images, list):
        concated_images = np.concatenate(images, axis=1)
        display_image = Image.fromarray(concated_images)

    if display_image is None:
        raise ValueError("Unsupported input format.")

    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(display_image)
    plt.show()
    plt.close()


backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3), dtype="float16"
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
)
text_to_image = keras_hub.models.StableDiffusion3TextToImage(backbone, preprocessor)

接下來,我們給它一個提示

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# When using JAX or TensorFlow backends, you might experience a significant
# compilation time during the first `generate()` call. The subsequent
# `generate()` call speedup highlights the power of JIT compilation and caching
# in frameworks like JAX and TensorFlow, making them well-suited for
# high-performance deep learning tasks like image generation.
generated_image = text_to_image.generate(prompt)
display_generated_images(generated_image)

png

非常令人印象深刻!但這是如何運作的?

讓我們深入研究「潛在擴散模型」的含義。

考慮一下「超解析度」的概念,其中深度學習模型「去噪」輸入影像,將其轉換為更高解析度的版本。該模型使用其訓練資料分佈來虛構在給定輸入的情況下最有可能出現的視覺細節。若要深入了解超解析度,您可以查看以下 Keras.io 教學課程

Super-resolution

當我們將這個想法推向極限時,我們可能會開始問 - 如果我們只是在純雜訊上運行這樣的模型會怎麼樣?那麼模型會「對雜訊進行去噪」,並開始虛構一個全新的影像。透過多次重複此過程,我們可以將一小塊雜訊變成越來越清晰且高解析度的人工圖片。

這是潛在擴散的關鍵思想,在 使用潛在擴散模型的高解析度影像合成 中提出。若要深入了解擴散,您可以查看 Keras.io 教學課程 去噪擴散隱式模型

Denoising diffusion

若要從潛在擴散轉換為文字轉影像系統,必須新增一個關鍵功能:使用提示關鍵字控制產生的視覺內容的能力。在 Stable Diffusion 3 中,來自 CLIP 和 T5XXL 模型的文字編碼器用於取得文字嵌入,然後將其輸入到擴散模型中以調節擴散過程。這種方法基於 無分類器擴散引導 中提出的「無分類器引導」概念。

當我們結合這些想法時,我們可以大致了解 Stable Diffusion 3 的架構

  • 文字編碼器:將文字提示轉換為文字嵌入。
  • 擴散模型:重複「去噪」較小的潛在影像區塊。
  • 解碼器:將最終的潛在區塊轉換為更高解析度的影像。

首先,文字提示由多個文字編碼器投影到潛在空間中,這些編碼器是經過預訓練且凍結的語言模型。接下來,將文字嵌入與隨機產生的雜訊區塊(通常來自高斯分佈)一起輸入到擴散模型中。擴散模型會在一系列步驟中重複「去噪」雜訊區塊(步驟越多,影像就越清晰、越精細 - 預設值為 28 個步驟)。最後,潛在區塊會通過 VAE 模型的解碼器,以高解析度呈現影像。

Stable Diffusion 3 架構的概述:Stable Diffusion 3 的架構

一旦我們在數十億張圖片及其標題上進行訓練,這個相對簡單的系統就會開始看起來像魔法。正如費曼談到宇宙時所說的:「它並不複雜,只是數量龐大!」


文字轉影像任務

現在我們知道 Stable Diffusion 3 和文字轉影像任務的基礎。讓我們使用 KerasHub API 深入探索。

若要使用 KerasHub 的 API 進行高效批次處理,我們可以為模型提供提示清單

generated_images = text_to_image.generate([prompt] * 3)
display_generated_images(generated_images)

png

num_steps 參數控制影像生成期間使用的去噪步驟數。增加步驟數通常會提高影像品質,但會增加生成時間。在 Stable Diffusion 3 中,此參數預設為 28

num_steps = [10, 28, 50]
generated_images = []
for n in num_steps:
    st = time.time()
    generated_images.append(text_to_image.generate(prompt, num_steps=n))
    print(f"Cost time (`num_steps={n}`): {time.time() - st:.2f}s")

display_generated_images(generated_images)
Cost time (`num_steps=10`): 1.35s

Cost time (`num_steps=28`): 3.44s

Cost time (`num_steps=50`): 6.18s

png

我們可以使用 "negative_prompts" 來引導模型避免產生特定的樣式和元素。輸入格式會變成一個字典,其鍵為 "prompts""negative_prompts"

如果未提供 "negative_prompts",則會將其解釋為具有預設值 "" 的無條件提示。

generated_images = text_to_image.generate(
    {
        "prompts": [prompt] * 3,
        "negative_prompts": ["Green color"] * 3,
    }
)
display_generated_images(generated_images)

png

guidance_scale 會影響 "prompts" 對影像生成的影響程度。較低的值會讓模型有創意地產生與提示較不相關的影像。較高的值會推動模型更密切地遵循提示。如果此值太高,您可能會在生成的影像中觀察到一些瑕疵。在 Stable Diffusion 3 中,它預設為 7.0

generated_images = [
    text_to_image.generate(prompt, guidance_scale=2.5),
    text_to_image.generate(prompt, guidance_scale=7.0),
    text_to_image.generate(prompt, guidance_scale=10.5),
]
display_generated_images(generated_images)

png

請注意,negative_promptsguidance_scale 是相關的。實作中的公式可以表示如下:predicted_noise = negative_noise + guidance_scale * (positive_noise - negative_noise)


影像轉影像任務

參考影像可以用作擴散過程的起點。這需要在管道中新增一個模組:來自 VAE 模型的編碼器。

參考影像由 VAE 編碼器編碼到潛在空間,然後在其中新增雜訊。後續的去噪步驟與文字轉影像任務的程序相同。

輸入格式會變成一個字典,其鍵為 "images""prompts" 和可選的 "negative_prompts"

image_to_image = keras_hub.models.StableDiffusion3ImageToImage(backbone, preprocessor)

image = Image.open("mountain_dog.png").convert("RGB")
image = image.resize((512, 512))
width, height = image.size

# Note that the values of the image must be in the range of [-1.0, 1.0].
rescale = keras.layers.Rescaling(scale=1 / 127.5, offset=-1.0)
image_array = rescale(np.array(image))

prompt = "dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, "
prompt += "adorable, Pixar, Disney, 8k"

generated_image = image_to_image.generate(
    {
        "images": image_array,
        "prompts": prompt,
    }
)
display_generated_images(
    [
        np.array(image),
        generated_image,
    ]
)

png

如您所見,會根據參考影像和提示生成新的影像。

strength 參數在決定生成的影像與參考影像的相似程度方面起著關鍵作用。該值範圍為 [0.0, 1.0],在 Stable Diffusion 3 中預設為 0.8

較高的 strength 值會讓模型有更多「創意」來產生與參考影像不同的影像。在值為 1.0 時,會完全忽略參考影像,使該任務純粹是文字轉影像。

較低的 strength 值表示生成的影像與參考影像更相似。

generated_images = [
    image_to_image.generate(
        {
            "images": image_array,
            "prompts": prompt,
        },
        strength=0.7,
    ),
    image_to_image.generate(
        {
            "images": image_array,
            "prompts": prompt,
        },
        strength=0.8,
    ),
    image_to_image.generate(
        {
            "images": image_array,
            "prompts": prompt,
        },
        strength=0.9,
    ),
]
display_generated_images(generated_images)

png


修復任務

在影像轉影像任務的基礎上,我們還可以使用遮罩來控制生成的區域。此過程稱為修復,其中影像的特定區域會被取代或編輯。

修復依賴遮罩來確定要修改影像的哪些區域。要修復的區域由白色像素 (True) 表示,而要保留的區域由黑色像素 (False) 表示。

對於修復,輸入是一個字典,其鍵為 "images""masks""prompts" 和可選的 "negative_prompts"

inpaint = keras_hub.models.StableDiffusion3Inpaint(backbone, preprocessor)

image = Image.open("mountain_dog.png").convert("RGB")
image = image.resize((512, 512))
image_array = rescale(np.array(image))

# Note that the mask values are of boolean dtype.
mask = Image.open("mountain_dog_mask.png").convert("L")
mask = mask.resize((512, 512))
mask_array = np.array(mask).astype("bool")

prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly "
prompt += "detailed, 8k"

generated_image = inpaint.generate(
    {
        "images": image_array,
        "masks": mask_array,
        "prompts": prompt,
    }
)
display_generated_images(
    [
        np.array(image),
        np.array(mask.convert("RGB")),
        generated_image,
    ]
)

png

太棒了!狗被可愛的黑貓取代,但與影像轉影像不同,背景被保留了。

請注意,修復任務也包含 strength 參數來控制影像生成,在 Stable Diffusion 3 中的預設值為 0.6


結論

KerasHub 的 StableDiffusion3 支援各種應用程式,並且在 Keras 3 的協助下,可以在 TensorFlow、JAX 和 PyTorch 上運行模型!