【Segment Anything Model】四:预处理自己的数据集接入SAM

文章目录

  • 1️⃣预备知识
  • 2️⃣实现思路
  • 🔸脚本预处理得到包含embedd和GT的npz
  • 🔸编写Dataset类
  • 3️⃣代码
  • 🔸实现脚本预处理得到包含embedd和GT的npz代码
  • 🔸实现Dataset的代码

1️⃣预备知识

欢迎订阅本专栏(为爱发电,限时免费),联系前三篇一起食用哈!上一篇讲了如何使用SAM接口完成一个训练流程,本篇只专注于如何处理包装自己的数据集。

流程如下:
在这里插入图片描述
直接将图像编码器编码得到的embedding存入npz代表原始图像,是因为,我们有很多种训练策略,但每一次的编码过程是一摸一摸的,并且也是最耗时的一部分,所以,将其静态化,每次用的时候拿来解压。

由于比较粗糙并且没有做交叉验证,所以这里在原始图像存放路径的时候就划分好了训练测试,但一般自己的数据集还是做个交叉验证,在得到npz之后划分训练测试。

2️⃣实现思路

🔸脚本预处理得到包含embedd和GT的npz

embedding步骤:
1.归一化
2.ResizeLongestSide到1024*1024
3.sam_model.preprocess预处理
4.sam_model.image_encoder编码

GD步骤:1.校验GT是否是2D 2.校验是否和img尺寸大小相同 3.uint8到255

🔸编写Dataset类

init:解压npz,读取数据放入self变量

getitem:根据GT获得边界框当作框提示,在GT内随机选择点当作点提示, 将embedd,box,point,GT,组装torch.tensor

len:返回图片个数就好啦

在这里插入图片描述

3️⃣代码

🔸实现脚本预处理得到包含embedd和GT的npz代码

注释都在代码里吗,按行注释,我真贴心💓

import numpy as np
import os
join = os.path.join
from skimage import transform, io
from tqdm import tqdm
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide# GT存放路径,到文件夹
gt_path = "./"
# 组装好npz的保存路径
save_path = "./"
# 获取所有GT图像名称
names = sorted(os.listdir(gt_path))
os.makedirs(save_path, exist_ok=True)
model_type = 'vit_b'
checkpoint = 'xx/sam_vit_b_01ec64.pth'
device = 'cuda:0'
sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
imgs = []
gts = []
img_embeddings = []
# image路径 到最后一层文件夹
img_path=""
for gt_name in tqdm(names):# 如果你是jpg改一下后缀image_name = gt_name.split('.')[0] + "png"# 读取GTgt_data = io.imread(join(gt_path, gt_name))# GT必须是2D,如果是3D就取前两通道if len(gt_data.shape) == 3:gt_data = gt_data[:, :, 0]assert len(gt_data.shape) == 2, 'GT must be 2D'# 尺寸转256数值转255gt_data = transform.resize(gt_data == 255, (256, 256), order=0,preserve_range=True, mode='constant')gt_data = np.uint8(gt_data)# 排除GT特别小的情况,这条可以不加if np.sum(gt_data) > 100:assert np.max(gt_data) == 1 and np.unique(gt_data).shape[0] == 2, 'GT must be 2D'image_data = io.imread(join(img_path, image_name))# 计算最大值最小值lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5)# 排除特别特殊的像素image_data_pre = np.clip(image_data, lower_bound, upper_bound)# 归一化image_data_pre = (image_data_pre - np.min(image_data_pre)) / (np.max(image_data_pre) - np.min(image_data_pre)) * 255.0image_data_pre[image_data == 0] = 0# 归一化image_data_pre = transform.resize(image_data_pre, (256, 256), order=3,preserve_range=True, mode='constant', anti_aliasing=True)image_data_pre = np.uint8(image_data_pre)imgs.append(image_data_pre)gts.append(gt_data)# SAM提供的resize到1024sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)resize_img = sam_transform.apply_image(image_data_pre)# resize_img是通道在后,sam要求通道在前,transposehi是对resize_img数组进行维度重排(dimension reordering)的操作。resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)# 增加一个channel假装当作有一个batchsize输入到sam_model.image_encoderinput_image = sam_model.preprocess(resize_img_tensor[None, :, :, :])  # (1, 3, 1024, 1024)# 提前计算图像embeddingwith torch.no_grad():embedding = sam_model.image_encoder(input_image)img_embeddings.append(embedding.cpu().numpy()[0])# 上面数据已经处理好并存在数组了,需要数据字典存在npz中
# 沿着纵轴堆砌,每一个都是(256, 256, 3),堆起来是(n, 256, 256, 3)
imgs = np.stack(imgs, axis=0)  # (n, 256, 256, 3)
gts = np.stack(gts, axis=0)  # (n, 256, 256)
img_embeddings = np.stack(img_embeddings, axis=0)  # (n, 1, 256, 64, 64)
# np的保存npz操作
np.savez_compressed(join(save_path, '.npz'), imgs=imgs, gts=gts, img_embeddings=img_embeddings)

🔸实现Dataset的代码

import numpy as np
import matplotlib.pyplot as plt
import osjoin = os.path.join
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import monai
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import randomtorch.manual_seed(2023)# 构造自己的Dataset继承Dataset类
class MyselfDataset(Dataset):def __init__(self, data_root):print("into init")self.data_root = data_root# 访问npz文件self.npz_files = sorted(os.listdir(self.data_root))# 去除npz里的数据self.npz_data = [np.load(join(data_root, f)) for f in self.npz_files]# 将取出来的数据放在变量保存self.ori_gts = np.vstack([d['gts'] for d in self.npz_data])self.img_embeddings = np.vstack([d['img_embeddings'] for d in self.npz_data])def __len__(self):return self.ori_gts.shape[0]def __getitem__(self, index):img_embed = self.img_embeddings[index]gt2D = self.ori_gts[index]# 获取非零点坐标y_indices, x_indices = np.where(gt2D > 0)# 获取GT坐标框x_min, x_max = np.min(x_indices), np.max(x_indices)y_min, y_max = np.min(y_indices), np.max(y_indices)# 在GT框加扰动H, W = gt2D.shapex_min = max(0, x_min - np.random.randint(0, 10))x_max = min(W, x_max + np.random.randint(0, 10))y_min = max(0, y_min - np.random.randint(0, 10))y_max = min(H, y_max + np.random.randint(0, 10))bboxes = np.array([x_min, y_min, x_max, y_max])# 在GT在5像素以内的地方随机选择两个背景点y_zero, x_zero = np.where(gt2D == 0)y_zero = np.unique(y_zero)x_zero = np.unique(x_zero)y_list = y_zero[(y_min - 5 < y_zero) & (y_zero < y_max + 5)]x_list = x_zero[(x_min - 5 < x_zero) & (x_zero < x_max + 5)]y1, y2 = random.choices(y_list, k=2)x1, x2 = random.choices(x_list, k=2)background_index1 = [x1, y1]background_index2 = [x2, y2]# 在GT内随机选择前景点foreground_index1, foreground_index2, foreground_index3 = random.choices(np.argwhere(gt2D == 1), k=3)# 将所有选择好的点添加到list,如果是单点,不需要直接返回点的index就好。pt_list_s = []pt_list_s.append(background_index1)pt_list_s.append(background_index2)pt_list_s.append(foreground_index1)pt_list_s.append(foreground_index2)pt_list_s.append(foreground_index3)points = pt_list_s# 0是背景1是前景points_labels = [0, 0, 1, 1, 1]return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :, :]).long(), torch.tensor(bboxes).float(), torch.tensor(points).float(), torch.tensor(points_labels).float()

之后连系上篇 【Segment Anything Model】SAM模型微调自定义数据集,更改混合提示方式:点,框,点框混合
在这里取值训练就好啦
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/79028.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Idea添加mybatis的mapper文件模版

针对Java开发人员&#xff0c;各种框架的配置模版的确是需要随时保留一份&#xff0c;在使用的时候&#xff0c;方便复制粘贴&#xff0c;但是也依然不方便&#xff0c;我们可以给开发工具&#xff08;IDE&#xff09;中添加配置模版&#xff0c;这里我介绍下使用idea开发工具&…

ad+硬件每日学习十个知识点(18)23.7.29 (LDO原理、LDO的补偿引脚)

文章目录 1.LDO名字介绍2.LDO的应用范围3.LDO的原理4.LDO输出端和输入端的差值至少满足多少V&#xff1f;怎么计算的&#xff1f;5.输出的误差和输出电流&#x1f446;&#xff08;右下角图像&#xff09;6.LDO一般会有个引脚是做补偿之用&#xff0c;datasheet会说明一个器件的…

Packet Tracer - 检验 IPv4 和 IPv6 编址

Packet Tracer - 检验 IPv4 和 IPv6 编址 地址分配表 设备 接口 IPv4 地址 子网掩码 默认网关 IPv6 地址/前缀 R1 G0/0 10.10.1.97 255.255.255.224 N/A 2001:DB8:1:1::1/64 N/A S0/0/1 10.10.1.6 255.255.255.252 N/A 2001:DB8:1:2::2/64 N/A 本地链路 F…

Linux 信号signal处理机制

Signal机制在Linux中是一个非常常用的进程间通信机制&#xff0c;很多人在使用的时候不会考虑该机制是具体如何实现的。signal机制可以被理解成进程的软中断&#xff0c;因此&#xff0c;在实时性方面还是相对比较高的。Linux中signal机制的模型可以采用下图进行描述。 每个进程…

电力巡检无人机助力迎峰度夏,保障夏季电力供应

夏季是电力需求量较高的时期&#xff0c;随着高温天气的来临&#xff0c;风扇、空调和冰箱等电器的使用量也大大增加&#xff0c;从而迎来夏季用电高峰期&#xff0c;电网用电负荷不断攀升。为了保障夏季电网供电稳定&#xff0c;供电公司会加强对电力设施设备的巡检&#xff0…

spring — Spring Security 5.7与6.0差异性对比

1. spring security Spring Security 是一个提供身份验证、授权和针对常见攻击保护的框架。 凭借对保护命令式和反应式应用程序的一流支持&#xff0c;它成为基于Spring的标准安全框架。 Spring Security 在最近几个版本中配置的写法都有一些变化&#xff0c;很多常见的方法都…

【力扣每日一题】2023.8.7 反转字符串

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们一个字符数组形式的字符串&#xff0c;让我们直接原地修改反转字符串&#xff0c;不必返回。 给出的条件是使用O(1)的额外空间…

24届近5年重庆邮电大学自动化考研院校分析

今天给大家带来的是重庆邮电大学控制考研分析 满满干货&#xff5e;还不快快点赞收藏 一、重庆邮电大学 学校简介 重庆邮电大学简称"重邮"&#xff0c;坐落于直辖市-重庆市&#xff0c;入选国家"中西部高校基础能力建设工程”、国家“卓越工程师教育培养计划…

【ES】笔记-let 声明及其特性

let 声明及其特性 声明变量 变量赋值、也可以批量赋值 let a;let b,c,d;let e100;let f521,giloveyou,h[];变量不能重复声明 let star罗志祥;let star小猪;块级作用域&#xff0c;let声明的变量只在块级作用域内有效 {let girl周杨青;}console.log(girl)注意&#xff1a;在 i…

SpringIOC注入的两种方式讲解以及代码示例

Ioc是Spring全家桶各个功能模块的基础&#xff0c;创建对象的容器。 AOP也是以IoC为基础&#xff0c;AOP是面向切面编程&#xff0c;抽象化的面向对象 AOP功能&#xff1a;打印日志&#xff0c;事务&#xff0c;权限处理 AOP的使用会在下一篇文章进行介绍 IoC 翻译为控制反…

配置Hive远程服务详细步骤

HiveServer2支持多客户端的并发和认证&#xff0c;为开放API客户端如JDBC、ODBC提供了更好的支持。 &#xff08;1&#xff09;修改hive-site.xml&#xff0c;在文件中添加以下内容&#xff1a; <property><name>hive.metastore.event.db.notification.api.auth&l…

嵌入式硬件系统的基本组成

嵌入式硬件系统的基本组成 嵌入式系统的硬件是以包含嵌入式微处理器的SOC为核心&#xff0c;主要由SOC、总线、存储器、输入/输出接口和设备组成。 嵌入式微处理器 每个嵌入式系统至少包含一个嵌入式微处理器 嵌入式微处理器体系结构可采用冯.诺依曼&#xff08;Von Neumann&…

【ShaderToy中图形效果转译到UnityShaderlab案例分享,实现科技感电流场_PlasmaGlobe】

Mac电脑系统下的显示: Windows系统下的显示: Shader"ShaderToy/PlasmaGlobe" {Properties{_MainTex("MainTex", 2D) = "white"{}_iMouse

AI编程工具Copilot与Codeium的实测对比

csdn原创谢绝转载 简介 现在没有AI编程工具&#xff0c;效率会打一个折扣&#xff0c;如果还没有&#xff0c;赶紧装起来&#xff0e; GitHub Copilot是OpenAi与github等共同开发的的AI辅助编程工具&#xff0c;基于ChatGPT驱动&#xff0c;功能强大&#xff0c;这个没人怀疑…

解决Win11右键菜单问题

✅作者简介&#xff1a;大家好&#xff0c;我是Cisyam&#xff0c;热爱Java后端开发者&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;Cisyam-Shark的博客 &#x1f49e;当前专栏&#xff1a; 程序日常 ✨特色专栏&…

OpenAI 已为 GPT-5 申请商标,GPT-4 发布不到半年,GPT-5 就要来了吗?

据美国专利商标局&#xff08;USPTO&#xff09;信息显示&#xff0c;OpenAI已经在7月18日申请注册了“GPT-5”商标。 在这份新商标申请中&#xff0c;OpenAI将“GPT-5”描述为一种“用于使用语言模型的可下载计算机软件”。 继GPT-4发布之后&#xff0c;它预计将成为OpenAI下一…

Python自动化测试之用Robot Framework进行自动化测试详解

概要 你还在手动测试&#xff1f;不妨了解一下更高效、准确且简单的测试方法——使用Python的Robot Framework进行自动化测试。 什么是Robot Framework&#xff1f; Robot Framework是一款开源的Python自动化测试框架&#xff0c;它基于关键字驱动的思想&#xff0c;具有易读、…

【2.1】Java微服务:详解Hystrix

✅作者简介&#xff1a;大家好&#xff0c;我是 Meteors., 向往着更加简洁高效的代码写法与编程方式&#xff0c;持续分享Java技术内容。 &#x1f34e;个人主页&#xff1a;Meteors.的博客 &#x1f49e;当前专栏&#xff1a; 深度学习 ✨特色专栏&#xff1a; 知识分享 &…

【测试】软件测试工具JMeter简单用法

简明扼要&#xff0c;点到为止。 1. JMeter介绍 JMeter的全称是Apache JMeter&#xff0c;是一款用于软件测试的工具软件&#xff0c;其是开源免费的&#xff0c;由Apache基金会运营。 官网&#xff1a;Apache JMeter - Apache JMeter™ 2. 下载安装及运行 2.1 安装 Java8…

AlmaLinux 9 安装 Go 1.20

AlmaLinux 9 安装 Golang 1.20 1. 下载 go 安装包2. 安装 go3. 配置环境变量4. 确认 go 版本 1. 下载 go 安装包 访问 https://go.dev/dl/&#xff0c;下载你想安装的版本&#xff0c;比如 go1.20.7.linux-amd64.tar.gz&#xff0c; 2. 安装 go (可选)删除旧版本&#xff0c;…