【yolov8系列】 yolov8 目标检测的模型剪枝

前言

最近在实现yolov8的剪枝,所以有找相关的工作作为参考,用以完成该项工作。

  • 先细读了 Torch-Pruning,个人简单记录了下 【剪枝】torch-pruning的基本使用,有框架完成的对网络所有结构都自适应剪枝是最佳的,但这里没有详细记录torch-pruning的yolov8的剪枝,是因为存在不解 对其yolov8具体的剪枝代码中操作:“比较疑惑 replace_c2f_with_c2f_v2(model.model)
    这句注释掉了代码就跑不通了。是tp不支持原本 c2f 吗,我如果想使用c2f进行剪枝,应该怎么办呢”。待解决该问题,然后在记录。
  • 然后另外参考博客 Jetson nano部署剪枝YOLOv8,该方法的代码实现仅针对yolov8 的剪枝,但可从中借鉴如何利用bn对模型进行剪枝,举一反三应用到其它工程中。bn剪枝的原理可以在 3.1 常用的结构化剪枝原理 简单记录。

1. 剪枝工程搭建

yolov8工程下载
本地有个版本 Ultralytics 8.0.81,所以该篇博客基于该版本记录。不同版本可能带来的影响,所以当使用最新版本出bug时,又无法定位和解决,可先尝试8.0.81版本。我们剪枝过程在VOC数据集上完成尝试。
在这里插入图片描述

这里添加两个代码文件
分别为【LL_pruning.pyLL_train.py】存放于根目录下在这里插入图片描述

  • LL_train.py的内容为
    from ultralytics import YOLO
    import os
    # os.environ["CUDA_VISIBLE_DEVICES"]="0,1" root = os.getcwd()
    ## 配置文件路径
    name_yaml             = os.path.join(root, "ultralytics/datasets/VOC.yaml")
    name_pretrain         = os.path.join(root, "yolov8s.pt")
    ## 原始训练路径
    path_train            = os.path.join(root, "runs/detect/VOC")
    name_train            = os.path.join(path_train, "weights/last.pt")
    ## 约束训练路径、剪枝模型文件
    path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
    name_prune_before     = os.path.join(path_constraint_train, "weights/last.pt")
    name_prune_after      = os.path.join(path_constraint_train, "weights/last_prune.pt")
    ## 微调路径
    path_fineturn         = os.path.join(root, "runs/detect/VOC_finetune")def else_api():path_data = ""path_result = ""model = YOLO(name_pretrain) metrics = model.val()  # evaluate model performance on the validation setmodel.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5)  # 这里的imgsz为高宽def step1_train():model = YOLO(name_pretrain) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train)  # train the modeldef step2_Constraint_train():model = YOLO(name_train) model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1,name=path_constraint_train)  # train the modeldef step3_pruning():from LL_pruning import do_pruningdo_pruning(os.path.join(name_prune_before, name_prune_after))def step4_finetune():model = YOLO(name_prune_after)     # load a pretrained model (recommended for training)model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn)  # train the modelstep1_train()
    # step2_Constraint_train()
    # step3_pruning()
    # step4_finetune()
  • LL_pruning.py的内容为
    from ultralytics import YOLO
    import torch
    from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
    import os
    # os.environ["CUDA_VISIBLE_DEVICES"] = "2"class PRUNE():def __init__(self) -> None:self.threshold = Nonedef get_threshold(self, model, factor=0.8):ws = []bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())print()# keepws = torch.cat(ws)self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):## a. 根据BN中的参数,获取需要保留的index================gamma = conv1.bn.weight.data.detach()beta  = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = self.thresholdwhile len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)print(n / len(gamma) * 100)# scale = len(idxs) / n## b. 利用index对BN进行剪枝============================conv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data   = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = n## c. 利用index对conv1进行剪枝=========================if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]## d. 利用index对conv2进行剪枝=========================if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is None: continueif isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(self, m1, m2):if isinstance(m1, C2f):      # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1self.prune_conv(m1, m2)def do_pruning(modelpath, savepath):pruning = PRUNE()### 0. 加载模型yolo = YOLO(modelpath)                  # build a new model from scratchpruning.get_threshold(yolo.model, 0.8)  # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。### 1. 剪枝c2f 中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):pruning.prune_conv(m.cv1, m.cv2)### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.modelfor i in [3,5,7,8]: pruning.prune(seq[i], seq[i+1])### 3. 对检测头进行剪枝# 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)# 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1] # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2] detect:Detect = seq[-1]last_inputs   = [seq[15], seq[18], seq[21]]colasts       = [seq[16], seq[19], None]for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]])pruning.prune(cv2[0], cv2[1])pruning.prune(cv2[1], cv2[2])pruning.prune(cv3[0], cv3[1])pruning.prune(cv3[1], cv3[2])# ***step4,一定要设置所有参数为需要训练。因为加载后的model他会给弄成false。导致报错# pipeline:# 1. 为模型的BN增加L1约束,lambda用1e-2左右# 2. 剪枝模型,比如用全局阈值# 3. finetune,一定要注意,此时需要去掉L1约束。最终final的版本一定是去掉的for name, p in yolo.model.named_parameters():p.requires_grad = True# 1. 不能剪枝的layer,其实可以不用约束# 2. 对于低于全局阈值的,可以删掉整个module# 3. keep channels,对于保留的channels,他应该能整除n才是最合适的,否则硬件加速比较差#    n怎么选,一般fp16时,n为8; int8时,n为16#    cp.async.cg.sharedyolo.val()torch.save(yolo.ckpt, savepath)yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))yolo.export(format="onnx")## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小yolo = YOLO(modelpath)  # build a new model from scratchyolo.export(format="onnx")if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt"savepath  = "runs/detect1/14_Constraint/weights/last_prune.pt"do_pruning(modelpath, savepath)
    

2 剪枝全流程

剪枝的流程可分为

  • 正常训练:我们可以通过模型得到未剪枝时的精度,方便剪枝前后进行精度对比
  • 稀疏训练:bn的参数约等于0的并不多,所以需要稀疏训练,使得部分参数接近0,然后再对卷积核进行裁剪,这样可以减小剪枝对网络输出的影响。
  • 剪枝:根据bn中参数对相应的卷积进行剪枝。
  • 微调:剪枝后模型必然会降低精度,所以需要再微调使其恢复精度。

2.1 正常训练
  1. 设置yaml文件:该工程使用VOC数据集,若已下载,在yaml文件中设置好数据路径。若无下来,运行训练代码时,工程会在一开始自动下载数据集,但下载速度可能会慢。
    LL_train.py 中设置:在这里插入图片描述
    ultralytics/datasets/VOC.yaml中如下::在这里插入图片描述
  2. LL_train.py脚本中,调用 step1_train(),注释其它的函数调用,如下图。

    在这里插入图片描述
  3. 激活相应环境,进行训练。运行python LL_train.py,训练结束指标如下:

    在这里插入图片描述

2.2 稀疏训练

  1. ./ultralytics/yolo/engine/trainer.py代码中修改:
    在[反向传播]和[梯度更新] 之间添加[bn的L1正则],使得bn参数在训练时变得稀疏。

                    # Backwardself.scaler.scale(self.loss).backward()## add start=============================## add l1 regulation for step2_Constraint_train               l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))## add end ==============================# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni
    

    在这里插入图片描述

  2. LL_train.py 中修改:
    在这里插入图片描述

  3. 然后进行训练,结束后的验证指标如下图。
    可以在以下截屏看到最顶行,此时模型共168层、参数量11133324、计算量28.5GFLOPs

    在这里插入图片描述


2.3 模型剪枝

  1. 修改LL_train.py 并训练:

    在这里插入图片描述在这里插入图片描述
  2. 看下剪枝后的文件
    du -sh ./runs/detect/VOC_Constraint/weights/last*
    
    终端输入命名如下,可查看文件大小。可以看到其中剪枝前后的pt的last.pt/last_prune.pt,对应的onnx模型为last.onnx last_prune.onnx,可以看剪枝后的pt增大了,但onnx减小了,我们只需要关注onnx的大小即可,由43M 剪枝为36M。

    在这里插入图片描述

2.4 微调

  1. 将第二步的约束训练添加的bn限制注释掉
    在这里插入图片描述
  2. 加载剪枝后的模型作为训练时的网络结构。在 ultralytics/yolo/engine/trainer.py 中修改内容:
    370行的 self.trainer.model 是从yaml中加载的模型( 未剪枝前的) 以及其它的配置信息,从pt文件中加载的权重( 剪枝后的)。所以只需将该变量中的网络结构更新为剪枝后的网络结构即可,增加代码可解决问题。
    否则训练出来的模型参数不会发生变化。
    self.trainer.model.model = self.model.model
    

在这里插入图片描述

  1. ultralytics/yolo/engine/trainer.py 中增加内容如下:
    model.22.dfl.conv.weight的梯度置为Fasle,是因为该层是解析box时的一个向量,具体的为 [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15],为了便捷的将该处理保存在模型中,所以就定义了一个卷积,所以该卷积不需要梯度、不需要反向传播,所以该层的param.requires_grad = False
    在这里插入图片描述
    否则训练过程中会报错:


至此,就可微调训练。训练结果如下:
在这里插入图片描述
若不进行修改,依然会从modeld的yaml文件中加载模型。参数量和计算量不会有任何变化。下图是未做更多的修改训练后的信息。这样不是我们想要的状态。
在这里插入图片描述


2.5 剪枝完成后的分析

剪枝过程中的所有训练mAP、准召日志,使用tensorboard打开查看对应变化曲线:本次0.8的剪枝率下剪枝前后的指标变化并不大。在错误微调时指标整体偏低。在这里插入图片描述

  • 正常训练:参数量 11133324、计算量28.5GFLOPs、验证集上的准召:0.812、0.727
  • 剪枝微调后:参数量9261550、计算量19.3GFLOPs、验证集上的准召:0.811、0.699

可以看出在0.8的剪枝率下,剪枝后的模型与原本模型的验证集上的准召相差不大。运行时间上,若在高性能GPU上不会有明显加速,在端侧会有加速。具体的数值,后面有空测试后附上。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/221024.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VBA之Word应用:利用代码统计文档中的书签个数

《VBA之Word应用》&#xff08;版权10178982&#xff09;&#xff0c;是我推出第八套教程&#xff0c;教程是专门讲解VBA在Word中的应用&#xff0c;围绕“面向对象编程”讲解&#xff0c;首先让大家认识Word中VBA的对象&#xff0c;以及对象的属性、方法&#xff0c;然后通过实…

10kV站所柜内运行状态及环境指标监测管理平台

背景&#xff1a; 10kV站所柜内运行状态及环境指标监测管理平台对分布在不同位置的动力设备、环境监测设备和安保设备进行遥测、遥信采集&#xff0c;对各设备的运行状态进行实时监控&#xff0c;同时就相关监测数据展开记录与处理&#xff0c;第一时间向相关人员发出通知&…

【算法】红黑树

一、红黑树介绍 红黑树是一种自平衡二叉查找树&#xff0c;是在计算机科学中用到的一种数据结构&#xff0c;典型的用途是实现关联数组。 红黑树是在1972年由Rudolf Bayer发明的&#xff0c;当时被称为平衡二叉B树&#xff08;symmetric binary B-trees&#xff09;。后来&am…

HTML---CSS美化网页元素

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.div 标签&#xff1a; <div>是HTML中的一个常用标签&#xff0c;用于定义HTML文档中的一个区块&#xff08;或一个容器&#xff09;。它可以包含其他HTML元素&#xff0c;如文本、图像…

15 使用v-model绑定单选框

概述 使用v-model绑定单选框也比较常见&#xff0c;比如性别&#xff0c;要么是男&#xff0c;要么是女。比如单选题&#xff0c;给出多个选择&#xff0c;但是只能选择其中的一个。 在本节课中&#xff0c;我们演示一下这两种常见的用法。 基本用法 我们创建src/component…

Word的兼容性问题很常见,禁用兼容模式虽步不是最有效的,但可以解决兼容性问题

当你在较新版本的Word应用程序中打开用较旧版本的Word创建的文档时&#xff0c;会出现兼容性问题。错误通常发生在文件名附近&#xff08;兼容模式&#xff09;。兼容性模式问题&#xff08;暂时&#xff09;禁用Word功能&#xff0c;从而限制使用较新版本Word的用户编辑文档。…

Nginx快速入门:安装目录结构详解及核心配置解读(二)

0. 引言 上节我们讲解了nginx的应用场景和安装&#xff0c;本节继续针对nginx的各个目录文件进行讲解&#xff0c;让大家更加深入的认识nginx。并通过一个实操案例&#xff0c;带大家来实际认知nginx的核心配置 1. nginx安装目录结构 首先nginx的默认安装目录为&#xff1a;…

【深度学习】语言模型与注意力机制以及Bert实战指引之一

文章目录 统计语言模型和神经网络语言模型注意力机制和Bert实战Bert配置环境和模型转换格式准备 模型构建网络设计模型配置代码实战 统计语言模型和神经网络语言模型 区别&#xff1a;统计语言模型的本质是基于词与词共现频次的统计&#xff0c;而神经网络语言模型则是给每个词…

2023大湾区汽车创新大会暨IEEE自动驾驶国际标准研讨会成功举办

2023年12月15日-12月16日&#xff0c;由IEEE ADWG工作组主席孙栋博士、杨子江博士共同主持的2023大湾区汽车创新大会平行主题论坛-IEEE自动驾驶国际标准研讨会在深圳坪山成功举办。图灵奖获得者Joseph Sifakis、英伟达仿真生态总监German Ros、ASAM标准组织CEO Marius Dupuis、…

云原生之深入解析Kubernetes集群发生网络异常时如何排查

一、Pod 网络异常 网络不可达&#xff0c;主要现象为 ping 不通&#xff0c;其可能原因为&#xff1a; 源端和目的端防火墙&#xff08;iptables, selinux&#xff09;限制&#xff1b; 网络路由配置不正确&#xff1b; 源端和目的端的系统负载过高&#xff0c;网络连接数满…

Re解析(正则表达式解析)

正则表达式基础 元字符 B站教学视频&#xff1a; 正则表达式元字符基本使用 量词 贪婪匹配和惰性匹配 惰性匹配如下两张图&#xff0c;而 .* 就表示贪婪匹配&#xff0c;即尽可能多的匹配到符合的字符串&#xff0c;如果使用贪婪匹配&#xff0c;那么结果就是图中的情况三 p…

CTF网络安全大赛是干什么的?发展史、赛制、赛程介绍,参赛需要学什么?

CTF&#xff08;Capture The Flag&#xff09;是一种网络安全竞赛&#xff0c;它模拟了各种信息安全场景&#xff0c;旨在提升参与者的网络安全技能。CTF 赛事通常包含多种类型的挑战&#xff0c;如密码学、逆向工程、网络攻防、Web 安全、二进制利用等。 发展史 CTF 的概念…

LLM大语言模型(二):Streamlit 无需前端经验也能画web页面

目录 问题 Streamlit是什么&#xff1f; 怎样用Streamlit画一个LLM的web页面呢&#xff1f; 文本输出 页面布局 滑动条 按钮 对话框 输入框 总结 问题 假如你是一位后端开发&#xff0c;没有任何的web开发经验&#xff0c;那如何去实现一个LLM的对话交互页面呢&…

持续集成交付CICD:K8S 自动化完成前端项目应用发布与回滚

目录 一、实验 1.环境 2.GitLab新建项目存放K8S部署文件 3.Jenkins手动测试前端项目CD 流水线代码&#xff08;下载部署文件&#xff09; 4. 将K8S master节点配置为jenkins从节点 5.K8S 手动回滚前端项目版本 6.Jenkins手动测试前端项目CD 流水线代码&#xff08;发布应…

棋牌的电脑计时计费管理系统教程,棋牌灯控管理软件操作教程

一、前言 有的棋牌室在计时的时候&#xff0c;需要使用灯控管理&#xff0c;在开始计时的时候打开灯&#xff0c;在结账后关闭灯&#xff0c;也有的不需要用灯控&#xff0c;只用来计时。 下面以 佳易王棋牌计时计费管理系统软件为例说明&#xff1a; 软件试用版下载或技术支…

Axure中继器完成表格的增删改查的自定义元件(三列表格与十列表格)

目录 一、中继器 1.1 定义 1.2 特点 1.3 适用场景 二、三列表格增删改查 2.1 实现思路 2.2 效果演示 三、十列表格增删改查 3.1 实现思路 3.2 效果演示 一、中继器 1.1 定义 在Axure中&#xff0c;"中继器"通常指的是界面设计中的一个元素&#xff0c;用…

Eclipse 一直提示 loading descriptor for 的解决方法

启动eclipse之后&#xff0c;进行相关操作时&#xff0c;弹出界面&#xff0c;提示&#xff1a;loading descriptor for xxx 解决方法&#xff1a; 在Eclipse左侧的Project Explorer 最右上角有一个小钮,鼠标移上去时提示"View Menu". 你点一下,在弹出的上下文菜单中…

TypeScript基础语法

官方网站&#xff1a; TypeScript: JavaScript With Syntax For Types. (typescriptlang.org) 中文网站&#xff1a;TypeScript中文网 TypeScript——JavaScript的超集 (tslang.cn) TypeScript基础语法 1.变量声明 2.条件控制 3.循环迭代 数组迭代 4.函数 默认参数 5.类和接…

JavaWeb笔记之JavaWeb JDBC

//Author 流云 //Version 1.0 一. 引言 1.1 如何操作数据库 使用客户端工具访问数据库&#xff0c;需要手工建立连接&#xff0c;输入用户名和密码登录&#xff0c;编写 SQL 语句&#xff0c;点击执行&#xff0c;查看操作结果&#xff08;结果集或受影响行数&#xff09;。…

计算机网络考研辨析(后续整理入笔记)

文章目录 体系结构物理层速率辨析交换方式辨析编码调制辨析 链路层链路层功能介质访问控制&#xff08;MAC&#xff09;信道划分控制之——CDMA随机访问控制轮询访问控制 扩展以太网交换机 网络层网络层功能IPv4协议IP地址IP数据报分析ICMP 网络拓扑与转发分析&#xff08;重点…