MogaNet实战:使用MogaNet实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

论文:https://arxiv.org/pdf/2211.03295.pdf

作者多阶博弈论交互这一全新视角探索了现代卷积神经网络的表示能力。这种交互反映了不同尺度上下文中变量间的相互作用效果。提出了一种新的纯卷积神经网络架构族,称为MogaNet。MogaNet具有出色的可扩展性,在ImageNet和其他多种典型视觉基准测试中,与最先进的模型相比,其参数使用更高效,且具有竞争力的性能。具体来说,MogaNet在ImageNet上实现了80.0%和87.8%的Top-1准确率,分别使用了5.2M和181M参数,优于ParC-Net-S和ConvNeXt-L,同时节省了59%的浮点运算和17M的参数。源代码可在GitHub上(https://github.com/Westlake-AI/MogaNet)获取。

在这里插入图片描述

本文使用MogaNet模型实现图像分类任务,模型选择最小的moganet_xtiny,在植物幼苗分类任务ACC达到了96%+。

请添加图片描述

请添加图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现MogaNet模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。
参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

MogaNet_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  └─moganet.py
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
train.py:训练MogaNet模型
models:来源官方代码。

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/267759.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

#WEB前端(DIV、SPAN)

1.实验&#xff1a;DIV、SPAN 2.IDE&#xff1a;VSCODE 3.记录&#xff1a; 类? 4.代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdev…

Python3零基础教程之数学运算专题进阶

大家好,我是千与编程,今天已经进入我们Python3的零基础教程的第十节之数学运算专题进阶。上一次的数学运算中我们介绍了简单的基础四则运算,加减乘除运算。当涉及到数学运算的 Python 3 刷题使用时,进阶课程包含了许多重要的概念和技巧。下面是一个简单的教程,涵盖了一些常…

BUUCTF---数据包中的线索1

1.题目描述 2.下载附件&#xff0c;是一个.pcap文件 3.放在wireshark中&#xff0c;仔细观察数据流&#xff0c;会发现有个叫fenxi.php的数据流 4.这条数据流是http,且使用GET方式&#xff0c;接下来我们使用http.request,methodGET 命令来过滤数据流 5.在分析栏中我们追踪htt…

VirtualBox 桥接网卡 未指定 “未能启动虚拟电脑Ubuntu,由于下述物理网卡未找到:”

解决办法&#xff0c;安装虚拟网卡&#xff0c;win11查找方式&#xff1a;控制面板→网络和共享中心→更改适配器设置 此时出现下面情况就算安装成功 但是如果报错&#xff1a;找不到指定的模块 则按下面步骤删除干净垃圾重新上面操作 先安装CCleaner, 链接:CCleaner Makes Y…

Day10:基础入门-HTTP数据包Postman构造请求方法请求头修改状态码判断

目录 数据-方法&头部&状态码 案例-文件探针 案例-登录爆破 工具-Postman自构造使用 思维导图 章节知识点&#xff1a; 应用架构&#xff1a;Web/APP/云应用/三方服务/负载均衡等 安全产品&#xff1a;CDN/WAF/IDS/IPS/蜜罐/防火墙/杀毒等 渗透命令&#xff1a;文件…

51单片机-(定时/计数器)

51单片机-&#xff08;定时/计数器&#xff09; 了解CPU时序、特殊功能寄存器和定时/计数器工作原理&#xff0c;以定时器0实现每次间隔一秒亮灯一秒的实验为例理解定时/计数器的编程实现。 1.CPU时序 1.1.四个周期 振荡周期&#xff1a;为单片机提供定时信号的振荡源的周期…

k8s.gcr.io/pause:3.2镜像丢失解决

文章目录 前言错误信息临时解决推荐解决onetwo 前言 使用Kubernetes&#xff08;k8s&#xff09;时遇到了镜像拉取的问题&#xff0c;导致Pod沙盒创建失败。错误显示在尝试从k8s.gcr.io拉取pause:3.2镜像时遇到了超时问题&#xff0c;这通常是因为网络问题或者镜像仓库服务器的…

springcloud:3.3测试重试机制

服务提供者【test-provider8001】 Openfeign远程调用服务提供者搭建 文章地址http://t.csdnimg.cn/06iz8 相关接口 测试远程调用&#xff1a;http://localhost:8001/payment/index 服务消费者【test-consumer-resilience4j8004】 Openfeign远程调用消费者搭建 文章地址http:/…

守护无价数据:文件备份的重要性与实用方案

一、数据安全之盾&#xff1a;文件备份的必要性 在数字化时代&#xff0c;我们生活的各个方面都与电子文件紧密相连。从工作文档、个人照片到珍贵视频&#xff0c;这些文件记录着我们的成长、工作与生活。然而&#xff0c;随着电子设备的使用频率增加&#xff0c;数据丢失的风…

mTSL: netty单向/双向TLS连接

创建证书 不管是单向tls还是双向tls(mTLS)&#xff0c;都需要创建证书。 创建证书可以使用openssl或者keytool&#xff0c;openssl 参考 mTLS: openssl创建CA证书 单向/双向tls需要使用到的相关文件: 文件单向tls双向tlsServer端Client端备注ca.key----需要保管好&#xff0…

Visual Studio C++项目远程断点调试客户现场程序方法

前言 程序开发一个很常见的场景&#xff0c;就是程序在自己本地部署调试明明一点问题都没有&#xff0c;但是部署到客户现场就问题百出&#xff0c;要调试起来还很困难&#xff0c;在自己本地也没有条件复现&#xff0c;很多时候只能靠日志一点点排查和猜测&#xff0c;耗费大…

[设计模式Java实现附plantuml源码~行为型] 对象状态及其转换——状态模式

前言&#xff1a; 为什么之前写过Golang 版的设计模式&#xff0c;还在重新写Java 版&#xff1f; 答&#xff1a;因为对于我而言&#xff0c;当然也希望对正在学习的大伙有帮助。Java作为一门纯面向对象的语言&#xff0c;更适合用于学习设计模式。 为什么类图要附上uml 因为很…

SparkStreaming在实时处理的两个场景示例

简介 Spark Streaming是Apache Spark生态系统中的一个组件&#xff0c;用于实时流式数据处理。它提供了类似于Spark的API&#xff0c;使开发者可以使用相似的编程模型来处理实时数据流。 Spark Streaming的工作原理是将连续的数据流划分成小的批次&#xff0c;并将每个批次作…

(C语言)函数详解上

&#xff08;C语言&#xff09;函数详解上 目录&#xff1a; 1. 函数的概念 2. 库函数 2.1 标准库和头文件 2.2 库函数的使用方法 2.2.1 sqrt 功能 2.2.2 头文件包含 2.2.3 实践 2.2.4 库函数文档的一般格式 3. 自定义函数 3.1 函数的语法形式 3.2 函数的举例 4. 形参和实参 4.…

Redis 命令全解析之 List类型

文章目录 命令RedisTemplate API使用场景 Redis 的 List 是一种有序、可重复、可变动的数据结构&#xff0c;它基于双向链表实现。在Redis中&#xff0c;List可以存储多个相同或不同类型的元素&#xff0c;每个元素在List中都有一个对应的索引位置。这使得List可以用来实现队列…

【计算机网络_应用层】协议定制序列化反序列化

文章目录 1. TCP协议的通信流程2. 应用层协议定制3. 通过“网络计算器”的实现来实现应用层协议定制和序列化3.1 protocol3.2 序列化和反序列化3.2.1 手写序列化和反序列化3.2.2 使用Json库 3.3 数据包读取3.4 服务端设计3.5 最后的源代码和运行结果 1. TCP协议的通信流程 在之…

oppo手机备忘录记录怎么转移到华为手机?

oppo手机备忘录记录怎么转移到华为手机?使用oppo手机已经有三四年了&#xff0c;因为平时习惯&#xff0c;在手机系统的备忘录中记录了很多重要的笔记&#xff0c;比如工作会议的要点、读书笔记、购物清单、朋友的生日提醒等。这些记录对我来说非常重要&#xff0c;我可以通过…

ElasticSearch搜索引擎使用指南

一、ES数据基础类型 1、数据类型 字符串 主要包括: text和keyword两种类型&#xff0c;keyword代表精确值不会参与分词&#xff0c;text类型的字符串会参与分词处理 数值 包括: long, integer, short, byte, double, float 布尔值 boolean 时间 date 数组 数组类型不…

独立游戏《星尘异变》UE5 C++程序开发日志1——项目与代码管理

写在前面&#xff1a;本日志系列将会向大家介绍在《星尘异变》这款模拟经营游戏&#xff0c;在开发时用到的与C相关的泛用代码与算法&#xff0c;主要记录UE5C与原生C的用法区别&#xff0c;以及遇到的问题和解决办法&#xff0c;因为这是我本人从ACM退役以后第一个从头开始的项…

【数据结构】知识点一:线性表之顺序表

内容导航 一、什么是线性表&#xff1f;二、什么是顺序表&#xff1f;1、顺序表的概念2、顺序表的结构a. 静态顺序表&#xff1a;使用定长数组存储元素。b. 动态顺序表&#xff1a;使用动态开辟的数组存储。 三、顺序表的接口实现精讲1.接口一&#xff1a;打印数据2.接口二&…