知识蒸馏代码实现(以MNIST手写数字体为例,自定义MLP网络做为教师和学生网络)

dataloader_tools.py

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoaderdef load_data():# 载入MNIST训练集train_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=True,transform=transforms.ToTensor(),download=True)# 载入MNIST测试集test_dataset = torchvision.datasets.MNIST(root = "../datasets/",train=False,transform=transforms.ToTensor(),download=True)# 生成训练集和测试集的dataloadertrain_dataloader = DataLoader(dataset=train_dataset,batch_size=12,shuffle=True)test_dataloader = DataLoader(dataset=test_dataset,batch_size=12,shuffle=False)return train_dataloader,test_dataloader

models.py

import torch
from torch import nn
# 教师模型
class TeacherModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(TeacherModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,1200)self.fc2 = nn.Linear(1200,1200)self.fc3 = nn.Linear(1200,num_classes)self.dropout = nn.Dropout(p=0.5) #p=0.5是丢弃该层一半的神经元.def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.dropout(x)x = self.relu(x)x = self.fc2(x)x = self.dropout(x)x = self.relu(x)x = self.fc3(x)return xclass StudentModel(nn.Module):def __init__(self,in_channels=1,num_classes=10):super(StudentModel,self).__init__()self.relu = nn.ReLU()self.fc1 = nn.Linear(784,20)self.fc2 = nn.Linear(20,20)self.fc3 = nn.Linear(20,num_classes)def forward(self,x):x = x.view(-1,784)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

train_tools.py

from torch import nn
import time
import torch
import tqdm
import torch.nn.functional as Fdef train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device):# ----------------------开始计时-----------------------------------start_time = time.time()# 设置参数开始训练best_acc, best_epoch = 0, 0criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)for epoch in range(epochs):model.train()# 训练集上训练模型权重for data, targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 前向传播preds = model(data)loss = criterion(preds, targets)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 测试集上评估模型性能model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x, y in test_dataloader:x = x.to(device)y = y.to(device)preds = model(x)predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions == y).sum()num_samples += predictions.size(0)acc = (num_correct / num_samples).item()if acc > best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(model.state_dict(), f"../weights/{model_name}_best_acc_params.pth")model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc),f'loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},最优参数已经保存到:weights/{model_name}_best_acc_params.pth')# -------------------------结束计时------------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')def distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device):# -------------------------------------开始计时--------------------------------start_time = time.time()# 定以损失函数hard_loss = nn.CrossEntropyLoss()soft_loss = nn.KLDivLoss(reduction="batchmean")# 定义优化器optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)best_acc,best_epoch = 0,0for epoch in range(epochs):student_model.train()# 训练集上训练模型权重for data,targets in tqdm.tqdm(train_dataloader):# 把数据加载到GPU上data = data.to(device)targets = targets.to(device)# 教师模型预测with torch.no_grad():teacher_preds = teacher_model(data)# 学生模型预测student_preds = student_model(data)# 计算hard_lossstudent_hard_loss = hard_loss(student_preds,targets)# 计算蒸馏后的预测结果及soft_lossditillation_loss = soft_loss(F.softmax(student_preds/temp,dim=1),F.softmax(teacher_preds/temp,dim=1))# 将hard_loss和soft_loss加权求和loss = temp * temp * alpha * student_hard_loss + (1-alpha)*ditillation_loss# 反向传播,优化权重optimizer.zero_grad()loss.backward()optimizer.step()#测试集上评估模型性能student_model.eval()num_correct = 0num_samples = 0with torch.no_grad():for x,y in test_dataloader:x = x.to(device)y = y.to(device)preds = student_model(x)predictions = preds.max(1).indices #返回每一行的最大值和该最大值在该行的列索引num_correct += (predictions ==y).sum()num_samples += predictions.size(0)acc = (num_correct/num_samples).item()if acc>best_acc:best_acc = accbest_epoch = epoch# 保存模型最优准确率的参数torch.save(student_model.state_dict(),f"../weights/{model_name}_best_acc_params.pth")student_model.train()print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1,acc))print(f'student_hard_loss={student_hard_loss},ditillation_loss={ditillation_loss},loss={loss}')print(f'最优准确率的epoch为{best_epoch},值为:{best_acc},')# --------------------------------结束计时----------------------------------end_time = time.time()run_time = end_time - start_time# 将输出的秒数保留两位小数if int(run_time) < 60:print(f'训练用时为:{round(run_time, 2)}s')else:print(f'训练用时为:{round(run_time / 60, 2)}minutes')

训练教师网络

import torch
from torchinfo import summary #用来可视化的
import models
import dataloader_tools
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 载入MNIST训练集和测试集
train_dataloader,test_dataloader = dataloader_tools.load_data()# 定义教师模型
model = models.TeacherModel()
model = model.to(device)
# 打印模型的参数
summary(model)# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'teacher'
train_tools.train(epochs,model,model_name,lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,值为:0.9868999719619751

用非蒸馏的方法训练学生网络

import torch
from torchinfo import summary #用来可视化的
import dataloader_tools
import models
import train_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 生成训练集和测试集的dataloader
train_dataloader,test_dataloader = dataloader_tools.load_data()# 从头训练学生模型
model = models.StudentModel()
model = model.to(device)
# 查看模型参数
print(summary(model))# 定义参数并开始训练
epochs = 10
lr = 1e-4
model_name = 'student'
train_tools.train(epochs, model, model_name, lr,train_dataloader,test_dataloader,device)
最优准确率的epoch为9,准确率为:0.9382999539375305,最优参数已经保存到:weights/student_best_acc_params.pth
训练用时为:1.74minutes

用知识蒸馏的方法训练student model

import torch
import train_tools
import models
import dataloader_tools# 设置随机数种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True# 加载数据
train_dataloader,test_dataloader = dataloader_tools.load_data()# 加载训练好的teacher model
teacher_model = models.TeacherModel()
teacher_model = teacher_model.to(device)
teacher_model.load_state_dict(torch.load('../weights/teacher_best_acc_params.pth'))
teacher_model.eval()# 准备新的学生模型
student_model = models.StudentModel()
student_model = student_model.to(device)
student_model.train()# 开始训练
lr = 0.0001
epochs = 20
alpha = 0.3 # hard_loss权重
temp = 7 # 蒸馏温度
model_name = 'distill_student_loss'
# 调用train_tools中的
train_tools.distill_train(epochs,teacher_model,student_model,model_name,train_dataloader,test_dataloader,alpha,lr,temp,device)
最优准确率的epoch为9,值为:0.9204999804496765,
训练用时为:2.14minutes

在这里插入图片描述

loss改为:

# temp的平方乘在student_hard_loss
loss = temp * temp * alpha * student_hard_loss + (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9336999654769897,
训练用时为:2.12minutes

loss改为:

# temp的平方乘ditillation_loss
loss = alpha * student_hard_loss + temp * temp * (1 - alpha) * ditillation_loss
最优准确率的epoch为9,值为:0.9176999926567078,
训练用时为:2.09minutes

上面的几种loss,蒸馏损失都出现了负数的情况。不太对劲。
在这里插入图片描述

其它开源的知识蒸馏算法如下:

open-mmlab开源的工具箱包含知识蒸馏算法

mmrazor

github.com/open-mmlab/mmrazor

在这里插入图片描述

NAS:神经架构搜索
剪枝:Pruning
KD: 知识蒸馏
Quantization: 量化

自定义知识蒸馏算法:
在这里插入图片描述

mmdeploy

可以把算法部署到一些厂商支持的中间格式,如ONNX,tensorRT等。

在这里插入图片描述

HobbitLong的RepDistiller

github.com/HobbitLong/RepDistiller

在这里插入图片描述
在这里插入图片描述
里面有12种最新的知识蒸馏算法。

蒸馏网络可以应用于同一种模型,将大的学习的知识蒸馏到小的上面。
如下将resnet100做教师网络,resnet32做学生网络。

在这里插入图片描述

将一种模型迁移到另一种模型上。如vgg13做教师网络,mobilNetv2做学生网络:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/206632.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙系统扫盲(三):鸿蒙开发用什么语言?

1.两种开发方向 我们常说鸿蒙开发&#xff0c;但是其实鸿蒙开发分为两个方向&#xff1a; 一个是系统级别的开发&#xff0c;比如驱动&#xff0c;内核和框架层的开发&#xff0c;这种开发以C/C为主 还有一个是应用级别的开发&#xff0c;在API7以及以下&#xff0c;还是支持…

咨询+低代码,强强联合为制造业客户赋能

内容来自演讲&#xff1a;沈毅 | 遨睿智库 | 董事长 & 王劭禹 | 橙木智能 | 联合创始人 摘要 文章主要讲述了智库董事长沈毅创办广告公司的经历&#xff0c;以及他在管理公司过程中遇到的问题和挑战&#xff0c;最后通过与明道云以及橙木智能联合创始人王邵禹老师的合作&…

分布式篇---第五篇

系列文章目录 文章目录 系列文章目录前言一、你知道哪些限流算法?二、说说什么是计数器(固定窗口)算法三、说说什么是滑动窗口算法前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去…

【软件测试】盘一盘工作中遇到的 Redis 异常测试

在测试工作中&#xff0c;涉及到与 redis 交互的场景变的越来越多了。关于redis本身就不作赘述了&#xff0c;网上随便搜&#xff0c;本人也做过一些整理。 今天只来复盘一下&#xff0c;在测试过程中与 redis 的二三事儿。其中提到的案例是经过抽象化的&#xff0c;用作辅助说…

阿里云跨账号建立局域网

最近有活动&#xff0c;和好友一并薅了下阿里云的羊毛。琢磨着两台机器组一个局域网&#xff0c;于是有了这个需求&#xff0c;把步骤记录一下&#xff1a; 假设两台机器叫A和B&#xff0c;我们开始进行建立和组网 1. 建立ECS 把A机器公共环境装好&#xff0c;然后使用《实例与…

【驱动】SPI驱动分析(四)-关键API解析

关键API 设备树 设备树解析 我们以Firefly 的SPI demo 分析下dts中对spi的描述&#xff1a; /* Firefly SPI demo */ &spi1 {spi_demo: spi-demo00{status "okay";compatible "firefly,rk3399-spi";reg <0x00>;spi-max-frequency <48…

什么是木马

木马 1. 定义2. 木马的特征3. 木马攻击流程4. 常见木马类型5. 如何防御木马 1. 定义 木马一名来源于古希腊特洛伊战争中著名的“木马计”&#xff0c;指可以非法控制计算机&#xff0c;或在他人计算机中从事秘密活动的恶意软件。 木马通过伪装成正常软件被下载到用户主机&…

js无法请求后端接口,别的都可以?

在每个接口的控制器中加入以下代码即可&#xff1a; header(Access-Control-Allow-Methods:*); header("Access-Control-Allow-Origin:*"); 如果嫌麻烦可以添加在api初始函数里面

万宾科技可燃气体监测仪的功能有哪些?

随着城市人口的持续增长和智慧城市不断发展&#xff0c;燃气作为一种重要的能源供应方式&#xff0c;已经广泛地应用于居民生活和工业生产的各个领域。然而燃气泄漏和安全事故的风险也随之增加&#xff0c;对城市的安全和社会的稳定构成了潜在的威胁。我国燃气管道安全事故的频…

设计模式详解(三):工厂方法

目录导航 抽象工厂及其作用工厂方法的好处工厂方法的实现关系图实现步骤 工厂方法的适用场景工厂方法举例 抽象工厂及其作用 工厂方法是一种创建型设计模式。所谓创建型设计模式是说针对创建对象方面的设计模式。在面向对象的编程语言里&#xff0c;我们通过对象间的相互协作&…

智能优化算法应用:基于教与学算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于教与学算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于教与学算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.教与学算法4.实验参数设定5.算法结果6.参考文献7.…

leetCode 78.子集 + 回溯算法 + 图解

给你一个整数数组 nums &#xff0c;数组中的元素 互不相同 。返回该数组所有可能的子集&#xff08;幂集&#xff09;。解集 不能 包含重复的子集。你可以按 任意顺序 返回解集 示例 1&#xff1a; 输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[],[1],[2],[1,2],[3],[1…

STM32---时钟树

写在前面&#xff1a;一个 MCU 越复杂&#xff0c;时钟系统也会相应地变得复杂&#xff0c;如 STM32F1 的时钟系统比较复杂&#xff0c;不像简单的 51 单片机一个系统时钟就 可以解决一切。对于 STM32F1 系列的芯片&#xff0c;其有多个时钟源&#xff0c;构成了一个庞大的是时…

Python 异常处理(try except)

文章目录 1 概述1.1 异常示例 2 异常处理2.1 捕获异常 try except2.2 抛出异常 raise 3 异常类型3.1 内置异常3.2 自定义异常 1 概述 1.1 异常示例 异常&#xff1a;程序执行中出现错误&#xff0c;若不处理&#xff0c;则程序终止 示例代码&#xff1a; v 6 / 0 # 除数不…

Redis入门指南学习笔记(3):Redis高级特性

一.前言 上一篇博客对Redis常用的数据结构进行了详细介绍。Redis除了丰富的数据类型支持&#xff0c;还包含许多高级特性&#xff0c;例如事务、内存驻留策略、排序、消息队列等&#xff0c;本文将对这些进行逐一介绍。 二.事务 Redis同样包含事务&#xff08;transaction&a…

【刷题笔记】H指数||数组||二分查找的变体

H指数 最新编辑于2023.11.29 之前的代码写得有点抽象&#xff0c;实在抱歉&#xff0c;好像我自己都不理解当时自己怎么写的&#xff0c;现在重新更新了代码&#xff0c;保证好理解。 1 题目描述 https://leetcode.cn/problems/h-index/ 给你一个整数数组 citations &#xf…

Python基础学习之包与模块详解

文章目录 前言什么是 Python 的包与模块包的身份证如何创建包创建包的小练习 包的导入 - import模块的导入 - from…import导入子包及子包函数的调用导入主包及主包的函数调用导入的包与子包模块之间过长如何优化 强大的第三方包什么是第三方包如何安装第三方包 总结关于Python…

【人工智能Ⅰ】实验3:蚁群算法

实验3 蚁群算法的应用 一、实验内容 TSP 问题的蚁群算法实现。 二、实验目的 1. 熟悉和掌握蚁群算法的基本概念和思想&#xff1b; 2. 理解和掌握蚁群算法的参数选取&#xff0c;解决实际应用问题。 三、实验原理 1&#xff0e;算法来源 蚁群算法的基本原理来源于自然界…

网络基础_1

目录 网络基础 协议 协议分层 OSI七层模型 网络传输的基本流程 数据包的封装和分用 IP地址和MAC地址 网络基础 网络就是不同的计算机之间可以进行通信&#xff0c;前面我们学了同一台计算机之间通信&#xff0c;其中有进程间通信&#xff0c;前面学过的有管道&#xff…

在Linux上安装KVM虚拟机

一、搭建KVM环境 KVM&#xff08;Kernel-based Virtual Machine&#xff09;是一个基于内核的系统虚拟化模块&#xff0c;从Linux内核版本2.6.20开始&#xff0c;各大Linux发行版就已经将其集成于发行版中。KVM与Xen等虚拟化相比&#xff0c;需要硬件支持的完全虚拟化。KVM由内…