TensorFlow-slim包进行图像数据集分类---具体流程

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim|-- BUILD|-- README.md|-- WORKSPACE|-- __init__.py|-- datasets|   |-- __init__.py|   |-- __pycache__|   |   |-- __init__.cpython-37.pyc|   |   |-- dataset_utils.cpython-37.pyc|   |   |-- download_and_convert_cifar10.cpython-37.pyc|   |   |-- download_and_convert_flowers.cpython-37.pyc|   |   `-- download_and_convert_mnist.cpython-37.pyc|   |-- build_imagenet_data.py|   |-- cifar10.py|   |-- dataset_factory.py|   |-- dataset_utils.py|   |-- download_and_convert_cifar10.py|   |-- download_and_convert_flowers.py|   |-- download_and_convert_imagenet.sh|   |-- download_and_convert_mnist.py|   |-- download_imagenet.sh|   |-- flowers.py|   |-- imagenet.py|   |-- imagenet_2012_validation_synset_labels.txt|   |-- imagenet_lsvrc_2015_synsets.txt|   |-- imagenet_metadata.txt|   |-- mnist.py|   |-- preprocess_imagenet_validation_data.py|   `-- process_bounding_boxes.py|-- deployment|   |-- __init__.py|   |-- model_deploy.py|   `-- model_deploy_test.py|-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py        # 测试模型分类效果|-- export_inference_graph.py|-- export_inference_graph_test.py|-- nets|   |-- __init__.py|   |-- alexnet.py|   |-- alexnet_test.py|   |-- cifarnet.py|   |-- cyclegan.py|   |-- cyclegan_test.py|   |-- dcgan.py|   |-- dcgan_test.py|   |-- i3d.py|   |-- i3d_test.py|   |-- i3d_utils.py|   |-- inception.py|   |-- inception_resnet_v2.py|   |-- inception_resnet_v2_test.py|   |-- inception_utils.py|   |-- inception_v1.py|   |-- inception_v1_test.py|   |-- inception_v2.py|   |-- inception_v2_test.py|   |-- inception_v3.py|   |-- inception_v3_test.py|   |-- inception_v4.py|   |-- inception_v4_test.py|   |-- lenet.py|   |-- mobilenet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- conv_blocks.py|   |   |-- madds_top1_accuracy.png|   |   |-- mnet_v1_vs_v2_pixel1_latency.png|   |   |-- mobilenet.py|   |   |-- mobilenet_example.ipynb|   |   |-- mobilenet_v2.py|   |   `-- mobilenet_v2_test.py|   |-- mobilenet_v1.md|   |-- mobilenet_v1.png|   |-- mobilenet_v1.py|   |-- mobilenet_v1_eval.py|   |-- mobilenet_v1_test.py|   |-- mobilenet_v1_train.py|   |-- nasnet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- nasnet.py|   |   |-- nasnet_test.py|   |   |-- nasnet_utils.py|   |   |-- nasnet_utils_test.py|   |   |-- pnasnet.py|   |   `-- pnasnet_test.py|   |-- nets_factory.py|   |-- nets_factory_test.py|   |-- overfeat.py|   |-- overfeat_test.py|   |-- pix2pix.py|   |-- pix2pix_test.py|   |-- resnet_utils.py|   |-- resnet_v1.py|   |-- resnet_v1_test.py|   |-- resnet_v2.py|   |-- resnet_v2_test.py|   |-- s3dg.py|   |-- s3dg_test.py|   |-- vgg.py|   `-- vgg_test.py|-- preprocessing|   |-- __init__.py|   |-- cifarnet_preprocessing.py|   |-- inception_preprocessing.py|   |-- lenet_preprocessing.py|   |-- preprocessing_factory.py|   `-- vgg_preprocessing.py|-- scripts                     # gqr:存储的是相关的模型训练脚本                |   |-- export_mobilenet.sh|   |-- finetune_inception_resnet_v2_on_flowers.sh|   |-- finetune_inception_v1_on_flowers.sh|   |-- finetune_inception_v3_on_flowers.sh|   |-- finetune_resnet_v1_50_on_flowers.sh|   |-- train_cifarnet_on_cifar10.sh|   `-- train_lenet_on_mnist.sh|-- setup.py|-- slim_walkthrough.ipynb`-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; thenmkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; thenwget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gztar -xvf resnet_v1_50_2016_08_28.tar.gzmv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckptrm resnet_v1_50_2016_08_28.tar.gz
fi# Download the dataset
python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50 \--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \--batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR} \--eval_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--checkpoint_path=${TRAIN_DIR} \--model_name=resnet_v1_50 \--max_number_of_steps=1000 \--batch_size=32 \--learning_rate=0.001 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR}/all \--eval_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \--train_dir=./data/flowers-models/resnet_v1_50\--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=./data/flowers \--model_name=resnet_v1_50 \--checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \ --batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/111433.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

延迟队列的理解与使用

目录 一、场景引入 二、延迟队列的三种场景 1、TTL对队列进行延迟 2、创建通用延时消息对消息延迟 3、使用rabbitmq的延时队列插件 x-delayed-message使用 父pom文件 pom文件 配置文件 config 生产者 消费者 结果 一、场景引入 我们知道可以通过TTL来对队列进行设…

Matlab(结构化程式和自定义函数)

目录 1.脚本编辑器 2.脚本流 2.1 控制流 2.2 关系(逻辑)操作符 3.脚本与函数 1.脚本编辑器 Matlab的命名规则: 常用功能: 智能缩进: 在写代码的时候,有的时候代码看起来并不是那么美观(可读性…

栈和队列(详解)

一、栈 1.1、栈的基本概念 1.1.1、栈的定义 栈(Stack):是只允许在一端进行插入或删除的线性表。首先栈是一种线性表,但限定这种线性表只能在某一端进行插入和删除操作。 栈顶(Top):线性表允许…

iPhone 15 Pro与谷歌Pixel 7 Pro:哪款相机手机更好?

考虑到苹果最近将更多高级功能转移到iPhone Pro设备上的趋势,今年秋天iPhone 15 Pro与谷歌Pixel 7 Pro的对决将是一场特别有趣的对决。去年发布的iPhone 14 Pro确实发生了这种情况,有传言称iPhone 15 Pro再次受到了苹果的大部分关注。 预计iPhone 15系列会有一些变化,例如切…

企业网络安全:威胁情报解决方案

什么是威胁情报 威胁情报是网络安全的关键组成部分,可为潜在的恶意来源提供有价值的见解,这些知识可帮助组织主动识别和防止网络攻击,通过利用 STIX/TAXII 等威胁源,组织可以检测其网络中的潜在攻击,从而促进快速检测…

Flutter Web 项目网络请求报 XMLHttpRequest error 解决方案

使用http库进行简单的网络请求时,运行在Chrome浏览器上,网络请求一直报错 XMLHttpRequest error,而在iOS 模拟器上运行则正常,后面在postman上发送请求,也是正常的。这就是很尴尬了!!&#xff0…

公有云与私有云,IaaS、PaaS 和 SaaS云服务模型概述

云计算主要分为 4 种类型:私有云、公共云、混合云和多云。同时,云计算服务主要有 3 种:基础架构即服务(IaaS)、平台即服务(PaaS)和软件即服务(SaaS) Saas(Sof…

nginx-concat

为了减少tcp请求数量,nginx从上有服务器获取多个静态资源(css,js)的时候,将多个静态资源合并成一个返回给客户端。 这种前面有两个问号的请求都是用了cancat合并功能。 先到官网下载安装包,拷贝到服务器编译…

UDP 多播(组播)

前言(了解分类的IP地址) 1.组播(多播) 单播地址标识单个IP接口,广播地址标识某个子网的所有IP接口,多播地址标识一组IP接口。单播和广播是寻址方案的两个极端(要么单个要么全部)&am…

微信小程序 实时日志

目录 实时日志 背景 如何使用 如何查看日志 注意事项 实时日志 背景 为帮助小程序开发者快捷地排查小程序漏洞、定位问题,我们推出了实时日志功能。从基础库2.7.1开始,开发者可通过提供的接口打印日志,日志汇聚并实时上报到小程序后台…

【base64】JavaScriptuniapp 将图片转为base64并展示

Base64是一种用于编码二进制数据的方法&#xff0c;它将二进制数据转换为文本字符串。它的主要目的是在网络传输或存储过程中&#xff0c;通过将二进制数据转换为可打印字符的形式进行传输 JavaScript 压缩图片 <html><body><script src"https://code.j…

数学建模:主成分分析法

&#x1f506; 文章首发于我的个人博客&#xff1a;欢迎大佬们来逛逛 主成分分析法 算法流程 构建原始数据矩阵 X X X &#xff0c;其中矩阵的形状为 x ∗ n x * n x∗n &#xff0c;有 m m m 个对象&#xff0c; n n n 个评价指标。然后进行矩阵的归一化处理。首先计算矩…

Android Looper Handler 机制浅析

最近想写个播放器demo&#xff0c;里面要用到 Looper Handler&#xff0c;看了很多资料都没能理解透彻&#xff0c;于是决定自己看看相关的源码&#xff0c;并在此记录心得体会&#xff0c;希望能够帮助到有需要的人。 本文会以 猜想 log验证 的方式来学习 Android Looper Ha…

第62步 深度学习图像识别:多分类建模(Pytorch)

基于WIN10的64位系统演示 一、写在前面 上期我们基于TensorFlow环境做了图像识别的多分类任务建模。 本期以健康组、肺结核组、COVID-19组、细菌性&#xff08;病毒性&#xff09;肺炎组为数据集&#xff0c;基于Pytorch环境&#xff0c;构建SqueezeNet多分类模型&#xff0…

Android Activity启动过程一:从Intent到Activity创建

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、人工智能等&#xff0c;希望大家多多支持。 目录 一、概览二、应用内启动源码流程 (startActivity)2.1 startActivit…

ADRV9009子卡 设计原理图:FMCJ450-基于ADRV9009的双收双发射频FMC子卡 便携测试设备

FMCJ450-基于ADRV9009的双收双发射频FMC子卡 一、板卡概述 ADRV9009是一款高集成度射频(RF)、捷变收发器&#xff0c;提供双通道发射器和接收器、集成式频率合成器以及数字信号处理功能。北京太速科技&#xff0c;这款IC具备多样化的高性能和低功耗组合&#xff0c;FMC子…

基于亚马逊云科技无服务器服务快速搭建电商平台——部署篇

受疫情影响消费者习惯发生改变&#xff0c;刺激了全球电商行业的快速发展。除了依托第三方电商平台将产品销售给消费者之外&#xff0c;企业通过品牌官网或者自有电商平台销售商品也是近几年电商领域快速发展的商业模式。独立站电商模式可以进行多方面、全渠道的互联网市场拓展…

Git分布式版本控制系统与github

第四阶段提升 时 间&#xff1a;2023年8月29日 参加人&#xff1a;全班人员 内 容&#xff1a; Git分布式版本控制系统与github 目录 一、案例概述 二、版本控制系统 &#xff08;一&#xff09; 本地版本控制 &#xff08;二&#xff09;集中化的版本控制系统 &…

DP读书:鲲鹏处理器 架构与编程(十三)操作系统内核与云基础软件

操作系统内核与云基础软件 鲲鹏软件构成硬件特定软件 鲲鹏软件构成硬件特定软件1. Boot Loader2. SBSA 与 SBBR3. UEFI4. ACPI 操作系统内核Linux系统调用Linux进程调度Linux内存管理Linux虚拟文件系统Linux网络子系统Linux进程间通信Linux可加载内核模块Linux设备驱动程序Linu…

Vue 项目性能优化 — 实践指南

前言 Vue 框架通过数据双向绑定和虚拟 DOM 技术&#xff0c;帮我们处理了前端开发中最脏最累的 DOM 操作部分&#xff0c; 我们不再需要去考虑如何操作 DOM 以及如何最高效地操作 DOM&#xff1b;但 Vue 项目中仍然存在项目首屏优化、Webpack 编译配置优化等问题&#xff0c;所…