BASNet
類別keras_cv.models.BASNet(
backbone,
num_classes,
input_shape=(None, None, 3),
input_tensor=None,
include_rescaling=False,
projection_filters=64,
prediction_heads=None,
refinement_head=None,
**kwargs
)
一個實現 BASNet 架構的 Keras 模型,用於語義分割。
參考文獻
參數
keras.Model
。模型的主幹網路,用作 BASNet 預測編碼器的特徵提取器。目前支援的主幹網路是 ResNet18 和 ResNet34。預設主幹網路是 keras_cv.models.ResNet34Backbone()
(注意:請勿在主幹網路中指定「input_shape」、「input_tensor」或「include_rescaling」。請在初始化「BASNet」模型時提供這些參數。)layers.Input()
的輸出),用作模型的影像輸入。True
,則輸入將通過 Rescaling(1/255.0)
層。backbone
投影低階特徵的卷積層中的濾波器數量。keras.layers.Layer
列表,定義模型的預測模組頭。如果未提供,則會使用 Conv2D 層後跟調整大小操作來建立預設頭。keras.layers.Layer
,定義模型的細化模組頭。如果未提供,則會使用 Conv2D 層建立預設頭。範例
import keras_cv
images = np.ones(shape=(1, 288, 288, 3))
labels = np.zeros(shape=(1, 288, 288, 1))
# Note: Do not specify 'input_shape', 'input_tensor', or
# 'include_rescaling' within the backbone.
backbone = keras_cv.models.ResNet34Backbone()
model = keras_cv.models.segmentation.BASNet(
backbone=backbone,
num_classes=1,
input_shape=[288, 288, 3],
include_rescaling=False
)
# Evaluate model
output = model(images)
pred_labels = output[0]
# Train model
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(from_logits=False),
metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
```
----
<span style="float:right;">[[source]](https://github.com/keras-team/keras-cv/tree/v0.9.0/keras_cv/src/models/task.py#L183)</span>
### `from_preset` method
```python
BASNet.from_preset()
從預設配置和權重實例化 BASNet 模型。
參數
None
,這取決於預設模型是否有可用的預先訓練權重。None
。如果為 None
,則將使用預設值。範例
# Load architecture and weights from preset
model = keras_cv.models.BASNet.from_preset(
"",
)
# Load randomly initialized model from preset architecture with weights
model = keras_cv.models.BASNet.from_preset(
"",
load_weights=False,
預設名稱 | 參數 | 描述 |
---|---|---|
basnet_resnet18 | 98.78M | 具有 ResNet18 v1 主幹網路的 BASNet。 |
basnet_resnet34 | 108.90M | 採用 ResNet34 v1 骨幹的 BASNet。 |