【yolov8】yolov8剪枝训练流程

yolov8剪枝训练流程

流程:

  • 约束
  • 剪枝
  • 微调

一、正常训练

yolo train model=./weights/yolov8s.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=train

二、约束训练

2.1 修改YOLOv8代码:

ultralytics/yolo/engine/trainer.py
添加内容:

# Backwardself.scaler.scale(self.loss).backward()# ========== 新增 ==========l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# ========== 新增 ==========# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni

2.2 训练

需要注意的就是amp=False

yolo train model=prunt/train/weights/best.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=constraint

训练完会得到一个best.pt和last.pt,推荐用last.pt

三、剪枝

上一步得到的last.pt作为剪枝对象,运行项目中的prun.py文件:

*这里的剪枝代码仅适用yolov8原模型,如有模块/模型的更改,则需要修改剪枝代码*

运行完会得到prune.pt和prune.onnx可以在netron.app网站拖入onnx文件查看是否剪枝成功了,成功的话可以看到某些通道数字为单数或者一些不规律的数字,如下图:

在这里插入图片描述

左侧为未剪枝的模型,右侧为剪枝后的模型。

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

四、 回调训练(finetune)

回调训练的唯一关键点就在于不让模型从yaml文件加载结构,直接加载pt文件

两种方法(因yolov8版本不同而选择不同方法):

方法一:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的443行左右

self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)# ========== 新增该行代码 ==========self.model = weights# ========== 新增该行代码 ==========return ckpt

方法二:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的335行左右

if not args.get('resume'):  # manually set model only if not resuming# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# self.model = self.trainer.model######################上面两行注释掉,添加下面一行#####self.trainer.model = self.model.train()##########################修改####################self.trainer.hub_session = self.session  # attach optional HUB session
3.3 修改完代码就可以进行finetun训练了

命令行输入:

yolo train model=prun/prune/weights/last_prune.pt data="yolo_bvn.yaml" amp=False epochs=100 project=prun name=finetune device=0

五、结果展示:

5.1模型大小:ONNX模型大小从42M减少到34M

在这里插入图片描述

5.2PR曲线:

正常训练约束训练100轮微调
在这里插入图片描述在这里插入图片描述在这里插入图片描述

5.3实测视频在ubuntu上检测速度:

未剪枝:平均每帧5毫秒

剪枝后:平均每帧3.7毫秒

六、问题及解决:

对剪枝完的yolov8进行finetune时遇到RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

self.proj 可能不在与 pred_dist 相同的设备上。这可能是因为 self.proj 被指定在 CPU 上,而 pred_dist 在 GPU 上(或反之)。
要解决这个问题,需要确保两个张量位于相同的设备上。可以使用 to() 方法将 self.proj 放到与 pred_dist 相同的设备上。

解决:在loss.py添加如下代码:

def bbox_decode(self, anchor_points, pred_dist):"""Decode predicted object bounding box coordinates from anchor points and distribution."""if self.use_dfl:b, a, c = pred_dist.shape  # batch, anchors, channels####添加device = pred_dist.deviceself.proj = self.proj.to(device)#####pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)return dist2bbox(pred_dist, anchor_points, xywh=False)

七、参考:

7.1 【yolov8系列】 yolov8 目标检测的模型剪枝_yolov8 剪枝-CSDN博客
7.2 YOLOv8剪枝全过程-CSDN博客

7.3 剪枝与重参第七课:YOLOv8剪枝-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/320859.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习之基于Vgg19预训练卷积神经网络图像风格迁移系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 在数字艺术和图像处理领域,图像风格迁移技术一直备受关注。该技术可以将一幅图像的内容和…

MATLAB实现杜拉德公式和凯夫公式的计算固液混合料浆临界流速

MATLAB实现杜拉德公式和凯夫公式的计算固液混合料浆临界流速: 杜拉德公式是用来计算非均质固液混合料浆在输送管中的临界速度的公式,具体形式为: uL FL (2gD / (ρ0 - ρ1))^(1/2) 其中: uL:表示料浆的临界速度,…

什么是泛域名证书?与普通SSL证书有什么区别

随着互联网的发展,越来越多的网站开始使用SSL证书来保护用户的隐私和安全。在SSL证书中,泛域名SSL证书和普通域名证书是两种常见的类型。那么,什么是泛域名SSL证书,与普通域名证书有什么区别呢? 首先,我们来…

投资者悄然收购二手楼梯楼,在杭州豪掷巨资购买12套!

独家首发 -------------- 日前杭州中介流传,一名投资客大举收购二手楼梯楼,下手就是12套,显示出一些具有前瞻性眼光的投资者悄悄放弃电梯楼,选择了处于价格洼地的楼梯楼。 二手楼梯楼当下被严重低估,在一线城市的二手楼…

【文献阅读】 The ITS Irregular Terrain Model(Longely-Rice模型)海上电波传播模型

前言 因为最近在做海上通信的一个项目,所以需要对海上的信道进行建模,所以才阅读到了这一篇文献,下面的内容大部分是我的个人理解,如有错误,请见谅。欢迎在评论区和我一起讨论。 Longely-Rice模型介绍 频率介于 20 …

AI摄影教程,让你实现写真自由!

AI摄影,就是用AI生成写真照片 和传统摄影不同的是,传统的摄影需要先妆造、布景,然后再进行拍摄,前后需要耗费的时间精力非常多 而AI摄影只需要在电脑上上传十几张自己的日常照片,就能根据自己的喜好去生成各种梦幻、甚…

软件测试经理工作日常随记【2】-接口自动化

软件测试主管工作日常随记【2】-接口自动化 1.接口自动化 jmeter-反电诈项目 这个我做过的一个非常有意义的项目,和腾讯合作的,主要为用户拦截并提示所有可能涉及到的诈骗类型,并以裂变的形式扩展用户,这个项目前期后端先完成&…

设计宝典与速查手册,设计师必备资料合集

一、资料描述 本套设计资料,大小194.34M,共有13个文件。 二、资料目录 01-《商业设计宝典》.pdf 02-《色彩速查宝典》.pdf 03-《配色宝典》.pdf 04-《解读色彩情感密码》.pdf 05-《行业色彩应用宝典》.pdf 06-《构图宝典》.pdf 07-《创意宝典》…

上位机图像处理和嵌入式模块部署(树莓派4b下ros安装方法)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 随着嵌入式开发板算力越来越强,很多的同学开始用树莓派做一些ros开发的工作。目前来说,ros有两个版本,分别是ro…

【RPC】Dubbo接口测试

关于rpc,推荐看看这篇 : 既然有HTTP协议,为什么还要有RPC 一、Dubbo 是一款alibaba开源的高性能服务框架: 分布式服务框架高性能和透明化的RPC远程服务调用方案SOA服务治理方案 二、Dubbo基础架构 三、 Dubbo接口测试 1、jme…

MambaMOS:基于激光雷达的三维运动物体分割与运动感知状态空间模型

MambaMOS:基于激光雷达的三维运动物体分割与运动感知状态空间模型 摘要INTRODUCTIONRelated WorkMethod MambaMOS: LiDAR-based 3D Moving Object Segmentation with Motion-aware State Space Model 摘要 激光雷达基于的运动目标分割(MOS)旨在利用之前…

一站式PDF解决方案:如何部署自己的PDF全能工具(Docker部署和群晖部署教程)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 开始部署 📒📝 Docker部署📝 群晖部署📝 本地安装⚓️ 相关链接 ⚓️📖 介绍 📖 在数字化办公的今天,PDF文件几乎成了我们日常工作中不可或缺的一部分。但你是否曾因为PDF文件的编辑、转换、合并等问题而头疼?如果…

Python类方法探秘:从单例模式到版本控制

引言: 在Python编程中,类方法作为一种特殊的实例方法,以其独特的魅力在众多编程范式中脱颖而出。它们不仅提供了无需实例即可调用的便捷性,还在设计模式、版本控制等方面发挥着重要作用。本文将通过几个生动的示例,带您…

STM32——GPIO篇

技术笔记! 1. 什么是GPIO? GPIO是通用输入输出端口(General-purpose input/output)的英文简写,是所有的微控制器必不可少的外设之一,可以由STM32直接驱动从而实现与外部设备通信、控制以及采集和捕获的功…

骨传导耳机哪个品牌值得入手?精选五款高性能骨传导耳机,闭眼入都不踩雷!

随着健康生活的日益普及,运动健身逐渐成为人们生活中的重要组成部分。在这一背景下,骨传导耳机作为一种新型蓝牙耳机,凭借其不堵塞耳道、防水性能强等特性,受到了广大运动爱好者的喜爱。然而,骨传导耳机的热销也吸引了…

海外大带宽服务器的带宽大小是如何定义的?

海外大带宽服务器的带宽大小通常是由提供的数据传输速率来衡量的。Rak部落小编为您整理发布海外大带宽服务器的带宽大小是如何定义的? 带宽的大小决定了服务器能够处理的数据量和传输速度,这对于确保服务器性能至关重要。在详细定义中,带宽可以根据以下…

Flutter笔记:Widgets Easier组件库(9)使用弹窗

Flutter笔记 Widgets Easier组件库(9):使用弹窗 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress o…

每日OJ题_贪心算法三②_力扣553. 最优除法

目录 力扣553. 最优除法 解析代码 力扣553. 最优除法 553. 最优除法 难度 中等 给定一正整数数组 nums,nums 中的相邻整数将进行浮点除法。例如, [2,3,4] -> 2 / 3 / 4 。 例如,nums [2,3,4],我们将求表达式的值 "…

MySQL-笔记-08.数据库编程

目录 8.1 编程基础 8.1.1 基本语法 8.1.2 运算符与表达式 1. 标识符 2. 常量 (1) 字符串常量 (2)日期时间常量 (3)数值常量 (4)布尔值常量 (5)NULL…

暖心又实用!母亲节教会妈妈这4招才是最贴心的礼物

母亲节就要到了,这个特殊的日子,我们总是想要为妈妈送上最真挚的祝福和关怀。在这个数字化时代,一部智能手机就能成为我们表达爱意的桥梁。今天,就让我们一起来看看华为手机的四个功能,让妈妈的手机使用体验更加便捷、…