YOLO系列解读DAY2—YOLOV1预测代码转换

一、说在前面

小伙伴们好,博主很久没有写博客了,略感生疏,不到之处敬请谅解,欢迎指出文中错误,大家一起探讨。欲看视频讲解,可转至博主DouYin、B站,欢迎关注,链接如下:

Github: GitHub - samylee/YOLOV1_PyTorch

DY: samylee_csdn

B站:samylee

二、博客系列链接

YOLO系列解读DAY1—YOLOV1预训练模型

YOLO系列解读DAY2—YOLOV1预测代码转换

三、预测代码转换效果演示

四、YOLOV1网络架构

从上图可以看出yolov1-tiny的网络架构较为简单,只是采用卷积层和全连接层累加的形式,但并不能否认该网络在当时是开山之作,后面的架构都是基于该架构的思想进行创作的!其中要注意的是yolov1并不是通过anchor回归边界框,而是暴力地直接回归边框信息,所以导致边界框并不是很准。这个在后面的yolov1训练部分会进一步讲解,敬请期待。

网络输出的全连接层默认S=7,表示输出特征图的尺寸为7x7,B=2,表示输出两个矩形框(boundingbox),C=20,表示voc2007+2012的20类目标。之所以作者尝试用全连接层作为网络输出,是因为当时没想到可以用全卷积网络呗,当时的主流网络框架就是卷积+全连接,后面作者开源的YOLOV2则是用了全卷积作为网络的主干架构。输出特征如下图(非原创图,若有侵权,联系删除)所示:

上图可以一目了然反应输出特征图表达的含义,即2个boundingbox,2个对应的有无目标概率,20个类别概率,然而confidence与class-score两者的乘积才作为类别的输出概率。需要注意yolov1全连接层的数据排序,其输出的channels=(S*S*(5*B+C))=1470,前7x7x20存储的是voc20类的概率,中间7x7x2存储的是2个矩形框的目标概率,最后7x7x2x4存储的是2个矩形框的边框信息。

五、YOLOV1网络预测转PyTorch

网络转换预测需要注意三个地方:

1、要求将yolov1-tiny网络按照其特定的网络序列复写出PyTorch的网络架构,切不可心浮气躁;

2、Darknet存储参数的序列和PyTorch稍有不同,若遇到BatchNorm架构的卷积层,Darknet会先存储BatchNorm层的参数,进而存储卷积层的参数;

3、图像数据预处理部分,yolov1-tiny网络的顺序是:BGR2RGB->Norm(1/255)->Resize,所以我们在复现的时候一定要按照该顺序进行数据预处理操作。

代码复现如下,版本PyTorch1.4,Python3.6。

detect.py
import torch
import torch.nn as nn
import cv2
import numpy as npfrom YOLOV1 import YOLOV1
from utils import load_darknet_weights, nmsdef postprocess(output, thresh, S, B, C, img_w, img_h):# to cpu numpypredictions = output.squeeze(0).data.cpu().numpy()# detection results# [xmin, ymin, xmax, ymax, score, class_id]results = np.empty((0, 4 + 1 + 1), dtype=np.float32)probs_tmp = np.empty((C), dtype=np.float32)boxes_tmp = np.empty((4), dtype=np.float32)for i in range(S * S):row = i // Scol = i % S# get objprob_index = S * S * C + i * Bobj1_prob = predictions[prob_index]obj2_prob = predictions[prob_index + 1]obj_prob_max = obj1_prob if obj1_prob > obj2_prob else obj2_probobj_prob_max_index = 0 if obj1_prob > obj2_prob else 1# get classclass_index = i * Cfor j in range(C):class_prob = obj_prob_max * predictions[class_index + j]probs_tmp[j] = class_prob if class_prob > thresh else 0if probs_tmp.max() > thresh:# get network boxesbox_index = S * S * (C + B) + (i * B + obj_prob_max_index) * 4boxes_tmp[0] = (predictions[box_index + 0] + col) / Sboxes_tmp[1] = (predictions[box_index + 1] + row) / Sboxes_tmp[2] = pow(predictions[box_index + 2], 2)boxes_tmp[3] = pow(predictions[box_index + 3], 2)# get real boxesxmin = (boxes_tmp[0] - boxes_tmp[2] / 2.) * img_wymin = (boxes_tmp[1] - boxes_tmp[3] / 2.) * img_hxmax = (boxes_tmp[0] + boxes_tmp[2] / 2.) * img_wymax = (boxes_tmp[1] + boxes_tmp[3] / 2.) * img_h# limit rectxmin = xmin if xmin > 0 else 0ymin = ymin if ymin > 0 else 0xmax = xmax if xmax < img_w else img_w - 1ymax = ymax if ymax < img_h else img_h - 1values = [xmin, ymin, xmax, ymax, probs_tmp.max(), probs_tmp.argmax()]row_values = np.expand_dims(np.array(values), axis=0)results = np.append(results, row_values, axis=0)return resultsdef preprocess(img, net_w, net_h):# img bgr2rgbimg_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# resize imgimg_resize = cv2.resize(img_rgb, (net_w, net_h))# norm imgimg_resize = torch.from_numpy(img_resize.transpose((2, 0, 1)))img_norm = img_resize.float().div(255).unsqueeze(0)return img_normif __name__ == '__main__':# load moelcheckpoint_path = 'weights/yolov1-tiny.weights'S, B, C = 7, 2, 20model = YOLOV1(S=S, B=B, C=C)load_darknet_weights(model, checkpoint_path)model.eval()# params initnet_w, net_h = 448, 448thresh = 0.2iou_thresh = 0.4classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle','bus', 'car', 'cat', 'chair', 'cow','diningtable', 'dog', 'horse', 'motorbike', 'person','pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']# load imgimg = cv2.imread('data/person.jpg')img_h, img_w, _ = img.shape# preprocessimg_norm = preprocess(img, net_w, net_h)# forwardoutput = model(img_norm)# postprocessresults = postprocess(output, thresh, S, B, C, img_w, img_h)# nmsresults = nms(results, iou_thresh)# showfor i in range(results.shape[0]):if results[i][4] > thresh:cv2.rectangle(img, (int(results[i][0]), int(results[i][1])), (int(results[i][2]), int(results[i][3])), (0,255,0), 2)cv2.putText(img, classes[int(results[i][5])] + '-' + str(round(results[i][4], 4)), (int(results[i][0]), int(results[i][1])), 0, 0.6, (0,255,255), 2)cv2.imwrite('demo.jpg', img)cv2.imshow('demo', img)cv2.waitKey(0)
YOLOV1.py
import torch
import torch.nn as nnclass YOLOV1(nn.Module):def __init__(self, S=7, B=2, C=20):super(YOLOV1, self).__init__()in_channels = 3out_channels = 256self.features = self.make_layers(in_channels=in_channels, out_channels=out_channels)self.fc = nn.Linear(S * S * out_channels, S * S * (5 * B + C))def forward(self, x):for feature in self.features:x = feature(x)# flattenx = x.view(x.size(0), -1)x = self.fc(x)return xdef make_layers(self, in_channels=3, out_channels=256):# conv: out_channels, kernel_size, stride, batchnorm, activate# maxpool: kernel_size strideparams = [[16, 3, 1, True, 'leaky'],['M', 2, 2],[32, 3, 1, True, 'leaky'],['M', 2, 2],[64, 3, 1, True, 'leaky'],['M', 2, 2],[128, 3, 1, True, 'leaky'],['M', 2, 2],[256, 3, 1, True, 'leaky'],['M', 2, 2],[512, 3, 1, True, 'leaky'],['M', 2, 2],[1024, 3, 1, True, 'leaky'],[out_channels, 3, 1, True, 'leaky']]module_list = nn.ModuleList()for i, v in enumerate(params):modules = nn.Sequential()if v[0] == 'M':modules.add_module(f'maxpool_{i}', nn.MaxPool2d(kernel_size=v[1], stride=v[2], padding=int((v[1] - 1) // 2)))else:modules.add_module(f'conv_{i}',nn.Conv2d(in_channels,v[0],kernel_size=v[1],stride=v[2],padding=(v[1] - 1) // 2,bias=not v[3]))if v[3]:modules.add_module(f'bn_{i}', nn.BatchNorm2d(v[0]))modules.add_module(f'act_{i}', nn.LeakyReLU(0.1) if v[4] == 'leaky' else nn.ReLU())in_channels = v[0]module_list.append(modules)return module_list
utils.py
import numpy as np
import torch
import torch.nn as nndef nms(boxes, iou_thres):x1 = boxes[:, 0]y1 = boxes[:, 1]x2 = boxes[:, 2]y2 = boxes[:, 3]scores = boxes[:, 4]areas = (x2 - x1 + 1) * (y2 - y1 + 1)keep = []index = np.argsort(scores)[::-1]while(index.size):i = index[0]keep.append(index[0])inter_x1 = np.maximum(x1[i], x1[index[1:]])inter_y1 = np.maximum(y1[i], y1[index[1:]])inter_x2 = np.minimum(x2[i], x2[index[1:]])inter_y2 = np.minimum(y2[i], y2[index[1:]])inter_area = np.maximum(inter_x2 - inter_x1 + 1, 0) * np.maximum(inter_y2 - inter_y1 + 1, 0)iou = inter_area / (areas[index[1:]] + areas[i] - inter_area)ids = np.where(iou <= iou_thres)[0]index = index[ids + 1]return boxes[keep]def load_darknet_weights(model, weights_path):# Open the weights filewith open(weights_path, 'rb') as f:# First five are header valuesheader = np.fromfile(f, dtype=np.int32, count=4)header_info = header  # Needed to write header when saving weightsseen = header[3]  # number of images seen during trainingweights = np.fromfile(f, dtype=np.float32)  # The rest are weightsweights_len = len(weights)ptr = 0# convfor module in model.features:if isinstance(module[0], nn.Conv2d):conv_layer = module[0]if isinstance(module[1], nn.BatchNorm2d):# Load BN bias, weights, running mean and running variancebn_layer = module[1]num_b = bn_layer.bias.numel()  # Number of biases# Biasbn_b = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.bias)bn_layer.bias.data.copy_(bn_b)ptr += num_b# Weightbn_w = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.weight)bn_layer.weight.data.copy_(bn_w)ptr += num_b# Running Meanbn_rm = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.running_mean)bn_layer.running_mean.data.copy_(bn_rm)ptr += num_b# Running Varbn_rv = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.running_var)bn_layer.running_var.data.copy_(bn_rv)ptr += num_belse:# Load conv. biasnum_b = conv_layer.bias.numel()conv_b = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(conv_layer.bias)conv_layer.bias.data.copy_(conv_b)ptr += num_b# Load conv. weightsnum_w = conv_layer.weight.numel()conv_w = torch.from_numpy(weights[ptr: ptr + num_w]).view_as(conv_layer.weight)conv_layer.weight.data.copy_(conv_w)ptr += num_w# fcfc_layer = model.fc# biasnum_b = fc_layer.bias.numel()fc_b = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(fc_layer.bias)fc_layer.bias.data.copy_(fc_b)ptr += num_b# weightsnum_w = fc_layer.weight.numel()fc_w = torch.from_numpy(weights[ptr: ptr + num_w]).view_as(fc_layer.weight)fc_layer.weight.data.copy_(fc_w)ptr += num_wassert weights_len == ptr, 'darknet_weights\'s length dont match pytorch_weights\'s length'

六、写在后面

小伙伴们若能坚持完成YOLO系列的代码和原理学习,相信能对图像检测任务有个全新的认识,跟随博主的脚步,培养自己的动手能力吧!希望博主也能坚持将该系列做下去,一起加油!!!

七、参考

YOLO: Real-Time Object Detection

GitHub - samylee/YOLOV1_PyTorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/96831.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dockerfile自定义镜像

文章目录 Dockerfile自定义镜像镜像结构Dockerfile语法构建java项目 小结 Dockerfile自定义镜像 常见的镜像在DockerHub就能找到&#xff0c;但是我们自己写的项目就必须自己构建镜像了。 而要自定义镜像&#xff0c;就必须先了解镜像的结构才行。 镜像结构 镜像是将应用程序及…

java知识-JVM线程四大引用

一、JVM (1) 基本概念&#xff1a; JVM 是可运行 Java 代码的假想计算机 &#xff0c;包括一套字节码指令集、一组寄存器、一个栈、 一个垃圾回收&#xff0c;堆 和 一个存储方法域。JVM 是运行在操作系统之上的&#xff0c;它与硬件没有直接 的交互。 (2) 运行过程&#x…

File 类的用法, InputStream和Reader, OutputStream和Writer 的用法

前言 普通的文件长这样&#xff1a; 其实目录也是一种特殊文件&#xff1a; 一、文件前缀知识 &#xff08;一&#xff09;绝对路径和相对路径 以盘符开头的的路径&#xff0c;叫做绝对路径&#xff0c;如&#xff1a;D:\360Downloads\cat.jpg 以.或..开头的路径&#xff0c…

华为认证为什么现在这么受欢迎?

华为认证目前受欢迎的原因有很多&#xff0c;以下是其中一些主要原因&#xff1a; 高质量的认证培训&#xff1a;华为认证提供了一系列高质量的培训课程&#xff0c;涵盖了IT技术、网络安全、云计算等领域。这些培训课程由华为的技术专家和工程师团队设计和提供&#xff0c;内容…

编程练习(3)

一.选择题 第一题&#xff1a; 函数传参的两个变量都是传的地址&#xff0c;而数组名c本身就是地址&#xff0c;int型变量b需要使用&符号&#xff0c;因此答案为A 第二题&#xff1a; 本题考察const修饰指针变量&#xff0c;答案为A,B,C,D 第三题&#xff1a; 注意int 型变…

回到未来:使用马尔可夫转移矩阵分析时间序列数据

一、说明 在本文中&#xff0c;我们将研究使用马尔可夫转移矩阵重构时间序列数据如何产生有趣的描述性见解以及用于预测、回溯和收敛分析的优雅方法。在时间上来回走动——就像科幻经典《回到未来》中 Doc 改装的 DeLorean 时间机器一样。 注意&#xff1a;以下各节中的所有方程…

CVE-2015-5254漏洞复现

1.漏洞介绍。 Apache ActiveMQ 是美国阿帕奇&#xff08;Apache&#xff09;软件基金会所研发的一套开源的消息中间件&#xff0c;它支持 Java 消息服务&#xff0c;集群&#xff0c;Spring Framework 等。Apache ActiveMQ 5.13.0之前 5.x 版本中存在安全漏洞&#xff0c;该漏…

曲面(弧面、柱面)展平(拉直)瓶子标签识别ocr

瓶子或者柱面在做字符识别的时候由于变形&#xff0c;识别效果是很不好的 或者是检测瓶子表面缺陷的时候效果也没有展平的好 下面介绍两个项目&#xff0c;关于曲面&#xff08;弧面、柱面&#xff09;展平&#xff08;拉直&#xff09; 项目一&#xff1a;通过识别曲面的6个点…

去除UI切图边缘上多余的线条

最近接到UI切图&#xff0c;放进项目&#xff0c;显示边缘有多余线条&#xff0c;影响UI美观。开始以为切图没切好&#xff0c;实则不是。如图&#xff1a; ->解决&#xff1a; 将该图片资源WrapMode改为Clamp

回归预测 | MATLAB实现SSA-BP麻雀搜索算法优化BP神经网络多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现SSA-BP麻雀搜索算法优化BP神经网络多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现SSA-BP麻雀搜索算法优化BP神经网络多输入单输出回归预测&#xff08;多指标&#xff0c;多图&#xff09;效果一览基本…

Matplotlib数据可视化(三)

目录 1.绘图的填充 1.1 曲线下方区域的填充 1.2 填充部分区域 1.3 两条曲线之间的区域填充 1.4 直接使用fill进行填充 1.绘图的填充 绘图的填充可以调用fill_between()或fill()进行填充。 1.1 曲线下方区域的填充 x np.linspace(0,1,500) y np.sin(3*np.pi*x)*np.exp…

深入理解【二叉树】

&#x1f4d9;作者简介&#xff1a; 清水加冰&#xff0c;目前大二在读&#xff0c;正在学习C/C、Python、操作系统、数据库等。 &#x1f4d8;相关专栏&#xff1a;C语言初阶、C语言进阶、C语言刷题训练营、数据结构刷题训练营、有感兴趣的可以看一看。 欢迎点赞 &#x1f44d…

3D- vista:预训练的3D视觉和文本对齐Transformer

论文&#xff1a;https://arxiv.org/abs/2308.04352 代码: GitHub - 3d-vista/3D-VisTA: Official implementation of ICCV 2023 paper "3D-VisTA: Pre-trained Transformer for 3D Vision and Text Alignment" 摘要 三维视觉语言基础(3D- vl)是一个新兴领域&…

Jmeter参数化类型

1.参数在多个请求报文中出现&#xff0c;执行一次需要使用同一个参数--随机生成(随机变更) 2.参数在请求报文中出现&#xff0c;执行过程需要使用同一个参数(--固定参数) 3.参数从指定几个固定中随机获取一个 4.参数从本地文件中获取 5.参数在多个请求报文中出现&#xff0c;每…

C++笔记之左值与右值、右值引用

C笔记之左值与右值、右值引用 code review! 文章目录 C笔记之左值与右值、右值引用1.左值与右值2.右值引用——关于int&& r 10;3.右值引用——对比int&& r 10;和int& r 10;4.右值引用&#xff08;rvalue reference&#xff09;的概念 1.左值与右值 2.…

CNN卷积详解(三)

一、卷积层的计算 4 ∗ * ∗ 4的输入矩阵 I I I 和 3 ∗ * ∗ 3 的卷积核 K K K: 在步长&#xff08;stride&#xff09;为 1 时&#xff0c;输出的大小为 ( 4 − 3 1 ) ( 4 − 3 1) 计算公式&#xff1a; ● 输入图片矩阵 I I I 大小&#xff1a; w w w w ww ●…

i.MX6ULL开发板无法进入NFS挂载文件系统的解决办法

问题 使用NFS网络挂载文件系统后卡住无法进入系统。 解决办法 此处不详细讲述NFS安装流程 查看板卡挂载在/home/etc/rc.init下的自启动程序 进入到../../home/etc目录下&#xff0c;查看rc.init文件&#xff0c;首先从第一行排查&#xff0c;查看/home/etc/netcfg文件代码内容&…

【C++进阶】继承、多态的详解(多态篇)

【C进阶】继承、多态的详解&#xff08;多态篇&#xff09; 目录 【C进阶】继承、多态的详解&#xff08;多态篇&#xff09;多态的概念多态的定义及实现多态的构成条件&#xff08;重点&#xff09;虚函数虚函数的重写&#xff08;覆盖、一种接口继承&#xff09;C11 override…

qsort函数详解

大家好&#xff0c;我是苏貝&#xff0c;本篇博客带大家了解qsort函数&#xff0c;如果你觉得我写的不错的话&#xff0c;可以给我一个赞&#x1f44d;吗&#xff0c;感谢❤️ 文章目录 一. qsort函数参数详解1.数组首元素地址base2.数组的元素个数num和元素所占内存空间大小w…

java语言B/S架构云HIS医院信息系统源码【springboot】

医院云HIS全称为基于云计算的医疗卫生信息系统( Cloud- Based Healthcare Information System)&#xff0c;是运用云计算、大数据、物联网等新兴信息技术&#xff0c;按照现代医疗卫生管理要求&#xff0c;在一定区域范围内以数字化形式提供医疗卫生行业数据收集、存储、传递、…