Keras 3 API 文件 / KerasNLP / 採樣器 / BeamSampler

BeamSampler

[來源]

BeamSampler 類別

keras_nlp.samplers.BeamSampler(num_beams=5, return_all_beams=False, **kwargs)

Beam 採樣器類別。

此採樣器實作了 Beam Search 演算法。在每個時間步,Beam Search 會保留累積機率最高的 num_beams 個 beam(序列),並使用每個 beam 來預測候選的下一個詞彙。

引數

  • num_beams:int。每個時間步應保留的 beam 數量。num_beams 應嚴格為正數。
  • return_all_beams:bool。設定為 True 時,採樣器將返回所有 beam 及其各自的機率分數。

呼叫引數

{{call_args}}

範例

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="beam")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.BeamSampler(num_beams=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])