SAM:Segment Anything 代码复现和测试 基本使用

相关地址

代码:
https://github.com/facebookresearch/segment-anything
在线网站:
https://segment-anything.com/demo

环境配置

建议可以clone下来学习相关代码,安装可以不依赖与这个库

git clone https://github.com/facebookresearch/segment-anything.git

1.创建environment.yaml

name: sam
channels:- pytorch- conda-forge
dependencies:- python=3.8- pytorch=1.9.0- torchvision=0.10.0- cudatoolkit=11.1- pip
conda env create -f environment.yaml
conda activate raptor

2.安装

pip install git+https://github.com/facebookresearch/segment-anything.git

3.其他库

pip install opencv-python pycocotools matplotlib onnxruntime onnx

目前安装的版本

Successfully installed coloredlogs-15.0.1 contourpy-1.1.1
cycler-0.12.1 flatbuffers-23.5.26 fonttools-4.43.1 humanfriendly-10.0
importlib-resources-6.1.0 kiwisolver-1.4.5 matplotlib-3.7.3
mpmath-1.3.0 numpy-1.24.4 onnx-1.15.0 onnxruntime-1.16.1
opencv-python-4.8.1.78 packaging-23.2 protobuf-4.24.4
pycocotools-2.0.7 pyparsing-3.1.1 python-dateutil-2.8.2 six-1.16.0
sympy-1.12 zipp-3.17.0

初阶测试

1.下载模型
https://github.com/facebookresearch/segment-anything#model-checkpoints

2.测试代码

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictordef show_anns(anns):if len(anns) == 0:returnsorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)ax = plt.gca()ax.set_autoscale_on(False)img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))img[:,:,3] = 0for ann in sorted_anns:m = ann['segmentation']color_mask = np.concatenate([np.random.random(3), [0.35]])img[m] = color_maskax.imshow(img)sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)mask_generator = SamAutomaticMaskGenerator(sam)img_path = '/data/qinl/code/segment-anything/notebooks/images/dog.jpg'
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)masks = mask_generator.generate(image)'''
Mask generation returns a list over masks, where each mask is a dictionary containing various data about the mask. These keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the boundary box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format
'''print(len(masks))
print(masks[0].keys())plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

3.输出

65
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

在这里插入图片描述

进阶测试

图片预处理部分

其他instruction,都是在这个基础上进行处理

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictordef show_mask(mask, ax, random_color=False):if random_color:color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)else:color = np.array([30/255, 144/255, 255/255, 0.6])h, w = mask.shape[-2:]mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)ax.imshow(mask_image)def show_points(coords, labels, ax, marker_size=375):pos_points = coords[labels==1]neg_points = coords[labels==0]ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   def show_box(box, ax):x0, y0 = box[0], box[1]w, h = box[2] - box[0], box[3] - box[1]ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))   sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)predictor = SamPredictor(sam)img_path = '/data/qinl/code/segment-anything/notebooks/images/truck.jpg'
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 预处理输入图片
predictor.set_image(image)

输入的instruction为point的情况

# 输入为point的情况input_point = np.array([[500, 375]])input_label = np.array([1])# 可以用来显示一下点的位置# plt.figure(figsize=(10,10))# plt.imshow(image)# show_points(input_point, input_label, plt.gca())# plt.axis('on')# plt.show()  masks, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=True,)print('masks.shape',masks.shape)  # (number_of_masks) x H x W# 输出3个mask,分别有不同的scorefor i, (mask, score) in enumerate(zip(masks, scores)):plt.figure(figsize=(10,10))plt.imshow(image)show_mask(mask, plt.gca())show_points(input_point, input_label, plt.gca())plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)plt.axis('off')plt.show()  

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

多点输入(都视为前景点)

# 输入为多个point的情况(前景点)input_point = np.array([[500, 375]])input_label = np.array([1])masks, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=True,)# additional pointsinput_point = np.array([[500, 375], [1125, 625]])input_label = np.array([1, 1])mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best maskmasks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,mask_input=mask_input[None, :, :],multimask_output=False,)print('masks.shape',masks.shape) # only 1 x H x Wplt.figure(figsize=(10,10))plt.imshow(image)show_mask(masks, plt.gca())show_points(input_point, input_label, plt.gca())plt.axis('off')plt.show() 

在这里插入图片描述

多点输入(前景点加后景点)

决定这个点是前景点还是后景点的就是label,0就是背景的意思

修改标签,得到不一样的结果

    # input_point = np.array([[500, 375], [1125, 625]])# input_label = np.array([1, 1])input_point = np.array([[500, 375], [1125, 625]])input_label = np.array([1, 0])

在这里插入图片描述

使用box框具体物体

# 输入为additional pointsinput_box = np.array([425, 600, 700, 875])masks, _, _ = predictor.predict(point_coords=None,point_labels=None,box=input_box[None, :],multimask_output=False,)plt.figure(figsize=(10, 10))plt.imshow(image)show_mask(masks[0], plt.gca())show_box(input_box, plt.gca())plt.axis('off')plt.show()

在这里插入图片描述

结合points和box

    # 输入为point和boxinput_box = np.array([425, 600, 700, 875])input_point = np.array([[575, 750]])input_label = np.array([0])masks, _, _ = predictor.predict(point_coords=input_point,point_labels=input_label,box=input_box,multimask_output=False,)plt.figure(figsize=(10, 10))plt.imshow(image)show_mask(masks[0], plt.gca())show_box(input_box, plt.gca())show_points(input_point, input_label, plt.gca())plt.axis('off')plt.show()

在这里插入图片描述

batch prompt inputs

    # batch prompt inputsinput_boxes = torch.tensor([[75, 275, 1725, 850],[425, 600, 700, 875],[1375, 550, 1650, 800],[1240, 675, 1400, 750],], device=predictor.device)transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])masks, _, _ = predictor.predict_torch(point_coords=None,point_labels=None,boxes=transformed_boxes,multimask_output=False,)print(masks.shape)  # (batch_size) x (num_predicted_masks_per_input) x H x Wplt.figure(figsize=(10, 10))plt.imshow(image)for mask in masks:show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)for box in input_boxes:show_box(box.cpu().numpy(), plt.gca())plt.axis('off')plt.show()

在这里插入图片描述

End-to-end batched inference

    ## End-to-end batched inferenceimage1 = image  # truck.jpg from aboveimage1_boxes = torch.tensor([[75, 275, 1725, 850],[425, 600, 700, 875],[1375, 550, 1650, 800],[1240, 675, 1400, 750],], device=sam.device)image2 = cv2.imread('./notebooks/images/groceries.jpg')image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)image2_boxes = torch.tensor([[450, 170, 520, 350],[350, 190, 450, 350],[500, 170, 580, 350],[580, 170, 640, 350],], device=sam.device)# Both images and prompts are input as PyTorch tensors that are already transformed to the correct frame. # Inputs are packaged as a list over images, which each element is a dict that takes the following keys:# * `image`: The input image as a PyTorch tensor in CHW format.# * `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.# * `point_coords`: Batched coordinates of point prompts.# * `point_labels`: Batched labels of point prompts.# * `boxes`: Batched input boxes.# * `mask_inputs`: Batched input masks.from segment_anything.utils.transforms import ResizeLongestSideresize_transform = ResizeLongestSide(sam.image_encoder.img_size)def prepare_image(image, transform, device):image = transform.apply_image(image)image = torch.as_tensor(image, device=device.device) return image.permute(2, 0, 1).contiguous()batched_input = [{'image': prepare_image(image1, resize_transform, sam),'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),'original_size': image1.shape[:2]},{'image': prepare_image(image2, resize_transform, sam),'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),'original_size': image2.shape[:2]}]batched_output = sam(batched_input, multimask_output=False)# The output is a list over results for each input image, where list elements are dictionaries with the following keys:# * `masks`: A batched torch tensor of predicted binary masks, the size of the original image.# * `iou_predictions`: The model's prediction of the quality for each mask.# * `low_res_logits`: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration.print('batched_output[0].keys()',batched_output[0].keys())fig, ax = plt.subplots(1, 2, figsize=(20, 20))ax[0].imshow(image1)for mask in batched_output[0]['masks']:show_mask(mask.cpu().numpy(), ax[0], random_color=True)for box in image1_boxes:show_box(box.cpu().numpy(), ax[0])ax[0].axis('off')ax[1].imshow(image2)for mask in batched_output[1]['masks']:show_mask(mask.cpu().numpy(), ax[1], random_color=True)for box in image2_boxes:show_box(box.cpu().numpy(), ax[1])ax[1].axis('off')plt.tight_layout()plt.show()

在这里插入图片描述

高阶测试

模型训练(waiting)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/175873.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端HTML

文章目录 一、什么是前端前端后端 前端三剑客1.什么是HTML2.编写前端的步骤1.编写服务端2.浏览器充当客户端访问服务端​ 3.浏览器无法正常展示服务端内容(因为服务端的数据没有遵循标准)4.HTTP协议>>>:最主要的内容就是规定了浏览器与服务端之间数据交互的格式 3. 前…

Angular-03:组件模板

各种学习后的知识点整理归纳,非原创! 组件模板 ① 数据绑定② 属性绑定③ 类名绑定④ 样式绑定⑤ 事件绑定⑥ 获取原生DOM对象6.1 在组件模板中获取6.2 在组件类中获取 ⑦ 双向数据绑定⑧ 内容投影8.1 select选择器8.2 单槽投影8.3 多槽投影 ⑨ 安全操作…

【Overload游戏引擎细节分析】PBR材质Shader---完结篇

PBR基于物理的渲染可以实现更加真实的效果,其Shader值得分析一下。但PBR需要较多的基础知识,不适合不会OpenGL的朋友。 一、PBR理论 PBR指基于物理的渲染,其理论较多,需要的基础知识也较多,我在这就不再写一遍了&…

leetcode:374. 猜数字大小(二分查找)

一、题目 函数原型:int guessNumber(int n) 二、思路 本题其实就是从 1 - n 中找出所要的答案。利用guess函数来判断数字是否符合答案。 答案小于当前数字,guess函数返回-1 答案等于当前数字,guess函数返回0 答案大于当前数字,gue…

nginx 转发数据流文件

1.问题描述 后端服务,从数据库中查询日志,并生成表格文件返回静态文件。当数据量几兆时,返回正常,但是超过几十兆,几百兆,就会超过网关的连接超时时间30秒。 时序图 这里面主要花费时间的地方在&#xff…

启动Vue项目报错Error: error:0308010C:digital envelope routines::unsupported

问题描述 启动Vue项目报错Error: error:0308010C:digital envelope routines::unsupported 出现这个一般就是node版本的问题,通过命令查看node -v查看node版本; 百度查了好多,都让我降低node版本,属实太麻烦了 在不改node版本的…

【C# Programming】委托和lambda表达式、事件

目录 一、委托和lambda表达式 1.1 委托概述 1.2 委托类型的声明 1.3 委托的实例化 1.4 委托的内部机制 1.5 Lambda 表达式 1.6 语句lambda 1.7 表达式lambda 1.8 Lambda表达式 1.9 通用的委托 1.10 委托没有结构相等性 1.11 Lambda表达式和匿名方法的内部机制 1.1…

博弈论学习笔记(2)——完全信息静态博弈

前言 这部分我们学习的是完全信息静态博弈,主要内容包括博弈论的基本概念、战略式博弈、Nash均衡、Nash均衡解的特性、以及Nash均衡的应用。 零、绪论 1、什么是博弈论 1)博弈的定义 博弈论:研究决策主体的行为发生直接相互作用时候的决策…

Java架构师软件架构的演化和维护

目录 1 导学2 软件架构演化和定义3 面向对象软件架构演化4 软件架构演化方式的分类5 软件架构演化原则6 软件架构演化评估方法7 大型网站架构演化8 软件架构维护想学习架构师构建流程请跳转:Java架构师系统架构设计 1 导学 2 软件架构演化和定义 软件架构的演化和维护就是对…

Kafka - 异步/同步发送API

文章目录 异步发送普通异步发送异步发送流程Code 带回调函数的异步发送带回调函数的异步发送流程Code 同步发送API 异步发送 普通异步发送 需求&#xff1a;创建Kafka生产者&#xff0c;采用异步的方式发送到Kafka broker 异步发送流程 Code <!-- https://mvnrepository…

飞鼠异地组网工具全网互通实战指南

飞鼠异地组网工具全网互通实战指南 一、飞鼠异地组网工具介绍1.1 飞鼠工具简介1.2 飞鼠工具官网 二、本次实践介绍2.1 本次实践前提2.2 本次实践简介2.3 本次实践环境规划 三、异地组网配置3.1 进入中心控制器节点管理后台3.2 网卡设置3.3 进入子网节点管理后台3.4 网卡设置 四…

项目综合实训,vrrp+bfd,以及策略路由的应用

目录 一&#xff0e; 项目需求 二&#xff0e; Visio设备画图 三&#xff0e; 设备选型 三&#xff0e;vlan规划 四&#xff0e;Ip地址规划 五&#xff0e;实验拓扑图 六&#xff0e;配置过程及结果 项目需求 1.S1作为VLAN10的主网关和根桥&#xff0c;S2作为v…

Pytorch L1,L2正则化

L1正则化和L2正则化是常用的正则化技术&#xff0c;用于在机器学习模型中控制过拟合。它们的主要区别在于正则化项的形式和对模型参数的影响。 L1正则化&#xff08;Lasso正则化&#xff09;&#xff1a; 正则化项形式&#xff1a;L1正则化使用模型参数的绝对值之和作为正则化…

Emscripten + CMakeLists.txt 将 C++ 项目编译成 WebAssembly(.wasm)/js,并编译 Html 测试

背景&#xff1a;Web 端需要使用已有的 C 库&#xff08;使用 CMake 编译&#xff09;&#xff0c;需要将 C 项目编译成 WebAssembly(.wasm) 供 js 调用。 上篇文章《Mac 上安装 Emscripten》 已讲解如何安装配置 Emscripten 环境。 本篇文章主要讲解如何将基于 CMakeLists 配…

Gitee 发行版

Gitee 发行版 1、Gitee 发行版管理2、项目仓库中创建发行版本3、项目中导入3.1 gradle配置3.2 dependencies执行正常&#xff0c;包没有下载 1、Gitee 发行版管理 Gitee 发行版&#xff08;Release&#xff09;管理 2、项目仓库中创建发行版本 按照Gitee官网操作就行 3、项目…

PCIe 访问 EP 配置空间,空间映射详解,BDF 计算偏移

访问 EP 的配置空间方法 内存映射IO 访问 内存访问配置空间 前置知识 PCIe 设备的寻址是按照 BDF 即 Bus-Device-Function 来组织的。访问某个设备则需要根据BDF计算偏移地址。 两种不同的内存访问配置空间方法 类 xilinx&#xff0c;基地址 偏移地址访问 // linux-5.10\…

http1,https,http2,http3总结

1.HTTP 当我们浏览网页时&#xff0c;地址栏中使用最多的多是https://开头的url&#xff0c;它与我们所学的http协议有什么区别&#xff1f; http协议又叫超文本传输协议&#xff0c;它是应用层中使用最多的协议&#xff0c; http与我们常说的socket有什么区别吗&#xff1f; …

【ARM 嵌入式 C 入门及渐进 10 -- 冒泡排序 选择排序 插入排序 快速排序 归并排序 堆排序 比较介绍】

文章目录 排序算法小结排序算法C实现排序方法的稳定性 排序算法小结 C语言中常用的排序算法包括冒泡排序、选择排序、插入排序、快速排序、归并排序、堆排序。下面我们来一一介绍&#xff1a; 冒泡排序&#xff08;Bubble Sort&#xff09;&#xff1a;冒泡排序是通过比较相邻…

android 8.1 disable unsupported sensor

如果device不支持某种sensor,可以在android/frameworks/base/core/java/android/hardware/SystemSensorManager.java里将其disabled掉。以disable proximity sensor为例。 public SystemSensorManager(Context context, Looper mainLooper) {synchronized(sLock) {if (!sNativ…

MWeb Pro for Mac:博客生成编辑器,助力你的创作之旅

在当今数字化时代&#xff0c;博客已经成为了许多人记录生活、分享知识和表达观点的重要渠道。而要打造一个专业、美观且易于管理的博客&#xff0c;选择一款强大的博客生成编辑器至关重要。今天&#xff0c;我向大家推荐一款备受好评的Mac软件——MWeb Pro。 MWeb Pro是一款专…