Perplexity
類別keras_hub.metrics.Perplexity(
from_logits=False, mask_token_id=None, dtype="float32", name="perplexity", **kwargs
)
困惑度指標。
這個類別實作了困惑度指標。簡而言之,這個類別計算交叉熵損失並取其指數。注意:此實作不適用於固定大小的窗口。
參數
y_pred
(update_state()
的輸入)應為模型返回的 logits。否則,y_pred
是一個機率張量。update_state()
中的 sample_weight
欄位,我們將計算最終的 sample_weight
作為遮罩和 sample_weight
的元素乘積。"float32"
。範例
sample_weight
和 mask_token_id
。>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity.update_state(target, logits)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
1.2. 指定了 sample_weight
(遮罩 ID 為 0 的 token)。
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> sample_weight = (target != 0).astype("float32")
>>> perplexity.update_state(target, logits, sample_weight)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(mask_token_id=0)
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>