【目标检测】yolov8结构及代码分析

yolov8代码:https://github.com/ultralytics/ultralytics

yolov8的整体结构如下图(来自mmyolo):

yolov8的配置文件:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

可以看出,主要包含Conv,C2f,SPPF,Concat,Detect几个模块。

一、Conv

Conv模块包含卷积层、BN层和激活函数层。

代码如下:

class Conv(nn.Module):"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):"""Initialize Conv layer with given arguments including activation."""super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):"""Apply convolution, batch normalization and activation to input tensor."""return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):"""Perform transposed convolution of 2D data."""return self.act(self.conv(x))

二、C2f

C2f就是模型结构图中的CSPLayer_2Conv,包含多个DarkNetBottleNeck。代码如下:

class C2f(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))class Bottleneck(nn.Module):"""Standard bottleneck."""def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):"""Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, andexpansion."""super().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, k[0], 1)self.cv2 = Conv(c_, c2, k[1], 1, g=g)self.add = shortcut and c1 == c2def forward(self, x):"""'forward()' applies the YOLO FPN to input data."""return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

三、SPPF

yolov8的SPPF实现是通过多次池化实现不同大小的池化窗口运算,代码如下:

class SPPF(nn.Module):"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""def __init__(self, c1, c2, k=5):"""Initializes the SPPF layer with given input/output channels and kernel size.This module is equivalent to SPP(k=(5, 9, 13))."""super().__init__()c_ = c1 // 2  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c_ * 4, c2, 1, 1)self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)def forward(self, x):"""Forward pass through Ghost Convolution block."""x = self.cv1(x)y1 = self.m(x)y2 = self.m(y1)return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))

四、Concat

和torch.cat功能几乎一致:

class Concat(nn.Module):"""Concatenate a list of tensors along dimension."""def __init__(self, dimension=1):"""Concatenates a list of tensors along a specified dimension."""super().__init__()self.d = dimensiondef forward(self, x):"""Forward pass for the YOLOv8 mask Proto module."""return torch.cat(x, self.d)

五、Detect

yolov8的检测头:

class Detect(nn.Module):"""YOLOv8 Detect head for detection models."""dynamic = False  # force grid reconstructionexport = False  # export modeshape = Noneanchors = torch.empty(0)  # initstrides = torch.empty(0)  # initdef __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc  # number of classesself.nl = len(ch)  # number of detection layersself.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4  # number of outputs per anchorself.stride = torch.zeros(self.nl)  # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()def forward(self, x):"""Concatenates and returns predicted bounding boxes and class probabilities."""shape = x[0].shape  # BCHWfor i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:return xelif self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shapex_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV opsbox = x_cat[:, :self.reg_max * 4]cls = x_cat[:, self.reg_max * 4:]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.stridesif self.export and self.format in ('tflite', 'edgetpu'):# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695img_h = shape[2] * self.stride[0]img_w = shape[3] * self.stride[0]img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)dbox /= img_sizey = torch.cat((dbox, cls.sigmoid()), 1)return y if self.export else (y, x)def bias_init(self):"""Initialize Detect() biases, WARNING: requires stride availability."""m = self  # self.model[-1]  # Detect() module# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequencyfor a, b, s in zip(m.cv2, m.cv3, m.stride):  # froma[-1].bias.data[:] = 1.0  # boxb[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/227437.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

跨境电商迎来综合竞争力比拼时代 五大趋势解读跨境2024

过去几年,跨境电商成为外贸出口增长的一大亮点,随着年底国务院办公厅《关于加快内外贸一体化发展的若干措施》的发布,跨境电商在促进经济发展、助力内外贸一体化发展方面的价值更加凸显。 这是跨境电商变化最快的时代,也是跨境电…

GrayLog日志平台的基本使用-ssh之Email报警

1、首先编辑并添加邮件配置到server.conf(注意:是添加) vim /etc/graylog/server/server.conf # Email transport transport_email_enabled true transport_email_hostname smtp.qq.com transport_email_port 465 transport_email_use_a…

Net6 Core webApi发布到IIS

Net6 Core Api发布到IIS不同于webapi,依赖框架不同,配置也移至项目内Program.cs 一、发布到指定文件夹和IIS,不过注意IIS应用程序池选择的是 “无托管代码“ 在IIS管理器中点击浏览,访问接口路径报500.19,原因是所依赖…

杰发科技AC7840——EEPROM初探

0.序 7840和7801的模拟EEPROM使用不太一样 1.现象 按照官方Demo,在这样的配置下,我们看到存储是这样的(连续三个数字1 2 3)。 使用串口工具的多帧发送功能 看不出多少规律 修改代码后 发现如下规律: 前四个字节是…

GBASE南大通用携手宇信科技打造“一表通”全链路解决方案

什么是“一表通”? “一表通”是国家金融监督管理总局为发挥统计监督效能、完善银行保险监管统计制度、推进监管数据标准化建设、打破数据壁垒,而制定的新型监管数据统计规范。相较于以往的报送接口,“一表通”提高了对报送时效性、校验准确…

stm32学习笔记:TIM-定时中断和外部时钟

定时器四部分讲解内容,本文是第一部分 ​​​​​TIM简介 基本定时器 时基单元:预分频器、计数器、自动重装载寄存器 预分频器之前,连接的就是基准计数时钟的输入,由于基本定时器只能选择内部时钟,所以可以认为这根…

【MySQL】数据库之存储过程(“SQL语句的脚本“)

目录 一、什么是存储过程? 二、存储过程的作用 三、如何创建、调用、查看、删除、修改存储过程 四、存储过程的参数(输入参数,输出参数,输入输出参数) 第一种:输入参数 第二种:输出参数 …

【数据结构】什么是二叉树?

🦄个人主页:修修修也 🎏所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 目录 📌二叉树的定义 📌二叉树的特点 📌特殊二叉树 📌二叉树的性质 📌二叉树的存储结构 📌二叉树…

链表精选题集

目录 1 链表翻转 题目链接: 解题: 试错版: 2 找中间节点 题目链接: 题解: 3 找倒数第k个节点 题目链接: 题解: 4 将两个升序链表合并为一个升序链表 题目链接: 题解: …

Spark SQL

目录 一、Spark SQL简介 (一)从Shark说起 (二)Spark SQL架构 (三)为什么推出Spark SQL 二、DataFrame概述 三、DataFrame的创建 四、DataFrame的保存 五、DataFrame的常用操作 六、从RDD转换得到D…

nodejs+vue+ElementUi农产品团购销售系统zto2c

目标是为了完成小区团购平台的设计和实现,在疫情当下的环境,方便小区业主购入生活所需,减小居民的生活压力 采用B/S模式架构系统,开发简单,只需要连接网络即可登录本系统,不需要安装任何客户端。开发工具采…

mxxWechatBot微信机器人V2版本文档说明

大家伙,我是雄雄,欢迎关注微信公众号:雄雄的小课堂。 先看这里 一、前言二、mxxWechatBot流程图三、怎么使用? 一、前言 经过不断地探索与研究,mxxWechatBot正式上线,届时全面开放使用。 mxxWechatBot&am…

LLM之RAG实战(十一)| 使用Mistral-7B和Langchain搭建基于PDF文件的聊天机器人

在本文中,使用LangChain、HuggingFaceEmbeddings和HuggingFace的Mistral-7B LLM创建一个简单的Python程序,可以从任何pdf文件中回答问题。 一、LangChain简介 LangChain是一个在语言模型之上开发上下文感知应用程序的框架。LangChain使用带prompt和few-…

小米电脑管家 - 手机平板电脑家居互联

系列文章目录 前言 联想电脑安装小米电脑管家实现设备互联 如图,将 小米平板 5 Pro 作为联想笔记本 GeekPro 5000 (这垃圾电脑)的副屏。 可以在小米平板控制笔记本,如图所示 一、官方使用手册 参考:小米电脑管家帮助 …

听GPT 讲Rust源代码--src/tools(34)

File: rust/src/tools/clippy/clippy_lints/src/collection_is_never_read.rs 文件"collection_is_never_read.rs"位于Rust源代码中的clippy_lints工具中,其作用是检查在集合类型(如Vec、HashMap等)的实例上执行的操作是否被忽略了…

学习笔记13——Spring整合Mybatis、junit、AOP、事务

学习笔记系列开头惯例发布一些寻亲消息 链接:https://baobeihuijia.com/bbhj/ Mybatis - Spring(使用第三方包new一个对象bean) 原始的Mybatis与数据库交互【通过sqlmapconfig来配置和连接】 初始化SqlSessionFactory获得连接获取数据层接口…

Jmeter吞吐量控制器总结

吞吐量控制器(Throughput Controller) 场景: 在同一个线程组里, 有10个并发, 7个做A业务, 3个做B业务,要模拟这种场景,可以通过吞吐量模拟器来实现。 添加吞吐量控制器 用法1: Percent Executions 在一个线程组内分别建立两个吞吐量控制器, 分别放业务A和业务B …

【三维目标检测/自动驾驶】IA-BEV:基于结构先验和自增强学习的实例感知三维目标检测(AAAI 2024)

系列文章目录 论文:Instance-aware Multi-Camera 3D Object Detection with Structural Priors Mining and Self-Boosting Learning 地址:https://arxiv.org/pdf/2312.08004.pdf 来源:复旦大学 英特尔Shanghai Key Lab /美团 文章目录 系列文…

数据预处理时,怎样处理类别型特征?

1. 序号编码 序号编码通常用于处理类别间具有大小关系的数据。例如成绩,可以分为低、中、高三档,并且存在“高>中>低”的排序关系。序号编码会按照大小关系对类别型特征赋予一个数值ID,例如高表示为3、中表示为2、低表示为1&#xff0…

Spring系列学习四、Spring数据访问

Spring数据访问 一、Spring中的JDBC模板介绍1、新建SpringBoot应用2、引入依赖:3、配置数据库连接,注入dbcTemplate对象,执行查询:4,测试验证: 二、整合MyBatis Plus1,在你的项目中添加MyBatis …