PyTorch内置损失函数汇总 !!

文章目录

一、损失函数的概念

二、Pytorch内置损失函数

1. nn.CrossEntropyLoss

2. nn.NLLLoss

3. nn.NLLLoss2d

4. nn.BCELoss

5. nn.BCEWithLogitsLoss

6. nn.L1Loss

7. nn.MSELoss

8. nn.SmoothL1Loss

9. nn.PoissonNLLLoss

10. nn.KLDivLoss

11. nn.MarginRankingLoss

12. nn.MultiLabelMarginLoss

13. nn.SoftMarginLoss

14. nn.MultilabelSoftMarginLoss

15. nn.MultiMarginLoss

16. nn.TripletMarginLoss

17. nn.HingeEmbeddingLoss

18. nn.CosineEmbeddingLoss

19. nn.CTCLoss


一、损失函数的概念

损失函数(loss function):衡量模型输出与真实标签的差异。

损失函数也叫代价函数(cost function)/ 准测(criterion)/ 目标函数(objective function)/ 误差函数(error function)。

二、Pytorch内置损失函数

1. nn.CrossEntropyLoss

功能:交叉熵损失函数,用于多分类问题。这个损失函数结合了nn.LogSoftmaxnn.NLLLoss的计算过程。通常用于网络最后的分类层输出

主要参数:

  • weight:各类别的loss设置权值
  • ignore_index:忽略某个类别
  • reduction:计算模式,可为 none /sum /mean:

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction=‘mean’)

用法示例:

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()

2. nn.NLLLoss

功能:负对数似然损失函数,当网络的最后一层是nn.LogSoftmax时使用。用于训练 C 个类别的分类问题

主要参数:

  • weight:各类别的loss设置权值,必须是一个长度为 C 的 Tensor
  • ignore _index:设置一个目标值, 该目标值会被忽略, 从而不会影响到 输入的梯度
  • reduction :计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

nn.NLLLoss(weight=None,size_average=None, ignore_index=-100, reduce=None, reduction='mean')

用法示例:

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)

3. nn.NLLLoss2d

功能:对于图片输入的负对数似然损失. 它计算每个像素的负对数似然损失。它是nn.NLLLoss的二维版本。适用于图像相关的任务,比如像素级任务或分割

torch.nn.NLLLoss2d(weight=None, ignore_index=-100, reduction='mean')

4. nn.BCELoss

功能:二元交叉熵损失函数,用于二分类问题。计算的是目标值和预测值之间的交叉熵。

注意事项:输入值取值在 [0,1]

主要参数:

  • weight:各类别的loss设置权值
  • ignore_index:忽略某个类别
  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

torch.nn.BCELoss(weight=None, size_average=None,reduce=None, reduction='mean')

用法示例:

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)

5. nn.BCEWithLogitsLoss

功能:结合了nn.Sigmoid层和nn.BCELoss的损失函数,用于二分类问题,尤其在预测值没有经过nn.Sigmoid层时

注意事项:网络最后不加sigmoid函数

主要参数:

  • pos_weight:正样本的权值
  • weight:各类别的loss设置权值
  • ignore_index:忽略某个类别
  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

用法示例:

loss = nn.BCEWithLogitsLoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(input, target)

6. nn.L1Loss

功能:L1损失函数,也称为最小绝对偏差(LAD)。它是预测值和真实值之间差的绝对值的和

主要参数:

  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

torch.nn.L1Loss(reduction='mean')

用法示例:

loss = nn.L1Loss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)

7. nn.MSELoss

功能:均方误差损失函数,计算预测值和真实值之间差的平方的平均值,用于回归问题。

主要参数:

  • reduction:计算模式,可为none /sum /mean

①. none:逐个元素计算

②. sum:所有元素求和,返回标量

③. mean:加权平均,返回标量

torch.nn.MSELoss(reduction='mean')

用法示例:

loss = nn.MSELoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)

8. nn.SmoothL1Loss

功能:平滑L1损失,也称为Huber损失,主要用于回归问题,尤其是当预测值与目标值差异较大时,比起L1损失更不易受到异常值的影响

  • size_average
  • reduce
  • reduction
  • beta
torch.nn.SmoothL1Loss(reduction='mean')

其中,

用法示例:

loss = nn.SmoothL1Loss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)
output = loss(input, target)

9. nn.PoissonNLLLoss

功能:泊松负对数似然损失,适用于计数或事件率预测,其中预测的是事件发生的平均率

主要参数:

  • log_inpput:输入是否为对数形式,决定计算公式
  • full:计算所有loss,默认为False
  • eps:修正项,避免log(input)为nan
torch.nn.PoissonNLLLoss(log_input=True, full=False,  eps=1e-08,  reduction='mean')

用法示例:

loss = nn.PoissonNLLLoss()
log_input = torch.randn(5, 2, requires_grad=True)
target = torch.randn(5, 2)
output = loss(log_input.exp(), target)

10. nn.KLDivLoss

功能::KL散度损失,用于衡量两个概率分布之间的差异。通常用于模型输出与某个目标分布或另一个模型输出之间的相似性度量

注意事项:需提前将输入计算 log-probabilities,如通过nn.logsoftmax()

主要参数:

  • reduction:none / sum / mean / batchmean

①. batchmean:batchsize维度求平均值

②. none:逐个元素计算

③. sum:所有元素求和,返回标量

④. mean:加权平均,返回标量

torch.nn.KLDivLoss(reduction='mean')

用法示例:

loss = nn.KLDivLoss(reduction='batchmean')
input = torch.log_softmax(torch.randn(5, 10), dim=1)
target = torch.softmax(torch.randn(5, 10), dim=1)
output = loss(input, target)

11. nn.MarginRankingLoss

功能:边缘排序损失,用于排序学习任务,它鼓励正例的得分比负例的得分更高一个边界值

注意事项:该方法计算两组数据之间的差异,返回一个 n*n 的loss 矩阵

主要参数:

  • margin:边界值,x1和x2之间的差异值
  • reduction:计算模式,可为none / sum / mean

①. y=1时,希望x1比x2大,当x1>x2时,不产生loss

②. y=-1时,希望x2比x1大,当x2>x1时,不产生loss

torch.nn.MarginRankingLoss(margin=0.0, reduction='mean')

用法示例:

loss = nn.MarginRankingLoss()
input1 = torch.randn(3, requires_grad=True)
input2 = torch.randn(3, requires_grad=True)
target = torch.randn(3).sign()
output = loss(input1, input2, target)

12. nn.MultiLabelMarginLoss

功能:多标签边缘损失,用于多标签分类问题,其中每个类别的损失是独立计算的。

举例:四分类任务,样本x属于0类或3类

主要参数:

  • reduction:计算模式,可为none / sum / mean
torch.nn.MultiLabelMarginLoss(reduction='mean')

对于mini-batch(小批量) 中的每个样本按如下公式计算损失:

用法示例:

loss = nn.MultiLabelMarginLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([[3, 0, -1, -1, -1],[1, 3, -1, -1, -1],[1, 2, 3, -1, -1]])
output = loss(input, target)

13. nn.SoftMarginLoss

功能:软边缘损失,用于二分类任务,是逻辑回归损失的平滑版本。

主要参数:

  • reduction:计算模式,可为none / sum / mean
torch.nn.SoftMarginLoss(reduction='mean')

用法示例:

loss = nn.SoftMarginLoss()
input = torch.randn(3, requires_grad=True)
target = torch.tensor([-1, 1, 1], dtype=torch.float)
output = loss(input, target)

14. nn.MultilabelSoftMarginLoss

功能:多标签软边缘损失,用于多标签分类问题,它是每个标签的二元交叉熵损失的加权版本

主要参数:

  • weight:各类别的loos设置权值
  • reduction:计算模式,可为none / sum / mean
torch.nn.MultiLabelSoftMarginLoss(weight=None, reduction='mean')

用法示例:

loss = nn.MultiLabelSoftMarginLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, 5).random_(2)
output = loss(input, target)

15. nn.MultiMarginLoss

功能:多类别边缘损失,是SVM(支持向量机)的一个变种,用于多类别分类问题。

主要参数:

  • p:可选1或2
  • weight:各类别的loos设置权值
  • margin:边界值
  • reduction:计算模式,可为none / sum / mean
torch.nn.MultiMarginLoss(p=1, margin=1.0, weight=None,  reduction='mean')

用法示例:

loss = nn.MultiMarginLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, 0, 4])
output = loss(input, target)

16. nn.TripletMarginLoss

功能:三元组边缘损失,用于度量学习,其中学习的是输入样本之间的相对距离。人脸验证中常用

主要参数:

  • p:范数的阶,默认为2
  • margin:边界值
  • reduction:计算模式,可为none / sum / mean

和孪生网络相似,具体例子:给一个A,然后再给B、C,看看B、C谁和A更像。

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, reduction='mean')

其中,

用法示例:

loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
output = loss(anchor, positive, negative)

17. nn.HingeEmbeddingLoss

功能:铰链嵌入损失,用于学习基于距离的相似性,当两个输入被认为是不相似的时,会惩罚它们的距离。常用于非线性embedding和半监督学习

注意事项:输入x 应为两个输入之差的绝对值

主要参数:

  • margin:边界值
  • reduction:计算模式,可为none / sum / mean

torch.nn.HingeEmbeddingLoss(margin=1.0,  reduction='mean')

用法示例:

loss = nn.HingeEmbeddingLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, -1, 1])
output = loss(input, target)

18. nn.CosineEmbeddingLoss

功能:余弦嵌入损失,用于学习输入之间的余弦相似性,适用于确定两个输入是否在方向上是相似的

主要参数:

  • margin:可取值[-1, 1],推荐为 [0,0.5]
  • reduction:计算模式,可为none / sum / mean
torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')

用法示例:

loss = nn.CosineEmbeddingLoss()
input1 = torch.randn(3, 5, requires_grad=True)
input2 = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([1, -1, 1])
output = loss(input1, input2, target)

19. nn.CTCLoss

功能:连接时序分类(CTC)损失,用于无对齐或序列到序列问题,如语音或手写识别。

主要参数:

  • blank:blank label
  • zero_infinity:无穷大的值或梯度置0
  • reduction:计算模式,可为none / sum / mean
torch.nn.CTCLoss(blank=0, reduction='mean')

用法示例:

T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch
S_min = 10  # Minimum target length, for demonstration purposes
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
loss = nn.CTCLoss()
output = loss(input, target, input_lengths, target_lengths)

在实际的代码实现中,你需要根据你的模型和数据来调整输入和目标张量的尺寸

参考:https://yolov5.blog.csdn.net/article/details/123441628

参考:深度学习爱好者

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/244212.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序(十二)在线图标与字体的获取与引入

注释很详细,直接上代码 上一篇 新增内容: 1.从IconFont获取图标与文字的样式链接 2.将在线图标配置进页面中(源码) 3.将字体配置进页面文字中(源码) 4.css样式的多文件导入 获取链接 1.获取图标链接 登入…

如何自己制作一个属于自己的小程序?

在这个数字化时代,小程序已经成为了我们生活中不可或缺的一部分。它们方便快捷,无需下载安装,扫一扫就能使用。如果你想拥有一个属于自己的小程序,不论是为了个人兴趣,还是商业用途,都可以通过编程或者使用…

搭建k8s集群实战(一)系统设置

1、架构及服务 Kubernetes作为容器集群系统,通过健康检查+重启策略实现了Pod故障自我修复能力,通过调度算法实现将Pod分布式部署,并保持预期副本数,根据Node失效状态自动在其他Node拉起Pod,实现了应用层的高可用性。 针对Kubernetes集群,高可用性还应包含以下两个层面的…

回归预测 | Matlab基于OOA-SVR鱼鹰算法优化支持向量机的数据多输入单输出回归预测

回归预测 | Matlab基于OOA-SVR鱼鹰算法优化支持向量机的数据多输入单输出回归预测 目录 回归预测 | Matlab基于OOA-SVR鱼鹰算法优化支持向量机的数据多输入单输出回归预测预测效果基本描述程序设计参考资料 预测效果 基本描述 1.Matlab基于OOA-SVR鱼鹰算法优化支持向量机的数据…

IS-IS:01 ISIS基本配置

这是实验拓扑,下面是基本配置: R1: sys sysname R1 user-interface console 0 idle-timeout 0 0 int loop 0 ip add 1.1.1.1 24 int g0/0/0 ip add 192.168.12.1 24 qR2: sys sysname R2 user-interface console 0 idle-timeout 0 0 int loop 0 ip add …

05.Elasticsearch应用(五)

Elasticsearch应用(五) 1.目标 咱们这一章主要学习Mapping(映射) 2.介绍 Mapping是对索引库中文档的约束,类似于数据表结构,作用如下: 定义索引中的字段的名称定义字段的数据类型&#xff…

HarmonyOS鸿蒙学习基础篇 - 基本语法概述

书接上文 HarmonyOS鸿蒙学习基础篇 - 运行第一个程序 Hello World 基本语法概述 打开 entry>src>main>ets>pages>index.ets 代码如下代码详细解释如下: Entry //Entry装饰的自定义组件将作为UI页面的入口。在单个UI页面中,最多可以使用…

<蓝桥杯软件赛>零基础备赛20周--第16周--GCD和LCM

报名明年4月蓝桥杯软件赛的同学们,如果你是大一零基础,目前懵懂中,不知该怎么办,可以看看本博客系列:备赛20周合集 20周的完整安排请点击:20周计划 每周发1个博客,共20周。 在QQ群上交流答疑&am…

OpenCV第 1 课 计算机视觉和 OpenCV 介绍

文章目录 第 1 课 计算机视觉和 OpenCV 介绍1.机器是如何“看”的2.机器视觉技术的常见应用3.图像识别介绍4. 图像识别技术的常见应用5.OpenCV 介绍6.图像在计算机中的存储形式 第 1 课 计算机视觉和 OpenCV 介绍 1.机器是如何“看”的 我们人类可以通过眼睛看到五颜六色的世界…

MySQL InnoDB 底层数据存储

InnoDB 页记录Page Directory记录迁移 页 是内存与磁盘交互的基本单位,16kb。 比如,查询的时候,并不是只从磁盘读取某条记录,而是记录所在的页 记录 记录的物理插入是随机的,就是在磁盘上的位置是无序的。但是在页中…

一文讲透Redis的LRU与LFU算法实现

深入解析Redis的LRU与LFU算法实现 一、前言 Redis是一款基于内存的高性能NoSQL数据库,数据都缓存在内存里, 这使得Redis可以每秒轻松地处理数万的读写请求。 相对于磁盘的容量,内存的空间一般都是有限的,为了避免Redis耗尽宿主…

【Linux工具篇】编辑器vim

目录 vim的基本操作 进入vim(正常模式) 正常模式->插入模式 插入模式->正常模式 正常模式->底行模式 底行模式->正常模式 底行模式->退出vim vim正常模式命令集 vim插入模式命令集 vim末行模式命令集 vim操作总结 vim配置 Linux编译器…

小米浏览器打开H5页面表格无法滑动,如何解决?

问题: 小米浏览器打开H5页面表格无法滑动,出现此问题时,第一时间怀疑是代码的css样式适配问题,也做了很多样式适配的尝试,最后测试均没有解决无法滑动的问题。 转变思维: 脑海中突然闪现是否可以使用其他…

【Python进阶编程】python编程高手常用的设计模式(持续更新中)

Python编程高手通常熟练运用各种设计模式,这些设计模式有助于提高代码的可维护性、可扩展性和重用性。 以下是一些Python编程高手常用的设计模式: 1.单例模式(Singleton Pattern) 确保一个类只有一个实例,并提供全局…

[晓理紫]每日论文分享(有中文摘要,源码或项目地址)-大模型、扩散模型、视觉导航

专属领域论文订阅 关注{晓理紫|小李子},每日更新论文,如感兴趣,请转发给有需要的同学,谢谢支持 如果你感觉对你有所帮助,请关注我,每日准时为你推送最新论文。 分类: 大语言模型LLM视觉模型VLM扩散模型视觉…

机器学习预测全家桶之单变量输入多步预测,天气温度预测为例,MATLAB代码

截止到本期,一共发了8篇关于机器学习预测全家桶的文章。参考文章如下: 1.五花八门的机器学习预测?一篇搞定不行吗? 2.机器学习预测全家桶,多步预测之BiGRU、BiLSTM、GRU、LSTM,LSSVM、TCN、CNN,…

性能优化(CPU优化技术)-NEON指令介绍

「发表于知乎专栏《移动端算法优化》」 本文主要介绍了 NEON 指令相关的知识,首先通过讲解 arm 指令集的分类,NEON寄存器的类型,树立基本概念。然后进一步梳理了 NEON 汇编以及 intrinsics 指令的格式。最后结合指令的分类,使用例…

前端实现贪吃蛇功能

大家都玩过贪吃蛇小游戏,控制一条蛇去吃食物,然后蛇在吃到食物后会变大。本篇博客将会实现贪吃蛇小游戏的功能。 1.实现效果 2.整体布局 /*** 游戏区域样式*/ const gameBoardStyle {gridTemplateColumns: repeat(${width}, 1fr),gridTemplateRows: re…

【强化学习】QAC、A2C、A3C学习笔记

强化学习算法:QAC vs A2C vs A3C 引言 经典的REINFORCE算法为我们提供了一种直接优化策略的方式,它通过梯度上升方法来寻找最优策略。然而,REINFORCE算法也有其局限性,采样效率低、高方差、收敛性差、难以处理高维离散空间。 为…

面试题: Nginx 的优化思路有哪些?网站的防盗链如何做?

文章目录 拓扑图推荐步骤在Centos01上安装Nginx,设置网站根目录/www使用域名www.h.com访问配置Nginx配置DNS 验证Nginx日志切割在www.h.com网站配置防盗链防止www.hy.com盗www.h.com的连接 注:本文提到的网址仅不是实际存在的网站,仅作为技术…