三、损失函数


损失函数

  • 前言
  • 一、分类问题的损失函数
    • 1.1 二分类损失函数
      • 1.1.1 数学定义
      • 1.1.2 函数解释:
      • 1.1.3 性质
      • 1.1.4 计算演示
      • 1.1.5 代码演示
    • 1.2 多分类损失函数
      • 1.1.1 数学定义
      • 1.1.2 性质与特点
      • 1.1.3 计算演示
      • 1.1.4 代码演示
  • 二、回归问题的损失函数
    • 2.1 MAE损失
    • 2.2 MSE损失
    • 2.3 Smooth L1损失
  • 总结


前言


  • 在深度学习中, 损失函数是用来衡量模型参数的质量的函数, 衡量的方式是比较网络输出和真实输出的差异:
  • 在这里插入图片描述

一、分类问题的损失函数

1.1 二分类损失函数

1.1.1 数学定义

  • 处理二分类问题的时候,我们使用 sigmoid 激活函数,使用二分类的交叉熵损失函数:
    L ( y , y ^ ) = − [ y l o g ( y ^ ) + ( 1 − y ) l o g ( 1 − y ^ ) ] L(y, \hat y)=−[ylog( \hat y )+(1−y)log(1− \hat y)] L(y,y^)=[ylog(y^)+(1y)log(1y^)]
  • 这里的 l o g log log 是以自然常数 e e e(约等于2.71828)为底的,即自然对数
  • 模型输出的预测概率为 y ^ \hat y y^ (在 0 到 1 之间),真实的标签为 y y y (取值为 0 或 1)

1.1.2 函数解释:

  • y = 1 y = 1 y=1 的时候,也就是真实值为 1 的时候
    L ( 1 , y ^ ) = − l o g ( y ^ ) L(1, \hat y) =−log(\hat y) L(1,y^)=log(y^)

    • 此时如果模型预测的 y ^ \hat y y^ 越接近 1,损失值 L L L 越小;如果越接近0,损失值 L L L 越大。
  • y = 0 y = 0 y=0 的时候,也就是真实值为 1 的时候
    L ( 0 , y ^ ) = − l o g ( 1 − y ^ ) L(0, \hat y) =−log(1-\hat y) L(0,y^)=log(1y^)

    • 此时如果模型预测的 y ^ \hat y y^ 越接近 0,损失值 L L L 越小;如果越接近1,损失值 L L L 越大。

1.1.3 性质

  • 非负性:交叉熵损失函数总是非负的,即 L ( y , y ^ ) ≥ 0 L(y, \hat y )≥0 L(y,y^)0
  • 严格凸函数:在 y ^ \hat y y^ 的定义域(0, 1)内,交叉熵损失函数是严格凸的,因此有助于优化过程找到全局最优解。
  • 对数尺度:由于使用对数函数,交叉熵损失函数对预测概率的微小变化非常敏感,这使得它在梯度下降等优化算法中表现良好。

1.1.4 计算演示

  • 假设我们有一个样本,真实标签 y = 1 y=1 y=1,模型预测的概率为 y ^ \hat y y^ =0.9:
    L ( 1 , 0.9 ) = − l o g ( 0.9 ) = − l n ( 0.9 ) ≈ 0.1054 L(1,0.9)=−log(0.9)= -ln(0.9)≈0.1054 L(1,0.9)=log(0.9)=ln(0.9)0.1054
  • 真实标签 y = 1 y=1 y=1,如果模型预测的概率为 y ^ \hat y y^ = 0.5:
    L ( 1 , 0.5 ) = − l o g ( 0.5 ) = − l n ( 0.5 ) ≈ 0.6931 L(1,0.5)=−log(0.5)= -ln(0.5)≈0.6931 L(1,0.5)=log(0.5)=ln(0.5)0.6931
  • 可以看到,当预测概率接近真实标签时,损失值较小;反之,损失值较大。

1.1.5 代码演示

代码演示 :

import torch
from torch import nndef my_BCELoss():# 1 设置真实值和预测值# 预测值是sigmoid输出的结果y_pred = torch.tensor([0.6901, 0.5459, 0.2469], requires_grad=True)y_true = torch.tensor([0, 1, 0], dtype=torch.float32)# 2 实例化二分类交叉熵损失criterion = nn.BCELoss()# 3 计算损失my_loss = criterion(y_pred, y_true).detach().numpy()# detach() 用于将损失函数从计算图中分离出来, numpy() 将tensor转成numpy,# 这样打印出来就会是个数字 print('loss:', my_loss)if __name__ == '__main__':my_BCELoss()   # loss: 0.6867941

1.2 多分类损失函数

1.1.1 数学定义

  • 处理多分类的问题的时候,我们就不能使用二分类的损失函数,我们需要使用多分类的损失函数,在多分类任务通常使用 softmax 将 logits(原始分数输出) 转换为概率的形式,所以多分类的交叉熵损失也叫做softmax损失,它的计算方法是:
    H ( y , y ^ ) = − ∑ i = 1 N y i l o g ( y ^ i ) H(y, \hat y)=−\sum_{i=1}^{N}y_ilog(\hat y_i) H(y,y^)=i=1Nyilog(y^i)
    y ^ i = s o f t m a x ( f ( x i ) ) \hat y_i = softmax(f(x_i)) y^i=softmax(f(xi))
    • f ( x i ) f(x_i) f(xi) 是对第 i 个类别的预测分数
    • y ^ i \hat y_i y^i 是预测为第 i 类的概率
    • y y y 是真实的标签分布,通常表示为 one-hot 编码向量,即只有一个元素为 1(表示真实类别),其余元素为 0,也可以不进行 热编码,在下边的代码演示中有提及
    • y ^ \hat y y^ 是模型预测的概率分布,由模型的输出层经过softmax函数转换得到,表示每个类别的预测概率
    • N N N是类别的总数
    • 这里的 l o g log log 是以自然常数 e e e(约等于2.71828)为底的,即自然对数

注意:我们不能在网络层的输出层通过softmax函数进行激活,这是因为多分类损失函数会对输入的结果先进行softmax变化,所以会影响损失值的大小,还有准确率

1.1.2 性质与特点

  • 非负性:交叉熵损失总是非负的,当且仅当预测分布与真实分布完全一致时,损失为 0
  • 敏感性:交叉熵损失函数对正确分类的概率非常敏感。如果实际类别的预测概率低(即接近于0),那么损失将会非常高
  • 非对称性:在处理极端概率(接近0或1)时,交叉熵损失表现出明显的非对称性。特别是当预测概率趋近于0时,损失会迅速增加
  • 凸性:交叉熵损失函数是凸函数,这有助于保证优化过程的稳定性和有效性
  • 适合多类别问题:交叉熵损失函数能够很好地处理多类别分类问题,通过分别计算每个类别的损失并求和来得到总损失

1.1.3 计算演示

  • 例子背景

    • 假设我们有一个三分类问题(例如,识别图像中的猫、狗或鸟),并且我们有一个训练好的神经网络模型。模型的输出层有3个神经元,每个神经元对应一个类别的预测得分(也称为logits)。这些得分通过softmax函数转换为概率分布,表示每个类别的预测概率。
  • 数据与标签

    • 假设我们有一个输入图像,其真实标签是“狗”(在one-hot编码中表示为[0, 1, 0])。模型对该图像的预测输出(经过softmax之前的logits)为[2.0, 1.0, 0.1],经过softmax转换后的概率分布为[0.33, 0.50, 0.17]。
  • 计算交叉熵损失

    • 真实标签(one-hot编码):y = [0, 1, 0]

    • 预测概率: y ^ \hat y y^ = [0.33, 0.50, 0.17]

    • 交叉熵损失的计算公式为:
      H ( y , y ^ ) = − ∑ i = 1 N y i l o g ( y ^ i ) H(y, \hat y)=−\sum_{i=1}^{N}y_ilog(\hat y_i) H(y,y^)=i=1Nyilog(y^i)

    • 将真实标签和预测概率代入公式中,我们得到:
      H ( y , y ^ ) = − ( 0 ⋅ l o g ( 0.33 ) + 1 ⋅ l o g ( 0.50 ) + 0 ⋅ l o g ( 0.17 ) ) = − l o g ( 0.50 ) ≈ 0.693 H(y, \hat y)=−(0⋅log(0.33)+1⋅log(0.50)+0⋅log(0.17)) =−log(0.50) ≈0.693 H(y,y^)=(0log(0.33)+1log(0.50)+0log(0.17))=log(0.50)0.693

  • 因此,该输入图像的交叉熵损失约为0.693。这个值表示模型预测的概率分布与真实标签分布之间的差异程度。损失值越小,表示模型的预测越准确。

1.1.4 代码演示

代码演示 :

import torch
from torch import nn# 多分类交叉熵损失,使用nn.CrossEntropyLoss()实现。nn.CrossEntropyLoss()=softmax + 损失计算
def my_CrossEntropyLoss():# 设置真实值: 可以是热编码后的结果也可以不进行热编码y_true = torch.tensor([[0, 1, 0], [0, 0, 1]], dtype=torch.float32)# 不解码的时候, 需要传入真实类别的tensor# 注意的类型必须是64位整型数据# 意思是 真实标签是 1(即第二类)真实标签是 2(即第三类) # 对应下边的 y_pred 有 0.2 的概率是 0 有0.6 的概率是1 有0.2 的概率是 2# y_true = torch.tensor([1, 2], dtype=torch.int64)y_pred = torch.tensor([[0.2, 0.6, 0.2], [0.1, 0.8, 0.1]], dtype=torch.float32)# 实例化交叉熵损失loss = nn.CrossEntropyLoss()# 计算损失结果my_loss = loss(y_pred, y_true).numpy()print('loss:', my_loss)if __name__ == '__main__':my_CrossEntropyLoss()

二、回归问题的损失函数

2.1 MAE损失

  • Mean absolute loss(MAE)也被称为 L1 Loss,是以绝对误差作为距离。损失函数公式:
    L = 1 N ∑ i = 1 N ∣ y i − y ^ ∣ L = \frac{1}{N}\sum_{i=1}^{N} |y_i-\hat y| L=N1i=1Nyiy^

    • N N N 是样本数量
    • y i y_i yi 是第 i i i 个样本的真实值
    • y ^ \hat y y^ 是第 i i i 个样本的预测值
  • 特点:

    • 由于L1 loss具有稀疏性,为了惩罚较大的值,因此常常将其作为正则项添加到其他loss中作为约束。
    • L1 loss的最大问题是梯度在零点不平滑,导致会跳过极小值。
  • 图像

    • 在这里插入图片描述
  • 代码演示

代码演示 :

import torch
from torch import nndef my_L1Loss():# 1 设置真实值和预测值y_pred = torch.tensor([1.0, 1.0, 1.9], requires_grad=True)y_true = torch.tensor([2.0, 2.0, 2.0], dtype=torch.float32)# 2 实例MAE损失对象loss = nn.L1Loss()# 3 计算损失my_loss = loss(y_pred, y_true).detach().numpy()print('loss:', my_loss)if __name__ == '__main__':my_L1Loss()

2.2 MSE损失

  • Mean Squared Loss/ Quadratic Loss(MSE loss)也被称为L2 loss,或欧氏距离,它以误差的平方和的均值作为距离,损失函数公式:
    L = 1 N ∑ i = 1 N ( y i − y ^ ) 2 L = \frac{1}{N}\sum_{i=1}^{N} (y_i-\hat y)^2 L=N1i=1N(yiy^)2

    • N N N 是样本数量
    • y i y_i yi 是第 i i i 个样本的真实值
    • y ^ \hat y y^ 是第 i i i 个样本的预测值
  • 特点:

    • .L2 loss也常常作为正则项
    • 当预测值与目标值相差很大时, 梯度容易爆炸
  • 图像

    • 在这里插入图片描述
  • 代码演示

代码演示 :

import torch
from torch import nndef my_MSELoss():# 1 设置真实值和预测值y_pred = torch.tensor([1.0, 1.0, 1.9], requires_grad=True)y_true = torch.tensor([2.0, 2.0, 2.0], dtype=torch.float32)# 2 实例MSE损失对象loss = nn.MSELoss()# 3 计算损失my_loss = loss(y_pred, y_true).detach().numpy()print('myloss:', my_loss)if __name__ == '__main__':my_MSELoss()

2.3 Smooth L1损失

  • smooth L1说的是光滑之后的L1。损失函数公式:
    Smooth L1 ( x ) = { 0.5 × x 2 if  ∣ x ∣ < 1 ∣ x ∣ − 0.5 otherwise \text{Smooth L1}(x) = \begin{cases} 0.5 \times x ^2 & \text{if } |x| < 1 \\ |x| - 0.5 & \text{otherwise} \end{cases} Smooth L1(x)={0.5×x2x0.5if x<1otherwise

    • 其中, x x x 就是真实值与预测值的差值
  • 图像

    • 在这里插入图片描述
  • 特点:从上述图像中我们可以看出

    • .在 [-1,1] 之间实际上就是 L2 损失,这样解决了 L1 的不光滑问题
    • 在 [-1,1 ] 区间外,实际上就是 L1 损失,这样就解决了离群点梯度爆炸的问题
  • 代码演示

代码演示 :

import torch
from torch import nndef my_SmoothL1Loss():# 1 设置真实值和预测值y_true = torch.tensor([1, 1])y_pred = torch.tensor([0.6, 0.4], requires_grad=True)# 2 实例化smoothL1损失对象loss = nn.SmoothL1Loss()# 3 计算损失my_loss = loss(y_pred, y_true).detach().numpy()print('loss:', my_loss)if __name__ == '__main__':my_SmoothL1Loss()

总结

  • 我们总结了神经网络中常用的损失函数,在不同的情况下,我们需要搭配不同的损失函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/471145.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PNG图片批量压缩exe工具+功能纯净+不改变原始尺寸

小编最近有一篇png图片要批量压缩&#xff0c;大小都在5MB之上&#xff0c;在网上找了半天要么就是有广告&#xff0c;要么就是有毒&#xff0c;要么就是功能复杂&#xff0c;整的我心烦意乱。 于是我自己用python写了一个纯净工具&#xff0c;只能压缩png图片&#xff0c;没任…

测试工程师简历「精选篇」

【#测试工程师简历#】一份专业且引人注目的测试工程师简历&#xff0c;无疑是你敲开理想职位大门的金钥匙。那么&#xff0c;如何撰写一份既体现技术水平又彰显个人特色的简历呢&#xff1f;以下是幻主简历网整理的测试工程师简历「程序员篇」&#xff0c;欢迎大家阅读收藏&…

git下载慢下载不了?Git国内国外下载地址镜像,git安装视频教程

git安装下载的视频教程在这 3分钟完成git下载和安装&#xff0c;git国内外下载地址镜像&#xff0c;Windows为例_哔哩哔哩_bilibili 一、Git安装包国内和国外下载地址镜像 1.1国外官方下载地址 打开Git的官方网站&#xff1a;Git官网下载页面。在页面上选择对应的系统&…

专题十八_动态规划_斐波那契数列模型_路径问题_算法专题详细总结

目录 动态规划 动态规范五步走&#xff1a; 1. 第 N 个泰波那契数&#xff08;easy&#xff09; 解析&#xff1a; 1.状态表达式&#xff1a; 2.状态转移方程&#xff1a; 3.初始化&#xff1a; 4.填表顺序&#xff1a; 5.返回值 编写代码&#xff1a; 总结&#xff…

阿里云centos7.9服务器磁盘挂载,切换服务路径

项目背景 1、项目使用的服务器为阿里云centos7.9&#xff0c;默认的磁盘为vda&#xff0c;文件系统挂载在这个磁盘上&#xff0c;项目上使用的文件夹为/home/hnst/uploadPath 2、vda使用率已达到91% 3、现购置一块新的磁盘为vdb&#xff0c;大小为2T 目的 切换服务所使用的…

STM32问题集

这里写目录标题 一、烧录1、 Can not connect to target!【ST-LINK烧录】 一、烧录 1、 Can not connect to target!【ST-LINK烧录】 烧录突然 If the target is in low power mode, please enable “Debug in Low Power mode” option from Target->settings menu 然后就&…

Scala学习记录,case class,迭代器

case class case class创建的对象的属性是不可改的 创建对象&#xff0c;可以不用写new 自动重写&#xff1a;toString, equals, hashCode, copy 自动重写方法&#xff1a;toString,equals,hashCode,copy 小习一下 1.case class 的定义语法是什么 基本形式&#xff1a;case …

成都睿明智科技有限公司解锁抖音电商新玩法

在这个短视频风起云涌的时代&#xff0c;抖音电商以其独特的魅力迅速崛起&#xff0c;成为众多商家争夺的流量高地。而在这片充满机遇与挑战的蓝海中&#xff0c;成都睿明智科技有限公司犹如一颗璀璨的新星&#xff0c;以其专业的抖音电商服务&#xff0c;助力无数品牌实现从零…

阅读2020-2023年《国外军用无人机装备技术发展综述》笔记_技术趋势

目录 文献基本信息 序言 1 发展概况 2 重点技术发展 2.1 人工智能技术 2.1.1 应用深化 2.1.2 作战效能提升 2.2 航空技术 2.2.1螺旋桨设计创新 2.2.2 发射回收技术进步 2.3 其他相关技术 2.3.1 远程控制技术探 2.3.2 云地控制平台应用 3 装备系统进展 3.1 无人作…

LeetCode 86.分隔链表

题目&#xff1a; 给你一个链表的头节点 head 和一个特定值 x &#xff0c;请你对链表进行分隔&#xff0c;使得所有 小于 x 的节点都出现在 大于或等于 x 的节点之前。 你应当 保留 两个分区中每个节点的初始相对位置。 思路&#xff1a; 代码&#xff1a; /*** Definiti…

SystemVerilog学习笔记(六):控制流

条件语句 条件语句用于检查块中的语句是否被执行。条件语句创建语句块。如果给出的表达式是 true&#xff0c;执行块中的语句集&#xff0c;如果表达式为 false&#xff0c;则 else 块语句将最后执行。 序号条件语句1.if2.if-else3.if-else ladder4.unique if5.unique0 if6.p…

SQL,力扣题目1127, 用户购买平台

一、力扣链接 LeetCode_1127 二、题目描述 支出表: Spending ---------------------- | Column Name | Type | ---------------------- | user_id | int | | spend_date | date | | platform | enum | | amount | int | ------------------…

【计算机网络】【传输层】【习题】

计算机网络-传输层-习题 文章目录 10. 图 5-29 给出了 TCP 连接建立的三次握手与连接释放的四次握手过程。根据 TCP 协议的工作原理&#xff0c;请填写图 5-29 中 ①~⑧ 位置的序号值。答案技巧 注&#xff1a;本文基于《计算机网络》&#xff08;第5版&#xff09;吴功宜、吴英…

群控系统服务端开发模式-应用开发-前端个人信息功能

个人信息功能我把他分为了3部分&#xff1a;第一部分是展示登录者信息&#xff1b;第二步就是登录者登录退出信息&#xff1b;第三部分就是修改个人资料。 一、展示登录者信息 1、优先添加固定路由 在根目录下src文件夹下route文件夹下index.js文件中&#xff0c;添加如下代码 …

Qwen2-VL:发票数据提取、视频聊天和使用 PDF 的多模态 RAG 的实践指南

概述 随着人工智能技术的迅猛发展&#xff0c;多模态模型在各类应用场景中展现出强大的潜力和广泛的适用性。Qwen2-VL 作为最新一代的多模态大模型&#xff0c;融合了视觉与语言处理能力&#xff0c;旨在提升复杂任务的执行效率和准确性。本指南聚焦于 Qwen2-VL 在三个关键领域…

Java面向对象高级2

1.代码块 2.内部类 成员内部类 public class Demo{public static void main(String[] args) {outer.inner innew outer().new inner();in.run();}}class outer{private String str"outer";public class inner{public void run(){String sstr;System.out.println(s);…

Elasticsearch 8.16:适用于生产的混合对话搜索和创新的向量数据量化,其性能优于乘积量化 (PQ)

作者&#xff1a;来自 Elastic Ranjana Devaji, Dana Juratoni Elasticsearch 8.16 引入了 BBQ&#xff08;Better Binary Quantization - 更好的二进制量化&#xff09;—— 一种压缩向量化数据的创新方法&#xff0c;其性能优于传统方法&#xff0c;例如乘积量化 (Product Qu…

androidstudio下载gradle慢

1&#xff0c;现象&#xff1a; 2&#xff0c;原因&#xff0c;国内到国外网址慢 3&#xff0c;解决方法&#xff1a;更改gradle-wrapper.properties #Wed Sep 26 20:01:52 CST 2018 distributionBaseGRADLE_USER_HOME distributionPathwrapper/dists zipStoreBaseGRADLE_USER…

浅谈:基于三维场景的视频融合方法

视频融合技术的出现可以追溯到 1996 年 , Paul Debevec等 提出了与视点相关的纹理混合方法 。 也就是说 &#xff0c; 现实的漫游效果不是从摄像机的角度来看 &#xff0c; 但其仍然存在很多困难 。基于三维场景的视频融合 &#xff0c; 因其直观等特效在视频监控等相关领域有着…

探索Python的HTTP利器:Requests库的神秘面纱

文章目录 **探索Python的HTTP利器&#xff1a;Requests库的神秘面纱**一、背景&#xff1a;为何选择Requests库&#xff1f;二、Requests库是什么&#xff1f;三、如何安装Requests库&#xff1f;四、Requests库的五个简单函数使用方法1. GET请求2. POST请求3. PUT请求4. DELET…