【学习】使用PyTorch训练与评估自己的ResNet网络教程

参考:保姆级使用PyTorch训练与评估自己的ResNet网络教程_训练自己的图像分类网络resnet101 pytorch-CSDN博客

项目地址:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

视频手把手教程:我将维护一个集成各主干网络的图像分类项目_哔哩哔哩_bilibili

主要是复现和训练测试自己的数据集

复现部分

0.环境问题

pytorch官网里面找个合适的CUDA11.0安装一下,然后把requirements.txt安装一下

pip install -r requirements.txt

 参考版本:

pip list
Package                Version
---------------------- ---------------
certifi                2021.5.30
cycler                 0.11.0
dataclasses            0.8
importlib-resources    5.4.0
joblib                 1.1.1
kiwisolver             1.3.1
matplotlib             3.3.4
mkl-fft                1.3.0
mkl-random             1.1.1
mkl-service            2.3.0
numpy                  1.19.2
olefile                0.46
opencv-contrib-python  4.0.1.24
opencv-python          4.0.1.24
opencv-python-headless 4.0.1.24
packaging              21.3
Pillow                 8.4.0
pip                    21.3.1
pyparsing              3.0.7
python-dateutil        2.9.0.post0
scikit-learn           0.24.2
scipy                  1.5.4
setuptools             36.4.0
six                    1.16.0
terminaltables         3.1.10
threadpoolctl          3.1.0
torch                  1.7.1
torchaudio             0.7.0a0+a853dff
torchvision            0.8.2
tqdm                   4.64.1
typing_extensions      4.1.1
wheel                  0.37.1
zipp                   3.6.0
  • 下载MobileNetV3-Small权重至datas
  • 利用项目里的猫狗图片检验一下安装情况
    python tools/single_test.py datas/cat-dog.png models/mobilenet/mobilenet_v3_small.py --classes-map datas/imageNet1kAnnotation.txt
    

    成功的话大概这样:

 1.数据集问题

 先下载花卉数据集(0zat):flower_photos.zip_免费高速下载|百度网盘-分享无限制 (baidu.com)

 原始地址在项目的资料部分:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison/magic modification project

 目录结构,按照花卉类型存放

├─flower_photos
│  ├─daisy
│  │      100080576_f52e8ee070_n.jpg
│  │      10140303196_b88d3d6cec.jpg
│  │      ...
│  ├─dandelion
│  │      10043234166_e6dd915111_n.jpg
│  │      10200780773_c6051a7d71_n.jpg
│  │      ...
│  ├─roses
│  │      10090824183_d02c613f10_m.jpg
│  │      102501987_3cdb8e5394_n.jpg
│  │      ...
│  ├─sunflowers
│  │      1008566138_6927679c8a.jpg
│  │      1022552002_2b93faf9e7_n.jpg
│  │      ...
│  └─tulips
│  │      100930342_92e8746431_n.jpg
│  │      10094729603_eeca3f2cb6.jpg
│  │      ...
  • datas/中创建标签文件annotations.txt,按行将类别名的索引写入文件(应该已经写好了);即
    daisy 0
    dandelion 1
    roses 2
    sunflowers 3
    tulips 4
    

    之后进行数据集划分,随机分为训练和测试集。

  • 在tools/split_data.py中修改原始数据集地址和划分后的数据集地址。(new_datasets最好别更改)

    init_dataset = './flower_photos'
    new_dataset = './Awesome-Backbones/datasets'
    

    终端使用命令:

    python tools/split_data.py
    

    划分后的数据集格式大概为:

    ├─...
    ├─datasets
    │  ├─test
    │  │  ├─daisy
    │  │  ├─dandelion
    │  │  ├─roses
    │  │  ├─sunflowers
    │  │  └─tulips
    │  └─train
    │      ├─daisy
    │      ├─dandelion
    │      ├─roses
    │      ├─sunflowers
    │      └─tulips
    ├─...
    

    查看tools/get_annotation.py,看看路径要不要更改:

  • datasets_path   = '你的数据集路径'
    

 终端使用命令:

python tools/get_annotation.py

 该命令应该会在datas/下形成train.txt和test.txt,里面是具体照片的位置

2.修改配置文件

/models下有许多的模型配置文件

 以resnet为例

 挑一个顺眼的改改

以resnet101为例

# model settingsmodel_cfg = dict(backbone=dict(type='ResNet',depth=101,num_stages=4,out_indices=(3, ),style='pytorch'),neck=dict(type='GlobalAveragePooling'),head=dict(type='LinearClsHead',num_classes=5,in_channels=2048,loss=dict(type='CrossEntropyLoss', loss_weight=1.0),topk=(1, 5),))# dataloader pipeline
img_lighting_cfg = dict(eigval=[55.4625, 4.7940, 1.1475],eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],[-0.5836, -0.6948, 0.4203]],alphastd=0.1,to_rgb=True)
policies = [dict(type='AutoContrast', prob=0.5),dict(type='Equalize', prob=0.5),dict(type='Invert', prob=0.5),dict(type='Rotate',magnitude_key='angle',magnitude_range=(0, 30),pad_val=0,prob=0.5,random_negative_prob=0.5),dict(type='Posterize',magnitude_key='bits',magnitude_range=(0, 4),prob=0.5),dict(type='Solarize',magnitude_key='thr',magnitude_range=(0, 256),prob=0.5),dict(type='SolarizeAdd',magnitude_key='magnitude',magnitude_range=(0, 110),thr=128,prob=0.5),dict(type='ColorTransform',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Contrast',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Brightness',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Sharpness',magnitude_key='magnitude',magnitude_range=(-0.9, 0.9),prob=0.5,random_negative_prob=0.),dict(type='Shear',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='horizontal',random_negative_prob=0.5),dict(type='Shear',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='vertical',random_negative_prob=0.5),dict(type='Cutout',magnitude_key='shape',magnitude_range=(1, 41),pad_val=0,prob=0.5),dict(type='Translate',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='horizontal',random_negative_prob=0.5,interpolation='bicubic'),dict(type='Translate',magnitude_key='magnitude',magnitude_range=(0, 0.3),pad_val=0,prob=0.5,direction='vertical',random_negative_prob=0.5,interpolation='bicubic')
]
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='RandAugment',policies=policies,num_policies=2,magnitude_level=12),dict(type='RandomResizedCrop',size=224,efficientnet_style=True,interpolation='bicubic',backend='pillow'),dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),dict(type='Lighting', **img_lighting_cfg),dict(type='Normalize',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=False),dict(type='ImageToTensor', keys=['img']),dict(type='ToTensor', keys=['gt_label']),dict(type='Collect', keys=['img', 'gt_label'])
]
val_pipeline = [dict(type='LoadImageFromFile'),dict(type='CenterCrop',crop_size=224,efficientnet_style=True,interpolation='bicubic',backend='pillow'),dict(type='Normalize',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=True),dict(type='ImageToTensor', keys=['img']),dict(type='Collect', keys=['img'])
]# train
data_cfg = dict(batch_size = 32,num_workers = 0,train = dict(pretrained_flag = False,pretrained_weights = '',freeze_flag = False,freeze_layers = ('backbone',),epoches = 150,),test=dict(ckpt = './logs/ResNet/2024-06-26-10-37-00/Last_Epoch150.pth',metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'confusion'],metric_options = dict(topk = (1,5),thrs = None,average_mode='none'))
)# optimizer
optimizer_cfg = dict(type='SGD',lr=0.001,momentum=0.9,weight_decay=1e-4)# learning 
lr_config = dict(type='StepLrUpdater', step=[30, 60, 90])

主要改model_cfg里面的num_classes,data_cfg里的batch_size与num_workers

若有预训练权重则可以将pretrained_weights设置为True并将预训练的路径赋值给pretrained_weights

optimizer_cfg中修改初始学习率,根据batch_size调试

3.训练

终端运行

python tools/train.py models/resnet/resnet101.py

 运行结果

4.评估

在实际使用的配置文件中将ckpt修改

ckpt = '你的训练权重路径'

终端运行

python tools/evaluation.py models/resnet/resnet101.py

 运行结果

 我跑出来的准确率不高哈

5.测试

单张测试

python tools/single_test.py datasets/test/dandelion/14283011_3e7452c5b2_n.jpg models/resnet/resnet101.py

多张测试

使用batch_test.py,路径使用文件夹路径。

----------------------------------------------------------------------------------------------

使用自己的数据集

1.数据集准备

2.配置文件

3.训练

4.评估

5.测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/361301.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高效修复机床导轨磨损,保障加工精度!

机床导轨是支承和引导运动构件沿着一定轨迹运动的传动装置,在机器设备中是个十分重要的部件,在机床中是常见的部件。机床的加工精度与导轨精度有直接的联系,且导轨一旦损坏,维修较复杂且困难。我们简单总结了以下几点对于机床导轨…

编程设计思想

健康检查脚本 nmap:扫描端口 while true do healthycurl B:httpPORT/healthy -i | grep HTTP/1.1 | tail -n 1 | awk {print $2} done 批量操作类型脚本(记录每一步日志) 将100个nginx:vn推送到harbor仓库192.168.0.100 根据镜像对比sha值…

【开源项目】自然语言处理领域的明星项目推荐:Hugging Face Transformers

在当今人工智能与大数据飞速发展的时代,自然语言处理(NLP)已成为推动科技进步的重要力量。而在NLP领域,Hugging Face Transformers无疑是一个备受瞩目的开源项目。本文将从项目介绍、代码解释以及技术特点等角度,为您深…

面向对象修炼手册(四)(多态与空间分配)(Java宝典)

🌈 个人主页:十二月的猫-CSDN博客 🔥 系列专栏: 🏀面向对象修炼手册 💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 目录 前言 1 多态 1.1 多态的形式&…

需求之 实现获取调试信息在h5页面,在手机端可以查看调试(二)

事实证明 chatgpt很好用,有不懂的问题可以问它 https://zhuanlan.zhihu.com/p/690118775 国内外9个免费的ChatGPT网站 我筛选出来的比较好用免费的网站 fchat.dykyzdh.cn/ 这个也可以 阿里云的 通义灵码 在vscode中安装使用 而且阿里云有一个产品,可以…

面试-Java线程池

1.利用Excutors创建不同的线程池满足不同场景的需求 分析: 如果并发的请求的数量非常多,但每个线程执行的时间非常短,这样就会频繁的创建和销毁线程。如此一来,会大大降低系统的效率。 可能出现,服务器在为每个线程创建…

jdk1.8升级到jdk11遇到的各种问题

一、第三方依赖使用了BASE64Decoder 如果项目中使用了这个类 sun.misc.BASE64Decoder,就会导致错误,因为再jdk11中,该类已经被删除。 Caused by: java.lang.NoClassDefFoundError: sun/misc/BASE64Encoder 当然这个类也有替换方式&#xf…

MySQL实训--原神数据库

原神数据库 er图DDL/DML语句查询语句存储过程/触发器 er图 DDL/DML语句 SET NAMES utf8mb4; SET FOREIGN_KEY_CHECKS 0;DROP TABLE IF EXISTS artifacts; CREATE TABLE artifacts (id int NOT NULL AUTO_INCREMENT,artifacts_name varchar(255) CHARACTER SET utf8 COLLATE …

一文搞懂Linux多线程【下】

目录 🚩多线程代码的健壮性 🚩多线程控制 🚩线程返回值问题 🚩关于Linux线程库 🚩对Linux线程简单的封装 在观看本博客之前,建议大家先看一文搞懂Linux多线程【上】由于上一篇博客篇幅太长,为…

文件操作<C语言>

导言 平时我们在写程序时,在运行时申请内存空间,运行完时内存空间被收回,如果想要持久化的保存,我们就可以使用文件,所以下文将要介绍一些在程序中完成一些文件操作。 目录 导言 文件流 文件指针 文件的打开与关闭 …

《黑神话悟空》电脑配置要求

《黑神话:悟空》这款国内优秀的3A游戏大作,拥有顶级的特效与故事剧情,自公布以来便备受玩家期待,其精美的画面与流畅的战斗体验,对玩家的电脑配置提出一定要求。那么这款优秀的游戏需要什么样的电脑配置,才…

记录:[android] SSLHandshakeException: Handshake failed 问题;已解决!

1、问题描述:在使用Retrofit2 时在安卓老设备上(安卓6.0)网络无法请求、安卓 10 、 11 未出现此问题?what? 原因:服务端 TLS 版本过高 2、废话不多说、解决方案A 、添加依赖:implementation org.conscrypt…

黑马苍穹外卖6 清理redis缓存+Spring Cache+购物车的增删改查

缓存菜品 后端服务都去查询数据库,对数据库访问压力增大。 解决方式:使用redis来缓存菜品,用内存比磁盘性能更高。 key :dish_分类id String key “dish_” categoryId; RestController("userDishController") RequestMapping…

游戏工厂:AI(AIGC/ChatGPT)与流程式游戏开发

游戏工厂:AI(AIGC/ChatGPT)与流程式游戏开发 码客 卢益贵 ygluu 关键词:AI(AIGC、ChatGPT、文心一言)、流程式管理、好莱坞电影流程、电影工厂、游戏工厂、游戏开发流程、游戏架构、模块化开发 一、前言…

【每日刷题】Day75

【每日刷题】Day75 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. 1833. 雪糕的最大数量 - 力扣(LeetCode) 2. 面试题 17.14. 最小K个数 - 力扣…

LabVIEW电梯钢丝绳实时监测系统

电梯作为现代高层建筑中不可或缺的交通工具,其安全性直接影响到乘客的生命财产安全。电梯钢丝绳作为承载乘客与货物的关键部件,其健康状况尤为重要。传统的钢丝绳检测方法大多依赖于定期检查,无法实现实时监控,存在一定的安全隐患…

DPDK使用make编译并运行示例程序

环境: VMware Workstation 16 Pro 16.2.4 虚拟机系统:Centos 8 DPDK版本:stable-20.11.10 下载源码后,使用meson和ninja编译完成、配置并挂载大页、内核和VFIO设置完成,在dpdk源码目录下的build/…

安全技术和防火墙

安全技术和防火墙 安全技术 入侵检测系统:特点是不阻断网络访问,主要提供报警和事后监督,不主动介入,默默看着你(监控) 入侵防御系统:透明模式工作,数据包,网络监控&am…

Python22 Pandas库

Pandas 是一个Python数据分析库,它提供了高性能、易于使用的数据结构和数据分析工具。这个库适用于处理和分析输入数据,常见于统计分析、金融分析、社会科学研究等领域。 1.Pandas的核心功能 Pandas 库的核心功能包括: 1.数据结构&#xff…

YIA主题侧边栏如何添加3D旋转标签云?

WordPress站点侧边栏默认的标签云排版很一般,而3D旋转标签云就比较酷炫了。下面boke112百科就以YIA主题为例,跟大家说一说如何将默认的标签云修改成3D旋转标签云,具体步骤如下: 1、点此下载3d标签云文件(密码&#xf…