神经网络构成、优化、常用函数+激活函数

Iris分类

数据集介绍,共有数据150组,每组包括长宽等4个输入特征,同时给出输入特征对应的Iris类别,分别用0,1,2表示。

从sklearn包datasets读入数据集。

from sklearn import darasets
from pandas import DataFrame
import pandas as pd
x_data = datasets.load_iris().data #  输入特征
y_data = datasets.load_iris().target # 标签
x_data = DataFrame(x_data, columns=["花萼长度",'花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width',True) # 设置列名对齐
x_dara['类别'] = y_data # 新增一列

神经网络实现分类步骤。

1.准备数据:

数据集读入,数据集乱序,生成训练集和测试集,配成输入特征/标签对,每次读入一部分

2.搭建网络

定义神经网络中的所有可训练参数

3.参数优化

嵌套循环迭代,with结构更新参数,显示当前loss

4.测试效果

计算当前向前传播后的准确率,显示当前acc

5可视化acc/loss

# -*- coding: UTF-8 -*-
# 利用鸢尾花数据集,实现前向传播、反向传播,可视化loss曲线# 导入所需模块
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
# 使用seed使每次生成的随机数相同(方便教学,使大家结果都一致,在现实使用时不写seed)
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))lr = 0.1  # 学习率为0.1
train_loss_results = []  # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = []  # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 500  # 循环500轮
loss_all = 0  # 每轮分4个step,loss_all记录四个step生成的4个loss的和# 训练部分
for epoch in range(epoch):  #数据集级别的循环,每个epoch循环一次数据集for step, (x_train, y_train) in enumerate(train_db):  #batch级别的循环 ,每个step循环一个batchwith tf.GradientTape() as tape:  # with结构记录梯度信息y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算y = tf.nn.softmax(y)  # 使输出y符合概率分布(此操作后与独热码同量级,可相减求loss)y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss和accuracyloss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-out)^2)loss_all += loss.numpy()  # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确# 计算loss对各个参数的梯度grads = tape.gradient(loss, [w1, b1])# 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_gradw1.assign_sub(lr * grads[0])  # 参数w1自更新b1.assign_sub(lr * grads[1])  # 参数b自更新# 每个epoch,打印loss信息print("Epoch {}, loss: {}".format(epoch, loss_all/4))train_loss_results.append(loss_all / 4)  # 将4个step的loss求平均记录在此变量中loss_all = 0  # loss_all归零,为记录下一个epoch的loss做准备# 测试部分# total_correct为预测对的样本个数, total_number为测试的总样本数,将这两个变量都初始化为0total_correct, total_number = 0, 0for x_test, y_test in test_db:# 使用更新后的参数进行预测y = tf.matmul(x_test, w1) + b1y = tf.nn.softmax(y)pred = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类# 将pred转换为y_test的数据类型pred = tf.cast(pred, dtype=y_test.dtype)# 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)# 将每个batch的correct数加起来correct = tf.reduce_sum(correct)# 将所有batch中的correct数加起来total_correct += int(correct)# total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数total_number += x_test.shape[0]# 总的准确率等于total_correct/total_numberacc = total_correct / total_numbertest_acc.append(acc)print("Test_acc:", acc)print("--------------------------")# 绘制 loss 曲线
plt.title('Loss Function Curve')  # 图片标题
plt.xlabel('Epoch')  # x轴变量名称
plt.ylabel('Loss')  # y轴变量名称
plt.plot(train_loss_results, label="$Loss$")  # 逐点画出trian_loss_results值并连线,连线图标是Loss
plt.legend()  # 画出曲线图标
plt.show()  # 画出图像# 绘制 Accuracy 曲线
plt.title('Acc Curve')  # 图片标题
plt.xlabel('Epoch')  # x轴变量名称
plt.ylabel('Acc')  # y轴变量名称
plt.plot(test_acc, label="$Accuracy$")  # 逐点画出test_acc值并连线,连线图标是Accuracy
plt.legend()
plt.show()

 根据MP模型可以看出,求出的y实际上就是计算出属于哪一种分类的概率

 

 其求出的loss就是概率和0/1相减,再将loss和w与b求偏导,通过公式运算得到w和b

预备函数

tf.where

a=tf.constant([1,2,3,1,1])

b=tf.constant([0,1,3,4,5])

c=tf.where(tf.greater(a,b),a,b)

2.np.random.RandomState.rand(维度)返回[0,1]的随机数

3.np.vstack() 将两个数组按垂直方向叠加

4np.mgridp[起始值:结束值:步长,....] [)

5.x.ravel() 将x变为一维数组

6.np.c_ [数组1,数组2]返回的间隔数值点配对

神经网络复杂度

指数衰减学习率

可以先用较大的学习率,快速得到较优解,然后逐步减小学习率,使模型在训练后期稳定

指数衰减学习率=初始学习率*学习率衰减率^(当前轮数/多少轮衰减一次)更新频率

epoch = 40
LR_BASE = 0.2
LR_DECAY = 0.99
LR_STEP = 1
for epoch in range(epoch):lr = LR_BASE*LR_DECAY**(epoch/LR_STEP)with tf.GradientTape() as tape:loss = tf.square(w + 1)grads = tape.gradient(loss, w)w.assign_sub(;lr*grads)

激活函数

1.Signmoid函数

特点:容易造成梯度消失,输出非0均值,收敛慢,幂运算复杂,训练时间长

2.Tanh函数

特点:输出是0均值,容易造成梯度消失,幂运算复杂,训练时间长

3.Relu函数

解决梯度消失问题正区间,容易造成神经元死亡,改变随机初始化,避免过多设置更小学习率,减少参数的巨大变化,避免训练中产生过多负数特征进入函数

Leaky Rely

1首选relu函数

2学习率设置较小值

3输入特征标准化,既让输入特征满足以0为均值,1为标准差的正态分布

4初始参数中心化,既让随机生成的参数满足以0为均值,sqrt(2/当前层输入特征个数)为标准差的正态分布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/373336.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【密码学】分组密码概述

一、分组密码的定义 分组密码和流密码都是对称密码体制。 流密码:是将明文视为连续的比特流,对每个比特或字节进行实时加密,而不将其分割成固定的块。流密码适用于加密实时数据流,如网络通信。分组密码:是将明文数据…

GuLi商城-商品服务-API-品牌管理-OSS获取服务端签名

新建第三方服务: 引入common 把common中oss的依赖都拿到第三方服务中来 配置文件: 加上nacos注解:<

windows USB 设备驱动开发-USB带宽

本文讨论如何仔细管理 USB 带宽的指导。 每个 USB 客户端驱动程序都有责任最大程度地减少其使用的 USB 带宽&#xff0c;并尽快将未使用的带宽返回到可用带宽池。 在这里&#xff0c;我们认为USB 2.0 的速度是480Mbps、12Mbps、1.5Mbps&#xff0c;这分别对应高速、全速、低速…

【QML之·基础语法概述】

系列文章目录 文章目录 前言一、QML基础语法二、属性三、脚本四、核心元素类型4.1 元素可以分为视觉元素和非视觉元素。4.2 Item4.2.1 几何属性(Geometry&#xff09;:4.2.2 布局处理:4.2.3 键处理&#xff1a;4.2.4 变换4.2.5 视觉4.2.6 状态定义 4.3 Rectangle4.3.1 颜色 4.4…

《植物大战僵尸杂交版》2.2版本:全新内容与下载指南

《植物大战僵尸杂交版》2.2版本已经火热更新&#xff0c;带来了一系列令人兴奋的新玩法和调整&#xff0c;为这款经典的塔防游戏注入了新的活力。如果你是《植物大战僵尸》系列的忠实粉丝&#xff0c;那么这个版本绝对值得你一探究竟。 2.2版本更新亮点 新增看星星玩法 这个新…

宏碁F5-572G-59K3笔记本笔记本电脑拆机清灰教程(详解)

1. 前言 我的笔记本开机比较慢&#xff0c;没有固态&#xff0c;听说最近固态比较便宜&#xff0c;就想入手一个&#xff0c;于是拆笔记本看一下有没有可以安的装位置。&#xff08;友情提示&#xff0c;在拆机之前记得洗手并擦干&#xff0c;以防静电损坏电源器件&#xff09…

ChatTTS使用

ChatTTS是一款适用于日常对话的生成式语音模型。 克隆仓库 git clone https://github.com/2noise/ChatTTS cd ChatTTS 使用 conda 安装 conda create -n chattts conda activate chattts pip install -r requirements.txt 安装完成后运行 下载模型并运行 python exampl…

Python酷库之旅-第三方库Pandas(013)

目录 一、用法精讲 31、pandas.read_feather函数 31-1、语法 31-2、参数 31-3、功能 31-4、返回值 31-5、说明 31-6、用法 31-6-1、数据准备 31-6-2、代码示例 31-6-3、结果输出 32、pandas.DataFrame.to_feather函数 32-1、语法 32-2、参数 32-3、功能 32-4、…

【计算机毕业设计】基于Springboot的IT技术交流和分享平台【源码+lw+部署文档】

包含论文源码的压缩包较大&#xff0c;请私信或者加我的绿色小软件获取 免责声明&#xff1a;资料部分来源于合法的互联网渠道收集和整理&#xff0c;部分自己学习积累成果&#xff0c;供大家学习参考与交流。收取的费用仅用于收集和整理资料耗费时间的酬劳。 本人尊重原创作者…

14-56 剑和诗人30 - IaC、PaC 和 OaC 在云成功中的作用

介绍 随着各大企业在 2024 年加速采用云计算&#xff0c;基础设施即代码 (IaC)、策略即代码 (PaC) 和优化即代码 (OaC) 已成为成功实现云迁移、IT 现代化和业务转型的关键功能。 让我在云计划的背景下全面了解这些代码功能的当前状态。我们将研究现代云基础设施趋势、IaC、Pa…

MATLAB备赛资源库(1)建模指令

一、介绍 MATLAB&#xff08;Matrix Laboratory&#xff09;是一种强大的数值计算环境和编程语言&#xff0c;特别设计用于科学计算、数据分析和工程应用。 二、使用 数学建模使用MATLAB通常涉及以下几个方面&#xff1a; 1. **数据处理与预处理**&#xff1a; - 导入和处理…

MacOS如何切换shell类型

切换 shell 类型 如果你想在不同的 shell 之间切换&#xff0c;以探索它们的不同之处&#xff0c;或者因为你知道自己需要其中的一个或另一个&#xff0c;可以使用如下命令&#xff1a; 切换到 bash chsh -s $(which bash)切换到 zsh chsh -s $(which zsh)$()语法的作用是运…

VSCode无法连接网络安装插件-手动安装插件

手动安装插件&#xff1a; 你可以尝试从 Visual Studio Code Marketplace 下载 .vsix 文件&#xff0c;然后在VSCode中手动安装。 手动安装的步骤如下&#xff1a; 1.访问插件页面&#xff0c;下载 .vsix 文件。 Extensions for Visual Studio family of products | Visual S…

CSS【详解】层叠 z-index (含 z-index 的层叠规则,不同样式的层叠效果)

仅对已定位的元素&#xff08; position:relative&#xff0c;position:absolute&#xff0c;position:fixed &#xff09;有效&#xff0c;默认值为0&#xff0c;可以为负值。 z-index 的层叠规则 z-index 值从小到大层叠 兄弟元素 z-index 值相同时&#xff0c;后面的元素在…

MySQL架构你了解多少?

MySQL是一个服务器-客户端应用&#xff0c;MySQL8.0服务器是由连接池、服务管理工具和公共组件、NoSQL接口、SQL接口、解析器、优化器、缓存、存储引擎、文件系统组成。MySQL还为各种编程语言提供了一套用于外部程序访问服务器的连接器。整体架构图如下所示: MySQLConnectors:为…

文件操作和IO流(Java版)

前言 我们无时无刻不在操作文件。可以说&#xff0c;我们在电脑上能看到的图片、视频、音频、文档都是一个又一个的文件&#xff0c;我们需要从文件中读取我们需要的数据&#xff0c;将数据运算后也需要将结果写入文件中长期保存。可见文件的重要性&#xff0c;今天我们就来简…

windows实现Grafana+Loki+loki4j轻量级日志系统,告别沉重的ELK

文章目录 Loki下载Grafana下载安装Loki添加Loki数据源springboot日志推送 Loki下载 下载地址&#xff1a;https://github.com/grafana/loki/releases/ 找到loki-windows-amd64.exe.zip点击开始下载&#xff0c;我这里下载的2.9.9版本 Grafana下载 下载地址&#xff1a;http…

Hi3861 OpenHarmony嵌入式应用入门--MQTT

MQTT 是机器对机器(M2M)/物联网(IoT)连接协议。它被设计为一个极其轻量级的发布/订阅消息传输 协议。对于需要较小代码占用空间和/或网络带宽非常宝贵的远程连接非常有用&#xff0c;是专为受限设备和低带宽、 高延迟或不可靠的网络而设计。这些原则也使该协议成为新兴的“机器…

“Numpy数据分析与挖掘:高效学习重点技能“

目录 # 开篇 # 补充 zeros & ones eye 1. numpy数组的创建 1.1 array 1.2 range 1.3 arange 1.4 常见的数据类型 1.5 astype 1.6 random.random() & round 2. numpy数组计算和数组计算 2.1 reshape 2.2 shape 2.3 将一维数组变成多维数组 2.4 指定一维…

Java版Flink使用指南——合流

大纲 新建工程无界流奇数Long型无界流偶数Long型无界流奇数String型无界流 合流UnionConnect 测试工程代码 在《Java版Flink使用指南——分流导出》中&#xff0c;我们通过addSink进行了输出分流。本文我们将介绍几种通过多个无界流输入合并成一个流来进行处理的方案。 新建工…