作者: Divyashree Sreepathihalli
建立日期 2023/10/23
最後修改日期 2023/10/30
描述: 將 Keras 2 程式碼遷移至多後端 Keras 3 的說明與疑難排解。
本指南將協助您將僅限 TensorFlow 的 Keras 2 程式碼遷移至多後端 Keras 3 程式碼。遷移的額外負擔極小。遷移後,您可以在 JAX、TensorFlow 或 PyTorch 之上執行 Keras 工作流程。
本指南分為兩個部分
讓我們開始吧。
首先,讓我們安裝 keras-nightly
。
此範例使用 TensorFlow 後端 (os.environ["KERAS_BACKEND"] = "tensorflow"
)。遷移程式碼後,您可以將 "tensorflow"
字串變更為 "jax"
或 "torch"
,然後在 Colab 中按一下「重新啟動執行階段」,程式碼將會在 JAX 或 PyTorch 後端上執行。
!pip install -q keras-nightly
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import tensorflow as tf
import numpy as np
[[34;49mnotice[1;39;49m][39;49m A new release of pip is available: [31;49m23.3.1[39;49m -> [32;49m24.0
[[34;49mnotice[1;39;49m][39;49m To update, run: [32;49mpip install --upgrade pip
首先,取代您的匯入
from tensorflow import keras
取代為 import keras
from tensorflow.keras import xyz
(例如,from tensorflow.keras import layers
) 取代為 from keras import xyz
(例如,from keras import layers
)tf.keras.*
取代為 keras.*
接下來,開始執行您的測試。大多數情況下,您的程式碼會在 Keras 3 上正常執行。您可能遇到的所有問題均詳述如下,並附有修正方式。
jit_compile
預設設為 True
。在 Keras 3 中,GPU 上 Model
建構函式的 jit_compile
引數預設值已設為 True
。這表示模型預設會在 GPU 上使用即時 (JIT) 編譯進行編譯。
JIT 編譯可以改善某些模型的效能。然而,它可能不適用於所有 TensorFlow 運算。如果您使用自訂模型或層,並且看到與 XLA 相關的錯誤,您可能需要將 jit_compile
引數設為 False
。以下是使用 XLA 與 TensorFlow 時遇到的已知問題清單。除了這些問題之外,還有一些 XLA 不支援的運算。
您可能會遇到的錯誤訊息如下
Detected unsupported operations when trying to compile graph
__inference_one_step_on_data_125[] on XLA_GPU_JIT
例如,以下程式碼片段會重現上述錯誤
class MyModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def call(self, inputs):
string_input = tf.strings.as_string(inputs)
return tf.strings.to_number(string_input)
subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
subclass_model.compile(optimizer="sgd", loss="mse")
subclass_model.predict(x_train)
如何修正:在 model.compile(..., jit_compile=False)
中設定 jit_compile=False
,或將 jit_compile
屬性設為 False
,如下所示
class MyModel(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def call(self, inputs):
# tf.strings ops aren't support by XLA
string_input = tf.strings.as_string(inputs)
return tf.strings.to_number(string_input)
subclass_model = MyModel()
x_train = np.array([[1, 2, 3], [4, 5, 6]])
subclass_model.jit_compile = False
subclass_model.predict(x_train)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step
array([[1., 2., 3.],
[4., 5., 6.]], dtype=float32)
在 Keras 3 中,不再支援透過 model.save()
儲存為 TF SavedModel 格式。
您可能會遇到的錯誤訊息如下
>>> model.save("mymodel")
ValueError: Invalid filepath extension for saving. Please add either a `.keras` extension
for the native Keras format (recommended) or a `.h5` extension. Use
`model.export(filepath)` if you want to export a SavedModel for use with
TFLite/TFServing/etc. Received: filepath=saved_model.
以下程式碼片段會重現上述錯誤
sequential_model = keras.Sequential([
keras.layers.Dense(2)
])
sequential_model.save("saved_model")
如何修正:使用 model.export(filepath)
而非 model.save(filepath)
sequential_model = keras.Sequential([keras.layers.Dense(2)])
sequential_model(np.random.rand(3, 5))
sequential_model.export("saved_model")
INFO:tensorflow:Assets written to: saved_model/assets
INFO:tensorflow:Assets written to: saved_model/assets
Saved artifact at 'saved_model'. The following endpoints are available:
* Endpoint 'serve'
args_0 (POSITIONAL_ONLY): TensorSpec(shape=(3, 5), dtype=tf.float32, name='keras_tensor')
Output Type:
TensorSpec(shape=(3, 2), dtype=tf.float32, name=None)
Captures:
14428321600: TensorSpec(shape=(), dtype=tf.resource, name=None)
14439128528: TensorSpec(shape=(), dtype=tf.resource, name=None)
不再支援透過 keras.models.load_model()
載入 TF SavedModel 檔案。如果您嘗試對 TF SavedModel 使用 keras.models.load_model()
,您會收到以下錯誤
ValueError: File format not supported: filepath=saved_model. Keras 3 only supports V3
`.keras` files and legacy H5 format files (`.h5` extension). Note that the legacy
SavedModel format is not supported by `load_model()` in Keras 3. In order to reload a
TensorFlow SavedModel as an inference-only layer in Keras 3, use
`keras.layers.TFSMLayer(saved_model, call_endpoint='serving_default')` (note that your
`call_endpoint` might have a different name).
以下程式碼片段會重現上述錯誤
keras.models.load_model("saved_model")
如何修正:使用 keras.layers.TFSMLayer(filepath, call_endpoint="serving_default")
以將 TF SavedModel 重新載入為 Keras 層。這不僅限於源自 Keras 的 SavedModel,它也適用於任何 SavedModel,例如 TF-Hub 模型。
keras.layers.TFSMLayer("saved_model", call_endpoint="serving_default")
<TFSMLayer name=tfsm_layer, built=True>
Model()
不再能傳遞深度巢狀的輸入/輸出 (巢狀深度超過 1 層,例如張量的列表的列表)。
您會遇到如下的錯誤
ValueError: When providing `inputs` as a dict, all values in the dict must be
KerasTensors. Received: inputs={'foo': <KerasTensor shape=(None, 1), dtype=float32,
sparse=None, name=foo>, 'bar': {'baz': <KerasTensor shape=(None, 1), dtype=float32,
sparse=None, name=bar>}} including invalid value {'baz': <KerasTensor shape=(None, 1),
dtype=float32, sparse=None, name=bar>} of type <class 'dict'>
以下程式碼片段會重現上述錯誤
inputs = {
"foo": keras.Input(shape=(1,), name="foo"),
"bar": {
"baz": keras.Input(shape=(1,), name="bar"),
},
}
outputs = inputs["foo"] + inputs["bar"]["baz"]
keras.Model(inputs, outputs)
如何修正:將巢狀輸入取代為輸入張量的字典、列表和元組。
inputs = {
"foo": keras.Input(shape=(1,), name="foo"),
"bar": keras.Input(shape=(1,), name="bar"),
}
outputs = inputs["foo"] + inputs["bar"]
keras.Model(inputs, outputs)
<Functional name=functional_2, built=True>
在 Keras 2 中,預設會在自訂層的 call()
方法上啟用 TF 自動圖形。在 Keras 3 中,則不是這樣。這表示如果您使用控制流程,您可能必須使用條件運算,或者您可以選擇使用 @tf.function
修飾您的 call()
方法。
您會遇到如下的錯誤
OperatorNotAllowedInGraphError: Exception encountered when calling MyCustomLayer.call().
Using a symbolic [`tf.Tensor`](https://tensorflow.dev.org.tw/api_docs/python/tf/Tensor) as a Python `bool` is not allowed. You can attempt the
following resolutions to the problem: If you are running in Graph mode, use Eager
execution mode or decorate this function with @tf.function. If you are using AutoGraph,
you can try decorating this function with @tf.function. If that does not work, then you
may be using an unsupported feature or your source code may not be visible to AutoGraph.
Here is a [link for more information](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/ref
erence/limitations.md#access-to-source-code).
以下程式碼片段會重現上述錯誤
class MyCustomLayer(keras.layers.Layer):
def call(self, inputs):
if tf.random.uniform(()) > 0.5:
return inputs * 2
else:
return inputs / 2
layer = MyCustomLayer()
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
如何修正:使用 @tf.function
修飾您的 call()
方法
class MyCustomLayer(keras.layers.Layer):
@tf.function()
def call(self, inputs):
if tf.random.uniform(()) > 0.5:
return inputs * 2
else:
return inputs / 2
layer = MyCustomLayer()
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 43ms/step
array([[0.59727275, 1.9986179 , 1.5514829 ],
[0.56239295, 1.6529864 , 0.33085832],
[0.67086476, 1.5208522 , 1.99276 ]], dtype=float32)
KerasTensor
呼叫 TF 運算在函數式模型建構期間,不允許在 Keras 張量上使用 TF 運算:「KerasTensor 無法用作 TensorFlow 函數的輸入」。
您會遇到的錯誤如下
ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor
is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional
models or Keras Functions. You can only use it as input to a Keras layer or a Keras
operation (from the namespaces `keras.layers` and `keras.operations`).
以下程式碼片段會重現此錯誤
input = keras.layers.Input([2, 2, 1])
tf.squeeze(input)
如何修正:使用 keras.ops
中的對等運算。
input = keras.layers.Input([2, 2, 1])
keras.ops.squeeze(input)
<KerasTensor shape=(None, 2, 2), dtype=float32, sparse=None, name=keras_tensor_6>
evaluate()
多輸出模型的 evaluate()
方法不再單獨傳回個別輸出損失。相反地,您應該利用 compile()
方法中的 metrics
引數來追蹤這些損失。
在處理多個具名輸出 (例如 output_a 和 output_b) 時,舊有的 tf.keras
會包含
以下程式碼片段會重現上述行為
from keras import layers
# A functional model with multiple outputs
inputs = layers.Input(shape=(10,))
x1 = layers.Dense(5, activation='relu')(inputs)
x2 = layers.Dense(5, activation='relu')(x1)
output_1 = layers.Dense(5, activation='softmax', name="output_1")(x1)
output_2 = layers.Dense(5, activation='softmax', name="output_2")(x2)
model = keras.Model(inputs=inputs, outputs=[output_1, output_2])
model.compile(optimizer='adam', loss='categorical_crossentropy')
# dummy data
x_test = np.random.uniform(size=[10, 10])
y_test = np.random.uniform(size=[10, 5])
model.evaluate(x_test, y_test)
from keras import layers
# A functional model with multiple outputs
inputs = layers.Input(shape=(10,))
x1 = layers.Dense(5, activation="relu")(inputs)
x2 = layers.Dense(5, activation="relu")(x1)
output_1 = layers.Dense(5, activation="softmax", name="output_1")(x1)
output_2 = layers.Dense(5, activation="softmax", name="output_2")(x2)
# dummy data
x_test = np.random.uniform(size=[10, 10])
y_test = np.random.uniform(size=[10, 5])
multi_output_model = keras.Model(inputs=inputs, outputs=[output_1, output_2])
multi_output_model.compile(
optimizer="adam",
loss="categorical_crossentropy",
metrics=["categorical_crossentropy", "categorical_crossentropy"],
)
multi_output_model.evaluate(x_test, y_test)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 112ms/step - loss: 4.0217 - output_1_categorical_crossentropy: 4.0217
[4.021683692932129, 4.021683692932129]
與 Keras 2 不同,將 tf.Variable
設定為 Keras 3 層或模型的屬性時,不會自動追蹤該變數。以下程式碼片段會顯示 tf.Variables
未被追蹤。
class MyCustomLayer(keras.layers.Layer):
def __init__(self, units):
super().__init__()
self.units = units
def build(self, input_shape):
input_dim = input_shape[-1]
self.w = tf.Variable(initial_value=tf.zeros([input_dim, self.units]))
self.b = tf.Variable(initial_value=tf.zeros([self.units,]))
def call(self, inputs):
return keras.ops.matmul(inputs, self.w) + self.b
layer = MyCustomLayer(3)
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
# The model does not have any trainable variables
for layer in model.layers:
print(layer.trainable_variables)
您會看到以下警告
UserWarning: The model does not have any trainable weights.
warnings.warn("The model does not have any trainable weights.")
如何修正:使用 self.add_weight()
方法或選擇使用 keras.Variable
。如果您目前使用 tf.variable
,您可以切換至 keras.Variable
。
class MyCustomLayer(keras.layers.Layer):
def __init__(self, units):
super().__init__()
self.units = units
def build(self, input_shape):
input_dim = input_shape[-1]
self.w = self.add_weight(
shape=[input_dim, self.units],
initializer="zeros",
)
self.b = self.add_weight(
shape=[
self.units,
],
initializer="zeros",
)
def call(self, inputs):
return keras.ops.matmul(inputs, self.w) + self.b
layer = MyCustomLayer(3)
data = np.random.uniform(size=[3, 3])
model = keras.models.Sequential([layer])
model.compile(optimizer="adam", loss="mse")
model.predict(data)
# Verify that the variables are now being tracked
for layer in model.layers:
print(layer.trainable_variables)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 33ms/step
[<KerasVariable shape=(3, 3), dtype=float32, path=sequential_2/my_custom_layer_1/variable>, <KerasVariable shape=(3,), dtype=float32, path=sequential_2/my_custom_layer_1/variable_1>]
call()
引數中的 None
項目在 Layer.call()
中,不允許將 None
項目作為巢狀 (例如,列表/元組) 張量引數的一部分,也不允許將其作為 call()
巢狀傳回值的一部分。
如果引數中的 None
是刻意的,且有特定目的,請確保引數是選擇性的,並將其構建為個別參數。例如,考慮使用選擇性引數定義 call
方法。
以下程式碼片段會重現此錯誤。
class CustomLayer(keras.layers.Layer):
def __init__(self):
super().__init__()
def call(self, inputs):
foo = inputs["foo"]
baz = inputs["bar"]["baz"]
if baz is not None:
return foo + baz
return foo
layer = CustomLayer()
inputs = {
"foo": keras.Input(shape=(1,), name="foo"),
"bar": {
"baz": None,
},
}
layer(inputs)
如何修正
解決方案 1:將 None
取代為值,如下所示
class CustomLayer(keras.layers.Layer):
def __init__(self):
super().__init__()
def call(self, inputs):
foo = inputs["foo"]
baz = inputs["bar"]["baz"]
return foo + baz
layer = CustomLayer()
inputs = {
"foo": keras.Input(shape=(1,), name="foo"),
"bar": {
"baz": keras.Input(shape=(1,), name="bar"),
},
}
layer(inputs)
<KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_14>
解決方案 2:使用選擇性引數定義呼叫方法。以下是此修正的範例
class CustomLayer(keras.layers.Layer):
def __init__(self):
super().__init__()
def call(self, foo, baz=None):
if baz is not None:
return foo + baz
return foo
layer = CustomLayer()
foo = keras.Input(shape=(1,), name="foo")
baz = None
layer(foo, baz=baz)
<KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_15>
關於何時可以建立狀態 (例如數值權重變數),Keras 3 比 Keras 2 嚴格得多。Keras 3 希望所有狀態在模型可以訓練之前建立。這是使用 JAX 的必要條件 (而 TensorFlow 在狀態建立時機方面非常寬鬆)。
Keras 層應在其建構函式 (__init__()
方法) 或其 build()
方法中建立其狀態。它們應避免在 call()
中建立狀態。
如果您忽略此建議,且仍然在 call()
中建立狀態 (例如,透過呼叫先前未建置的層),那麼 Keras 會嘗試在訓練之前透過對符號輸入呼叫 call()
方法來自動建置該層。然而,此自動狀態建立嘗試在某些情況下可能會失敗。這會導致類似如下的錯誤
Layer 'frame_position_embedding' looks like it has unbuilt state,
but Keras is not able to trace the layer `call()` in order to build it automatically.
Possible causes:
1. The `call()` method of your layer may be crashing.
Try to `__call__()` the layer eagerly on some test input first to see if it works.
E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement
the `def build(self, input_shape)` method on your layer.
It should create all variables used by the layer
(e.g. by calling `layer.build()` on all its children layers).
當與 JAX 後端搭配使用時,您可以使用以下層重現此錯誤
class PositionalEmbedding(keras.layers.Layer):
def __init__(self, sequence_length, output_dim, **kwargs):
super().__init__(**kwargs)
self.position_embeddings = layers.Embedding(
input_dim=sequence_length, output_dim=output_dim
)
self.sequence_length = sequence_length
self.output_dim = output_dim
def call(self, inputs):
inputs = keras.ops.cast(inputs, self.compute_dtype)
length = keras.ops.shape(inputs)[1]
positions = keras.ops.arange(start=0, stop=length, step=1)
embedded_positions = self.position_embeddings(positions)
return inputs + embedded_positions
如何修正:完全按照錯誤訊息的要求執行。首先,嘗試急切地執行該層,以查看 call()
方法是否確實正確 (注意:如果它在 Keras 2 中運作,那麼它是正確的,並且不需要變更)。如果確實正確,那麼您應該實作一個 build(self, input_shape)
方法,該方法會建立層的所有狀態,包括子層的狀態。以下是針對上述層套用的修正 (請注意 build()
方法)
class PositionalEmbedding(keras.layers.Layer):
def __init__(self, sequence_length, output_dim, **kwargs):
super().__init__(**kwargs)
self.position_embeddings = layers.Embedding(
input_dim=sequence_length, output_dim=output_dim
)
self.sequence_length = sequence_length
self.output_dim = output_dim
def build(self, input_shape):
self.position_embeddings.build(input_shape)
def call(self, inputs):
inputs = keras.ops.cast(inputs, self.compute_dtype)
length = keras.ops.shape(inputs)[1]
positions = keras.ops.arange(start=0, stop=length, step=1)
embedded_positions = self.position_embeddings(positions)
return inputs + embedded_positions
為了清理,Keras 3 中移除了一些使用率非常低的舊有功能
keras.layers.ThresholdedReLU
。相反地,您可以直接使用帶有引數 threshold
的 ReLU
層。Layer.add_loss()
:已移除符號 add_loss()
(您仍然可以在層/模型的 call()
方法內使用 add_loss()
)。LocallyConnected1D
、LocallyConnected2D
)。若要使用局部連線層,請將層實作複製到您自己的程式碼庫中。keras.layers.experimental.RandomFourierFeatures
。若要使用它,請將層實作複製到您自己的程式碼庫中。metrics
、dynamic
。metrics
在 Model
類別上仍然可用。constants
和 time_major
引數。constants
引數是 Theano 的殘留物,使用率非常低。time_major
引數的使用率也非常低。reset_metrics
引數:已從 model.*_on_batch()
方法中移除 reset_metrics
引數。此引數的使用率非常低。keras.constraints.RadialConstraint
物件。此物件的使用率非常低。使用 TensorFlow 後端的 Keras 3 程式碼將會與原生 TensorFlow API 協同運作。然而,如果您希望您的程式碼與後端無關,則需要:
tf.*
API 呼叫替換為等效的 Keras API。train_step
/test_step
方法轉換為多框架實作。keras.random
運算正確。讓我們詳細探討每個要點。
在許多情況下,這是您唯一需要做的,即可開始使用 JAX 和 PyTorch 執行自訂層和指標:將任何 tf.*
、 tf.math*
、tf.linalg.*
等替換為 keras.ops.*
。大多數 TF 運算應與 Keras 3 一致。如果名稱不同,將在本指南中突出顯示。
Keras 將 NumPy API 實作為 keras.ops
的一部分。
下表僅列出 TensorFlow 和 Keras 運算的一小部分;未列出的運算通常在兩個框架中都具有相同的名稱(例如,reshape
、matmul
、cast
等)。
train_step()
方法您的模型可能包含自訂的 train_step()
或 test_step()
方法,這些方法依賴僅限 TensorFlow 的 API – 例如,您的 train_step()
方法可能會利用 TensorFlow 的 tf.GradientTape
。若要轉換此類模型以在 JAX 或 PyTorch 上執行,您必須為您要支援的每個後端編寫不同的 train_step()
實作。
在某些情況下,您可能可以簡單地覆寫 Model.compute_loss()
方法,使其完全與後端無關,而不是覆寫 train_step()
。以下是一個具有自訂 compute_loss()
方法的層範例,該方法可在 JAX、TensorFlow 和 PyTorch 中運作
class MyModel(keras.Model):
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
loss = keras.ops.sum(keras.losses.mean_squared_error(y, y_pred, sample_weight))
return loss
如果您需要修改優化機制本身(超出損失計算範圍),則需要覆寫 train_step()
,並為每個後端實作一個 train_step
方法,如下所示。
請參閱以下指南,了解如何處理每個後端的詳細資訊
class MyModel(keras.Model):
def train_step(self, *args, **kwargs):
if keras.backend.backend() == "jax":
return self._jax_train_step(*args, **kwargs)
elif keras.backend.backend() == "tensorflow":
return self._tensorflow_train_step(*args, **kwargs)
elif keras.backend.backend() == "torch":
return self._torch_train_step(*args, **kwargs)
def _jax_train_step(self, state, data):
pass # See guide: keras.io/guides/custom_train_step_in_jax/
def _tensorflow_train_step(self, data):
pass # See guide: keras.io/guides/custom_train_step_in_tensorflow/
def _torch_train_step(self, data):
pass # See guide: keras.io/guides/custom_train_step_in_torch/
Keras 3 有一個新的 keras.random
命名空間,其中包含
這些運算是無狀態的,這表示如果您傳遞 seed
引數,它們每次都會傳回相同的結果。像這樣
print(keras.random.normal(shape=(), seed=123))
print(keras.random.normal(shape=(), seed=123))
tf.Tensor(0.7832616, shape=(), dtype=float32)
tf.Tensor(0.7832616, shape=(), dtype=float32)
至關重要的是,這與有狀態的 tf.random
運算的行為不同
print(tf.random.normal(shape=(), seed=123))
print(tf.random.normal(shape=(), seed=123))
tf.Tensor(2.4435377, shape=(), dtype=float32)
tf.Tensor(-0.6386405, shape=(), dtype=float32)
當您編寫一個使用 RNG 的層(例如,自訂的 dropout 層)時,您會希望在層呼叫時使用不同的 seed 值。但是,您不能僅遞增 Python 整數並傳遞它,因為雖然這在急切執行時可以正常運作,但在使用編譯(JAX、TensorFlow 和 PyTorch 都有提供)時將無法按預期運作。在編譯層時,該層看到的第一個 Python 整數 seed 值將會硬編碼到編譯圖中。
若要解決此問題,您應該將狀態化的 keras.random.SeedGenerator
物件的執行個體作為 seed
引數傳遞,如下所示
seed_generator = keras.random.SeedGenerator(1337)
print(keras.random.normal(shape=(), seed=seed_generator))
print(keras.random.normal(shape=(), seed=seed_generator))
tf.Tensor(0.6077996, shape=(), dtype=float32)
tf.Tensor(0.8211102, shape=(), dtype=float32)
因此,當編寫使用 RNG 的層時,您會使用以下模式
class RandomNoiseLayer(keras.layers.Layer):
def __init__(self, noise_rate, **kwargs):
super().__init__(**kwargs)
self.noise_rate = noise_rate
self.seed_generator = keras.random.SeedGenerator(1337)
def call(self, inputs):
noise = keras.random.uniform(
minval=0, maxval=self.noise_rate, seed=self.seed_generator
)
return inputs + noise
這樣的層可以在任何設定中使用 - 在急切執行或編譯模型中。每次層呼叫都會使用不同的 seed 值,如預期的一樣。