Keras 3 API 文件 / KerasNLP / 指標 / 困惑度指標

困惑度指標

[來源]

Perplexity 類別

keras_nlp.metrics.Perplexity(
    from_logits=False, mask_token_id=None, dtype="float32", name="perplexity", **kwargs
)

困惑度指標。

此類別實作了困惑度指標。簡而言之,此類別會計算交叉熵損失並取其指數。注意:此實作不適用於固定大小的視窗。

參數

  • from_logits:布林值。如果為 True,則 y_predupdate_state() 的輸入)應該是模型返回的 logits。否則,y_pred 是機率張量。
  • mask_token_id:整數。要遮罩的權杖 ID。如果提供,則會為此類別計算遮罩。請注意,如果提供了此欄位,並且也提供了 update_state() 中的 sample_weight 欄位,我們將計算最終的 sample_weight 為遮罩和 sample_weight 的元素乘積。
  • dtype:字串或 tf.dtypes.Dtype。指標計算的精度。如果未指定,則預設為 "float32"
  • name:字串。指標實例的名稱。
  • **kwargs:其他關鍵字參數。

範例

  1. 透過呼叫 update_state() 和 result() 計算困惑度。 1.1. 未提供 sample_weightmask_token_id
>>> np.random.seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity.update_state(target, logits)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

1.2. 指定了 sample_weight(遮罩 ID 為 0 的權杖)。

>>> np.random.seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> sample_weight = (target != 0).astype("float32")
>>> perplexity.update_state(target, logits, sample_weight)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
  1. 直接呼叫困惑度。
>>> np.random.seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
  1. 提供填充權杖 ID 並讓類別自行計算遮罩。
>>> np.random.seed(42)
>>> perplexity = keras_nlp.metrics.Perplexity(mask_token_id=0)
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>