基于CNN+RNNs(LSTM, GRU)的红点位置检测(pytorch)

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

在这里插入图片描述
其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

在这里插入图片描述

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

在之前尝试过纯 RNNs 检测红点,但是准确率感人,在噪声极低的情况下并不能精准识别位置。但是有次尝试transformer位置编码之后发现效果不错:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
Position embedding + GRU16.34035025.0/9000 (56%)
Position embedding + LSTM204.15511603.0/9000 (18%)

这说明模型的难点在于学习位置信息而不是寻找颜色有问题的点。联想到CNN也能提供位置信息,我决定尝试卷积一下的效果。

2 数据集

还是之前那个代码合成的数据集数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:
在这里插入图片描述
加入噪音后,每个样本的预览如下图所示:

在这里插入图片描述

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:
在这里插入图片描述
另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

在这里插入图片描述

3 思路

其实思路特别朴素。就是在RNNs要读序列化数据之前先用CNN把数据跑一遍,让原始的输入序列变成具有局部特征表示的嵌入表示,卷积后提取的特征输入到 RNN层,RNN 保持了序列中的长时依赖信息。接下来先用 fc1 把 RNN 的输出映射成分数,然后用 fc2 预测三个具体位置,经过 Sigmoid 输出 [0, 1] 的相对位置,再与宽度相乘得到真实位置。具体的流程如下图所示:

在这里插入图片描述

4 结果

在图片长度为1080、低噪声环境时,对比实验的结果如下:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
CNN+GRU1419.5781601.0/9000 (7%)
CNN+LSTM1166.4599762.0/9000 (8%)

1080长度下图片抽样预测的效果如下:

在这里插入图片描述

在简单图片中的效果跟其他方法差距不大——基本都能准确定位红线,但是还是没办法做到像素级别的精确

在这里插入图片描述

可能是我的打开方式不对,但是CNN+RNN的效果并不如意。

从训练过程来看存在过拟合:

在这里插入图片描述

5 代码

CNN+GRU结构:


class CNN_GRU(nn.Module):def __init__(self, config):super(CNN_GRU, self).__init__()self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_layers = config.num_layersself.device = config.device# CNNself.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)self.gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,batch_first=True, bidirectional=True, dropout=0.6)self.fc1 = nn.Sequential(nn.Linear(self.hidden_size * 2, 1))self.fc2 = nn.Sequential(nn.Linear(config.width, 3),  # predict 3 pointsnn.Sigmoid(),)self.scale = config.widthself.device = config.devicedef forward(self, x):x = x.squeeze(2)x = F.relu(self.conv1(x))  # (batch_size, 64, width)x = F.relu(self.conv2(x))  # (batch_size, 128, width)x = F.relu(self.conv3(x))  # (batch_size, input_size, width)x = x.permute(0, 2, 1)h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)output, _ = self.gru(x0, h0)scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)predicted_positions = self.fc2(scores)scaled_predicted_positions = predicted_positions * self.scalefinal_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)return final_predicted_positions

CNN+LSTM结构:

class CNN_GRU(nn.Module):def __init__(self, config):super(CNN_GRU, self).__init__()self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_layers = config.num_layersself.device = config.device# CNNself.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,batch_first=True, bidirectional=True, dropout=0.6)self.fc1 = nn.Sequential(nn.Linear(self.hidden_size * 2, 1))self.fc2 = nn.Sequential(nn.Linear(config.width, 3),  # predict 3 pointsnn.Sigmoid(),)self.scale = config.widthself.device = config.devicedef forward(self, x):x = x.squeeze(2)x = F.relu(self.conv1(x))  # (batch_size, 64, width)x = F.relu(self.conv2(x))  # (batch_size, 128, width)x = F.relu(self.conv3(x))  # (batch_size, input_size, width)x = x.permute(0, 2, 1)h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)output, _ = self.lstm(x, (h0, c0))scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)predicted_positions = self.fc2(scores)scaled_predicted_positions = predicted_positions * self.scalefinal_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)return final_predicted_positions

路过的大佬有什么建议 ball ball 在评论区打出来,我会去尝试~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/477403.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Elastic 收集 Windows 遥测数据:ETW Filebeat 输入简介

作者:来自 Elastic Chema Martinez 在安全领域,能够使用 Windows 主机的系统遥测数据为监控、故障排除和保护 IT 环境开辟了新的可能性。意识到这一点,Elastic 推出了专注于 Windows 事件跟踪 (ETW) 的新功能 - 这是一种强大的 Windows 原生机…

leetcode刷题记录(四十二)——101. 对称二叉树

(一)问题描述 . - 力扣(LeetCode). - 备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/symmetric-tree/description/给你…

LeetCode 力扣 热题 100道(九)反转链表(C++)

给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 方法一:迭代法 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNod…

取电快充协议芯片,支持全协议、内部集成LDO支持从UART串口读取电压电流消息

H004D 是一款支持全协议的受电端诱骗取电协议芯片,支持宽电压输入 3.3V~30V,芯片内部集成LDO,可输出 3.3V电压, 支持 通过UART 串口读取电压电流,支持定制功能,芯片采用QFN_20封装,线路简单,芯片…

FreeRTOS——事件标志组

一、概念与应用 1.1概念 事件是实现任务与任务或任务与中断间 通信的机制,用于同步,无数据传输。(注意与二值信号量区分) 与信号量不同的是,事件可以实现一对多、多对多的同步,即一个任务可以等待多个事…

window11编译pycdc.exe

一、代码库和参考链接 在对python打包的exe文件进行反编译时,会使用到uncompyle6工具,但是这个工具只支持python3.8及以下,针对更高的版本的python则不能反编译。 关于反编译参考几个文章: Python3.9及以上Pyinstaller 反编译教…

【100ask】IMX6ULL开发板用SPI驱动RC522模块

目录 一、问题汇总: 1.无法寻卡 2.寻卡不稳定 二、修改设备树 三、驱动程序 四、测试程序 1.rc522_ap.c 2.rc522_app.h 3.rc522_test.c 4.Makefile 前言: CSDN上大部分对于RC522的文章都是正点的,虽然文章写的挺详细,两…

springboot购物推荐网站的设计与实现(代码+数据库+LW)

摘要 随着信息互联网购物的飞速发展,一般企业都去创建属于自己的电商平台以及购物管理系统。本文介绍了东大每日推购物推荐网站的开发全过程。通过分析企业对于东大每日推购物推荐网站的需求,创建了一个计算机管理东大每日推购物推荐网站的方案。文章介…

小R的二叉树探险 | 模拟

问题描述 在一个神奇的二叉树中,结构非常独特: 每层的节点值赋值方向是交替的,第一层从左到右,第二层从右到左,以此类推,且该二叉树有无穷多层。 小R对这个二叉树充满了好奇,她想知道&#xf…

高精度计算题目合集

高精度计算题目合集 1168:大整数加法 1168:大整数加法 1168:大整数加法 高精度加法原理: a,b,c 都可以用数组表示。这些都是基于c语言的算术运算符形成的运算。 c 3 ( c 1 c 2 ) % 10 c_3(c_1c_2)\%1…

【2024APMCM亚太赛A题】完整参考论文与代码分享

A题 一、问题重述二、问题分析问题一:水下图像分类问题二:退化原因建模问题三:针对单一退化的图像增强方法问题四:复杂场景的综合增强模型问题五:针对性增强与综合增强的比较 三、问题假设退化特征独立性假设物理模型普…

VMware虚拟机(Ubuntu或centOS)共享宿主机网络资源

VMware虚拟机(Ubuntu或centOS)共享宿主机网络资源 由于需要在 Linux 环境下进行一些测试工作,于是决定使用 VMware 虚拟化软件来安装 Ubuntu 24.04 .1操作系统。考虑到测试过程中需要访问 Github ,要使用Docker拉去镜像等外部网络资源,因此产…

C0030.Clion中运行提示Process finished with exit code -1073741515 (0xC0000135)解决办法

1.错误提示 2.解决办法 添加环境变量完成之后,重启Clion软件,然后就可以正常调用由mingw编译的opencv库了。

每日计划-1123

1. 完成 15. 三数之和 class Solution { public:vector<vector<int>> threeSum(vector<int>& nums) {sort(nums.begin(), nums.end());// 待返回的三元组vector<vector<int>> triples;for(int i 0; i < nums.size(); i){// 检测重复的 n…

汇编语言基础

目录 基本套路 头部&#xff1a; 段&#xff1a; 函数&#xff1a; 导入masm32库 输入输出 加法指令 常见数据类型 定义数据类型 数据传达指令&#xff08;mov&#xff09; 加减法 常用伪指令 间接寻址 JMP和LOOP 堆栈操作 定义函数(ret,call) 位运算 jcc(跳…

React (三)

文章目录 项目地址十二、性能优化12.1 使用useMemo避免不必要的计算12.2 使用memo缓存组件,防止过度渲染12.3 useCallBack缓存函数12.4 useCallBack里访问之前的状态(没懂)十三、Styled-Components13.1 安装13.2给普通html元素添加样式13.3 继承和覆盖样式13.4 给react组件添…

MD5算法的学习

MD5_百度百科 MD5信息摘要算法&#xff08;Message-Digest Algorithm&#xff09;,一种被广泛使用的密码散列函数&#xff0c;可以产生出一个128位的&#xff08;16字节&#xff09;的散列值&#xff08;hash value&#xff09;&#xff0c;用于确保信息传输完整一致。MD5由美…

【虚拟机】VMWare的CentOS虚拟机断电或强制关机出现问题

VMware 虚拟机因为笔记本突然断电故障了&#xff0c;开机提示“Entering emergency mode. Exit the shell to continue.”&#xff0c;如下图所示&#xff1a; 解决方法&#xff1a;输入命令&#xff1a; xfs_repair -v -L /dev/dm-0 注&#xff1a;报 no such file or direct…

【论文阅读】WGSR

0. 摘要 0.1. 问题提出 1.超分辨率(SR)是一个不适定逆问题&#xff0c;可行解众多。 2.超分辨率(SR)算法在可行解中寻找一个在保真度和感知质量之间取得平衡的“良好”解。 3.现有的方法重建高频细节时会产生伪影和幻觉&#xff0c;模型区分图像细节与伪影仍是难题。 0.2. …

解决 Android 单元测试 No tests found for given includes:

问题 报错&#xff1a; Execution failed for task :testDebugUnitTest. > No tests found for given includes: 解决方案 1、一开始以为是没有给测试类加public修饰 2、然后替换 Test 注解的包可以解决&#xff0c;将 org.junit.jupiter.api.Test 修改为 org.junit.Tes…