cuda从零开始手搓PB神经网络

cuda实现PB神经网络


基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;}   // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif

下面实验一下我们的bp神经网络。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}

以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:
在这里插入图片描述
可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/5266.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java数据结构】排序

【Java数据结构】排序 一、排序1.1 排序的概念1.2 排序的稳定性1.3 内部排序和外部排序1.3.1 内部排序1.3.2 外部排序 二、插入排序2.1 直接插入排序2.2 希尔排序 三、选择排序3.1 选择排序3.2 堆排序 四、交换排序4.1 冒泡排序4.2 快速排序Hoare法&#xff1a;挖坑法&#xff…

内存 管理

1、如何在LCD上面实现SD卡文件浏览&#xff1f; 需要读取所有文件名到内存&#xff0c;方法是定义一个数组才存储所有文件名。&#xff08;最大文件名的长度和文件个数&#xff09; 2、内存管理是什么&#xff1f; 指软件运行时对MCU内存资源的分配和使用的技术。要实现两个函…

1月21日星期二今日早报简报微语报早读

1月21日星期二&#xff0c;农历腊月廿二&#xff0c;早报#微语早读。 1、多地官宣&#xff1a;2025年可有序、限时或在限定区域燃放烟花爆竹&#xff1b; 2、TikTok恢复在美服务&#xff1b;特朗普提出继续运营TikTok方案&#xff0c;外交部&#xff1a;若涉及收购中国企业应…

深度学习python基础(第三节) 函数、列表

本节主要介绍函数、列表的基本语法格式。 函数 与c语言的函数差不多&#xff0c;就是语法基本格式不同。 name "loveyou" length len(name) print("字符串的长度为&#xff1a;%d" % length) # 自定义函数 def countstr(data):count 0for i in da…

STM32 FreeROTS Tickless低功耗模式

低功耗模式简介 FreeRTOS 的 Tickless 模式是一种特殊的运行模式&#xff0c;用于最小化系统的时钟中断频率&#xff0c;以降低功耗。在 Tickless 模式下&#xff0c;系统只在有需要时才会启动时钟中断&#xff0c;而在无任务要运行时则完全进入休眠状态&#xff0c;从而降低功…

65,【5】buuctf web [SUCTF 2019]Upload Labs 2

进入靶场 1,源代码 点击题目时有个就有个admin.php <?php // 引入配置文件 include config.php;class Ad{public $cmd;public $clazz;public $func1;public $func2;public $func3;public $instance;public $arg1;public $arg2;public $arg3;// 构造函数&#xff0c;用于初…

Apache Tomcat文件包含漏洞复现(详细教程)

1.漏洞原理 Tomcat 服务器是一个免费的开放源代码的Web 应用服务器&#xff0c;其安装后会默认开启ajp连接器&#xff0c;方便与其他web服务器通过ajp协议进行交互。属于轻量级应用服务器&#xff0c;在中小型系统和并发访问用户不是很多的场合下被普遍使用&#xff0c;是开发…

springboot基于安卓的智启教育服务平台app

基于Spring Boot的智启教育服务平台App是一个结合了Spring Boot后端框架与安卓前端技术的综合性教育服务平台。 一、技术背景与架构 1.开发语言&#xff1a;后端采用Java语言开发&#xff0c;充分利用Java的跨平台性、面向对象特性和强大的后端处理能力。前端则使用安卓开发技…

我的创作纪念日,纪念我的第512天

目录 年末 年初 入围 博客 变动 生活 期待 年末 很快&#xff0c;2024年已经过去了&#xff0c;本想在跨年夜的时候营造一点小小的仪式感&#xff0c;结果也因为身体的原因放弃了&#xff0c;浑身感觉疼痛&#xff0c;躺在床上&#xff0c;闭上眼睛&#xff0c;什么也不…

2025/1/21 学习Vue的第四天

睡觉。 --------------------------------------------------------------------------------------------------------------------------------- 11.Object.defineProperty 1.在我们之前学习JS的时候&#xff0c;普通得定义一个对象与属性。 <!DOCTYPE html> <h…

卸载和安装Git小乌龟、git基本命令

卸载 Git 打开控制面板&#xff1a; 按 Win R 打开运行对话框&#xff0c;输入 control 并按回车键。或直接在功能搜索里搜索“控制面板”。在控制面板中&#xff0c;选择“程序”或“程序和功能”。 查找并卸载 Git&#xff1a; 在程序列表中找到“Git”或“Git for Windows…

OSI5GWIFI自组网协议层次对比

目录 5G网络5G与其他协议栈各层映射 5G网络 物理层 (PHY) 是 5G 基站协议架构的最底层&#xff0c;负责将数字数据转换为适合无线传输的信号&#xff0c;并将接收到的无线信号转换为数字数据。实现数据的编码、调制、多天线处理、资源映射等操作。涉及使用新的频段&#xff08…

ThinkPHP 8的多对多关联

【图书介绍】《ThinkPHP 8高效构建Web应用》-CSDN博客 《2025新书 ThinkPHP 8高效构建Web应用 编程与应用开发丛书 夏磊 清华大学出版社教材书籍 9787302678236 ThinkPHP 8高效构建Web应用》【摘要 书评 试读】- 京东图书 使用VS Code开发ThinkPHP项目-CSDN博客 编程与应用开…

可视化-numpy实现线性回归和梯度下降法

代码如下&#xff1a; import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from matplotlib.patches import Patch# 生成二维输入数据 np.random.seed(0) X1 2 * np.random.rand(100, 1) # 第一个特征 X2 3 * np.random.rand(10…

python_在钉钉群@人员发送消息

python_在钉钉群人员发送消息 1、第一种 企业内部机器人群聊实现人接入指南&#xff0c;适用于群机器人接收消息&#xff0c;处理完一系列的动作之后&#xff0c;将消息返回给发消息的人员&#xff0c;同时该人员。 需要在企微后台新建一个自建应用&#xff0c;在自建应用里…

递归练习六(普通练习11-15)

一、例题 1、有效数独 36. 有效的数独 - 力扣&#xff08;LeetCode&#xff09; 2、填数独 37. 解数独 - 力扣&#xff08;LeetCode&#xff09; 3、单词搜索 79. 单词搜索 - 力扣&#xff08;LeetCode&#xff09; 4、黄金矿工 1219. 黄金矿工 - 力扣&#xff08;LeetCod…

【生产力工具】ChatGPT for Windows桌面版本安装教程

使用桌面版的ChatGPT目前可解决官方轻微降智的问题。 文章目录 准备安装步骤步骤 1: 更改系统区域设置步骤 2: 关闭系统代理&#xff08;如果你正在使用的话&#xff09;步骤 3: 启动EXE文件步骤 4: 完成安装 准备 下载并保存好 ChatGPT桌面版的EXE安装文件。 下载地址1&…

兼职全职招聘系统架构与功能分析

2015工作至今&#xff0c;10年资深全栈工程师&#xff0c;CTO&#xff0c;擅长带团队、攻克各种技术难题、研发各类软件产品&#xff0c;我的代码态度&#xff1a;代码虐我千百遍&#xff0c;我待代码如初恋&#xff0c;我的工作态度&#xff1a;极致&#xff0c;责任&#xff…

【ESP32】ESP32连接JY61P并通过WIFI发送给电脑

前言 手头上有个ESP32&#xff0c;发现有wifi功能&#xff0c;希望连接JY61P并通过WIFI把姿态数据发送给电脑 1.采用Arduino IDE编译器&#xff1b;需要安装ESP32的开发板管理器&#xff1b; 2.电脑接受数据是基于python的&#xff1b; 1. ESP32 连接手机WIFI #include <…

第23篇 基于ARM A9处理器用汇编语言实现中断<五>

Q&#xff1a;怎样修改HPS Timer 0定时器产生的中断周期&#xff1f; A&#xff1a;在上一期实验的基础上&#xff0c;可以修改按键中断服务程序&#xff0c;实现红色LED上的计数值递增的速率&#xff0c;主程序和其余代码文件不用修改。 实现以下功能&#xff1a;按下KEY0…