基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

文章目录

  • 数据集介绍:
  • resnet18模型代码
  • 加载数据集(Dataset与Dataloader)
  • 模型训练
  • 训练准确率及损失函数:
  • resnet18交通标志分类源码
  • yolov5检测与识别(交通标志)

本文共包含两部分,
第一部分是用resnet18对交通标志分类,仅仅只是交通标志分类
文末附有yolov5和resnet18结合的源码,yolov5复制检测交通标志位置,然后使用resnet18对交通标志进行分类。

数据集介绍:

本文使用的数据集共有6000多张,共包含58个类别。部分数据集如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

resnet18模型代码

使用pytorch自带的resnet18模型,代码如下:

from torchvision import models
import torch.nn as nn#加载resnet18模型
net=models.resnet18(weights=None)
#因为分类个数为58,所以需要修改模型最后一层全连接层
net.fc=nn.Linear(in_features=512, out_features=58, bias=True)
# print(net)

加载数据集(Dataset与Dataloader)

from torch.utils.data import Dataset,DataLoader
import numpy as np
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from PIL import Image
import os
from torchvision import transforms
import torch
import randoma=[]
class Mydata(Dataset):def __init__(self,lines,train=True):super(Mydata, self).__init__()self.lines=linesrandom.shuffle(self.lines)self.train=traindef __len__(self):return len(self.lines)def __getitem__(self, index):txts=self.lines[index].strip().split(';')src_path='pic/'+txts[0]w=int(txts[1])h=int(txts[2])x1=int(txts[3])y1=int(txts[4])x2=int(txts[5])y2=int(txts[6])new_x1=random.randint(0,x1)new_y1=random.randint(0,y1)new_x2=random.randint(x2,w-1)new_y2=random.randint(y2,h-1)lab=int(txts[7])# if lab in a:#     pass# else:a.append(lab)## a.sort()# print(len(a))# print(a)img = Image.open(src_path)img=np.array(img)[...,:3]img=img[new_y1:new_y2,new_x1:new_x2]#数据增强if self.train:img=self.get_random_data(img)else:img = cv2.resize(img, (128, 128))# cv2.imshow('img',img[...,::-1])# cv2.waitKey(0)#归一化img=(img/255.0).astype('float32')img=np.transpose(img,(2,0,1))img=torch.from_numpy(img)return img,labdef get_random_data(self,img):seq = iaa.Sequential([# iaa.Flipud(0.5),  # flip up and down (vertical)# iaa.Fliplr(0.5),  # flip left and right (horizontal)iaa.Multiply((0.8, 1.2)),  # change brightness, doesn't affect BBs(bounding boxes)iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值iaa.Crop(percent=(0, 0.2)),iaa.Affine(translate_px={"x": (0,15), "y": (0,15)},  # 平移scale=(0.8, 1.2),  # 尺度变换rotate=(-20, 20),mode='constant',cval=(125)),iaa.Resize(128)])img= seq(image=img)return img
if __name__ == '__main__':lines=open('data.txt','r').readlines()my=Mydata(lines=lines,train=True)myloader=DataLoader(dataset=my,batch_size=3,shuffle=False)for i,j in myloader:print(i.shape,j.shape)

模型训练

经过60个epoch训练后,模型准确率基本上达到百分百

from mymodel import net
from myDataset import Mydata
import random
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
import matplotlib.pylab as pltbatch_size=32
Epoch=60
lr=0.001lines=open('data.txt','r').readlines()
random.shuffle(lines)
val_lines=random.sample(lines,int(len(lines)*0.1))
train_lines=list(set(lines)-set(val_lines))train_data=Mydata(lines=train_lines)
val_data=Mydata(lines=val_lines,train=False)
train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False)num_train   = len(train_lines)
epoch_step  = num_train // batch_size
BCE_loss     = nn.CrossEntropyLoss()
optimizer  = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#获取学习率函数
def get_lr(optimizer):for param_group in optimizer.param_groups:return param_group['lr']
#计算准确率函数
def metric_func(pred,lab):_,index=torch.max(pred,dim=-1)acc=torch.where(index==lab,1.,0.).mean()return acc
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net=net.to(device)
#设置损失函数
loss_fun     = nn.CrossEntropyLoss()if __name__ == '__main__':T_acc=[]V_acc=[]T_loss=[]V_loss=[]# 设置迭代次数200次epoch_step = num_train // batch_sizefor epoch in range(1, Epoch + 1):net.train()total_loss = 0loss_sum = 0.0train_acc_sum=0.0with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:for step, (features, labels) in enumerate(train_loader, 1):features = features.to(device)labels = labels.to(device)batch_size = labels.size()[0]optimizer.zero_grad()predictions = net(features)loss = loss_fun(predictions, labels)loss.backward()optimizer.step()total_loss += losstrain_acc = metric_func(predictions, labels)train_acc_sum+=train_accpbar.set_postfix(**{'loss': total_loss.item() / (step),"acc":train_acc_sum.item()/(step),'lr': get_lr(optimizer)})pbar.update(1)T_acc.append(train_acc_sum.item()/(step))T_loss.append(total_loss.item() / (step))# 验证net.eval()val_acc_sum = 0val_loss_sum=0for val_step, (features, labels) in enumerate(val_loader, 1):with torch.no_grad():features = features.to(device)labels = labels.to(device)predictions = net(features)val_metric = metric_func(predictions, labels)loss=loss_fun(predictions,labels)val_acc_sum += val_metric.item()val_loss_sum+=loss.item()print('val_acc=%.4f' % (val_acc_sum / val_step))V_acc.append(round(val_acc_sum / val_step,2))V_loss.append(val_loss_sum/val_step)# 保存模型if (epoch) % 2 == 0:torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f_.pth' % (epoch, total_loss / (epoch_step + 1)))lr_scheduler.step()plt.figure()plt.plot(T_acc,'r')plt.plot(V_acc,'b')plt.title('Training and validation Acc')plt.xlabel("Epochs")plt.ylabel("Acc")plt.legend(["Train_acc", "Val_acc"])# plt.show()plt.savefig("ACC.png")plt.figure()plt.plot(T_loss, 'r')plt.plot(V_loss, 'b')plt.title('Training and validation loss')plt.xlabel("Epochs")plt.ylabel("loss")plt.legend(["Train_loss", "Val_loss"])plt.savefig("LOSS.png")plt.show()

训练准确率及损失函数:

准确率:

在这里插入图片描述
损失函数:
在这里插入图片描述

resnet18交通标志分类源码

(包含训练,预测代码,准确率,损失函结果图像,数据集等):
下载地址:

yolov5检测与识别(交通标志)

前面是使用resnet18网络对交通标志分类,只是单单的分类,无法从一张完整的全局图像中检测交通标志位置。对此,首先使用yolov5从全局图像中检测交通标志的位置,只是检测没有分类,然后再使用前面训练好的resnet18模型对交通标志分类。其效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/330941.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 279 —— 完全平方数

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 此图利用动态规划进行求解,首先,我们求出小于 n n n 的所有完全平方数,存放在数组 squareNums 中。 定义 dp[n] 为和为 n n n 的完全平方数的最小数量,那么有状态…

mysql中text,longtext,mediumtext区别

文章目录 一.概览二、字节限制不同三、I/O 不同四、行迁移不同 一.概览 在 MySQL 中,text、mediumtext 和 longtext 都是用来存储大量文本数据的数据类型。 TEXT:TEXT 数据类型可以用来存储最大长度为 65,535(2^16-1)个字符的文本数据。如果存储的数据…

【服务器】使用mobaXterm远程连接服务器

目录 1、安装mobaXterm2、使用mobaXterm3、程序后台保持运行状态 1、安装mobaXterm 下载地址:https://mobaxterm.mobatek.net/download.html 下载免费版 分为蓝色便携版(下载后可直接使用)和绿色安装版(需要正常安装后使用&…

【老王最佳实践-6】Spring 如何给静态变量注入值

有些时候,我们可能需要给静态变量注入 spring bean,尝试过使用 Autowired 给静态变量做注入的同学应该都能发现注入是失败的。 Autowired 给静态变量注入bean 失败的原因 spring 底层已经限制了,不能给静态属性注入值: 如果我…

Go语言(Golang)的开发框架

在Go语言(Golang)的开发中,有多种开发框架可供选择,它们各自具有不同的特点和优势。以下是一些流行的Go语言开发框架,选择Go语言的开发框架时,需要考虑项目需求、团队熟悉度、社区支持、框架性能和可维护性…

docker- 购建服务镜像并启动

文章目录 前言docker- 购建服务镜像并启动1. 前期准备2. 构建镜像3. 运行容器4. 验证 前言 如果您觉得有用的话,记得给博主点个赞,评论,收藏一键三连啊,写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差,实…

基于微信小程序的校园捐赠系统的设计与实现

校园捐赠系统是一种便捷的平台,为校园内的各种慈善活动提供支持和便利。通过该系统,学生、教职员工和校友可以方便地进行捐赠,并了解到相关的项目信息和捐助情况。本文将介绍一个基于Java后端和MySQL数据库的校园捐赠系统的设计与实现。 技术…

PGP软件安装文件加密解密签名实践记录

文章目录 环境说明PGP软件安装PGP软件汉化AB电脑新建密钥并互换密钥对称密钥并互换密钥 文件加密和解密A电脑加密B电脑解密 文件签名A电脑签名文件B电脑校验文件修改文件内容校验失败修改文件名称正常校验 环境说明 使用VM虚拟两个win11,进行操作演示 PGP软件安装 PGP软件下…

【Andoird开发】android获取蓝牙权限,搜索蓝牙设备MAC

<!-- Android 12以下才需要定位权限&#xff0c; Android 9以下官方建议申请ACCESS_COARSE_LOCATION --><uses-permission android:name"android.permission.ACCESS_COARSE_LOCATION" /><uses-permission android:name"android.permission.ACCES…

通过域名接口申请免费的ssl多域名证书

来此加密已顺利接入阿里云的域名接口&#xff0c;用户只需一键调用&#xff0c;便可轻松完成域名验证&#xff0c;从而更高效地申请证书。接下来&#xff0c;让我们详细解读一下整个操作过程。 来此加密官网 免费申请SSL证书 免费SSL多域名证书&#xff0c;泛域名证书。 首先&a…

【游戏引擎】Unity脚本基础 开启游戏开发之旅

持续更新。。。。。。。。。。。。。。。 【游戏引擎】Unity脚本基础 Unity脚本基础C#语言简介C#基础 Unity脚本基础创建和附加脚本MonoBehaviour生命周期生命周期方法 示例脚本 Unity特有的API常用Unity API 实践示例&#xff1a;制作一个简单的移动脚本步骤1&#xff1a;创建…

水泥超低排平台哪家好?

随着环保政策的加强和绿色发展理念的深入人心&#xff0c;水泥行业的超低排放改造已成为行业发展的新趋势。选择一个合适的水泥超低排平台对于确保改造效果和实现企业的可持续发展至关重要。朗观视觉小编将从多个角度出发&#xff0c;为您提供一份综合评估与选择攻略&#xff0…

Flask-SQLAlchemy的使用【二】

目录 一.查询 1.1查询语句的格式 1.2查询过滤器 1.3查询执行器 1.4具体例子 1.4.1查询有多少个用户 1.4.2查询第一个用户 1.4.3查询id为4的用户 1.4.4查询id为4title为4的记录 1.4.5查询id为4或者title为4的记录 1.4.6查询id为[1,3,5,7,9]的记录 1.4.7查询所有记录&a…

无人机助力光伏项目测绘建模

随着全球对可再生能源需求的不断增长&#xff0c;光伏项目作为其中的重要一环&#xff0c;其建设规模和速度都在不断提高。在这一背景下&#xff0c;如何高效、准确地完成光伏项目的测绘与建模工作&#xff0c;成为了行业发展的重要课题。近年来&#xff0c;无人机技术的快速发…

汇聚荣科技有限公司优点有哪些?

在当今快速发展的科技时代&#xff0c;企业之间的竞争愈发激烈。作为一家专注于科技创新与研发的公司&#xff0c;汇聚荣科技有限公司凭借其卓越的技术实力和创新能力&#xff0c;在业界树立了良好的口碑。那么&#xff0c;汇聚荣科技有限公司究竟有哪些优点呢?接下来&#xf…

WPF中MVVM架构学习笔记

MVVM架构是一种基于数据驱动的软件开发架构&#xff0c;它将数据模型&#xff08;Model&#xff09;、视图&#xff08;View&#xff09;和视图模型&#xff08;ViewModel&#xff09;三者进行分离&#xff0c;使得开发者可以更加专注于各自领域的开发。其中&#xff0c;Model负…

Add object from object library 从对象库中添加内置器件

Add object from object library 从对象库中添加内置器件 正文正文 对于 Lumerical,有些时候我们在使用中,可能需要从 Object library 中添加器件,通常我们的做法是手动添加。如下图所示,我们添加一个 Directional Coupler 到我们的工程文件中: 但是这种操作方式不够智能…

封装了一个iOS中间放大的collectionView layout

效果图如下所示 原理&#xff1a;就是首先确定一个放大和缩小系数和原大小对应的基准位置&#xff0c;然后根据距离每个布局属性到视图中心的距离和基准点到中心的距离的差距/基准点到中心的距离&#xff0c; 计算出每个布局属性的缩放系数 下面是代码 // // LBHorizontalCe…

基于AT89C52单片机的智能窗帘系统

点击链接获取Keil源码与Project Backups仿真图&#xff1a; https://download.csdn.net/download/qq_64505944/89276984?spm1001.2014.3001.5503 C 源码仿真图毕业设计实物制作步骤07 智能窗户控制系统学院&#xff08;部&#xff09;&#xff1a; 专 业&#xff1a; 班 级&…

springboot vue 开源 会员收银系统 (4) 门店模块开发

前言 完整版演示 前面我们对会员系统 springboot vue 开源 会员收银系统 (3) 会员管理的开发 实现了简单的会员添加 下面我们将从会员模块进行延伸 门店模块的开发 首先我们先分析一下常见门店的管理模式 常见的管理形式为总公司 - 区域管理&#xff08;若干个门店&#xff…