1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)

1、卷积神经网络的手写数字旋转角度预测原理及流程

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:

原理:

  1. 数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。

  2. 模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。

  3. 训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。

流程:

  1. 加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。

  2. 构建CNN模型:使用MATLAB深度学习工具箱中的函数(如convolution2dLayermaxPooling2dLayerfullyConnectedLayerclassificationLayer)构建一个适合手写数字旋转角度预测的CNN模型。

  3. 定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。

  4. 训练模型:使用训练数据集对CNN模型进行训练,通过调用trainNetwork函数并传入训练数据和训练选项来完成训练过程。

  5. 评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。

  6. 预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。

这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。

2、卷积神经网络的手写数字旋转角度预测案例说明

1)解决问题

卷积神经网络来预测手写数字的旋转角度

2)技术方案

回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。

3、加载数据

1)数据说明

数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。

2)加载数据代码

说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。

load DigitsDataTrain
load DigitsDataTest

3)显示训练集代码

numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);

 视图效果

06f6163700a5438b802835e39b5c0504.png

4)数据集划分代码

说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。

[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);

4、检查数据归一化

1)归一化说明

训练神经网络时,确保数据在网络的所有阶段均归一化

对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.

数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离

归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1

2)绘制响应的分布代码

说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。

figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")

视图效果 

b293b09f497e445ea96b6b4fd31045df.png

5、定义神经网络架构

1)神经网络架构说明

对于图像输入,指定一个图像输入层。

指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。

在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。

在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。

2)神经网络架构代码

numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];

6、指定训练选项

1)指定训练选项说明

使用Experiment Manager。

将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。

通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。

在图中显示训练进度并监控均方根误差。

2)指定训练选项代码

miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);

7、训练神经网络

1)训练神经网络说明

使用 trainnet 函数训练神经网络。

对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

2)训练神经网络代码

net = trainnet(XTrain,anglesTrain,layers,"mse",options);

视图效果

 eaa6e5e1078e4fd8bd4d9c5fe0cf06f0.png

8、测试网络

1)测试网络说明

基于测试数据评估准确度来测试网络性能。

使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。

2)测试网络代码

YTest = minibatchpredict(net,XTest);

3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异 

predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))

 4)散点图中可视化预测。绘制预测值对真实值的图。

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")

视图效果 

a35e3dd73632423a92451bfbd0ae7b66.png

9、使用新数据进行预测

1)测试说明

使用 predict 函数并使用神经网络对第一个测试图像进行预测

2)测试代码

X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)

10、总结

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:

总结要点:

  1. 数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。

  2. 模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。

  3. 训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。

实现流程:

  1. 数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。

  2. CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。

  3. 训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。

  5. 预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。

通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。

11、源代码

代码

%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。load DigitsDataTrain
load DigitsDataTest%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")%%  定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)

工程文件

https://download.csdn.net/download/XU157303764/89494539

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/377868.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数学建模·Topsis优劣解距离法

Topsis优劣解 一种新的评价方法,特点就是利用原有数据,客观性强。 相较于模糊评价和层次评价 更加客观,充分利用原有数据,精确反映方案差距 基本原理 离最优解最近,离最劣解越远 具体步骤 正向化 代码与原理与熵权…

Spring Boot中@Async注解的使用及原理 + 常见问题及解决方案

😄 19年之后由于某些原因断更了三年,23年重新扬帆起航,推出更多优质博文,希望大家多多支持~ 🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Mi…

阿里云GPU服务器安装ComfyUI

连接到GPU服务器: 使用SSH客户端(如PuTTY或终端)连接到你的服务器。命令通常是: ssh usernameserver_ip安装依赖: 确保Python和Git已安装。在大多数Linux系统上,可以这样安装: sudo apt update sudo apt install python3 python3-pip git克隆ComfyUI仓库: 这步骤会下载ComfyUI的…

Jetson-AGX-Orin 非docker环境源码编译安装CyberRT

Jetson-AGX-Orin 非docker环境源码编译安装CyberRT 1、安装依赖 sudo apt update sudo apt-get install g gdb gcc cmake sudo apt install libpoco-dev uuid-dev libncurses5-dev python3-dev python3-pip python3 -m pip install protobuf3.14.02、下载CyberRT源码 git cl…

C语言 ——— 大/小端存储模式的介绍及判断

目录 何为大端小端 如何测试当前机器是大端还是小端 编写代码,判断当前机器的字节序 何为大端小端 大端字节序存储模式:数据的低位字节的内容 存放在 内存的高地址 中,数据的高位字节的内容 保存在 内存的低地址 中 小端字节序存储模式&am…

【系统架构设计师】九、软件工程(面向对象方法|逆向工程)

目录 六、面向对象方法 6.1 基本概念 6.2 面向对象的分析 6.2.1 用例关系 6.2.2 类之间的关系 6.3 面向对象的设计 6.4 面向对象设计原则与设计模式 6.5 面向对象软件的测试 七、逆向工程 历年真题练习 六、面向对象方法 面向对象的分析方法 (Object-Oriented Analys…

Vue从零到实战

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 非常期待和您一起在这个小…

《后端程序员 · Nacos 常见配置 · 第一弹》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

showdoc sqli to rce漏洞利用思考

漏洞版本 sqli <3.2.5 phar 反序列化 <3.2.4 漏洞分析 前台sqli 补丁 https://github.com/star7th/showdoc/commit/84fc28d07c5dfc894f5fbc6e8c42efd13c976fda 补丁对比发现&#xff0c;在server/Application/Api/Controller/ItemController.class.php中将$item_id变量…

海外ASO:iOS与谷歌优化的相同点和区别

海外ASO是针对iOS的App Store和谷歌的Google Play这两个主要海外应用商店进行的优化过程&#xff0c;两个不同的平台需要采取不同的优化策略&#xff0c;以下是对iOS优化和谷歌优化的详细解析&#xff1a; 一、iOS优化&#xff08;App Store&#xff09; 1、关键词覆盖 选择关…

服务器数据恢复—raid5阵列热备盘同步失败导致lun不可用的数据恢复案例

服务器存储数据恢复环境&#xff1a; 华为S5300存储中有一组由16块FC硬盘组建的RAID5磁盘阵列&#xff08;包含一块热备盘&#xff09;。 服务器存储故障&#xff1a; 该存储中的RAID5阵列1块硬盘由于未知原因离线&#xff0c;热备盘上线并开始同步数据&#xff0c;数据同步到…

starRocks搭建

公司要使用新的大数据架构&#xff0c;打算用国产代替国外的大数据平台。所以这里我就纠结用doris还是starrocks&#xff0c;如果用doris&#xff0c;因为是开源的&#xff0c;以后就可以直接用云厂商的。如果用starrocks就得自己搭建&#xff0c;但是以后肯定会商业化&#xf…

【linux】服务器ubuntu安装cuda11.0、cuDNN教程,简单易懂,包教包会

【linux】服务器ubuntu安装cuda11.0、cuDNN教程&#xff0c;简单易懂&#xff0c;包教包会 【创作不易&#xff0c;求点赞关注收藏】 文章目录 【linux】服务器ubuntu安装cuda11.0、cuDNN教程&#xff0c;简单易懂&#xff0c;包教包会一、版本情况介绍二、安装cuda1、到官网…

最新PHP自助商城源码,彩虹商城源码

演示效果图 后台效果图 运行环境&#xff1a; Nginx 1.22.1 Mysql5.7 PHP7.4 直接访问域名即可安装 彩虹自助下单系统二次开发 拥有供货商系统 多余模板删除 保留一套商城,两套发卡 源码无后门隐患 已知存在的BUG修复 彩虹商城源码&#xff1a;下载 密码:chsc 免责声明&…

「AI得贤招聘官」通过首批“AI产业创新场景应用案例”评估

近日&#xff0c;上海近屿智能科技有限公司的「AI得贤招聘官」&#xff0c;经过工业和信息化部工业文化发展中心数字科技中心的严格评估&#xff0c;荣获首批“AI产业创新场景应用案例”。 据官方介绍&#xff0c;为积极推进通用人工智能产业高质量发展&#xff0c;围绕人工智能…

浅析Kafka Streams消息流式处理流程及原理

以下结合案例&#xff1a;统计消息中单词出现次数&#xff0c;来测试并说明kafka消息流式处理的执行流程 Maven依赖 <dependencies><dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-streams</artifactId><exclusio…

kafka发送消息流程

配置props.put(ProducerConfig.PARTITIONER_CLASS_CONFIG, RoundRobinPartitioner.class); public Map<String,Object> producerConfigs(){Map<String,Object> props new HashMap<>();props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG,bootstrapServers…

文件的顺序读写

文件读写函数介绍 文件顺序读写函数 函数名功能适用于fputc 字符输出函数 所有输出流 fgetc 字符输⼊函数 所有输⼊流 fputs ⽂本⾏输出函数 所有输出流fgets ⽂本⾏输⼊函数 所有输⼊流fprintf 格式化输出函数 所有输出流fscanf 格式化输⼊函数 所有输⼊流fwrite ⼆进制输出…

Spring源码中的模板方法模式

1. 什么是模板方法模式 模板方法模式&#xff08;Template Method Pattern&#xff09;是一种行为设计模式&#xff0c;它在操作中定义算法的框架&#xff0c;将一些步骤推迟到子类中。模板方法让子类在不改变算法结构的情况下重新定义算法的某些步骤。 模板方法模式的定义&…

被动的机器人非线性MPC控制

MPC是一种基于数学模型的控制策略&#xff0c;它通过预测系统在未来一段时间内的行为&#xff0c;并求解优化问题来确定当前的控制输入&#xff0c;以实现期望的控制目标。对于非线性系统&#xff0c;MPC可以采用非线性模型进行预测和优化&#xff0c;这种方法被称为非线性模型…