用deepseek学大模型08-cnn残差网络

残差网络

参考:https://blog.csdn.net/2301_80750681/article/details/142882802
在这里插入图片描述

以下是使用PyTorch实现的三层残差网络示例,包含三个残差块和完整的网络结构:

import torch
import torch.nn as nnclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampledef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet3Layer(nn.Module):def __init__(self, num_classes=10):super(ResNet3Layer, self).__init__()self.in_channels = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 三个残差块self.layer1 = self._make_layer(BasicBlock, 64, 1, stride=1)self.layer2 = self._make_layer(BasicBlock, 128, 1, stride=2)self.layer3 = self._make_layer(BasicBlock, 256, 1, stride=2)# 分类层self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(256 * BasicBlock.expansion, num_classes)def _make_layer(self, block, out_channels, blocks, stride=1):downsample = Noneif stride != 1 or self.in_channels != out_channels * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * block.expansion))layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channels * block.expansionfor _ in range(1, blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x# 示例用法
model = ResNet3Layer(num_classes=10)
print(model)

关键实现细节说明

  1. 残差块设计:每个BasicBlock包含两个3×3卷积层,通过downsample参数处理维度匹配问题
  2. 网络结构
    • 初始卷积层(7×7卷积)用于提取基础特征
    • 最大池化层进行初步下采样
    • 三个残差块分别实现64→128→256通道的特征提取
    • 全局平均池化替代全连接层减少参数量
  3. 维度匹配:通过1×1卷积调整shortcut连接的维度,保证残差相加的有效性
  4. 参数配置
    • 每个残差块的步长(stride)分别为1、2、2,实现特征图尺寸的逐步缩小
    • 使用Batch Normalization加速训练收敛

该网络适用于CIFAR-10等小尺寸图像分类任务,可通过调整num_classes参数适配不同数据集。实际训练时建议配合数据增强和正则化技术。

残差网络的数学推导核心在于其残差映射设计和梯度传播特性,主要包含以下关键点:

1. 残差前向传播公式

残差块的基本结构可表示为:
x l + 1 = x l + F ( x l , W l ) \mathbf{x}_{l+1} = \mathbf{x}_l + \mathcal{F}(\mathbf{x}_l, \mathbf{W}_l) xl+1=xl+F(xl,Wl)
其中:

  • x l \mathbf{x}_l xl:第 l l l层的输入
  • F \mathcal{F} F:残差函数(通常包含卷积、BN、激活等操作)
  • W l \mathbf{W}_l Wl:可学习参数

对于 L L L层深度网络,累积表达式为:
x L = x 0 + ∑ i = 0 L − 1 F ( x i , W i ) \mathbf{x}_L = \mathbf{x}_0 + \sum_{i=0}^{L-1} \mathcal{F}(\mathbf{x}_i, \mathbf{W}_i) xL=x0+i=0L1F(xi,Wi)
这表明深层特征可分解为浅层特征与残差之和

2. 反向传播梯度推导

通过链式法则计算梯度:
∂ L ∂ x l = ∂ L ∂ x L ⋅ ∏ i = l L − 1 ( 1 + ∂ F ( x i , W i ) ∂ x i ) \frac{\partial \mathcal{L}}{\partial \mathbf{x}_l} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}_L} \cdot \prod_{i=l}^{L-1} \left( 1 + \frac{\partial \mathcal{F}(\mathbf{x}_i, \mathbf{W}_i)}{\partial \mathbf{x}_i} \right) xlL=xLLi=lL1(1+xiF(xi,Wi))
其中:

  • 常数项1保证梯度直接传递(恒等映射路径)
  • 残差项 ∂ F ∂ x i \frac{\partial \mathcal{F}}{\partial \mathbf{x}_i} xiF通过权重层传播

3. 解决梯度问题的数学机制

当残差项趋近于0时:
∂ L ∂ x l ≈ ∂ L ∂ x L ⋅ 1 \frac{\partial \mathcal{L}}{\partial \mathbf{x}_l} \approx \frac{\partial \mathcal{L}}{\partial \mathbf{x}_L} \cdot 1 xlLxLL1
即使深层梯度 ∂ L ∂ x L \frac{\partial \mathcal{L}}{\partial \mathbf{x}_L} xLL较小,浅层仍能获得有效梯度更新,从根本上缓解梯度消失问题

4. 网络退化问题的解决

假设最优映射为 H ∗ ( x ) H^*(x) H(x),传统网络需直接拟合:
H ( x ) = H ∗ ( x ) H(x) = H^*(x) H(x)=H(x)
而残差网络改为拟合:
F ( x ) = H ∗ ( x ) − x \mathcal{F}(x) = H^*(x) - x F(x)=H(x)x
这使得当 F ( x ) = 0 \mathcal{F}(x)=0 F(x)=0时,网络退化为恒等映射,保证性能不劣化

5. 维度匹配的数学处理

当输入输出维度不匹配时,引入1×1卷积:
y = F ( x , W i ) + W s x \mathbf{y} = \mathcal{F}(\mathbf{x}, \mathbf{W}_i) + \mathbf{W}_s\mathbf{x} y=F(x,Wi)+Wsx
其中 W s \mathbf{W}_s Ws为线性变换矩阵,保证残差相加的维度一致性

通过上述数学设计,残差网络实现了:

  1. 梯度稳定传播(反向过程)
  2. 深层特征的有效累积(前向过程)
  3. 网络退化现象的根本性解决

残差网络(ResNet)相比普通直接卷积网络的核心优势体现在以下方面:

1. 解决梯度消失与网络退化问题

通过跳跃连接(Shortcut Connection)的残差结构,反向传播时梯度可绕过非线性层直接传递。数学上,第 l l l层的梯度为:
∂ L ∂ x l = ∂ L ∂ x L ⋅ ∏ i = l L − 1 ( 1 + ∂ F ( x i , W i ) ∂ x i ) \frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \cdot \prod_{i=l}^{L-1} \left( 1 + \frac{\partial F(x_i, W_i)}{\partial x_i} \right) xlL=xLLi=lL1(1+xiF(xi,Wi))
当残差项 ∂ F ∂ x i ≈ 0 \frac{\partial F}{\partial x_i} \approx 0 xiF0时,梯度 ∂ L ∂ x l ≈ ∂ L ∂ x L \frac{\partial \mathcal{L}}{\partial x_l} \approx \frac{\partial \mathcal{L}}{\partial x_L} xlLxLL,避免链式求导的指数衰减。

2. 优化目标简化

残差网络学习残差映射 F ( x ) = H ( x ) − x F(x) = H(x) - x F(x)=H(x)x,而非直接学习目标函数 H ( x ) H(x) H(x)。当最优映射接近恒等变换时,残差 F ( x ) → 0 F(x) \to 0 F(x)0比直接学习 H ( x ) → x H(x) \to x H(x)x更容易收敛。

3. 支持极深网络结构

普通CNN在超过20层时会出现性能退化(训练/测试误差同时上升),而ResNet通过残差块堆叠可构建超过1000层的网络,且准确率随深度增加持续提升(如ResNet-152在ImageNet上Top-5错误率仅3.57%)。

4. 参数效率与计算优化

  • 维度调整:使用1×1卷积调整通道数,参数量仅需 C i n × C o u t C_{in} \times C_{out} Cin×Cout,远少于3×3卷积的 9 C i n C o u t 9C_{in}C_{out} 9CinCout
  • 瓶颈结构:通过“1×1→3×3→1×1”的Bottleneck设计(如ResNet-50),在保持性能的同时减少计算量。

5. 实际性能优势

  • 分类任务:ResNet-50在ImageNet上的Top-1准确率达76.5%,比VGG-16提升约8%。
  • 训练效率:引入BN层后,ResNet训练速度比普通CNN快2-3倍,且收敛更稳定。

对比总结

特性普通CNNResNet
最大有效深度~20层>1000层
梯度传播稳定性易消失/爆炸通过跳跃连接稳定
训练误差随深度变化先降后升(退化)持续下降
参数量(同精度)较高更低(瓶颈结构)

这些设计使得ResNet成为计算机视觉任务的基础架构,广泛应用于图像分类、目标检测等领域。

以下是使用PyTorch实现的残差网络(ResNet)在MNIST手写数字识别中的示例:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoaderclass ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU())self.conv2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels))self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),nn.BatchNorm2d(out_channels))def forward(self, x):residual = self.shortcut(x)out = self.conv1(x)out = self.conv2(out)out += residualout = nn.ReLU()(out)return outclass ResNetMNIST(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU())self.res_blocks = nn.Sequential(ResidualBlock(64, 64),ResidualBlock(64, 128, stride=2),ResidualBlock(128, 256, stride=2))self.fc = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(256, 10))def forward(self, x):x = self.conv1(x)x = self.res_blocks(x)x = self.fc(x)return x# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 加载数据集
train_set = MNIST(root='./data', train=True, download=True, transform=transform)
test_set = MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)# 初始化模型和优化器
model = ResNetMNIST()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()# 训练循环
for epoch in range(10):model.train()for images, labels in train_loader:outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()# 测试准确率model.eval()correct = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()acc = 100 * correct / len(test_set)print(f'Epoch {epoch+1}, Test Accuracy: {acc:.2f}%')

关键实现细节说明

  1. 残差块设计:每个残差块包含两个3×3卷积层,通过shortcut连接处理维度变化
  2. 网络结构
    • 初始卷积层(3×3)提取基础特征
    • 三个残差块实现64→128→256通道的特征提取
    • 全局平均池化替代全连接层减少参数量
  3. 数据预处理
    • 标准化处理: μ = 0.1307 \mu=0.1307 μ=0.1307, σ = 0.3081 \sigma=0.3081 σ=0.3081
    • 输入维度:1×28×28(通道×高×宽)
  4. 训练配置
    • Adam优化器(学习率0.001)
    • 交叉熵损失函数
    • 批量大小128,训练10个epoch

该模型在MNIST测试集上通常能达到**99%+**的准确率。实际训练时可添加数据增强(随机旋转、平移)提升泛化能力,或使用学习率调度器优化收敛过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/20198.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AIGC(生成式AI)试用 21 -- Python调用deepseek API

1. 安装openai pip3 install openai########################## Collecting openaiUsing cached openai-1.61.1-py3-none-any.whl.metadata (27 kB) Collecting anyio<5,>3.5.0 (from openai)Using cached anyio-4.8.0-py3-none-any.whl.metadata (4.6 kB) Collecting d…

分享一款AI绘画图片展示和分享的小程序

&#x1f3a8;奇绘图册 【开源】一款帮AI绘画爱好者维护绘图作品的小程序 查看Demo 反馈 github 文章目录 前言一、奇绘图册是什么&#xff1f;二、项目全景三、预览体验3.1 截图示例3.2 在线体验 四、功能介绍4.1 小程序4.2 服务端 五、安装部署5.1 快速开始~~5.2 手动部…

node.js + html调用ChatGPTApi实现Ai网站demo(带源码)

文章目录 前言一、demo演示二、node.js 使用步骤1.引入库2.引入包 前端HTML调用接口和UI所有文件总结 前言 关注博主&#xff0c;学习每天一个小demo 今天是Ai对话网站 又到了每天一个小demo的时候咯&#xff0c;前面我写了多人实时对话demo、和视频转换demo&#xff0c;今天…

Java基础(其一)

1.八个基础数据类型&#xff1a; 整数型&#xff1a;int long short byte 浮点型&#xff1a;float double 字符型&#xff1a;char 布尔型&#xff1a;bool 1.1. byte 范围&#xff1a;-128 到 127&#xff08;8位&#xff0c;有符号&#xff09; 用途&#xff1a; 小范围…

【Linux AnolisOS】关于Docker的一系列问题。尤其是拉取东西时的网络问题,镜像源问题。

AnolisOS 8中使用Docker部署&#xff08;全&#xff09;_anolis安装docker-CSDN博客 从在虚拟机安装龙蜥到安装docker上面这篇文章写的很清晰了&#xff0c;我重点讲述我解决文章里面问题一些的方法。 问题1&#xff1a; docker: Get https://registry-1.docker.io/v2/: net/h…

Java:单例模式(Singleton Pattern)及实现方式

一、单例模式的概念 单例模式是一种创建型设计模式&#xff0c;确保一个类只有一个实例&#xff0c;并提供一个全局访问点来访问该实例&#xff0c;是 Java 中最简单的设计模式之一。该模式常用于需要全局唯一实例的场景&#xff0c;例如日志记录器、配置管理、线程池、数据库…

【Python项目】文本相似度计算系统

【Python项目】文本相似度计算系统 技术简介&#xff1a;采用Python技术、Django技术、MYSQL数据库等实现。 系统简介&#xff1a;本系统基于Django进行开发&#xff0c;包含前端和后端两个部分。前端基于Bootstrap框架进行开发&#xff0c;主要包括系统首页&#xff0c;文本分…

通过VSCode直接连接使用 GPT的编程助手

GPT的编程助手在VSC上可以直接使用 选择相应的版本都可以正常使用。每个月可以使用40条&#xff0c;超过限制要付费。 如下图对应的4o和claude3.5等模型都可以使用。VSC直接连接即可。 配置步骤如下&#xff1a; 安装VSCODE 直接&#xff0c;官网下载就行 https://code.vis…

神经网络剪枝技术的重大突破:sGLP-IB与sTLP-IB

神经网络剪枝技术的重大突破:sGLP-IB与sTLP-IB 在人工智能飞速发展的今天,深度学习技术已经成为推动计算机视觉、自然语言处理等领域的核心力量。然而,随着模型规模的不断膨胀,如何在有限的计算资源和存储条件下高效部署这些复杂的神经网络模型,成为了研究者们亟待解决的…

深度集成DeepSeek大模型:WebSocket流式聊天实现

目录 5分钟快速接入DeepSeek大模型&#xff1a;WebSocket实时聊天指南创建应用开发后端代码 (Python/Node.js)结语 5分钟快速接入DeepSeek大模型&#xff1a;WebSocket实时聊天指南 创建应用 访问DeepSeek官网 前往 DeepSeek官网。如果还没有账号&#xff0c;需要先注册一个。…

Javascript网页设计案例:通过PDF.js实现一款PDF阅读器,包括预览、页面旋转、页面切换、放大缩小、黑夜模式等功能

前言 目前功能包括&#xff1a; 切换到首页。切换到尾页。上一页。下一页。添加标签。标签管理页面旋转页面随意拖动双击后还原位置 其实按照自己的预期来说&#xff0c;有很多功能还没有开发完&#xff0c;配色也没有全都搞完&#xff0c;先发出来吧&#xff0c;后期有需要…

使用html css js 来实现一个服装行业的企业站源码-静态网站模板

最近在练习 前端基础&#xff0c;html css 和js 为了加强 代码的 熟悉程序&#xff0c;就使用 前端 写了一个个服装行业的企业站。把使用的技术 和 页面效果分享给大家。 应用场景 该制衣服装工厂官网前端静态网站模板主要用于前端练习和编程练习&#xff0c;适合初学者进行 HT…

Ubuntu24安装MongoDB(解压版)

目录 0.需求说明1.环境检查2.下载软件2.1.下载MongoDB服务端2.2.下载MongoDB连接工具(可略过)2.3.检查上传或下载的安装包 3.安装MongoDB3.1.编辑系统服务3.2.启动服务3.3.客户端连接验证3.3.1.创建管理员用户 4.远程访问4.1.开启远程访问4.2.开放防火墙 0.需求说明 问&#x…

打造一个有点好看的 uniapp 网络测速软件

大家好&#xff0c;我是一名前端小白。今天想和分享一个有点好看的网络测速 uniapp 组件的实现过程。这个组件不仅外观精美&#xff0c;而且具有完整的功能性&#xff0c;是一个非常适合学习和实践的案例。 设计理念 在开始coding之前&#xff0c;先聊聊设计理念。一个好的测…

ESP32 ESP-IDF TFT-LCD(ST7735 128x160)自定义组件驱动显示

ESP32 ESP-IDF TFT-LCD(ST7735 128x160)自定义组件驱动显示 &#x1f33f;驱动参考来源&#xff1a;https://blog.csdn.net/weixin_59250390/article/details/142691848&#x1f4cd;个人相关驱动内容文章&#xff1a;《ESP32 ESP-IDF TFT-LCD(ST7735 128x160) LVGL基本配置和使…

Redis的简单使用

1.Redis的安装Ubuntu安装Redis-CSDN博客 2.Redis在Spring Boot 3 下的使用 2.1 pom.xml <!-- Redis --> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-data-redis</artifac…

elabradio入门第四讲——位同步(符号同步)

位同步是数字通信系统中特有的一种同步技术&#xff0c;又称为码元同步。在数字通信系统中&#xff0c;任何消息都是一串信号码元序列&#xff0c;接收端为了恢复码元序列&#xff0c;则需要知道每个码元的起止时刻&#xff0c;以便对于解调后的信号进行抽样判决&#xff0c;这…

网络安全推荐的视频教程 网络安全系列

第一章 网络安全概述 1.2.1 网络安全概念P4 网络安全是指网络系统的硬件、软件及其系统中的数据受到保护&#xff0c;不因偶然的或恶意的原因而遭到破坏、更改、泄露&#xff0c;系统连续可靠正常地运行&#xff0c;网络服务不中断。 1.2.3 网络安全的种类P5 &#xff08;1…

工控网络安全介绍 工控网络安全知识题目

31.PDR模型与访问控制的主要区别(A) A、PDR把对象看作一个整体 B、PDR作为系统保护的第一道防线 C、PDR采用定性评估与定量评估相结合 D、PDR的关键因素是人 32.信息安全中PDR模型的关键因素是(A) A、人 B、技术 C、模型 D、客体 33.计算机网络最早出现在哪个年代(B) A、20世…

Golang学习笔记_33——桥接模式

Golang学习笔记_30——建造者模式 Golang学习笔记_31——原型模式 Golang学习笔记_32——适配器模式 文章目录 桥接模式详解一、桥接模式核心概念1. 定义2. 解决的问题3. 核心角色4. 类图 二、桥接模式的特点三、适用场景1. 多维度变化2. 跨平台开发3. 动态切换实现 四、与其他…