Keras 3 API 文件 / 模型 API / 模型訓練 API

模型訓練 API

[原始碼]

compile 方法

Model.compile(
    optimizer="rmsprop",
    loss=None,
    loss_weights=None,
    metrics=None,
    weighted_metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
    jit_compile="auto",
    auto_scale_loss=True,
)

配置模型以進行訓練。

範例

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.FalseNegatives(),
    ],
)

參數

  • optimizer:字串(優化器名稱)或優化器實例。請參閱 keras.optimizers
  • loss:損失函數。可以是字串(損失函數名稱)或 keras.losses.Loss 實例。請參閱 keras.losses。損失函數是任何具有簽名 loss = fn(y_true, y_pred) 的可呼叫函數,其中 y_true 是真實值,而 y_pred 是模型的預測值。y_true 的形狀應為 (batch_size, d0, .. dN)(除非在稀疏損失函數的情況下,例如稀疏類別交叉熵,它期望形狀為 (batch_size, d0, .. dN-1) 的整數數組)。y_pred 的形狀應為 (batch_size, d0, .. dN)。損失函數應返回浮點張量。
  • loss_weights:可選的列表或字典,指定純量係數(Python 浮點數)來加權不同模型輸出的損失貢獻。模型將最小化的損失值將是所有個別損失的加權總和,由 loss_weights 係數加權。如果是列表,則預期它與模型的輸出具有 1:1 的映射。如果是字典,則預期它將輸出名稱(字串)映射到純量係數。
  • metrics:模型在訓練和測試期間要評估的指標列表。每個指標都可以是字串(內建函數的名稱)、函數或 keras.metrics.Metric 實例。請參閱 keras.metrics。通常您會使用 metrics=['accuracy']。函數是任何具有簽名 result = fn(y_true, _pred) 的可呼叫函數。若要為多輸出模型指定不同輸出的不同指標,您也可以傳遞字典,例如 metrics={'a':'accuracy', 'b':['accuracy', 'mse']}。您也可以傳遞列表來為每個輸出指定指標或指標列表,例如 metrics=[['accuracy'], ['accuracy', 'mse']]metrics=['accuracy', ['accuracy', 'mse']]。當您傳遞字串 'accuracy' 或 'acc' 時,我們會根據目標和模型輸出的形狀將其轉換為 keras.metrics.BinaryAccuracykeras.metrics.CategoricalAccuracykeras.metrics.SparseCategoricalAccuracy 之一。對於字串 "crossentropy""ce" 也會進行類似的轉換。此處傳遞的指標會在不進行樣本加權的情況下進行評估;如果您希望應用樣本加權,則可以改為透過 weighted_metrics 參數指定指標。
  • weighted_metrics:在訓練和測試期間要評估並由 sample_weightclass_weight 加權的指標列表。
  • run_eagerly:布林值。如果為 True,則此模型的前向傳遞永遠不會被編譯。建議在訓練時將此設定為 False(以獲得最佳效能),並在偵錯時將其設定為 True
  • steps_per_execution:整數。在每次單一編譯函數呼叫期間要執行的批次數量。在單一編譯函數呼叫內執行多個批次可以大幅提升在 TPU 或具有大量 Python 開銷的小型模型上的效能。每次執行最多會執行一個完整週期。如果傳遞的數字大於週期的長度,則執行將截斷為週期的長度。請注意,如果 steps_per_execution 設定為 N,則 Callback.on_batch_beginCallback.on_batch_end 方法只會在每 N 個批次之後呼叫(即,在每次編譯函數執行之前/之後)。不支援 PyTorch 後端。
  • jit_compile:布林值或 "auto"。編譯模型時是否使用 XLA 編譯。對於 jaxtensorflow 後端,如果模型支援,jit_compile="auto" 會啟用 XLA 編譯,否則會停用。對於 torch 後端,"auto" 會預設為急切執行,而 jit_compile=True 會使用 torch.compile"inductor" 後端執行。
  • auto_scale_loss:布林值。如果為 True 且模型 dtype 原則為 "mixed_float16",則傳遞的優化器會自動包裝在 LossScaleOptimizer 中,該優化器會動態調整損失以防止下溢。

[原始碼]

fit 方法

Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
)

訓練模型固定數量的週期(資料集迭代)。

參數

  • x:輸入資料。它可以是
    • NumPy 陣列(或類似陣列)或陣列列表(如果模型有多個輸入)。
    • 後端原生張量或張量列表(如果模型有多個輸入)。
    • 如果模型具有命名輸入,則將輸入名稱映射到對應陣列/張量的字典。
    • 傳回 (inputs, targets)(inputs, targets, sample_weights)keras.utils.PyDataset
    • 產生 (inputs, targets)(inputs, targets, sample_weights)tf.data.Dataset
    • 產生 (inputs, targets)(inputs, targets, sample_weights)torch.utils.data.DataLoader
    • 產生 (inputs, targets)(inputs, targets, sample_weights) 的 Python 產生器函數。
  • y:目標資料。與輸入資料 x 類似,它可以是 NumPy 陣列或後端原生張量。如果 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,則不應指定 y,因為目標將從 x 取得。
  • batch_size:整數或 None。每次梯度更新的樣本數。如果未指定,則 batch_size 會預設為 32。如果您的輸入資料 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,則請勿指定 batch_size,因為它們會產生批次。
  • epochs:整數。訓練模型的週期數。一個週期是對提供的整個 xy 資料的迭代(除非 steps_per_epoch 旗標設定為非 None 的值)。請注意,與 initial_epoch 結合使用時,epochs 應理解為「最終週期」。模型不會根據 epochs 給定的迭代次數進行訓練,而只是訓練到達到索引為 epochs 的週期。
  • verbose"auto"、0、1 或 2。詳細模式。0 = 無聲、1 = 進度列、2 = 每個週期一行。「auto」在大多數情況下會變為 1。請注意,進度列在記錄到檔案時並不是特別有用,因此建議在非互動式執行時(例如,在生產環境中)使用 verbose=2。預設值為 "auto"
  • callbackskeras.callbacks.Callback 實例列表。在訓練期間要應用的回呼列表。請參閱 keras.callbacks。請注意,keras.callbacks.ProgbarLoggerkeras.callbacks.History 回呼會自動建立,不需要傳遞給 model.fit()keras.callbacks.ProgbarLogger 是根據 model.fit() 中的 verbose 參數建立的。
  • validation_split:介於 0 和 1 之間的浮點數。用作驗證資料的訓練資料比例。模型會將此比例的訓練資料分開,不會對其進行訓練,並在每個週期結束時評估此資料上的損失和任何模型指標。驗證資料是從提供的 xy 資料的最後樣本中選取的,在洗牌之前。只有當 xy 由 NumPy 陣列或張量組成時,才支援此參數。如果同時提供 validation_datavalidation_split,則 validation_data 會覆寫 validation_split
  • validation_data:在每個週期結束時評估損失和任何模型指標的資料。模型不會在此資料上進行訓練。因此,請注意,使用 validation_splitvalidation_data 提供的資料的驗證損失不會受到雜訊和 dropout 等正規化層的影響。validation_data 會覆寫 validation_split。它可以是
    • NumPy 陣列或張量的元組 (x_val, y_val)
    • NumPy 陣列的元組 (x_val, y_val, val_sample_weights)
    • 產生 (inputs, targets)keras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或產生 (x_val, y_val)(inputs, targets, sample_weights) 的 Python 產生器函數。
  • shuffle:布林值,是否在每個週期之前對訓練資料進行洗牌。當 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數時,會忽略此參數。
  • class_weight:可選的字典,將類別索引(整數)映射到權重(浮點數)值,用於加權損失函數(僅在訓練期間)。這可用於告知模型「更注意」來自代表性不足類別的樣本。當指定 class_weight 且目標的秩為 2 或更大時,y 必須進行單熱編碼,或者必須為稀疏類別標籤包含明確的最終維度 1
  • sample_weight:可選的 NumPy 陣列或張量,用於訓練樣本的權重,僅在訓練期間用於加權損失函數。您可以傳遞一個與輸入樣本長度相同的扁平(1D)NumPy 陣列或張量(權重和樣本之間 1:1 對應),或者在時間序列資料的情況下,您可以傳遞一個形狀為 (樣本數, 序列長度) 的 2D NumPy 陣列或張量,以便對每個樣本的每個時間步應用不同的權重。當 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 生成器函數時,不支援此參數。請改為將 sample_weights 作為 x 的第三個元素提供。請注意,樣本加權不適用於在 compile() 中透過 metrics 參數指定的指標。要將樣本加權應用於您的指標,您可以在 compile() 中透過 weighted_metrics 來指定它們。
  • initial_epoch:整數。開始訓練的 epoch (對於恢復先前的訓練執行很有用)。
  • steps_per_epoch:整數或 None。在宣告一個 epoch 完成並開始下一個 epoch 之前,總共需要執行多少步驟(樣本批次)。當使用輸入張量或 NumPy 陣列進行訓練時,預設的 None 表示使用的值是數據集中樣本數除以批次大小,如果無法確定,則為 1。如果 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 生成器函數,則 epoch 將執行直到輸入數據集耗盡。當傳遞無限重複的數據集時,您必須指定 steps_per_epoch 參數,否則訓練將無限期執行。
  • validation_steps:整數或 None。僅在提供 validation_data 時相關。在每個 epoch 結束時執行驗證之前,需要提取的步驟總數(樣本批次)。如果 validation_stepsNone,則驗證將執行直到 validation_data 數據集耗盡。在無限重複的數據集的情況下,它將無限期執行。如果指定了 validation_steps 並且僅消耗了數據集的一部分,則評估將在每個 epoch 從數據集的開頭開始。這確保了每次都使用相同的驗證樣本。
  • validation_batch_size:整數或 None。每個驗證批次的樣本數。如果未指定,則預設為 batch_size。如果您的資料是 keras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 生成器函數,則不要指定 validation_batch_size,因為它們會產生批次。
  • validation_freq:僅在提供驗證資料時相關。指定在執行新的驗證之前要執行多少個訓練 epoch,例如 validation_freq=2 每 2 個 epoch 執行一次驗證。

類迭代器輸入的解包行為:一個常見的模式是將類似迭代器的物件(例如 tf.data.Datasetkeras.utils.PyDataset)傳遞給 fit(),實際上它不僅會產生特徵 (x),還會選擇性地產生目標 (y) 和樣本權重 (sample_weight)。Keras 要求此類類似迭代器的輸出是明確的。迭代器應返回長度為 1、2 或 3 的元組,其中可選的第二個和第三個元素將分別用於 ysample_weight。提供的任何其他類型將被包裝在長度為一的元組中,有效地將所有內容都視為 x。當產生字典時,它們仍然應遵守頂層的元組結構,例如 ({"x0": x0, "x1": x1}, y)。Keras 不會嘗試從單個字典的鍵中分離特徵、目標和權重。一個值得注意的不支援資料類型是 namedtuple。原因在於它既像有序資料類型 (元組) 又像映射資料類型 (字典)。因此,給定一個形式為:namedtuple("example_tuple", ["y", "x"]) 的 namedtuple,當解釋值時,是否反轉元素的順序是不明確的。更糟的是形式為:namedtuple("other_tuple", ["x", "y", "z"]) 的元組,其中不清楚元組是否打算解包到 xysample_weight 中,還是作為單個元素傳遞到 x 中。

回傳值

一個 History 物件。其 History.history 屬性記錄了連續 epoch 的訓練損失值和指標值,以及驗證損失值和驗證指標值(如果適用)。


[原始碼]

evaluate 方法

Model.evaluate(
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs
)

回傳模型在測試模式下的損失值和指標值。

計算是以批次完成的(請參閱 batch_size 參數)。

參數

  • x:輸入資料。它可以是
    • NumPy 陣列(或類似陣列)或陣列列表(如果模型有多個輸入)。
    • 後端原生張量或張量列表(如果模型有多個輸入)。
    • 如果模型具有命名輸入,則將輸入名稱映射到對應陣列/張量的字典。
    • 傳回 (inputs, targets)(inputs, targets, sample_weights)keras.utils.PyDataset
    • 產生 (inputs, targets)(inputs, targets, sample_weights)tf.data.Dataset
    • 產生 (inputs, targets)(inputs, targets, sample_weights)torch.utils.data.DataLoader
    • 產生 (inputs, targets)(inputs, targets, sample_weights) 的 Python 產生器函數。
  • y:目標資料。與輸入資料 x 類似,它可以是 NumPy 陣列或後端原生張量。如果 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,則不應指定 y,因為目標將從 x 取得。
  • batch_size:整數或 None。每個計算批次的樣本數。如果未指定,則 batch_size 將預設為 32。如果您的輸入資料 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 生成器函數,則不要指定 batch_size,因為它們會產生批次。
  • verbose"auto"、0、1 或 2。詳細模式。0 = 無聲,1 = 進度條,2 = 單行。"auto" 在大多數情況下變為 1。請注意,進度條在記錄到檔案時並非特別有用,因此當不在互動模式下執行時(例如,在生產環境中),建議使用 verbose=2。預設值為 "auto"
  • sample_weight:可選的 NumPy 陣列或張量,用於訓練樣本的權重,僅在訓練期間用於加權損失函數。您可以傳遞一個與輸入樣本長度相同的扁平(1D)NumPy 陣列或張量(權重和樣本之間 1:1 對應),或者在時間序列資料的情況下,您可以傳遞一個形狀為 (樣本數, 序列長度) 的 2D NumPy 陣列或張量,以便對每個樣本的每個時間步應用不同的權重。當 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 生成器函數時,不支援此參數。請改為將 sample_weights 作為 x 的第三個元素提供。請注意,樣本加權不適用於在 compile() 中透過 metrics 參數指定的指標。要將樣本加權應用於您的指標,您可以在 compile() 中透過 weighted_metrics 來指定它們。
  • steps:整數或 None。在宣告評估輪次完成之前,需要提取的步驟總數(樣本批次)。如果 stepsNone,則將執行直到 x 耗盡。在無限重複的數據集的情況下,它將無限期執行。
  • callbackskeras.callbacks.Callback 實例的列表。在評估期間套用的回調列表。
  • return_dict:如果為 True,則損失和指標結果將以字典的形式回傳,其中每個鍵是指標的名稱。如果為 False,則將以列表的形式回傳。

回傳值

純量測試損失(如果模型具有單個輸出且沒有指標)或純量列表(如果模型具有多個輸出和/或指標)。屬性 model.metrics_names 將為您提供純量輸出的顯示標籤。


[原始碼]

predict 方法

Model.predict(x, batch_size=None, verbose="auto", steps=None, callbacks=None)

為輸入樣本產生輸出預測。

計算是以批次完成的。此方法旨在批量處理大量輸入。它不適用於在迴圈內迭代資料並一次處理少量輸入。

對於適合單個批次的小量輸入,直接使用 __call__() 可更快地執行,例如,model(x),或者如果您有諸如 BatchNormalization 之類的在推論期間行為不同的層,則使用 model(x, training=False)

注意:有關 Model 方法 predict()__call__() 之間差異的更多詳細資訊,請參閱此常見問題解答條目

參數

  • x:輸入資料。它可以是
    • NumPy 陣列(或類似陣列)或陣列列表(如果模型有多個輸入)。
    • 後端原生張量或張量列表(如果模型有多個輸入)。
    • 如果模型具有命名輸入,則將輸入名稱映射到對應陣列/張量的字典。
    • 一個 keras.utils.PyDataset
    • 一個 tf.data.Dataset
    • 一個 torch.utils.data.DataLoader
    • 一個 Python 生成器函數。
  • batch_size:整數或 None。每個計算批次的樣本數。如果未指定,則 batch_size 將預設為 32。如果您的輸入資料 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 生成器函數,則不要指定 batch_size,因為它們會產生批次。
  • verbose"auto"、0、1 或 2。詳細模式。0 = 無聲,1 = 進度條,2 = 單行。"auto" 在大多數情況下變為 1。請注意,進度條在記錄到檔案時並非特別有用,因此當不在互動模式下執行時(例如,在生產環境中),建議使用 verbose=2。預設值為 "auto"
  • steps:在宣告預測輪次完成之前,需要提取的步驟總數(樣本批次)。如果 stepsNone,則將執行直到 x 耗盡。在無限重複的數據集的情況下,它將無限期執行。
  • callbackskeras.callbacks.Callback 實例的列表。在預測期間套用的回調列表。

回傳值

預測的 NumPy 陣列。


[原始碼]

train_on_batch 方法

Model.train_on_batch(
    x, y=None, sample_weight=None, class_weight=None, return_dict=False
)

對單批次資料執行單一梯度更新。

參數

  • x:輸入資料。必須為類陣列。
  • y:目標資料。必須為類陣列。
  • sample_weight:與 x 長度相同的可選陣列,其中包含要應用於每個樣本的模型損失的權重。在時間序列資料的情況下,您可以傳遞形狀為 (樣本數, 序列長度) 的 2D 陣列,以便對每個樣本的每個時間步應用不同的權重。
  • class_weight:可選字典,將類別索引(整數)映射到權重(浮點數),以便在訓練期間應用於來自該類別的樣本的模型損失。這對於告知模型「更多關注」來自代表性不足類別的樣本很有用。當指定 class_weight 並且目標的排名為 2 或更高時,y 必須為單熱編碼,或者對於稀疏類別標籤,必須包含明確的最終維度 1。
  • return_dict:如果為 True,則損失和指標結果將以字典的形式回傳,其中每個鍵是指標的名稱。如果為 False,則將以列表的形式回傳。

回傳值

純量損失值(當沒有指標且 return_dict=False 時),損失和指標值的列表(如果有指標且 return_dict=False 時),或指標和損失值的字典(如果 return_dict=True 時)。


[原始碼]

test_on_batch 方法

Model.test_on_batch(x, y=None, sample_weight=None, return_dict=False)

在單批次樣本上測試模型。

參數

  • x:輸入資料。必須為類陣列。
  • y:目標資料。必須為類陣列。
  • sample_weight:與 x 長度相同的可選陣列,其中包含要應用於每個樣本的模型損失的權重。在時間序列資料的情況下,您可以傳遞形狀為 (樣本數, 序列長度) 的 2D 陣列,以便對每個樣本的每個時間步應用不同的權重。
  • return_dict:如果為 True,則損失和指標結果將以字典的形式回傳,其中每個鍵是指標的名稱。如果為 False,則將以列表的形式回傳。

回傳值

純量損失值(當沒有指標且 return_dict=False 時),損失和指標值的列表(如果有指標且 return_dict=False 時),或指標和損失值的字典(如果 return_dict=True 時)。


[原始碼]

predict_on_batch 方法

Model.predict_on_batch(x)

回傳單批次樣本的預測。

參數

  • x:輸入資料。它必須為類陣列。

回傳值

預測的 NumPy 陣列。