Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能

前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。

实现代码

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 设置超参数
learning_rate = 0.001
num_epochs = 100
batch_size = 32# 定义输入和输出的维度
input_dim = X.shape[1]
output_dim = len(set(y))# 定义权重和偏置项
W1 = tf.Variable(tf.random.normal(shape=(input_dim, 64), dtype=tf.float64))
b1 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W2 = tf.Variable(tf.random.normal(shape=(64, 64), dtype=tf.float64))
b2 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W3 = tf.Variable(tf.random.normal(shape=(64, output_dim), dtype=tf.float64))
b3 = tf.Variable(tf.zeros(shape=(output_dim,), dtype=tf.float64))# 定义前向传播函数
def forward_pass(X):X = tf.cast(X, tf.float64)h1 = tf.nn.relu(tf.matmul(X, W1) + b1)h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)logits = tf.matmul(h2, W3) + b3return logits# 定义损失函数
def loss_fn(logits, labels):return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate)# 定义准确率指标
accuracy_metric = tf.metrics.SparseCategoricalAccuracy()# 定义训练步骤
def train_step(inputs, labels):with tf.GradientTape() as tape:logits = forward_pass(inputs)loss_value = loss_fn(logits, labels)gradients = tape.gradient(loss_value, [W1, b1, W2, b2, W3, b3])optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2, W3, b3]))accuracy_metric(labels, logits)return loss_value# 进行训练
for epoch in range(num_epochs):epoch_loss = 0.0accuracy_metric.reset_states()for batch_start in range(0, len(X_train), batch_size):batch_end = batch_start + batch_sizebatch_X = X_train[batch_start:batch_end]batch_y = y_train[batch_start:batch_end]loss = train_step(batch_X, batch_y)epoch_loss += losstrain_loss = epoch_loss / (len(X_train) // batch_size)train_accuracy = accuracy_metric.result()print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}")# 进行评估
logits = forward_pass(X_test)
test_loss = loss_fn(logits, y_test)
test_accuracy = accuracy_metric(y_test, logits)print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

实现效果

本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python、机器学习、深度学习基础知识与案例。

致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

邀请三个朋友关注V订阅号:数据杂坛,即可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/173925.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据库】

文章目录 1. 聚合函数练习: 2. 子查询 1. 聚合函数 where中过滤条件中不能写聚合函数,有聚合函数需要写到Having中 方式一效率高: Select执行流程 练习: 2. 第七题:count(*)有问题,原因是左外连接后…

【继承练习题--多态-- 动态绑定-- 重写】

文章目录 继承的练习题:多态多态实现条件 动态绑定什么是重写Override 注解重写的条件(缺一不可)有一种特殊的重写:叫协变类型重写的设计原则快捷键生成重写重写和重载的区别object类是所有类的父类 总结 继承的练习题&#xff1a…

支付宝证书到期更新完整过程

如果用户收到 支付宝公钥证书 到期通知后,可以根据如下指引更新证书 确认上传成功后就会生成新的证书,把新的证书替换到生产环境就可以了

标准ACL,扩展ACL,基本ACL,高级ACL

其实标准ACL,扩展ACL,基本ACL,高级ACL是同一个概念的不同名称,区别在于: 思科路由器支持标准ACL和扩展ACL两种类型的访问控制列表,没有”基本ACL“和”高级ACL“的概念,而华为路由器都支持 编号范围&…

基于鸟群算法的无人机航迹规划-附代码

基于鸟群算法的无人机航迹规划 文章目录 基于鸟群算法的无人机航迹规划1.鸟群搜索算法2.无人机飞行环境建模3.无人机航迹规划建模4.实验结果4.1地图创建4.2 航迹规划 5.参考文献6.Matlab代码 摘要:本文主要介绍利用鸟群算法来优化无人机航迹规划。 1.鸟群搜索算法 …

[java/力扣110]平衡二叉树——优化前后的两种方法

分析 根据平衡二叉树的定义,只需要满足:1、根节点两个子树的高度差不超过1;2、左右子树都为平衡二叉树 代码 public class BalancedBinaryTree {public class TreeNode{int val;TreeNode left;TreeNode right;TreeNode(){}TreeNode(int va…

springboot第44集:Kafka集群和Lua脚本

servers:Kafka服务器的地址。这是Kafka集群的地址,生产者将使用它来发送消息。retries:在消息发送失败时,生产者将尝试重新发送消息的次数。这个属性指定了重试次数。batchSize:指定了生产者在发送消息之前累积的消息大…

2.flink编码第一步(maven工程创建)

概述 万里第一步,要进行flink代码开发,第一步先整个 flink 代码工程 flink相关文章链接 flink官方文档 两种方式 一种命令行 mvn 创建,另一种直接在 idea 中创建一个工程,使用 mvn 的一些配置 mvn命令行创建 mvn 创建flink工程&…

基于SpringBoot的工厂车间管理系统设计与实现

目录 前言 一、技术栈 二、系统功能介绍 管理员功能实现 人员管理 看板信息管理 设备信息管理 生产开立管理 人员功能实现 生产开立管理 生产工序管理 生产流程管理 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 社会发展日新月异,用计…

6.scala辅助构造器与为构造函数提供默认值(一)

概述 本文主要说明: 辅助构造器 与 为构造函数提供默认值 的使用 辅助构造器为构造函数提供默认值 相关链接 阅读之前,可以浏览一下 scala相关文章 辅助构造器 可以通过定义名为this的方法来定义辅助Scala类构造函数。只有几个规则需要了解: 每个辅助…

冯诺依曼体系结构、进程、环境变量

冯诺依曼体系结构、进程、环境变量 一、冯诺依曼体系结构1、结构图2、示例3、CPU与数据 二、进程1、概念2、查看进程(1)通过/proc系统文件夹(2)通过top和ps用户级工具(3)通过系统调用 3、通过系统调用创建进…

会声会影2024这款视频剪辑软件怎么样?

众所周知,每每有新兴行业逐渐崛起壮大的时候,随机而来的就是这个行业创造出的衍生行业,比如说现在的短视频平台或者是视频剪辑行业,都是很明显的例子,今天我们就针对剪辑软件来和大家聊一聊,会声会影2024这…

论坛搭建.

目录 一.配置软件仓库 二.安装http php miriadb 三.配置数据库 一.配置软件仓库 1.进入仓库目录 cd /etc/yum.repos.d 2.创建仓库文件 vim local.repo 3.在 local.repo中写入:(粘贴的时候注意位置) [biaoshi] 仓库标识符 namemiaoshu …

禁用Google Chrome自动升级、查看Chrome版本号

问题 查看Chrome版本时,会自动升级,这个设计很垃圾,对开发者不友好;查看Chrome版本方法:chrome浏览器右上角—>自定义及控制Google Chrome(三个竖着的点号)------>帮助---->关于Google Chrome。 解决办法 禁用自动升级…

[Unity][VR]透视开发系列3-Passthrough应用的真机测试方法

【视频讲解】 视频讲解地址请关注我的B站。 专栏后期会有一些不公开的高阶实战内容或是更细节的指导内容。 B站地址: https://www.bilibili.com/video/BV1Zg4y1w7fZ/ 我还有一些免费和收费课程在网易云课堂(大徐VR课堂): https://study.163.com/provider/480000002282025/…

Open3D(C++) 最小二乘拟合平面(直接求解法)

目录 一、算法原理二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接。 一、算法原理 平面方程的一般表达式为: A x + B y + C

14个最实用的WordPress SEO插件推荐

在这篇文章中,将为你推荐最有利于网站SEO的WordPress插件,这里介绍这些插件的主要功能及使用技巧,合理使用它们将有助于网站SEO排名。无论你是一个刚刚开始的博客作者,还是一个经验丰富的企业网站管理员,我们都将帮助你…

洛谷P1765 手机 / 秋季赛 九宫格

手机 题目描述 一般的手机的键盘是这样的: 要按出英文字母就必须要按数字键多下。例如要按出 x \tt x x 就得按 9 9 9 两下,第一下会出 w \tt w w,而第二下会把 w \tt w w 变成 x \tt x x。 0 0 0 键按一下会出一个空格。 你的任务是…

React 你还在用 Redux 吗?更简化的状态管理工具(Recoil)

以往传统的 Redux 状态管理工具使用起来代码太过于复杂。 你需要通过纯函数触发 action 再去修改 data 中定义的数据,而且要通过接口请求数据还需要借助 redux - think 这个中间件才能完成。。。 更加方便使用的工具:Recoil ~ 由 facebook 推出契合 R…

使用示例和应用程序全面了解高效数据管理的Golang MySQL数据库

Golang,也被称为Go,已经成为构建强大高性能应用程序的首选语言。在处理MySQL数据库时,Golang提供了一系列强大的库,简化了数据库交互并提高了效率。在本文中,我们将深入探讨一些最流行的Golang MySQL数据库库&#xff…