Keras 3 API 文件 / RNG API / SeedGenerator 類別

SeedGenerator 類別

[原始碼]

SeedGenerator 類別

keras.random.SeedGenerator(seed=None, name=None, **kwargs)

每次呼叫產生隨機數字的函數時,產生可變的種子值。

在 Keras 中,所有隨機數字產生器(例如 keras.random.normal())都是無狀態的,這表示如果您將整數種子傳遞給它們(例如 seed=42),它們在重複呼叫時將返回相同的值。為了在每次呼叫時獲得不同的值,必須使用提供隨機產生器狀態的 SeedGenerator

請注意,所有隨機數字產生器都有預設種子值 None,這表示使用內部的全域 SeedGenerator。如果您需要將 RNG 與全域狀態解耦,您可以提供具有確定性或隨機初始狀態的本地 StateGenerator

關於 JAX 後端的注意事項:請注意,對於使用 JAX 後端進行 RNG 的 JIT 編譯,需要使用本地 StateGenerator 作為種子參數,因為不支援使用全域狀態。

範例

seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)

在層中的用法

class Dropout(keras.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            return keras.random.dropout(
                x, rate=0.5, seed=self.seed_generator
            )
        return x