Keras 3 API 文件 / KerasCV / 模型 / 任務 / BASNet 分割

BASNet 分割

[來源]

BASNet 類別

keras_cv.models.BASNet(
    backbone,
    num_classes,
    input_shape=(None, None, 3),
    input_tensor=None,
    include_rescaling=False,
    projection_filters=64,
    prediction_heads=None,
    refinement_head=None,
    **kwargs
)

一個實現 BASNet 架構的 Keras 模型,用於語義分割。

參考文獻

參數

  • backbonekeras.Model。模型的主幹網路,用作 BASNet 預測編碼器的特徵提取器。目前支援的主幹網路是 ResNet18 和 ResNet34。預設主幹網路是 keras_cv.models.ResNet34Backbone()(注意:請勿在主幹網路中指定「input_shape」、「input_tensor」或「include_rescaling」。請在初始化「BASNet」模型時提供這些參數。)
  • num_classes:整數,分割模型的類別數。
  • input_shape:可選的形狀元組,預設為 (None, None, 3)。
  • input_tensor:可選的 Keras 張量(例如,layers.Input() 的輸出),用作模型的影像輸入。
  • include_rescaling:布林值,是否重新縮放輸入。如果設為 True,則輸入將通過 Rescaling(1/255.0) 層。
  • projection_filters:整數,從 backbone 投影低階特徵的卷積層中的濾波器數量。
  • prediction_heads:(可選)keras.layers.Layer 列表,定義模型的預測模組頭。如果未提供,則會使用 Conv2D 層後跟調整大小操作來建立預設頭。
  • refinement_head:(可選)keras.layers.Layer,定義模型的細化模組頭。如果未提供,則會使用 Conv2D 層建立預設頭。

範例

import keras_cv

images = np.ones(shape=(1, 288, 288, 3))
labels = np.zeros(shape=(1, 288, 288, 1))

# Note: Do not specify 'input_shape', 'input_tensor', or
# 'include_rescaling' within the backbone.
backbone = keras_cv.models.ResNet34Backbone()
model = keras_cv.models.segmentation.BASNet(
    backbone=backbone,
    num_classes=1,
    input_shape=[288, 288, 3],
    include_rescaling=False
)

# Evaluate model
output = model(images)
pred_labels = output[0]

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
    ```


----

<span style="float:right;">[[source]](https://github.com/keras-team/keras-cv/tree/v0.9.0/keras_cv/src/models/task.py#L183)</span>

### `from_preset` method


```python
BASNet.from_preset()

從預設配置和權重實例化 BASNet 模型。

參數

  • preset:字串。必須是「resnet18」、「resnet34」、「basnet_resnet18」或「basnet_resnet34」其中之一。如果要尋找具有預先訓練權重的預設模型,請選擇「」其中之一。
  • load_weights:是否將預先訓練的權重載入模型。預設為 None,這取決於預設模型是否有可用的預先訓練權重。
  • input_shape :將傳遞給主幹網路初始化的輸入形狀,預設為 None。如果為 None,則將使用預設值。

範例

# Load architecture and weights from preset
model = keras_cv.models.BASNet.from_preset(
    "",
)

# Load randomly initialized model from preset architecture with weights
model = keras_cv.models.BASNet.from_preset(
    "",
    load_weights=False,
預設名稱 參數 描述
basnet_resnet18 98.78M 具有 ResNet18 v1 主幹網路的 BASNet。
basnet_resnet34 108.90M 採用 ResNet34 v1 骨幹的 BASNet。