Keras 3 API 文件 / KerasCV / 損失函式 / 焦點損失 (Focal Loss)

焦點損失 (Focal Loss)

[原始碼]

FocalLoss 類別

keras_cv.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

實作焦點損失

焦點損失是一種改進的交叉熵,旨在在類別不平衡的情況下表現更好。因此,它通常用於物件偵測器。

參數

  • alpha:介於 0 到 1 之間的浮點數值,表示用於處理類別不平衡的加權因子。正類別和負類別分別具有 alpha 和 (1 - alpha) 作為其加權因子。預設值為 0.25。
  • gamma:表示可調整焦點參數的正浮點數值,預設值為 2。
  • from_logits:是否預期 y_pred 為 logits 張量。預設情況下,假設 y_pred 編碼機率分佈。預設值為 False
  • label_smoothing:介於 [0, 1] 的浮點數。如果大於 0,則透過將標籤壓縮至 0.5 來平滑標籤,也就是對目標類別使用 1. - 0.5 * label_smoothing,對非目標類別使用 0.5 * label_smoothing

參考

範例

y_true = np.random.uniform(size=[10], low=0, high=4)
y_pred = np.random.uniform(size=[10], low=0, high=4)
loss = FocalLoss()
loss(y_true, y_pred)

使用 compile() API 的用法

model.compile(optimizer='adam', loss=keras_cv.losses.FocalLoss())