本专栏主要提供一种国产化图像识别的解决方案:专栏中实现了 YOLOv5/v8 在国产化芯片上的部署,并支持在网页端实时查看识别结果。可根据自己的具体需求直接进行产品化部署。
B站配套视频:https://www.bilibili.com/video/BV1or421T74f
图像增强的必要性
在日常训练模型时,我们经常会遇到数据集不足的情况,比如需要对某种特定物品进行识别时,很难收集到数量足够的有效数据。这时就可以考虑采用图像增强的方式来扩充我们的数据量。
不废话直接上代码
import os

import cv2
from albumentations import *
from tqdm import tqdm
class enhancement:
    """Offline image augmentation for a YOLO-format object-detection dataset.

    Each image in ``picture_path`` is paired with the label file of the same
    stem in ``label_path`` (``<stem>.txt``) and run through the albumentations
    pipeline built by :meth:`get_transform`. The augmented image and boxes are
    written to ``save_img_path`` / ``save_lable_path``, with ``prefix``
    prepended to the original file names. Calling the instance processes the
    whole dataset.

    Label files use the layout expected by ``BboxParams(format='yolo')``:
    one ``class cx cy w h`` row per box, coordinates normalized to the image
    size.

    Args:
        picture_path: Directory holding the source images.
        label_path: Directory holding the YOLO ``.txt`` label files.
        save_img_path: Output directory for augmented images (created if missing).
        save_lable_path: Output directory for augmented labels (created if missing).
        prefix: String prepended to every output file name.
        batch_size: Number of images processed per progress-bar step.
    """

    # Image extensions picked up from ``picture_path`` (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')

    def __init__(self, picture_path, label_path, save_img_path, save_lable_path,
                 prefix='aug2_camber_', batch_size=10):
        self.picture_name = sorted(
            name for name in os.listdir(picture_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # Derive each label name from its image's stem instead of zipping two
        # independently sorted listings. With zipping, a single missing or extra
        # label file would shift every later image onto the wrong boxes.
        self.label_name = [os.path.splitext(name)[0] + '.txt' for name in self.picture_name]
        # os.path.join works whether or not the directories end with a separator.
        self.picture_path = [os.path.join(picture_path, name) for name in self.picture_name]
        self.label_path = [os.path.join(label_path, name) for name in self.label_name]
        self.save_img_path = save_img_path
        self.save_lable_path = save_lable_path
        self.prefix = prefix
        self.batch_size = batch_size

    def iter(self):
        """Yield ``(picture_batch, label_batch, [start, stop])`` slices of ``batch_size`` items.

        ``stop`` may exceed the dataset length for the last batch; callers zip
        it with the (shorter) path slices, so the overshoot is harmless.
        """
        for start in tqdm(range(0, len(self.picture_path), self.batch_size), desc='批次进度'):
            stop = start + self.batch_size
            yield self.picture_path[start:stop], self.label_path[start:stop], [start, stop]

    def get_transform(self):
        """Build the albumentations pipeline applied to every image.

        ``BboxParams(format='yolo', label_fields=['class_labels'])`` makes the
        pipeline transform the boxes together with the image and keep the
        ``class_labels`` list aligned with the surviving boxes.

        Other transforms that can be added to the list (argument names as used
        in the original script; some were renamed in newer albumentations
        releases, so check them against the installed version):
            Blur(blur_limit=7, p=0.5)              # mean (box) blur
            VerticalFlip(p=0.5)                    # flip upside down
            HorizontalFlip(p=1)                    # mirror left/right
            CenterCrop(200, 200, p=1.0)            # center crop
            RandomCrop(width=200, height=200)      # random crop
            RandomFog(fog_coef_lower=0.3, fog_coef_upper=0.7, alpha_coef=0.08, p=1)
            Downscale(p=1)                         # downscale then upscale (quality loss)
            HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30,
                               val_shift_limit=20, p=0.5)  # random HSV shifts
            GaussianBlur(blur_limit=7, p=0.5)      # Gaussian blur, random kernel size
            CoarseDropout(max_holes=30, max_height=110, max_width=110,
                          min_holes=10, min_height=50, min_width=50,
                          fill_value=0, p=0.5)     # random occluding rectangles
        """
        return Compose(
            [
                # Drop rectangular regions laid out on a grid. p=1.0 always applies
                # it, which is what the original always_apply=True did; that
                # argument is deprecated in newer albumentations releases.
                GridDropout(p=1.0),
            ],
            bbox_params=BboxParams(format='yolo', label_fields=['class_labels']),
        )

    def augmentations(self, image, bboxes, class_labels):
        """Run the pipeline once and return ``(image, bboxes, class_labels)`` after augmentation."""
        transform = self.get_transform()
        transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)
        return transformed['image'], transformed['bboxes'], transformed['class_labels']

    @staticmethod
    def _read_labels(l_path):
        """Parse a YOLO label file into ``(class_labels, bboxes)``.

        Whitespace-only lines are skipped, so a missing trailing newline no
        longer drops the last box, and a blank line no longer crashes ``int('')``.
        """
        class_labels, bboxes = [], []
        with open(l_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                class_labels.append(int(fields[0]))
                bboxes.append([float(v) for v in fields[1:]])
        return class_labels, bboxes

    def augmented_image_bboxes(self, img_path, l_path):
        """Load one image/label pair and augment it.

        Returns:
            ``(augmented_image, augmented_bboxes, augmented_labels, original_image)``.

        Raises:
            OSError: if the image cannot be read or decoded.
        """
        # Print the label path currently being processed.
        print(l_path)
        class_labels, original_bboxes = self._read_labels(l_path)
        original_image = cv2.imread(img_path)
        # cv2.imread returns None instead of raising when it cannot read the file.
        if original_image is None:
            raise OSError(f'无法读取图片: {img_path}')
        augmented_image, augmented_bboxes, augmented_labels = self.augmentations(
            original_image, original_bboxes, class_labels)
        return augmented_image, augmented_bboxes, augmented_labels, original_image

    def parsing_data(self, p_l_i):
        """Augment and save one ``(image_path, label_path, index)`` item.

        Items with no label file, an unreadable image, or no boxes left after
        augmentation are skipped with a message instead of being saved.
        """
        img_path, l_path, index = p_l_i
        if not os.path.isfile(l_path):
            print(f'{self.picture_name[index]}没有对应的标签文件,不做保存')
            return
        try:
            (self.augmented_image, self.augmented_bboxes,
             augmented_labels, _original_image) = self.augmented_image_bboxes(img_path, l_path)
        except OSError as err:
            print(err)
            return
        # Use len() rather than truthiness, which is ambiguous if the label
        # container ever comes back array-like.
        if len(augmented_labels) == 0:
            print(f'{self.picture_name[index]}该图片没有标签,不做保存')
            return
        data = '\n'.join(
            ' '.join(map(str, [label, *box]))
            for label, box in zip(augmented_labels, self.augmented_bboxes)
        )
        self.save_img_lable(data, self.augmented_image, self.save_img_path,
                            self.save_lable_path, index)

    def save_img_lable(self, data, img, save_img_path, save_lable_path, index):
        """Write the augmented image and its label text under ``self.prefix``-prefixed names.

        Raises:
            OSError: if OpenCV fails to write the image. The label is then not
                written either, so no orphan label file is left behind.
        """
        # cv2.imwrite returns False instead of raising when the directory is missing.
        os.makedirs(save_img_path, exist_ok=True)
        os.makedirs(save_lable_path, exist_ok=True)
        img_out = os.path.join(save_img_path, self.prefix + self.picture_name[index])
        if not cv2.imwrite(img_out, img):
            raise OSError(f'图片保存失败: {img_out}')
        label_out = os.path.join(save_lable_path, self.prefix + self.label_name[index])
        with open(label_out, 'w', encoding='utf-8') as f:
            f.write(data)

    def __call__(self):
        """Augment and save every image/label pair, batch by batch."""
        for picture_batch, label_batch, (start, stop) in self.iter():
            for item in zip(picture_batch, label_batch, range(start, stop)):
                self.parsing_data(item)

    def show_img(self, boxe=False):
        """Return a copy of the last augmented image, with its boxes drawn when ``boxe`` is true.

        Drawing happens on a copy so the image saved by ``save_img_lable`` is
        never polluted with rectangles. Uncomment the ``cv2.imshow`` lines to
        preview interactively.
        """
        canvas = self.augmented_image.copy()
        if boxe:
            height, width = canvas.shape[:2]
            for x, y, w, h in self.augmented_bboxes:
                # YOLO (center, size), normalized -> pixel corner coordinates.
                x1 = int((x - w / 2) * width)
                y1 = int((y - h / 2) * height)
                x2 = int((x + w / 2) * width)
                y2 = int((y + h / 2) * height)
                cv2.rectangle(canvas, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # cv2.imshow('Augmented Image', canvas)
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()
        return canvas


if __name__ == '__main__':
    # Source image and label directories.
    picture_path = 'images/train/'
    label_path = 'labels/train/'
    # Output directories for the augmented images and labels.
    save_img_path = 'images/train-aug-point/'
    save_lable_path = 'labels/train-aug-point/'
    c = enhancement(picture_path=picture_path,
                    label_path=label_path,
                    save_img_path=save_img_path,
                    save_lable_path=save_lable_path)
    c()
代码说明
除了常用的 OpenCV 之外,还需要安装 albumentations(图像增强库)和 tqdm(进度条显示):
pip install albumentations tqdm
代码讲解请查看视频:https://www.bilibili.com/video/BV1or421T74f