TensorFlow2实战-系列教程2:神经网络分类任务

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、Mnist数据集

下载mnist数据集:

%matplotlib inline
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)

制作数据:

import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

简单展示数据:

from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
print(y_train[0])

打印结果:

(50000, 784)
5

在这里插入图片描述

2、模型构建

在这里插入图片描述
在这里插入图片描述
输入为784神经元,经过隐层提取特征后为10个神经元,10个神经元的输出值经过softmax得到10个概率值,取出10个概率值中最高的一个就是神经网络的最后预测值

构建模型代码:

import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

选择损失函数,损失函数是机器学习一个非常重要的部分,基本直接决定了这个算法的效果,这里是多分类任务,一般我们就直接选用多元交叉熵函数就好了:
TensorFlow损失函数API

编译模型:

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  1. adam优化器,学习率为0.001
  2. 多元交叉熵损失函数
  3. 评价指标

模型训练:

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_valid, y_valid))

训练数据,训练标签,训练轮次,batch_size,验证集

打印结果:

Train on 50000 samples, validate on 10000 samples
Epoch 1/5 50000/50000  1s 29us
sample-loss: 115566 - sparse_categorical_accuracy: 0.1122 - val_loss: 364928.5786 - val_sparse_categorical_accuracy: 0.1064
Epoch 2/5 50000/50000 1s 21us
sample - loss: 837104 - sparse_categorical_accuracy: 0.1136 - val_loss: 1323287.7028 - val_sparse_categorical_accuracy: 0.1064
Epoch 3/5 50000/50000 1s 20us
sample - loss: 1892431 - sparse_categorical_accuracy: 0.1136 - val_loss: 2448062.2680 - val_sparse_categorical_accuracy: 0.1064
Epoch 4/5 50000/50000 1s 20us
sample - loss: 3131130 - sparse_categorical_accuracy: 0.1136 - val_loss: 3773744.5348 - val_sparse_categorical_accuracy: 0.1064
Epoch 5/5 50000/50000 1s 20us
sample - loss: 4527781 - sparse_categorical_accuracy: 0.1136 - val_loss: 5207194.3728 - val_sparse_categorical_accuracy: 0.1064
<tensorflow.python.keras.callbacks.History at 0x1d3eb9015f8>

模型保存:

model.save('Mnist_model.h5')

3、TensorFlow常用模块

3.1 Tensor格式转换

创建一组数据

import numpy as np
input_data = np.arange(16)
input_data

打印结果:
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

转换成TensorFlow格式的数据:

dataset = tf.data.Dataset.from_tensor_slices(input_data)
for data in dataset:print (data)

将一个ndarray转换成
打印结果:
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

3.2repeat操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2)
for data in dataset:print (data)

tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

会将当前的输出重复一遍

3.3 batch操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2).batch(4)
for data in dataset:print (data)

tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)
tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)

将原来的数据按照4个为一个批次

3.4 shuffle操作

dataset = tf.data.Dataset.from_tensor_slices(input_data).shuffle(buffer_size=10).batch(4)
for data in dataset:print (data)

tf.Tensor([ 9 8 11 3], shape=(4,), dtype=int32)
tf.Tensor([ 5 6 1 13], shape=(4,), dtype=int32)
tf.Tensor([14 15 4 2], shape=(4,), dtype=int32)
tf.Tensor([12 7 0 10], shape=(4,), dtype=int32)

shuffle操作,直接翻译过来就是洗牌,把当前的数据进行打乱操作
buffer_size=10,就是缓存10来进行打乱取数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/246075.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

项目实战——Qt实现FFmpeg音视频转码器

文章目录 前言一、移植 FFmpeg 相关文件二、绘制 ui 界面三、实现简单的转码四、功能优化1、控件布局及美化2、缩放界面3、实现拖拽4、解析文件5、开启独立线程6、开启定时器7、最终运行效果 五、附录六、资源自取 前言 本文记录使用 Qt 实现 FFmepg 音视频转码器项目的开发过…

【干货】【常用电子元器件介绍】【电阻】(二)--敏感电阻器

声明&#xff1a;本人水平有限&#xff0c;博客可能存在部分错误的地方&#xff0c;请广大读者谅解并向本人反馈错误。   电子电路中除了采用普通电阻器外&#xff0c;还有一些敏感电阻器&#xff08;如热敏电阻器、压敏电阻器、光敏电阻器等&#xff09;也被广泛地应用。然而…

【华为 ICT HCIA eNSP 习题汇总】——题目集10

1、以下哪个动态路由协议不能应用在 IPv6 网络中&#xff1f; A、IS-IS B、RIPng C、BGP4 D、OSPFv3 考点&#xff1a;路由技术原理 解析&#xff1a;&#xff08;A&#xff09; IS-ISv6 是在 IPv6 环境下&#xff0c;IS-IS 协议进行了相应的扩展和改进&#xff0c;以适应 IPv6…

css中隐藏页面元素的方式(详细讲解)

这里写目录标题 一、前言二、实现方式display:nonevisibility:hiddenopacity:0设置height、width属性为0position:absoluteclip-path小结 三、区别参考文献 一、前言 在平常的样式排版中&#xff0c;我们经常遇到将某个模块隐藏的场景 通过css隐藏元素的方法有很多种&#xf…

vue项目中使用Element多个Form表单同时验证

一、项目需求 在项目中一个页面中需要实现多个Form表单&#xff0c;并在页面提交时需要对多个Form表单进行校验&#xff0c;多个表单都校验成功时才能提交。 二、实现效果 三、多个表单验证 注意项&#xff1a;多个form表单&#xff0c;每个表单上都设置单独的model和ref&am…

matlab对负数开立方根得到虚数的解决方案

问题描述&#xff1a;在matlab中&#xff0c;对负数开立方根&#xff0c;不出意外你将得到虚数。 例如 − 27 3 \sqrt[3]{-27} 3−27 ​&#xff0c;我们知道其实数解是-3&#xff0c;但在matlab中的计算结果如下&#xff1a; 问题原因&#xff1a;matlab中的立方根运算是在…

详细分析Java的树形工具类(含注释)

目录 前言1. 基本框架2. 实战应用 前言 对应的每个子孙属于该父亲&#xff0c;这其实是数据结构的基础知识&#xff0c;那怎么划分怎么归属呢 对应的基本知识推荐如下&#xff1a; 【数据结构】树和二叉树详细分析&#xff08;全&#xff09;【数据结构】B树和B树的笔记详细…

vue实现在线Excel表格功能

目录 1.安装x-data-spreadsheet xlsx 2.引入 3.使用 1.安装x-data-spreadsheet xlsx npm i x-data-spreadsheet xlsx2.引入 import zhCN from "x-data-spreadsheet/src/locale/zh-cn"; import Spreadsheet from "x-data-spreadsheet"; import * as X…

Scala基础知识

scala 1、scala简介 ​ scala是运行在JVM上的多范式编程语言&#xff0c;同时支持面向对象和面向函数式编程。 2、scala解释器 要启动scala解释器&#xff0c;只需要以下几步&#xff1a; 按住windows键 r输入scala即可 在scala命令提示窗口中执行:quit&#xff0c;即可退…

mybatis基础知识

title: mybatis的基础知识创建mybatis数据库&#xff0c;创建user表&#xff0c;在里面添加数据 创建空项目&#xff0c;创建maven模块 导入相关的依赖 在pom.xml文件中导入mysql&#xff0c;junit测试&#xff0c;mybatis依赖。 <!--导入mysql依赖--><dependency&g…

四步搞定国赛!快速入门大小模型融合的AI产品开发

前不久&#xff0c;2024中国大学生服务外包创新创业大赛正式启动&#xff01;作为中国高等教育学会“全国普通高校学科竞赛排行榜”竞赛&#xff0c;飞桨赛道已经吸引了超过200位选手报名参赛。 本文旨在助力“A01-基于文心大模型智能阅卷平台设计”赛道选手&#xff0c;更快地…

7.【SpringBoot3】项目部署、属性配置、多环境开发

1. SpringBoot 项目部署 项目完成后&#xff0c;需要部署到服务器上。 SpringBoot 项目需要经过编译打包生成一个 jar 包&#xff08;借助打包插件 spring-boot-maven-plugin&#xff09;&#xff0c;再将该 jar 包发送或拷贝到服务器上&#xff0c;然后就可以通过执行 java …

手机视频压缩怎么压缩?一键瘦身~

现在手机已经成为我们日常生活中必不可少的工具&#xff0c;而在手机的应用领域中&#xff0c;文件的传输和存储是一个非常重要的问题。很多用户都会遇到这样一个问题&#xff0c;那就是在手机上存储的文件太多太大&#xff0c;导致手机存储空间不足&#xff0c;那么怎么在手机…

Typora 无法导出 pdf 问题的解决

目录 问题描述 解决困难 解决方法 问题描述 Windows 下&#xff0c;以前&#xff08;Windows 11&#xff09; Typora 可以顺利较快地由 .md 导出 .pdf 文件&#xff0c;此功能当然非常实用与重要。 然而&#xff0c;有一次电脑因故重装了系统&#xff08;刷机&#xff09;…

SpringMVC 环境搭建入门

SpringMVC 是一种基于 Java 的实现 MVC 设计模型的请求驱动类型的轻量级 Web 框架&#xff0c;属于SpringFrameWork 的后续产品&#xff0c;已经融合在 Spring Web Flow 中。 SpringMVC 已经成为目前最主流的MVC框架之一&#xff0c;并且随着Spring3.0 的发布&#xff0c;全面…

论述Python中列表、元组、字典和集合的概念

Python列表是用于存储任意数目、任意类型的数据集合&#xff0c;包含多个元素的有序连续的内存空间&#xff0c;是内置可变序列&#xff0c;或者说可以任意修改。在Python中&#xff0c;列表以方括号&#xff08;[ ]&#xff09;形式编写。 Python元组与Python列表类似&#x…

Flink 集成 Debezium Confluent Avro ( format=debezium-avro-confluent )

博主历时三年精心创作的《大数据平台架构与原型实现:数据中台建设实战》一书现已由知名IT图书品牌电子工业出版社博文视点出版发行,点击《重磅推荐:建大数据平台太难了!给我发个工程原型吧!》了解图书详情,京东购书链接:https://item.jd.com/12677623.html,扫描左侧二维…

docker生命周期管理命令

文章目录 前言1、docker create2、docker run2.1、常用选项2.2、系统2.3、网络2.4、健康检查 3、docker start/stop/restart4、docker kill5、docker rm6、docker pause/unpause总结 前言 在云原生时代&#xff0c;Docker已成为必不可少的容器管理工具。通过掌握Docker常用的容…

[UE]无法接收OnInputTouchBegin事件

遇到问题 想做一个鼠标按住左键选中Actor拖动而旋转的功能&#xff0c;想法是通过OnInputTouchBeginOnInputTouchEndTick实现。但是却无法接收OnInputTouchBegin与OnInputTouchEnd事件。 解决方案 想要触发OnInputTouchBegin事件 1.需要设置勾选ProjectSettings->Input-&…

.net访问oracle数据库性能问题

问题&#xff1a; 生产环境相同的inser语句在别的非.NET程序相应明显快于.NET程序&#xff0c;执行时间相差比较大&#xff0c;影响正常业务运行&#xff0c;测试环境反而正常。 问题详细诊断过程 问题初步判断诊断过程&#xff1a; 查询插入慢的sql_id 检查对应的执行计划…