NLP(六十七)BERT模型训练后动态量化(PTDQ)

  本文将会介绍BERT模型训练后动态量化(Post Training Dynamic Quantization,PTDQ)。

量化

  在深度学习中,量化(Quantization)指的是使用更少的bit来存储原本以浮点数存储的tensor,以及使用更少的bit来完成原本以浮点数完成的计算。这么做的好处主要有如下几点:

  • 更少的模型体积,接近4倍的减少
  • 可以更快地计算,由于更少的内存访问和更快的int8计算,可以快2~4倍

  PyTorch中的模型参数默认以FP32精度储存。对于量化后的模型,其部分或者全部的tensor操作会使用int类型来计算,而不是使用量化之前的float类型。当然,量化还需要底层硬件支持,x86 CPU(支持AVX2)、ARM CPU、Google TPU、Nvidia Volta/Turing/Ampere、Qualcomm DSP这些主流硬件都对量化提供了支持。

模型量化示例图片

PTDQ

  PyTorch对量化的支持目前有如下三种方式:

  • Post Training Dynamic Quantization:模型训练完毕后的动态量化
  • Post Training Static Quantization:模型训练完毕后的静态量化
  • QAT (Quantization Aware Training):模型训练中开启量化

  本文仅介绍Post Training Dynamic Quantization(PTDQ)
  对训练后的模型权重执行动态量化,将浮点模型转换为动态量化模型,仅对模型权重进行量化,偏置不会量化。默认情况下,仅对Linear和RNN变体量化 (因为这些layer的参数量很大,收益更高)。

torch.quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)

参数解释:

  • model:模型(默认为FP32)
  • qconfig_spec:
  1. 集合:比如: qconfig_spec={nn.LSTM, nn.Linear} 。列出要量化的神经网络模块。
  2. 字典: qconfig_spec = {nn.Linear: default_dynamic_qconfig, nn.LSTM: default_dynamic_qconfig}
  • dtype: float16 或 qint8
  • mapping:就地执行模型转换,原始模块发生变异
  • inplace:将子模块的类型映射到需要替换子模块的相应动态量化版本的类型

例子:

# -*- coding: utf-8 -*-
# 动态量化模型,只量化权重
import torch
from torch import nnclass DemoModel(torch.nn.Module):def __init__(self):super(DemoModel, self).__init__()self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)self.relu = nn.ReLU()self.fc = torch.nn.Linear(2, 2)def forward(self, x):x = self.conv(x)x = self.relu(x)x = self.fc(x)return xif __name__ == "__main__":model_fp32 = DemoModel()# 创建一个量化的模型实例model_int8 = torch.quantization.quantize_dynamic(model=model_fp32,  # 原始模型qconfig_spec={torch.nn.Linear},  # 要动态量化的算子dtype=torch.qint8)  # 将权重量化为:qint8print(model_fp32)print(model_int8)# 运行模型input_fp32 = torch.randn(1, 1, 2, 2)output_fp32 = model_fp32(input_fp32)print(output_fp32)output_int8 = model_int8(input_fp32)print(output_int8)

输出结果如下:

DemoModel((conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))(relu): ReLU()(fc): Linear(in_features=2, out_features=2, bias=True)
)
DemoModel((conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))(relu): ReLU()(fc): DynamicQuantizedLinear(in_features=2, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
tensor([[[[0.3120, 0.3042],[0.3120, 0.3042]]]], grad_fn=<AddBackward0>)
tensor([[[[0.3120, 0.3042],[0.3120, 0.3042]]]])

模型量化策略

  当前,由于量化算子的覆盖有限,因此,对于不同的深度学习模型,其量化策略不同,见下表:

模型量化策略原因
LSTM/RNNDynamic Quantization模型吞吐量由权重的计算/内存带宽决定
BERT/TransformerDynamic Quantization模型吞吐量由权重的计算/内存带宽决定
CNNStatic Quantization模型吞吐量由激活函数的内存带宽决定
CNNQuantization Aware Training模型准确率不能由Static Quantization获取的情况

   下面对BERT模型进行训练后动态量化,分析模型在量化前后,推理效果和推理性能的变化。

实验

   我们使用的训练后的模型为中文文本分类模型,其训练过程可以参考文章:NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调 。
   训练后的BERT模型动态量化实验的设置如下:

  1. base model: bert-base-chinese
  2. CPU info: x86-64, Intel® Core™ i5-10210U CPU @ 1.60GHz
  3. batch size: 1
  4. thread: 1

   具体的实验过程如下

  • 加载模型及tokenizer
import torch
from transformers import AutoModelForSequenceClassificationMAX_LENGTH = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = f"./sougou_test_trainer_{MAX_LENGTH}/checkpoint-96"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)
from transformers import AutoTokenizer, DataCollatorWithPaddingtokenizer = AutoTokenizer.from_pretrained(checkpoint)
  • 测试数据集
import pandas as pdtest_df = pd.read_csv("./data/sougou/test.csv")test_df.head()
textlabel
0届数比赛时间比赛地点参加国家和地区冠军亚军决赛成绩第一届1956-1957英国11美国丹麦6...0
1商品属性材质软橡胶带加浮雕工艺+合金彩色队徽吊牌规格162mm数量这一系列产品不限量发行图案...0
2今天下午,沈阳金德和长春亚泰队将在五里河相遇。在这两支球队中沈阳籍球员居多,因此这场比赛实际...0
3本报讯中国足协准备好了与特鲁西埃谈判的合同文本,也在北京给他预订好了房间,但特鲁西埃爽约了!...0
4网友点击发表评论祝贺中国队夺得五连冠搜狐体育讯北京时间5月6日,2006年尤伯杯羽毛球赛在日...0
  • 量化前模型的推理时间及评估指标
import numpy as np
import times_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():row_s_time = time.time()true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(device)# print(encoded_text)logits = model(**encoded_text)label_id = np.argmax(logits[0].detach().cpu().numpy(), axis=1)[0]pred_labels.append(label_id)print(i, (time.time() - row_s_time)*1000, label_id)print("avg time: ", (time.time() - s_time) * 1000 / test_df.shape[0])
0 229.3872833251953 0
100 362.0314598083496 1
200 311.16747856140137 2
300 324.13792610168457 3
400 406.9099426269531 4
avg time:  352.44047810332944
from sklearn.metrics import classification_reportprint(classification_report(true_labels, pred_labels, digits=4))
              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9691    0.9495    0.9592        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9895    0.9495    0.9691        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495
  • 设置量化后端
# 模型量化
cpu_device = torch.device("cpu")
torch.backends.quantized.supported_engines
['none', 'onednn', 'x86', 'fbgemm']
torch.backends.quantized.engine = 'x86'
  • 量化后模型的推理时间及评估指标
# 8-bit 量化
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8
).to(cpu_device)
q_s_time = time.time()
q_true_labels, q_pred_labels = [], [] for i, row in test_df.iterrows():row_s_time = time.time()q_true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(cpu_device)logits = quantized_model(**encoded_text)label_id = np.argmax(logits[0].detach().numpy(), axis=1)[0]q_pred_labels.append(label_id)print(i, (time.time() - row_s_time) * 1000, label_id)print("avg time: ", (time.time() - q_s_time) * 1000 / test_df.shape[0])
0 195.47462463378906 0
100 247.33805656433105 1
200 219.41304206848145 2
300 206.44831657409668 3
400 187.4992847442627 4
avg time:  217.63229466447928
from sklearn.metrics import classification_reportprint(classification_report(q_true_labels, q_pred_labels, digits=4))
              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9688    0.9394    0.9538        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9896    0.9596    0.9744        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495
  • 量化前后模型大小对比
import osdef print_size_of_model(model):torch.save(model.state_dict(), "temp.p")print("Size (MB): ", os.path.getsize("temp.p")/1e6)os.remove("temp.p")print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB):  409.155273
Size (MB):  152.627621

  量化后端(Quantization backend)取决于CPU架构,不同计算机的CPU架构不同,因此,默认的动态量化不一定在所有的CPU上都能生效,需根据自己计算机的CPU架构设置好对应的量化后端。另外,不同的量化后端也有些许差异。Linux服务器使用uname -a可查看CPU信息。
  重复上述实验过程,以模型的最大输入长度为变量,取值为128,256,384,每种情况各做3次实验,结果如下:

实验最大长度量化前平均推理时间(ms)量化前weighted F1值量化前平均推理时间(ms)量化前weighted F1值
实验138410660.97976860.9838
实验23841047.60.9899738.10.9879
实验33841020.90.9817714.00.9838
实验1256668.70.9717431.40.9718
实验2256675.10.9717449.90.9718
实验3256656.00.9717446.50.9718
实验1128335.80.9737200.50.9737
实验2128336.50.9737227.20.9737
实验3128352.40.9737217.60.9737

  综上所述,对于训练后的BERT模型(文本分类模型)进行动态量化,其结论如下:

  • 模型推理效果:量化前后基本相同,量化后略有下降
  • 模型推理时间:量化后平均提速约1.52倍

总结

  本文介绍了量化基本概念,PyTorch模型量化方式,以及对BERT模型训练后进行动态量化后在推理效果和推理性能上的实验。
  本文项目已开源至Github项目:https://github.com/percent4/dynamic_quantization_on_bert 。
  本人已开通个人博客网站,网址为:https://percent4.github.io/ ,欢迎大家访问~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/117483.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java泛型机制

✅作者简介&#xff1a;大家好&#xff0c;我是Leo&#xff0c;热爱Java后端开发者&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;Leo的博客 &#x1f49e;当前专栏&#xff1a;每天一个知识点 ✨特色专栏&#xff1a…

【半监督医学图像分割】2022-MedIA-UWI

【半监督医学图像分割】2022-MedIA-UWI 论文题目&#xff1a;Semi-supervise d me dical image segmentation via a triple d-uncertainty guided mean teacher model with contrastive learning 中文题目&#xff1a;基于对比学习的三维不确定性指导平均教师模型的半监督图像分…

“新KG”视点 | 陈华钧——大模型时代的知识处理:新机遇与新挑战

OpenKG 大模型专辑 导读 知识图谱和大型语言模型都是用来表示和处理知识的手段。大模型补足了理解语言的能力&#xff0c;知识图谱则丰富了表示知识的方式&#xff0c;两者的深度结合必将为人工智能提供更为全面、可靠、可控的知识处理方法。在这一背景下&#xff0c;OpenKG组织…

微机原理 || 第2次测试:汇编指令(加减乘除运算,XOR,PUSH,POP,寻址方式,物理地址公式,状态标志位)(测试题+手写解析)

&#xff08;一&#xff09;测试题目&#xff1a; 1.数[X]补1111,1110B&#xff0c;则其真值为 2.在I/O指令中,可用于表示端口地址的寄存器 3. MOV AX,[BXSl]的指令中&#xff0c;源操作数的物理地址应该如何计算 4.执行以下两条指令后&#xff0c;标志寄存器FLAGS的六个状态…

Cmake qt ,vtkDataArray.cxx.obj: File too big

解决方法&#xff1a; Qt4 在pro 加入“QMAKE_CXXFLAGS -BigObj” 可以解决 Qt5 在网上用“-Wa,-mbig-obj” 不能解决&#xff0c;最后通过“QMAKE_CXXFLAGS -Ofast -flto”解决问题。 Qt4 在pro 加入“QMAKE_CXXFLAGS -BigObj” 可以解决Qt5 在网上用“-Wa,-mbig-obj” …

wxWidgets从空项目开始Hello World

前文回顾 接上篇&#xff0c;已经是在CodeBlocks20.03配置了wxWidgets3.0.5&#xff0c;并且能够通过项目创建导航创建一个新的工程&#xff0c;并且成功运行。 那么上一个是通过CodeBlocks的模板创建的&#xff0c;一进去就已经是2个头文件2个cpp文件&#xff0c;总是感觉缺…

OAuth2.0二 JWT以及Oauth2实现SSO

一 JWT 1.1 什么是JWT JSON Web Token&#xff08;JWT&#xff09;是一个开放的行业标准&#xff08;RFC 7519&#xff09;&#xff0c;它定义了一种简介的、自包含的协议格式&#xff0c;用于在通信双方传递json对象&#xff0c;传递的信息经过数字签名可以被验证和信任。JW…

python web 开发与 Node.js + Express 创建web服务器入门

目录 1. Node.js Express 框架简介 2 Node.js Express 和 Python 创建web服务器的对比 3 使用 Node.js Express 创建web服务器示例 3.1 Node.js Express 下载安装 3.2 使用Node.js Express 创建 web服务器流程 1. Node.js Express 框架简介 Node.js Express 是一种…

机器学习---决策树的划分依据(熵、信息增益、信息增益率、基尼值和基尼指数)

1. 熵 物理学上&#xff0c;熵 Entropy 是“混乱”程度的量度。 系统越有序&#xff0c;熵值越低&#xff1b;系统越混乱或者分散&#xff0c;熵值越⾼。 1948年⾹农提出了信息熵&#xff08;Entropy&#xff09;的概念。 从信息的完整性上进⾏的描述&#xff1a;当系统的有序…

myspl使用指南

mysql数据库 使用命令行工具连接数据库 mysql -h -u 用户名 -p -u表示后面是用户名-p表示后面是密码-h表示后面是主机名&#xff0c;登录当前设备可省略。 如我们要登录本机用户名为root&#xff0c;密码为123456的账户&#xff1a; mysql -u root -p按回车&#xff0c;然后…

大数据组件-Flume集群环境的启动与验证

&#x1f947;&#x1f947;【大数据学习记录篇】-持续更新中~&#x1f947;&#x1f947; 个人主页&#xff1a;beixi 本文章收录于专栏&#xff08;点击传送&#xff09;&#xff1a;【大数据学习】 &#x1f493;&#x1f493;持续更新中&#xff0c;感谢各位前辈朋友们支持…

gitlab-rake gitlab:backup:create 执行报错 Errno::ENOSPC: No space left on device

gitlab仓库备份执行 gitlab-rake gitlab:backup:create报错如下&#xff1a; 问题分析&#xff1a;存储备份的空间满 解决方法&#xff1a; 方法1&#xff1a;清理存放路径&#xff0c;删除不需要文件&#xff0c;释放空间。 方法2&#xff1a;创建一个根目录的挂载点&#x…

八一参考文献:[八一新书]许少辉.乡村振兴战略下传统村落文化旅游设计[M]北京:中国建筑出版传媒,2022.

八一参考文献&#xff1a;&#xff3b;八一新书&#xff3d;许少辉&#xff0e;乡村振兴战略下传统村落文化旅游设计&#xff3b;&#xff2d;&#xff3d;北京&#xff1a;中国建筑出版传媒&#xff0c;&#xff12;&#xff10;&#xff12;&#xff12;&#xff0e;

机器视觉工程师,有哪几种类型

1.光学实验室&#xff08;打光机器视觉工程师&#xff0c;一般此职位&#xff0c;要求有光学学历的背景最佳&#xff09; 2.机器视觉算法开发工程师&#xff08;此职位国内稀缺&#xff09;3.机器视觉工程师/机器视觉开发工程师&#xff08;MV工程师/MV工程师&#xff09;&…

常见项目管理中npm包操作总结

前言 我们在日常工作中&#xff0c;可能需要下载包、创建包、发布包等等。本篇推文将记录日常项目中关于npm包的操作。 引用包 npm仓库公开的包我们都可以通过npm install的命令进行引用下载。 而我们开发的业务公共组件需要在公司内部项目公共引用&#xff0c;而不希望公开为外…

Android——基本控件(下)(二十)

1. 树型组件&#xff1a;ExpandableListView 1.1 知识点 &#xff08;1&#xff09;掌握树型组件的定义&#xff1b; &#xff08;2&#xff09;可以使用事件对树操作进行监听。 2. 具体内容 既然这个组件可以完成列表的功能&#xff0c;肯定就需要一个可以操作的数据&…

el-select 选择一条数据后,把其余数据带过来

1. 案例&#xff1a; ps: 票号是下拉框选择&#xff0c;风险分类、场站名称以及开始时间是选择【票号】后带过来的。 2. 思路: 使用官网上给的方法&#xff0c;选择之后&#xff0c;触发change方法从而给其余字段赋值 3. 代码 <el-form-itemlabel"票号&#xff1a;&…

(leetcode1761一个图中连通三元组的最小度数,暴力+剪枝)-------------------Java实现

&#xff08;leetcode1761一个图中连通三元组的最小度数&#xff0c;暴力剪枝&#xff09;-------------------Java实现 题目表述 给你一个无向图&#xff0c;整数 n 表示图中节点的数目&#xff0c;edges 数组表示图中的边&#xff0c;其中 edges[i] [ui, vi] &#xff0c;…

2023_Spark_实验四:SCALA基础

一、在IDEA中执行以下语句 或者用windows徽标R 输入cmd 进入命令提示符 输入scala直接进入编写界面 1、Scala的常用数据类型 注意&#xff1a;在Scala中&#xff0c;任何数据都是对象。例如&#xff1a; scala> 1 res0: Int 1scala> 1.toString res1: String 1scala…

linux安装firefox

1.下载对应包 https://www.mozilla.org/en-US/firefox/all/#product-desktop-release 2. 挂载桌面链接(如果/usr/bin/firefox下有的话,先删除) ln -s /opt/firefox/firefox /usr/bin/firefox 3.执行以下命令&#xff0c;即可启动Firefox客户端&#xff1a; firefox