Model
類別keras.Model()
一個將層分組到具有訓練/推論功能的物件中的模型。
有三種實例化 Model
的方法
從 Input
開始,串聯層呼叫以指定模型的前向傳遞,最後,從輸入和輸出建立模型
inputs = keras.Input(shape=(37,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
注意:僅支援輸入張量的字典、列表和元組。不支援巢狀輸入(例如列表的列表或字典的字典)。
也可以使用中間張量建立新的函數式 API 模型。這使您可以快速提取模型的子組件。
範例
inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)
full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)
請注意,backbone
和 activations
模型不是使用 keras.Input
物件建立的,而是使用源自 keras.Input
物件的張量建立的。在底層,層和權重將在這些模型之間共享,以便使用者可以訓練 full_model
,並使用 backbone
或 activations
進行特徵提取。模型的輸入和輸出也可以是張量的巢狀結構,並且建立的模型是標準的函數式 API 模型,支援所有現有的 API。
Model
類別在這種情況下,您應該在 __init__()
中定義您的層,並且您應該在 call()
中實作模型的前向傳遞。
class MyModel(keras.Model):
def __init__(self):
super().__init__()
self.dense1 = keras.layers.Dense(32, activation="relu")
self.dense2 = keras.layers.Dense(5, activation="softmax")
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
如果您子類化 Model
,您可以選擇在 call()
中使用 training
參數(布林值),您可以使用它來指定訓練和推論中的不同行為
class MyModel(keras.Model):
def __init__(self):
super().__init__()
self.dense1 = keras.layers.Dense(32, activation="relu")
self.dense2 = keras.layers.Dense(5, activation="softmax")
self.dropout = keras.layers.Dropout(0.5)
def call(self, inputs, training=False):
x = self.dense1(inputs)
x = self.dropout(x, training=training)
return self.dense2(x)
model = MyModel()
模型建立後,您可以使用 model.compile()
配置模型的損失和指標,使用 model.fit()
訓練模型,或使用模型透過 model.predict()
進行預測。
Sequential
類別此外,keras.Sequential
是一種特殊的模型,其中模型純粹是單輸入、單輸出層的堆疊。
model = keras.Sequential([
keras.Input(shape=(None, None, 3)),
keras.layers.Conv2D(filters=32, kernel_size=3),
])
summary
方法Model.summary(
line_length=None,
positions=None,
print_fn=None,
expand_nested=False,
show_trainable=False,
layer_range=None,
)
印出網路的字串摘要。
參數
[0.3, 0.6, 0.70, 1.]
。預設為 None
。stdout
。如果 stdout
在您的環境中不起作用,請變更為 print
。它將在摘要的每一行上呼叫。您可以將其設定為自訂函數,以便捕獲字串摘要。False
。False
。layer_range[0]
的元素,結束謂詞將是最後一個符合 layer_range[1]
的元素。預設情況下,None
考慮模型的所有層。引發
summary()
。get_layer
方法Model.get_layer(name=None, index=None)
根據層的名稱(唯一)或索引檢索層。
如果同時提供了 name
和 index
,則 index
將優先。索引基於水平圖遍歷的順序(由下而上)。
參數
回傳
一個層實例。