深度学习(三)

5.Functional API 搭建神经网络模型
5.1利用Functional API编写宽深神经网络模型进行手写数字识别
import numpy as npimport pandas as pdimport matplotlib.pyplot as pltfrom sklearn.datasets import load_irisfrom sklearn.model_selection import train_test_splitfrom tensorflow.keras.layers import Input, Dense, concatenatefrom tensorflow.keras.models import Modeliris = load_iris()x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23)X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12)print(X_valid.shape)print(X_train.shape)inputs = Input(shape=X_train.shape[1:])hidden1 = Dense(300, activation="relu")(inputs)hidden2 = Dense(100, activation="relu")(hidden1)concat = concatenate([inputs, hidden2])output = Dense(10, activation="softmax")(concat)model_wide_deep = Model(inputs=inputs, outputs=output)

iris = load_iris():加载iris数据集,这是一个常用的多类分类数据集,包含了150个样本,每个样本有4个特征,属于3个不同的类别。

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23):将iris数据集分割为训练集和测试集。test_size=0.2表示测试集的大小为原始数据的20%,random_state=23是一个随机种子,确保分割的可重复性。

X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12):进一步将训练集分割为训练集和验证集。同样,test_size=0.2表示验证集的大小为分割后训练数据的20%,random_state=12确保分割的可重复性。

print(X_valid.shape):打印验证集的特征数据的形状。

print(X_train.shape):打印新的训练集的特征数据的形状。

inputs = Input(shape=X_train.shape[1:]):定义模型的输入层,shape=X_train.shape[1:]指定输入的形状,由于X_train是一个二维数组,shape[1:]表示除了第一维(样本数量)之外的所有维度。

hidden1 = Dense(300, activation="relu")(inputs):定义第一个隐藏层,它有300个神经元,并使用ReLU激活函数。

hidden2 = Dense(100, activation="relu")(hidden1):定义第二个隐藏层,它有100个神经元,并使用ReLU激活函数。

concat = concatenate([inputs, hidden2]):将输入层和第二个隐藏层的输出拼接起来,形成更宽的网络。

output = Dense(10, activation="softmax")(concat):定义输出层,它有10个神经元(对应于3个类别和一个额外的神经元,这是常见的做法),并使用softmax激活函数输出概率分布。

model_wide_deep = Model(inputs=inputs, outputs=output):创建一个Keras模型,将输入层和输出层连接起来。

使用scikit-learn库中的load_iris函数来加载iris数据集,然后使用train_test_split函数将数据集分割为训练集和测试集,以及进一步的训练集和验证集。接着,它定义了一个宽深网络(wide and deep network)模型,其中包含了输入层、两个隐藏层和一个输出层。

model_wide_deep.summary()

model_wide_deep.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])h = model_wide_deep.fit(X_train, y_train, batch_size=32, epochs=30,validation_data=(X_valid, y_valid))

# 绘图pd.DataFrame(h.history).plot(figsize=(8,5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()

# 使用 model_wide_deep 评估测试集test_loss, test_accuracy = model_wide_deep.evaluate(x_test, y_test, batch_size=32)print(f"Test Loss: {test_loss}")print(f"Test Accuracy: {test_accuracy}")

6.SubClassing API 搭建神经网络模型

以前馈全连接神经网络手写数字识别为例

import numpy as npimport pandas as pdimport matplotlib.pyplot as pltfrom sklearn.datasets import load_irisfrom sklearn.model_selection import train_test_splitfrom tensorflow.keras.layers import Input, Dense, concatenatefrom tensorflow.keras.models import Modelfrom tensorflow.keras import backend as K# 加载数据集iris = load_iris()X = iris.datay = iris.target# 分割数据集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=23)X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=12)# 打印验证集和训练集的形状print(X_valid.shape)print(X_train.shape)# 定义 Model_sub_fn 类class Model_sub_fn(Model):def __init__(self, units_1, units_2, units_out, activation="relu"):super(Model_sub_fn, self).__init__()self.hidden1 = Dense(units_1, activation=activation)self.hidden2 = Dense(units_2, activation=activation)self.main_output = Dense(units_out, activation="softmax")def call(self, inputs):x = self.hidden1(inputs)x = self.hidden2(x)return self.main_output(x)

定义了一个名为Model_sub_fn的类,该类继承自tensorflow.keras.Model。这个类用于创建一个简单的神经网络模型,它包含两个隐藏层和一个输出层。

class Model_sub_fn(Model)定义一个名为Model_sub_fn的类,它继承自tensorflow.keras.Model。这意味着Model_sub_fn类可以访问和继承Model类的所有属性和方法。

def __init__(self, units_1, units_2, units_out, activation="relu"):定义类的构造函数__init__,它接受四个参数:units_1(第一个隐藏层的神经元数量)、units_2(第二个隐藏层的神经元数量)、units_out(输出层的神经元数量)和activation(激活函数类型,默认为ReLU)。

super(Model_sub_fn, self).__init__():调用父类的构造函数,这是继承自Model类的标准做法。

self.hidden1 = Dense(units_1, activation=activation):定义第一个隐藏层,它有units_1个神经元,并使用activation作为激活函数。

self.hidden2 = Dense(units_2, activation=activation):定义第二个隐藏层,它有units_2个神经元,并使用activation作为激活函数。

self.main_output = Dense(units_out, activation="softmax"):定义输出层,它有units_out个神经元,并使用softmax作为激活函数。

def call(self, inputs):定义call方法,这是所有Keras模型必须定义的方法,它用于前向传播。在这个方法中,输入数据通过两个隐藏层,最后通过输出层。

x = self.hidden1(inputs):将输入数据通过第一个隐藏层。

x = self.hidden2(x):将第一个隐藏层的输出通过第二个隐藏层。

return self.main_output(x):将第二个隐藏层的输出通过输出层,并返回结果。

model_sub_fn = Model_sub_fn(units_1=64, units_2=32, units_out=3)# 创建 Model_sub_fn 实例model_sub_fn = Model_sub_fn(300, 100, 3, activation="relu")  # 假设输出层有3个单元,因为Iris数据集有3个类别# 编译模型model_sub_fn.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])# 训练模型history = model_sub_fn.fit(X_train, y_train, batch_size=32, epochs=30, validation_data=(X_valid, y_valid))

model_sub_fn.summary()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/345998.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx配置详细解释:(3)http模块及server模块,location模块

目录 环境概述: http模块中的全局模块 1. root配置主要是对主web页面的路径访问。 2.server虚拟主机 2.1基于IP: 2.2基于域名: 3.alias别名 4.location匹配 5.access模块: 6.验证模块 7.自定义错误页面 8.日志存放位置…

Clearedge3d EdgeWise 5.8 强大的自动化建模软件

EdgeWise是功能强大的建模软件,提供领先的建模功能和先进的技术,让您的整个过程更快更准确!您可以获得使用自动特征提取和对象识别的 3D 建模,ClearEdge3D 自动建模和对象识别软件通过创建竣工文档和施工验证完成该过程。拓普康和…

Python第二语言(七、Python模块)

目录 1. 什么是模块 2. 基本语法 2.1 模块的导入方式 2.2 基本语法 import 模块名 2.3 基本语法 from 模块名 import 功能名 2.4 基本语法as 别名 3. 自定义模块 4. 调用自定义模块时,如何让其模块中的函数不被调用(__name__) 5. 调…

【数据结构与算法】使用数组实现栈:原理、步骤与应用

💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:《数据结构与算法》 期待您的关注 ​ 目录 一、引言 🎄栈(Stack)是什么? &#x1…

安卓逆向经典案例——XX牛

安卓逆向经典案例——XX牛 按钮绑定方式 1.抓包 2.查看界面元素,找到控件id 通过抓包,发现点击登录后,才会出现Encrpt加密信息,所以我们通过控件找到对应id:btn_login 按钮绑定方法——第四种 public class LoginA…

开源规则引擎LiteFlow项目应用实践

本文介绍基于开源规则引擎LiteFlow,如何开发规则设计器,在低代码平台中集成规则引擎,并在项目中实现应用的效果。由于低代码平台使用规则引擎实现了逻辑编排的需求,所以本文中的叫法为“逻辑设计”、“逻辑编排”、“逻辑流引擎”…

【Android面试八股文】一图展示 Android生命周期:从Activity到Fragment,以及完整的Android Fragment生命周期

图片来源于:https://github.com/xxv/android-lifecycle Android生命周期:从Activity到Fragment 图:android-lifecycle-activity-to-fragments.png 完整的Android Fragment生命周期 图:complete_android_fragment_lifecycle.png…

设置路径别名

一、描述 如果想要给路径设置为别名,就是常见的有些项目前面的引入文件通过开头的,也就是替换了一些固定的文件路径,怎么配置。 二、配置 import { defineConfig } from vite import react from vitejs/plugin-react import path from path…

基础数据结构 -- 堆

1. 简介 堆可以看做是一种特殊的完全二叉树,它满足任意节点的值都大于或小于其子节点的值。 2. 功能 插入元素:插入新元素时,先将元素放至数组末尾,然后通过上浮算法自底向上调整,使堆保持性质。删除堆顶元素&#xff…

App UI 风格,尽显魅力

精妙无比的App UI 风格

动态内存管理(malloc,calloc,realloc,free)+经典笔试题

动态内存管理 一. malloc 和 free1. malloc2. free 二. calloc三. realloc四.动态内存的错误1.对NULL指针的解引用操作2.对动态开辟空间的越界访问3.对非动态开辟内存使用free释放4.使用free释放一块动态开辟内存的一部分5.对同一块动态内存多次释放6.动态开辟内存忘记释放&…

ROS 获取激光雷达数据(C++实现)

ROS 获取激光雷达数据(C实现) 实现思路 在机器人ROS系统中,激光雷达通常会有一个对应的节点,这个节点一般是由雷达的厂商提供,我们只需要简单的配置以下端口参数,就能和激光雷达的电路系统建立连接&#…

“安全生产月”专题报道:AI智能监控技术如何助力安全生产

今年6月是第23个全国“安全生产月”,6月16日为全国“安全宣传咨询日”。今年全国“安全生产月”活动主题为“人人讲安全、个个会应急——畅通生命通道”。近日,国务院安委会办公室、应急管理部对开展好2024年全国“安全生产月”活动作出安排部署。 随着科…

单臂路由的配置(思科、华为)

#交换设备 不同vlan属于不同广播域,不能互相通信,他们配置的是不同网段的IP地址,针对不同网段的IP地址进行通信,就需要用到路由技术 实现不同vlan之间的通信技术有两种 单臂路由三层交换 单臂路由 一、思科设备的单臂路由配…

AutoCAD Mechanical机械版专业的计算机辅助设计软件安装包下载安装!

AutoCAD机械版作为一款专业的计算机辅助设计软件,不仅具备卓越的二维绘图功能,更是拥有令人瞩目的3D建模工具,为机械设计师们提供了前所未有的创作空间。 在AutoCAD机械版的3D建模环境中,用户可以借助一系列简洁明了的命令&#…

【EDA】SSTA中最慢路径与最快路径统计计算

假设(X1,X2)为二元高斯随机向量,均值(μ1,μ2),标准差(σ1,σ2),相关系数ρ 定义:X=max(X1,X2),Y=min(X1,X2) SSTA中计算setup/hold的worst delay时即求X、Y,路径N对应维度为N维。 X的概率密度函数PDF为f(x)=f1(-x)+f2(-x),f1和f2为: 其中小Φ和大Φ…

在Linux上的Java项目导出PDF乱码问题

在Linux上的Java项目导出PDF乱码问题 场景:一个Java项目导出PDF,在我本地导出是没有问题,但是部署上Linux上后,导出就出现了乱码了。 处理方案 我这里使用的处理方案是在Linux服务器上安装一些PDF需要使用的字体 1.把字体上传到…

Dell服务器根据GPU温度调整风扇转速

前言 dell服务器自动风扇是根据CPU温度来调速的,我跑AI的时候cpu温度不高但是GPU温度很高导致显卡卡死PVE虚拟机直接挂起无法运行,我看了下也没有基于显卡温度调速的脚本,于是我就自己写了一个 基于ipmi工具 乌班图等linux先安装ipmi apt …

Springboot结合redis实现关注推送

关注推送 Feed流的模式 Timeline:不做内容筛选,简单的按照内容发布时间排序。常用于好友与关注。例如朋友圈的时间发布排序。 优点:信息全面,不会有缺失。并且实现也相对简单 缺点:信息噪音较多,用户不一定感兴趣,内容获取效率…

Transparent 且 Post-quantum zkSNARKs

1. 引言 前序博客有: SNARK原理示例SNARK性能及安全——Prover篇SNARK性能及安全——Verifier篇 上图摘自STARKs and STARK VM: Proofs of Computational Integrity。 上图选自:Dan Boneh 斯坦福大学 CS251 Fall 2023 Building a SNARK 课件。 SNARK…