Keras 3 API 文件 / KerasTuner / 超模型 / 基礎 HyperModel 類別

基礎 HyperModel 類別

[原始碼]

HyperModel 類別

keras_tuner.HyperModel(name=None, tunable=True)

定義模型的搜尋空間。

搜尋空間是模型的集合。build 函數將使用給定的 HyperParameters 物件從空間中建構其中一個模型。

使用者應該繼承 HyperModel 類別並覆寫 build() 來定義他們的搜尋空間,該函數會建立並返回 Keras 模型。您也可以選擇性地覆寫 fit() 來客製化模型的訓練過程。

範例

build() 中,您可以使用超參數建立模型。

class MyHyperModel(kt.HyperModel):
    def build(self, hp):
        model = keras.Sequential()
        model.add(keras.layers.Dense(
            hp.Choice('units', [8, 16, 32]),
            activation='relu'))
        model.add(keras.layers.Dense(1, activation='relu'))
        model.compile(loss='mse')
        return model

覆寫 HyperModel.fit() 時,如果您使用 model.fit() 來訓練您的模型(它會返回訓練歷史記錄),您可以直接返回它。您可以使用 hp 指定要調整的任何超參數。

class MyHyperModel(kt.HyperModel):
    def build(self, hp):
        ...

    def fit(self, hp, model, *args, **kwargs):
        return model.fit(
            *args,
            epochs=hp.Int("epochs", 5, 20),
            **kwargs)

如果您有客製化的訓練過程,您可以將目標值作為浮點數返回。

如果您想追蹤更多指標,您可以返回要追蹤的指標字典。

class MyHyperModel(kt.HyperModel):
    def build(self, hp):
        ...

    def fit(self, hp, model, *args, **kwargs):
        ...
        return {
            "loss": loss,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy
        }

參數

  • name:選填字串,此 HyperModel 的名稱。
  • tunable:布林值,是否應將此超模型中定義的超參數添加到搜尋空間。如果為 False,則必須事先定義這些參數的搜尋空間,否則將使用預設值。預設值為 True。

[原始碼]

build 方法

HyperModel.build(hp)

建構模型。

參數

  • hpHyperParameters 實例。

返回

模型實例。