【计算机视觉】ViT:代码逐行解读

这篇具有很好参考价值的文章主要介绍了【计算机视觉】ViT:代码逐行解读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、代码

import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y

二、代码解读

2.1 大体理解

这段代码是一个实现了 Vision Transformer(ViT)模型的 PyTorch 实现。

ViT 是一个基于 Transformer 架构的图像分类模型,其主要思想是将图像分成一个个固定大小的 patch ,并将这些 patch 看做是一个个 token 输入到 Transformer 中进行特征提取和分类。

以下是对代码的解读:

  1. ViT类继承自nn.Module类,其构造函数有一系列参数,包括输入图像的尺寸、patch的大小、输出类别数、注意力机制中的头数等等。
  2. project_patches函数通过一个全连接层将每个patch映射到一个d维的特征空间中。
  3. 如果classification = True,则将一个额外的CLS token添加到输入的token序列的开头,即对于每张图像添加一个形状为[1, 1, d]的CLS token。同时,在ViT中采用的是绝对位置编码,因此还添加了一个1D的位置编码向量,其形状为[num_patches + 1, d],其中num_patches表示图像被划分成的patch数目。如果classification = False,则不添加CLS token。
  4. forward函数首先将输入的图像进行patch划分,并通过project_patches函数将每个patch映射到d维特征空间中。接着,将位置编码向量加到映射后的patch特征向量上,并进行dropout处理。如果classification=True,则在特征序列开头添加CLS token。接着将这些特征输入到Transformer中,进行特征提取。最后输出分类结果,如果classification=True,则只返回CLS token的分类结果。

2.2 详细理解

from self_attention_cv import TransformerEncoder

self_attention_cv是一个基于PyTorch实现的库,提供了在计算机视觉任务中使用自注意力机制的模块和网络,例如Transformer EncoderAttention Modules

它主要针对图像分类、对象检测、语义分割等任务,支持多种自注意力模块的实现,包括Simplified Self-AttentionFull Self-AttentionLocal Self-Attention等。此外,该库还提供了一些常见的计算机视觉任务模型的实现,例如Vision Transformer(ViT)Swin Transformer等。

TransformerEncoder是一个自注意力机制的编码器,用于将输入序列转换为编码后的序列。自注意力机制允许模型能够根据输入序列中的其他位置来加权计算每个位置的表示。这种机制在自然语言处理中的应用非常广泛,比如BERT、GPT等模型都采用了自注意力机制。

TransformerEncoder是基于PyTorch实现的,可以在计算机视觉任务中使用,例如图像分类、对象检测、语义分割等。它支持多头注意力、残差连接和LayerNorm等特性。在这个代码中,ViT模型中的Transformer部分采用了TransformerEncoder作为默认的实现。

def __init__(self, *,
                img_dim,
                in_channels=3,
                patch_dim=16,
                num_classes=10,
                dim=512,
                blocks=6,
                heads=4,
                dim_linear_block=1024,
                dim_head=None,
                dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

这段代码定义了一个名为 ViT 的 PyTorch 模型类,它是一个使用自注意力机制(Self-Attention)实现的视觉 Transformer 模型。其中主要参数包括:

  • img_dim:输入图片的空间大小
  • in_channels:输入图片的通道数
  • patch_dim:将图片划分成固定大小的 patch 的大小
  • num_classes:分类任务的类别数
  • dim:线性层的维度,用于将每个 patch 投影到 MHSA 空间
  • blocks:Transformer 模型中的块数
  • heads:注意力头的数量
  • dim_linear_block:线性块内部的维度
  • dim_head:每个头的维度,如果没有指定则默认为 dim/heads
  • dropout:用于位置编码和 Transformer 的 dropout 概率
  • transformer:可选的 TransformerEncoder 类实例
  • classification:是否包含额外的 CLS 标记以用于分类任务
def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
    super().__init__()

这里定义了 ViT 类的构造函数,其包含多个参数,包括输入图像大小 img_dim,输入通道数 in_channels,分块大小 patch_dim,分类数目 num_classes,嵌入维度 dim,Transformer编码器的块数 blocks,头数 heads,线性块的维度 dim_linear_block,注意力头维度 dim_head,Dropout概率 dropout,可选的Transformer编码器 transformer,以及是否进行分类的标志 classification

    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification

这里检查 img_dim 是否能够被 patch_dim 整除,如果不能整除,则会引发断言错误。同时,将 patch_dim 存储到 self.p 中,并将是否进行分类的标志存储到 self.classification 中。

    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

这里计算了输入图像中可分块的数量 tokens,并将每个块的维度 self.token_dim 设置为 in_channels * (patch_dim ** 2)

将嵌入维度 dim 存储到 self.dim 中,并根据 dim_head 是否为 None,设置注意力头维度 self.dim_headself.project_patches 是一个线性层,用于将每个块投影到嵌入空间中。

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

这里定义了嵌入层的Dropout层,并根据是否进行分类的标志,设置类别标记 self.cls_token、位置嵌入 self.pos_emb1DMLPself.mlp_head。如果不进行分类,则不需要 self.cls_tokenself.mlp_head

if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

self.emb_dropout = nn.Dropout(dropout): 定义了一个dropout层,用于在embedding后对其进行dropout操作。

if self.classification:: 如果是分类任务,就执行下面的操作,否则跳过。

self.cls_token = nn.Parameter(torch.randn(1, 1, dim)): 定义了一个可训练参数cls_token,表示分类token,它是一个1x1xdim的tensor,其中dim表示embedding维度。

self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)): 定义了一个可训练参数pos_emb1D,表示位置嵌入,它是一个(tokens+1)xdim的tensor,其中tokens表示图像被分成的patch数,dim表示embedding维度。

self.mlp_head = nn.Linear(dim, num_classes): 定义了一个全连接层,将embedding映射到输出类别的数量。

最后,根据传入的参数来选择使用默认的TransformerEncoder,还是使用传入的transformer。如果没有传入,则使用默认的TransformerEncoder,否则使用传入的transformer。

def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])

该方法的作用是将 Transformer 中的分类 token 扩展到整个批次的样本数。它接受一个 batch 参数作为批次大小,返回一个形状为 [batch, 1, dim] 的张量,其中 dim 是 Transformer 模型的维度大小。在这个方法中,使用了 PyTorch 的 expand() 方法来实现扩展操作。

def forward(self, img, mask=None):
    batch_size = img.shape[0]
    img_patches = rearrange(
        img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)

    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

    # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
    y = self.transformer(patch_embeddings, mask)

    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y

forward 函数中,接收输入的 imgmask

通过 img_dimpatch_dim 计算出 tokens 数量,其中 tokens 为图像分割成的块的数量。

将输入的 img 分成 patch,并通过 rearrange 函数重组成形状为 [batch_size, tokens, patch_dim * patch_dim * in_channels] 的张量。

通过 Linear 层将每个 patch 映射到 dim 维度,并加上位置编码 pos_emb1D

如果是用于分类任务,则在序列的开头插入一个 CLS token,然后与处理后的 patch 张量按列拼接。

对 patch_embeddings 应用 dropout,并输入到 TransformerEncoder 中,返回输出张量 y,形状为 [batch_size, tokens, dim]

如果是用于分类任务,则从 y 中取出 CLS token,输入到一个 Linear 层中进行分类,输出分类结果。

如果不是分类任务,则直接返回 y。文章来源地址https://www.toymoban.com/news/detail-436562.html

到了这里,关于【计算机视觉】ViT:代码逐行解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【计算机视觉 | Kaggle】飞机凝结轨迹识别 Baseline 分享和解读(含源代码)

    比赛名称:Google Research - Identify Contrails to Reduce Global Warming 训练 ML 模型以识别卫星图像中的尾迹 比赛类型:计算机视觉、语义分割 Contrails 是“凝结轨迹”的缩写,是在飞机发动机排气中形成的线状冰晶云,由飞机飞过大气中的超潮湿区域时产生。持续的尾迹对全球变暖的贡

    2024年02月14日
    浏览(40)
  • 【计算机视觉】Visual Transformer (ViT)模型结构以及原理解析

    Visual Transformer (ViT) 出自于论文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在视觉领域的开篇之作。 本文将尽可能简洁地介绍一下ViT模型的整体架构以及基本原理。 ViT模型是基于Transformer Encoder模型的,在这里假设读者已经了解Transfo

    2024年02月02日
    浏览(47)
  • 【计算机视觉】Gaussian Splatting源码解读补充(一)

    本文旨在补充@gwpscut创作的博文学习笔记之——3D Gaussian Splatting源码解读。 Gaussian Splatting Github地址:https://github.com/graphdeco-inria/gaussian-splatting 论文地址:https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf 这部分可以参考PlenOctrees论文的附录B。 有时候从不同的

    2024年04月09日
    浏览(176)
  • 【计算机视觉】Gaussian Splatting源码解读补充(二)

    第一部分 本文是对学习笔记之——3D Gaussian Splatting源码解读的补充,并订正了一些错误。 其中出现的辅助函数: 这部分的参考资料: [1] CUDA Tutorial [2] An Even Easier Introduction to CUDA [3] Introduction to CUDA Programming CUDA是一个为支持CUDA的GPU提供的平台和编程模型。该平台使GPU能够进

    2024年04月10日
    浏览(48)
  • 【计算机视觉 | 目标检测】Grounding DINO:开集目标检测论文解读

    介绍一篇较新的目标检测工作: 论文地址为: github 地址为: 作者展示一种开集目标检测方案: Grounding DINO ,将将基于 Transformer 的检测器 DINO 与真值预训练相结合。 开集检测关键是引入 language 至闭集检测器,用于开集概念泛化。作者将闭集检测器分为三个阶段,提出一种

    2024年02月10日
    浏览(61)
  • 【计算机视觉 | 目标检测】Open-Vocabulary DETR with Conditional Matching论文解读

    论文题目:具有条件匹配的开放词汇表DETR 开放词汇对象检测是指在自然语言的引导下对新对象进行检测的问题,越来越受到社会的关注。理想情况下,我们希望扩展一个开放词汇表检测器,这样它就可以基于自然语言或范例图像形式的用户输入生成边界框预测。这为人机交

    2024年01月21日
    浏览(42)
  • 13 计算机视觉-代码详解

    为了防止在训练集上过拟合,有两种办法,第一种是扩大训练集数量,但是需要大量的成本;第二种就是应用迁移学习,将源数据学习到的知识迁移到目标数据集,即在把在源数据训练好的参数和模型(除去输出层)直接复制到目标数据集训练。 13.2.1 获取数据集  13.2.2 初始

    2024年02月12日
    浏览(41)
  • 计算机视觉之姿态识别(原理+代码实操)

    •人体分割使用的方法可以大体分为人体骨骼关键点检测、语义分割等方式实现。这里主要分析与姿态相关的人体骨骼关键点检测。人体骨骼关键点检测输出是人体的骨架信息,一般主要作为人体姿态识别的基础部分,主要用于分割、对齐等。一般实现流程为: •主要检测人

    2023年04月16日
    浏览(38)
  • 【计算机视觉】DINOv2(视觉大模型)代码使用和测试(完整的源代码)

    输出为: 命令是一个Git命令,用于克隆(Clone)名为\\\"dinov2\\\"的存储库。它使用了一个名为\\\"ghproxy.com\\\"的代理,用于加速GitHub的克隆操作。 我们需要切换为output的路径: 以下是代码的逐行中文解读: 这段代码的功能是对给定的图像进行一系列处理和特征提取,并使用PCA对特征进

    2024年02月16日
    浏览(58)
  • 【计算机视觉】YOLOv8如何使用?(含源代码)

    comments description keywords true Boost your Python projects with object detection, segmentation and classification using YOLOv8. Explore how to load, train, validate, predict, export, track and benchmark models with ease. YOLOv8, Ultralytics, Python, object detection, segmentation, classification, model training, validation, prediction, model export, bench

    2024年02月04日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包