模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

  • 前言
  • 量化
    • Post-Training-Quantization(PTQ)
    • Quantization-Aware-Training(QAT)
  • 参考文献

前言

随着人工智能的不断发展,深度学习网络被广泛应用于图像处理、自然语言处理等实际场景,将其部署至多种不同设备的需求也日益增加。然而,常见的深度学习网络模型通常包含大量参数和数百万的浮点数运算(例如ResNet50具有95MB的参数以及38亿浮点数运算),实时地运行这些模型需要消耗大量内存和算力,这使得它们难以部署到资源受限且需要满足实时性、低功耗等要求的边缘设备。为了进一步推动深度学习网络模型在移动端或边缘设备中的快速部署,深度学习领域提出了一系列的模型压缩与加速方法:

  • 知识蒸馏(Knowledge distillation):使用教师-学生网络结构,让小型的学生网络模仿大型教师网络的行为,以使得准确率尽可能高的同时,能够获得一个轻量化的网络。
  • 剪枝(Parameter pruning):删除不必要的网络参数,以减少模型的规模和计算复杂度。
  • 低秩分解(Low-rank factorization):将模型的参数矩阵分解为较低秩的小矩阵,以减少模型的复杂度和计算成本。
  • 参数共享(Parameter sharing):将多个层共用一组参数,以减少模型的参数数量。
  • 量化(Quantization):将模型的参数和运算转化为更小的数据类型,以减少内存占用和计算时间。

量化

模型量化(Quantization)是一种将浮点计算转化为定点计算的技术,例如从FP32降低至INT8,主要用于减少模型的计算强度、参数大小以及内存消耗,以提高模型在设备上的推理计算效率,但是也有可能会带来一定的精度损失。

模型量化精度损失的主要原因为量化-反量化(Quantization-Dequantization)过程中取整引起的误差。这里简单介绍一下量化的计算方法,以FP32到INT8的量化为例,量化的核心思想就是将浮点数区间的参数映射到INT8的离散区间中。
量化公式:
q = r s + Z q = \frac{r}{s} + Z q=sr+Z反量化公式:
r = S ( q − Z ) r = S(q-Z) r=S(qZ)其中, r r r 为FP32的浮点数(real value), q q q 为INT8的量化值(quantization value),
S S S Z Z Z 分别为缩放因子(Scale-factor)和零点(Zero-Point)。

量化最重要的便是确定 S S S Z Z Z 的值, S S S Z Z Z 的计算公式如下:
S = r m a x − r m i n q m a x − q m i n S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}} S=qmaxqminrmaxrmin Z = − r m i n S + q m i n Z = -\frac{r_{min}}{S} + q_{min} Z=Srmin+qmin其中, r m a x r_{max} rmax r m i n r_{min} rmin 分别为FP32网络参数最大、最小值, q m a x q_{max} qmax q m i n q_{min} qmin 分别为INT8网络参数最大、最小值。

为了减少量化所带来的精度损失,学者提出了Quantization-Aware-Training(QAT)方法,再介绍此之前,由于Post-Training-Quantization(PTQ)方法也经常在文献中出现,此篇博客将着重介绍这两个方法的含义与区别。
在这里插入图片描述

Post-Training-Quantization(PTQ)

Post-Training-Quantization(PTQ)是目前常用的模型量化方法之一。以INT8量化为例,PTQ方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的模型;
  2. 使用小部分数据对FP32模型进行采样(Calibration),主要是为了得到网络各层参数的数据分布特性(比如统计最大最小值);
  3. 根据步骤2中的数据分布特性,计算出网络各层 S 和 Z 量化参数;
  4. 使用步骤3中的量化参数对FP32模型进行量化得到INT8模型,并将其部署至推理框架进行推理。

PTQ方法会使用小部分数据集来估计网络各层参数的数据分布,找到合适的S和Z的取值,从而一定程度上降低模型精度损失。然而,论文中指出PTQ方式虽然在大模型上效果较好(例如ResNet101),但是在小模型上经常会有较大的精度损失(例如MobileNet) 不同通道的输出范围相差可能会非常大(大于100x), 对异常值较为敏感。

Quantization-Aware-Training(QAT)

由上文可知PTQ方法中模型的训练和量化是分开的,而Quantization-Aware-Training(QAT)方法则是在模型训练时加入了伪量化节点,用于模拟模型量化时引起的误差,并通过微调使得模型在量化后尽可能减少精度损失。以INT8量化为例,QAT方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的FP32模型;
  2. 在FP32模型中插入伪量化节点,得到QAT模型,并且在数据集上对QAT模型进行微调(Fine-tuning);
  3. 同PTQ方法中的采样(Calibration),并计算量化参数 S 和 Z ;
  4. 使用步骤3中得到的量化参数对QAT模型进行量化得到INT8模型,并部署至推理框架中进行推理。

在PyTorch中,可以使用 torch.quantization.quantize_dynamic() 方法来执行 QAT。这是一个基本的 QAT 代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import quantize_dynamic, QuantStub, DeQuantStub# 定义简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.quant = QuantStub()self.dequant = DeQuantStub()self.fc1 = nn.Linear(784, 256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 10)def forward(self, x):x = self.quant(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x)return x# 数据加载
# 这里使用 MNIST 数据集作为示例
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=False, download=True, transform=transform),batch_size=64, shuffle=False)# 定义损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 定义 QAT 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=5):model.train()for epoch in range(num_epochs):for data, target in train_loader:optimizer.zero_grad()output = model(data.view(data.shape[0], -1))loss = criterion(output, target)loss.backward()optimizer.step()# 训练模型
train(model, train_loader, criterion, optimizer, num_epochs=5)# 在训练完成后执行动态量化
quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)# 评估量化模型
def test(model, test_loader, criterion):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:output = model(data.view(data.shape[0], -1))_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = correct / totalprint(f'Accuracy of the network on the test images: {accuracy * 100:.2f}%')# 测试量化模型
test(quantized_model, test_loader, criterion)

上述代码示例中,我使用了一个简单的全连接神经网络,并在训练完成后使用torch.quantization.quantize_dynamic()对模型进行动态量化。在量化之前,我们通过QuantStub()DeQuantStub()添加了量化和反量化的辅助模块。这个示例使用了MNIST数据集,你可以根据你的实际需求替换成其他数据集和模型。

参考文献

量化感知训练(Quantization-aware-training)探索-从原理到实践

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/190732.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AIGC|如何将Milvus集成到LangFlow中?详细代码演示!

目录 一、基本介绍 二、修改langflow代码使其支持milvus 三、效果演示 langflow是一个LangChain UI,它提供了一种交互界面来使用LangChain,通过简单的拖拽即可搭建自己的实验、原型流。通过在langflow中引入Milvus,用户可以更方便地存储和…

【Java 进阶篇】JQuery DOM操作:通用属性操作的绝妙魔法

在前端的舞台上,JQuery犹如一位魔法师,为我们展现了操纵HTML元素的奇妙技巧。而在这个技巧的精妙组成中,通用属性操作是一门绝妙的魔法。在本篇博客中,我们将深入研究JQuery DOM操作中的通用属性操作,揭示这段魔法的神…

业务出海之服务器探秘

这几年随着国内互联网市场的逐渐饱和,越来越多的公司加入到出海的行列,很多领域都取得了很不错的成就。虽然出海可以获得更加广阔的市场,但也需要面对很多之前在国内可能没有重视的一些问题。集中在海外服务器的选择维度上就有很大的变化。例…

“第六十七天”

各位,昨天查找子串的方法想起来了,就是那个KMP算法......自己理解都有点困难,还看看能不能想一下,确实很困难啊。 不要忘了toupper函数和tolower函数不是直接改变字符的大小写,而是返回对应的大小写的值,需…

2023nacos源码解读第2集——nacos-server的启动

nacos 是一个典型的server-client中间件,server这里安装最新的nacos-server 2.3.0-BETA版本 1.docker启动nacos-server 镜像详情参考nacos-docker项目的readme ,很方便,但是官方提供的nacos-server镜像往往可能滞后,且不便于后续…

Libhybris之线程局部存储TLS实例(五)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

【STM32 CAN】STM32G47x 单片机FDCAN作为普通CAN外设使用教程

STM32G47x 单片机FDCAN作为普通CAN外设使用教程 控制器局域网总线(CAN,Controller Area Network)是一种用于实时应用的串行通讯协议总线,它可以使用双绞线来传输信号,是世界上应用最广泛的现场总线之一。CAN协议用于汽…

Apache Druid连接回收引发的血案

问题 线上执行大批量定时任务,发现SQL执行失败的报错: CommunicationsException, druid version 1.1.10, jdbcUrl : jdbc:mysql://xxx?useUnicodetrue&characterEncodingUTF-8&zeroDateTimeBehaviorconvertToNull,testWhileIdle true, idle …

Java事务详解

一、事务的理解: 1、事务的特性: 1) 原子性(atomicity):事务是数据库的逻辑工作单位,而且是必须是原子工作单位,对于其数据修改,要么全部执行,要么全部不执行。 2) 一致性…

Vatee万腾外汇数字化策略:Vatee科技决策力的未来引领

在外汇市场,Vatee万腾通过其前瞻性的外汇数字化策略,正引领着科技决策的未来。这一数字化策略的崭新愿景为投资者提供了更智慧、更高效的外汇投资体验,成为科技决策领域的翘楚。 Vatee万腾的外汇数字化策略是科技决策力未来引领的典范。通过运…

消息队列之初识Rabbit及安装

文章目录 一、MQ的相关概念1.什么是MQ?2.为什么要用MQ2.1流量消峰2.2应用解耦2.3异步处理 3.MQ 的分类3.1.ActiveMQ3.2.Kafka3.3.RocketMQ3.4.RabbitMQ 4.MQ 的选择4.1.Kafka4.2.RocketMQ4.3.RabbitMQ 二、RabbitMQ的相关概念1.四大核心概念2.RabbitMQ 核心部分3.Ra…

游戏AI:游戏开发和运营的新增长点

游戏AI(Game AI)是指在游戏开发运营的过程中模拟人类玩家或创建虚构性对手行为的人工智能技术。游戏AI的目标是增强游戏的互动性、可玩性和挑战性,使游戏中的角色能够智能地做出决策和行为。在游戏的开发和运营过程中使用人工智能技术&#x…

caffe搭建squeezenet网络的整套工程

之前用pytorch构建了squeezenet,个人觉得pytorch是最好用的,但是有的工程就是需要caffe结构的,所以本篇也用caffe构建一个squeezenet网络。 数据处理 首先要对数据进行处理,跟pytorch不同,pytorch读取数据只需要给数据…

【C++】类和对象(2)--构造函数

目录 一 概念 二 构造函数特性 三 默认构造函数 一 概念 对于以下Date类&#xff1a; class Date { public:void Init(int year, int month, int day){_year year;_month month;_day day;}void Print(){cout << _year << "-" << _month <…

Qt贝塞尔曲线

目录 引言核心代码基本表达绘制曲线使用QEasingCurve 完整代码 引言 贝塞尔曲线客户端开发中常见的过渡效果&#xff0c;如界面的淡入淡出、数值变化、颜色变化等等。为了能够更深的了解地理解贝塞尔曲线&#xff0c;本文通过Demo将贝塞尔曲线绘制出来&#xff0c;如下所示&am…

DevChat:开发者专属的基于IDE插件化编程协助工具

DevChat&#xff1a;开发者专属的基于IDE插件化编程协助工具 一、DevChat 的介绍1.1 DevChat 简介1.2 DevChat 优势 二、DevChat 在 VSCode 上的使用2.1 安装 DevChat2.2 注册 DevChat2.3 使用 DevChat 三、DevChat 的实战四、总结 一、DevChat 的介绍 在AI浪潮的席卷下&#x…

基于开源项目OCR做一个探究(chineseocr_lite)

背景&#xff1a;基于图片识别的技术有很多&#xff0c;应用与各行各业&#xff0c;我们公司围绕电子身份证识别自动录入需求开展&#xff0c;以下是我的研究心得 技术栈&#xff1a;python3.6&#xff0c;chineseocr_lite的onnx推理 环境部署&#xff1a;直接上截图&#xff…

c语言-数据结构-栈和队列的实现和解析

目录 一、栈 1、栈的概念 1.2 栈的结构 2、栈的创建及初始化 3、压栈操作 4、出栈操作 5、显示栈顶元素 6、显示栈空间内元素的总个数 7、释放栈空间 8、测试栈 二、队列 1、队列的概念 1.2 队列的结构 2、队列的创建及初始化 3、入队 4、出队 5、显示队头、队…

在Spring Boot中使用JTA实现对多数据源的事务管理

了解事务的都知道&#xff0c;在我们日常开发中单单靠事务管理就可以解决绝大多数问题了&#xff0c;但是为啥还要提出JTA这个玩意呢&#xff0c;到底JTA是什么呢&#xff1f;他又是具体来解决啥问题的呢&#xff1f; JTA JTA&#xff08;Java Transaction API&#xff09;是…

CG Magic分享效果图中VRay的灯光设置分析

效果图制作中&#xff0c;一张图VRay效果图好不好看主要看灯光、材质、模型、相机、后期这五点。VRay的灯光设置来说是极为重要的。 VRay灯光设置不好&#xff0c;就会出现vray灯光颜色不能正常显示再或是vray的灯光不亮的问题。 vray的灯光怎么设置才能使效果图展现的更加真实…