I. Definitions
- API summary
- Quantization modes explained
II. Implementation
API Summary
1. PyTorch provides three different quantization modes: Eager Mode Quantization, FX Graph Mode Quantization (in maintenance), and PyTorch 2 Export Quantization.
2. Eager Mode Quantization is a beta feature. The user has to perform fusion and manually specify where quantization and dequantization happen, and it only supports modules, not functionals.
3. FX Graph Mode Quantization is an automated quantization workflow in PyTorch. It is currently a prototype feature and is in maintenance mode now that PyTorch 2 Export Quantization exists. It improves on Eager Mode Quantization by adding support for functionals and automating the quantization process, although a model may need to be refactored to make it compatible with FX Graph Mode Quantization (i.e. symbolically traceable with torch.fx). Note that FX Graph Mode Quantization is not expected to work on arbitrary models, since a model may not be symbolically traceable; it is integrated into domain libraries such as torchvision, so users can quantize models similar to those in supported domain libraries with FX Graph Mode Quantization. For arbitrary models, general guidelines are provided, but to actually make it work users may need to be familiar with torch.fx, in particular how to make a model symbolically traceable.
4. PyTorch 2 Export Quantization is the new full-graph-mode quantization workflow, released as a prototype feature in PyTorch 2.1. With PyTorch 2, PyTorch is moving to a better full-program capture solution (torch.export), since it can capture a higher percentage of models (88.8% on 14K models) than torch.fx.symbolic_trace (72.7% on 14K models), the program capture solution used by FX Graph Mode Quantization. torch.export still has limitations around certain Python constructs and requires user involvement to support dynamism in the exported model, but overall it is an improvement over the previous program capture solutions. PyTorch 2 Export Quantization is built for models captured by torch.export, with flexibility and productivity of both modeling users and backend developers in mind. Its main features are: (1) a programmable API for configuring how a model is quantized, which can scale to many more use cases (a small sketch of this API follows this list); (2) a simplified user experience for modeling users and backend developers, since they only need to interact with a single object (the Quantizer) to express the user's intent about how to quantize a model and what the backend supports; (3) an optional reference quantized model representation that can express quantized computation with integer operations, which is closer to the actual quantized computation happening in hardware.
5. New users of quantization are encouraged to try PyTorch 2 Export Quantization first; if that does not work well, they can fall back to Eager Mode Quantization.
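Feature (1), the programmable Quantizer API, replaces the qconfig / QConfigMapping objects of the older workflows. Below is a minimal sketch assuming the XNNPACKQuantizer prototype shipped with PyTorch 2.1; the import path, the set_module_name method, and the submodule name "fc2" are illustrative and may differ across releases and backends.

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
# global intent: symmetric int8 quantization wherever the backend supports it
quantizer.set_global(get_symmetric_quantization_config())
# per-module override (hypothetical submodule name "fc2"); passing None leaves it in fp32
quantizer.set_module_name("fc2", None)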
(The PyTorch quantization documentation includes a table comparing Eager Mode Quantization, FX Graph Mode Quantization, and PyTorch 2 Export Quantization in more detail.)
Three types of quantization are supported:
- Dynamic quantization: once training is finished the weight values are fixed, so the weight quantization parameters can be computed ahead of time, but the activation scaling factors are computed dynamically for each input. (post-training quantization)
- Static quantization: a statically quantized model goes through a calibration step before use (to calibrate the scaling factors): prepare some representative inputs (for an image classification model, a set of images; similar for other tasks) and run the prepared model on them; during this process the scaling factors are adjusted to the distribution of the input data. (post-training quantization)
- Static quantization-aware training: static quantization is inserted directly into the training process, eliminating the post-training calibration step. (quantization during training)
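To make the "scaling factor" above concrete, here is a minimal sketch of affine (asymmetric) int8 quantization arithmetic. This is only the math that observers estimate, written out by hand; it is not any specific PyTorch API.

import torch

def affine_quantize_int8(x: torch.Tensor):
    # map the observed range [x.min(), x.max()] onto the int8 range [-128, 127]
    qmin, qmax = -128, 127
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale, zero_point

x = torch.randn(4, 4)
q, scale, zero_point = affine_quantize_int8(x)
x_hat = (q.float() - zero_point) * scale  # dequantize: x_hat approximates x up to rounding error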
Quantization Modes Explained
The quantization modes are: Eager mode, FX mode, and PyTorch 2 (export) mode.
Quantization in Eager mode
1. PTQ – dynamic quantization
import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        return x

# create a model instance
model_fp32 = M()

# create a quantized model instance
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32,           # the original model
    {torch.nn.Linear},    # the set of layer types to dynamically quantize
    dtype=torch.qint8)    # the target dtype for quantized weights

# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)
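As a quick sanity check (not part of the original snippet), one can print the converted model: the fc submodule should have been swapped for a dynamically quantized Linear, which stores int8 weights and quantizes activations on the fly. The class name in the comment assumes a recent PyTorch release.

print(model_int8)
# fc should now be an instance of torch.ao.nn.quantized.dynamic.Linear
print(type(model_int8.fc))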
2. PTQ – static quantization
import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig ('x86' targets server CPUs)
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# fuse conv + relu manually
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# insert observers
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations;
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model; the relevant calculations will happen in int8
res = model_int8(input_fp32)
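A common follow-up check (not shown above) is to compare the serialized size of the float and quantized models. A small helper sketch, assuming both models expose a state_dict; the size gap is only noticeable for models with large conv/linear layers, not for this toy example.

import os

def print_size_of_model(model, label=""):
    # serialize the weights to a temporary file and report the file size
    torch.save(model.state_dict(), "temp_weights.p")
    size_mb = os.path.getsize("temp_weights.p") / 1e6
    os.remove("temp_weights.p")
    print(f"{label}: {size_mb:.2f} MB")

print_size_of_model(model_fp32, "fp32 model")
print_size_of_model(model_int8, "int8 model")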
3. QAT – quantization-aware training:
import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse the activations to preceding layers, where applicable;
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in the model
# that will observe weight and activation tensors during calibration;
# the model needs to be set to train for the QAT logic to work.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (helpers such as train_one_epoch, evaluate, criterion,
# optimizer and the data loaders come from the surrounding training setup and
# are not shown here)
qat_model = model_fp32_prepared
num_train_batches = 20

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch.
for nepoch in range(8):
    train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
    if nepoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.ao.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch
    quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Epoch %d: Evaluation accuracy on %d images, %2.2f'
          % (nepoch, num_eval_batches * eval_batch_size, top1.avg))
Quantization in FX mode. Advantage: operator fusion and quantization are automated.
PTQ – static quantization
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping

# float_model, data_loader and data_loader_test come from the user's own
# training/evaluation setup (see the tutorial linked below)
float_model.eval()

# The old 'fbgemm' is still available but 'x86' is the recommended default.
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

example_inputs = (next(iter(data_loader))[0],)  # get an example input (as a tuple)
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)  # fuse modules and insert observers
calibrate(prepared_model, data_loader_test)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model
For details, see: https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html
PTQ – dynamic quantization
import torch
from torch.ao.quantization import default_dynamic_qconfig, QConfigMapping
# Note that this is temporary; these functions will be exposed under torch.ao.quantization after the official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model.eval()

# use the default dynamic qconfig: weights are quantized, activations are quantized dynamically at runtime
qconfig = default_dynamic_qconfig
qconfig_mapping = QConfigMapping().set_global(qconfig)

# example_inputs: a tuple of example inputs, as in the static example above
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)  # fuse modules and insert observers
# no calibration is required for dynamic quantization
quantized_model = convert_fx(prepared_model)  # convert the model to a dynamically quantized model
Quantization in PyTorch 2 (export) mode
1. Post-training quantization (PTQ)
import torch
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 5),)
m = M().eval()

# Step 1. program capture
# NOTE: this API will be updated to the torch.export API in the future, but the captured
# result should mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)  # capture the graph
# we now have a model with aten ops

# Step 2. quantization
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
from torch.ao.quantization.quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# backend developers write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())  # int8 symmetric quantizer
m = prepare_pt2e(m, quantizer)

# calibration: in a real setting this would run the prepared model over a
# representative data_loader that yields (image, target) pairs
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# here we simply feed the example input as a stand-in for calibration data
m(*example_inputs)

m = convert_pt2e(m)
# we now have a model with aten ops doing integer computations when possible

# Optional: lower to compiled C++ via TorchInductor's cpp wrapper
import torch._inductor.config as config
config.cpp_wrapper = True

with torch.no_grad():
    optimized_model = torch.compile(m)
    # run a quick benchmark / smoke test
    optimized_model(*example_inputs)

res = optimized_model(example_inputs[0])
print(res)
# tensor([[ 0.0312,  0.0998, -0.7920,  0.0748,  0.7982,  0.1808,  0.4365,  0.0998,
#           0.5800,  0.4428]])
2. QAT – quantization-aware training
# simplified basic steps
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    prepare_qat_pt2e,
    convert_pt2e,
)
from torch.ao.quantization.quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 5),)
m = M()

# Step 1. program capture
# NOTE: this API will be updated to the torch.export API in the future, but the captured
# result should mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
# we now have a model with aten ops

# Step 2. quantization-aware training
# backend developers write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_qat_pt2e(m, quantizer)

# run the training loop (omitted)

m = convert_pt2e(m)
# we now have a model with aten ops doing integer computations when possible

# move the quantized model to eval mode, equivalent to `m.eval()` on an nn.Module
torch.ao.quantization.move_exported_model_to_eval(m)
For details, see: https://pytorch.org/tutorials/prototype/pt2e_quant_qat.html