PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []if stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/91620.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【博客692】grafana如何解决step动态变化时可能出现range duration小于step

grafana如何解决step动态变化时可能出现range duration小于step 1、grafana中的step和resolution grafana中的 “step” grafana本身是没有提供step参数的,因为仪表盘根据查询数据区间以及仪表盘线条宽度等,对于不同查询,相同的step并不能…

win10下如何安装ffmpeg

安装ffmpeg之前先安装win10 绿色软件管理软件:scoop. Scoop的基本介绍 Scoop是一款适用于Windows平台的命令行软件(包)管理工具,这里是Github介绍页。简单来说,就是可以通过命令行工具(PowerShell、CMD等…

构建 LVS-DR 群集、配置nginx负载均衡。

目录 一、基于 CentOS 7 构建 LVS-DR 群集 1、准备四台虚拟机 2、配置负载调度器(192.168.2.130) 3、部署共享存储(192.168.2.133) 4、配置两个Web服务器(192.168.2.131、192.168.2.132) 测试集群 二…

广联达OA前台sql注入+后台文件上传漏洞复现分析

文章目录 前言资产特征前台sql注入后台文件上传解决办法 前言 最近看到广联达OA的前端sql注入和后端文件上传漏洞联动的poc 广联达科技股份有限公司以建设工程领域专业应用为核心基础支撑,提供一百余款基于“端云大数据”产品/服务,提供产业大数据、产业…

Qt读写Excel--QXlsx编译为静态库2

1、概述🥔 在使用QXlsx时由于源码文件比较多,如果直接加载进项目里面,会增加每次编译的时间; 直接将源码加载进项目工程中,会导致项目文件非常多,结构变得更加臃肿; 所以在本文中将会将QXlsx编译…

leetcode810. 黑板异或游戏(博弈论 - java)

黑板异或游戏 lc 810 - 黑板异或游戏题目描述博弈论 动态规划 lc 810 - 黑板异或游戏 难度 - 困难 原题链接 - 黑板异或游戏 题目描述 黑板上写着一个非负整数数组 nums[i] 。 Alice 和 Bob 轮流从黑板上擦掉一个数字,Alice 先手。如果擦除一个数字后,剩…

人工智能任务1-【NLP系列】句子嵌入的应用与多模型实现方式

大家好,我是微学AI,今天给大家介绍一下人工智能任务1-【NLP系列】句子嵌入的应用与多模型实现方式。句子嵌入是将句子映射到一个固定维度的向量表示形式,它在自然语言处理(NLP)中有着广泛的应用。通过将句子转化为向量…

Springboot + Vue ElementUI 实现MySQLPostgresql可视化

一、功能展示: 效果如图: DB连接配置维护: Schema功能:集成Screw生成文档,导出库的表结构,导出表结构和数据 表对象操作:翻页查询,查看创建SQL,生成代码 可以单个代码文…

微服务与Nacos概述-6

RBAC 模型 RBAC 基于角色的访问控制是实施面向企业安全策略的一种有效的访问控制方式。 基本思想是,对系统操作的各种权限不是直接授予具体的用户,而是在用户集合与权限集合之间建立一个角色集合。每一种角色对应一组相应的权限。一旦用户被分配了适当…

时序预测 | MATLAB实现EEMD-GRU、GRU集合经验模态分解结合门控循环单元时间序列预测对比

时序预测 | MATLAB实现EEMD-GRU、GRU集合经验模态分解结合门控循环单元时间序列预测对比 目录 时序预测 | MATLAB实现EEMD-GRU、GRU集合经验模态分解结合门控循环单元时间序列预测对比效果一览基本介绍模型搭建程序设计参考资料 效果一览 基本介绍 1.MATLAB实现EEMD-GRU、GRU时…

SpringBoot 的自动装配特性

1. Spring Boot 的自动装配特性 Spring Boot 的自动装配(Auto-Configuration)是一种特性,它允许您在应用程序中使用默认配置来自动配置 Spring Framework 的各种功能和组件,从而减少了繁琐的配置工作。通过自动装配,您…

操作系统—网络系统

什么是零拷贝 磁盘是计算机系统最慢的的硬件之一,所以有不少优化磁盘的方法,比如零拷贝、直接IO、异步IO等等,这些优化的目的是为了提高系统的吞吐量,另外操作系统内核中的磁盘高度缓存区,可以有效的减少磁盘的访问次…

matlab使用教程(15)—图论基础

1.有向图和无向图 1.1什么是图? 图是表示各种关系的节点和边的集合: • 节点 是与对象对应的顶点。 • 边 是对象之间的连接。 • 图的边有时会有权重 ,表示节点之间的每个连接的强度(或一些其他属性)。 这些定…

嵌入式Linux开发实操(八):UART串口开发

串口可以说是非常好用的一个接口,它同USB、CAN、I2C、SPI等接口一样,为SOC/MCU构建了丰富的接口功能。那么在嵌入式linux中又是如何搭建和使用UART接口的呢? 一、Console接口即ttyS0 ttyS0通常做为u-boot(bootloader的一种,像是Windows的BIOS),它需要一个交互界面,一般…

pytest 编写规范

一、pytest 编写规范 1、介绍 pytest是一个非常成熟的全功能的Python测试框架,主要特点有以下几点: 1、简单灵活,容易上手,文档丰富;2、支持参数化,可以细粒度地控制要测试的测试用例;3、能够…

问AI一个严肃的问题

chatgpt的问世再一次掀起了AI的浪潮,其实我一直在想,AI和人类的关系未来会怎样发展,我们未来会怎样和AI相处,AI真的会完全取代人类吗,带着这个问题,我问了下chatgpt,看一看它是怎么看待这个问题…

DNS WEB HTTP

DNS与域名 网络是基于 TCP/IP 协议进行通信和连接的。 每一台主机都有唯一的标识,用于区别在网络上成千上万个用户和计算机。即固定的IP地址(32位二进制数转换成为十进制数——点分十进制)。每一个与网络相连接的计算机和服务器都被指派一个…

应用程序运行报错:First section must be [net] or [network]:No such file or directory

应用程序报错环境: 在linux下,调用darknet训练的模型,报错:First section must be [net] or [network]:No such file or directory,并提示:"./src/utils.c:256: error: Assertion 0 failed." 如…

pytest自动化测试框架tep环境变量、fixtures、用例三者之间的关系

tep是一款测试工具,在pytest测试框架基础上集成了第三方包,提供项目脚手架,帮助以写Python代码方式,快速实现自动化项目落地。 在tep项目中,自动化测试用例都是放到tests目录下的,每个.py文件相互独立&…

TPAMI, 2023 | 用压缩隐逆向神经网络进行高精度稀疏雷达成像

CoIR: Compressive Implicit Radar | IEEE TPAMI, 2023 | 用压缩隐逆向神经网络进行高精度稀疏雷达成像 注1:本文系“无线感知论文速递”系列之一,致力于简洁清晰完整地介绍、解读无线感知领域最新的顶会/顶刊论文(包括但不限于Nature/Science及其子刊;MobiCom, Sigcom, MobiSy…