笔记小结:现代卷积神经网络之批量归一化

本文为李沐老师《动手学深度学习》笔记小结,用于个人复习并记录学习历程,适用于初学者

训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。 本节将介绍批量规范化(batch normalization),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。

从零开始实现

张量的批量规范化函数
import torch
from torch import nndef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data
批量规范化层
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y
使用批量规范化层作用于LeNet

批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))
训练
准备工作

和之前多篇文章中提到的一样,不再赘述,只给出代码

from IPython import display
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as pltdef load_data_fashion_mnist(batch_size, resize=None): """下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=0)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=0)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))def get_dataloader_workers():  """使用4个进程来读取数据"""return 4batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1) #找出输入张量(tensor)中最大值的索引cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]import matplotlib.pyplot as plt
from matplotlib_inline import backend_inlinedef use_svg_display(): """使⽤svg格式在Jupyter中显⽰绘图"""backend_inline.set_matplotlib_formats('svg')def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置matplotlib的轴"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()class Animator:  #@save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []use_svg_display()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]import time
class Timer:  #@save"""记录多次运行时间"""def __init__(self):self.times = []self.start()def start(self):"""启动计时器"""self.tik = time.time()def stop(self):"""停止计时器并将时间记录在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均时间"""return sum(self.times) / len(self.times)def sum(self):"""返回时间总和"""return sum(self.times)def cumsum(self):"""返回累计时间"""return np.array(self.times).cumsum().tolist()def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')def try_gpu(i=0):  #@save"""如果存在,则返回gpu(i),否则返回cpu()"""if torch.cuda.device_count() >= i + 1:return torch.device(f'cuda:{i}')return torch.device('cpu')
训练

和以前一样,我们将在Fashion-MNIST数据集上训练网络。 这个代码与我们第一次训练LeNet时几乎完全相同,主要区别在于学习率大得多。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()
print(end - begin)

这个结果,对比当时不用批量归一化层的LeNet,训练的收敛速度快了许多,loss变小了,train acc提高了许多,但是test acc没有提高太多,出现了过拟合。 

简洁实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))

下面,我们使用相同超参数来训练模型。 请注意,通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()

从结果可以看到,运行速度快了,并且过拟合也小了许多。 

小结

  • 在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量规范化在全连接层和卷积层的使用略有不同。
  • 批量规范化层和暂退层一样,在训练模式和预测模式下计算不同。
  • 批量规范化有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/382496.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dav_笔记10:Using SQL Plan Management之3

将SQL计划基准与SQL Tuning Advisor一起使用 使用SQL Tuning Advisor调整SQL语句时,如果顾问程序找到调优计划并验证其性能优于从相应SQL计划基准中选择的计划,则建议接受SQL配置文件。 接受SQL配置文件后,数据库会将调整后的计划添加到相应…

群管机器人官网源码

一款非常好看的群管机器人html官网源码 搭建教程: 域名解析绑定 源码文件上传解压 访问域名即可 演示图片: 群管机器人官网源码下载:客户端下载 - 红客网络编程与渗透技术 原文链接: 群管机器人官网源码

云仓如何改变传统仓储模式?

云仓,即云仓储,是一种基于互联网技术的现代仓储模式,与传统的仓储模式相比,它在多个方面进行了创新和优化,包括: ———————————————————— 1、数据管理与实时监控: 云仓储利…

Element-ui :el-table 中表尾合计行

Table 表格 | Element Plus <template><el-table :data"tableData" border show-summary :summary-method"getSummariesss" style"width: 100%"><el-table-column prop"id" label"ID" width"180"…

C++与lua联合编程

C与lua联合编程 一、环境配置二、lua基本语法1.第一个lua和C程序2.基本数据类型和变量2.1 Nil2.2 Booleans2.3 Numbers2.4 String(最常用) 3. 字符串处理3.1 错误处理3.2 字符串长度:string.len3.3 字符串子串 :string.sub3.4 字符串查找: string.find3.5字符串替换: string.gs…

【笔记】ubuntu 误退了搜狗输入法:终端上输入fcitx即可重启

有时候&#xff0c;我们可能嫌弃ubuntu上的搜狗输入法&#xff0c;点击了退出&#xff1a; 但是当我们想开启搜狗输入法时&#xff0c;发现它消失了&#xff0c;此时我们可以打开终端&#xff0c;键入&#xff1a; fcitx 即可成功开启。

一些和颜色相关网站

1.中国传统色 2.网页颜色选择器 3.渐变色网站 4.多风格色卡生成 5.波浪生成 6.半透明磨砂框 7.色卡组合

全国区块链职业技能大赛国赛考题前端功能开发

任务3-1:区块链应用前端功能开发 1.请基于前端系统的开发模板,在登录组件login.js、组件管理文件components.js中添加对应的逻辑代码,实现对前端的角色选择功能,并测试功能完整性,示例页面如下: 具体要求如下: (1)有明确的提示,提示用户选择角色; (2)用户可看…

微服务安全——OAuth2详解、授权码模式、SpringAuthorizationServer实战、SSO单点登录、Gateway整合OAuth2

文章目录 Spring Authorization Server介绍OAuth2.0协议介绍角色OAuth2.0协议的运行流程应用场景授权模式详解客户端模式密码模式授权码模式简化模式token刷新模式 OAuth 2.1 协议介绍授权码模式PKCE扩展设备授权码模式拓展授权模式 OpenID Connect 1.0协议Spring Authorizatio…

用51单片机或者stm32能否开发机器人呢?

在开始前刚好我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“888”之后私信回复“888”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&#xff01;能的。但是由于单片机和st…

百度,有道,谷歌翻译API

API翻译 百度&#xff0c;有道&#xff0c;谷歌API翻译&#xff08;只针对中英相互翻译&#xff09;,其他语言翻译需要对应from&#xff0c;to的code 百度翻译 package fills.tools.translate; import java.util.ArrayList; import java.util.HashMap; import java.util.Lis…

OpenWrt 为软件包和docker空间扩容

参考资料 【openwrt折腾日记】解决openwrt固件刷入后磁盘空间默认小的问题&#xff0c;关联openwrt磁盘扩容空间扩容【openwrt分区扩容】轻松解决空间可用不足的尴尬丨老李一瓶奶油的YouTube 划分空间 参考一瓶奶油的YouTube 系统 -> 磁盘管理 -> 磁盘 -> 修改 格…

axios请求大全

本文讲解axios封装方式以及针对各种后台接口的请求方式 axios的介绍和基础配置可以看这个文档: 起步 | Axios中文文档 | Axios中文网 axios的封装 axios封装的重点有三个&#xff0c;一是设置全局config,比如请求的基础路径&#xff0c;超时时间等&#xff0c;第二点是在每次…

SQL labs-SQL注入(二)

环境搭建参考 SQL注入&#xff08;一&#xff09; 一&#xff0c;SQL labs-less2。 http://192.168.61.206:8001/Less-2/?id-1 union select 1,2,group_concat(username , password) from users-- 与第一关没什么太大的不同&#xff0c;唯一区别就是闭合方式为数字型。 二…

【ffmpeg命令入门】ffplay常用命令

文章目录 前言ffplay的简介FFplay 的基本用法常用参数及其作用示例 效果演示图播放普通视频播放网络媒体流RTSP 总结 前言 FFplay 是 FFmpeg 套件中的一个强大的媒体播放器&#xff0c;它基于命令行接口&#xff0c;允许用户以灵活且高效的方式播放音频和视频文件。作为一个简…

系统架构设计师教程 第4章 信息安全技术基础知识-4.3 信息安全系统的组成框架4.4 信息加解密技术-解读

系统架构设计师教程 第4章 信息安全技术基础知识-4.3 信息安全系统的组成框架 4.3 信息安全系统的组成框架4.3.1 技术体系4.3.1.1 基础安全设备4.3.1.2 计算机网络安全4.3.1.3 操作系统安全4.3.1.4 数据库安全4.3.1.5 终端安全设备4.3.2 组织机构体系4.3.3 管理体系4.4 信息加…

使用 Socket和动态代理以及反射 实现一个简易的 RPC 调用

使用 Socket、动态代理、反射 实现一个简易的 RPC 调用 我们前面有一篇 socket 的文章&#xff0c;再之前&#xff0c;还有一篇 java动态代理的文章&#xff0c;本文用到了那两篇文章中的知识点&#xff0c;需要的话可以回顾一下。 下面正文开始&#xff1a; 我们的背景是一个…

CSS实现的扫光效果组件

theme: lilsnake 图片和内容如有侵权&#xff0c;及时与我联系~ 详细内容与注释&#xff1a; CSS实现的扫光效果组件 代码 技术栈与框架 Vue3 CSS 扫光效果的原理 扫光效果的原理就是从左到右无限循环的一个位移动画 实现方式 适配文字扫光效果的css .shark-box { …

【devops】gitlab 实现cicd 实践

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…

npm 安装报错(已解决)+ 运行 “wue-cli-service”不是内部或外部命令,也不是可运行的程序(已解决)

首先先说一下我这个项目是3年前的一个项目了&#xff0c;中间也是经过了多个人的修改惨咋了布置多少个人的思想&#xff0c;这这道我手里直接npm都安装不上&#xff0c;在网上也查询了多种方法&#xff0c;终于是找到问题所在了 问题1&#xff1a; 先是npm i 报错在下面图片&…