KerasHub:預訓練模型 / API 文件 / 取樣器 / 集束取樣器 (BeamSampler)

集束取樣器 (BeamSampler)

[原始碼]

BeamSampler 類別

keras_hub.samplers.BeamSampler(num_beams=5, return_all_beams=False, **kwargs)

集束取樣器類別。

此取樣器實作了集束搜尋演算法。在每個時間步,集束搜尋會保留累積機率最高的 num_beams 個集束(序列),並使用每個集束來預測候選的下一個符記。

參數

  • num_beams:整數。在每個時間步應保留的集束數量。num_beams 應為嚴格正數。
  • return_all_beams:布林值。當設定為 True 時,取樣器將傳回所有集束及其各自的機率分數。

呼叫參數

{{call_args}}

範例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="beam")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_hub.samplers.BeamSampler(num_beams=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])