Detectron2和LSTM进行人体动作识别

1. 项目简介

本项目旨在开发一个人体动作识别系统,利用深度学习模型Detectron2和LSTM(长短时记忆网络)实现对视频中人体动作的精确识别与分类。项目背景是由于在现代智能监控、健康管理、体育分析等领域中,对人体动作的自动识别和分析需求日益增加。传统的计算机视觉方法在复杂环境下表现有限,而结合先进的深度学习模型能够在动态场景中捕捉细微的动作变化,从而提升识别准确率。本项目的主要目标是通过Detectron2进行人体关键点检测,将检测到的时序特征输入LSTM模型中,从而捕捉动作序列的时间依赖性并进行分类。Detectron2作为一个基于Mask R-CNN的深度学习模型,擅长对象检测和实例分割,能够准确定位人体的各个关键点,而LSTM则用于处理时间序列数据,学习动作随时间变化的模式。整体模型在训练数据集上进行训练后,能够对视频流中的每帧进行实时分析,从而实现对复杂人体动作的自动识别。项目应用场景广泛,包括智能监控、行为分析、体育动作指导及康复训练等,具备极高的实际应用价值。
在这里插入图片描述

在这里插入图片描述

Detectron2

Detectron2是 Facebook AI Research 的开源平台,用于物体检测、密集姿势、分割和其他视觉识别任务。该平台现在在 PyTorch 中实现,而其之前的版本Detectron则在 Caffe2 中实现。

在这里,我们使用来自 Detectron2 模型库的预训练“R50-FPN”模型进行姿势估计。该模型已经在 COCO 数据集上进行了训练,该数据集包含超过 200,000 张图像和 250,000 个人物实例,并标有关键点。该模型为输入图像帧中存在的每个人输出 17 个关键点,如下图所示。

在这里插入图片描述

人体上有 17 个关键点。图片左侧显示一个人,中间部分显示关键点列表,右侧显示人体上关键点的位置

2.技术创新点摘要

本项目结合了Detectron2和LSTM的优点,在人体动作识别领域实现了创新性的模型设计和数据处理方式。其技术创新点体现在以下几个方面:

  1. 多模型融合与关键点映射: 本项目采用Detectron2作为基础的关键点检测模型,并利用其在目标检测和实例分割方面的优势,精准识别人体的17个关键点。同时,为了解决不同深度学习框架(如OpenPose和Detectron2)之间的关键点顺序不一致的问题,项目引入了关键点映射机制(openpose_to_detectron_mapping),实现了不同检测格式间的数据统一,从而提升了整体模型的兼容性和准确性。
  2. 时序信息捕捉与动作识别: 通过LSTM(长短时记忆网络)对人体关键点数据进行时间序列建模,模型能够捕捉连续帧之间的动态变化。这种设计使得项目能够有效处理复杂的动态动作,如挥手、踢腿等,需要考虑时间维度的动作识别场景。具体实现中,模型采用了32帧的滑动窗口(WINDOW_SIZE),将连续帧划分为时序数据块输入LSTM,从而提升了模型对长时间依赖性的处理能力。
  3. 基于Pytorch Lightning的模块化设计: 项目采用Pytorch Lightning框架,实现了模型的模块化设计,并通过PoseDataModule类将数据加载、预处理、训练和评估过程进行有效封装。这种设计方式降低了代码复杂度,增强了模型的可扩展性和复用性,使得后续模型优化和调参更加方便。
  4. 关键点格式转换与数据增强: 在数据预处理中,通过convert_to_detectron_format函数将OpenPose格式的关键点数据转换为Detectron2标准格式,并结合了自定义的过滤和数据增强策略,从而解决了不同检测模型间数据格式不一致的问题,提升了训练数据的质量。

3. 数据集与预处理

本项目使用的主要数据集是人体动作识别公开数据集,包含大量人体关键点位置信息以及对应的动作标签。

它由使用OpenPose深度学习模型在伯克利多模态人类动作数据库 (MHAD)数据集的子集上进行的关键点检测组成。

OpenPose 是首个实时多人系统,可在单张图片上联合检测人体、手、脸和脚的关键点(共 135 个关键点)。关键点检测基于 12 个主体的视频(从 4 个角度拍摄),执行以下 6 个动作并重复 5 .

在这里插入图片描述

数据集的特点在于:每个样本是基于视频序列提取的连续人体关键点数据,每个动作由多帧(如32帧)组成,并且每帧都标注了人体不同关节点(如头部、肩膀、肘部、膝盖等)的位置坐标。相较于单帧图像数据,这种时序数据具有明显的时间依赖性,可以反映动作在时间维度上的连续性与动态变化。

该数据集由 12 名受试者组成,他们重复执行以下 6 个动作 5 次,从 4 个角度拍摄,每个角度重复 5 次。

  • 跳跃,
  • 跳跃运动,
  • 拳击,
  • 挥手,
  • 挥手,
  • 鼓掌。

总共有 1438 个视频(缺少 2 个),由 211200 个独立帧组成。

数据预处理流程
  1. 关键点格式转换: 数据集中部分样本是通过OpenPose生成的18个关键点序列,而项目中使用的Detectron2只输出17个关键点。为此,数据预处理中实现了关键点格式转换,将OpenPose生成的关键点数据重新排列,并通过去除不必要的点(如脖子点),将其标准化为Detectron2格式,以确保不同数据源的格式一致性。
  2. 归一化处理: 在数据输入模型之前,所有关键点坐标都会经过归一化处理,将其坐标值缩放到[0, 1]区间。这样做的目的是消除人体大小和不同图像分辨率对关键点坐标的影响,从而提高模型的鲁棒性。
  3. 数据分块与滑动窗口: 针对时序特征的处理,采用滑动窗口策略(如32帧),将连续的关键点序列按照固定的窗口大小进行分块,从而形成时序数据。每个时序数据块对应一个完整的动作标签。
  4. 数据增强与转换: 项目还通过随机删除、镜像变换等数据增强手段来增加训练样本的多样性,防止模型过拟合。同时,数据转换过程中,模型会自动识别并填充缺失值,确保输入特征的完整性。

4. 模型架构

1) 模型结构的逻辑

本项目的深度学习模型架构主要由Detectron2和LSTM模型两部分构成。首先,Detectron2用于人体关键点的检测,然后将其输出作为时间序列输入LSTM进行动作分类。具体模型结构如下:

  1. 关键点检测模型(Detectron2) :Detectron2模型基于Mask R-CNN架构,负责从视频帧中提取人体的17个关键点坐标数据。其骨干网络(Backbone)使用ResNet结构,提取每帧图像的特征图,并通过RPN(区域提议网络)生成目标区域,从而精确识别各个关节点的位置坐标。

  2. 动作分类模型(LSTM)

    1. LSTM层: 输入特征为人体关键点的二维坐标(input_features),每帧代表一个时间步(Time Step)。模型采用标准的LSTM结构:
    2. h t , c t = LSTM ( X t , h t − 1 , c t − 1 ) h_t, c_t = \text{LSTM}(X_t, h_{t-1}, c_{t-1}) ht,ct=LSTM(Xt,ht1,ct1)
    3. 其中,ht 和 ct 分别表示LSTM的隐状态和记忆状态, Xt 为当前时间步的输入特征。LSTM通过隐藏层状态捕捉不同帧之间的时间依赖性。
    4. 线性层: LSTM的最终隐状态输出通过线性层(self.linear)映射到动作类别:
      1. y = W ⋅ h t + b y = W ⋅ h t + b y = W ⋅ h t + b y=W⋅ht+by = W \cdot h_t + by=W⋅ht+b y=Wht+by=Wht+by=Wht+b
    5. 其中,WWW 为线性变换矩阵, hth_tht 为LSTM的输出, yyy 为预测的动作类别得分。此处的输出大小为动作类别的总数(TOT_ACTION_CLASSES)。
  3. 激活函数与损失函数:

    1. LSTM输出通过softmax激活函数转换为每个动作类别的概率分布: P ( y ) = softmax ( y ) P(y) = \text{softmax}(y) P(y)=softmax(y)
    2. 损失函数采用交叉熵损失(CrossEntropyLoss)来衡量预测结果与真实标签之间的差距,从而指导模型优化。
2) 模型的整体训练流程
  1. 数据预处理与加载: 数据首先通过PoseDataModule进行加载和预处理(包括关键点格式转换、归一化等),并通过PyTorch Lightning的DataLoader模块进行批次(Batch)的分发。

  2. 训练流程: 使用PyTorch Lightning框架构建标准化训练流程:

    1. 前向传播: 输入人体关键点序列,经过LSTM层处理后通过线性层输出动作类别预测结果。
    2. 反向传播与优化: 使用Adam优化器调整模型权重,最小化损失函数值。模型还引入了学习率调度器来动态调整学习率,防止过拟合。
  3. 模型评估指标: 模型评估采用以下几个指标:

    1. 验证损失( val_loss ): 用于评估模型在验证集上的性能。
    2. 分类准确率(Accuracy): 衡量模型预测的类别与真实标签的匹配程度。
    3. F1-Score: 评估模型在多类别动作识别中的精准度和召回率。
  4. 模型保存与回调机制: 训练过程中通过ModelCheckpoint保存验证集损失最低的模型,同时引入EarlyStopping策略,防止模型过拟合。

我们对模型进行了400 次训练,验证准确率达到0.913。验证准确率和损失曲线如下所示。训练后的模型被录入代码库,并在推理过程中使用。

在“Tesla T4”GPU 上进行的测试表明,Detectron2 推理大约需要0.14 秒,而 LSTM推理大约需要 0.002 秒。因此,如果我们处理视频中的每一帧,我们的推理管道执行的总FPS(每秒帧数)约为每秒 6 帧。

在这里插入图片描述

5. 核心代码详细讲解

1. 模型架构构建 (LSTM 构建部分 - lstm.py)

class ActionClassificationLSTM(pl.LightningModule):# 初始化方法,定义模型的基本结构和超参数def init(self, input_features, hidden_dim, learning_rate=0.001):super().__init__()# 保存初始化时传入的超参数,用于模型重建时复用self.save_hyperparameters()# 构建LSTM层:接收的输入为关键点数据,输出为隐藏状态的维度为 hidden_dimself.lstm = nn.LSTM(input_features, hidden_dim, batch_first=True)# 构建线性全连接层:将 LSTM 的输出映射到动作类别空间(分类任务的最终输出层)self.linear = nn.Linear(hidden_dim, TOT_ACTION_CLASSES)
  • 解释:

    • class ActionClassificationLSTM: 定义一个基于 LSTM 的动作分类模型。继承自 PyTorch Lightning 的 LightningModule,便于实现训练和评估模块化。
    • def __init__(self, input_features, hidden_dim, learning_rate=0.001): 初始化模型时定义输入特征维度 (input_features)、LSTM 隐藏层维度 (hidden_dim) 以及学习率 (learning_rate)。
    • self.save_hyperparameters(): 保存模型超参数,用于后续调用或模型的复现。
    • self.lstm = nn.LSTM(input_features, hidden_dim, batch_first=True): 创建一个 LSTM 网络层。input_features为输入特征维度,hidden_dim为 LSTM 隐藏层的维度。batch_first=True 表示输入的第一个维度为 batch size。
    • self.linear = nn.Linear(hidden_dim, TOT_ACTION_CLASSES): 定义一个线性层,将 LSTM 的输出(隐藏状态)映射到动作类别空间。TOT_ACTION_CLASSES表示动作类别的总数。

2. 数据预处理与特征工程 (utils.py)

def draw_line(image, p1, p2, color):# 使用 OpenCV 画线函数,在图像上绘制人体关键点之间的连线cv2.line(image, p1, p2, color, thickness=2, lineType=cv2.LINE_AA)
  • 解释:

    • def draw_line(image, p1, p2, color): 定义了一个用于在图像中绘制关键点连线的函数。
    • cv2.line(image, p1, p2, color, thickness=2, lineType=cv2.LINE_AA): 使用 OpenCV 函数在 image 图像上从 p1p2 点绘制一条颜色为 color 的直线,并设置线条厚度为 2,lineType=cv2.LINE_AA 表示使用抗锯齿算法绘制。

def find_person_indicies(scores):# 根据模型输出的置信分数找到可能是人的索引位置return [i for i, s in enumerate(scores) if s > 0.9]

  • 解释:

    • def find_person_indicies(scores): 通过对检测结果的置信度进行过滤,筛选出可能为人体的检测框。
    • return [i for i, s in enumerate(scores) if s > 0.9]: 遍历模型检测结果中的置信度列表 scores,返回大于 0.9 的索引位置,表示这些是高置信度的人体检测框。
def filter_persons(outputs):# 过滤模型输出中符合条件的人体关键点,并返回包含所有人体关键点的字典persons = {}p_indicies = find_person_indicies(outputs["instances"].scores)for x in p_indicies:# 提取检测到的关键点信息,并转换到CPU上,便于后续处理desired_kp = outputs["instances"].pred_keypoints[x][:].to("cpu")persons[x] = desired_kpreturn (persons, p_indicies)
  • 解释:

    • def filter_persons(outputs): 对检测模型输出的结果进行过滤,获取符合条件的人的关键点位置。
    • persons = {}: 初始化一个空字典用于存储人体关键点。
    • p_indicies = find_person_indicies(outputs["instances"].scores): 调用 find_person_indicies,获取检测到的高置信度人体索引。
    • for x in p_indicies: 遍历所有检测到的高置信度人体。
    • desired_kp = outputs["instances"].pred_keypoints[x][:].to("cpu"): 将每个检测到的人的关键点数据提取并转换到 CPU 中进行存储。
    • persons[x] = desired_kp: 将每个人体的关键点存入 persons 字典中,以检测框索引作为键值。

3. 模型训练与评估 (train.py)

checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor='val_loss')
  • 解释:

    • checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor='val_loss'): 定义模型检查点回调,用于在训练过程中保存验证集损失(val_loss)最低的模型权重。save_top_k=1 表示只保存验证集上最优的模型。
callbacks=[EarlyStopping(monitor='train_loss', patience=15), checkpoint_callback, lr_monitor]
  • 解释:

    • callbacks=[EarlyStopping(monitor='train_loss', patience=15), checkpoint_callback, lr_monitor]: 定义训练回调函数列表。
    • EarlyStopping(monitor='train_loss', patience=15): 使用早停策略(EarlyStopping)监控训练损失(train_loss),当验证集损失在 15 个 epoch 内不再下降时,停止训练,防止模型过拟合。
    • checkpoint_callback: 保存最优模型权重。
    • lr_monitor: 监控学习率的变化(未在此片段中完全展开)。

6. 模型优缺点评价

模型优点:
  1. 多模型融合: 本项目结合了Detectron2和LSTM的优点,前者用于人体关键点的精准定位,后者用于捕捉动作的时序依赖性。这种多模型融合策略能够在复杂场景中实现更高的动作识别精度。
  2. 时序信息捕捉能力强: 通过LSTM对人体关键点序列进行建模,可以有效处理动作的时间维度依赖关系,适合动态场景中的动作分类任务。
  3. 模块化设计与易于扩展: 使用PyTorch Lightning框架,模型的训练、验证和测试模块化程度高,便于后续对模型结构进行修改和扩展,具有较好的复用性和可维护性。
  4. 数据格式转换与兼容性: 项目实现了OpenPose到Detectron2格式的关键点映射,提升了不同数据源间的兼容性,使得模型可以处理不同格式的数据。
模型缺点:
  1. 对数据质量依赖较高: 由于动作识别依赖于人体关键点的准确性,一旦输入数据存在噪声或关键点检测精度不足,会导致识别效果下降。
  2. 无法处理复杂的多人物场景: 目前模型仅处理单一人物动作,缺乏对多人物交互动作的识别能力,这在群体场景中容易造成误分类。
  3. 计算开销较大: Detectron2和LSTM的组合在处理大规模数据时需要较大的计算资源,可能在实时场景中存在延迟问题。
模型改进方向:
  1. 模型结构优化: 考虑引入双向LSTM或Transformer结构以更好地捕捉动作序列的全局信息。
  2. 超参数调整: 针对LSTM的隐藏层维度、学习率及优化策略进行超参数搜索,以找到更优的训练配置。
  3. 数据增强策略: 增加对人体关键点的随机偏移、旋转、遮挡等数据增强方法,提高模型的鲁棒性。
  4. 支持多人物检测: 扩展模型结构和数据标注方式,以适应多人物场景中的复杂交互动作识别。

↓↓↓更多热门推荐:
Bi-LSTM-CRF实现中文命名实体识别工具(TensorFlow)

全部项目数据集、代码、教程进入官网zzgcz.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/456370.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

旧电脑安装Win11提示“这台电脑当前不满足windows11系统要求”,安装中断。怎么办?

前言 最近有很多小伙伴也获取了LTSC版本的Win11镜像,很大一部分小伙伴安装这个系统也是比较顺利的。 有顺利安装完成的,肯定也有安装不顺利的。这都是很正常的事情,毕竟这个镜像对电脑硬件要求还是挺高的。 有一部分小伙伴在安装Windows11 …

C++对象模型:关于对象

C语言和C对比 ⭐ 关联知识点:C和C语言区别 (1)C 语言的特点 简洁与高效:C 语言被设计为一种系统级的编程语言,它提供了对硬件的直接访问能力,并且编译后的代码通常非常紧凑,运行效率高。 全…

Java SnakeYaml 反序列化漏洞原理

目录 SnakeYaml 使用 SnakeYAML 序列化与反序列化 SnakeYAML 序列化实现 SnakeYAML 反序列化实现 SnakeYaml 反序列化漏洞 基于 ScriptEngineManager 利用链 漏洞原因分析 SPI 服务提供者发现机制 命令执行 漏洞修复 SnakeYaml SnakeYAML 是一个用于 Java 语言的 YA…

面试题:JVM(一)

1. JVM概述 1.1 JVM的生命周期 说说Java虚拟机的生命周期(阿里) 虚拟机的启动 Java虚拟机的启动是通过引导类加载器(bootstrap class loader)创建一个初始类(initial class)来完成的,这个类是由虚拟机的具体实现指定的。 虚拟机的退出有如下…

接口测试(九)jmeter——关联(JSON提取器)

一、JSON提取器介绍 要检查的响应字段:样本数据源引用名称:可自定义设置引用方法:${引用变量名}匹配数字 匹配数字含义-1表示全部0随机1第一个2第二个…以此类推 缺省值:匹配失败时的默认值ERROR,可以不写 二、js…

2024年双十一有什么好物推荐?盘点2024双十一爆款好物分享

第一款:希亦ACE内衣洗衣机 一句话点评:常出口欧美等多个国家,被超百家专业媒体评为“洗护一体技术之王”,妇科细菌除菌率达99.99%,清洁度高达99.8%! CEYEE希亦是清洁领域的实力大牌子了,也是母…

老照片如何修复变清晰?手把手教你4种模糊照片变清晰方法!

在洋溢着温情的生日聚会上,家人们围坐一堂,总会情不自禁地翻阅那些尘封已久的老照片,一同沉醉于往昔的温情岁月。然而,时光荏苒,许多承载着深情厚意的照片已变得泛黄、模糊,难以再现昔日的清晰与鲜活。但请…

vue2 a-input输入框使用正则限制为数字、英文及中文,出现吞字符和英文字符打断问题

需求是输入框限制数字、英文和中文,原始使用的正则是: replace(/[^a-zA-Z0-9\u4E00-\u9FA5]/g,)1、使用这个正则表达式使用搜狗输入法没问题,使用微软自带输入法后会存在输入英文会吞并当前光标前的字符,也有英文打断问题。 输入…

2024年【制冷与空调设备安装修理】考试及制冷与空调设备安装修理最新解析

题库来源:安全生产模拟考试一点通公众号小程序 制冷与空调设备安装修理考试参考答案及制冷与空调设备安装修理考试试题解析是安全生产模拟考试一点通题库老师及制冷与空调设备安装修理操作证已考过的学员汇总,相对有效帮助制冷与空调设备安装修理最新解…

线上遇到的问题记录(说多了都是泪)

写在前面 我觉得,工作中最有价值的就是及遇到的问题了,特别时线上这种容易让人血压升高的环境中遇到的问题,本文就是记录这些血压升高时刻。 如果你遇到什么真实环境的问题,也欢迎评论或者私信分享给我!!&…

Angular 保姆级别教程高阶应用 - RxJs

RxJS 13.1.1 什么是 RxJS ? RxJS 是一个用于处理异步编程的 JavaScript 库,目标是使编写异步和基于回调的代码更容易。 13.1.2 为什么要学习 RxJS ? 就像 Angular 深度集成 TypeScript 一样,Angular 也深度集成了 RxJS。 服务、表单、事件、全局状…

经典功率谱估计的原理及MATLAB仿真(自相关函数BT法、周期图法、bartlett法、welch法)

经典功率谱估计的原理及MATLAB仿真(自相关函数BT法、周期图法、bartlett法、welch法) 文章目录 前言一、BT法二、周期图法三、Bartlett法四、welch法五、MATLAB仿真六、MATLAB详细代码总结 前言 经典功率谱估计方法包括BT法(对自相关函数求傅…

基于Java的就业信息管理系统源码带本地搭建教程

技术框架:jQuery MySQL5.7 mybatis shiro Layui HTML CSs JS 运行环境:jdk8 IntelliJ IDEA maven3 宝塔面板 实现了就业信息管理、就业统计、用户管理等功能。有普通用户和管理员两种角色。

开源限流组件分析(三):golang-time/rate

文章目录 本系列前言提供获取令牌的API数据结构基础方法tokensFromDurationdurationFromTokensadvance 获取令牌方法reverseN其他系列API 令人费解的CancelAt是bug吗 取消后无法唤醒其他请求 本系列 开源限流组件分析(一):juju/ratelimit开源…

智能AI监测系统燃气安全改造方案的背景及应用价值

随着燃气行业的迅速发展和城市化进程的加快,燃气安全管理成为企业运营和城市管理中不可忽视的关键领域。燃气泄漏、管道破损等事故的发生不仅会造成严重的经济损失,还威胁到人民生命财产安全。传统的安全管理方法往往依赖人工巡检和手动监测,…

如何写一个视频编码器演示篇

先前写过《视频编码原理简介》,有朋友问光代码和文字不太真切,能否补充几张图片,今天我们演示一下: 这是第一帧画面:P1(我们的参考帧) 这是第二帧画面:P2(需要编码的帧&…

C2W4.LAB.Word_Embedding.Part2

理论课:C2W4.Word Embeddings with Neural Networks 文章目录 Training the CBOW modelForward propagationInitialization of the weights and biasesTraining exampleValues of the hidden layerValues of the output layerCross-entropy loss BackpropagationGr…

大家都在用的HR招聘管理工具:国内Top5排名

招聘管理工具是专为HR及招聘团队设计的数字化助手,旨在简化招聘流程,提高效率。众所周知,招聘管理工具通常集成简历收集、筛选、面试安排、候选人跟踪等功能于一体,让招聘过程更加流畅。使用招聘管理工具,不仅能节省时…

高边坡稳定安全监测预警系统解决方案

一、项目背景 高边坡的滑坡和崩塌是一种常见的自然地质灾害,一但发生而没有提前预告将给人民的生命财产和社会危害产生严重影响。对高边坡可能产生的灾害提前预警、必将有利于决策者采取应对措施、减少和降低灾害造成的损失。现有的高边坡监测技术有人工巡查和利用测…

100个候选人,没一个能讲明白什么是自动化框架?

什么是自动化测试框架 01 什么是框架 框架是整个或部分系统的可重用设计,表现为一组抽象构件及构件实例间交互的方法。它规定了应用的体系结构,阐明了整个设计、协作构件之间的依赖关系、责任分配和控制流程,表现为一组抽象类以及其实例之间…