yolov8剪枝实践

本文使用的剪枝库是 torch-pruning,实验了该库的三个剪枝算法:GroupNormPruner、BNScalePruner 和 GrowingRegPruner(下文代码以 GroupNormPruner 为例)。

安装使用

  1. 安装依赖库
pip install torch-pruning 
  2. 把 https://github.com/VainF/Torch-Pruning/blob/master/examples/yolov8/yolov8_pruning.py 文件拷贝到 yolov8 的根目录下;或者使用我的剪枝代码,它在原有代码的基础上稍作修改,保存了不同剪枝阶段的模型。
# This code is adapted from Issue [#147](https://github.com/VainF/Torch-Pruning/issues/147), implemented by @Hyunseok-Kim0.
"""Iteratively prune a trained YOLOv8 detector with torch-pruning and fine-tune after every step.

Compared with the upstream example, this version also saves the pruned model of every pruning
step to ``runs/detect/prune/<timestamp>/step_<i>_pruned_model.pth``.
"""
import argparse
import math
import os

# Restrict the run to GPU 7 unless the caller already selected devices through the environment.
# NOTE: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialised, which is
# why this statement sits at the top of the module, before torch is imported.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from ultralytics import YOLO, __version__
from ultralytics.nn.modules import Detect, C2f, Conv, Bottleneck
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.engine.model import TASK_MAP
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import yaml_load, LOGGER, RANK, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import initialize_weights, de_parallel

import torch_pruning as tp


def save_pruning_performance_graph(x, y1, y2, y3):
    """Draw the mAP / MACs change over all pruning steps and save it to ``pruning_perf_change.png``.

    Parameters
    ----------
    x : List
        Parameter numbers of all pruning steps (as a percentage of the unpruned model).
    y1 : List
        mAPs after fine-tuning of all pruning steps.
    y2 : List
        MACs of all pruning steps; ``y2[0]`` is the unpruned model and is used as the reference.
    y3 : List
        mAPs after pruning (not fine-tuned) of all pruning steps.

    Returns
    -------
    None
    """
    try:
        plt.style.use("ggplot")
    except OSError:
        # Style not available in this matplotlib installation; keep the default style.
        pass

    x, y1, y2, y3 = np.array(x), np.array(y1), np.array(y2), np.array(y3)
    y2_ratio = y2 / y2[0]  # MACs relative to the unpruned model

    # create the figure and the axis object
    fig, ax = plt.subplots(figsize=(8, 6))

    # plot the pruned mAP and recovered mAP
    ax.set_xlabel('Pruning Ratio')
    ax.set_ylabel('mAP')
    ax.plot(x, y1, label='recovered mAP')
    ax.scatter(x, y1)
    ax.plot(x, y3, color='tab:gray', label='pruned mAP')
    ax.scatter(x, y3, color='tab:gray')

    # create a second axis that shares the same x-axis and plot the relative MACs on it
    ax2 = ax.twinx()
    ax2.set_ylabel('MACs')
    ax2.plot(x, y2_ratio, color='tab:orange', label='MACs')
    ax2.scatter(x, y2_ratio, color='tab:orange')

    # one legend for both axes
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    # x runs from 100 % parameters (left) down to fewer parameters (right)
    ax.set_xlim(105, -5)
    ax.set_ylim(0, max(y1) + 0.05)
    ax2.set_ylim(0.05, 1.05)

    # calculate the highest and lowest points for each set of data
    max_y1_idx = np.argmax(y1)
    min_y1_idx = np.argmin(y1)
    max_y2_idx = np.argmax(y2)
    min_y2_idx = np.argmin(y2)
    max_y1 = y1[max_y1_idx]
    min_y1 = y1[min_y1_idx]
    max_y2 = y2_ratio[max_y2_idx]
    min_y2 = y2_ratio[min_y2_idx]

    # add text for the highest and lowest values near the points
    ax.text(x[max_y1_idx], max_y1 - 0.05, f'max mAP = {max_y1:.2f}', fontsize=10)
    ax.text(x[min_y1_idx], min_y1 + 0.02, f'min mAP = {min_y1:.2f}', fontsize=10)
    ax2.text(x[max_y2_idx], max_y2 - 0.05, f'max MACs = {max_y2 * y2[0] / 1e9:.2f}G', fontsize=10)
    ax2.text(x[min_y2_idx], min_y2 + 0.02, f'min MACs = {min_y2 * y2[0] / 1e9:.2f}G', fontsize=10)

    plt.title('Comparison of mAP and MACs with Pruning Ratio')
    plt.savefig('pruning_perf_change.png')
    # This function runs once per pruning step; release the figure so figures do not accumulate.
    plt.close(fig)


def infer_shortcut(bottleneck):
    """Return True if ``bottleneck`` uses its residual add (in/out channels match and ``add`` is set)."""
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


class C2f_v2(nn.Module):
    """CSP Bottleneck with 2 convolutions, with C2f's ``cv1`` split into two separate convs.

    C2f computes one conv and then ``chunk(2, 1)``s its output. Here ``cv0`` and ``cv1`` each
    produce one half directly. Presumably this lets torch-pruning treat the two halves as
    independent layers; the computation is otherwise the same.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # Equivalent of C2f's `y = list(self.cv1(x).chunk(2, 1))`.
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


def transfer_weights(c2f, c2f_v2):
    """Copy all weights and public attributes of a ``C2f`` into an equivalent ``C2f_v2``.

    ``cv2`` and ``m`` are shared as module objects. ``cv1``'s conv weights and BN
    parameters/buffers are split in half: the first half goes to ``cv0`` and the second half
    to ``cv1``. This mirrors the original ``chunk(2, 1)`` order.
    """
    c2f_v2.cv2 = c2f.cv2
    c2f_v2.m = c2f.m

    state_dict = c2f.state_dict()
    state_dict_v2 = c2f_v2.state_dict()

    # Transfer cv1 weights from C2f to cv0 and cv1 in C2f_v2
    old_weight = state_dict['cv1.conv.weight']
    half_channels = old_weight.shape[0] // 2
    state_dict_v2['cv0.conv.weight'] = old_weight[:half_channels]
    state_dict_v2['cv1.conv.weight'] = old_weight[half_channels:]

    # Transfer cv1 batchnorm weights and buffers from C2f to cv0 and cv1 in C2f_v2
    for bn_key in ['weight', 'bias', 'running_mean', 'running_var']:
        old_bn = state_dict[f'cv1.bn.{bn_key}']
        state_dict_v2[f'cv0.bn.{bn_key}'] = old_bn[:half_channels]
        state_dict_v2[f'cv1.bn.{bn_key}'] = old_bn[half_channels:]

    # Transfer remaining weights and buffers
    for key in state_dict:
        if not key.startswith('cv1.'):
            state_dict_v2[key] = state_dict[key]

    # Transfer all non-method attributes whose names contain no '_'. This copies plain
    # attributes such as `c`; presumably also the bookkeeping ultralytics attaches to layers
    # (TODO confirm which ones matter).
    for attr_name in dir(c2f):
        attr_value = getattr(c2f, attr_name)
        if not callable(attr_value) and '_' not in attr_name:
            setattr(c2f_v2, attr_name, attr_value)

    c2f_v2.load_state_dict(state_dict_v2)


def replace_c2f_with_c2f_v2(module):
    """Recursively replace every ``C2f`` child of ``module`` in place with a weight-equivalent ``C2f_v2``."""
    for name, child_module in module.named_children():
        if isinstance(child_module, C2f):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels, child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)
        else:
            replace_c2f_with_c2f_v2(child_module)


def save_model_v2(self: BaseTrainer):
    """Save last/best/periodic checkpoints without half-precision conversion.

    Originated from ultralytics/yolo/engine/trainer.py; bound onto the trainer in ``train_v2``.
    """
    ckpt = {
        'epoch': self.epoch,
        'best_fitness': self.best_fitness,
        'model': deepcopy(de_parallel(self.model)),
        'ema': deepcopy(self.ema.ema),
        'updates': self.ema.updates,
        'optimizer': self.optimizer.state_dict(),
        'train_args': vars(self.args),  # save as dict
        'date': datetime.now().isoformat(),
        'version': __version__}

    # Save last, best and delete
    torch.save(ckpt, self.last)
    if self.best_fitness == self.fitness:
        torch.save(ckpt, self.best)
    if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
        torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
    del ckpt


def final_eval_v2(self: BaseTrainer):
    """Strip optimizers from last/best checkpoints (keeping FP32) and validate the best one.

    Originated from ultralytics/yolo/engine/trainer.py; bound onto the trainer in ``train_v2``.
    """
    for f in self.last, self.best:
        if f.exists():
            strip_optimizer_v2(f)  # strip optimizers
            if f is self.best:
                LOGGER.info(f'\nValidating {f}...')
                self.metrics = self.validator(model=f)
                self.metrics.pop('fitness', None)
                self.run_callbacks('on_fit_epoch_end')


def strip_optimizer_v2(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
    """Strip optimizer/EMA state from checkpoint ``f``, without half-precision conversion.

    Originated from ultralytics/yolo/utils/torch_utils.py.

    Args:
        f: checkpoint to read.
        s: output path; when empty, ``f`` is overwritten in place.
    """
    # NOTE: torch.load unpickles the checkpoint; only use on trusted files.
    x = torch.load(f, map_location=torch.device('cpu'))
    args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'ema', 'updates':  # keys
        x[k] = None
    for p in x['model'].parameters():
        p.requires_grad = False
    x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def train_v2(self: YOLO, pruning=False, **kwargs):
    """Train like ``YOLO.train``, but keep the current (pruned) model when ``pruning`` is set.

    When ``pruning`` is True, the trainer trains ``self.model`` as-is instead of rebuilding it
    from the yaml. Its checkpoint saving and final evaluation are also swapped for the FP32
    variants above. Originated from ultralytics/yolo/engine/model.py.
    """
    self._check_is_pytorch_model()
    if self.session:  # Ultralytics HUB session
        if any(kwargs):
            LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
        kwargs = self.session.train_args
    overrides = self.overrides.copy()
    overrides.update(kwargs)
    if kwargs.get('cfg'):
        LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
        overrides = yaml_load(check_yaml(kwargs['cfg']))
    overrides['mode'] = 'train'
    if not overrides.get('data'):
        raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
    if overrides.get('resume'):
        overrides['resume'] = self.ckpt_path

    self.task = overrides.get('task') or self.task
    self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)

    if not pruning:
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
    else:
        # pruning mode: train the already-pruned module directly
        self.trainer.pruning = True
        self.trainer.model = self.model

        # replace some functions to disable half precision saving
        self.trainer.save_model = save_model_v2.__get__(self.trainer)
        self.trainer.final_eval = final_eval_v2.__get__(self.trainer)

    self.trainer.hub_session = self.session  # attach optional HUB session
    self.trainer.train()
    # Update model and cfg after training
    if RANK in (-1, 0):
        self.model, _ = attempt_load_one_weight(str(self.trainer.best))
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, 'metrics', None)


def prune(args):
    """Iteratively prune ``args.model`` with GroupNormPruner, fine-tuning after every step.

    Attributes read from ``args``: ``model``, ``cfg``, ``iterative_steps``,
    ``target_prune_rate`` and, optionally, ``data``, ``epochs`` and ``max_map_drop``. Defaults
    are applied when the optional ones are missing.

    Each step prunes a fraction ``ch_sparsity`` of the current channels, chosen so that
    ``(1 - ch_sparsity) ** iterative_steps == 1 - target_prune_rate``. After pruning, each step
    validates, fine-tunes, validates again, saves the pruned model and redraws
    ``pruning_perf_change.png``.
    """
    # Run-name prefix shared by every val/train run of this session. strftime avoids the spaces
    # and ':' of str(datetime.now()); ':' is not allowed in Windows file names.
    base_name = 'prune/' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '/'

    # load trained yolov8 model
    model = YOLO(args.model)
    model.train_v2 = train_v2.__get__(model)
    pruning_cfg = yaml_load(check_yaml(args.cfg))
    batch_size = pruning_cfg['batch']

    # Dataset and fine-tuning epochs per pruning step. These override the values in the config
    # file; the defaults are the values this script previously hard-coded.
    pruning_cfg['data'] = getattr(args, 'data', None) or "./ultralytics/datasets/soccer.yaml"
    pruning_cfg['epochs'] = getattr(args, 'epochs', None) or 4
    max_map_drop = getattr(args, 'max_map_drop', None)

    model.model.train()
    replace_c2f_with_c2f_v2(model.model)
    initialize_weights(model.model)  # set BN.eps, momentum, ReLU.inplace

    for param in model.model.parameters():
        param.requires_grad = True

    example_inputs = torch.randn(1, 3, pruning_cfg["imgsz"], pruning_cfg["imgsz"]).to(model.device)
    macs_list, nparams_list, map_list, pruned_map_list = [], [], [], []
    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)

    # do validation before pruning model
    pruning_cfg['name'] = base_name + "baseline_val"
    pruning_cfg['batch'] = 128
    validation_model = deepcopy(model)
    metric = validation_model.val(**pruning_cfg)
    init_map = metric.box.map
    macs_list.append(base_macs)
    nparams_list.append(100)
    map_list.append(init_map)
    pruned_map_list.append(init_map)
    print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M, mAP={init_map: .5f}")

    # prune same ratio of filter based on initial size
    ch_sparsity = 1 - math.pow((1 - args.target_prune_rate), 1 / args.iterative_steps)

    for i in range(args.iterative_steps):
        model.model.train()
        for param in model.model.parameters():
            param.requires_grad = True

        # Keep the detection head intact.
        ignored_layers = [m for m in model.model.modules() if isinstance(m, Detect)]
        unwrapped_parameters = []

        example_inputs = example_inputs.to(model.device)
        pruner = tp.pruner.GroupNormPruner(
            model.model,
            example_inputs,
            importance=tp.importance.GroupNormImportance(),  # L2 norm pruning
            iterative_steps=1,
            ch_sparsity=ch_sparsity,
            ignored_layers=ignored_layers,
            unwrapped_parameters=unwrapped_parameters,
        )
        # Optional regularization test: run a forward pass, backprop
        # (output[0].sum() + sum(o.sum() for o in output[1])), then call
        # pruner.regularize(model.model) before pruner.step().
        pruner.step()

        # pre fine-tuning validation
        pruning_cfg['name'] = base_name + f"step_{i}_pre_val"
        pruning_cfg['batch'] = 128
        validation_model.model = deepcopy(model.model)
        metric = validation_model.val(**pruning_cfg)
        pruned_map = metric.box.map
        pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs.to(model.device))
        current_speed_up = float(macs_list[0]) / pruned_macs
        print(f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
              f"mAP={pruned_map}, speed up={current_speed_up}")

        # fine-tuning
        for param in model.model.parameters():
            param.requires_grad = True
        pruning_cfg['name'] = base_name + f"step_{i}_finetune"
        pruning_cfg['batch'] = batch_size  # restore batch size
        model.train_v2(pruning=True, **pruning_cfg)

        # post fine-tuning validation
        pruning_cfg['name'] = base_name + f"step_{i}_post_val"
        pruning_cfg['batch'] = 128
        validation_model = YOLO(model.trainer.best)
        metric = validation_model.val(**pruning_cfg)
        current_map = metric.box.map
        print(f"After fine tuning mAP={current_map}")

        macs_list.append(pruned_macs)
        nparams_list.append(pruned_nparams / base_nparams * 100)
        pruned_map_list.append(pruned_map)
        map_list.append(current_map)

        # remove pruner after single iteration
        del pruner

        model.model.zero_grad()  # Remove gradients
        save_path = Path('runs/detect') / base_name / f"step_{i}_pruned_model.pth"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.model, save_path)  # whole module, not just .state_dict()
        print('pruned model saved in', save_path)
        # Reload later with: model = torch.load(save_path)

        save_pruning_performance_graph(nparams_list, map_list, macs_list, pruned_map_list)

        if max_map_drop is not None and init_map - current_map > max_map_drop:
            print("Pruning early stop")
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='runs/detect/train/weights/last.pt', help='Pretrained pruning target model file')
    parser.add_argument('--cfg', default='default.yaml',
                        help='Pruning config file.'
                             ' This file should have same format with ultralytics/yolo/cfg/default.yaml')
    parser.add_argument('--data', default='./ultralytics/datasets/soccer.yaml',
                        help='Dataset yaml used for validation and fine-tuning (overrides --cfg)')
    parser.add_argument('--epochs', default=4, type=int,
                        help='Fine-tuning epochs per pruning step (overrides --cfg)')
    parser.add_argument('--iterative-steps', default=4, type=int, help='Total pruning iteration step')
    parser.add_argument('--target-prune-rate', default=0.2, type=float, help='Target pruning rate')
    parser.add_argument('--max-map-drop', default=1, type=float, help='Allowed maximum map drop after fine-tuning')

    args = parser.parse_args()

    prune(args)
  3. 在代码的以下位置加上一些限制,否则它会频繁地验证模型:
    [图片缺失:原文此处为需要修改位置的截图 1]
    [图片缺失:原文此处为需要修改位置的截图 2]

实验结果:实验进行中,待续。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/154894.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用wireshark解析ipsec esp包

Ipsec esp包就是ipsec通过ike协议协商好后建立的通信隧道使用的加密包,该加密包里面就是用户的数据,比如通过的语音等。 那么如何将抓出来的esp包解析出来看呢? 获取相关的esp的key信息. 打开wireshark -> edit->preferences 找到pr…

【手写数字识别】GPU训练版本

SVM Adaboost Bagging 完整代码 I import torch import torch.nn.functional as F from torch.utils.data import DataLoader, TensorDataset from torchvision import transforms, datasets import matplotlib.pyplot as plt # 超参数 batch_size = 64 num_epochs = 10 # 数据…

安卓三防平板在行业应用中有哪些优势

在工业维修和检测中,安卓三防平板的应用也十分广泛。它可以搭载各种专业软件和工具,帮助工人们进行设备故障排查和维护,降低了维修成本和停机时间。 一、产品卖点: 1. 防水性能:该手持平板采用了防水设计,…

国标28181 开源WVP-PRO项目部署

感谢大牛的开源框架 https://doc.wvp-pro.cn/#/ 一.直接使用源码部署(在linux) -- 安装环境 yum install -y java-1.8.0-openjdk.x86_64 git maven nodejs npm -- 下载源码-wvp项目 git clone https://gitee.com/pan648540858/wvp-GB28181-pro.git ---…

MyBatis-Plus为简化开发而生

简介 MyBatis-Plus 简称 MP是一个 MyBatis 的增强工具,在 MyBatis 的基础上只做增强不做改变,为简化开发、提高效率而生。 他们的愿景是成为 MyBatis 最好的搭档,就像魂斗罗中的 1P、2P,基友搭配,效率翻倍。 特性 无…

【ARM CoreLink 系列 6 -- DMC-400控制器简介】

文章目录 1.1 DMC-400 简介1.1.1 DFI(DDR PHY Interface)1.1.2 DFI 接口组1.1.3 DMC-400 兼容协议1.1.4 DMC-400 特性1.1.5 DMC-400 Interface 1.1 DMC-400 简介 DMC-400是一个由ARM开发、测试和授权的动态内存控制器,同时 DMC-400也是一个符…

如何实现 Es 全文检索、高亮文本略缩处理

如何实现 Es 全文检索、高亮文本略缩处理 前言技术选型JAVA 常用语法说明全文检索开发高亮开发Es Map 转对象使用核心代码 Trans 接口(支持父类属性的复杂映射)Trans 接口的不足真实项目落地效果结语 前言 最近手上在做 Es 全文检索的需求,类…

曦力音视频转换工具Xilisoft Video Converter Ultimate mac中文版

Xilisoft Video Converter Ultimate mac是一款功能强大的视频转换软件,它可以将几乎所有流行的视频格式转换为其他格式,包括AVI、MPEG、WMV、DivX、MP4、H.264/AVC、AVCHD、MKV、RM、MOV、XviD、3GP等。此外,它还支持将视频转换为音频格式…

选择适合自身业务的HTTP代理有哪些因素决定?

相信对很多爬虫工作者和数据采集的企业来说,如何选购适合自己业务的HTTP代理是一个特别特别困扰的选题,市面上那么多HTTP代理厂商,好像这家有这些缺点,转头又看到另外一家的缺点,要找一家心仪的仿佛大海捞针。今天我们…

前端预览、下载二进制文件流(png、pdf)

前端请求设置 responseType: "blob",后台接口返回的文件流如下: 拿到后端返回的文件流后: 预览 <iframe :src="previewUrl" frameborder="0" style="width: 500px; height: 500px;"></iframe> 1、预览 v…

虹科分享 | 想买车无忧?AR为您带来全新体验!

新能源汽车的蓬勃发展,推动着汽车行业加速进行数字化变革。据数据显示,全球新能源汽车销售额持续上升,预计到2025年,新能源汽车市场规模将达到约 4200亿美元,年复合增长率超过 30%。这表明消费者对清洁能源出行的需求不…

隔离上网,安全上网

SDC沙盒数据防泄密系统(安全上网,隔离上网) •深信达SDC沙盒数据防泄密系统,是专门针对敏感数据进行防泄密保护的系统,根据隔离上网和安全上网的原则实现数据的代码级保护,不会影响工作效率,不…

数据挖掘与统计分析——T检验,正态性检验和一致性检验——代码复现

T检验是一种统计测试,用于确定两个样本组的均值是否有统计学上的显著差异。以下是对T检验的详细介绍: 定义: T检验是一种参数检验,它的前提是数据近似于正态分布。它通过计算T统计量,并将其与特定分布(T分…

PyTorch Lightning - LightningModule 训练逻辑 (training_step) 异常处理 try-except

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/133673820 在使用 LightningModule 框架训练模型时,因数据导致的训练错误,严重影响训练稳定性,因此需要使用 t…

华为OD机试 - 数组组成的最小数字(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷…

Spring5应用之事务处理

作者简介:☕️大家好,我是Aomsir,一个爱折腾的开发者! 个人主页:Aomsir_Spring5应用专栏,Netty应用专栏,RPC应用专栏-CSDN博客 当前专栏:Spring5应用专栏_Aomsir的博客-CSDN博客 文章目录 参考文献前言事务…

ant-design-vue 实现表格表头纵排列

结果如图: 区域,成功率,清单率为表头,右侧为动态的数据 废话不多说直接上代码: 1.先声明表格,使用框架自带a-table,核心点就在data和columns上 <div style="margin-bottom: 60px;"…

jvm--执行引擎

文章目录 1. 执行引擎的工作流程2. 解释器、JIT即时编译器3. 热点代码及探测技术4. HotSpotVM 中 JIT 分类 执行引擎属于 JVM 的下层,里面包括解释器、即时编译器、垃圾回收器 JVM 的主要任务是负责 装载字节码到其内部,但字节码并不能够直接运行在操作…

C之fopen/fclose/fread/fwrite/fseek

一、C中文件操作简介 c中的文件操作大致和linux的文件操作类似,但是毕竟是不同的API,所以会有些差异。部分差异会在下面的案例中体验 二、fopen open的参数有两个一个是文件名,一个是模式选择,不同open函数,open中的模…

Jmeter 分布式压测,你的系统能否承受高负载?

你可以使用 JMeter 来模拟高并发秒杀场景下的压力测试。这里有一个例子,它模拟了同时有 5000 个用户、循环 10 次的情况。 请求默认配置 token 配置 秒杀接口 结果分析 但是,实际企业中,这种压测方式根本不满足实际需求。下…