【模型压缩】原理及实例

在移动智能终端品类越发多样的时代,为了让模型可以顺利部署在算力和存储空间都受限的移动终端,对模型进行压缩尤为重要。模型压缩(model compression)可以降低神经网络参数量,减少延迟时间,从而实现提高神经网络推理速度、节省存储空间等目的。

一.量化

量化是指将模型权重参数用更少的比特数存储,以此来减少模型的存储空间和算力消耗。

1.基本原理

(1) 量化感知训练

Quantization-aware Training,QAT在训练过程中模拟量化过程,数据虽然表示为float32,但实际的值的间隔却会受到量化参数的设置。

QAT的具体流程如下:

1)初始化:设置权重和激活值范的范围q_{min}q_{max}的初始值;

2)构建模拟量化网络:在需要量化的权重和激活值后插入伪量化算子;

3)量化训练:重复执行以下步骤直至网络收敛(计算量化网络层的权重和激活值的范围q_{min}q_{max},并根据该范围将量化损失带入到前向推理和后向参数更新的过程中);

4)导出量化网络:获取q_{min}q_{max},并计算量化参数,将量化参数s和z代入到量化公式中,转换网络中的权重为量化整数值;删除伪量化算子,在量化网络层前后分别插入量化和反量化算子。

(2) 后训练动态量化

Post training dynamic quantization是在浮点模型训练收敛之后进行量化操作,weight被提前量化,activation在前向推理过程中被动态量化(即每次都要根据实际运算的浮点数据范围每一层计算1次scale和zero_point,然后进行量化)。

在量化激活值时会以校准数据集为输入,执行推理流程然后统计每层激活值的数据分布并得到相应的量化参数,具体操作流程如下:

1)使用直方图统计的方式得到原始float32数据的统计分布P_{f}

2)在给定的搜索空间中选取若干个q_{min}q_{max}分别对激活值进行量化,得到量化后的数据Q_{q}

3)使用直方图统计得到Q_{q}的统计分布;

4)计算每个Q_{q}P_{f}的统计分布差异,并找到差异性最低的1个对应的q_{min}q_{max}来计算相应的量化参数;常用的用于度量分布差异的指标包括KL散度、对称KL散度和JS散度。

(3) 后训练静态量化

activation会基于之前校准过程中记录下的固定的scale和zero_point进行量化,整个过程不存在量化参数(scale,zero_point)的再计算。

2.代码实例

(1) 加载数据

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
import torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

(2) 构建量化网络

class QuantizedCNN(nn.Module):def __init__(self):super(QuantizedCNN, self).__init__()self.quant = QuantStub()self.conv1 = nn.Conv2d(3, 16, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.fc1 = nn.Linear(32 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.dequant = DeQuantStub()def forward(self, x):# x = self.quant(x)x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)x = self.dequant(x)return xdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')

(3) 量化训练并保存模型

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0# 切换到评估模式进行测试model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))# 在最后1个epoch后完成量化if epoch == num_epochs - 1:model_quantized = convert(model.eval(), inplace=True)print("Model quantization completed.")# 保存量化模型torch.save(model_quantized.state_dict(), 'quantized_model.pth')

(4) 模型测试

def test_quantized_model(model, dataloader, device='cpu'):model = convert(model.eval(), inplace=True)model.to(device) correct = 0total = 0with torch.no_grad(): for data, targets in dataloader:data, targets = data.to(device), targets.to(device)  outputs = model(data)  _, predicted = torch.max(outputs.data, 1)  total += targets.size(0)correct += (predicted == targets).sum().item()accuracy = 100 * correct / totalprint(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')# 测试模型
quantized_model=QuantizedCNN()
quantized_model.load_state_dict(torch.load('quantized_model.pth'))
test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

二.剪枝

剪枝是指去除模型参数中冗余或不重要的部分,可以高效地生成规模更小、内存利用率更高、能耗更低、推断速度更快的模型。

1.基本原理

根据剪枝流程的位置,可以将剪枝操作分为2种:训练时剪枝和后剪枝。

(1) 训练时剪枝

和训练时使用dropout操作较为类似,训练时剪枝会根据当前模型的结果,删除不重要的结构,固化模型再进行训练,以后续的训练来弥补部分结构剪枝带来的不利影响。

(2) 后剪枝

在模型训练完成后,根据模型权重参数和剪枝测试选取需要剪枝的部分。

2.代码实例

(1) 加载预训练模型

import torch
import torchvision.models as models# 加载预训练的ResNet18模型
model = models.resnet18(pretrained=True)

(2) 定义剪枝算法

from torch.nn.utils.prune import global_unstructured# 定义剪枝比例
pruning_rate = 0.5# 对全连接层进行剪枝
def prune_model(model, pruning_rate):for name, module in model.named_modules():if isinstance(module, torch.nn.Linear):global_unstructured(module, pruning_dim=0, amount=pruning_rate)

(3)执行剪枝操作

prune_model(model, pruning_rate)# 查看剪枝后的模型结构
print(model)

(4) 重新训练和微调

剪枝后的模型需要重新进行训练和微调,以保证模型的准确性和性能。

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

(5) 性能测试

三.蒸馏

蒸馏是指将知识从大模型(教师模型)向小模型(学生模型)传输的过程,可以用于模型压缩和训练加速。核心组件包括:知识(knowledge)、蒸馏算法(distillation algorithm)、教师学生架构(teacher-student architecture)。

1.基本原理

蒸馏的知识的形式可以是:激活、神经元、中间层特征、教师网络参数等。可将其归类为以下3种类型。

(1) Feature-Based Knowledge

基于特征的知识蒸馏引入中间层表征,教师网络的中间层作为学生网络对应层的提示(Hints层),从而提升学生网络模型的性能。核心是期望学生能够直接模仿教师网络的特征激活值。

(2) Relation-Based Knowledege

基于关系的知识蒸馏可以分为不同层之间的关系建模和不同样本之间的关系建模2种。

•不同层之间的关系建模

通常可以建模为:

其中,f _{t}f_{s}表示学生网络内成对的特征图,\Psi _{t}\Psi _{s}是相似度函数,L_{​{R^{1}}}代表教师网络与学生网络的关联函数。

•不同样本之间的关系建模

建模如下:

其中,F _{t}F _{s}分别是teacher和student模型的特征表示;\left ( t_{i}, t_{j}\right )\in F _{t}\left ( s_{i}, s_{j}\right )\in F _{s}

基于关系的知识蒸馏的具体算法如下表所示。

(3) Response-Based Knowleddge

基于响应的知识蒸馏里响应一般指的是神经元的响应,即教师模型的最后1层逻辑输出。核心想法是让学生模型模仿教师网络的输出。

响应知识的loss:

Hinton提出的KD是将teacher的logits层作为soft label:

T是用于控制soft target重要程度的超参数。

整体蒸馏loss可以写作:

2.代码实例

(1) 加载数据

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

(2) 构建teacher 、student模型结构

# Create the teacher
teacher = keras.Sequential([keras.Input(shape=(28, 28, 1)),layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),layers.LeakyReLU(alpha=0.2),layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),layers.Flatten(),layers.Dense(10),],name="teacher",
)# Create the student
student = keras.Sequential([keras.Input(shape=(28, 28, 1)),layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),layers.LeakyReLU(alpha=0.2),layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),layers.Flatten(),layers.Dense(10),],name="student",
)# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

(3) 训练模型

# 1.Train teacher as usual
teacher.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=3)
teacher.evaluate(x_test, y_test)# 2.Train student as  usual
student_scratch.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student on data
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

(4) 构建蒸馏模型

class Distiller(keras.Model):def __init__(self, student, teacher):super(Distiller, self).__init__()self.teacher = teacherself.student = studentdef compile(self,optimizer,metrics,student_loss_fn,distillation_loss_fn,alpha=0.1,temperature=3,):super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fnself.alpha = alphaself.temperature = temperaturedef train_step(self, data):# Unpack datax, y = data# Forward pass of teacherteacher_predictions = self.teacher(x, training=False)with tf.GradientTape() as tape:# Forward pass of studentstudent_predictions = self.student(x, training=True)# Compute lossesstudent_loss = self.student_loss_fn(y, student_predictions)# Compute scaled distillation lossdistillation_loss = (self.distillation_loss_fn(tf.nn.softmax(teacher_predictions / self.temperature, axis=1),tf.nn.softmax(student_predictions / self.temperature, axis=1),)* self.temperature**2)loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss# Compute gradientstrainable_vars = self.student.trainable_variablesgradients = tape.gradient(loss, trainable_vars)# Update weightsself.optimizer.apply_gradients(zip(gradients, trainable_vars))# Update the metrics configured in `compile()`.self.compiled_metrics.update_state(y, student_predictions)# Return a dict of performanceresults = {m.name: m.result() for m in self.metrics}results.update({"student_loss": student_loss, "distillation_loss": distillation_loss})return resultsdef test_step(self, data):# Unpack the datax, y = data# Compute predictionsy_prediction = self.student(x, training=False)# Calculate the lossstudent_loss = self.student_loss_fn(y, y_prediction)# Update the metrics.self.compiled_metrics.update_state(y, y_prediction)# Return a dict of performanceresults = {m.name: m.result() for m in self.metrics}results.update({"student_loss": student_loss})return results

(5)蒸馏

# Train student as doen usually
student_scratch.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()],
)# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=1)
student_scratch.evaluate(x_test, y_test)

四.参考

(1) Knowledge Distillation: A Survey

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/494084.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode-128.最长连续序列-day14

为什么我感觉上述代码时间复杂度接近O(2n), 虽然有while循环,但是前面有个if判断,能进入while循环的也不多,while循环就相当于两个for循环,但不是嵌套类型的: 变量作用域问题:

Burp与其他安全工具联动及代理设置教程

Burp Suite 是一款功能强大的 Web 安全测试工具,其流量拦截和调试功能可以与其他安全工具(如 Xray、Yakit、Goby 等)实现联动,从而提升渗透测试的效率。本文将详细讲解 Burp 与其他工具联动的原理以及代理设置的操作方法&#xff…

文件操作(File类)

目录 一、初识文件 二、File类 常用方法 一、初识文件 我们目前是如何存储数据的?弊端是什么? int a 1; int[] arr new int[5];我们这些数据是在内存中存储的,是不能够长久保存的。 那么,我们的计算机当中有没有一块硬件可以长久存储数据的? 磁…

Ubuntu硬盘分区及挂载(命令行)

文章目录 一、简介二、硬盘分区三、格式化分区四、自动挂载分区五、调整分区大小小结 一、简介 创建磁盘分区首先需要找出Linux系统中的物理磁盘,在Linux中采用了一种标准格式来为硬盘分配设备名称。 SATA驱动器和SCSI驱动器:设备命名格式为/dev/sdx&a…

用java造1万条数据

上个月项目有造数需求记录一下。 package com.company;public class CreateSqlZhou {public static void main(String[] args) {//insert into Student (id,name,sex,age,adress) values(68881624120312320,zhangsan,男,18,北京);String startSql "insert into Student…

vue iframe进行父子页面通信并切换URL

需求是2个项目需要使用同一个面包屑进行跳转&#xff0c;其中一个是iframe所在的项目&#xff0c;另一个需要通过地址访问。通过 window.parent.postMessage &#xff0c;帮助 <iframe> 内嵌入的子页面和其父页面之间进行跨域通信。 使用通义千问提问后得到一个很好的示…

【Qt】显示类控件:QLabel、QLCDNumber、QProgressBar、QCalendarWidget

目录 QLabel QFrame 例子&#xff1a; textFormat pixmap、scaledContents alignment wordWrap、indent、margin buddy QLCDNumber 例子&#xff1a; QTimer QProgressBar 例子&#xff1a; QCalendarWidget 例子&#xff1a; QLabel 标签控件&#xff0c;用来显示…

UVM 验证方法学之interface学习系列文章(十二)virtual interface 终结篇

一 双向和三态问题 任何具有多个驱动器的信号,都需要使用网(net)来建模。网是唯一能够同时解决不同状态和强度驱动同一信号效果的构造。net的行为由内置解析函数定义,该函数使用net上所有驱动器的值和强度。每当其中一个驱动器发生变化时,就会调用该函数来生成解析值。该…

【游戏设计原理】22 - 石头剪刀布

一、游戏基础&#xff1a;拳头、掌心、分指 首先&#xff0c;石头剪刀布&#xff08;又名“Roshambo”&#xff09;看似简单&#xff0c;实际上可是个“深藏玄机”的零和博弈&#xff08;听起来很高深&#xff0c;其实就是输赢相抵消的意思&#xff09;。游戏中有三种手势&…

iterm2 focus时灰色蒙层出现的解决办法

问题描述&#xff1a; 当前我的iterm2版本是3.5.10&#xff0c;是我最近才更新的&#xff0c;然后就出现以下页面显示问题&#xff0c;如图所示&#xff1a; 我个人对终端、编辑器等使用存在洁癖&#xff0c;尤其是页面显示效果不满意更是不能忍受&#xff0c;之前找了很久没有…

如何在window 使用 conda 环境下载大模型

最近开始学习 变形金刚&#xff0c;最大的问题就是 huggingface 无法访问&#xff0c;无论是翻墙还是通过本地镜像网站HF-Mirror&#xff0c;然后再通过git下载都很慢&#xff0c;影响学习进度&#xff0c;后面看了如下文章&#xff0c;Huggingface配置镜像_huggingface镜像-CS…

Linux 网络维护相关命令简介

目录 零. 概要一. ping二. ip命令2.1 ip address2.2 ip route2.3 ip neighbour 三. traceroute四. DNS查询4.1 nslookup4.2 dig 五. ss 查看网络连接状态 零. 概要 ⏹在Linux系统中有2套用于网络管理的工具集 net-tools 早期网络管理的主要工具集&#xff0c;缺乏对 IPv6、网…

Liveweb视频融合共享平台在果园农场等项目中的视频监控系统搭建方案

一、背景介绍 在我国的大江南北遍布着各种各样的果园&#xff0c;针对这些地处偏僻的果园及农场等环境&#xff0c;较为传统的安全防范方式是建立围墙&#xff0c;但是仅靠围墙仍然无法阻挡不法分子的有意入侵和破坏&#xff0c;因此为了及时发现和处理一些难以察觉的问题&…

Ubuntu vi(vim)编辑器配置一键补全main函数

1.打开对应的配置文件 vi ~/.vim/snippets/c.snippets 2.按G将光标定位到文件末尾 3.按i进入插入模式 以tab键开头插入下的内容&#xff0c;空行也要加 tab键 4.:wq保存退出 5.再打开任意一个新的 .c文件后&#xff0c;插入模式输入 main 然后按tal键就能补全了

javaEE-线程的常用方法-4

目录 一.start():启动一个线程 调用start()方法 start()方法只能调用一次&#xff1a; java中的API: start()和run()的区别: 二.中断一个线程 中断线程方法1:引入标志位 中断线程方法2:调⽤interrupt()⽅法 抛出的异常: 三.等待一个线程 join() 四、获取线程引用 五…

服务器数据恢复—V7000存储中多块磁盘出现故障导致业务中断的数据恢复案例

服务器存储数据恢复环境&#xff1a; 一台V7000存储上共12块SAS机械硬盘&#xff08;其中1块是热备盘&#xff09;&#xff0c;组建了2组Mdisk&#xff0c;创建了一个pool。挂载在小型机上作为逻辑盘使用&#xff0c;小型机上安装的AIXSybase。 服务器存储故障&#xff1a; V7…

2024年图像处理、多媒体技术与机器学习

重要信息 官网&#xff1a;www.ipmml.org 时间&#xff1a;2024年12月27-29日 地点&#xff1a;中国-大理 简介 2024年图像处理、多媒体技术与机器学习&#xff08;CIPMT 2024&#xff09;将于2024年12月27-29日于中国大理召开。将围绕图像处理与多媒体技术、机器学习等在…

linux----文件访问(c语言)

linux文件访问相关函数 打开文件函数 - open 函数原型&#xff1a;int open(const char *pathname, int flags, mode_t mode);参数说明&#xff1a; pathname&#xff1a;这是要打开的文件的路径名&#xff0c;可以是绝对路径或者相对路径。例如&#xff0c;"/home/user/…

Redis 集群实操:强大的数据“分身术”

目录 Redis Cluster集群模式 1、介绍 2、架构设计 3、集群模式实操 4、故障转移 5、常用命令 Redis Cluster集群模式 1、介绍 redis3.0版本推出的Redis Cluster 集群模式&#xff0c;每个节点都可以保存数据和整个集群状态&#xff0c;每个节点都和其他所有节点连接。Cl…

探索 Seaborn Palette 的奥秘:为数据可视化增色添彩

一、引言 在数据科学的世界里&#xff0c;视觉传达是不可或缺的一环。一个好的数据可视化不仅能传递信息&#xff0c;还能引发共鸣。Seaborn 是 Python 中一款广受欢迎的可视化库&#xff0c;而它的调色板&#xff08;palette&#xff09;功能&#xff0c;则为我们提供了调配绚…