往期,笔者基于LLava的数据对齐训练,搞了一个Reyes多模态大模型,并且看了些多模态大模型,相关开源的多模态大模型如:KimiVL、Internvl、QwenVL等,其视觉编码器的尺寸都比较大,如:MoonViT-SO-400M、InternViT-6B-448px-V2_5 等都非常大,对于特定的垂直场景(或者是端侧落地都不大友好),也许并不需要这么大视觉编码器。如:表格场景(【多模态 & 文档智能】一次多模态大模型表格识别解析探索小实践记录),当时笔者用了一个8B参数的模型及百万表格数据进行训练达到了不错的效果。近期,因此思考一些模型轻量化的方案,寻找一个轻量点的视觉编码器(比如参数量小于100M) ,下面来看看SAM,供参考。
Segment Anything Model(SAM)是Meta AI发布的一个突破性图像分割模型为计算机视觉领域提供一个通用的、灵活的基座视觉大模型。它受到自然语言处理(NLP)中基础模型(如GPT、BERT)的启发,强调零样本迁移和提示式交互能力。在SA-1B数据集上的训练,该数据集包含超过11百万张图像和11亿个高质量分割掩码,覆盖了从日常场景到专业领域的多样化内容。
SAM借鉴了NLP领域的Prompt策略,通过给图像分割任务提供Prompt提示来完成任意目标的快速分割。Prompt类型可以是「前景/背景点集、粗略的框或遮罩、任意形式的文本或者任何指示图像中需要进行分割」的信息。如图(a)所示,模型的输入是原始的图像和一些prompt,目标是输出"valid"的分割,所谓valid,就是当prompt的指向是模糊时,模型能够输出至少其中一个mask。
模型结构
SAM的模型结构由三个核心组件组成,Image Encoder 、Prompt Encoder和Mask Decoder。分别负责图像特征提取、提示编码和掩码生成。图像经过Image Encoder编码,Prompt提示经过Prompt Encoder编码,两部分Embedding再经过一个轻量化的Mask Decoder得到融合后的特征。其中,Encoder部分使用的是已有模型,Decoder部分使用Transformer。 下表为三个组件的总结:
组件名称 | 功能 | 关键特点 |
---|---|---|
Image Encoder | ||
将输入图像转换为密集特征表示 | ||
使用MAE预训练的Vision Transformer(ViT-H/16),输入1024x1024x3,输出64x64x256嵌入。 | ||
Prompt Encoder | ||
将用户提示(点、框、文本、掩码)编码为嵌入 | ||
支持稀疏提示(点、框、文本)和密集提示(掩码),使用CLIP处理文本,灵活适应多种输入。 | ||
Mask Decoder | ||
结合图像和提示嵌入,生成最终分割掩码 | ||
轻量级Transformer解码器,通过自注意力与交叉注意力机制预测掩码,实时高效。 | ||
Image Encoder
本文的目的是为了寻找一个轻量化的视觉编码器 ,因此下面来详细看下视觉编码器部分。Image Encoder的作用是把图像映射到特征空间,整体过程如下图所示。
正如论文中所讲,本质上这个Encoder可以是任何网络结构,在这里使用的是微调的Detectron的ViT,当然它也可以被改成传统的卷积结构,非常合理。
可以看到,Image Encoder就是一个ViT的结构,由PatchEmbed、Transformer Encoder、Neck Convolution组成。
输入图像经过ViT结构的过程如下:
- Patch Embedding
输入图像通过一个卷积base,将图像划分为16x16的patches,步长也为16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。Patch Embedding示意图如下所示。
将输入的图像转换为序列化的特征向量
Patch Embedding过程在Vision Transformer结构图中对应下图所示。
- Transformer Encode
feature map通过16个Transformer Block,其中12个Block 使用了基于Window Partition(就是把特征图分成14*14的windows做局部的Attention)的注意力机制,以处理局部信息。另外4个Block是全局注意力模块(多头注意力),它们穿插在Window Partition模块之间,以捕捉图像的全局上下文。
循环叠加Transformer Encode 2. Neck Convolution
最后,通过两层卷积(Neck)将通道数降低至256,生成最终的Image Embedding。其结构图如下所示。
SAM构建与轻量化编码器提取
通过下面代码提取一个参数量大小仅为80几M的视觉编码器。
import torch
from functools import partial
from modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
def build\_sam\_vit\_b(checkpoint=None):
return \_build\_sam(
encoder\_embed\_dim=768,
encoder\_depth=12,
encoder\_num\_heads=12,
encoder\_global\_attn\_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
sam\_model\_registry = {
"vit\_b": build\_sam\_vit\_b,
}
def \_build\_sam(
encoder\_embed\_dim,
encoder\_depth,
encoder\_num\_heads,
encoder\_global\_attn\_indexes,
checkpoint=None,
):
prompt\_embed\_dim = 256
image\_size = 1024
vit\_patch\_size = 16
image\_embedding\_size = image\_size // vit\_patch\_size
sam = Sam(
image\_encoder=ImageEncoderViT(
depth=encoder\_depth,
embed\_dim=encoder\_embed\_dim,
img\_size=image\_size,
mlp\_ratio=4,
norm\_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num\_heads=encoder\_num\_heads,
patch\_size=vit\_patch\_size,
qkv\_bias=True,
use\_rel\_pos=True,
global\_attn\_indexes=encoder\_global\_attn\_indexes,
window\_size=14,
out\_chans=prompt\_embed\_dim,
),
prompt\_encoder=PromptEncoder(
embed\_dim=prompt\_embed\_dim,
image\_embedding\_size=(image\_embedding\_size, image\_embedding\_size),
input\_image\_size=(image\_size, image\_size),
mask\_in\_chans=16,
),
mask\_decoder=MaskDecoder(
num\_multimask\_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding\_dim=prompt\_embed\_dim,
mlp\_dim=2048,
num\_heads=8,
),
transformer\_dim=prompt\_embed\_dim,
iou\_head\_depth=3,
iou\_head\_hidden\_dim=256,
),
pixel\_mean=[123.675, 116.28, 103.53],
pixel\_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state\_dict = torch.load(f)
sam.load\_state\_dict(state\_dict)
return sam
if \_\_name\_\_ == '\_\_main\_\_':
x = torch.zeros(2, 3, 1024, 1024)
net = build\_sam\_vit\_b(checkpoint='sam\_vit\_b\_01ec64.pth')
image\_encoder = net.image\_encoder
print(image\_encoder)
print(image\_encoder(x).shape) # 输出:torch.Size([2, 256, 64, 64])
total\_params = sum(p.numel() for p in image\_encoder.parameters())
print(f"模型的参数量为: {(total\_params/ 1e6):.2f}M") # 模型的参数量为: 89.67M
参考文献:
Segment Anything,https://arxiv.org/pdf/2304.02643
code:https://github.com/facebookresearch/segment-anything
往期相关:
Qwen-VL系列多模态大模型技术演进-模型架构、训练方法、数据细节
Phi-4-multimodal:图、文、音频统一的多模态大模型架构、训练方法、数据细节
deepseek多模态大模型Janus、Janus-Pro模型架构及优化方法浅谈
Encoder-free无编码器多模态大模型EVEv2模型架构、训练方法浅尝
关于我:余俊晖,主要研究方向为自然语言处理、大语言模型、文档智能。曾获CCF、Kaggle、ICPR、ICDAR、CCL、CAIL等国内外近二十项AI算法竞赛/评测冠亚季军。发表SCI、顶会等文章多篇,专利数项。