KerasHub:預訓練模型 / API 文件 / 採樣器 / 對比採樣器 (ContrastiveSampler)

對比採樣器 (ContrastiveSampler)

[原始碼]

ContrastiveSampler 類別

keras_hub.samplers.ContrastiveSampler(k=5, alpha=0.6, **kwargs)

對比採樣器類別。

此採樣器實作對比搜尋演算法。簡而言之,採樣器選擇具有最大「分數」的詞元作為下一個詞元。「分數」是詞元機率與先前詞元最大相似度之間的加權總和。透過使用這種聯合分數,對比採樣器可以減少重複出現已見詞元的行為。

參數

  • k:整數,top-k 的 k 值。下一個詞元將從 k 個詞元中選擇。
  • alpha:浮點數,在聯合分數計算中,負最大相似度的權重。alpha 的值越大,分數越依賴相似度而非詞元機率。

呼叫參數

{{call_args}}

範例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="contrastive")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_hub.samplers.ContrastiveSampler(k=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])