什么是知识蒸馏?什么是Knowledge Distillation?知识蒸馏实例

知识蒸馏

  • 1. 知识蒸馏的核心概念
    • 什么是知识蒸馏?
  • 2. 知识蒸馏的关键组成部分
    • (1)温度调节(Temperature Scaling)
    • (2)蒸馏损失(Distillation Loss)
    • (3)蒸馏流程
  • 3. 知识蒸馏的主要方法
    • (1)经典蒸馏(Soft Target Distillation)
    • (2)中间层特征蒸馏(Feature-based Distillation)
    • (3)对抗式蒸馏(Adversarial Distillation)
    • (4)自蒸馏(Self-Distillation)
  • 4. 知识蒸馏的优点
  • 5. 知识蒸馏的实现步骤
  • 6. 知识蒸馏的应用场景
  • 从理论到实践全面掌握知识蒸馏

知识蒸馏(Knowledge Distillation)是机器学习中的一种技术,主要用于将一个复杂的、计算成本高的大模型(通常称为教师模型,Teacher Model)中的知识提炼并传递给一个较小的、计算高效的模型(通常称为 学生模型,Student Model)。通过这种方式,学生模型在保持接近教师模型性能的同时,具备更高的效率和更低的计算需求。

以下是系统性学习知识蒸馏的步骤:

1. 知识蒸馏的核心概念

什么是知识蒸馏?

知识蒸馏是一种训练方法,重点是通过利用教师模型的输出(例如概率分布或中间特征)作为“软目标”(Soft Target),指导学生模型的训练,而不是直接依赖训练数据的真实标签(硬目标,Hard Target)。

  • 硬目标(Hard Target): 常规分类问题中,每个样本的标签是确定的,例如“猫”的类别是1,其余类别为0。
  • 软目标(Soft Target): 教师模型输出的概率分布,通常包含更多的信息。例如,教师模型预测“猫”的概率为0.8,但也可能预测“狗”是0.1、“兔子”是0.05,这些反映了教师对类别间关系的理解。

2. 知识蒸馏的关键组成部分

(1)温度调节(Temperature Scaling)

在知识蒸馏中,教师模型的输出概率通常会通过温度参数 T T T 进行调节:

$$

q_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}

$$

  • z i z_i zi 是模型预测的原始得分(logits)。
  • T T T 是温度参数,较高的 T T T 会使输出分布更加平滑,包含更多类别间关系的信息。

学生模型的目标是模仿教师模型的这些经过温度调节的概率分布。

(2)蒸馏损失(Distillation Loss)

知识蒸馏的训练目标是最小化以下两个损失函数的加权和:

  1. 蒸馏损失(Distillation Loss): 让学生模型模仿教师模型的概率分布,常用交叉熵来衡量两者的差异。
    L distill = − ∑ i q i teacher log ⁡ q i student \mathcal{L}{\text{distill}} = -\sum{i} q_i^{\text{teacher}} \log q_i^{\text{student}} Ldistill=iqiteacherlogqistudent
  2. 监督损失(Supervised Loss): 学生模型使用真实标签进行传统监督训练。
    L supervised = − ∑ i y i true log ⁡ q i student \mathcal{L}{\text{supervised}} = -\sum{i} y_i^{\text{true}} \log q_i^{\text{student}} Lsupervised=iyitruelogqistudent

总损失函数:
L = α ⋅ L distill + ( 1 − α ) ⋅ L supervised \mathcal{L} = \alpha \cdot \mathcal{L}{\text{distill}} + (1 - \alpha) \cdot \mathcal{L}{\text{supervised}} L=αLdistill+(1α)Lsupervised

其中 α \alpha α 是平衡两个损失的超参数。

(3)蒸馏流程

  1. 先训练一个性能较好的教师模型。
  2. 利用教师模型生成概率分布(软目标)。
  3. 用上述蒸馏损失训练学生模型,使其学习教师的知识。

3. 知识蒸馏的主要方法

(1)经典蒸馏(Soft Target Distillation)

学生模型通过模仿教师模型输出的软目标概率分布进行训练,这是知识蒸馏最基础的形式。

(2)中间层特征蒸馏(Feature-based Distillation)

除了模仿最终输出概率,学生模型还可以学习教师模型中间层的特征表示,从而更好地捕捉深层信息。

(3)对抗式蒸馏(Adversarial Distillation)

将蒸馏过程视为生成对抗网络(GAN)的形式,学生模型作为生成器,教师模型的特征表示作为判别器的目标,使学生生成的输出更加接近教师。

(4)自蒸馏(Self-Distillation)

一种特殊形式,学生模型和教师模型使用相同的结构。学生模型从前几轮训练的“教师模型版本”中学习。这种方法不需要单独训练教师模型。

4. 知识蒸馏的优点

  1. 降低模型复杂度: 减少计算资源需求,使模型更适合部署在边缘设备或实时应用中。
  2. 保留教师模型知识: 学生模型不仅学习到了准确性,还能捕获类别间的潜在关系。
  3. 提升小模型性能: 即使学生模型参数少,通过知识蒸馏,性能通常优于直接训练的小模型。

5. 知识蒸馏的实现步骤

以下是用Python(PyTorch)实现知识蒸馏的简化代码示例:

import torch
import torch.nn as nn
import torch.optim as optim# 假设已经定义好教师模型 (teacher_model) 和学生模型 (student_model)# 超参数
temperature = 4.0
alpha = 0.7  # 蒸馏损失权重
learning_rate = 0.001# 定义蒸馏损失
def distillation_loss(student_logits, teacher_logits, temperature):soft_teacher = nn.functional.softmax(teacher_logits / temperature, dim=1)soft_student = nn.functional.log_softmax(student_logits / temperature, dim=1)return nn.functional.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)# 优化器和损失函数
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)# 训练循环
for epoch in range(num_epochs):for data, labels in train_loader:# 教师模型的输出teacher_logits = teacher_model(data).detach()# 学生模型的输出student_logits = student_model(data)# 计算蒸馏损失loss_distill = distillation_loss(student_logits, teacher_logits, temperature)loss_supervised = criterion(student_logits, labels)# 总损失loss = alpha * loss_distill + (1 - alpha) * loss_supervised# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()

6. 知识蒸馏的应用场景

  1. 模型压缩: 将大型预训练模型(如GPT-3、BERT)压缩成轻量化版本,便于在移动设备或嵌入式系统中使用。
  2. 迁移学习: 将复杂模型的知识迁移到特定领域的小模型中。
  3. 多模型集成: 用多个教师模型的输出指导单一学生模型的训练,合并多个模型的知识。
  4. 实时推理: 提高模型推理速度,适应低延迟场景。

通过以上步骤和知识,你可以从理论到实践全面掌握知识蒸馏!

从理论到实践全面掌握知识蒸馏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/476925.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java学习-集合

为什么有集合? 自动扩容 数组:长度固定,可以存基本数据类型和引用数据类型 集合:长度可变,可以存引用数据类型,基本数据类型的话需要包装类 ArrayList public class studentTest {public static void m…

MATLAB GUI设计(基础)

一、目的和要求 1、熟悉和掌握MATLAB GUI的基本控件的使用及属性设置。 2、熟悉和掌握通过GUIDE创建MATLAB GUI的方法。 3、熟悉和掌握MATLAB GUI的菜单、对话框及文件管理框的设计。 4、熟悉和掌握MATLAB GUI的M文件编写。 5、了解通过程序创建MATLAB GUI的方法。 二、内…

【工具变量】中国省级及地级市保障性住房数据集(2010-2023年)

一、测算方式:参考顶刊《世界经济》蔡庆丰(2024)老师的研究,具体而言,本文将土地用途为经济适用住房用地、廉租住房用地、公共租赁住房用地、共有产权住房用 地等类型的土地定义为具有保障性住房用途的土地。根据具有保…

第T8周:Tensorflow实现猫狗识别(1)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: PyCharm 框 架: (二)具体步骤 from absl.l…

Day 18

修建二叉搜索树 link:669. 修剪二叉搜索树 - 力扣(LeetCode) 思路分析 注意修剪的时候要考虑到全部的节点,即搜到到限定区间小于左值或者大于右值时还需要检查当前不符合区间大小节点的右子树/左子树,不能直接返回n…

核间通信-Linux下RPMsg使用与源码框架分析

目录 1 文档目的 2 相关概念 2.1 术语 2.2 RPMsg相关概念 3 RPMsg核间通信软硬件模块框架 3.1 硬件原理 3.2 软件框架 4 使用RPMsg进行核间通信 4.1 RPMsg通信建立 4.1.1 使用名称服务建立通信 4.1.2 不用名称服务 4.2 RPMsg应用过程 4.3 应用层示例 5 RPMsg内核…

常用Adb 命令

# 连接设备 adb connect 192.168.10.125# 断开连接 adb disconnect 192.168.10.125# 查看已连接的设备 adb devices# 安装webview adb install -r "D:\webview\com.google.android.webview_103.0.5060.129-506012903_minAPI23(arm64-v8a,armeabi-v7a)(nodpi)_apkmirror.co…

高质量代理池go_Proxy_Pool

高质量代理池go_Proxy_Pool 声明! 学习视频来自B站up主 ​泷羽sec​​ 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章 笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以…

有关博客博客系统的测试报告 --- 初次进行项目测试篇

文章目录 前言一、博客系统的项目背景二、博客系统的项目简介1.后端功能1.1 用户管理1.2 博客管理1.3 权限管理 2.前端功能2.1 用户界面 测试计划测试工具、环境设计的测试动作功能测试访问博客登录页面博客首页测试博客详情页博客编辑页 自动化测试自动化测试用例自动化测试脚…

物业管理系统的设计和实现

一、项目背景 物业管理系统在现代城市化进程中起着至关重要的作用。 随着居民生活水平的提高和信息技术的迅猛发展,传统的物业管理模式已不能满足业主和管理者的需求。 为了提高管理效率、降低运营成本、提升服务质量,设计并实现一个集成化、智能化的物业…

JDBC编程---Java

目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源,并设置数据库所在的位置,三条固定写法 2.建立和数据库服务器之间的连接,连接好了后&#xff…

快速图像识别:落叶植物叶片分类

1.背景意义 研究背景与意义 随着全球生态环境的变化,植物的多样性及其在生态系统中的重要性日益受到关注。植物叶片的分类不仅是植物学研究的基础,也是生态监测、农业管理和生物多样性保护的重要环节。传统的植物分类方法依赖于人工观察和专家知识&…

数字化那点事:一文读懂物联网

一、物联网是什么? 物联网(Internet of Things,简称IoT)是指通过网络将各种物理设备连接起来,使它们可以互相通信并进行数据交换的技术系统。通过在物理对象中嵌入传感器、处理器、通信模块等硬件,IoT将“…

IntelliJ+SpringBoot项目实战(十)--常量类、自定义错误页、全局异常处理

一、常量类 在项目开发中,经常需要约定一些常量,比如接口返回响应请求指定状态码、异常类型、默认页数等,为了增加代码的可阅读性以及开发团队中规范一些常量的使用,可开发一些常量类。下面有3个常量类示例,代码位于op…

ubuntu20.04的arduino+MU编辑器安装教程

arduino 按照这个博客,是2.3版本的: Ubuntu20.04/22.04 安装 Arduino IDE 2.x_ubuntu ide-CSDN博客https://blog.csdn.net/michaelchain/article/details/128744935以下这个博客是1.8版本的 在ubuntu系统安装Arduino IDE的方法_ubuntu arduino ide-CS…

Docker核心概念总结

本文只是对 Docker 的概念做了较为详细的介绍,并不涉及一些像 Docker 环境的安装以及 Docker 的一些常见操作和命令。 容器介绍 Docker 是世界领先的软件容器平台,所以想要搞懂 Docker 的概念我们必须先从容器开始说起。 什么是容器? 先来看看容器较为…

Redis ⽀持哪⼏种数据类型?适⽤场景,底层结构

目录 Redis 数据类型 一、String(字符串) 二、Hash(哈希) 三、List(列表) 四、Set(集合) 五、ZSet(sorted set:有序集合) 六、BitMap 七、HyperLogLog 八、GEO …

uniapp接入BMapGL百度地图

下面代码兼容安卓APP和H5 百度地图官网:控制台 | 百度地图开放平台 应用类别选择《浏览器端》 /utils/map.js 需要设置你自己的key export function myBMapGL1() {return new Promise(function(resolve, reject) {if (typeof window.initMyBMapGL1 function) {r…

Docker+Nginx | Docker(Nginx) + Docker(fastapi)反向代理

在DockerHub搜 nginx,第一个就是官方镜像库,这里使用1.27.2版本演示 1.下载镜像 docker pull nginx:1.27.2 2.测试运行 docker run --name nginx -p 9090:80 -d nginx:1.27.2 这里绑定了宿主机的9090端口,只要访问宿主机的9090端口&#…

AmazonS3集成minio实现https访问

最近系统全面升级到https,之前AmazonS3大文件分片上传直接使用http://ip:9000访问minio的方式已然行不通,https服务器访问http资源会报Mixed Content混合内容错误。 一般有两种解决方案,一是升级minio服务,配置ssl证书&#xff0c…