Keras 3 API 文件 / KerasHub / 取樣器 / 取樣器基類別

取樣器基類別

[來源]

Sampler 類別

keras_hub.samplers.Sampler(temperature=1.0)

基本取樣器類別。

參數

  • temperature:浮點數,可選。用於控制取樣的隨機性。溫度越高,樣本的多樣性就越高。預設值為 1.0

呼叫參數

{{call_args}}

可以擴充此基類別以實現不同的自迴歸取樣方法。為此,請覆寫 get_next_token() 方法,該方法根據所有可能的詞彙項目的機率分佈計算下一個標記。

範例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Greedy search with some tokens forbidden.
class CustomSampler(keras_hub.samplers.Sampler):
    def __init__(self, forbidden_tokens, **kwargs):
        super().__init__(**kwargs)
        self.forbidden_tokens = forbidden_tokens

    def get_next_token(self, probs):
        batch_size, vocab_size = keras.ops.shape(probs)
        for id in self.forbidden_tokens:
            update = keras.ops.zeros((batch_size, 1))
            probs = keras.ops.slice_update(probs, (0, id), update)
        return keras.ops.argmax(probs, axis=-1)

# 257 = "a" with a leading space, 262 = "the" with a leading space.
causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
causal_lm.summary()
causal_lm.generate(["That's strange"])

[來源]

get_next_token 方法

Sampler.get_next_token(probabilities)

取得下一個標記。參數

  • probabilities:一個張量,表示下一個標記在所有詞彙標記上的機率分佈。

根據給定的標記機率分佈取得下一個標記。子類別必須實作此方法。