(1)通道选择
这里要先解释一下:
(1)通道剪枝
那我们实际做法不是上面直接对所有层都添加L1正则项,而是仅仅对BN层权重添加L1正则项。通道剪枝具体步骤如下:
1.BN层权重添加L1正则项,进行稀疏训练
2.对BN层权重的scale factor进行排序,对scale factor低于阈值的通道进行裁剪,得到剪枝模型
3.对剪枝模型进行finetune
注:进行finetune的目的是因为剪枝完整个网络结构发生了变化,之前的训练的模型无法再加载进入,必须要finetune(或者这里用重新训练更合适),否则会发现推理结果都是0.
在深度学习中,Batch Normalization(BN)层通常用于加速训练过程并提高模型的泛化能力。BN层的权重参数包括scale factor(缩放因子)和shift factor(偏移因子)。通过对BN层的scale factor添加L1正则化,我们可以实现通道剪枝。
下面是一个示例代码,展示了如何对BN层的scale factor添加L1正则化,并进行通道剪枝和微调(finetune)。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(64)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.max_pool2d(x, 2)x = F.relu(self.bn2(self.conv2(x)))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# L1正则化参数
lambda_l1 = 0.001# 稀疏训练
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)# 计算BN层scale factor的L1正则化项l1_regularization = 0for module in model.modules():if isinstance(module, nn.BatchNorm2d):l1_regularization += torch.norm(module.weight, p=1)loss += lambda_l1 * l1_regularizationloss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')print("稀疏训练完成")# 通道剪枝
def prune_channels(model, sparsity_threshold):for module in model.modules():if isinstance(module, nn.BatchNorm2d):weights = module.weight.datamask = torch.abs(weights) > sparsity_thresholdmodule.weight.data = weights[mask]module.bias.data = module.bias.data[mask]module.num_features = int(torch.sum(mask))# 更新卷积层的输入通道数if hasattr(module, 'conv'):conv_module = getattr(module, 'conv')conv_module.out_channels = int(torch.sum(mask))conv_module.weight.data = conv_module.weight.data[mask]if conv_module.bias is not None:conv_module.bias.data = conv_module.bias.data[mask]# 设置稀疏性阈值
sparsity_threshold = 0.01# 剪枝
prune_channels(model, sparsity_threshold)print("通道剪枝完成")# 微调
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()print(f'Finetune Epoch {epoch + 1}, Loss: {loss.item()}')print("微调完成")
(2)卷积核剪枝
1.conv层权重添加L1正则项,进行稀疏训练
2.对conv层权重进行排序,对权重低于阈值的卷积核进行裁剪,得到剪枝模型
3.对剪枝模型进行finetune
下面我写了一个简单的示例代码,展示了如何在训练过程中计算权重的稀疏性,并根据稀疏性剪掉通道。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# L1正则化参数
lambda_l1 = 0.001# 训练模型
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)# 计算L1正则化项l1_regularization = 0for param in model.parameters():l1_regularization += torch.norm(param, p=1)loss += lambda_l1 * l1_regularizationloss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')print("训练完成")# 根据稀疏性剪掉通道
def prune_channels(model, sparsity_threshold):for name, module in model.named_modules():if isinstance(module, nn.Conv2d):weights = module.weight.dataabs_weights = torch.abs(weights)channel_sums = torch.sum(abs_weights, dim=(1, 2, 3))mask = channel_sums > sparsity_thresholdmodule.weight.data = weights[mask]module.out_channels = int(torch.sum(mask))if module.bias is not None:module.bias.data = module.bias.data[mask]# 设置稀疏性阈值
sparsity_threshold = 0.01# 剪枝
prune_channels(model, sparsity_threshold)print("剪枝完成")
在上面例子中,我们在训练完成后,通过 prune_channels
函数根据稀疏性剪掉通道。
具体步骤如下:
-
计算权重的稀疏性:对于每个卷积层的权重,我们计算每个通道的权重绝对值之和。
-
剪枝:根据设定的稀疏性阈值,我们创建一个掩码(mask),只保留那些权重绝对值之和大于阈值的通道,并更新卷积层的权重和偏置。
通过这种方式,我们可以根据权重的稀疏性剪掉不重要的通道,从而减少模型的复杂度和计算量。
通道剪枝和卷积核剪枝小结:
卷积核剪枝(Kernel Pruning)和通道剪枝(Channel Pruning)是两种不同的模型剪枝技术,它们在剪枝的对象和目标上有所区别。
卷积核剪枝(Kernel Pruning)
卷积核剪枝 是指从卷积层中移除整个卷积核(kernel)。一个卷积核通常由一组权重组成,这些权重在卷积操作中与输入特征图的局部区域进行卷积运算。卷积核剪枝的目标是移除那些对模型性能贡献较小的卷积核,从而减少模型的计算量和参数数量。
-
剪枝对象:卷积核(kernel)。
-
剪枝目标:移除整个卷积核。
-
影响:减少卷积层的输出通道数。
通道剪枝(Channel Pruning)
通道剪枝 是指从卷积层或全连接层中移除整个通道(channel)。一个通道通常由一组权重组成,这些权重在卷积操作中与输入特征图的所有位置进行卷积运算。通道剪枝的目标是移除那些对模型性能贡献较小的通道,从而减少模型的计算量和参数数量。
-
剪枝对象:通道(channel)。
-
剪枝目标:移除整个通道。
-
影响:减少卷积层的输入或输出通道数。
主要区别
-
剪枝对象:
-
卷积核剪枝针对的是卷积核,即卷积层中的单个权重组。
-
通道剪枝针对的是通道,即卷积层或全连接层中的整个权重集合。
-
-
剪枝目标:
-
卷积核剪枝的目标是移除整个卷积核。
-
通道剪枝的目标是移除整个通道。
-
-
影响:
-
卷积核剪枝主要影响卷积层的输出通道数。
-
通道剪枝既可以影响卷积层的输入通道数,也可以影响输出通道数。
-
卷积核剪枝代码:
def prune_kernels(model, sparsity_threshold):for name, module in model.named_modules():if isinstance(module, nn.Conv2d):weights = module.weight.dataabs_weights = torch.abs(weights)kernel_sums = torch.sum(abs_weights, dim=(1, 2, 3))mask = kernel_sums > sparsity_thresholdmodule.weight.data = weights[mask]module.out_channels = int(torch.sum(mask))if module.bias is not None:module.bias.data = module.bias.data[mask]
通道剪枝代码:
def prune_channels(model, sparsity_threshold):#遍历模型中的所有模块for module in model.modules():#检查模块是否为BN层if isinstance(module, nn.BatchNorm2d):#获取BN层的权重weights = module.weight.data#根据稀疏性阈值创建掩码mask = torch.abs(weights) > sparsity_threshold#应用掩码到BN层的权重和偏置module.weight.data = weights[mask]module.bias.data = module.bias.data[mask]module.num_features = int(torch.sum(mask))#检查BN层是否有与之关联的卷积层if hasattr(module, 'conv'):conv_module = getattr(module, 'conv')#应用掩码到卷积层的权重和偏置conv_module.out_channels = int(torch.sum(mask))conv_module.weight.data = conv_module.weight.data[mask]if conv_module.bias is not None:conv_module.bias.data = conv_module.bias.data[mask]
在通道剪枝中,我们不仅需要剪枝Batch Normalization(BN)层的权重,还需要相应地剪枝与之关联的卷积层的权重。具体来说,BN层的权重(scale factor)决定了哪些通道是重要的,因此我们需要根据BN层的权重来剪枝卷积层的通道。
通过这种方式,我们确保了BN层的剪枝与卷积层的剪枝是一致的,即剪枝后的BN层和卷积层具有相同的通道数。这样可以保证模型在剪枝后的结构是有效的,并且能够正常工作。
总结来说,通道剪枝不仅涉及BN层的权重剪枝,还涉及与之关联的卷积层的权重剪枝,以确保剪枝后的模型结构的一致性和有效性。
(3)特征图重构
特征图重构是一种在通道剪枝中常用的方法,旨在最小化剪枝后特征图与原始特征图之间的差异。通过这种方式,我们可以更直接地控制剪枝的力度,并确保剪枝后的模型在性能上与原始模型尽可能接近。
下面是一个示例代码,展示了如何使用最小二乘法(linear least squares)来实现特征图重构,从而控制通道剪枝的力度。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(64)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.max_pool2d(x, 2)x = F.relu(self.bn2(self.conv2(x)))x = F.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
for epoch in range(10):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()print(f'Epoch {epoch + 1}, Loss: {loss.item()}')print("训练完成")# 特征图重构
def feature_map_reconstruction(model, train_loader, alpha=0.01):model.eval()original_features = []pruned_features = []# 收集原始特征图with torch.no_grad():for data, _ in train_loader:output = model(data)original_features.append(output)# 剪枝def prune_channels(model, sparsity_threshold):for module in model.modules():if isinstance(module, nn.BatchNorm2d):weights = module.weight.datamask = torch.abs(weights) > sparsity_thresholdmodule.weight.data = weights[mask]module.bias.data = module.bias.data[mask]module.num_features = int(torch.sum(mask))# 更新卷积层的输入通道数if hasattr(module, 'conv'):conv_module = getattr(module, 'conv')conv_module.out_channels = int(torch.sum(mask))conv_module.weight.data = conv_module.weight.data[mask]if conv_module.bias is not None:conv_module.bias.data = conv_module.bias.data[mask]# 设置稀疏性阈值sparsity_threshold = 0.01prune_channels(model, sparsity_threshold)# 收集剪枝后的特征图with torch.no_grad():for data, _ in train_loader:output = model(data)pruned_features.append(output)# 计算特征图差异original_features = torch.cat(original_features, dim=0)pruned_features = torch.cat(pruned_features, dim=0)diff = original_features - pruned_featuresloss = alpha * torch.norm(diff, p=2)# 反向传播和优化loss.backward()optimizer.step()print(f'Feature Map Reconstruction Loss: {loss.item()}')# 特征图重构
feature_map_reconstruction(model, train_loader)print("特征图重构完成")