人工智能学习与实训笔记(七):神经网络之推荐系统处理

九、模型压缩与知识蒸馏

出于对响应速度,存储大小和能耗的考虑,往往需要对大模型进行压缩。

模型压缩方法主要可以分为以下四类:

  • 参数修剪和量化(Parameter pruning and quantization):用于消除对模型表现影响不大的冗余参数。早期工作表明,网络修剪和量化在降低网络复杂性和解决过拟合问题上是有效的。它可以为神经网络带来正则化效果从而提高泛化能力。参数修剪和量化可以进一步分为三类:量化和二值化,网络剪枝和结构化矩阵。量化可以看作是“量子级别的减肥”,神经网络模型的参数一般都用float32的数据表示,但如果我们将float32的数据计算精度变成int8的计算精度,则可以牺牲一点模型精度来换取更快的计算速度。而剪枝则类似“化学结构式的减肥”,将模型结构中对预测结果不重要的网络结构剪裁掉,使网络结构变得更加 ”瘦身“。比如,在每层网络,有些神经元节点的权重非常小,对模型加载信息的影响微乎其微。如果将这些权重较小的神经元删除,则既能保证模型精度不受大影响,又能减小模型大小。结构化矩阵则是用少于 m×n 个参数来描述一个 m×n 阶矩阵,以此来减少内存消耗。
  • 低秩分解(Low-rank factorization):卷积神经网络中的主要计算量在于卷积计算,而卷积计算本质上是矩阵分析问题,因此可以通过对多维矩阵进行分解的方式,用多个低秩矩阵来逼近该矩阵,比如将一个3D卷积转换为3个1D卷积,从而降低参数复杂度和运算复杂度。
  • 迁移/压缩卷积滤波器(Transferred/compact convolutional filters):通过构造特殊结构的卷积滤波器来降低存储空间、减小计算复杂度。
  • 知识蒸馏(Knowledge distillation):类似“老师教学生”,使用一个效果好的大模型指导一个小模型训练,因为大模型可以提供更多的软分类信息量,所以会训练出一个效果接近大模型的小模型。

9.1 知识蒸馏

知识蒸馏(knowledge distillation)是模型压缩的一种常用的方法,不同于模型压缩中的剪枝和量化,知识蒸馏是通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息,来训练这个小模型,以期达到更好的性能和精度。最早是由 Hinton 在 2015 年首次提出并应用在分类任务上面,这个大模型我们称之为 teacher(教师模型),小模型我们称之为 Student(学生模型)。来自 Teacher 模型输出的监督信息称之为 knowledge(知识),而 student 学习迁移来自 teacher 的监督信息的过程称之为 Distillation(蒸馏)。

9.1.1 知识蒸馏的原理

一般使用蒸馏的时候,往往会找一个参数量更小的 student 网络,那么相比于 teacher 来说,这个轻量级的网络不能很好的学习到数据集之前隐藏的潜在关系,如上图所示,相比于 one hot 的输出,teacher 网络是将输出的 logits 进行了 softmax,更加平滑的处理了标签,即将数字 1 输出成了 0.6(对 1 的预测)和 0.4(对 0 的预测)然后输入到 student 网络中,相比于 1 来说,这种 softmax 含有更多的信息。好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让 student 学习到 teacher 的泛化能力,理论上得到的结果会比单纯拟合训练数据的 student 要好。另外,对于分类任务,如果 soft targets 的熵比 hard targets 高,那显然 student 会学习到更多的信息。最终 student 模型学习的是 teacher 模型的泛化能力,而不是“过拟合训练数据”

  • 1. 如上图所示,左边的教师网络是一个复杂的大模型,以它带有温度参数T的softmax输出作为软目标作为学生网络学习的软目标。
  • 2. 学生网络在学习时,也通过带有温度参数T的softmax进行概率分布预测,与软目标计算soft loss。
  • 3. 同时,也通过正常的训练流程获得预测的样本类别与真实的样本类别计算hard loss。
  • 4 最终根据 γ∗softloss+(1−γ)∗hardloss作为损失函数来训练学生网络。

这个公式就是知识蒸馏的核心理论。其实就是要让学生模型学习到老师模型的泛化能力。

9.1.2 知识蒸馏的种类

1、 离线蒸馏

离线蒸馏方式即为传统的知识蒸馏,如上图(a)。用户需要在已知数据集上面提前训练好一个 teacher 模型,然后在对 student 模型进行训练的时候,利用所获取的 teacher 模型进行监督训练来达到蒸馏的目的,而且这个 teacher 的训练精度要比 student 模型精度要高,差值越大,蒸馏效果也就越明显。一般来讲,teacher 的模型参数在蒸馏训练的过程中保持不变,达到训练 student 模型的目的。蒸馏的损失函数 distillation loss 计算 teacher 和 student 之前输出预测值的差别,和 student 的 loss 加在一起作为整个训练 loss,来进行梯度更新,最终得到一个更高性能和精度的 student 模型。

2、 半监督蒸馏

半监督方式的蒸馏利用了 teacher 模型的预测信息作为标签,来对 student 网络进行监督学习,如上图(b)。那么不同于传统离线蒸馏的方式,在对 student 模型训练之前,先输入部分的未标记的数据,利用 teacher 网络输出标签作为监督信息再输入到 student 网络中,来完成蒸馏过程,这样就可以使用更少标注量的数据集,达到提升模型精度的目的。

3、 自监督蒸馏

自监督蒸馏相比于传统的离线蒸馏的方式是不需要提前训练一个 teacher 网络模型,而是 student 网络本身的训练完成一个蒸馏过程,如上图(c)。具体实现方式 有多种,例如先开始训练 student 模型,在整个训练过程的最后几个 epoch 的时候,利用前面训练的 student 作为监督模型,在剩下的 epoch 中,对模型进行蒸馏。这样做的好处是不需要提前训练好 teacher 模型,就可以变训练边蒸馏,节省整个蒸馏过程的训练时间。

9.1.3 知识蒸馏的作用

1、提升模型精度

用户如果对目前的网络模型 A 的精度不是很满意,那么可以先训练一个更高精度的 teacher 模型 B(通常参数量更多,时延更大),然后用这个训练好的 teacher 模型 B 对 student 模型 A 进行知识蒸馏,得到一个更高精度的模型。

2、降低模型时延,压缩网络参数

用户如果对目前的网络模型 A 的时延不满意,可以先找到一个时延更低,参数量更小的模型 B,通常来讲,这种模型精度也会比较低,然后通过训练一个更高精度的 teacher 模型 C 来对这个参数量小的模型 B 进行知识蒸馏,使得该模型 B 的精度接近最原始的模型 A,从而达到降低时延的目的。

3、图片标签之间的域迁移

用户使用狗和猫的数据集训练了一个 teacher 模型 A,使用香蕉和苹果训练了一个 teacher 模型 B,那么就可以用这两个模型同时蒸馏出一个可以识别狗,猫,香蕉以及苹果的模型,将两个不同与的数据集进行集成和迁移。

4、降低标注量

该功能可以通过半监督的蒸馏方式来实现,用户利用训练好的 teacher 网络模型来对未标注的数据集进行蒸馏,达到降低标注量的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/258242.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥省赛真题|简单:分数

题目链接:https://www.lanqiao.cn/problems/610/learning/?page1&first_category_id1&second_category_id3&tags2018&name%E5%88%86%E6%95%B0 题不难,但是可以帮助编程时好的习惯的养成,更加注意一些细节。 注意几个地方︰…

机器人专题:我国机器人产业园区发展现状、问题、经验及建议

今天分享的是机器人系列深度研究报告:《机器人专题:我国机器人产业园区发展现状、问题、经验及建议》。 (报告出品方:赛迪研究院) 报告共计:26页 机器人作为推动工业化发展和数字中国建设的重要工具&…

【RT-DETR有效改进】利用EMAttention加深网络深度提高模型特征提取能力(特征选择模块)

一、本文介绍 本文给大家带来的改进机制是EMAttention注意力机制,它的核心思想是,重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中…

【Qt】环境安装与初识

目录 一、Qt背景介绍 二、搭建Qt开发环境 三、新建工程 四、Qt中的命名规范 五、Qt Creator中的快捷键 六、QWidget基础项目文件详解 6.1 .pro文件解析 6.2 widget.h文件解析 6.3 widget.cpp文件解析 6.4 widget.ui文件解析 6.5 main.cpp文件解析 七、对象树 八、…

MySQL 基础知识(一)之数据库和 SQL 概述

目录 1 数据库相关概念 2 数据库的结构 ​3 SQL 概要 4 SQL 的基本书写规则 1 数据库相关概念 数据库是将大量的数据保存起来,通过计算机加工而成的可以进行高效访问的数据集合数据库管理系统(DBMS)是用来管理数据库的计算机系统&#xf…

指针的经典笔试题

经典的指针试题,让你彻底理解指针 前言 之前对于指针做了一个详解,现在来看一些关于指针的经典面试题。 再次说一下数组名 数组名通常表示的都是首元素的地址,但是有两个意外,1.sizeof(数组名)这里数组名…

stm32:pwm output模块,记录一下我是用smt32,输出pwm波的记录--(实现--重要)

我是实现了输出pwm波,频率固定,占空比可以不断调整的方法,将PA0接到示波器上,可以看到是一个标准的PWM波,如图下面示波器图。 1,首先是ioc的配置 我刚开始设置的分频的倍数是7199,使得分频的太…

Mac M2芯片配置PHP环境

Mac M2芯片配置PHP环境 1. XAMPP2. PHPBrew(PHP版本管理)安装php7.4.33版本 参考 1. XAMPP 官网地址 https://www.apachefriends.org/ 安装 安装完成 web server打开后,在打开localhost 成功! 2. PHPBrew(PHP版本管理) 官方文档 https://github.co…

【教程】C++语言基础学习笔记(五)——Vector向量

写在前面: 如果文章对你有帮助,记得点赞关注加收藏一波,利于以后需要的时候复习,多谢支持! 【C语言基础学习】系列文章 第一章 《项目与程序结构》 第二章 《数据类型》 第三章 《运算符》 第四章 《流程控制》 第五章…

【OpenAI Sora】开启未来:视频生成模型作为终极世界模拟器的突破之旅

这份技术报告主要关注两个方面:(1)我们的方法将各种类型的视觉数据转化为统一的表示形式,从而实现了大规模生成模型的训练;(2)对Sora的能力和局限性进行了定性评估。报告中不包含模型和实现细节…

jenkins 发布远程服务器并部署项目

安装参考另一个文章 配置maven 和 jdk 和 git 注意jdk的安装目录,是jenkins 安装所在服务器的jdk目录 注意maven的目录 是jenkins 安装所在服务器的maven目录 注意git的目录 是jenkins 安装所在服务器的 git 目录 安装 Publish Over SSH 插件 配置远程服务器 创…

【AIGC】Stable Diffusion的采样器入门

在 Stable Diffusion 中,采样器(Sampler)是指用于生成图像的一种技术或方法,它决定了模型如何从潜在空间中抽样并生成图像。采样器在生成图像的过程中起着重要作用,影响着生成图像的多样性、质量和创造性。以下是对 St…

批量梯度下降、随机梯度下降、小批量梯度下降

一、批量梯度下降(Batch Gradient Descent,BGD) 在批量梯度下降中,每次迭代都使用整个训练集的数据进行梯度计算和参数更新。也就是说,每次迭代都对所有的样本求取梯度,然后更新参数。由于要处理整个训练集&#xff0c…

用HTML5实现动画

用HTML5实现动画 要在HTML5中实现动画&#xff0c;可以使用以下几种方法&#xff1a;CSS动画、使用<canvas>元素和JavaScript来实现动画、使用JavaScript动画库。重点介绍前两种。 一、CSS动画 CSS3 动画&#xff1a;使用CSS3的动画属性和关键帧&#xff08;keyframes&…

第三节 zookeeper基础应用与实战2

目录 1. Watch事件监听 1.1 一次性监听方式&#xff1a;Watcher 1.2 Curator事件监听机制 2. 事务&异步操作演示 2.1 事务演示 2.2 异步操作 3. Zookeeper权限控制 3.1 zk权限控制介绍 3.2 Scheme 权限模式 3.3 ID 授权对象 3.4 Permission权限类型 3.5 在控制台…

JDBC教程+数据库连接池

JDBC 1.JDBC概述 ​ JDBC&#xff0c;全称Java数据库连接&#xff08;Java DataBase Connectivity&#xff09;&#xff0c;它是使用Java语言操作关系型数据库的一套API。 ​ JDBC本质是官方&#xff08;原SUN公司&#xff0c;现ORACLE&#xff09;定义的一套操作所有关系型数…

讲解用Python处理Excel表格

我们今天来一起探索一下用Python怎么操作Excel文件。与word文件的操作库python-docx类似&#xff0c;Python也有专门的库为Excel文件的操作提供支持&#xff0c;这些库包括xlrd、xlwt、xlutils、openpyxl、xlsxwriter几种&#xff0c;其中我最喜欢用的是openpyxl&#xff0c;这…

GitLab配置SSHKey

段落一&#xff1a;什么是SSH密钥 SSH&#xff08;Secure Shell&#xff09;是一种网络协议&#xff0c;用于安全地远程登录和执行命令。SSH密钥是一种用于身份验证的加密文件&#xff0c;它允许您在与远程服务器通信时&#xff0c;无需输入密码即可进行认证。在GitLab中配置S…

Vue2学习第一天

Vue2 学习第一天 1. 什么是 vue? Vue 是一套用于构建用户界面的渐进式框架。 2. vue 历史 vue 是在 2013 年创建的&#xff0c;vue3 是 2020 出现的&#xff0c;现在主要是用 vue2&#xff0c;创新公司用的是 vue3 vue 的作者是尤雨溪&#xff0c;vue 的搜索热度比 react…

【算法随想录03】相交链表

题目&#xff1a;160. 相交链表 难度&#xff1a;EASY 思路 主要难点在于如何进行节点之间的对应。两条链表长度不定长&#xff0c;如何找到需要对比的节点至关重要。 我们从后往前看&#xff0c;我们需要对比的节点有什么特点。一个最大的特点就是后面的节点数相同。这就…