贝叶斯神经网络用于学习曲线的概率预测【ICLR 2017】

论文下载地址:Excellent-Paper-For-Daily-Reading/hyper-parameters at main

类别:超参数

时间:2023/10/30

摘要

面对不同的神经网络结构、超参数和训练协议,通常需要检查生成学习曲线,以快速终止超参数设置不佳的运行,从而大大加快手动超参数优化。通过跨超参数设置的学习曲线的概率模型,可以在自动超参数优化中利用相同的信息。论文研究了贝叶斯神经网络的使用,并通过一个专门的学习曲线层来提高它们的性能。

论文完成的成果

  • 研究贝叶斯神经网络如何很好地适应各种架构和超参数设置的学习曲线,以及它们的不确定性估计有多可靠。
  • 开发了一个带有学习曲线层的专用神经网络架构,以改进学习曲线预测。
  • 比较了生成贝叶斯神经网络的不同方法:概率反向传播和两种不同的基于随机梯度的马尔可夫链蒙特卡罗(MCMC)方法。
  • 评估了全新学习曲线和外推部分观察曲线的预测质量,在学习曲线尚未收敛的阶段。
  • 扩展了多臂强盗策略(multi-armed bandit strategy),使用我们的模型进行采样,而不是均匀随机采样,从而使其能够比传统的贝叶斯优化更快地接近最优配置。

实验

学习曲线预测的实验

在实验部分,采用了不同的神经网络架构和学习曲线预测方法,并在不同数据集上进行了评估。实验结果表明,新模型的性能表现出良好的均方误差和平均对数似然,特别是使用随机梯度汉密尔顿MCMC方法(SGHMC)时表现更佳。此外,文章还比较了其他用于学习曲线预测的方法,包括随机森林、高斯过程、概率反向传播和简单的“最后一个观察到的值”方法。

左图显示了不同方法在CNN基准上的平均预测。所有模型都观察到真实学习曲线(黑色)的前12个epoch的验证误差。右图,绘制了40个epoch值的后验分布。

结论

论文研究了一种基于贝叶斯神经网络的学习曲线建模方法,为解决超参数优化和性能改进问题提供了新的思路和工具。贝叶斯神经网络的引入以及新型学习曲线层的设计为未来的研究和实践提供了有趣的方向。

这篇论文为深度学习领域的研究者提供了一个全新的视角,强调了贝叶斯神经网络在学习曲线预测和超参数优化中的重要性。通过结合不同领域的知识,我们有望进一步提高机器学习算法的性能。

学习率范围测试

学习率范围测试,又被称为LR Finder,是机器学习领域的一个重要实践工具。在深度学习模型训练中,学习率的选择通常是一个挑战,因为一个合适的学习率可以加速收敛并提高性能,但不合适的学习率可能导致训练不稳定或收敛缓慢。传统上,学习率的设定是基于经验和试错的,这篇论文介绍了一种更科学、更系统的方法,即学习率范围测试。

学习率范围测试的主要思想是在训练过程中逐渐增加学习率,然后观察模型的损失如何随学习率的增加而变化。通过分析损失与学习率之间的关系,可以找到一个合适的学习率范围,其中学习率既不会过高导致模型发散,也不会过低导致训练速度过慢。这种方法有助于为模型选择一个更有科学依据的初始学习率。

下面我根据模型、dataloader、损失函数和学习率进行调整:


import torch.nn as nn
import torchimport matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from utils.util import get_network
from utils.datasets import get_train_loader
from utils.lr_scheduler import _LRScheduler
from pyzjr.dlearn.learnrate import get_optimizerclass FindLR(_LRScheduler):"""exponentially increasing learning rateArgs:optimizer: optimzier(e.g. SGD)num_iter: totoal_itersmax_lr: maximum  learning rate"""def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):self.total_iters = num_iterself.max_lr = max_lrsuper().__init__(optimizer, last_epoch)def get_lr(self):return [base_lr * (self.max_lr / base_lr) ** (self.last_epoch / (self.total_iters + 1e-32)) for base_lr in self.base_lrs]# class lr_finder():
#     def __init__(self,net, training_loader, loss_function,optimizer_type="sgd",num_iter=100,batch_size=4):
#         self.net = net
#         self.training_loader = training_loader
#         self.loss_function = loss_function
#         self.optimizer_type = optimizer_type
#         self.num_iter = num_iter
#         self.batch_size = batch_size
# 
#     def update(self, init_lr=1e-7, max_lr=10):
#         n = 0
#         learning_rate = []
#         losses = []
#         optimizer = get_optimizer(self.net, self.optimizer_type, init_lr)
#         lr_scheduler = FindLR(optimizer, max_lr=max_lr, num_iter=self.num_iter)
#         epoches = int(args.num_iter / len(self.training_loader)) + 1
# 
#         for epoch in range(epoches):
#             net.train()
#             for batch_index, (images, labels) in enumerate(self.training_loader):
#                 if n > self.num_iter:
#                     break
#                 if torch.cuda.is_available():
#                     images = images.cuda()
#                     labels = labels.cuda()
# 
#                 optimizer.zero_grad()
#                 predicts = net(images)
#                 loss = loss_function(predicts, labels)
#                 if torch.isnan(loss).any():
#                     n += 1e8
#                     break
#                 loss.backward()
#                 optimizer.step()
#                 lr_scheduler.step()
# 
#                 print('Iterations: {iter_num} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.8f}'.format(
#                     loss.item(),
#                     optimizer.param_groups[0]['lr'],
#                     iter_num=n,
#                     trained_samples=batch_index * self.batch_size + len(images),
#                     total_samples=len(self.training_loader),
#                 ))
# 
#                 learning_rate.append(optimizer.param_groups[0]['lr'])
#                 losses.append(loss.item())
#                 n += 1
# 
#         self.learning_rate = learning_rate[10:-5]
#         self.losses = losses[10:-5]
# 
#     def plotshow(self, show=True):
#         import matplotlib
#         matplotlib.use("TkAgg")
#         fig, ax = plt.subplots(1, 1)
#         ax.plot(self.learning_rate, self.losses)
#         ax.set_xlabel('learning rate')
#         ax.set_ylabel('losses')
#         ax.set_xscale('log')
#         ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
#         if show:
#             plt.show()
# 
#     def save(self, path='result.jpg'):
#         self.plotshow(show=False)
#         plt.savefig(path)if __name__ == '__main__':class parser_args():def __init__(self):self.net = "vgg16"self.batch_size = 64self.base_lr = 1e-7self.max_lr = 10self.num_iter = 100self.Cuda = Trueself.num_class = 4from pyzjr.dlearn.learnrate import lr_finderargs = parser_args()txt_path = r"D:\PythonProject\Torchproject\classification\dataset\train.txt"train_loader = get_train_loader(txt_path, batch_size=4, train=True)net = get_network(args)loss_function = nn.CrossEntropyLoss()lrfinder = lr_finder(net, train_loader, loss_function)lrfinder.update()lrfinder.plotshow()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/175610.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++ pcl 选取点云某一点反馈XYZ坐标的代码

看了看以前的代码,有一小段代码很有意思,是关于pcl点云处理的。 如有帮助,点赞收藏关注!!! 读取点云数据,想可视化点云数据,并根据选择,实时显示点云的空间坐标数值。 接…

分享大数据分析师前景怎么样? 从事行业有哪些?

数据分析师发展前景和待遇怎么样?有前途吗?好找工作吗?根据某招聘网数据显示,当前市场表现为: 2023年较2022年同期对比增长160%,2022年较2021年下降了46%。 工资待遇:2023年较2022年下降了2…

使用 kube-downscaler 降低Kubernetes集群成本

新钛云服已累计为您分享772篇技术干货 介绍 Kube-downscaler 是一款开源工具,允许用户定义 Kubernetes 中 pod 资源自动缩减的时间。这有助于通过减少非高峰时段的资源使用量来降低基础设施成本。 在本文中,我们将详细介绍 kube-downscaler 的功能、安装…

《算法通关村—队列基本特征和实现问题解析》

《算法通关村—队列基本特征和实现问题解析》 队列的基本特征 队列(Queue)是一种常见的数据结构,具有以下基本特征: 先进先出(FIFO):队列中的元素按照它们被添加到队列的顺序排列,…

UIAlertController 修改 title 或 message 样式相关

UIAlertController 文字换行后默认对齐方式为居中,若想调整其相关样式属性可以借鉴如下方式进行修改,具体实现方式 code 如下: NSString *msg "1、注销≠退出登录;\n注销:对不再使用的账号进行清空移除;注销后,App中数据将全部丢失,不可再找回;\n2、注销后,与账号相关的…

第三方软件测评选择远程测试好还是现场测试好?

如今许多软件企业在软件开发过程完成之后,会将软件测试工作交由第三方软件测评机构来进行,那么做第三方软件测试时,远程测试和现场测试哪个更好呢?我想这是许多软件企业都十分关注的问题,今天卓码软件测评小编将对以上问题作出简…

Linux rm命令:删除文件或目录

当 Linux 系统使用很长时间之后,可能会有一些已经没用的文件(即垃圾),这些文件不但会消耗宝贵的硬盘资源,还是降低系统的运行效率,因此需要及时地清理。 rm 是强大的删除命令,它可以永久性地删除…

天软特色因子看板(2023.10 第13期)

该因子看板跟踪天软特色因子A05005(近一月单笔流涌金额占比(%),该因子为近一个月单笔流通金额占比因,用以刻画股票在收盘时,主力资金在总交易金额中所占的比重。 今日为该因子跟踪第11期,跟踪其在SW801150 (申万医药生物) 中的表现…

如何通过会员营销数字化推动精准营销与用户忠诚度培养?

营销策略的制定和实施对于企业的成功至关重要,而会员数字化营销系统将通过用户画像、会员标签等重要功能,推动企业提高用户忠诚度培养。目前市面上有哪些热门的会员营销功能? 一、用户画像:让营销更精准 用户画像是一种通过收集和…

Hadoop学习总结(搭建Hadoop集群(伪分布式模式))

如果前面有搭建过Hadoop集群完全分布式模式,现在搭建Hadoop伪分布式模式可以选择直接克隆完全分布式模式中的主节点(hadoop001)。以下是在搭建过完全分布式模式下的Hadoop集群的情况进行 伪分布式模式下的Hadoop功能与完全分布式模式下的Hadoop功能相同。 一、克隆…

电脑不显示桌面?盘点4个正确操作!

“我的电脑一打开后完全加载不出来桌面,现在我也不知道怎么办,有没有比较了解电脑的大佬可以分享一下经验呀?” 有时候我们使用电脑时可能会遇到桌面上所有的应用程序都消失了甚至桌面不显示的情况。如果电脑不显示桌面我们可能就很难进行下一…

ToLua使用原生C#List和Dictionary

ToLua是使用原生C#List 介绍Lua中使用原生ListC#调用luaLua中操作打印测试如下 Lua中使用原生DictionaryC#调用luaLua中操作打印测试如下 介绍 当你用ToLua时C#和Lua之间肯定是会互相调用的,那么lua里面使用List和Dictionary肯定是必然的,在C#中可以调用…

关于FreeTypeFont‘ object has no attribute ‘getsize‘问题的解决方案

引言 这个问题是在训练yolov5_obb项目遇到的,大概率又是环境问题。如下图: 解决方法 出现这个问题是Pillow版本太高了,下载低版本的: pip install Pillow9.5 OK!

【UE】属性同步,源码详解一个勾选了Actor复制的Actor第一次被创建时经历了什么

本文参考https://zhuanlan.zhihu.com/p/640723352 准备工作 先准备一个勾选了复制的Actor,然后在游戏开始时Spawn这个Actor 源码过程详解 发送属性同步 在NetDriver的TickFlush中发送属性同步的数据 1、ServerReplicateActors_BuildConsiderList 去找到所有需…

算法通关村第四关-黄金挑战栈的经典问题

括号匹配问题 描述 : 给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效。 有效字符串需满足: 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有…

超2000个大模型应用,支持文心4.0!AI Studio星河大模型社区升级上新

想给自己做个私人定制的旅行攻略,满足个性化的出游需求,还要细致关注到天气、穿衣、老人孩子的作息等等,但太耗时费力怎么办?让AI帮忙搞定。一位开发者在AI Studio星河大模型社区用短短数小时就做好了“旅行规划家”智能应用。像这…

色彩校正及OpenCV mcc模块介绍

一、术语 1.光:是电磁波,可见光是可被人眼感知的电磁波。可见光大约在400-700nm波段。光子携带的能量与波长成反比,400nm--700nm之间的单色光的颜色从紫色渐变成红色。 2.光谱:除了太阳光源外,LED灯、白炽灯等各种照明…

基于STC系列单片机实现定时器0扫描数码管显示定时器/计数器1作为计数器1产生频率的功能

#define uchar unsigned char//自定义无符号字符型为uchar #define uint unsigned int//自定义无符号整数型为uint #define NixieTubeSegmentCode P0//自定义数码管段码为单片机P0组引脚 #define NixieTubeBitCode P2//自定义数码管位码为单片机P2组引脚 sbit LED P1^0;//位定义…

探索Vue 3和Vue 2的区别

目录 响应式系统 性能优化 Composition API TypeScript支持 总结 Vue.js是一款流行的JavaScript框架,用于构建用户界面。Vue 3是Vue.js的最新版本,相较于Vue 2引入了许多重大变化和改进。在本文中,我们将探索Vue 3和Vue 2之间的区别。 …

深入探究Vue.js生命周期及其应用场景

当谈到Vue.js的生命周期时,我们指的是组件在创建、更新和销毁过程中发生的一系列事件。了解Vue的生命周期对于开发人员来说是至关重要的,因为它们提供了一个机会来执行特定任务,并在不同的阶段处理组件。 Vue的生命周期可以分为八个不同的阶…