多模态大模型轻量化探索-视觉大模型SAM的视觉编码器

大模型向量数据库机器学习

往期,笔者基于LLava的数据对齐训练,搞了一个Reyes多模态大模型,并且看了些多模态大模型,相关开源的多模态大模型如:KimiVL、Internvl、QwenVL等,其视觉编码器的尺寸都比较大,如:MoonViT-SO-400M、InternViT-6B-448px-V2_5 等都非常大,对于特定的垂直场景(或者是端侧落地都不大友好),也许并不需要这么大视觉编码器。如:表格场景(【多模态 & 文档智能】一次多模态大模型表格识别解析探索小实践记录),当时笔者用了一个8B参数的模型及百万表格数据进行训练达到了不错的效果。近期,因此思考一些模型轻量化的方案,寻找一个轻量点的视觉编码器(比如参数量小于100M) ,下面来看看SAM,供参考。

Segment Anything Model(SAM)是Meta AI发布的一个突破性图像分割模型为计算机视觉领域提供一个通用的、灵活的基座视觉大模型。它受到自然语言处理(NLP)中基础模型(如GPT、BERT)的启发,强调零样本迁移和提示式交互能力。在SA-1B数据集上的训练,该数据集包含超过11百万张图像和11亿个高质量分割掩码,覆盖了从日常场景到专业领域的多样化内容。

picture.image

SAM借鉴了NLP领域的Prompt策略,通过给图像分割任务提供Prompt提示来完成任意目标的快速分割。Prompt类型可以是「前景/背景点集、粗略的框或遮罩、任意形式的文本或者任何指示图像中需要进行分割」的信息。如图(a)所示,模型的输入是原始的图像和一些prompt,目标是输出"valid"的分割,所谓valid,就是当prompt的指向是模糊时,模型能够输出至少其中一个mask。

模型结构

picture.image

SAM的模型结构由三个核心组件组成,Image Encoder 、Prompt Encoder和Mask Decoder。分别负责图像特征提取、提示编码和掩码生成。图像经过Image Encoder编码,Prompt提示经过Prompt Encoder编码,两部分Embedding再经过一个轻量化的Mask Decoder得到融合后的特征。其中,Encoder部分使用的是已有模型,Decoder部分使用Transformer。 下表为三个组件的总结:

组件名称功能关键特点
Image Encoder
将输入图像转换为密集特征表示
使用MAE预训练的Vision Transformer(ViT-H/16),输入1024x1024x3,输出64x64x256嵌入。
Prompt Encoder
将用户提示(点、框、文本、掩码)编码为嵌入
支持稀疏提示(点、框、文本)和密集提示(掩码),使用CLIP处理文本,灵活适应多种输入。
Mask Decoder
结合图像和提示嵌入,生成最终分割掩码
轻量级Transformer解码器,通过自注意力与交叉注意力机制预测掩码,实时高效。

Image Encoder

本文的目的是为了寻找一个轻量化的视觉编码器 ,因此下面来详细看下视觉编码器部分。Image Encoder的作用是把图像映射到特征空间,整体过程如下图所示。

picture.image

正如论文中所讲,本质上这个Encoder可以是任何网络结构,在这里使用的是微调的Detectron的ViT,当然它也可以被改成传统的卷积结构,非常合理。

picture.image可以看到,Image Encoder就是一个ViT的结构,由PatchEmbed、Transformer Encoder、Neck Convolution组成。

输入图像经过ViT结构的过程如下:

  1. Patch Embedding

输入图像通过一个卷积base,将图像划分为16x16的patches,步长也为16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。Patch Embedding示意图如下所示。

picture.image

将输入的图像转换为序列化的特征向量

picture.image

Patch Embedding过程在Vision Transformer结构图中对应下图所示。

picture.image

  1. Transformer Encode

feature map通过16个Transformer Block,其中12个Block 使用了基于Window Partition(就是把特征图分成14*14的windows做局部的Attention)的注意力机制,以处理局部信息。另外4个Block是全局注意力模块(多头注意力),它们穿插在Window Partition模块之间,以捕捉图像的全局上下文。

picture.image

picture.image

循环叠加Transformer Encode 2. Neck Convolution

最后,通过两层卷积(Neck)将通道数降低至256,生成最终的Image Embedding。其结构图如下所示。

picture.image

picture.image

SAM构建与轻量化编码器提取

通过下面代码提取一个参数量大小仅为80几M的视觉编码器。

  
  
import torch  
from functools import partial  
from modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer  
  
  
def build\_sam\_vit\_b(checkpoint=None):  
    return \_build\_sam(  
        encoder\_embed\_dim=768,  
        encoder\_depth=12,  
        encoder\_num\_heads=12,  
        encoder\_global\_attn\_indexes=[2, 5, 8, 11],  
        checkpoint=checkpoint,  
    )  
  
  
sam\_model\_registry = {  
    "vit\_b": build\_sam\_vit\_b,  
}  
  
  
def \_build\_sam(  
        encoder\_embed\_dim,  
        encoder\_depth,  
        encoder\_num\_heads,  
        encoder\_global\_attn\_indexes,  
        checkpoint=None,  
):  
    prompt\_embed\_dim = 256  
    image\_size = 1024  
    vit\_patch\_size = 16  
    image\_embedding\_size = image\_size // vit\_patch\_size  
    sam = Sam(  
        image\_encoder=ImageEncoderViT(  
            depth=encoder\_depth,  
            embed\_dim=encoder\_embed\_dim,  
            img\_size=image\_size,  
            mlp\_ratio=4,  
            norm\_layer=partial(torch.nn.LayerNorm, eps=1e-6),  
            num\_heads=encoder\_num\_heads,  
            patch\_size=vit\_patch\_size,  
            qkv\_bias=True,  
            use\_rel\_pos=True,  
            global\_attn\_indexes=encoder\_global\_attn\_indexes,  
            window\_size=14,  
            out\_chans=prompt\_embed\_dim,  
        ),  
        prompt\_encoder=PromptEncoder(  
            embed\_dim=prompt\_embed\_dim,  
            image\_embedding\_size=(image\_embedding\_size, image\_embedding\_size),  
            input\_image\_size=(image\_size, image\_size),  
            mask\_in\_chans=16,  
        ),  
        mask\_decoder=MaskDecoder(  
            num\_multimask\_outputs=3,  
            transformer=TwoWayTransformer(  
                depth=2,  
                embedding\_dim=prompt\_embed\_dim,  
                mlp\_dim=2048,  
                num\_heads=8,  
            ),  
            transformer\_dim=prompt\_embed\_dim,  
            iou\_head\_depth=3,  
            iou\_head\_hidden\_dim=256,  
        ),  
        pixel\_mean=[123.675, 116.28, 103.53],  
        pixel\_std=[58.395, 57.12, 57.375],  
    )  
    sam.eval()  
    if checkpoint is not None:  
        with open(checkpoint, "rb") as f:  
            state\_dict = torch.load(f)  
        sam.load\_state\_dict(state\_dict)  
    return sam  
  
  
if \_\_name\_\_ == '\_\_main\_\_':  
    x = torch.zeros(2, 3, 1024, 1024)  
    net = build\_sam\_vit\_b(checkpoint='sam\_vit\_b\_01ec64.pth')  
    image\_encoder = net.image\_encoder  
  
    print(image\_encoder)  
    print(image\_encoder(x).shape)  # 输出:torch.Size([2, 256, 64, 64])  
      
    total\_params = sum(p.numel() for p in image\_encoder.parameters())  
    print(f"模型的参数量为: {(total\_params/ 1e6):.2f}M")      # 模型的参数量为: 89.67M  
     

参考文献:

Segment Anything,https://arxiv.org/pdf/2304.02643

code:https://github.com/facebookresearch/segment-anything

往期相关:

Reyes:一个从0到1开始训练的多模态大模型(技术报告)

多模态大模型Ovis核心技术点、训练方法、数据细节

Qwen-VL系列多模态大模型技术演进-模型架构、训练方法、数据细节

Phi-4-multimodal:图、文、音频统一的多模态大模型架构、训练方法、数据细节

deepseek多模态大模型Janus、Janus-Pro模型架构及优化方法浅谈

POINTS多模态大模型浅谈

英伟达NVLM多模态大模型细节和数据集

OCR-free感知多模态大模型技术链路及训练数据细节

Kimi-VL开源多模态大模型结构、训练方法、训练数据浅析

Encoder-free无编码器多模态大模型EVEv2模型架构、训练方法浅尝

关于我:余俊晖,主要研究方向为自然语言处理、大语言模型、文档智能。曾获CCF、Kaggle、ICPR、ICDAR、CCL、CAIL等国内外近二十项AI算法竞赛/评测冠亚季军。发表SCI、顶会等文章多篇,专利数项。

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
在火山引擎云搜索服务上构建混合搜索的设计与实现
本次演讲将重点介绍字节跳动在混合搜索领域的探索,并探讨如何在多模态数据场景下进行海量数据搜索。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论