【图像分类实用脚本】数据可视化以及高数量类别截断

图像分类时,如果某个类别或者某些类别的数量远大于其他类别的话,模型在计算的时候,更倾向于拟合数量更多的类别;因此,观察类别数量以及对数据量多的类别进行截断是很有必要的。

1.准备数据

数据的格式为图像分类数据集格式,根目录下分为train和val文件夹,每个文件夹下以类别名命名的子文件夹:

.
├── ./datasets
│ ├── ./datasets/train/A
│ │ ├── ./datasets/train/A/1.jpg
│ │ ├── ./datasets/train/A/2.jpg
│ │ ├── ./datasets/train/A/3.jpg
│ │ ├── …
│ ├── ./datasets/train/B
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── …
│ ├── ./datasets/val/A
│ │ ├── ./datasets/val/A/1.jpg
│ │ ├── ./datasets/val/A/2.jpg
│ │ ├── ./datasets/val/A/3.jpg
│ │ ├── …
│ ├── ./datasets/val/B
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── …

2.查看数据分布

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pddef count_images(directory, image_extensions):"""统计每个子文件夹中的图像数量。:param directory: 主目录路径(train或val):param image_extensions: 允许的图像文件扩展名元组:return: 一个字典,键为类别名,值为图像数量"""counts = {}if not os.path.exists(directory):print(f"目录不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):# 统计符合扩展名的文件数量image_count = sum(1 for file in os.listdir(class_path)if file.lower().endswith(image_extensions))counts[class_name] = image_countreturn countsdef count_images_in_single_directory(directory, image_extensions):"""统计单个目录下每个类别的图像数量。:param directory: 主目录路径:param image_extensions: 允许的图像文件扩展名元组:return: 一个字典,键为类别名,值为图像数量"""counts = {}if not os.path.exists(directory):print(f"目录不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):image_count = sum(1 for file in os.listdir(class_path)if file.lower().endswith(image_extensions))counts[class_name] = image_countreturn countsdef autolabel(ax, rects):"""在每个柱状图上方添加数值标签。:param ax: Matplotlib 的轴对象:param rects: 柱状图对象"""for rect in rects:height = rect.get_height()ax.annotate(f'{height}',xy=(rect.get_x() + rect.get_width() / 2, height),xytext=(0, 3),  # 3 points vertical offsettextcoords="offset points",ha='center', va='bottom')def plot_distribution(all_classes, train_values, val_values, output_path, has_val=False):"""绘制并保存训练集和验证集中每个类别的图像数量分布柱状图。如果没有验证集数据,则只绘制训练集数据。:param all_classes: 所有类别名称列表:param train_values: 训练集中每个类别的图像数量列表:param val_values: 验证集中每个类别的图像数量列表(如果有的话):param output_path: 保存图表的文件路径:param has_val: 是否包含验证集数据"""x = np.arange(len(all_classes))  # 类别位置width = 0.35  # 柱状图的宽度fig, ax = plt.subplots(figsize=(12, 8))if has_val:rects1 = ax.bar(x - width/2, train_values, width, label='Train')rects2 = ax.bar(x + width/2, val_values, width, label='Validation')else:rects1 = ax.bar(x, train_values, width, label='Count')# 添加一些文本标签ax.set_xlabel('Category')ax.set_ylabel('Number of Images')title = 'Number of Images in Each Category for Train and Validation' if has_val else 'Number of Images in Each Category'ax.set_title(title)ax.set_xticks(x)ax.set_xticklabels(all_classes, rotation=45, ha='right')ax.legend() if has_val else ax.legend(['Count'])# 自动标注柱状图上的数值autolabel(ax, rects1)if has_val:autolabel(ax, rects2)fig.tight_layout()# 保存图表为图片文件plt.savefig(output_path, dpi=300, bbox_inches='tight')print(f"图表已保存到 {output_path}")def compute_and_display_statistics(counts_dict, dataset_name, save_csv=False):"""计算并展示统计数据,包括总图像数量、类别数量、平均每个类别的图像数量和类别占比。:param counts_dict: 类别名称与图像数量的字典:param dataset_name: 数据集名称(例如 'Train', 'Validation', 'Dataset'):param save_csv: 是否保存统计结果为 CSV 文件"""total_images = sum(counts_dict.values())num_classes = len(counts_dict)avg_per_class = total_images / num_classes if num_classes > 0 else 0# 计算每个类别的占比category_proportions = {cls: (count / total_images * 100) if total_images > 0 else 0 for cls, count in counts_dict.items()}# 创建 DataFramedf = pd.DataFrame({'类别名称': list(counts_dict.keys()),'图像数量': list(counts_dict.values()),'占比 (%)': [f"{prop:.2f}" for prop in category_proportions.values()]})# 排序 DataFrame 按图像数量降序df = df.sort_values(by='图像数量', ascending=False)print(f"\n===== {dataset_name} 数据统计 =====")print(df.to_string(index=False))print(f"总图像数量: {total_images}")print(f"类别数量: {num_classes}")print(f"平均每个类别的图像数量: {avg_per_class:.2f}")# 根据 save_csv 参数决定是否保存为 CSV 文件if save_csv:# 将数据集名称转换为小写并去除空格,以作为文件名的一部分sanitized_name = dataset_name.lower().replace(" ", "_").replace("(", "").replace(")", "")csv_filename = f"{sanitized_name}_statistics.csv"df.to_csv(csv_filename, index=False, encoding='utf-8-sig')print(f"统计表已保存为 {csv_filename}\n")def main():# ================== 配置参数 ==================# 设置数据集的根目录路径dataset_root = 'datasets/device_cls_merge_manual_with_21w_1218'  # 替换为你的数据集路径# 定义train和val目录train_dir = os.path.join(dataset_root, 'train')val_dir = os.path.join(dataset_root, 'val')# 定义允许的图像文件扩展名image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')# 输出图表的路径output_path = 'dataset_distribution.png'  # 你可以更改为你想要的文件名和路径# 是否保存统计结果为 CSV 文件(默认不保存)SAVE_CSV = False  # 设置为 True 以启用保存 CSV# ================== 统计图像数量 ==================has_train = os.path.exists(train_dir) and os.path.isdir(train_dir)has_val = os.path.exists(val_dir) and os.path.isdir(val_dir)if has_train and has_val:print("检测到 'train' 和 'val' 目录。统计训练集和验证集中的图像数量...")train_counts = count_images(train_dir, image_extensions)val_counts = count_images(val_dir, image_extensions)# 获取所有类别的名称(确保train和val中的类别一致)all_classes = sorted(list(set(train_counts.keys()) | set(val_counts.keys())))# 准备绘图数据train_values = [train_counts.get(cls, 0) for cls in all_classes]val_values = [val_counts.get(cls, 0) for cls in all_classes]# ================== 计算并展示统计数据 ==================compute_and_display_statistics(train_counts, '训练集 (Train)', save_csv=SAVE_CSV)compute_and_display_statistics(val_counts, '验证集 (Validation)', save_csv=SAVE_CSV)# ================== 绘制并保存图表 ==================print("绘制并保存训练集和验证集的图表...")plot_distribution(all_classes, train_values, val_values, output_path, has_val=True)else:print("未检测到 'train' 和 'val' 目录。将统计主目录下的图像数量...")# 如果没有train和val目录,则统计主目录下的图像分布main_counts = count_images_in_single_directory(dataset_root, image_extensions)# 获取所有类别的名称all_classes = sorted(main_counts.keys())# 准备绘图数据main_values = [main_counts.get(cls, 0) for cls in all_classes]# 定义输出图表路径(可以区分不同的输出文件名)output_path_single = 'dataset_distribution_single.png'  # 或者使用与train_val相同的output_path# ================== 计算并展示统计数据 ==================compute_and_display_statistics(main_counts, '数据集 (Dataset)', save_csv=SAVE_CSV)# ================== 绘制并保存图表 ==================print("绘制并保存主目录的图表...")plot_distribution(all_classes, main_values, [], output_path_single, has_val=False)if __name__ == "__main__":main()

下图为原始数据集运行结果,可以看到数据存在严重不均衡问题
在这里插入图片描述

3.数据截断

import os
import shutil
import randomdef count_images(directory, image_extensions):"""统计每个子文件夹中的图像文件路径列表。:param directory: 主目录路径(train或val):param image_extensions: 允许的图像文件扩展名列表:return: 一个字典,键为类别名,值为图像文件路径列表"""counts = {}if not os.path.exists(directory):print(f"目录不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):# 获取符合扩展名的文件列表images = [file for file in os.listdir(class_path)if file.lower().endswith(tuple(image_extensions))]image_paths = [os.path.join(class_path, img) for img in images]counts[class_name] = image_pathsreturn countsdef truncate_dataset(class_images, threshold, seed=42):"""对每个类别的图像进行截断,如果超过阈值则随机选择一定数量的图像。:param class_images: 一个字典,键为类别名,值为图像文件路径列表:param threshold: 每个类别的图像数量阈值:param seed: 随机种子:return: 截断后的类别图像字典"""truncated = {}random.seed(seed)for class_name, images in class_images.items():if len(images) > threshold:truncated_images = random.sample(images, threshold)truncated[class_name] = truncated_imagesprint(f"类别 '{class_name}' 超过阈值 {threshold},已随机选择 {threshold} 张图像。")else:truncated[class_name] = imagesprint(f"类别 '{class_name}' 不超过阈值 {threshold},保留所有 {len(images)} 张图像。")return truncateddef copy_images(truncated_data, subset, output_root):"""将截断后的图像复制到输出目录,保持原有的目录结构。:param truncated_data: 截断后的类别图像字典:param subset: 'train' 或 'val':param output_root: 输出根目录路径"""for class_name, images in truncated_data.items():dest_dir = os.path.join(output_root, subset, class_name)os.makedirs(dest_dir, exist_ok=True)for img_path in images:img_name = os.path.basename(img_path)dest_path = os.path.join(dest_dir, img_name)shutil.copy2(img_path, dest_path)print(f"'{subset}' 子集已复制到 {output_root}")def main():"""主函数,执行数据集截断和复制操作。"""# ================== 配置参数 ==================# 原始数据集根目录路径input_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224'  # 替换为你的原始数据集路径# 截断后数据集的输出根目录路径output_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate'  # 替换为你希望保存截断后数据集的路径# 训练集每个类别的图像数量阈值train_threshold = 2000  # 设置为你需要的训练集阈值# 验证集每个类别的图像数量阈值val_threshold = 400  # 设置为你需要的验证集阈值# 允许的图像文件扩展名image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']# 随机种子以确保可重复性random_seed = 42# ================== 脚本实现 ==================# 设置随机种子random.seed(random_seed)# 定义train和val目录路径train_input_dir = os.path.join(input_dir, 'train')val_input_dir = os.path.join(input_dir, 'val')# 统计train和val中的图像print("统计训练集中的图像数量...")train_counts = count_images(train_input_dir, image_extensions)print("统计验证集中的图像数量...")val_counts = count_images(val_input_dir, image_extensions)# 截断train和val中的图像print("\n截断训练集中的图像...")truncated_train = truncate_dataset(train_counts, train_threshold, random_seed)print("\n截断验证集中的图像...")truncated_val = truncate_dataset(val_counts, val_threshold, random_seed)# 复制截断后的图像到输出目录print("\n复制截断后的训练集图像...")copy_images(truncated_train, 'train', output_dir)print("复制截断后的验证集图像...")copy_images(truncated_val, 'val', output_dir)print("\n数据集截断完成。")if __name__ == "__main__":main()

再次查看已经符合截断后的数据分布了
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/492691.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Javascript-web API-day02

文章目录 01-事件监听02-点击关闭广告03-随机点名案例04-鼠标经过或离开事件05-可点击的轮播图06-小米搜索框07-键盘类型事件08-键盘事件-发布评论案例09-focus选择器10-评论回车发布11-事件对象12-trim方法13-环境对象14-回调函数15-tab栏切换 01-事件监听 <!DOCTYPE html…

c语言-------循环结构

基本概念 循环结构是C语言中一种重要的程序控制结构&#xff0c;它允许程序在满足一定条件的情况下&#xff0c;反复执行一段代码。这可以避免重复编写相似的代码&#xff0c;提高代码的效率和可读性。 while循环 语法格式 while (条件表达式) { 循环体语句; } 执行流程 首先判…

Centos创建共享文件夹拉取文件

1.打开VMware程序&#xff0c;鼠标右检你的虚拟机&#xff0c;打开设置 2.点击选项——共享文件夹——总是启用 点击添加&#xff0c;设置你想要共享的文件夹在pc上的路径&#xff08;我这里已经添加过了就不加了&#xff09; 注意不要中文&#xff0c;建议用share&#xff0c…

SpringBoot项目Jar包使用systemctl运行

1. 前言 SpringBoot项目打成jar包后&#xff0c;可以直接使用 java -jar xxx.jar 启动。但是为了方便启动和停止服务&#xff0c;通常我们会写两个脚本&#xff0c;分别是启动脚本 start.sh 和 停止脚本 shutdown.sh&#xff08;这两个脚本内容我们下文会实现&#xff09;&…

计算机网络-L2TP VPN基础概念与原理

一、概述 前面学习了GRE和IPSec VPN&#xff0c;今天继续学习另外一个也很常见的VPN类型-L2TP VPN。 L2TP&#xff08;Layer 2 Tunneling Protocol&#xff09; 协议结合了L2F协议和PPTP协议的优点&#xff0c;是IETF有关二层隧道协议的工业标准。L2TP是虚拟私有拨号网VPDN&…

踩准智能汽车+机器人两大风口,速腾聚创AI+机器人应用双线爆发

日前&#xff0c;RoboSense速腾聚创交出了一份亮眼的Q3财报。受到多重利好消息影响&#xff0c;其股价也应势连续大涨。截止12月9日发稿前&#xff0c;速腾聚创股价近一个月内累计涨幅已超88%。 财务数据方面&#xff0c;速腾聚创在今年前三季度实现总收入约11.3亿元&#xff0…

使用Idea自带的git功能进行分支合并

文章目录 1.背景描述2.分支切换3.分支合并的具体操作4.将在local环境下&#xff0c;从dev合并到qas分支上的代码&#xff0c;推送到远端 1.背景描述 目前在开发的当前项目有四个分支&#xff0c;master(主分支)、pre(预生产分支)、qas(测试分支)、dev(开发分支)&#xff1b; …

EE308FZ_Sixth Assignment_Beta Sprint_Sprint Essay 3

Assignment 6Beta SprintCourseEE308FZ[A] — Software EngineeringClass Link2401_MU_SE_FZURequirementsTeamwork—Beta SprintTeam NameFZUGOObjectiveSprint Essay 3_Day5-Day6 (12.15-12.16)Other Reference1. WeChat Mini Program Design Guide 2. Javascript Style Guid…

凯酷全科技抖音电商服务的卓越践行者

在数字经济蓬勃发展的今天&#xff0c;电子商务已成为企业增长的新引擎。随着短视频平台的崛起&#xff0c;抖音作为全球领先的短视频社交平台&#xff0c;不仅改变了人们的娱乐方式&#xff0c;也为品牌和商家提供了全新的营销渠道。厦门凯酷全科技有限公司&#xff08;以下简…

架构信息收集(小迪网络安全笔记~

附&#xff1a;完整笔记目录~ ps&#xff1a;本人小白&#xff0c;笔记均在个人理解基础上整理&#xff0c;若有错误欢迎指正&#xff01; 2.2 架构信息收集 引子&#xff1a;一个Web应用的构成&#xff0c;由诸多组件&服务相结合&#xff0c;而域名仅是处于Web架构中最表…

一.photoshop导入到spine

这里使用的是 photoshoptospine脚本 下载地址:https://download.csdn.net/download/boyxgb/90156744 脚本的使用,可以通过文件的脚本的浏览,浏览该脚本使用该脚本,也可以将该脚本放在photoshop安装文件夹里的script文件夹下,具体路径:Photoshop\Presets\Scripts,重启photosho…

Mapbox-GL 的源码解读的一般步骤

Mapbox-GL 是一个非常优秀的二三维地理引擎&#xff0c;随着智能驾驶时代的到来&#xff0c;应用也会越来越广泛&#xff0c;关于mapbox-gl和其他地理引擎的详细对比&#xff08;比如CesiumJS&#xff09;&#xff0c;后续有时间会加更。地理首先理解 Mapbox-GL 的源码是一项复…

SparkSQL运行架构及原理

文章目录 SparkSQL运行架构及原理1.1. Catalyst优化器简介1.2. SparkSQL运行架构1.3. SparkSQL解析Core底层原理1.4. 执行计划查看 SparkSQL运行架构及原理 1.1. Catalyst优化器简介 SparkSQL使得我们开发人员可以使用DSL风格的数据来处理数据&#xff0c;甚至可以直接使用SQ…

大数据-254 离线数仓 - Airflow 任务调度 核心交易调度任务集成

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; Java篇开始了&#xff01; 目前开始更新 MyBatis&#xff0c;一起深入浅出&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff0…

昇思25天学习打卡营第33天|共赴算力时代

文章目录 一、平台简介二、深度学习模型2.1 处理数据集2.2 模型训练2.3 加载模型 三、共赴算力时代 一、平台简介 昇思大模型平台&#xff0c;就像是AI学习者和开发者的超级基地&#xff0c;这里不仅提供丰富的项目、模型和大模型体验&#xff0c;还有一大堆经典数据集任你挑。…

Docker 镜像加速和配置的分享 云服务器搭建beef-xss

前言 最近很多的docker镜像加速都鸡鸡了 找点资源是越来越不容易了 什么事docker 因为我是个业余的人 我简单的说 docker就是比如我们的软件商店的 下载 docker镜像&#xff08;之前就是我们在服务器上搭建网站 和环境的很费力费时 之后就有了这个 镜像 &#xff1a;这…

浅谈怎样系统的准备前端面试

前言 创业梦碎&#xff0c;回归现实&#xff0c;7 月底毅然裸辞&#xff0c;苦战两个月&#xff0c;拿到了美团和字节跳动的 offer&#xff0c;这算是从业以来第一次真正意义的面试&#xff0c;遇到蛮多问题&#xff0c;比如一开始具体的面试过程我都不懂&#xff0c;基本一直是…

告别机器人味:如何让ChatGPT写出有灵魂的内容

目录 ChatGPT的一些AI味道小问题 1.提供编辑指南 2.提供样本 3.思维链大纲 4.融入自己的想法 5.去除重复增加多样性 6.删除废话 ChatGPT的一些AI味道小问题 大多数宝子们再使用ChatGPT进行写作时&#xff0c;发现我们的老朋友ChatGPT在各类写作上还有点“机器人味”太重…

【长城杯】Web题 hello_web 解题思路

查看源代码发现路径提示 访问…/tips.php显示无用页面&#xff0c;怀疑…/被过滤&#xff0c;采用…/./形式&#xff0c;看到phpinfo()页面 注意到disable_functions&#xff0c;禁用了很多函数 访问hackme.php,看到页面源码 发现eval函数&#xff0c;包含base64 解密获得php代…

Windows部署Docker及PostgreSQL数据库相关操作

一、Windows安装Docker 1.wsl安装 以管理员身份启动命令行&#xff0c;运行&#xff1a;wsl --install&#xff1b; 安装结束后&#xff0c;重启电脑&#xff0c;以管理员身份启动命令行&#xff0c;运行&#xff1a;wsl --install -d Ubuntu&#xff1b; 中间需要输入用户名…