Keras 3 API 文件 / 模型 API / 模型訓練 API

模型訓練 API

[來源]

compile 方法

Model.compile(
    optimizer="rmsprop",
    loss=None,
    loss_weights=None,
    metrics=None,
    weighted_metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
    jit_compile="auto",
    auto_scale_loss=True,
)

配置模型以進行訓練。

範例

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.FalseNegatives(),
    ],
)

引數

  • optimizer:字串(最佳化器名稱)或最佳化器實例。請參閱 keras.optimizers
  • loss:損失函數。可以是字串(損失函數名稱),或 keras.losses.Loss 實例。請參閱 keras.losses。損失函數可以是任何具有簽章 loss = fn(y_true, y_pred) 的可呼叫物件,其中 y_true 是真實值,而 y_pred 是模型的預測值。y_true 應具有形狀 (batch_size, d0, .. dN) (稀疏損失函數(例如稀疏類別交叉熵)除外,其預期為形狀 (batch_size, d0, .. dN-1) 的整數陣列)。y_pred 應具有形狀 (batch_size, d0, .. dN)。損失函數應傳回浮點張量。
  • loss_weights:選用的列表或字典,用於指定純量係數(Python 浮點數)以權衡不同模型輸出的損失貢獻。模型將最小化的損失值將會是所有個別損失的加權總和,並由 loss_weights 係數加權。如果是列表,則預期與模型的輸出具有 1:1 的映射關係。如果是字典,則預期將輸出名稱(字串)映射到純量係數。
  • metrics:要在模型訓練和測試期間評估的指標列表。每個指標可以是字串(內建函數的名稱)、函數或 keras.metrics.Metric 實例。請參閱 keras.metrics。通常您會使用 metrics=['accuracy']。函數可以是任何具有簽章 result = fn(y_true, _pred) 的可呼叫物件。若要為多輸出模型指定不同輸出的不同指標,您也可以傳遞字典,例如 metrics={'a':'accuracy', 'b':['accuracy', 'mse']}。您也可以傳遞列表,為每個輸出指定一個指標或指標列表,例如 metrics=[['accuracy'], ['accuracy', 'mse']]metrics=['accuracy', ['accuracy', 'mse']]。當您傳遞字串 'accuracy' 或 'acc' 時,我們會根據目標和模型輸出的形狀,將其轉換為 keras.metrics.BinaryAccuracykeras.metrics.CategoricalAccuracykeras.metrics.SparseCategoricalAccuracy 其中之一。對於字串 "crossentropy""ce" 也會進行類似的轉換。此處傳遞的指標會在不進行樣本加權的情況下評估;如果您希望套用樣本加權,您可以改為透過 weighted_metrics 引數指定您的指標。
  • weighted_metrics:要在訓練和測試期間評估並由 sample_weightclass_weight 加權的指標列表。
  • run_eagerly:布林值。如果為 True,則永遠不會編譯此模型的前向傳遞。建議在訓練時將其保留為 False(以獲得最佳效能),並在偵錯時將其設定為 True
  • steps_per_execution:整數。每次單一編譯函數呼叫期間要執行的批次數量。在單一編譯函數呼叫內執行多個批次可以大幅改善 TPU 或具有大量 Python 額外負荷的小型模型的效能。每次執行最多會執行一個完整週期。如果傳遞的數字大於週期的長度,則執行將會截斷為週期的長度。請注意,如果 steps_per_execution 設定為 N,則只會每 N 個批次呼叫 Callback.on_batch_beginCallback.on_batch_end 方法(即在每次編譯函數執行之前/之後)。PyTorch 後端不支援。
  • jit_compile:布林值或 "auto"。是否在編譯模型時使用 XLA 編譯。對於 jaxtensorflow 後端,如果模型支援 XLA 編譯,jit_compile="auto" 會啟用 XLA 編譯,否則會停用。對於 torch 後端,"auto" 將預設為 eager 執行,而 jit_compile=True 將使用 "inductor" 後端執行 torch.compile
  • auto_scale_loss:布林值。如果為 True 且模型 dtype 策略為 "mixed_float16",則傳遞的最佳化器將會自動包裝在 LossScaleOptimizer 中,這將會動態縮放損失以防止下溢。

[來源]

fit 方法

Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
)

針對固定數量的週期(資料集迭代)訓練模型。

引數

  • x:輸入資料。它可以是
    • NumPy 陣列(或類陣列),或陣列列表(如果模型具有多個輸入)。
    • 後端原生張量,或張量列表(如果模型具有多個輸入)。
    • 字典,將輸入名稱映射到對應的陣列/張量,如果模型具有具名輸入。
    • keras.utils.PyDataset,傳回 (inputs, targets)(inputs, targets, sample_weights)
    • tf.data.Dataset,產生 (inputs, targets)(inputs, targets, sample_weights)
    • torch.utils.data.DataLoader,產生 (inputs, targets)(inputs, targets, sample_weights)
    • Python 產生器函數,產生 (inputs, targets)(inputs, targets, sample_weights)
  • y:目標資料。與輸入資料 x 類似,它可以是 NumPy 陣列或後端原生張量。如果 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,則不應指定 y,因為目標將從 x 取得。
  • batch_size:整數或 None。每次梯度更新的樣本數。如果未指定,batch_size 將預設為 32。如果您的輸入資料 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,請勿指定 batch_size,因為它們會產生批次。
  • epochs:整數。要訓練模型的週期數。一個週期是對提供的整個 xy 資料進行一次迭代(除非 steps_per_epoch 旗標設定為 None 以外的值)。請注意,結合 initial_epochepochs 應理解為「最終週期」。模型不會針對 epochs 給定的迭代次數進行訓練,而僅僅訓練到達到索引為 epochs 的週期為止。
  • verbose"auto"、0、1 或 2。詳細模式。0 = 靜音,1 = 進度列,2 = 每週期一行。"auto" 在大多數情況下變為 1。請注意,當記錄到檔案時,進度列不是特別有用,因此建議在非互動式執行時(例如,在生產環境中)使用 verbose=2。預設為 "auto"
  • callbackskeras.callbacks.Callback 實例列表。要在訓練期間套用的回呼列表。請參閱 keras.callbacks。請注意,keras.callbacks.ProgbarLoggerkeras.callbacks.History 回呼會自動建立,無需傳遞給 model.fit()keras.callbacks.ProgbarLogger 是否建立取決於 model.fit() 中的 verbose 引數。
  • validation_split:介於 0 和 1 之間的浮點數。用作驗證資料的訓練資料比例。模型將會分開訓練資料的此比例,不會對其進行訓練,並會在每個週期結束時評估此資料的損失和任何模型指標。驗證資料是從提供的 xy 資料中的最後樣本選取的,在混洗之前。僅當 xy 由 NumPy 陣列或張量組成時,才支援此引數。如果同時提供了 validation_datavalidation_split,則 validation_data 將會覆寫 validation_split
  • validation_data:在每個週期結束時評估損失和任何模型指標的資料。模型不會在此資料上進行訓練。因此,請注意使用 validation_splitvalidation_data 提供的資料的驗證損失不受雜訊和 dropout 等正規化層的影響。validation_data 將會覆寫 validation_split。它可以是
    • NumPy 陣列或張量的元組 (x_val, y_val)
    • NumPy 陣列的元組 (x_val, y_val, val_sample_weights)
    • keras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader,產生 (inputs, targets) 或 Python 產生器函數,產生 (x_val, y_val)(inputs, targets, sample_weights)
  • shuffle:布林值,是否在每個週期之前混洗訓練資料。當 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數時,將會忽略此引數。
  • class_weight:選用的字典,將類別索引(整數)映射到權重(浮點數)值,用於權衡損失函數(僅在訓練期間)。這可用於告知模型「更關注」來自代表性不足類別的樣本。當指定 class_weight 且目標的秩為 2 或更大時,y 必須是 one-hot 編碼,或者必須為稀疏類別標籤包含明確的最終維度 1
  • sample_weight:選用的 NumPy 陣列或訓練樣本權重的張量,用於權衡損失函數(僅在訓練期間)。您可以傳遞與輸入樣本長度相同的平面 (1D) NumPy 陣列或張量(權重與樣本之間的 1:1 映射),或者在時間序列資料的情況下,您可以傳遞形狀為 (samples, sequence_length) 的 2D NumPy 陣列或張量,以將不同的權重套用至每個樣本的每個時間步。當 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數時,不支援此引數。相反地,請將 sample_weights 作為 x 的第三個元素提供。請注意,樣本加權不適用於透過 compile() 中的 metrics 引數指定的指標。若要將樣本加權套用至您的指標,您可以改為透過 compile() 中的 weighted_metrics 指定它們。
  • initial_epoch:整數。開始訓練的週期(適用於恢復先前的訓練執行)。
  • steps_per_epoch:整數或 None。宣告一個週期完成並開始下一個週期之前的總步數(樣本批次)。當使用輸入張量或 NumPy 陣列進行訓練時,預設值 None 表示使用的值是資料集中樣本數除以批次大小,如果無法判斷,則為 1。如果 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,則週期將執行到輸入資料集耗盡為止。當傳遞無限重複的資料集時,您必須指定 steps_per_epoch 引數,否則訓練將會無限期執行。
  • validation_steps:整數或 None。僅在提供 validation_data 時相關。在每個週期結束時執行驗證時,在停止之前要繪製的總步數(樣本批次)。如果 validation_stepsNone,則驗證將執行到 validation_data 資料集耗盡為止。在無限重複資料集的情況下,它將無限期執行。如果指定了 validation_steps 且僅使用了部分資料集,則評估將在每個週期從資料集的開頭開始。這確保了每次都使用相同的驗證樣本。
  • validation_batch_size:整數或 None。每個驗證批次的樣本數。如果未指定,則預設為 batch_size。如果您的資料是 keras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,請勿指定 validation_batch_size,因為它們會產生批次。
  • validation_freq:僅在提供驗證資料時相關。指定在執行新的驗證執行之前要執行的訓練週期數,例如 validation_freq=2 每 2 個週期執行一次驗證。

類似迭代器輸入的解包行為:常見的模式是將類似迭代器的物件(例如 tf.data.Datasetkeras.utils.PyDataset)傳遞給 fit(),這實際上不僅會產生特徵 (x),還可能會產生目標 (y) 和樣本權重 (sample_weight)。Keras 要求此類類似迭代器的輸出必須明確。迭代器應傳回長度為 1、2 或 3 的元組,其中選用的第二個和第三個元素將分別用於 ysample_weight。提供的任何其他類型都將包裝在長度為 1 的元組中,有效地將所有內容視為 x。當產生字典時,它們仍應遵守最上層的元組結構,例如 ({"x0": x0, "x1": x1}, y)。Keras 不會嘗試從單一字典的鍵中分離特徵、目標和權重。一個值得注意的不支援資料類型是 namedtuple。原因是它的行為既像有序資料類型(元組)又像映射資料類型(字典)。因此,給定形式為 namedtuple("example_tuple", ["y", "x"]) 的 namedtuple,在解釋值時是否要反轉元素的順序是不明確的。更糟的是形式為 namedtuple("other_tuple", ["x", "y", "z"]) 的元組,其中不清楚元組是打算解包到 xysample_weight 中,還是作為單一元素傳遞到 x

傳回

History 物件。其 History.history 屬性是連續週期中訓練損失值和指標值的記錄,以及驗證損失值和驗證指標值(如果適用)。


[來源]

evaluate 方法

Model.evaluate(
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs
)

傳回模型在測試模式下的損失值和指標值。

計算以批次方式完成(請參閱 batch_size 引數)。

引數

  • x:輸入資料。它可以是
    • NumPy 陣列(或類陣列),或陣列列表(如果模型具有多個輸入)。
    • 後端原生張量,或張量列表(如果模型具有多個輸入)。
    • 字典,將輸入名稱映射到對應的陣列/張量,如果模型具有具名輸入。
    • keras.utils.PyDataset,傳回 (inputs, targets)(inputs, targets, sample_weights)
    • tf.data.Dataset,產生 (inputs, targets)(inputs, targets, sample_weights)
    • torch.utils.data.DataLoader,產生 (inputs, targets)(inputs, targets, sample_weights)
    • Python 產生器函數,產生 (inputs, targets)(inputs, targets, sample_weights)
  • y:目標資料。與輸入資料 x 類似,它可以是 NumPy 陣列或後端原生張量。如果 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,則不應指定 y,因為目標將從 x 取得。
  • batch_size:整數或 None。每次計算批次的樣本數。如果未指定,batch_size 將預設為 32。如果您的輸入資料 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,請勿指定 batch_size,因為它們會產生批次。
  • verbose"auto"、0、1 或 2。詳細模式。0 = 靜音,1 = 進度列,2 = 單行。"auto" 在大多數情況下變為 1。請注意,當記錄到檔案時,進度列不是特別有用,因此建議在非互動式執行時(例如在生產環境中)使用 verbose=2。預設為 "auto"
  • sample_weight:選用的 NumPy 陣列或訓練樣本權重的張量,用於權衡損失函數(僅在訓練期間)。您可以傳遞與輸入樣本長度相同的平面 (1D) NumPy 陣列或張量(權重與樣本之間的 1:1 映射),或者在時間序列資料的情況下,您可以傳遞形狀為 (samples, sequence_length) 的 2D NumPy 陣列或張量,以將不同的權重套用至每個樣本的每個時間步。當 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數時,不支援此引數。相反地,請將 sample_weights 作為 x 的第三個元素提供。請注意,樣本加權不適用於透過 compile() 中的 metrics 引數指定的指標。若要將樣本加權套用至您的指標,您可以改為透過 compile() 中的 weighted_metrics 指定它們。
  • steps:整數或 None。宣告評估回合完成之前要繪製的總步數(樣本批次)。如果 stepsNone,則將執行到 x 耗盡為止。在無限重複資料集的情況下,它將無限期執行。
  • callbackskeras.callbacks.Callback 實例列表。要在評估期間套用的回呼列表。
  • return_dict:如果為 True,則損失和指標結果會以字典形式傳回,每個鍵都是指標的名稱。如果為 False,則會以列表形式傳回。

傳回

純量測試損失(如果模型具有單一輸出且沒有指標)或純量列表(如果模型具有多個輸出和/或指標)。屬性 model.metrics_names 將為您提供純量輸出的顯示標籤。


[來源]

predict 方法

Model.predict(x, batch_size=None, verbose="auto", steps=None, callbacks=None)

為輸入樣本產生輸出預測。

計算以批次方式完成。此方法旨在用於大量輸入的批次處理。它不適用於在資料上迭代並一次處理少量輸入的迴圈內使用。

對於適合一個批次的小量輸入,直接使用 __call__() 以獲得更快的執行速度,例如,model(x),或如果您有在推論期間行為不同的層(例如 BatchNormalization),則使用 model(x, training=False)

注意:有關 Model 方法 predict()__call__() 之間差異的更多詳細資訊,請參閱 此常見問題解答條目

引數

  • x:輸入資料。它可以是
    • NumPy 陣列(或類陣列),或陣列列表(如果模型具有多個輸入)。
    • 後端原生張量,或張量列表(如果模型具有多個輸入)。
    • 字典,將輸入名稱映射到對應的陣列/張量,如果模型具有具名輸入。
    • keras.utils.PyDataset
    • tf.data.Dataset
    • torch.utils.data.DataLoader
    • Python 產生器函數。
  • batch_size:整數或 None。每次計算批次的樣本數。如果未指定,batch_size 將預設為 32。如果您的輸入資料 xkeras.utils.PyDatasettf.data.Datasettorch.utils.data.DataLoader 或 Python 產生器函數,請勿指定 batch_size,因為它們會產生批次。
  • verbose"auto"、0、1 或 2。詳細模式。0 = 靜音,1 = 進度列,2 = 單行。"auto" 在大多數情況下變為 1。請注意,當記錄到檔案時,進度列不是特別有用,因此建議在非互動式執行時(例如在生產環境中)使用 verbose=2。預設為 "auto"
  • steps:宣告預測回合完成之前要繪製的總步數(樣本批次)。如果 stepsNone,則將執行到 x 耗盡為止。在無限重複資料集的情況下,它將無限期執行。
  • callbackskeras.callbacks.Callback 實例列表。要在預測期間套用的回呼列表。

傳回

NumPy 預測陣列。


[來源]

train_on_batch 方法

Model.train_on_batch(
    x, y=None, sample_weight=None, class_weight=None, return_dict=False
)

在單一批次資料上執行單次梯度更新。

引數

  • x:輸入資料。必須是類陣列。
  • y:目標資料。必須是類陣列。
  • sample_weight:選用的與 x 長度相同的陣列,包含要套用於每個樣本的模型損失的權重。在時間序列資料的情況下,您可以傳遞形狀為 (samples, sequence_length) 的 2D 陣列,以將不同的權重套用至每個樣本的每個時間步。
  • class_weight:選用的字典,將類別索引(整數)映射到權重(浮點數),以在訓練期間套用於來自此類別的樣本的模型損失。這可用於告知模型「更關注」來自代表性不足類別的樣本。當指定 class_weight 且目標的秩為 2 或更大時,y 必須是 one-hot 編碼,或者必須為稀疏類別標籤包含明確的最終維度 1。
  • return_dict:如果為 True,則損失和指標結果會以字典形式傳回,每個鍵都是指標的名稱。如果為 False,則會以列表形式傳回。

傳回

純量損失值(當沒有指標且 return_dict=False 時)、損失和指標值列表(如果存在指標且 return_dict=False 時),或指標和損失值字典(如果 return_dict=True 時)。


[來源]

test_on_batch 方法

Model.test_on_batch(x, y=None, sample_weight=None, return_dict=False)

在單一批次樣本上測試模型。

引數

  • x:輸入資料。必須是類陣列。
  • y:目標資料。必須是類陣列。
  • sample_weight:選用的與 x 長度相同的陣列,包含要套用於每個樣本的模型損失的權重。在時間序列資料的情況下,您可以傳遞形狀為 (samples, sequence_length) 的 2D 陣列,以將不同的權重套用至每個樣本的每個時間步。
  • return_dict:如果為 True,則損失和指標結果會以字典形式傳回,每個鍵都是指標的名稱。如果為 False,則會以列表形式傳回。

傳回

純量損失值(當沒有指標且 return_dict=False 時)、損失和指標值列表(如果存在指標且 return_dict=False 時),或指標和損失值字典(如果 return_dict=True 時)。


[來源]

predict_on_batch 方法

Model.predict_on_batch(x)

傳回單一批次樣本的預測。

引數

  • x:輸入資料。它必須是類陣列。

傳回

NumPy 預測陣列。