【自用】SAM模型论文笔记与复现代码(segment-anything-model)

这篇具有很好参考价值的文章主要介绍了【自用】SAM模型论文笔记与复现代码(segment-anything-model)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

总模型结构

一个prompt encoder,对提示进行编码,image encoder对图像编码,生成embedding, 最后融合2个encoder,再接一个轻量的mask decoder,输出最后的mask。

模型结构示意图:【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

流程图:

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

模型的结构如上图所示. prompt会经过prompt encoder, 图像会经过image encoder。然后将两部分embedding经过一个轻量化的mask decoder得到融合后的特征。encoder部分使用的都是已有模型,decoder使用transformer。

image encoder

利用MAE(Masked AutoEncoder)预训练的ViT模型,对每张图片只处理一次,且在prompt encoder之前进行。输入(c,h,w)的图像,对图像进行缩放,按照长边缩放成1024,短边不够就填充,得到(c,1024,1024)的图像,经过image encoder,得到对图像16倍下采样的feature,大小为(256,64,64)。

prompt encoder

prompt encoder结构图:
【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

分为两类:稀疏与密集

稀疏:
  • point:使用position encodings
  • box:使用position encodings
  • text:使用CLIP作为encoder
密集:
  • mask:使用卷积作为encoder

mask decoder

  • prompt self-attention
  • cross-attention(从prompt到image和从image到prompt)

valid mask(模型输出)

  • 解决混淆的输入: 对于一个prompt,模型会输出3个mask,实际上也可以输出更多的分割结果,3个可以看作一个物体的整体、部分、子部分,基本能满足大多数情况。使用IOU的方式,排序mask。在反向传播时,参与计算的只有loss最小的mask相关的参数.
  • 高效: 这里主要指的是prompt encodermask decoder。在web浏览器上,CPU计算只用约50ms

SAM模型复现

环境:

python 3.8.10
pytorch 1.11.0
cuda 11.3

环境安装

git clone https://github.com/facebookresearch/segment-anything
pip install opencv-python matplotlib
pip install -e .
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth #下载SAM_VIT-H模型

定义用于可视化的工具函数

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    
    
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

可视化原图片

image = cv2.imread('R.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(14,14))
plt.imshow(image)
plt.axis('on')
plt.show()

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

原图片:
【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

加载SAM模型

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

predictor.set_image(image)

点作为prompt

单点
input_point = np.array([[430, 605]])
input_label = np.array([1])
plt.figure(figsize=(14,14))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

使用SAM模型进行分割,并输出模型分割出的3个mask

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True, #`multimask_output=True`表示是否输出三个mask结果
)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(14,14))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  
  

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

多点(使用先前单点输出的mask作为mask prompt)
仅前景点
input_point = np.array([[430, 605],[520, 650]])
input_label = np.array([1, 1]) #1代表前景点(绿色),0代表后景点(红色)

mask_input = logits[np.argmax(scores), :, :]  #选择先前分数最高的mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(14,14))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show() 

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

前景点+后景点
input_point = np.array([[430, 605],[520, 650], [520,500]])
input_label = np.array([1, 1, 0])  #1代表前景点(绿色),0代表后景点(红色)

mask_input = logits[np.argmax(scores), :, :]  
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(14,14))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show() 

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

矩形框作为prompt

单个矩形框
input_box = np.array([730, 105, 1030, 315])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(17, 17))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('on')
plt.show()

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

多个矩形框(需要使用transform.apply_boxes_torch方法进行转换)
input_boxes = torch.tensor([
    [730, 105, 1030, 315],
    [970, 155, 1025, 250]
], device=predictor.device)

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

plt.figure(figsize=(17, 17))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('on')
plt.show()

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

自动分割

from segment_anything import SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('on')
plt.show() 

输出mask数量
178

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理

开始调参

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100, 
)
masks_2 = mask_generator_2.generate(image)
print(len(masks_2))
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks_2)
plt.axis('on')
plt.show() 

输出mask数量
335

【自用】SAM模型论文笔记与复现代码(segment-anything-model),深度学习,人工智能,python,ai,论文笔记,图像处理文章来源地址https://www.toymoban.com/news/detail-821397.html

到了这里,关于【自用】SAM模型论文笔记与复现代码(segment-anything-model)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • FedAvg与FedProx论文笔记以及代码复现(1)

    目录 一、FedAvg原始论文笔记 1、联邦优化问题:  2、联邦平均算法: FedSGD算法: FedAvg算法: 实验结果: 3、代码解释  3.1、main_fed.py主函数 3.2、Fed.py: 3.3、Nets.py:模型定义 3.4、option.py超参数设置 3.5、sampling.py: 3.6、update.py :局部更新 3.7、main_nn.py对照组 普通的nn 联邦平均算法

    2023年04月15日
    浏览(35)
  • GraphDTA论文阅读小白笔记(附代码注释和复现流程)

    具体代码复现以及代码注释可以查看https://github.com/zhaolongNCU/Demo-GraphDTA- 1.发展前景: 新药设计需要花费2.6billion,17years 药物再利用已被用于现实的疾病中 为了有效地重新调整药物的用途,了解哪些蛋白质是哪些药物的靶标是有用的 高通量筛选方法高消耗,并且彻底地完成搜

    2024年02月15日
    浏览(44)
  • LVI-SAM代码复现、调试与运行

            LVI-SAM是Tixiao Shan的最新力作,Tixiao Shan是Lego-loam和Lio-sam的作者,LVI-SAM是Tixiao Shan最新开源的基于视觉-激光-惯导里程计SLAM框架,结合了Lio-sam和Vins-Mono的视觉-激光-惯导融合的SLAM框架。 LVI-SAM系统框架         文章主要工作: 实现了一个激光-视觉-惯性的紧耦合

    2024年02月05日
    浏览(39)
  • 经典神经网络论文超详细解读(六)——DenseNet学习笔记(翻译+精读+代码复现)

    上一篇我们介绍了ResNet:经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现) ResNet通过短路连接,可以训练出更深的CNN模型,从而实现更高的准确度。今天我们要介绍的是 DenseNet(《Densely connected convolutional networks》) 模型,它的基本

    2024年02月03日
    浏览(62)
  • 经典神经网络论文超详细解读(八)——ResNeXt学习笔记(翻译+精读+代码复现)

    今天我们一起来学习何恺明大神的又一经典之作:  ResNeXt(《Aggregated Residual Transformations for Deep Neural Networks》) 。这个网络可以被解释为 VGG、ResNet 和 Inception 的结合体,它通过重复多个block(如在 VGG 中)块组成,每个block块聚合了多种转换(如 Inception),同时考虑到跨层

    2024年02月03日
    浏览(55)
  • 经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现)

    《Deep Residual Learning for Image Recognition》这篇论文是何恺明等大佬写的,在深度学习领域相当经典,在2016CVPR获得best paper。今天就让我们一起来学习一下吧! 论文原文:https://arxiv.org/abs/1512.03385 前情回顾: 经典神经网络论文超详细解读(一)——AlexNet学习笔记(翻译+精读)

    2024年02月08日
    浏览(47)
  • Segment Anything论文翻译,SAM模型,SAM论文,SAM论文翻译;一个用于图像分割的新任务、模型和数据集;SA-1B数据集

    论文链接: https://arxiv.org/pdf/2304.02643.pdf https://ai.facebook.com/research/publications/segment-anything/ 代码连接:https://github.com/facebookresearch/segment-anything 论文翻译: http://t.csdn.cn/nnqs8 https://blog.csdn.net/leiduifan6944/article/details/130080159 本文提出Segment Anything (SA)项目:一个用于图像分割的新任务

    2023年04月19日
    浏览(51)
  • 语义分割大模型SAM论文阅读(二)

    Segment Anything SAM 我们介绍了分割一切(SA)项目:一个新的图像分割任务,模型和数据集。在数据收集循环中使用我们的高效模型,我们建立了迄今为止(到目前为止)最大的分割数据集,在1100万张许可和尊重隐私的图像上拥有超过10亿个掩模。 该模型被设计和训练为提示 ,因此它

    2024年02月13日
    浏览(45)
  • 【论文阅读笔记】Mamba模型代码理解

    官方实现:state-spaces/mamba (github.com) 最简化实现:johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com) 直接实现:alxndrTL/mamba.py: A simple and efficient Mamba implementation in PyTorch and MLX. (github.com) 官方代码做了大量优化,目录层级较多,对于理解模型含

    2024年04月13日
    浏览(70)
  • 【论文阅读】Segment Anything(SAM)——可分割一切的CV大模型

    【前言】随着ChatGPT席卷自然语言处理,Facebook凭借着Segment Anything在CV圈也算扳回一城。迄今为止,github的star已经超过3万,火的可谓一塌糊涂。作为AI菜鸟,可不得自己爬到巨人肩膀上瞅一瞅~ 论文地址:https://arxiv.org/abs/2304.02643 代码地址:GitHub - facebookresearch/segment-anything: T

    2024年02月15日
    浏览(44)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包