[医学分割大模型系列] (1) SAM 分割大模型解析

[医学大模型系列] [1] SAM 分割大模型解析

  • 1. 特点
  • 2. 网络结构
    • 2.1 Image encoder
    • 2.2 Prompt encoder
    • 2.3 Mask decoder
  • 3. 数据引擎
  • 4. 讨论

论文地址:Segment Anything

开源地址:https://github.com/facebookresearch/segment-anything

demo地址:Segment Anything | Meta AI

参考:

  1. SAM模型详解
  2. https://www.bilibili.com/video/BV1K94y177Ka/?spm_id_from=333.337.search-card.all.click&vd_source=6bf14836c2866f1a29042ffd4369a079
  3. https://www.bilibili.com/video/BV1Bu411s79u/?spm_id_from=333.337.search-card.all.click&vd_source=6bf14836c2866f1a29042ffd4369a079

1. 特点

在这里插入图片描述

  • 可提示的(Prompt)交互图像分割大模型(Foundation Models)
  • 四种Prompt形式:点,框,Mask,文本
  • 构建数据引擎:使用高效模型来协助数据收集和使用新收集的数据来帮助模型迭代。

2. 网络结构

模型整体上包含三个大模块,image encoder,prompt encoder和mask decoder。

2.1 Image encoder

image encoder旨在映射待分割的图像到图像特征空间。
在这里插入图片描述
这里的ViT结构也并不是十分复杂,这里简单列出输入图像经过ViT的流程,其实整体只有4个步骤:

  • 输入图像进入网络,先经过一个卷积base的patch_embedding:取16*16为一个patch,步长也是16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。
  • patch_embed过后加positional_embedding:positional_embedding是个可学习的参数矩阵,初始化是0。
  • 加了positional_embedding后的feature map过16个transformer block,其中12个transformer是基于window partition(就是把特征图分成14*14的windows做局部的attention)的attn模块,和4个全局attn,这4个全局attn是穿插在windowed attention中的。
  • 最后过两层卷积(neck)把channel数降到256(token长度),这就是最终的image embedding的结果。

整体来看,这个部分的计算量是相对来说比较大的,demo体验过程中,只有这个过程的计算是在fb的服务器上做的,prompt encoder和mask decoder体积比较小,都是在浏览器内部或者说用本地的内存跑的,整体速度还比较快。

其使用的预训练模型: Masked Autoencoders Are Scalable Vision Learners - 代码

# build_sam.py
sam_model_registry = {"default": build_sam_vit_h,"vit_h": build_sam_vit_h,"vit_l": build_sam_vit_l,"vit_b": build_sam_vit_b,
}

2.2 Prompt encoder

prompt encoder则是负责映射输入的prompt到prompt的特征空间,这里有一点要提就是作者定义了sparse(包括 点,框,文本)和dense(mask)两种prompt,其中sparse prompt比较好理解,就是指demo中我们可以输入的点,目标框或者是描述目标的text,而dense prompt在目前的线上demo中体验不到,paper中也只说它对应的是mask类型的prompt,从代码里看应该是训练时候用的比较多,一般是上一次迭代预测出的一个粗分割的mask,粗略指出待分割的目标区域。
在这里插入图片描述
映射出的特征的channel和image embedding的channel一致(默认均为256),因为这两个后边要用attention进行融合。

  • sparse prompt:
    • 如果prompt是point,那么它的映射由两个部分相加组成,一个是位置编码,这里的位置编码使用的是Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains的编码方式,用空间坐标乘以高斯分布的向量来描述位置比直接的线性向量描述效果更好,另一个部分是一个描述当前点是前景还是背景(因为demo里可以选择pos点也可以选择neg点)特征的可学习的一维向量。换句话说,如果当前选择的点是positive,那么就在位置编码的2维向量上加一个表示postitive的一维向量,如果是neg,就加一个表示neg的一维向量,对于所有的positive的点,加上去的pos向量都是一样的。
    • 如果prompt是box,那box的映射也是由两个部分相加组成,第一部分是左上和右下两个点的位置编码,第二部分是一组一维向量用来描述这个点是“左上”还是“右下”。也就是说,对于左上的点,他的映射就是位置编码+“左上”这个特征的描述向量,右下的点,就是位置编码+“右下”这个特征的描述向量。
  • dense prompt:
    • 对于mask这类的dense prompt,他的映射就比较简单粗暴。在输入prompt encoder之前,先要把mask降采样到4x,再过两个2x2,stride=2的卷积,这样尺寸又降了4x,就和降了16x的图像特征图尺寸一致了,再过一个1*1的卷积,把channel也升到256。如果没有提供mask,也就是我们实际inference时候的场景,这个结构会直接返回一个描述“没有mask”特征的特征图。

预训练: CLIP

# prompt_encoder.pydef forward(self,points: Optional[Tuple[torch.Tensor, torch.Tensor]],boxes: Optional[torch.Tensor],masks: Optional[torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor]:"""Embeds different types of prompts, returning both sparse and denseembeddings.Arguments:points (tuple(torch.Tensor, torch.Tensor) or none): point coordinatesand labels to embed.boxes (torch.Tensor or none): boxes to embedmasks (torch.Tensor or none): masks to embedReturns:torch.Tensor: sparse embeddings for the points and boxes, with shapeBxNx(embed_dim), where N is determined by the number of input pointsand boxes.torch.Tensor: dense embeddings for the masks, in the shapeBx(embed_dim)x(embed_H)x(embed_W)"""bs = self._get_batch_size(points, boxes, masks)sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())if points is not None:coords, labels = pointspoint_embeddings = self._embed_points(coords, labels, pad=(boxes is None))sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)if boxes is not None:box_embeddings = self._embed_boxes(boxes)sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)if masks is not None:dense_embeddings = self._embed_masks(masks)else:dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])return sparse_embeddings, dense_embeddings

2.3 Mask decoder

在这里插入图片描述

  • detr结构回顾

    • cross att:
      在这里插入图片描述
      100个queries token对图片特征进行遍历,确定每个token要对应寻找的物体特征

    • self att:
      在这里插入图片描述
      避免token之间寻找物体特征重复,queries token进行自注意力机制。

  • 结构描述

    • (1) self-attention on the tokens,
    • (2) cross-attention from tokens (as queries) to the image embedding,
    • (3) a point-wise MLP updates each token,
    • (4) cross-attention from the image embedding (as queries) to tokens.
    • (5) Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer. 提示加了两遍
      在这里插入图片描述
  • 结构分析

    • 在prompt embedding进入decoder之前,先在它上面concat了一组可学习的output tokens (相当于dert的queries token,100个tokens),output tokens由两个部分构成:
      • 一个是iou token,它会在后面被分离出来用于预测iou的可靠性(对应结构图右侧的IoU output token),它受到模型计算出的iou与模型计算出的mask与GT实际的iou之间的MSE loss监督;
      • 另一个是mask token,它也会在后面被分离出来参与预测最终的mask(对应结构图右侧的output token per mask),mask受到focal loss和dice loss 20:1的加权组合监督。
      • 这两个token的意义我感觉比较抽象,因为理论来说进入decoder的变量应该是由模型的输入,也就是prompt和image的映射构成,但这两个token的定义与prompt和image完全没有关系,而是凭空出现的。从结果反推原因,只能把它们理解成对模型的额外约束,因为它们两个参与构成了模型的两个输出并且有loss对他们进行监督。
      • 最终prompt embedding(这一步改名叫prompt token)和刚才提到这两个token concat到一起统称为tokens进入decoder。
    • image embedding在进入decoder之前也要进行一步操作:dense prompt由于包含密集的空间信息,与image embedding所在的特征空间一致性更高,所以直接与image embedding相加融合。因为后面要与prompt做cross attention融合,这里还要先算一下image embedding的位置编码。
    • 接下来{image embedding,image embedding的位置编码,tokens}进入一个两层transformer结构的decoder做融合。值得注意的是,在transformer结构中,为了保持位置信息始终不丢失,每做一次attention运算,不管是self-attention还是cross-attention,tokens都叠加一次初始的tokens,image embedding都叠加一次它自己的位置编码,并且每个attention后边都接一个layer_norm。
      • tokens先过一个self-attention。
      • tokens作为q,对image embedding做cross attention,更新tokens。
      • tokens再过两层的mlp做特征变换。
      • image embedding作为q,对tokens做cross attention,更新image embedding。
    • 更新后的tokens作为q,再对更新后的image embedding做cross attention,产生最终的tokens。
    • 更新后的image embedding过两层kernel_size=2, stride=2的转置卷积,升采样到4x大小(依然是4x降采样原图的大小),产生最终的image embedding。
    • 接下来兵分两路:
      • mask token被从tokens中分离出来(因为他一开始就是concat上去的,可以直接按维度摘出来),过一个三层的mlp调整channel数与最终的image embedding一致,并且他们两个做矩阵乘法生成mask的预测。最终的image embedding大小为[token长度(channel),图像长,图像宽],mask token的大小为[token长度(channel),token个数(最终生成的mask个数)]。两者点乘后的大小为[token个数(最终生成的mask个数),图像长,图像宽]。也就是说有几个token个数生成几个mask。
      • iou token被从tokens中分离出来,也过一个三层的mlp生成最终的iou预测。在反向传播时,排序mask,参与计算的只有loss最小的mask相关的参数。
    • 最后,如前文所述,分别对mask的预测和iou预测进行监督,反向传播,更新参数。每一个mask,会随机产生11种prompt与之配对。

对于一个输出,如果给出一个模糊的提示,该模型将平均多个有效的掩码。为了解决这个问题,我们修改了模型,以预测单个提示的多个输出掩码(比如说提示在衣服上,会分出来衣服的mask和人的mask)。我们发现3个掩模(output/mask token的个数为3)输出足以解决大多数常见的情况(嵌套掩模通常最多有三个深度:整体、部分和子部分)。在训练期间,我们只支持mask上的最小损失[匈牙利损失]。为了对掩码进行排名,该模型预测了每个掩码的置信度分数(即估计的IoU)

在这里插入图片描述

# transformer.pydef forward(self,image_embedding: Tensor,image_pe: Tensor,point_embedding: Tensor,) -> Tuple[Tensor, Tensor]:"""Args:image_embedding (torch.Tensor): image to attend to. Should be shapeB x embedding_dim x h x w for any h and w.image_pe (torch.Tensor): the positional encoding to add to the image. Musthave the same shape as image_embedding.point_embedding (torch.Tensor): the embedding to add to the query points.Must have shape B x N_points x embedding_dim for any N_points.Returns:torch.Tensor: the processed point_embeddingtorch.Tensor: the processed image_embedding"""# BxCxHxW -> BxHWxC == B x N_image_tokens x Cbs, c, h, w = image_embedding.shapeimage_embedding = image_embedding.flatten(2).permute(0, 2, 1)image_pe = image_pe.flatten(2).permute(0, 2, 1)# Prepare queriesqueries = point_embeddingkeys = image_embedding# Apply transformer blocks and final layernormfor layer in self.layers:queries, keys = layer(queries=queries,keys=keys,query_pe=point_embedding,key_pe=image_pe,)# Apply the final attention layer from the points to the imageq = queries + point_embeddingk = keys + image_peattn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)queries = queries + attn_outqueries = self.norm_final_attn(queries)return queries, keys

3. 数据引擎

在这里插入图片描述
阶段一,手动阶段,模型越来越好,数据越来越多

  1. 拿现有已经表好的数据集训练一个粗糙版的SAM
  2. 粗糙版的SAM分割新数据,人工修正后,重新训练模型
  3. 上两部反复迭代

阶段二,半自动阶段,默认准确率够高,召回率不够好(漏标)

  1. 用训练好的SAM模型分割图像
  2. 把分割不出来的给人来标
  3. 重新训练模型

阶段三,全自动阶段

  1. SAM模型在图像上分割,对得分较高,准确率较高,符合设定规则的结果进行保留

4. 讨论

在这里插入图片描述
SAM的一个我个人认为比较新颖的点子是它从interactive segmentation引申出了一个新的任务类型,叫做promptable segmentation。从他的模型中也能看出,输入的prompt是模型在输出最终mask的关键指导信息,这也是为什么我发现目前的SAM模型在处理一些专业领域图像(比如我自己从事的医学图像分割)时,直接使用他的segment everything功能,也就是无prompt进行分割时效果不好的原因。

另一个要搞清楚的问题是在进行有prompt的分割时,实际上实现的是一个二分类的分割任务,模型要解决的问题是根据我们选择的点的特征,从图像(背景)中分割出这个点所在的目标物体(前景),它本质上并不关心这个目标物体是个什么东西。滑稽一点来说,整个过程实际上有点类似photoshop里魔棒的功能,adobe倒是可以考虑把这个模型整合进ps里提升一些性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/283201.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1、初识JVM

一、JVM是什么? JVM的英文全称是 Java Virtual Machine,其中文译名为Java虚拟机。它在本质上就是是一个运行在计算机上的程序,他的职责是运行Java字节码文件。 JVM执行流程如下 二、JVM有哪些功能? 2.1 解释和运行 对字节码文…

平衡隐私与效率,Partisia Blockchain 解锁数字安全新时代

原文:https://cointelegraph.com/news/exploring-multiparty-computations-role-in-the-future-of-blockchain-privacy; https://medium.com/partisia-blockchain/unlocking-tomorrow-outlook-for-mpc-in-2024-and-beyond-cb170e3ec567 编译&#xff1…

数据仓库系列总结

一、数据仓库架构 1、数据仓库的概念 数据仓库(Data Warehouse)是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合,用于支持管理决策。 数据仓库通常包含多个来源的数据,这些数据按照主题进行组织和存储&#x…

蓝桥杯练习——神秘咒语——axios

目标 完善 index.js 中的 TODO 部分,通过新增或者修改代码,完成以下目标: 点击钥匙 1 和钥匙 2 按钮时会通过 axios 发送请求,在发送请求时需要在请求头中添加 Authorization 字段携带 token,token 的值为 2b58f9a8-…

【DataWhale学习】用免费GPU线上跑chatGLM、SD项目实践

用免费GPU线上跑chatGLM、SD项目实践 ​ DataWhale组织了一个线上白嫖GPU跑chatGLM与SD的项目活动,我很感兴趣就参加啦。之前就对chatGLM有所耳闻,是去年清华联合发布的开源大语言模型,可以用来打造个人知识库什么的,一直没有尝试…

基于python+vue 的一加剧场管理系统的设计与实现flask-django-nodejs-php

二十一世纪我们的社会进入了信息时代,信息管理系统的建立,大大提高了人们信息化水平。传统的管理方式对时间、地点的限制太多,而在线管理系统刚好能满足这些需求,在线管理系统突破了传统管理方式的局限性。于是本文针对这一需求设…

从0到1实现RPC | 02 RpcConsumer的远程调用

一、RPC的简化版原理如下图(核心是代理机制)。 1.本地代理存根: Stub 2.本地序列化反序列化 3.网络通信 4.远程序列化反序列化 5.远程服务存根: Skeleton 6.调用实际业务服务 7.原路返回服务结果 8.返回给本地调用方 二、新建一个模块rpc-demo-c…

Docker容器初始

华子目录 docker简介虚拟化技术硬件级虚拟化硬件级虚拟化历史操作系统虚拟化历史基于服务的云计算模式 什么是dockerDocker和传统虚拟化方式的不同之处为什么要使用docker?Docker 在如下几个方面具有较大的优势 对比传统虚拟机总结docker应用场景docker改变了什么 基…

WebClient上载文件——实现将本地文件同步到远端服务器上

问题描述 用户上传产品示例图片到服务器端上,客户端在请求图片资源时,当服务端架设了多个节点的情况下,由于没有负载均衡请求到保存图片资源的服务器,出现图片访问404的问题。 这里保存上传文件时,同时需要将该文件保…

30V转5V 1A 30降压12V 1A DCDC低电压恒压IC 车充芯片-H4110

30V转5V和30V转12V的DCDC低电压恒压IC(也称为降压恒压芯片或车充芯片)工作原理如下: 输入电压识别:芯片首先识别输入的30V电压,并准备进行转换。 PWM控制:芯片内部的控制逻辑生成PWM信号。这个信号用于控制…

基于python+vue文学名著分享系统的设计与实现flask-django-nodejs-php

随着世界经济信息化、全球化的到来和互联网的飞速发展,推动了各行业的改革。若想达到安全,快捷的目的,就需要拥有信息化的组织和管理模式,建立一套合理、动态的、交互友好的、高效的文学名著分享系统。当前的信息管理存在工作效率…

easyExcel大数据量导出oom

easyExcel大数据量导出 异常信息 com.alibaba.excel.exception.ExcelGenerateException: java.lang.OutOfMemoryError: GC overhead limit exceededat com.alibaba.excel.write.ExcelBuilderImpl.fill(ExcelBuilderImpl.java:84)at com.alibaba.excel.ExcelWriter.fill(Excel…

huggingface的transformers训练bert

目录 理论 实践 理论 https://arxiv.org/abs/1810.04805 BERT(Bidirectional Encoder Representations from Transformers)是一种自然语言处理(NLP)模型,由Google在2018年提出。它是基于Transformer模型的预训练方法…

STM32 AD单通道函数设计

单片机学习! 目录 文章目录 前言 一、ADC配置步骤 二、详细步骤 2.1 开启RCC时钟 2.2 配置GPIO 2.3 配置多路开关 2.4 配置ADC转换器 2.5 开启ADC电源 2.6 ADC进行校准 2.6.1 复位校准 2.6.2 等待复位校准完成 2.6.3 开始校准 2.6.4 等待校准完成 三、启动AD转换函数…

selenium自动化登录模块HTMLTestRunner测试报告

1.下载HTMLTestRunner.py放到python的Lib目录下,python3之后的,文件要修改以下内容: 第94行,将import StringIO修改成import io 第539行,将self.outputBuffer StringIO.StringIO()修改成self.outputBuffer io.Strin…

WordPress站点如何实现发布文章即主动推送到神马搜索引擎?

平时boke112百科很少关注到神马搜索引擎,近日有站长留言想要实现WordPress站点发布文章就主动推送到神马搜索引擎,而且推送成功就自动添加一个自定义字段,以防重复推送。 登录进入神马站长平台后才知道神马也有一个API推送功能,不…

GitHub Copilot+ESP开发实战-串口

上篇文章讲了GitHub Copilot在应用中可能遇到的问题,接下来小启就简单介绍下GitHub Copilot在ESP32开发中C语言实现串口功能,感兴趣的可以看看。 一、向Copilot提问: 1. ESP32用C语言实现串口初始化; 2.配置uart为1&#xff0c…

思腾合力受邀出席文化和旅游虚拟现实应用推广交流活动并作主题演讲

3月21日,由文化和旅游部产业发展司主办,中国信息通信研究院、北京市石景山区文化和旅游局、中国动漫集团有限公司承办的“数字赋能文旅场景建设行动——文化和旅游虚拟现实应用推广交流活动”在北京首钢一高炉SoReal科幻乐园成功举办。 思腾合力CMO徐莉受…

python文学名著分享系统的设计与实现flask-django-nodejs-php

在此基础上,结合现有文学名著分享体系的特点,运用新技术,构建了以python为基础的文学名著分享信息化管理体系。首先,以需求为依据,根据需求分析结果进行了系统的设计,并将其划分为管理员和用户二种角色和多…

【Node.js从基础到高级运用】十五、单元测试与集成测试

引言 在Node.js开发过程中,测试是确保代码质量和功能正确性的关键步骤。单元测试和集成测试是最常见的测试类型。下面我们将使用Jest框架来进行测试。 单元测试 单元测试是指对软件中的最小可测试单元进行检查和验证。在Node.js中,这通常指的是函数或者…