系列文章目录
文章目录
- 系列文章目录
- 前言
- 1. 什么是 Focal Loss
- 2. 逐过程解析 Focal Loss
- 3. Focal Loss 的 PyTorch 实现
- 总结
前言
类别不平衡是一个在目标检测领域被广泛讨论的问题,因为目标数量的多少在数据集中能很直观的体现。同时,在分割中这也是一个值得关注的问题,毕竟分割的本质是对像素进行分类。而处理类别不平衡一个非差常用的方法就是通过 Focal Loss
来引导模型更关注困难的类。
1. 什么是 Focal Loss
Focal Loss
是在标准交叉熵损失基础上修改得到的。相比 CrossEntropy Loss
它增加了容易和难分样本的权重,对于难分的样本增加权重,增加 loss 的贡献度;减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。
Focal Loss
从另外的视角来解决样本不平衡问题,那就是根据置信度动态调整 CE Loss
,当预测正确的置信度增加时,loss 的权重系数会逐渐衰减至0,这样模型训练的 loss 更关注难例,而大量容易的例子其 loss 贡献很低。
比如假如一张图片上有 10 个正样本,每个正样本的损失值是 3,那么这些正样本的总损失是 10x3=30。而假如该图片上有 10000 个简单易分负样本,尽管每个负样本的损失值很小,假设是 0.1,那么这些简单易分负样本的总损失是 10000x0.1=1000,那么损失值要远远高于正样本的损失值。所以如果在训练的过程中使用全部的正负样本,那么它的训练效果会很差。
2. 逐过程解析 Focal Loss
- 公式一览:
- α \alpha α 侧重的是正负样本之间的不平衡,一般设置为 0.25
- γ \gamma γ 难易样本上的权重调节,一般设置为 2
- 简单的加权 CE Loss 可能只能实现正负样本之间不平衡的调节,所以对于大多数不平衡任务来说 Focal Loss 应该还是能起到更好的效果
- 首先看一下二分类交叉熵损失函数
- 二分类交叉熵损失函数: y y y 是样本的标签值,而 p p p 是模型预测某一个样本为正样本的概率,对于真实标签为正样本的样本,它的概率 p p p 越大说明模型预测的越准确,对于真实标签为负样本的样本,它的概率 p p p 越小说明模型预测的越准确
- 如果我们定义 p t p_t pt 为如下的形式
- 公式 (1) 可以修改为如下形式 (2)
- 现在我们定义一个参数 α \alpha α 和 1 − α 1 - \alpha 1−α 来平衡正负样本的权重,定义 α t \alpha_t αt 如下,需要注意的是, α \alpha α 是个超参数用来平衡正负样本的权重,并不是实际的正负样本的比例,
- 公式 (2) 可以修改为如下形式 (3)
- 又因为样本有难易之分,所以我们必须要能区分出困难样本和简单样本,所以我们设置一个系数 ( 1 − p t ) γ ( 1-p_t )^{\gamma} (1−pt)γ
- 它可以降低简单样本的损失贡献,而使得训练时更重视一些困难样本,
Focal Loss
可以定义为:
- 看一些权重计算的例子:
- 如果预测正样本概率是 0.95(即对于一个真实标签为正样本的样本,使用模型预测它也是正样本的概率是 0.95),这显然是一个简单的样本
- 如果预测正样本概率是 0.5 ,这显然是一个稍微困难一定的样本
- 如果预测负样本的概率为 0.9(即对于一个真实标签为负样本的样本,使用模型预测它是正样本的概率是 0.9),这显然是一个困难的样本,则该样本的难易权重是
- 如果预测负样本的概率为 0.1(即对于一个真实标签为负样本的样本,使用模型预测它是正样本的概率是 0.1),这显然是一个简单的样本,
- 为此,我们得到最终的
Focal Loss
3. Focal Loss 的 PyTorch 实现
首先感谢上海 AI Lab 的杰出工作,SAM-Med2D
我这里的实现来自仓库:SAM-Med2D
如果能对大家有帮助,希望后期大家不要忘记引用这个工作:
class FocalLoss(nn.Module):def __init__(self, gamma=2.0, alpha=0.25):super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = alphadef forward(self, pred, mask):"""pred: [B, 1, H, W]mask: [B, 1, H, W]"""assert pred.shape == mask.shape, "pred and mask should have the same shape."p = torch.sigmoid(pred)num_pos = torch.sum(mask)num_neg = mask.numel() - num_posw_pos = (1 - p) ** self.gammaw_neg = p ** self.gammaloss_pos = -self.alpha * mask * w_pos * torch.log(p + 1e-12)loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * torch.log(1 - p + 1e-12)loss = (torch.sum(loss_pos) + torch.sum(loss_neg)) / (num_pos + num_neg + 1e-12)return loss
总结
参考链接:
深入剖析Focal loss损失函数