2024强化学习的结构化剪枝模型RL-Pruner原理及实践

[2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration

目录

  • [2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration
    • 一、论文说明
    • 二、原理
    • 三、实验与分析
      • 1、环境配置
        • 在Windows配置git bash链接conda环境
      • 2、项目代码运行
        • 1、训练预训练权重
        • 2、模型压缩
        • 3、模型验证
    • 四、总结

一、论文说明

论文标题:使用强化学习进行结构化剪枝用于卷积神经网路压缩和加速

机构:伊利诺伊大学厄巴纳-香槟分校

论文链接:https://arxiv.org/pdf/2411.06463

代码链接:https://github.com/Beryex/RLPruner-CNN

论文简介: 卷积神经网络(ConvolutionalNeural Networks, CNNs)近年来表现出卓越的性能。压缩这些模型不仅减少了存储需求,使其在边缘设备上的部署变得可行,还加速了推理,从而降低了延迟和计算成本。结构化剪枝,它在层级上去除过滤器,直接修改了模型架构。这种方法实现了更紧凑的架构,同时保持目标准确性,确保压缩模型具有较好的兼容性和硬件效率。所提方法基于一个关键观察:

  • 1、神经网络中不同层的过滤器对模型性能的重要性各不相同。

  • 2、当修剪的过滤器数量固定时,不同层之间的最佳修剪分配是不均匀的,以最小化性能损失

  • 3、对修剪敏感的层应该占据更小的修剪分配比例。

为了利用这一洞察,文中提出了RL-Pruner,它使用强化学习来学习最佳修剪分配。RL-Pruner可以自动提取输入模型中过滤器之间的依赖关系并执行修剪,无需特定于模型的修剪实现。在GoogleNet、ResNet和MobileNet 等模型上进行了实验,将所提方法与其他结构化剪枝方法进行了比较,以验证其有效性。

在这里插入图片描述

二、原理

RL-Pruner 首先在模型中的层之间构建依赖图,然后分几个步骤进行剪枝。在每个步骤中:1) 基于基础分布生成一个新的剪枝稀疏分布作为动作 ,这作为策略;2)根据相应的稀疏度,使用泰勒准则(Taylorcriterion)对每一层进行剪枝;3) 评估压缩后的模型以获得奖励,并将动作和奖励存储在回放(replay)经验池中。每个步骤后,基础分布根据经验池更新,如果计算资源足够,则对压缩模型应用后训练阶段,使用知识蒸馏(knowledge distillation),其中原始模型作为教师。具体框图如图2所示。

三、实验与分析

1、环境配置

实验平台及软件

  • Windows 10
  • git bash
  • conda环境

这里主要介绍如何在windows系统上让git bash链接conda环境。

在Windows配置git bash链接conda环境

由于工程代码中需要使用bash命令运行代码,因此需要保证git bash能调用conda环境运行对应的脚本文件。

C:\Users\username\.bashrc文件内设置conda.sh位置(文中示例为:D:\\Anaconda3\\etc\\profile.d\\conda.sh),并激活配置。在git bash界面输入具体命令如下:

echo "D:\\Anaconda3\\etc\\profile.d\\conda.sh" >> ~/.bashrc  
source ~/.bashrc

然后关闭git bash界面,再重新打开一个git bash界面,最后输入命令激活conda环境conda activate 虚拟环境名字。如果命令提示中出现如下图所示的字样,即为配置成功,否则根据提示的要求进行配置,比如输入conda init,重新打开一个新的git bash界面。

在这里插入图片描述

2、项目代码运行

克隆项目文件,具体命令如下:

git clone https://github.com/Beryex/RLPruner-CNN.git --depth 1
cd RLPruner-CNN

安装python第三方包,具体命令如下(如果之前有conda环境,可以不用进行下面这一步,等报错了再根据提示安装对应的包即可):

conda create -n RLPruner python=3.10 -y
conda activate RLPruner
pip install -r requirements.txt

官方代码提供了一步到位的运行脚本,从预训练模型、模型压缩到模型验证,仅需在命令行中输入如下代码:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

为了更好地了解每一步的设置,下面内容将分为预训练模型、模型压缩、模型验证三个步骤进行介绍。

1、训练预训练权重

训练模型得到对应的预训练权重,这里以resnet32googlenet为例,在git bash输入具体命令(默认使用cuda)如下:

./scripts/train.sh googlenet cifar100
./scripts/train.sh resnet32 cifar100

或者使用参考指定配置命令:

python -m train --model ${MODEL} --dataset ${DATASET} --device cuda \--output_dir ${PRETRAINED_MODEL_DIR} \--log_dir ${LOG}

其中,
${MODEL}代表backbone的类型([“vgg11”, “vgg13”, “vgg16”, “vgg19”, “resnet18”, “resnet34”, “resnet50”, “resnet101”, “resnet152”, “resnet8”, “resnet14”, “resnet20”, “resnet32”, “resnet44”, “resnet56”, “resnet110”, “densenet121”, “densenet161”, “densenet169”, “densenet201”, “mobilenetv3_small”, “mobilenetv3_large”, “googlenet”]);
${DATASET}代表数据集名称,如cifar10或者cifar100。
${PRETRAINED_MODEL_DIR}代表输出权重文件路径,默认在pretrained_model文件夹下;
${LOG}代表输出日志路径,默认在log文件夹下。

在CIFAR100数据集上训练resnet32的结果(最佳准确率:0.706)如下图所示。
在这里插入图片描述
在CIFAR100数据集上训练googlenet的结果(最佳准确率:0.774)如下图所示。
在这里插入图片描述

2、模型压缩

模型结构化剪枝这里以0.2的稀疏度,taylor剪枝策略和Q_FLOP_coef=0,Q_Para_coef=0的参数进行测试。在git bash输入具体命令(默认使用cuda)如下:

./scripts/flexible.sh googlenet cifar100 0.20 taylor 0.00 0.00

同理,也可以使用参考指定配置命令:

python -m compress --model ${MODEL} --dataset ${DATASET} --device cuda \--sparsity ${SPARSITY} --prune_strategy ${prune_strategy} --ppo \--Q_FLOP_coef ${Q_FLOP_coef} --Q_Para_coef ${Q_Para_coef} \--pretrained_pth ${PRETRAINED_MODEL_PTH} \--compressed_dir ${COMPRESSED_MODEL_DIR} \--checkpoint_dir ${CKPT_DIR} \--log_dir ${LOG} --save_model

测试结果如下图所示:
在这里插入图片描述

3、模型验证

在数据集上验证模型的识别性能,在git bash输入具体命令(默认使用cuda)如下:

./scripts/evaluate.sh googlenet cifar100

同理,也可以使用参考指定配置命令:

python -m evaluate --model ${MODEL} --dataset ${DATASET} --device cuda \--pretrained_pth ${PRETRAINED_MODEL_PTH} \--compressed_pth ${COMPRESSED_MODEL_PTH} \--log_dir ${LOG}

测试结果如下图所示:
在这里插入图片描述

四、总结

本文提出了RL-Pruner,一种结构化剪枝方法,能够学习各层之间的最优稀疏性分布,并支持无模型特定修改的一般剪枝。希望所提方法能够认识到每一层对模型(model)性能的重要性不同,这将影响未来在神经网络压缩领域的工作,包括无结构剪枝和量化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/472578.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式硬件电子电路设计(五)MOS管详解(NMOS、PMOS、三极管跟mos管的区别)

引言:在我们的日常使用中,MOS就是个纯粹的电子开关,虽然MOS管也有放大作用,但是几乎用不到,只用它的开关作用,一般的电机驱动,开关电源,逆变器等大功率设备,全部使用MOS管…

蓝桥杯每日真题 - 第14天

题目:(2022) 题目描述(13届 C&C B组A题) 解题思路: 定义状态: 使用一个二维数组 dp[j][k] 来表示将数字 k 拆分为 j 个不同正整数的方案数。 初始化: 初始状态设定为 dp[0][0]…

利用云计算实现高效的数据备份与恢复策略

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 利用云计算实现高效的数据备份与恢复策略 利用云计算实现高效的数据备份与恢复策略 利用云计算实现高效的数据备份与恢复策略 引…

thinkphp6配置多应用项目及多域名访问路由app配置

这里写一写TP6下配置多应用。TP6默认是单应用模式(单模块),而我们实际项目中往往是多应用的(多个模块),所以在利用TP6是就需要进行配置,开启多应用模式。 1、安装ThinkPHP6 1.1安装ThinkPHP6.…

JavaScript:浏览器对象模型BOM

BOM介绍 浏览器对象模型(Brower Object Model,BOM)提供了独立于内容而与浏览器窗口进行交互的对象,其核心对象是window BOM由一系列相关的对象构成,并且每个对象都提供了很多方法和属性。 BOM与DOM区别 DOM是文档对…

SpringBoot 2.2.10 无法执行Test单元测试

很早之前的项目今天clone现在,想执行一个业务订单的检查,该检查的代码放在test单元测试中,启动也是好好的,当点击对应的方法执行Test的时候就报错 tip:已添加spring-boot-test-starter 所以本身就引入了junit5的库 No…

前后端、网关、协议方面补充

这里写目录标题 前后端接口文档简介前后端视角对于前端对于后端代码注册路由路由处理函数 关于httpGET/POST底层网络关于前端的获取 路由器网关路由器的IP简介公网IP(WAN IP)私网IP(LAN IP)无线网络IP(WIFI IP)查询路由器私网IP路由器公网IP LAN口与WIFI简介基本原理 手动配置电…

英伟达基于Mistral 7B开发新一代Embedding模型——NV-Embed-v2

我们介绍的 NV-Embed-v2 是一种通用嵌入模型,它在大规模文本嵌入基准(MTEB 基准)(截至 2024 年 8 月 30 日)的 56 项文本嵌入任务中以 72.31 的高分排名第一。此外,它还在检索子类别中排名第一(…

【计算机网络】TCP网络特点2

断开连接 四次挥手 原因 TCP 四次挥手是为了满足 TCP 连接的全双工特性:两个方向都可以自由传输 保证数据传输的完整性:两方都完成了数据发送和接收并且都同意断开连接 可靠地终止连接以及避免数据混淆和错误等需求:每个方向都需要单独确认导致四次挥手过程 这些…

Opengl光照测试

代码 #include "Model.h" #include "shader_m.h" #include "imgui.h" #include "imgui_impl_glfw.h" #include "imgui_impl_opengl3.h" //以上是放在同目录的头文件#include <glad/glad.h> #include <GLFW/glfw3.…

【MySQL】SQL语言

【MySQL】SQL语言 文章目录 【MySQL】SQL语言前言一、SQL的通用语法二、SQL的分类三、SQLDDLDMLDQLDCL 总结 前言 本篇文章将讲到SQL语言&#xff0c;包括SQL的通用语法,SQL的分类,以及SQL语言的DDL,DML,DQL,DCL。 一、SQL的通用语法 在学习具体的SQL语句之前&#xff0c;先来…

.netcore + postgis 保存地图围栏数据

一、数据库字段 字段类型选择(Type) 设置对象类型为&#xff1a;geometry 二、前端传递的Json格式转换 前端传递围栏的各个坐标点数据如下&#xff1a; {"AreaRange": [{"lat": 30.123456,"lng": 120.123456},{"lat": 30.123456…

T265相机双目鱼眼+imu联合标定(全记录)

最近工作用到t265&#xff0c;记录一遍标定过程 1.安装驱动 首先安装realsense驱动&#xff0c;因为笔者之前使用过d435i&#xff0c;装的librealsense版本为2.55.1&#xff0c;直接使用t265会出现找不到设备的问题&#xff0c;经查阅发现是因为realsense在2.53.1后就不再支持…

【C语言指南】C语言内存管理 深度解析

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《C语言指南》 期待您的关注 引言 C语言是一种强大而灵活的编程语言&#xff0c;为程序员提供了对内存的直接控制能力。这种对内存…

Python学习从0到1 day26 第三阶段 Spark ④ 数据输出

半山腰太挤了&#xff0c;你该去山顶看看 —— 24.11.10 一、输出为python对象 1.collect算子 功能: 将RDD各个分区内的数据&#xff0c;统一收集到Driver中&#xff0c;形成一个List对象 语法&#xff1a; rdd.collect() 返回值是一个list列表 示例&#xff1a; from …

【机器学习】机器学习中用到的高等数学知识-1.线性代数 (Linear Algebra)

向量(Vector)和矩阵(Matrix)&#xff1a;用于表示数据集&#xff08;Dataset&#xff09;和特征&#xff08;Feature&#xff09;。矩阵运算&#xff1a;加法、乘法和逆矩阵(Inverse Matrix)等&#xff0c;用于计算模型参数。特征值(Eigenvalues)和特征向量(Eigenvectors)&…

java项目-jenkins任务的创建和执行

参考内容: jenkins的安装部署以及全局配置 1.编译任务的general 2.源码管理 3.构建里编译打包然后copy复制jar包到运行服务器的路径 clean install -DskipTests -Pdev 中的-Pdev这个参数用于激活 Maven 项目中的特定构建配置&#xff08;Profile&#xff09; 在 pom.xml 文件…

【数据库取证】快速从服务器镜像文件中获取后台隐藏数据

文章关键词&#xff1a;电子数据取证、数据库取证、电子物证、云取证、手机取证、计算机取证、服务器取证 小编最近做了很多鉴定案件和参加相关电子数据取证比武赛&#xff0c;经常涉及到服务器数据库分析。现在分享一下技术方案&#xff0c;供各位在工作中和取证赛事中取得好成…

__VUE_PROD_HYDRATION_MISMATCH_DETAILS__ is not explicitly defined

VUE_PROD_HYDRATION_MISMATCH_DETAILS 未明确定义。您正在运行 Vue 的 esm-bundler 构建&#xff0c;它期望这些编译时功能标志通过捆绑器配置全局注入&#xff0c;以便在生产捆绑包中获得更好的tree-shaking优化。 Vue.js应用程序正在使用ESM&#xff08;ECMAScript模块&#…

git撤销、回退某个commit的修改

文章目录 撤销某个特定的commit方法 1&#xff1a;使用 git revert方法 2&#xff1a;使用 git rebase -i方法 3&#xff1a;使用 git reset 撤销某个特定的commit 如果你要撤销某个很早之前的 commit&#xff0c;比如 7461f745cfd58496554bd672d52efa8b1ccf0b42&#xff0c;可…