模型量化笔记--KL散度量化

KL散度量化

前面介绍的非对称量化中,是将数据中的min值和max值直接映射到[-128, 127]。
同样的,前面介绍的对称量化是将数据的最大绝对值 ∣ m a x ∣ |max| max直接映射到127。
上面两种直接映射的方法比较粗暴,而TensorRT中的int8量化是基于KL散度来选取最佳的阈值T来映射到127中。超出阈值 ± ∣ T ∣ \pm|T| ±T的数据会直接映射为阈值(类似于截断映射)。

KL散度定义

KL散度常用来衡量两个分布P和Q之间的差异,KL散度越小,两个分布越相似,其公式定义如下:
D K L = ∑ i P ( x i ) l o g ( P ( x i ) Q ( x i ) ) D_{KL} = \sum_{i}P(x_i)log(\frac{P(x_i)}{Q(x_i)}) DKL=iP(xi)log(Q(xi)P(xi))

TensorRT实现KL散度量化的步骤

  1. 基于原始输入数据生成拥有2048个bin的直方图
hist, bin_edges = np.histogram(P, bins = 2048)
  1. 在[128, 2048]返回内循环执行3-5步,寻找最佳的划分 b i n i bin_{i} bini
  2. [ 0 , b i n i ] [0,bin{i}] [0,bini]范围内的直方图数据作为原始P, 并将 b i n i bin_{i} bini之后的直方图数据进行求和,并累加到 b i n i − 1 bin_{i-1} bini1中形成以 b i n i bin_{i} bini作为划分的最终P分布
  3. 对P分布进行量化形成Q分布(一般是划分和合并bins,计算合并后的平均值作为Q分布对应bins的值)
  1. 计算P分布和Q分布的KL散度
  2. 根据最小的KL散度来选取最佳的 b i n b e s t bin_{best} binbest,将bin_edges[ b i n b e s t bin_{best} binbest]作为最终的阈值threshold,即映射到127的阈值T
  3. 根据最佳的阈值T来计算scale
    s c a l e = T i n t m a x = T 127 scale = \frac{T}{int_{max}} = \frac{T}{127} scale=intmaxT=127T
  4. 根据对称量化来量化原始数据(权重、激活值等等)

TensorRT使用KL散度量化的目的

通过KL散度选取合适的阈值T,根据阈值计算对应的缩放系数scale,力求int8量化后的数值能更准确表示出量化前的FP32数值。

代码案例

import random
import numpy as np
import matplotlib.pyplot as plt  
import copy
import scipy.stats as stats# 随机生成测试数据
def generator_P(size):walk = []avg = random.uniform(3.000, 600.999)std = random.uniform(500.000, 1024.959)for _ in range(size):walk.append(random.gauss(avg, std)) # 生成符合高斯分布的随机数return walk# 平滑p和q,防止出现nan值,因为KL散度会计算log(p/q), 当q为0值时会出现nan
def smooth_distribution(p, eps=0.0001):is_zeros = (p == 0).astype(np.float32)is_nonzeros = (p != 0).astype(np.float32)n_zeros = is_zeros.sum()n_nonzeros = p.size - n_zerosif not n_nonzeros:raise ValueError('The discrete probability distribution is malformed. All entries are 0.')eps1 = eps * float(n_zeros) / float(n_nonzeros)assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)hist = p.astype(np.float32)hist += eps * is_zeros + (-eps1) * is_nonzerosassert (hist <= 0).sum() == 0return histdef threshold_distribution(distribution, target_bin = 128):distribution = distribution[1:]length = distribution.size # 2047threshold_sum = sum(distribution[target_bin:]) # [128: ]kl_divergence = np.zeros(length - target_bin) # 初始化 2047 - 128 = 1919 个KL散度值for threshold in range(target_bin, length): # 遍历threshold寻找KL散度最低的阈值sliced_nd_hist = copy.deepcopy(distribution[:threshold]) # [0, threshold)内的作为Pp = sliced_nd_hist.copy() # 生成pp[threshold - 1] += threshold_sum # 把 [threshold:] 后的累加和加到 p[threshold - 1] 中threshold_sum = threshold_sum - distribution[threshold] # 更新下一轮的累加和,即上一轮的累加和减去即将移入P分布的区间数据is_nonzeros = (p != 0).astype(np.int64) # [0:threshold]内不为0的区间quantized_bins = np.zeros(target_bin, dtype = np.int64) # 初始化量化后的binsnum_merged_bins = sliced_nd_hist.size // target_bin # 计算多少个区间需要合并来计算平均值,例如最初有8个bins,需要合并到4个bins,则每两个bins需要进行合并# 合并binsfor j in range(target_bin): start = j * num_merged_bins # 合并开始的binsstop = start + num_merged_bins # 合并结束的binsquantized_bins[j] = sliced_nd_hist[start:stop].sum() # 计算区间内bins的总和quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum()# 计算qq = np.zeros(sliced_nd_hist.size, dtype = np.float64) # 初始化量化后的qfor j in range(target_bin):start = j * num_merged_binsif j == target_bin - 1:stop = -1else:stop = start + num_merged_bins # 每num_merged_bins个bins进行合并组成qnorm = is_nonzeros[start:stop].sum() # 看看合并区间里,不为0的区间个数if norm != 0:q[start:stop] = float(quantized_bins[j]) / float(norm) # 用均值(假如区间内都不为0)填充q# 平滑p和qp = smooth_distribution(p)q = smooth_distribution(q)# 计算p和q之间的KL散度kl_divergence[threshold - target_bin] = stats.entropy(p, q)# 寻找最小KL散度对应threshold的索引min_kl_divergence = np.argmin(kl_divergence)threshold_value = min_kl_divergence + target_bin # 计算真正的threshold, 基于最初的128, 因为一开始就是从128开始不断向外计算来扩大P的范围return threshold_valueif __name__ == '__main__':# 随机初始化测试数据size = 20480 P = generator_P(size) P = np.array(P)P = P[P > 0] # 保留大于0的数# print("maximum activation value", max(np.absolute(P))) # 最大的激活值hist, bin_edges = np.histogram(P, bins = 2048) # 生成直方图 hist表示每一个bins对应的数量, bins表示截止 threshold = threshold_distribution(hist, target_bin = 128) # 返回KL散度最小的划分binsprint("threshold: ", threshold)print("threshold edges:", bin_edges[threshold]) # 截止到threshold对应的bins, 能够表示的范围 bin_edges[-1]表示上面最大的激活值,即能够表示所有数# 计算scale# scale = bin_edges[threshold] / int_max # 即bin_edges[threshold] / 127 # 在最初的对称量化中,我们是用绝对值最大的数值作为bin_edges[threhold], 而TensorRT就是利用KL散度来评估最佳的bin_edges[threshold]# 分成 split_zie 组, density表示是否要normedplt.title("Relu activation value Histogram")plt.xlabel("Activation values")plt.ylabel("Normalized number of Counts")plt.hist(P, bins=2047)plt.vlines(bin_edges[threshold], 0, 30, colors = "r", linestyles = "dashed") # 红线向左就是能够表示的所有范围plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/163459.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

openGauss学习笔记-102 openGauss 数据库管理-管理数据库安全-客户端接入之查看数据库连接数

文章目录 openGauss学习笔记-102 openGauss 数据库管理-管理数据库安全-客户端接入之查看数据库连接数102.1 背景信息102.2 操作步骤 openGauss学习笔记-102 openGauss 数据库管理-管理数据库安全-客户端接入之查看数据库连接数 102.1 背景信息 当用户连接数达到上限后&#…

小黑子—Maven基础

Maven基础 一 小黑子的Maven学习1. Mavn的介绍2. Maven基础概念2.1 仓库2.2 坐标2.3 仓库配置 3. 手动写一个maven项目3.1 Maven项目构建命令3.2 插件创建工程 4. IDEA下的maven项目5. 依赖管理5.1 依赖配置5.2 依赖传递5.3 可选依赖&#xff08;不透明&#xff09;5.4 排除依赖…

【一:实战开发testng的介绍】

目录 1、主要内容1.1、为啥要做接口测试1.2、接口自动化测试落地过程1.3、接口测试范围1.4、手工接口常用的工具1.5、自动化框架的设计 2、testng自动化测试框架基本测试1、基本注解2、忽略测试3、依赖测试4、超时测试5、异常测试6、通过xml文件参数测试7、通过data实现数据驱动…

UWB十个知识点

UWB是一直被基于厚望的高精度定位技术 1&#xff1a;定位技术及UWB特点 位置空间感知技术包括了GNSS、RFID、蓝牙和UWB&#xff0c;在室内和区域空间测量最具技术优势的技术是UWB。 GNSS是广域定位技术&#xff0c;室内以及建筑物旁边等场景&#xff0c;GNSS无法实现定位&am…

【微服务 SpringCloud】实用篇 · Ribbon负载均衡

微服务&#xff08;4&#xff09; 文章目录 微服务&#xff08;4&#xff09;1. 负载均衡原理2. 源码跟踪1&#xff09;LoadBalancerIntercepor2&#xff09;LoadBalancerClient3&#xff09;负载均衡策略IRule4&#xff09;总结 3. 负载均衡策略3.1 负载均衡策略3.2 自定义负载…

企业IT资产设备折旧残值如何计算

环境&#xff1a; 企业/公司 IT资产 问题描述&#xff1a; 企业IT设备折旧残值如何计算&#xff1f; 解决方案&#xff1a; 1.按三年折旧 净值原值-月折旧额折旧月份 &#xff0c; 月折旧额原值(1-3%)/36 折旧月份ROUND(E2*(1-3%)/36,2) 2.净值E2-F2*G2

vue使用pdf 导出当前页面,(jspdf, html2canvas )

需要安装两个插件 npm install html2canvas jspdfyarn add html2canvas jspdf<div class"app-container" id"pdfPage">我是内容 </div><el-button size"mini" click"onExportPdf">导出数据</el-button>onexp…

代码随想录算法训练营第23期day24|回溯算法理论基础、77. 组合

目录 一、回溯算法基础 回溯法模板 二、&#xff08;leetcode 77&#xff09;组合 剪枝 一、回溯算法基础 1.回溯的本质是穷举&#xff0c;穷举所有可能&#xff0c;然后选出想要的答案&#xff08;为了提升效率&#xff0c;最多再加一个剪枝&#xff09; 2.回溯法解决的…

凝聚技术力量 共建测试生态 ——集成电路测试技术交流日成功举办

10月18日下午&#xff0c;凝聚技术力量&#xff0c;共建测试生态 ——集成电路测试技术交流会在上海成功举办。来自全国各地知名专家学者、技术大咖及企业代表齐聚一堂&#xff0c;共同探讨封装测试技术的发展方向&#xff0c;共话产业未来&#xff0c;共促产业发展。 本次活动…

Stable Diffusion WebUI扩展a1111-sd-webui-tagcomplete之Booru风格Tag自动补全功能详细介绍

安装地址 直接附上地址先: Ranting8323 / A1111 Sd Webui Tagcomplete GitCodeGitCode——开源代码托管平台,独立第三方开源社区,Git/Github/Gitlabhttps://gitcode.net/ranting8323/a1111-sd-webui-tagcomplete.git上面是GitCode的地址,下面是GitHub的地址,根据自身情…

CUDA编程入门系列(二) GPU硬件架构综述

一、Fermi GPU Fermi GPU如下图所示&#xff0c;由16个SM&#xff08;stream multiprocessor&#xff09;组成&#xff0c;不同的SM之间通过L2 Cache和全局内存进行相连。整个架构大致分为两个层次&#xff0c;①总体架构由多个SM组成 ②每个SM由多个SP core&#xff08;stream…

数据结构中的七大排序(Java实现)

目录 一、直接插入排序 二、希尔排序 三、直接选择排序 四、堆排序 五、冒泡排序 六、快速排序 七、归并排序 一、直接插入排序 思想&#xff1a; 定义i下标之前的元素全部已经有序&#xff0c;遍历一遍要排序的数组&#xff0c;把i下标前的元素全部进行排序&#xff0…

elementui select组件下拉框底部增加自定义按钮

elementui select组件下拉框底部增加自定义按钮 el-select组件的visible-change 事件&#xff08;下拉框出现/隐藏时触发&#xff09; <el-selectref"select":value"value"placeholder"请选择"visible-change"visibleChange">&…

一天吃透Java集合面试八股文

内容摘自我的学习网站&#xff1a;topjavaer.cn 常见的集合有哪些&#xff1f; Java集合类主要由两个接口Collection和Map派生出来的&#xff0c;Collection有三个子接口&#xff1a;List、Set、Queue。 Java集合框架图如下&#xff1a; List代表了有序可重复集合&#xff0c…

软考-访问控制技术原理与应用

本文为作者学习文章&#xff0c;按作者习惯写成&#xff0c;如有错误或需要追加内容请留言&#xff08;不喜勿喷&#xff09; 本文为追加文章&#xff0c;后期慢慢追加 by 2023年10月 访问控制概念 访问控制是计算机安全的一个重要组成部分&#xff0c;用于控制用户或程序如…

LiveGBS流媒体平台GB/T28181常见问题-安全控制HTTP接口鉴权勾选流地址鉴权后401Unauthorized如何播放调用接口

LiveGBS流媒体平台GB/T28181常见问题-安全控制HTTP接口鉴权勾选流地址鉴权后401 Unauthorized如何播放调用接口&#xff1f; 1、安全控制1.1、HTTP接口鉴权1.2、流地址鉴权 2、401 Unauthorized2.1、携带token调用接口2.1.1、获取鉴权token2.1.2、调用其它接口2.1.2.1、携带 Co…

Spring Boot 可以同时处理多少请求?

文章目录 Spring Boot 的请求处理能力1. 硬件资源2. 应用程序的设计3. 配置4. 运行时环境 基准测试和性能优化高性能的 Spring Boot 应用程序示例结论 &#x1f389;欢迎来到架构设计专栏~Spring Boot 可以同时处理多少请求&#xff1f; ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#…

C语言实现面向对象编程 | 干货

前言 GOF的《设计模式》一书的副标题叫做“可复用面向对象软件的基础”&#xff0c;从标题就能看出面向对象是设计模式基本思想。 由于C语言并不是面向对象的语言&#xff0c;C语言没有直接提供封装、继承、组合、多态等面向对象的功能&#xff0c;但C语言有struct和函数指针。…

019-第三代软件开发-Git提交规范

第三代软件开发-Git提交规范 文章目录 第三代软件开发-Git提交规范项目介绍Git提交规范分支规范Commit Message FormatHeaderBodyFooterRevert 总结一下 关键字&#xff1a; Qt、 Qml、 git、 Commit、 release 项目介绍 欢迎来到我们的 QML & C 项目&#xff01;这个…

【数据结构】优先级队列(堆)

作者主页&#xff1a;paper jie_博客 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文录入于《JAVA数据结构》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白精心打造的。笔者用重金(时间和精力…