yolov8通过训练完成的模型生成图片热力图--论文需要

源代码来自于网络

使用pytorch_grad_cam,对特定图片生成热力图结果。
请添加图片描述

安装热力图工具

pip install pytorch_grad_cam
pip install grad-cam
# get_params中的参数:
# weight:
#         模型权重文件,代码默认是yolov8m.pt
# cfg:
#         模型文件,代码默认是yolov8m.yaml,需要注意的是需要跟weight中的预训练文件的配置是一样的,不然会报错
# device:
#         选择使用GPU还是CPU
# method:
#         选择grad-cam方法,默认是GradCAM,这里是提供了几种,可能对效果有点不一样,大家大胆尝试。
# layer::
#         选择需要可视化的层数,只需要修改数字即可,比如想用第9层,也就是model.model[9]。
# backward_type:
#         反向传播的方式,可以是以conf的loss传播,也可以class的loss传播,一般选用all,效果比较好一点。
# conf_threshold:
#         置信度,默认是0.6。
# ratio:
#         默认是0.02,就是用来筛选置信度高的结果,低的就舍弃,0.02则是筛选置信度最高的前2%的图像来进行热力图。![请添加图片描述](https://img-blog.csdnimg.cn/direct/4403f71e29314c68909ca28c037bd2b2.png)
import warningswarnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, cv2, os, shutil
import numpy as npnp.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from ultralytics.nn.tasks import DetectionModel as Model
from ultralytics.utils.torch_utils import intersect_dicts
from ultralytics.utils.ops import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsdef letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):# Resize and pad image while meeting stride-multiple constraintsshape = im.shape[:2]  # current shape [height, width]if isinstance(new_shape, int):new_shape = (new_shape, new_shape)# Scale ratio (new / old)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])if not scaleup:  # only scale down, do not scale up (for better val mAP)r = min(r, 1.0)# Compute paddingratio = r, r  # width, height ratiosnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh paddingif auto:  # minimum rectangledw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh paddingelif scaleFill:  # stretchdw, dh = 0.0, 0.0new_unpad = (new_shape[1], new_shape[0])ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratiosdw /= 2  # divide padding into 2 sidesdh /= 2if shape[::-1] != new_unpad:  # resizeim = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add borderreturn im, ratio, (dw, dh)class yolov8_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int32)self.__dict__.update(locals())def post_process(self, result):logits_ = result[:, 4:]boxes_ = result[:, :4]sorted, indices = torch.sort(logits_.max(1)[0], descending=True)return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2,lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, pre_post_boxes, post_boxes = self.post_process(result[0])print(post_result.size(0))for i in trange(int(post_result.size(0) * self.ratio)):if float(post_result[i].max()) < self.conf_threshold:breakself.model.zero_grad()# get max probability for this predictionif self.backward_type == 'class' or self.backward_type == 'all':score = post_result[i].max()score.backward(retain_graph=True)if self.backward_type == 'box' or self.backward_type == 'all':for j in range(4):score = pre_post_boxes[i, j]score.backward(retain_graph=True)# process heatmapif self.backward_type == 'class':gradients = grads.gradients[0]elif self.backward_type == 'box':gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3]else:gradients = grads.gradients[0] + grads.gradients[1] + grads.gradients[2] + grads.gradients[3] + \grads.gradients[4]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations,gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': './weights/bz-yolov8-aspp-s-100.pt', # 这选择想要热力可视化的模型权重路径'cfg': './ultralytics/cfg/models/cfg2024/YOLOv8-金字塔结构改进/YOLOv8-ASPP.yaml', # 这里选择与训练上面模型权重相对应的.yaml文件路径'device': 'cpu', # 选择设备,其中0表示0号显卡。如果使用CPU可视化 # 'device': 'cpu' cuda:0'method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[6]',   # 选择特征层'backward_type': 'all', # class, box, all'conf_threshold': 0.65, # 置信度阈值默认0.65, 可根据情况调节'ratio': 0.02 # 取前多少数据,默认是0.02,可根据情况调节}return paramsif __name__ == '__main__':model = yolov8_heatmap(**get_params()) # 初始化model('output_002.jpg', './result') # 第一个参数是图片的路径,第二个参数是保存路径,比如是result的话,其会创建一个名字为result的文件夹,如果result文件夹不为空,其会先清空文件夹。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/350262.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】集成学习方法:Bagging与Boosting的应用与优势

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 引言一、集成学习的定义二、Bagging方法1. 随机森林&#xff08;Random Forest&#xff09;2. 其他Bagging方法 二、Boosting方法1. 梯度提升树&#xff08;Gradient Boosting Machine, GBM&#xff09;解释GBM的基本原理和…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 特惠寿司(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 📎在线评测链接 特惠寿司(100分) 🌍 评测功能需要订阅专栏后私信联系清隆解…

9. 文本三剑客之awk

文章目录 9.1 什么是awk9.2 awk命令格式9.3 awk执行流程11.4 行与列11.4.1 取行11.4.2 取列 9.1 什么是awk 虽然sed编辑器是非常方便自动修改文本文件的工具&#xff0c;但其也有自身的限制。通常你需要一个用来处理文件中的数据的更高级工具&#xff0c;它能提供一个类编程环…

【二】【动态规划NEW】91. 解码方法,62. 不同路径,63. 不同路径 II

91. 解码方法 一条包含字母 A-Z 的消息通过以下映射进行了 编码 &#xff1a; ‘A’ -> “1” ‘B’ -> “2” … ‘Z’ -> “26” 要 解码 已编码的消息&#xff0c;所有数字必须基于上述映射的方法&#xff0c;反向映射回字母&#xff08;可能有多种方法&#xff…

利用74HC165实现8路并行输入口的扩展

代码&#xff1a; #include <mega16.h>// Declare your global variables here #define hc165_clk PORTB.0 #define hc165_lp PORTB.1 #define hc165_out PINB.2unsigned char read_hc165(void) {unsigned char data0,i,temp0x80;hc165_lp0;hc165_lp1; for(i0;i<7;i)…

Git 基础操作(一)

Git 基础操作 配置Git 安装完Git后&#xff0c;首先要做的事情是设置你的 用户名 和 e-mail 地址。这样在你向仓库提交代码的时候&#xff0c;就知道是谁提交的&#xff0c;以及提交人的联系方式。 配置用户名和邮箱 使用git config [--global] user.name "你的名字&qu…

碳中和研究院OLED透明屏2x2整机项目方案

一、项目背景 随着全球气候变化和环境问题的日益严重&#xff0c;碳中和成为各国政府和企业的重要议题。为了响应这一趋势&#xff0c;黑龙江碳中和研究院计划引入先进的OLED透明屏技术&#xff0c;以展示其研究成果和碳中和理念。本项目旨在为该研究院提供一套高质量的OLED透明…

干部选拔任用的六条原则

在干部选拔任用的过程中&#xff0c;为确保选拔出的干部能够真正符合党和人民的期望&#xff0c;必须遵循以下六条原则&#xff1a; 一、党管干部原则 党管干部原则是指在整个干部选拔任用过程中&#xff0c;党要发挥总揽全局、协调各方的领导作用&#xff0c;确保选拔出的干…

pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同 1.标签为long类型&#xff0c;BCE标签为float类型 2.当reduction为mean时计算每个像素点的损失的平均&#xff0c;BCE除以像素数得到平均值&#xff0c;CE除以像素对应的权重之和得到平均值。 参数配置torch.nn.CrossEntropyLoss(weightNone,…

算法01 递推算法及相关问题详解【C++实现】

目录 递推的概念 训练&#xff1a;斐波那契数列 解析 参考代码 训练&#xff1a;上台阶 参考代码 训练&#xff1a;信封 解析 参考代码 递推的概念 递推是一种处理问题的重要方法。 递推通过对问题的分析&#xff0c;找到问题相邻项之间的关系&#xff08;递推式&a…

【机器学习】LightGBM: 优化机器学习的高效梯度提升决策树

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 LightGBM: 优化机器学习的高效梯度提升决策树引言一、LightGBM概览二、核心技术…

Go语言结构体内嵌接口

前言 在golang中&#xff0c;结构体内嵌结构体&#xff0c;接口内嵌接口都很常见&#xff0c;但是结构体内嵌接口很少见。它是做什么用的呢&#xff1f; 当我们需要重写实现了某个接口的结构体的(该接口)的部分方法&#xff0c;可以使用结构体内嵌接口。 作用 继承赋值给接口…

激活和禁用Hierarchy面板上的物体

1、准备工作&#xff1a; (1) 在HIerarchy上添加待隐藏/显示的物体&#xff0c;名字自取。如&#xff1a;endImage (2) 在Inspector面板&#xff0c;该物体的名称前取消勾选&#xff08;隐藏&#xff09; (3) 在HIerarchy上添加按钮&#xff0c;名字自取。如&#xff1a;tip…

chatgpt的命令词

人不走空 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌赋&#xff1a;斯是陋室&#xff0c;惟吾德馨 目录 &#x1f308;个人主页&#xff1a;人不走空 &#x1f496;系列专栏&#xff1a;算法专题 ⏰诗词歌…

汽车级TPSI2140QDWQRQ1隔离式固态继电器,TMUX6136PWR、TMUX1109PWR、TMUX1133PWR模拟开关与多路复用器(参数)

1、TPSI2140-Q1 是一款隔离式固态继电器&#xff0c;专为高电压汽车和工业应用而设计。 TPSI2140-Q1 与 TI 具有高可靠性的电容隔离技术和内部背对背 MOSFET 整合在一起&#xff0c;形成了一款完全集成式解决方案&#xff0c;无需次级侧电源。 该器件的初级侧仅由 9mA 的输入电…

CSS实现经典打字小游戏《生死时速》

&#x1f33b; 前言 CSS 中有这样一个模块&#xff1a;Motion Path 运动模块&#xff0c;它可以使元素按照自定义的路径进行移动。本文将为你讲解这个模块属性的使用&#xff0c;并且利用它实现我小时候电脑课经常玩的一个打字游戏&#xff1a;金山打字的《生死时速》。 &…

僵尸网络相关

个人电脑被植入木马之后&#xff0c;就会主动的连接被黑客控制的这个C&C服务器&#xff0c;然后这个服务器就会给被植入木马的这个电脑发指令&#xff0c;让他探测在他的局域网内还有没有其他的电脑了&#xff0c;如果有那么就继续感染同局域网的其他病毒&#xff0c;黑客就…

随机森林算法进行预测(+调参+变量重要性)--血友病计数数据

1.读取数据 所使用的数据是血友病数据&#xff0c;如有需要&#xff0c;可在主页资源处获取&#xff0c;数据信息如下&#xff1a; import pandas as pd import numpy as np hemophilia pd.read_csv(D:/my_files/data.csv) #读取数据 2.数据预处理 在使用机器学习方法时&…

手机在网状态-手机在网状态查询-手机在网站状态接口

查询手机号在网状态&#xff0c;返回正常使用、停机、未启用/在网但不可用、不在网&#xff08;销号/未启用/异常&#xff09;、预销户等多种状态 直连三大运营商&#xff0c;实时更新&#xff0c;可查询实时在网状态 高准确率-实时更新&#xff0c;准确率99.99% 接口地址&…

Matlab的Simulink系统仿真(simulink调用m函数)

这几天要用Simulink做一个小东西&#xff0c;所以在网上现学现卖&#xff0c;加油&#xff01; 起初的入门是看这篇文章MATLAB 之 Simulink 操作基础和系统仿真模型的建立_matlab仿真模型搭建-CSDN博客 写的很不错 后面我想在simulink中调用m文件 在 Simulink 中调用 MATLA…