model.py
ultralytics\models\yolo\model.py
目录
model.py
1.所需的库和模块
2.class YOLO(Model):
3.class YOLOWorld(Model):
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 licensefrom pathlib import Pathfrom ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
from ultralytics.utils import yaml_load, ROOT
2.class YOLO(Model):
# 这段代码定义了一个名为 YOLO 的类,它是一个用于目标检测的模型,支持多种任务类型(如分类、检测、分割等),并能够根据不同的任务加载相应的模型和工具。
# 定义了一个名为 YOLO 的类,继承自 Model 。 Model 是一个基类,提供了通用的模型初始化和操作功能。
# class Model(nn.Module):
# -> 用于加载、训练、预测和部署YOLO模型。
# -> def __init__(self, model: Union[str, Path] = "yolov8n.pt", task: str = None, verbose: bool = False,) -> None:
class YOLO(Model):# YOLO(You Only Look Once)物体检测模型。"""YOLO (You Only Look Once) object detection model."""# 定义了 YOLO 类的初始化方法 __init__ ,它接受以下参数 :# 1.model :默认值为 "yolov8n.pt" ,表示模型文件的路径或名称。# 2.task :表示任务类型(如分类、检测等),默认为 None 。# 3.verbose :布尔值,表示是否输出详细信息,默认为 False 。def __init__(self, model="yolov8n.pt", task=None, verbose=False):# 初始化 YOLO 模型,如果模型文件名包含‘-world’则切换到 YOLOWorld。"""Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""# 将传入的 model 参数转换为 Path 对象,方便后续对路径进行操作。 Path 是 pathlib 模块中的一个类,用于处理文件路径。path = Path(model)# 检查模型路径是否包含 -world ,并且文件扩展名是否为 .pt 、 .yaml 或 .yml 。如果满足条件,说明这是一个 YOLOWorld 模型。if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model# 如果条件满足,创建一个 YOLOWorld 类的实例,并将模型路径传递给它。new_instance = YOLOWorld(path)# 将 当前实例的类动态 更改为 YOLOWorld 类的类型。这是一种动态改变实例类的方法,使得当前实例的行为与 YOLOWorld 类一致。self.__class__ = type(new_instance)# 将 当前实例的属性字典 ( __dict__ )替换为 YOLOWorld 实例的属性字典,从而继承 YOLOWorld 的所有属性和状态。self.__dict__ = new_instance.__dict__# 如果模型路径不满足 YOLOWorld 的条件,则执行默认的 YOLO 初始化逻辑。else:# Continue with default YOLO initialization# 调用父类 Model 的初始化方法,将 model 、 task 和 verbose 参数传递给父类,完成默认的 YOLO 初始化。super().__init__(model=model, task=task, verbose=verbose)# 定义了一个属性装饰器,用于将下面的方法定义为只读属性。@property# 定义了一个名为 task_map 的方法,它返回一个字典,映射了不同任务类型到对应的模型、训练器、验证器和预测器。def task_map(self):# 将头部映射到模型、训练器、验证器和预测器类别。"""Map head to model, trainer, validator, and predictor classes."""# 返回一个字典,包含不同任务类型及其对应的组件。return {# 定义了分类任务( classify )的映射。"classify": {# 指定分类任务的 模型 为 ClassificationModel 。"model": ClassificationModel,# 指定分类任务的 训练器 为 yolo.classify.ClassificationTrainer 。"trainer": yolo.classify.ClassificationTrainer,# 指定分类任务的 验证器 为 yolo.classify.ClassificationValidator 。"validator": yolo.classify.ClassificationValidator,# 指定分类任务的 预测器 为 yolo.classify.ClassificationPredictor 。"predictor": yolo.classify.ClassificationPredictor,# 结束分类任务的映射。},# 类似的结构定义了其他任务类型(如 detect 、 segment 、 pose 和 obb ),分别对应 检测 、 分割 、 姿态估计 和 定向边界框 任务。每个任务都指定了对应的模型、训练器、验证器和预测器。"detect": {"model": DetectionModel,"trainer": yolo.detect.DetectionTrainer,"validator": yolo.detect.DetectionValidator,"predictor": yolo.detect.DetectionPredictor,},"segment": {"model": SegmentationModel,"trainer": yolo.segment.SegmentationTrainer,"validator": yolo.segment.SegmentationValidator,"predictor": yolo.segment.SegmentationPredictor,},"pose": {"model": PoseModel,"trainer": yolo.pose.PoseTrainer,"validator": yolo.pose.PoseValidator,"predictor": yolo.pose.PosePredictor,},"obb": {"model": OBBModel,"trainer": yolo.obb.OBBTrainer,"validator": yolo.obb.OBBValidator,"predictor": yolo.obb.OBBPredictor,},}
# 这段代码定义了一个灵活的 YOLO 类,支持多种任务类型(分类、检测、分割等),并能够根据任务类型动态加载相应的组件。它还支持 YOLOWorld 模型的特殊处理,通过动态改变实例的类和属性,实现与 YOLOWorld 的无缝集成。这种设计使得 YOLO 类具有高度的可扩展性和灵活性,能够适应不同的任务需求和模型类型。
3.class YOLOWorld(Model):
# 这段代码定义了一个名为 YOLOWorld 的类,它是 YOLO 类的一个扩展,专门用于处理 YOLO-World 目标检测模型。
# 定义了一个名为 YOLOWorld 的类,继承自 Model 。这表明 YOLOWorld 是一个基于 Model 基类的模型类,继承了基类的通用功能。
class YOLOWorld(Model):# YOLO-World 物体检测模型。"""YOLO-World object detection model."""# 定义了 YOLOWorld 类的初始化方法 __init__ ,它接受一个参数 :# 1.model :默认值为 "yolov8s-world.pt" ,表示模型文件的路径或名称。def __init__(self, model="yolov8s-world.pt") -> None:# 使用给定的预训练模型文件初始化 YOLOv8-World 模型。支持 *.pt 和 *.yaml 格式。"""Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.Args:model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'."""# 调用父类 Model 的初始化方法,将 model 参数传递给父类,并固定任务类型为 "detect" 。这表明 YOLOWorld 专门用于目标检测任务。super().__init__(model=model, task="detect")# Assign default COCO class names when there are no custom names# 检查模型对象是否没有 names 属性。 names 属性通常用于存储类别名称。if not hasattr(self.model, "names"):# 如果模型没有 names 属性,则从默认的 COCO 数据集配置文件( coco8.yaml )中加载类别名称,并将其赋值给 self.model.names 。 yaml_load 是一个函数,用于加载 YAML 文件的内容。self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")# 定义了一个属性装饰器,用于将下面的方法定义为只读属性。@property# 定义了一个名为 task_map 的方法,它返回一个字典,映射了目标检测任务( detect )到对应的模型、验证器和预测器。def task_map(self):# 将头部映射到模型、验证器和预测器类别。"""Map head to model, validator, and predictor classes."""# 返回一个字典,包含目标检测任务的组件映射。return {# 在 YOLOWorld 类的 task_map 属性中,目标检测任务( detect )的映射中没有指定训练器( trainer )。这种设计可能是基于以下几种原因 :# YOLO-World 模型的特殊性 : YOLO-World 模型可能是一个预训练模型,主要用于推理(预测)和验证,而不是训练。在这种情况下,训练器可能并不需要,或者训练过程是由其他工具或模块完成的。因此, task_map 中只提供了验证器( validator )和预测器( predictor ),用于模型的评估和推理。# 训练器的通用性 : YOLO-World 模型可能使用与普通 YOLO 模型相同的训练器。在这种情况下,训练器不需要在 YOLOWorld 类中单独指定,而是可以复用父类 YOLO 或其他通用训练器。例如,在 YOLO 类的 task_map 中,目标检测任务的训练器被定义为 yolo.detect.DetectionTrainer 。如果 YOLO-World 模型可以使用相同的训练器,那么在 YOLOWorld 类中重复定义是多余的。# 代码设计的简洁性 :为了避免代码冗余, YOLOWorld 类可能故意省略了训练器的定义。如果训练器的实现与父类或其他模块相同,那么在 task_map 中重复定义可能会导致代码冗余和维护复杂性。通过省略训练器,代码更加简洁,同时也能减少潜在的错误。# 动态加载训练器 :在某些情况下,训练器可能需要根据具体任务动态加载。例如,训练器可能依赖于其他配置或参数,这些参数在类初始化时并不确定。因此,训练器的加载可能被推迟到运行时,而不是在 task_map 中静态定义。# 模型的使用场景 : YOLO-World 模型可能主要被设计为一个推理模型,用于快速部署和使用。在这种情况下,训练器可能不是该类的主要关注点,而是由其他模块或工具负责模型的训练过程。# 总结 :YOLOWorld 类的 task_map 属性中没有指定目标检测任务的训练器,可能是由于以下原因之一。 YOLO-World 模型主要用于推理,不需要单独的训练器。训练器可能与父类或其他模块共享,无需重复定义。训练器可能需要动态加载,不适合在 task_map 中静态定义。模型的设计重点是推理和验证,训练过程由其他工具完成。这种设计反映了代码的灵活性和简洁性,同时也符合 YOLO-World 模型的实际使用场景。# 定义了目标检测任务( detect )的映射。"detect": {# 指定目标检测任务的模型为 WorldModel 。"model": WorldModel,# 指定目标检测任务的验证器为 yolo.detect.DetectionValidator 。"validator": yolo.detect.DetectionValidator,# 指定目标检测任务的预测器为 yolo.detect.DetectionPredictor 。"predictor": yolo.detect.DetectionPredictor,}# 结束目标检测任务的映射。}# 定义了一个名为 set_classes 的方法,用于设置模型的类别名称。# 1.classes :类别名称列表。def set_classes(self, classes):# 设置类别。"""Set classes.Args:classes (List(str)): A list of categories i.e ["person"]."""# 调用模型的 set_classes 方法,将传入的 类别名称列表 传递给模型。self.model.set_classes(classes)# Remove background if it's given# 定义了一个变量 background ,值为空格,用于表示背景类别。background = " "# 检查传入的类别名称列表中是否包含背景类别。if background in classes:# 如果包含背景类别,则从类别列表中移除它。classes.remove(background)# 将更新后的类别名称列表赋值给模型的 names 属性。self.model.names = classes# Reset method class names# self.predictor = None # reset predictor otherwise old names remain# 检查是否存在预测器实例。if self.predictor:# 如果存在预测器,则更新预测器模型的类别名称。self.predictor.model.names = classes
# YOLOWorld 类是一个专门用于目标检测的模型类,继承自 Model 。它具有以下特点。固定任务类型: YOLOWorld 仅支持目标检测任务( detect ),并通过 task_map 提供了任务所需的模型、验证器和预测器。默认类别名称:如果模型没有定义类别名称,则从 COCO 数据集配置文件中加载默认类别名称。动态类别设置:通过 set_classes 方法,用户可以动态设置模型的类别名称,并在需要时移除背景类别。同时,该方法还会同步更新预测器的类别名称。扩展性: YOLOWorld 类通过继承 Model ,继承了基类的通用功能,同时针对 YOLO-World 模型进行了定制化扩展。这种设计使得 YOLOWorld 类能够灵活地处理目标检测任务,同时支持动态类别设置和默认类别加载,适用于多种应用场景。