【机器学习】基于tensorflow实现你的第一个DNN网络

博客导读:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

目录

一、引言

二、tensorflow介绍

2.1 tensorflow历史

2.2 tensorflow特点

 2.3 tensorflow安装

三、tensorflow实战

3.1 引入依赖的tensorflow库

3.2 训练数据准备

3.3 创建三层DNN模型

3.4 编译模型、定义损失函数与优化器

3.5 启动训练,迭代收敛

3.6 模型评估

3.7 可以直接跑的代码 

四、总结


一、引言

上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发介绍如何使用pytorch实现一个简单的DNN网络,今天我们还是用同样的例子,看看使用tensorflow如何实现。

二、tensorflow介绍

2.1 tensorflow历史

TensorFlow由谷歌人工智能团队谷歌大脑(Google Brain)开发和维护,拥有包括TensorFlow Hub、TensorFlow Lite、TensorFlow Research Cloud在内的多个项目以及各类应用程序接口(Application Programming Interface, API)。自2015年11月9日起,TensorFlow依据阿帕奇授权协议(Apache 2.0 open source license)开放源代码。

2.2 tensorflow特点

深度学习时代,tensorflow在工业应用较为广泛,而pytorch更多应用于研究中。大模型时代,pytorch是很多项目的底层库,大有超过tensorflow的趋势。可谓并驾齐驱。

  • 生态系统更成熟:TensorFlow拥有一个庞大的社区和丰富的资源,包括大量的教程、预训练模型和工具,适合从初学者到专家的各个层次用户。
  • 生产部署友好:TensorFlow支持更多的平台和设备,包括移动设备和边缘设备,提供了TensorFlow Lite和TensorFlow.js等,便于模型的部署和优化。
  • 静态图与动态图的结合:虽然早期TensorFlow以静态图为主,但TensorFlow 2.x引入了Eager Execution,结合了动态图的易用性和静态图的高性能,同时保持了模型的可部署性。
  • Keras集成:TensorFlow内建了Keras,这是一个高级神经网络API,使得模型构建、训练和评估更加简洁直观。
  • TensorBoard:TensorFlow自带的可视化工具TensorBoard,便于可视化模型结构、训练过程中的损失和指标,帮助用户更好地理解和调试模型。
  • 广泛的工业应用支持:由于其成熟度和稳定性,TensorFlow在工业界得到了广泛的应用,特别是在大型企业中。

 2.3 tensorflow安装

与pytorch一样,还是采用conda创建环境,采用pip安装tensorflow包

1.建立名为pytrain,python版本为3.11的conda环境(这里与pytorch一样)

conda create -n pytrain python=3.11
conda activate pytrain

​  

 2.采用pip下载tensorflow以及机器学习常用的scikit-learn和numpy包

pip install tensorflow scikit-learn numpy  -i https://mirrors.cloud.tencent.com/pypi/simple

​ 

这里未指定版本,默认下载最新版本tensorflow-2.16.1以及其他tensorboard等生态包。 

三、tensorflow实战

 动手实现一个三层DNN网络:

3.1 引入依赖的tensorflow库

这里主要是tensorflow、keras、sklearn、numpy等

Keras是一个用于构建和训练深度学习模型的高级API,它设计得极其用户友好,支持快速实验。Keras可以运行在TensorFlow之上。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

3.2 训练数据准备

这里采用numpy库进行数据随机生成

# 假设你已经有了特征数据 X 和标签数据 y
# X, y = ...  # 实际数据加载和预处理步骤
# 这里我们用随机数据作为示例
np.random.seed(0)
X = np.random.rand(1000, 1000)  # 1000个样本,每个样本1000个特征
y = np.random.randint(0, 2, size=(1000, 1))  # 二分类标签# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
  • 首先,采用numpy的random随机生成X矩阵(1000行样本*1000行特征)和y矩阵(1000行0或1的label)
  • 其次,采用sklearn库中的StandardScaler将X矩阵中的每个样本特征数值标准化(将每个特征都转换为正态分布,均值为0,标准差为1),这一步骤对于机器学习算法的性能至关重要,特别是那些对输入数据的尺度敏感的算法。
  • 最后,按照2:8的比例从数据中切分出测试机与训练集

3.3 创建三层DNN模型

采用keras.sequential类,顾名思义“按顺序的”由输入至输出编排神经网络

# 创建模型
model = Sequential([Dense(512, input_shape=(X_train.shape[1],)),  # 第一层Activation('relu'),Dense(512),  # 第二层Activation('relu'),Dense(1),  # 输出层Activation('sigmoid')  # 二分类使用sigmoid
])

 Sequential是Keras中用于构建深度学习模型的一个类,特别适合于构建线性的堆叠层模型。这种模型结构是层与层直接相连,没有复杂的拓扑结构,适合于解决如图像分类、文本分类等任务

特点

  • 线性堆叠:层按照添加的顺序堆叠,每一层只与前一层有连接。
  • 易于使用:适合初学者和快速原型设计,对于复杂的网络结构可能不够灵活。
  • 灵活性限制:对于需要多输入或多输出,或者层间有复杂连接的模型,应使用更高级的模型结构,如Functional API。

3.4 编译模型、定义损失函数与优化器

不同于pytorch的实例化模型对象,这里采用compile对模型进行编译。与pytorch相同点是都要定义损失函数和优化器,方法与技巧完全相同。

# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001),loss=BinaryCrossentropy(),metrics=['accuracy'])
  • optimizer=Adam(learning_rate=0.001):这里选择了Adam作为优化器。Adam(Adaptive Moment Estimation)是一种常用的优化算法,它结合了RMSprop和Momentum的优点,能够自动调整学习率。通过设置learning_rate=0.001,可以控制模型学习的速度。学习率是训练过程中的一个重要超参数,影响模型收敛的速度和最终的性能。
  • loss=BinaryCrossentropy():损失函数设置为二元交叉熵(Binary Crossentropy)。这个损失函数适用于二分类问题,它衡量了模型预测的概率分布与实际标签之间的差异。在二分类任务中,正确选择损失函数对于模型的性能至关重要。
  • metrics=['accuracy']:指定评估模型性能的指标。这里使用的是准确率(accuracy),即分类正确的比例。在训练和验证过程中,除了损失值外,还会计算并显示这个指标,帮助我们了解模型的性能。

3.5 启动训练,迭代收敛

不同于pytorch需要写两个循环处理每一行样本,tensorflow直接采用fit方法对输入的特征样本矩阵以及label矩阵进行训练

tensorflow版:

# 训练模型
history = model.fit(X_train, y_train, epochs=100, validation_split=0.1,  # 使用10%的数据作为验证集verbose=1)

pytorch版:

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()  # 设置为训练模式running_loss = 0.0for i, (inputs, labels) in enumerate(data_loader, 0):optimizer.zero_grad()  # 清零梯度outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()  # 反向传播optimizer.step()  # 更新权重running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(data_loader)}')

对比来看,pytorch版的更加透明,有助于理解,tensorflow更加便捷 

运行后可以看到loss逐步收敛:​

3.6 模型评估

通过model.evaluate对模型进行评估,evaluate与fit的区别是只计算指标不进行模型更新

tensorflow版:

# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

 pytorch版:

import torchmetrics # 导入torchmetricstest_num_samples = 200  # 测试样本数
test_X_train = torch.randn(test_num_samples, input_size) 
test_y_train = torch.randint(0, output_size, (test_num_samples,))# 数据加载
test_dataset = TensorDataset(test_X_train,test_y_train)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)# 在模型训练完成后进行评估
# 首先,我们需要确保模型在评估模式下
model.eval()# 初始化准确率和召回率的计算器
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=output_size)
recall = torchmetrics.Recall(task="multiclass", num_classes=output_size)with torch.no_grad():  # 确保在评估时不进行梯度计算for inputs, labels in test_data_loader:outputs = model(inputs)preds = torch.softmax(outputs, dim=1)# 更新指标计算器accuracy.update(preds, labels)recall.update(preds, labels)# 打印准确率和召回率
print(f'Accuracy: {accuracy.compute():.4f}')
print(f'Recall: {recall.compute():.4f}')print('Evaluation finished.')

对比pytorch需要写一个循环,tensorflow.keras的封装更为简洁

运行后,可以输出模型的准确率与召回率,由于采用随机生成的测试数据且迭代轮数较少,具体数值不错参考,可以根据自己需要丰富数据。

3.7 可以直接跑的代码 

与上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发一样,附可以直接运行的代码,先跑起来,再一行行研究!

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np# 假设你已经有了特征数据 X 和标签数据 y
# X, y = ...  # 实际数据加载和预处理步骤
# 这里我们用随机数据作为示例
np.random.seed(0)
X = np.random.rand(1000, 1000)  # 1000个样本,每个样本1000个特征
y = np.random.randint(0, 2, size=(1000, 1))  # 二分类标签# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 创建模型
model = Sequential([Dense(512, input_shape=(X_train.shape[1],)),  # 第一层Activation('relu'),Dense(512),  # 第二层Activation('relu'),Dense(1),  # 输出层Activation('sigmoid')  # 二分类使用sigmoid
])# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001),loss=BinaryCrossentropy(),metrics=['accuracy'])# 训练模型
history = model.fit(X_train, y_train, epochs=10, validation_split=0.1,  # 使用10%的数据作为验证集verbose=1)# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

四、总结

本文先对tensorflow深度学习框架历史、特点及安装方法进行介绍,接下来基于tensorflow带读者一步步开发一个简单的三层神经网络程序,最后附可执行的代码供读者进行测试学习。个人感觉tensorflow封装程度高于pytorch,网络结构也更加清晰,但pytorch更加透明。

喜欢的话期待您的关注、点赞、收藏,您的互动是对我最大的鼓励!

如果还有时间,可以看看我的其他文章:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/335288.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

香港优才计划申请时间要多久?各流程申请周期规划,再晚就来不及了!

香港优才计划申请时间要多久?各流程申请周期规划,再晚就来不及了! 2024年是香港优才计划不限配额的最后一年,明年政策如何变化还未可知,但如果明年又设置限额了,那么今年最后的机会一定要抓住了。 在这里…

LSTM长短时记忆网络:推导与实现(pytorch)

LSTM长短时记忆网络:推导与实现(pytorch) 背景推导遗忘门输入门输出门 LSTM的改进:GRU实现 背景 人类不会每秒钟都从头开始思考。当你阅读这篇文章时,你会根据你对以前单词的理解来理解每个单词。你不会把所有东西都扔…

STM32-12-OLED模块

STM32-01-认识单片机 STM32-02-基础知识 STM32-03-HAL库 STM32-04-时钟树 STM32-05-SYSTEM文件夹 STM32-06-GPIO STM32-07-外部中断 STM32-08-串口 STM32-09-IWDG和WWDG STM32-10-定时器 STM32-11-电容触摸按键 文章目录 1. OLED显示屏介绍2. OLED驱动原理3. OLED驱动芯片简介4…

v4l2抓取rv1126图像

0.准备工作 本文是基于正点原子的rv1126开发板使用mx415摄像头对不同节点的图像进行抓取 1.数据流向 图1 mx415采集到的数据为原始的拜尔格式(也就是raw格式),我们需要通过isp进行图像的调节才符合视觉,其中isp和ispp是两个处理的…

【大数据】Hadoop 2.X和1.X升级优化对比

目录 1.前言 2.hadoop 1.X的缺点和优化方向 3.解决NameNode的局限性 3.1.Hadoop HA 3.2.Haddop federation 4.yarn 5.周边组件 1.前言 本文是作者大数据系列中的一文,专栏地址: https://blog.csdn.net/joker_zjn/category_12631789.html?spm10…

网络侦察技术

网络侦察技术 收集的信息网络侦察步骤搜索引擎检索命令bing搜索引擎Baidu搜索引擎Shodan钟馗之眼(zoomeye) whois数据库:信息宝库查询注册资料 域名系统网络拓扑社交网络跨域拓展攻击 其它侦察手段社会工程学社会工程学常见形式Web网站查询 其它非技术侦察手段总结网…

GDPU 操作系统 天码行空13

文章目录 ❌ TODO:本文仅供参考,极有可能有误1.生产者消费者问题(信号量)💖 ProducerConsumerExample.java🏆 运行结果 💖 ProducerConsumerSelectiveExample.java🏆 运行结果 2.实现…

将四种算法的预测结果绘制在一张图中

​ 声明:文章是从本人公众号中复制而来,因此,想最新最快了解各类智能优化算法及其改进的朋友,可关注我的公众号:强盛机器学习,不定期会有很多免费代码分享~ 之前的一期推文中,我们推出了…

TREK高压发生器维修高压电源615-3-L-JX 615-3

美国TREK高压电源维修故障分析应注意两点: 故障分析检测和故障硬件更换,由高压电源故障和工作表现初步判断故障的类型和哪些硬件出了问题,初步判断缩小检测范围,通过排除法和更替新配件准确找到故障硬件。维修过程需要对trek电源维…

C语言学习笔记之指针(一)

目录 什么是指针? 指针和指针类型 指针的类型 指针类型的意义 指针-整数 指针的解引用 指针 - 指针 指针的关系运算 野指针 什么是野指针? 野指针的成因 如何规避野指针? 二级指针 什么是指针? 在介绍指针之前&#…

【ai】livekit:Agents 1 : Agents Framework 与 LiveKit 核心 API 原语

agents 官方文档LiveKit Agents LiveKit Agents is an end-to-end framework for building realtime, multimodal AI “agents” that interact with end-users through voice, video, and data channels. This framework allows you to build an agent using Python.是一个端到…

2024年6月1日(星期六)骑行禹都甸

2024年6月1日 (星期六)骑行禹都甸(韭葱花),早8:30到9:00,昆明氧气厂门口集合,9:30准时出发【因迟到者,骑行速度快者,可自行追赶偶遇。】 偶遇地点:昆明氧气厂门口集合 ,…

010-Linux磁盘介绍

文章目录 1、名词 2、类型 3、尺寸 4、接口/协议/总线 5、命名 6、分区方式 MBR分区 GPT分区 1、名词 磁盘是计算机主要的存储介质,可以存储大量的二进制数据,并且断电后也能保持数据不丢失。早期计算机使用的磁盘是软磁盘(Floppy D…

三方语言中调用, Go Energy GUI编译的dll动态链接库CEF

如何在其它编程语言中调用energy编译的dll动态链接库,以使用CEF 或 LCL库 Energy是Go语言基于LCL CEF开发的跨平台GUI框架, 具有很容易使用CEF 和 LCL控件库 interface 便利 示例链接 正文 为方便起见使用 python 调用 go energy 编译的dll 准备 系统&#x…

C++:vector的模拟实现

✨✨✨学习的道路很枯燥,希望我们能并肩走下来! 文章目录 目录 文章目录 前言 一、vector的模拟实现 1.1 迭代器的获取 1.2 构造函数和赋值重载 1.2.1 无参构造函数 1.2.2 有参构造函数(对n个对象的去调用他们的构造) 1.2.3 迭代器区…

【UnityShader入门精要学习笔记】第十五章 使用噪声

本系列为作者学习UnityShader入门精要而作的笔记,内容将包括: 书本中句子照抄 个人批注项目源码一堆新手会犯的错误潜在的太监断更,有始无终 我的GitHub仓库 总之适用于同样开始学习Shader的同学们进行有取舍的参考。 文章目录 使用噪声上…

亮相CCIG2024,合合信息文档解析技术破解大模型语料“饥荒”难题

近日,2024中国图象图形大会在古都西安盛大开幕。本届大会由中国图象图形学学会主办,空军军医大学、西安交通大学、西北工业大学承办,通过二十多场论坛、百余项成果,集中展示了生成式人工智能、大模型、机器学习、类脑计算等多个图…

Compose第一弹 可组合函数+Text

目标: 1.Compose是什么?有什么特征? 2.Compose的文本控件 一、Compose是什么? Jetpack Compose 是用于构建原生 Android 界面的新工具包。 Compose特征: 1)声明式UI:使用声明性的函数构建一…

opencascade 快速显示AIS_ConnectedInteractive源码学习

AIS_ConcentricRelation typedef PrsDim_ConcentricRelation AIS_ConcentricRelation AIS_ConnectedInteractive 简介 创建一个任意位置的另一个交互对象实例作为参考。这允许您使用连接的交互对象,而无需重新计算其表示、选择或图形结构。这些属性是从您的参考对…

蓝桥杯嵌入式国赛笔记(4):多路AD采集

1、前言 蓝桥杯的国赛会遇到多路AD采集的情况,这时候之前的单路采集的方式就不可用了,下面介绍两种多路采集的方式。 以第13届国赛为例 2、方法一(配置通道) 2.1 使用CubeMx配置 设置IN13与IN17为Single-ended 在Parameter S…