【模块化编程】制作、可视化不平衡数据(长尾数据)(以Cifar-10为例)

🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀《PyTorch科研加速指南:即插即用式模块开发》-CSDN博客

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 

目录

 1. 前言

2. Cifar数据集介绍

3. 分模块介绍Cifar10-LT数据集的生成方式

3.1 导入包

3.2 数据集描述与引用

3.3 Cifar10LTConfig 类

3.4 Cifar10 类

3.5 _info 方法

3.6 _split_generators 方法

3.7 _generate_examples 方法

3.8 _imbalance_indices 方法

3.9 _generate_indices_targets 方法

3.10 _get_img_num_per_cls 方法

3.11 _gen_imbalanced_data 方法

4. 完整代码

5. 统计&可视化长尾数据集 

6. 总结


 1. 前言

  • 👑《PyTorch科研加速指南:即插即用式模块开发》专栏持续更新中,未来最少文章数量为100篇。由于专栏刚刚建立,目前免费,后续将慢慢恢复原价至99.9🍉。
  • 👑《PyTorch科研加速指南:即插即用式模块开发》专栏主要针对零基础入门的小伙伴。不需要Python基础,不需要深度学习基础,只要你愿意学,这一个专栏将真正让你做到零基础入门。
  • 🔥每例项目都包括理论讲解、数据集、源代码

正在更新中💹💹

🚨项目运行环境:

  • 平台:Window11
  • 语言环境:Python3.8
  • 运行环境1:PyCharm 2021.3
  • 运行环境2:Jupyter Notebook 7.3.2
  • 框架:PyTorch 2.5.1(CUDA11.8)

2. Cifar数据集介绍

3. 分模块介绍Cifar10-LT数据集的生成方式

3.1 导入包

import pickle
from typing import Dict, Iterator, List, Tuple, BinaryIO
import numpy as np
import datasets
from datasets.tasks import ImageClassification
  • pickle: 用于序列化和反序列化Python对象,这里用于加载CIFAR-10数据集。
  • typing: 提供类型注解,增强代码可读性和类型检查。
  • numpy: 用于高效的数值计算,特别是数组操作。
  • datasets: Hugging Face的datasets库,用于加载和处理数据集。
  • ImageClassification: 定义图像分类任务的模板。

注意:datasets.tasks 的用法在最新的datasets中不适用,大家可以下载旧版本😇

3.2 数据集描述与引用

数据集的引用:

_CITATION = """\
@TECHREPORT{Krizhevsky09learningmultiple,author = {Alex Krizhevsky},title = {Learning multiple layers of features from tiny images},institution = {},year = {2009}
}
"""

 _CITATION: 数据集的引用信息,引用CIFAR-10的原始论文。

数据集的描述:

_DESCRIPTION = """\
The CIFAR-10-LT imbalanced dataset is comprised of under 60,000 color images, each measuring 32x32 pixels, 
distributed across 10 distinct classes.  
The dataset includes 10,000 test images, with 1000 images per class, 
and fewer than 50,000 training images.
The number of samples within each class of the train set decreases exponentially with factors of 10, 20, 50, 100, or 200.
"""

_DESCRIPTION: 数据集的详细描述,说明CIFAR-10-LT是一个长尾数据集,训练集类别样本数呈指数下降。

原数据集下载链接:

_DATA_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

DATA_URL: 数据集下载链接。

数据集的类别名称:

_NAMES = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck",
]

_NAMES: CIFAR-10数据集的10个类别名称。

3.3 Cifar10LTConfig 类

Cifar10LTConfig 类:继承自datasets.BuilderConfig,用于配置CIFAR-10-LT数据集的不同版本。

class Cifar10LTConfig(datasets.BuilderConfig):"""BuilderConfig for CIFAR-10-LT."""def __init__(self, imb_type: str, imb_factor: float, rand_number: int = 0, cls_num: int = 10, **kwargs):"""BuilderConfig for CIFAR-10-LT.Args:imb_type (str): imbalance type, including 'exp', 'step'.imb_factor (float): imbalance factor.rand_number (int): random seed, default: 0.cls_num (int): number of classes, default: 10.**kwargs: keyword arguments forwarded to super."""# Version history:super().__init__(version=datasets.Version("1.0.0"), **kwargs)self.imb_type = imb_typeself.imb_factor = imb_factorself.rand_number = rand_numberself.cls_num = cls_numnp.random.seed(self.rand_number)

__init__: 初始化函数,接收以下参数:

  1. imb_type: 不平衡类型,如'exp'(指数)或'step'(阶梯)。
  2. imb_factor: 不平衡因子,控制类别样本数的比例。
  3. rand_number: 随机种子,默认值为0。
  4. cls_num: 类别数量,默认值为10。
  5. **kwargs: 其他参数,传递给父类。

简单来说:Cifar10LTConfig就是用来设置生成Cifar10LT数据集的参数(不同设置能够生成不同程度的长尾数据集)

3.4 Cifar10 类

Cifar10 类:利用前面的设置去生成不同参数下(不平衡程度不同)的长尾数据。

class Cifar10(datasets.GeneratorBasedBuilder):"""CIFAR-10 Dataset"""BUILDER_CONFIGS = [Cifar10LTConfig(name="r-10",description="CIFAR-10-LT-r-10 Dataset",imb_type='exp',imb_factor=1/10,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-20",description="CIFAR-10-LT-r-20 Dataset",imb_type='exp',imb_factor=1/20,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-50",description="CIFAR-10-LT-r-50 Dataset",imb_type='exp',imb_factor=1/50,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-100",description="CIFAR-10-LT-r-100 Dataset",imb_type='exp',imb_factor=1/100,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-200",description="CIFAR-10-LT-r-200 Dataset",imb_type='exp',imb_factor=1/200,rand_number=0,cls_num=10,),]

BUILDER_CONFIGS: 定义不同版本的CIFAR-10-LT数据集配置,包括r-10、r-20、r-50、r-100和r-200。

3.5 _info 方法

_info方法: 返回数据集的基本信息,包括描述、特征、引用和任务模板。

def _info(self) -> datasets.DatasetInfo:return datasets.DatasetInfo(description=_DESCRIPTION,features=datasets.Features({"img": datasets.Image(),"label": datasets.features.ClassLabel(names=_NAMES),}),supervised_keys=None,homepage="https://www.cs.toronto.edu/~kriz/cifar.html",citation=_CITATION,task_templates=[ImageClassification(image_column="img", label_column="label")],)

-> :指示函数返回值的类型。

datasets.DatasetInfo 主要有以下几个字段:

  1. description

    • 数据集的描述,可以是一个简短的文本,解释数据集的背景、来源、用途等。
  2. features

    • 用于定义数据集的特征(即数据的字段和类型)。例如,数据集可能有图像(Image 类型)和标签(ClassLabel 类型)字段。
    • 使用 datasets.Features() 来定义。
  3. citation

    • 数据集的引用信息,通常是一个字符串,包含数据集的学术引用格式(如 BibTeX)。
  4. homepage

    • 数据集的官方网站或主页,通常指向数据集的来源页面,提供更多的背景和下载信息。
  5. license

    • 数据集的许可证信息,通常描述使用数据集的条款和条件。
  6. task_templates

    • 定义数据集适用于哪些机器学习任务。例如,图像分类任务、文本分类任务等。
  7. supervised_keys

    • 在监督学习任务中,supervised_keys 用来指示哪些字段是特征(输入)和标签(输出)。如果数据集不需要这些信息,通常设置为 None

3.6 _split_generators 方法

_split_generators方法: 下载数据集并生成训练集和测试集的生成器。

def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:archive = dl_manager.download(_DATA_URL) # 这里的dl_manager在前面的方法中自定义的return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train"}),datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "test"}),]

dl_manager: datasets.DownloadManager:当:后面的不是一个循环时,而是一个类型时表示对前面实例数据类型的提示。

3.7 _generate_examples 方法

_generate_examples方法: 生成数据集的样本。

    def _generate_examples(self, files: Iterator[Tuple[str, BinaryIO]], split: str) -> Iterator[Dict]:"""This function returns the examples in the array form."""if split == "train":batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]if split == "test":batches = ["test_batch"]batches = [f"cifar-10-batches-py/{filename}" for filename in batches]for path, fo in files:if path in batches:dict = pickle.load(fo, encoding="bytes")labels = dict[b"labels"]images = dict[b"data"]if split == "train":indices = self._imbalance_indices()else:indices = range(len(labels))for idx in indices:img_reshaped = np.transpose(np.reshape(images[idx], (3, 32, 32)), (1, 2, 0))yield f"{path}_{idx}", {"img": img_reshaped,"label": labels[idx],}break

具体功能如下:

if split == "train":batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]
if split == "test":batches = ["test_batch"]
batches = [f"cifar-10-batches-py/{filename}" for filename in batches]

根据split参数选择训练集或测试集的文件名。


for path, fo in files:if path in batches:dict = pickle.load(fo, encoding="bytes")labels = dict[b"labels"]images = dict[b"data"]

加载数据文件并解析标签和图像。


if split == "train":indices = self._imbalance_indices()
else:indices = range(len(labels))

 如果是训练集,调用_imbalance_indices生成不平衡的索引(下面介绍);否则测试机则使用全部样本。


for idx in indices:img_reshaped = np.transpose(np.reshape(images[idx], (3, 32, 32)), (1, 2, 0))yield f"{path}_{idx}", {"img": img_reshaped,"label": labels[idx],}

1.for idx in indices:

  • 这行代码表示对 indices 中的每一个元素进行迭代,indice是一个列表或数组,包含了一些索引值。每次迭代时,idx 就是 indices 中的当前索引值。

2. img_reshaped = np.transpose(np.reshape(images[idx], (3, 32, 32)), (1, 2, 0))

  • 这行代码先对 images[idx] 中的图像数据进行操作。假设 images 是一个包含图像数据的列表或数组。

  • images[idx]:根据当前的 idximages 中取出对应的图像。

  • np.reshape(images[idx], (3, 32, 32)):首先使用 numpyreshape 函数将图像数据重新调整为形状 (3, 32, 32),通常这意味着将图像从原来的形状(例如 (32, 32, 3))调整为一个不同的维度。在这里,假设图像是 RGB 色彩图像,所以会有 3 个颜色通道(红、绿、蓝),每个通道大小为 32x32 像素。

  • np.transpose(..., (1, 2, 0)):然后通过 np.transpose 对图像的维度进行转换。原始维度是 (3, 32, 32),意味着 3 个颜色通道在前,32x32 像素在后。(1, 2, 0) 表示我们将维度顺序改变为 (32, 32, 3),即图像的高度和宽度维度排在前面,颜色通道排在后面。通常这是为了适配不同的图像处理库或模型的输入格式。

  • img_reshaped:经过这些操作,最终得到的 img_reshaped 就是调整好维度后的图像数据。

3. yield f"{path}_{idx}", { "img": img_reshaped, "label": labels[idx] }

  • 这一行使用 yield 语句返回一个生成器对象。
  • f"{path}_{idx}":这是一个 f-string,用来生成一个字符串,形式为 path_idx,其中 path 是某个路径(可能是数据集的根目录或者其他信息),idx 是当前的索引。这个字符串通常用于作为该数据项的标识符。
  • {"img": img_reshaped, "label": labels[idx]}:这是一个字典,包含两个键值对:
    • "img":键对应的值是之前得到的 img_reshaped,即处理后的图像数据。
    • "label":键对应的值是 labels[idx],即图像对应的标签。labels 可能是一个包含每张图像类别标签的列表或数组。

yield 会将 f"{path}_{idx}" 和字典 {"img": img_reshaped, "label": labels[idx]} 作为一对数据返回,这个数据对通常会在后续代码中被进一步处理。

3.8 _imbalance_indices 方法

_imbalance_indices 方法:这个函数生成不平衡的数据集的索引。

def _imbalance_indices(self) -> List[int]:"""This function returns the indices of imbalanced CIFAR-10-LT dataset."""dl_manager = datasets.DownloadManager()archive = dl_manager.download(_DATA_URL)data_iterator = self._generate_indices_targets(dl_manager.iter_archive(archive), "train")indices = []targets = []for i, targets_dict in data_iterator:indices.append(i)targets.append(targets_dict["label"])data_length = len(indices)img_num_per_cls = self._get_img_num_per_cls(data_length)new_indices, _ = self._gen_imbalanced_data(img_num_per_cls, targets)return new_indices
  • 使用 datasets.DownloadManager() 下载 CIFAR-10 数据集。
  • 调用 _generate_indices_targets 函数,获取训练集数据的索引和标签。
  • 遍历数据迭代器,收集数据的索引和标签。
  • 计算数据集的总长度 data_length
  • 使用 _get_img_num_per_cls 函数计算每个类别的目标图像数量。
  • 使用 _gen_imbalanced_data 函数生成不平衡数据集的索引。

3.9 _generate_indices_targets 方法

_generate_indices_targets 方法:这个函数用于从给定的文件中加载 CIFAR-10 数据集的标签,并生成数据的索引和标签字典。

def _generate_indices_targets(self, files: Iterator[Tuple[str, BinaryIO]], split: str) -> Iterator[Dict]:"""This function returns the examples in the array form."""if split == "train":batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]if split == "test":batches = ["test_batch"]batches = [f"cifar-10-batches-py/{filename}" for filename in batches]for path, fo in files:if path in batches:dict = pickle.load(fo, encoding="bytes")labels = dict[b"labels"]for idx, _ in enumerate(labels):yield f"{path}_{idx}", {"label": labels[idx],}break
  • 输入
    • split: 一个字符串,表示数据集的划分,可以是 "train" 或 "test"。
    • files: 是一个文件迭代器,每个元素是一个元组 (path, fo),其中 path 是文件路径,fo 是文件对象。
  • 处理
    • 根据 split 的值,选择训练集 (train) 或测试集 (test) 的数据批次文件。
    • 对于训练集,选择了 5 个数据批次;对于测试集,只选择了 1 个批次。
    • 使用 pickle 加载批次文件,读取其中的标签(即图片的类别标签)。
    • 遍历标签,为每个标签生成一个字典,其中键为 "label",值为对应的标签。返回一个 (index, label) 的生成器,index 由文件路径和索引 idx 组合而成。                         
  • 返回值:
    • 返回的是一个生成器,生成元组 (index, {"label": label}),其中 index 是由文件路径和图像索引构成的唯一标识,label 是图像对应的标签。

3.10 _get_img_num_per_cls 方法

_get_img_num_per_cls 方法:根据采用的不平衡策略+每一类图片的数量按照特定计算方法确定每一个类图片的数量

def _get_img_num_per_cls(self, data_length: int) -> List[int]:"""Get the number of images per class given the imbalance ratio and total number of images."""img_max = data_length / self.config.cls_numimg_num_per_cls = []if self.config.imb_type == 'exp':for cls_idx in range(self.config.cls_num):num = img_max * (self.config.imb_factor**(cls_idx / (self.config.cls_num - 1.0)))img_num_per_cls.append(int(num))elif self.config.imb_type == 'step':for cls_idx in range(self.config.cls_num // 2):img_num_per_cls.append(int(img_max))for cls_idx in range(self.config.cls_num // 2):img_num_per_cls.append(int(img_max * self.config.imb_factor))else:img_num_per_cls.extend([int(img_max)] * self.config.cls_num)return img_num_per_cls
  • 输入
    • data_length: 数据集中的样本数量。
  • 处理
    • img_max 计算每个类别的最大图像数量,即将总样本数 data_length 平均分配到每个类别中。
    • 如果不平衡类型(self.config.imb_type)是 'exp'(指数型不平衡),则每个类别的样本数量根据指数因子(self.config.imb_factor)来计算。类别索引越大,样本数量越少。
    • 如果不平衡类型是 'step'(阶梯型不平衡),则前一半类别样本数为 img_max,后一半类别的样本数为 img_max * imb_factor
    • 如果不平衡类型是 'uniform'(均衡),则所有类别的样本数都是相同的,等于 img_max
  • 返回值
    • 返回一个列表 img_num_per_cls,表示每个类别的图像数量。

3.11 _gen_imbalanced_data 方法

 _gen_imbalanced_data 方法:这个函数根据不平衡的图像数量,生成不平衡的数据集索引。

def _gen_imbalanced_data(self, img_num_per_cls: List[int], targets: List[int]) -> Tuple[List[int], Dict[int, int]]:"""This function returns the indices of imbalanced CIFAR-10-LT dataset and the number of images per class."""new_indices = []targets_np = np.array(targets, dtype=np.int64)classes = np.unique(targets_np)num_per_cls_dict = dict()for the_class, the_img_num in zip(classes, img_num_per_cls):num_per_cls_dict[the_class] = the_img_numidx = np.where(targets_np == the_class)[0]np.random.shuffle(idx)selec_idx = idx[:the_img_num]new_indices.extend(selec_idx.tolist())return new_indices, num_per_cls_dict
  • 输入
    • img_num_per_cls: 每个类别的目标图像数量。
    • targets: 数据集的标签列表。
  • 处理
    • targets_np 将 targets 转换为 NumPy 数组。
    • classes 获取标签中的所有类别。
    • 遍历每个类别和目标图像数量,找出该类别的所有样本的索引 idx,并随机打乱。
    • 根据目标图像数量 the_img_num 从打乱后的索引中选择前 the_img_num 个样本。
    • 将这些样本的索引添加到 new_indices 列表中。
    • num_per_cls_dict 字典存储每个类别对应的目标样本数。
  • 返回值
    • 返回一个元组:
      • new_indices: 包含不平衡数据集索引的列表。
      • num_per_cls_dict: 每个类别的目标样本数字典。

4. 完整代码

核心逻辑:

  1. 用户输入不平衡策略+不平衡程度等参数。
  2. 根据类的图片数量按照前面用户的输入生成该类在不平衡数据集中的图片数量。
  3. 根据每个类在不平衡数据集中的图片数量生成索引。
  4. 根据索引从原数据集中生成新的长尾数据集。
import os
import pickle
from typing import Dict, Iterator, List, Tuple, BinaryIO
import numpy as np
import datasets
from datasets.tasks import ImageClassification_CITATION = """\
@TECHREPORT{Krizhevsky09learningmultiple,author = {Alex Krizhevsky},title = {Learning multiple layers of features from tiny images},institution = {},year = {2009}
}
"""_DESCRIPTION = """\
The CIFAR-10-LT imbalanced dataset is comprised of under 60,000 color images, each measuring 32x32 pixels, 
distributed across 10 distinct classes.  
The dataset includes 10,000 test images, with 1000 images per class, 
and fewer than 50,000 training images.
The number of samples within each class of the train set decreases exponentially with factors of 10, 20, 50, 100, or 200.
"""_DATA_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"_NAMES = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck",
]class Cifar10LTConfig(datasets.BuilderConfig):"""BuilderConfig for CIFAR-10-LT."""def __init__(self, imb_type: str, imb_factor: float, rand_number: int = 0, cls_num: int = 10, **kwargs):"""BuilderConfig for CIFAR-10-LT.Args:imb_type (str): imbalance type, including 'exp', 'step'.imb_factor (float): imbalance factor.rand_number (int): random seed, default: 0.cls_num (int): number of classes, default: 10.**kwargs: keyword arguments forwarded to super."""# Version history:super().__init__(version=datasets.Version("1.0.0"), **kwargs)self.imb_type = imb_typeself.imb_factor = imb_factorself.rand_number = rand_numberself.cls_num = cls_numnp.random.seed(self.rand_number)class Cifar10(datasets.GeneratorBasedBuilder):"""CIFAR-10 Dataset"""BUILDER_CONFIGS = [Cifar10LTConfig(name="r-10",description="CIFAR-10-LT-r-10 Dataset",imb_type='exp',imb_factor=1/10,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-20",description="CIFAR-10-LT-r-20 Dataset",imb_type='exp',imb_factor=1/20,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-50",description="CIFAR-10-LT-r-50 Dataset",imb_type='exp',imb_factor=1/50,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-100",description="CIFAR-10-LT-r-100 Dataset",imb_type='exp',imb_factor=1/100,rand_number=0,cls_num=10,),Cifar10LTConfig(name="r-200",description="CIFAR-10-LT-r-200 Dataset",imb_type='exp',imb_factor=1/200,rand_number=0,cls_num=10,),]def _info(self) -> datasets.DatasetInfo:return datasets.DatasetInfo(description=_DESCRIPTION,features=datasets.Features({"img": datasets.Image(),"label": datasets.features.ClassLabel(names=_NAMES),}),supervised_keys=None,  # Probably needs to be fixed.homepage="https://www.cs.toronto.edu/~kriz/cifar.html",citation=_CITATION,task_templates=[ImageClassification(image_column="img", label_column="label")],)def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:archive = dl_manager.download(_DATA_URL)return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "train"}),datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"files": dl_manager.iter_archive(archive), "split": "test"}),]def _generate_examples(self, files: Iterator[Tuple[str, BinaryIO]], split: str) -> Iterator[Dict]:"""This function returns the examples in the array form."""if split == "train":batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]if split == "test":batches = ["test_batch"]batches = [f"cifar-10-batches-py/{filename}" for filename in batches]for path, fo in files:if path in batches:dict = pickle.load(fo, encoding="bytes")labels = dict[b"labels"]images = dict[b"data"]if split == "train":indices = self._imbalance_indices()else:indices = range(len(labels))for idx in indices:img_reshaped = np.transpose(np.reshape(images[idx], (3, 32, 32)), (1, 2, 0))yield f"{path}_{idx}", {"img": img_reshaped,"label": labels[idx],}breakdef _generate_indices_targets(self, files: Iterator[Tuple[str, BinaryIO]], split: str) -> Iterator[Dict]:"""This function returns the examples in the array form."""if split == "train":batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"]if split == "test":batches = ["test_batch"]batches = [f"cifar-10-batches-py/{filename}" for filename in batches]for path, fo in files:if path in batches:dict = pickle.load(fo, encoding="bytes")labels = dict[b"labels"]for idx, _ in enumerate(labels):yield f"{path}_{idx}", {"label": labels[idx],}breakdef _get_img_num_per_cls(self, data_length: int) -> List[int]:"""Get the number of images per class given the imbalance ratio and total number of images."""img_max = data_length / self.config.cls_numimg_num_per_cls = []if self.config.imb_type == 'exp':for cls_idx in range(self.config.cls_num):num = img_max * (self.config.imb_factor**(cls_idx / (self.config.cls_num - 1.0)))img_num_per_cls.append(int(num))elif self.config.imb_type == 'step':for cls_idx in range(self.config.cls_num // 2):img_num_per_cls.append(int(img_max))for cls_idx in range(self.config.cls_num // 2):img_num_per_cls.append(int(img_max * self.config.imb_factor))else:img_num_per_cls.extend([int(img_max)] * self.config.cls_num)return img_num_per_clsdef _gen_imbalanced_data(self, img_num_per_cls: List[int], targets: List[int]) -> Tuple[List[int], Dict[int, int]]:"""This function returns the indices of imbalanced CIFAR-10-LT dataset and the number of images per class."""new_indices = []targets_np = np.array(targets, dtype=np.int64)classes = np.unique(targets_np)num_per_cls_dict = dict()for the_class, the_img_num in zip(classes, img_num_per_cls):num_per_cls_dict[the_class] = the_img_numidx = np.where(targets_np == the_class)[0]np.random.shuffle(idx)selec_idx = idx[:the_img_num]new_indices.extend(selec_idx.tolist())return new_indices, num_per_cls_dictdef _imbalance_indices(self) -> List[int]:"""This function returns the indices of imbalanced CIFAR-10-LT dataset."""dl_manager = datasets.DownloadManager()archive = dl_manager.download(_DATA_URL)data_iterator = self._generate_indices_targets(dl_manager.iter_archive(archive), "train")indices = []targets = []for i, targets_dict in data_iterator:indices.append(i)targets.append(targets_dict["label"])data_length = len(indices)img_num_per_cls = self._get_img_num_per_cls(data_length)new_indices, _ = self._gen_imbalanced_data(img_num_per_cls, targets)return new_indicesdef main():"""Generate and save CIFAR-10-LT dataset as .npy files."""# 直接通过 name 参数指定配置dataset = Cifar10(name="r-10")  # 使用预定义的 "r-10" 配置# 下载并加载数据集dataset.download_and_prepare()train_data = dataset.as_dataset(split="train")test_data = dataset.as_dataset(split="test")# 提取图像和标签train_images = np.array([example["img"] for example in train_data])train_labels = np.array([example["label"] for example in train_data])test_images = np.array([example["img"] for example in test_data])test_labels = np.array([example["label"] for example in test_data])# 保存为 .npy 文件os.makedirs("cifar10_lt", exist_ok=True)np.save("cifar10_lt/train_images.npy", train_images)np.save("cifar10_lt/train_labels.npy", train_labels)np.save("cifar10_lt/test_images.npy", test_images)np.save("cifar10_lt/test_labels.npy", test_labels)print("CIFAR-10-LT dataset saved as .npy files in 'cifar10_lt' directory.")if __name__ == "__main__":main()

运行结果:

  

5. 统计&可视化长尾数据集 

import matplotlib
matplotlib.use('Agg')  # 必须在其他matplotlib导入之前设置
import matplotlib.pyplot as plt
import numpy as np
import os# 加载数据集
def load_data():base_dir = "cifar10_lt"return (np.load(os.path.join(base_dir, "train_images.npy")),np.load(os.path.join(base_dir, "train_labels.npy")),np.load(os.path.join(base_dir, "test_images.npy")),np.load(os.path.join(base_dir, "test_labels.npy")),)# 定义CIFAR-10类别名称
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck'
]# 统计类别分布并保存图片
def plot_class_distribution(labels, title, filename):unique, counts = np.unique(labels, return_counts=True)plt.figure(figsize=(12, 6))bars = plt.bar(CLASS_NAMES, counts)# 显示数值标签for bar in bars:height = bar.get_height()plt.text(bar.get_x() + bar.get_width() / 2,height + 5,  # 稍微抬高数值避免重叠f'{int(height)}',ha='center', va='bottom')plt.title(f"Class Distribution ({title})")plt.xlabel("Class")plt.ylabel("Number of Samples")plt.xticks(rotation=45)plt.tight_layout()plt.savefig(filename)  # 保存为文件plt.close()  # 关闭图形避免内存泄漏# 可视化样本并保存图片
def visualize_samples(images, labels, filename, num_samples=20):plt.figure(figsize=(15, 10))indices = np.random.choice(len(images), num_samples, replace=False)for i, idx in enumerate(indices, 1):plt.subplot(4, 5, i)image = images[idx].astype(np.float32)# 确保像素值在 [0,1] 范围内if image.max() > 1.0:image /= 255.0# 转换通道顺序 (HWC格式直接可用)plt.imshow(image)plt.title(f"{CLASS_NAMES[labels[idx]]}\n({labels[idx]})")plt.axis('off')plt.tight_layout()plt.savefig(filename)plt.close()if __name__ == "__main__":# 强制设置后端(防止环境变量干扰)matplotlib.use('Agg')# 加载数据和后续代码保持不变train_images, train_labels, test_images, test_labels = load_data()os.makedirs("visualization", exist_ok=True)plot_class_distribution(train_labels, "Training Set", "visualization/train_dist.png")plot_class_distribution(test_labels, "Test Set", "visualization/test_dist.png")# 可视化样本visualize_samples(train_images, train_labels, "visualization/samples.png")print("所有图表已保存至 'visualization' 目录!")

运行结果:

6. 总结

 【如果想学习更多深度学习文章,可以订阅一下热门专栏】

  • 《PyTorch科研加速指南:即插即用式模块开发》_十二月的猫的博客-CSDN博客
  • 《深度学习理论直觉三十讲》_十二月的猫的博客-CSDN博客
  • 《AI认知筑基三十讲》_十二月的猫的博客-CSDN博客

如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/36202.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux操作系统7- 线程同步与互斥1(POSIX互斥锁的使用详解)

上篇文章:Linux操作系统6- 线程4(POSIX线程的简单封装)-CSDN博客 本篇代码仓库:myLerningCode/l30 橘子真甜/Linux操作系统与网络编程学习 - 码云 - 开源中国 (gitee.com) 目录 一. 线程不互斥造成的结果 二. pthread_mutex_t 互斥…

深入 Linux 声卡驱动开发:核心问题与实战解析

1. 字符设备驱动如何为声卡提供操作接口? 问题背景 在 Linux 系统中,声卡被抽象为字符设备。如何通过代码让应用程序能够访问声卡的录音和播放功能? 核心答案 1.1 字符设备驱动的核心结构 Linux 字符设备驱动通过 file_operations 结构体定…

洛谷 [语言月赛 202503] 题解(C++)

本文为洛谷3月的语言月赛题目全部题解,难度为入门到普及-, 觉的有帮助或者写的不错的可以点个赞 题目链接为 题目列表 - 洛谷 | 计算机科学教育新生态 目录 题目A:长方形 解题思路: 代码(C): 题目B:水流 题目大意: 解题思路: 代码(C): 题目C:格…

算法每日一练 (15)

💢欢迎来到张胤尘的技术站 💥技术如江河,汇聚众志成。代码似星辰,照亮行征程。开源精神长,传承永不忘。携手共前行,未来更辉煌💥 文章目录 算法每日一练 (15)第 N 个泰波那契数题目描述解题思路…

实验11 机器学习-贝叶斯分类器

实验11 机器学习-贝叶斯分类器 一、实验目的 (1)理解并熟悉贝叶斯分类器的思想和原理; (2)熟悉贝叶斯分类器的数学推导过程; (3)能运用贝叶斯分类器解决实际问题并体会算法的效果&a…

Matrix-breakout-2-morpheus靶机实战攻略

1.安装并开启靶机 2.获取靶机IP 3.浏览器访问靶机 4.扫描敏感目录文件和端口 gobuster dir -u http://192.168.52.135 -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt -x php,txt,html 5.访问文件和端口 发现在graffiti.php输入框输入内容后页面会返回内容…

【知识】Graph Sparsification、Graph Coarsening、Graph Condensation的详细介绍和对比

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~ 目录 1. 理论基础(Definitions & Theoretical Background) 2. 算法方法(Techniques & Algorithms&#x…

Java单元测试、Junit、断言、单元测试常见注解、单元测试Maven依赖范围、Maven常见问题解决方法

一. 测试 1. 测试:是一种用来促进鉴定软件的正确性、完整性、安全性和质量的过程 2. 阶段划分:单元测试、集成测试、系统测试、验收测试。 ① 单元测试:对软件的基本组成单位进行测试,最小测试单位;目的检验软件基本组…

【Notepad】Notepad优化笔记AutoHotkey语法高亮\设置替换默认的notepad程序\设置主题\增加返回上一个编辑地方插件

Npp使用优化笔记 AHK或自定义语法高亮设置替换系统默认的notepad设置主题返回上一次编辑的地方插件使用 AHK或自定义语法高亮 具体参考该论坛 https://www.autohotkey.com/boards/viewtopic.php?t50 设置替换默认的notepad程序 参考文章: https://www.winhelpo…

Mac:Maven 下载+安装+环境配置(详细讲解)

📌 下载 Maven 下载地址:https://maven.apache.org/download.cgi 📌 无需安装 Apache官网下载 Maven 压缩包,无需安装,下载解压后放到自己指定目录下即可。 按我自己的习惯,我会在用户 jane 目录下新建…

[K!nd4SUS 2025] Crypto

最后一个把周末的补完。这个今天问了小鸡块神终于把一个补上,完成5/6,最后一个网站也上不去不弄了。 Matrices Matrices Matrices 这个是不是叫LWE呀,名词忘了,但意思还是知道。 b a*s e 这里的e是高斯分成,用1000…

学习threejs,构建THREE.ParametricGeometry参数化函数生成几何体

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.ParametricGeometry1…

Canal 解析与 Spring Boot 整合实战

一、Canal 简介 1.1 Canal 是什么? Canal 是阿里巴巴开源的一款基于 MySQL 数据库增量日志解析(Binlog)中间件,它模拟 MySQL 的从机(Slave)行为,监听 MySQL 主机的二进制日志(Binl…

【海螺AI视频】蓝耘智算 | AI视频新浪潮:蓝耘MaaS与海螺AI视频创作体验

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈人工智能与大模型应用 ⌋ ⌋ ⌋ 人工智能(AI)通过算法模拟人类智能,利用机器学习、深度学习等技术驱动医疗、金融等领域的智能化。大模型是千亿参数的深度神经网络(如ChatGPT&…

Prometheus使用

介绍:Prometheus 是一个开源的 监控与告警系统,主要用于采集和存储时间序列数据(Time Series Data) Prometheus的自定义查询语言PromQL Metric类型 为了能够帮助用户理解和区分这些不同监控指标之间的差异,Prometheu…

Linux 文件操作-标准IO函数3- fread读取、fwrite写入、 fprintf向文件写入格式化数据、fscanf逐行读取格式化数据的验证

目录 1. fread 从文件中读取数据 1.1 读取次数 每次读取字节数 < 原内容字节数 1.2 读取次数 每次读取字节数 > 原内容字节数 2.fwrite 向文件中写入数据 2.1写入字符串验证 2.2写入结构体验证 3. fprintf 将数据写入到指定文件 4. fscanf 从文件中逐行读取内容…

再学:abi编码 地址类型与底层调用

目录 1.内置全局变量及函数 2.abi 3.地址类型 4.transfer 1.内置全局变量及函数 2.abi data就是abi编码 abi描述&#xff1a;以json格式表明有什么方法 3.地址类型 4.transfer x.transfer:合约转给x call 和 delegatecall 是 Solidity 中用于底层合约调用的函数&#xff0…

解决前端文字超高度有滚动条的情况下padding失效(el-scrollbar)使用

<div class"detailsBlocksContent"><div>测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试测试…

SpringCloud 学习笔记3(OpenFeign)

OpenFeign 微服务之间的通信方式&#xff0c;通常有两种&#xff1a;RPC 和 HTTP。 简言之&#xff0c;RPC 就是像调用本地方法一样调用远程方法。 在 SpringCloud 中&#xff0c;默认是使用 HTTP 来进行微服务的通信&#xff0c;最常用的实现形式有两种&#xff1a; RestTem…

c中<string.h>

常见错误与最佳实践 缓冲区溢出&#xff1a; strcpy 和 strcat 不检查目标缓冲区大小&#xff0c;需手动确保空间足够。替代方案&#xff1a;使用 strncpy 和 strncat&#xff0c;或动态分配内存&#xff08;如 malloc&#xff09;。 未终止的字符串&#xff1a; 确保字符串以…