開發者指南 / KerasCV / 在 KerasCV 中使用 Segment Anything!

在 KerasCV 中使用 Segment Anything!

作者:Tirth Patel、Ian Stenbit
建立日期 2023/12/04
上次修改日期 2023/12/19
說明:使用 KerasCV 中的文字、方框和點提示來分割任何物件。

在 Colab 中檢視 GitHub 原始碼


概觀

Segment Anything Model (SAM) 可以根據輸入的提示(例如點或方框)產生高品質的物件遮罩,並且可以用於生成影像中所有物件的遮罩。它已經在一個包含 1,100 萬張影像和 11 億個遮罩的資料集上進行了訓練,並且在各種分割任務上具有強大的零樣本學習效能。

在本指南中,我們將展示如何使用 KerasCV 對Segment Anything Model的實作,並展示 TensorFlow 和 JAX 的效能提升有多麼強大。

首先,讓我們獲取演示所需的所有依賴項和圖像。

!pip install -Uq keras-cv
!pip install -Uq keras
!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg

選擇後端

使用 Keras 3,您可以選擇使用您喜歡的后端!

import os

os.environ["KERAS_BACKEND"] = "jax"

import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_cv

輔助函數

讓我們定義一些輔助函數,用於可視化圖像、提示和分割結果。

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    box = box.reshape(-1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def inference_resizing(image, pad=True):
    # Compute Preprocess Shape
    image = ops.cast(image, dtype="float32")
    old_h, old_w = image.shape[0], image.shape[1]
    scale = 1024 * 1.0 / max(old_h, old_w)
    new_h = old_h * scale
    new_w = old_w * scale
    preprocess_shape = int(new_h + 0.5), int(new_w + 0.5)

    # Resize the image
    image = ops.image.resize(image[None, ...], preprocess_shape)[0]

    # Pad the shorter side
    if pad:
        pixel_mean = ops.array([123.675, 116.28, 103.53])
        pixel_std = ops.array([58.395, 57.12, 57.375])
        image = (image - pixel_mean) / pixel_std
        h, w = image.shape[0], image.shape[1]
        pad_h = 1024 - h
        pad_w = 1024 - w
        image = ops.pad(image, [(0, pad_h), (0, pad_w), (0, 0)])
        # KerasCV now rescales the images and normalizes them.
        # Just unnormalize such that when KerasCV normalizes them
        # again, the padded values map to 0.
        image = image * pixel_std + pixel_mean
    return image

獲取預先訓練的 SAM 模型

我們可以使用 KerasCV 的 from_preset 工廠方法來初始化一個經過訓練的 SAM 模型。在這裡,我們使用在 SA-1B 數據集(sam_huge_sa1b)上訓練的大型 ViT 骨幹網路,以獲得高質量的分割遮罩。您也可以使用 sam_large_sa1bsam_base_sa1b 中的一個來獲得更好的性能(但代價是分割遮罩的質量會下降)。

model = keras_cv.models.SegmentAnythingModel.from_preset("sam_huge_sa1b")

理解提示

Segment Anything 允許使用點、框和遮罩來提示圖像

  1. 點提示是最基本的:模型會嘗試根據圖像上的一個點來猜測物體。該點可以是前景點(即所需的分割遮罩包含該點)或背景點(即該點位於所需的遮罩之外)。
  2. 提示模型的另一種方法是使用框。給定一個邊界框,模型會嘗試分割其中包含的物體。
  3. 最後,模型也可以使用遮罩本身進行提示。例如,這對於細化先前預測或已知的分割遮罩的邊界很有用。

該模型功能強大的原因在於能夠組合上述提示。點、框和遮罩提示可以通過幾種不同的方式組合起來,以達到最佳效果。

讓我們看看在 KerasCV 中將這些提示傳遞給 Segment Anything 模型的語義。SAM 模型的輸入是一個字典,鍵值如下

  1. "images":要分割的一批圖像。形狀必須為 (B, 1024, 1024, 3)
  2. "points":一批點提示。每個點都是從圖像左上角開始的 (x, y) 坐標。換句話說,每個點的形式為 (r, c),其中 rc 是像素在圖像中的行和列。形狀必須為 (B, N, 2)
  3. "labels":給定點的一批標籤。1 表示前景點,0 表示背景點。形狀必須為 (B, N)
  4. "boxes":一批框。請注意,模型每批只接受一個框。因此,預期的形狀為 (B, 1, 2, 2)。每個框都是 2 個點的集合:框的左上角和右下角。此處的點遵循與點提示相同的語義。這裡第二維的 1 表示存在框提示。如果缺少框提示,則必須傳遞形狀為 (B, 0, 2, 2) 的佔位符輸入。
  5. "masks":一批遮罩。與框提示一樣,每個圖像只允許一個遮罩提示。如果存在遮罩提示,則輸入遮罩的形狀必須為 (B, 1, 256, 256, 1),如果缺少遮罩提示,則為 (B, 0, 256, 256, 1)

僅當直接調用模型時(即 model(...)),才需要佔位符提示。當調用 predict 方法時,可以在輸入字典中省略缺少的提示。


點提示

首先,讓我們使用點提示來分割圖像。我們載入圖像並將其調整為 (1024, 1024) 的形狀,這是預先訓練的 SAM 模型期望的圖像大小。

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
plt.axis("on")
plt.show()

png

接下來,我們將定義要分割的物體上的點。讓我們嘗試分割坐標為 (284, 213) 的卡車窗格。

# Define the input point prompt
input_point = np.array([[284, 213.5]])
input_label = np.array([1])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()

png

現在讓我們調用模型的 predict 方法來獲取分割遮罩。

注意:我們不會直接呼叫模型 (model(...)),因為需要佔位提示才能這樣做。缺少的提示會由預測方法自動處理,因此我們改為呼叫該方法。此外,當沒有框提示時,需要分別使用零點提示和 -1 標籤提示來填充點和標籤。以下儲存格說明了這是如何運作的。

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": np.concatenate(
            [input_point[np.newaxis, ...], np.zeros((1, 1, 2))], axis=1
        ),
        "labels": np.concatenate(
            [input_label[np.newaxis, ...], np.full((1, 1), fill_value=-1)], axis=1
        ),
    }
)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 48s 48s/step

SegmentAnythingModel.predict 會傳回兩個輸出。第一個是形狀為 (1, 4, 256, 256) 的 logits(分割遮罩),另一個是每個預測遮罩的 IoU 信賴度分數(形狀為 (1, 4))。預先訓練的 SAM 模型會預測四個遮罩:第一個是模型針對給定提示所能想出的最佳遮罩,另外三個是替代遮罩,可在最佳預測未包含所需物件時使用。使用者可以選擇他們想要的任何遮罩。

讓我們將模型傳回的遮罩視覺化!

# Resize the mask to our image shape i.e. (1024, 1024)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
# Convert the logits to a numpy array
# and convert the logits to a boolean mask
mask = ops.convert_to_numpy(mask) > 0.0
iou_score = ops.convert_to_numpy(outputs["iou_pred"][0][0])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"IoU Score: {iou_score:.3f}", fontsize=18)
plt.axis("off")
plt.show()

png

如預期所示,模型會傳回卡車窗格的分割遮罩。但是,我們的點提示也可能代表其他範圍的事物。例如,另一個包含我們點的可能遮罩只是窗格的右側或整輛卡車。

我們也來視覺化模型預測的其他遮罩。

fig, ax = plt.subplots(1, 3, figsize=(20, 60))
masks, scores = outputs["masks"][0][1:], outputs["iou_pred"][0][1:]
for i, (mask, score) in enumerate(zip(masks, scores)):
    mask = inference_resizing(mask[..., None], pad=False)[..., 0]
    mask, score = map(ops.convert_to_numpy, (mask, score))
    mask = 1 * (mask > 0.0)
    ax[i].imshow(ops.convert_to_numpy(image) / 255.0)
    show_mask(mask, ax[i])
    show_points(input_point, input_label, ax[i])
    ax[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)
    ax[i].axis("off")
plt.show()

png

太棒了!SAM 能夠捕捉到我們點提示的模糊性,並傳回其他可能的分割遮罩。


框提示

現在,讓我們看看如何使用框提示模型。使用兩個點指定框,即邊界框左上角和右下角的 xyxy 格式。讓我們使用卡車左前輪胎周圍的邊界框提示模型。

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])

outputs = model.predict(
    {"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
plt.axis("off")
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 13s 13s/step

png

轟!模型完美地分割出我們邊界框中的左前輪胎。


組合提示

為了發揮模型的真正潛力,讓我們組合框和點提示,看看模型會怎麼做。

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])
# Let's specify the point and mark it background
input_point = np.array([[325, 425]])
input_label = np.array([0])

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": input_point[np.newaxis, ...],
        "labels": input_label[np.newaxis, ...],
        "boxes": input_box[np.newaxis, np.newaxis, ...],
    }
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("off")
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 16s 16s/step

png

瞧!模型了解到我們想要從遮罩中排除的物件是輪胎的輪框。


文字提示

最後,讓我們看看如何將文字提示與 KerasCV 的 SegmentAnythingModel 一起使用。

在本展示中,我們將使用官方 Grounding DINO 模型。Grounding DINO 是一種模型,它將 (影像、文字) 對作為輸入,並在 文字 描述的 影像 中的物件周圍產生邊界框。您可以參考論文,以取得有關模型實作的更多詳細資訊。

對於展示的這一部分,我們需要從原始碼安裝 groundingdino 套件

pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git

然後,我們可以安裝預先訓練的模型的權重和配置

!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py
from groundingdino.util.inference import Model as GroundingDINO

CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "groundingdino_swint_ogc.pth"

grounding_dino = GroundingDINO(CONFIG_PATH, WEIGHTS_PATH)
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

final text_encoder_type: bert-base-uncased

讓我們載入一張狗的影像作為這一部分!

filepath = keras.utils.get_file(
    origin="https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"
)
image = np.array(keras.utils.load_img(filepath))
image = ops.convert_to_numpy(inference_resizing(image))

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)
plt.axis("on")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

我們首先使用 Grounding DINO 模型預測我們要分割的物件的邊界框。然後,我們使用邊界框提示 SAM 模型以取得分割遮罩。

讓我們嘗試分割出狗的背帶。變更以下的影像和文字,使用影像中的文字分割您想要的任何內容!

# Let's predict the bounding box for the harness of the dog
boxes = grounding_dino.predict_with_caption(image.astype(np.uint8), "harness")
boxes = np.array(boxes[0].xyxy)

outputs = model.predict(
    {
        "images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
        "boxes": boxes.reshape(-1, 1, 2, 2),
    },
    batch_size=1,
)
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/transformers/modeling_utils.py:942: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.
  warnings.warn(
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
/home/tirthp/oss/virtualenvs/keras-io-dev/lib/python3.10/site-packages/torch/utils/checkpoint.py:61: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(

 1/1 ━━━━━━━━━━━━━━━━━━━━ 13s 13s/step

就是這樣!我們使用 Gounding DINO + SAM 的組合為我們的文字提示取得了一個分割遮罩!這是一種非常強大的技術,可以結合不同的模型來擴展應用程式!

讓我們將結果視覺化。

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)

for mask in outputs["masks"]:
    mask = inference_resizing(mask[0][..., None], pad=False)[..., 0]
    mask = ops.convert_to_numpy(mask) > 0.0
    show_mask(mask, plt.gca())
    show_box(boxes, plt.gca())

plt.axis("off")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png


最佳化 SAM

您可以使用 mixed_float16bfloat16 dtype 策略,以在精度損失相對較低的情況下獲得巨大的速度提升和記憶體最佳化。

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

# Specify the prompt
input_box = np.array([[240, 340], [400, 500]])

# Let's first see how fast the model is with float32 dtype
time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float32 dtype: {min(time_taken) / 3:.10f}s")

# Set the dtype policy in Keras
keras.mixed_precision.set_global_policy("mixed_float16")

model = keras_cv.models.SegmentAnythingModel.from_preset("sam_huge_sa1b")

time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float16 dtype: {min(time_taken) / 3:.10f}s")
Time taken with float32 dtype: 0.5304666963s
Time taken with float16 dtype: 0.1586400040s

以下是 KerasCV 的實作與原始 PyTorch 實作的比較!

benchmark

用於產生基準測試的指令碼位於這裡


結論

KerasCV 的 SegmentAnythingModel 支援各種應用,並在 Keras 3 的幫助下,可以在 TensorFlow、JAX 和 PyTorch 上執行模型!借助 JAX 和 TensorFlow 中的 XLA,該模型的運行速度比原始實作快數倍。此外,使用 Keras 的混合精度支援只需一行代碼即可優化內存使用和計算時間!

如需更進階的用途,請查看自動遮罩產生器範例