旷视科技ShuffleNetV1代码分析[pytorch版]

一、前述 

旷视科技针对于ShuffleNet系列网络在GitHub网站上已开源,其链接:https://github.com/megvii-model/ShuffleNet-Series

在这个系列中,包括了ShuffleNetV1/V2网络,如下图所示。 

我们点开ShuffleNetV1文件夹,如下图所示。 

  • ShuffleNetV1文件夹中有五个文件,分别为:README.md、blocks.py、network.py、train.py、utils.py文件。
  • 其中,blocks.py中的代码是ShuffleNetV1的基本模块;
  • network.py 中的代码是 blocks.py 中基本模块堆叠出来的 ShuffleNetV1 网络;
  • train.py 中是训练 ImageNet 数据集图像分类的训练代码;
  • utils.py 是一些常用的函数。

旷视科技GitHub网站给出的ShufflNetV1网络的结果,如下表所示: 

二、代码分析

2.1 blocks.py(ShuffleNetV1 Unit) 

我们先来回顾以下ShuffleNetV1 Unit,如下图(b)、图(c)所示。 
图(b)表示的是stride=1的ShuffleNetV1 Unit,在该基本单元中,右侧被称为主分支,在该主分支中:
①先1×1GConv(group pointwise convolution)降维,第一个红色模块;
②然后channel shuffle,蓝色模块;
③再然后3×3DWConv(depthwise convolution),绿色模块;
④然后再1×1GConv升维,第二个红色模块。

图(c)表示的是stride=2 的ShuffleNetV1 Unit,在该基本单元中,右侧被称为主分支,在该主分支中:
①先1×1GConv(group pointwise convolution)降维,第一个红色模块;
②然后channel shuffle,蓝色模块;
③再然后3×3DWConv(depthwise convolution),绿色模块;
④然后再1×1GConv升维,第二个红色模块。

ShuffleNetV1 Unit
(b)stride = 1; (c)stride = 2

ShuffleNetV1网络基本模块的总体代码如下所示,该代码包括了:stride=1的基本单元构建、stride=2的基本单元构建、channel shuffle(通道重排)操作。 

# blocks.py
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ShuffleV1Block(nn.Module):def __init__(self, inp, oup, *, group, first_group, mid_channels, ksize, stride):super(ShuffleV1Block, self).__init__()self.stride = strideassert stride in [1, 2]self.mid_channels = mid_channelsself.ksize = ksizepad = ksize // 2self.pad = padself.inp = inpself.group = groupif stride == 2:outputs = oup - inpelse:outputs = oupbranch_main_1 = [# pwnn.Conv2d(inp, mid_channels, 1, 1, 0, groups=1 if first_group else group, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),# dwnn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),nn.BatchNorm2d(mid_channels),]branch_main_2 = [# pw-linearnn.Conv2d(mid_channels, outputs, 1, 1, 0, groups=group, bias=False),nn.BatchNorm2d(outputs),]self.branch_main_1 = nn.Sequential(*branch_main_1)self.branch_main_2 = nn.Sequential(*branch_main_2)if stride == 2:self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)def forward(self, old_x):x = old_xx_proj = old_xx = self.branch_main_1(x)if self.group > 1:x = self.channel_shuffle(x)x = self.branch_main_2(x)if self.stride == 1:return F.relu(x + x_proj)elif self.stride == 2:return torch.cat((self.branch_proj(x_proj), F.relu(x)), 1)def channel_shuffle(self, x):batchsize, num_channels, height, width = x.data.size()assert num_channels % self.group == 0group_channels = num_channels // self.groupx = x.reshape(batchsize, group_channels, self.group, height, width)x = x.permute(0, 2, 1, 3, 4)x = x.reshape(batchsize, num_channels, height, width)return x

我们一步一步做好乐高积木然后将这些乐高积木拼装起来,如下: 

图(b)主分支代码如下: 

图(c)主分支代码如下:
图(c)侧分支代码如下:

channel shuffle代码: 

做好乐高积木之后,我们在forward函数中开始搭建这些乐高积木,如下所示: 

2.2 networks.py (ShuffleNetV1网络架构)

ShuffleNetV1网络架构: 

 ShuffleNetV1 网络架构代码:

import torch
import torch.nn as nn
from blocks import ShuffleV1Blockclass ShuffleNetV1(nn.Module):def __init__(self, input_size=224, n_class=1000, model_size='2.0x', group=None):super(ShuffleNetV1, self).__init__()print('model size is ', model_size)assert group is not Noneself.stage_repeats = [4, 8, 4]self.model_size = model_sizeif group == 3:if model_size == '0.5x':self.stage_out_channels = [-1, 12, 120, 240, 480]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 240, 480, 960]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 360, 720, 1440]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 480, 960, 1920]else:raise NotImplementedErrorelif group == 8:if model_size == '0.5x':self.stage_out_channels = [-1, 16, 192, 384, 768]elif model_size == '1.0x':self.stage_out_channels = [-1, 24, 384, 768, 1536]elif model_size == '1.5x':self.stage_out_channels = [-1, 24, 576, 1152, 2304]elif model_size == '2.0x':self.stage_out_channels = [-1, 48, 768, 1536, 3072]else:raise NotImplementedError# building first layerinput_channel = self.stage_out_channels[1]self.first_conv = nn.Sequential(nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),nn.BatchNorm2d(input_channel),nn.ReLU(inplace=True),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.features = []for idxstage in range(len(self.stage_repeats)):numrepeat = self.stage_repeats[idxstage]output_channel = self.stage_out_channels[idxstage+2]for i in range(numrepeat):stride = 2 if i == 0 else 1first_group = idxstage == 0 and i == 0self.features.append(ShuffleV1Block(input_channel, output_channel,group=group, first_group=first_group,mid_channels=output_channel // 4, ksize=3, stride=stride))input_channel = output_channelself.features = nn.Sequential(*self.features)self.globalpool = nn.AvgPool2d(7)self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class, bias=False))self._initialize_weights()def forward(self, x):x = self.first_conv(x)x = self.maxpool(x)x = self.features(x)x = self.globalpool(x)x = x.contiguous().view(-1, self.stage_out_channels[-1])x = self.classifier(x)return xdef _initialize_weights(self):for name, m in self.named_modules():if isinstance(m, nn.Conv2d):if 'first' in name:nn.init.normal_(m.weight, 0, 0.01)else:nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)if m.bias is not None:nn.init.constant_(m.bias, 0.0001)nn.init.constant_(m.running_mean, 0)elif isinstance(m, nn.BatchNorm1d):nn.init.constant_(m.weight, 1)if m.bias is not None:nn.init.constant_(m.bias, 0.0001)nn.init.constant_(m.running_mean, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)if __name__ == "__main__":model = ShuffleNetV1(group=3)# print(model)test_data = torch.rand(5, 3, 224, 224)test_outputs = model(test_data)print(test_outputs.size())

分析: 

 

asset函数:
张量的连续性:https://blog.csdn.net/m0_48241022/article/details/132804698 
如何理解张量、张量索引等:https://blog.csdn.net/m0_48241022/article/details/132729561
torch.nn.Conv2d函数:
torch.nn.BatchNorm2d函数:
torch.nn.ReLU函数:
torch.nn.AvgPool2d函数:
torch.nn.Linear函数:
torch.nn.Sequential函数:
torch.cat函数:
permute函数:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430706.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python爬虫:从12306网站获取火车站信息

代码逻辑 初始化 (init 方法): 设置请求头信息。设置车站版本号。 同步车站信息 (synchronization 方法): 发送GET请求获取车站信息。返回服务器响应的文本。 提取信息 (extract 方法): 从服务器响应中提取车站信息字符串。去掉字符串末尾的…

UML——统一建模语言

序言: 是统一建模语言的简称,它是一种由一整套图表组成的标准化建模语言。UML用于帮助系统开发人员阐明,展示,构建和记录软件系统的产出。UML代表了一系列在大型而复杂系统建模中被证明是成功的做法,是开发面向对象软件…

【计算机基础】用bat命令将Unity导出PC包转成单个exe可执行文件

Unity打包成exe可执行文件 上边连接是很久以前用过的方法,发现操作有些不一样了,并且如果按上述操作比较麻烦,所以写了个bat命令。 图1、导出的pc程序 如图1是导出的pc程序,点击exe文件可运行该程序。 添加pack_project.bat文件 …

el-form中三级动态添加数据

el-form中三级动态添加数据 data数据view按钮触发事件 data数据 submitForm: {id: undefined, //修改IDapp_id: undefined, //IP类型name: , //规则名称sort: undefined, //排序detail: [{keycode: 0,title_one: undefined, //一级标题desc_detail: [{keycode: 0,title_two: u…

LPDDR4芯片学习(一)——基础知识与引脚定义

一、基础知识 01 dram基本存储单元 当需要将一位数据存储到DRAM中时,晶体管会充电或放电电容。充电的电容表示逻辑高(1),放电的电容表示逻辑低(0)。由于电容会随着时间泄漏电荷,因此需要定期刷…

Axure大屏可视化模板:跨领域数据分析平台原型案例

随着信息技术的飞速发展,数据可视化已成为各行各业提升管理效率、优化决策过程的重要手段。Axure作为一款强大的原型设计工具,其大屏可视化模板在农业、园区、城市、企业数据可视化、医疗等多个领域得到了广泛应用。本文将通过几个具体案例,展…

PyTorch使用------自动微分模块

目录 🍔 梯度基本计算 1.1 单标量梯度的计算 1.2 单向量梯度的计算 1.3 多标量梯度计算 1.4 多向量梯度计算 1.5 运行结果💯 🍔 控制梯度计算 2.1 控制不计算梯度 2.2 注意: 累计梯度 2.3 梯度下降优化最优解 2.4 运行结果&#x1…

mybatis 配置文件完成增删改查(五) :单条件 动态sql查询,相当于switch

文章目录 单条件 动态sql查询写测试方法 疑问总结 单条件 动态sql查询 <select id"selectByConditionBySingle" resultMap"brandResultMap">.select *from tb_brandwhere<choose>/*相当于switch*/<when test"status ! null">…

江协科技STM32学习- P17 TIM输入捕获

&#x1f680;write in front&#x1f680; &#x1f50e;大家好&#xff0c;我是黄桃罐头&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流 &#x1f381;欢迎各位→点赞&#x1f44d; 收藏⭐️ 留言&#x1f4dd;​…

if __name__ == ‘__main__‘: 在 Python 中的作用

Python Python 是一种广泛使用的高级编程语言&#xff0c;它以其易读性和简洁的语法而闻名。Python 支持多种编程范式&#xff0c;包括面向对象、命令式、函数式和过程式编程。它由 Guido van Rossum 创建&#xff0c;并在 1991 年首次发布。 Python 的一些关键特性包括&#…

Python中requests模块(爬虫)基本使用

Python的requests模块是一个非常流行的HTTP库&#xff0c;用于发送HTTP/1.1请求。 一、模块导入 1、requests模块的下载&#xff1a; 使用包管理器下载&#xff0c;在cmd窗口&#xff0c;或者在项目的虚拟环境目录下&#xff1a; pip3 install -i https://pypi.tuna.tsingh…

Chrome开发者工具如何才能看到Vue项目的源码

大家好&#xff0c;我是 程序员码递夫。 今天给大家分享的是 Chrome开发者工具如何才能看到Vue项目的源码。 问题 我们在编写一下Vue项目时&#xff0c;常常要通过 chrome 进行本地调试后&#xff0c;才打包 生产版本。 但有时打开 chrome 的开发者工具后&#xff0c;看到的…

什么是反射,反射用途,spring哪些地方用到了反射,我们项目中哪些地方用到了反射

3分钟搞懂Java反射 一、反射是什么 在Java中&#xff0c;反射&#xff08;Reflection&#xff09;是一种强大的工具&#xff0c;它允许程序在运行时获取和操作类、接口、构造器、方法和字段等。反射是Java语言的一个重要特性&#xff0c;它为开发人员提供了许多灵活性&#xf…

50页PPT麦肯锡精益运营转型五步法

读者朋友大家好&#xff0c;最近有会员朋友咨询晓雯&#xff0c;需要《 50页PPT麦肯锡精益运营转型五步法》资料&#xff0c;欢迎大家下载学习。 知识星球已上传的资料链接&#xff1a; 企业架构 企业架构 (EA) 设计咨询项目-企业架构治理(EAM)现状诊断 105页PPTHW企业架构设…

收据信息提取系统源码分享

收据信息提取检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

vue-baidu-map的基本使用

前言 公司项目需求引入百度地图&#xff0c;由于给的时间比较短&#xff0c;所以就用了已经封装好了的vue-baidu-map 一、vue-baidu-map是什么&#xff1f; vue-baidu-map是基于vue.js封装的百度地图组件(官方文档) 二、使用步骤 1.下载插件 //我下载的版本 npm install …

在虚幻引擎中实现Camera Shake 相机抖动/震屏效果

在虚幻引擎游戏中创建相机抖动有时能让画面更加高级 , 比如 遇到大型的Boss , 出现一些炫酷的特效 加一些短而快的 Camera Shake 能达到很好的效果 , 为玩家提供沉浸感 创建Camera Shake 调整Shake参数 到第三人称或第一人称蓝图 调用Camera Shake Radius值越大 晃动越强

vscode缩进 和自动格式化

如下图&#xff0c;缩进太大了。 检查2个地方 prettierrc.cjs文件。此处决定缩进几个tab vscode 的设置。 保存的时候 格式化。

数据结构——顺序表、链表

目录 前言 一&#xff0c;数据结构 1&#xff0c;什么是数据结构&#xff1f; 2&#xff0c;有什么类型&#xff1f; 二&#xff0c;顺序表 1&#xff0c;线性表 2&#xff0c;顺序表基本结构 3&#xff0c;动态顺序表的功能实现 三&#xff0c;链表 1&#xff0c;链…

AI大模型微调训练营,全面解析微调技术理论,掌握大模型微调核心技能

一、引言 随着人工智能技术的飞速发展&#xff0c;大型预训练模型&#xff08;如GPT、BERT、Transformer等&#xff09;已成为自然语言处理、图像识别等领域的核心工具。然而&#xff0c;这些大模型在直接应用于特定任务时&#xff0c;往往无法直接达到理想的性能。因此&#…