动手学深度学习(Pytorch版)代码实践 -卷积神经网络-29残差网络ResNet

29残差网络ResNet

在这里插入图片描述

import torch  
from torch import nn  
from torch.nn import functional as F 
import liliPytorch as lp  
import matplotlib.pyplot as plt# 定义一个继承自nn.Module的残差块类
class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()# 第一个卷积层,使用3x3的卷积核,填充为1,步幅为指定值self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)# 第二个卷积层,使用3x3的卷积核,填充为1self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)# 可选的1x1卷积层,用于匹配输入输出通道数和步幅if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)# 为什么需要两个不同的批量归一化层?# 1.不同的位置,不同的输入特征# 2.独立的参数和统计数据def forward(self, X):# 先通过第一个卷积层、批量归一化层和ReLU激活函数Y = F.relu(self.bn1(self.conv1(X)))# 然后通过第二个卷积层和批量归一化层Y = self.bn2(self.conv2(Y))# 如果定义了conv3,则通过conv3调整Xif self.conv3:X = self.conv3(X)# 将输入X加到输出Y上实现残差连接Y += X# 通过ReLU激活函数return F.relu(Y)# 创建一个包含输入和输出形状一致的残差块实例,并测试其输出形状
# blk = Residual(3, 3)
# X = torch.rand(4, 3, 6, 6)
# Y = blk(X)
# print(Y.shape)  # 预期输出形状:torch.Size([4, 3, 6, 6])# 创建一个包含1x1卷积和步幅为2的残差块实例,并测试其输出形状
# blk = Residual(3, 6, use_1x1conv=True, strides=2)
# print(blk(X).shape)  # 预期输出形状:torch.Size([4, 6, 3, 3])# 定义一个包含初始卷积层、批量归一化层、ReLU激活函数和最大池化层的顺序容器
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)# 定义一个函数,用于创建由多个残差块组成的模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):# 如果是第一个残差块且不是第一个模块,则使用1x1卷积和步幅为2if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 创建由残差块组成的各个模块
# *符号有多种用途,但在函数调用时,*符号主要用于将列表或元组解包。
# *resnet_block()的作用是将列表中的元素逐个传递给nn.Sequential
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 创建整个ResNet模型
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),  # 自适应平均池化层nn.Flatten(),  # 展平层nn.Linear(512, 176)  # 全连接层,输出10类
)# 测试整个网络的输出形状
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
# AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
# Flatten output shape:    torch.Size([1, 512])
# Linear output shape:     torch.Size([1, 10])# 设置训练参数
lr, num_epochs, batch_size = 0.05, 10, 256
# 加载训练和测试数据
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=96)
# 训练模型
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# 显示训练结果
plt.show()# loss 0.009, train acc 0.998, test acc 0.920
# 2306.3 examples/sec on cuda:0

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/364133.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ROS2创建自定义接口

ROS2提供了四种通信方式: 话题-Topics 服务-Services 动作-Action 参数-Parameters 查看系统自定义接口命令 使用ros2 interface package sensor_msgs命令可以查看某一个接口包下所有的接口 除了参数之外,话题、服务和动作(Action)都支持自定义接口&am…

微服务实战系列之云原生

前言 话说博主的微服务实战系列从去年走到今天,已过去了半年多了。本系列,博主主要围绕微服务实践过程中的主要组件或工具展开介绍。其中基本覆盖了我们项目或产品研发过程中,经常使用的中间件或第三方工具。至此,该系列也该朝着…

LangChain真的好用吗?谈一下LangChain封装FAISS的一些坑

最近在做一个知识库问答项目,就是现在大模型浪潮下比较火的 RAG 应用。LangChain 可以说是 RAG 最受欢迎的工具,因此我首选 LangChain 来快速构建我的应用。坦白来讲 LangChain 本身一套对于组件的定义已经让我感觉很复杂,为什么采用 f-strin…

SM2258XT量产工具,SM2258XT开卡三星SSV4颗粒成功分享,SM2259XT量产参考教程,威刚ADATA SP580开卡记录

前两天拆了笔记本上的威刚ADATA SP580 240GB,准备做移动硬盘用,装入移动硬盘盒之后接入电脑,发现系统可认盘,SMART显示正常,Windows的磁盘管理能显示正确容量,但处于未初始化状态,且始终无法初始…

gin数据解析,绑定和渲染

一. 数据解析和绑定 1.1 Json数据解析和绑定 html文件&#xff1a; <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0&quo…

数据脱敏学习

数据脱敏是一种保护敏感信息的方法&#xff0c;它通过修改或删除数据中的敏感部分&#xff0c;使得数据在保持一定可用性的同时&#xff0c;不再直接关联到个人隐私或重要信息。 自然人指可以直接或间接标识 直接标识&#xff1a;如姓名、身份证号码、家庭住址、电话号码、电…

Power BI可视化表格矩阵如何保持样式导出数据?

故事背景&#xff1a; 有朋友留言询问&#xff1a;自己从Power BI可视化矩阵表格中导出数据时&#xff0c;导出的表格样式会发生改变&#xff0c;需要线下再手动调整&#xff0c;重新进行透视组合成自己想要的格式。 有没有什么办法让表格导出来跟可视化一样&#xff1f; Po…

【proteus 51单片机入门】8*8led点阵

文章目录 前言如何点亮led点阵仿真图代码点亮led核心代码解析 爱心代码 滚动总结 前言 在嵌入式系统的开发中&#xff0c;LED点阵显示器是一种常见的显示设备&#xff0c;它可以用来显示各种图形和文字&#xff0c;为用户提供直观的信息反馈。本文将介绍如何使用Proteus软件和…

Element 页面滚动表头置顶

在开发后台管理系统时&#xff0c;表格是最常用的一个组件&#xff0c;为了看数据方便&#xff0c;时常需要固定表头。 如果页面基本只有一个表格区域&#xff0c;我们可以根据屏幕的高度动态的计算出一个值&#xff0c;给表格设定一个固定高度&#xff0c;这样表头就可以固定…

在 PMP 考试中,项目管理经验不足怎么办?

在项目管理的专业成长之路上&#xff0c;PMP认证如同一块里程碑&#xff0c;标志着从业者的专业水平达到了国际公认的标准。然而&#xff0c;对于那些项目管理经验尚浅的考生来说&#xff0c;这座里程碑似乎显得有些遥不可及。那么&#xff0c;在PMP考试准备中&#xff0c;项目…

冯雷老师:618大退货事件分析

近日冯雷老师受邀为某头部电商36名高管进行培训&#xff0c;其中聊到了今年618退货潮的问题。以下内容整理自冯雷老师的部分授课内容。 一、引言 随着电子商务的蓬勃发展&#xff0c;每年的618大促已成为消费者和商家共同关注的焦点。然而&#xff0c;在销售额不断攀升的同时…

DigiRL:让 AI 自己学会控制手机

类似于苹果此前发布的Ferret-UI 的安卓开源平替。主要用于在 Android 设备上识别 UI 和执行指令&#xff0c;不同的是它利用了离线到在线强化学习&#xff08;Offline-to-Online RL&#xff09;&#xff0c;能够快速适应应用更新或 UI 变化。

如何解决java程序CPU负载过高问题

1、介绍 在生产环境中&#xff0c;有时会遇到cpu占用过高且一直下不去的场景。这种情况可能会导致服务器宕机&#xff0c;进而中断对外服务&#xff0c;也会影响硬件寿命。 2、原因 1、Java代码存在因递归不当等原因导致的死循环的问题&#xff0c;推荐有条件的循环&#xf…

OpenAI禁止中国使用API,国内大模型市场何去何从

GPT-5 一年半后发布&#xff1f;对此你有何期待&#xff1f; 前言 前言&#xff1a; 近日&#xff0c;OpenAI宣布禁止中国用户使用其API&#xff0c;这一决策引起了国内大模型市场的广泛关注。面对这一挑战&#xff0c;国内大模型市场的发展路径和前景成为业界热议的焦点。本…

pytorch-01

加载mnist数据集 one-hot编码实现 import numpy as np import torch x_train np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集&#xff0c;并解压缩 y_train_label np.load("../dataset/mnist/y_train_label.npy") x torch.tensor(y…

【小程序静态页面】猜拳游戏大转盘积分游戏小程序前端模板源码

猜拳游戏大转盘积分游戏小程序前端模板源码&#xff0c; 一共五个静态页面&#xff0c;首页、任务列表、大转盘和猜拳等五个页面。 主要是通过做任务来获取积分&#xff0c;积分可以兑换商品&#xff0c;也可用来玩游戏&#xff1b;通过玩游戏既可能获取奖品或积分也可能会消…

一文速览Google的Gemma:从gemma1到gemma2(2代27B的能力接近llama3 70B)

前言 如此文《七月论文审稿GPT第3.2版和第3.5版&#xff1a;通过paper-review数据集分别微调Mistral、gemma》所讲 Google作为曾经的AI老大&#xff0c;我司自然紧密关注&#xff0c;所以当Google总算开源了一个gemma 7b&#xff0c;作为有技术追求、技术信仰的我司&#xff0…

maven安装jar和pom到本地仓库

举例子我们要将 elastic-job-spring-boot-starter安装到本地的maven仓库&#xff0c;如下&#xff1a; <dependency><groupId>com.github.yinjihuan</groupId><artifactId>elastic-job-spring-boot-starter</artifactId><version>1.0.5&l…

关于组织赴俄罗斯(莫斯科)第 28 届国际汽车零部件、汽车维修设备和商品展览会商务考察的通知

关于组织赴俄罗斯&#xff08;莫斯科&#xff09; 第 28 届国际汽车零部件、汽车维修设备和商品展览会商务考察的通知 展会名称&#xff1a;俄罗斯&#xff08;莫斯科&#xff09;第 28 届国际汽车零部件、汽车零部件、汽车维修设备和商品展览会 时间&#xff1a;2024 年 8 月…

day02-Spark集群及参数

一、Spark运行环境变量问题(了解) 1-pycharm远程开发运行时&#xff0c;执行的是服务器的代码 2-通过本地传递指令到远程服务器运行代码时&#xff0c;会加载对应环境变量数据&#xff0c;加载环境变量文件是用户目录下的.bashrc文件 在/etc/bashrc 1-1 在代码中添加 使用os模块…