TensorFlow中slim包的具体用法

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim|-- BUILD|-- README.md|-- WORKSPACE|-- __init__.py|-- datasets|   |-- __init__.py|   |-- __pycache__|   |   |-- __init__.cpython-37.pyc|   |   |-- dataset_utils.cpython-37.pyc|   |   |-- download_and_convert_cifar10.cpython-37.pyc|   |   |-- download_and_convert_flowers.cpython-37.pyc|   |   `-- download_and_convert_mnist.cpython-37.pyc|   |-- build_imagenet_data.py|   |-- cifar10.py|   |-- dataset_factory.py|   |-- dataset_utils.py|   |-- download_and_convert_cifar10.py|   |-- download_and_convert_flowers.py|   |-- download_and_convert_imagenet.sh|   |-- download_and_convert_mnist.py|   |-- download_imagenet.sh|   |-- flowers.py|   |-- imagenet.py|   |-- imagenet_2012_validation_synset_labels.txt|   |-- imagenet_lsvrc_2015_synsets.txt|   |-- imagenet_metadata.txt|   |-- mnist.py|   |-- preprocess_imagenet_validation_data.py|   `-- process_bounding_boxes.py|-- deployment|   |-- __init__.py|   |-- model_deploy.py|   `-- model_deploy_test.py|-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py        # 测试模型分类效果|-- export_inference_graph.py|-- export_inference_graph_test.py|-- nets|   |-- __init__.py|   |-- alexnet.py|   |-- alexnet_test.py|   |-- cifarnet.py|   |-- cyclegan.py|   |-- cyclegan_test.py|   |-- dcgan.py|   |-- dcgan_test.py|   |-- i3d.py|   |-- i3d_test.py|   |-- i3d_utils.py|   |-- inception.py|   |-- inception_resnet_v2.py|   |-- inception_resnet_v2_test.py|   |-- inception_utils.py|   |-- inception_v1.py|   |-- inception_v1_test.py|   |-- inception_v2.py|   |-- inception_v2_test.py|   |-- inception_v3.py|   |-- inception_v3_test.py|   |-- inception_v4.py|   |-- inception_v4_test.py|   |-- lenet.py|   |-- mobilenet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- conv_blocks.py|   |   |-- madds_top1_accuracy.png|   |   |-- mnet_v1_vs_v2_pixel1_latency.png|   |   |-- mobilenet.py|   |   |-- mobilenet_example.ipynb|   |   |-- mobilenet_v2.py|   |   `-- mobilenet_v2_test.py|   |-- mobilenet_v1.md|   |-- mobilenet_v1.png|   |-- mobilenet_v1.py|   |-- mobilenet_v1_eval.py|   |-- mobilenet_v1_test.py|   |-- mobilenet_v1_train.py|   |-- nasnet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- nasnet.py|   |   |-- nasnet_test.py|   |   |-- nasnet_utils.py|   |   |-- nasnet_utils_test.py|   |   |-- pnasnet.py|   |   `-- pnasnet_test.py|   |-- nets_factory.py|   |-- nets_factory_test.py|   |-- overfeat.py|   |-- overfeat_test.py|   |-- pix2pix.py|   |-- pix2pix_test.py|   |-- resnet_utils.py|   |-- resnet_v1.py|   |-- resnet_v1_test.py|   |-- resnet_v2.py|   |-- resnet_v2_test.py|   |-- s3dg.py|   |-- s3dg_test.py|   |-- vgg.py|   `-- vgg_test.py|-- preprocessing|   |-- __init__.py|   |-- cifarnet_preprocessing.py|   |-- inception_preprocessing.py|   |-- lenet_preprocessing.py|   |-- preprocessing_factory.py|   `-- vgg_preprocessing.py|-- scripts                     # gqr:存储的是相关的模型训练脚本                |   |-- export_mobilenet.sh|   |-- finetune_inception_resnet_v2_on_flowers.sh|   |-- finetune_inception_v1_on_flowers.sh|   |-- finetune_inception_v3_on_flowers.sh|   |-- finetune_resnet_v1_50_on_flowers.sh|   |-- train_cifarnet_on_cifar10.sh|   `-- train_lenet_on_mnist.sh|-- setup.py|-- slim_walkthrough.ipynb`-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; thenmkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; thenwget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gztar -xvf resnet_v1_50_2016_08_28.tar.gzmv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckptrm resnet_v1_50_2016_08_28.tar.gz
fi# Download the dataset
python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50 \--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \--batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR} \--eval_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--checkpoint_path=${TRAIN_DIR} \--model_name=resnet_v1_50 \--max_number_of_steps=1000 \--batch_size=32 \--learning_rate=0.001 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR}/all \--eval_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \--train_dir=./data/flowers-models/resnet_v1_50\--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=./data/flowers \--model_name=resnet_v1_50 \--checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \ --batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/106911.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】—— C++11之可变参数模板

前言: 在C语言中,我们谈论了有关可变参数的相关知识。在C11中引入了一个新特性---即可变参数模板。本期,我们将要介绍的就是有关可变参数模板的相关知识!!! 目录 序言 (一)可变参…

0基础学习VR全景平台篇 第90篇:智慧眼-数据统计

【数据统计】是按不同条件去统计整个智慧眼项目中的热点,共包含四大块,分别是数据统计、分类热点、待审核、回收站,下面我们来逐一进行介绍。 1、数据统计 ① 可以按所属分类、场景分组、所属场景、热点类型以及输入热点名去筛选对应的热点&…

文生图模型之Stable Diffusion

原始文章地址 autoencoder CLIP text encoder tokenizer最大长度为77(CLIP训练时所采用的设置),当输入text的tokens数量超过77后,将进行截断,如果不足则进行paddings,这样将保证无论输入任何长度的文本&…

Kaniko在containerd中无特权快速构建并推送容器镜像

目录 一、kaniko是什么 二、kaniko工作原理 三、kanijo工作在Containerd上 基于serverless的考虑,我们选择了kaniko作为镜像打包工具,它是google提供了一种不需要特权就可以构建的docker镜像构建工具。 一、kaniko是什么 kaniko 是一种在容器或 Kube…

【Linux】进程状态|僵尸进程|孤儿进程

前言 本文继续深入讲解进程内容——进程状态。 一个进程包含有多种状态,有运行状态,阻塞状态,挂起状态,僵尸状态,死亡状态等等,其中,阻塞状态还包含深度睡眠和浅度睡眠状态。 个人主页&#xff…

Diffusion Models for Image Restoration and Enhancement – A Comprehensive Survey

图像恢复与增强的扩散模型综述 论文链接:https://arxiv.org/abs/2308.09388 项目地址:https://github.com/lixinustc/Awesome-diffusion-model-for-image-processing/ Abstract 图像恢复(IR)一直是低水平视觉领域不可或缺的一项具有挑战性的任务&…

算法竞赛入门【码蹄集新手村600题】(MT1220-1240)C语言

算法竞赛入门【码蹄集新手村600题】(MT1220-1240)C语言 目录MT1221 分数的总和MT1222 等差数列MT1223 N是什么MT1224 棋盘MT1225 复杂分数MT1226 解不等式MT1227 宝宝爬楼梯MT1228 宝宝抢糖果MT1229 搬家公司MT1230 圆周率MT1231圆周率IIMT1232 数字和MT1233 数字之…

适配器模式实现stack和queue

适配器模式实现stack和queue 什么是适配器模式?STL标准库中stack和queue的底层结构stack的模拟实现queue的模拟实现 什么是适配器模式? 适配器是一种设计模式(设计模式是一套被反复使用的、多数人知晓的、经过分类编目的、代码设计经验的总结)&#xff…

时间和日期--Python

1. 时间:time模块 总结:2. datetime模块 相比与time模块,datetime模块的接口更直观、更容易调用 2.1 datetime模块定义的类 (1)datetime.date:表示日期的类。常用的属性有:year、month、day; &#xff…

k8s节点pod驱逐、污点标记

一、设置污点,禁止pod被调度到节点上 kubectl cordon k8s-node-145 设置完成后,可以看到该节点附带了 SchedulingDisabled 的标记 二、驱逐节点上运行的pod到其他节点 kubectl drain --ignore-daemonsets --delete-emptydir-data k8s-node-145 显示被驱逐…

抓包相关,抓包学习

检查网络流量 - 提琴手经典 (telerik.com) Headers Reference - Fiddler Classic (telerik.com) 以上是fiddler官方文档 F12要勾选保留日志 不勾选的话跳转到新页面之前页面的日志不会在下方显示 会保留所有抓到的包 如果重定向到别的页面 F12抓包可能看不到响应信息,但是…

【PHP】PHP开发教程-PHP开发环境安装

1、PHP简单介绍 PHP(全称:Hypertext Preprocessor)是一种广泛使用的开放源代码脚本语言,特别适用于Web开发。它嵌入在HTML中,通过在HTML文档中添加PHP标记和脚本,可以生成动态的、个性化的Web页面。 PHP最…

Vant 4.6.4发布,增加了一些新功能,并修复了一些bug

导读Vant 4.6.4发布,增加了一些新功能,并修复了一些bug等。 新功能 feat(area-data): 更新芜湖的县区数据,由 nivin-studio 在 #12122 中贡献feat(Locale): 添加塞尔维亚语到国际化,由 RogerZXY 在 #12145 中贡献feat(ImagePreview): 添加 c…

百望云华为云共建零售数字化新生态 聚焦数智新消费升级

零售业是一个充满活力和创新的行业,但也是当前面临很大新挑战和新机遇的行业。数智新消费时代,数字化转型已经成为零售企业必须面对的重要课题。 8 月 20 日-21日,以“云上创新 韧性增长”为主题的华为云数智新消费创新峰会2023在成都隆重召…

Redis从基础到进阶篇(二)----内存模型与内存优化

目录 一、缓存通识 1.1 ⽆处不在的缓存 1.2 多级缓存 (重点) 二、Redis简介 2.1 什么是Redis 2.2 Redis的应用场景 三、Redis数据存储的细节 3.1 Redis数据类型 3.2 内存结构 3.3 内存分配器 3.4 redisObject 3.4.1 type 3.4.2 encoding 3…

微积分基本概念

微分 函数的微分是指对函数的局部变化的一种线性描述。微分可以近似地描述当函数自变量的取值作足够小的改变时,函数的值是怎样改变的。。对于函数 y f ( x ) y f(x) yf(x) 的微分记作: d y f ′ ( x ) d x d_y f^{}(x)d_x dy​f′(x)dx​ 微分和…

什么是响应式设计(Responsive Design)?如何实现一个响应式网页?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 响应式设计(Responsive Design)⭐ 如何实现一个响应式网页?1. 弹性网格布局2. 媒体查询3. 弹性图像和媒体4. 流式布局5. 优化导航6. 测试和调整7. 图片优化8. 字体优化9. 渐进增强10. 面向移动优先11. …

芯讯通SIMCOM A7680C (4G Cat.1)AT指令测试 TCP通信过程

A7680C TCP通信 1、文档准备 去SIMCOM官网找到A7680C的AT指令集 AT指令官网 进入官网有这么多AT指令文件,只需要找到你需要用到的,这里我们用到了HTTP和TCP的,所以下载这两个即可。 2、串口助手 任意准备一个串口助手即可 这里我使用的是XC…

C++笔记之设计模式:setter函数、依赖注入

C笔记之设计模式:setter函数、依赖注入 参考笔记: 1.C笔记之静态成员函数可以在类外部访问私有构造函数吗? 2.C笔记之设计模式:setter函数、依赖注入 3.C笔记之两个类的实例之间传递参数——通过构造函数传递类对象的方法详细探究…