TenorFlow多层感知机识别手写体

文章目录

  • 数据准备
  • 建立模型
          • 建立输入层 x
          • 建立隐藏层h1
          • 建立隐藏层h2
          • 建立输出层
  • 定义训练方式
          • 建立训练数据label真实值 placeholder
          • 定义loss function
          • 选择optimizer
  • 定义评估模型的准确率
          • 计算每一项数据是否正确预测
          • 将计算预测正确结果,加总平均
  • 开始训练
          • 画出误差执行结果
          • 画出准确率执行结果
  • 评估模型的准确率
  • 进行预测
  • 找出预测错误

GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403

数据准备

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.from ._conv import register_converters as _register_convertersWARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images     :', mnist.train.images.shape,'labels:'           , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,' labels:'          , mnist.validation.labels.shape)
print('test images      :', mnist.test.images.shape,'labels:'           , mnist.test.labels.shape)
train images     : (55000, 784) labels: (55000, 10)
validation images: (5000, 784)  labels: (5000, 10)
test images      : (10000, 784) labels: (10000, 10)

建立模型

def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为NoneW = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重Wb = tf.Variable(tf.random_normal([1, output_dim]))XWb = tf.matmul(inputs, W) + bif activation is None:outputs = XWbelse:outputs = activation(XWb)return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,inputs=x ,activation=tf.nn.relu)  
建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,inputs=h1 ,activation=tf.nn.relu)  
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,inputs=h2,activation=None)

定义训练方式

建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict , labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \.minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化

定义评估模型的准确率

计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label  , 1),tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

开始训练

trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):#执行15个训练周期#每个训练周期执行550批次训练for i in range(totalBatchs):batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})#使用验证数据计算准确率loss,acc = sess.run([loss_function,accuracy],feed_dict={x: mnist.validation.images, #验证数据的featuresy_label: mnist.validation.labels})#验证数据的labelepoch_list.append(epoch)loss_list.append(loss);accuracy_list.append(acc)    print("Train Epoch:", '%02d' % (epoch+1), \"Loss=","{:.9f}".format(loss)," Accuracy=",acc)duration =time()-startTime
print("Train Finished takes:",duration)        
Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

评估模型的准确率

print("Accuracy:", sess.run(accuracy,feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
Accuracy: 0.9643

进行预测

prediction_result=sess.run(tf.argmax(y_predict,1),feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,idx,num=10):fig = plt.gcf()fig.set_size_inches(12, 14)if num>25: num=25 for i in range(0, num):ax=plt.subplot(5,5, 1+i)ax.imshow(np.reshape(images[idx],(28, 28)), cmap='binary')title= "label=" +str(np.argmax(labels[idx]))if len(prediction)>0:title+=",predict="+str(prediction[idx]) ax.set_title(title,fontsize=10) ax.set_xticks([]);ax.set_yticks([])        idx+=1 plt.show()
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

y_predict_Onehot=sess.run(y_predict,feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],dtype=float32)

找出预测错误

for i in range(400):if prediction_result[i]!=np.argmax(mnist.test.labels[i]):print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),"predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=18   label= 3 predict= 8
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=233   label= 8 predict= 7
i=241   label= 9 predict= 8
i=245   label= 3 predict= 5
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=320   label= 9 predict= 1
i=340   label= 5 predict= 3
i=381   label= 3 predict= 7
i=386   label= 6 predict= 5
sess.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/258763.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot 笔记 012 创建接口_添加文章分类

1.1.1 实体类添加校验 package com.geji.pojo;import jakarta.validation.constraints.NotEmpty; import lombok.Data;import java.time.LocalDateTime;Data public class Category {private Integer id;//主键IDNotEmptyprivate String categoryName;//分类名称NotEmptypriva…

【MySQL】多表关系的基本学习

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-3oES1ZdkKIklfKzq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

【Linux】Framebuffer 应用

# 前置知识 LCD 操作原理 在 Linux 系统中通过 Framebuffer 驱动程序来控制 LCD。 Frame 是帧的意思&#xff0c; buffer 是缓冲的意思&#xff0c;这意味着 Framebuffer 就是一块内存&#xff0c;里面保存着一帧图像。 Framebuffer 中保存着一帧图像的每一个像素颜色值&…

17.3.1.3 灰度

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 灰度的算法主要有以下三种&#xff1a; 1、最大值法: 原图像&#xff1a;颜色值color&#xff08;R&#xff0c;G&#xff0c;B&a…

【python】网络爬虫与信息提取--requests库

导学 当一个软件想获得数据&#xff0c;那么我们只有把网站当成api就可以 requests库:自动爬取HTML页面&#xff0c;自动网络请求提交 robots协议&#xff1a;网络爬虫排除标准&#xff08;网络爬虫的规则&#xff09; beautiful soup库&#xff1a;解析HTML页面 工具&…

nginx2

mkdir /usr/local/develop cd /usr/local/develop 下载 wget http://nginx.org/download/nginx-1.17.4.tar.gz yum install git git clone https://github.com/arut/nginx-rtmp-module.git 解压文件 tar zxmf nginx-1.17.4.tar.gz 进入解压目录 cd nginx-1.17.4/ 安装编译…

《Go 简易速速上手小册》第9章:数据库交互(2024 最新版)

文章目录 9.1 连接数据库 - Go 语言的海底宝藏之门9.1.1 基础知识讲解安装数据库驱动数据库连接 9.1.2 重点案例&#xff1a;用户信息管理系统准备数据库Go 代码实现连接数据库添加新用户查询用户信息用户登录验证主函数 9.1.3 拓展案例 1&#xff1a;批量添加用户准备数据库Go…

【刷题】牛客— NC21 链表内指定区间反转

链表内指定区间反转 题目描述思路一&#xff08;暴力破解版&#xff09;思路二&#xff08;技巧反转版&#xff09;思路三&#xff08;递归魔法版&#xff09;Thanks♪(&#xff65;ω&#xff65;)&#xff89;谢谢阅读&#xff01;&#xff01;&#xff01;下一篇文章见&…

IIC--集成电路总线

目录 一、IIC基础知识 1、设计IIC电路的原因&#xff1a; 2、上拉电阻阻值怎么确定 3、IIC分类 4、IIC协议 二、单片机使用IIC读写数据 1、 IIC发送一个字节数据&#xff1a; 2、IIC读取一个字节数据&#xff1a; 一、IIC基础知识 1、设计IIC电路的原因&#xff1a; (…

【机器学习案例5】语言建模 - 最常见的预训练任务一览表

自监督学习 (SSL) 是基于 Transformer 的预训练语言模型的支柱,该范例涉及解决有助于建模自然语言的预训练任务 (PT)。本文将所有流行的预训练任务放在一起,以便我们一目了然地评估它们。 SSL 中的损失函数 这里的损失函数只是模型训练的各个预训练任务损失的加权和。 以BE…

selenium定位元素报错:‘WebDriver‘ object has no attribute ‘find_element_by_id‘

Selenium更新到 4.x版本后&#xff0c;以前的一些常用的代码的语法发生了改变 from selenium import webdriver browser webdriver.Chrome() browser.get(https://www.baidu.com) input browser.find_element_by_id(By.ID,kw) input.send_keys(Python)目标&#xff1a;希望通…

IP地址+子网掩码+CIDR学习笔记

目录 一、IP地址 1、表示方法&#xff1a; 2、特殊IP地址 二、子网掩码 1、判断网络位和主机位 2、子网划分 三、无分类编址CIDR 1、CIDR路由汇聚 汇聚规则&#xff1a; 汇聚ID&#xff1a; 2、最佳路由匹配原则 一、IP地址 1、表示方法&#xff1a; 机器中存放的…

STM32F1 - 中断系统

Interrupt 1> 硬件框图2> NVIC 中断管理3> EXTI 中断管理3.1> EXTI与NVIC3.2> EXTI内部框图 4> 外部中断实验4.1> 实验概述4.2> 程序设计 5> 中断向量表6> 总结 1> 硬件框图 NVIC&#xff1a;Nested Vectored Interrupt Controller【嵌套向量…

家庭动态网络怎么在公网访问主机数据?--DDNS配置(动态域名解析配置)

前言 Dynamic DNS是一个DNS服务。当您的设备IP地址被互联网服务提供商动态变更时,它提供选项来自动变更一个或多个DNS记录的IP地址。 此服务在技术术语上也被称作DDNS或是Dyn DNS 如果您没有一个静态IP,那么每次您重新连接到互联网是IP都会改变。为了避免每次IP变化时手动更…

WebStorm | 如何修改webstorm中新建html文件默认生成模板中title的初始值

在近期的JS的学习中&#xff0c;使用webstorm&#xff0c;总是要先新建一个html文件&#xff0c;然后再到里面书写<script>标签&#xff0c;真是麻烦&#xff0c;而且标题也是默认的title&#xff0c;想改成文件名还总是需要手动去改 经过小小的研究&#xff0c;找到了修…

【复现】Supabase后端服务 SQL注入漏洞_48

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一&#xff1a; 四.修复建议&#xff1a; 五. 搜索语法&#xff1a; 六.免责声明 一.概述 Supabase是什么 Supabase将自己定位为Firebase的开源替代品&#xff0c;提供了一套工具来帮助开发者构建web或移动应用程序。 Sup…

Spring Boot3自定义异常及全局异常捕获

⛰️个人主页: 蒾酒 &#x1f525;系列专栏&#xff1a;《spring boot实战》 &#x1f30a;山高路远&#xff0c;行路漫漫&#xff0c;终有归途。 目录 前置条件 目的 主要步骤 定义自定义异常类 创建全局异常处理器 手动抛出自定义异常 前置条件 已经初始化好一个…

Linux第52步_移植ST公司的linux内核第4步_关闭内核模块验证和log信息时间戳_编译_并通过tftp下载测试

1、采用程序配置关闭“内核模块验证” 默认配置文件“stm32mp1_atk_defconfig”路径为“arch/arm/configs”; 使用VSCode打开默认配置文件“stm32mp1_atk_defconfg”&#xff0c;然后将下面的4条语句屏蔽掉&#xff0c;如下&#xff1a; CONFIG_MODULE_SIGy CONFIG_MODULE_…

二叉树前序中序后序遍历(非递归)

大家好&#xff0c;又和大家见面啦&#xff01;今天我们一起去看一下二叉树的前序中序后序的遍历&#xff0c;相信这个对大家来说是信手拈来&#xff0c;但是&#xff0c;今天我们并不是使用常见的递归方式来解题&#xff0c;我们采用迭代方式解答。我们先看第一道前序遍历 1…

前端秘法进阶篇----这还是我们熟悉的浏览器吗?(浏览器的渲染原理)

目录 一.浏览器渲染原理 二.渲染时间点 三.渲染流水线 1.解析html(Parse HTML) 1.1解析成DOM树(document object model) 1.2解析成CSSOM树(css object model) 2.样式计算(Recalculate Style) 3.布局(Layout) 4.分层(Layer) 5. 绘制(Paint) 6.分块(Tiling) 7. 光栅化…