本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——6Resnet实现黑线识别

本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——6Resnet实现黑线识别

​ 比赛还有重要部分就是黑线的识别,这块地平线社区的帖子很多

​ 在本次我就使用了社区吴超大佬写出的文章,当然我们的步骤有所不同,也是比较省事的一种

​ 现在和大家一起聊聊我们的应用

1.代码介绍

​ 社区上的代码主要这几个py文件,这里我们数据制作采用了自己的,只是简单的opencv

image-20240810195030960

2.准备工作

​ 准备工作就是创建好数据集

​ 这里建议直接运行超的代码

DATASET_NMAE = "DataSet_3_1119"from os import makedirs
path = "./" + DATASET_NMAE + "/"
try:makedirs(path + "train/image/")makedirs(path + "train/label/")makedirs(path + "test/image/")makedirs(path + "test/label/")print("Dirs Success")
except:print("Dirs Failed")

3. 录制视频

​ 首先第一步是,获取512*512大小的图片,这里给大家我们是使用Opencv 进行录制视频,进行使用代码裁剪

​ 对于小车打开摄像头选项是8,这里按下Ctrl C就可以自动保存

import cv2
import signal
import sysdef signal_handler(sig, frame):global stop_recordingstop_recording = Truedef record_video(output_file, width=640, height=480, fps=30):global stop_recordingstop_recording = False# 创建视频捕获对象fourcc = cv2.VideoWriter_fourcc(*'XVID')  # 使用XVID编码器out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))# 打开默认摄像头cap = cv2.VideoCapture(8)# 设置视频帧率cap.set(cv2.CAP_PROP_FPS, fps)# 获取摄像头的实际帧率(以确保设置成功)actual_fps = cap.get(cv2.CAP_PROP_FPS)print(f"实际帧率: {actual_fps}")# 注册信号处理函数以便捕获Ctrl+Csignal.signal(signal.SIGINT, signal_handler)# 开始录制while True:ret, frame = cap.read()if not ret:break# 写入当前帧到视频文件out.write(frame)# 检查是否收到停止信号if stop_recording:print("录制已结束")break# 按 'q' 键退出if cv2.waitKey(1) & 0xFF == ord('q'):break# 释放资源cap.release()out.release()cv2.destroyAllWindows()# 调用函数
record_video('output.avi', width=640, height=480, fps=30)

4. 01_get512img

​ 针对小车录制的视频,保存到本地,使用下面代码进行抽帧

​ 代码里已经给大家写好注释了

from datetime import datetimeimport cv2
import osdef extract_frames(video_path, output_folder, i,frame_size=(512, 512) ,skip_frames=3):# 确保输出文件夹存在if not os.path.exists(output_folder):os.makedirs(output_folder)# 打开视频文件cap = cv2.VideoCapture(video_path)if not cap.isOpened():print("Error: Could not open video.")returnframe_count = 0saved_frame_count = 0current_time = datetime.now().strftime("%Y%m%d_%H%M%S")  # 获取当前时间并格式化为字符串while True:# 读取视频中的一帧ret, frame = cap.read()if not ret:break  # 视频结束或读取错误# 每隔skip_frames帧保存一帧if frame_count % skip_frames == 0:# 调整帧的大小resized_frame = cv2.resize(frame, frame_size, interpolation=cv2.INTER_AREA)# 保存帧到文件,文件名前加上当前时间frame_filename = f"{output_folder}/{i}_frame_{saved_frame_count:04d}.png"cv2.imwrite(frame_filename, resized_frame)print(f"Saved {frame_filename}")saved_frame_count += 1frame_count += 1# 释放视频捕获对象cap.release()print("Done extracting frames.")# 使用示例
i = 1
while i<=5:video_path = rf'D:\cardata\{i}.avi'output_folder = f'./DataSet_3_1119/train/image'extract_frames(video_path, output_folder, i,frame_size=(512, 512), skip_frames=5)  # 每隔5帧保存一次# extract_frames(video_path, output_folder, i,frame_size=(672, 672), skip_frames=5)  # 每隔5帧保存一次i+=1

5. 进行标记

如果上面你按照我粘贴的代码,现在这里就可以无脑运行超哥的代码

import cv2
import osDATASET_NMAE = "DataSet_3_1119"  # 数据集名称
ZOOM = 1.0  # 显示缩放倍数,与标注数据无关,仅仅适应一些高分屏的电脑def get_xy(file_path):# 读取txt文件,获取x,y坐标(浮点数表示)x,y = -1.0,-1.0with open(file_path) as f:content = f.read().split(" ")x, y = float(content[0]), float(content[1])f.close()return x,ydef mouse_callback(event, x, y, flags, param):# 鼠标点击事件global img_x, img_y, label_path, txt_name,img,img_width, img_heightif event == cv2.EVENT_LBUTTONUP:print(img_width,img_height)img_x, img_y = float(x)/img_width/ZOOM, float(y)/img_height/ZOOMcv2.imshow("img", cv2.circle(img.copy(), (x, y), 10,(0,0,255), -1))print("Mouse Click(%d, %d), Save as(%.8f, %.8f)"%(x,y,img_x,img_y))with open(label_path + txt_name,"w") as f:f.write("%.8f %.8f"%(img_x, img_y))# 新建cv2的工作窗口,并绑定鼠标点击的回调函数
img_x, img_y = -1,-1
cv2.namedWindow('img')
cv2.setMouseCallback('img', mouse_callback)img_path = DATASET_NMAE + "/train/image/"
label_path = DATASET_NMAE + "/train/label/"
print("img path = %s"%img_path)
print("label path = %s"%label_path)img_names = os.listdir(img_path)# img size
img_width, img_height = 0, 0      
# img control
img_control = 0
img_control_min = 0
img_control_max = len(img_names) - 1
while True:name = img_names[img_control]print(name, end="  ")img = cv2.imread(img_path + name)img_height, img_width = img.shape[:2]img = cv2.resize(img, (0,0), fx=ZOOM, fy=ZOOM)cv2.imshow("img", img)print("height = %d, width = %d"%(img_height, img_width), end="  ")## 若存在标签则绘制点,若不存在则不绘制txt_name = name.split(".")[0] + ".txt"label_names = os.listdir(label_path)if txt_name in label_names:img_x, img_y = get_xy(label_path + txt_name)cv2.imshow("img", cv2.circle(img.copy(), (int(ZOOM*img_width*img_x), int(ZOOM*img_height*img_y)), 10,(0,0,255), -1))# print(int(ZOOM*img_width*img_x), int(ZOOM*img_height*img_y))print("\033[32;40m" + "Label Exist" + "\033[0m" + ": x = %.8f, y = %.8f"%(img_x, img_y))else:print("\033[31m" + "NO Label" + "\033[0m")## while 循环的控制command = cv2.waitKey(0) & 0xFF# 慢速退if command == ord('a'):if img_control > img_control_min:img_control -= 1else:img_control = 0print("First img already")# 慢速进elif command == ord('d'):if img_control < img_control_max:img_control += 1else:img_control = img_control_maxprint("Last img already")# 快速退elif command == ord('z'):if img_control - 4 > img_control_min:img_control -= 5else:img_control = 0print("First img already")# 快速进elif command == ord('c'):if img_control + 4 < img_control_max:img_control += 5else:img_control = img_control_maxprint("Last img already")# 退出elif command == ord('q'):breakelse:print("Unknown Command")

6. 删除多余标签

​ 上面打完标签,会生成很多txt,但是对于有些时候有些图片并没有黑线

​ 这里就会导致图片与txt不匹配,这里给大家写了个删除多余图片和标记的代码

import osdef delete_unlabeled_images(image_dir, label_dir):# 遍历图片目录中的所有文件for img_file in os.listdir(image_dir):if img_file.endswith(".png"):  # 确保处理的是PNG图片# 构建对应的标签文件路径label_file = os.path.splitext(img_file)[0] + ".txt"label_path = os.path.join(label_dir, label_file)# 检查标签文件是否存在if not os.path.exists(label_path):# 如果标签文件不存在,则删除图片img_path = os.path.join(image_dir, img_file)os.remove(img_path)print(f"Deleted image without label: {img_path}")# 数据集名称
DATASET_NAME = "./taSet_3_1119"
# 图片和标签的目录路径
img_path = os.path.join(DATASET_NAME, "train/image")
label_path = os.path.join(DATASET_NAME, "train/label")# 调用函数,删除没有对应标签的图片
delete_unlabeled_images(img_path, label_path)

7. 划分数据

​ 上面打完标签,会生成很多txt,当然现在是仅仅在DataSet_3_1119/train/image这个目录

​ 我们现在需要划分一部分图片和对应标签到test里面

​ 按照上面的这里仍然可以无脑进行下一步

DATASET_NMAE = "DataSet_3_1119"  # 数据集名称
test_percent = 0.25  # 0.25表示25%的图片作为测试集from random import sample
from shutil import move
from os import listdirpath = "./" + DATASET_NMAE + "/"train_image = "train/image/"
train_label = "train/label/"
test_image = "test/image/"
test_label = "test/label/"images_names = listdir(path + train_image)
# 抽样并移动
test_number = int(len(images_names)*test_percent)
test_names = sample(images_names, test_number)
for name in test_names:# 移动图片image_old = path + train_image + nameimage_path = path + test_image + nameprint(image_old, end=" ")try:move(image_old,image_path)print("\033[32;40m" + "Success." + "\033[0m")except:print("\033[31m" + "Failed! " + "\033[0m")# 移动标签label_old = path + train_label + name.split(".")[0] + ".txt"label_path = path + test_label + name.split(".")[0] + ".txt"print(label_old, end=" ")try:move(label_old, label_path)print("\033[32;40m" + "Success." + "\033[0m")except:print("\033[31m" + "Failed! " + "\033[0m")

8.训练

这个时候,就可以真正训练了

## 此Python脚本在开发机上运行 ### Step 5# 训练ResNet18
# 如果是第一次训练会自动下载预训练权重,约40MB
# 训练结束后会在当前目录下生成一个名为BEST_MODEL_PATH的模型文件
# CPU就能训练,我的R7-4800H约12秒一个Epoch,不会太慢import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np
from threading import Thread
from time import time, sleepDATASET_NMAE = "DataSet_3_1119"   # 数据集名称
BEST_MODEL_PATH = './model_best1000.pth'  # 最好的训练结果
BATCH_SIZE = 256
NUM_EPOCHS = 1000            # 迭代次数def main(args=None):best_loss = 1e9train_image = "./" + DATASET_NMAE + "/train/"test_image = "./" + DATASET_NMAE + "/test/"train_dataset = XYDataset(train_image)test_dataset = XYDataset(test_image)train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)# 创建ResNet18模型,这里选用已经预训练的模型,# 更改fc输出为2,即x、y坐标值model = models.resnet18(pretrained=True)model.fc = torch.nn.Linear(512, 2)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"Using device: {device}")model = model.to(device)optimizer = optim.Adam(model.parameters())print("开始训练")for epoch in range(NUM_EPOCHS):print(epoch)epoch_time_begin = time()model.train()train_loss = 0.0for images, labels in iter(train_loader):images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = F.mse_loss(outputs, labels)train_loss += float(loss)loss.backward()optimizer.step()train_loss /= len(train_loader)model.eval()test_loss = 0.0for images, labels in iter(test_loader):images = images.to(device)labels = labels.to(device)outputs = model(images)loss = F.mse_loss(outputs, labels)test_loss += float(loss)test_loss /= len(test_loader)msgStr = "Epoch" + "\033[32;40m" + " %d " % epoch + "\033[0m"msgStr += "-> time: \033[32;40m%.3f\033[0m s,  train_loss: \033[32;40m%f\033[0m,  test_loss: \033[32;40m%f\033[0m" % (time() - epoch_time_begin, train_loss, test_loss)if test_loss < best_loss:msgStr += (" \033[31m" + " Saved" + "\033[0m")torch.save(model.state_dict(), BEST_MODEL_PATH)best_loss = test_losselse:msgStr += " Done"print(msgStr)class XYDataset(torch.utils.data.Dataset):def __init__(self, directory, random_hflips=False):self.directory = directoryself.random_hflips = random_hflipsself.image_paths = glob.glob(os.path.join(self.directory + "/image", '*.png'))self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image_path = self.image_paths[idx]image = PIL.Image.open(image_path)with open(os.path.join(self.directory + "/label", os.path.splitext(os.path.basename(image_path))[0]+".txt"), 'r') as label_file:content = label_file.read()values = content.split()if len(values) == 2:value1 = float(values[0])value2 = float(values[1])else:print("文件格式不正确")x, y = value1, value2if self.random_hflips:if float(np.random.rand(1)) > 0.5:image = transforms.functional.hflip(image)x = -ximage = self.color_jitter(image)image = transforms.functional.resize(image, (224, 224))image = transforms.functional.to_tensor(image)image = image.numpy().copy()image = torch.from_numpy(image)image = transforms.functional.normalize(image,[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])return image, torch.tensor([x, y]).float()if __name__ == '__main__':main()

9.转ONNX

直接无脑运行哈哈,超超佬的代码太好用了

!!!当然无脑也得改路径

import torchvision
import torchBEST_MODEL_PATH = r'C:\Users\jszjg\Desktop\ResNet18\model_000056_all.pth'  # 最好的训练结果def main(args=None):model = torchvision.models.resnet18(pretrained=False)model.fc = torch.nn.Linear(512,2)model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location='cpu'))device = torch.device('cpu')model = model.to(device)model.eval()x = torch.randn(1, 3, 224, 224, requires_grad=True)# torch_out = model(x)torch.onnx.export(model,x,BEST_MODEL_PATH[:-4] + ".onnx",export_params=True,opset_version=11,do_constant_folding=True,input_names=['input'],output_names=['output'])if __name__ == '__main__':main()

10.总结与下期预告

​ 现在按照地平线大佬的教程,训练一段时间已经可以获得一个onnx模型了

del.to(device)
model.eval()
x = torch.randn(1, 3, 224, 224, requires_grad=True)

torch_out = model(x)

torch.onnx.export(model,
x,
BEST_MODEL_PATH[:-4] + “.onnx”,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=[‘input’],
output_names=[‘output’])

if name == ‘main’:
main()

# 10.总结与下期预告​	现在按照地平线大佬的教程,训练一段时间已经可以获得一个onnx模型了​	后面将把resnet转模型步骤给大家进行演示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/396770.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

黄牛杀手 抢票脚本 V3.0

黄牛杀手 抢票脚本 V3.0 介绍 现在黄牛太tm多了&#xff0c;根本抢不到票 为了解决这个问题&#xff0c;开发了这个脚本&#xff0c;支持大麦网&#xff0c;淘票票、缤玩岛等多个平台 依赖 selenium (4.10.0以下版本) pip install selenium 现在黄牛太tm多了&#xff0c;根…

2.类和对象(上)

1. 类的定义 1.1 类定义格式 • class为定义类的关键字&#xff0c;Stack为类的名字&#xff0c;{ }中为类的主体&#xff0c;注意类定义结束时后面分号不能省略。类体中内容称为类的成员&#xff1a;类中的变量称为类的属性或成员变量; &#xff08;类和结构体非常像&#…

LVS原理——详细介绍

目录 lvs简介 LVS作用 LVS 的优势与不足 LVS概念与相关术语 LVS的3种工作模式 LVS调度算法 LVS-dr模式 LVS-tun模式 ipvsadm工具使用 lvs简介 LVS 是Linux Virtual Server的简称&#xff0c;也就是 Linux 虚拟服务器,是一个极好的负载均衡解决方案&#xff0c;它将一个…

计数排序,桶排序,基数排序

计数排序&#xff1a; 找出数据中的最大值和最小值&#xff0c;并创建哈希表&#xff0c;把 数据-最小值 作为数组的下标访问哈希表并标记数量&#xff0c;标记完后&#xff0c;遍历哈希表&#xff0c;当表中的值大于0&#xff0c;把 **下标最小值 (下标元素-最小值)**还原数据…

LLVM 寄存器分配

概述 基本寄存器分配器是四种寄存器分配器中最简单的寄存器分配pass实现(<llvm_root/livm/lib/CodeGen/RegAllocBasic.cpp>) 但麻雀虽小&#xff0c;五脏俱全&#xff0c;基本寄存器分配器中实现了根据溢出权重确实虚拟寄存器优先级、按优先级分配物理寄存器&#xff0…

韦东山瑞士军刀项目自学之UART

放自己一星期假回家&#xff0c;回来继续准备秋招。 本章记录关于UART协议的相关知识笔记。平时主要还是基于HAL库开发&#xff0c;但笔记里也讲了韦老师介绍的如何控制寄存器来设置UART的参数。 以及一些UART防止采集的抖动设置的一些策略与波特率与比特率的区别等。

文件共享服务NFS(服务名nfs,端口tcp/2049)

目录 前言 配置文件 工作原理 NFS服务器的配置 查看服务器是否安装 查看服务器状态 开启服务 编写配置文件 客户端挂载 前言 NFS&#xff08;Network File System&#xff09;是一种分布式文件系统协议&#xff0c;它允许网络中的不同计算机共享文件和目录&#xff0…

使用tailwindcss轻松实现移动端rem适配

本示例节选自小卷全栈开发实战系列的《Vue3实战》。演示如何用tailwindcss所支持的rem体系轻松实现一个仿b站移动端头部导航栏rem适配。 友情声明 学习分享不易&#xff0c;如果小伙伴觉得有帮助&#xff0c;点赞支持下。满30赞&#xff0c;将随文附赠录屏讲解&#xff0c;感谢…

树莓派4/5:运行Yolov5n模型(文末附镜像文件)

〇、前言 因国内网络问题&#xff0c;可直接烧录文末镜像文件&#xff0c;或者按照本教程进行手动操作。 一、实验目的 在树莓派4B运行Yolov5n模型。 二、实验条件 1、Windows 11计算机&#xff1a;安装了Mobaxterm 2、树莓派4B&#xff1a;64Bit Lite OS&#xff0c;安装了…

案例:Nginx + Tomcat集群(负载均衡 动静分离)

目录 案例 案例环境 案例步骤 部署Tomcat服务器 部署Nginx服务器 实现负载均衡和读写分离 日志控制 案例 案例环境 操作系统 IP 地址 角色 CentOS 192.168.10.101 Nginx服务器&#xff08;调度器&#xff09; CentOS 192.168.10.102 Tomcat服务器① CentOS 1…

uniapp 对于scroll-view滑动和页面滑动的联动处理

需求 遇到一个需求 解决方案 这个时候可以做一个内页面滑动判断 <!-- scroll-y 做true或者false的判断是否滑动 --> <view class"u-menu-wrap" style"background-color: #fff;"><scroll-view :scroll-y"data.isGo" scroll-wit…

贷奇乐漏洞学习 --- 两个变态WAF绕过

代码分析 第一个WAF 代码 function dowith_sql($str) {$check preg_match(/select|insert|update|delete|\|\/\*|\*|\.\.\/|\.\/|union|into|load_file|outfile/is, $str);if ($check) {echo "非法字符!";exit();}return $str;} 实现原理 这段PHP代码定义了一个…

uniapp切换同一个子组件时,钩子函数只进了一次

给子组件添加不同的 “key” 值&#xff0c;当 key 值改变时&#xff0c;Vue 会认为这是一个不同的组件&#xff0c;并重新创建它 props: ["L1Id"],// 方式1: 使用keycomputed: {// 切换子组件时,发现created、mounted等钩子函数只会进一次,或者用 keykey(){this.ref…

CSS技巧专栏:一日一例 19 -纯CSS实现超酷的水晶按钮特效

CSS技巧专栏:一日一例 19 -纯CSS实现超酷的水晶按钮特效 今天给大家分享一个纯CSS按钮水晶按钮,效果很赞,希望对大家有所帮助。 本例图片 案例分析 这个按钮看起来效果很赞,我们分析一下它由几个层组成: 1. 按钮本体:渐变层+按钮文字 2.用before伪元素实现高光层+内…

线程与多线程(二)

线程与多线程&#xff08;二&#xff09; 一、线程互斥1、相关概念 二、互斥锁1、介绍2、使用场景3、初始化&#xff08;1&#xff09;函数&#xff08;2&#xff09;概念 4、销毁&#xff08;1&#xff09;函数&#xff08;2&#xff09;概念 5、加锁&#xff08;1&#xff09…

SAM-Med2D 大模型学习笔记(续):训练自己数据集

1、前言、数据集介绍 SAM-Med2D大模型介绍参考上文&#xff1a;第三章&#xff1a;SAM-Med2D大模型复现-CSDN博客 本文将使用SAM-Med2D大模型训练自己的数据集 关于SAM-Med2D大模型官方demo数据集的介绍上文已经介绍过&#xff0c;这里简单回顾下 其中data_demo为数据集的目…

leetcode171. Excel 表列序号,进制转换

leetcode171. Excel 表列序号 给你一个字符串 columnTitle &#xff0c;表示 Excel 表格中的列名称。返回 该列名称对应的列序号 。 例如&#xff1a; A -> 1 B -> 2 C -> 3 … Z -> 26 AA -> 27 AB -> 28 … 示例 1: 输入: columnTitle “A” 输出: 1 示…

电商平台产品ID|CDN与预渲染|前端边缘计算

技术实现 都是通过ID拿到属性&#xff0c;进行预渲染html&#xff0c;通过 oss 分发出去 详情页这种基本都是通过 ssr 渲染出来&#xff0c;然后上缓存 CDN 分发到边缘节点来处理&#xff0c;具体逻辑可以参考 淘宝——EdgeRoutine边缘计算&#xff08;CDNServerless 边缘计算…

国内真正意义上的OpenAI,最强多模态大模型 MiniCPM-V 2.6 发布

最近这一两周看到不少互联网公司都已经开始秋招提前批了。不同以往的是&#xff0c;当前职场环境已不再是那个双向奔赴时代了。求职者在变多&#xff0c;HC 在变少&#xff0c;岗位要求还更高了。 最近&#xff0c;我们又陆续整理了很多大厂的面试题&#xff0c;帮助一些球友解…

二叉树的最大深度

二叉树的最大深度 思路&#xff1a; 法一&#xff1a;深搜 也就是递归 要想清楚边界条件 好久没写深搜了 回忆下怎么写。 突然就悟了&#xff1a; /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *rig…