隆重介紹 Keras 3.0

經過五個月廣泛的公開測試後,我們很高興宣布 Keras 3.0 正式發布。Keras 3 是 Keras 的完全重寫版本,可讓您在 JAX、TensorFlow、PyTorch 或 OpenVINO(僅用於推論)之上運行 Keras 工作流程,並解鎖全新的大規模模型訓練和部署功能。您可以選擇最適合您的框架,並根據您當前的目標在框架之間切換。您也可以將 Keras 作為低階跨框架語言,來開發可在 JAX、TensorFlow 或 PyTorch 原生工作流程中使用的自訂組件,例如層、模型或指標,而只需要一個程式碼庫。


歡迎來到多框架機器學習。

您已經熟悉使用 Keras 的好處 — 它透過專注於優異的使用者體驗、API 設計和除錯能力,實現了高速開發。它也是一個經過實戰考驗的框架,已被超過 250 萬開發人員選用,並為世界上一些最複雜、最大規模的 ML 系統提供支援,例如 Waymo 自動駕駛車隊和 YouTube 推薦引擎。但是,使用新的多後端 Keras 3 有哪些額外的好處呢?

  • 永遠為您的模型取得最佳效能。 在我們的基準測試中,我們發現 JAX 通常在 GPU、TPU 和 CPU 上提供最佳的訓練和推論效能 — 但結果因模型而異,因為非 XLA TensorFlow 在 GPU 上有時速度更快。能夠動態選擇能為您的模型提供最佳效能的後端,而無需更改任何程式碼,這表示您可以保證以最高的效率進行訓練和服務。
  • 為您的模型解鎖生態系統的選擇性。 任何 Keras 3 模型都可以實例化為 PyTorch Module,可以匯出為 TensorFlow SavedModel,也可以實例化為無狀態 JAX 函數。這表示您可以將您的 Keras 3 模型與 PyTorch 生態系統套件、全系列的 TensorFlow 部署和生產工具(例如 TF-Serving、TF.js 和 TFLite)以及 JAX 大規模 TPU 訓練基礎架構一起使用。使用 Keras 3 API 撰寫一個 model.py,即可存取 ML 世界所能提供的一切。
  • 利用 JAX 進行大規模模型平行化和資料平行化。 Keras 3 包含一個全新的分散式 API,即 keras.distribution 名稱空間,目前已針對 JAX 後端實作(即將推出到 TensorFlow 和 PyTorch 後端)。它可以輕鬆地進行模型平行化、資料平行化以及兩者的組合 — 在任意模型規模和叢集規模下進行。由於它將模型定義、訓練邏輯和分片組態彼此分開,因此讓您的分散式工作流程易於開發和維護。請參閱我們的入門指南
  • 最大限度地提高您的開源模型發布的覆蓋範圍。 想要發布預訓練模型嗎?希望盡可能多的人可以使用它嗎?如果您使用純 TensorFlow 或 PyTorch 實作它,則大約只有一半的社群可以使用它。如果您使用 Keras 3 實作它,則無論他們選擇哪個框架(即使他們本身不是 Keras 用戶),都可以立即使用它。在沒有增加開發成本的情況下,影響力增加一倍。
  • 使用來自任何來源的資料管道。 無論您使用的是哪個後端,Keras 3 的 fit()/evaluate()/predict() 常式都與 tf.data.Dataset 物件、PyTorch DataLoader 物件、NumPy 陣列、Pandas 資料框架相容。您可以在 PyTorch DataLoader 上訓練 Keras 3 + TensorFlow 模型,或在 tf.data.Dataset 上訓練 Keras 3 + PyTorch 模型。

完整的 Keras API,適用於 JAX、TensorFlow 和 PyTorch。

Keras 3 實作了完整的 Keras API,並使其可與 TensorFlow、JAX 和 PyTorch 一起使用 — 超過一百個層、數十個指標、損失函數、最佳化器和回呼、Keras 訓練和評估迴圈以及 Keras 儲存和序列化基礎架構。您熟悉和喜愛的所有 API 都在這裡。

任何僅使用內建層的 Keras 模型都可以立即與所有支援的後端一起使用。事實上,您現有的僅使用內建層的 tf.keras 模型可以立即開始在 JAX 和 PyTorch 中執行!沒錯,您的程式碼庫剛剛獲得了一整套新的功能。


建立多框架層、模型、指標...

Keras 3 可讓您建立在任何框架中都以相同方式運作的組件(例如任意自訂層或預訓練模型)。特別是,Keras 3 可讓您存取跨所有後端運作的 keras.ops 名稱空間。它包含

  • NumPy API 的完整實作。 不是「類 NumPy」的東西 — 而是字面上的 NumPy API,具有相同的函數和相同的引數。您會得到 ops.matmulops.sumops.stackops.einsum 等。
  • 一組 NumPy 中沒有的類神經網路特定函數,例如 ops.softmaxops.binary_crossentropyops.conv 等。

只要您僅使用 keras.ops 中的運算,您的自訂層、自訂損失、自訂指標和自訂最佳化器將與 JAX、PyTorch 和 TensorFlow 一起使用 — 使用相同的程式碼。這表示您只需要維護一個組件實作(例如,一個 model.py 和一個檢查點檔案),並且可以在所有框架中使用它,且具有完全相同的數值。


...與任何 JAX、TensorFlow 和 PyTorch 工作流程無縫協作。

Keras 3 不僅適用於以 Keras 為中心的工作流程,在這些工作流程中,您定義 Keras 模型、Keras 最佳化器、Keras 損失和指標,然後呼叫 fit()evaluate()predict()。它也旨在與低階後端原生工作流程無縫協作:您可以採用 Keras 模型(或任何其他組件,例如損失或指標),並開始在 JAX 訓練迴圈、TensorFlow 訓練迴圈或 PyTorch 訓練迴圈中使用它,或作為 JAX 或 PyTorch 模型的一部分,而不會產生任何阻礙。Keras 3 在 JAX 和 PyTorch 中提供的低階實作彈性與 tf.keras 先前在 TensorFlow 中提供的完全相同。

您可以

  • 撰寫一個低階 JAX 訓練迴圈,以使用 optax 最佳化器、jax.gradjax.jitjax.pmap 來訓練 Keras 模型。
  • 撰寫一個低階 TensorFlow 訓練迴圈,以使用 tf.GradientTapetf.distribute 來訓練 Keras 模型。
  • 撰寫一個低階 PyTorch 訓練迴圈,以使用 torch.optim 最佳化器、torch 損失函數和 torch.nn.parallel.DistributedDataParallel 包裝器來訓練 Keras 模型。
  • 在 PyTorch Module 中使用 Keras 層(因為它們也是 Module 實例!)
  • 在 Keras 模型中使用任何 PyTorch Module,就像它是 Keras 層一樣。
  • 等等。


用於大規模資料平行化和模型平行化的全新分散式 API。

我們一直在使用的模型越來越大,因此我們希望為多裝置模型分片問題提供 Keras 式解決方案。我們設計的 API 將模型定義、訓練邏輯和分片組態完全彼此分開,這表示您可以像模型將在單一裝置上執行一樣撰寫模型。然後,您可以在訓練模型時將任意分片組態新增至任意模型。

資料平行化(在多個裝置上完全複製小型模型)只需兩行即可處理

模型平行化可讓您指定模型變數和中間輸出張量的分片佈局,沿著多個命名維度。在典型情況下,您會將可用裝置組織為 2D 網格(稱為裝置網格),其中第一個維度用於資料平行化,第二個維度用於模型平行化。然後,您會將模型設定為沿著模型維度進行分片,並沿著資料維度進行複製。

API 可讓您透過正規表示式設定每個變數和每個輸出張量的佈局。這讓您可以輕鬆快速地為整個類別的變數指定相同的佈局。

新的分散式 API 旨在支援多後端,但目前僅適用於 JAX 後端。TensorFlow 和 PyTorch 支援即將推出。請參考本指南開始使用!


預訓練模型。

您可以立即開始使用 Keras 3 的各種預訓練模型。

所有 40 個 Keras 應用程式模型(keras.applications 名稱空間)都可在所有後端中使用。KerasCVKerasHub 中大量的預訓練模型也適用於所有後端。這包括

  • BERT
  • OPT
  • Whisper
  • T5
  • StableDiffusion
  • YOLOv8
  • SegmentAnything
  • 等等。

支援所有後端的跨框架資料管道。

多框架 ML 也表示多框架資料載入和預處理。Keras 3 模型可以使用各種資料管道進行訓練 — 無論您使用的是 JAX、PyTorch 或 TensorFlow 後端。它可以正常運作。

  • tf.data.Dataset 管道:可擴展生產 ML 的參考。
  • torch.utils.data.DataLoader 物件。
  • NumPy 陣列和 Pandas 資料框架。
  • Keras 自己的 keras.utils.PyDataset 物件。

複雜性的漸進揭露。

複雜性的漸進揭露是 Keras API 核心的設計原則。Keras 不會強迫您遵循單一的「真實」模型建構和訓練方式。相反,它支援各種不同的工作流程,從非常高階到非常低階,對應於不同的使用者設定檔。

這表示您可以從簡單的工作流程開始 — 例如使用 SequentialFunctional 模型並使用 fit() 訓練它們 — 當您需要更多彈性時,您可以輕鬆自訂不同的組件,同時重複使用您的大部分先前程式碼。隨著您的需求變得更具體,您不會突然跌落複雜性懸崖,也不需要切換到不同的工具組。

我們已將此原則應用於所有後端。例如,您可以自訂訓練迴圈中的行為,同時仍然能利用 fit() 的強大功能,而無需從頭開始編寫自己的訓練迴圈 — 僅需覆寫 train_step 方法即可。

以下是在 PyTorch 和 TensorFlow 中的運作方式

這裡有 JAX 版本的連結


用於層、模型、指標和最佳化器的新無狀態 API。

您喜歡函數式程式設計嗎?您將會感到非常驚喜。

Keras 中所有具狀態的物件(即擁有在訓練或評估期間會更新的數值變數的物件)現在都有一個無狀態 API,使其可以在 JAX 函數中使用(這些函數必須是完全無狀態的)。

  • 所有層和模型都有一個 stateless_call() 方法,此方法會鏡像 __call__()
  • 所有最佳化器都有一個 stateless_apply() 方法,此方法會鏡像 apply()
  • 所有指標都有一個 stateless_update_state() 方法,此方法會鏡像 update_state(),以及一個 stateless_result() 方法,此方法會鏡像 result()

這些方法完全沒有副作用:它們將目標物件的狀態變數的目前值作為輸入,並將更新的值作為其輸出的一部分傳回,例如:

outputs, updated_non_trainable_variables = layer.stateless_call(
    trainable_variables,
    non_trainable_variables,
    inputs,
)

您永遠不必自己實作這些方法 — 只要您實作了具狀態的版本(例如 call()update_state()),它們就會自動可用。


使用 OpenVINO 後端執行推論。

從 3.8 版本開始,Keras 引入了 OpenVINO 後端,這是一個僅限於推論的後端,表示它僅設計用於使用 predict() 方法執行模型預測。此後端可以直接在 Keras 工作流程中利用 OpenVINO 的效能優化,從而在 OpenVINO 支援的硬體上實現更快的推論。

若要切換到 OpenVINO 後端,請將 KERAS_BACKEND 環境變數設定為 "openvino",或在 ~/.keras/keras.json 的本機設定檔中指定後端。以下是示範如何使用 OpenVINO 後端推論模型(使用 PyTorch、JAX 或 TensorFlow 後端訓練)的範例

import os
os.environ["KERAS_BACKEND"] = "openvino"
import keras

loaded_model = keras.saving.load_model(...)
predictions = loaded_model.predict(...)

請注意,OpenVINO 後端目前可能缺少對某些操作的支援。這將在未來的 Keras 版本中解決,因為操作範圍正在擴大。


從 Keras 2 遷移到 Keras 3

Keras 3 與 Keras 2 高度向後相容:它實作了 Keras 2 的完整公開 API 介面,只有少數例外情況,列在這裡。大多數使用者無需進行任何程式碼變更即可開始在 Keras 3 上執行其 Keras 腳本。

較大的程式碼庫可能需要進行一些程式碼變更,因為它們更有可能遇到上述列出的例外情況之一,並且更有可能使用了私有 API 或已棄用的 API(tf.compat.v1.keras 命名空間、experimental 命名空間、keras.src 私有命名空間)。為了協助您遷移到 Keras 3,我們發布了完整的遷移指南,其中提供了您可能遇到的所有問題的快速修復方法。

您也可以選擇忽略 Keras 3 中的變更,並繼續將 Keras 2 與 TensorFlow 一起使用 — 這對於未積極開發但需要使用更新的相依性來繼續執行的專案來說可能是一個不錯的選擇。您有兩種可能性

  1. 如果您是將 keras 作為獨立套件存取,只需切換為改用 Python 套件 tf_keras 即可,您可以透過 pip install tf_keras 安裝。程式碼和 API 完全沒有變更 — 這是具有不同套件名稱的 Keras 2.15。我們將繼續修復 tf_keras 中的錯誤,並且我們將定期發布新版本。但是,由於該套件現在處於維護模式,因此不會新增任何新功能或效能改進。
  2. 如果您是透過 tf.keras 存取 keras,則在 TensorFlow 2.16 之前不會有任何立即的變更。TensorFlow 2.16+ 預設會使用 Keras 3。在 TensorFlow 2.16+ 中,若要繼續使用 Keras 2,您可以先安裝 tf_keras,然後匯出環境變數 TF_USE_LEGACY_KERAS=1。這將指示 TensorFlow 2.16+ 將 tf.keras 解析為本機安裝的 tf_keras 套件。請注意,這可能會影響的不僅僅是您自己的程式碼:它會影響 Python 程序中匯入 tf.keras 的任何套件。為了確保您的變更僅影響您自己的程式碼,您應該使用 tf_keras 套件。

祝您使用愉快!

我們很高興您能試用新的 Keras,並透過利用多框架 ML 來改進您的工作流程。請讓我們知道您的使用情況:問題、摩擦點、功能請求或成功案例 — 我們很樂意收到您的回饋!


常見問題

問:Keras 3 與舊版 Keras 2 相容嗎?

使用 tf.keras 開發的程式碼通常可以按原樣使用 Keras 3(使用 TensorFlow 後端)執行。有一些您應該注意的相容性問題,這些問題都在此遷移指南中解決。

當涉及到並排使用 tf.keras 和 Keras 3 的 API 時,這是不可能的 — 它們是不同的套件,在完全不同的引擎上執行。

問:在舊版 Keras 2 中開發的預訓練模型是否適用於 Keras 3?

一般來說,是的。任何 tf.keras 模型都應該可以開箱即用地與使用 TensorFlow 後端的 Keras 3 搭配使用(請確保將其儲存為 .keras v3 格式)。此外,如果模型僅使用內建的 Keras 層,則它也可以開箱即用地與使用 JAX 和 PyTorch 後端的 Keras 3 搭配使用。

如果模型包含使用 TensorFlow API 撰寫的自訂層,則通常很容易將程式碼轉換為與後端無關。例如,我們僅花了幾個小時就將 Keras Applications 中的所有 40 個舊版 tf.keras 模型轉換為與後端無關。

問:我可以將 Keras 3 模型儲存在一個後端中,然後在另一個後端中重新載入嗎?

是的,您可以。在儲存的 .keras 檔案中沒有任何後端專用性。您儲存的 Keras 模型與框架無關,並且可以使用任何後端重新載入。

但是,請注意,使用不同的後端重新載入包含自訂元件的模型需要您的自訂元件使用與後端無關的 API 實作,例如 keras.ops

問:我可以在 tf.data 管線內使用 Keras 3 元件嗎?

使用 TensorFlow 後端,Keras 3 完全與 tf.data 相容(例如,您可以將 Sequential 模型 .map()tf.data 管線中)。

使用不同的後端,Keras 3 對 tf.data 的支援有限。您將無法將任意層或模型 .map()tf.data 管線中。但是,您可以使用特定 Keras 3 預處理層與 tf.data,例如 IntegerLookupCategoryEncoding

當涉及到使用 tf.data 管線(未使用 Keras)來饋送您對 .fit().evaluate().predict() 的呼叫時 — 這在所有後端中都可以開箱即用地使用。

問:使用不同後端執行時,Keras 3 模型的行為是否相同?

是的,數值在所有後端中都是相同的。但是,請記住以下注意事項

  • RNG 行為在不同後端之間是不同的(即使在設定種子之後 — 您的結果在每個後端中都將是確定性的,但在後端之間會有所不同)。因此,隨機權重初始化值和 dropout 值在後端之間會有所不同。
  • 由於浮點實作的性質,在 float32 中,每次函數執行結果僅在 1e-7 的精確度內相同。因此,當長時間訓練模型時,小的數值差異會累積,並可能最終導致顯著的數值差異。
  • 由於 PyTorch 中缺少對具有不對稱填充的平均池化支援,因此 padding="same" 的平均池化層可能會導致邊界行/列上的數值不同。這種情況在實踐中並不常發生 — 在 40 個 Keras Applications 視覺模型中,只有一個受到影響。

問:Keras 3 是否支援分散式訓練?

JAX、TensorFlow 和 PyTorch 均開箱即用地支援資料並行分佈。使用 keras.distribution API,JAX 開箱即用地支援模型並行分佈。

使用 TensorFlow

Keras 3 與 tf.distribute 相容 — 只需開啟 Distribution Strategy 範圍並在其中建立/訓練您的模型即可。這裡有一個範例

使用 PyTorch

Keras 3 與 PyTorch 的 DistributedDataParallel 工具相容。這裡有一個範例

使用 JAX

您可以使用 keras.distribution API 在 JAX 中執行資料並行和模型並行分佈。例如,若要執行資料並行分佈,您只需要以下程式碼片段

distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)

若要瞭解模型並行分佈,請參閱以下指南

您也可以透過 JAX API(例如 jax.sharding)自行分配訓練。這裡有一個範例

問:我的自訂 Keras 層是否可以在原生 PyTorch Modules 中使用或與 Flax Modules 一起使用?

如果它們僅使用 Keras API(例如 keras.ops 命名空間)撰寫,那麼是的,您的 Keras 層將開箱即用地與原生 PyTorch 和 JAX 程式碼搭配使用。在 PyTorch 中,只需像其他 PyTorch Module 一樣使用您的 Keras 層即可。在 JAX 中,請確保使用無狀態層 API,即 layer.stateless_call()

問:您未來會新增更多後端嗎?框架 XYZ 呢?

我們樂於新增新的後端,只要目標框架具有龐大的使用者群或在技術上具有一些獨特的優勢。但是,新增和維護新的後端是一項繁重的負擔,因此我們將仔細考慮每個新的後端候選者,並且我們不太可能新增許多新的後端。我們不會新增任何尚未建立完善的框架。我們現在可能會考慮新增一個使用 Mojo 撰寫的後端。如果您覺得這可能對您有用,請告知 Mojo 團隊。