FlaxLayer
類別keras.layers.FlaxLayer(module, method=None, variables=None, **kwargs)
Keras 層,用於封裝 Flax 模組。
當使用 JAX 作為 Keras 的後端時,此層能夠在 Keras 內使用 flax.linen.Module
實例形式的 Flax 組件。
用於前向傳遞的模組方法可以通過 method
參數指定,預設為 __call__
。此方法必須接受以下具有完全相同名稱的參數
self
如果該方法綁定到模組,預設的 __call__
就是這種情況,否則為 module
以傳遞模組。inputs
:模型的輸入,JAX 陣列或陣列的 PyTree
。training
(可選):指定我們處於訓練模式還是推論模式的參數,訓練模式下傳遞 True
。FlaxLayer
自動處理模型的非可訓練狀態和所需的 RNG。請注意,flax.linen.Module.apply()
的 mutable
參數設定為 DenyList(["params"])
,因此假設 "params" 集合之外的所有變數都是非可訓練權重。
此範例展示瞭如何從具有預設 __call__
方法且沒有訓練參數的 Flax Module
建立 FlaxLayer
class MyFlaxModule(flax.linen.Module):
@flax.linen.compact
def __call__(self, inputs):
x = inputs
x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
x = flax.linen.relu(x)
x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = flax.linen.Dense(features=200)(x)
x = flax.linen.relu(x)
x = flax.linen.Dense(features=10)(x)
x = flax.linen.softmax(x)
return x
flax_module = MyFlaxModule()
keras_layer = FlaxLayer(flax_module)
此範例展示瞭如何封裝模組方法以符合所需的簽名。這允許有多個輸入參數和一個具有不同名稱和值的訓練參數。此外,這還展示瞭如何使用未綁定到模組的函數。
class MyFlaxModule(flax.linen.Module):
@flax.linen.compact
def forward(self, input1, input2, deterministic):
...
return outputs
def my_flax_module_wrapper(module, inputs, training):
input1, input2 = inputs
return module.forward(input1, input2, not training)
flax_module = MyFlaxModule()
keras_layer = FlaxLayer(
module=flax_module,
method=my_flax_module_wrapper,
)
參數
flax.linen.Module
或子類別的實例。Module
中的方法。如果未提供,則使用 __call__
方法。method
也可以是未在 Module
中定義的函數,在這種情況下,它必須將 Module
作為第一個參數。它用於 Module.init
和 Module.apply
。詳細資訊記錄在 flax.linen.Module.apply()
的 method
參數中。dict
,其中包含模組的所有變數,格式與 flax.linen.Module.init()
返回的格式相同。它應包含 "params" 鍵,如果適用,還應包含非可訓練狀態的變數集合的其他鍵。這允許傳遞已訓練的參數和已學習的非可訓練狀態,或控制初始化。如果傳遞 None
,則會在建構時調用模組的 init
函數以初始化模型的變數。