TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,
融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

1.简介

TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。

主要特点:

  • 模型无关:适用于多种模型结构(主要面向Transfomer结构)
  • 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
  • 非侵入式:无需对教师与学生模型本身结构进行修改
  • 支持典型的NLP任务:文本分类、阅读理解、序列标注等

TextBrewer目前支持的知识蒸馏技术有:

  • 软标签与硬标签混合训练
  • 动态损失权重调整与蒸馏温度调整
  • 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, …
  • 任意构建中间层特征匹配方案
  • 多教师知识蒸馏

TextBrewer的主要功能与模块分为3块:

  1. Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
  2. Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
  3. Utilities:模型参数分析显示等辅助工具

用户需要准备:

  1. 已训练好的教师模型, 待蒸馏的学生模型
  2. 训练数据与必要的实验配置, 即可开始蒸馏

在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。

2.TextBrewer结构

2.1 安装要求

  • Python >= 3.6

  • PyTorch >= 1.1.0

  • TensorboardX or Tensorboard

  • NumPy

  • tqdm

  • Transformers >= 2.0 (可选, Transformer相关示例需要用到)

  • Apex == 0.1.0 (可选,用于混合精度训练)

  • 从PyPI自动下载安装包安装:

pip install textbrewer
  • 从源码文件夹安装:
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer

2.2工作流程

  • Stage 1 : 蒸馏之前的准备工作:

    1. 训练教师模型
    2. 定义与初始化学生模型(随机初始化,或载入预训练权重)
    3. 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
  • Stage 2 : 使用TextBrewer蒸馏:

    1. 构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig),初始化distiller
    2. 定义adaptorcallback ,分别用于适配模型输入输出和训练过程中的回调
    3. 调用distillertrain方法开始蒸馏

2.3 以蒸馏BERT-base到3层BERT为例展示TextBrewer用法

在开始蒸馏之前准备:

  • 训练好的教师模型teacher_model (BERT-base),待训练学生模型student_model (3-layer BERT)
  • 数据集dataloader,优化器optimizer,学习率调节器类或者构造函数scheduler_class 和构造用的参数字典 scheduler_args

使用TextBrewer蒸馏:

import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig#展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)#定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):# model输出的第二、三个元素分别是logits和hidden statesreturn {'logits': model_outputs[1], 'hidden': model_outputs[2]}#蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(intermediate_matches=[    {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()#初始化distiller
distiller = GeneralDistiller(train_config=train_config, distill_config = distill_config,model_T = teacher_model, model_S = student_model, adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)#开始蒸馏
with distiller:distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)

2.4蒸馏任务示例

  • Transformers 4示例

    • examples/notebook_examples/sst2.ipynb (英文): SST-2文本分类任务上的BERT模型训练与蒸馏。
    • examples/notebook_examples/msra_ner.ipynb (中文): MSRA NER中文命名实体识别任务上的BERT模型训练与蒸馏。
    • examples/notebook_examples/sqaudv1.1.ipynb (英文): SQuAD 1.1英文阅读理解任务上的BERT模型训练与蒸馏。
  • examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。

  • examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。

  • examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。

  • examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。

  • examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。

2.4.1蒸馏效果

我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。

  • 模型

  • 对于英文任务,教师模型为BERT-base-cased

  • 对于中文任务,教师模型为HFL发布的RoBERTa-wwm-extElectra-base

我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。

  • 英文模型
Model#LayersHidden sizeFeed-forward size#ParamsRelative size
BERT-base-cased (教师)127683072108M100%
T6 (学生)6768307265M60%
T3 (学生)3768307244M41%
T3-small (学生)3384153617M16%
T4-Tiny (学生)4312120014M13%
T12-nano (学生)12256102417M16%
BiGRU (学生)-768-31M29%
  • 中文模型
Model#LayersHidden sizeFeed-forward size#ParamsRelative size
RoBERTa-wwm-ext (教师)127683072102M100%
Electra-base (教师)127683072102M100%
T3 (学生)3768307238M37%
T3-small (学生)3384153614M14%
T4-Tiny (学生)4312120011M11%
Electra-small (学生)12256102412M12%
  • T6的结构与DistilBERT[1], BERT6-PKD[2], BERT-of-Theseus[3] 相同。
  • T4-tiny的结构与 TinyBERT[4] 相同。
  • T3的结构与BERT3-PKD[2] 相同。

2.4.2 蒸馏配置

distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
#其他参数为默认值

不同的模型用的matches我们采用了以下配置:

Modelmatches
BiGRUNone
T6L6_hidden_mse + L6_hidden_smmd
T3L3_hidden_mse + L3_hidden_smmd
T3-smallL3n_hidden_mse + L3_hidden_smmd
T4-TinyL4t_hidden_mse + L4_hidden_smmd
T12-nanosmall_hidden_mse + small_hidden_smmd
Electra-smallsmall_hidden_mse + small_hidden_smmd

各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。

2.4.3训练配置

蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。

2.4.4英文实验结果

在英文实验中,我们使用了如下三个典型数据集。

DatasetTask typeMetrics#Train#DevNote
MNLI文本分类m/mm Acc393K20K句对三分类任务
SQuAD 1.1阅读理解EM/F188K11K篇章片段抽取型阅读理解
CoNLL-2003序列标注F123K6K命名实体识别任务

我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。

Public results:

Model (public)MNLISQuADCoNLL-2003
DistilBERT (T6)81.6 / 81.178.1 / 86.2-
BERT6-PKD (T6)81.5 / 81.077.1 / 85.3-
BERT-of-Theseus (T6)82.4/ 82.1--
BERT3-PKD (T3)76.7 / 76.3--
TinyBERT (T4-tiny)82.8 / 82.972.7 / 82.1-

Our results:

Model (ours)MNLISQuADCoNLL-2003
BERT-base-cased (教师)83.7 / 84.081.5 / 88.691.1
BiGRU--85.3
T683.5 / 84.080.8 / 88.190.7
T381.8 / 82.776.4 / 84.987.5
T3-small81.3 / 81.772.3 / 81.478.6
T4-tiny82.0 / 82.675.2 / 84.089.1
T12-nano83.2 / 83.979.0 / 86.689.6

说明:

  1. 公开模型的名称后括号内是其等价的模型结构
  2. 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
  3. 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据

2.4.5中文实验结果

在中文实验中,我们使用了如下典型数据集。

DatasetTask typeMetrics#Train#DevNote
XNLI文本分类Acc393K2.5KMNLI的中文翻译版本,3分类任务
LCQMC文本分类Acc239K8.8K句对二分类任务,判断两个句子的语义是否相同
CMRC 2018阅读理解EM/F110K3.4K篇章片段抽取型阅读理解
DRCD阅读理解EM/F127K3.5K繁体中文篇章片段抽取型阅读理解
MSRA NER序列标注F145K3.4K (测试集)中文命名实体识别

实验结果如下表所示。

ModelXNLILCQMCCMRC 2018DRCD
RoBERTa-wwm-ext (教师)79.989.468.8 / 86.486.5 / 92.5
T378.489.066.4 / 84.278.2 / 86.4
T3-small76.088.158.0 / 79.375.8 / 84.8
T4-tiny76.288.461.8 / 81.877.3 / 86.1
ModelXNLILCQMCCMRC 2018DRCDMSRA NER
Electra-base (教师)77.889.865.6 / 84.786.9 / 92.395.14
Electra-small77.789.366.5 / 84.985.5 / 91.393.48

说明:

  1. 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
  2. CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
  3. Electra-base的教师模型训练设置参考自Chinese-ELECTRA
  4. Electra-small学生模型采用预训练权重初始化

3.核心概念

3.1Configurations

  • TrainingConfigDistillationConfig:训练和蒸馏相关的配置。

3.2Distillers

Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:

  • BasicDistiller: 提供单模型单任务蒸馏方式。可用作测试或简单实验。
  • GeneralDistiller (常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用
  • MultiTeacherDistiller: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配
  • MultiTaskDistiller:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。
  • BasicTrainer:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型

3.3用户定义函数

蒸馏实验中,有两个组件需要由用户提供,分别是callbackadaptor :

3.3.1Callback

回调函数。在每个checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。

3.3.2Adaptor

将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptoradaptor负责重新组织这些数据,返回一个字典。

更多细节可参见完整文档中的说明。

4.FAQ

Q: 学生模型该如何初始化?

A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。

Q: 如何设置蒸馏的训练参数以达到一个较好的效果?

A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考

Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?

A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?

A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/80334.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C高级-day4

#!/bin/bash function fun1(){arr[0]id -u $1arr[1]id -g $1echo ${arr[*]} }arr(fun1 ubuntu) echo ${arr[*]}冒泡排序 void Maopao(int arr[],int len){for(int i1;i<len;i){int count0;for(int j0;j<len-i;j){if(arr[j]>arr[j1]){int tarr[j];arr[j]arr[j1];arr[j…

智能优化算法——哈里鹰算法(Matlab实现)

目录 1 算法简介 2 算法数学模型 2.1.全局探索阶段 2.2 过渡阶段 2.3.局部开采阶段 3 求解步骤与程序框图 3.1 步骤 3.2 程序框图 4 matlab代码及结果 4.1 代码 4.2 结果 1 算法简介 哈里斯鹰算法(Harris Hawks Optimization&#xff0c;HHO)&#xff0c;是由Ali…

如何解决 Elasticsearch 查询缓慢的问题以获得更好的用户体验

作者&#xff1a;Philipp Kahr Elasticsearch Service 用户的重要注意事项&#xff1a;目前&#xff0c;本文中描述的 Kibana 设置更改仅限于 Cloud 控制台&#xff0c;如果没有我们支持团队的手动干预&#xff0c;则无法进行配置。 我们的工程团队正在努力消除对这些设置的限制…

arcgis宗地或者地块四至权利人信息提取教程

ARCGIS怎样将图斑四邻的名称及方位加入其属性表 以前曾发表过一篇《 如何把相邻图斑的属性添加在某个字段中》的个人心得,有些会员提出了进一步的要求,不但要相邻图斑的名称,还要求有方位,下面讲一下自己的做法。 基本思路是:连接相邻图斑质心,根据连线的角度确定相邻图斑…

c++游戏制作指南(三):c++剧情类文字游戏的制作

&#x1f37f;*★,*:.☆(&#xffe3;▽&#xffe3;)/$:*.★* &#x1f37f; &#x1f35f;欢迎来到静渊隐者的csdn博文&#xff0c;本文是c游戏制作指南的一部&#x1f35f; &#x1f355;更多文章请点击下方链接&#x1f355; &#x1f368; c游戏制作指南&#x1f3…

21、p6spy输出执行SQL日志

文章目录 1、背景2、简介3、接入3.1、 引入依赖3.2、修改database参数&#xff1a;3.3、 创建P6SpyLogger类&#xff0c;自定义日志格式3.4、添加spy.properties3.5、 输出样例 4、补充4.1、参数说明 1、背景 在开发的过程中&#xff0c;总希望方法执行完了可以看到完整是sql语…

供水管网漏损监测,24小时保障城市供水安全

供水管网作为城市生命线重要组成部分&#xff0c;其安全运行是城市建设和人民生活的基本保障。随着我国社会经济的快速发展和城市化进程的加快&#xff0c;城市供水管网的建设规模日益增长。然而&#xff0c;由于管网老化、外力破坏和不当维护等因素导致的供水管网漏损&#xf…

js修改img的src属性显示变换图片到前端页面,img的src属性显示java后台读取返回的本地图片

文章目录 前言一、HTML 图像- 图像标签&#xff08; <img>&#xff09;1.1图像标签的源属性&#xff08;Src&#xff09;1.2图像标签源属性&#xff08;Src&#xff09;显示项目中图片1.3图像标签源属性&#xff08;Src&#xff09;显示网络图片 二、图像标签&#xff08…

Docker实战-关于Docker镜像的相关操作(二)

导语   之前的分享中&#xff0c;我们介绍了关于Docker镜像的查询操作相关的内容&#xff0c;下面我们继续来介绍删除清理、导入导出、创建镜像等操作。 如何删除和清理镜像&#xff1f; 使用标签删除镜像 可以使用docker rmi 或者是 docker image rm 命令来删除镜像&#x…

segment-anything使用说明

文章目录 一. segment-anything介绍二. 官网Demo使用说明三. 安装教程四. python调用生成掩码教程五. python调用SAM分割后转labelme数据集 一. segment-anything介绍 Segment Anything Model&#xff08;SAM&#xff09;根据点或框等输入提示生成高质量的对象遮罩&#xff0c…

spring eurake中使用IP注册

在开发spring cloud的时候遇到一个很奇葩的问题&#xff0c;就是服务向spring eureka中注册实例的时候使用的是机器名&#xff0c;然后出现localhost、xxx.xx等这样的内容&#xff0c;如下图&#xff1a; eureka.instance.perferIpAddresstrue 我不知道这朋友用的什么spring c…

爬虫009_字符串高级_替换_去空格_分割_取长度_统计字符_间隔插入---python工作笔记028

然后再来看字符串的高级操作 取长度 查找字符串下标位置 判断是否以某个字符,开头结尾 计算字符出现次数 替换

sigmoid ReLU 等激活函数总结

sigmoid ReLU sigoid和ReLU对比 1.sigmoid有梯度消失问题&#xff1a;当sigmoid的输出非常接近0或者1时&#xff0c;区域的梯度几乎为0&#xff0c;而ReLU在正区间的梯度总为1。如果Sigmoid没有正确初始化&#xff0c;它可能在正区间得到几乎为0的梯度。使模型无法有效训练。 …

什么是OCR?OCR技术详解

光学字符识别(Optical Character Recognition)简称为“OCR”。ORC是指对包含文本资料的图像文件进行分析识别处理&#xff0c;获取文字及版面信息的技术。 一般包括以下几个过程&#xff1a; 1.图像输入 针对不同格式的图像&#xff0c;有着不同的存储格式和压缩方式。目前&…

TypeScript学习笔记

1.ts和js的区别 2. ts的优势 3. ts下载后报错解决方法 报错: PS C:\Users\\Desktop> tsc -v tsc : 无法加载文件 C:\Users\32173\AppData\Roaming\npm\tsc.ps1&#xff0c;因为在此系统上禁止运行脚本。有关详细信息&#xff0c;请参阅 https:/ go.microsoft.com/fwlink/?…

接口测试——电商网站接口测试实战(四)

1. 接口测试需求分析 常见接口文档提供的两种方式 ①word文档 ②在线文档 电商网站网址模拟练习&#xff1a;Swagger UI 2. 登陆的分析 慕慕生鲜网址&#xff1a;慕慕生鲜账号密码点击execute后 输入账号密码后点击开发者工具&#xff0c;再登录&#xff0c;点击网络&…

Linux下C/C++的gdb工具与Python的pdb工具常见用法之对比

1、gdb和pdb分别是什么&#xff1f; 1.1、gdb GDB&#xff08;GNU Debugger&#xff09;是一个功能强大的命令行调试工具&#xff0c;由GNU项目开发&#xff0c;用于调试C、C等编程语言的程序。它在多个操作系统中都可以使用&#xff0c;包括Linux、MacOS和Windows&#xff0…

oracle的管道函数

Oracle管道函数(Pipelined Table Function)oracle管道函数 1、管道函数即是可以返回行集合&#xff08;可以使嵌套表nested table 或数组 varray&#xff09;的函数&#xff0c;我们可以像查询物理表一样查询它或者将其赋值给集合变量。 2、管道函数为并行执行&#xff0c;在…

如何实现基于场景的接口自动化测试用例?

自动化本身是为了提高工作效率&#xff0c;不论选择何种框架&#xff0c;何种开发语言&#xff0c;我们最终想实现的效果&#xff0c;就是让大家用最少的代码&#xff0c;最小的投入&#xff0c;完成自动化测试的工作。 基于这个想法&#xff0c;我们的接口自动化测试思路如下…

软件设计师(七)面向对象技术

面向对象&#xff1a; Object-Oriented&#xff0c; 是一种以客观世界中的对象为中心的开发方法。 面向对象方法有Booch方法、Coad方法和OMT方法等。推出了同一建模语言UML。 面向对象方法包括面向对象分析、面向对象设计和面向对象实现。 一、面向对象基础 1、面向对象的基本…