1、前言、数据集介绍
SAM-Med2D大模型介绍参考上文:第三章:SAM-Med2D大模型复现-CSDN博客
本文将使用SAM-Med2D大模型训练自己的数据集
关于SAM-Med2D大模型官方demo数据集的介绍上文已经介绍过,这里简单回顾下
- 其中data_demo为数据集的目录,下面有images和masks两个目录,分别存放数据和标签
- 其中images,就是正常的数据图像,格式是png格式
- masks格式值得注意,正常的mask是灰度等级的阈值图像【0 1 2 3】,这里把每个类别单独提取出来,变成【0 255】的二值图像,有几个类别就有几张对应的mask模板
例如mask是【0 1 2 2 1】,mask模板有两个,分别是1对应的模板【0 255 0 0 255】,就是只分割前景1。以及只是分割2的模板【0 0 255 255 0】。
mask的命名可以是image名字加上灰度,例如image_1.png和image_2.png
两个json文件如下:
训练数据就是单张image对应的一组mask标签字典
测试集是mask对应的image
2、生成数据的脚本
有了上面介绍,就很简单了,也就是说我们只需要把自己的数据集换成上面格式就可以正常训练了!
其他补充,因为官方的image和mask都是png格式的。
格式需要是png,因为之前本人做过实验只有png保存的二值图像,灰度值才不会乱掉(比如你保存【0 255 0】的jpg读取,np.unique读取可能变成【0 224 223】之类的)
更改文件后缀可以参考:PYTHON 自动化办公:更改图片后缀_改变文件夹里面图片后缀名的pytorch代码-CSDN博客
这里需要把自己的数据集摆放如下:
划分数据集的脚本参考:关于图像分割任务中按照比例将数据集随机划分成训练集和测试集_图像数据划分训练集-CSDN博客
然后运行下面代码就行了:
这个代码会生成image对应mask不同类别的掩膜数据,并且生成两个json文件。这里的目录命名一定要和上面对应
import json
import numpy as np
from tqdm import tqdm
import os
import shutil
from PIL import Image
import cv2def mkdir():root = 'data_demo'if os.path.exists(root):shutil.rmtree(root)os.mkdir(root)os.mkdir(os.path.join(root,'images'))os.mkdir(os.path.join(root,'masks'))# 生成训练集
def gen_trainSet(img_suff,msk_suff):p = 'RawData/train/images'image_list = [os.path.join(p,i) for i in os.listdir(p)]with open('data_demo/image2label_train.json', 'a') as jf:json_all ={} # json文件for i in tqdm(image_list,desc='generate train set'):j = i.replace('images','masks').replace(img_suff,msk_suff)assert os.path.exists(j) # 判断label是否存在shutil.copy(i,'data_demo/images')mask = np.array(Image.open(j).convert('L')) # 标签图像gray_list = np.unique(mask)img_list = []for gray in gray_list[1:]: # 遍历mask所有的分割前景ret_mask = np.zeros(mask.shape,dtype=np.uint8)ret_mask[mask==gray] =255 # 指定前景为255,其余为背景ret_mask[ret_mask<255] = 0# 去除小的分割区域h,w = ret_mask.shapetotal_pixel = h*wif (np.sum(ret_mask!=0)/total_pixel) < 0.005:continueret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/train/images','data_demo/masks')cv2.imwrite(ret_name,ret_mask) # 保存生成的数据img_list.append(ret_name)if len(img_list) == 0:continuejson_all[i.replace('RawData/train/images','data_demo/images')] = img_listjson_str = json.dumps(json_all,indent=4)jf.write(json_str)# 生成测试集
def gen_testSet(img_suff,msk_suff):p = 'RawData/test/images'image_list = [os.path.join(p,i) for i in os.listdir(p)]with open('data_demo/label2image_test.json', 'a') as jf:json_all ={} # json文件for i in tqdm(image_list,desc='generate test set'):j = i.replace('images','masks').replace(img_suff,msk_suff)assert os.path.exists(j) # 判断label是否存在shutil.copy(i,'data_demo/images')mask = np.array(Image.open(j).convert('L')) # 标签图像gray_list = np.unique(mask)for gray in gray_list[1:]: # 遍历mask所有的分割前景ret_mask = np.zeros(mask.shape,dtype=np.uint8)ret_mask[mask==gray] =255 # 指定前景为255,其余为背景ret_mask[ret_mask<255] = 0# 去除小的分割区域h,w = ret_mask.shapetotal_pixel = h*wif (np.sum(ret_mask!=0)/total_pixel) < 0.005:continueret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/test/images','data_demo/masks')cv2.imwrite(ret_name,ret_mask) # 保存生成的数据json_all[ret_name] = i.replace('RawData/test/images','data_demo/images')json_str = json.dumps(json_all,indent=4)jf.write(json_str)if __name__ == '__main__':imgFormat = '.png' # image 的后缀maskFormat = '.png' # mask 的后缀mkdir() # 生成目录gen_trainSet(img_suff=imgFormat,msk_suff=maskFormat) # 生成训练数据gen_testSet(img_suff=imgFormat,msk_suff=maskFormat) # 生成测试数据
Tips
运行过程如下
如下:
可以看到image生成了三个对应的mask数据,命名是image的名字加上类别。
下图的8 9 17后缀是原来mask中8 9 17的像素值
测试代码的时候,训练会报错误,大概是len(box)什么分母为零,不能被除的bug。本人猜测可能是生成的组mask里面,前景区域太小之类的,所有脚本里增加点处理
代码会将不足千分之五的分割前景区域删除
3、训练脚本
因为生成的数据就是data_demo目录,所有train脚本不需要任何更改,直接运行即可
这里的parser.add_argument("--mask_num", type=int, default=5, help="get mask number")参数还是没懂
生成的结果如下:每个权重大约2G左右吧
4、测试脚本
代码如下:
python test.py --sam_checkpoint workdir/models/sam-med2d/epoch10_sam.pth
测试结果如下: