pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item())  # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256))  # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int()  # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())

调用库(手动设置权重)

import torch
import torch.nn as nn# 假设有一些模型输出和目标标签
model_output = torch.randn(3, 5)  # 假设有5个类别
target = torch.empty(3, dtype=torch.long).random_(5)# 定义权重
weights = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])# 定义交叉熵损失函数,并设置权重
criterion = nn.CrossEntropyLoss(weight=weights)# 计算损失
loss = criterion(model_output, target)
print(loss)

自适应计算权重

import torch
import torch.nn as nn
import numpy as np# 假设我们有一个包含10个样本的批次,每个样本属于4个类别之一
batch_size = 10
num_classes = 4# 随机生成未经过 softmax 的logits输出(网络的最后一层输出)
logits = torch.randn(batch_size, num_classes, requires_grad=True)# 真实的标签(每个样本的类别索引),例如 [0, 2, 1, 3, 0, 0, 1, 2, 3, 3]
labels = torch.tensor([0, 2, 1, 3, 0, 0, 1, 2, 3, 3])# 统计每个类别的频率
class_counts = torch.bincount(labels, minlength=num_classes).float()# 计算每个类别的权重,权重可以为类别频率的倒数
# 为了防止分母为零,这里加一个小的常数epsilon
epsilon = 1e-6
class_weights = 1.0 / (class_counts + epsilon)# 归一化权重,使其和为1
class_weights /= class_weights.sum()print('Class Counts:', class_counts)
print('Class Weights:', class_weights)# 创建带权重的交叉熵损失函数
criterion = nn.CrossEntropyLoss(weight=class_weights)# 计算损失值
loss = criterion(logits, labels)print('Logits:\n', logits)
print('Labels:\n', labels)
print('Weighted Cross-Entropy Loss:', loss.item())# 反向传播梯度
loss.backward()

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/350248.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法01 递推算法及相关问题详解【C++实现】

目录 递推的概念 训练:斐波那契数列 解析 参考代码 训练:上台阶 参考代码 训练:信封 解析 参考代码 递推的概念 递推是一种处理问题的重要方法。 递推通过对问题的分析,找到问题相邻项之间的关系(递推式&a…

【机器学习】LightGBM: 优化机器学习的高效梯度提升决策树

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 LightGBM: 优化机器学习的高效梯度提升决策树引言一、LightGBM概览二、核心技术…

Go语言结构体内嵌接口

前言 在golang中,结构体内嵌结构体,接口内嵌接口都很常见,但是结构体内嵌接口很少见。它是做什么用的呢? 当我们需要重写实现了某个接口的结构体的(该接口)的部分方法,可以使用结构体内嵌接口。 作用 继承赋值给接口…

激活和禁用Hierarchy面板上的物体

1、准备工作: (1) 在HIerarchy上添加待隐藏/显示的物体,名字自取。如:endImage (2) 在Inspector面板,该物体的名称前取消勾选(隐藏) (3) 在HIerarchy上添加按钮,名字自取。如:tip…

chatgpt的命令词

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

汽车级TPSI2140QDWQRQ1隔离式固态继电器,TMUX6136PWR、TMUX1109PWR、TMUX1133PWR模拟开关与多路复用器(参数)

1、TPSI2140-Q1 是一款隔离式固态继电器,专为高电压汽车和工业应用而设计。 TPSI2140-Q1 与 TI 具有高可靠性的电容隔离技术和内部背对背 MOSFET 整合在一起,形成了一款完全集成式解决方案,无需次级侧电源。 该器件的初级侧仅由 9mA 的输入电…

CSS实现经典打字小游戏《生死时速》

🌻 前言 CSS 中有这样一个模块:Motion Path 运动模块,它可以使元素按照自定义的路径进行移动。本文将为你讲解这个模块属性的使用,并且利用它实现我小时候电脑课经常玩的一个打字游戏:金山打字的《生死时速》。 &…

僵尸网络相关

个人电脑被植入木马之后,就会主动的连接被黑客控制的这个C&C服务器,然后这个服务器就会给被植入木马的这个电脑发指令,让他探测在他的局域网内还有没有其他的电脑了,如果有那么就继续感染同局域网的其他病毒,黑客就…

随机森林算法进行预测(+调参+变量重要性)--血友病计数数据

1.读取数据 所使用的数据是血友病数据,如有需要,可在主页资源处获取,数据信息如下: import pandas as pd import numpy as np hemophilia pd.read_csv(D:/my_files/data.csv) #读取数据 2.数据预处理 在使用机器学习方法时&…

手机在网状态-手机在网状态查询-手机在网站状态接口

查询手机号在网状态,返回正常使用、停机、未启用/在网但不可用、不在网(销号/未启用/异常)、预销户等多种状态 直连三大运营商,实时更新,可查询实时在网状态 高准确率-实时更新,准确率99.99% 接口地址&…

Matlab的Simulink系统仿真(simulink调用m函数)

这几天要用Simulink做一个小东西,所以在网上现学现卖,加油! 起初的入门是看这篇文章MATLAB 之 Simulink 操作基础和系统仿真模型的建立_matlab仿真模型搭建-CSDN博客 写的很不错 后面我想在simulink中调用m文件 在 Simulink 中调用 MATLA…

解决Maven依赖引入不成功的问题

解决Maven依赖引入不成功的问题 确认IntelliJ IDEA中Maven的设置是否正确。 file --> settings --> maven 清除无效的jar,进入本地仓库清除或利用bat工具 以下是bat工具内容,运行即可。【把仓库地址换成你自己的地址进行无效jar包清除】 echo o…

复制网页文字和图片到Word中-Word插件-大珩助手

问题整理: 为什么从浏览器的网页上复制文字和图片后,在Word中粘贴时图片无法显示?有没有插件可以将网页中的文字和图片复制到Office Word 中? Word大珩助手是一款功能丰富的Office Word插件,旨在提高用户在处理文档时…

Jenkins三种构建类型

目录 传送门前言一、概念二、前置处理(必做)1、赋予777权限2、让jenkins用户拥有root用户的kill权限3、要运行jar包端口号需要大于1024 三、自由风格软件项目(FreeStyle Project)(推荐)三、Maven项目&#…

Bigtable: A Distributed Storage System for Structured Data

2003年USENIX,出自谷歌,开启分布式大数据时代的三篇论文之一,底层依赖 GFS 存储,上层供 MapReduce 查询使用 Abstract 是一种分布式结构化数据存储管理系统,存储量级是PB级别。存储的数据类型和延时要求差异都很大。…

植物ATAC-seq文献集锦(三)——果实发育篇

ATAC-seq在植物研究领域的应用我们已经介绍2期了,本期我们聚焦ATAC-seq技术在果实发育方向的应用案例。 植物ATAC-seq文献集锦(一)——基因组篇 植物ATAC-seq文献集锦(二)——生长发育篇 文献一:Ident…

electron基础使用

安装以及运行 当前node版本18,按照官网提供操作,npm init进行初始化操作,将index.js修改为main.js,执行npm install --save-dev electron。(这里我挂梯子下载成功了。),添加如下代码至package.…

【AI绘画】教你一个完美的文生图方法,简单易学好上手,新手也能轻松掌握Stable Diffusion使用!

当我们还在思索, AI(人工智能)是否能替代人类的地位, 它已悄然无声, 融入我们生活的点滴细节。 在艺术创作领域, AI技术作为核心力量,展现其无尽的魅力。 AI换脸、AI影像,AI角色、A…

Nas实现软路由OpenWrt安装

文章目录 基本配置步骤 基本配置 NAS:TS-264C 宇宙魔方 步骤 1.下载软路由OpenWrt 下载地址:https://openwrt.org/ 2.下载好以后,需要下载虚拟盘转换工具(StarWind V2V Convert) 下载地址:https://…

sprintboot容器功能

容器 容器功能Spring注入组件的注解Component,Controller,Service,Repository案例演示 Configuration应用实例传统方式使用Configuration 注意事项和细节 Import应用实例 ConditionalConditional介绍应用实例 ImportResource应用实例 配置绑定…