SAM-Med2D 大模型学习笔记(续):训练自己数据集

1、前言、数据集介绍

SAM-Med2D大模型介绍参考上文:第三章:SAM-Med2D大模型复现-CSDN博客

本文将使用SAM-Med2D大模型训练自己的数据集

关于SAM-Med2D大模型官方demo数据集的介绍上文已经介绍过,这里简单回顾下

  • 其中data_demo为数据集的目录,下面有images和masks两个目录,分别存放数据和标签
  • 其中images,就是正常的数据图像,格式是png格式
  • masks格式值得注意,正常的mask是灰度等级的阈值图像【0 1 2 3】,这里把每个类别单独提取出来,变成【0 255】的二值图像,有几个类别就有几张对应的mask模板

例如mask是【0 1 2 2 1】,mask模板有两个,分别是1对应的模板【0 255 0 0 255】,就是只分割前景1。以及只是分割2的模板【0 0 255 255 0】。

mask的命名可以是image名字加上灰度,例如image_1.png和image_2.png

两个json文件如下:

训练数据就是单张image对应的一组mask标签字典

测试集是mask对应的image

2、生成数据的脚本

有了上面介绍,就很简单了,也就是说我们只需要把自己的数据集换成上面格式就可以正常训练了!

其他补充,因为官方的image和mask都是png格式的。

格式需要是png,因为之前本人做过实验只有png保存的二值图像,灰度值才不会乱掉(比如你保存【0 255 0】的jpg读取,np.unique读取可能变成【0 224 223】之类的)

更改文件后缀可以参考:PYTHON 自动化办公:更改图片后缀_改变文件夹里面图片后缀名的pytorch代码-CSDN博客

这里需要把自己的数据集摆放如下:

划分数据集的脚本参考:关于图像分割任务中按照比例将数据集随机划分成训练集和测试集_图像数据划分训练集-CSDN博客

然后运行下面代码就行了:

这个代码会生成image对应mask不同类别的掩膜数据,并且生成两个json文件。这里的目录命名一定要和上面对应

import json
import numpy as np
from tqdm import tqdm
import os
import shutil
from PIL import Image
import cv2def mkdir():root = 'data_demo'if os.path.exists(root):shutil.rmtree(root)os.mkdir(root)os.mkdir(os.path.join(root,'images'))os.mkdir(os.path.join(root,'masks'))# 生成训练集
def gen_trainSet(img_suff,msk_suff):p = 'RawData/train/images'image_list = [os.path.join(p,i) for i in os.listdir(p)]with open('data_demo/image2label_train.json', 'a') as jf:json_all ={}        # json文件for i in tqdm(image_list,desc='generate train set'):j = i.replace('images','masks').replace(img_suff,msk_suff)assert os.path.exists(j)        # 判断label是否存在shutil.copy(i,'data_demo/images')mask = np.array(Image.open(j).convert('L'))     # 标签图像gray_list = np.unique(mask)img_list = []for gray in gray_list[1:]:          # 遍历mask所有的分割前景ret_mask = np.zeros(mask.shape,dtype=np.uint8)ret_mask[mask==gray] =255      # 指定前景为255,其余为背景ret_mask[ret_mask<255] = 0# 去除小的分割区域h,w = ret_mask.shapetotal_pixel = h*wif (np.sum(ret_mask!=0)/total_pixel) < 0.005:continueret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/train/images','data_demo/masks')cv2.imwrite(ret_name,ret_mask)  # 保存生成的数据img_list.append(ret_name)if len(img_list) == 0:continuejson_all[i.replace('RawData/train/images','data_demo/images')] = img_listjson_str = json.dumps(json_all,indent=4)jf.write(json_str)# 生成测试集
def gen_testSet(img_suff,msk_suff):p = 'RawData/test/images'image_list = [os.path.join(p,i) for i in os.listdir(p)]with open('data_demo/label2image_test.json', 'a') as jf:json_all ={}        # json文件for i in tqdm(image_list,desc='generate test set'):j = i.replace('images','masks').replace(img_suff,msk_suff)assert os.path.exists(j)        # 判断label是否存在shutil.copy(i,'data_demo/images')mask = np.array(Image.open(j).convert('L'))     # 标签图像gray_list = np.unique(mask)for gray in gray_list[1:]:          # 遍历mask所有的分割前景ret_mask = np.zeros(mask.shape,dtype=np.uint8)ret_mask[mask==gray] =255      # 指定前景为255,其余为背景ret_mask[ret_mask<255] = 0# 去除小的分割区域h,w = ret_mask.shapetotal_pixel = h*wif (np.sum(ret_mask!=0)/total_pixel) < 0.005:continueret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/test/images','data_demo/masks')cv2.imwrite(ret_name,ret_mask)  # 保存生成的数据json_all[ret_name] = i.replace('RawData/test/images','data_demo/images')json_str = json.dumps(json_all,indent=4)jf.write(json_str)if __name__ == '__main__':imgFormat = '.png'          # image 的后缀maskFormat = '.png'         # mask 的后缀mkdir()         # 生成目录gen_trainSet(img_suff=imgFormat,msk_suff=maskFormat)        # 生成训练数据gen_testSet(img_suff=imgFormat,msk_suff=maskFormat)         # 生成测试数据

Tips

运行过程如下

如下:

可以看到image生成了三个对应的mask数据,命名是image的名字加上类别。

下图的8 9 17后缀是原来mask中8 9 17的像素值

测试代码的时候,训练会报错误,大概是len(box)什么分母为零,不能被除的bug。本人猜测可能是生成的组mask里面,前景区域太小之类的,所有脚本里增加点处理

代码会将不足千分之五的分割前景区域删除

3、训练脚本

因为生成的数据就是data_demo目录,所有train脚本不需要任何更改,直接运行即可

这里的parser.add_argument("--mask_num", type=int, default=5, help="get mask number")参数还是没懂

生成的结果如下:每个权重大约2G左右吧

4、测试脚本

代码如下:

python test.py --sam_checkpoint workdir/models/sam-med2d/epoch10_sam.pth

测试结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/396743.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode171. Excel 表列序号,进制转换

leetcode171. Excel 表列序号 给你一个字符串 columnTitle &#xff0c;表示 Excel 表格中的列名称。返回 该列名称对应的列序号 。 例如&#xff1a; A -> 1 B -> 2 C -> 3 … Z -> 26 AA -> 27 AB -> 28 … 示例 1: 输入: columnTitle “A” 输出: 1 示…

电商平台产品ID|CDN与预渲染|前端边缘计算

技术实现 都是通过ID拿到属性&#xff0c;进行预渲染html&#xff0c;通过 oss 分发出去 详情页这种基本都是通过 ssr 渲染出来&#xff0c;然后上缓存 CDN 分发到边缘节点来处理&#xff0c;具体逻辑可以参考 淘宝——EdgeRoutine边缘计算&#xff08;CDNServerless 边缘计算…

国内真正意义上的OpenAI,最强多模态大模型 MiniCPM-V 2.6 发布

最近这一两周看到不少互联网公司都已经开始秋招提前批了。不同以往的是&#xff0c;当前职场环境已不再是那个双向奔赴时代了。求职者在变多&#xff0c;HC 在变少&#xff0c;岗位要求还更高了。 最近&#xff0c;我们又陆续整理了很多大厂的面试题&#xff0c;帮助一些球友解…

二叉树的最大深度

二叉树的最大深度 思路&#xff1a; 法一&#xff1a;深搜 也就是递归 要想清楚边界条件 好久没写深搜了 回忆下怎么写。 突然就悟了&#xff1a; /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *rig…

2024年6月 青少年机器人技术等级考试理论综合试卷(二级)

202406 青少年等级考试机器人理论真题二级 第 1 题 如图&#xff0c;这是飞机起飞时的机翼示意图&#xff0c;下列说法正确的是&#xff1f;&#xff08; &#xff09; A&#xff1a;机翼上侧所受的气压为0 B&#xff1a;机翼受到向下的力的作用 C&#xff1a;机翼下侧所受…

基于sklearn的机器学习 — 支持向量机(SVM)

支持向量机&#xff08;SVM&#xff1a;support vector machine&#xff09;另一种功能强大、应用广泛的学习算法&#xff0c;可应用于分类、回归、密度估计、聚类等问题。SVM可以看作是感知器&#xff08;可被视为一种最简单形式的前馈神经网络&#xff0c;是一种二元线性分类…

C++ 特殊类设计

目录 0.前言 1.设计一个不能被拷贝的类 1.1C98实现 1.2C11实现 2.设计一个只能在堆上创建对象的类 3.设计一个只能在栈上创建对象的类 4.设计一个不能被继承的类 4.1C98实现 4.2C11实现 5.设计只能创建一个对象的类&#xff08;单例模式&#xff09; 5.1设计模式简介 5.2单例模…

Jupyter nbextensions安装与使用

这里写自定义目录标题 Jupyter nbextensions安装与使用安装7以下版本&#xff0c;安装插件包推荐使用的插件 Jupyter nbextensions安装与使用 目前&#xff0c;jupyter版本升级到了7以上版本&#xff0c;导致其界面非常难看&#xff0c;因此&#xff0c;为了重回之前的使用界面…

buuctf-crypto

前言 查找资料的时候,意外翻出之前刷的一些ctf题目,算是简单记录一下,当然因为常用typeo去写md文件,所以其中有很多当时记录的图片都失效了,可惜了 题目1:一眼就解密 ZmxhZ3tUSEVfRkxBR19PRl9USElTX1NUUklOR30 base64解密 flag:flag{THE_FLAG_OF_THIS_STRING} 题目2:MD5 …

全球化浪潮下的数据库革新:嘉里物流 TiDB 实践价值的设想

导读 本文来自 TiDB 社区武汉站——嘉里物流架构团队负责人肖飞老师的演讲《嘉里物流 & TiDB 在全球化业务场景中应用设想》。本次分享探讨了嘉里物流在全球化扩展中&#xff0c;将如何通过 TiDB 的强大功能应对海量数据挑战&#xff0c;优化技术架构&#xff0c;并提升决…

【Linux】详解自定义Shell管道 | 构建简易进程池

目录 续&#xff1a;通信 4 种情况 应用场景 1. 自定义 shell 管道 1. 包含头文件 2. 解析命令函数 详细步骤 3. 执行命令函数 4. 主函数 总结 2. 使用管道实现一个简易版本的进程池 代码结构 代码实现 channel.hpp tasks.hpp main.cc 子进程读取任务&#xff…

十九、虚拟机VMware Workstation(CentOSDebian)的安装

目录 &#x1f33b;&#x1f33b; 一、安装 VMware Workstation1.1 安装 VMware Workstation1.2 虚拟机上安装 CentOS1.3 虚拟机安装 Debian 二、配置Debian方便第三方工具远程连接2.1 配置debian2.2 安装远程SSH工具并连接 一、安装 VMware Workstation 官网下载 本地资源库…

你好! Git——企业级开发模型

企业级开发模型&#xff08;6&#xff09; 一、删除远程分支&#xff0c;git branch -a &#xff08;查看所有本地分支与远程分支&#xff09;还能看到已经删除的分支&#xff0c;怎么解决&#xff1f;二、企业级开发流程2.1 企业级开发流程2.2 系统开发环境 三、Git分支设计模…

RabbitMQ面试题汇总

RabbitMQ面试题 一、RabbitMQ基础1. 什么是RabbitMQ&#xff0c;它的基本架构是怎样的&#xff1f;2. RabbitMQ支持哪些协议&#xff1f;3. 说一下AMQP协议&#xff1f;4. 为什么要使用RabbitMQ&#xff1f;5. MQ的应用场景有哪些&#xff1f;6. 解耦、异步、削峰是什么&#x…

购物系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;商品分类管理&#xff0c;商品信息管理&#xff0c;特价商品管理&#xff0c;用户管理&#xff0c;留言板管理&#xff0c;订单管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首页&…

uni-app总结

1. <u-form-item label"报废人" ><u--input v-model"model.remark" border"bottom" placeholder"请输入"></u--input> </u-form-item> border"bottom" 报废日期 为了

后端Web开发之Maven

1.java项目构建工具maven介绍 Maven是apache旗下的一个开源项目。Apache软件基金会&#xff0c;成立于1999年7月&#xff0c;是目前世界上最大的最受欢迎的开源&#xff08;源代码开放&#xff09;软件基金会也是一一个专门为支持开源项目而生的非盈利性组织。 apache开源项目…

PDO在CANopen协议同步传输和异步传输

PDO&#xff08;过程数据对象&#xff09;在CANopen协议中有两种主要的传输方式&#xff1a;同步传输和异步传输。这两种方式决定了PDO数据的传输时机和条件。下面分别举例说明这两种传输方式&#xff1a; 1. 同步传输 (Synchronous Transmission) 概念&#xff1a; 在同步传输…

3GPP 4G 5G 主要协议

4G LTE的协议主要是36 series 5G NR的协议主要是38 series

RustScan:开源端口扫描器

RustScan 是一款开源端口扫描器&#xff0c;专为速度和多功能性而设计。 它结合了时尚的界面和随时间推移而适应和改进的能力。 借助 RustScan 的自适应学习功能&#xff0c;该工具不断优化其性能&#xff0c;使其成为最高效的端口扫描器。 在几秒钟内发现开放端口&#xff…