MATLAB深度学习(七)——ResNet残差网络

一、ResNet网络

        ResNet是深度残差网络的简称。其核心思想就是在,每两个网络层之间加入一个残差连接,缓解深层网络中的梯度消失问题

二、残差结构

        在多层神经网络模型里,设想一个包含诺干层自网络,子网络的函数用H(x)来表示,其中x是子网络的输入。残差学习是通过重新设定这个参数,让一个参数层表达一个残差函数                 ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        F(x)=H(x)-x

        因此这个子网络的输出y 就是 

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​   y=F(x)+x

        其中 +x 的操作,是通过一个相当于恒等映射的跳跃连接来完成的,它将残差块的输入直接与输出连接,这就是残差结构。按照上面的结构递推,根据前向传播,第i个残差块的输出就是第i+1个残差块的输入

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        x_{\ell+1}=F(x_\ell)+x_\ell

        根据递归公式,可以推导出

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        x_L=x_\ell+\sum_{i=l}^{L-1}F(x_i)

        这里的L表示任意后续残差块,i是靠前块,那么公式就说明了总会有信号能从浅层到深层。

        从反向传播来看,根据上面的公式对 Xl 进行求导,可以得到

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        \begin{aligned} \frac{\partial\mathcal{E}}{\partial x_\ell} & =\frac{\partial\mathcal{E}}{\partial x_L}\frac{\partial x_L}{\partial x_\ell} \\ & =\frac{\partial\mathcal{E}}{\partial x_L}\left(1+\frac{\partial}{\partial x_\ell}\sum_{i=l}^{L-1}F(x_i)\right) \\ & =\frac{\partial\mathcal{E}}{\partial x_L}+\frac{\partial\mathcal{E}}{\partial x_L}\frac{\partial}{\partial x_\ell}\sum_{i=l}^{L-1}F(x_i) \end{aligned}

        这里的 \mathcal{E} 是最小损失化函数。以上说明,浅层的梯度计算 \frac{\partial\mathcal{E}}{\partial x_{\ell}},总会直接加上上一个项 \frac{\partial\mathcal{E}}{\partial x_L}因为存在额外的一项,所以就想 F(xi)很小,总的梯度都不会消失。

三、基于ResNet识别实现步骤

        其主要步骤为1.加载图像数据,并将数据分为训练集合与验证集;2.加载MATLAB训练好的ResNet50;3.和Alexnet一样替换最后几层;4.按照网络配置调整图像数据;5.对网络进行训练

3.1 调整ResNet实现迁移学习

        针对ImageNet的数据任务,原本最后三层 FC SOFTMAX 输出层是针对1000个类别的物体进行识别,针对图像问题,继续调整这三层,首先冻结上面的层。将下面的三层进行替换。由于ResNet50需要输入的图像大小为224 * 224 *3.

unzip('MerchData.zip');
img_ds = imageDatastore('MerchData', ...'IncludeSubfolders',true, ...'LabelSource','foldernames');total_split = countEachLabel(img_ds); %返回一个包含每个标签和相应图像数量的表格num_images = length(img_ds.Labels); %返回的图像的个数,5个类型都有15张照片
perm =  randperm(num_images,10);  %随机取出10个
figure
for i = 1:9subplot(3,3,i)imshow(imread(img_ds.Files{perm(i)})); %从图像数据集 (img_ds) 中读取一张图像并显示它%可以用alexnet的方法
endtest_idx = randperm(num_images,9); %随机取5张做样本
img_ds_Test = subset(img_ds,test_idx);
train_idx = setdiff(1:length(img_ds.Files),test_idx);
img_ds_Train = subset(img_ds,train_idx);%% 步骤2:加载预训练好的网络% 加载ResNet50网络(注:该网络需要提前下载,当输入下面命令时按要求下载即可)
net = resnet50;%% 步骤3:对网络结构进行调整,替换最后几层% 获取网络图结构
LayerGraph = layerGraph(net);
clear net;% 确定训练数据中新冠图片标签类别数量:5类
numClasses = numel(categories(img_ds_Train.Labels));
disp(numClasses);% 保留ResNet50倒数第三层之前的网络,并替换后3层
% 倒数第三层的全连接层,这里修改为5类
newLearnableLayer = fullyConnectedLayer(numClasses,...'Name','new_fc',...'WeightLearnRateFactor',10,...
'BiasLearnRateFactor',10);
%numClasses 分类任务。通过设置 WeightLearnRateFactor 和 BiasLearnRateFactor
%来控制学习率的调整,使得这些层在训练过程中能更快速地学习
% 分别替换最后3层:fc1000、softmax和分类输出层
LayerGraph = replaceLayer(LayerGraph,'fc1000',newLearnableLayer);
%替换了原网络中的 fc1000 层(ResNet50 中的全连接层)为 new_fc 层
% (即刚才定义的新的全连接层)。replaceLayer 函数通过层的名字来替换图层newSoftmaxLayer = softmaxLayer('Name','new_softmax');
LayerGraph = replaceLayer(LayerGraph,'fc1000_softmax',newSoftmaxLayer);newClassLayer = classificationLayer('Name','new_classoutput');
LayerGraph = replaceLayer(LayerGraph,'ClassificationLayer_fc1000',newClassLayer);%% 步骤4:按照网络配置调整图像数据% 输入图像格式转换,这里调用了自定义函数preprocess
img_ds_Train.ReadFcn = @(filename)preprocess(filename);
img_ds_Test.ReadFcn  = @(filename)preprocess(filename);% 数据增强的参数
augmenter = imageDataAugmenter(...'RandRotation',[-5 5],...'RandXReflection',1,...'RandYReflection',1,...'RandXShear',[-0.05 0.05],...'RandYShear',[-0.05 0.05]);
% 将批量训练图像的大小调整为与输入层的大小相同
aug_img_ds_train = augmentedImageDatastore([224 224],img_ds_Train,'DataAugmentation',augmenter);
% 将批量测试图像的大小调整为与输入层的大小相同
aug_img_ds_test = augmentedImageDatastore([224 224],img_ds_Test);%% 步骤5:对网络进行训练% 对训练参数进行设置
options = trainingOptions('adam',...'MaxEpochs',10,...'MiniBatchSize',8,...'Shuffle','every-epoch',...'InitialLearnRate',1e-4,...'Verbose',false,...'Plots','training-progress',...'ExecutionEnvironment','cpu');% 用训练图像对网络进行训练
netTransfer = trainNetwork(aug_img_ds_train,LayerGraph,options);%% 步骤6:进行测试并查看结果% 对训练好的网络采用验证数据集进行验证
[YPred,scores] = classify(netTransfer,aug_img_ds_test);% 随机显示验证效果
idx = randperm(numel(img_ds_Test.Files),4);
figure
for i = 1:4subplot(2,2,i)I = readimage(img_ds_Test,idx(i));imshow(I)label = YPred(idx(i));title(string(label));
end%% 计算分类准确率
YValidation = img_ds_Test.Labels;
accuracy = mean(YPred == YValidation);%% 创建并显示混淆矩阵
figure
confusionchart(YValidation,YPred) 

实现效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/488324.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PHP】部署和发布PHP网站到IIS服务器

欢迎来到《小5讲堂》 这是《PHP》系列文章,每篇文章将以博主理解的角度展开讲解。 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 前言安装PHP稳定版本线程安全版解压使用 PHP配置配置文件扩展文件路径…

SSM 校园一卡通密钥管理系统 PF 于校园图书借阅管理的安全保障

摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装校园一卡通密钥管理系统软件来发挥其高效地信息处理的作用&a…

TCP 2

文章目录 Tcp状态三次握手四次挥手理解TIME WAIT状态 如上就是TCP连接管理部分 流量控制滑动窗口快重传 延迟应答原理 捎带应答总结TCP拥塞控制拥塞控制的策略 -- 每台识别主机拥塞的机器都要做 面向字节流和粘包问题tcp连接异常进程终止机器重启机器掉电/网线断开 Tcp状态 建…

【操作系统】实验二:观察Linux,使用proc文件系统

实验二 观察Linux,使用proc文件系统 实验目的:学习Linux内核、进程、存储和其他资源的一些重要特征。读/proc/stat文件,计算并显示系统CPU占用率和用户态CPU占用率。(编写一个程序使用/proc机制获得以及修改机器的各种资源参数。…

【密码学】AES算法

一、AES算法介绍: AES(Advanced Encryption Standard)算法是一种广泛使用的对称密钥加密,由美国国家标准与技术研究院(NIST)于2001年发布。 AES是一种分组密码,支持128位、192位和256位三种不同…

【学习笔记】目前市面中手持激光雷达设备及参数汇总

手持激光雷达设备介绍 手持激光雷达设备是一种利用激光时间飞行原理来测量物体距离并构建三维模型的便携式高科技产品。它通过发射激光束并分析反射回来的激光信号,能够精确地获取物体的三维结构信息。这种设备以其高精度、适应各种光照环境的能力和便携性&#xf…

探索 LeNet-5:卷积神经网络的先驱与手写数字识别传奇

一、引言 在当今深度学习技术蓬勃发展的时代,各种复杂而强大的神经网络架构不断涌现,如 ResNet、VGG、Transformer 等,它们在图像识别、自然语言处理、语音识别等众多领域都取得了令人瞩目的成果。然而,当我们回顾深度学习的发展历…

【数据结构——栈与队列】链栈的基本运算(头歌实践教学平台习题)【合集】

目录😋 任务描述 相关知识 测试说明 我的通关代码: 测试结果: 任务描述 本关任务:编写一个程序实现链栈的基本运算。 相关知识 为了完成本关任务,你需要掌握: 初始化栈、销毁栈、判断栈是否为空、进栈、出栈、取栈…

【笔记】架构上篇Day6 法则四:为什么要顺应技术的生命周期?

法则四:为什么要顺应技术的生命周期? 简介:包含模块一 架构师的六大生存法则-法则四:为什么要顺应技术的生命周期?&法则四:架构设计中怎么判断和利用技术趋势? 2024-08-29 17:30:07 你好&am…

Security自定义逻辑认证(极简案例)

项目结构 config SecurityConfig package com.wunaiieq.tmp2024121105.config;import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.crypto.password.NoOpPasswordEnco…

docker安装ddns-go(外网连接局域网)

docker先下载镜像,目前最新版是v6.7.6 也可以csdn资源下载 再导入dockers https://download.csdn.net/download/u014756339/90096748 docker load -i ddns-go.tar 启动 docker run -d --name ddns-go --restartalways --nethost -v /opt/ddns-go:/root jeessy/…

技术速递|dotnet scaffold – .NET 的下一代内容创建

作者:Sayed Ibrahim Hashimi - 首席项目经理 排版:Alan Wang Visual Studio 中为 ASP.NET Core 项目搭建脚手架是一项长期特性,是在 ASP.NET Core 发布后不久添加的。多年来,我们一直支持从命令行搭建脚手架。根据从命令行操作中获…

基于ZYNQ 7z010开发板 oled点亮的实现

dc拉高的时候就是发送128字节数据的时候 发送指令dc拉低 模式是00 sck先置低再置高 复位是与开发板上的按键一样都是低有效 25位字节指令 加 3字节的 页地址加起始结束 b0,00,10, timescale 1ns / 1ps module top0(input wire clk ,input wire rst_n,// out…

使用torch模拟 BMM int8量化计算。

使用torch模型BMM int8计算。 模拟:BMM->softmax->BMM 计算流程 import torch import numpy as np torch.manual_seed(777) def int8_quantize_per_token(x: torch.Tensor, axis: int -1, attnsFalse):if x.dtype ! torch.float32:x x.type(torch.float32)…

【CSS in Depth 2 精译_070】11.3 利用 OKLCH 颜色值来处理 CSS 中的颜色问题(下):从页面其他颜色衍生出新颜色

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第四部分 视觉增强技术 ✔️【第 11 章 颜色与对比】 ✔️ 11.1 通过对比进行交流 11.1.1 模式的建立11.1.2 还原设计稿 11.2 颜色的定义 11.2.1 色域与色彩空间11.2.2 CSS 颜色表示法 11.2.2.1 RGB…

HTML:表格重点

用表格就用table caption为该表上部信息,用来说明表的作用 thead为表头主要信息,效果加粗 tbody为表格中的主体内容 tr是 table row 表格的行 td是table data th是table heading表格标题 ,一般表格第一行的数据都是table heading

15.Java 网络编程(网络相关概念、InetAddress、NetworkInterface、TCP 网络通信、UDP 网络通信、超时中断)

一、网络相关概念 1、网络通信 网络通信指两台设备之间通过网络实现数据传输,将数据通过网络从一台设备传输到另一台设备 java.net 包下提供了一系列的类和接口用于完成网络通信 2、网络 两台以上设备通过一定物理设备连接构成网络,根据网络的覆盖范…

项目中使用AntV L7地图(五)添加飞线

项目中使用AntV L7地图,添加 飞线 文档地址:https://l7.antv.antgroup.com/zh/examples/line/animate/#trip_animate 一、初始化地图 使用的地图文件为四川地图JSON,下载地址:https://datav.aliyun.com/portal/school/atlas/area_selector#&…

MySQL-DQL之数据表操作

文章目录 零. 准备工作一. 简单查询1.查询所有的商品.2.查询商品名和商品价格.3.查询结果是表达式(运算查询):将所有商品的价格10元进行显示. 二. 条件查询1. 比较查询2. 范围查询3. 逻辑查询4. 模糊查询5. 非空查询 三. 排序查询四. 聚合查询…

nacos bootstrap.yml 和 spring.config.import 加载配置的流程区别

相关依赖 springboot:2.7.15 nacos:2.2.3 bootstrap.yml加载方式 加载流程如下图所示 从图中可以看出,: 1.bootstrap.yml 的加载是在 BootstrapApplicationListener.onApplicationEvent 接收到 ApplicationEnvironmentPreparedEventEvent 事件后另起一个 Sprin…