Pytorch量化之Post Train Static Quantization(训练后静态量化)

使用Pytorch训练出的模型权重为fp32,部署时,为了加快速度,一般会将模型量化至int8。与fp32相比,int8模型的大小为原来的1/4, 速度为2~4倍。
Pytorch支持三种量化方式:

  • 动态量化(Dynamic Quantization): 只量化权重,激活在推理过程中进行量化
  • 静态量化(Static Quantization): 量化权重和激活
  • 量化感知训练(Quantization Aware Training,QAT): 插入量化算子后进行训练,主要在静态量化精度不满足需求时进行。
    大多数情况下,我们只需要进行静态量化,少数情况下在量化感知训练不满足时使用QAT进行微调。所以本篇只重点讲静态量化,并且理论部分先略过(后面再专门总结),只关注实操。
    注:下面的代码是在pytorch1.10下,后面Pytorch对量化的接口有调整
    官方文档:Quantization — PyTorch 1.10 documentation

动态模式(Eager Mode)与静态模式(fx graph)

Pytorch支持用2种方式量化,一种是动态图模式,也是我们日常使用Pytorch训练所使用的方式,使用这种方式量化需要自己手动修改网络结构,在支持量化的算子前、后插入量化节点,优点是方便调试。静态模式则是由pytorch自动在计算图中插入量化节点,不需要手动修改网络。
网络上大部分的教程都是基于静态模式,这种方式比较大的问题就是需要手动修改网络结构,官方教程里的网络是属于demo型, 其中的QuantStub和DeQuantStub就分别是量化和反量化的节点:

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):def __init__(self):super(M, self).__init__()# QuantStub converts tensors from floating point to quantizedself.quant = torch.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 1, 1)self.relu = torch.nn.ReLU()# DeQuantStub converts tensors from quantized to floating pointself.dequant = torch.quantization.DeQuantStub()def forward(self, x):# manually specify where tensors will be converted from floating# point to quantized in the quantized modelx = self.quant(x)x = self.conv(x)x = self.relu(x)# manually specify where tensors will be converted from quantized# to floating point in the quantized modelx = self.dequant(x)return x

Pytorch对于很多网络层是不支持量化的(比如很常用的Prelu),如果我们用这种方式,我们就必须在这些不支持的层前面插入DeQuantStub,然后在支持的层前面插入QuantStub。笔者体验下来,体验很差,个人觉得不太实用,会破坏原来的网络结构。
而静态图模式,我们只需要调用Pytorch提供的接口将原模型转换一下即可,不需要修改原来的网络结构文件,个人认为实用性更强。
image.png

静态模式量化

1. 载入fp32模型,并转成fx graph

其中量化参数有‘fbgemm’和‘qnnpack’两种,前者在x86运行,后者在arm运行。

model_fp32 = torch.load(xxx)
model_fp32_quantize = copy.deepcopy(model_fp32)
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
model_fp32_quantize.eval()
# preparemodel_prepared = quantize_fx.prepare_fx(model_fp32_quantize, qconfig_dict)
model_prepared.eval()

2.读取量化数据,标定(Calibration)量化参数

标定的过程就是使用模型推理量化图片,然后统计权重和激活分布,从而得到量化参数。量化图片一般来源于训练集(几百张左右,根据测试情况调整)。量化图片可以通过Pytorch的Dataloader读取,也可以直接自行实现读图片然后送入网络。

### 使用dataloader读取
for i, (data, label) in enumerate(train_loader):data = data.to(torch.device("cpu:0"))outputs = model_prepared(data)print("calibrating {}".format(i))if i > 1000:break

3. 转换为量化模型并保存

quantized_model = quantize_fx.convert_fx(model_prepared)
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")

速度测试

量化后的模型使用方法与fp32模型一样:

import torch
import cv2
import numpy as np
torch.set_num_threads(1)fused_model = torch.jit.load("jit_model.pt")
fused_model.eval()
fused_model.to(torch.device("cpu:0"))img = cv2.imread("./1.png")
img_fp32 = img.astype(np.float32)
img_fp32 = (img_fp32-127.5) / 127.5
input = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()def speed_test(model, input):# warm upfor i in range(10):model(input)import timestart = time.time()for i in range(100):model(input)end = time.time()print("model time: ", (end-start)/100)time.sleep(10)# quantized model
quantized_model= torch.jit.load("quantized_model.pt")
quantized_model.eval()
quantized_model.to(torch.device("cpu:0"))speed_test(fused_model, input)
speed_test(quantized_model, input)

实测fp32模型单核运行120ms, 量化后47ms

结语

本文介绍了fx graph模式下的Pytorch的PTSQ方法,并实测了一个模型,效果还比较不错。
1_995567224_161_79_3_732056265_62005da0d7c1b531a6cf91ea587d312e.jpg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/83136.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务服务拆分和远程调用

一、服务架构比较 单体架构:简单方便,高度耦合,扩展性差,适合小型项目。例如:学生管理系统 分布式架构:松耦合,扩展性好,但架构复杂,难度大。适合大型互联网项目&#x…

python的gui界面程序爬虫,python的gui界面怎么打开

大家好,小编来为大家解答以下问题,python的gui界面怎么打开,python的gui界面程序爬虫,今天让我们一起来看看吧! Python支持多种图形界面的第三方库,包括: wxWidgets Qt GTK Tkinter&#xf…

[信号与系统系列] 正弦振幅调制之差拍信号

当将具有不同频率的两个正弦曲线相乘时,可以创建一个有趣的音频效果,称为差拍音符。这种现象听起来像颤音,最好通过选择一个频率非常小的信号与和另一个频率大约1KHz的信号,把二者混合从而听到。一些乐器能够自然产生差拍音符。使…

idea如何上传项目到github(超详细)

idea如何上传项目到github 1、IDEA配置2、项目上传到本地仓库2.1、创建本地git仓库2.2、Add操作2.3、Commit操作 3、项目上传到Github4、拿到登录Github的token 1、IDEA配置 File-Settings-VersionControl-Git Git的安装路径下bin目录下的git.exe可执行文件 可以直接点 Gene…

基于TF-IDF+TensorFlow+词云+LDA 新闻自动文摘推荐系统—深度学习算法应用(含ipynb源码)+训练数据集

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境TensorFlow环境方法一方法二 模块实现1. 数据预处理1)导入数据2)数据清洗3)统计词频 2. 词云构建3. 关键词提取4. 语音播报5. LDA主题模型6. 模型构建 系统测试工程源代码下载…

十九、docker学习-Dockerfile

Dockerfile 官网地址 https://docs.docker.com/engine/reference/builder/Dockerfile其实就是我们用来构建Docker镜像的源码,当然这不是所谓的编程源码,而是一些命令的集合,只要理解它的逻辑和语法格式,就可以很容易的编写Docke…

Android 面试重点之Framework (Handler篇)

近期在网上看到不少Android 开发分享的面试经验,我发现基本每个面经中多多少少都有Framework 底层原理的影子。它也是Android 开发中最重要的一个部分,面试官一般会通过 Framework底层中的一些逻辑原理由浅入深进行提问,来评估应聘者的真实水…

对强缓存和协商缓存的理解

浏览器缓存的定义: 浏览器缓存是浏览器在本地磁盘对用户最近请求过的文档进行存储,当访问者再次访问同一页面时,浏览器就可以直接从本地磁盘加载文档。 浏览器缓存分为强缓存和协商缓存。 浏览器是如何使用缓存的: 浏览器缓存…

HarmonyOS应用开发者基础认证考试题库

此博文为HarmonyOS应用开发者基础认证考试的最后的大考,要求100分取得90分方可获取证书、现将考试的题库进行分享,希望能帮到大家。但是需要注意的是,题库会不定时的进行题目删减,但是大概的内容是不会进行改变的。真心希望这篇博…

MongoDB 使用总结

🍓 简介:java系列技术分享(👉持续更新中…🔥) 🍓 初衷:一起学习、一起进步、坚持不懈 🍓 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正🙏 🍓 希望这篇文章对你有所帮助,欢…

数据结构和算法——哈希查找冲突处理方法(开放地址法-线性探测、平方探测、双散列探测、再散列,分离链接法)

目录 开放地址法(Open Addressing) 线性探测(Linear Probing) 散列表查找性能分析 平方探测(Quadratic Probing) 定理 平方探测法的查找与插入 双散列探测法(Double Hashing&#xff09…

爬虫011_元组高级操作_以及字符串的切片操作---python工作笔记030

获取元组的下标对应的值 注意元组是不可以修改值的,只能获取不能修改 但是列表是可以修改值的对吧

界面控件DevExpress WPF Chart组件——拥有超快的数据可视化库!

DevExpress WPF Chart组件拥有超大的可视化数据集,并提供交互式仪表板与高性能WPF图表库。DevExpress Charts提供了全面的2D / 3D图形集合,包括数十个UI定制和数据分析/数据挖掘选项。 PS:DevExpress WPF拥有120个控件和库,将帮助…

亚马逊对AIGC的定义

大家好,这里是Doker,最近AIGC非常火,这里我们聊一下什么是AIGC. 一、 AIGC 介绍与典型行业应用场景 ​AIGC 又称生成式 AI (Generative AI),是继专业生产内容(PGC, Professional-generated Content)、用户…

人脸识别场景下Faiss大规模向量检测性能测试评估分析

在前面的两篇博文中,主要是考虑基于之前以往的人脸识别项目经历结合最近使用到的faiss来构建更加高效的检索系统,感兴趣的话可以自行移步阅读即可: 《基于facenetfaiss开发构建人脸识别系统》 Facenet算法的优点:高准确率&#…

DoIP学习笔记系列:(三)用CAPL脚本过“安全认证”,$27服务实现

文章目录 1. 如何调用接口通过安全认证?如何新建CAPL工程,在此不再赘述,本章主要分享一下如何在CAPL中调用DoIP接口、diag接口进行DoIP和诊断的测试。 注意:CANoe工具本身的使用没什么难的,所谓会者不难难者不会,各位小伙伴有疑问要多问,多交流,往往难事都只是一层窗户…

ElasticSearch:环境搭建步骤

1、拉取镜像 docker pull elasticsearch:7.4.0 2、创建容器 docker run -id --name elasticsearch -d --restartalways -p 9200:9200 -p 9300:9300 -v /usr/share/elasticsearch/plugins:/usr/share/elasticsearch/plugins -e "discovery.typesingle-node" elasti…

图的拓扑排序算法

拓扑排序 什么是拓扑排序? 比如说,我们平时工作过程中一定听过一个词叫做—不能循环依赖。什么意思? A依赖BCD,B依赖CD,C依赖D,D依赖EF,想要获得A的话,首先就要先有EF,有…

webpack基础知识五:说说Loader和Plugin的区别?编写Loader,Plugin的思路?

一、区别 前面两节我们有提到Loader与Plugin对应的概念,先来回顾下 loader 是文件加载器,能够加载资源文件,并对这些文件进行一些处理,诸如编译、压缩等,最终一起打包到指定的文件中plugin 赋予了 webpack 各种灵活的…

二分法的应用

文章目录 什么是二分法🎮二分查找的优先级二分查找的步骤💥图解演示🧩 代码演示🫕python程序实现🐈‍⬛C程序实现🐕‍🦺C程序实现🐯Java程序实现🐳 非常规类二分查找&…