transformer优化(一)-UNeXt 学习笔记

这篇具有很好参考价值的文章主要介绍了transformer优化(一)-UNeXt 学习笔记。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

论文地址:https://arxiv.org/abs/2203.04967

代码地址:hhttps://link.zhihu.com/?target=https%3A//github.com/jeya-maria-jose/UNeXt-pytorch

1.是什么?

UNeXt是约翰霍普金斯大学在2022年发布的论文。它在早期阶段使用卷积,在潜在空间阶段使用 MLP。通过一个标记化的 MLP 块来标记和投影卷积特征,并使用 MLP 对表示进行建模。对输入通道进行移位,可以专注于学习局部依赖性。

2.为什么?

近年来,UNet及其最新扩展如TransUNet已成为领先的医学图像分割方法。然而,由于这些网络参数多、计算复杂且使用速度慢,因此不能有效地用于即时护理应用中的快速图像分割。为此,我们提出了一种基于卷积多层感知器(MLP)的图像分割网络UNeXt。我们以一种有效的方式设计了UNeXt,其中包s和潜伏阶段的MLP阶段。我们提出了一个标记化(切片)的MLP块,我们有效地标记和投影(关联)卷积特征,并使用MLP来建模表示。为了进一步提高性能,我们建议在向mlp输入的同时改变输入的通道,以便专注于学习局部依赖关系。在潜在空间中使用标记化mlp减少了参数的数量和计算复杂性,同时能够产生更好的表示来帮助分割。该网络还包括各级编码器和解码器之间的跳过连接。我们在多个医学图像分割数据集上对UNeXt进行了测试,结果表明我们将参数数量减少了72倍,计算复杂度降低了68倍,推理速度提高了10倍,同时也获得了比目前最先进的医学图像分割架构更好的分割性能
 

3.怎么样?

3.1网络结构

UNeXt是一个编码器-解码器架构,有两个阶段:1)卷积阶段,和2)标记化MLP阶段。输入图像通过编码器,其中前3个块是卷积的,接下来的2个是Tokenized MLP块。解码器有2个Tokenized MLP块,后面跟着3个卷积块。每个编码器块将特征分辨率降低2,每个解码器块将特征分辨率提高2。

跳过连接是也包括在编码器和解码器之间。每个块上的通道数是一个超参数,表示为C1到C5。对于使用UNeXt架构的实验,除非另有说明,否则我们遵循C1 = 32, C2 = 64, C3 = 128, C4 = 160和C5 = 256。请注意,这些数字实际上小于UNet及其变体的过滤器数量,这有助于减少参数和计算的第一个更改。

transformer优化(一)-UNeXt 学习笔记,学习,笔记,transformer,人工智能,python

3.2原理分析

3.2.1卷积阶段

有三个卷积块,每个卷积块配备一个卷积层,一个批处理归一化层和ReLU激活。我们使用内核大小为3 × 3, stride为1,padding为1。编码器中的转换块使用池窗口为2 × 2的最大池化层,而解码器中的转换块由双线性插值层组成,用于对特征映射进行上采样。我们使用双线性插值代替转置卷积,因为转置卷积基本上是可学习的上采样,并有助于获得更多可学习的参数。

3.2.2 Shifted MLP

在移位MLP中,我们首先在标记之前移动卷积特征的通道轴。这有助于MLP只关注conv特征的某些位置,从而诱导块的局域性。这里的直觉类似于Swin transformer[5],其中引入了基于窗口的注意力,为完全全局的模型添加了更多的局部性。由于Tokenized MLP块有2个MLP,我们像轴向注意一样,在一个块中跨宽度移动特征,在另一个块中跨高度移动特征[24]。我们把这些特征分成h个不同的,根据指定的轴对它们进行分区并移动j个位置。这有助于我们创建沿轴引入局部性的随机窗口。我们把这些特征分成h个不同的部分,并根据指定的轴移动它们j个位置。这有助于我们创建沿轴引入局部性的随机窗口

transformer优化(一)-UNeXt 学习笔记,学习,笔记,transformer,人工智能,python

3.2.3  Tokenized MLP

在标记化的MLP块中,我们首先转移特征并将其投影到标记中。为了标记化,我们首先使用内核大小为3,并将通道数更改为E,其中E是嵌入维度(标记数),这是一个超参数。然后,我们将这些令牌传递给移位的MLP(跨宽度),其中MLP的隐藏维度是超参数h。
接下来,这些特征通过深度卷积层(DWConv)传递。我们在这个块中使用DWConv有两个原因:
1)它有助于编码MLP特征的位置信息。文献[26]表明,MLP块中的Conv层足以对位置信息进行编码,并且实际上比标准的位置编码技术性能更好。当测试和训练分辨率不相同时,需要插入像ViT中的位置编码技术,这通常会导致性能下降。
2) DWConv使用较少的参数,因此提高了效率。然后我们使用GELU[12]激活层。我们使用GELU而不是RELU,因为它是一个更平滑的选择,并且发现它的性能更好。此外,像ViT[10]和BERT[9]这样的最新架构已经成功地使用了GELU来获得改进的结果。然后,我们通过另一个移位的MLP(跨高度)传递特征,该MLP将维度从H转换为o。
我们在这里使用残差连接并添加原始标记作为残差。然后我们应用层归一化(LN)并将输出特征传递给下一个块。LN比BN更受欢迎,因为沿着令牌进行规范化比在token化的MLP块中跨批进行规范化更有意义。
token化MLP块中的计算可以总结为:

X shift = Shift W (X);T W = Tokenize(X shift ), (1)
Y = f(DWConv((MLP(T W )))), (2)
Y shift = Shift H (Y );T H = Tokenize(Y shift ), (3)
Y = f(LN(T W + MLP(GELU(T H )))), (4) 

其中T为令牌,H为高度,W为宽度,DW Conv为深度卷积,LN为层归一化。请注意,所有这些计算都是在嵌入维数H上进行的,这明显小于特征映射的维数H /N × H/ N,其中N是取决于块的2倍。在我们的实验中,除非另有说明,否则我们将H设为768。这种设计Tokenized MLP块的方法有助于编码有意义的特征信息,并且在计算或参数方面贡献不大。

3.3代码实现

import torch
from torch import nn
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
from utils import *
__all__ = ['UNext']

import timm
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import types
import math
from abc import ABCMeta, abstractmethod
from mmcv.cnn import ConvModule
import pdb



def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)


def shift(dim):
            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
            x_cat = torch.cat(x_shift, 1)
            x_cat = torch.narrow(x_cat, 2, self.pad, H)
            x_cat = torch.narrow(x_cat, 3, self.pad, W)
            return x_cat

class shiftmlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.shift_size = shift_size
        self.pad = shift_size // 2

        
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    
#     def shift(x, dim):
#         x = F.pad(x, "constant", 0)
#         x = torch.chunk(x, shift_size, 1)
#         x = [ torch.roll(x_c, shift, dim) for x_s, shift in zip(x, range(-pad, pad+1))]
#         x = torch.cat(x, 1)
#         return x[:, :, pad:-pad, pad:-pad]

    def forward(self, x, H, W):
        # pdb.set_trace()
        B, N, C = x.shape

        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)


        x_s = x_s.reshape(B,C,H*W).contiguous()
        x_shift_r = x_s.transpose(1,2)


        x = self.fc1(x_shift_r)

        x = self.dwconv(x, H, W)
        x = self.act(x) 
        x = self.drop(x)

        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)
        x_s = x_s.reshape(B,C,H*W).contiguous()
        x_shift_c = x_s.transpose(1,2)

        x = self.fc2(x_shift_c)
        x = self.drop(x)
        return x



class shiftedBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()


        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):

        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class UNext(nn.Module):

    ## Conv 3 + MLP 2 + shifted MLP
    
    def __init__(self,  num_classes, input_channels=3, deep_supervision=False,img_size=224, patch_size=16, in_chans=3,  embed_dims=[ 128, 160, 256],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
        super().__init__()
        
        self.encoder1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)  
        self.encoder2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)  
        self.encoder3 = nn.Conv2d(32, 128, 3, stride=1, padding=1)

        self.ebn1 = nn.BatchNorm2d(16)
        self.ebn2 = nn.BatchNorm2d(32)
        self.ebn3 = nn.BatchNorm2d(128)
        
        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])

        self.dnorm3 = norm_layer(160)
        self.dnorm4 = norm_layer(128)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.block1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.block2 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.dblock1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.dblock2 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])

        self.decoder1 = nn.Conv2d(256, 160, 3, stride=1,padding=1)  
        self.decoder2 =   nn.Conv2d(160, 128, 3, stride=1, padding=1)  
        self.decoder3 =   nn.Conv2d(128, 32, 3, stride=1, padding=1) 
        self.decoder4 =   nn.Conv2d(32, 16, 3, stride=1, padding=1)
        self.decoder5 =   nn.Conv2d(16, 16, 3, stride=1, padding=1)

        self.dbn1 = nn.BatchNorm2d(160)
        self.dbn2 = nn.BatchNorm2d(128)
        self.dbn3 = nn.BatchNorm2d(32)
        self.dbn4 = nn.BatchNorm2d(16)
        
        self.final = nn.Conv2d(16, num_classes, kernel_size=1)

        self.soft = nn.Softmax(dim =1)

    def forward(self, x):
        
        B = x.shape[0]
        ### Encoder
        ### Conv Stage

        ### Stage 1
        out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
        t1 = out
        ### Stage 2
        out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
        t2 = out
        ### Stage 3
        out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
        t3 = out

        ### Tokenized MLP Stage
        ### Stage 4

        out,H,W = self.patch_embed3(out)
        for i, blk in enumerate(self.block1):
            out = blk(out, H, W)
        out = self.norm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out

        ### Bottleneck

        out ,H,W= self.patch_embed4(out)
        for i, blk in enumerate(self.block2):
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        ### Stage 4

        out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
        
        out = torch.add(out,t4)
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.dblock1):
            out = blk(out, H, W)

        ### Stage 3
        
        out = self.dnorm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t3)
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        
        for i, blk in enumerate(self.dblock2):
            out = blk(out, H, W)

        out = self.dnorm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t2)
        out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t1)
        out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))

        return self.final(out)


class UNext_S(nn.Module):

    ## Conv 3 + MLP 2 + shifted MLP w less parameters
    
    def __init__(self,  num_classes, input_channels=3, deep_supervision=False,img_size=224, patch_size=16, in_chans=3,  embed_dims=[32, 64, 128, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
        super().__init__()
        
        self.encoder1 = nn.Conv2d(3, 8, 3, stride=1, padding=1)  
        self.encoder2 = nn.Conv2d(8, 16, 3, stride=1, padding=1)  
        self.encoder3 = nn.Conv2d(16, 32, 3, stride=1, padding=1)

        self.ebn1 = nn.BatchNorm2d(8)
        self.ebn2 = nn.BatchNorm2d(16)
        self.ebn3 = nn.BatchNorm2d(32)
        
        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])

        self.dnorm3 = norm_layer(64)
        self.dnorm4 = norm_layer(32)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.block1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.block2 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.dblock1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.dblock2 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])

        self.decoder1 = nn.Conv2d(128, 64, 3, stride=1,padding=1)  
        self.decoder2 =   nn.Conv2d(64, 32, 3, stride=1, padding=1)  
        self.decoder3 =   nn.Conv2d(32, 16, 3, stride=1, padding=1) 
        self.decoder4 =   nn.Conv2d(16, 8, 3, stride=1, padding=1)
        self.decoder5 =   nn.Conv2d(8, 8, 3, stride=1, padding=1)

        self.dbn1 = nn.BatchNorm2d(64)
        self.dbn2 = nn.BatchNorm2d(32)
        self.dbn3 = nn.BatchNorm2d(16)
        self.dbn4 = nn.BatchNorm2d(8)
        
        self.final = nn.Conv2d(8, num_classes, kernel_size=1)

        self.soft = nn.Softmax(dim =1)

    def forward(self, x):
        
        B = x.shape[0]
        ### Encoder
        ### Conv Stage

        ### Stage 1
        out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
        t1 = out
        ### Stage 2
        out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
        t2 = out
        ### Stage 3
        out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
        t3 = out

        ### Tokenized MLP Stage
        ### Stage 4

        out,H,W = self.patch_embed3(out)
        for i, blk in enumerate(self.block1):
            out = blk(out, H, W)
        out = self.norm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out

        ### Bottleneck

        out ,H,W= self.patch_embed4(out)
        for i, blk in enumerate(self.block2):
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        ### Stage 4

        out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
        
        out = torch.add(out,t4)
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.dblock1):
            out = blk(out, H, W)

        ### Stage 3
        
        out = self.dnorm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t3)
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        
        for i, blk in enumerate(self.dblock2):
            out = blk(out, H, W)

        out = self.dnorm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t2)
        out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t1)
        out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))

        return self.final(out)

参考:

UNext翻译(UNext基于MLP的医学图像快速分割网络)(有那么一丢丢没看懂)文章来源地址https://www.toymoban.com/news/detail-822629.html

到了这里,关于transformer优化(一)-UNeXt 学习笔记的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 人工智能入门学习笔记(一)

    家人们,好久不见哈!最近在尝试着学习人工智能的相关知识和具体技能呀。说实话,当像我这样的 小白初探人工智能体系 时,总是被很多未知的名词以及茫茫内容所淹没,便去想通过网络学习帮助自己建立正确的人工智能基本概念认知。在此,我便进一步对人工智能体系

    2024年02月02日
    浏览(63)
  • 模型训练:优化人工智能和机器学习,完善DevOps工具的使用

    作者:JFrog大中华区总经理董任远 据说法餐的秘诀在于黄油、黄油、更多的黄油。同样,对于DevOps而言,成功的三大秘诀是自动化、自动化、更高程度的自动化,而这一切归根结底都在于构建能够更快速地不断发布新版软件的流程。 尽管人们认为在人工智能(AI)和机器学习

    2024年02月10日
    浏览(39)
  • 人工智能学习笔记六——CBOW模型

    连续词袋模型(CBOW)模型是word2vec下的一个模型,是一群用来产生词向量的相关模型。这些模型为浅而双层的神经网络,用来训练以重新建构语言学之词文本。 网络 以词表现,并且需猜测相邻位置的输入词,在word2vec中词袋模型假设下,词的顺序是不重要的。训练完成之后,

    2024年02月14日
    浏览(39)
  • 人工智能( 第 3 版)第一章学习笔记

    第 1 章 人工智能概述 1.0 引言 本文对人工智能的观点:人工智能是由人(people)、想法(idea)、方法(method)、机器(machine)和结果(outcome)等对象组成的。人通过机器(计算机)将自己的想法以某种方法进行实现,最终实现的东西称为结果。 研究人工智能或实现人工智能系

    2024年01月25日
    浏览(46)
  • 探索人工智能 | 模型训练 使用算法和数据对机器学习模型进行参数调整和优化

    模型训练是指 使用算法和数据对机器学习模型进行参数调整和优化 的过程。模型训练一般包含以下步骤:数据收集、数据预处理、模型选择、模型训练、模型评估、超参数调优、模型部署、持续优化。 数据收集是指为机器学习或数据分析任务收集和获取用于训练或分析的数

    2024年02月12日
    浏览(56)
  • 飞浆AI studio人工智能课程学习(3)-在具体场景下优化Prompt

    01 常见应用场景与优化示例 02 优质Prompt模板化 03 大作业指引:Prompt作品积分赛 01 常见应用场景与优化示例 内容产业规模庞大、领域众多,大模型强大的生成能力给工作和生活带来了极大的想象力。 ?弹幕说一说,哪些AIGC场景是你最感兴趣的?先来看几类常见的: ·产品海报背景

    2024年02月06日
    浏览(49)
  • 人工智能基础_机器学习006_有监督机器学习_正规方程的公式推导_最小二乘法_凸函数的判定---人工智能工作笔记0046

    我们来看一下公式的推导这部分比较难一些, 首先要记住公式,这个公式,不用自己理解,知道怎么用就行, 比如这个(mA)T 这个转置的关系要知道 然后我们看这个符号就是求X的导数,X导数的转置除以X的导数,就得到单位矩阵, 可以看到下面也是,各种X的导数,然后计算,得到对应的矩阵

    2024年02月08日
    浏览(52)
  • 人工智能课程笔记(7)强化学习(基本概念 Q学习 深度强化学习 附有大量例题)

    强化学习和深度学习都是机器学习的分支,但是两者在方法和应用场景上有所不同。 强化学习 : 强化学习概述 :强化学习是一种通过智能体与环境进行交互来学习最优行动策略的算法。在强化学习中,智能体与环境不断交互,观察环境的状态并采取不同的行动,从而获得奖

    2024年01月17日
    浏览(50)
  • 人工智能基础_机器学习007_高斯分布_概率计算_最小二乘法推导_得出损失函数---人工智能工作笔记0047

    这个不分也是挺难的,但是之前有详细的,解释了,之前的文章中有, 那么这里会简单提一下,然后,继续向下学习 首先我们要知道高斯分布,也就是,正太分布, 这个可以预测x在多少的时候,概率最大 要知道在概率分布这个,高斯分布公式中,u代表平均值,然后西格玛代表标准差,知道了

    2024年02月07日
    浏览(68)
  • 人工智能学习与实训笔记(一):零基础理解神经网络

    人工智能专栏文章汇总:人工智能学习专栏文章汇总-CSDN博客 本篇目录 一、什么是神经网络模型 二、机器学习的类型 2.1 监督学习 2.2 无监督学习 2.3 半监督学习 2.4 强化学习 三、网络模型结构基础 3.1 单层网络 ​编辑 3.2 多层网络 3.3 非线性多层网络  四、 神经网络解决回

    2024年02月20日
    浏览(41)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包