百度飞将 paddle ,实现贝叶斯神经网络 bayesue neure network bnn,aistudio公开项目 复现效果不好

论文复现赛:贝叶斯神经网络 - 飞桨AI Studio星河社区

https://github.com/hrdwsong/BayesianCNN-Paddle

论文复现:Weight Uncertainty in Neural Networks

本项目复现时遇到一个比较大的问题,用pytorch顺利跑通源代码后,修改至paddle框架下再次训练,发现模型不收敛,训练准确率一直维持在0.1附近(随机挑选概率), 模型完全没有学到东西。

针对此问题,我依次对dataset、dataloader、模型参数初始化、优化器、loss函数,甚至沿着整个计算图跟踪了梯度是否正确传递。 最终定位为paddle.max函数,使用该函数后,问题出现;屏蔽该函数后,问题消失。 经分析,应该是该函数不连续,不支持梯度传递。而pytorch版本的max函数则没有此问题。

 

一、简介

本文将不确定性引入神经网络,将确定性参数的神经网络改造为具有随机特性的概率神经网络(也成贝叶斯神经网络)。本文是贝叶斯神经网络的奠基作之一,具有很高的引用量。

具体地,在传统神经网络中,各网络节点的参数为确定值;通过本文方法引入不确定性后,各网络节点的参数转变为满足概率分布的随机变量。 每次正向推理时,网络会根据概率分布对参数值进行采样,并以采样到的值作为本次正向推理的参数值。传统神经网络与贝叶斯神经网络的异同点如下图所示:

训练贝叶斯神经网络时,通过本文方法,可将loss函数反向传递到网络节点的概率分布参数上,从而动态调优该网络。

论文链接:Weight Uncertainty in Neural Networks

二、复现精度

基于paddlepaddle深度学习框架,对文献算法进行复现后,本项目达到的测试精度,如下表所示。 参考文献的最高精度为98.68%

模型和方法本项目精度
lenet-bbb98.75%
alexnet-bbb98.73%
3conv3fc-bbb99.07%
lenet-lrt98.76%
alexnet-lrt98.82%
3conv3fc-lrt99.29%

超参数配置如下:

超参数名设置值
lr0.01
batch_size256
epochs200

三、数据集

本项目使用的是MNIST数据集。该数据集为美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,一共统计了来自250个不同的人手写数字图片,其中50%是高中生,50%来自人口普查局的工作人员。该数据集的收集目的是希望通过算法,实现对手写数字的识别。

  • 数据集大小:
    • MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。
  • 数据格式:它包含了四个部分
    • (1)Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
    • (2)Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
    • (3)Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
    • (4)Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

数据集链接:MNIST

四、环境依赖

  • 硬件:

    • x86 cpu
    • NVIDIA GPU
  • 框架:

    • PaddlePaddle = 2.1.2
  • 其他依赖项:

    • numpy==1.19.3
    • matplotlib==3.3.4
    • pandas==1.2.4
    • pytest==6.2.4
    • paddle==1.0.2
    • Pillow==8.3.1

五、快速开始

1、执行以下命令启动训练:

python train.py --net_type 3conv3fc --dataset MNIST

训练贝叶斯神经网络,运行完毕后,模型参数文件保存在./checkpoints/MNIST/bayesian目录下。

2、执行以下命令进行评估

python test.py --net_type 3conv3fc --dataset MNIST 用于测试贝叶斯神经网络,测试前,将已训练好的最优参数模型从./results/3CONV3FC拷贝至./checkpoints/MNIST/bayesian

In [ ]

# 解压项目文件夹
!unzip -o Paddle-BayesianCNN-V1.zip
%cd Paddle-BayesianCNN

In [7]

# config_bayesian.py文件中修改训练方法,选择'bbb'或'lrt'
# 训练模型
!python train.py --net_type 3conv3fc --dataset MNIST

In [ ]

# 测试模型精度
!python test.py --net_type 3conv3fc --dataset MNIST

六、代码结构与详细说明

6.1 代码结构

├── onfig_bayesian.py               # 配置
├── metrics.py                      # 度量相关
├── README.md                       # readme
├── requirements.txt                # 依赖
├── test                            # 测试
├── train                           # 启动训练入口
├── utils.py                        # 公共调用
├── checkpoints                     # 保存
│   ├── MNIST                        # 数据集名称
│      ├── bayesian
│      ├── best
├── data
│   ├── data.py
├── layers
│   ├── misc.py
│   ├── BBB
│       ├── BBBConv.py
│       ├── BBBLinear.py
├── models
│   ├── BayesianModels
│       ├── BayesianOriginNet.py
│       ├── BayesianLeNet.py

6.2 参数说明

可以在 train.py 中设置训练与评估相关参数,具体如下:

参数默认值说明其他
--net_type3conv3fc, 可选选择模型可选择lenet/alexnet/3conv3fc/originet
--datasetMNIST, 可选选择数据集本项目目前仅支持MNIST

6.3 训练流程

可参考快速开始章节中的描述

训练输出

执行训练开始后,将得到类似如下的输出。每一轮epoch训练将会打印当前training loss、training acc、val loss、val acc以及训练kl散度。

Epoch: 0 	Training Loss: 957661.3024 	Training Accuracy: 0.5314 	Validation Loss: 6048323.2596 	Validation Accuracy: 0.8872 	train_kl_div: 108218176.5714
Validation loss decreased (inf --> 6048323.259558).  Saving model ...
Epoch: 1 	Training Loss: 620338.8870 	Training Accuracy: 0.7838 	Validation Loss: 4819156.8720 	Validation Accuracy: 0.8885 	train_kl_div: 90394454.2449
Validation loss decreased (6048323.259558 --> 4819156.872046).  Saving model ...
Epoch: 2 	Training Loss: 483882.8229 	Training Accuracy: 0.8268 	Validation Loss: 3822200.3844 	Validation Accuracy: 0.8913 	train_kl_div: 71920784.3061
Validation loss decreased (4819156.872046 --> 3822200.384351).  Saving model ...
Epoch: 3 	Training Loss: 390434.6679 	Training Accuracy: 0.8332 	Validation Loss: 2554367.9053 	Validation Accuracy: 0.9018 	train_kl_div: 48361270.8571
Validation loss decreased (3822200.384351 --> 2554367.905322).  Saving model ...
Epoch: 4 	Training Loss: 275255.5825 	Training Accuracy: 0.8434 	Validation Loss: 1809289.2232 	Validation Accuracy: 0.9172 	train_kl_div: 34525619.3469

6.4 测试流程

可参考快速开始章节中的描述

此时的输出为:

Testing Accuracy: 0.9907

七、实验数据比较及复现心得

7.1 实验数据比较

在不同的超参数配置下,模型的收敛效果、达到的精度指标有较大的差异,以下列举不同超参数配置下,实验结果的差异性,便于比较分析:

(1)学习率:

原文献采用的优化器与本项目一致,为Adam优化器,原文献学习率设置为0.001,本项目经调参发现, 学习率设置为0.01或0.0001时,网络有时会不收敛,该模型的稳定性存在可改进空间。

(2)epoch轮次

本项目训练时,采用的epoch轮次为200。LOSS和准确率在110个epoch附近已趋于稳定,模型处于收敛状态,下图为3CONV3FC-BBB的训练曲线。

7.2 复现心得

本项目复现时遇到一个比较大的问题,用pytorch顺利跑通源代码后,修改至paddle框架下再次训练,发现模型不收敛,训练准确率一直维持在0.1附近(随机挑选概率), 模型完全没有学到东西。

针对此问题,我依次对dataset、dataloader、模型参数初始化、优化器、loss函数,甚至沿着整个计算图跟踪了梯度是否正确传递。 最终定位为paddle.max函数,使用该函数后,问题出现;屏蔽该函数后,问题消失。 经分析,应该是该函数不连续,不支持梯度传递。而pytorch版本的max函数则没有此问题。

八、模型信息

训练完成后,模型保存在checkpoints目录下。

训练和测试日志保存在results目录下。

信息说明
发布者hrdwsong
时间2021.08
框架版本Paddle 2.1.2
应用场景贝叶斯神经网络
支持硬件GPU、CPU
repo地址https://github.com/hrdwsong/BayesianCNN-Paddle

请点击此处查看本环境基本用法.
Please click here for more detailed instructions.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/418866.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python报错已解决】 AttributeError: ‘move_to‘ requires a WebElement

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 文章目录 前言一、问题描述1.1 报错示例1.2 报错分析1.3 解决思路 二、解决方法2.1 方法一:检查元素选择器2.2 方法…

828华为云征文|华为云Flexus X实例docker部署rancher并构建k8s集群

828华为云征文|华为云Flexus X实例docker部署rancher并构建k8s集群 华为云最近正在举办828 B2B企业节,Flexus X实例的促销力度非常大,特别适合那些对算力性能有高要求的小伙伴。如果你有自建MySQL、Redis、Nginx等服务的需求,一定…

一款支持同一个屏幕界面同时播放多个视频的视频播放软件

GridPlayer 是一款基于 VLC 的免费开源跨平台多视频同步播放工具,支持在一块屏幕上同时播放多个视频。其主要功能包括: 多视频播放:用户可以在一个窗口中同时播放任意数量的视频,数量仅受硬件性能限制。支持多种格式和流媒体&…

java实现,PDF转换为TIF

目录 ■JDK版本 ■java代码・实现效果 ■POM引用 ■之前TIF相关的问题(两张TIF合并) ■对于成果物TIF,需要考虑的点 ■问题 ■问题1:无法生成TIF,已解决 ■问题2:生成的TIF过大,已解决 …

vue3 自定义指令 directive

1、官方说明:https://cn.vuejs.org/guide/reusability/custom-directives 除了 Vue 内置的一系列指令 (比如 v-model 或 v-show) 之外,Vue 还允许你注册自定义的指令 (Custom Directives)。 我们已经介绍了两种在 Vue 中重用代码的方式:组件和…

QT 编译报错:C3861: ‘tr‘ identifier not found

问题: QT 编译报错:C3861: ‘tr’ identifier not found 原因 使用tr的地方所在的类没有继承自 QObject 类 或者在不在某一类中, 解决方案 就直接用类名引用 :QObject::tr( )

ApacheKafka中的设计

文章目录 1、介绍1_Kafka&MQ场景2_Kafka 架构剖析3_分区&日志4_生产者&消费者组5_核心概念总结6_顺写&mmap7_Kafka的数据存储形式 2、Kafka的数据同步机制1_高水位(High Watermark)2_LEO3_高水位更新机制4_副本同步机制解析5_消息丢失问…

matplotlib中文乱码问题

在使用Matplotlib进行数据可视化的过程中,经常会遇到中文乱码的问题。显示乱码是由于编码问题导致的,而matplotlib 默认使用ASCII 编码,但是当使用pyplot时,是支持unicode编码的,只是默认字体是英文字体,导…

GraphPad Prism 10 for Mac/Win:高效统计分析与精美绘图的科学利器

GraphPad Prism 10 是一款专为科研工作者设计的强大统计分析与绘图软件,无论是Mac还是Windows用户,都能享受到其带来的便捷与高效。该软件广泛应用于生物医学研究、实验设计和数据分析领域,以其直观的操作界面、丰富的统计方法和多样化的图表…

【HuggingFace Transformers】OpenAIGPTModel源码解析

OpenAIGPTModel源码解析 1. GPT 介绍2. OpenAIGPTModel类 源码解析 说到ChatGPT,大家可能都使用过吧。2022年,ChatGPT的推出引发了广泛的关注和讨论。这款对话生成模型不仅具备了强大的语言理解和生成能力,还能进行非常自然的对话&#xff0c…

MapSet之二叉搜索树

系列文章: 1. 先导片--Map&Set之二叉搜索树 2. Map&Set之相关概念 目录 前言 1.二叉搜索树 1.1 定义 1.2 操作-查找 1.3 操作-新增 1.4 操作-删除(难点) 1.5 总体实现代码 1.6 性能分析 前言 TreeMap 和 TreeSet 是 Java 中基于搜索树实现的 M…

图形语言传输格式glTF和三维瓦片数据3Dtiles(b3dm、pnts)学习

文章目录 一、3DTiles二、b3dm三、glTF1.glTF 3D模型格式有两种2.glTF 场景描述结构和坐标系3.glTF的索引访问与ID4.glTF asset5.glTF的JSON结构scenesscene.nodes nodesnodes.children transformations对外部数据的引用buffers 原始二进制数据块,没有固有的结构或含…

表单项标签简单学习

目录 1. 单选框 radio​编辑​编辑​编辑​编辑 2. 复选框 checkbox ​编辑​编辑​编辑 3. 隐藏域 hidden 4. 多行文本框 textarea​编辑​编辑 5. 下拉框 select​编辑​编辑 6. 选择头像​编辑​编辑 <!DOCTYPE html> <html lang"en"> <head&…

自用NAS系列1-设备

拾光坞 拾光坞多账号绑定青龙面板SMBWebdav小雅alist下载到NASDocker安装迅雷功能利用qBittorrentEEJackett打造一站式下载工具安装jackett插件 外网访问内网拾光客户端拾光穿透公网ipv6路由器配置ipv6拾光坞公网验证拾光坞域名验证 拾光坞 多账号绑定 手机注册拾光坞账号&am…

GEE数据集:加拿大卫星森林资源调查 (SBFI)-2020 年加拿大森林覆盖、干扰恢复、结构、物种、林分年龄以及 1985-2020 年林分替代干扰的信息

目录 简介 数据集后处理 数据下载链接 矢量属性 代码 代码链接 引用 许可 网址推荐 0代码在线构建地图应用 机器学习 加拿大卫星森林资源调查 (SBFI) 简介 卫星森林资源清查&#xff08;SBFI&#xff09;提供了 2020 年加拿大森林覆盖、干扰恢复、结构、物种、林分…

海外云手机是否适合运营TikTok?

随着科技的迅猛发展&#xff0c;海外云手机逐渐成为改变工作模式的重要工具。这种基于云端技术的虚拟手机&#xff0c;不仅提供了更加便捷、安全的使用体验&#xff0c;还在电商引流和海外社媒管理等领域展示了其巨大潜力。那么&#xff0c;海外云手机究竟能否有效用于运营TikT…

828华为云征文 | Flexus X 实例服务器网络性能深度评测

引言 随着互联网应用的快速发展&#xff0c;网络带宽和性能对云服务器的表现至关重要。在不同的云服务平台上&#xff0c;即便配置相同的带宽&#xff0c;实际的网络表现也可能有所差异。因此&#xff0c;了解并测试服务器的网络性能变得尤为重要。本文将以华为云X实例服务器为…

Open-Sora代码详细解读(1):解读DiT结构

Diffusion Models专栏文章汇总&#xff1a;入门与实战 前言&#xff1a;目前开源的DiT视频生成模型不是很多&#xff0c;Open-Sora是开发者生态最好的一个&#xff0c;涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-S…

【MySQL】MySQL基础

目录 什么是数据库主流数据库基本使用MySQL的安装连接服务器服务器、数据库、表关系使用案例数据逻辑存储 MySQL的架构SQL分类什么是存储引擎 什么是数据库 mysql它是数据库服务的客户端mysqld它是数据库服务的服务器端mysql本质&#xff1a;基于C&#xff08;mysql&#xff09…

linux系统中,计算两个文件的相对路径

realpath --relative-to/home/itheima/smartnic/smartinc/blocks/ruby/seanet_diamond/tb/parser/test_parser_top /home/itheima/smartnic/smartinc/corundum/fpga/lib/eth/lib/axis/rtl/axis_fifo.v 检验方式就是直接在当前路径下&#xff0c;把输出的路径复制一份&#xff0…