基于PaddleClas的人物年龄分类项目

目录

一、任务概述

二、算法研发

2.1 下载数据集

2.2 数据集预处理

2.3 安装PaddleClas套件

2.4 算法训练

2.5 静态图导出

2.6 静态图推理

三、小结


一、任务概述

    最近遇到个需求,需要将图像中的人物区分为成人和小孩,这是一个典型的二分类问题,打算采用飞桨的图像分类套件PaddleClas来完成算法研发。本文记录相关流程。

二、算法研发

2.1 下载数据集

    本文采用MaGaAge_Asian数据集,该数据集主要由亚洲人图片组成,训练集包含40000张图像,验证集包含3495张图像,每张图像都有对应的年龄真值,所有图像均处理成了统一的大小,宽178像素,高218像素。

数据集地址下载链接。数据集部分示例如下图所示:

    该数据集本意是用来做年龄预测的,属于一个数值回归任务,本文将其变成二分类任务,以13岁年龄为界限,小于该年龄的属于小孩,大于该年龄的属于成人。这里之所以选择13岁,因为这个任务是需要筛选出长得很“像”小孩的小孩,13岁以上的青少年很多本身已经长的像成人了,因此,选择13岁作为分界线。

    下面首先对该数据集进行处理。

2.2 数据集预处理

    MaGaAge_Asian数据集每张图片对应的人物年龄存放在list文件夹的两个文件中,其中train_age.txt存放训练集对应的年龄真值,test_age.txt存放验证集对应的年龄真值。下面要写一个脚本,将所有小于13岁的图片移动到一个文件夹内,所有大于等于13岁的图片移动到另一个文件夹内。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@文件        :split_asian.py
@说明        :拆分megaage_asian数据集,将小于13岁的移动到一个文件夹,大于等于13岁的移动到另一个文件夹
@时间        :2024/07/16 09:11:16
@作者        :Bin Qian
@版本        :1.0
'''import os
import cv2thr = 13 # 年龄阈值# 读取年龄列表
agefile = 'megaage_asian/list/test_age.txt'
f=open(agefile) 
ageLst = f.read().splitlines()
f.close() # 读取图像
imgFolder = 'megaage_asian/val'
imgnames = os.listdir(imgFolder)
index = 50000
for imgname in imgnames:imgPath = os.path.join(imgFolder,imgname)img = cv2.imread(imgPath)if img is None:continueprint(imgPath)imgindex = int(imgname.split('.')[0])age = int(ageLst[imgindex-1])if age < thr:dstFolder = 'ageclas/child'else:dstFolder = 'ageclas/adult'savePath = os.path.join(dstFolder,str(index)+'_asian.jpg')cv2.imwrite(savePath,img)index += 1
print('完成')

值得注意的是MaGaAge_Asian数据集中有很多质量较差的图像,这些“脏”图像会影响学习效果,最好手工检查这些数据并将其剔除。

另外,为了能够取得更好的效果,本文从互联网和FFHQ数据集里面再挑选出一些小孩和成人的照片进行补充。部分代码如下:

import os
import cv2# 读取图像
imgFolder = 'adult'
imgnames = os.listdir(imgFolder)
index = 1
for imgname in imgnames:imgPath = os.path.join(imgFolder,imgname)img = cv2.imread(imgPath)if img is None:continueprint(imgPath)dstFolder = 'ageclas/adult'savePath = os.path.join(dstFolder,str(index)+'_data.jpg')cv2.imwrite(savePath,img)index += 1
print('完成')

补充完整后,最后对整理好的数据集进行拆分,并且获得对应的文件列表:

# 导入系统库
import os
import random
import cv2# 定义参数
img_folder = 'ageclas'
trainlst = 'train_list.txt'
vallst = 'val_list.txt'
ratio = 0.95 # 训练集占比
labellst='label.txt'def writeLst(lstpath,namelst):'''保存文件列表'''print('正在写入 '+lstpath)random.shuffle (namelst)# 写入训练样本文件f=open(lstpath, 'a', encoding='utf-8')for i in range(len(namelst)):text = namelst[i]+'\n'f.write(text)f.close()print(lstpath+ '已完成写入')def main():'''主函数'''# 查找文件夹folderlst = os.listdir(img_folder)print('共找到 %d 个文件夹' % len(folderlst))# 循环处理trainnamelst = list()valnamelst = list()labelnamelst = list()for i in range(len(folderlst)):class_name = folderlst[i]class_label = iprint('开始处理 '+class_name+' 文件夹')# 获取子文件夹文件列表filenamelst = os.listdir(os.path.join(img_folder,class_name))totalNum = len(filenamelst)print('当前文件夹图片数量为: ' + str(totalNum)) trainNum = int(ratio*totalNum)text =  str(class_label)+ ' ' + class_namelabelnamelst.append(text)# 检查并校验图像for j in range(totalNum):imgpath = os.path.join(img_folder,class_name,filenamelst[j])img = cv2.imread(imgpath, cv2.IMREAD_COLOR)if img is None:continuetext = imgpath + ' ' + str(class_label)if j <= trainNum: trainnamelst.append(text)else:valnamelst.append(text)writeLst(trainlst,trainnamelst)writeLst(vallst,valnamelst)   writeLst(labellst,labelnamelst)     print('全部完成')if __name__ == '__main__':'''程序入口'''main()

运行后会生成train_lst.txt、val_lst.txt以及label.txt三个文件,有了这三个文件就可以使用PaddleClas套件进行算法研发了。

2.3 安装PaddleClas套件

git clone https://gitee.com/paddlepaddle/PaddleClas.git
cd PaddleClas
sudo python setup.py install

2.4 算法训练

在PaddleClas目录下新建一个配置文件config_lcnet.yaml,采用PPLCNet_x0_5模型来训练,配置文件代码如下:

# global configs
Global:checkpoints: nullpretrained_model: nulloutput_dir: ./output/device: gpusave_interval: 5eval_during_train: Trueeval_interval: 5epochs: 200print_batch_step: 10use_visualdl: True# used for static mode and model exportimage_shape: [3, 224, 224]save_inference_dir: ./output/inference
# model architecture
Arch:name: PPLCNet_x0_5class_num: 2# loss function config for traing/eval process
Loss:Train:- CELoss:weight: 1.0epsilon: 0.1Eval:- CELoss:weight: 1.0Optimizer:name: Momentummomentum: 0.9lr:name: Cosinelearning_rate: 0.8warmup_epoch: 5regularizer:name: 'L2'coeff: 0.00003# data loader for train and eval
DataLoader:Train:dataset:name: ImageNetDatasetimage_root: ../process_data/cls_label_path: ../process_data/train_list.txttransform_ops:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- RandFlipImage:flip_code: 1- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''sampler:name: DistributedBatchSamplerbatch_size: 64drop_last: Falseshuffle: Trueloader:num_workers: 4use_shared_memory: TrueEval:dataset: name: ImageNetDatasetimage_root: ../process_data/cls_label_path: ../process_data/val_list.txttransform_ops:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''sampler:name: DistributedBatchSamplerbatch_size: 64drop_last: Falseshuffle: Falseloader:num_workers: 4use_shared_memory: TrueInfer:infer_imgs: "../testimgs/10.jpg"batch_size: 1transforms:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''- ToCHWImage:PostProcess:name: Topktopk: 1class_id_map_file: "../process_data/label.txt"Metric:Train:- TopkAcc:topk: [1]Eval:- TopkAcc:topk: [1]

然后使用下面的命令进行训练:

export CUDA_VISIBLE_DEVICES=0,1
python3 -m paddle.distributed.launch \--gpus="0,1" \tools/train.py \-c config_lcnet.yaml 

训练完成后可以使用下面的命令可视化查看训练结果:

visualdl --logdir results/vdl

运行效果如下:

可以看到,基本在epoch=100以后就收敛了,最高top1准确率达到97.5%,准确率还是比较高的。

下面可以使用动态图对单张图像进行测试,命令如下:

python3 tools/infer.py -c config_lcnet.yaml -o Global.pretrained_model=output/PPLCNet_x0_5/best_model

输出如下:

[{'class_ids': [1], 'scores': [0.93522], 'file_name': '../testimgs/10.jpg', 'label_names': ['adult']}]

2.5 静态图导出

为了方便后面进行模型部署,将训练好的最佳模型进行静态图导出。具体命令如下:

python3 tools/export_model.py \-c config_lcnet.yaml \-o Global.pretrained_model=output/PPLCNet_x0_5/best_model \-o Global.save_inference_dir=output/inference

导出的静态图模型存放在output/inference文件夹下面,整个模型参数加起来不超过3M,因此可以看出这个训练好的PPLCNet_x0_5模型是一个非常轻量级的模型。

2.6 静态图推理

下面使用静态图来进行推理。在推理前先使用visualdl工具查看下静态图模型的输入和输出,这将为编写推理脚本奠定基础。

可以看到,输入是[batch,3,224,224]的float型图像数据,输出是[batch,2]的float型数据。尤其是输出的两个值,代表的是两个类别的概率。

有了上面的分析,下面可以用PaddleInference写一个推理脚本infer.py:

import cv2
import numpy as np
from paddle.inference import create_predictor
from paddle.inference import Config as PredictConfig# 加载静态图模型
model_path = "./output/inference/inference.pdmodel"
params_path = "./output/inference/inference.pdiparams"
pred_cfg = PredictConfig(model_path, params_path)
pred_cfg.enable_memory_optim()  # 启用内存优化
pred_cfg.switch_ir_optim(True)
pred_cfg.enable_use_gpu(500, 0)  # 启用GPU推理
predictor = create_predictor(pred_cfg)  # 创建PaddleInference推理器# 解析模型输入输出
input_names = predictor.get_input_names()
input_handle = {}
for i in range(len(input_names)):input_handle[input_names[i]] = predictor.get_input_handle(input_names[i])
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])# 图像预处理
img = cv2.imread("../testimgs/10.jpg", flags=cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
img = img.astype(np.float32)
PIXEL_MEANS =(0.485, 0.456, 0.406)    # RGB格式的均值和方差
PIXEL_STDS = (0.229, 0.224, 0.225)
img/=255.0
img-=np.array(PIXEL_MEANS)
img/=np.array(PIXEL_STDS)
img = np.transpose(img[np.newaxis, :, :, :], (0, 3, 1, 2))# 预测
input_handle["x"].copy_from_cpu(img)
predictor.run()
results = output_handle.copy_to_cpu()# 后处理
results = results.squeeze(0)
if results[0]>results[1]:print('小孩'+"  "+str(results[0]))
else:print('大人'+"  "+str(results[1]))

从网上随便找两张照片,运行效果如下:

输出结果:

小孩  0.7256172

输出结果:

大人  0.9533998

可以看到,推理效果还是比较满意的。

三、小结

本文以项目为主线,使用了PaddleClas算法套件解决了年龄分类问题。后续读者如果想要深入学习PaddlePaddle(飞桨)及相关算法套件,可以关注我的书籍(链接)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/386068.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python | Leetcode Python题解之第283题移动零

题目&#xff1a; 题解&#xff1a; class Solution:def moveZeroes(self, nums: List[int]) -> None:n len(nums)left right 0while right < n:if nums[right] ! 0:nums[left], nums[right] nums[right], nums[left]left 1right 1

ClickHouse 进阶【建表、查询优化】

1、ClickHouse 进阶 因为上一节部署了集群模式&#xff0c;所以需要启动 Zookeeper 和 ck 集群&#xff1b; 1.1、Explain 基本语法 EXPLAIN [AST | SYNTAX | PLAN | PIPELINE] [setting value, ...] SELECT ... [FORMAT ...] AST&#xff1a;用于查看语法树SYNTAX&#…

橙单后端项目下载编译遇到的问题与解决

今天下载orange-admin项目&#xff0c;不过下载下来运行出现一些问题。 1、涉及到XMLStreamException的几个类都出现下面的错误 The package javax.xml.stream is accessible from more than one module: <unnamed>, java.xml ctrl-shift-t 可以找到这个引入是哪些包里…

成为git砖家(5): 理解 HEAD

文章目录 1. git rev-parse 命令2. 什么是 HEAD2.1 创建分支当并未切换&#xff0c; HEAD 不变2.2 切换分支&#xff0c;HEAD 改变2.3 再次切换分支&#xff0c; HEAD 再次改变 3. detached HEAD4. HEAD 表示分支、表示 detached HEAD 有什么区别&#xff1f;区别相同点 5. HEA…

【SpringCloud】企业认证、分布式事务,分布式锁方案落地-2

目录 高并发缓存三问 - 穿透 缓存穿透 概念 现象举例 解决方案 缓存穿透 - 预热架构 缓存穿透 - 布隆过滤器 布隆过滤器 布隆过滤器基本思想​编辑 了解 高并发缓存三问 - 击穿 缓存击穿 高并发缓存三问 - 雪崩 缓存雪崩 解决方案 总结 为什么要使用数据字典&…

Python网络爬虫:基础与实战!附淘宝抢购源码

Python网络爬虫是一个强大的工具&#xff0c;用于从互联网上自动抓取和提取数据。下面我将为你概述Python网络爬虫的基础知识和一些实战技巧。 Python网络爬虫基础 1. HTTP请求与响应 网络爬虫的核心是发送HTTP请求到目标网站并接收响应。Python中的requests库是处理HTTP请求…

Java NIO (一)

因工作需要我接触到了netty框架&#xff0c;这让我想起之前为夺高薪而在CSDN购买的Netty课程。如今看来&#xff0c;这套课程买的很值。这套课程中关于NIO的讲解&#xff0c;让我对Tomcat产生了浓厚的兴趣&#xff0c;于是我阅读了Tomcat中关于服务端和客户端之间连接部分的源码…

乐尚代驾六订单执行一

加载当前订单 需求 无论是司机端&#xff0c;还是乘客端&#xff0c;遇到页面切换&#xff0c;重新登录小程序等&#xff0c;只要回到首页面&#xff0c;查看当前是否有正在执行订单&#xff0c;如果有跳转到当前订单执行页面 之前这个接口已经开发&#xff0c;为了测试&…

JAVAWeb实战(后端篇)

因为前后端代码内容过多&#xff0c;这篇只写后端的代码&#xff0c;前端的在另一篇写 项目实战一&#xff1a; 1.创建数据库,表等数据 创建数据库 create database schedule_system 创建表&#xff0c;并添加内容 SET NAMES utf8mb4; SET FOREIGN_KEY_CHECKS 0;-- ---------…

Node.js版本管理工具之NVM

目录 一、NVM介绍二、NVM的下载安装1、NVM下载2、卸载旧版Node.js3、安装 三、NVM配置及使用1、设置nvm镜像源2、安装Node.js3、卸载Node.js4、使用或切换Node.js版本5、设置全局安装路径和缓存路径 四、常用命令技术交流 博主介绍&#xff1a; 计算机科班人&#xff0c;全栈工…

Win11 操作(四)g502鼠标连接电脑不亮灯无反应

罗技鼠标连接电脑不亮灯无反应 前言 罗技技术&#x1f4a9;中&#x1f4a9;&#xff0c;贴吧技术神中神&#xff01; 最近买了一个g502&#xff0c;结果买回来直接插上电脑连灯都不亮&#xff0c;问了一下客服。客服简单的让我换接口&#xff0c;又是下载ghub之类的&#xf…

Linux 安装 GDB (无Root 权限)

引入 在Linux系统中&#xff0c;如果你需要在集群或者远程操作没有root权限的机子&#xff0c;安装GDB&#xff08;GNU调试器&#xff09;可能会有些限制&#xff0c;因为通常安装新软件或更新系统文件需要管理员权限。下面我们介绍可以在没有root权限的情况下安装GDB&#xf…

ElasticSearch核心之DSL查询语句实战

什么是DSL&#xff1f; Elasticsearch提供丰富且灵活的查询语言叫做DSL查询(Query DSL),它允许你构建更加复杂、强大的查询。 DSL(Domain Specific Language特定领域语言)以JSON请求体的形式出现。目前常用的框架查询方法什么的底层都是构建DSL语句实现的&#xff0c;所以你必…

openFeign配置okhttp

原来的项目出现了性能问题&#xff0c;老大不知道怎么的&#xff0c;让我改openFeign线程池为okhttp&#xff0c;说原生的不支持线程池性能比较差。 原openFeign配置文章地址 一、pom文件 <dependency><groupId>org.springframework.cloud</groupId><arti…

【短视频矩阵系统源码部署/技术应用开发】

短视频矩阵系统&#xff1a;选择专业服务商指南 该短视频矩阵系统由多个关键模块组成&#xff0c;包括混剪算法、账号管理与发布、消息处理以及数据管理等。为了优化带宽使用&#xff0c;文件导出功能已被独立处理。 此外&#xff0c;系统还集成了后台运营管理功能。 在技术架…

Python设计模式 - 工厂方法模式

定义 工厂方法模式是一种创建型设计模式&#xff0c;它定义一个创建对象的接口&#xff0c;让其子类来处理对象的创建&#xff0c;而不是直接实例化对象。 结构 抽象工厂&#xff08;Factory&#xff09;&#xff1a;声明工厂方法&#xff0c;返回一个产品对象。具体工厂类都…

git等常用工具以及cmake

一、将git中的代码克隆进电脑以及常用工具介绍 1.安装git 首先需要安装git sudo apt install git 注意一定要加--recursive&#xff0c;因为文件中有很多“引用文件“&#xff0c;即第三方文件&#xff08;库&#xff09;&#xff0c;加入该选项会将文件中包含的子模…

区块链技术如何重塑医疗健康行业未来?

区块链在医疗领域的应用日益广泛&#xff0c;主要体现在以下几个方面&#xff1a; 一、医疗数据管理 电子病历管理&#xff1a; 区块链技术可以用于构建去中心化的电子病历系统&#xff0c;确保病历数据的不可篡改性和安全性。患者可以通过区块链平台安全地管理自己的电子病历…

30岁决心转行,AI太香了

今天是一篇老学员的经历分享&#xff0c;此时的王同学在大洋彼岸即将毕业&#xff0c;手握多家北美大厂offer&#xff0c;一片明媚。谁能想到王同学的转码之路竟始于一场裁员&#xff0c;这场访谈拉开了他的回忆。 最近总刷到一些关于转行的话题&#xff0c;很多刚毕业的同学喜…

【OpenCV C++20 学习笔记】图片融合

图片融合 原理实现结果展示完整代码 原理 关于OpenCV的配置和基础用法&#xff0c;请参阅本专栏的其他文章&#xff1a;垚武田的OpenCV合集 这里采用的图片熔合的算法来自Richard Szeliski的书《Computer Vision: Algorithms and Applications》&#xff08;《计算机视觉&#…