程式碼範例 / 電腦視覺 / 從頭開始的圖像分類

從頭開始的圖像分類

作者: fchollet
建立日期 2020/04/27
上次修改日期 2023/11/09
說明: 從頭開始在 Kaggle 貓狗資料集上訓練圖像分類器。

ⓘ 此範例使用 Keras 3

在 Colab 中檢視 GitHub 來源


簡介

本範例展示如何從頭開始進行圖像分類,從磁碟上的 JPEG 圖像檔案開始,不利用預訓練權重或預先製作的 Keras 應用程式模型。我們在 Kaggle 貓狗二元分類資料集上示範此工作流程。

我們使用 image_dataset_from_directory 工具來產生資料集,並使用 Keras 圖像預處理層進行圖像標準化和資料擴增。


設定

import os
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt

載入資料:貓狗資料集

原始資料下載

首先,讓我們下載原始資料的 786M ZIP 壓縮檔

!curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
!unzip -q kagglecatsanddogs_5340.zip
!ls
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  786M  100  786M    0     0  11.1M      0  0:01:10  0:01:10 --:--:-- 11.8M

 CDLA-Permissive-2.0.pdf           kagglecatsanddogs_5340.zip
 PetImages                'readme[1].txt'
 image_classification_from_scratch.ipynb

現在我們有一個 PetImages 資料夾,其中包含兩個子資料夾,CatDog。每個子資料夾都包含每個類別的圖像檔案。

!ls PetImages
Cat  Dog

篩選損壞的圖像

當使用大量真實世界的圖像資料時,損壞的圖像是一種常見的情況。讓我們篩選掉標頭中沒有 "JFIF" 字串的不良編碼圖像。

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = b"JFIF" in fobj.peek(10)
        finally:
            fobj.close()

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")
Deleted 1590 images.

產生 Dataset

image_size = (180, 180)
batch_size = 128

train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="both",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
Found 23410 files belonging to 2 classes.
Using 18728 files for training.
Using 4682 files for validation.

視覺化資料

以下是訓練資料集中的前 9 張圖像。

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

png


使用圖像資料擴增

當您沒有大型圖像資料集時,最好透過對訓練圖像應用隨機但真實的轉換(例如隨機水平翻轉或小的隨機旋轉)來人工引入樣本多樣性。這有助於讓模型接觸到訓練資料的不同面向,同時減緩過度擬合。

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images

讓我們透過將 data_augmentation 重複應用於資料集中的前幾個圖像,來視覺化擴增樣本的外觀

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.axis("off")

png


標準化資料

我們的圖像已經是標準大小 (180x180),因為它們是由我們的資料集產生為連續的 float32 批次。但是,它們的 RGB 通道值在 [0, 255] 範圍內。這對於神經網路來說並不理想;一般來說,您應該盡量使您的輸入值變小。在這裡,我們將使用模型開始時的 Rescaling 層將值標準化為 [0, 1]


預處理資料的兩種選項

您可以使用 data_augmentation 預處理器的兩種方法

選項 1:將其設為模型的一部分,如下所示

inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1./255)(x)
...  # Rest of the model

使用此選項,您的資料擴增將在裝置上發生,與模型執行的其餘部分同步,這表示它將受益於 GPU 加速。

請注意,資料擴增在測試時處於非活動狀態,因此輸入樣本只會在 fit() 期間擴增,而不是在呼叫 evaluate()predict() 時擴增。

如果您在 GPU 上訓練,這可能是一個不錯的選項。

選項 2:將其應用於資料集,以取得產生擴增圖像批次的資料集,如下所示

augmented_train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y))

使用此選項,您的資料擴增將在 CPU 上非同步發生,並會在進入模型之前進行緩衝。

如果您是在 CPU 上訓練,這是較好的選擇,因為它可以讓資料擴增變成非同步且非阻塞的。

在我們的例子中,我們會選擇第二個選項。如果您不確定該選哪個,第二個選項(非同步預處理)通常是個穩妥的選擇。


為了效能配置資料集

讓我們將資料擴增應用於我們的訓練資料集,並確保使用緩衝預取,這樣我們才能從磁碟產生資料,而不會讓 I/O 變成阻塞。

# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)

建立模型

我們將建立一個小型版本的 Xception 網路。我們並沒有特別嘗試優化架構;如果您想系統性地搜尋最佳的模型配置,請考慮使用 KerasTuner

請注意

  • 我們從 data_augmentation 預處理器開始模型,接著是一個 Rescaling 層。
  • 我們在最後的分類層之前包含一個 Dropout 層。
def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # We specify activation=None so as to return logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)

png


訓練模型

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)
Epoch 1/25
...
Epoch 25/25
 147/147 ━━━━━━━━━━━━━━━━━━━━ 53s 354ms/step - acc: 0.9638 - loss: 0.0903 - val_acc: 0.9382 - val_loss: 0.1542

<keras.src.callbacks.history.History at 0x7f41003c24a0>

在完整資料集上訓練 25 個週期後,我們獲得了 >90% 的驗證準確度(實際上,您可以在驗證效能開始下降之前訓練 50 個以上的週期)。


在新資料上執行推論

請注意,資料擴增和 dropout 在推論時處於非活動狀態。

img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
plt.imshow(img)

img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = float(keras.ops.sigmoid(predictions[0][0]))
print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")
 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step
This image is 94.30% cat and 5.70% dog.

png