DALL·E 2: Hierarchical Text-Conditional Image Generation with CLIP Latents


Reference video: DALL·E 2 paper reading by 跟李沐学AI on bilibili, https://www.bilibili.com/video/BV17r4y1u77B?spm_id_from=333.999.0.0&vd_source=4aed82e35f26bb600bc5b46e65e25c22 (more paper readings: https://github.com/mli/paper-reading). Most of the DALL·E 2 explanations floating around are not very clear. The three main families of generative models are VAEs, GANs, and diffusion models; the autoencoder lineage runs AE -> DAE -> VAE -> VQ-VAE -> diffusion, and within diffusion models the lineage runs DDPM -> Improved DDPM -> Diffusion Models Beat GANs -> GLIDE -> DALL·E 2.

1. Introduction

CLIP representations are robust to shifts in image distribution and support zero-shot transfer, while diffusion models offer good sample diversity with reasonable fidelity. DALL·E 2 combines the strengths of both.
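To make the zero-shot point concrete, here is a minimal sketch of zero-shot classification with a frozen CLIP checkpoint via the openai/clip package; the model name, image path, and prompts are placeholders for illustration, not anything used by DALL·E 2 itself.

import torch
import clip  # pip install git+https://github.com/openai/CLIP.git
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)  # frozen, pre-trained CLIP

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
texts = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)   # (1, 768) for ViT-L/14
    text_features = model.encode_text(texts)     # (2, 768)
    # normalize, then cosine similarity acts as a zero-shot classifier
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(probs)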

2. Method

[Figure: DALL·E 2 overview diagram from the paper — CLIP above the dashed line, prior + decoder generation pipeline below]

The figure above is very helpful; read the method against it. Above the dashed line is CLIP. This CLIP is pre-trained in advance and its weights are frozen; it is never updated during DALL·E 2 training. The training input is still a (text, image) pair. The text goes through CLIP's text encoder (CLIP uses a ViT for images and a Transformer-based encoder for text; CLIP is plain contrastive learning, so the two modality encoders matter a lot, and after encoding the two embeddings are compared directly with cosine similarity). The image goes through CLIP's image encoder to produce an image embedding, and this image embedding serves as the ground truth.

The text embedding is fed into the first module, the prior. The prior is a diffusion model (an autoregressive transformer can be used instead). It outputs an image embedding, which is supervised against the image embedding produced by CLIP, so this stage is trained with direct supervision. After the prior comes the decoder. In the original DALL·E, the encoder and decoder were trained jointly inside a dVAE, but here the decoder is trained on its own and is also a diffusion model. In other words, the generation pipeline below the dashed line splits what used to be a single generation step into an explicit two-stage image generation, and the authors found experimentally that this explicit factorization generates better images.

The paper calls its model unCLIP: CLIP maps input text and images to features, whereas DALL·E 2 goes the other way, from a text feature to an image feature and then to an image, with the image-feature-to-image step handled by a diffusion model. The decoder uses both classifier-free guidance and CLIP guidance. Guidance here means the following: during decoding the input is a noisy image at timestep t and the final output is an image; the intermediate result the U-Net produces at each step can be scored by an image classifier (typically with a cross-entropy loss, e.g. a binary classification), and the gradient of that classifier with respect to the noisy image is used to steer the diffusion toward a better decode.
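To make the last point concrete, here is a schematic of classifier-style guidance for an eps-predicting diffusion model. This is a generic sketch in the spirit of Dhariwal & Nichol's classifier guidance, not the dalle2-pytorch implementation; all names here (unet, classifier, the alpha-bar term) are illustrative.

import torch

def classifier_guided_eps(unet, classifier, x_t, t, y, sqrt_one_minus_alpha_bar_t, scale=1.0):
    # Plain noise prediction for the noisy image x_t at timestep t.
    eps = unet(x_t, t)
    # Score the condition y with a classifier that accepts noisy images,
    # and take the gradient of log p(y | x_t) with respect to x_t.
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        log_p = torch.log_softmax(classifier(x_in, t), dim=-1)
        selected = log_p[torch.arange(len(y)), y].sum()
        grad = torch.autograd.grad(selected, x_in)[0]
    # Shift the predicted noise along the classifier gradient:
    # eps_hat = eps - sqrt(1 - alpha_bar_t) * scale * grad(log p(y | x_t))
    return eps - sqrt_one_minus_alpha_bar_t * scale * grad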

3. Code

GitHub - lucidrains/DALLE2-pytorch: Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

The core is to train a prior model and a decoder model, both diffusion models. The prior can also be an autoregressive model; using an autoregressive prior is essentially the original DALL·E approach. CLIP itself is just a pre-trained checkpoint you pick; its job is to supply good image and text embeddings.
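A rough usage sketch of the two stages, loosely following the README of lucidrains/DALLE2-pytorch. Constructor arguments change between versions of the repo, so treat the exact keyword names and values as illustrative rather than authoritative; 768 matches the ViT-L/14 embedding size visible in the module dumps below.

import torch
from dalle2_pytorch import OpenAIClipAdapter, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, DALLE2

# frozen, pre-trained CLIP; it only supplies text/image embeddings
clip = OpenAIClipAdapter('ViT-L/14')

# stage 1: the prior maps a CLIP text embedding to a CLIP image embedding
prior_network = DiffusionPriorNetwork(dim=768, depth=6, dim_head=64, heads=8)
diffusion_prior = DiffusionPrior(net=prior_network, clip=clip, timesteps=1000, cond_drop_prob=0.2)

# stage 2: the decoder turns a CLIP image embedding back into pixels
unet = Unet(dim=128, image_embed_dim=768, cond_dim=128, channels=3, dim_mults=(1, 2, 4, 8))
decoder = Decoder(unet=unet, clip=clip, timesteps=1000,
                  image_cond_drop_prob=0.1, text_cond_drop_prob=0.5)
# (depending on the repo version, an image size argument may also be required here)

# the two stages are trained separately on (text, image) pairs
text = torch.randint(0, 49408, (4, 77))     # CLIP BPE token ids, context length 77
images = torch.randn(4, 3, 256, 256)
prior_loss = diffusion_prior(text, images)
prior_loss.backward()
decoder_loss = decoder(images)
decoder_loss.backward()

# inference chains them: text -> text embed -> (prior) image embed -> (decoder) image
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)
generated = dalle2(['a corgi playing a flame throwing trumpet'], cond_scale=2.0)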

The input of train_diffusion_prior comes from img2dataset and clip-retrieval: the raw data is downloaded with img2dataset, and clip-retrieval then produces the img_emb/text_emb files and the metadata (meta_url) that prior training needs. The input of train_decoder is the tar shards (webdataset) produced by img2dataset; the tars only need to contain images. dalle2-pytorch validates the input images and uses a built-in CLIP to extract their features. In practice the prior and the generative decoder are explicitly separated parts and are trained independently.
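A sketch of that data-preparation step, using img2dataset's Python entry point. The URL list path, column names, and output folders are placeholders, and the clip-retrieval step is only described in comments because its exact invocation depends on the installed version.

from img2dataset import download

# 1) download (url, caption) pairs into webdataset tar shards
download(
    url_list="laion_subset.parquet",     # placeholder parquet with url/caption columns
    input_format="parquet",
    url_col="url",
    caption_col="caption",
    image_size=256,
    output_format="webdataset",
    output_folder="dataset",
    processes_count=8,
    thread_count=32,
)

# 2) run clip-retrieval inference over the resulting shards to produce the
#    img_emb/ and text_emb/ .npy files plus metadata parquet files that
#    train_diffusion_prior reads (roughly: clip-retrieval inference
#    --input_dataset dataset --output_folder embeddings --enable_metadata True).
# 3) train_decoder, by contrast, reads the tar shards directly and extracts
#    CLIP features on the fly with its built-in CLIP.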

diffusion_prior:

DiffusionPrior(
  (noise_scheduler): NoiseScheduler()
  (clip): OpenAIClipAdapter(
    (clip): CLIP(
      (visual): VisionTransformer(
        (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (transformer): Transformer(
          (resblocks): Sequential(
            (0-23): 24 x ResidualAttentionBlock(
              (attn): MultiheadAttention(
                (out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
              )
              (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): Sequential(
                (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
                (gelu): QuickGELU()
                (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
          )
        )
        (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (transformer): Transformer(
        (resblocks): Sequential(
          (0-11): 12 x ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (token_embedding): Embedding(49408, 768)
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (clip_normalize): Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
  )
  (net): DiffusionPriorNetwork(
    (to_text_embeds): Sequential(
      (0): Identity()
      (1): Rearrange('b (n d) -> b n d', n=1)
    )
    (to_time_embeds): Sequential(
      (0): Embedding(1000, 768)
      (1): Rearrange('b (n d) -> b n d', n=1)
    )
    (to_image_embeds): Sequential(
      (0): Identity()
      (1): Rearrange('b (n d) -> b n d', n=1)
    )
    (causal_transformer): CausalTransformer(
      (init_norm): Identity()
      (rel_pos_bias): RelPosBias(
        (relative_attention_bias): Embedding(32, 12)
      )
      (layers): ModuleList(
        (0-11): 12 x ModuleList(
          (0): Attention(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.05, inplace=False)
            (to_q): Linear(in_features=768, out_features=768, bias=False)
            (to_kv): Linear(in_features=768, out_features=128, bias=False)
            (rotary_emb): RotaryEmbedding()
            (to_out): Sequential(
              (0): Linear(in_features=768, out_features=768, bias=False)
              (1): LayerNorm()
            )
          )
          (1): Sequential(
            (0): LayerNorm()
            (1): Linear(in_features=768, out_features=6144, bias=False)
            (2): SwiGLU()
            (3): LayerNorm()
            (4): Dropout(p=0.05, inplace=False)
            (5): Linear(in_features=3072, out_features=768, bias=False)
          )
        )
      )
      (norm): LayerNorm()
      (project_out): Linear(in_features=768, out_features=768, bias=False)
    )
  )
)

The prior's denoising network is a causal (autoregressive-style) Transformer; the training flow is roughly as follows:

Config and model construction: TrainDiffusionPriorConfig.from_json_path() -> prior/data/train/tracker (saved) -> train: DiffusionPriorTrainer = make_model() -> (prior_config) DiffusionPriorConfig / DiffusionPriorTrainConfig -> diffusion_prior = prior_config.create() -> clip = clip.create() -> AdapterConfig.create() -> OpenAIClipAdapter() -> diffusion_prior_network = net.create() -> DiffusionPriorNetworkConfig.create() -> DiffusionPriorNetwork -> trainer = DiffusionPriorTrainer -> tracker = create_tracker() -> TrackerConfig.create()

Data readers: img_reader = get_reader(), whose inputs are the three URL sets img_url/text_url/meta_url -> image_reader = EmbeddingReader() / text_reader = EmbeddingReader() -> train_loader / eval_loader / test_loader = make_splits

One training step: train: DiffusionPriorTrainer / Tracker / DiffusionPriorTrainConfig -> img (16, 768), txt (16, 77) -> DiffusionPrior.forward with net: DiffusionPriorNetwork -> image_embed, _ = self.clip.embed_image() / text_embed, text_encodings = self.clip.embed_text() -> times = self.noise_scheduler.sample_random_times -> self.p_losses -> image_embed_noisy = self.noise_scheduler.q_sample (NoiseScheduler) -> pred = self.net(image_embed_noisy (16, 768), text_cond: text_embed (16, 768) / text_encodings (16, 77, 768)) -> DiffusionPriorNetwork.forward() -> image_embed (16, 768), text_embed (16, 768) -> tokens (16, 81, 768) -> pred_image_embed (16, 768) -> target = noise (16, 768) -> loss = self.noise_scheduler.loss_fn (l2: mse) -> trainer.update
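Condensed into code, the core of that training step looks roughly like the following. This is a schematic of what DiffusionPrior.p_losses computes (assuming noise prediction, matching the target = noise line in the trace above); argument names are simplified rather than copied from the repo, and the shapes follow the trace (batch 16, embedding dim 768).

import torch
import torch.nn.functional as F

def prior_training_step_sketch(net, noise_scheduler, clip, text_tokens, images):
    # CLIP supplies the (frozen) embeddings; the image embedding is the ground truth.
    image_embed, _ = clip.embed_image(images)                  # (16, 768)
    text_embed, text_encodings = clip.embed_text(text_tokens)  # (16, 768), (16, 77, 768)

    # sample a random diffusion timestep and noise the image embedding
    times = noise_scheduler.sample_random_times(image_embed.shape[0])
    noise = torch.randn_like(image_embed)
    image_embed_noisy = noise_scheduler.q_sample(image_embed, times, noise=noise)

    # the causal transformer sees text encodings, text embed, time embed, the noisy
    # image embed, and a learned query as a ~81-token sequence, and makes a prediction
    pred = net(image_embed_noisy, times, text_embed=text_embed, text_encodings=text_encodings)

    # l2 loss against the noise (eps-prediction), as in the trace above
    return F.mse_loss(pred, noise)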

decoder:

Decoder(
  (clip): OpenAIClipAdapter(
    (clip): CLIP(
      (visual): VisionTransformer(
        (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (transformer): Transformer(
          (resblocks): Sequential(
            (0-23): 24 x ResidualAttentionBlock(
              (attn): MultiheadAttention(
                (out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
              )
              (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): Sequential(
                (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
                (gelu): QuickGELU()
                (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
        )
        (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (transformer): Transformer(
        (resblocks): Sequential(
          (0-11): 12 x ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
      )
      (token_embedding): Embedding(49408, 768)
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (clip_normalize): Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
  )
  (unets): ModuleList(
    (0): Unet(
      (init_conv): CrossEmbedLayer(
        (convs): ModuleList(
          (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): Conv2d(3, 4, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (2): Conv2d(3, 4, kernel_size=(15, 15), stride=(1, 1), padding=(7, 7))
        )
      )
      (to_time_hiddens): Sequential(
        (0): SinusoidalPosEmb()
        (1): Linear(in_features=16, out_features=64, bias=True)
        (2): GELU()
      )
      (to_time_tokens): Sequential(
        (0): Linear(in_features=64, out_features=32, bias=True)
        (1): Rearrange('b (r d) -> b r d', r=2)
      )
      (to_time_cond): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
      )
      (image_to_tokens): Identity()
      (norm_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (norm_mid_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (downs): ModuleList(
        (0): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=32, bias=True)
            )
            (block1): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): Identity()
          (4): Conv2d(16, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        )
        (1): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=32, bias=True)
            )
            (block1): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=16, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=16, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=16, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=16, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=16, out_features=64, bias=False)
                (to_kv): Linear(in_features=16, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=16, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        )
        (2): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=64, bias=True)
            )
            (block1): Block(
              (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=32, out_features=64, bias=False)
                (to_kv): Linear(in_features=32, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=32, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        )
        (3): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=128, bias=True)
            )
            (block1): Block(
              (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=64, out_features=64, bias=False)
                (to_kv): Linear(in_features=64, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=64, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (4): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (ups): ModuleList(
        (0): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=256, bias=True)
            )
            (cross_attn): EinopsToAndFrom(
              (fn): CrossAttention(
                (norm): LayerNorm()
                (norm_context): Identity()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=128, out_features=512, bias=False)
                (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=128, bias=False)
                  (1): LayerNorm()
                )
              )
            )
            (block1): Block(
              (project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=256, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=128, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=128, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=256, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=128, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=128, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=128, out_features=64, bias=False)
                (to_kv): Linear(in_features=128, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=128, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (3): PixelShuffleUpsample(
            (net): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
              (1): SiLU()
              (2): PixelShuffle(upscale_factor=2)
            )
          )
        )
        (1): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=128, bias=True)
            )
            (cross_attn): EinopsToAndFrom(
              (fn): CrossAttention(
                (norm): LayerNorm()
                (norm_context): Identity()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=64, out_features=512, bias=False)
                (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=64, bias=False)
                  (1): LayerNorm()
                )
              )
            )
            (block1): Block(
              (project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=64, out_features=64, bias=False)
                (to_kv): Linear(in_features=64, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=64, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (3): PixelShuffleUpsample(
            (net): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
              (1): SiLU()
              (2): PixelShuffle(upscale_factor=2)
            )
          )
        )
        (2): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=64, bias=True)
            )
            (cross_attn): EinopsToAndFrom(
              (fn): CrossAttention(
                (norm): LayerNorm()
                (norm_context): Identity()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=32, out_features=512, bias=False)
                (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=32, bias=False)
                  (1): LayerNorm()
                )
              )
            )
            (block1): Block(
              (project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=32, out_features=64, bias=False)
                (to_kv): Linear(in_features=32, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=32, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (3): PixelShuffleUpsample(
            (net): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
              (1): SiLU()
              (2): PixelShuffle(upscale_factor=2)
            )
          )
        )
        (3): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=32, bias=True)
            )
            (block1): Block(
              (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): Identity()
          (3): Identity()
        )
      )
      (mid_block1): ResnetBlock(
        (time_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=256, bias=True)
        )
        (cross_attn): EinopsToAndFrom(
          (fn): CrossAttention(
            (norm): LayerNorm()
            (norm_context): Identity()
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=128, out_features=512, bias=False)
            (to_kv): Linear(in_features=16, out_features=1024, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=128, bias=False)
              (1): LayerNorm()
            )
          )
        )
        (block1): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (mid_attn): EinopsToAndFrom(
        (fn): Residual(
          (fn): Attention(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=128, out_features=64, bias=False)
            (to_kv): Linear(in_features=128, out_features=32, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=False)
              (1): LayerNorm()
            )
          )
        )
      )
      (mid_block2): ResnetBlock(
        (time_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=256, bias=True)
        )
        (cross_attn): EinopsToAndFrom(
          (fn): CrossAttention(
            (norm): LayerNorm()
            (norm_context): Identity()
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=128, out_features=512, bias=False)
            (to_kv): Linear(in_features=16, out_features=1024, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=128, bias=False)
              (1): LayerNorm()
            )
          )
        )
        (block1): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (upsample_combiner): UpsampleCombiner()
      (final_resnet_block): ResnetBlock(
        (time_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=32, bias=True)
        )
        (block1): Block(
          (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (to_out): Conv2d(16, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (vaes): ModuleList(
    (0): NullVQGanVAE()
  )
  (noise_schedulers): ModuleList(
    (0): NoiseScheduler()
  )
  (lowres_conds): ModuleList(
    (0): None
  )
)
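
To make the printout above easier to reproduce, here is a minimal sketch of how such a decoder can be assembled with dalle2-pytorch, following the style of the repository README. The specific sizes (dim=16, dim_mults=(1, 2, 4, 8), the 'ViT-L/14' CLIP variant, image_embed_dim=768) are assumptions chosen only to roughly match the 16/32/64/128 channel widths and the 768-dim text transformer seen above; argument names and defaults may differ between library versions.

import torch
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

# frozen CLIP, used only to produce image/text embeddings (assumed ViT-L/14 here)
clip = OpenAIClipAdapter('ViT-L/14')

# a very small Unet, roughly matching the 16/32/64/128 channels in the printout
unet = Unet(
    dim = 16,
    image_embed_dim = 768,           # CLIP image embedding size for ViT-L/14
    cond_dim = 16,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    cond_on_text_encodings = True    # enables the cross-attention blocks shown above
)

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 1000,
    image_cond_drop_prob = 0.1,      # classifier-free guidance dropout on the image embedding
    text_cond_drop_prob = 0.5
)

print(decoder)   # prints a module tree like the one listed above

Because the prior and the decoder are trained separately, the same frozen clip instance can be shared between DiffusionPrior and Decoder.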

train_decoder flow: TrainDecoderConfig is split into DecoderConfig / DecoderDataConfig / DecoderTrainConfig / DecoderEvaluateConfig / TrackerConfig. The dataloader is built by create_dataloaders -> create_image_embedding_dataloader. The model comes from decoder = config.decoder.create() (DecoderConfig.create()), which instantiates the Unet(s) from their configs, builds CLIP via clip.create() -> OpenAIClipAdapter, and wraps them in Decoder; the experiment tracker comes from create_tracker. Training then goes through DecoderTrainer: for a batch of img: 4,3,224,224 with captions such as cat/sea/tree/motel, DecoderTrainer.forward() calls self.decoder -> Decoder.forward(), which resizes the images (resize_image_to: 4,3,64,64), encodes them with the VAE (image = vae.encode(image); here the VAE is a NullVQGanVAE), and computes p_losses: x_noisy = noise_scheduler.q_sample(), model_output = unet(...) with shape 4,3,64,64, target: 4,3,64,64, pred: 4,3,64,64, loss = noise_scheduler.loss_fn (l2/mse); then mean, var = noise_scheduler.q_posterior(), kl = normal_kl, and the total loss is loss + vb_loss. A sketch of this training step follows below.
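
A hedged sketch of the trainer side of this chain, again following the README-style DecoderTrainer API (exact arguments may differ between versions); images and text here are random stand-ins for the batch that would come out of the img2dataset tars, with text already tokenized as CLIP token ids.

import torch
from dalle2_pytorch import DecoderTrainer

decoder_trainer = DecoderTrainer(
    decoder,                                 # the Decoder built above
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99
)

images = torch.randn(4, 3, 224, 224)         # stand-in for a real image batch
text = torch.randint(0, 49408, (4, 77))      # stand-in for tokenized captions (cat/sea/tree/motel)

loss = decoder_trainer(images, text = text, unet_number = 1, max_batch_size = 4)
decoder_trainer.update(unet_number = 1)      # optimizer step + EMA update for that unet

And a generic DDPM-style sketch of what the p_losses step computes (not the library's exact implementation): the forward process noises x_0 into x_t, the Unet predicts the noise, and the main loss is the L2 distance between prediction and noise, optionally plus a variational-bound term from q_posterior / normal_kl.

import torch
import torch.nn.functional as F

# beta schedule and cumulative alphas, as in a standard DDPM
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim = 0)

def q_sample(x_start, t, noise):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    sqrt_ab = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus_ab = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
    return sqrt_ab * x_start + sqrt_one_minus_ab * noise

x_start = torch.randn(4, 3, 64, 64)          # images after resize_image_to / vae.encode
t = torch.randint(0, 1000, (4,))
noise = torch.randn_like(x_start)

x_noisy = q_sample(x_start, t, noise)        # noise_scheduler.q_sample()
model_output = torch.randn_like(x_start)     # stand-in for unet(x_noisy, t, image_embed=...), shape 4,3,64,64
loss = F.mse_loss(model_output, noise)       # noise_scheduler.loss_fn: l2 / mse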
