對比採樣器 (ContrastiveSampler)
類別keras_nlp.samplers.ContrastiveSampler(k=5, alpha=0.6, **kwargs)
對比採樣器類別。
這個採樣器實作了對比搜尋演算法。簡而言之,採樣器會選擇具有最大「分數」的詞作為下一個詞。「分數」是詞的機率與先前詞的最大相似度之間的加權總和。通過使用這個聯合分數,對比採樣器減少了重複出現已見詞的行為。
參數
k
值。下一個詞將從 k 個詞中選出。alpha
的值越大,分數越依賴於相似度而不是詞的機率。呼叫參數
{{call_args}}
範例
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Pass by name to compile.
causal_lm.compile(sampler="contrastive")
causal_lm.generate(["Keras is a"])
# Pass by object to compile.
sampler = keras_nlp.samplers.ContrastiveSampler(k=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])