[ACL 2024] Revisiting Knowledge Distillation for Autoregressive Language Models

Contents

  • Introduction
  • Method
    • Rethinking Knowledge Distillation for Autoregressive LMs
    • Improving Knowledge Distillation with Adaptive Teaching Modes
  • Experiments
  • References

Introduction

  • 作者提出 Autoregressive KD with Adaptive Teaching Modes (ATKD),通过对难易样本采用不同的学习策略来解决 larger teachers might dramatically result in a poorer student, especially when the model capability gap is large 的问题,可以作为一种通用的学习策略提升不同的已有 KD 算法的精度
    在这里插入图片描述

Method

Rethinking Knowledge Distillation for Autoregressive LMs

  • Reformulation of L K L \mathcal L_{\mathbf {KL}} LKL. KL 散度可以被分解为 ground truth 类别上的 binary KL loss K L ( p b t ∣ ∣ q b t ) \mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t) KL(pbt∣∣qbt) 和非 ground truth 类别上的 KL loss K L ( p ^ t ∣ ∣ q ^ t ) \mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t}) KL(p^t∣∣q^t),前者可以帮助 student 学习 target 相关的信息,被称为 target-oriented knowledge distillation (TKD),后者可以帮助 student 学习 non-target 中蕴含的知识,被称为 diversity-oriented knowledge distillation (DKD);此外,这两部分的蒸馏损失被加上了一个权值 p \ g t t p_{\backslash g_t}^t p\gtt,该项反映了 teacher 的 uncertainty,被称为 uncertainty coefficient (UNC)
    L K L = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + ∑ j = 1 , j ≠ g t C p j t log ⁡ ( p j t q j t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + p \ g t t ∑ j = 1 , j ≠ g t C p ^ j t ( log ⁡ ( p ^ j t q ^ j t ) + log ⁡ ( p \ g t t q \ g t t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + p ∖ g t t log ⁡ ( p ∖ g t t q ∖ g t t ) + p ∖ g t t ∑ j = 1 , j ≠ g t C p ^ j t log ⁡ ( p ^ j t q ^ j t ) = ∑ t = 1 T ( K L ( p b t ∣ ∣ q b t ) + p \ g t t K L ( p ^ t ∣ ∣ q ^ t ) ) \begin{aligned} \mathcal{L}_{\mathrm{KL}}& =\sum_{t=1}^{T}(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+\sum_{j=1,j\neq g_{t}}^{C}p_{j}^{t}\log(\frac{p_{j}^{t}}{q_{j}^{t}})) \\&=\sum_{t=1}^T\left(p_{g_t}^t\log(\frac{p_{g_t}^t}{q_{g_t}^t})\right. \\ &\ \ \ \ \ +p_{\backslash g_{t}}^{t}\sum_{j=1,j\neq g_{t}}^{C}\hat{p}_{j}^{t}\left(\log(\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}})+\log(\frac{p_{\backslash g_{t}}^{t}}{q_{\backslash g_{t}}^{t}})\right) \\ &=\sum_{t=1}^{T}\left(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+p_{\setminus g_{t}}^{t}\log(\frac{p_{\setminus g_{t}}^{t}}{q_{\setminus g_{t}}^{t}})\right. \\ &\ \ \ \ \ +p_{\setminus g_t}^t\sum_{j=1,j\neq g_t}^C\hat{p}_j^t\log(\frac{\hat{p}_j^t}{\hat{q}_j^t}) \\ &=\sum_{t=1}^T\left(\mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t)+p_{\backslash g_t}^t\mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t})\right) \end{aligned} LKL=t=1T(pgttlog(qgttpgtt)+j=1,j=gtCpjtlog(qjtpjt))=t=1T(pgttlog(qgttpgtt)     +p\gttj=1,j=gtCp^jt(log(q^jtp^jt)+log(q\gttp\gtt))=t=1T(pgttlog(qgttpgtt)+pgttlog(qgttpgtt)     +pgttj=1,j=gtCp^jtlog(q^jtp^jt)=t=1T(KL(pbt∣∣qbt)+p\gttKL(p^t∣∣q^t))其中, T T T 为序列长度, p , q p,q p,q 分别为 teacher 和 student 的概率分布, g t gt gt 为 teacher 预测的 ground-truth 类别, p g t t = exp ⁡ ( z g t t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ∖ g t t = ∑ k = 1 , k ≠ g t C exp ⁡ ( z k t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ^ i t = exp ⁡ ( z i t ) ∑ j = 1 , j ≠ g t C exp ⁡ ( z j t ) p_{g_t}^t=\frac{\exp(z_{g_t}^t)}{\sum_{j=1}^C\exp(z_j^t)},p_{\setminus g_t}^t=\frac{\sum_{k=1,k\neq g_t}^C\exp(z_k^t)}{\sum_{j=1}^C\exp(z_j^t)},\hat{p}_i^t=\frac{\exp(z_i^t)}{\sum_{j=1,j\neq g_t}^C\exp(z_j^t)} pgtt=j=1Cexp(zjt)exp(zgtt),pgtt=j=1Cexp(zjt)k=1,k=gtCexp(zkt),p^it=j=1,j=gtCexp(zjt)exp(zit) p i t = p ∖ g t t ⋅ p ^ i t p_i^t=p_{\setminus g_t}^t\cdot \hat{p}_i^t pit=pgttp^it p b t = [ p g t t , p ∖ g t t ] \mathrm{p}_{\mathrm{b}}^t=[p_{g_t}^t,p_{\setminus g_t}^t] pbt=[pgtt,pgtt]
  • Empirical Analyses. (1) UNC measures the learning difficulties of tokens, where the hard-to-learn ones are more important for KD. 根据 p \ g t t p_{\backslash g_t}^t p\gtt 的大小可以把 tokens 分为难样本 (top-50% uncertainty) 和简单样本,实验发现难样本对 student 的学习更重要,尤其是 student 和 teacher 差距比较大的时候,这可能是因为难样本能让 student 学到丰富的类间信息,同时避免过拟合
    在这里插入图片描述(2) DKD contributes more (than TKD) but is greatly suppressed, especially for the larger teachers. 作者对 TKD 和 DKD 做了解耦,去除了权重 p \ g t t p_{\backslash g_t}^t p\gtt 来考察它们各自的作用,作者发现 DKD 显著优于 TKD,但在 KL loss 中,由于 p \ g t t p_{\backslash g_t}^t p\gtt 的存在,DKD 的权值被降低了,并且这一现象在更大规模的模型中尤为显著,这也是作者认为的导致 larger teachers might dramatically result in a poorer student 的原因在这里插入图片描述在这里插入图片描述(3) TKD plays different roles in tokens with different learning difficulties. TKD 在简单样本上可能会导致 student 过拟合,从而影响泛化性;在难样本上能降低难样本的学习难度,从而提升 student 精度
    在这里插入图片描述

Improving Knowledge Distillation with Adaptive Teaching Modes

  • Autoregressive KD with Adaptive Teaching Modes (ATKD). 基于上述观察很容易想到,不同的 tokens 根据其难易程度,应该有不同的学习策略;简单样本仅使用 DKD,难样本 (top-50% uncertainty) 使用 DKD + TKD
    L K L e = − ∑ t ∈ D e K L ( p ^ t ∣ ∣ q ^ t ) , L K L h = − ∑ t ∈ D h K L ( p b t ∣ ∣ q b t ) + K L ( p ^ t ∣ ∣ q ^ t ) \begin{aligned} &\mathcal{L}_\mathrm{KL}^{e} =-\sum_{t\in\mathcal{D}_e}\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}), \\ &\mathcal{L}_{\mathrm{KL}}^h =-\sum_{t\in\mathcal{D}_h}\mathrm{KL}(\mathbf{p_b^t}||\mathbf{q_b^t})+\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}) \end{aligned} LKLe=tDeKL(p^t∣∣q^t),LKLh=tDhKL(pbt∣∣qbt)+KL(p^t∣∣q^t)最终的损失函数为简单样本和难样本上损失的加权和 L K L a l l = λ ∗ L K L e + ( 1 − λ ) ∗ L K L h \mathcal{L}_{\mathrm{KL}}^{all}=\lambda*\mathcal{L}_{\mathrm{KL}}^e+(1-\lambda)*\mathcal{L}_{\mathrm{KL}}^h LKLall=λLKLe+(1λ)LKLh其中, λ = 0.2 \lambda=0.2 λ=0.2

Experiments

  • Compared Results. S NLG \mathcal S_{\textrm{NLG}} SNLG 为语言生成任务,由 GPT-4 打分; S NLU \mathcal S_{\textrm{NLU}} SNLU 为语言理解任务,为 benchmark 得分
    在这里插入图片描述在这里插入图片描述
  • Ablation Study. (1) Impact of ratio k k k. k k k 用于确定 top- k k k uncertainty 的 tokens 为难样本;(2) Impact of coefficient λ λ λ. 用于确定难易样本损失的权重
    在这里插入图片描述

References

  • Zhong, Qihuang, et al. “Revisiting knowledge distillation for autoregressive language models.” arXiv preprint arXiv:2402.11890 (2024).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/408977.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5000套精美PPT免费分享

目录 部分展示目录 几乎包含各种应用场景的PPT模板 这里只展示部分目录 部分展示目录 ##PPT下载 链接:https://pan.baidu.com/s/1ckvN9xeMR82hL30lHXfJ0g 提取码:ZYNB 点击下载,记得点个赞哦

3 pytest Fixture

目录 3.1 通过 conftest.py 共享 fixture3.2 使用 fixture 执行配置及销毁逻辑3.3 使用 --setup-show 回溯 fixture 的执行过程3.4 使用 fixture 传递测试数据3.5 使用多个 fixture3.6 指定 fixture 作用范围3.7 使用 usefixtures 指定 fixture3.8 为常用 fixture 添加 autouse…

MySQL从入门到精通(第9-10章)

文章目录 9 子查询9.1 需求分析与问题解决9.1.1 实际问题9.1.2 子查询的使用9.1.3 子查询的分类 9.2 单行子查询9.2.1 单行比较操作符9.2.2 代码示例9.2.3 HAVING中的子查询9.2.4 CASE中的子查询9.2.5 子查询中的空值问题9.2.6 非法使用子查询 9.3 多行子查询9.3.1 多行比较操作…

linux系统使用 docker 来部署web环境 nginx+php7.4 并配置称 docker-compose-mysql.yml 文件

Docker是一个开源的容器化平台,旨在简化应用程序的创建、部署和管理。它基于OS-level虚拟化技术,通过将应用程序和其依赖项打包到一个称为容器的标准化单元中,使得应用程序可以在任何环境中快速、可靠地运行。 Docker的优势有以下几个方面&a…

如何使用ssm实现基于java斗车交易系统设计与实现+vue

TOC ssm082基于java斗车交易系统设计与实现vue 系统概述 1.1 概述 随着社会的快速发展,计算机的影响是全面且深入的。人们的生活水平不断提高,日常生活中人们对斗车交易方面的要求也在不断提高,需要咨询的人数更是不断增加,使得…

【第69课】Java安全JWT攻防Swagger自动化算法签名密匙Druid未授权

免责声明 本文发布的工具和脚本,仅用作测试和学习研究,禁止用于商业用途,不能保证其合法性,准确性,完整性和有效性,请根据情况自行判断。 如果任何单位或个人认为该项目的脚本可能涉嫌侵犯其权利&#xff0…

RM遥控键鼠控制总结

硬件&通信介绍 RM比赛中各个参赛队伍使用的都是大疆官方提供的遥控器套装,包括遥控器和接收机,接收机上共三个引脚:VCC,GND,DBUS(数据通道),首次使用需要进行遥控器和接收机配对…

C++类和对象(下):初始化列表、explicit关键字、友元函数、友元类

文章目录 C类和对象9、初始化列表9.1构造函数体赋值9.2初始化列表9.3 explicit(显示)关键字 10、友元10.1友元函数10.2友元类 C类和对象 9、初始化列表 一个类的构造函数要初始化成员变量有两种方式,一种是构造函数体赋值,另一种…

8.23-docker基础命令学习

docker 1.docker容器 [rootdocker ~]# systemctl start docker[rootdocker ~]# docker imagesREPOSITORY TAG IMAGE ID CREATED SIZEcentos latest 5d0da3dc9764 2 years ago 231MB​# 容器执行完就退出了​[rootdocker ~]# docker run -it …

spring框架简介

文章目录 1.Spring的简介2.Spring的起源与发展3.Spring的核心体系介绍4.Spring框架的特点总结5.xml定义bean的相关属性1、class属性、id属性、name属性2、作用域属性3、初始化方法和销毁方法 1.Spring的简介 Spring的英文翻译为春天,可以说是给Java程序员带来了春天…

python爬虫——入门

一、概念 万维网之所以叫做网,是因为通过点击超链接或者进入URL,我们可以访问任何网络资源,从一个网页跳转到另一个网页,所有的相关资源连接在一起,就形成了一个网。 而爬虫呢,听名字就让人想起来一个黏糊…

设计模式篇(DesignPattern - 创建型模式)

目录 模式一:单例模式 一、简介 二、种类 1. 饿汉式(静态常量) 1.1. 代码 1.2. 优缺点 2. 饿汉式(静态代码块) 2.1. 代码 2.2. 优缺点 3. 懒汉式(线程不安全) 3.1. 代码 3.2. 优缺点 4. 懒汉式(线程安全,…

Vulkan入门系列16 - 生成多级纹理贴图( Mipmaps)

一:概述 我们的程序现在可以加载和渲染 3D 模型了。在本章中,我们将再添加一项功能-- Mipmaps 生成。Mipmaps 广泛应用于游戏和渲染软件中,Vulkan 让我们可以完全控制 Mpmaps 的生成方式。 Mipmaps 是预先计算的、缩放的图像。每个新图像的宽度和高度都是前一个图像的一半。…

ssrf漏洞之——漏洞复现

漏洞介绍 SSRF漏洞:SSRF(Server-Side Request Forgery:服务器端请求伪造) 是一种由恶意访问者构造url,由服务端对此url发起请求的一个安全漏洞。 漏洞原理 SSRF 形成的原因大都是由于服务端提供了从其他服务器应用获取数据的功能,并且没有对目…

(QT-UI)十四、在时间轴上绘制一段段时间片

本系列预计实现 ①刻度上方文字显示, ②时间轴拖动效果, ③时间轴刻度缩放, ④时间轴和其他控件联动显示, ⑤鼠标放置到时间轴,显示具体时间。 ⑥通过定时器,实时更新时间轴 ⑦时间轴上绘制时间片 完…

用excel内容批量建立文件夹

建文件夹是电脑操作过程中比较常见的,但是用EXCEL内容批量建文件夹,这似乎不相关的两个操作,那么怎么实现这样的一个功能,我们需要用到专门的软件进行关联,推荐:可易文件夹批量生成器,这个软件有…

数据结构基础详解(C语言): 栈与队列的详解附完整代码

数据结构 栈 栈的核心重点: 栈是只能从表尾插入和删除的数据结构。 栈的顺序存储结构由两部分组成,top指针和数组。 链栈其实本质就是单链表头插法 文章目录 数据结构 栈1.栈的基本概念1.1 栈的常用操作 2.栈的存储结构2.1 栈的顺序存储结构2.1.1 栈的定…

AVL树的旋转

目录 一、AVL树的概念 二、AVL树节点的定义 三、AVL树的插入 四、AVL树的旋转 4.1右单旋 4.2左单旋 4.3左右双旋 4.4右左双旋 五、AVL树的验证 六、AVL树的性能 一、AVL树的概念 二叉搜索树虽可以缩短查找的效率,但如果数据有序或接近有序二叉搜索树将退化…

【AI绘画】Midjourney提示词详解:精细化技巧与高效实践指南

文章目录 💯Midjourney提示词基础结构1 图片链接1.1 上传流程 2 文字描述3 后置参数 💯Midjourney提示词的文字描述结构全面剖析1 主体主体细节描述2 环境背景2.1 环境2.2 光线2.3 色彩2.4 氛围 3 视角4 景别构图5 艺术风格6 图片制作方法7 作品质量万能…

鸿蒙(API 12 Beta3版)【使用Image完成图片接收器】图片开发指导依赖JS对象

图像接收类,用于获取组件surface id,接收最新的图片和读取下一张图片,以及释放ImageReceiver实例。 开发步骤 添加依赖 在进行应用开发之前,开发者需要打开native工程的src/main/cpp/CMakeLists.txt,在target_link_…