export
方法Model.export(
filepath, format="tf_saved_model", verbose=None, input_signature=None, **kwargs
)
將模型匯出為推論的成品。
引數
str
或 pathlib.Path
物件。儲存成品的路徑。str
。匯出格式。支援的值:"tf_saved_model"
和 "onnx"
。預設為 "tf_saved_model"
。bool
。是否在匯出期間列印訊息。預設為 None
,這會使用不同後端和格式設定的預設值。keras.InputSpec
、tf.TensorSpec
、backend.KerasTensor
或後端張量的結構。若未提供,將會自動計算。預設為 None
。format="tf_saved_model"
: - is_static
:選用 bool
。指出 fn
是否為靜態。如果 fn
涉及狀態更新(例如,RNG 種子和計數器),則設定為 False
。 - jax2tf_kwargs
:選用 dict
。jax2tf.convert
的引數。請參閱 jax2tf.convert
的文件。如果未提供 native_serialization
和 polymorphic_shapes
,則會自動計算。注意: 此功能目前僅支援 TensorFlow、JAX 和 Torch 後端。
注意: 請注意,當使用 format="onnx"
、verbose=True
和 Torch 後端時,匯出的成品可能包含來自本機檔案系統的資訊。
範例
以下是如何匯出 TensorFlow SavedModel 以進行推論。
# Export the model as a TensorFlow SavedModel artifact
model.export("path/to/location", format="tf_saved_model")
# Load the artifact in a different process/environment
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)
以下是如何匯出 ONNX 以進行推論。
# Export the model as a ONNX artifact
model.export("path/to/location", format="onnx")
# Load the artifact in a different process/environment
ort_session = onnxruntime.InferenceSession("path/to/location")
ort_inputs = {
k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
}
predictions = ort_session.run(None, ort_inputs)
ExportArchive
類別keras.export.ExportArchive()
ExportArchive 用於寫入 SavedModel 成品(例如,用於推論)。
如果您有想要匯出為 SavedModel 以進行服務(例如,透過 TensorFlow-Serving)的 Keras 模型或層,您可以使用 ExportArchive
來設定您需要提供的不同服務端點,以及它們的簽章。只需實例化 ExportArchive
,使用 track()
註冊要使用的層或模型,然後使用 add_endpoint()
方法註冊新的服務端點。完成後,使用 write_out()
方法儲存成品。
產生的成品是 SavedModel,可以透過 tf.saved_model.load
重新載入。
範例
以下是如何匯出模型以進行推論。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")
# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
以下是如何匯出具有一個用於推論的端點和一個用於訓練模式正向傳遞的端點(例如,啟用 dropout)的模型。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")
關於資源追蹤的注意事項
ExportArchive
能夠自動追蹤其端點使用的所有 keras.Variables
,因此大多數時候並非嚴格要求呼叫 .track(model)
。但是,如果您的模型使用查找層(例如 IntegerLookup
、StringLookup
或 TextVectorization
),則需要透過 .track(model)
明確追蹤。
如果您需要能夠存取已復原封存上的屬性 variables
、trainable_variables
或 non_trainable_variables
,也需要明確追蹤。
add_endpoint
方法ExportArchive.add_endpoint(name, fn, input_signature=None, **kwargs)
註冊新的服務端點。
引數
str
。端點的名稱。ExportArchive
追蹤的模型/層上可用的資源(例如,keras.Variable
物件或 tf.lookup.StaticHashTable
物件)(您可以呼叫 .track(model)
來追蹤新模型)。函數輸入的形狀和 dtype 必須已知。為此,您可以 1) 確保 fn
是至少呼叫過一次的 tf.function
,或 2) 提供 input_signature
引數,以指定輸入的形狀和 dtype(請參閱下方的範例,其中顯示具有 2 個輸入引數的 Functional
模型)。fn
的形狀和 dtype。可以是 keras.InputSpec
、tf.TensorSpec
、backend.KerasTensor
或後端張量的結構(請參閱下方範例,其中顯示具有 2 個輸入引數的 Functional
模型)。若未提供,fn
必須是至少呼叫過一次的 tf.function
。預設為 None
。is_static
:選用 bool
。指出 fn
是否為靜態。如果 fn
涉及狀態更新(例如,RNG 種子),則設定為 False
。 - jax2tf_kwargs
:選用 dict
。jax2tf.convert
的引數。請參閱 jax2tf.convert
。如果未提供 native_serialization
和 polymorphic_shapes
,則會自動計算。傳回
包裝新增至封存的 fn
的 tf.function
。
範例
當模型具有單一輸入引數時,使用 input_signature
引數新增端點
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
當模型具有兩個位置輸入引數時,使用 input_signature
引數新增端點
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
keras.InputSpec(shape=(None, 3), dtype="float32"),
keras.InputSpec(shape=(None, 4), dtype="float32"),
],
)
當模型具有一個輸入引數,該引數是 2 個張量的列表時,使用 input_signature
引數新增端點(例如,具有 2 個輸入的 Functional 模型)
model = keras.Model(inputs=[x1, x2], outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
[
keras.InputSpec(shape=(None, 3), dtype="float32"),
keras.InputSpec(shape=(None, 4), dtype="float32"),
],
],
)
這也適用於字典輸入
model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
{
"x1": keras.InputSpec(shape=(None, 3), dtype="float32"),
"x2": keras.InputSpec(shape=(None, 4), dtype="float32"),
},
],
)
新增身為 tf.function
的端點
@tf.function()
def serving_fn(x):
return model(x)
# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)
將模型與一些 TensorFlow 預處理結合,這可以使用 TensorFlow 資源
lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)
export_archive = ExportArchive()
model_fn = export_archive.track_and_add_endpoint(
"model_fn",
model,
input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],
)
export_archive.track(lookup_table)
@tf.function()
def serving_fn(x):
x = lookup_table.lookup(x)
return model_fn(x)
export_archive.add_endpoint(name="serve", fn=serving_fn)
add_variable_collection
方法ExportArchive.add_variable_collection(name, variables)
註冊一組變數,以便在重新載入後擷取。
引數
keras.Variable
實例的元組/列表/集合。範例
export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
# Save a variable collection
export_archive.add_variable_collection(
name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")
# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables
track
方法ExportArchive.track(resource)
追蹤變數(層或模型的)和其他資產。
預設情況下,當您呼叫 add_endpoint()
時,會自動追蹤端點函數使用的所有變數。但是,非變數資產(例如查找表)需要手動追蹤。請注意,內建 Keras 層(TextVectorization
、IntegerLookup
、StringLookup
)使用的查找表會由 add_endpoint()
自動追蹤。
引數
write_out
方法ExportArchive.write_out(filepath, options=None, verbose=True)
將對應的 SavedModel 寫入磁碟。
引數
str
或 pathlib.Path
物件。儲存成品的路徑。tf.saved_model.SaveOptions
物件,指定 SavedModel 儲存選項。關於 TF-Serving 的注意事項:透過 add_endpoint()
註冊的所有端點都會在 SavedModel 成品中對 TF-Serving 可見。此外,第一個註冊的端點會在別名 "serving_default"
下可見(除非已手動註冊名稱為 "serving_default"
的端點),因為 TF-Serving 要求設定此端點。