模型压缩之剪枝

(1)通道选择

这里要先解释一下:

(1)通道剪枝

那我们实际做法不是上面直接对所有层都添加L1正则项,而是仅仅对BN层权重添加L1正则项。通道剪枝具体步骤如下:
 

1.BN层权重添加L1正则项,进行稀疏训练

2.BN层权重的scale factor进行排序,对scale factor低于阈值的通道进行裁剪,得到剪枝模型

3.对剪枝模型进行finetune
 

注:进行finetune的目的是因为剪枝完整个网络结构发生了变化,之前的训练的模型无法再加载进入,必须要finetune(或者这里用重新训练更合适),否则会发现推理结果都是0.

在深度学习中,Batch Normalization(BN)层通常用于加速训练过程并提高模型的泛化能力。BN层的权重参数包括scale factor(缩放因子)和shift factor(偏移因子)。通过对BN层的scale factor添加L1正则化,我们可以实现通道剪枝。

下面是一个示例代码,展示了如何对BN层的scale factor添加L1正则化,并进行通道剪枝和微调(finetune)。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(64)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.max_pool2d(x, 2)x = F.relu(self.bn2(self.conv2(x)))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# L1正则化参数
lambda_l1 = 0.001# 稀疏训练
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)# 计算BN层scale factor的L1正则化项l1_regularization = 0for module in model.modules():if isinstance(module, nn.BatchNorm2d):l1_regularization += torch.norm(module.weight, p=1)loss += lambda_l1 * l1_regularizationloss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')print("稀疏训练完成")# 通道剪枝
def prune_channels(model, sparsity_threshold):for module in model.modules():if isinstance(module, nn.BatchNorm2d):weights = module.weight.datamask = torch.abs(weights) > sparsity_thresholdmodule.weight.data = weights[mask]module.bias.data = module.bias.data[mask]module.num_features = int(torch.sum(mask))# 更新卷积层的输入通道数if hasattr(module, 'conv'):conv_module = getattr(module, 'conv')conv_module.out_channels = int(torch.sum(mask))conv_module.weight.data = conv_module.weight.data[mask]if conv_module.bias is not None:conv_module.bias.data = conv_module.bias.data[mask]# 设置稀疏性阈值
sparsity_threshold = 0.01# 剪枝
prune_channels(model, sparsity_threshold)print("通道剪枝完成")# 微调
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()print(f'Finetune Epoch {epoch + 1}, Loss: {loss.item()}')print("微调完成")

(2)卷积核剪枝

1.conv层权重添加L1正则项,进行稀疏训练

2.conv层权重进行排序,对权重低于阈值的卷积核进行裁剪,得到剪枝模型

3.对剪枝模型进行finetune

下面我写了一个简单的示例代码,展示了如何在训练过程中计算权重的稀疏性,并根据稀疏性剪掉通道。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# L1正则化参数
lambda_l1 = 0.001# 训练模型
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)# 计算L1正则化项l1_regularization = 0for param in model.parameters():l1_regularization += torch.norm(param, p=1)loss += lambda_l1 * l1_regularizationloss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')print("训练完成")# 根据稀疏性剪掉通道
def prune_channels(model, sparsity_threshold):for name, module in model.named_modules():if isinstance(module, nn.Conv2d):weights = module.weight.dataabs_weights = torch.abs(weights)channel_sums = torch.sum(abs_weights, dim=(1, 2, 3))mask = channel_sums > sparsity_thresholdmodule.weight.data = weights[mask]module.out_channels = int(torch.sum(mask))if module.bias is not None:module.bias.data = module.bias.data[mask]# 设置稀疏性阈值
sparsity_threshold = 0.01# 剪枝
prune_channels(model, sparsity_threshold)print("剪枝完成")

在上面例子中,我们在训练完成后,通过 prune_channels 函数根据稀疏性剪掉通道。

具体步骤如下:

  1. 计算权重的稀疏性:对于每个卷积层的权重,我们计算每个通道的权重绝对值之和。

  2. 剪枝:根据设定的稀疏性阈值,我们创建一个掩码(mask),只保留那些权重绝对值之和大于阈值的通道,并更新卷积层的权重和偏置。

通过这种方式,我们可以根据权重的稀疏性剪掉不重要的通道,从而减少模型的复杂度和计算量。

通道剪枝和卷积核剪枝小结:

卷积核剪枝(Kernel Pruning)和通道剪枝(Channel Pruning)是两种不同的模型剪枝技术,它们在剪枝的对象和目标上有所区别。

卷积核剪枝(Kernel Pruning)

卷积核剪枝 是指从卷积层中移除整个卷积核(kernel)。一个卷积核通常由一组权重组成,这些权重在卷积操作中与输入特征图的局部区域进行卷积运算。卷积核剪枝的目标是移除那些对模型性能贡献较小的卷积核,从而减少模型的计算量和参数数量。

  • 剪枝对象:卷积核(kernel)。

  • 剪枝目标:移除整个卷积核。

  • 影响:减少卷积层的输出通道数。

通道剪枝(Channel Pruning)

通道剪枝 是指从卷积层或全连接层中移除整个通道(channel)。一个通道通常由一组权重组成,这些权重在卷积操作中与输入特征图的所有位置进行卷积运算。通道剪枝的目标是移除那些对模型性能贡献较小的通道,从而减少模型的计算量和参数数量。

  • 剪枝对象:通道(channel)。

  • 剪枝目标:移除整个通道。

  • 影响:减少卷积层的输入或输出通道数。

主要区别

  1. 剪枝对象

    • 卷积核剪枝针对的是卷积核,即卷积层中的单个权重组。

    • 通道剪枝针对的是通道,即卷积层或全连接层中的整个权重集合。

  2. 剪枝目标

    • 卷积核剪枝的目标是移除整个卷积核。

    • 通道剪枝的目标是移除整个通道。

  3. 影响

    • 卷积核剪枝主要影响卷积层的输出通道数。

    • 通道剪枝既可以影响卷积层的输入通道数,也可以影响输出通道数。

卷积核剪枝代码:
 

def prune_kernels(model, sparsity_threshold):for name, module in model.named_modules():if isinstance(module, nn.Conv2d):weights = module.weight.dataabs_weights = torch.abs(weights)kernel_sums = torch.sum(abs_weights, dim=(1, 2, 3))mask = kernel_sums > sparsity_thresholdmodule.weight.data = weights[mask]module.out_channels = int(torch.sum(mask))if module.bias is not None:module.bias.data = module.bias.data[mask]

通道剪枝代码:

def prune_channels(model, sparsity_threshold):#遍历模型中的所有模块for module in model.modules():#检查模块是否为BN层if isinstance(module, nn.BatchNorm2d):#获取BN层的权重weights = module.weight.data#根据稀疏性阈值创建掩码mask = torch.abs(weights) > sparsity_threshold#应用掩码到BN层的权重和偏置module.weight.data = weights[mask]module.bias.data = module.bias.data[mask]module.num_features = int(torch.sum(mask))#检查BN层是否有与之关联的卷积层if hasattr(module, 'conv'):conv_module = getattr(module, 'conv')#应用掩码到卷积层的权重和偏置conv_module.out_channels = int(torch.sum(mask))conv_module.weight.data = conv_module.weight.data[mask]if conv_module.bias is not None:conv_module.bias.data = conv_module.bias.data[mask]

在通道剪枝中,我们不仅需要剪枝Batch Normalization(BN)层的权重,还需要相应地剪枝与之关联的卷积层的权重。具体来说,BN层的权重(scale factor)决定了哪些通道是重要的,因此我们需要根据BN层的权重来剪枝卷积层的通道。

通过这种方式,我们确保了BN层的剪枝与卷积层的剪枝是一致的,即剪枝后的BN层和卷积层具有相同的通道数。这样可以保证模型在剪枝后的结构是有效的,并且能够正常工作。

总结来说,通道剪枝不仅涉及BN层的权重剪枝,还涉及与之关联的卷积层的权重剪枝,以确保剪枝后的模型结构的一致性和有效性。

(3)特征图重构
 

特征图重构是一种在通道剪枝中常用的方法,旨在最小化剪枝后特征图与原始特征图之间的差异。通过这种方式,我们可以更直接地控制剪枝的力度,并确保剪枝后的模型在性能上与原始模型尽可能接近。

下面是一个示例代码,展示了如何使用最小二乘法(linear least squares)来实现特征图重构,从而控制通道剪枝的力度。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(64)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.max_pool2d(x, 2)x = F.relu(self.bn2(self.conv2(x)))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')print("训练完成")# 特征图重构
def feature_map_reconstruction(model, train_loader, alpha=0.01):model.eval()original_features = []pruned_features = []# 收集原始特征图with torch.no_grad():for data, _ in train_loader:output = model(data)original_features.append(output)# 剪枝def prune_channels(model, sparsity_threshold):for module in model.modules():if isinstance(module, nn.BatchNorm2d):weights = module.weight.datamask = torch.abs(weights) > sparsity_thresholdmodule.weight.data = weights[mask]module.bias.data = module.bias.data[mask]module.num_features = int(torch.sum(mask))# 更新卷积层的输入通道数if hasattr(module, 'conv'):conv_module = getattr(module, 'conv')conv_module.out_channels = int(torch.sum(mask))conv_module.weight.data = conv_module.weight.data[mask]if conv_module.bias is not None:conv_module.bias.data = conv_module.bias.data[mask]# 设置稀疏性阈值sparsity_threshold = 0.01prune_channels(model, sparsity_threshold)# 收集剪枝后的特征图with torch.no_grad():for data, _ in train_loader:output = model(data)pruned_features.append(output)# 计算特征图差异original_features = torch.cat(original_features, dim=0)pruned_features = torch.cat(pruned_features, dim=0)diff = original_features - pruned_featuresloss = alpha * torch.norm(diff, p=2)# 反向传播和优化loss.backward()optimizer.step()print(f'Feature Map Reconstruction Loss: {loss.item()}')# 特征图重构
feature_map_reconstruction(model, train_loader)print("特征图重构完成")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/419452.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

还不懂BIO,NIO,AIO吗

BIO(Blocking I/O)、NIO(Non-blocking I/O)和 AIO(Asynchronous I/O)是 Java 中三种不同的 I/O 模型,主要用于处理输入 / 输出操作。 一、BIO(Blocking I/O) 定义与工作原…

ANSA联合ABAQS基于梁单元的螺栓预紧力分析实例

1、在螺栓孔之间创建一个模拟螺栓 ABAQUS界面→AUXILIARIES→bolt→分鳖选择上下两圈节点,这样在螺栓孔中间就会生成一个梁单元。 中键确定,因为螺杆使用的是变形体,所以接下来需要为其创建一个属性: 单击ok,完成虚拟螺栓的创建,该螺栓两端是刚性MPC,中间是弹性的梁单元…

美畅物联丨科技赋能校车安全:智慧监控管理系统的创新应用

1、背景 1.1应用需求 孩子,作为国家未来的希望之星和民族发展的潜力所在,其安全与健康向来都是社会瞩目的核心要点。校车,作为孩子们日常出行的关键交通载体,其安全性更是时刻牵动着每一个家庭的敏感神经。然而,不可…

利用TCP编程实现FTP功能

模拟FTP核心原理:客户端连接服务器后,向服务器发送一个文件。文件名可以通过参数指定,服务器端接收客户端传来的文件(文件名随意),如果文件不存在自动创建文件,如果文件存在,那么清空…

828华为云征文|使用sysbench对Mysql应用加速测评

文章目录 ❀前言❀测试环境准备❀测试工具选择❀测试工具安装❀mysql配置❀未开启Mysql加速测试❀开启Mysql加速测试❀总结 ❀前言 大家好,我是早九晚十二。 昨天有梳理一篇关于华为云最新推出的云服务器产品Flexus云服务器X。当时有说过,这次的华为云F…

【科研小白系列】使用screen创建虚拟终端,实现本地关机后服务器仍然跑模型

博主简介:努力学习的22级计算机科学与技术本科生一枚🌸博主主页: 是瑶瑶子啦往期回顾: 【科研小白系列】模型训练已经停止(强行中断)了,可GPU不释放显存,如何解决? 每日一言🌼: “生…

k8s网络

pod 网络 在K8S集群里,多个节点上的Pod相互通信,要通过网络插件来完成,比如Calico网络插件。 使用kubeadm初始化K8S集群时,有指定一个参数–pod-network-cidr10.18.0.0/16 它用来定义Pod的网段。 而我们在配置Calico的时候&#…

Trm理论 2(Word2Vec)

神经网络模型(NNLM)和Word2Vec NNLM模型是上次说过的模型,其目的是为了预测下一个词。 softmax(w2tanh(w1x b1)b2) 会得到一个副产品词向量 而Word2Vue就是专门求词向量的模型 softmax(w2*(w1*x b1)b2) Word2Vec softmax(w2*(w1*x b1)b…

jmeter性能测试HTML测试报告生成详解

作用:jmeter支持生成HTML测试报告,方便查看测试计划中获得图表和统计信息 命令: jmeter -n -t [jmx file] -l [result file] -e -o [html report folder] 示例:jmeter -n -t login.jmx -l result.jtl -e -o ./report jmx文件&a…

Gmsh:一个开源的三维有限元网格生成工具

Gmsh 是一个开源的三维有限元网格生成工具,主要用于在计算流体力学(CFD)和有限元分析(FEA)中生成复杂几何体的网格。它具有强大的几何建模、网格生成、求解器接口和后处理功能。Gmsh 适用于多种物理领域的模拟,包括流体力学、结构分析、电磁学等。 下载地址:https://gm…

【HarmonyOS】- 内存优化

文章目录 知识回顾前言源码分析1. onMemoryLevel2. 使用LRUCache优化ArkTS内存原理介绍3. 使用生命周期管理优化ArkTS内存4. 使用purgeable优化C++内存拓展知识1. Purgeable Memory总结知识回顾 前言 当应用程序占用过多内存时,系统可能会频繁进行内存回收和重新分配,导致应…

Java中Date类型上的注解

在日常开发中,涉及到日期时间类型Date和常用的注解DateTimeFormat和JsonFormat java.util.Date; org.springframework.format.annotation.DateTimeFormat; com.fasterxml.jackson.annotation.JsonFormat; 一 Date类型字段不使用注解 Data AllArgsConstructor N…

开源还是封闭?人工智能的两难选择

这篇文章于 2024 年 7 月 29 日首次出现在 The New Stack 上。人工智能正处于软件行业的完美风暴中,现在马克扎克伯格 (Mark Zuckerberg) 正在呼吁开源 AI。 关于如何控制 AI 的三个强大观点正在发生碰撞: 1 . 所有 AI 都应该是开…

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥…

【软件设计师真题】下午题第一大题---数据流图设计

解答数据流图的题目关键在于细心。 考试时一定要仔细阅读题目说明和给出的流程图。另外,解题时要懂得将说明和流程图进行对照,将父图和子图进行对照,切忌按照常识来猜测。同时应按照一定顺序考虑问题,以防遗漏,比如可以…

Einsum(Einstein summation convention)

Einsum(Einstein summation convention) 笔记来源: Permute和Reshape嫌麻烦?einsum来帮忙! The Einstein summation convention is a notational shorthand used in tensor calculus, particularly in the fields of …

[数据集][目标检测]西红柿缺陷检测数据集VOC+YOLO格式17318张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):17318 标注数量(xml文件个数):17318 标注数量(txt文件个数):17318 标…

张飞硬件11~19-电容篇笔记

电容作用 作为源,对后级电路提供能量,对源进行充电。简单讲就是放电和充电。在电路设计中,源往往与负载相隔很远,增加电容就可以起到稳定作用。电容两端的电压不能激变,增加电容可以稳定电压。 电容可以类比为水坝&a…

(javaweb)mysql---DDL

一.数据模型,数据库操作 1.二维表:有行有列 2. 3.客户端连接数据库,发送sql语句给DBMS(数据库管理系统),DBMS创建--以文件夹显示 二.表结构操作--创建 database和schema含义一样。 这样就显示出了之前的内容…

系统编程--线程

这里写目录标题 线程概念什么是线程简介图解 内核原理图解 线程共享资源与非共享资源共享资源非共享资源 线程优缺点 线程控制原语pthread_self、pthread_create简介代码总结 循环创建多个子线程错误代码 线程间全局变量共享pthread_exit简介代码 pthread_join(回收…