Tensorflow2.0笔记 - 使用compile,fit,evaluate,predict简化流程

        本笔记主要用compile, fit, evalutate和predict来简化整体代码,使用这些高层API可以减少很多重复代码。具体内容请自行百度,本笔记基于FashionMnist的训练笔记,原始笔记如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读347次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502        

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)test_db_iter = iter(test_db)
sample = next(test_db_iter)
print(sample[0].shape)
print(sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])
print("\n=====Building Model=========\n")
model.build(input_shape=(None, 28*28))
model.summary()total_epoches = 10
learn_rate = 0.01
#编译网络,使用compile(),传入优化器,损失函数和metrics度量
#https://blog.csdn.net/weixin_48169169/article/details/120793534
print("\n=====Compiling Model=========\n")
model.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['Accuracy'])
#使用fit()进行训练
print("\n=====Fitting Model=========\n")
model.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=2)
time.sleep(1)
#使用evaluate()进行模型验证,这里使用了test_db,实际中可以使用另外的数据集
print("\n=====Evaluating Model=========\n")
model.evaluate(test_db)
#使用predict()做数据预测
sample = next(iter(test_db))
#获得数据和标签
real_data = sample[0]
real_label = sample[1]
pred = model.predict(real_data)
pred = tf.argmax(pred, axis=1)
print("Predicted Labels:", pred.numpy())
print("Actual Labels:", real_label.numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/288709.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OSPF-区域间路由计算

一、概述 前面学习了我们学习了Router-LSA和Network-LSA,它们都只能在区域内进行泛洪,而且我们之前一直主要是单区域学习。OSPF的核心是骨干区域Area 0,其它都为非骨干区域。但是在大型网络中,单区域OSPF会存在一定的问题&#xf…

[BT]BUUCTF刷题第9天(3.27)

第9天(共2题) [护网杯 2018]easy_tornado 打开网站就是三个txt文件 /flag.txt flag in /fllllllllllllag/welcome.txt render/hints.txt md5(cookie_secretmd5(filename))当点进flag.txt时,url变为 http://b9e52e06-e591-46ad-953e-7e8c5f…

gopher伪协议

基础知识 基本格式 基本格式&#xff1a;URL:gopher://<host>:<port>/<gopher-path>web也需要加端口号80gophert协议默认端口为70gopheri请求不转发第一个字符 get请求 问号&#xff08;&#xff1f;)需要转码为URL编码&#xff0c;也就是%3f回车换行要变…

【Linux】信号的处理{信号处理的时机/了解寄存器/内核态与用户态/信号操作函数}

文章目录 0.对于信号捕捉的理解1.信号处理的时机1.1 何时处理信号&#xff1f;1.2 内核态和用户态1.3 内核态和用户态的切换 2.了解寄存器3.信号捕捉的原理4.信号操作函数4.1sighandler_t signal(int signum, sighandler_t handler);4.2int sigaction(int signum, const struct…

IP如何异地共享文件?

【天联】 组网由于操作简单、跨平台应用、无网络要求、独创的安全加速方案等原因&#xff0c;被几十万用户广泛应用&#xff0c;解决了各行业客户的远程连接需求。采用穿透技术&#xff0c;简单易用&#xff0c;不需要在硬件设备中端口映射即可实现远程访问。 异地共享文件 在…

基于SpringBoot和Vue的校园管理系统的设计与实现

今天要和大家聊的是一款基于SpringBoot和Vue的校园管理系统的设计与实现 &#xff01;&#xff01;&#xff01; 有需要的小伙伴可以通过文章末尾名片咨询我哦&#xff01;&#xff01;&#xff01; &#x1f495;&#x1f495;作者&#xff1a;李同学 &#x1f495;&#x1f…

vscode上编辑vba

安装xvba插件更换vscode的工作目录启动扩展服务器在config.json中添加目标工作簿的名称加载excel文件&#xff08;必须带宏的xlsm&#xff09;这个扩展就会自动提取出Excel文件中的代码Export VBA&#xff08;编辑完成的VBA代码保存到 Excel文件 &#xff09;再打开excel文件可…

【LVGL-键盘部件,实体按键控制】

LVGL-二维码库 ■ LVGL-键盘部件■ 示例一&#xff1a;键盘弹窗提示■ 示例二&#xff1a;设置键盘模式■ 综合示例&#xff1a; ■ LVGL-实体按键控制■ 简介 ■ LVGL-键盘部件 ■ 示例一&#xff1a;键盘弹窗提示 lv_keyboard_set_popovers(kb,true);■ 示例二&#xff1a;设…

基于Givens旋转完成QR分解进而求解实矩阵的逆矩阵

基于Givens旋转完成QR分解进而求解实矩阵的逆矩阵 目录 前言 一、Givens旋转简介 二、Givens旋转解释 三、Givens旋转进行QR分解 四、Givens旋转进行QR分解数值计算例子 五、求逆矩阵 六、MATLAB仿真 七、参考资料 总结 前言 在进行QR分解时&#xff0c;HouseHolder变换…

开源AI引擎|企业合同管理:自然语言处理与OCR技术深度融合

一、企业应用&#xff1a;合同智能管理 结合NLP和OCR技术&#xff0c;企业可以构建智能化的合同管理系统&#xff0c;实现合同的自动化审查、风险评估和知识抽取。这样的系统不仅能够提高合同处理的效率&#xff0c;还能够降低人为错误&#xff0c;加强风险控制。 例如&#x…

【React】vite + react 项目,进行配置 eslint

安装与配置 eslint 1 安装 eslint babel/eslint-parser2 初始化配置 eslint3 安装 vite-plugin-eslint4 配置 vite.config.js 文件5 修改 eslint 默认配置 1 安装 eslint babel/eslint-parser npm i -D eslint babel/eslint-parser2 初始化配置 eslint npx eslint --init相关…

iOS - Runtime-消息机制-objc_msgSend()

iOS - Runtime-消息机制-objc_msgSend() 前言 本章主要介绍消息机制-objc_msgSend的执行流程&#xff0c;分为消息发送、动态方法解析、消息转发三个阶段&#xff0c;每个阶段可以做什么。还介绍了super的本质是什么&#xff0c;如何调用的 1. objc_msgSend执行流程 OC中的…

USART发送单字节数据原理及程序实现

硬件接线&#xff1a; 显示屏的SCA接在B11&#xff0c;SCL接在B10&#xff0c;串口的RX连接A9&#xff0c;TX连接A10。 新建Serial.c和Serial.h文件 在Serial.c文件中&#xff0c;实现初始化函数&#xff0c;等需要的函数&#xff0c;首先对串口进行初始化&#xff0c;只需要…

【深度学习|基础算法】2.AlexNet学习记录

AlexNet示例代码与解析 1、前言2、模型tips3、模型架构4、模型代码backbonetrainpredict 5、模型训练6、导出onnx模型 1、前言 AlexNet由Hinton和他的学生Alex Krizhevsky设计&#xff0c;模型名字来源于论文第一作者的姓名Alex。该模型以很大的优势获得了2012年ISLVRC竞赛的冠…

每日一题 --- 链表相交[力扣][Go]

链表相交 题目&#xff1a;面试题 02.07. 链表相交 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交**&#xff1a;** 题目数据 保证 整个链式结…

神经网络:梯度下降法更新模型参数

作者&#xff1a;CSDN _养乐多_ 在神经网络领域&#xff0c;梯度下降是一种核心的优化算法&#xff0c;本文将介绍神经网络中梯度下降法更新参数的公式&#xff0c;并通过实例演示其在模型训练中的应用。通过本博客&#xff0c;读者将能够更好地理解深度学习中的优化算法和损…

H5小程序视频方案解决方案,实现轻量化视频制作

对于许多企业而言&#xff0c;制作高质量的视频仍然是一个技术门槛高、成本高昂的挑战。针对这一痛点&#xff0c;美摄科技凭借其深厚的技术积累和创新能力&#xff0c;推出了面向企业的H5/小程序视频方案解决方案&#xff0c;为企业提供了一种轻量化、高效、便捷的视频制作方式…

线程局部存储(TLS)

线程局部存储&#xff08;Thread Local Storage&#xff0c;TLS&#xff09;&#xff0c;是一种变量的存储方法&#xff0c;这个变量在它所在的线程内是全局可访问的&#xff0c;但是不能被其他线程访问到&#xff0c;这样就保持了数据的线程独立性。而熟知的全局变量&#xff…

mac-git上传至github(ssh版本,个人tokens总出错)

第一步 git clone https://github.com/用户名/项目名.git 第二步 cd 项目名 第三步 将本地的文件移动到项目下 第四步 git add . 第五步 git commit -m "添加****文件夹" 第六步 git push origin main 报错&#xff1a; 采用ssh验证 本地文件链接公钥 …

超级会员卡积分收银系统源码:积分+收银+商城三合一小程序 带完整的安装代码包以及搭建教程

信息技术的迅猛发展&#xff0c;移动支付和线上购物已经成为现代人生活的常态。在这样的背景下&#xff0c;商家对于能够整合收银、积分管理和在线商城的综合性系统的需求日益强烈。下面&#xff0c;罗峰给大家分享一款超级会员卡积分收银系统源码&#xff0c;它集积分、收银、…