数据增广是提升模型泛化能力的重要手段之一。CopyPaste 是一种新颖的数据增强技巧,已经在目标检测和实例分割任务中验证了有效性。利用 CopyPaste,可以合成文本实例来平衡训练图像中正负样本之间的比例;相比之下,传统的图像旋转、随机翻转和随机裁剪无法做到这一点。
CopyPaste 主要步骤包括:
- 随机选择两幅训练图像;
- 随机尺度抖动缩放;
- 随机水平翻转;
- 随机选择一幅图像中的目标子集;
- 粘贴在另一幅图像中随机的位置。
这样既较好地提升了样本丰富度,也增强了模型对环境的鲁棒性。如下图所示,将左下角图像中裁剪出来的文本随机旋转、缩放之后粘贴到左上角的图像中,进一步丰富了该文本在不同背景下的多样性。
参考代码:
# !/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time : 2024.07
# @Author : 绿色羽毛
# @Email : lvseyumao@foxmail.com
# @Blog : https://blog.csdn.net/ViatorSun
# @Note :
import json
import logging
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
def create_operators(op_param_list, global_config=None):
    """Instantiate a list of operators from a config list.

    Args:
        op_param_list (list): a list of single-key dicts. Each key is the
            operator class name and each value is a dict of keyword
            arguments for that class (or None for "no arguments").
        global_config (dict | None): extra keyword arguments merged into
            every operator's parameters; on a key clash these win.

    Returns:
        list: the instantiated operators, in config order.

    Raises:
        AssertionError: if ``op_param_list`` is not a list, or an entry is
            not a dict with exactly one key.

    NOTE(review): the operator classes named in the config (for example
    ``DecodeImage``) must be defined in or imported into this module,
    because names are resolved against this module's globals. No such
    import is visible in this file - confirm where they come from.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        raw_param = operator[op_name]
        # Work on a copy: merging global_config into the caller's own dict
        # would silently rewrite the config and leak into later calls.
        param = {} if raw_param is None else dict(raw_param)
        if global_config is not None:
            param.update(global_config)
        # NOTE(security): eval() executes arbitrary expressions. This is only
        # acceptable while op_name comes from a trusted, hand-written config;
        # never feed it names taken from untrusted input.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops


def transform(data, ops=None):
    """Run ``data`` through every operator in ``ops``, in order.

    Args:
        data: the sample handed to the first operator.
        ops (list | None): callables applied one after another; None means
            "no operators".

    Returns:
        The output of the last operator, or None as soon as any operator
        returns None (the remaining operators are skipped).
    """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


# Demo class for the CopyPaste augmentation
class CopyPasteDemo(object):
    """Small dataset-like wrapper used to demonstrate CopyPaste.

    Each label-file line has the form ``<image file name>\\t<label>``.
    Indexing the object loads that image, attaches one extra sample under
    ``'ext_data'`` and runs the operator pipeline built in ``__init__``.
    """

    # Defaults kept from the original demo so CopyPasteDemo() still works.
    DEFAULT_DATA_DIR = "/media/sun/Data/Dataset/OCR_Data/det/train/"
    DEFAULT_LABEL_FILE = "/media/sun/Data/Dataset/OCR_Data/det/train.txt"

    def __init__(self, data_dir=None, label_file_list=None):
        """Read the label file(s) and build the operator pipeline.

        Args:
            data_dir (str | None): directory the image file names are
                relative to; None selects ``DEFAULT_DATA_DIR``.
            label_file_list (str | list | None): one label file path or a
                list of them; None selects ``DEFAULT_LABEL_FILE``.
        """
        self.data_dir = self.DEFAULT_DATA_DIR if data_dir is None else data_dir
        self.label_file_list = (self.DEFAULT_LABEL_FILE
                                if label_file_list is None else label_file_list)
        self.data_lines = self.get_image_info_list(self.label_file_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        transforms = [
            {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
            {"DetLabelEncode": {}},
            {"CopyPaste": {"objects_paste_ratio": 1.0}},
        ]
        self.ops = create_operators(transforms)

    def _parse_line(self, text):
        """Split one decoded label line into an image path and a label.

        Raises:
            IndexError: if the line contains no tab-separated label.
        """
        fields = text.strip("\n").split("\t")
        file_name = fields[0]
        label = fields[1]
        return {'img_path': os.path.join(self.data_dir, file_name),
                'label': label}

    # Pick another image whose content will be copied into the current one
    def get_ext_data(self, idx):
        """Load the extra sample that accompanies sample ``idx``.

        Candidates are tried in order starting at ``idx + 1`` and wrapping
        around; ``idx`` itself is the last candidate. Only the first two
        operators of the pipeline are applied to a candidate.

        Returns:
            list: at most one loaded sample. The list is empty when no
            candidate could be loaded; the search is limited to one pass
            over the dataset so it cannot spin forever.
        """
        ext_data_num = 1
        ext_data = []
        load_data_ops = self.ops[:2]
        total = len(self)
        next_idx = idx
        for _ in range(total):
            if len(ext_data) >= ext_data_num:
                break
            next_idx = (next_idx + 1) % total
            file_idx = self.data_idx_order_list[next_idx]
            data = self._parse_line(self.data_lines[file_idx].decode('utf-8'))
            if not os.path.exists(data['img_path']):
                continue
            with open(data['img_path'], 'rb') as f:
                data['image'] = f.read()
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    # Read the raw label lines
    def get_image_info_list(self, file_list):
        """Return the raw byte lines of one label file or a list of them.

        Args:
            file_list (str | list): a single path or a list of paths.

        Returns:
            list: every line of every file, as bytes, in file order.
        """
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for label_path in file_list:
            with open(label_path, "rb") as f:
                data_lines.extend(f.readlines())
        return data_lines

    # Fetch one sample of the dataset
    def __getitem__(self, idx):
        """Load sample ``idx`` and run it through the whole pipeline.

        Returns:
            The pipeline output, or None when the sample cannot be loaded
            or processed. Failures are reported with ``print`` and are not
            raised (deliberate best effort).
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            data = self._parse_line(data_line)
            img_path = data['img_path']
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(img_path, 'rb') as f:
                data['image'] = f.read()
            data['ext_data'] = self.get_ext_data(idx)
            outs = transform(data, self.ops)
        except Exception as e:
            print("When parsing line {}, error happened with msg: {}".format(
                data_line, e))
            outs = None
        return outs

    def __len__(self):
        """Return the number of label lines."""
        return len(self.data_idx_order_list)


def _plot_closed_polygon(points, color):
    """Draw ``points`` as a closed outline on the current matplotlib axes."""
    xs, ys = zip(*points)
    # Repeat the first vertex so the outline closes.
    xs = list(xs) + [xs[0]]
    ys = list(ys) + [ys[0]]
    plt.plot(xs, ys, color)


if __name__ == '__main__':
    copy_paste_demo = CopyPasteDemo()
    idx = 1
    data1 = copy_paste_demo[idx]
    if data1 is None:
        # __getitem__ has already printed the reason.
        raise SystemExit("sample {} could not be loaded".format(idx))

    print(data1.keys())
    print(data1["img_path"])
    print(data1["ext_data"][0]["img_path"])

    infos = copy_paste_demo.data_lines[idx]
    infos = json.loads(infos.decode('utf-8').split("\t")[1])

    img3 = data1["image"].copy()
    plt.figure(figsize=(15, 10))
    # The image was decoded as BGR; reverse the channels for matplotlib.
    plt.imshow(img3[:, :, ::-1])

    # Original annotations
    for info in infos:
        _plot_closed_polygon(info["points"], "r")

    # Annotations beyond the original count are drawn as newly added
    for poly in data1["polys"][len(infos):]:
        _plot_closed_polygon(poly, "b")

    plt.show()
生成后的图像