Tensorflow2.0笔记 - FashionMnist数据集训练

        本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。

        本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。

        关于FashionMnist的介绍,请自行百度。

        

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 30
learn_rate = 0.01optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: CrossEntropy-", loss_ce.numpy())#使用测试集进行验证
total_correct = 0
total_num = 0
for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]acc = total_correct / total_num
print("Accuracy:", acc)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/281643.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二、Web3 学习(区块链)

区块链基础知识 一、基础知识1. 区块链可以做什么?2. 区块链的三个特点 二、区块链的类型概括1. PoW2. PoS3. 私有链和联盟链 三、智能合约1. 什么是智能合约2. 如何使用智能合约 四、困境1. 三难选择的基本要素2. 这真的是一个三难选择吗? 五、比特币1. 什么是比特…

java JVM详解(持续更新)

JVM定义 JVM结构 类装载子系统 双亲委派模型 运行时数据区 方法区(Method Area) 堆区(Heap) 虚拟机栈区 程序计数区 执行引擎子系统 垃圾回收机制 内存分代机制 JVM调优 JVM面试题 JVM定义 JVM它是jre的一部分,也…

sentinel使用控制台实现

1、添加依赖 <!--整合控制台--><dependency> <groupId>com.alibaba.csp</groupId> <artifactId>sentinel-transport-simple-http</artifactId> <version>1.8.0</version></dependency> 此项方法&#xff0…

PSO-ELM,粒子群优化算法优化ELM极限学习机数据回归预测(多输入单输出)-MATLAB实现

粒子群优化算法&#xff08;Particle Swarm Optimization, PSO&#xff09;结合极限学习机&#xff08;Extreme Learning Machine, ELM&#xff09;进行数据回归预测是一种常见的机器学习方法。ELM作为一种单隐层前馈神经网络&#xff0c;具有快速训练和良好的泛化能力。而PSO则…

报表生成器FastReport .Net用户指南:关于脚本(下)

FastReport的报表生成器&#xff08;无论VCL平台还是.NET平台&#xff09;&#xff0c;跨平台的多语言脚本引擎FastScript&#xff0c;桌面OLAP FastCube&#xff0c;如今都被世界各地的开发者所认可&#xff0c;这些名字被等价于“速度”、“可靠”和“品质”,在美国&#xff…

【CKA模拟题】查找集群中使用内存最高的node节点

题干 For this question, please set this context (In exam, diff cluster name) kubectl config use-context kubernetes-adminkubernetesFind the Node that consumes the most MEMORY in all cluster(currently we have single cluster). Then, store the result in the …

MySQL数据库的下载和安装以及命令行语法学习

MySQL数据库的下载和安装以及命令行语法学习 学习MYSQL&#xff0c;掌握住基础的SQL句型&#xff08;创建数据库、查看数据库列表、数据增、删、改、查等操作类型&#xff09; 首先要知道MySQL下载和安装方法&#xff1a; 提示&#xff1a;别嫌啰嗦&#xff0c;对于一个初识MY…

python智慧农业小程序flask-django-php-nodejs

当今社会已经步入了科学技术进步和经济社会快速发展的新时期&#xff0c;国际信息和学术交流也不断加强&#xff0c;计算机技术对经济社会发展和人民生活改善的影响也日益突出&#xff0c;人类的生存和思考方式也产生了变化。传统智慧农业采取了人工的管理方法&#xff0c;但这…

基于Python3的数据结构与算法 - 16 链表

目录 链表 1. 创建链表 2. 链表的插入和删除 3. 双链表 4. 链表总结 链表 链表是由一系列节点组成的元素集合。每个节点包含两部分&#xff0c;数据域item和指向下一个节点得指针next。通过节点之间的相互连接&#xff0c;最终串联成一个链表。 class Node:def __init…

vue key的bug

今天遇到一个bug&#xff0c;列表删除元素时&#xff0c;明明在外层设置了key&#xff0c;但是列表元素的状态居然复用了&#xff0c;找了好久原因&#xff0c;最后是key的取值问题&#xff0c;记录一下。 首先key可以取undefine&#xff0c;这个是不会报错的 然后项目的代码结…

通过 Socket 手动实现 HTTP 协议

你好&#xff0c;我是 shengjk1&#xff0c;多年大厂经验&#xff0c;努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注&#xff01;你会有如下收益&#xff1a; 了解大厂经验拥有和大厂相匹配的技术等 希望看什么&#xff0c;评论或者私信告诉我&#xff01; 文章目录 一…

社交媒体的未来:探讨Facebook的发展趋势

引言 在数字化时代&#xff0c;社交媒体已经成为人们日常生活中不可或缺的一部分。作为全球最大的社交媒体平台之一&#xff0c;Facebook一直在不断地追求创新&#xff0c;以满足用户日益增长的需求和适应科技发展的变革。本文将探讨Facebook在未来发展中可能面临的挑战和应对…

WM8978 —— 带扬声器驱动程序的立体声编解码器(2)

接前一篇文章&#xff1a;WM8978 —— 带扬声器驱动程序的立体声编解码器&#xff08;1&#xff09; 六、引脚详细说明 引脚&#xff08;PIN&#xff09;名称&#xff08;NAME&#xff09;类型&#xff08;TYPE&#xff09;描述&#xff08;DESCRIPTION&#xff09;1LIP模拟输入…

uniApp中使用小程序XR-Frame创建3D场景(1)环境搭建

1.XR-Frame简介 XR-Frame作为微信小程序官方推出的3D框架&#xff0c;是目前所有小程序平台中3D效果最好的一个&#xff0c;由于其本身针对微信小程序做了优化&#xff0c;在性能方面比其他第三方库都要高很多。 2.与Three.js的区别 做3D小程序的同学们对Three.js一定不陌生…

停止docker 容器并删除对应镜像

docker 容器相关命令 docker ps 查看当前系统正在运行的容器情况&#xff0c;返回信息分别为&#xff1a; 容器ID&#xff1a;CONTAINER ID 镜像名IMAGE NAMES 运行命令COMMAND 创建时间CREATED 状态STATUS 映射端口 PORTS docker ps |grep XXX 可以…

ssm项目(tomcat项目),定时任务(每天运行一次)相同时间多次重复运行job 的bug

目录标题 一、原因 一、原因 debug本地调试没有出现定时任务多次运行的bug&#xff0c;上传到服务器就出现多次运行的bug。&#xff08;war的方式部署到tomcat&#xff09; 一开始我以为是代码原因&#xff0c;或者是linux和win环境不同运行定时任务的方式不一样。 但是自己…

sentinel整合gateway实现服务限流

导入依赖: <dependencies><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-gateway</artifactId></dependency><dependency><groupId>com.alibaba.csp</groupId><…

数据结构:堆的创建和使用

上一期我们学习了树和二叉树的定义&#xff0c;其中我们了解到了两种特殊的二叉树&#xff1a;满二叉树和完全二叉树。 今天我们还要学习一种新的结构&#xff1a;堆 那这种结构和二叉树有什么联系呢&#xff1f;&#xff1f;&#xff1f; 通过观察我们可以发现&#xff0c;…

UE5 C++增强输入

一.创建charactor&#xff0c;并且包含增强输入相关的头文件 1.项目名.build.cs。添加模块“EnhancedInput”&#xff0c;方便找到头文件和映射的一些文件。 PublicDependencyModuleNames.AddRange(new string[] { "Core", "CoreUObject", "Engine&q…

塔楼VR火灾逃生应急安全教育突破了传统模式

城镇化的高速发展&#xff0c;给消防安全带来了严峻的挑战&#xff0c;尤其是人员密集的办公场所&#xff0c;如何预防火灾发生&#xff0c;学习火灾成因&#xff0c;减少火灾发生避免不必要的损失&#xff0c;成为安全应急科普的重中之重。 通过模拟真实的办公场所火灾场景&am…