基于迁移学习的手势分类模型训练

1、基本原理介绍

       这里介绍的单指模型迁移。一般我们训练模型时,往往会自定义一个模型类,这个类中定义了神经网络的结构,训练时将数据集输入,从0开始训练;而迁移学习中(单指模型迁移策略),是在一个已经有过训练基础的模型上,用自己的数据集,进一步训练,使得这个模型能够完成我们需要的任务。

这么做有有这样几个显而易见的好处:

※  因为模型之前被训练过,所以初始参数不会是0,这样能够加速模型训练

※  因为预训练模型(什么是预训练模型下文会讲到)在其他数据集上训练过,而其他数据集往往和我们用的数据集存在一定的区别,所以这可以提高模型的泛化能力

※  通过迁移学习,可以将来自大规模数据的优势转移到小规模或新任务上,提高模型的表现和效果

2、预训练模型

        在进行迁移学习时,我们要先找到一个预训练模型。在分类任务领域,比较流行的如resnet系列、mobilenet系列(更轻量化)、vgg(系列)、efficientnet(系列)等等网络,都是比较常用且容易获得的预训练模型,这些模型都能够通过python直接下载。

        而且由于上述模型基本都是在ImageNet这一大规模,多分类类别的数据集上进行过训练的,所以对于简单的二分类等少数类别分类,能有较好的效果。

3、训练流程

迁移学习完整的训练流程和一般搭建神经网络的训练模型的流程基本类似:数据预处理->数据集的切分->加载预训练模型(搭建神经网络)->设置超参数/损失函数/优化器等->训练模型

3.1 模型训练

下面的代码是一个利用mobilenet网络训练得到的手势分类模型,该模型能够较准确的分类不同类别手势。

相关解释已在代码中注释说明。

from torchvision.models import mobilenet_v2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation# 定义数据预处理和增强器
transform = Compose([RandomHorizontalFlip(),  # 随机水平翻转RandomRotation(10),      # 随机旋转10度Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集并应用预处理和增强器
dataset = ImageFolder(root='data', transform=transform)
# 这里由于数据比较少,将所有数据集全部用来训练,得到的模型直接拿来用了,这其实不算是非常规范的操作,仅供参考# 定义网络结构
model = mobilenet_v2(pretrained=True)  # 加载预训练模型,也可以试试其他模型,效果差别挺大的
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 5)  # 假设是5分类问题,具体几分类,改这里的参数就行了# 将模型移动到设备上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 定义优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()# 定义训练循环
def train_model(model, criterion, optimizer, num_epochs, train_loader):for epoch in range(num_epochs):model.train()  # 设置模型为训练模式train_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * inputs.size(0)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()epoch_loss = train_loss / totalepoch_acc = 100. * correct / totalprint(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 创建训练集的DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)# 开始训练模型
train_model(model, criterion, optimizer, num_epochs=15, train_loader=train_loader)
torch.save(model, 'my_model(1).pth')

3.2 数据集文件结构

当然,你也可以自己定义读取数据集的data_loader类。

3.3 模型推理

这段代码是用训练得到的模型对一张图片进行推理测试的,如果需要对系列图片进行推理,评估模型效果,可自行修改,调用对应函数即可。

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
def predict_image(image_path, model_path='my_model(1).pth'):image = Image.open(image_path).convert("RGB")# 对测试的图片进行预处理,需要和训练时处理的方式一样transform = Compose([Resize((224, 224)),CenterCrop(224),ToTensor(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])image_tensor = transform(image).unsqueeze(0)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')image_tensor = image_tensor.to(device)model = torch.load(model_path,map_location=device)model.eval()with torch.no_grad():output = model(image_tensor)_, predicted = torch.max(output.data, 1)  # 获得分类标记return predicted.item()
if __name__=="__main__":image_path = "test2/6.jpg"print(predict_image(image_path))

3.4 整体项目文件

4、补充说明

        这种利用迁移学习策略,进行少类别,不同类别特征差距小的任务需求来说,效果一般来说是比较好的。因为之前做过相关实验,准确率90%以上是很容易的,所以这里没有模型评估,生成混淆矩阵等过程。对于多类别分类,建议有完整的评估体系。

        上述使用的方法仅适用于分类任务,对于真正的目标检测如手势识别,直接使用该模型的问题是:由于无法定位手势的位置,所以导致识别不准确。

        本实验数据集是不同类别手势图片,为自制,不开源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/384565.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何查看jvm资源占用情况

如何设置jar的内存 java -XX:MetaspaceSize256M -XX:MaxMetaspaceSize256M -XX:AlwaysPreTouch -XX:ReservedCodeCacheSize128m -XX:InitialCodeCacheSize128m -Xss512k -Xmx2g -Xms2g -XX:UseG1GC -XX:G1HeapRegionSize4M -jar your-application.jar以上配置为堆内存4G jar项…

二叉树详解-第四篇 二叉树链式结构的实现

目录 1.二叉树的遍历 1.1前序遍历: 1.2 中序遍历: 1.3 后序遍历: 2.二叉树链式结构的实现 2.1 Tree.h 2.2 Tree.cpp 2.2.1 前序遍历 void PreOrder(TNode* Root) 2.2.2 中序遍历 void InOrder(TNode* Root) 2.2.3 后序遍历 void Bac…

基于opencv[python]的人脸检测

1 图片爬虫 这里的代码转载自:http://t.csdnimg.cn/T4R4F # 获取图片数据 import os.path import fake_useragent import requests from lxml import etree# UA伪装 head {"User-Agent": fake_useragent.UserAgent().random}pic_name 0 def request_pic…

DVWA的安装和使用

背景介绍 DVWA是Damn Vulnerable Web Application的缩写,是一个用于安全脆弱性检测的开源Web应用。它旨在为安全专业人员提供一个合法的测试环境,帮助他们测试自己的专业技能和工具,同时也帮助web开发者更好地理解web应用安全防范的过程。DV…

微信小程序-本地部署(前端)

遇到问题:因为是游客模式所以不能修改appID. 参考链接:微信开发者工具如何从游客模式切换为开发者模式?_微信开发者工具如何修改游客模式-CSDN博客 其余参考:Ego微商项目部署(小程序项目)(全网…

Wonder3D 论文学习

论文链接:https://arxiv.org/abs/2310.15008 代码链接:https://github.com/xxlong0/Wonder3D 解决了什么问题? 随着扩散模型的提出,3D 生成领域取得了长足进步。从单张图片重建出 3D 几何是计算机图形学和 3D 视觉的基础任务&am…

k8s安装

说明 本事件适合刚刚装系统的新机子,前提是可以ping通www.baidu。yum可以下载软件 本实验模拟单机k8s,主机ip为172.26.50.222 关闭防火墙 systemctl status firewalld systemctl stop firewalld systemctl disable firewalld getenforce setenforce …

【React】详解样式控制:从基础到进阶应用的全面指南

文章目录 一、内联样式1. 什么是内联样式?2. 内联样式的定义3. 基本示例4. 动态内联样式 二、CSS模块1. 什么是CSS模块?2. CSS模块的定义3. 基本示例4. 动态应用样式 三、CSS-in-JS1. 什么是CSS-in-JS?2. styled-components的定义3. 基本示例…

llama模型,nano

目录 llama模型 Llama模型性能评测 nano模型是什么 Gemini Nano模型 参数量 MMLU、GPQA、HumanEval 1. MMLU(Massive Multi-task Language Understanding) 2. GPQA(Grade School Physics Question Answering) 3. HumanEval llama模型 Large Language Model AI Ll…

【React】详解 Redux 状态管理

文章目录 一、Redux 的基本概念1. 什么是 Redux?2. Redux 的三大原则 二、Redux 的核心组件1. Store2. Action3. Reducer 三、Redux 的使用流程1. 安装 Redux 及其 React 绑定2. 创建 Action3. 创建 Reducer4. 创建 Store5. 在 React 应用中使用 Store6. 连接 React…

【Redis】主从复制分析-基础

1 主从节点运行数据的存储 在主从复制中, 对于主节点, 从节点就是自身的一个客户端, 所以和普通的客户端一样, 会被组织为一个 client 的结构体。 typedef struct client {// 省略 } client;同时无论是从节点, 还是主节点, 在运行中的数据都存放在一个 redisServer 的结构体中…

使用C#手搓Word插件

WordTools主要功能介绍 编码语言:C#【VSTO】 1、选择 1.1、表格 作用:全选文档中的表格; 1.2、表头 作用:全选文档所有表格的表头【第一行】; 1.3、表正文 全选文档中所有表格的除表头部分【除第一行部分】 1.…

Vue常用指令及其生命周期

作者:CSDN-PleaSure乐事 欢迎大家阅读我的博客 希望大家喜欢 目录 1.常用指令 1.1 v-bind 1.2 v-model 注意事项 1.3 v-on 注意事项 1.4 v-if / v-else-if / v-else 1.5 v-show 1.6 v-for 无索引 有索引 生命周期 定义 流程 1.常用指令 Vue当中的指令…

福派斯牛肉高脂猫粮,为何成猫舍首选?揭秘其神奇功效!

🐾 说到猫咪的伙食,咱们当铲屎官的可是操碎了心!想让自家毛孩子吃得健康又开心,选对猫粮真的太重要了。今天就来聊聊为啥福派斯牛肉高脂猫粮能成为众多猫舍的首选,以及它到底能帮咱们的小猫咪哪些忙吧! 1️…

数据传输安全--SSL VPN

目录 IPSEC在Client to LAN场景下比较吃力的表现 SSL VPV SSL VPN优势 SSL协议 SSL所在层次 SSL工作原理 SSL握手协议、SSL密码变化协议、SSL警告协议三个协议作用 工作过程 1、进行TCP三次握手、建立网络连接会话 2、客户端先发送Client HELLO包,下图是包…

springboot项目从jdk8升级为jdk17过程记录

背景:公司有升级项目jdk的规划,计划从jdk8升级到jdk11 开始 首先配置本地的java_home 参考文档:Mac环境下切换JDK版本及不同的maven-CSDN博客 将pom.xml中jdk1.8相关的版本全部改为jdk17,主要是maven编译插件之类的&#xff0c…

ubuntu22.04 安装 NVIDIA 驱动以及CUDA

目录 1、事前问题解决 2、安装 nvidia 驱动 3、卸载 nvidia 驱动方法 4、安装 CUDA 5、安装 Anaconda 6、安装 PyTorch 1、事前问题解决 在安装完ubuntu之后,如果进入ubuntu出现黑屏情况,一般就是nvidia驱动与linux自带的不兼容,可以通…

找工作准备刷题Day10 回溯算法 (卡尔41期训练营 7.24)

回溯算法今天这几个题目做过,晚上有面试,今天水一水。 第一题:Leetcode77. 组合 题目描述 解题思路 从题目示例来看,k个数是不能重合的,但是题目没有明确说明这一点。 使用回溯算法解决此问题,利用树形…

数据结构 —— B+树和B*树及MySQL底层引擎

数据结构 —— B树和B*树及MySQL底层引擎 B树B*树B树的应用B树在MySQL中的应用MyISAMInnoDB 我们之前学习了B树的基本原理,今天我们来看看B树的一些改良版本——B树和B*树。如果还没有了解过的小伙伴可以点击这里: https://blog.csdn.net/qq_67693066/ar…

【MySQL进阶之路 | 高级篇】MVCC三剑客:隐藏字段,Undo Log,ReadView

1. 再谈隔离级别 我们知道事务有四个隔离级别,可能存在三种并发问题: 在MySQL中,默认的隔离级别是可重复读,可以解决脏读和不可重复读的问题,如果仅从定义的角度来看,它并不能解决幻读问题。如果我们想要解…