Yolov8模型用torch_pruning剪枝

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True).eval()# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )print(pruning_group.details())  # or print(pruning_group)# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):# handle groups in sequential orderidxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tpmodel = resnet18(pretrained=True)# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗ignored_layers = []
for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier!iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(model,example_inputs,importance=imp,iterative_steps=iterative_steps,pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers,
)base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/272612.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【重新定义matlab强大系列十七】Matlab深入浅出长短期记忆神经网络LSTM

&#x1f517; 运行环境&#xff1a;Matlab &#x1f6a9; 撰写作者&#xff1a;左手の明天 &#x1f947; 精选专栏&#xff1a;《python》 &#x1f525; 推荐专栏&#xff1a;《算法研究》 #### 防伪水印——左手の明天 #### &#x1f497; 大家好&#x1f917;&#x1f91…

NPP VIIRS卫星数据介绍及获取

VIIRS&#xff08;Visible infrared Imaging Radiometer&#xff09;可见光红外成像辐射仪。扫描式成像辐射仪&#xff0c;可收集陆地、大气、冰层和海洋在可见光和红外波段的辐射图像。它是高分辨率辐射仪AVHRR和地球观测系列中分辨率成像光谱仪MODIS系列的拓展和改进。VIIRS数…

java 数据结构二叉树

目录 树 树的概念 树的表示形式 二叉树 两种特殊的二叉树 二叉树的性质 二叉树的存储 二叉树的基本操作 二叉树的遍历 二叉树的基本操作 二叉树oj题 树 树是一种 非线性 的数据结构&#xff0c;它是由 n &#xff08; n>0 &#xff09;个有限结点组成一个具有层次…

vs创建asp.net core webapi发布到ISS服务器

打开服务器创建test123文件夹&#xff0c;并设置共享。 ISS配置信息&#xff1a; 邮件网站&#xff0c;添加网站 webapi asp.net core发布到ISS服务器网页无法打开解决方法 点击ISS Express测试&#xff0c;可以成功打开网页。 点击生成&#xff0c;发布到服务器 找到服务器IP…

OJ_复数集合

题干 C实现 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <queue> #include <string> using namespace std;struct Complex {int re;int im;//构造函数Complex(int _re, int _im) {//注意参数名字必须不同re _re;im _im;} };//结构体不支…

新闻文章分类项目

注意&#xff1a;本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 &#xff08;[www.aideeplearning.cn]&#xff09; 新闻文章分类模型比较项目报告 项目介绍 背景 新闻文章自动分类是自然语言处理和文本挖掘领域的一个重要任务。正确分类新闻文章不仅能帮助用…

日期问题---算法精讲

前言 今天讲讲日期问题&#xff0c;所谓日期问题&#xff0c;在蓝桥杯中出现众多&#xff0c;但是解法比较固定。 一般有判断日期合法性&#xff0c;判断是否闰年&#xff0c;判断日期的特殊形式&#xff08;回文或abababab型等&#xff09; 目录 例题 题2 题三 总结 …

问题:前端获取long型数值精度丢失,后面几位都为0

文章目录 问题分析解决 问题 通过接口获取到的数据和 Postman 获取到的数据不一样&#xff0c;仔细看 data 的第17位之后 分析 该字段类型是long类型问题&#xff1a;前端接收到数据后&#xff0c;发现精度丢失&#xff0c;当返回的结果超过17位的时候&#xff0c;后面的全…

[java入门到精通] 11 泛型,数据结构,List,Set

今日目标 泛型使用 数据结构 List Set 1 泛型 1.1 泛型的介绍 泛型是一种类型参数&#xff0c;专门用来保存类型用的 最早接触泛型是在ArrayList&#xff0c;这个E就是所谓的泛型了。使用ArrayList时&#xff0c;只要给E指定某一个类型&#xff0c;里面所有用到泛型的地…

【C++】函数重载

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:C ⚙️操作环境:Visual Studio 2022 目录 &#x1f4cc;函数重载的定义 &#x1f4cc;函数重载的三种类型 &#x1f38f;参数个数不同 &#x1f38f;参数类型不同 &#x1f38f;参数类型顺序不同 &#x1f4cc;重载…

用C语言执行SQLite3的gcc编译细节

错误信息&#xff1a; /tmp/cc3joSwp.o: In function main: execSqlite.c:(.text0x100): undefined reference to sqlite3_open execSqlite.c:(.text0x16c): undefined reference to sqlite3_exec execSqlite.c:(.text0x174): undefined reference to sqlite3_close execSqlit…

❤ Vue3项目搭建系统篇(二)

❤ Vue3项目搭建系统篇&#xff08;二&#xff09; 1、安装和配置 Element Plus&#xff08;完整导入&#xff09; yarn add element-plus --savemain.ts中引入&#xff1a; // 引入组件 import ElementPlus from element-plus import element-plus/dist/index.css const ap…

STL之deque容器代码详解

1 基础概念 功能&#xff1a; 双端数组&#xff0c;可以对头端进行插入删除操作。 deque与vector区别&#xff1a; vector对于头部的插入删除效率低&#xff0c;数据量越大&#xff0c;效率越低。 deque相对而言&#xff0c;对头部的插入删除速度回比vector快。 vector访问…

jpg 转 ico 强大的图片处理工具 imageMagick

点击下载 windows, mac os, linux版本 GitHub - ImageMagick/ImageMagick: &#x1f9d9;‍♂️ ImageMagick 7 1. windows程序 链接&#xff1a;https://pan.baidu.com/s/1wZLqpcytpCVAl52pIrBBEw 提取码&#xff1a;hbfy 一直点击下一步安装 2. 然后 winr键 打开cmd 然…

SSD的原理

简介 SSD&#xff08;Solid State Drive&#xff09;是一种使用闪存存储芯片&#xff08;NAND Flash&#xff09;的存储设备。与传统的机械硬盘不同&#xff0c;SSD没有移动部件&#xff0c;因此具有更快的读写速度和更低的能耗。 架构 NAND Flash是一种非易失性存储器&…

nodejs web服务器 -- 搭建开发环境

一、配置目录结构 1、使用npm生成package.json&#xff0c;我创建了一个nodejs_network 文件夹&#xff0c;cd到这个文件夹下&#xff0c;执行&#xff1a; npm init -y 其中-y的含义是yes的意思&#xff0c;在init的时候省去了敲回车的步骤&#xff0c;如此就生成了默认的pac…

008-slot插槽

slot插槽 1、插槽 slot 的简单使用2、插槽分类2.1 默认插槽2.2 具名插槽2.3 作用域插槽 插槽就是子组件中的提供给父组件使用的一个占位符&#xff0c;用<slot></slot> 表示&#xff0c;父组件可以在这个占位符中填充任何模板代码&#xff0c;如 HTML、组件等&…

cannot import name ‘Flask‘ from partially initialized module ‘flask‘

bug&#xff1a; ImportError: cannot import name Flask from partially initialized module flask (most likely due to a circular import) (G:\pythonProject6\flask.py) 这个是因为包的名字和文件的名字一样 修改文件名&#xff1a; 结果 &#x1f923;&#x1f923;&…

nginx配置支持ipv6访问,ipv4改造ipv6

一、前言 本地测试nginx部署的web系统支持ipv6地址访问。 二、本机ipv6地址 cmd ipconfig 找到IPv6地址 其中带有%号其实是临时分配得到地址 我们可以ping一下看看 另一种ping的方式 加上中括号 还有就是去掉%号 三、nginx增加配置 server块里增加 listen [::]:80; 四、测…

arcgis 栅格数据处理2——栅格转地级市(栅格转矢量图)

1. 获取空间分析权限&#xff08;解决无法执行所选工具问题&#xff09; 选中“自定义”中的“扩展模块” 在弹出的模块中选中能选的模块&#xff0c;此处需要选择“spatial analysis”以进行下一步分析 3. 将栅格数据转为整数型&#xff08;解决无法矢量化&#xff09; 选…