FastViT实战:使用FastViT实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
    • 安装 grad-cam
    • 安装mmcv
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集
  • 补充一个知识点:torch.jit
    • 两种保存方式

摘要

论文翻译:https://wanghao.blog.csdn.net/article/details/132407722?spm=1001.2014.3001.5502
或者
https://blog.csdn.net/m0_47867638/article/details/132441806?spm=1001.2014.3001.5502

官方源码:https://github.com/apple/ml-fastvit

FastViT是一种混合ViT架构,它通过引入一种新型的token混合运算符RepMixer来达到最先进的延迟-准确性权衡。RepMixer通过消除网络中的跳过连接来降低内存访问成本。FastViT进一步应用训练时间过度参数化和大核卷积来提高准确性,并根据经验表明这些选择对延迟的影响最小。实验结果表明,FastViT在移动设备上的速度比最近的混合Transformer架构CMT快3.5倍,比EfficientNet快4.9倍,比ConvNeXt快1.9倍。在相似的延迟下,FastViT在ImageNet上的Top-1精度比MobileOne高出4.2%。此外,FastViT模型能够较好的适应域外和破损数据,相较于其它SOTA架构具备很强的鲁棒性和泛化性能。

在这里插入图片描述

这篇文章使用FastViT完成植物分类任务,模型采用fastvit_t8向大家展示如何使用FastViT。fastvit_t8在这个数据集上实现了95+%的ACC,如下图:

在这里插入图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现FastViT模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装 grad-cam

pip install grad-cam

安装mmcv

pip install -U openmim
mim install mmcv

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

FastViT_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  ├─modules
│  │  ├─mobileone.py
│  │  └─replknet.py
│  └─fastvit.py
├─mean_std.py
├─export_model.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

models:来源官方代码,对面的代码做了一些适应性修改。
export_model.py:导出重参数模型
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练InceptionNext模型
cam_image.py:热力图可视化

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

补充一个知识点:torch.jit

FastViT用到了Torch.jit保存模型。所以,我把这个知识点做个说明,方便大家理解。模型训练好后自然想要将里面所有层涉及的权重保存下来,这样子我们的模型就能部署在任意有pytorch环境下了。但是,用Torch.save/load还会依赖模型文件。

torch.jit是PyTorch的模型压缩和序列化工具,它可以将训练好的神经网络模型转换成TorchScript格式的脚本,以便在不需要Python解释器的情况下进行部署和运行。不再依赖模型文件。

torch.jit可以将训练好的神经网络模型转换成TorchScript格式的脚本,这样可以大大减少模型的内存占用,提高模型的运行速度,同时也可以避免Python环境的不稳定性对模型运行的影响。

两种保存方式

torch.jit.trace:这种方式为追踪一个函数的执行流,使用时需要提供一个测试输入。详见:https://pytorch.org/docs/1.6.0/generated/torch.jit.trace.html?highlight=jit%20trace#torch.jit.trace

需要注意的是这个接口只追踪测试输入走过的函数执行流(如果模型中有多条分支的话只会保存测试输入走过的分支!!!!!),所以对于一些多分支的模型不要采用这种方式,采用下面的Torch.jit.script。比如model.eval()和model.train()可以控制模型内BN层和dropout的权重是否固定,如果采用这种方式只能保留其中之一状态(固定或不固定)。

torch.jit.script:使用这种方式可以将一个模型完整的保存下来,和上面的trace正好相对。如果模型中的分支很多,并且在运行时会改变的话一定要用这种形式保存。详见:https://pytorch.org/docs/1.6.0/generated/torch.jit.script.html?highlight=torch%20jit%20script#torch.jit.script

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/121434.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端实习第七周周记

前言 第六周没写,是因为第六周的前两天在处理第五周的样本库部分。问题解决一个是嵌套问题(因为我用到了递归),还有一个问题在于本机没有问题,打包上线接口404。这个问题我会在这周的总结中说。 第六周第三天才谈好新…

【核心复现】基于改进灰狼算法的并网交流微电网经济优化调度(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Re44:数据集 GSM8K 和 论文 Training Verifiers to Solve Math Word Problems

诸神缄默不语-个人CSDN博文目录 诸神缄默不语的论文阅读笔记和分类 论文全名:Training Verifiers to Solve Math Word Problems GSM8K数据集原始论文 OpenAI 2021年的工作,关注解决MWP问题(具体场景是小学(grade school&#xf…

如何在Mac电脑上安装WeasyPrint:简单易懂的步骤

1. 安装homebrew 首先需要确保安装了homebrew,通过homebrew安装weasyprint可以将需要的库都安装好,比pip安装更简单快捷。 安装方法如下: /bin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh)&qu…

SpringBoot v2.7.x+ 整合Swagger3入坑记?

目录 一、依赖 二、集成Swagger Java Config 三、配置完毕 四、解决方案 彩蛋 想尝鲜&#xff0c;坑也多&#xff0c;一起入个坑~ 一、依赖 SpringBoot版本&#xff1a;2.7.14 Swagger版本&#xff1a;3.0.0 <dependency><groupId>com.github.xiaoymin<…

方案展示 | RK3588开发板Linux双摄同显方案

iTOP-RK3588开发板使用手册更新&#xff0c;后续资料会不断更新&#xff0c;不断完善&#xff0c;帮助用户快速入门&#xff0c;大大提升研发速度。 RK3588开发板载4路MIPI CAMERA摄像头接口、MIPI CSI DPHY的4.5Gbps、2.5Gops的MIPI CSI CPHY&#xff0c;四路同时输入&#xf…

react快速开始(三)-create-react-app脚手架项目启动;使用VScode调试react

文章目录 react快速开始(三)-create-react-app脚手架项目启动&#xff1b;使用VScode调试react一、create-react-app脚手架项目启动1. react-scripts2. 关于better-npm-runbetter-npm-run安装 二、使用VScode调试react1. 浏览器插件React Developer Tools2. 【重点】用 VSCode …

MEMS传感器的原理与构造——单片式硅陀螺仪

一、前言 机械转子式陀螺仪在很长的一段时间内都是唯一的选项&#xff0c;也正是因为它的结构和原理&#xff0c;使其不再适用于现代小型、单体、集成式传感器的设计。常规的机械转子式陀螺仪包括平衡环、支撑轴承、电机和转子等部件&#xff0c;这些部件需要精密加工和…

mysql group by 字段 与 select 字段

表数据如下&#xff1a; 执行SQL语句1&#xff1a; SELECT * FROM z_course GROUP BY NAME,SEX 结果&#xff1a; 执行SQL语句2&#xff1a; SELECT * FROM z_course GROUP BY NAME sql 1 根据 name&#xff0c;sex 两个字段分组&#xff0c;查询 所有字段&#xff0c;返回结…

骨传导耳机用久了伤耳朵吗?骨传导耳机有什么优势

骨传导耳机用久了不伤耳朵&#xff0c;相对于传统的入耳式耳机来说&#xff0c;对耳朵的压力和损伤较小。由于骨传导技术不直接通过耳道传递声音&#xff0c;而是通过振动将声音传送到内耳&#xff0c;因此相比其他类型的耳机&#xff0c;它在减少听力损伤的风险方面具有优势。…

Java 加了@PreAuthorize注解的接口在Postman中访问

1. 首先&#xff0c;你需要获取一个有效的用户token&#xff0c;该token应包含了相应的接口权限。你可以通过登录或其他身份验证方式来获取token。2. 打开Postman&#xff0c;并确保已选择正确的HTTP方法&#xff08;GET、POST等&#xff09;。3. 在请求的Headers部分&#xff…

Flink中RPC实现原理简介

前提知识 Akka是一套可扩展、弹性和快速的系统&#xff0c;为此Flink基于Akka实现了一套内部的RPC通信框架&#xff1b;为此先对Akka进行了解 Akka Akka是使用Scala语言编写的库&#xff0c;基于Actor模型提供一个用于构建可扩展、弹性、快速响应的系统&#xff1b;并被应用…

内网隧道代理技术(十九)之 CS工具自带上线不出网机器

CS工具自带上线不出网机器 如图A区域存在一台中转机器,这台机器可以出网,这种是最常见的情况。我们在渗透测试的过程中经常是拿下一台边缘机器,其有多块网卡,边缘机器可以访问内网机器,内网机器都不出网。这种情况下拿这个边缘机器做中转,就可以使用CS工具自带上线不出网…

【每日一题】54. 螺旋矩阵

54. 螺旋矩阵 - 力扣&#xff08;LeetCode&#xff09; 给你一个 m 行 n 列的矩阵 matrix &#xff0c;请按照 顺时针螺旋顺序 &#xff0c;返回矩阵中的所有元素。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,2,3],[4,5,6],[7,8,9]] 输出&#xff1a;[1,2,3,6,9,8,7,4,5…

Myvatis关联关系映射与表对象之间的关系

目录 一、关联关系映射 1.1 一对一 1.2 一对多 1.3 多对多 二、处理关联关系的方式 2.1 嵌套查询 2.2 嵌套结果 三、一对一关联映射 3.1 建表 ​编辑 3.2 配置文件 3.3 代码生成 3.4 编写测试 四、一对多关联映射 五、多对多关联映射 六、小结 一、关联关系映射 …

BFS练习1

BFS练习1 - 题目 - Daimayuan Online Judge 问题描述&#xff1a; 刚开始吓一跳&#xff0c;以为有什么更简单的呢&#xff0c;因为每一次都要走一次bfs&#xff0c;看了数据范围后&#xff0c;感觉跑一次bfs进行记录即可。 代码&#xff1a; void solve() {int a,k; cin>…

部署项目至服务器

安装conda https://zhuanlan.zhihu.com/p/489499097 个人租借的服务器如何进行端口的开放呢&#xff1f; 防火墙设置&#xff1a; 添加规则设置&#xff1a; 即可&#xff1b; 通常下租借的服务器没有防火墙设置 相关链接&#xff1a; https://blog.csdn.net/weixin_4520…

线上展厅可以用在哪些行业,线上展厅如何获取访客

引言&#xff1a; 随着数字化时代的到来&#xff0c;线上展厅成为了一种重要的营销工具&#xff0c;适用于多个行业&#xff0c;帮助他们吸引来自不同领域的潜在用户。 一&#xff0e;线上展厅在哪些行业有应用 1.零售行业 线上展厅为零售商提供了一个虚拟展示产品的平台&am…

Unity——工程与资源

本文将详细介绍Unity工程的文件夹结构&#xff0c;以及动态加载资源的技术要点 一、Unity项目的文件夹结构 1.工程文件夹 在新建工程时&#xff0c;Unity会创建所有必要的文件夹。第一级文件夹有Assets,Library,Logs,Packages,ProjectSettings。 Assets&#xff1a;最主要的文…

NVIDIA CUDA Win10安装步骤

前言 windows10 版本安装 CUDA &#xff0c;首先需要下载两个安装包 CUDA toolkit&#xff08;toolkit就是指工具包&#xff09;cuDNN 1. 安装前准备 在安装CUDA之前&#xff0c;需要完成以下准备工作&#xff1a; 确认你的显卡已经正确安装&#xff0c;在设备管理器中可以看…