【深度学习实验】网络优化与正则化(三):随机梯度下降的改进——Adam算法详解(Adam≈梯度方向优化Momentum+自适应学习率RMSprop)

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 导入必要的库
    • 1. 随机梯度下降SGD算法
      • a. PyTorch中的SGD优化器
      • b. 使用SGD优化器的前馈神经网络
    • 2.随机梯度下降的改进方法
      • a. 学习率调整
      • b. 梯度估计修正
    • 3. 梯度估计修正:动量法Momentum
    • 4. 自适应学习率
      • RMSprop算法
    • 5. Adam算法
      • 更新公式
      • 算法实现
      • 算法测试
    • 6. 代码整合

  任何数学技巧都不能弥补信息的缺失。
——科尼利厄斯·兰佐斯(Cornelius Lanczos)匈牙利数学家、物理学家

一、实验介绍

  深度神经网络在机器学习中应用时面临两类主要问题:优化问题和泛化问题。

  • 优化问题:深度神经网络的优化具有挑战性。

    • 神经网络的损失函数通常是非凸函数,因此找到全局最优解往往困难。
    • 深度神经网络的参数通常非常多,而训练数据也很大,因此使用计算代价较高的二阶优化方法不太可行,而一阶优化方法的训练效率通常较低。
    • 深度神经网络存在梯度消失梯度爆炸问题,导致基于梯度的优化方法经常失效。
  • 泛化问题:由于深度神经网络的复杂度较高且具有强大的拟合能力,很容易在训练集上产生过拟合现象。因此,在训练深度神经网络时需要采用一定的正则化方法来提高网络的泛化能力。

  目前,研究人员通过大量实践总结了一些经验方法,以在神经网络的表示能力、复杂度、学习效率和泛化能力之间取得良好的平衡,从而得到良好的网络模型。本系列文章将从网络优化和网络正则化两个方面来介绍如下方法:

  • 在网络优化方面,常用的方法包括优化算法的选择参数初始化方法数据预处理方法逐层归一化方法超参数优化方法
  • 在网络正则化方面,一些提高网络泛化能力的方法包括ℓ1和ℓ2正则化权重衰减提前停止丢弃法数据增强标签平滑等。

  本文将介绍基于自适应学习率的优化算法:Adam算法详解(Adam≈梯度方向优化Momentum+自适应学习率RMSprop)

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的库

import torch
import torch.nn.functional as F
from d2l import torch as d2l
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader

1. 随机梯度下降SGD算法

  随机梯度下降(Stochastic Gradient Descent,SGD)是一种常用的优化算法,用于训练深度神经网络。在每次迭代中,SGD通过随机均匀采样一个数据样本的索引,并计算该样本的梯度来更新网络参数。具体而言,SGD的更新步骤如下:

  1. 从训练数据中随机选择一个样本的索引。
  2. 使用选择的样本计算损失函数对于网络参数的梯度。
  3. 根据计算得到的梯度更新网络参数。
  4. 重复以上步骤,直到达到停止条件(如达到固定的迭代次数或损失函数收敛)。

a. PyTorch中的SGD优化器

   Pytorch官方教程

optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

b. 使用SGD优化器的前馈神经网络

   【深度学习实验】前馈神经网络(final):自定义鸢尾花分类前馈神经网络模型并进行训练及评价

2.随机梯度下降的改进方法

  传统的SGD在某些情况下可能存在一些问题,例如学习率选择困难和梯度的不稳定性。为了改进这些问题,提出了一些随机梯度下降的改进方法,其中包括学习率的调整和梯度的优化。

a. 学习率调整

在这里插入图片描述

  • 学习率衰减(Learning Rate Decay):随着训练的进行,逐渐降低学习率。常见的学习率衰减方法有固定衰减、按照指数衰减、按照时间表衰减等。
  • Adagrad:自适应地调整学习率。Adagrad根据参数在训练过程中的历史梯度进行调整,对于稀疏梯度较大的参数,降低学习率;对于稀疏梯度较小的参数,增加学习率。这样可以在不同参数上采用不同的学习率,提高收敛速度。
  • Adadelta:与Adagrad类似,但进一步解决了Adagrad学习率递减过快的问题。Adadelta不仅考虑了历史梯度,还引入了一个累积的平方梯度的衰减平均,以动态调整学习率。
  • RMSprop:也是一种自适应学习率的方法,通过使用梯度的指数加权移动平均来调整学习率。RMSprop结合了Adagrad的思想,但使用了衰减平均来减缓学习率的累积效果,从而更加稳定。

b. 梯度估计修正

  • Momentum:使用梯度的“加权移动平均”作为参数的更新方向。Momentum方法引入了一个动量项,用于加速梯度下降的过程。通过积累之前的梯度信息,可以在更新参数时保持一定的惯性,有助于跳出局部最优解、加快收敛速度。
  • Nesterov accelerated gradient:Nesterov加速梯度(NAG)是Momentum的一种变体。与Momentum不同的是,NAG会先根据当前的梯度估计出一个未来位置,然后在该位置计算梯度。这样可以更准确地估计当前位置的梯度,并且在参数更新时更加稳定。
  • 梯度截断(Gradient Clipping):为了应对梯度爆炸或梯度消失的问题,梯度截断的方法被提出。梯度截断通过限制梯度的范围,将梯度控制在一个合理的范围内。常见的梯度截断方法有阈值截断和梯度缩放。

3. 梯度估计修正:动量法Momentum

【深度学习实验】网络优化与正则化(一):优化算法:使用动量优化的随机梯度下降算法(Stochastic Gradient Descent with Momentum)

def init_momentum_states(feature_dim):v_w = torch.zeros((feature_dim, 3))v_b = torch.zeros(3)return (v_w, v_b)def sgd_momentum(params, states, hyperparams):for p, v in zip(params, states):with torch.no_grad():v[:] = hyperparams['momentum'] * v + p.gradp[:] -= hyperparams['lr'] * vp.grad.data.zero_()

4. 自适应学习率

【深度学习实验】网络优化与正则化(二):基于自适应学习率的优化算法详解:Adagrad、Adadelta、RMSprop

RMSprop算法

   RMSprop(Root Mean Square Propagation)算法是一种针对Adagrad算法的改进方法,通过引入衰减系数来平衡历史梯度和当前梯度的贡献。它能够更好地适应不同参数的变化情况,对于非稀疏数据集表现较好。

在这里插入图片描述

def init_rmsprop_states(feature_dim):s_w = torch.zeros((feature_dim, 3))s_b = torch.zeros(3)return (s_w, s_b)def rmsprop(params, states, hyperparams):gamma, eps = hyperparams['gamma'], 1e-6for p, s in zip(params, states):with torch.no_grad():s[:] = gamma * s + (1 - gamma) * torch.square(p.grad)p[:] -= hyperparams['lr'] * p.grad / torch.sqrt(s + eps)p.grad.data.zero_()

5. Adam算法

  Adam算法(Adaptive Moment Estimation Algorithm)[Kingma et al., 2015]可以看作动量法和 RMSprop 算法的结合,不但使用动量作为参数更新方向,而且可以自适应调整学习率

更新公式

在这里插入图片描述

算法实现

def init_adam_states(feature_dim):v_w, v_b = torch.zeros((feature_dim, 3)), torch.zeros(3)s_w, s_b = torch.zeros((feature_dim, 3)), torch.zeros(3)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):with torch.no_grad():v[:] = beta1 * v + (1 - beta1) * p.grads[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)p.grad.data.zero_()hyperparams['t'] += 1

  init_adam_states函数用于初始化Adam优化算法的状态。它接受一个特征维度feature_dim作为输入,并返回包含权重和偏置项的状态变量((v_w, s_w), (v_b, s_b))。这些状态变量用于存储权重和偏置项的一阶矩估计(动量)和二阶矩估计(RMSProp)。

  adam函数是Adam优化算法的主要实现部分,它接受三个参数:params(待优化的参数),states(状态变量),和hyperparams(超参数)。

  • 在函数内部,使用一个循环来遍历待优化的参数params和对应的状态变量states,然后根据Adam算法的更新规则,对每个参数进行更新:
    • 在更新过程中,使用torch.no_grad()上下文管理器,表示在更新过程中不会计算梯度。
      • 根据Adam算法的公式,计算动量和二阶矩估计的更新值,并将其累加到对应的状态变量中。
      • 根据偏差修正公式,计算修正后的动量和二阶矩估计。
      • 根据修正后的动量和二阶矩估计,计算参数的更新量,并将其应用到参数上。
    • 使用p.grad.data.zero_()将参数的梯度清零,以便下一次迭代时重新计算梯度。
  • 在代码的最后,hyperparams['t'] += 1用于更新迭代次数t的计数器。

算法测试

# batch_size = 1
batch_size = 24
# batch_size = 120# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)lr = 0.02
train(adam, init_adam_states(4), {'lr': lr, 't': 1}, train_loader, 4)
  • IrisDataset类:

    • 参照前文:【深度学习实验】前馈神经网络(七):批量加载数据(直接加载数据→定义类封装数据)
  • train函数:

    • 参照前文:【深度学习实验】网络优化与正则化(一):优化算法:使用动量优化的随机梯度下降算法(Stochastic Gradient Descent with Momentum)

6. 代码整合

import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from d2l import torch as d2l
from sklearn.datasets import load_iris
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoaderclass FeedForward(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(FeedForward, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)self.act = nn.Sigmoid()def forward(self, inputs):outputs = self.fc1(inputs)outputs = self.act(outputs)outputs = self.fc2(outputs)return outputsdef evaluate_loss(net, data_iter, loss):"""评估给定数据集上模型的损失Defined in :numref:`sec_model_selection`"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:X = X.to(torch.float32)out = net(X)#         y = d2l.reshape(y, out.shape)l = loss(out, y.long())metric.add(d2l.reduce_sum(l), d2l.size(l))return metric[0] / metric[1]def train(trainer_fn, states, hyperparams, data_iter, feature_dim, num_epochs=2):"""Defined in :numref:`sec_minibatches`"""# 初始化模型w = torch.normal(mean=0.0, std=0.01, size=(feature_dim, 3),requires_grad=True)b = torch.zeros((3), requires_grad=True)# 训练模型animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[0, num_epochs], ylim=[0.9, 1.1])n, timer = 0, d2l.Timer()# 这是一个单层线性层net = lambda X: d2l.linreg(X, w, b)loss = F.cross_entropyfor _ in range(num_epochs):for X, y in data_iter:X = X.to(torch.float32)l = loss(net(X), y.long()).mean()l.backward()trainer_fn([w, b], states, hyperparams)n += X.shape[0]if n % 48 == 0:timer.stop()animator.add(n / X.shape[0] / len(data_iter),(evaluate_loss(net, data_iter, loss),))timer.start()print(f'loss: {animator.Y[0][-1]:.3f}, {timer.avg():.3f} sec/epoch')return timer.cumsum(), animator.Y[0]def load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, yclass IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset, self).__init__()x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)def init_adam_states(feature_dim):v_w, v_b = torch.zeros((feature_dim, 3)), torch.zeros(3)s_w, s_b = torch.zeros((feature_dim, 3)), torch.zeros(3)return ((v_w, s_w), (v_b, s_b))def adam(params, states, hyperparams):beta1, beta2, eps = 0.9, 0.999, 1e-6for p, (v, s) in zip(params, states):with torch.no_grad():v[:] = beta1 * v + (1 - beta1) * p.grads[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)# 偏差修正,请在下方实现偏差修正的计算公式v_bias_corr = v / (1 - beta1 ** hyperparams['t'])s_bias_corr = s / (1 - beta2 ** hyperparams['t'])p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)p.grad.data.zero_()hyperparams['t'] += 1# batch_size = 1
batch_size = 24
# batch_size = 120# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)lr = 0.02
train(adam, init_adam_states(4), {'lr':lr, 't':1}, train_loader, 4, num_epochs=200)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/190426.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32--系统滴答SysTick

一、SysTick是什么? Systick定时器是一个24bit的倒计时(向下计数)定时器,功能就是实现简单的延时。 SysTick 是一种系统定时器,通常在嵌入式系统中使用。它是 ARM Cortex-M 处理器的一个特殊定时器,用于提…

基于Qt 多线程(继承自QThread篇)

# 简介 我们写的一个应用程序,应用程序跑起来后一般情况下只有一个线程,但是可能也有特殊情况。比如我们前面章节写的例程都跑起来后只有一个线程,就是程序的主线程。线程内的操作都是顺序执行的。恩,顺序执行?试着想一下,我们的程序顺序执行,假设我们的用户界面点击有某…

JavaScript_动态表格_删除功能

1、动态表格_删除功能 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>动态表格_添加和删除功能</title><style>table{border: 1px solid;margin: auto;width: 100%;}td,th{text-align: …

网络渗透测试(被动扫描)

被动扫描 主要是指的是在目标无法察觉的情况下进行信息搜集。在Google上进行人名的搜素就是一次被动扫描。最经典的被动扫描技术就是"Google Hacking"技术。由于Google退出中国&#xff0c;暂时无法使用。在此介绍三个优秀的信息搜集工具 被动扫描范围 1.企业网络…

Kafka入门

kafka无疑是当今互联网公司使用最广泛的分布式实时消息流系统&#xff0c;它的高吞吐量&#xff0c;高可靠等特点为并发下的大批量实时请求处理提供了可靠保障。很多同学在项目中都用到过kafka&#xff0c;但是对kafka的设计原理以及处理机制并不是十分清楚。为了知其然知其所以…

k8s-docker二进制(1.28)的搭建

二进制文件-docker方式 1、准备的服务器 角色ip组件k8s-master1192.168.11.111kube-apiserver,kube-controller-manager,kube-scheduler,etcdk8s-master2192.168.11.112kube-apiserver,kube-controller-manager,kube-scheduler,etcdk8s-node1192.168.11.113kubelet,kube-prox…

Presentation Prompter 5.4.2(mac屏幕提词器)

Presentation Prompter是一款演讲辅助屏幕提词器软件&#xff0c;旨在帮助演讲者在公共演讲、主持活动或录制视频时更加流畅地进行演讲。以下是Presentation Prompter的一些特色功能&#xff1a; 提供滚动或分页显示&#xff1a;可以将演讲稿以滚动或分页的形式显示在屏幕上&a…

client-go controller-runtime kubebuilder

背景 这半年一直做k8s相关的工作&#xff0c;一直接触client-go controller-runtime kubebuilder&#xff0c;但是很少有文章将这三个的区别说明白&#xff0c;直接用框架是简单&#xff0c;但是出了问题就是黑盒&#xff0c;这不符合我的理念&#xff0c;所以这篇文章从头说起…

【Android】画面卡顿优化列表流畅度三之RecyclerView刷新机制notifyItemRangeInserted

经过长达一个多星期的反复渲染耗时记录&#xff0c;大致上有以下几个方面的地方可以优化&#xff1a; 列表组件RecyclerView刷新机制由notifyDataSetChanged()优化为notifyItemRangeInserted&#xff08;&#xff09;&#xff0c;后期有必要也会使用notifyItemRangeRemoved、n…

uniapp发行web页面在老版本浏览器打开一片空白

uniapp发行的web页面&#xff08;菜单->发行->网站-PC Web或手机H5&#xff09;&#xff0c;对于一些老的浏览器&#xff08;或内核&#xff09;&#xff0c;打开一片空白&#xff1b; 而在新版本的浏览器中打开却正常。这是因为那些版本较低的浏览器不支持ES6的语法和新…

mapboxGL中的底图切换

概述 底图切换&#xff0c;这么简单的功能还要写一篇文章&#xff1f;值得的&#xff0c;为什么这么说呢&#xff1f;因为mapboxGL的矢量底图有上百个&#xff0c;不同的底图用的样式、图层的名称、图层的内容、字体库、图标库都不一样&#xff0c;尤其是当地图上已经叠加了很…

vue 使用js new Map()优化多个if else 执行方法

前言 在实际开发中根据业务需求我们经常要判断情况&#xff0c;一个if 我们科技直接使用ES6就可以解决 经常会出现根据不同的条件执行不同的方法&#xff0c;这是就会有多个if else 看起不太美观也费劲 js new map &#xff08;&#xff09;就可以解决这个问题&#xff0c;它…

linux下搭建gperftools工具分析程序瓶颈

1. 先安装 unwind //使用root wget https://github.com/libunwind/libunwind/archive/v0.99.tar.gz tar -xvf v0.99.tar.gz cd libunwind-0.99 autoreconf --force -v --install ./configure make sudo make install2. 安装gperftools wget https://github.com/gp…

【无标题(PC+WAP)花卉租赁盆栽绿植类pbootcms站模板

(PCWAP)花卉租赁盆栽绿植类pbootcms网站模板 PbootCMS内核开发的网站模板&#xff0c;该模板适用于盆栽绿植网站等企业&#xff0c;当然其他行业也可以做&#xff0c;只需要把文字图片换成其他行业的即可&#xff1b; PCWAP&#xff0c;同一个后台&#xff0c;数据即时同步&…

纯c语言模拟栈和队列(初学必看)

一、栈(Stack) 1.栈的概念及其结构 栈是一种特殊的线性表&#xff0c;在栈这个结构里&#xff0c;越先存进去的数据越难取出来。 这个结构就像是一个只有一端有打开的容器&#xff0c;越先放进去的球越在底部&#xff0c;想要把底部的球拿出来&#xff0c;就必须先把前面的求…

【Pytest】跳过执行之@pytest.mark.skip()详解

一、skip介绍及运用 在我们自动化测试过程中&#xff0c;经常会遇到功能阻塞、功能未实现、环境等一系列外部因素问题导致的一些用例执行不了&#xff0c;这时我们就可以用到跳过skip用例&#xff0c;如果我们注释掉或删除掉&#xff0c;后面还要进行恢复操作。 1、skip跳过成…

Spring IOC - Bean的生命周期之实例化

在Spring启动流程文章中讲到&#xff0c;容器的初始化是从refresh方法开始的&#xff0c;其在初始化的过程中会调用finishBeanFactoryInitialization方法。 而在该方法中则会调用DefaultListableBeanFactory#preInstantiateSingletons方法&#xff0c;该方法的核心作用是初始化…

元核云亮相金博会,智能质检助力金融合规

11月初&#xff0c;第五届中新&#xff08;苏州&#xff09;数字金融应用博览会&#xff5c;2023金融科技大会在苏州国际博览中心举办&#xff0c;围绕金融科技发展热点领域及金融行业信息科技领域重点工作&#xff0c;分享优秀实践经验&#xff0c;探讨数字化转型路径与未来发…

蓝桥杯每日一题2023.11.8

题目描述 题目分析 对于输入的abc我们可以以a为年也可以以c为年&#xff0c;将abc,cab,cba这三种情况进行判断合法性即可&#xff0c;注意需要排序去重&#xff0c;所以考虑使用set 此处为纯模拟的写法&#xff0c;但使用循环代码会更加简洁。 方法一&#xff1a; #include&…

物联网AI MicroPython学习之语法 network网络配置模块

学物联网&#xff0c;来万物简单IoT物联网&#xff01;&#xff01; network介绍 模块功能&#xff1a; 用于管理Wi-Fi和以太网的网络模块参考用法&#xff1a; import network import time nic network.WLAN(network.STA_IF) nic.active(True) if not nic.isconnected():…