Table of Contents
- Preface
- 1. FocalLoss
- 1.1. Formula Definition
- 2. Code
- Summary
Preface
To build a better understanding of Focal Loss, this post walks through a simple hand-written demo.
1. FocalLoss
Many articles already introduce Focal Loss, so it is only reviewed briefly here.
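1.1. Formula Definition
The Focal Loss formula is:
$$\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
where $p_t = p$ and $\alpha_t = \alpha$ when $y = 1$, and $p_t = 1 - p$ and $\alpha_t = 1 - \alpha$ when $y = 0$. Depending on the ground-truth label $y$, Focal Loss therefore splits into two cases: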
1) When the ground-truth label is $y = 1$, the formula becomes:
$$\text{FL}(p) = -\alpha (1 - p)^{\gamma} \log(p)$$
2) When the ground-truth label is $y = 0$, the formula becomes:
$$\text{FL}(p) = -(1 - \alpha) p^{\gamma} \log(1 - p)$$
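As a quick check that the compact $p_t$ form and the two cases describe the same loss, the plain-Python sketch below evaluates both for a single prediction and confirms they agree. The function names `focal_loss_pt` and `focal_loss_cases` are hypothetical, chosen only for illustration, and $\alpha = 0.25$, $\gamma = 2$ are assumed defaults. It also prints the modulating factor $(1 - p_t)^{\gamma}$ for an easy and a hard positive sample.

```python
import math


def focal_loss_pt(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Compact form: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def focal_loss_cases(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Two-case form, written out separately for y = 1 and y = 0."""
    if y == 1:
        return -alpha * (1.0 - p) ** gamma * math.log(p)
    return -(1.0 - alpha) * p ** gamma * math.log(1.0 - p)


# Both forms agree for positive and negative samples.
for p in (0.1, 0.5, 0.9):
    for y in (0, 1):
        assert math.isclose(focal_loss_pt(p, y), focal_loss_cases(p, y))

# Modulating factor (1 - p_t)^gamma with gamma = 2 for two positive samples.
easy_factor = (1 - 0.9) ** 2  # well-classified positive: ~0.01
hard_factor = (1 - 0.1) ** 2  # badly classified positive: ~0.81
print(f"easy: {easy_factor:.2f}, hard: {hard_factor:.2f}")  # prints: easy: 0.01, hard: 0.81
```

With $\gamma = 2$, a well-classified positive ($p = 0.9$) keeps only about 1% of its $\alpha$-weighted cross-entropy, while a badly classified one ($p = 0.1$) keeps about 81%, which is how Focal Loss shifts the training signal toward hard examples.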
The complete Focal Loss formula can be written as:
$$\text{FL}(y, p) = -\left[ y \cdot \alpha (1 - p)^{\gamma} \log(p) + (1 - y) \cdot (1 - \alpha) p^{\gamma} \log(1 - p) \right]$$
where $p$ is the predicted probability after the sigmoid, i.e. $p = \sigma(x)$ for a raw model output (logit) $x$. This post implements the complete formula directly, without relying on any additional wrapper functions.
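The complete formula can also be transcribed as a single element-wise tensor expression, one term per part of the equation. The sketch below assumes PyTorch, raw logits as input, and 0/1 float targets; `focal_loss_full` is a hypothetical name. As a sanity check, setting $\gamma = 0$ and $\alpha = 0.5$ removes the modulating factor, so the result should equal half of PyTorch's `F.binary_cross_entropy_with_logits`.

```python
import torch
import torch.nn.functional as F


def focal_loss_full(logits: torch.Tensor, targets: torch.Tensor,
                    alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
    """Element-wise focal loss, term by term from
    FL(y, p) = -[y*alpha*(1-p)^gamma*log(p) + (1-y)*(1-alpha)*p^gamma*log(1-p)],
    with p = sigmoid(logits) and targets y in {0, 1}."""
    p = torch.sigmoid(logits)
    pos_term = targets * alpha * (1 - p) ** gamma * torch.log(p)
    neg_term = (1 - targets) * (1 - alpha) * p ** gamma * torch.log(1 - p)
    return -(pos_term + neg_term)


logits = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.7])
targets = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])

# gamma = 0 and alpha = 0.5: no modulating factor, so FL == 0.5 * BCE element-wise.
fl = focal_loss_full(logits, targets, alpha=0.5, gamma=0.0)
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
print(torch.allclose(fl, 0.5 * bce))
```

One caveat of this masked form: if the sigmoid saturates to exactly 0 or 1, the masked-out term becomes `0 * (-inf)`, which evaluates to `nan`. Selecting positives and negatives by index, as the `FocalLoss` class in section 2 does, avoids multiplying by the masked-out term.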
2. Code
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# focal_loss = pos_loss + neg_loss
# if y == 1: pos_loss = -|1-p|^gamma * log(p)
# if y == 0: neg_loss = -|0-p|^gamma * log(1-p)
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduce='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, classifications, targets):
        alpha = self.alpha
        gamma = self.gamma
        # classifications are raw logits: flatten them and map to probabilities
        classifications = classifications.view(-1)
        p = torch.sigmoid(classifications)
        targets = targets.view(-1)
        # indices of positive (y == 1) and negative (y == 0) samples
        pos_idx = torch.nonzero(targets == 1).view(-1)
        neg_idx = torch.nonzero(targets == 0).view(-1)
        # step 1: compute the positive-sample loss
        pos_loss = -(1 - p[pos_idx]).abs() ** gamma * torch.log(p[pos_idx])
        # step 2: compute the negative-sample loss
        neg_loss = -(0 - p[neg_idx]).abs() ** gamma * torch.log(1 - p[neg_idx])
        loss = torch.cat((pos_loss, neg_loss), dim=0)
        # reorder targets the same way so each target lines up with its loss value
        concat_idx = torch.cat((pos_idx, neg_idx), dim=0)
        targets = targets[concat_idx]
        # alpha weighting: alpha for positives, 1 - alpha for negatives (alpha < 0 disables it)
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        if self.reduce == 'sum':
            loss = loss.sum()
        elif self.reduce == 'mean':
            loss = loss.mean()
        else:
            raise ValueError('reduce type is wrong!')
        return loss

# --- test unit --- #
def main():
    # single-class focal loss
    focal_loss = FocalLoss()
    pred = torch.FloatTensor([0.1, 0.9, 0.2, 0.8, 0.7])  # nb_anchors: 5; these are logits, not probabilities
    tgt = torch.FloatTensor([0, 1, 0, 1, 1])  # neg: 0, pos: 1; no ignored samples
    loss = focal_loss(pred, tgt)
    print('loss:', loss)

if __name__ == '__main__':
    main()
```
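`torch.log(torch.sigmoid(x))` in the class above can produce `inf` once the sigmoid rounds to exactly 0 or 1 in float32, which happens for logits of large magnitude. One possible remedy, sketched below under the assumption that inputs are raw logits and targets are 0/1, is to compute $\log p$ and $\log(1 - p)$ as `F.logsigmoid(x)` and `F.logsigmoid(-x)`. The function `focal_loss_stable` and its `reduction` argument are hypothetical names of my own, not part of the original code or a PyTorch API; for ordinary logits it should return the same value as `FocalLoss`.

```python
import torch
import torch.nn.functional as F


def focal_loss_stable(logits, targets, alpha=0.25, gamma=2.0, reduction="sum"):
    """Binary focal loss from raw logits (hypothetical helper, not a PyTorch API).

    log(p)     is computed as logsigmoid(x)
    log(1 - p) is computed as logsigmoid(-x)
    """
    logits = logits.reshape(-1)
    targets = targets.reshape(-1).to(logits.dtype)
    p = torch.sigmoid(logits)
    log_p = F.logsigmoid(logits)     # stays finite for very negative logits
    log_1mp = F.logsigmoid(-logits)  # stays finite for very positive logits
    # Pick the positive or negative formula per element: alpha for y = 1, 1 - alpha for y = 0.
    loss = torch.where(
        targets == 1,
        -alpha * (1 - p) ** gamma * log_p,
        -(1 - alpha) * p ** gamma * log_1mp,
    )
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "none":
        return loss
    raise ValueError(f"unknown reduction: {reduction}")


# Same toy inputs as main(): logits and 0/1 targets.
pred = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.7])
tgt = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])
print("stable loss:", focal_loss_stable(pred, tgt))

# A very confident positive logit on a negative sample.
extreme, label = torch.tensor([100.0]), torch.tensor([0.0])
naive = -(1 - 0.25) * torch.sigmoid(extreme) ** 2 * torch.log(1 - torch.sigmoid(extreme))
print("naive:", naive, "stable:", focal_loss_stable(extreme, label))
```

`torch.where` evaluates both branches before selecting, which is safe here because both log-sigmoid terms stay finite.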
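Summary
This post implements only a simple binary Focal Loss, with the aim of deepening the reader's understanding of it. Comments and corrections are welcome.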