Vision Transformer架构Pytorch逐行实现

这篇具有很好参考价值的文章主要介绍了Vision Transformer架构Pytorch逐行实现。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言

  • 代码来自哔哩哔哩博主deep_thoughts,视频地址,该博主对深度学习框架方面讲的非常详细,推荐大家也去看看原视频,不管是否已经非常熟练,我相信都能有很大收获。
  • 论文An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,下载地址。开源项目地址
  • 本文不对开源项目中代码进行解析,仅使用pytorch实现ViT框架,让大家对框架有更清楚的认知。

模型框架展示

Vision Transformer架构Pytorch逐行实现

  • Encoder部分和Transformer中的实现方法一致,可以直接调用pytorch中的API实现(博主在前面几个视频中使用pytorch逐行写了decoderencoder,再次推荐大家去看他的视频),下面主要针对左边的部分进行实现。
  • 架构思维导图,如下图
    -Vision Transformer架构Pytorch逐行实现
  • 导入必要包
import torch
import torch.nn as nn
import torch.nn.functional as F
  • 定义初始变量
# batch_size, 输入通道数,图像高,图像宽
bs, ic, image_h, image_w = 1, 3, 8, 8
# 分块边长
patch_size = 4
# 输出通道数
model_dim = 8
# 最大子图片块数
max_num_token = 16
# 分类数
num_classes = 10
# 生成真实标签
label = torch.randint(10,(bs,))
# 卷积核面积 * 输入通道数
patch_depth = patch_size * patch_size * ic
# image张量
image = torch.randn(bs, ic, image_h, image_w)
# model_dim:输出通道数,patcg_depth:卷积核面积 * 输入通道数
weight = torch.randn(patch_depth, model_dim)

perspective

  • 这一部分有两种实现方式,第1种是DNN方式,利用pytorch中的unfold函数滑动提取图像块。第2种是使用2维卷积的方法,最后将特征铺平。

DNN perspective

  • 首先使用unfold函数,滑动提取不重叠的块,所以kernel_size和stride相同。
  • 再与weight进行矩阵相乘,维度变化以及每个维度意义都在注释中。
def image2emb_naive(image, patch_size, weight):
    # patch:[batch_size, patch_size * patch_size * ic, (image_h * image_w) / (patch_size * patch_size)]
    patch = F.unfold(image, kernel_size=patch_size,stride=patch_size)
    # 转置操作[batch_size, (image_h * image_w) / (patch_size * patch_size), patch_size * patch_size * ic]]
    patch = patch.transpose(-1, -2)
    # 矩阵乘法weight:[patch_size * patch_size * ic, model_dim]
    patch_embedding = patch @ weight
    return patch_embedding
  • 调用函数,得到patch_embedding,检查维度
# 得到patch_embedding:[batch_size, (image_h * image_w) / (patch_size * patch_size), model_dim]
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(patch_embedding_naive.shape)

输出:

torch.Size([1, 4, 8])

CNN perspective

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride = stride)
    bs, oc, oh, ow = conv_output.shape
    # patch_embedding:[batch_size, outchannel, o_h * o_w]
    patch_embedding = conv_output.reshape((bs, oc, oh*ow))
    print(patch_embedding.shape)
    # patch_embedding:[batch_size, o_h * o_w, outchannel]
    patch_embedding = patch_embedding.transpose(-1,-2)
    print(patch_embedding.shape)
    return patch_embedding

weight = weight.transpose(0,1)
print(weight.shape)
# kernel:[outchannel, inchannel, patch_size, patch_size]
kernel = weight.reshape((-1,ic, patch_size, patch_size))
print(kernel.shape)
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
print(patch_embedding_conv.shape)

输出:

torch.Size([8, 48])
torch.Size([8, 3, 4, 4])
torch.Size([1, 8, 4])
torch.Size([1, 4, 8])
torch.Size([1, 4, 8])

class token embedding

  • 随机生成cls_token_emnedding,并将其设为可训练参数。沿着图片块数维度进行拼接,检查cls_token_emneddingtoken_embedding维度。
# CLS token embedding
# cls_token_emnedding:[batch_size,1,mode_dim]
cls_token_emnedding = torch.randn(bs, 1, model_dim, requires_grad=True)
# 沿着图片块数维度进行拼接
token_embedding = torch.cat([cls_token_emnedding, patch_embedding_naive], dim=1)
print(cls_token_emnedding.shape)
print(token_embedding.shape)

输出:文章来源地址https://www.toymoban.com/news/detail-461356.html

torch.Size([1, 1, 8])
torch.Size([1, 5, 8])

position embedding

  • 创建pos embedding:[max_num_token,model_dim],然后使用tile函数进行增自我拼接,重复batch_size次。
# add position embedding
# 创建pos embedding:[max_num_token,model_dim]
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad = True)
# 取图片块数维度
seq_len = token_embedding.shape[1]
# tile增自我拼接,dims参数指定每个维度中的重复次数,dims = [batch_size,1,1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0],1,1])
print(position_embedding.shape)

Transformer Encoder部分

  • 实例化TransformerEncoderLayer,再实例化TransformerEncoder,得到Encoder输出。
# pass embedding to Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim,nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
enconder_output = transformer_encoder(token_embedding)
print(enconder_output.shape)

classification head

  • 取出pos embedding维,经过线性层,对输出计算交叉熵损失
# 取出第1个图片块数维度,就是pos embedding维
cls_token_output = enconder_output[:,0,:]
# 实例化线性层model_dim --> num_classes
linear_layer = nn.Linear(model_dim, num_classes)
# 得到线性层输出
logits = linear_layer(cls_token_output)
# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
# 计算交叉熵损失
loss = loss_fn(logits,label)
print(loss)

到了这里,关于Vision Transformer架构Pytorch逐行实现的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 一个基于PVT(Pyramid Vision Transformer)的视频插帧程序(pytorch)

    项目地址(欢迎大家来star一下!):GitHub - liaoyanqing666/PVT_v2_video_frame_interpolation: 使用PVT_v2作为编码器的视频插帧程序,A program using PVT_v2 as the encoder of video frame interpolation, VFI, pytorch         众所周知,视频是由一系列连续帧组成的,而将两个连续帧中间插入中间帧就完成了

    2024年04月09日
    浏览(68)
  • 36k字从Attention解读Transformer及其在Vision中的应用(pytorch版)

    深度学习中的卷积操作:https://blog.csdn.net/zyw2002/article/details/128306697 1.1.1 Encoder-Decoder Encoder-Decoder框架顾名思义也就是 编码-解码框架 ,在NLP中Encoder-Decoder框架主要被用来处理 序列-序列问题 。也就是输入一个序列,生成一个序列的问题。这两个序列可以分别是任意长度。 具体

    2024年02月11日
    浏览(38)
  • 类ChatGPT逐行代码解读(1/2):从零实现Transformer、ChatGLM-6B

    最近一直在做类ChatGPT项目的部署 微调,关注比较多的是两个:一个LLaMA,一个ChatGLM,会发现有不少模型是基于这两个模型去做微调的,说到微调,那具体怎么微调呢,因此又详细了解了一下微调代码,发现微调LLM时一般都会用到Hugging face实现的Transformers库的Trainer类 从而发现

    2024年02月03日
    浏览(52)
  • 类ChatGPT逐行代码解读(1/2):从零起步实现Transformer、ChatGLM-6B

    最近一直在做类ChatGPT项目的部署 微调,关注比较多的是两个:一个LLaMA,一个ChatGLM,会发现有不少模型是基于这两个模型去做微调的,说到微调,那具体怎么微调呢,因此又详细了解了一下微调代码,发现微调LLM时一般都会用到Hugging face实现的Transformers库的Trainer类 从而发现

    2023年04月26日
    浏览(43)
  • SRCNN超分辨率Pytorch实现,代码逐行讲解,附源码

    目录 0.图像超分辨率 1.SRCNN介绍 训练过程 损失函数  个人对SRCNN训练过程的理解 2.实验常见问题和部分解读 1. torch.utils.data.dataloader中DataLoader函数的用法 2.SRCNN图像颜色空间转换原因以及方法? 3. model.parameters()与model.state_dict()的区别 4. .item()函数的用法? 5.最后的测试过程

    2023年04月15日
    浏览(36)
  • 人工智能(Pytorch)搭建transformer模型,真正跑通transformer模型,深刻了解transformer的架构

    大家好,我是微学AI,今天给大家讲述一下人工智能(Pytorch)搭建transformer模型,手动搭建transformer模型,我们知道transformer模型是相对复杂的模型,它是一种利用自注意力机制进行序列建模的深度学习模型。相较于 RNN 和 CNN,transformer 模型更高效、更容易并行化,广泛应用于神

    2023年04月10日
    浏览(62)
  • 《Vision Transformer (ViT)》论文精度,并解析ViT模型结构以及代码实现

    《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》 论文共有22页,表格和图像很多,网络模型结构解释的很清楚,并且用四个公式展示了模型的计算过程;本文章对其进行精度,并对源码进行剖析,希望读者可以耐心读下去。 论文地址:https://arxiv.org/abs/2010.11929 源

    2024年02月05日
    浏览(39)
  • 挑战Transformer的新架构Mamba解析以及Pytorch复现

    今天我们来详细研究这篇论文“Mamba:具有选择性状态空间的线性时间序列建模” Mamba一直在人工智能界掀起波澜,被吹捧为Transformer的潜在竞争对手。到底是什么让Mamba在拥挤的序列建中脱颖而出? 在介绍之前先简要回顾一下现有的模型 Transformer:以其注意力机制而闻名,其中序

    2024年02月02日
    浏览(45)
  • 深度学习实战24-人工智能(Pytorch)搭建transformer模型,真正跑通transformer模型,深刻了解transformer的架构

    大家好,我是微学AI,今天给大家讲述一下人工智能(Pytorch)搭建transformer模型,手动搭建transformer模型,我们知道transformer模型是相对复杂的模型,它是一种利用自注意力机制进行序列建模的深度学习模型。相较于 RNN 和 CNN,transformer 模型更高效、更容易并行化,广泛应用于神

    2023年04月22日
    浏览(58)
  • 深度学习应用篇-计算机视觉-图像分类[3]:ResNeXt、Res2Net、Swin Transformer、Vision Transformer等模型结构、实现、模型特点详细介绍

    【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等 专栏详细介绍:【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、

    2024年02月14日
    浏览(51)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包