知识蒸馏Matching logits与RocketQAv2

知识蒸馏Matching logits

公式推导

刚开始的\frac{\partial L}{\partial z_i}=q_i-p_i怎么来,可以转看下面证明梯度等于输出值-标签y

C是一个交叉熵,我们要求解的是这个交叉熵对z_i的这个梯度。z_i就是你可以理解成第i个类别的得分。z_i就是student model,被蒸馏的模型,它所输出的logits。

p_i是什么?是target probability对吧。q_i是什么?q_i认为就是这个distilled model的输出的那个probability。所以就是说这两个概率相减,再乘以这个T分之一T是什么?T是一个温度。

我们现在假定是说我们是用teacher model输出的这个label,然后去训练student model,或者说去训练distilled model。我们对这个第i个类别的梯度,就等于\frac{1}{T}{(q_i-p_i)},然后呢,q_ip_i可以做一个化简。

q_ip_i进行展开,概率都是用softmax算出来的,就可以得到这个式子。

通过e^x\approx 1+x来进行化简,这个式子在x比较小的时候是成立的。

在这里,当T足够大的时(相比z的logits,即z),\frac{z_i}{T}就足够的小,接近于0,此时e^{\frac{z_i}{T}}\approx 1+\frac{z_i}{T}

\sum_j e^{\frac{z_i}{T}}\approx \sum_j{1+\frac{z_i}{T}}=N+\sum_j{\frac{z_i}{T}} 

z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0,所以这两个分母直接就变成了N。

\frac{1}{T}({\frac{1+z_i/T}{N}}-{\frac{1+v_i/T}{N}})=\frac{1}{TN}{\frac{z_i-v_i}{T}}

则所求梯度

想说明的事情

它其实就想说明这样一个事情。我们试图用一个teacher model,或者说我们想用VI对应的那个概率叫p_iz_i对应的概率叫q_i。如果我们想用这个p_i作为label去用交叉商去训练q_i去用这个soft label的交叉商去训练q_i,那么其实我们可能不需要套用交叉商这个东西了,我们也不需要什么softmax的label的交叉商,然后去做这个事了。因为这个东西在我们的这样一通推导下就会发现,其实就等于均方误差,右边这一项其实就是什么均方误差的求导,它就是均方误差求导之后的结果,你可以这样认为。

我们就会发现说,原来对于交叉商对于这个知识蒸馏的这个交叉商,然后我们对他求导求出来的梯度其实是近似等同于我们直接用MSE去训练,然后得到的梯度的。那么既然这样,我们为什么不直接用MSE?

它的推导就告诉我们说我们对于两个模型,两个多分类模型来说,我们要用a模型去交B模型做蒸馏。我们没有必要让这两个模型生成分别生成什么label,然后再生成预测的概率,然后再加上去优化了。

我们直接让这两个多分类模型的这个logic,然后直接做MSE就可以了,就可以做到一种就是一种这种MSE就是一种什么蒸馏的特殊形式。就是蒸馏的一个最早期的雏形,其实在这个时候都还没有考虑用这个什么KL散度来做,就只是提出最简单的一个思想是什么,就是用MSE来做就够了。

我们一直即便到今天,我们做很多知识正溜的实验,我们依然会发现MIC可能有的时候都会比K要好。虽然大家都说自己用什么KL散度用什么JS散度,但是就是否现在就最优,还真不一定有的时候就是MSE效果好。

注:MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

需要注意的事

公式的推导基于两个假设:

1.T得足够大的(相比z的logits,即z_i)

2.模型输出的logic是零均值的(即均值为0),因为模型输出的logic是零均值的,这个z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0

 证明梯度等于输出值-标签y

softmax函数

归一化,使其输出的概率和为1

S_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

S_i代表的是第i个神经元的输出。

神经元的输出,一个神经元如下图:

z_i=\sum_jw_{ij}x_{ij}+b

其中w_{ij}是第i个神经元的第j个权重,b是偏移值。z_i表示该网络的第i个输出。

给这个输出加上一个softmax函数,得a_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

a_i代表softmax的第i个输出值

交叉熵损失函数 loss function

L=-\sum_i{y_i}{lna_i}

其中y_i表示真实的分类结果。

证明梯度等于输出值-标签y

loss对于神经元输出z_i的梯度为\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}

由于softmax公式的特性,它的分母包含了所有神经元的输出,对于不等于i的其他输出里面,也包含着z_i,所有的a都要纳入到计算范围中,并且后面的计算可以看到需要分为i=ji \ne j两种情况求导。

由于\frac{\partial (-\sum_{k\ne j}y_{k}ln a_k)}{\partial a_j}=0

\frac{\partial C}{\partial a_j}=\frac{\partial (-\sum_jy_jln a_j)}{\partial a_j}=-\sum_jy_j\frac{1}{a_J}

如果i=j

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_{k\ne i}e^{z_k}+e^{z_i}})}{\partial z_i}=\frac{\sum_ke^{z_k}e^{z_i}-(e^{z_i})^2}{\sum_k(e^{z_k})^2}
=(\frac{e^{z_i}}{\sum_ke^{z_k}})(1-\frac{e^{z_i}}{\sum_ke^{z_k}})=a_i(1-a_i)

这里\sum_ke^{z_k}=\sum_{k\ne i}e^{z_k}+e^{z_i}

如果i \ne j

这里\sum_ke^{z_k}=\sum_{k\ne j}e^{z_k}+e^{z_j}

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=-e^{z_j}(\frac{1}{\sum_ke^{z_k}})e^{z_i}=-a_ia_j

综上

\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}=(-\sum_jy_j\frac{1}{a_j})\frac{\partial a_j}{\partial z_i}=-\frac{y_i}{a_i}a_i(1-a_i)+\sum_{j\neq i}\frac{y_i}{a_j}a_ia_j
=-y_i+y_ia_i+\sum_{j\neq i}{y_ia_i}=-y_i+a_i\sum_{j}y_j

最后,针对分类问题,我们给定的结果y_i最终只会有一个类别是1,其他非标签类别都是0,因此,对于分类问题,这个梯度等于

\frac{\partial L}{\partial z_i}=a_i-y_i

知识蒸馏RocketQAv2

https://arxiv.org/pdf/2110.07367.pdf

这个模型有两部分组成一个retriever和一个ranker。这个做的事情就是说用label去监督re-ranker,然后用ranker去监督retriever。用KL散度去约束它约束,用这个K散路去让这个re-ranker的分布和retriever的分布对齐。

要注意就是说。这里就是他们就没有用MSE,就是说如果用MSE怎么做,就是说对应的这个直接相减,就对应位置直接相减,然后分MSE就行。这里用的是KL散度。

KL散度的定义,你可以认为是这样的,让这两个概率分别相除,除完了之后都要再取对数,然后再乘以这个概率。

DE,这个teacher model的概率乘以teacher model的概率乘以log,teacher model的概率除以student model的概率。然后把这么多概率给它都累加起来。

在这里,假定这里的是retriever给出来的一个概率分布假如说是十个候选,ranker也给了这样一个概率分布,那么就是十个概率分布对应的一项一项的去算这个KL度,即概率除概率,然后再取对数,然后再乘上ranker这个概率。

然后再把这十项给它累加起来,然后就是一个KL散度,这样的话,这个K散度其实是现在就是接受最多的一种损失函数。

因为KL散度就是天生的,可以捕获这个分布和分布之间的距离。像MSE缺点是什么?MSE的缺点是它没有整体的那种距离衡量的能力。MSE其实是对于细节的这种距离的衡量很强。如果MSE来的话,每一个每一项,这十项每一项的重要性对于MIC来说都是一样的。但是这个KL散度可能就会更在乎一个整体的一个分布上的一个区别了,就而不是说就在乎一些细节上的一些差别,因为有可能就是说。你某一些细节差距虽然大一些,但是你整体差距不大,所以KL散度也可以比较小。

实际上一切可以衡量两个分布之间距离的指标都可以用来做知识蒸馏,所以其实wasserstein距离也可以用来作为蒸馏的损失函数:

https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Wasserstein_Contrastive_Representation_Distillation_CVPR_2021_paper.pdf

为什么知识蒸馏会有效?

1. teacher model可以生成soft label,相比于原始数据的hard label,包含了更多信息量。

所以很多时候你与其说直接用一个数据集去训练一个模型,你还不如用这个数据集先训练一个大a模型比a模型要大的模型。再让大a模型去教会a模型去做,有可能效果就更好。就是因为大a模型这个teacher model可以生成soft label相比于原始数据的hard label,可以包含更多的信息量,从而就天然的有一种去燥的一种功能。

2. teacher model可以为大量的无标签数据打上label,然后为student提供一个大规模的训练集。然后从而可以给student提供一个更大尺度的训练集,然后防止student的一个过拟合,然后提高student model的一个泛化能力。也就是说,teacher model可以把自己的泛化能力交给student model

在这个知识蒸馏的过程当中,这也是为什么说很多大公司里边现在线上的模型都是蒸馏出来的小模型就是因为我们与其说直接训练小模型。还不如说就用这个蒸馏去蒸馏一个小模型反而泛化能力会更强一些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/276989.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java两周半速成之路(第十六天)

一、网络编程 1.概述: 就是用来实现网络互连的不同计算机上运行的程序间可以进行数据交换 2.网络模型 3.网络参考模型图 4.网络通信三要素 4.1IP地址 InetAddress类的使用: 注意:通过API查看,此类没有构造方法,如…

Docker 哲学 - 容器操作

容器: 创建 停止 删除 强制删除(正在运行) run stop rm rm -f 列出本地容器: docker ps / docker container ls 镜像: search pull run : …

第十四届蓝桥杯省赛真题 Java A 组【原卷】

文章目录 发现宝藏【考生须知】试题 A \mathrm{A} A : 特殊日期试题 B: 与或异或试题 C : \mathrm{C}: C: 平均试题 D: 棋盘试题 E : \mathrm{E}: E: 互质数的个数试题 F: 阶乘的和试题 G: 小蓝的旅行计划试题 H: 太阳试题 I: 高塔试题 J \mathrm{J} J : 反异或 01 串 发现…

代码随想录 Day45 动态规划(背包问题)

对背包问题有了更深刻的理解,物品的遍历是背包可能要装的物品的遍历,跟多少数量有关系,而背包的遍历则跟物品的重量,背包的容量,以及价值有关。

C语言数据结构易错知识点(3)(堆)

1.堆的本质:完全二叉树 堆在物理结构上是顺序结构,实现方式类似于顺序表(数组);但在逻辑结构上是树形结构,准确来说堆是一棵完全二叉树。因为堆的实现在代码上和顺序表有很高的相似度,所以在写…

关于用max,min函数超时的情况—算法小Tips

今天在做这道题的时候,有了一点对一些题max函数min函数就会超时的思考,不是每道题都这样,但也可以是个做题小tips; 题目连接:登录—专业IT笔试面试备考平台_牛客网 题目很简单,用个前缀和暴力一下就行&…

【Preprocessing数据预处理】之Scaler

在机器学习中,特征缩放是训练模型前数据预处理阶段的一个关键步骤。不同的缩放器被用来规范化或标准化特征。这里简要概述了您提到的几种缩放器: StandardScaler StandardScaler 通过去除均值并缩放至单位方差来标准化特征。这种缩放器假设特征分布是正…

C语言从入门到熟悉------第四阶段

指针 地址和指针的概念 要明白什么是指针,必须先要弄清楚数据在内存中是如何存储的,又是如何被读取的。如果在程序中定义了一个变量,在对程序进行编译时,系统就会为这个变量分配内存单元。编译系统根据程序中定义的变量类型分配…

深度学习 精选笔记(11)深度学习计算相关:GPU、参数、读写、块

学习参考: 动手学深度学习2.0Deep-Learning-with-TensorFlow-bookpytorchlightning ①如有冒犯、请联系侵删。 ②已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。 ③非常推荐上面(学习参考&#x…

C++ 入门篇

目录 1、了解C 2、C关键字 2、命名空间 2.1 命名空间的定义 2.2 命名空间的使用 3. C输入与输出 4.缺省参数 4.1 缺省参数的概念 4.2 缺省参数的分类 5. 函数重载 5.1 函数重载的概念 5.2 C中支持函数重载的原理--名字修饰 6. 引用 6.1 引用概念 6.2 引用…

HTML案例-2.标签综合练习

目录 效果 知识点 1.图像标签 2.链接标签 3.锚点定位 4.base标签 源码 页面1 页面2 效果 知识点 1.图像标签 <img src="图像URL" /> 单标签 属性 属性值 描述 src URL 图像的路径 alt 文本

Linux使用Docker部署Registry结合内网穿透实现公网远程拉取推送镜像

文章目录 1. 部署Docker Registry2. 本地测试推送镜像3. Linux 安装cpolar4. 配置Docker Registry公网访问地址5. 公网远程推送Docker Registry6. 固定Docker Registry公网地址 Docker Registry 本地镜像仓库,简单几步结合cpolar内网穿透工具实现远程pull or push (拉取和推送)…

0G联合创始人MICHAEL HEINRICH确认出席Hack.Summit() 2024区块链开发者大会

随着区块链技术的不断发展和应用&#xff0c;全球开发者瞩目的Hack.Summit() 2024区块链开发者大会即将于2024年4月9日至10日在香港数码港盛大举行。此次大会由Hack VC主办&#xff0c;并得到AltLayer和Berachain的协办&#xff0c;同时汇聚了Solana、The Graph、Blockchain Ac…

路由和流量控制

项目拓扑与项目需求 项目需求:某政务网络拥有两个园区&#xff0c;园区A和园区B之间通过物理专线相连。IP地址如图所示。现在需要实现以下需求&#xff1a; 要求A园区无法访问B园区的vlan 30 网络&#xff0c;要求使用路由过滤的方式实现。 配置步骤 设备IP地址的规划 设备名…

从0开始回顾MySQL---事务四大特性

事务概述 事务是一个最小的工作单元。在数据库当中&#xff0c;事务表示一件完整的事儿。一个业务的完成可能需要多条DML语句共同配合才能完成&#xff0c;例如转账业务&#xff0c;需要执行两条DML语句&#xff0c;先更新张三账户的余额&#xff0c;再更新李四账户的余额&…

螺旋阵思维与代码

1.思维 56789419202110318252211217242312116151413观察上面的螺旋阵,你就会发现数字从小到大,按照贝壳的螺旋形依次排列. 走到头就要换一个方向. 看到螺旋数组可以让人想象到贪吃蛇,拿出一个字符串设置为方向,碰到头方向改变,这样循环模拟,直到格子里的数>行和列数(n) .…

c++ 常用函数 集锦 整理中

c 常用函数集锦 目录 c 常用函数集锦 1、string和wstring之间转换 2、经纬度转 xyz 值 互转 3 、获取 根目录下的文件地址 1、string和wstring之间转换 std::string convertWStringToString(std::wstring wstr) {std::string str;if (!wstr.empty()){std::wstring_convert<…

51-31 VastGaussian,3D高斯大型场景重建

2024 年 2 月&#xff0c;清华大学、华为和中科院联合发布的 VastGaussian 模型&#xff0c;实现了基于 3D Gaussian Splatting 进行大型场景高保真重建和实时渲染。 Abstract 现有基于NeRF大型场景重建方法&#xff0c;往往在视觉质量和渲染速度方面存在局限性。虽然最近 3D…

pycharm @NotNull parameter ‘module‘ of ...

下载了最新pycharm &#xff0c;无法启动运行 pycharm或者idea中Run/Debug Python项目报错 Argument for NotNull parameter ‘module‘ of … 解决方案 删除项目根目录的 idea 文件夹 随后重启&#xff0c;重新配置即可

图论(蓝桥杯 C++ 题目 代码 注解)

目录 迪杰斯特拉模板&#xff08;用来求一个点出发到其它点的最短距离&#xff09;&#xff1a; 克鲁斯卡尔模板&#xff08;用来求最小生成树&#xff09;&#xff1a; 题目一&#xff08;蓝桥王国&#xff09;&#xff1a; 题目二&#xff08;随机数据下的最短路径&#…