Keras 3 API 文件 / KerasNLP / 取樣器 / TopPSampler

TopPSampler

[來源]

TopPSampler 類別

keras_nlp.samplers.TopPSampler(p=0.1, k=None, seed=None, **kwargs)

Top-P 取樣器類別。

此取樣器實作了 top-p 搜尋演算法。Top-p 搜尋會從輸出機率總和大於 p 的最小子集中選擇標記。換句話說,top-p 會先根據可能性對標記預測進行排序,並忽略累積機率超過 p 之後的所有標記,然後從剩餘的標記中選擇一個標記。

參數

  • p:浮點數,top-p 的 p 值。
  • k:整數。如果設定,此參數會定義在「top-p」取樣之前應用的啟發式「top-k」截止值。不在前 k 個中的所有對數將被捨棄,剩餘的對數將被排序以找到 p 的截止點。設定此參數可以透過減少要排序的標記數量來顯著加快取樣速度。預設值為 None
  • seed:整數。隨機種子。預設值為 None

呼叫參數

{{call_args}}

範例

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="top_p")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])