Tensorflow2.0笔记 - ResNet实践

        本笔记记录使用ResNet18网络结构,进行CIFAR100数据集的训练和验证。由于参数较多,训练时间会比较长,因此只跑了10个epoch,准确率还没有提升上去。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Inputos.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__#关于ResNet的描述,可以参考如下链接:
#https://blog.csdn.net/qq_39770163/article/details/126169080
#代码基于ResNet18结构,有少许不一样
class BasicBlock(layers.Layer):def __init__(self, filter_num, strides = 1):super(BasicBlock, self).__init__()#卷积层1self.conv1 = layers.Conv2D(filter_num, (3,3), strides = strides, padding='same')#BN层self.bn1 = layers.BatchNormalization()#Relu层self.relu = layers.Activation('relu')#卷积层2,BN层2,self.conv2 = layers.Conv2D(filter_num, (3,3), strides = 1, padding='same')self.bn2 = layers.BatchNormalization()#Shortcutif strides != 1:#如果strides不为1,需要下采样self.downsample = Sequential()self.downsample.add(layers.Conv2D(filter_num, (1,1), strides=strides))else:#strides为1, 直接返回原始值即可self.downsample = lambda x:xdef call(self, inputs, training = None):#经过第一个卷积层,BN和Reluout = self.conv1(inputs)out = self.bn1(out)out = self.relu(out)#经过第二个卷积层out = self.conv2(out)out = self.bn2(out)#Shortt处理,out和输入相加identity = self.downsample(inputs)output = layers.add([out, identity])#再经过一个reluoutput = tf.nn.relu(output)return outputclass ResNet(keras.Model):#layer_dims表示对应位置的ResBlock包含了几个BasicBlock#比如[2,2,2,2] => 总共4个ResBlock,每个ResBlock包含两个BasicBlock#num_classes表示输出的类别的个数def __init__(self, layer_dims, num_classes=100):super(ResNet, self).__init__()#预处理单元self.stem = Sequential([layers.Conv2D(64, (3,3), strides=(1,1)),layers.BatchNormalization(),layers.Activation('relu'),layers.MaxPool2D(pool_size=(2,2), strides=(1,1), padding='same')])#创建中间ResBlock层self.layer1 = self.buildResBlock(64, layer_dims[0])self.layer2 = self.buildResBlock(128, layer_dims[1], strides=2)self.layer3 = self.buildResBlock(256, layer_dims[2], strides=2)self.layer4 = self.buildResBlock(512, layer_dims[3], strides=2)#自适应输出层self.avgpool = layers.GlobalAveragePooling2D()#全连接层self.fc = layers.Dense(num_classes)def call(self, inputs, training = None):x = self.stem(inputs)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)#经过avgpool => [b, 512]x = self.avgpool(x)#经过Dense => [b, 100]x = self.fc(x)return xdef buildResBlock(self, filter_num, blocks, strides = 1):resBlocks = Sequential()resBlocks.add(BasicBlock(filter_num, strides))#后续的resBlock的strides都设置为1for _ in range(1, blocks):resBlocks.add(BasicBlock(filter_num))return resBlocks;def ResNet18():return ResNet([2, 2, 2 ,2]);def ResNet34():return ResNet([3, 4, 6, 3])#加载CIFAR100数据集
#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,yy_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))def main():#创建ResNetresNet = ResNet18()resNet.build(input_shape=[None, 32, 32, 3])resNet.summary()#设置优化器optimizer = optimizers.Adam(learning_rate=1e-3)#进行训练num_epoches = 10for epoch in range(num_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:#[b, 32, 32, 3] => [b, 100]logits = resNet(x)#标签做one_hot encodingy_onehot = tf.one_hot(y, depth=100)#计算损失loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss = tf.reduce_mean(loss)#计算梯度grads = tape.gradient(loss, resNet.trainable_variables)#更新参数optimizer.apply_gradients(zip(grads, resNet.trainable_variables))if (step % 100 == 0):print("Epoch[", epoch + 1, "/", num_epoches, "]: step - ", step, " loss:", float(loss))#进行验证total_samples = 0total_correct = 0for x,y in test_db:logits = resNet(x)prob = tf.nn.softmax(logits, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_samples += x.shape[0]total_correct += int(correct)#统计准确率acc = total_correct / total_samplesprint("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)if __name__ == '__main__':main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/317900.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自适应医疗决策框架 MDAgents:问题复杂度评估 + 医疗决策 + 多智能体协作

自适应医疗决策框架 MDAgents:问题复杂度评估 医疗决策 多智能体协作 提出背景MDAgents 拆解解法:MDAgents框架处理医疗问题3.1 查询复杂性评估例子:糖尿病患者的医疗查询 3.2 专家招募3.3 医疗协作与改良3.4 决策制定 分阶段决策1. 问题复…

【实时数仓架构】方法论

笔者不是专业的实时数仓架构,这是笔者从其他人经验和网上资料整理而来,仅供参考。写此文章意义,加深对实时数仓理解。 一、实时数仓架构技术演进 1.1 四种架构演进 1)离线大数据架构 一种批处理离线数据分析架构,…

目标检测算法YOLOv3简介

YOLOv3由Joseph Redmon等人于2018年提出,论文名为:《YOLOv3: An Incremental Improvement》,论文见:https://arxiv.org/pdf/1804.02767.pdf ,项目网页:https://pjreddie.com/darknet/yolo/ 。YOLOv3是对YOL…

leetcode870.优势洗牌

题目描述: 给定两个长度相等的数组 nums1 和 nums2,nums1 相对于 nums2 的优势可以用满足 nums1[i] > nums2[i] 的索引 i 的数目来描述。 返回 nums1 的任意排列,使其相对于 nums2 的优势最大化。 示例一: 输入&#xff…

BIO、NIO与AIO

文章目录 一 BIO同步阻塞案例BIO模式消息多发多收实现 二 NIONIO核心组件Buffer(缓冲区)Buffer常见方法缓冲区的数据操作直接内存与非直接内存 Channel(通道)channel常用操作 Selector(选择器)selector选择器处理流程NIO非阻塞式网络通信原理分析 NIO网络编程实现群聊系统服务端…

Acrobat Pro DC 2023:专业PDF编辑软件,引领高效办公新时代

Acrobat Pro DC 2023是一款专为Mac和Windows用户设计的专业PDF编辑软件,凭借其强大的功能和卓越的性能,成为现代职场人士不可或缺的得力助手。 这款软件拥有出色的PDF编辑能力。用户不仅可以轻松地对PDF文档中的文字、图片和布局进行编辑和调整&#xf…

【C++】哈希的应用---位图

目录 1、引入 2、位图的概念 3、位图的实现 ①框架的搭建 ②设置存在 ③设置不存在 ④检查存在 ​4、位图计算出现的次数 5、完整代码 1、引入 我们可以看一道面试题 给40亿个不重复的无符号整数,没排过序。给一个无符号整数,如何快速判断一个数…

菜鸡学习netty源码(一)——ServerBootStrap启动

1.概述 对于初学者而然,写一个netty本地进行测试的Server端和Client端,我们最先接触到的类就是ServerBootstrap和Bootstrap。这两个类都有一个公共的父类就是AbstractBootstrap. 那既然 ServerBootstrap和Bootstrap都有一个公共的分类,那就证明它们两个肯定有很多公共的职…

树莓派4B安装安卓系统LineageOS 21(Android14)

1:系统下载 2:下载好镜像后,准备写入SD卡,我这边使用的是 balenaetcher 3:插入树莓派,按照指示一步一步进行配置,可以配置时区,语言。 注意点 1》:想返回的时候按F2 2》:进入系统…

基于springboot实现中药实验管理系统设计项目【项目源码+论文说明】计算机毕业设计

基于springboot实现中药实验管理系统设计演示 摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了中药实验管理系统的开发全过程。通过分析中药实验管理系统管理的不足,创建了一个计算机管理中药实验管…

AI视频教程下载:用ChatGPT提示词开发AI应用和GPTs

在这个课程中,你将深入ChatGPT的迷人世界,学习如何利用其能力构建创新和有影响力的工具。你将发现如何创建不仅吸引而且保持用户参与度的应用程序,将流量驱动到你的网站,并开辟新的货币化途径。 **课程的主要特点:** …

Python异步Redis客户端与通用缓存装饰器

前言 这里我将通过 redis-py 简易封装一个异步的Redis客户端,然后主要讲解设计一个支持各种缓存代理(本地内存、Redis等)的缓存装饰器,用于在减少一些不必要的计算、存储层的查询、网络IO等。 具体代码都封装在 HuiDBK/py-tools: …

蓝桥杯练习系统(算法训练)ALGO-953 混合积

资源限制 内存限制:256.0MB C/C时间限制:1.0s Java时间限制:3.0s Python时间限制:5.0s 问题描述 众所周知,人人都在学习线性代数,既然都学过,那么解决本题应该很方便。   宇宙大战中&…

Python量化炒股的财务因子选股—质量因子选股

Python量化炒股的财务因子选股—质量因子选股 在Python财务因子量化选股中,质量类因子有2个,分别是净资产收益率和总资产净利率。需要注意的是,质量类因子在财务指标数据表indicator中。 净资产收益率(roe)选股 净资…

Linux基础——Linux开发工具(上)_vim

前言:在了解完Linux基本指令和Linux权限后,我们有了足够了能力来学习后面的内容,但是在真正进入Linux之前,我们还得要学会使用Linux中的几个开发工具。而我们主要介绍的是以下几个: yum, vim, gcc / g, gdb, make / ma…

Spark核心名词解释与编程

Spark核心概念 名词解释 1)ClusterManager:在Standalone(上述安装的模式,也就是依托于spark集群本身)模式中即为Master(主节点),控制整个集群,监控Worker。在YARN模式中为资源管理器ResourceManager(国内…

YOLOv9/YOLOv8算法改进【NO.128】 使用ICCV2023超轻量级且高效的动态上采样器( DySample)改进yolov8中的上采样

前 言 YOLO算法改进系列出到这,很多朋友问改进如何选择是最佳的,下面我就根据个人多年的写作发文章以及指导发文章的经验来看,按照优先顺序进行排序讲解YOLO算法改进方法的顺序选择。具体有需求的同学可以私信我沟通: 首推…

如何远程访问连接管理器?

远程访问连接管理器是一种方便的工具,可以实现远程访问计算机和网络设备的功能。它使用户能够从任何地点连接到远程计算机,并进行文件传输、桌面共享和远程控制等操作。远程访问连接管理器不仅提供了便利性,还能提高工作效率,并为…

机器人正反向运动学(FK和IK)

绕第一个顶点可以沿Z轴转动,角度用alpha表示 绕第二个点沿X轴转动,角度为Beta 第三个点沿X轴转动,记作gama 这三个点构成姿态(pose) 我们记第一个点为P0,画出它的本地坐标系,和世界坐标系一样红…

AI智能名片商城小程序:引领企业迈向第三增长极

随着数字化浪潮的席卷,私域流量的重要性逐渐凸显,为企业增长提供了全新的动力。在这一背景下,AI智能名片商城系统崭露头角,以其独特的优势,引领企业迈向第三增长极。 私域流量的兴起,为企业打开了一扇新的销…