YOLOv5 + SE Attention: Improving Object Detection Performance in Practice

1. Introduction

Object detection is a core task in computer vision, widely used in autonomous driving, security surveillance, industrial inspection, and other fields. YOLOv5, one of the most widely adopted models in the YOLO family, stands out in practice for its efficiency and accuracy. However, as application scenarios grow more complex, conventional convolutional networks can hit performance bottlenecks when handling cluttered backgrounds and multi-scale targets. Introducing an attention mechanism is an effective way to address this. This article walks through adding the SE (Squeeze-and-Excitation) attention mechanism to YOLOv5: modifying the model configuration file and the code, then comparing training results before and after.

Compared with earlier YOLO versions, YOLOv5 introduces improvements to the model architecture, training strategy, and data augmentation that noticeably raise both performance and efficiency. Its main features include:

  • Optimized architecture: a redesigned backbone and path-aggregation neck improve feature extraction and feature fusion.
  • Data augmentation: methods such as Mosaic and MixUp improve the model's generalization.
  • Training strategy improvements: automatic anchor computation (AutoAnchor) and refined loss weighting raise training efficiency and detection accuracy.

Still, as task complexity increases, a plain convolutional network often falls short on multi-scale targets, and introducing the SE attention mechanism offers a new way to raise detection accuracy.

2. YOLOv5 and the SE Attention Mechanism

2.1 A Brief Look at YOLOv5

YOLOv5 is widely used in object detection thanks to its efficiency and accuracy. Its architecture has three main components:

  • Backbone: extracts features from the input image.
  • Neck: fuses features to improve multi-scale perception.
  • Head: makes predictions from the extracted features.

2.2 The SE Attention Mechanism

SE (Squeeze-and-Excitation) is a lightweight attention module that improves a model's representational power by explicitly modeling the dependencies between channels. It consists of two key parts:

  • Squeeze: global average pooling collapses the spatial dimensions of the feature map to 1×1, producing a per-channel descriptor.
  • Excitation: two fully connected layers followed by a Sigmoid produce channel weights that recalibrate the channel responses of the feature map.

With SE modules, YOLOv5 can emphasize informative feature channels and suppress less useful ones, improving overall performance.
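In compact form, for an input feature map x with C channels (GAP denotes global average pooling, and W1, W2 are the two fully connected layers with reduction ratio r = 16, matching the code in section 3.2):

s = Sigmoid(W2 · ReLU(W1 · GAP(x))),   y = x ⊙ s

Each of the C channels of x is rescaled by its learned weight in s, so informative channels are amplified and uninformative ones damped.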

3. Implementing YOLOv5 + SE Attention

3.1 Modifying the Model Configuration File

To introduce the SE attention mechanism into YOLOv5, three files need to change: common.py, yolo.py, and the model configuration yaml (saved here as yolov5_se.yaml), which adds the SE module to the Backbone, with optional, commented-out placements in the head, for example just before the P3 output. Note that inserting the SE module shifts the indices of every subsequent layer, so cross-layer references (the Concat sources and the Detect inputs) must be renumbered to match. The modified configuration file is shown below:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SENet, [1024]],  # 9 SEAttention
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],  # 11
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14
   [-1, 1, Conv, [256, 1, 1]],  # 15
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   # [-1, 1, SENet, [512]],  # optional SEAttention before the P3 output (uncommenting shifts later indices again)
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4 (was 14 before SE shifted the indices)
   # [-1, 1, SENet, [512]],  # optional SEAttention
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5 (was 10 before SE shifted the indices)
   # [-1, 1, SENet, [1024]],  # optional SEAttention
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)
   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
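Before training, it is worth confirming that the modified yaml actually parses and builds. YOLOv5's model script can be run directly against a config file from the repository root (the yaml path below is wherever you saved the file, an assumption):

python models/yolo.py --cfg models/yolov5_se.yaml

The per-layer table printed during construction should show the SENet layer at index 9 and Detect taking its inputs from layers 18, 21, and 24.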

3.2 Code for the SE Attention Module

Next, the SE module itself has to be implemented in YOLOv5's code, typically in models/common.py. Here is an implementation, named SENet so that it matches the yaml:

import torch
import torch.nn as nn


class SENet(nn.Module):
    # Squeeze-and-Excitation block; the signature (c1, c2, n=1, shortcut=True, g=1, e=0.5)
    # mirrors other YOLOv5 modules so parse_model can construct it uniformly
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super(SENet, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # squeeze: B x C x H x W -> B x C x 1 x 1
        self.l1 = nn.Linear(c1, c1 // 16, bias=False)  # excitation: reduce channels by ratio 16
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // 16, c1, bias=False)  # restore the channel dimension
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)  # per-channel descriptor
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)  # channel weights in (0, 1)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)  # recalibrate the channels of x
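As a quick sanity check (a minimal sketch with an arbitrary feature-map size), the module must return a tensor of the same shape it receives:

x = torch.randn(1, 1024, 20, 20)   # B x C x H x W
se = SENet(1024, 1024)             # c1 = c2 at the insertion point
print(se(x).shape)                 # torch.Size([1, 1024, 20, 20])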

3.3 Wiring the SE Module into the Model

To let YOLOv5 actually construct SE layers from the yaml, the parse_model function in models/yolo.py must be extended to recognize SENet. The modified parse_model is shown below; it also registers several other attention modules (CBAM, CA, ECA, and GAM variants) that can be wired in the same way:

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, SENet}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x, CBAMBottleneck, CABottleneck, CBAMC3, SENet, CANet,
                     CAC3, CBAM, ECANet, GAMNet}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
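One more wiring detail: SENet must be visible inside yolo.py for the eval(m) call above to resolve it. If the class was added to models/common.py as in section 3.2, the star import already at the top of yolo.py (from models.common import *) picks it up; otherwise add an explicit import, for example:

from models.common import SENet  # adjust the module path to wherever SENet is defined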

3.4 Training and Comparing Results

With the configuration file and code changes in place, training can begin. COCO or any custom dataset works for training and validation; here I used a custom dataset, camel_elephant_training, which contains only two classes (camel and elephant), and trained for 100 epochs.
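A minimal sketch of the training invocation (the dataset yaml name data/camel_elephant.yaml and the batch size are assumptions; adjust them to your setup):

python train.py --cfg models/yolov5_se.yaml --data data/camel_elephant.yaml --weights yolov5s.pt --epochs 100 --img 640 --batch-size 16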

After training, AP (average precision) can be used to compare model performance before and after introducing the SE attention mechanism. In general, with SE modules the model handles cluttered backgrounds and multi-scale targets noticeably better.
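For a like-for-like comparison, the same validation command can be run on both the baseline and the SE checkpoints. A minimal sketch (the paths follow YOLOv5's default runs/train/exp layout and, like the dataset yaml name, are assumptions):

python val.py --weights runs/train/exp/weights/best.pt --data data/camel_elephant.yaml --img 640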

As for the training results: since time was limited I trained only 100 epochs, whereas 150 to 200 would normally be set, and judging from the train/obj_loss curve there was still room for the loss to fall.

3.5 Training Steps

  1. Set up the training environment and make sure YOLOv5 and its dependencies are installed.
  2. Download the COCO dataset or prepare a custom dataset for training.
  3. Modify the training script to load the updated model configuration file yolov5_se.yaml.
  4. Start training and monitor loss and accuracy throughout.
  5. After training finishes, evaluate the model on the validation set.

3.6 Model Deployment

The trained weights can be converted to .onnx format with export.py, after which they can be deployed on a wide range of platforms. A typical export command is sketched below, followed by the full export.py source from the YOLOv5 repository for reference.
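A minimal sketch of the export invocation (the weights path is an assumption; --opset 12 matches the script's default):

python export.py --weights runs/train/exp/weights/best.pt --include onnx --opset 12 --img 640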

import argparse
import contextlib
import json
import os
import platform
import re
import subprocess
import sys
import time
import warnings
from pathlib import Path

import pandas as pd
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
from utils.torch_utils import select_device, smart_inference_mode

MACOS = platform.system() == 'Darwin'  # macOS environment


def export_formats():
    # YOLOv5 export formats
    x = [
        ['PyTorch', '-', '.pt', True, True],
        ['TorchScript', 'torchscript', '.torchscript', True, True],
        ['ONNX', 'onnx', '.onnx', True, True],
        ['OpenVINO', 'openvino', '_openvino_model', True, False],
        ['TensorRT', 'engine', '.engine', False, True],
        ['CoreML', 'coreml', '.mlmodel', True, False],
        ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
        ['TensorFlow GraphDef', 'pb', '.pb', True, True],
        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
        ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
        ['TensorFlow.js', 'tfjs', '_web_model', False, False],
        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


def try_export(inner_func):
    # YOLOv5 export decorator, i.e. @try_export
    inner_args = get_default_args(inner_func)

    def outer_func(*args, **kwargs):
        prefix = inner_args['prefix']
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
            return f, model
        except Exception as e:
            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
            return None, None

    return outer_func


@try_export
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
    f = file.with_suffix('.torchscript')

    ts = torch.jit.trace(model, im, strict=False)
    d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
    extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
    if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
        optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
    else:
        ts.save(str(f), _extra_files=extra_files)
    return f, None


@try_export
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    check_requirements('onnx')
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
    if dynamic:
        dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
        if isinstance(model, SegmentationModel):
            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
            dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
        elif isinstance(model, DetectionModel):
            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        f,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic or None)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    onnx.save(model_onnx, f)

    # Simplify
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'assert check failed'
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx


@try_export
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
    # YOLOv5 OpenVINO export
    check_requirements('openvino-dev')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
    import openvino.inference_engine as ie

    LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
    f = str(file).replace('.pt', f'_openvino_model{os.sep}')

    cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
    subprocess.run(cmd.split(), check=True, env=os.environ)  # export
    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
    return f, None


@try_export
def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
    # YOLOv5 Paddle export
    check_requirements(('paddlepaddle', 'x2paddle'))
    import x2paddle
    from x2paddle.convert import pytorch2paddle

    LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
    f = str(file).replace('.pt', f'_paddle_model{os.sep}')

    pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im])  # export
    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
    return f, None


@try_export
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
    # YOLOv5 CoreML export
    check_requirements('coremltools')
    import coremltools as ct

    LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
    f = file.with_suffix('.mlmodel')

    ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
    ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
    bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
    if bits < 32:
        if MACOS:  # quantization only supported on macOS
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
        else:
            print(f'{prefix} quantization only supported on macOS, skipping...')
    ct_model.save(f)
    return f, ct_model


@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
        import tensorrt as trt
    except Exception:
        if platform.system() == 'Linux':
            check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
        import tensorrt as trt

    if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    onnx = file.with_suffix('.onnx')

    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    assert onnx.exists(), f'failed to export ONNX file: {onnx}'
    f = file.with_suffix('.engine')  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)

    LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    return f, None


@try_export
def export_saved_model(model,
                       im,
                       file,
                       dynamic,
                       tf_nms=False,
                       agnostic_nms=False,
                       topk_per_class=100,
                       topk_all=100,
                       iou_thres=0.45,
                       conf_thres=0.25,
                       keras=False,
                       prefix=colorstr('TensorFlow SavedModel:')):
    # YOLOv5 TensorFlow SavedModel export
    try:
        import tensorflow as tf
    except Exception:
        check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
        import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    from models.tf import TFModel

    LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
    f = str(file).replace('.pt', '_saved_model')
    batch_size, ch, *imgsz = list(im.shape)  # BCHW

    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
    im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
    _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
    inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
    outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
    keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    keras_model.trainable = False
    keras_model.summary()
    if keras:
        keras_model.save(f, save_format='tf')
    else:
        spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(spec)
        frozen_func = convert_variables_to_constants_v2(m)
        tfm = tf.Module()
        tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
        tfm.__call__(im)
        tf.saved_model.save(tfm,
                            f,
                            options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
                            if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
    return f, keras_model


@try_export
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
    # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
    f = file.with_suffix('.pb')

    m = tf.function(lambda x: keras_model(x))  # full model
    m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
    frozen_func = convert_variables_to_constants_v2(m)
    frozen_func.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
    return f, None


@try_export
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
    # YOLOv5 TensorFlow Lite export
    import tensorflow as tf

    LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
    batch_size, ch, *imgsz = list(im.shape)  # BCHW
    f = str(file).replace('.pt', '-fp16.tflite')

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    converter.target_spec.supported_types = [tf.float16]
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if int8:
        from models.tf import representative_dataset_gen
        dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
        converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.target_spec.supported_types = []
        converter.inference_input_type = tf.uint8  # or tf.int8
        converter.inference_output_type = tf.uint8  # or tf.int8
        converter.experimental_new_quantizer = True
        f = str(file).replace('.pt', '-int8.tflite')
    if nms or agnostic_nms:
        converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

    tflite_model = converter.convert()
    open(f, "wb").write(tflite_model)
    return f, None


@try_export
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
    # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
    cmd = 'edgetpu_compiler --version'
    help_url = 'https://coral.ai/docs/edgetpu/compiler/'
    assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
    if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
        LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
        sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
        for c in (
                'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
                'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
                'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
            subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
    ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

    LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
    f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU model
    f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model

    cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
    subprocess.run(cmd.split(), check=True)
    return f, None


@try_export
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
    # YOLOv5 TensorFlow.js export
    check_requirements('tensorflowjs')
    import tensorflowjs as tfjs

    LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
    f = str(file).replace('.pt', '_web_model')  # js dir
    f_pb = file.with_suffix('.pb')  # *.pb path
    f_json = f'{f}/model.json'  # *.json path

    cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
          f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
    subprocess.run(cmd.split())

    json = Path(f_json).read_text()
    with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
        subst = re.sub(
            r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}, '
            r'"Identity.?.?": {"name": "Identity.?.?"}}}',
            r'{"outputs": {"Identity": {"name": "Identity"}, '
            r'"Identity_1": {"name": "Identity_1"}, '
            r'"Identity_2": {"name": "Identity_2"}, '
            r'"Identity_3": {"name": "Identity_3"}}}',
            json)
        j.write(subst)
    return f, None


def add_tflite_metadata(file, metadata, num_outputs):
    # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
    with contextlib.suppress(ImportError):
        # check_requirements('tflite_support')
        from tflite_support import flatbuffers
        from tflite_support import metadata as _metadata
        from tflite_support import metadata_schema_py_generated as _metadata_fb

        tmp_file = Path('/tmp/meta.txt')
        with open(tmp_file, 'w') as meta_f:
            meta_f.write(str(metadata))

        model_meta = _metadata_fb.ModelMetadataT()
        label_file = _metadata_fb.AssociatedFileT()
        label_file.name = tmp_file.name
        model_meta.associatedFiles = [label_file]

        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
        subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
        model_meta.subgraphMetadata = [subgraph]

        b = flatbuffers.Builder(0)
        b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        metadata_buf = b.Output()

        populator = _metadata.MetadataPopulator.with_model_file(file)
        populator.load_metadata_buffer(metadata_buf)
        populator.load_associated_files([str(tmp_file)])
        populator.populate()
        tmp_file.unlink()


@smart_inference_mode()
def run(
        data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
        weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        keras=False,  # use Keras
        optimize=False,  # TorchScript: optimize for mobile
        int8=False,  # CoreML/TF INT8 quantization
        dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        verbose=False,  # TensorRT: verbose log
        workspace=4,  # TensorRT: workspace size (GB)
        nms=False,  # TF: add NMS to model
        agnostic_nms=False,  # TF: add agnostic NMS to model
        topk_per_class=100,  # TF.js NMS: topk per class to keep
        topk_all=100,  # TF.js NMS: topk for all classes to keep
        iou_thres=0.45,  # TF.js NMS: IoU threshold
        conf_thres=0.25,  # TF.js NMS: confidence threshold
):
    t = time.time()
    include = [x.lower() for x in include]  # to lowercase
    fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
    flags = [x in include for x in fmts]
    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleans
    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights

    # Load PyTorch model
    device = select_device(device)
    if half:
        assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
        assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
    model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model

    # Checks
    imgsz *= 2 if len(imgsz) == 1 else 1  # expand
    if optimize:
        assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

    # Update model
    model.eval()
    for k, m in model.named_modules():
        if isinstance(m, Detect):
            m.inplace = inplace
            m.dynamic = dynamic
            m.export = True

    for _ in range(2):
        y = model(im)  # dry runs
    if half and not coreml:
        im, model = im.half(), model.half()  # to FP16
    shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
    metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

    # Exports
    f = [''] * len(fmts)  # exported filenames
    warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
    if jit:  # TorchScript
        f[0], _ = export_torchscript(model, im, file, optimize)
    if engine:  # TensorRT required before ONNX
        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
    if onnx or xml:  # OpenVINO requires ONNX
        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
    if xml:  # OpenVINO
        f[3], _ = export_openvino(file, metadata, half)
    if coreml:  # CoreML
        f[4], _ = export_coreml(model, im, file, int8, half)
    if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
        assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
        assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
        f[5], s_model = export_saved_model(model.cpu(),
                                           im,
                                           file,
                                           dynamic,
                                           tf_nms=nms or agnostic_nms or tfjs,
                                           agnostic_nms=agnostic_nms or tfjs,
                                           topk_per_class=topk_per_class,
                                           topk_all=topk_all,
                                           iou_thres=iou_thres,
                                           conf_thres=conf_thres,
                                           keras=keras)
        if pb or tfjs:  # pb prerequisite to tfjs
            f[6], _ = export_pb(s_model, file)
        if tflite or edgetpu:
            f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
            if edgetpu:
                f[8], _ = export_edgetpu(file)
            add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
        if tfjs:
            f[9], _ = export_tfjs(file)
    if paddle:  # PaddlePaddle
        f[10], _ = export_paddle(model, im, file, metadata)

    # Finish
    f = [str(x) for x in f if x]  # filter out '' and None
    if any(f):
        cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel))  # type
        dir = Path('segment' if seg else 'classify' if cls else '')
        h = '--half' if half else ''  # --half FP16 inference arg
        s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \
            "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''
        LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                    f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                    f"\nDetect:          python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"
                    f"\nValidate:        python {dir / 'val.py'} --weights {f[-1]} {h}"
                    f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')  {s}"
                    f"\nVisualize:       https://netron.app")
    return f  # return list of exported files/dirs


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
    parser.add_argument('--keras', action='store_true', help='TF: use Keras')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
    parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
    parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
    parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
    parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
    parser.add_argument(
        '--include',
        nargs='+',
        default=['torchscript'],
        help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')
    opt = parser.parse_args()
    print_args(vars(opt))
    return opt


def main(opt):
    for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
        run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
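To confirm the exported file loads and runs, a quick onnxruntime check can help. This is a sketch: the file name best.onnx is an assumption, and the output shape shown corresponds to a 640x640 input with the two-class dataset above:

import numpy as np
import onnxruntime as ort

# Load the exported model on CPU and run a dummy forward pass
sess = ort.InferenceSession('best.onnx', providers=['CPUExecutionProvider'])
out = sess.run(None, {'images': np.zeros((1, 3, 640, 640), dtype=np.float32)})
print([o.shape for o in out])  # e.g. [(1, 25200, 7)]: 25200 candidate boxes x (4 box + 1 obj + 2 class scores)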

4. Conclusion

This article showed how to introduce the SE attention mechanism into YOLOv5, covering the configuration-file changes, the module implementation, the parse_model changes, the training steps, and a comparison of results. With SE modules in place, YOLOv5's detection accuracy on multi-scale targets and cluttered backgrounds improves. Other attention mechanisms (such as CBAM and ECA) are worth exploring next to push performance further. Thanks for reading.
