JaxLayer

[原始碼]

JaxLayer 類別

keras.layers.JaxLayer(
    call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)

包裝 JAX 模型的 Keras 層。

當使用 JAX 作為 Keras 的後端時,此層啟用在 Keras 中使用 JAX 組件。

模型函數

此層接受 JAX 模型,形式為函數 call_fn,它必須接受以下具有這些確切名稱的引數

  • params:模型的可訓練參數。
  • state (選用):模型的非訓練狀態。如果模型沒有非訓練狀態,則可以省略。
  • rng (選用):jax.random.PRNGKey 實例。如果模型在訓練或推論期間都不需要 RNG,則可以省略。
  • inputs:模型的輸入,JAX 陣列或陣列的 PyTree
  • training (選用):一個引數,指定我們是否處於訓練模式或推論模式,在訓練模式下傳遞 True。如果模型在訓練模式和推論模式下的行為相同,則可以省略。

inputs 引數是強制性的。模型的輸入必須透過單一引數提供。如果 JAX 模型將多個輸入作為單獨的引數,則它們必須組合成單一結構,例如在 tupledict 中。

模型權重初始化

模型的 paramsstate 的初始化可以由此層處理,在這種情況下,必須提供 init_fn 引數。這允許模型使用正確的形狀動態初始化。或者,如果形狀已知,則可以使用 params 引數以及可選的 state 引數來建立已初始化的模型。

init_fn 函數(如果提供)必須接受以下具有這些確切名稱的引數

  • rngjax.random.PRNGKey 實例。
  • inputs:JAX 陣列或陣列的 PyTree,具有佔位符值以提供輸入的形狀。
  • training (選用):一個引數,指定我們是否處於訓練模式或推論模式。True 始終傳遞給 init_fn。無論 call_fn 是否具有 training 引數,都可以省略。

具有非訓練狀態的模型

對於具有非訓練狀態的 JAX 模型

  • call_fn 必須具有 state 引數
  • call_fn 必須返回一個 tuple,其中包含模型的輸出和模型的新非訓練狀態
  • init_fn 必須返回一個 tuple,其中包含模型的初始可訓練參數和模型的初始非訓練狀態。

此程式碼顯示了具有非訓練狀態的模型的 call_fninit_fn 簽章的可能組合。在此範例中,模型在 call_fn 中具有 training 引數和 rng 引數。

def stateful_call(params, state, rng, inputs, training):
    outputs = ...
    new_state = ...
    return outputs, new_state

def stateful_init(rng, inputs):
    initial_params = ...
    initial_state = ...
    return initial_params, initial_state

沒有非訓練狀態的模型

對於沒有非訓練狀態的 JAX 模型

  • call_fn 不得具有 state 引數
  • call_fn 必須僅返回模型的輸出
  • init_fn 必須僅返回模型的初始可訓練參數。

此程式碼顯示了沒有非訓練狀態的模型的 call_fninit_fn 簽章的可能組合。在此範例中,模型在 call_fn 中沒有 training 引數,也沒有 rng 引數。

def stateless_call(params, inputs):
    outputs = ...
    return outputs

def stateless_init(rng, inputs):
    initial_params = ...
    return initial_params

符合所需的簽章

如果模型的簽章與 JaxLayer 所需的簽章不同,則可以輕鬆編寫一個包裝方法來調整引數。此範例顯示了一個模型,該模型將多個輸入作為單獨的引數,在 dict 中預期多個 RNG,並且具有與 training 相反含義的 deterministic 引數。為了符合,輸入使用 tuple 組合在單一結構中,RNG 被拆分並用於填充預期的 dict,並且布林標誌被否定

def my_model_fn(params, rngs, input1, input2, deterministic):
    ...
    if not deterministic:
        dropout_rng = rngs["dropout"]
        keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)
        x = jax.numpy.where(keep, x / dropout_rate, 0)
        ...
    ...
    return outputs

def my_model_wrapper_fn(params, rng, inputs, training):
    input1, input2 = inputs
    rng1, rng2 = jax.random.split(rng)
    rngs = {"dropout": rng1, "preprocessing": rng2}
    deterministic = not training
    return my_model_fn(params, rngs, input1, input2, deterministic)

keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)

與 Haiku 模組一起使用

JaxLayer 啟用以 Haiku 組件的形式使用 haiku.Module。這是通過按照 Haiku 模式轉換模組,然後在 call_fn 參數中傳遞 module.apply,並在需要時在 init_fn 參數中傳遞 module.init 來實現的。

如果模型具有非訓練狀態,則應使用 haiku.transform_with_state 進行轉換。如果模型沒有非訓練狀態,則應使用 haiku.transform 進行轉換。此外,可選地,如果模組在 "apply" 中不使用 RNG,則可以使用 haiku.without_apply_rng 進行轉換。

以下範例示範如何從 Haiku 模組建立 JaxLayer,該模組透過 hk.next_rng_key() 使用隨機數生成器,並採用訓練位置引數

class MyHaikuModule(hk.Module):
    def __call__(self, x, training):
        x = hk.Conv2D(32, (3, 3))(x)
        x = jax.nn.relu(x)
        x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x)
        x = hk.Flatten()(x)
        x = hk.Linear(200)(x)
        if training:
            x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)
        x = jax.nn.relu(x)
        x = hk.Linear(10)(x)
        x = jax.nn.softmax(x)
        return x

def my_haiku_module_fn(inputs, training):
    module = MyHaikuModule()
    return module(inputs, training)

transformed_module = hk.transform(my_haiku_module_fn)

keras_layer = JaxLayer(
    call_fn=transformed_module.apply,
    init_fn=transformed_module.init,
)

引數

  • call_fn:呼叫模型的函數。請參閱上面的說明,了解它接受的引數列表以及它返回的輸出。init_fn:呼叫以初始化模型的函數。請參閱上面的說明,了解它接受的引數列表以及它返回的輸出。如果為 None,則必須提供 params 和/或 state
  • params:一個 PyTree,包含所有模型可訓練參數。這允許傳遞已訓練的參數或控制初始化。如果 paramsstate 均為 None,則在建置時呼叫 init_fn 以初始化模型的可訓練參數。
  • state:一個 PyTree,包含所有模型非訓練狀態。這允許傳遞已學習的狀態或控制初始化。如果 paramsstate 均為 None,且 call_fn 接受 state 引數,則在建置時呼叫 init_fn 以初始化模型的非訓練狀態。
  • seed:隨機數生成器的種子。選用。
  • dtype:層的計算和權重的 dtype。也可以是 keras.DTypePolicy。選用。預設為預設策略。