华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用CodeLab体验Notebook实例

下载NoteBook样例代码,Lite-HRNet实现人体关键点检测 ,.ipynb为样例代码

在这里插入图片描述

选择ModelArts Upload Files上传.ipynb文件

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

选择Kernel环境

在这里插入图片描述

切换至GPU环境,切换成第一个限时免费

在这里插入图片描述

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

回到Notebook中,在第一块代码前加入命令
在这里插入图片描述

conda update -n base -c defaults conda

在这里插入图片描述

安装MindSpore 2.0 GPU版本

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

安装mindvision

pip install mindvision

在这里插入图片描述

安装下载download

pip install download

人体关键点检测模型Lite-HRNet

人体关键点检测是计算机视觉的基本任务之一,在许多应用场景诸如自动驾驶、安防等有着重要的地位。可以发现,在这些应用场景下,深度学习模型可能需要部署在IoT设备上,这些设备算力较低,存储空间有限,无法支撑太大的模型,因此轻量但不失高性能的人体关键点检测级模型将极大降低模型部署难度。Lite-HRNet便提供了一轻量级神经网络骨干,通过接上不同的后续模型可以完成不同的任务,其中便包括人体关键点检测,在配置合理的情况下,Lite-HRNet可以以大型神经网络数十分之一的参数量及计算量达到相近的性能。

模型简介

Lite-HRNet由HRNet(High-Resolution Network)改进而来,HRNet的主要思路是在前向传播过程中通过维持不同分辨率的特征,使得最后生成的高阶特征既可以保留低分辨率高阶特征中的图像语义信息,也可以保留高分辨率高阶特征中的物体位置信息,进而提高在分辨率敏感的任务如语义分割、姿态检测中的表现。Lite-HRNet是HRNet的轻量化改进,改进了HRNet中的卷积模块,将HRNet中的参数量从28.5M降低至1.1M,计算量从7.1GFLOPS降低至0.2GFLOPS,但AP75仅下降了7%。
综上,Lite-HRNet具有计算量、参数量低,精度可观的优点,有利于部署在物联网低算力设备上服务于各个应用场景。

数据准备

本案例使用COCO2017数据集作为训练、验证数据集,请首先安装Mindspore Vision套件,并确保安装的Mindspore是GPU版本,随后请在https://cocodataset.org/ 上下载好2017 Train Images、2017 Val Images以及对应的标记2017 Train/Val Annotations,并解压至当前文件夹,文件夹结构下表所示

Lite-HRNet/├── imgs├── src├── annotations├──person_keypoints_train2017.json└──person_keypoints_train2017.json├── train2017└── val2017

训练、测试原始图片如下所示,图片中可能包含多个人体,且包含的人体不一定包含COCO2017中定义的17个关键点,标注中有每个人体的边框、关键点信息,以便处理图像后供模型训练。

数据预处理

src/mindspore_coco.py中定义了供mindspore模型训练、测试的COCO数据集接口,在加载训练数据集时只需指定所用数据集文件夹位置、输入图像的尺寸、目标热力图的尺寸、以及手动设置对训练图像采用的变换即可


import mindspore as ms
import mindspore.dataset as dataset
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.nn as nn
from mindspore.dataset.transforms.py_transforms import Composefrom src.configs.dataset_config import COCOConfig
from src.dataset.mindspore_coco import COCODatasetcfg = COCOConfig(root="./", output_dir="outputs/", image_size=[192, 256], heatmap_size=[48, 64])
trans = Compose([py_vision.ToTensor(),py_vision.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_ds = COCODataset(cfg, "../", "train2017", True, transform=trans)
train_loader = dataset.GeneratorDataset(train_ds, ["data", "target", "weight"])

在这里插入图片描述

构建网络

Lite-HRNet网络骨干大体结构如下图所示:
在这里插入图片描述

网络中存在不同的分辨率分支,网络主干上维持着较高分辨率、较少通道数的输出特征,网络分支上延展出较低分辨率、较多通道数的输出特征,且这些不同分辨率的特征之间通过上采样、下采样的卷积层进行交互、融合。Stage内的Cross Channel Weighting(CCW)则是网络实现轻量化的精髓,它将原HRNet中复杂度较高的1*1卷积以更低复杂度的Spatial Weighting等方法替代,从而实现降低网络参数、计算量的效果。CCW的结构如下图所示

在这里插入图片描述

值得注意的是,除了骨干网络,作者在论文中同时也给出了所使用的检测头即SimpleBaseline,为了简洁起见,在本次的Lite-HRNet的Mindspore实现中,检测头(代码中包括IterativeHeads和LiteTopDownSimpleHeatMap)已集成至骨干网络之后,作为整体模型的一部分,直接调用模型即可得到热力图预测输出。

损失函数

此处使用损失函数为JointMSELoss,即关节点的均方差误差损失函数,其源码如下所示,总体流程即计算每个关节点预测热力图与实际热力图的均方差,其中target是根据关节点的人工标注坐标,通过二维高斯分布生成的热力图,target_weight用于指定参与计算的关节点,若某关节点对应target_weight取值为0,则表明该关节点在输入图像中未出现,不参与计算。

"""JointMSELoss"""
import mindspore.nn as nn
import mindspore.ops as opsclass JointsMSELoss(nn.Cell):"""Joint MSELoss"""def __init__(self, use_target_weight):"""JointMSELoss"""super(JointsMSELoss, self).__init__()self.criterion = nn.MSELoss(reduction='mean')self.use_target_weight = use_target_weightdef construct(self, output, target, weight):"""construct"""target = targettarget_weight = weightbatch_size = output.shape[0]num_joints = output.shape[1]spliter = ops.Split(axis=1, output_num=num_joints)mul = ops.Mul()heatmaps_pred = spliter(output.reshape((batch_size, num_joints, -1)))heatmaps_gt = spliter(target.reshape((batch_size, num_joints, -1)))loss = 0for idx in range(num_joints):heatmap_pred = heatmaps_pred[idx].squeeze()heatmap_gt = heatmaps_gt[idx].squeeze()if self.use_target_weight:heatmap_pred = mul(heatmap_pred, target_weight[:, idx])heatmap_gt = mul(heatmap_gt, target_weight[:, idx])loss += 0.5 * self.criterion(heatmap_pred,heatmap_gt)else:loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)return loss/num_joints

模型实现与训练

在实现模型时,需指定模型内部结构,在src/net_configs中已指定原论文中10种结构配置,在训练样例种取Lite_18_coco作为模型结构,此处作为案例,仅设置epoch数量为1,在实际训练中可以设置为200,并且可以加入warmup。由于mindspore的训练接口默认数据集中每条数据只有两列(即训练数据和标签),所以这里需自定义Loss Cell。值得注意的是loss在训练前后变化并不会十分大,训练好的模型的loss为0.0004左右

class CustomWithLossCell(nn.Cell):def __init__(self,net: nn.Cell,loss_fn: nn.Cell):super(CustomWithLossCell, self).__init__()self.net = netself._loss_fn = loss_fndef construct(self, img, target, weight):""" build network """heatmap_pred = self.net(img)return self._loss_fn(heatmap_pred,target,weight)
from src.configs.net_configs import get_netconfig
from mindspore.train.callback import  LossMonitor
from src.backbone import LiteHRNetext = get_netconfig("extra_lite_18_coco")
net = LiteHRNet(ext)
criterion = JointsMSELoss(use_target_weight=True)train_loader = train_loader.batch(64)
optim = nn.Adam(net.trainable_params(), learning_rate=2e-3)
loss = JointsMSELoss(use_target_weight=True)
net_with_loss = CustomWithLossCell(net, loss)model = ms.Model(network=net_with_loss, optimizer=optim)
epochs = 1
#Start Training
model.train(epochs, train_loader, callbacks=[LossMonitor(100)], dataset_sink_mode=False)

在这里插入图片描述

模型评估

模型评估过程中使用AP、AP50、AP75以及AR50、AR75作为评价指标,val2017作为评价数据集,pycocotool包中已实现根据评价函数,且src/mindspore_coco.py中的evaluate函数也实现了调用该评价函数的接口,只需提供预测关键点坐标等信息即可获得评价指标。此处载入Lite_18_coco的预训练模型进行评价。

from mindspore import load_checkpoint
from mindspore import load_param_into_netfrom src.utils.utils import get_final_preds
import numpy as npdef evaluate_model(model, dataset, output_path):"""Evaluate"""num_samples = len(dataset)all_preds = np.zeros((num_samples, 17, 3),dtype=np.float32)all_boxes = np.zeros((num_samples, 6))image_path = []for i, data in enumerate(dataset):input_data, target, meta = data[0], data[1], data[3]input_data = ms.Tensor(input_data[0], ms.float32).reshape(1, 3, 256, 192)shit = model(input_data).asnumpy()target = target.reshape(shit.shape)c = meta['center'].reshape(1, 2)s = meta['scale'].reshape(1, 2)score = meta['score']preds, maxvals = get_final_preds(shit, c, s)all_preds[i:i + 1, :, 0:2] = preds[:, :, 0:2]all_preds[i:i + 1, :, 2:3] = maxvals# double check this all_boxes partsall_boxes[i:i + 1, 0:2] = c[:, 0:2]all_boxes[i:i + 1, 2:4] = s[:, 0:2]all_boxes[i:i + 1, 4] = np.prod(s*200, 1)all_boxes[i:i + 1, 5] = scoreimage_path.append(meta['image'])dataset.evaluate(0, all_preds, output_path, all_boxes, image_path)net_dict = load_checkpoint("./ckpt/litehrnet_18_coco_256x192.ckpt")
load_param_into_net(net, net_dict)eval_ds = COCODataset(cfg, "./", "val2017", False, transform=trans)
evaluate_model(net, eval_ds, "./result")

在这里插入图片描述

模型推理

  1. Lite-HRNet是关键点检测模型,所以输入待推理图像应为包含单个人体的图像,作者在论文中提及在coco test 2017测试前已使用SimpleBaseline生成的目标检测Bounding Box处理图像,所以待推理图像应仅包含单个人体。
  2. 网络的输入为(1,3,256,192),所以在输入图像前应先将其变换成网络可处理的形式。
import cv2
from src.utils.utils import get_max_preds
origin_img = cv2.imread("./imgs/man.jpg")
origin_h, origin_w, _ = origin_img.shape
scale_factor = [origin_w/192, origin_h/256]# resize to (112 112 3) and convert to tensor
img = cv2.resize(origin_img, (192, 256))
print(img.shape)
img = trans(img)
# img = np.expand_dims(img, axis=0)
img = ms.Tensor(img)
print(img.shape)# Infer
heatmap_pred = net(img).asnumpy()
pred, _ = get_max_preds(heatmap_pred)# Postprocess
pred = pred.reshape(pred.shape[0], -1, 2)
print(pred[0])
pre_landmark = pred[0] * 4 * scale_factor
# Draw points
for (x, y) in pre_landmark.astype(np.int32):cv2.circle(origin_img, (x, y), 3, (255, 255, 255), -1)# Save image
cv2.imwrite("./imgs/man_infer.jpg", origin_img)

在这里插入图片描述

可以看到模型基本正确标注出了关键点的位置\

在这里插入图片描述

算法基本流程

  1. 获取原始数据
  2. 从数据集的标注json文件中得到各个图像bbox以及关键点坐标信息
  3. 根据bbox裁剪图像,并放缩至指定尺寸,如果是训练还可以作适当数据增强,生成指定尺寸的目标热力图
  4. 指定尺寸的输入经过网络前向传播后得到预测的关键点热力图
  5. 经过处理后取热力图中的最大值坐标作为关键点的预测坐标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/473429.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gitlab和jenkins连接

一:jenkins 配置 安装gitlab插件 生成密钥 id_rsa 要上传到jenkins,id_rsa.pub要上传到gitlab cat /root/.ssh/id_rsa 复制查看的内容 可以看到已经成功创建出来了对于gitlab的认证凭据 二:配置gitlab cat /root/.ssh/id_rsa.pub 复制查…

SpringBoot实现WebSocket

参考链接&#xff1a;https://www.kancloud.cn/king_om/mic_03/2783864 一、环境搭建 1.创建SpringBoot项目&#xff0c;引入相关依赖 <dependencies><!-- Spring Boot核心启动器&#xff0c;引入常用依赖基础 --><dependency><groupId>org.springf…

现代密码学|公钥密码体制 | RSA加密算法及其数学基础

文章目录 公钥密码RSA数学基础欧拉函数欧拉定理模指数运算 RSA加密算法对rsa的攻击 公钥密码 现代密码学&#xff5c;公钥密码体制概述 加密 A用B的公钥加密 B用B的私钥解密 认证 A使用A的私钥加密 B使用A的公钥解密 加密认证 A用A的私钥加密&#xff0c;再用B的公钥加密 B用…

VuePress v2 快速搭建属于自己的个人博客网站

目录 为什么用VuePress&#xff1f; 一、前期准备 Node.js 使用主题快速开发 二、VuePress安装 三、个性化定制 修改配置信息 删除不需要的信息 博客上传 四、部署 使用github快速部署 初始化仓库 本地配置 配置github的ssh密钥 部署 为什么用VuePress&#xff…

【阅读记录-章节1】Build a Large Language Model (From Scratch)

目录 1. Understanding large language models1.1 What is an LLM?补充介绍人工智能、机器学习和深度学习的关系机器学习 vs 深度学习传统机器学习 vs 深度学习&#xff08;以垃圾邮件分类为例&#xff09; 1.2 Applications of LLMs1.3 Stages of building and using LLMs1.4…

平台整合是网络安全成功的关键

如今&#xff0c;组织面临着日益复杂、动态的网络威胁环境&#xff0c;随着恶意行为者采用越来越阴险的技术来破坏环境&#xff0c;攻击的数量和有效性也在不断上升。我们最近的 Cyber​​Ark 身份威胁形势报告&#xff08;2024 年 5 月&#xff09;发现&#xff0c;去年 99% 的…

PlantUML——时序图

PlantUML时序图 背景 时序图&#xff08;Sequence Diagram&#xff09;&#xff0c;又名序列图、循序图&#xff0c;是一种UML交互图&#xff0c;用于描述对象之间发送消息的时间顺序&#xff0c;显示多个对象之间的动态协作。时序图的使用场景非常广泛&#xff0c;几乎各行各…

【MYSQL】分库分表

一、什么是分库分表 分库分表就是指在一个数据库在存储数据过大&#xff0c;或者一个表存储数据过多的情况下&#xff0c;为了提高数据存储的可持续性&#xff0c;查询数据的性能而进行的将单一库或者表分成多个库&#xff0c;表使用。 二、为什么要分库分表 分库分表其实是两…

Spring纯注解开发

在我的另一篇文章中&#xff08;初识Spring-CSDN博客&#xff09;&#xff0c;讲述了Bean&#xff0c;以及通过xml方式定义Bean。接下来将讲解通过注解的方法管理Bean。 我们在创建具体的类的时候&#xff0c;可以直接在类的上面标明“注解”&#xff0c;以此来声明类。 1. 常…

git push时报错! [rejected] master -> master (fetch first)error: ...

错误描述&#xff1a;在我向远程仓库push代码时&#xff0c;即执行 git push origin master命令时发生的错误。直接上错误截图。 错误截图 错误原因&#xff1a; 在网上查了许多资料&#xff0c;是因为Git仓库中已经有一部分代码&#xff0c;它不允许你直接把你的代码覆盖上去…

java常用工具包介绍

Java 作为一种广泛使用的编程语言&#xff0c;提供了丰富的标准库和工具包来帮助开发者高效地进行开发。这些工具包涵盖了从基础的数据类型操作到高级的网络编程、数据库连接等各个方面。下面是一些 Java 中常用的工具包&#xff08;Package&#xff09;及其简要介绍&#xff1…

latex中,两个相邻的表格,怎样留一定的空白

目录 问题描述 问题解决 问题描述 在使用latex写论文时&#xff0c;经常表格需要置顶写&#xff0c;则会出现两个表格连在一起的情况。下一个表名容易与上面的横线相连&#xff0c;如何通过明令&#xff0c;留出一定的空白。 问题解决 在第二个表格的 \centering命令之后…

react中如何在一张图片上加一个灰色蒙层,并添加事件?

最终效果&#xff1a; 实现原理&#xff1a; 移动到图片上的时候&#xff0c;给img加一个伪类 &#xff01;&#xff01;此时就要地方要注意了&#xff0c;因为img标签是闭合的标签&#xff0c;无法直接添加 伪类&#xff08;::after&#xff09;&#xff0c;所以 我是在img外…

C++builder中的人工智能(27):如何将 GPT-3 API 集成到 C++ 中

人工智能软件和硬件技术正在迅速发展。我们每天都能看到新的进步。其中一个巨大的飞跃是我们拥有更多基于自然语言处理&#xff08;NLP&#xff09;和深度学习&#xff08;DL&#xff09;机制的逻辑性更强的AI聊天应用。有许多AI工具可以用来开发由C、C、Delphi、Python等编程语…

【项目开发】URL中井号(#)的技术细节

未经许可,不得转载。 文章目录 前言一、# 的基本含义二、# 不参与 HTTP 请求三、# 后的字符处理机制四、# 的变化不会触发网页重新加载五、# 的变化会记录在浏览器历史中六、通过 window.location.hash 操作七、onhashchange 事件八、Google 对 # 的处理机制前言 2023 年 9 月…

AUTOSAR_EXP_ARAComAPI的7章笔记(5)

☞返回总目录 相关总结&#xff1a;典型的 SOME/IP 多绑定用例总结 7.3.3 典型的SOME/IP多绑定用例 在前面的章节中&#xff0c;我们简要提到&#xff0c;在一个典型的SOME/IP 网络协议的部署场景中&#xff0c;AP SWC不太可能自己打开套接字连接来与远程服务通信。为什么不…

Jenkins下载安装、构建部署到linux远程启动运行

Jenkins详细教程 Winodws下载安装Jenkins一、Jenkins配置Plugins插件管理1、汉化插件2、Maven插件3、重启Jenkins&#xff1a;Restart Safely插件4、文件传输&#xff1a;Publish Over SSH5、gitee插件6、清理插件&#xff1a;workspace cleanup system系统配置1、Gitee配置2、…

Flutter:Dio下载文件到本地

import dart:io; import package:dio/dio.dart;main(){// 创建dio对象final dio Dio();// 下载地址var url https://*******.org/files/1.0.0.apk;// 手机端路径String savePath Directory.systemTemp.path/ceshi.apk;print(savePath);downLoad(dio,url,savePath); }downLo…

【C++笔记】C++三大特性之多态

【C笔记】C三大特性之多态 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;C笔记 文章目录 【C笔记】C三大特性之多态前言一.多态1.1 多态的概念1.2 虚函数1.3 虚函数的重写/覆盖1.4 多态的定义及实现 二.虚函数重写的⼀些其他问题2.1 协变(…

2.STM32之通信接口《精讲》之USART通信

有关通信详解进我主页观看其他文章&#xff01;【免费】SPIIICUARTRS232/485-详细版_UART、IIC、SPI资源-CSDN文库 通过以上可以看出。根据电频标准&#xff0c;可以分为TTL电平&#xff0c;RS232电平&#xff0c;RS485电平&#xff0c;这些本质上都属于串口通信。有区别的仅是…