Keras 3 API 文件 / 模型 API / 儲存與序列化 / 模型匯出以進行推論

模型匯出以進行推論

[來源]

export 方法

Model.export(
    filepath, format="tf_saved_model", verbose=None, input_signature=None, **kwargs
)

將模型匯出為推論的成品。

引數

  • filepathstrpathlib.Path 物件。儲存成品的路徑。
  • formatstr。匯出格式。支援的值:"tf_saved_model""onnx"。預設為 "tf_saved_model"
  • verbosebool。是否在匯出期間列印訊息。預設為 None,這會使用不同後端和格式設定的預設值。
  • input_signature:選用。指定模型輸入的形狀和 dtype。可以是 keras.InputSpectf.TensorSpecbackend.KerasTensor 或後端張量的結構。若未提供,將會自動計算。預設為 None
  • **kwargs:其他關鍵字引數
    • 特定於 JAX 後端和 format="tf_saved_model": - is_static:選用 bool。指出 fn 是否為靜態。如果 fn 涉及狀態更新(例如,RNG 種子和計數器),則設定為 False。 - jax2tf_kwargs:選用 dictjax2tf.convert 的引數。請參閱 jax2tf.convert 的文件。如果未提供 native_serializationpolymorphic_shapes,則會自動計算。

注意: 此功能目前僅支援 TensorFlow、JAX 和 Torch 後端。

注意: 請注意,當使用 format="onnx"verbose=True 和 Torch 後端時,匯出的成品可能包含來自本機檔案系統的資訊。

範例

以下是如何匯出 TensorFlow SavedModel 以進行推論。

# Export the model as a TensorFlow SavedModel artifact
model.export("path/to/location", format="tf_saved_model")

# Load the artifact in a different process/environment
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)

以下是如何匯出 ONNX 以進行推論。

# Export the model as a ONNX artifact
model.export("path/to/location", format="onnx")

# Load the artifact in a different process/environment
ort_session = onnxruntime.InferenceSession("path/to/location")
ort_inputs = {
    k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
}
predictions = ort_session.run(None, ort_inputs)

[來源]

ExportArchive 類別

keras.export.ExportArchive()

ExportArchive 用於寫入 SavedModel 成品(例如,用於推論)。

如果您有想要匯出為 SavedModel 以進行服務(例如,透過 TensorFlow-Serving)的 Keras 模型或層,您可以使用 ExportArchive 來設定您需要提供的不同服務端點,以及它們的簽章。只需實例化 ExportArchive,使用 track() 註冊要使用的層或模型,然後使用 add_endpoint() 方法註冊新的服務端點。完成後,使用 write_out() 方法儲存成品。

產生的成品是 SavedModel,可以透過 tf.saved_model.load 重新載入。

範例

以下是如何匯出模型以進行推論。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")

# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

以下是如何匯出具有一個用於推論的端點和一個用於訓練模式正向傳遞的端點(例如,啟用 dropout)的模型。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")

關於資源追蹤的注意事項

ExportArchive 能夠自動追蹤其端點使用的所有 keras.Variables,因此大多數時候並非嚴格要求呼叫 .track(model)。但是,如果您的模型使用查找層(例如 IntegerLookupStringLookupTextVectorization),則需要透過 .track(model) 明確追蹤。

如果您需要能夠存取已復原封存上的屬性 variablestrainable_variablesnon_trainable_variables,也需要明確追蹤。


[來源]

add_endpoint 方法

ExportArchive.add_endpoint(name, fn, input_signature=None, **kwargs)

註冊新的服務端點。

引數

  • namestr。端點的名稱。
  • fn:可呼叫物件。它應僅利用 ExportArchive 追蹤的模型/層上可用的資源(例如,keras.Variable 物件或 tf.lookup.StaticHashTable 物件)(您可以呼叫 .track(model) 來追蹤新模型)。函數輸入的形狀和 dtype 必須已知。為此,您可以 1) 確保 fn 是至少呼叫過一次的 tf.function,或 2) 提供 input_signature 引數,以指定輸入的形狀和 dtype(請參閱下方的範例,其中顯示具有 2 個輸入引數的 Functional 模型)。
  • input_signature:選用。指定 fn 的形狀和 dtype。可以是 keras.InputSpectf.TensorSpecbackend.KerasTensor 或後端張量的結構(請參閱下方範例,其中顯示具有 2 個輸入引數的 Functional 模型)。若未提供,fn 必須是至少呼叫過一次的 tf.function。預設為 None
  • **kwargs:其他關鍵字引數
    • 特定於 JAX 後端: - is_static:選用 bool。指出 fn 是否為靜態。如果 fn 涉及狀態更新(例如,RNG 種子),則設定為 False。 - jax2tf_kwargs:選用 dictjax2tf.convert 的引數。請參閱 jax2tf.convert。如果未提供 native_serializationpolymorphic_shapes,則會自動計算。

傳回

包裝新增至封存的 fntf.function

範例

當模型具有單一輸入引數時,使用 input_signature 引數新增端點

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)

當模型具有兩個位置輸入引數時,使用 input_signature 引數新增端點

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        keras.InputSpec(shape=(None, 3), dtype="float32"),
        keras.InputSpec(shape=(None, 4), dtype="float32"),
    ],
)

當模型具有一個輸入引數,該引數是 2 個張量的列表時,使用 input_signature 引數新增端點(例如,具有 2 個輸入的 Functional 模型)

model = keras.Model(inputs=[x1, x2], outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        [
            keras.InputSpec(shape=(None, 3), dtype="float32"),
            keras.InputSpec(shape=(None, 4), dtype="float32"),
        ],
    ],
)

這也適用於字典輸入

model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        {
            "x1": keras.InputSpec(shape=(None, 3), dtype="float32"),
            "x2": keras.InputSpec(shape=(None, 4), dtype="float32"),
        },
    ],
)

新增身為 tf.function 的端點

@tf.function()
def serving_fn(x):
    return model(x)

# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)

將模型與一些 TensorFlow 預處理結合,這可以使用 TensorFlow 資源

lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)

export_archive = ExportArchive()
model_fn = export_archive.track_and_add_endpoint(
    "model_fn",
    model,
    input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],
)
export_archive.track(lookup_table)

@tf.function()
def serving_fn(x):
    x = lookup_table.lookup(x)
    return model_fn(x)

export_archive.add_endpoint(name="serve", fn=serving_fn)

[來源]

add_variable_collection 方法

ExportArchive.add_variable_collection(name, variables)

註冊一組變數,以便在重新載入後擷取。

引數

  • name:集合的字串名稱。
  • variableskeras.Variable 實例的元組/列表/集合。

範例

export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
# Save a variable collection
export_archive.add_variable_collection(
    name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")

# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables

[來源]

track 方法

ExportArchive.track(resource)

追蹤變數(層或模型的)和其他資產。

預設情況下,當您呼叫 add_endpoint() 時,會自動追蹤端點函數使用的所有變數。但是,非變數資產(例如查找表)需要手動追蹤。請注意,內建 Keras 層(TextVectorizationIntegerLookupStringLookup)使用的查找表會由 add_endpoint() 自動追蹤。

引數

  • resource:層、模型或 TensorFlow 可追蹤資源。

[來源]

write_out 方法

ExportArchive.write_out(filepath, options=None, verbose=True)

將對應的 SavedModel 寫入磁碟。

引數

  • filepathstrpathlib.Path 物件。儲存成品的路徑。
  • optionstf.saved_model.SaveOptions 物件,指定 SavedModel 儲存選項。
  • verbose:是否列印匯出的 SavedModel 的所有變數。

關於 TF-Serving 的注意事項:透過 add_endpoint() 註冊的所有端點都會在 SavedModel 成品中對 TF-Serving 可見。此外,第一個註冊的端點會在別名 "serving_default" 下可見(除非已手動註冊名稱為 "serving_default" 的端點),因為 TF-Serving 要求設定此端點。