【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录

  • 【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
    • 目录
      • 摘要
      • 研究背景
      • 问题与挑战
      • 如何解决
      • 创新点
      • 算法模型
      • 实验效果
      • 代码
      • 推荐阅读指数:✭✭✭✭✩
    • 后记


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbFclass Round(Function):@staticmethoddef forward(self, input):sign = torch.sign(input)output = sign * torch.floor(torch.abs(input) + 0.5)return output@staticmethoddef backward(self, grad_output):grad_input = grad_output.clone()return grad_input# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,zero_point=True, q_group_size=-1,inplace=False,get_scale_zp=False):org_w_shape = w.shapeif q_group_size > 0:assert org_w_shape[-1] % q_group_size == 0w = w.reshape(-1, q_group_size)elif q_group_size == -1:w = w.reshape(-1, w.shape[-1])assert w.dim() == 2if zero_point:max_val = w.amax(dim=1, keepdim=True)min_val = w.amin(dim=1, keepdim=True)max_int = 2 ** n_bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)else:  # we actually never used thisassert min_val is Nonemax_val = w.abs().amax(dim=1, keepdim=True)max_val = max_val.clamp(min=1e-5)max_int = 2 ** (n_bit - 1) - 1min_int = - 2 ** (n_bit - 1)scales = max_val / max_intzeros = 0assert torch.isnan(scales).sum() == 0assert torch.isnan(w).sum() == 0if inplace:((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)else:w = (torch.clamp(torch.round(w / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(w).sum() == 0w = w.reshape(org_w_shape)if get_scale_zp:return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)else:return w@torch.no_grad()
def real_quantize_model_weight(model, w_bit, q_config,init_only=False
):from .qmodule import WQLinearfrom .pre_quant import get_blocks, get_named_linears, set_op_by_nameassert q_config["zero_point"], "We only support zero_point quantization now."layers = get_blocks(model)for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):layer = layers[i]named_linears = get_named_linears(layer)# scale_activations(layer)for name, module in named_linears.items():if init_only:q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)q_linear.to(next(layer.parameters()).device)set_op_by_name(layer, name, q_linear)else:module.cuda()module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)# scales = scales.t().contiguous()# zeros = zeros.t().contiguous()q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], False, scales, zeros)module.cpu()q_linear.to(next(layer.parameters()).device)set_op_by_name(layer, name, q_linear)torch.cuda.empty_cache()gc.collect()torch.cuda.empty_cache()gc.collect()def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):quantizer = SteN2F3Quantizer(q_group_size=q_group_size)w = quantizer(w)return wclass SteInt3AsymQuantizer(nn.Module):def __init__(self, q_group_size=128):super().__init__()self.q_group_size = q_group_sizeself.bit = 3def forward(self, x):org_w_shape = x.shapeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, self.q_group_size)elif self.q_group_size == -1:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, x.shape[-1])assert x.dim() == 2max_val = x.amax(dim=1, keepdim=True)min_val = x.amin(dim=1, keepdim=True)max_int = 2 ** self.bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)assert torch.isnan(scales).sum() == 0assert torch.isnan(x).sum() == 0x = (torch.clamp(Round.apply(x / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(x).sum() == 0x = x.reshape(org_w_shape)return xclass SteInt2AsymQuantizer(nn.Module):def __init__(self, q_group_size=64):super().__init__()self.q_group_size = q_group_sizeself.bit = 2def forward(self, x):org_w_shape = x.shapeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0x = x.reshape(-1, self.q_group_size)assert x.dim() == 2max_val = x.amax(dim=1, keepdim=True)min_val = x.amin(dim=1, keepdim=True)max_int = 2 ** self.bit - 1min_int = 0scales = (max_val - min_val).clamp(min=1e-5) / max_intzeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)assert torch.isnan(scales).sum() == 0assert torch.isnan(x).sum() == 0x = (torch.clamp(Round.apply(x / scales) +zeros, min_int, max_int) - zeros) * scalesassert torch.isnan(x).sum() == 0x = x.reshape(org_w_shape)return xclass SteN2F3Quantizer(nn.Module):def __init__(self, q_group_size=128):super().__init__()self.q_group_size = q_group_sizedef forward(self, x):org_w_shape = x.shape# reshape to groupsizeif self.q_group_size > 0:assert org_w_shape[-1] % self.q_group_size == 0qx = x.reshape(-1, self.q_group_size)elif self.q_group_size == -1:qx = x.reshape(-1, x.shape[-1])assert qx.dim() == 2# Get the Min Maxmax_val = qx.amax(dim=1, keepdim=True)min_val = qx.amin(dim=1, keepdim=True)scale_pos = torch.abs(max_val)scale_neg = torch.abs(min_val)dev = qx.devicex_pos = torch.zeros_like(qx)x_neg = torch.zeros_like(qx)x_pos = torch.where(qx >= 0, qx, x_pos)x_neg = torch.where(qx < 0, qx, x_neg)q_pos = x_pos / scale_posq_neg = x_neg / scale_negq_pos, q_neg = self.round_pass(q_pos, q_neg, dev)qx = q_pos * scale_pos + q_neg * scale_negqx = qx.reshape(org_w_shape)return qxdef round_n2f3(self, q_pos, q_neg, dev):q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)return q_pos, q_negdef round_pass(self, q_pos, q_neg, dev):y_grad_pos, y_grad_neg = q_pos, q_negy_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/463476.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

P9220 「TAOI-1」椎名真昼

P9220 「TAOI-1」椎名真昼 考点&#xff1a;博弈论、拓扑、强连通分量。 难度&#xff1a; 提高/省选- 。 题意&#xff1a; ​ Alice 和 Bob 玩游戏&#xff0c;给定一个有向图&#xff0c;每个点有初始颜色&#xff08;黑/白&#xff09;。 ​ 双方轮番操作一次&#xf…

计算机网络:网络层 —— 多播路由选择协议

文章目录 多播路由选择协议多播转发树构建多播转发树基于源树的多播路由选择建立广播转发树建立多播转发树 组共享树的多播路由选择基于核心的生成树的建立过程 因特网的多播路由选择协议 多播路由选择协议 仅使用 IGMP 并不能在因特网上进行IP多播。连接在局域网上的多播路由…

例行性工作

1、单一执行------at-----仅处理执行一次就结束了 1.1工作过程 /etc/at.allow&#xff0c;写在该文件的人可以使用at命令/etc/at.deny&#xff0c;黑名单两个文件如果都不存在&#xff0c;只有root能使用 1.2命令详解------命令格式&#xff1a;at [参数] [时间] 2、循环执行…

使用Kafka构建大规模消息传递系统

&#x1f493; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4dd; Gitee主页&#xff1a;瑕疵的gitee主页 ⏩ 文章专栏&#xff1a;《热点资讯》 使用Kafka构建大规模消息传递系统 引言 Kafka 简介 安装 Kafka 创建主题 生产者 消费者 高级特性 分区 持久化 消费者组 消息确认…

【sqlmap使用】

sqlmap简介 sqlmap 目录结构 sqlmap常用参数 sqlmap实现注入 测试注入点&#xff0c;检测到注入点后&#xff0c;直接爆数据库名 python sqlmap.py –u http://172.16.12.2/7/9/strsql.php --data "usernameadmin" --dbs注意sqlmap在使用过程中可能会出现几个需要…

【java】java的基本程序设计结构07-字符串

字符串 1. 创建字符串 最简单的&#xff1a; String str "hello"; 用构造函数创建字符串&#xff1a; String str2new String("hello"); String 创建的字符串存储在公共池中&#xff0c;而 new 创建的字符串对象在堆上&#xff1a; 注意: String 类…

数组排序简介-基数排序(Radix Sort)

基本思想 将整数按位数切割成不同的数字&#xff0c;然后从低位开始&#xff0c;依次到高位&#xff0c;逐位进行排序&#xff0c;从而达到排序的目的。 算法步骤 基数排序算法可以采用「最低位优先法&#xff08;Least Significant Digit First&#xff09;」或者「最高位优先…

w~Transformer~合集8

我自己的原文哦~ https://blog.51cto.com/whaosoft/12419881 #Batch Normalization 本文聚焦于Batch Normalization&#xff0c;Layer Normalization两个标准化方法&#xff0c;对其原理和优势等进行了详细的阐述。 这一篇写Transformer里标准化的方法。在Transformer中&am…

Hadoop——HDFS

什么是HDFS HDFS&#xff08;Hadoop Distributed File System&#xff09;是Apache Hadoop的核心组件之一&#xff0c;是一个分布式文件系统&#xff0c;专门设计用于在大规模集群上存储和管理海量数据。它的设计目标是提供高吞吐量的数据访问和容错能力&#xff0c;以支持大数…

废弃物分类分割系统:入门训练营

废弃物分类分割系统源码&#xff06;数据集分享 [yolov8-seg-C2f-DCNV2-Dynamic&#xff06;yolov8-seg-C2f-DWR等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI Glob…

java项目之微服务在线教育系统设计与实现(springcloud)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的闲一品交易平台。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 微服务在线教育系统设计与…

拆换LED灯珠后测量是短路的,为何

今天更换灯珠遇到一个怪事情&#xff0c;拆换一颗好的灯珠上去&#xff0c;万用表测试是短路的。 后面测试电路板上面&#xff0c;中间的散热部分是跟二极管的正极想通的。而且恰恰此时&#xff0c;LED灯珠的散热部分是跟负极想通的。 遂将线路板上面的散热部分跟二极管正极割…

串口屏控制的自动滑轨(未完工)

序言 疫情期间自己制作了一个自动滑轨&#xff0c;基于无线遥控的&#xff0c;但是整体太大了&#xff0c;非常不方便携带&#xff0c;所以重新设计了一个新的&#xff0c;以2020铝型材做导轨的滑轨&#xff0c;目前2020做滑轨已经很成熟了&#xff0c;配件也都非常便宜&#x…

【NOIP提高组】Hankson的趣味题

【NOIP提高组】Hankson的趣味题 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; Hanks 博士是BT (Bio-Tech&#xff0c;生物技术) 领域的知名专家&#xff0c;他的儿子名叫Hankson。现在&#xff0c;刚刚放学回家的Hankson 正在思考一个有趣…

Matlab车牌识别课程设计报告(附源代码)

Matlab车牌识别系统 分院&#xff08;系&#xff09; 信息科学与工程 专业 学生姓名 学号 设计题目 车牌识别系统设计 内容及要求&#xff1a; 车牌定位系统的目的在于正确获取整个图像中车牌的区域&#xff0c; 并识别出车牌号。通过设计实现车牌识别系…

【Unity基础】初识UI Toolkit - 运行时UI

Unity中的UI工具包&#xff08;UI Toolkit&#xff09;不但可以用于创建编辑器UI&#xff0c;同样可以来创建运行时UI。 关于Unity中的UI系统以及使用UI工具包创建编辑器UI可以参见&#xff1a; 1. Unity中的UI系统 2. 初识UI Toolkit - 编辑器UI 本文将通过一个简单示例来…

【重生之我要苦学C语言】深入理解指针4

深入理解指针4 字符指针变量 指针指向字符变量 char ch w; char* p &ch;指针指向字符数组 char arr[10] "abcdef"; char* p arr;printf("%s\n", arr); printf("%s\n", p);结果是一样的 也可以写成&#xff1a; char* p "abc…

Freertos学习日志(1)-基础知识

目录 1.什么是Freertos&#xff1f; 2.为什么要学习RTOS&#xff1f; 3.Freertos多任务处理的原理 1.什么是Freertos&#xff1f; RTOS&#xff0c;即&#xff08;Real Time Operating System 实时操作系统&#xff09;&#xff0c;是一种体积小巧、确定性强的计算机操作系统…

勒索软件通过易受攻击的 Cyber​​Panel 实例攻击网络托管服务器

一个威胁行为者&#xff08;或可能多个&#xff09;使用 PSAUX 和其他勒索软件攻击了大约 22,000 个易受攻击的 Cyber​​Panel 实例以及运行该实例的服务器上的加密文件。 PSAUX 赎金记录&#xff08;来源&#xff1a;LeakIX&#xff09; Cyber​​Panel 漏洞 Cyber​​Pane…

基于vue3和elementPlus的el-tree组件,实现树结构穿梭框,支持数据回显和懒加载

一、功能 功能描述 数据双向穿梭&#xff1a;支持从左侧向右侧转移数据&#xff0c;以及从右侧向左侧转移数据。懒加载支持&#xff1a;支持懒加载数据&#xff0c;适用于大数据量的情况。多种展示形式&#xff1a;右侧列表支持以树形结构或列表形式展示。全选与反选&#xf…