FlaxLayer

[來源]

FlaxLayer 類別

keras.layers.FlaxLayer(module, method=None, variables=None, **kwargs)

Keras 層,用於封裝 Flax 模組。

當使用 JAX 作為 Keras 的後端時,此層能夠在 Keras 內使用 flax.linen.Module 實例形式的 Flax 組件。

用於前向傳遞的模組方法可以通過 method 參數指定,預設為 __call__。此方法必須接受以下具有完全相同名稱的參數

  • self 如果該方法綁定到模組,預設的 __call__ 就是這種情況,否則為 module 以傳遞模組。
  • inputs:模型的輸入,JAX 陣列或陣列的 PyTree
  • training (可選):指定我們處於訓練模式還是推論模式的參數,訓練模式下傳遞 True

FlaxLayer 自動處理模型的非可訓練狀態和所需的 RNG。請注意,flax.linen.Module.apply()mutable 參數設定為 DenyList(["params"]),因此假設 "params" 集合之外的所有變數都是非可訓練權重。

此範例展示瞭如何從具有預設 __call__ 方法且沒有訓練參數的 Flax Module 建立 FlaxLayer

class MyFlaxModule(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, inputs):
        x = inputs
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=200)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=10)(x)
        x = flax.linen.softmax(x)
        return x

flax_module = MyFlaxModule()
keras_layer = FlaxLayer(flax_module)

此範例展示瞭如何封裝模組方法以符合所需的簽名。這允許有多個輸入參數和一個具有不同名稱和值的訓練參數。此外,這還展示瞭如何使用未綁定到模組的函數。

class MyFlaxModule(flax.linen.Module):
    @flax.linen.compact
    def forward(self, input1, input2, deterministic):
        ...
        return outputs

def my_flax_module_wrapper(module, inputs, training):
    input1, input2 = inputs
    return module.forward(input1, input2, not training)

flax_module = MyFlaxModule()
keras_layer = FlaxLayer(
    module=flax_module,
    method=my_flax_module_wrapper,
)

參數

  • moduleflax.linen.Module 或子類別的實例。
  • method:調用模型的方法。這通常是 Module 中的方法。如果未提供,則使用 __call__ 方法。method 也可以是未在 Module 中定義的函數,在這種情況下,它必須將 Module 作為第一個參數。它用於 Module.initModule.apply。詳細資訊記錄在 flax.linen.Module.apply()method 參數中。
  • variables:一個 dict,其中包含模組的所有變數,格式與 flax.linen.Module.init() 返回的格式相同。它應包含 "params" 鍵,如果適用,還應包含非可訓練狀態的變數集合的其他鍵。這允許傳遞已訓練的參數和已學習的非可訓練狀態,或控制初始化。如果傳遞 None,則會在建構時調用模組的 init 函數以初始化模型的變數。