Yolov8 目标检测剪枝学习记录

最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能
这里主要参考了torch-pruning的基本使用,v8模型剪枝,Jetson nano部署剪枝YOLOv8
下面只是记录一个简单流程,用于后续使用在自己的任务和网络中,数据不作为参考

首先训练一个base模型用于参考

  • 环境:Ultralytics YOLOv8.2.18 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (NVIDIA H100 PCIe, 81008MiB)
  • 训练代码

参考网上或者自己写一个能训练即可,为了方便我将通用的记录下来,实测可用来自代码来源

from ultralytics import YOLO
import osroot = os.getcwd()
## 配置文件路径
name_yaml             = os.path.join(root, "ultralytics/datasets/VOC.yaml")
name_pretrain         = os.path.join(root, "yolov8s.pt")
## 原始训练路径
path_train            = os.path.join(root, "runs/detect/VOC")
name_train            = os.path.join(path_train, "weights/last.pt")
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
name_prune_before     = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after      = os.path.join(path_constraint_train, "weights/last_prune.pt")
## 微调路径
path_fineturn         = os.path.join(root, "runs/detect/VOC_finetune")def else_api():path_data = ""path_result = ""model = YOLO(name_pretrain) metrics = model.val()  # evaluate model performance on the validation setmodel.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5)  # 这里的imgsz为高宽def step1_train():model = YOLO(name_pretrain) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train)  # train the model## 2024.3.4添加【amp=False】
def step2_Constraint_train():model = YOLO(name_train) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, amp=False, workers=16, save_period=1,name=path_constraint_train)  # train the modeldef step3_pruning():from LL_pruning import do_pruningdo_pruning(name_prune_before, name_prune_after)def step4_finetune():model = YOLO(name_prune_after)     # load a pretrained model (recommended for training)model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn)  # train the modelstep1_train()
# step2_Constraint_train()
# step3_pruning()
# step4_finetune()

第一步,step1_train()

  • 即训练一个base模型,用于最后性能好坏的重要参考
    在这里插入图片描述

第二步,step2_Constraint_train()

训练之前在ultralytics\engine\trainer.py添加bn的L1正则,使得bn参数在训练时变得稀疏

  • 通过对参数的绝对值进行惩罚,使得一些不重要的权重变为零,从而实现模型的稀疏化和简化
     # Backwardself.scaler.scale(self.loss).backward()## add new code=============================duj## add l1 regulation for step2_Constraint_train               l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni

在这里插入图片描述

  • 个人理解的稀疏化作用
    • 通过对 gamma 和 beta 添加 L1 正则化,可以促使某些通道的 BN 权重变得非常小,甚至为零。这意味着在剪枝时,可以将这些通道从模型中移除
    • 通过稀疏化 BN 层并剪除不重要的通道,剩下的通道会更有效地利用计算资源,减少无用计算。

第三步,step3_pruning()剪枝操作

LL_pruning.py

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
import osclass PRUNE():def __init__(self) -> None:self.threshold = Nonedef get_threshold(self, model, factor=0.8):ws = []bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())print()# keepws = torch.cat(ws)self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):## a. 根据BN中的参数,获取需要保留的index================gamma = conv1.bn.weight.data.detach()beta  = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = self.thresholdwhile len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)print(n / len(gamma) * 100)# scale = len(idxs) / n## b. 利用index对BN进行剪枝============================conv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data   = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = n## c. 利用index对conv1进行剪枝=========================if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]## d. 利用index对conv2进行剪枝=========================if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is None: continueif isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(self, m1, m2):if isinstance(m1, C2f):      # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1self.prune_conv(m1, m2)def do_pruning(modelpath, savepath):pruning = PRUNE()### 0. 加载模型yolo = YOLO(modelpath)                  # build a new model from scratchpruning.get_threshold(yolo.model, 0.8)  # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。### 1. 剪枝c2f 中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):pruning.prune_conv(m.cv1, m.cv2)### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.modelfor i in [3,5,7,8]: pruning.prune(seq[i], seq[i+1])### 3. 对检测头进行剪枝# 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)# 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1] # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2] detect:Detect = seq[-1]last_inputs   = [seq[15], seq[18], seq[21]]colasts       = [seq[16], seq[19], None]for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]])pruning.prune(cv2[0], cv2[1])pruning.prune(cv2[1], cv2[2])pruning.prune(cv3[0], cv3[1])pruning.prune(cv3[1], cv3[2])### 4. 模型梯度设置与保存for name, p in yolo.model.named_parameters():p.requires_grad = Trueyolo.val()torch.save(yolo.ckpt, savepath)yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))yolo.export(format="onnx")## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小yolo = YOLO(modelpath)  # build a new model from scratchyolo.export(format="onnx")if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt"savepath  = "runs/detect1/14_Constraint/weights/last_prune.pt"do_pruning(modelpath, savepath)

在这里插入图片描述

  • 如下图可用看到剪枝前后还是有区别的,参数量减少很多,网络性能将不可用,需要微调恢复精度
    在这里插入图片描述
  • 查看剪枝前后模型大小 du -sh ./runs/detect/VOC_Constraint/weights/last*yolov8n模型
    在这里插入图片描述

微调

该部分内容我也存在一些疑问,例如很多博主让ultralytics\engine\trainer.py添加加载模型代码,经过我8.2版本测试代码添加是完全失效的,因为setup_model在执行if isinstance(self.model, torch.nn.Module)就已经return了。

 def setup_model(self):"""Load/create/download model for any task."""if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup neededreturn
  • 例如ultralytics\engine\trainer.py
  • v8…x添加代码:548行 参考这里
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
# duj add code to finetune
self.model = weights
return ckpt
  • 如果是v8.0.x 参考这里

在看到这篇中的修改1启发

  • v8.2.x上面我不确定是哪个版本需要添加的,但是我实测都不起作用
  • 我尝试在ultralytics\engine\model.py添加如下代码加载模型成功
 self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)if not args.get("resume"):  # manually set model only if not resuming# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# self.model = self.trainer.model# dujiang edit self.trainer.model = self.model.train()if SETTINGS["hub"] is True and not self.session:
  • 这里就是确保自己加载的是剪枝后的模型,但是不同版本好像不同,后续在探究原因。。。
  • 这里有个小插曲,我在使用自己模型稀疏训练后剪枝发现(步骤2)发现BN层全没了,这里后面我将别人的稀疏训练的v8s模型拿来进行剪枝就没问题
  • 可能是v8n的问题,也可能是我训练的问题,这里先不做深究继续查看剪枝是否成功且微调加载成功后能否恢复精度
    在这里插入图片描述
  • 此时多次尝试我基本确定微调加载的是我剪枝后的模型,接下来就是等待训练结果是否参数量正确了。
    在这里插入图片描述

总结

总的来说跑通整个流程了,接下来尝试在自己的任务和数据上面进行剪枝,看看更换了模型结构又会有哪些坑等着我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/2712.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AWS Lambda

AWS Lambda 是 Amazon Web Services&#xff08;AWS&#xff09;提供的无服务器计算服务&#xff0c;它让开发者能够运行代码而不需要管理服务器或基础设施。AWS Lambda 会自动处理代码的执行、扩展和计费&#xff0c;开发者只需关注编写和部署代码&#xff0c;而无需担心底层硬…

MySQL-索引

目录 &#x1f334;概念 &#x1f335;作用 &#x1f331;使用场景 &#x1f384;使用 查看索引 创建索引 删除索引 &#x1f334;概念 索引是一种特殊的文件&#xff0c;包含着对数据表里所有记录的引用指针。可以对表中的一列或多列创建索引&#xff0c;并指定索引的类…

自动化办公|xlwings简介

xlwings 是一个开源的 Python 库&#xff0c;旨在实现 Python 与 Microsoft Excel 的无缝集成。它允许用户使用 Python 脚本自动化 Excel 操作&#xff0c;读取和写入数据&#xff0c;执行宏&#xff0c;甚至调用 VBA 脚本。这使得数据分析、报告生成和其他与 Excel 相关的任务…

【物联网】ARM核介绍

文章目录 芯片产业链1. CPU核(1)ARM(2)MIPS(3)PowerPc(4)Intel(5)RISC-V 2. SOC芯片(1)主流厂家(2)产品解决方案 3. 产品 ARM核发展1. 不同架构的特点分析(1)VFP(2)Jazelle(3)Thumb(4)TrustZone(5)SIMD(6)NEON ARM核(ARMv7)工作模式1. 权限级别(privilege level)2. ARM process…

SuperMap iClient3D for Cesium立体地图选中+下钻特效

在大屏展示系统中&#xff0c;对行政区划数据制作了立体效果&#xff0c;如果希望选中某一行政区划进行重点介绍&#xff0c;目前常见的方式是通过修改选中对象色彩、边线等方式进行实现&#xff1b;这里提供另外一种偏移动效的思路&#xff0c;并提供下钻功能&#xff0c;让地…

Linux的常用命令(三)

目录 六、网络通信命令 1.网络通信命令ping 2.网络通信命令ifconfig 七、系统命令 1. 系统命令shutdown 2. 系统命令reboot 八、vi编辑器 六、网络通信命令 1.网络通信命令ping 命令名称&#xff1a;ping 命令所在路径&#xff1a;/usr/sbin/ping 执行权限&#xff…

SQL Prompt 插件

SQL Prompt 插件 注&#xff1a;SQL Prompt插件提供智能代码补全、SQL格式化、代码自动提示和快捷输入等功能&#xff0c;非常方便&#xff0c;可以自行去尝试体会。 1、问题 SSMS&#xff08;SQL Server Management Studio&#xff09;是SQL Server自带的管理工具&#xff0c…

《小迪安全》学习笔记05

目录 读取&#xff1a; 写入&#xff1a; &#xff08;其中的读取和写入时我认为比较重要的&#xff0c;所以单独做成了目录&#xff0c;这里的读取和写入是指在进行sql注入的时候与本地文件进行的交互&#xff09; 好久没发博客了。。。从这篇开始的小迪安全学习笔记就开始…

SpringCloud源码-Ribbon

一、Spring定制化RestTemplate&#xff0c;预留出RestTemplate定制化扩展点 org.springframework.cloud.client.loadbalancer.LoadBalancerAutoConfiguration 二、Ribbon定义RestTemplate Ribbon扩展点功能 org.springframework.cloud.netflix.ribbon.RibbonAutoConfiguratio…

Linux 常用命令 - chmod 【改变文件或目录权限】

简介 “chmod” 这个命令来自于 “change mode” 的缩写&#xff0c;用于更改文件或目录的访问权限。这个命令允许用户设定谁可以读取、写入或执行一个文件。在 Linux 和其他类 Unix 系统中&#xff0c;文件权限对系统安全和用户隐私至关重要。 Linux/Unix 的文件调用权限分为…

Linux系统离线部署MySQL详细教程(带每步骤图文教程)

1、登录官网下载对应的安装包 MySQL :: Developer Zone 2、将压缩包上传到服务器上&#xff0c;这里直接上传到/usr/local路径上 使用sftp工具上传到/usr/local目录上 3、解压压缩包 tar -xf mysql-8.0.39-linux-glibc2.17-x86_64.tar.xz 4、将mysql-8.0.39-linux-glibc2.17…

PyTorch使用教程(1)—PyTorch简介

PyTorch是一个开源的深度学习框架&#xff0c;由Facebook人工智能研究院&#xff08;FAIR&#xff09;于2016年开发并发布&#xff0c;其主要特点包括自动微分功能和动态计算图的支持&#xff0c;使得模型建立更加灵活‌。官网网址&#xff1a;https://pytorch.org。以下是关于…

Linux浅谈——管道、网络配置和客户端软件的使用

目录 一、管道 1、管道符 2、过滤功能 3、特殊功能 4、扩展处理 5、xargs命令扩展 二、网络配置 1、ifconfig查看网络信息 2、配置文件详解 网卡配置文件位置 3、systemctl查看网卡状态 4、systemctl启动/重启/停止网卡 三、客户端软件 1、什么是SSH 2、常用SSH终…

arcgis中生成格网矢量带高度

效果 1、数据准备 (1)矢量边界(miain.shp) (2)DEM(用于提取格网标高) (3)DSM(用于提取格网最高点) 2、根据矢量范围生成格网 模板范围选择矢量边界,像元宽度和高度根据坐标系来输入,我这边是4326的,所以输入的是弧度,输出格网矢量gewang.shp 3、分区统计 …

IEC103 转 ModbusTCP 网关

一、产品概述 IEC103 转 ModbusTCP 网关型号 SG-TCP-IEC103 &#xff0c;是三格电子推出的工业级网关&#xff08;以下简 称网关&#xff09;&#xff0c;主要用于 IEC103 数据采集、 DLT645-1997/2007 数据采集&#xff0c; IEC103 支持遥测和遥 信&#xff0c;可接…

Android BottomNavigationView不加icon使text垂直居中,完美解决。

这个问题网上千篇一律的设置iconsize为0&#xff0c;labale固定什么的&#xff0c;都没有效果。我的这个基本上所有人用都会有效果。 问题解决之前的效果&#xff1a;垂直方向&#xff0c;文本不居中&#xff0c;看着很难受 问题解决之后&#xff1a;舒服多了 其实很简单&…

【蓝桥杯】43687.赢球票

题目描述 某机构举办球票大奖赛。获奖选手有机会赢得若干张球票。 主持人拿出 N 张卡片&#xff08;上面写着 1⋯N 的数字&#xff09;&#xff0c;打乱顺序&#xff0c;排成一个圆圈。 你可以从任意一张卡片开始顺时针数数: 1,2,3 ⋯ ⋯ 如果数到的数字刚好和卡片上的数字…

(01)FreeRTOS移植到STM32

一、以STM32的裸机工程模板 任意模板即可 二、去官网上下载FreeRTOS V9.0.0 源码 在移植之前&#xff0c;我们首先要获取到 FreeRTOS 的官方的源码包。这里我们提供两个下载 链 接 &#xff0c; 一 个 是 官 网 &#xff1a; http://www.freertos.org/ &#xff0c; 另…

金融项目实战 05|Python实现接口自动化——登录接口

目录 一、代码实现自动化理论及流程 二、脚本实现的理论和准备工作 1、抽取功能转为自动化用例 2、搭建环境(测试工具) 3、搭建目录结构 三、登录接口脚本实现 1、代码编写 1️⃣api目录 2️⃣script目录 2、断言 3、参数化 1️⃣编写数据存储文件&#xff1a;jso…

C# .NetCore 使用 Flurl.Http 与 HttpClient 请求处理流式响应

AI对话接口采用流式返回&#xff1a; 1、使用Flurl处理返回的数据流 using Flurl; using Flurl.Http; [HttpPost] public async Task<string> GetLiushiChatLaw() { //1、请求参数&#xff0c;根据实际情况 YourModel request new YourModel();string allStr …