深度学习技术之加宽前馈全连接神经网络

深度学习技术

  • 加宽前馈全连接神经网络
    • 1. Functional API 搭建神经网络模型
      • 1.1 利用Functional API编写宽深神经网络模型进行手写数字识别
        • 1.1.1 导入需要的库
        • 1.1.2 加载虹膜(Iris)数据集
        • 1.1.3 分割训练集和测试集
        • 1.1.4 定义模型输入层
        • 1.1.5 添加隐藏层
        • 1.1.6 拼接输入层和第二个隐藏层
        • 1.1.7 添加输出层
        • 1.1.8 创建模型
        • 1.1.9 打印模型的摘要
        • 1.1.10 模型编译并训练
      • 1.2 利用Functional API编写多输入神经网络模型进行手写数字识别
        • 1.2.1 分割子集
        • 1.2.2 定义输入层
        • 1.2.3 定义全连接层
        • 1.2.4 创建模型
        • 1.2.5 编译与训练模型
        • 1.2.6 训练历史数据的可视化
    • 2. SubClassing API 搭建神经网络模型
      • 2.1 前馈全连接神经网络手写数字识别
        • 2.1.1 定义一个Keras模型类
        • 2.1.2 定义方法
        • 2.1.3 初始化模型
        • 2.1.4 通过在初始化中传递参数改变模型元素默认值
        • 2.1.5 编译与训练模型
        • 2.1.6 打印模型摘要

加宽前馈全连接神经网络

1. Functional API 搭建神经网络模型

1.1 利用Functional API编写宽深神经网络模型进行手写数字识别

1.1.1 导入需要的库

利用Sequential API建立一个顺序传播的前馈全连接神经网络,导入numpy、pandas,tensorflow等库,以及导入matplotlib的pyplot模块。从sklearn库的datasets模块中导入load_iris函数,以及从sklearn库的model_selection模块中导入train_test_split函数。从TensorFlow库中导入Keras模块。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
1.1.2 加载虹膜(Iris)数据集

虹膜(Iris)数据集是scikit-learn库中内置的一个样本数据集,它包含了150个样本,分为三个类,每个类有50个样本。这三个类分别是山鸢尾(Iris Setosa)、杂色鸢尾(Iris Versicolour)和维吉尼亚鸢尾(Iris Virginica)。

iris = load_iris()
1.1.3 分割训练集和测试集

将虹膜(Iris)数据集分割为训练集和测试集,得到训练集x_train和y_train,再将分割得到的训练集x_train和y_train分割为新的训练集和验证集。

x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target, test_size=0.2, random_state=23)
X_train, X_valid, y_train, y_valid=train_test_split(x_train, y_train,test_size=0.2, random_state=12)
1.1.4 定义模型输入层

使用X_train.shape[1:]作为输入层的形状,因为X_train.shape[0]是批量大小,通常在训练过程中改变,而X_train.shape[1:]包含了特征的数量,这些数量在训练过程中保持不变。

inputs = keras.layers.Input(shape=X_train.shape[1:])
1.1.5 添加隐藏层

隐藏层,包含神经元,并使用ReLU激活函数。

hidden1 = keras.layers.Dense(300, activation="relu")(inputs)
hidden2 = keras.layers.Dense(100, activation="relu")(hidden1)
1.1.6 拼接输入层和第二个隐藏层

将输入层和第二个隐藏层的输出进行拼接,得到一个融合了输入和中间层信息的特征向量。

concat = keras.layers.concatenate([inputs, hidden2])
1.1.7 添加输出层

添加了一个输出层,包含10个神经元,使用softmax激活函数,因为模型是用于多类分类任务。

output = keras.layers.Dense(10, activation="softmax")(concat)
1.1.8 创建模型

创建了一个完整的模型,将输入层和输出层连接起来,形成了一个有监督学习的模型结构。
这个模型结构结合了“宽”模型(wide model)和“深”模型(deep model)的特点,通过输入层和隐藏层的拼接来融合这两种模型。

model_fun_WideDeep = keras.models.Model(inputs=[inputs], outputs=[output])

运行结果:
在这里插入图片描述

1.1.9 打印模型的摘要
model_fun_WideDeep.summary()
1.1.10 模型编译并训练

model_fun_WideDeep.fit()方法将开始模型的训练过程,并在每个轮次结束后使用验证数据评估模型的性能。训练过程中,模型将逐渐学习如何将输入特征映射到正确的输出类别。

model_fun_WideDeep.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])
h=model_fun_WideDeep.fit(X_train, y_train, batch_size=32, epochs=30, validation_data=(X_valid, y_valid))

运行结果:
在这里插入图片描述

1.2 利用Functional API编写多输入神经网络模型进行手写数字识别

1.2.1 分割子集

将训练集X_train和验证集X_valid分割为两个子集。

X_train_A, X_train_B = X_train[:, :200], X_train[:, 100:]
X_valid_A, X_valid_B = X_valid[:, :200], X_valid[:, 100:]
1.2.2 定义输入层
input_A = keras.layers.Input(shape=X_train_A.shape[1])
input_B = keras.layers.Input(shape=X_train_B.shape[1])
1.2.3 定义全连接层
hidden1 = keras.layers.Dense(300, activation="relu")(input_B)
hidden2 = keras.layers.Dense(100, activation="relu")(hiddenl)
1.2.4 创建模型

将输入层和输出层连接起来。

model_fun_MulIn = keras.models.Model(inputs=[input_A, input_B], outputs=[output])
1.2.5 编译与训练模型

在训练过程中,模型将使用指定的损失函数和优化器来更新权重,并使用准确率作为评估指标来监控性能。

model_fun_MulIn.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])

运行结果:
在这里插入图片描述

1.2.6 训练历史数据的可视化

图中显示了训练和验证集上的损失和准确率随轮次的变化情况。

pd.DataFrame(h.history).plot(figsize=(8,5))
plt.grid(True)
plt.gca().set_ylim(0,1)
plt.show()

运行结果:
在这里插入图片描述

2. SubClassing API 搭建神经网络模型

2.1 前馈全连接神经网络手写数字识别

2.1.1 定义一个Keras模型类

定义一个自定义的Keras模型类Model_sub_fnn,继承自keras.models.Model。这个类定义了一个简单的全连接神经网络,它有两个隐藏层和一个输出层。

class Model_sub_fnn(keras.models.Model):def __init__(self, units_1=300, units_2=100, units_out=10, activation='relu'):super().__init__()self.hidden1 = keras.layers.Dense(units_1, activation=activation)self.hidden2 = keras.layers.Dense(units_2, activation=activation)self.main_output = keras.layers.Dense(units_out, activation='softmax')
2.1.2 定义方法

给Model_sub_fnn类定义一个call方法。这个方法是Keras模型中的一个特殊方法,它定义了模型的前向传播过程,它将输入数据通过模型的所有层,并返回最终的输出。

def call(self, data):hidden1 = self.hidden1(data)hidden2 = self.hidden2(hidden1)main_output = self.main_output(hidden2)return main_output
2.1.3 初始化模型
model_sub_fnn = Model_sub_fnn()
2.1.4 通过在初始化中传递参数改变模型元素默认值
model_sub_fnn2 = Model_sub_fnn(300, 100, 10, activation='relu')
2.1.5 编译与训练模型

编译模型,使用训练数据和验证数据进行训练。在训练过程中,模型将使用指定的损失函数和优化器来更新权重,并使用准确率作为评估指标来监控性能。训练完成后,将得到模型的摘要,其中包含了模型的详细信息。

model_sub_fnn.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=["accuracy")
h= model_sub_fnn.fit(X_train,y_train,batch_size=32,epochs=30,validation_data = (X_valid,y_valid))

运行结果:
在这里插入图片描述

2.1.6 打印模型摘要

打印出模型的摘要,其中包括模型的层结构、每个层的输出形状、层的参数数量以及整个模型的总参数数量。

model_sub_fnn.summary()

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/326294.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

栈结构(详解)

1.栈的概念 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶,另一端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)的原则。 压栈&am…

省级生活垃圾无害化处理率面板数据(2004-2022年)

01、数据简介 生活垃圾无害化处理率是指经过处理的生活垃圾中,达到无害化标准的垃圾所占的比例。这一指标是衡量城市垃圾处理水平的重要标准,反映了城市对垃圾进行有效管理和处理的能力。 生活垃圾无害化处理的主要方式包括生活垃圾焚烧、生活垃圾卫生…

react18【系列实用教程】moxb —— 集中状态管理 (2024最新版)

官方文档 https://www.mobxjs.com/ moxb 和 redux 都能用于 react 的状态管理,但 moxb 更简单,适合规模不大的应用 (规模大的应用若合理组织代码结构,也能用 moxb) 安装 moxb npm i mobx npm i mobx-react-lite此处安…

C语言洛谷题目分享(11)回文质数

目录 1.前言 2.题目:回文质数 1.题目描述 2.输入格式 3.输出格式 4.输入输出样例 5.题解 3.小结 1.前言 哈喽大家好,今儿继续为大家分享一道蛮有价值的一道题,希望大家多多支持喔~ 2.题目:回文质数 1.题目描述 因为 151 …

【Oracle直播课】5月19日Oracle 19c OCM认证大师课 (附课件预览)

Oracle 19c OCM认证大师培训 - 课程体系 - 云贝教育 (yunbee.net) 部分课件预览 OCM部分课件预览 Oracle Database 19c Certified Master Exam (OCM) 认证大师 25 天 / 150课时 什么是Oracle 19c OCM? Oracle Certified Master (OCM)是Oracle认证大师,…

【51】Camunda8-Zeebe核心引擎-Zeebe Gateway

概述 Zeebe网关是Zeebe集群的一个组件,它可以被视为Zeebe集群的联系点,它允许Zeebe客户端与Zeebe集群内的Zeebe代理进行通信。有关Zeebe broker的更多信息,请访问我们的附加文档。 总而言之,Zeebe broker是Zeebe集群的主要部分,它完成所有繁重的工作,如处理、复制、导出…

软件工程期末复习(2)

软件工程 软件危机与软件工程的提出: 面对软件危机,1968年德国召开的一次NATO会议上首次签署声明“软件工程”这一说法,认为软件工程应当使用业已建立的工程学科的基本原理和范型。 背后驱使的观念是:软件设计、实现和维护应当与…

网络编程套接字详解

目录 1. 预备介绍 2.网络字节序 3.udp网络程序 4.地址转换函数 5.udp网络编程 1.预备介绍 1.1源IP地址和目标IP地址 举个例子: 从北京出发到上海旅游, 那么源IP地址就是北京, 目标IP地址就是上海. 1.2 端口号 作用: 标识一个进程, 告诉OS这个数据交给那个进程来处理; (1)…

Redis详解(二)

事务 什么是事务? 事务是一个单独的隔离操作:事务中的所有命令都会序列化、按顺序地执行。事务在执行的过程中,不会被其他客户端发送来的命令请求所打断。 事务是一个原子操作:事务中的命令要么全部被执行,要么全部都…

JVM内存结构

文章目录 JVM的内存结构程序计数器虚拟机栈(栈)本地方法栈Native关键字 堆元空间(JDK1.8之前为永久代) 实例对象实例化的过程 JVM的内存结构 先看下图 内存结构分为以下的两种:线程私有以及线程共享 线程私有&#x…

1070: 邻接矩阵存储简单路径

解法&#xff1a; #include<iostream> #include<vector> using namespace std; int arr[100][100]; int n; int sta, des; vector<int> path; vector<vector<int>> res; void dfs(vector<int> &a,int i) {a[i] 1;path.push_back(i);…

【数据分析面试】42.用户流失预测模型搭建(资料数据分享)

题目 保持高的客户留存率可以稳定和提到企业的收入。因此&#xff0c;预测和防止客户流失是在业务中常见的一项数据分析任务。这次分享的数据集包括了电信行业、银行、人力资源和电商行业&#xff0c;涵盖了不同业务背景下的流失预测数据。 后台回复暗号&#xff08;在本文末…

【Ubuntu永久授权串口设备读取权限“/dev/ttyUSB0”】

Ubuntu永久授权串口设备读取权限 1 问题描述2 解决方案2.1 查看ttyUSB0权限&#xff0c;拥有者是root&#xff0c;所属用户组为dialout2.2 查看dialout用户组成员&#xff0c;如图所示&#xff0c;普通用户y不在dialout组中2.3 将普通用户y加入dialout组中2.4 再次查看dialout用…

Python 求高斯误差函数 erf 和 erfc

文章目录 Part.I IntroductionPart.II 概念定义Chap.I 误差函数 erfChap.II 误差函数补 erfc Part.II 求值与绘图Chap.I 求取方式Chap.II 绘图 Reference Part.I Introduction 本文将对误差函数&#xff08;ERror Function, ERF&#xff09;进行简单的介绍&#xff0c;并探索如…

Linux网站服务

1.概念:HTML:超级文本编辑语言 网页:使用HTML,PHP,JAVA语言格式书写的文件。 主页:网页中呈现用户的第一个界面。 网站:多个网页组合而成的一台网站服务器。 URL:统一资源定位符&#xff0c;访问网站的地址。 网站架构:LAMP: LinuxApacheMYSQLPHP(系统服务器程序数据管理…

Multitouch for Mac:手势自定义,提升工作效率

Multitouch for Mac作为一款触控板手势增强软件&#xff0c;其核心功能在于手势的自定义和与Mac系统的深度整合。通过Multitouch&#xff0c;用户可以轻松设置各种手势&#xff0c;如三指轻点、四指左右滑动等&#xff0c;来执行常见的任务&#xff0c;如打开应用、切换窗口、滚…

爬虫工作量由小到大的思维转变---<第七十三章 > Scrapy爬虫详解一下HTTPERROE的问题

前言&#xff1a; 在我们的日常工作中&#xff0c;有时会忽略一些工具或组件的重要性&#xff0c;直到它们引起一连串的问题&#xff0c;我们才意识到它们的价值。正如在Scrapy框架中的HttpErrorMiddleware&#xff08;HTTP错误中间件&#xff09;一样&#xff0c;在开始时&…

MATLAB R2017b安装的关键一步

作者&#xff1a;朱金灿 来源&#xff1a;clever101的专栏 MATLAB R2017b的安装文件比较庞大&#xff0c;由两个iso文件组成&#xff1a;R2017b_win64_dvd1.iso和R2017b_win64_dvd2.iso。安装时需要注意的是&#xff0c;首先使用DAEMON Tools Lite打开R2017b_win64_dvd1.iso&am…

【Git】Git在Gitee上的基本操作指南

文章目录 1. 查看 git 版本2. 从Gitee克隆仓库&#xff1a;3. 复制文件到工作目录&#xff1a;4. 将未跟踪的文件添加到暂存区&#xff1a;5. 在本地提交更改&#xff1a;6. 将更改推送到远程仓库&#xff08;Gitee&#xff09;&#xff1a;7. Windows特定提示&#xff1a; 1. …

品鉴中的品鉴笔记:如何记录和分享自己的品鉴心得

品鉴云仓酒庄雷盛红酒的过程&#xff0c;不仅是品尝美酒&#xff0c;更是一次与葡萄酒深度对话的旅程。为了更好地记录和分享自己的品鉴心得&#xff0c;养成写品鉴笔记的习惯是十分必要的。 首先&#xff0c;选择一个适合的记录工具。可以是传统的笔记本&#xff0c;也可以是…