计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用,本文我们借助PyTorch,快速构建和训练卷积神经网络(CNN)等模型,以实现街道房屋号码的准确识别。引入并注意力机制,它是一种模仿人类视觉注意机制的方法,在图像处理任务中具有广泛应用。通过引入注意力机制,模型可以自动关注图像中与房屋号码相关的区域,提高识别的准确性和鲁棒性。

一、项目介绍

街道房屋号码识别是计算机视觉中的一个重要任务,通过对街道房屋号码的自动识别,可以对街道图像进行更好的理解和分析。本文将介绍如何使用PyTorch框架和注意力机制,结合SVHN数据集,来实现街道房屋号码的分类识别。

二、SVHN数据集

SVHN(Street View House Numbers)是一个公开的大规模街道数字图像数据集。该数据集包含了从Google Street View中获取的房屋门牌号码图像,可以用于训练和测试机器学习模型,以实现自动识别街道房屋号码的任务。

2.1 数据集下载和加载

首先,我们需要下载并加载SVHN数据集。在PyTorch中,我们可以使用torchvision库中的datasets模块来实现这一步。

数据集的下载与查看:

train_dataset = datasets.SVHN(root='./data', split='train', download=True)images = train_dataset.data[:10]  # shape: (10, 3, 32, 32)
labels = train_dataset.labels[:10]images = np.transpose(images, (0, 2, 3, 1))# Plot the images
fig, axs = plt.subplots(2, 5, figsize=(12, 6))
axs = axs.ravel()for i in range(10):axs[i].imshow(images[i])axs[i].set_title(f"Label: {labels[i]}")axs[i].axis('off')plt.tight_layout()
plt.show()

在这里插入图片描述

数据集的加载,预处理,便于输入模型训练:

import torch
from torchvision import datasets, transforms# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 下载并加载SVHN数据集
trainset = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
testset = datasets.SVHN(root='./data', split='test', download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

三、卷积网络搭建

使用PyTorch搭建卷积神经网络。卷积神经网络(Convolutional Neural Network, CNN)是一种主要用于处理具有类似网格结构的数据的神经网络,如图像(2D网格的像素点)或者文本(1D网格的单词)。

3.1 网络结构定义

下面是一个基础的卷积神经网络模型,包含两个卷积层、两个最大池化层和两个全连接层。

from torch import nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.drop_out = nn.Dropout()self.fc1 = nn.Linear(7 * 7 * 64, 1000)self.fc2 = nn.Linear(1000, 10)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.drop_out(out)out = self.fc1(out)return self.fc2(out)

四、加入注意力机制

注意力机制是一种能够改进模型性能的技术。在我们的模型中,我们将添加一个注意力层来帮助模型更好地专注于输入图像中的重要部分。

4.1 注意力层定义

我将实现基本的注意力层,这个层将会生成一个和输入同样大小的注意力图,然后将输入和这个注意力图对应元素相乘,以此来实现对输入的加权。

注意力机制层的数学原理:
注意力机制的数学原理可以用以下公式表示:

给定输入张量 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w,其中 b b b 是批量大小, c c c 是通道数, h h h 是高度, w w w 是宽度。注意力机制分为两个阶段:特征提取和特征加权。
1.特征提取阶段:
首先,通过自适应平均池化层(AdaptiveAvgPool2d)将输入张量 x x x 在高度和宽度上进行平均池化,得到形状为 b × c × 1 × 1 b \times c \times 1 \times 1 b×c×1×1 的张量 y y y。这里使用自适应平均池化是为了使得张量 y y y 在不同尺寸的输入上也能产生相同的输出。
2.特征加权阶段:
接下来,通过全连接层(Linear)和非线性激活函数ReLU对张量 y y y 进行特征变换,减少通道数,并保留重要特征。然后再通过另一个全连接层和Sigmoid激活函数得到权重张量 y ′ ∈ R b × c × 1 × 1 y' \in \mathbb{R}^{b \times c \times 1 \times 1} yRb×c×1×1,表示每个通道的权重值。这里的权重值在0到1之间,用于控制每个通道在后续的计算中所占的比重。将权重张量 y ′ y' y 扩展成与输入张量 x x x 相同的形状,并将其与输入张量相乘,得到经过注意力加权的特征张量。这样就实现了对输入张量的自适应特征加权。

数学表示为:
y = AdaptiveAvgPool2d ( x ) y ′ = Sigmoid ( Linear ( ReLU ( Linear ( y ) ) ) ) output = x ⊙ y ′ y = \text{AdaptiveAvgPool2d}(x) \\ y' = \text{Sigmoid}(\text{Linear}(\text{ReLU}(\text{Linear}(y)))) \\ \text{output} = x \odot y' y=AdaptiveAvgPool2d(x)y=Sigmoid(Linear(ReLU(Linear(y))))output=xy

其中 ⊙ \odot 表示按元素相乘操作。

注意力机制层的搭建代码:

class AttentionLayer(nn.Module):def __init__(self, channel, reduction=16):super(AttentionLayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel// reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

4.2 在网络中加入注意力层

我们将注意力层加入到ConvNet模型中:

class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),AttentionLayer(32))self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),AttentionLayer(64))self.drop_out = nn.Dropout()self.fc1 = nn.Linear(8 * 8 * 64, 1000)self.fc2 = nn.Linear(1000, 10)def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = out.reshape(out.size(0), -1)out = self.drop_out(out)out = self.fc1(out)return self.fc2(out)

五、模型训练与测试

接下来,我们将进行模型的训练和测试。

5.1 模型训练

import torch.optim as optimmodel = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)for epoch in range(10):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 20 == 0:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

5.2 模型测试

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

六、结论

这篇文章就像是一张奇妙的地图,引领你进入计算机视觉任务的神奇世界。在这个世界里,你将与PyTorch和注意力机制这两位强大的伙伴结伴前行,共同探索街道房屋号码识别的奥秘。

想象一下,你置身于繁忙的街道上,满目琳琅的房屋号码挑战着你的视力。而你却拥有了一种神奇的眼力,能轻松识别出每一个号码。这种超凡能力正是计算机视觉任务的魔法所在。

我们要携手PyTorch这位强大的工具,它如同一把巧妙的魔法棒,能帮助我们构建强大的神经网络模型。通过PyTorch,我们可以灵活地定义模型的结构,设置各种参数,并进行高效的训练和推理。

我们遇到了注意力机制,就像是一盏明亮的灯塔,照亮了我们前进的方向。注意力机制能够使神经网络集中注意力于图像中的重要区域,从而提高识别的准确性。利用这种机制,我们可以让模型更加聪明地注重街道房屋号码所在的位置和细节,从而更好地进行识别。而SVHN数据集则是我们探险的指南,其中包含了大量真实世界中的街道房屋号码图像。通过导入这些数据,我们可以让模型从中学习并提高自己的识别能力。这些图像将带领我们穿越城市的角落,感受不同场景下的挑战和变化。通过这篇文章,我们不仅可以更深入地理解计算机视觉任务的本质,还能获得启发。就像是一次奇妙的冒险,我们将学会如何使用PyTorch和注意力机制来实现街道房屋号码的识别任务。让我们一起跟随这个引人入胜的旅程,开拓视野,追寻新的可能性吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/97328.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Google开源了可视化编程框架Visual Blocks for ML

Visual Blocks for ML是一个由Google开发的开源可视化编程框架。它使你能够在易于使用的无代码图形编辑器中创建ML管道。 为了运行Visual Blocks for ML。需要确保你的GPU是可以工作的。剩下的就是clone代码,然后运行,下面我们做一个简单的介绍&#xf…

FifthOne:计算机视觉提示和技巧

一、说明 欢迎来到我们每周的FiftyOne提示和技巧博客,我们回顾了最近在Slack,GitHub,Stack Overflow和Reddit上弹出的问题和答案。FiftyOne是一个开源机器学习工具集,使数据科学团队能够通过帮助他们策划高质量数据集、评估模型、…

图像处理常见的两种拉流方式

传统算法或者深度学习在进行图像处理之前,总是会首先进行图像的采集,也就是所谓的拉流。解决拉流的方式有两种,一个是直接使用opencv进行取流,另一个是使用ffmpeg进行取流,如下分别介绍这两种方式进行拉流处理。 1、o…

webSocket 聊天室 node.js 版

全局安装vue脚手架 npm install vue/cli -g 创建 vue3 ts 脚手架 vue create vue3-chatroom 后端代码 src 同级目录下建 server: const express require(express); const app express(); const http require(http); const server http.createServer(app);const io req…

云原生反模式

通过了解这些反模式并遵循云原生最佳实践,您可以设计、构建和运营更加强大、可扩展和成本效益高的云原生应用程序。 1.单体架构:在云上运行一个大而紧密耦合的应用程序,妨碍了可扩展性和敏捷性。2.忽略成本优化:云服务可能昂贵&am…

大数据课程K2——Spark的RDD弹性分布式数据集

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 了解Spark的RDD结构; ⚪ 掌握Spark的RDD操作方法; ⚪ 掌握Spark的RDD常用变换方法、常用执行方法; 一、Spark最核心的数据结构——RDD弹性分布式数据集 1. 概述 初学Spark时,把RDD看…

easyx图形库基础:3实现弹球小游戏

实现弹球小游戏 一.实现弹球小游戏:1.初始化布:2.初始化一个球的信息:3.球的移动和碰撞反弹4.底边挡板的绘制和移动碰撞重置数据。 二.整体代码: 一.实现弹球小游戏: 1.初始化布: int main() {initgraph(800, 600);setorigin(40…

麻辣烫数据可视化,麻辣烫市场将持续蓬勃发展

麻辣烫,这道源自中国的美食,早已成为人们生活中不可或缺的一部分。它独特的香辣口味,让人忍不住每每流连忘返。与人们的关系,简直如同挚友一般。每当寒冷的冬日或疲惫的时刻,麻辣烫总是悄然走进人们的心房,…

FreeCAD的傻瓜式初级使用教程

起因:自己想DIY一套线性手刹和序列档,以便和我之前的freejoy控制器相连接应用,需要自己制图和在某宝找代加工的商家,但我又不想安装体积巨大的AutoCAD,所以找了以下开源、免费的解决方案,所以就有了这篇文章…

使用PostgreSQL构建强大的Web应用程序:最佳实践和建议

PostgreSQL是一个功能强大的开源关系型数据库,它拥有广泛的用户群和活跃的开发社区。越来越多的Web应用选择PostgreSQL作为数据库 backend。如何充分利用PostgreSQL的特性来构建健壮、高性能的Web应用?本文将给出一些最佳实践和建议。 一、选择合适的PostgreSQL数据类型 Pos…

C# WPF 中 外部图标引入iconfont,无法正常显示问题 【小白记录】

wpf iconfont 外部图标引入&#xff0c;无法正常显示问题。 1. 检查资源路径和引入格式是否正确2. 检查资源是否包含在程序集中 1. 检查资源路径和引入格式是否正确 正确的格式&#xff0c;注意字体文件 “xxxx.ttf” 应写为 “#xxxx” <TextBlock Text"&#xe7ae;…

类之间的比较

作者简介&#xff1a; zoro-1&#xff0c;目前大一&#xff0c;正在学习Java&#xff0c;数据结构等 作者主页&#xff1a; zoro-1的主页 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01;&#x1f496;&#x1f496; 类之间的比较 固定需求式比较器 固定需求式 通过…

【C语言】字符分类函数、字符转换函数、内存函数

前言 之前我们用两篇文章介绍了strlen、strcpy、stract、strcmp、strncpy、strncat、strncmp、strstr、strtok、streeror这些函数 第一篇文章strlen、strcpy、stract 第二篇文章strcmp、strncpy、strncat、strncmp 第三篇文章strstr、strtok、streeror 今天我们就来学习字…

ES 概念

es 概念 Elasticsearch是分布式实时搜索、实时分析、实时存储引擎&#xff0c;简称&#xff08;ES&#xff09;成立于2012年&#xff0c;是一家来自荷兰的、开源的大数据搜索、分析服务提供商&#xff0c;为企业提供实时搜索、数据分析服务&#xff0c;支持PB级的大数据。 -- …

HTML详解连载(8)

HTML详解连载&#xff08;8&#xff09; 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽浮动-产品区域布局场景 解决方法清除浮动方法一&#xff1a;额外标签发方法二&#xff1a;单伪元素法方法三&#xff1a;双伪元素法方法四&#xff1a;overflow浮动-总结…

GO学习之 数据库(mysql)

GO系列 1、GO学习之Hello World 2、GO学习之入门语法 3、GO学习之切片操作 4、GO学习之 Map 操作 5、GO学习之 结构体 操作 6、GO学习之 通道(Channel) 7、GO学习之 多线程(goroutine) 8、GO学习之 函数(Function) 9、GO学习之 接口(Interface) 10、GO学习之 网络通信(Net/Htt…

【C++】stack/queue/优先级队列的模拟实现

目录 1. stack/queue1.1 模拟实现 2. 优先级队列2.1 模拟实现2.2 仿函数 1. stack/queue stack文档说明 queue文档说明 stack和queue被称为容器适配器。 容器适配器是什么&#xff1f; 它是一种特殊的容器类型&#xff0c;通过封装已有的容器类型来提供特定功能的接口函数&a…

使用Nginx调用网关,然后网关调用其他微服务

问题前提&#xff1a;目前我的项目是已经搭建了网关根据访问路径路由到微服务&#xff0c;然后现在我使用了Nginx将静态资源都放在了Nginx中&#xff0c;然后我后端定义了一个接口访问一个html页面&#xff0c;但是html页面要用到静态资源&#xff0c;这个静态资源在我的后端是…

关于es中索引,倒排索引的理解

下面是我查询进行理解的东西 也就是说我们ES中的索引就相当于我们mysql中的数据库表&#xff0c;索引库就相当于我们的数据库&#xff0c;我们按照mapping规则会根据相应的字段&#xff08;index为true默认&#xff09;来创建倒排索引&#xff0c;这个倒排索引就相当于我们索引…

QT-Mysql数据库图形化接口

QT sql mysqloper.h qsqlrelationaltablemodelview.h /************************************************************************* 接口描述&#xff1a;Mysql数据库图形化接口 拟制&#xff1a; 接口版本&#xff1a;V1.0 时间&#xff1a;20230727 说明&#xff1a;支…