网络剪枝——network-slimming 项目复现

目录

文章目录

  • 目录
  • 网络剪枝——network-slimming 项目复现
    • clone 存储库
    • Baseline
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Sparsity
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Prune
      • vgg
        • 命令
        • 结果
      • resnet
        • 命令
        • 结果
      • densenet
        • 命令
        • 结果
    • Fine-tune
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • 模型大小计算脚本 param_counter.py
    • 结果汇总
      • CIFAR10

网络剪枝——network-slimming 项目复现

  • 【GiHnub】:Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)
  • 【作者复现项目】:
  • 通过百度网盘分享的文件:network-slimming-regin.zip
    链接:https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwd=t1z2
    提取码:t1z2
    • 仅以 CIFAR-10 为例,CIFAR-100 同理.
    • 提供中文README_zh-CN.md.
    • 包含 CIFAR-10/100 数据集data.cifar10data.cifar100.
    • 解决了 main.py 运行报错问题.
    • 加入了计算训练后模型的 Parameters 大小脚本param_counter.py.

clone 存储库

注:若 clone 作者复现项目,则忽略这一步,直接进入下一步;若想自行从头复现,则 clone 以下存储库.

  • 链接:https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwd=ud89
    提取码:ud89


Baseline

vgg

训练
  • 【命令】:
python main.py --dataset cifar10 --arch vgg --depth 19

  • 这个报错通常出现在使用 Python 的multiprocessing库来创建进程时,尤其是在 Windows 操作系统上. 在 Windows 上,Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同,它使用 “spawn” 来启动新进程,这意味着每个子进程都会从头开始执行脚本. 因此,如果在脚本顶层级别启动进程(而不是在受保护的if __name__ == '__main__':块中),每个子进程都会尝试再次启动子进程,从而导致无限递归和上述错误.
  • 为了解决这个问题,应 确保多进程代码(即main.py)位于if __name__ == '__main__':保护块内.
# 导入部分
...def main():...if __name__ == '__main__':main()
  • 再次运行命令,又报错:

  • 这个报错通常发生在尝试直接索引一个0维的张量(tensor)时. 在 PyTorch 中,0 维张量是一个单一值的张量,但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值,需要使用.item()方法.
  • 为了解决这个问题,应该 使用.item()方法来替换所有.data[0]的用法
# 在 train 函数中
if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# 在 test 函数中
for data, target in test_loader:if args.cuda:data, target = data.cuda(), target.cuda()data, target = Variable(data), Variable(target)output = model(data)test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch losspred = output.data.max(1, keepdim=True)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum()test_loss /= len(test_loader.dataset)
  • 再次运行命令就正常运行了:

结果
  • Terminal

  • 在 ./logs 生成文件checkpoint.pth.tarmodel_best.pth.tar

resnet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch densenet --depth 40
结果


Sparsity

vgg

训练
  • 【命令】:
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
结果

resnet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
结果


Prune

vgg

命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes

  • main.py同理,为了解决这个问题,应 确保多进程代码位于if __name__ == '__main__':保护块内
# 导入部分
...def main():...if __name__ == '__main__':main()
  • 之后就可以正常运行了.

结果
  • Terminal

  • 在./prunes生成文件prune.txtpruned.pth.tar

  • prune.txt中我们可以看到 Number of parametersTest accuracy

resnet

命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes
结果

densenet

命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes
结果


Fine-tune

vgg

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160
结果

resnet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160
结果

densenet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160
结果


模型大小计算脚本 param_counter.py

  • 【路径】:./script/param_counter.py
import torchdef load_model(model_path):model = torch.load(model_path, map_location=torch.device('cpu'))return modeldef count_parameters(model_state_dict):total_params = sum(p.numel() for p in model_state_dict.values())return total_paramsdef get_model_parameters(model_path):# 加载模型状态字典model = load_model(model_path)# 模型状态字典存储在 'state_dict' 键下model_state_dict = model['state_dict'] if 'state_dict' in model else model# 计算参数总数total_params = count_parameters(model_state_dict)return total_params
  • main.py中:
from script.param_counter import get_model_parametersdef main():...# 计算 Parametersmodel_path = 'logs/model_best.pth.tar'total_params = get_model_parameters(model_path)print(f'Total parameters in the model: {total_params}')

结果汇总

注:与原项目结果略有差别.

CIFAR10

CIFAR10-VggBaselineSparsity(1e-4)Prune(70%)Fine-tune-160(70%)
Top1 Accuracy(%)93.7293.6033.9893.75
Parameters20.05M20.05M2.22M2.23M
CIFAR10-Resnet-164BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.9995.0094.5995.27
Parameters1.74M1.74M1.46M1.49M
CIFAR10-Densenet-40BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.1594.3794.1494.48
Parameters1.09M1.09M0.70M0.72M

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/397717.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

移情别恋c++ ദ്ദി˶ー̀֊ー́ ) ——5.string

1.字符串相乘 . - 力扣(LeetCode) 思路: 1.如果两个串有一个首元素为‘0’,则直接返回‘0’ 2.设置两层循环,内层第一次循环 用于str插入初始数据 (num2 的各个元素和num1 的最后一个元素相乘的结果&#…

C代码做底层及Matlab_SimuLink做应用层设计单片机程序

前言:SimuLink工具极其强大,但是能直接支持单片机自主开发的很少,造成这个问题的原因主要是我们使用的芯片底层多是C代码工程,芯片厂家也只提供C代码库,很少能提供SimuLink的支持库,即使提供也不是很不完善,如NXP的一些芯片提供的SimuLink库不含盖高级应用,再比如意法半…

哈希表 - 快乐数

202. 快乐数 方法一:用哈希集合检测循环 /*** param {number} n* return {boolean}*/let getNext function(n) {return n.toString().split().map(i > i ** 2).reduce((a, b) > a b); }let isHappy function(n) {let seen new Set();while (n ! 1 &&…

什么是跨境电商独立站?为什么选择做独立站?

独立站在近两年被推上风口,很多人跟风涌入赛道,但并不知道做独立站的根本原因是什么?为什么跨境电商要做独立站? 今天分享这篇文章,希望能帮助正在建站或想要建站的朋友们建立起对独立站的基本认知,做到不踩…

【学习笔记】Matlab和python双语言的学习(图论最短路径)

文章目录 前言一、图论基本概念示例 二、代码实现----Matlab三、代码实现----python总结 前言 通过模型算法,熟练对Matlab和python的应用。 学习视频链接: https://www.bilibili.com/video/BV1EK41187QF?p36&vd_source67471d3a1b4f517b7a7964093e6…

Java线程模型

一、相关知识 用户级线程(ULT):实现在用户空间的线程称为用户级线程。用户线程是完全建立在用户空间的线程库,用户线程的创建、调度、同步和销毁全由用户空间的库函数完成,不需要内核的参与,也不需要进行用…

FPGA之间数据传输的讨论:解析数据传输与同步技术

在现代电子工程领域,数据传输和同步技术是确保信息准确、高效传递的关键。FPGA间的高速数据传输是实现复杂系统功能的关键技术之一。本文将基于移知公开课《FPGA之间数据传输的讨论》的内容,探讨FPGA间数据传输的技术细节和面临的挑战,帮助读…

使用VS2022生成安装包

首先需要本地已经能够正常运行的软件包,包含可执行文件及必要的运行库等,如下所示RemoteCli.exe为最终的可执行文件 打开VS2022 ,选择 扩展–>管理扩展–>联机,搜索Microsoft Visual Studio Installer Projects,…

Lua调用c#

1. 类 --lua中使用C#的类非常简单 --固定套路 --CS.命名空间.类名 --Unity的类 比如 GameObject Transform等等 —— CS.UnityEngine.类名 --CS.UnityEngine.GameObject--通过C#中的类 实例化一个对象 lua中没有new 所以我们直接 类名括号就是实例化对象 --默认调用的 相当于就…

智能分析/视频汇聚EasyCVR安防视频融合管理云平台技术优势分析

安防行业的发展历程主要围绕视频监控技术的不断改革升级,从最初的模拟监控到数字监控,再到高清化、网络化监控,直至现在的智能化监控,每一次变革都推动了行业的快速发展。特别是近年来,随着AI、大数据、物联网等技术的…

LVS负载均衡(twenty-six day)

一、LVS (一)什么是LVS linux virtural server的简称,也就是linxu虚拟机服务器,这是一个由章文岩博士发起的开源项目,官网是http://www.linuxvirtualserver.org,现在lvs已经是linux内核标准的-部分,使用lv…

学术周交流与学习节选

文章目录 1、粒度多模态运动分析1.1 免特征重建的终身行人重识别1.2 无样本保留的终身行人重识别1.3 粒度多模态运动之类增量学习1.4 粒度多模态之人体姿态估计扩散模型 2、深度伪造的被动取证与主动防御2.1 研究现状及主要方法2.2 基于梯度的伪影特征表示2.3 基于伪造自适应学…

SQL注入实例(sqli-labs/less-18)

0、初始页面 先使用brup爆破密码,账号admin,密码admin 1、确定闭合字符 判断注入点在post请求参数的User-agent处 闭合字符为单引号 2、爆库名 3、爆表名 4、爆列名 5、查询最终目标 在index.php中有这么一句 $insert"INSERT INTO security.uage…

haproxy算法与具体实现

一、负载均衡 1.什么是负载均衡 负载均衡:Load Balance,简称LB,是一种服务或基于硬件设备等实现的高可用反向代理技术,负载均 衡将特定的业务(web服务、网络流量等)分担给指定的一个或多个后端特定的服务器或设备,从…

『大模型笔记』人类反馈的强化学习(Reinforcement Learning from Human Feedback, RLHF)

人类反馈的强化学习(Reinforcement Learning from Human Feedback, RLHF) 文章目录 一. 人类反馈的强化学习(Reinforcement Learning from Human Feedback, RLHF)1. 概念解释2. RLHF的组成部分2.1. 强化学习(Reinforcement Learning, RL)2.2. 状态空间(state space)2.3. 动作空…

深入InnoDB核心:揭秘B+树在数据库索引中的高效应用

目录 一、索引页与数据行的紧密关联 (一)数据页的双向链表结构 (二)记录行的单向链表结构 二、未创建索引情况 (一)无索引下的单页查找过程 以主键为搜索条件 以非主键列为搜索条件 (二…

ffmpeg 内存模型

最近在学习ffmpeg,阅读了一些packet和frame关于内存操作的api。在此长话短说,只说核心点。 ffmpeg模型 AVFrame 表示编码前的原始数据帧,AVPacket 表示编码后的压缩数据包。 问题: (1)从av_read_frame读…

算法打卡 Day20(二叉树)-找树左下角的值 + 路径总和 + 从中序与后序遍历序列构造二叉树

文章目录 Leetcode 513-找树左下角的值题目描述解题思路 Leetcode 112-路径总和题目描述解题思路相关题目Leetcode 113-路径总和 ii Leetcode 106-从中序与后序遍历序列构造二叉树题目描述解题思路类似题目Leetcode 105-从前序与中序遍历序列构造二叉树 Leetcode 513-找树左下角…

HSL模型和HSB模型,和懒人配色的Color Hunt

色彩不仅仅是视觉上的享受,它在数据可视化中也扮演着关键角色。通过合理运用色彩模型,我们可以使数据更具可读性和解释性。在这篇文章将探讨HSL(Hue, Saturation, Lightness)和HSB(Hue, Saturation, Brightness&#x…

Java中的抽象类与接口

1. 抽象类 1.1 抽象类概念 在面向对象的概念中,所有的对象都是通过类来描绘的,但是反过来,并不是所有的类都是用来描绘对象的, 如果一个类中没有包含足够的信息来描绘一个具体的对象,这样的类就是抽象类。 比如&…