【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。

【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。

【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。


文章目录

  • 【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。
  • 1. 算法提出
  • 2. 概述
  • 3. 发展
  • 4. 应用
  • 5. 优缺点
  • 6. Python代码实现


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://arxiv.org/pdf/1512.03385

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
在这里插入图片描述

1. 算法提出

深度残差网络(DRN)最初由何凯明等人于2015年在论文“Deep Residual Learning for Image Recognition”中提出。该算法的核心思想是通过残差块(Residual Block)来解决深层神经网络训练中的退化问题

传统神经网络在层数增加时,随着网络变深,训练误差反而会上升,这种现象被称为梯度消失/爆炸问题DRN通过引入跳跃连接(Skip Connection),将前几层的输入直接传递到后几层,从而有效缓解了这个问题

2. 概述

DRN的核心结构是残差块。一个典型的残差块包含一个跳跃连接,将输入直接加到输出上,如下所示:

y = F ( x ) + x y=F(x)+x y=F(x)+x

其中, x x x是残差块的输入, F ( x ) F(x) F(x)是经过几层非线性变换后的输出。通过将输入 x x x直接添加到输出 F ( x ) F(x) F(x),残差网络实际上是在学习一个残差函数。这种结构使得网络能够更容易训练,并且即使网络层数增加,网络也不会出现退化现象。

残差网络的优点在于:

  • 更深的网络结构:传统前馈神经网络(Feedforward Neural Networks, FFNN)的层数通常在几层到几十层,而DRN可以扩展到上百层甚至更深(如ResNet-152)。
  • 稳定的训练过程:通过引入跳跃连接,梯度可以更好地传播,从而缓解了梯度消失问题。

3. 发展

自2015年提出以来,残差网络成为了许多深度学习模型的基础架构。随着研究的深入,残差网络的变种也被提出,例如:

  • ResNet:最早的残差网络版本,适用于图像分类等任务。
  • ResNeXt:将残差块中的卷积运算拆分为多个并行的路径,提高了模型的可扩展性。
  • DenseNet:一种变体,进一步增加了层之间的密集连接。

4. 应用

DRN被广泛应用于各种深度学习任务中,特别是在计算机视觉领域表现出色。典型的应用包括:

  • 图像分类:ResNet在ImageNet分类任务中取得了极好的效果,常用于图像分类任务。
  • 目标检测:许多目标检测模型(如Faster R-CNN)都基于残差网络作为主干结构。
  • 语义分割:在语义分割任务中,残差网络作为特征提取器也广泛使用。

5. 优缺点

优点:

  • 有效的深度学习:DRN能够有效训练非常深的网络(可达150层甚至更多),而不会出现明显的性能退化。
  • 跳跃连接:通过跳跃连接,DRN能够更好地传播梯度,解决梯度消失问题,从而加快训练速度。
  • 强大的表达能力:可以通过残差学习获得更高的模型表达能力,适用于复杂的学习任务。

缺点:

  • 计算复杂性高:随着网络深度的增加,计算资源需求显著增加,训练时间可能较长。
  • 模型可解释性差:深度模型的复杂性可能导致难以解释其内部机制和决策过程。
  • 需要大量数据:有效训练深度残差网络通常需要大量标注数据,以防止过拟合。

6. Python代码实现

以下是一个使用深度残差网络进行图像分类的示例,基于PyTorch框架:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义残差块
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 如果输入维度和输出维度不匹配,通过1x1卷积进行匹配self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = self.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)  # 跳跃连接out = self.relu(out)return out# 定义ResNet模型
class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=10):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):layers = []layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channelsfor _ in range(1, num_blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):out = self.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avg_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# 实例化ResNet18模型
def ResNet18():return ResNet(ResidualBlock, [2, 2, 2, 2])  # 定义ResNet18结构# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)# 定义设备、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
def train_model(num_epochs=5):for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 评估模型
def test_model():model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集准确率: {100 * correct / total:.2f}%')# 运行训练和测试
train_model(num_epochs=5)
test_model()

代码解释:

  • ResidualBlock:实现了残差块,其中包括卷积层、批量归一化(Batch Normalization)、ReLU激活函数和跳跃连接。通过跳跃连接,将输入直接加到输出中,以实现残差学习。
  • ResNet:定义了ResNet模型结构,包括多个残差块的堆叠。_make_layer方法用于构建每一层的残差块。
  • 数据预处理:使用transforms.Compose对CIFAR-10数据集进行转换,进行标准化处理。
  • 模型训练:在train_model函数中,模型通过多轮训练,不断优化损失函数。
  • 模型评估:在test_model函数中,模型评估在测试集上的性能,并输出准确率。

该代码实现了基于深度残差网络的图像分类任务,展示了DRN在实际应用中的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/438895.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python小示例——质地不均匀的硬币概率统计

在概率论和统计学中,随机事件的行为可以通过大量实验来研究。在日常生活中,我们经常用硬币进行抽样,比如抛硬币来决定某个结果。然而,当我们处理的是“质地不均匀”的硬币时,事情就变得复杂了。质地不均匀的硬币意味着…

Spring Boot中线程池使用

说明:在一些场景,如导入数据,批量插入数据库,使用常规方法,需要等待较长时间,而使用线程池可以提高效率。本文介绍如何在Spring Boot中使用线程池来批量插入数据。 搭建环境 首先,创建一个Spr…

每日学习一个数据结构-树

文章目录 树的相关概念一、树的定义二、树的基本术语三、树的分类四、特殊类型的树五、树的遍历六、树的应用场景 树的遍历一、前序遍历二、中序遍历三、后序遍历使用java代码实现遍历总结 树的相关概念 树是一种重要的非线性数据结构,在计算机科学中有着广泛的应用…

24-10-4-读书笔记(二十四)-《一个孤独漫步者的遐想》下([法] 让·雅克·卢梭 [译]陈阳)

文章目录 《一个孤独漫步者的遐想》下([法] 让雅克卢梭 [译]陈阳)目录阅读笔记记录总结 《一个孤独漫步者的遐想》下([法] 让雅克卢梭 [译]陈阳) 十月第四篇,这次应该能拿到流量券吧!《一个孤独漫步者的遐想…

A Learning-Based Approach to Static Program Slicing —— 论文笔记

A Learning-Based Approach to Static Program Slicing OOPLSA’2024 文章目录 A Learning-Based Approach to Static Program Slicing1. Abstract2. Motivation(1) 为什么需要能处理不完整代码(2) 现有方法局限性(3) 验证局限性: 初步实验研究实验设计何为不完整代码实验结果…

C#串口温度读取

背景:每天学点,坚持 要安装好虚拟串口和modbus poll,方便调试(相关资源在文末,也可以私信找我要) 传感器部分使用的是达林科技的DL11B-MC-D1,当时42软妹币买的(官网上面有这个传感…

网络编程(12)——完善粘包处理操作(id字段)

十二、day12 之前的粘包处理是基于消息头包含的消息体长度进行对应的切包操作,但并不完整。一般来说,消息头仅包含数据域的长度,但是如果要进行逻辑处理,就需要传递一个id字段表示要处理的消息id,当然可以不在包头传i…

Linux网络编程

文章目录 参考资料在前1. 前置知识2. 进程概述2.1 fork()函数2.2 守护进程 3. 浅谈printf()函数与write()函数3.1 printf()函数缓存问题3.2 write()函数思考 4. 网络编程剖析4.1 listen()监听套接字4.2 阻塞/非阻塞IO4.3 同步/异步IO4.4 TCP/IP设计4.4.1 三次握手4.4.2 四次挥手…

机器人的性能指标

1. 负荷能力 负荷能力负荷能力是指机器人在满足其他性能要求的情况下,能够承载的负荷重量。例如,一台机器人的最大负荷能力可能远大于它的额定负荷能力,但是达到最大负荷时,机器人的工作精度可能会降低,可能无法准确地沿着预定的轨迹运动,或者产生额外的偏差。机器人的负荷量与…

【重学 MySQL】四十一、子查询举例与分类

【重学 MySQL】四十一、子查询举例与分类 引入子查询在SELECT子句中引入子查询在FROM子句中引入子查询在WHERE子句中引入子查询注意事项 子查询分类标量子查询列子查询行子查询表子查询 子查询注意事项子查询的位置子查询的返回类型别名的使用性能考虑相关性错误处理逻辑清晰 总…

Flet介绍:平替PyQt的好用跨平台Python UI框架

随着Python在各个领域的广泛应用,特别是在数据科学和Web开发领域,对于一个简单易用且功能强大的用户界面(UI)开发工具的需求日益增长。传统的Python GUI库如Tkinter、PyQt虽然功能强大,但在易用性和现代感方面略显不足…

数据结构--二叉树的顺序实现(堆实现)

引言 在计算机科学中,二叉树是一种重要的数据结构,广泛应用于各种算法和程序设计中。本文将探讨二叉树的顺序实现,特别是堆的实现方式。 一、树 1.1树的概念与结构 树是⼀种⾮线性的数据结构,它是由 n(n>0) 个有限结点组成…

【HTML5】html5开篇基础(5)

1.❤️❤️前言~🥳🎉🎉🎉 Hello, Hello~ 亲爱的朋友们👋👋,这里是E绵绵呀✍️✍️。 如果你喜欢这篇文章,请别吝啬你的点赞❤️❤️和收藏📖📖。如果你对我的…

vue-live2d看板娘集成方案设计使用教程

文章目录 前言v1.1.x版本:vue集成看板娘(暂不使用,在v1.2.x已替换)集成看板娘实现看板娘拖拽效果方案资源备份存储 当前最新调研:2024.10.2开源方案1:OhMyLive2D(推荐)开源方案2&…

【设计模式】软件设计原则——接口隔离迪米特

接口隔离原则引出 接口隔离原则 定义:用多个专门的接口,不使用单一的总接口,客户端不应该依赖它不需要的接口; 一个类对另一个类的依赖,应该建立在最小接口上;如果有一个大接口,里面有很多方法,如果使用一个类实现该接口,所有的类都要实现,导致代码冗余;…

android 全面屏最底部栏沉浸式

Activity的onCreate方法中添加 this.getWindow().addFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_NAVIGATION); Android 系统 Bar 沉浸式完美兼容方案自 Android 5.0 版本,Android 带来了沉浸式系统 ba - 掘金 (juejin.cn)https://juejin.cn/post/7075578…

【HTTP(3)】(状态码,https)

【认识状态码】 状态码最重要的目的,就是反馈给浏览器:这次请求是否成功,若失败,则出现失败原因 常见状态码: 200:OK,表示成功 404:Not Found,浏览器访问的资源在服务器上没有找到 403:Forbidden,访问被…

【每天学个新注解】Day 15 Lombok注解简解(十四)—@UtilityClass、@Helper

UtilityClass 生成工具类的注解 将一个类通过注解变成一个工具类,并没有什么用,本来代码中的工具类数量就极为有限,并不能达到减少重复代码的目的 1、如何使用 加在需要委托将其变为工具类的普通类上。 2、代码示例 例: Uti…

基于Java,SpringBoot,Vue智慧校园健康驿站体检论坛请假管理系统

摘要 互联网发展至今,无论是其理论还是技术都已经成熟,而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播,搭配信息管理工具可以很好地为人们提供服务。针对信息管理混乱,出错率高,信息安全性差&#xf…

景区+商业,如何实现1+1>2?

景区商业,如何实现11>2? 近两年,随着旅游业的蓬勃发展,旅游热潮持续升温,游客的消费观念也在逐步升级。为了适应这一趋势,各大景区纷纷着手打造具有鲜明特色的文旅项目,希望能够吸引…