Author: Sachin Prasad, Divyashree Sreepathihalli, Ian Stenbit
Date created: 2024/10/11
Last modified: 2024/10/22
Description: DeepLabV3 training and inference with KerasHub.
Semantic segmentation is a computer vision task that assigns a class label (such as "person", "bicycle", or "background") to every pixel of an image, effectively partitioning the image into regions that correspond to different object classes or categories.
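For instance (a purely hypothetical toy example), the label for an image is just a 2D array of per-pixel class indices with the same spatial shape as the image:
import numpy as np

# A hypothetical 4x4 mask: 0 = background, 1 = aeroplane, 15 = person.
# Every pixel of the input image is assigned exactly one class index.
mask = np.array(
    [
        [0, 0, 15, 15],
        [0, 0, 15, 15],
        [1, 1, 0, 0],
        [1, 1, 0, 0],
    ]
)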
KerasHub offers models such as DeepLabv3, DeepLabv3+, and SegFormer for semantic segmentation.
This guide demonstrates how to fine-tune and use the DeepLabv3+ model, developed by Google, for image semantic segmentation with KerasHub. Its architecture combines atrous convolutions, contextual information aggregation, and powerful backbones to achieve accurate and detailed semantic segmentation.
DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module that refines the segmentation results, especially along object boundaries. Both models have achieved state-of-the-art results on a variety of image segmentation benchmarks.
References:
Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
Rethinking Atrous Convolution for Semantic Image Segmentation
Let's install the dependencies and import the necessary modules.
To run this tutorial, you will need to install the following packages:
keras-hub
keras
!pip install -q --upgrade keras-hub
!pip install -q --upgrade keras
After installing keras and keras-hub, set the backend for keras. This guide can be run with any backend (TensorFlow, JAX, PyTorch).
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import ops
import keras_hub
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
The highest-level API in the KerasHub semantic segmentation API is the keras_hub.models API. This API includes fully pretrained semantic segmentation models, such as keras_hub.models.DeepLabV3ImageSegmenter.
Let's get started by constructing a DeepLabv3 pretrained on the Pascal VOC dataset. Also, define the preprocessing function for the model to preprocess images and labels. Note: by default, the from_preset() method in KerasHub loads the pretrained task weights with all classes, 21 in this case.
model = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
"deeplab_v3_plus_resnet50_pascalvoc"
)
image_converter = keras_hub.layers.DeepLabV3ImageConverter(
image_size=(512, 512),
interpolation="bilinear",
)
preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor(image_converter)
Let's visualize the results of this pretrained model.
filepath = keras.utils.get_file(
origin="https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png"
)
image = keras.utils.load_img(filepath)
image = np.array(image)
image = preprocessor(image)
image = keras.ops.expand_dims(image, axis=0)
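# Take the argmax over the 21 class channels to get a per-pixel class index,
# then add back a channel axis so the mask can be plotted like an image.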
preds = ops.expand_dims(ops.argmax(model.predict(image), axis=-1), axis=-1)
def plot_segmentation(original_image, predicted_mask):
plt.figure(figsize=(5, 5))
plt.subplot(1, 2, 1)
plt.imshow(original_image[0] / 255)
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(predicted_mask[0])
plt.axis("off")
plt.tight_layout()
plt.show()
plot_segmentation(image, preds)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5s/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 5s 5s/step
In this guide, we'll assemble a full training pipeline for a KerasHub DeepLabV3 semantic segmentation model. This includes data loading, augmentation, training, metric evaluation, and inference!
We download the Pascal VOC 2012 dataset with the additional annotations provided in Semantic contours from inverse detectors, and split it into a training dataset train_ds and an evaluation dataset eval_ds.
# @title helper functions
import logging
import multiprocessing
from builtins import open
import os.path
import random
import xml.etree.ElementTree
import tensorflow_datasets as tfds
VOC_URL = "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
SBD_URL = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
# Note that this list doesn't contain the background class. In the
# classification use case, the label is 0-based (aeroplane -> 0), whereas in
# the segmentation use case, 0 is reserved for the background, so aeroplane
# maps to 1.
CLASSES = [
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
# This is used to map a class name string to its index.
CLASS_TO_INDEX = {name: index for index, name in enumerate(CLASSES)}
# For the mask data in the PNG file, the encoded raw pixel values need to be
# converted to the proper class index. In the following map, [0, 0, 0] will be
# converted to 0, [128, 0, 0] will be converted to 1, and so on. Also note
# that the mask classes are 1-based since class 0 is reserved for the
# background; [128, 0, 0] (class 1) is mapped to `aeroplane`.
VOC_PNG_COLOR_VALUE = [
[0, 0, 0],
[128, 0, 0],
[0, 128, 0],
[128, 128, 0],
[0, 0, 128],
[128, 0, 128],
[0, 128, 128],
[128, 128, 128],
[64, 0, 0],
[192, 0, 0],
[64, 128, 0],
[192, 128, 0],
[64, 0, 128],
[192, 0, 128],
[64, 128, 128],
[192, 128, 128],
[0, 64, 0],
[128, 64, 0],
[0, 192, 0],
[128, 192, 0],
[0, 64, 128],
]
# Will be populated by maybe_populate_voc_color_mapping() below.
VOC_PNG_COLOR_MAPPING = None
def maybe_populate_voc_color_mapping():
"""Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory."""
global VOC_PNG_COLOR_MAPPING
if VOC_PNG_COLOR_MAPPING is None:
VOC_PNG_COLOR_MAPPING = [0] * (256**3)
for i, colormap in enumerate(VOC_PNG_COLOR_VALUE):
VOC_PNG_COLOR_MAPPING[
(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]
] = i
# There is a special mapping with [224, 224, 192] -> 255
VOC_PNG_COLOR_MAPPING[224 * 256 * 256 + 224 * 256 + 192] = 255
VOC_PNG_COLOR_MAPPING = tf.constant(VOC_PNG_COLOR_MAPPING)
return VOC_PNG_COLOR_MAPPING
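# As a quick check of the encoding above: the color [128, 0, 0] hashes to
# (128 * 256 + 0) * 256 + 0 = 8388608, and VOC_PNG_COLOR_MAPPING[8388608] is 1,
# the `aeroplane` class.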
def parse_annotation_data(annotation_file_path):
"""Parse the annotation XML file for the image.
The annotation contains the metadata, as well as the object bounding box
information.
"""
with open(annotation_file_path, "r") as f:
root = xml.etree.ElementTree.parse(f).getroot()
size = root.find("size")
width = int(size.find("width").text)
height = int(size.find("height").text)
objects = []
for obj in root.findall("object"):
# Get object's label name.
label = CLASS_TO_INDEX[obj.find("name").text.lower()]
# Get objects' pose name.
pose = obj.find("pose").text.lower()
is_truncated = obj.find("truncated").text == "1"
is_difficult = obj.find("difficult").text == "1"
bndbox = obj.find("bndbox")
xmax = int(bndbox.find("xmax").text)
xmin = int(bndbox.find("xmin").text)
ymax = int(bndbox.find("ymax").text)
ymin = int(bndbox.find("ymin").text)
objects.append(
{
"label": label,
"pose": pose,
"bbox": [ymin, xmin, ymax, xmax],
"is_truncated": is_truncated,
"is_difficult": is_difficult,
}
)
return {"width": width, "height": height, "objects": objects}
def get_image_ids(data_dir, split):
"""To get image ids from the "train", "eval" or "trainval" files of VOC data."""
data_file_mapping = {
"train": "train.txt",
"eval": "val.txt",
"trainval": "trainval.txt",
}
with open(
os.path.join(data_dir, "ImageSets", "Segmentation", data_file_mapping[split]),
"r",
) as f:
image_ids = f.read().splitlines()
logging.info(f"Received {len(image_ids)} images for {split} dataset.")
return image_ids
def get_sbd_image_ids(data_dir, split):
"""To get image ids from the "sbd_train", "sbd_eval" from files of SBD data."""
data_file_mapping = {"sbd_train": "train.txt", "sbd_eval": "val.txt"}
with open(
os.path.join(data_dir, data_file_mapping[split]),
"r",
) as f:
image_ids = f.read().splitlines()
logging.info(f"Received {len(image_ids)} images for {split} dataset.")
return image_ids
def parse_single_image(image_file_path):
"""Creates metadata of VOC images and path."""
data_dir, image_file_name = os.path.split(image_file_path)
data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
image_id, _ = os.path.splitext(image_file_name)
class_segmentation_file_path = os.path.join(
data_dir, "SegmentationClass", image_id + ".png"
)
object_segmentation_file_path = os.path.join(
data_dir, "SegmentationObject", image_id + ".png"
)
annotation_file_path = os.path.join(data_dir, "Annotations", image_id + ".xml")
image_annotations = parse_annotation_data(annotation_file_path)
result = {
"image/filename": image_id + ".jpg",
"image/file_path": image_file_path,
"segmentation/class/file_path": class_segmentation_file_path,
"segmentation/object/file_path": object_segmentation_file_path,
}
result.update(image_annotations)
    # The labels field should be the unique, sorted set of the objects' labels.
labels = list(set([o["label"] for o in result["objects"]]))
result["labels"] = sorted(labels)
return result
def parse_single_sbd_image(image_file_path):
"""Creates metadata of SBD images and path."""
data_dir, image_file_name = os.path.split(image_file_path)
data_dir = os.path.normpath(os.path.join(data_dir, os.path.pardir))
image_id, _ = os.path.splitext(image_file_name)
class_segmentation_file_path = os.path.join(data_dir, "cls", image_id + ".mat")
object_segmentation_file_path = os.path.join(data_dir, "inst", image_id + ".mat")
result = {
"image/filename": image_id + ".jpg",
"image/file_path": image_file_path,
"segmentation/class/file_path": class_segmentation_file_path,
"segmentation/object/file_path": object_segmentation_file_path,
}
return result
def build_metadata(data_dir, image_ids):
"""Transpose the metadata which convert from list of dict to dict of list."""
# Parallel process all the images.
image_file_paths = [
os.path.join(data_dir, "JPEGImages", i + ".jpg") for i in image_ids
]
pool_size = 10 if len(image_ids) > 10 else len(image_ids)
with multiprocessing.Pool(pool_size) as p:
metadata = p.map(parse_single_image, image_file_paths)
keys = [
"image/filename",
"image/file_path",
"segmentation/class/file_path",
"segmentation/object/file_path",
"labels",
"width",
"height",
]
result = {}
for key in keys:
values = [value[key] for value in metadata]
result[key] = values
# The ragged objects need some special handling
for key in ["label", "pose", "bbox", "is_truncated", "is_difficult"]:
values = []
objects = [value["objects"] for value in metadata]
        for objs in objects:
            values.append([o[key] for o in objs])
result["objects/" + key] = values
return result
def build_sbd_metadata(data_dir, image_ids):
"""Transpose the metadata which convert from list of dict to dict of list."""
# Parallel process all the images.
image_file_paths = [os.path.join(data_dir, "img", i + ".jpg") for i in image_ids]
pool_size = 10 if len(image_ids) > 10 else len(image_ids)
with multiprocessing.Pool(pool_size) as p:
metadata = p.map(parse_single_sbd_image, image_file_paths)
keys = [
"image/filename",
"image/file_path",
"segmentation/class/file_path",
"segmentation/object/file_path",
]
result = {}
for key in keys:
values = [value[key] for value in metadata]
result[key] = values
return result
def decode_png_mask(mask):
"""Decode the raw PNG image and convert it to 2D tensor with probably
class."""
# Cast the mask to int32 since the original uint8 will overflow when
# multiplied with 256
mask = tf.cast(mask, tf.int32)
mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]
mask = tf.expand_dims(tf.gather(VOC_PNG_COLOR_MAPPING, mask), -1)
mask = tf.cast(mask, tf.uint8)
return mask
def load_images(example):
"""Loads VOC images for segmentation task from the provided paths"""
image_file_path = example.pop("image/file_path")
segmentation_class_file_path = example.pop("segmentation/class/file_path")
segmentation_object_file_path = example.pop("segmentation/object/file_path")
image = tf.io.read_file(image_file_path)
image = tf.image.decode_jpeg(image)
segmentation_class_mask = tf.io.read_file(segmentation_class_file_path)
segmentation_class_mask = tf.image.decode_png(segmentation_class_mask)
segmentation_class_mask = decode_png_mask(segmentation_class_mask)
segmentation_object_mask = tf.io.read_file(segmentation_object_file_path)
segmentation_object_mask = tf.image.decode_png(segmentation_object_mask)
segmentation_object_mask = decode_png_mask(segmentation_object_mask)
example.update(
{
"image": image,
"class_segmentation": segmentation_class_mask,
"object_segmentation": segmentation_object_mask,
}
)
return example
def load_sbd_images(image_file_path, seg_cls_file_path, seg_obj_file_path):
"""Loads SBD images for segmentation task from the provided paths"""
image = tf.io.read_file(image_file_path)
image = tf.image.decode_jpeg(image)
segmentation_class_mask = tfds.core.lazy_imports.scipy.io.loadmat(seg_cls_file_path)
segmentation_class_mask = segmentation_class_mask["GTcls"]["Segmentation"][0][0]
segmentation_class_mask = segmentation_class_mask[..., np.newaxis]
segmentation_object_mask = tfds.core.lazy_imports.scipy.io.loadmat(
seg_obj_file_path
)
segmentation_object_mask = segmentation_object_mask["GTinst"]["Segmentation"][0][0]
segmentation_object_mask = segmentation_object_mask[..., np.newaxis]
return {
"image": image,
"class_segmentation": segmentation_class_mask,
"object_segmentation": segmentation_object_mask,
}
def build_dataset_from_metadata(metadata):
"""Builds TensorFlow dataset from the image metadata of VOC dataset."""
# The objects need some manual conversion to ragged tensor.
metadata["labels"] = tf.ragged.constant(metadata["labels"])
metadata["objects/label"] = tf.ragged.constant(metadata["objects/label"])
metadata["objects/pose"] = tf.ragged.constant(metadata["objects/pose"])
metadata["objects/is_truncated"] = tf.ragged.constant(
metadata["objects/is_truncated"]
)
metadata["objects/is_difficult"] = tf.ragged.constant(
metadata["objects/is_difficult"]
)
metadata["objects/bbox"] = tf.ragged.constant(
metadata["objects/bbox"], ragged_rank=1
)
dataset = tf.data.Dataset.from_tensor_slices(metadata)
dataset = dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
return dataset
def build_sbd_dataset_from_metadata(metadata):
"""Builds TensorFlow dataset from the image metadata of SBD dataset."""
img_filepath = metadata["image/file_path"]
cls_filepath = metadata["segmentation/class/file_path"]
obj_filepath = metadata["segmentation/object/file_path"]
def md_gen():
c = list(zip(img_filepath, cls_filepath, obj_filepath))
        # Shuffle the file list so each pass over the generator yields a
        # different order.
random.shuffle(c)
for fp in c:
img_fp, cls_fp, obj_fp = fp
yield load_sbd_images(img_fp, cls_fp, obj_fp)
dataset = tf.data.Dataset.from_generator(
md_gen,
output_signature=(
{
"image": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
"class_segmentation": tf.TensorSpec(
shape=(None, None, 1), dtype=tf.uint8
),
"object_segmentation": tf.TensorSpec(
shape=(None, None, 1), dtype=tf.uint8
),
}
),
)
return dataset
def load(
split="sbd_train",
data_dir=None,
):
"""Load the Pacal VOC 2012 dataset.
This function will download the data tar file from remote if needed, and
untar to the local `data_dir`, and build dataset from it.
It supports both VOC2012 and Semantic Boundaries Dataset (SBD).
The returned segmentation masks will be int ranging from [0, num_classes),
as well as 255 which is the boundary mask.
Args:
split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or
'sbd_eval'. 'sbd_train' represents the training dataset for SBD
dataset, while 'train' represents the training dataset for VOC2012
dataset. Defaults to `sbd_train`.
data_dir: string, local directory path for the loaded data. This will be
used to download the data file, and unzip. It will be used as a
cache directory. Defaults to None, and `~/.keras/pascal_voc_2012`
will be used.
"""
supported_split_value = [
"train",
"eval",
"trainval",
"sbd_train",
"sbd_eval",
]
if split not in supported_split_value:
raise ValueError(
f"The support value for `split` are {supported_split_value}. "
f"Got: {split}"
)
if data_dir is not None:
data_dir = os.path.expanduser(data_dir)
if "sbd" in split:
return load_sbd(split, data_dir)
else:
return load_voc(split, data_dir)
def load_voc(
split="train",
data_dir=None,
):
"""This function will download VOC data from a URL. If the data is already
present in the cache directory, it will load the data from that directory
instead.
"""
extracted_dir = os.path.join("VOCdevkit", "VOC2012")
get_data = keras.utils.get_file(
fname=os.path.basename(VOC_URL),
origin=VOC_URL,
cache_dir=data_dir,
extract=True,
)
data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
image_ids = get_image_ids(data_dir, split)
# len(metadata) = #samples, metadata[i] is a dict.
metadata = build_metadata(data_dir, image_ids)
maybe_populate_voc_color_mapping()
dataset = build_dataset_from_metadata(metadata)
return dataset
def load_sbd(
split="sbd_train",
data_dir=None,
):
"""This function will download SBD data from a URL. If the data is already
present in the cache directory, it will load the data from that directory
instead.
"""
extracted_dir = os.path.join("benchmark_RELEASE", "dataset")
get_data = keras.utils.get_file(
fname=os.path.basename(SBD_URL),
origin=SBD_URL,
cache_dir=data_dir,
extract=True,
)
data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
image_ids = get_sbd_image_ids(data_dir, split)
# len(metadata) = #samples, metadata[i] is a dict.
metadata = build_sbd_metadata(data_dir, image_ids)
dataset = build_sbd_dataset_from_metadata(metadata)
return dataset
For training and evaluation, let's use "sbd_train" and "sbd_eval". You can also choose any of these splits for the load function: "train", "eval", "trainval", "sbd_train", or "sbd_eval". "sbd_train" represents the training set of the SBD dataset, while "train" represents the training set of the VOC2012 dataset.
train_ds = load(split="sbd_train", data_dir="segmentation")
eval_ds = load(split="sbd_eval", data_dir="segmentation")
The preprocess_inputs utility function preprocesses the inputs into a dictionary containing images and segmentation_masks. Both the images and segmentation masks are resized to 512x512. The resulting dataset is then batched into groups of four image and segmentation-mask pairs.
def preprocess_inputs(inputs):
def unpackage_inputs(inputs):
return {
"images": inputs["image"],
"segmentation_masks": inputs["class_segmentation"],
}
outputs = inputs.map(unpackage_inputs)
outputs = outputs.map(keras.layers.Resizing(height=512, width=512))
outputs = outputs.batch(4, drop_remainder=True)
return outputs
train_ds = preprocess_inputs(train_ds)
batch = train_ds.take(1).get_single_element()
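# Each batch is a dict: batch["images"] has shape (4, 512, 512, 3) and
# batch["segmentation_masks"] has shape (4, 512, 512, 1).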
A batch of this preprocessed input training data can be visualized using the plot_images_masks function. This function takes a batch of images, segmentation masks, and optional prediction masks as input and displays them in a grid.
def plot_images_masks(images, masks, pred_masks=None):
num_images = len(images)
plt.figure(figsize=(8, 4))
rows = 3 if pred_masks is not None else 2
for i in range(num_images):
plt.subplot(rows, num_images, i + 1)
plt.imshow(images[i] / 255)
plt.axis("off")
plt.subplot(rows, num_images, num_images + i + 1)
plt.imshow(masks[i])
plt.axis("off")
if pred_masks is not None:
plt.subplot(rows, num_images, i + 1 + 2 * num_images)
plt.imshow(pred_masks[i])
plt.axis("off")
plt.show()
plot_images_masks(batch["images"], batch["segmentation_masks"])
The preprocessing is applied to the evaluation dataset eval_ds as well.
eval_ds = preprocess_inputs(eval_ds)
Keras provides a variety of image augmentation options. In this example, we will use the RandomFlip augmentation to augment the training dataset. The RandomFlip augmentation randomly flips the images in the training dataset horizontally or vertically. This helps improve the model's robustness to variations in the orientation of objects in the images.
train_ds = train_ds.map(keras.layers.RandomFlip())
batch = train_ds.take(1).get_single_element()
plot_images_masks(batch["images"], batch["segmentation_masks"])
Feel free to modify the configuration for model training and note how the training results change. This is a great exercise to get a better understanding of the training pipeline.
The optimizer uses a learning rate schedule to compute the learning rate at each training step, and then uses that learning rate to update the model's weights. Here, the learning rate schedule uses a cosine decay function: it starts at a high value and decreases over time, eventually reaching zero. The cardinality of the VOC dataset is 2124 with a batch size of 4. The dataset cardinality matters for learning rate decay because it determines how many steps the model will train for. The initial learning rate is scaled from a base of 0.007 according to the batch size (0.007 * BATCH_SIZE / 16), and the decay steps are 2124. This means the learning rate starts at INITIAL_LR and decays to zero over 2124 steps.
BATCH_SIZE = 4
INITIAL_LR = 0.007 * BATCH_SIZE / 16
EPOCHS = 1
NUM_CLASSES = 21
learning_rate = keras.optimizers.schedules.CosineDecay(
INITIAL_LR,
decay_steps=EPOCHS * 2124,
)
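As a quick sanity check (a small sketch, not part of the original pipeline), we can sample the schedule at a few steps to see the decay from INITIAL_LR toward zero:
# Sample the cosine decay schedule at the start, midpoint, and end of the
# 2124-step horizon. With INITIAL_LR = 0.007 * 4 / 16 = 0.00175, this prints
# roughly 0.00175, 0.000875, and 0.0.
for step in [0, 1062, 2124]:
    print(step, float(learning_rate(step)))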
Let's use the resnet_50_imagenet pretrained weights as the image encoder for our model. This implementation can be used both as DeepLabV3 and as DeepLabV3+, with an additional decoder block. For DeepLabV3+, we instantiate a DeepLabV3Backbone model with low_level_feature_key set to the P2 pyramid-level output, from which features are extracted from resnet_50_imagenet to serve as the decoder block. To use this model as the DeepLabV3 architecture, omit the low_level_feature_key, which defaults to None.
We then create a DeepLabV3ImageSegmenter instance. The num_classes parameter specifies the number of classes the model will be trained to segment. The preprocessor argument applies preprocessing to the image input and masks.
image_encoder = keras_hub.models.Backbone.from_preset("resnet_50_imagenet")
deeplab_backbone = keras_hub.models.DeepLabV3Backbone(
image_encoder=image_encoder,
low_level_feature_key="P2",
spatial_pyramid_pooling_key="P5",
dilation_rates=[6, 12, 18],
upsampling_size=8,
)
model = keras_hub.models.DeepLabV3ImageSegmenter(
backbone=deeplab_backbone,
num_classes=21,
activation="softmax",
preprocessor=preprocessor,
)
The model.compile() function sets up the training process for the model. It defines the optimization algorithm (Stochastic Gradient Descent, SGD), the loss function (categorical cross-entropy), and the evaluation metrics (Mean IoU and categorical accuracy).
Semantic segmentation evaluation metrics:
Mean Intersection over Union (MeanIoU): MeanIoU measures how accurately a semantic segmentation model identifies and delineates different objects or regions in an image. It calculates the overlap between the predicted and ground-truth object boundaries, yielding a score between 0 and 1, where 1 represents a perfect match.
Categorical Accuracy: Categorical accuracy measures the proportion of pixels in an image that are classified correctly. It gives a simple percentage indicating how accurately the model predicts pixel classes across the whole image.
In essence, MeanIoU emphasizes how accurately specific object boundaries are identified, while categorical accuracy gives a broad overview of overall pixel-level correctness.
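As a tiny illustration of MeanIoU (hypothetical values, separate from the training pipeline), consider four pixels and two classes:
# With y_true = [0, 0, 1, 1] and y_pred = [0, 1, 1, 1]: class 0 has IoU 1/2
# (one pixel correct, one missed) and class 1 has IoU 2/3 (two correct, one
# false positive), so MeanIoU = (1/2 + 2/3) / 2 ≈ 0.583.
miou = keras.metrics.MeanIoU(num_classes=2)
miou.update_state(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]))
print(float(miou.result()))  # ~0.583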
model.compile(
optimizer=keras.optimizers.SGD(
learning_rate=learning_rate, weight_decay=0.0001, momentum=0.9, clipnorm=10.0
),
loss=keras.losses.CategoricalCrossentropy(from_logits=False),
metrics=[
keras.metrics.MeanIoU(
num_classes=NUM_CLASSES, sparse_y_true=False, sparse_y_pred=False
),
keras.metrics.CategoricalAccuracy(),
],
)
model.summary()
Preprocessor: "deep_lab_v3_image_segmenter_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                                           ┃ Config                  ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ deep_lab_v3_image_converter (DeepLabV3ImageConverter)  │ Image size: (512, 512)  │
└────────────────────────────────────────────────────────┴─────────────────────────┘
Model: "deep_lab_v3_image_segmenter"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Layer (type)                               ┃ Output Shape              ┃     Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ inputs (InputLayer)                        │ (None, None, None, 3)     │           0 │
├────────────────────────────────────────────┼───────────────────────────┼─────────────┤
│ deep_lab_v3_backbone (DeepLabV3Backbone)   │ (None, None, None, 256)   │  39,190,656 │
├────────────────────────────────────────────┼───────────────────────────┼─────────────┤
│ segmentation_output (Conv2D)               │ (None, None, None, 21)    │       5,376 │
└────────────────────────────────────────────┴───────────────────────────┴─────────────┘
Total params: 39,196,032 (149.52 MB)
Trainable params: 39,139,232 (149.30 MB)
Non-trainable params: 56,800 (221.88 KB)
The utility function dict_to_tuple efficiently converts the dictionaries of the training and validation datasets into tuples of images and one-hot encoded segmentation masks, which are used during training and evaluation of the DeepLabv3+ model.
def dict_to_tuple(x):
return x["images"], tf.one_hot(
tf.cast(tf.squeeze(x["segmentation_masks"], axis=-1), "int32"), 21
)
train_ds = train_ds.map(dict_to_tuple)
eval_ds = eval_ds.map(dict_to_tuple)
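# Each element is now an (images, one_hot_masks) tuple with shapes
# (4, 512, 512, 3) and (4, 512, 512, 21).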
model.fit(train_ds, validation_data=eval_ds, epochs=EPOCHS)
1/Unknown 40s 40s/step - categorical_accuracy: 0.1191 - loss: 3.0568 - mean_io_u: 0.0118
2124/2124 ━━━━━━━━━━━━━━━━━━━━ 281s 114ms/step - categorical_accuracy: 0.7286 - loss: 1.0707 - mean_io_u: 0.0926 - val_categorical_accuracy: 0.8199 - val_loss: 0.5900 - val_mean_io_u: 0.3265
<keras.src.callbacks.history.History at 0x7fd7a897f8d0>
Now that the model training for DeepLabv3+ is complete, let's test it by making predictions on a few sample images. Note: the model has been trained for only 1 epoch for demonstration purposes; for better accuracy and results, please train for more epochs.
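If you do train longer, keep the learning-rate schedule in sync with the run length. A minimal sketch (hypothetical epoch count, reusing the 2124 steps-per-epoch figure from above) to apply before compiling and fitting again:
# Scale the cosine decay horizon to the total number of optimizer steps so
# the learning rate still reaches ~0 at the end of the longer run.
EPOCHS = 20  # hypothetical
learning_rate = keras.optimizers.schedules.CosineDecay(
    INITIAL_LR,
    decay_steps=EPOCHS * 2124,
)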
test_ds = load(split="sbd_eval")
test_ds = preprocess_inputs(test_ds)
images, masks = next(iter(test_ds.take(1)))
images = ops.convert_to_tensor(images)
masks = ops.convert_to_tensor(masks)
preds = ops.expand_dims(ops.argmax(model.predict(images), axis=-1), axis=-1)
masks = ops.expand_dims(ops.argmax(masks, axis=-1), axis=-1)
plot_images_masks(images, masks, preds)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 3s/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step
Here are some additional tips for using the KerasHub DeepLabv3 model: