基于 PyTorch 的模型瘦身三部曲:量化、剪枝和蒸馏,让模型更短小精悍!

基于 PyTorch 的模型量化、剪枝和蒸馏

    • 1. 模型量化
      • 1.1 原理介绍
      • 1.2 PyTorch 实现
    • 2. 模型剪枝
      • 2.1 原理介绍
      • 2.2 PyTorch 实现
    • 3. 模型蒸馏
      • 3.1 原理介绍
      • 3.2 PyTorch 实现
    • 参考文献

在这里插入图片描述

1. 模型量化

1.1 原理介绍

模型量化是将模型参数从高精度(通常是 float32)转换为低精度(如 int8 或更低)的过程。这种技术可以显著减少模型大小、降低计算复杂度,并加快推理速度,同时尽可能保持模型的性能。
在这里插入图片描述
量化的主要方法包括:

  1. 动态量化

    • 在推理时动态地将权重从 float32 量化为 int8。
    • 激活值在计算过程中保持为浮点数。
    • 适用于 RNN 和变换器等模型。
  2. 静态量化

    • 在推理之前,预先将权重从 float32 量化为 int8。
    • 在推理过程中,激活值也被量化。
    • 需要校准数据来确定激活值的量化参数。
  3. 量化感知训练(QAT)

    • 在训练过程中模拟量化操作。
    • 允许模型适应量化带来的精度损失。
    • 通常能够获得比后量化更高的精度。

1.2 PyTorch 实现

import torch# 1. 动态量化
model_fp32 = MyModel()
model_int8 = torch.quantization.quantize_dynamic(model_fp32,  # 原始模型{torch.nn.Linear, torch.nn.LSTM},  # 要量化的层类型dtype=torch.qint8  # 量化后的数据类型
)# 2. 静态量化
model_fp32 = MyModel()
model_fp32.eval()  # 设置为评估模式# 设置量化配置
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare(model_fp32)# 使用校准数据进行校准
with torch.no_grad():for batch in calibration_data:model_fp32_prepared(batch)# 转换模型
model_int8 = torch.quantization.convert(model_fp32_prepared)# 3. 量化感知训练
model_fp32 = MyModel()
model_fp32.train()  # 设置为训练模式# 设置量化感知训练配置
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32)# 训练循环
for epoch in range(num_epochs):for batch in train_data:output = model_fp32_prepared(batch)loss = criterion(output, target)loss.backward()optimizer.step()# 转换模型
model_int8 = torch.quantization.convert(model_fp32_prepared)

2. 模型剪枝

2.1 原理介绍

模型剪枝是一种通过移除模型中不重要的权重或神经元来减少模型复杂度的技术。剪枝可以减少模型大小、降低计算复杂度,并可能改善模型的泛化能力。
在这里插入图片描述

主要的剪枝方法包括:

  1. 权重剪枝

    • 移除绝对值小于某个阈值的单个权重。
    • 可以大幅减少模型参数数量,但可能导致非结构化稀疏性。
  2. 结构化剪枝

    • 移除整个卷积核、神经元或通道。
    • 产生更加规则的稀疏结构,有利于硬件加速。
  3. 重要性剪枝

    • 基于权重或激活值的重要性评分来决定剪枝对象。
    • 常用的重要性度量包括权重幅度、激活值、梯度等。

2.2 PyTorch 实现

import torch
import torch.nn.utils.prune as prunemodel = MyModel()# 1. 权重剪枝
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)# 2. 结构化剪枝
prune.ln_structured(model.conv1, name='weight', amount=0.5, n=2, dim=0)# 3. 全局剪枝
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),
)
prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2
)# 4. 移除剪枝
for module in model.modules():if isinstance(module, torch.nn.Conv2d):prune.remove(module, 'weight')

3. 模型蒸馏

3.1 原理介绍

模型蒸馏是一种将复杂模型(教师模型)的知识转移到简单模型(学生模型)的技术。这种方法可以在保持性能的同时,大幅减少模型的复杂度和计算需求。
在这里插入图片描述

主要的蒸馏方法包括:

  1. 响应蒸馏

    • 学生模型学习教师模型的最终输出(软标签)。
    • 软标签包含了教师模型对不同类别的置信度信息。
  2. 特征蒸馏

    • 学生模型学习教师模型的中间层特征。
    • 可以传递更丰富的知识,但需要设计合适的映射函数。
  3. 关系蒸馏

    • 学习样本之间的关系,如相似度或排序。
    • 有助于保持教师模型学到的数据结构。

3.2 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DistillationLoss(nn.Module):def __init__(self, alpha=0.5, temperature=2.0):super().__init__()self.alpha = alphaself.T = temperaturedef forward(self, student_outputs, teacher_outputs, labels):# 硬标签损失hard_loss = F.cross_entropy(student_outputs, labels)# 软标签损失soft_loss = F.kl_div(F.log_softmax(student_outputs / self.T, dim=1),F.softmax(teacher_outputs / self.T, dim=1),reduction='batchmean') * (self.T * self.T)# 总损失loss = (1 - self.alpha) * hard_loss + self.alpha * soft_lossreturn loss# 训练循环
teacher_model = TeacherModel().eval()
student_model = StudentModel().train()
distillation_loss = DistillationLoss(alpha=0.5, temperature=2.0)for epoch in range(num_epochs):for batch, labels in train_loader:optimizer.zero_grad()with torch.no_grad():teacher_outputs = teacher_model(batch)student_outputs = student_model(batch)loss = distillation_loss(student_outputs, teacher_outputs, labels)loss.backward()optimizer.step()

通过这些技术的组合使用,可以显著减小模型大小、提高推理速度,同时尽可能保持模型性能。在实际应用中,可能需要根据具体任务和硬件限制来选择和调整这些方法。

参考文献

[1]Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., Adam, H., & Kalenichenko, D. (2018). Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 2704-2713).[2]Krishnamoorthi, R. (2018). Quantizing deep convolutional networks for efficient inference: A whitepaper. arXiv preprint arXiv:1806.08342.[3]Han, S., Pool, J., Tran, J., & Dally, W. (2015). Learning both Weights and Connections for Efficient Neural Network. In Advances in Neural Information Processing Systems (NeurIPS) (pp. 1135-1143).[4]Li, H., Kadav, A., Durdanovic, I., Samet, H., & Graf, H. P. (2016). Pruning Filters for Efficient ConvNets. arXiv preprint arXiv:1608.08710.[5]Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.[6]Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2014). FitNets: Hints for Thin Deep Nets. arXiv preprint arXiv:1412.6550.

创作不易,烦请各位观众老爷给个三连,小编在这里跪谢了!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/381953.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第四十四章 注册字符设备号

i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

【Linux】汇总TCP网络连接状态命令

输入命令: netstat -na | awk /^tcp/ {S[$NF]} END {for(a in S) print a, S[a]} 显示: 让我们逐步解析这个命令: netstat -na: netstat 是一个用于显示网络连接、路由表、接口统计等信息的命令。 -n 选项表示输出地址和端口以数字格式显示…

Armv8/Armv9架构的学习大纲-学习方法-自学路线-付费学习路线

本文给大家列出了Arm架构的学习大纲、学习方法、自学路线、付费学习路线。有兴趣的可以关注,希望对您有帮助。 如果大家有需要的,欢迎关注我的CSDN课程:https://edu.csdn.net/lecturer/6964 ARM 64位架构介绍 ARM 64位架构介绍 ARM架构概况…

Wi-SUN无线通信技术 — 大规模分散式物联网应用首选

引言 在数字化浪潮的推动下,物联网(IoT)正逐渐渗透到我们生活的方方面面。Wi-SUN技术以其卓越的性能和广泛的应用前景,成为了大规模分散式物联网应用的首选。本文将深入探讨Wi-SUN技术的市场现状、核心优势、实际应用中的案例以及…

JavaEE (1)

web开发概述 所谓web开发,指的是从网页中向后端程序发送请求,与后端程序进行 交互. 流程图如下 Web服务器是指驻留于因特网上某种类型计算机的程序. 可以向浏览器等Web客户端提供文档,也可以放置网站文件,让全世界浏览; 它是一个容器&…

C++ —— 关于模板初阶

1.什么是模板 在C中,模板(template)是一种通用的编程工具,允许程序员编写通用代码以处理多种数据型或数据结构,而不需要为每种特定类型编写重复的代码,通过模板,可以实现代码的复用和泛化提高代…

QT5.9.9+Android开发环境搭建

文章目录 1.安装准备1.1 下载地址1.2 安装前准备2.安装过程2.1 JDK安装2.1.1 安装2.1.2 环境变量配置2.2 SDK配置2.2.1 安装2.2.2 环境变量配置2.2.3 adb 错误解决2.2.4 其他SDK安装2.2.5 AVD虚拟机配置2.3 NDK配置2.4 QT 5.9.9安装配置2.4.1 QT安装2.4.2 配置安卓环境3.QT工程…

【Linux】进程信号 --- 信号处理

👦个人主页:Weraphael ✍🏻作者简介:目前正在学习c和算法 ✈️专栏:Linux 🐋 希望大家多多支持,咱一起进步!😁 如果文章有啥瑕疵,希望大佬指点一二 如果文章对…

Java---异常

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 ,Java 欢迎大家访问~ 创作不易,大佬们点赞鼓励下吧~ 文章目录 什么是异常异常的分类编译…

安装 VMware vSphere vCenter 8.0

安装 VMware vSphere vCenter 8.0 1、运行安装程序 2、语言选择中文 3、点下一步 4、接受许可协议,点下一步 5、填写部署vCenter服务的ESXI主机IP地址以及对应ESXI主机的账号密码,这里将vCenter服务部署在192.168.1.14这台ESXi主机上 6、接受证书警告 7…

新手小白的pytorch学习第十弹----多类别分类问题模型以及九、十弹的练习

目录 1 多类别分类模型1.1 创建数据1.2 创建模型1.3 模型传出的数据1.4 损失函数和优化器1.5 训练和测试1.6 衡量模型性能的指标 2 练习Exercise 之前我们已经学习了 二分类问题,二分类就像抛硬币正面和反面,只有两种情况。 这里我们要探讨一个 多类别…

基于关键字驱动设计Web UI自动化测试框架!

引言 在自动化测试领域,关键字驱动测试(Keyword-Driven Testing, KDT)是一种高效且灵活的方法,它通过抽象测试用例中的操作为关键字,实现了测试用例与测试代码的分离,从而提高了测试脚本的可维护性和可扩展…

记录解决springboot项目上传图片到本地,在html里不能回显的问题

项目场景: 项目场景:在我的博客系统里:有个相册模块:需要把图片上传到项目里,在html页面上显示 解决方案 1.建一个文件夹 例如在windows系统下。可以在项目根目录下建个photos文件夹,把上传的图片文件…

[PM]产品运营

生命周期 运营阶段 主要工作 拉新 新用户的定义 冷启动 拉新方式 促活 用户活跃的原因 量化活跃度 运营社区化/内容化 留存 用户流失 培养用户习惯 用户挽回 变现 变现方式 付费模式 广告模式 数据变现 变现指标 传播 营销 认识营销 电商营销中心 拼团活动 1.需求整理 2.…

JMeter请求导出Excel

前言 今天记录一个使用JMeter模拟浏览器请求后端导出,并下载Excel到指定位置的过程 创建请求 同样先创建一个线程组,再创建一个请求,设置好请求路径,端口号等 查看结果树 右键--添加--监听器--查看结果树 这里可以查看&#…

VUE之---slot插槽

什么是插槽 slot 【插槽】, 是 Vue 的内容分发机制, 组件内部的模板引擎使用slot 元素作为承载分发内容的出口。slot 是子组件的一个模板标签元素, 而这一个标签元素是否显示, 以及怎么显示是由父组件决定的。 VUE中slot【插槽】…

详细讲解vue3 watch回调的触发时机

目录 Vue 3 watch 基本用法 副作用刷新时机 flush 选项 flush: pre flush: post flush: sync Vue 3 watch 基本用法 计算属性允许我们声明性地计算衍生值。然而在有些情况下,我们需要在状态变化时执行一些“副作用”:例如更改 DOM,或是…

display: flex 和 justify-content: center 强大居中

你还在为居中而烦恼吗,水平居中多个元素、创建响应式布局、垂直和水平同时居中内容。它,display: flex 和 justify-content: center 都可以完成! display: flex:将元素定义为flex容器 justify-content:定义项目在主轴…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] LYA的生日派对座位安排(200分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 🍿 最新华为OD机试D卷目录,全、新、准,题目覆盖率达 95% 以上,支持题目在线…

FairGuard游戏加固入选《嘶吼2024网络安全产业图谱》

2024年7月16日,国内网络安全专业媒体——嘶吼安全产业研究院正式发布《嘶吼2024网络安全产业图谱》(以下简称“产业图谱”)。 本次发布的产业图谱,共涉及七大类别,127个细分领域。全面展现了网络安全产业的构成和重要组成部分,探…