YOLOv6-4.0部分代码阅读笔记-effidehead_lite.py

effidehead_lite.py

yolov6\models\heads\effidehead_lite.py

目录

effidehead_lite.py

1.所需的库和模块

2.class Detect(nn.Module): 

3.def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers): 


1.所需的库和模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from yolov6.layers.common import DPBlock
from yolov6.assigners.anchor_generator import generate_anchors
from yolov6.utils.general import dist2bbox

2.class Detect(nn.Module): 

class Detect(nn.Module):# 高效分离头。# 利用硬件感知设计,使用混合通道方法对解耦头进行优化。'''Efficient Decoupled HeadWith hardware-aware degisn, the decoupled head is optimized withhybridchannels methods.'''def __init__(self, num_classes=80, num_layers=3, inplace=True, head_layers=None):  # detection layer    检测层super().__init__()assert head_layers is not Noneself.nc = num_classes  # number of classes    类别数量self.no = num_classes + 5  # number of outputs per anchor    每个锚点的输出数量self.nl = num_layers  # number of detection layers    检测层数self.grid = [torch.zeros(1)] * num_layersself.prior_prob = 1e-2self.inplace = inplacestride = [8, 16, 32] if num_layers == 3 else [8, 16, 32, 64] # strides computed during build    构建期间计算的步长self.stride = torch.tensor(stride)self.grid_cell_offset = 0.5self.grid_cell_size = 5.0# Init decouple head    初始化解耦头self.stems = nn.ModuleList()self.cls_convs = nn.ModuleList()self.reg_convs = nn.ModuleList()self.cls_preds = nn.ModuleList()self.reg_preds = nn.ModuleList()# Efficient decoupled head layers    高效解耦的头部层for i in range(num_layers):idx = i*5self.stems.append(head_layers[idx])self.cls_convs.append(head_layers[idx+1])self.reg_convs.append(head_layers[idx+2])self.cls_preds.append(head_layers[idx+3])self.reg_preds.append(head_layers[idx+4])# 它用于初始化神经网络中特定层的偏置(biases)。这个方法特别针对于类别预测层( self.cls_preds )和边界框回归预测层( self.reg_preds )的偏置和权重进行初始化。# 接受 self 作为参数,表示类的实例。def initialize_biases(self):# 遍历所有类别预测层, self.cls_preds 是一个包含卷积层的列表。for conv in self.cls_preds:# 获取当前卷积层的偏置,并将其展平为一维张量。b = conv.bias.view(-1, )# 使用逻辑斯谛分布的公式来初始化偏置值。这里的 self.prior_prob 是一个先验概率,通常用于目标检测中表示目标存在的概率。这个公式确保了在开始训练时,模型对目标的存在与否持中立态度。b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))# 将初始化后的偏置值重新设置为卷积层的偏置,并确保它们是可训练的( requires_grad=True )。conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)# 获取当前卷积层的权重。w = conv.weight# 将权重初始化为0。w.data.fill_(0.)# 将权重重新设置为卷积层的权重,并确保它们是可训练的。conv.weight = torch.nn.Parameter(w, requires_grad=True)# 遍历所有边界框回归预测层, self.reg_preds 是一个包含卷积层的列表。for conv in self.reg_preds:# 获取当前卷积层的偏置,并将其展平为一维张量。b = conv.bias.view(-1, )# 将偏置值初始化为1.0,这是因为在边界框回归中,我们通常希望预测的边界框与真实边界框的中心点对齐。b.data.fill_(1.0)# 将初始化后的偏置值重新设置为卷积层的偏置,并确保它们是可训练的。conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)# 获取当前卷积层的权重。w = conv.weight# 将权重初始化为0。w.data.fill_(0.)# 将权重重新设置为卷积层的权重,并确保它们是可训练的。conv.weight = torch.nn.Parameter(w, requires_grad=True)def forward(self, x):if self.training:cls_score_list = []reg_distri_list = []for i in range(self.nl):x[i] = self.stems[i](x[i])cls_x = x[i]reg_x = x[i]cls_feat = self.cls_convs[i](cls_x)cls_output = self.cls_preds[i](cls_feat)reg_feat = self.reg_convs[i](reg_x)reg_output = self.reg_preds[i](reg_feat)cls_output = torch.sigmoid(cls_output)cls_score_list.append(cls_output.flatten(2).permute((0, 2, 1)))reg_distri_list.append(reg_output.flatten(2).permute((0, 2, 1)))cls_score_list = torch.cat(cls_score_list, axis=1)reg_distri_list = torch.cat(reg_distri_list, axis=1)return x, cls_score_list, reg_distri_listelse:cls_score_list = []reg_dist_list = []# def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5,  device='cpu', is_eval=False, mode='af'):# -> 根据特征生成锚点。# -> return anchor_points, stride_tensor  /  return anchors, anchor_points, num_anchors_list, stride_tensoranchor_points, stride_tensor = generate_anchors(x, self.stride, self.grid_cell_size, self.grid_cell_offset, device=x[0].device, is_eval=True, mode='af')for i in range(self.nl):b, _, h, w = x[i].shapel = h * wx[i] = self.stems[i](x[i])cls_x = x[i]reg_x = x[i]cls_feat = self.cls_convs[i](cls_x)cls_output = self.cls_preds[i](cls_feat)reg_feat = self.reg_convs[i](reg_x)reg_output = self.reg_preds[i](reg_feat)cls_output = torch.sigmoid(cls_output)cls_score_list.append(cls_output.reshape([b, self.nc, l]))reg_dist_list.append(reg_output.reshape([b, 4, l]))cls_score_list = torch.cat(cls_score_list, axis=-1).permute(0, 2, 1)reg_dist_list = torch.cat(reg_dist_list, axis=-1).permute(0, 2, 1)# def dist2bbox(distance, anchor_points, box_format='xyxy'): -> 将距离(ltrb)转换为盒子(xywh或xyxy)。 -> return bboxpred_bboxes = dist2bbox(reg_dist_list, anchor_points, box_format='xywh')pred_bboxes *= stride_tensorreturn torch.cat([pred_bboxes,torch.ones((b, pred_bboxes.shape[1], 1), device=pred_bboxes.device, dtype=pred_bboxes.dtype),cls_score_list],axis=-1)

3.def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers): 

def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers):head_layers = nn.Sequential(# stem0DPBlock(in_channel=channels_list[0],out_channel=channels_list[0],kernel_size=5,stride=1),# cls_conv0DPBlock(in_channel=channels_list[0],out_channel=channels_list[0],kernel_size=5,stride=1),# reg_conv0DPBlock(in_channel=channels_list[0],out_channel=channels_list[0],kernel_size=5,stride=1),# cls_pred0nn.Conv2d(in_channels=channels_list[0],out_channels=num_classes * num_anchors,kernel_size=1),# reg_pred0nn.Conv2d(in_channels=channels_list[0],out_channels=4 * num_anchors,kernel_size=1),# stem1DPBlock(in_channel=channels_list[1],out_channel=channels_list[1],kernel_size=5,stride=1),# cls_conv1DPBlock(in_channel=channels_list[1],out_channel=channels_list[1],kernel_size=5,stride=1),# reg_conv1DPBlock(in_channel=channels_list[1],out_channel=channels_list[1],kernel_size=5,stride=1),# cls_pred1nn.Conv2d(in_channels=channels_list[1],out_channels=num_classes * num_anchors,kernel_size=1),# reg_pred1nn.Conv2d(in_channels=channels_list[1],out_channels=4 * num_anchors,kernel_size=1),# stem2DPBlock(in_channel=channels_list[2],out_channel=channels_list[2],kernel_size=5,stride=1),# cls_conv2DPBlock(in_channel=channels_list[2],out_channel=channels_list[2],kernel_size=5,stride=1),# reg_conv2DPBlock(in_channel=channels_list[2],out_channel=channels_list[2],kernel_size=5,stride=1),# cls_pred2nn.Conv2d(in_channels=channels_list[2],out_channels=num_classes * num_anchors,kernel_size=1),# reg_pred2nn.Conv2d(in_channels=channels_list[2],out_channels=4 * num_anchors,kernel_size=1))if num_layers == 4:head_layers.add_module('stem3',# stem3DPBlock(in_channel=channels_list[3],out_channel=channels_list[3],kernel_size=5,stride=1))head_layers.add_module('cls_conv3',# cls_conv3DPBlock(in_channel=channels_list[3],out_channel=channels_list[3],kernel_size=5,stride=1))head_layers.add_module('reg_conv3',# reg_conv3DPBlock(in_channel=channels_list[3],out_channel=channels_list[3],kernel_size=5,stride=1))head_layers.add_module('cls_pred3',# cls_pred3nn.Conv2d(in_channels=channels_list[3],out_channels=num_classes * num_anchors,kernel_size=1))head_layers.add_module('reg_pred3',# reg_pred3nn.Conv2d(in_channels=channels_list[3],out_channels=4 * num_anchors,kernel_size=1))return head_layers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/466903.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

得物多模态大模型在重复商品识别上的应用和架构演进

重复商品治理介绍 根据得物的平台特性,同一个商品在平台上不能出现多个链接,原因是平台需要保证一品一链的特点,以保障商品的集中竞价,所以说一个商品在整个得物平台上只能有一个商详链接,因此我们需要对一品多链的情…

盘点2024年惊艳的10款录屏工具!!

你是否经常需要捕捉电脑屏幕上的精彩瞬间?或者想要记录自己操作某个应用程序的流程?这时候你就需要一款录屏工具啦!在学习、工作和娱乐中,录屏工具都能成为你的得力助手。无论你是做教学视频、游戏解说还是分享精彩瞬间&#xff0…

vue+websocket实现即时聊天平台

目录 1 什么是websocket 2 实现步骤 2.1 导入依赖 2.2 编写代码 1 什么是websocket WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议。它主要用于在客户端和服务器之间建立持久的连接,允许实时数据交换。WebSocket 的设计目的是为了提高 Web 应用程序的…

软件设计师-上午题-15 计算机网络(5分)

计算机网络题号一般为66-70题,分值一般为5分。 目录 1 网络设备 1.1 真题 2 协议簇 2.1 真题 3 TCP和UDP 3.1 真题 4 SMTP和POP3 4.1 真题 5 ARP 5.1 真题 6 DHCP 6.1 真题 7 URL 7.1 真题 8 浏览器 8.1 真题 9 IP地址和子网掩码 9.1 真题 10 I…

C++:map 和 set 的使用

前言 平衡二叉搜索树 ( AVL树 ) 由于二叉搜索树在特殊情况下,其增删查的效率会降低到 O ( N ),因此对二叉搜索树进行改良,通过旋转等方式将其转换为一个左右均衡的二叉树,这样的树就称为平衡二叉搜索树,又称 AVL树。…

Vue 自定义icon组件封装SVG图标

通过自定义子组件CustomIcon.vue使用SVG图标&#xff0c;相比iconfont下载文件、重新替换更节省时间。 子组件包括&#xff1a; 1. Icons.vue 存放所有SVG图标的path 2. CustomIcon.vue 通过icon的id索引对应的图标 使用的时候需要将 <Icons></Icons> 引到使用的…

面相小白的php反序列化漏洞原理剖析

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文整理反序列化漏洞的一些成因原理 建议学习反序列化之前 先对php基础语法与面向对象有个大体的了解 (我觉得我整理的比较细致&#xff0c;了解这俩是个啥就行) 漏洞实战情况 这个漏洞黑盒几乎不会被发现&am…

ReactPress:深入解析技术方案设计与源码

ReactPress Github项目地址&#xff1a;https://github.com/fecommunity/reactpress 欢迎提出宝贵的建议&#xff0c;欢迎一起共建&#xff0c;感谢Star。 ReactPress是一个基于React框架开发的开源发布平台&#xff0c;它不仅仅是一个简单的博客系统&#xff0c;更是一个功能全…

canal1.1.7使用canal-adapter进行mysql同步数据

重要的事情说前面&#xff0c;canal1.1.8需要jdk11以上&#xff0c;大家自行选择&#xff0c;我这由于项目原因只能使用1.1.7兼容版的 文章参考地址&#xff1a; canal 使用详解_canal使用-CSDN博客 使用canal.deployer-1.1.7和canal.adapter-1.1.7实现mysql数据同步_mysql更…

SpringBoot之定时任务

1. 前言 本篇博客是个人的经验之谈&#xff0c;不是普适的解决方案。阅读本篇博客的朋友&#xff0c;可以参考这里的写法&#xff0c;如有不同的见解和想法&#xff0c;欢迎评论区交流。如果此篇博客对你有帮助&#xff0c;感谢点个赞~ 2. 场景 我们讨论在单体项目&#xff0c…

绿色能源发展关键:优化风电运维体系

根据QYResearch调研团队最新发布的《全球风电运维市场报告2023-2029》显示&#xff0c;预计到2029年&#xff0c;全球风电运维市场的规模将攀升至307.8亿美元&#xff0c;并且在接下来的几年里&#xff0c;其年复合增长率&#xff08;CAGR&#xff09;将达到12.5%。 上述图表及…

前端 Canvas 绘画 总结

目录 一、使用案例 1、基础使用案例 2、基本案例改为直接JS实现 二、相关资料 1、API教程文档 2、炫酷案例 一、使用案例 1、基础使用案例 使用Canvas的基本步骤&#xff1a; 1、需要一个canvas标签 2、需要获取 画笔 对象 3、使用canvas提供的api进行绘图 <!--…

力扣排序455题(分发饼干)

假设你是一位很棒的家长&#xff0c;想要给你的孩子们一些小饼干。 但是&#xff0c;每个孩子最多只能给一块饼干。 对每个孩子 i&#xff0c;都有一个胃口值 g[i],这是能 让孩子们满足胃口的饼干的最小尺寸;并且每块饼 干j&#xff0c;都有一个尺寸 s[j]。如果 s[j]> g[i]&…

C语言 | Leetcode C语言题解之第537题复数乘法

题目&#xff1a; 题解&#xff1a; bool parseComplexNumber(const char * num, int * real, int * image) {char *token strtok(num, "");*real atoi(token);token strtok(NULL, "i");*image atoi(token);return true; };char * complexNumberMulti…

Android使用scheme方式唤醒处于后台时的App场景

场景&#xff1a;甲App唤醒处于后台时的乙App的目标界面Activity&#xff0c;且乙App的目标界面Activity处于最上层&#xff0c;即已经打开状态&#xff0c;要求甲App使用scheme唤醒乙App时&#xff0c;达到跟从桌面icon拉起App效果一致&#xff0c;不能出现只拉起了乙App的目标…

如何对接低价折扣相对稳定电影票渠道?

对接低价折扣电影票渠道需要经过一系列步骤&#xff0c;以确保能够为用户提供优惠且可靠的购票体验。以下是一个基本的对接流程&#xff1a; 1.市场调研&#xff1a; 调研市场上的电影票销售渠道&#xff0c;了解主要的电影票批发商和分销商。分析竞争对手的折扣电影票服务&a…

【上云拼团Go】如何在腾讯云双十一活动中省钱

1. 前言 双十一已经成为了全球最大的购物狂欢节&#xff0c;除了电商平台的优惠&#xff0c;云计算服务商也纷纷在这个期间推出了诱人的促销活动。腾讯云作为中国云计算的领军企业之一&#xff0c;每年双十一的活动都吸引了大量开发者、企业和个人用户参与。那么&#xff0c;在…

新能源企业在精益变革过程中可能会遇到哪些困难?

在绿色转型的浪潮中&#xff0c;新能源企业作为推动社会可持续发展的先锋力量&#xff0c;正加速迈向精益化管理的新阶段。然而&#xff0c;这条变革之路并非坦途&#xff0c;而是布满了未知与挑战。本文&#xff0c;天行健王春城老师将深入探讨新能源企业在精益变革过程中可能…

Maven的安装配置

文章目录 一、MVN 的下载二、配置maven2.1、更改maven/conf/settings.xml配置2.2、配置环境变量一、MVN 的下载 还是那句话,要去就去官网或者github,别的地方不要去下载。我们下载binaries/ 目录下的 cd /opt/server wget https://downloads.apache.org/maven/maven-3/3.9.6/…

OpenCV视觉分析之目标跟踪(10)估计两个点集之间的刚性变换函数estimateRigidTransform的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 计算两个2D点集之间的最优仿射变换 estimateRigidTransform 是 OpenCV 中的一个函数&#xff0c;用于估计两个点集之间的刚性变换&#xff08;即…