模型架构
模型主要由三个部份组成,一个笨重的 image encoder ,用于将图像 embedding;一个轻量级的 prompt encoder 和一个轻量级的 mask decoder,
Image Encoder
Vision Transformer(ViT)
这个部份采用了 NLP 领域常用的 Transformer 思想。
输入图片首先经过 PatchEmbed 模块。在这里,将图片切割成 16 * 16 个 patch, 每个 patch 的维度是 768.
然后,如果启用 absolute positional embeddings, 也就是位置信息,那么直接将它 sum 到 patch embedding 上。
接下来通过一组 depth 个的 transformer blocks. 这里使用了 Multihead Attention,然后再经过一个激活函数是 GeLU 的 MLP 。 这里它使用了 Window Attention ,也就是每次自注意力只关注一个局部。论文称它使用了 14 * 14 的 Window。同时还加上了 relative positional embeddings。
Reduce Channel Dimension
在这里它参照了Exploring Plain Vision Transformer Backbones for Object Detection,把输出先后通过 1 * 1, 256 channel 和 3 * 3, 256 channel 的卷积压缩。每次卷积后 LayerNorm 一次。
Prompt Encoder
将输入分成以下四种:
- 一个点:将 positional encoding 和一个表示它是在前景还是背景的 embedding 相加。
- 一个框:分别用两个 embedding 表示左上角和右下角。
- mask:可能是用于训练,直接塞入
dense_prompt_embedding
,sparse 项置零。 - 无输入:单独的 embedding,表示 no prompt.
分别传出两个部份 sparse_prompt_embedding
和 dense_prompt_embedding
Lightweight mask Decoder
传入的信息包括两个部份:
- token:这里包括两个 Embedding :
iou_token
和mask_token
,还有 prompt encoder 传出的sparse_prompt_embedding
,将他们 concatenate 在一起。 - src:这个部分将
image_embedding
,dense_prompt_embedding
sum 在一起。
接下来将 src 和 pos_src (分别代表 prompt 的信息和 image 的信息)放入 TwoWayTransformer
。分别用两个 cross attention 来处理 token 和 image 之间的相互关系。
一个 upscaling,然后生成 4 个 mask token ;同时预测一个 IoU (用来给结果质量排序)
Ambiguity-aware
这个部分主要问题是可能会把多个有效输出的 mask 给平均掉。「observe that」的处理方案是同时预测并输出三个 mask(代表整体,部份和子部份),同时预测一个 IoU 给结果排序,并且只考虑质量最好的 loss 来反向传播。
同时如果给出多个 prompt 的话只返回一个(多个 prompt 足以确认一个有效输出),为了不和前面混淆总共需要生成 4 个 mask。
数据
原文用了很大篇幅来解释数据的采样在种族,国家和生存环境上的多样性。
数据集生成
原始数据是从某个摄影公司处获取的,同时附有
大致分成三个步骤
- Assisted-manual stage: 这个阶段主要由打标人在一个 web 端上用一些工具来打标
- Semi-auto stage: 这个阶段标记出 confident masks,然后要求打标人给其他的对象打标
- Fully-auto stage: 这个部份用一个 32 * 32 的 point 型 prompt 来生成一组 masks,然后根据预测的 IoU 来筛选
一些 trick
- 只保留 confident mask ,也就是 IoU > 88.0
- 去除覆盖超过 95% 的 mask,提升 mask 质量,同时处理掉过小(100像素)的 spurious holes 和 components
训练
Losses
用 focal loss 和 dice loss 的 20:1 的线性组合来监督 mask 用 mean-square-error loss 监督 IoU 预测,factor 是 1.0
Training Algorithm
等概率选择 foreground point 或者 bounding box,然后加一些扰动。
之后从误差区域里加入新的采样点作为 prompt
把前一代的 mask 作为 prompt 塞给后一代(这可能是前面 prompt encoder 中 mask 的作用)
由于 prompt encoder 和 mask decoder 的开销很小(不足 1% 相对 image encoder),所以可以支持多步迭代(这里选择了 1 次初始,8 次更新采样点,然后 2 次没有额外信息的迭代)
Zero-shot Text-to-Mask
The key observation here is that because CLIP’s image embeddings are trained to align with its text embeddings, we can train with image embeddings, but use textembeddings for inference. That is, at inference time we run text through CLIP’s text encoder and then give the resulting text embedding as a prompt to SAM.
用 CLIP 的 image embedding 做训练,用 text embeddings 作推理。很深刻。