Keras 3 API 文件 / 層 API / 後端特定層 / Tensorflow SavedModel 層

Tensorflow SavedModel 層

[原始碼]

TFSMLayer 類別

keras.layers.TFSMLayer(
    filepath,
    call_endpoint="serve",
    call_training_endpoint=None,
    trainable=True,
    name=None,
    dtype=None,
)

重新載入透過 SavedModel / ExportArchive 儲存的 Keras 模型/層。

引數

  • filepathstrpathlib.Path 物件。SavedModel 的路徑。
  • call_endpoint:要用作重新載入層的 call() 方法的端點名稱。如果 SavedModel 是透過 model.export() 建立的,則預設端點名稱為 'serve'。在其他情況下,它可能被命名為 'serving_default'

範例

model.export("path/to/artifact")
reloaded_layer = TFSMLayer("path/to/artifact")
outputs = reloaded_layer(inputs)

重新載入的物件可以像常規 Keras 層一樣使用,並支援其可訓練權重的訓練/微調。請注意,重新載入的物件不保留原始物件的內部結構或自訂方法——它是一個圍繞儲存函數建立的全新層。

限制

  • 僅支援具有單個 inputs 張量引數(可以選擇性地為張量的 dict/tuple/list)的呼叫端點。對於具有多個獨立輸入張量引數的端點,請考慮子類別化 TFSMLayer 並實作具有自訂簽名的 call() 方法。
  • 如果您需要訓練時行為與推論時行為不同(即,如果您需要重新載入的物件在 __call__() 中支援 training=True 引數),請確保訓練時呼叫函數作為獨立端點儲存在工件中,並透過 call_training_endpoint 引數將其名稱提供給 TFSMLayer