Keras 3 API 文件 / 指標 / 基礎指標類別

基礎指標類別

[原始碼]

Metric 類別

keras.metrics.Metric(dtype=None, name=None)

封裝指標邏輯和狀態。

參數

  • name:指標實例的可選名稱。
  • dtype:指標計算的 dtype。預設為 None,表示使用 keras.backend.floatx()。除非設定為不同的值(透過 keras.backend.set_floatx()),否則 keras.backend.floatx()"float32"。如果提供了 keras.DTypePolicy,則將使用 compute_dtype

範例

m = SomeMetric(...)
for input in ...:
    m.update_state(input)
print('Final result: ', m.result())

compile() API 的使用方式

model = keras.Sequential()
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))

model.compile(optimizer=keras.optimizers.RMSprop(0.01),
              loss=keras.losses.CategoricalCrossentropy(),
              metrics=[keras.metrics.CategoricalAccuracy()])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

model.fit(data, labels, epochs=10)

由子類別實作

  • __init__():所有狀態變數都應在此方法中透過呼叫 self.add_variable() 來建立,例如:self.var = self.add_variable(...)
  • update_state():包含對狀態變數的所有更新,例如:self.var.assign(...)
  • result():從狀態變數計算並傳回指標的純量值或純量值字典。

範例子類別實作

class BinaryTruePositives(Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_variable(
            shape=(),
            initializer='zeros',
            name='true_positives'
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = ops.cast(y_true, "bool")
        y_pred = ops.cast(y_pred, "bool")

        values = ops.logical_and(
            ops.equal(y_true, True), ops.equal(y_pred, True))
        values = ops.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, self.dtype)
            sample_weight = ops.broadcast_to(
                sample_weight, ops.shape(values)
            )
            values = ops.multiply(values, sample_weight)
        self.true_positives.assign(self.true_positives + ops.sum(values))

    def result(self):
        return self.true_positives