【ViT 微调时关于position embedding如何插值(interpolate)的详解】

这篇具有很好参考价值的文章主要介绍了【ViT 微调时关于position embedding如何插值(interpolate)的详解】。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

【ViT 微调时关于position embedding如何插值(interpolate)的详解】
本文适合对Vision Transformer有一定了解(知道内部结构和一些实现细节,最好是精读过ViT这篇论文)的读者阅读,这篇博客不会详细说明ViT的结构和前向推断过程。

1. 问题描述

符号 含义
b b b batch size
N N N patch size
H H H W W W 低分辨率图像的高和宽
H ′ H' H W ′ W' W 高分辨率图像的高和宽
s o s_o so 低分辨率图像的sequence length的长度( o o o是original的意思)
s n s_n sn 高分辨率图像的sequence length的长度( n n n是new的意思)
h h h hidden dimension,即每个patch经过linear layer后得到的vector的长度,原文是16x16x3=768

最近在读ViT相关的论文(ViT、DeiT、Swin Transformer),感觉看得比较细致,但ViT中有个细节我一直不太理解:就是在用高分辨率(high resolution)图像做微调时,作者在论文里说:保持patch size不变,直接把position embedding向量进行插值处理(interpolate),原文如下:

【ViT 微调时关于position embedding如何插值(interpolate)的详解】

作者的意思是:当使用高分辨率(high resolution)图像对预训练好的ViT进行微调(fine-tuning)时,保持patch size( N ∗ N N*N NN)不变(即每个patch中的像素数量不变),但由于image size( H ′ ∗ W ′ H'*W' HW,且 H ′ = W ′ H'=W' H=W)变大了,则sequence length s n = H ′ / N s_n=H'/N sn=H/N 也相应变大了。而预训练好的position embedding是对原先低分辨率(low resolution)图像的位置编码(即原来的sequence length s o = H / N s_o=H/N so=H/N),自然无法适应现在的新的sequence length s n s_n sn。作者对此提出的解决方案是对原先的postion embedding进行2D的插值处理。

这我就很困惑了:position embedding是个1-D的向量,怎么做2D的插值呢?查了好久也没找到满意的解释,最后还是去看了torchvision中ViT的实现才明白怎么回事儿,其实很简单。

2. positional embedding如何interpolate

我们用图来表示想做的事情:
【ViT 微调时关于position embedding如何插值(interpolate)的详解】
【ViT 微调时关于position embedding如何插值(interpolate)的详解】

如何把 s o s_o so变成 s n s_n sn呢?具体做法如下:

假设position_embedding_img的shape为 ( b , h , s o ) (b, h, s_o) (b,h,so),其中 b b b为batch size,设置 b = 1 b=1 b=1 h h h s o s_o so的含义见上面的表格。

  • 首先将position_embedding_img的shape由 ( b , h , s o ) (b, h, s_o) (b,h,so)reshape成( b b b, h h h, s o \sqrt{s_o} so , s o \sqrt{s_o} so
  • 然后将后两维 ( s o (\sqrt{s_o} (so , s o ) \sqrt{s_o}) so ) 使用torch.nn.functinoal.interpolate,插值成: ( s n (\sqrt{s_n} (sn , s n ) \sqrt{s_n}) sn ),此时position_embedding_img_new的shape为:( b b b, h h h, s n \sqrt{s_n} sn , s n \sqrt{s_n} sn
  • 最后再把position_embedding_img_new reshape成( b b b, h h h, s n s_n sn

经过上述步骤,我们就将position_embedding_img的 ( b , h , s o ) (b, h, s_o) (b,h,so)变成了position_embedding_img_new的 ( b , h , s n ) (b, h, s_n) (b,h,sn)。示意图如下(这里设 b = 1 , h = 1 b=1,h=1 b=1,h=1):

【ViT 微调时关于position embedding如何插值(interpolate)的详解】

3. 输入的sequence length改变了ViT还能正常前向推断?

【ViT 微调时关于position embedding如何插值(interpolate)的详解】

其实到了第二步就已经结束了,但可能有些人(包括我之前)还会有个疑问:之前我们预训练时输入给Transformer Encoder(即上图中红色圈出的部分)的tensor的shape为: ( b , s o , h ) (b, s_o, h) (b,so,h),而如果使用高分辨率的img进行微调,那输入到Transformer Encoder的shape变成了: ( b , s n , h ) (b, s_n, h) (b,sn,h),还可以前向推断吗?Transformer Encoder不需要改内部结构吗?

答案是不需要。原因在于微调时hidden dimension h h h的值没有变,为什么这么说呢?我们考虑下Transformer Encoder的内部结构,主要是多头自注意力(multi-head self-attention)和MLP。multi-head self-attention其实就是把输入切分成n个头,分别进行self-attention,然后再把结果concat起来,所以我们以单头自注意力、batch size=1为例,self-attention的大致流程为:

【ViT 微调时关于position embedding如何插值(interpolate)的详解】
可以看出,Transformer Encoder中训练的参数: W q 、 W k 、 W v W_q、W_k、W_v WqWkWv的形状都为 ( h , h ) (h, h) (h,h),并不会随着sequence length由 s o s_o so变为 s n s_n sn而发生改变。

同理,Transformer Encoder中的MLP的input layer的神经元个数也是 h h h,和 s n s_n sn无关。

即Transformer Encoder中参数只和hidden embedding的长度 h h h有关,和sequence length s o 、 s n s_o、s_n sosn无关。

因此,即使我们输入Transformer Encoder的维度由 ( b , s o , h ) (b, s_o, h) (b,so,h)变为 ( b , s n , h ) (b, s_n, h) (b,sn,h),也不会影响ViT的前向推断过程。


如果想看Torchvision官方中关于interpolate代码的细节实现,我放在下面:

def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolating positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, not copying the state of heads. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state

参考:


1.原论文:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

2.Torchvision中ViT的实现源代码文章来源地址https://www.toymoban.com/news/detail-446499.html

到了这里,关于【ViT 微调时关于position embedding如何插值(interpolate)的详解】的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Transformer的PE(position embedding),即位置编码理解

    最近要搞理论学习了,先前搞了大半年的工程,又要捡起一些理论原理,现在还是从transformer熟悉理解一下,争取吃透。 关于transformer的经典介绍和资料也一大堆,我就不展开来讲了,碰到了一些一时没太想明白的问题,就记一下,也当是重新理解一遍。 transformer的输入要么

    2024年02月16日
    浏览(66)
  • Ai 算法之Transformer 模型的实现: 一 、Input Embedding模块和Positional Embedding模块的实现

    比较常见的文章生成模型有以下几种: RNN:循环神经网络。可以处理长度变化的序列数据,比如自然语言文本。RNN通过隐藏层中的循环结构来传递时间序列中的信息,从而使当前的计算可以参照之前的信息。但这种模型有梯度爆炸和梯度消失的风险,所以只能做简单的生成任

    2024年01月23日
    浏览(39)
  • 深度学习笔记之Transformer(六)Position Embedding铺垫:Skipgram与CBOW模型

    上一节介绍了 Word2vec text{Word2vec} Word2vec 模型架构与对应策略。本节将继续介绍 Skipgram text{Skipgram} Skipgram 与 CBOW text{CBOW} CBOW 模型架构。 关于 Word2vec text{Word2vec} Word2vec 模型,它的 任务目标 是基于 语料库 ( Corpus ) (text{Corpus}) ( Corpus ) ,对该语料库对应 词汇表 ( Vocabulary ) (

    2024年02月15日
    浏览(39)
  • 【论文阅读随笔】RoPE/旋转编码:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING

    绝对位置编码比较简单,加或乘一个有次序的数 实现相对位置编码,也即意味着,要蕴含位置差的信息: 假设m是某个token的位置信息,n是另一个token的位置信息,要有类似 m − n m-n m − n 的信息,比较容易想到复数乘法会产生 m − n m-n m − n ,以及复数乘法和复数内积的性

    2024年03月11日
    浏览(45)
  • PETR: Position Embedding Transformation for Multi-View 3D Object Detection

    PETR: Position Embedding Transformation for Multi-View 3D Object Detection 旷视 DETR3D 中 2D-3D过程 存在的问题: 预测的参考点坐标可能不准确,在采样图片特征时可能拿不到对应的特征。 只有参考点 投影位置的图像特征被使用,无法学到全局的特征。 采样图像特征的过程过于复杂,难于应用

    2024年02月16日
    浏览(52)
  • 深度学习笔记之Transformer(五) Position Embedding铺垫:Word2vec

    在Transformer(三)自注意力机制一节中介绍了 位置编码 ( Position Embedding ) (text{Position Embedding}) ( Position Embedding ) ,本系列针对位置编码 再回首 ,从公式角度重新认识位置编码。本节作为铺垫,介绍一下 词向量 模型—— Word2vec text{Word2vec} Word2vec 。 在循环神经网络简单示例中

    2024年02月13日
    浏览(35)
  • CVPR 2022 Image Dehazing Transformer with Transmission-Aware 3D Position Embedding 个人学习笔记

    源码下载: CVPR2022ImageDehazingTransformerwithTransmission-Aware3D代码-深度学习文档类资源-CSDN下载 Abstract 尽管卷积神经网络(CNNs)的单图像去模糊已经取得了良好的进展,但卷积固有的 等方差 和 局部性 仍然是去雾性能的 瓶颈 。虽然 Transformer 占据了各种计算机视觉任务,但直接利

    2023年04月08日
    浏览(51)
  • LLaMa 原理+源码——拆解 (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)

    Vanilla Transformer 与 LLaMa 的对比:LLaMa与普通的Transformer架构不同的地方,包括采用了 前置 了层归一化(Pre-normalization)并使用 RMSNorm 归一化函数(Normalizing Function)、使用了 旋转位置嵌入(RoPE) 、激活函数由ReLU更换为 SwiGLU ,并且将self-attention改进为 使用KV-Cache的Grouped Quer

    2024年01月25日
    浏览(45)
  • 用于多视图 3D 对象检测的位置嵌入变换(PETR: Position Embedding Transformation for Multi-View 3D Object Detection)

    本文PETR (PETR: Position Embedding Transformation for Multi-View 3D Object Detection)是对DETR3D (3D Object Detection from Multi-view Images via 3D-to-2D Queries)的改进,将2D转换至3D,还存在三个问题: (1) 空间与多视图之间的信息交互依赖于3D参考点估计的准确性,使得采样的特征超出了对象区域,无法投影

    2024年02月07日
    浏览(50)
  • 详解AI大模型行业黑话,迅速搞懂提示工程(prompt)、向量工程(embedding)、微调工程(fine-tune)

    大家都在讨论大模型,似乎什么都可以与大模型结合,可当初学者也想上手时,却面临一堆令人头大的词汇,什么Prompt、、Embedding、Fine-tuning,看到瞬间头都大了。一堆英文就算了,还不容易查到正确解释,怎么办呢?别担心,本文就用一种有趣的方式让大家认识它们。 首先

    2024年02月02日
    浏览(40)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包