手撕FocalLoss

文章目录

  • 前言
  • 1、FocalLoss
    • 1.1.公式定义
  • 2、代码
  • 总结


前言

 为了加深对Focal Loss理解,本文提供了一个简单的手写Demo。

1、FocalLoss

 介绍FocalLoss的文章已经很多了,这里简单提一下:

1.1.公式定义

 Focal Loss 的公式如下:

FL ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t) FL(pt)=αt(1pt)γlog(pt)

 ;根据真实标签 y y y 的不同,Focal Loss 可以分为两种情况:

 1) 当真实标签 y = 1 y = 1 y=1 时,公式变为:

FL ( p ) = − α ( 1 − p ) γ log ⁡ ( p ) \text{FL}(p) = -\alpha (1 - p)^{\gamma} \log(p) FL(p)=α(1p)γlog(p)

 2) 当真实标签 y = 0 y = 0 y=0 时,公式变为:

FL ( p ) = − ( 1 − α ) p γ log ⁡ ( 1 − p ) \text{FL}(p) = -(1 - \alpha) p^{\gamma} \log(1 - p) FL(p)=(1α)pγlog(1p)

 Focal Loss 的完整公式可以写为:

FL ( y , p ) = − [ y ⋅ α ( 1 − p ) γ log ⁡ ( p ) + ( 1 − y ) ⋅ ( 1 − α ) p γ log ⁡ ( 1 − p ) ] \text{FL}(y, p) = -\left[ y \cdot \alpha (1 - p)^{\gamma} \log(p) + (1 - y) \cdot (1 - \alpha) p^{\gamma} \log(1 - p) \right] FL(y,p)=[yα(1p)γlog(p)+(1y)(1α)pγlog(1p)]

其中 p p p表示经过sigmoid的预测值。本文实现的是完整版的公式,而且没有引入额外的封装函数。

2、代码

import torch
import torch.nn as nn
import torch.nn.functional as F# focal_loss = pos_loss + neg_loss 
# if y == 1: pos_loss = -|1-p|^gamma * log(p)  
# if y == 0: neg_loss = -|0-p|^gamma * log(1-p)
class FocalLoss(nn.Module):def __init__(self,alpha=0.25,gamma=2.0,reduce='sum'):super(FocalLoss,self).__init__()self.alpha = alphaself.gamma = gammaself.reduce = reducedef forward(self,classifications,targets):alpha = self.alphagamma = self.gammaclassifications = classifications.view(-1)p = torch.sigmoid(classifications)targets = targets.view(-1)# 获取pos 和 neg 的索引pos_idx = torch.nonzero(targets==1).view(-1)neg_idx = torch.nonzero(targets==0).view(-1)# step1: cpt pos loss       pos_loss = -(1-p[pos_idx]).abs() ** gamma * torch.log(p[pos_idx])# step2: cpt neg loss neg_loss = -(0-p[neg_idx]).abs() ** gamma * torch.log(1-p[neg_idx])loss = torch.cat((pos_loss, neg_loss), dim=0)# targets 也需要重新排序 来跟loss值对应 concat_idx = torch.cat((pos_idx, neg_idx), dim=0)targets = targets[concat_idx]if alpha >= 0:alpha_t = alpha * targets + (1 - alpha) * (1 - targets)loss = alpha_t * lossif self.reduce=='sum':loss = loss.sum()elif self.reduce=='mean':loss = loss.mean()else:raise ValueError('reduce type is wrong!')return loss# ---test unit --- #
def main():# single cls focal loss focal_loss = FocalLoss()pred = torch.FloatTensor([0.1,0.9,0.2,0.8,0.7]) # nb_anchors :5tgt  = torch.FloatTensor([0,1,0,1,1])           # neg:0 pos:1 ; no ignoreloss = focal_loss(pred, tgt)print('loss:', loss) 

总结

 本文只是简单实现了一个二分类的FocalLoss,旨在加深读者对其理解。欢迎批评指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/22355.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

矩阵的扩展运算(MATLAB和pytorch实例)

秩(Rank)的定义 秩的计算 初等行变换法(最常用)行列式法(仅适用于方阵) 满秩的分类方阵的满秩非方阵的满秩几何意义应用场景判断方法 矩阵的特征值 定义求解特征值 特征方程步骤 关键性质 迹与行列式相似矩…

python面试题整理

Python 如何处理异常? Python中,使用try 和 except 关键字来捕获和处理异常 try 块中放置可能会引发异常的代码,然后在except块中处理这些异常。 能补充一下finally的作用吗? finally 块中的代码无论是否发生异常都会执行&#xf…

linux之perf(17)PMU事件采集脚本

Linux之perf(17)PMU事件采集脚本 Author: Once Day Date: 2025年2月22日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章可参考专栏: Perf性能分析_Once_day的博…

Java数据结构-排序

目录 一.本文关注焦点 二.七大排序分析及相关实现 1.冒泡排序 2.简单选择排序 3.直接插入排序 4.希尔排序 5.堆排序 ​编辑 6.归并排序 7.快速排序 一.本文关注焦点 各种排序的代码实现及各自的时间空间复杂度分析及稳定性。 时间复杂度:在比较排序中主…

改进收敛因子和比例权重的灰狼优化算法【期刊论文完美复现】(Matlab代码实现)

2 灰狼优化算法 2.1 基本灰狼优化算法 灰狼优化算法是一种模拟灰狼捕猎自然群体行为的社会启发式优化算法,属于一种新型的群体智能优化算法。灰狼优化算法具有高度的灵活性,是当前较为流行的优化算法之一。灰狼优化算法主要分为三个阶段:追…

创建Linux虚拟环境并远程连接

目录 下载VMware软件 下载CentOS 创建虚拟环境 远程连接Linux系统 下载VMware软件 不会的可以参考 传送门 下载CentOS 不会的可以参考 传送门 创建虚拟环境 打开VMware软件,创建虚拟机 选择典型安装 找到我们安装好的centOS文件,之后会自动检…

汽车智能制造企业数字化转型SAP解决方案总结

一、项目实施概述 项目阶段划分: 蓝图设计阶段主数据管理方案各模块蓝图设计方案下一阶段工作计划 关键里程碑: 2022年6月6日:项目启动会2022年12月1日:系统上线 二、总体目标 通过SAP实施,构建研产供销协同、业财一…

《Head First设计模式》读书笔记 —— 命令模式

文章目录 本节用例餐厅类比点餐流程角色与职责从餐厅到命令模式 命令模式第一个命令对象实现命令接口实现一个命令 使用命令对象NoCommand与空对象 定义命令模式支持撤销功能使用状态实现撤销多层次撤销 One One One …… more things宏命令使用宏命令 队列请求日志请求 总结 《…

基于YOLO11深度学习的运动鞋品牌检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

DAY08 List接口、Collections接口、Set接口

学习目标 能够说出List集合特点1.有序2.允许存储重复的元素3.有带索引的方法(练习 add,remove,set,get) 能够使用集合工具类Collections类:static void sort(List<T> list) 根据元素的自然顺序 对指定列表按升序进行排序。static <T> void sort(List<T> lis…

shell编程总结

前言 shell编程学习总结&#xff0c;1万3千多字带你学习shell编程 往期推荐 14wpoc&#xff0c;nuclei全家桶&#xff1a;nuclei模版管理工具Nuclei 哥斯拉二开&#xff0c;免杀绕过规避流量检测设备 fscan全家桶&#xff1a;FscanPlus&#xff0c;fs&#xff0c;fscan适用…

OpenAI ChatGPT在心理治疗领域展现超凡同理心,通过图灵测试挑战人类专家

近期&#xff0c;一项关于OpenAI ChatGPT在心理治疗领域的研究更是引起了广泛关注。据报道&#xff0c;ChatGPT已经成功通过了治疗师领域的图灵测试&#xff0c;其表现甚至在某些方面超越了人类治疗师&#xff0c;尤其是在展现同理心方面&#xff0c;这一发现无疑为AI在心理健康…

【智能客服】ChatGPT大模型话术优化落地方案

本文原创作者:姚瑞南 AI-agent 大模型运营专家,先后任职于美团、猎聘等中大厂AI训练专家和智能运营专家岗;多年人工智能行业智能产品运营及大模型落地经验,拥有AI外呼方向国家专利与PMP项目管理证书。(转载需经授权) 目录 一、项目背景 1.1 行业背景 1.2 业务现…

【JavaWeb12】数据交换与异步请求:JSON与Ajax的绝妙搭配是否塑造了Web的交互革命?

文章目录 &#x1f30d;一. 数据交换--JSON❄️1. JSON介绍❄️2. JSON 快速入门❄️3. JSON 对象和字符串对象转换❄️4. JSON 在 java 中使用❄️5. 代码演示 &#x1f30d;二. 异步请求--Ajax❄️1. 基本介绍❄️2. JavaScript 原生 Ajax 请求❄️3. JQuery 的 Ajax 请求 &a…

[Android]APP自启动

APP添加自启动权限&#xff0c;重启设备后自动打开APP。 1.AndroidManifest.xml <?xml version"1.0" encoding"utf-8"?> <manifest xmlns:android"http://schemas.android.com/apk/res/android"xmlns:tools"http://schemas.an…

Moonshot AI 新突破:MoBA 为大语言模型长文本处理提效论文速读

前言 在自然语言处理领域&#xff0c;随着大语言模型&#xff08;LLMs&#xff09;不断拓展其阅读、理解和生成文本的能力&#xff0c;如何高效处理长文本成为一项关键挑战。近日&#xff0c;Moonshot AI Research 联合清华大学、浙江大学的研究人员提出了一种创新方法 —— 混…

cs224w课程学习笔记-第2课

cs224w课程学习笔记-第2课 传统图学习 前言一、节点任务1、任务背景2、特征节点度3、特征节点中心性3.1 特征向量中心性&#xff08;Eigenvector Centrality&#xff09;3.2 中介中心性&#xff08;Betweenness Centrality&#xff09;3.3 接近中心性&#xff08;Closeness Cen…

Centos虚拟机扩展磁盘空间

Centos虚拟机扩展磁盘空间 扩展前后效果1 虚拟机vmware关机后&#xff0c;编辑2 扩展2.1 查看2.2 新建分区2.3 格式化新建分区ext42.3.1 格式化2.3.2 创建2.3.3 修改2.3.4 查看 2.4 扩容2.4.1 扩容2.4.1 查看 扩展前后效果 df -h1 虚拟机vmware关机后&#xff0c;编辑 2 扩展 …

1.13作业

1 if(!preg_match("/[0-9]|\~|\|\|\#|\\$|\%|\^|\&|\*|\&#xff08;|\&#xff09;|\-|\|\|\{|\[|\]|\}|\:|\|\"|\,|\<|\.|\>|\/|\?|\\\\/i", $c)){eval($c); 构造数组rce ?ceval(array_pop(next(get_defined_vars()))); post传参:asystem("c…

如何在 SpringBoot 项目使用 Redis 的 Pipeline 功能

本文是博主在批量存储聊天中用户状态和登陆信息到 Redis 缓存中时&#xff0c;使用到了 Pipeline 功能&#xff0c;并对此做出了整理。 一、Redis Pipeline 是什么 Redis 的 Pipeline 功能可以显著提升 Redis 操作的性能&#xff0c;性能提升的原因在于可以批量执行命令。当我…