【代码复现】TransUNet代码实现流程

这篇具有很好参考价值的文章主要介绍了【代码复现】TransUNet代码实现流程。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。


摘要:从图像输入模型开始,梳理TransUNet模型流程。

注:这里n_patches原代码设置196,但自己在分割硬渗出物的时候设置成了1024,不知道会不会使得效果变差

一. class VisionTransformer()

图片开始进入模型所在代码行:

outputs = model(image_batch)  # (B,n_classes,H,W)

然后进入类VisionTransformer(nn.Moudle)

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:   # 如果图片是灰度图就在其通道方向进行复制从1维转成3维(比如CT图像就是灰度图)
            x = x.repeat(1,3,1,1)   # (B,3,H,W) 

        # 然后将x送入self.trasnformer中,实现在1节
        x, attn_weights, features = self.transformer(x) 
         # (B, n_patch, hidden):(B,196,768) # features:(B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

1. self.transformer()

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):   # (B,3,H,W)

        # 将x送入self.embeddings中,实现在1.1节
        embedding_output, features = self.embeddings(input_ids) # (B, 1024, 768) # (B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patchs, hidden)
        return encoded, attn_weights, features

这部分实现:
features:表示CNN支路中的3个特征图;embedding_output:表示Transformer支路的输入;self.encoder:即Transformer支路的实现;
encoded:表示

1.1. self.embeddings()

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:   # ResNet
            grid_size = config.patches["grid"]  # (14,14)
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])  # (1,1)
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)  # (16,16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])  # 14*14
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) # (1,n_patches, hidden): (1,196,768)

        self.dropout = Dropout(config.transformer["dropout_rate"])


    def forward(self, x):
        if self.hybrid:      # True

            # 将x送入self.hybrid_model中,实现在1.1.1节
            x, features = self.hybrid_model(x) # (B,3,H,W) -> (B,1024,H/16,W/16)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden. n_patches^(1/2), n_patches^(1/2)) = (B,768,H/16,W/16)
        x = x.flatten(2)   # (B, hidden, n_patches)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden) = (B, 1024, 768)

        embeddings = x + self.position_embeddings # (B, 1024, 768)
        embeddings = self.dropout(embeddings)
        return embeddings, features

1.1.1. self.hybrid_model()

class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode."""

    def __init__(self, block_units, width_factor):  # (3,4,9); 1
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
                ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
                ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
                ))),
        ]))

    def forward(self, x):  # (B,3,H,W)
        features = []
        b, c, in_size, _ = x.size()    # B,3,H
        x = self.root(x)   # (B,3,H,W) -> (B,64,H/2,W/2)
        features.append(x) # (B,64,H/2,W/2)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)   # (B,64,H/4-1,W/4-1)
        for i in range(len(self.body)-1): # i:       0         ;      1
            x = self.body[i](x)           # (B,256,H/4-1,W/4-1);(B,512,H/8,W/8)
            right_size = int(in_size / 4 / (i+1))  # 56,28

            # 进行补0操作,将特征图size补成right_size
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)  # (B,256,H/4,W/4)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x

            # 这三个特征图就是CNNs分支里的
            features.append(feat) # (B,64,H/2,W/2);(B,256,H/4,W/4);(B,512,H/8,W/8)
        x = self.body[-1](x)      # (B,1024,H/16,W/16)
        return x, features[::-1]  #对features逆序排序

到这里,这部分代码就走完了,接下来将x,features两个参数传入类Embeddings,继续1.1小节

1.1.(2)

为了简洁,这里就只显示前向传播forward过程

    def forward(self, x):
        if self.hybrid:      # True

            # 将x送入self.hybrid_model中,实现在1.1.1节
            x, features = self.hybrid_model(x) # (B,3,H,W) -> (B,1024,H/16,W/16)
            # x:(B,1024,H/16,W/16)
            # features[0]:(B,512,H/8,W/8)
            # features[1]:(B,256,H/4,W/4)
            # features[2]:(B,64,H/2,W/2)
        else:
            features = None

        # patch_embedding是通过卷积层实现的, channel数变化:1024->768
        x = self.patch_embeddings(x)  # (B, hidden. n_patches^(1/2), n_patches^(1/2)) = (B,768,H/16,W/16)
        
        x = x.flatten(2)   # (B, hidden, n_patches) = (B, 768, H*W/(16*16) )
        x = x.transpose(-1, -2)  # (B, n_patches, hidden) = (B, H*W/(16*16), 768)

        # 加入位置编码(1, H*W/(16*16), 768)
        embeddings = x + self.position_embeddings # (B, H*W/(16*16), 768)
        embeddings = self.dropout(embeddings)
        return embeddings, features      # (B, n_patches, 768); features

    这段代码中的最初的x 1 / 16 1/16 1/16倍的特征图,然后分别经过:①patch_embeddings(x):改变通道数为768;②flatten(2):合并后两个维度;③transpose(-1, -2):交换后两个维度,目的是将张量改成transformer需要的shape形状;④self.position_embeddings:加入位置编码;⑤self.dropout:防止过拟合。
最后,返回的embeddings会送入Transformer Layer

到这里类Embeddings执行结束。下面将embeddings传入Encoder,继续1小节。

1.(2)

前向传播forward过程:

    def forward(self, input_ids):   # (B,3,H,W)

        # 将x送入self.embeddings中,实现在1.1节
        embedding_output, features = self.embeddings(input_ids) # (B, 1024, 768) # (B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)

        # 这里将embedding_output传入self.encoder中,实现在1.2小节
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patchs, hidden):(B, 1024, 768)
        return encoded, attn_weights, features

1.2. self.encoder()

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):  # 12
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    # 这里的hidden_states是上面代码中的embedding_output
    def forward(self, hidden_states): # (B, n_patches, 768):(B, 1024, 768)
        attn_weights = []
        
        '''
        这里的layer_block是Transformer Layer层,共有12层
        '''
        for layer_block in self.layer:
            # 跳转到类Block,实现见1.2.1节
            hidden_states, weights = layer_block(hidden_states)  
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

1.2.1. class Block()

类class BlockTransformer Layer的实现过程,有12个Transformer Layer,所以要执行12次

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x                       # (B, n_patchs, hidden):(B, 1024, 768)
        x = self.attention_norm(x)  # (B, n_patchs, hidden):(B, 1024, 768) 

        # 下面使用注意力模块,实现见1.2.1.1节
        x, weights = self.attn(x) # (B, 1024, 768) 
        x = x + h 

        h = x # (B,196,768)
        x = self.ffn_norm(x) 
        x = self.ffn(x) 
        x = x + h
        return x, weights
1.2.1.1. self.attn(x)

类Attention()

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):      # (B, 1024, 768)

        # 构造新的张量,其形状为x前两维加上(self.num_attention_heads, self.attention_head_size)作为后两维,
        # 其形状为(B, n_patchs, self.num_attention_heads, self.attention_head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (B,1024,12,64)
        x = x.view(*new_x_shape)   # 将 x 变成new_x_shape的形状: (B,1024,12,64)
        return x.permute(0, 2, 1, 3)   # (B,12,1024,64)

    def forward(self, hidden_states):  # (B, n_patchs, hidden):(B, 1024, 768)
 
        # 使用Linear生成q、k、v
        mixed_query_layer = self.query(hidden_states)     # (B, 1024, 768)
        mixed_key_layer = self.key(hidden_states)         # (B, 1024, 768)
        mixed_value_layer = self.value(hidden_states)     # (B, 1024, 768)
       
        # 将q、k、v 变换成需要的形状
        query_layer = self.transpose_for_scores(mixed_query_layer)  # (B,12,1024,64)
        key_layer = self.transpose_for_scores(mixed_key_layer)      # (B,12,1024,64)
        value_layer = self.transpose_for_scores(mixed_value_layer)  # (B,12,1024,64)

        # q 点乘 k的转置,记为:q·k'
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))# (B,12,1024,64)matmul(B,12,64,1024)=(B,12,1024,1024)

        # q·k'/√d  :attention_head_size是注意力机制的维度d
        attention_scores = attention_scores / math.sqrt(self.attention_head_size) # (B,12,1024,1024)

        # softmax(q·k'/√d)
        attention_probs = self.softmax(attention_scores)  # (B,12,1024,1024)

        # self.vis = False
        weights = attention_probs if self.vis else None
        
        # softmax(q·k'/√d)  : 这一步猜测是防止过拟合
        attention_probs = self.attn_dropout(attention_probs)  # (B,12,1024,1024)

        # softmax(q·k'/√d)·v , 其中v:(B,12,1024,64)
        context_layer = torch.matmul(attention_probs, value_layer) # (B,12,1024,64)  #(B,12,1024,1024)matmul(B,12,1024,64)=(B,12,1024,64)
        # 至此, 完成Attention(q,k,v) = softmax(q·k'/√d)·v
        
        # 调整shape:(B,12,1024,64)->(B,1024,12,64)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # (B,1024,12,64)

        # 构造一个新的张量形状(B,1024,768), 将 context_layer 变成其形状
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) # (B,1024,768)
        context_layer = context_layer.view(*new_context_layer_shape) # (B,1024,768)
        
        # 通过一个Linear层和dropout层
        attention_output = self.out(context_layer) # (B,1024,768)
        attention_output = self.proj_dropout(attention_output) # (B,1024,768)

        # 到这, attention_output 为注意力机制的输出, weight为None
        return attention_output, weights          # (B,1024,768)
        # 下面跳转到 1.2.1.小节的 class Block()

1.2.1.(2)

class Block()
只看前向传播过程,该从x = x+h 开始


    def forward(self, x):
        h = x                       # (B, n_patchs, hidden):(B, 1024, 768)
        x = self.attention_norm(x)  # (B, n_patchs, hidden):(B, 1024, 768) 

        # 下面使用注意力模块,实现见1.2.1.1节
        x, weights = self.attn(x) # (B, 1024, 768) 
        x = x + h                 # (B, 1024, 768)

        h = x                     # (B, 1024, 768)

        # self.ffn_norm是一个LayerNorm层
        x = self.ffn_norm(x)      # (B, 1024, 768)

        # self.ffn = Mlp(), 实现见1.2.1.2小节
        x = self.ffn(x) 
        x = x + h
        return x, weights
1.2.1.2. self.ffn(x)

类Mlp

class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])  # mlp_dim:3072
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):   # (B, 1024, 768)
    
        x = self.fc1(x)         # (B, 1024, 3072)
        x = self.act_fn(x)      # (B, 1024, 3072)
        x = self.dropout(x)     # (B, 1024, 3072)
        
        x = self.fc2(x)         # (B, 1024, 768)
        x = self.dropout(x)     # (B, 1024, 768)
        return x
        # 到这里, Mlp 的过程完成,下面进入1.2.1小节class Block()

1.2.1.(3)

class Block()
只看前向传播过程,该从第二个x = x+h 开始


    def forward(self, x):
        h = x                       # (B, n_patchs, hidden):(B, 1024, 768)
        x = self.attention_norm(x)  # (B, n_patchs, hidden):(B, 1024, 768) 

        # 下面使用注意力模块,实现见1.2.1.1节
        x, weights = self.attn(x) # (B, 1024, 768) 
        x = x + h                 # (B, 1024, 768)

        h = x                     # (B, 1024, 768)

        # self.ffn_norm是一个LayerNorm层
        x = self.ffn_norm(x)      # (B, 1024, 768)

        # self.ffn = Mlp(), 实现见1.2.1.2小节
        x = self.ffn(x)           # (B, 1024, 768)
        x = x + h                 # (B, 1024, 768)
        return x, weights    # x: (B, 1024, 768) ; weights: None
        # 到这 class Block() 的过程结束,下面跳转到 1.2. self.encoder()

1.2.(2) self.encoder()

类Encoder
代码该从if self.vis:开始

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):  # 12
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states): # (B, n_patch, hidden):(B, 1024, 768)
        attn_weights = []
        
        '''
        这里的layer_block是Transformer Layer层,共有12层
        '''
        # 这里self.layer有12层,即循环要走12次class Block()
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)  # (B,1024,768)
            if self.vis:    # False
                attn_weights.append(weights)

        # LayerNorm层, 
        encoded = self.encoder_norm(hidden_states)  # (B,1024,768)
        return encoded, attn_weights  # encoded:(B,1024,768); attn_weights:None
        # 到这 class Encoder() 的过程结束,下面跳转到 1. self.transformer()

1.(3)

类Transformer()
代码该从return encoded, attn_weights, features开始

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):   # (B,3,H,W)

        # 将x送入self.embeddings中,实现在1.1节
        embedding_output, features = self.embeddings(input_ids) # (B, 196, 768) # (B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patchs, hidden)

        # 返回 encoded:(B,1024,768); features: (B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)
        return encoded, attn_weights, features
        # 到这 class Transformer() 的过程结束,下面跳转到 一. class VisionTransformer

一.(2)

类VisionTransformer()
代码该从x = self.decoder(x, features)开始

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:   # 如果图片是灰度图就在其通道方向进行复制从1维转成3维(比如CT图像就是灰度图)
            x = x.repeat(1,3,1,1)   # (B,3,H,W) 

        # 然后将x送入self.trasnformer中,实现在1节
        x, attn_weights, features = self.transformer(x) 
         # (B, n_patchs, hidden):(B,1024,768) # features:(B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)

        # x: (B,1024,768); features: (B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)
        # 将x,features送入DecoderCup(), 实现在2小节
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

2. self.decoder()

类Decoder()

class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels
            for i in range(4-self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3-i]=0    # skip_channels=[512,256,64,0]

        else:
            skip_channels=[0,0,0,0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):  # (B, n_patchs, hidden):(B,1024,768) # features:(B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)

        # reshape from (B, n_patchs, hidden) to (B, h, w, hidden)
        B, n_patch, hidden = hidden_states.size()  
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) # 32,32
        x = hidden_states.permute(0, 2, 1) # (B,768,1024)
        x = x.contiguous().view(B, hidden, h, w) # (B,768,32,32)

        # 将channel变成512, 对应前面三个特征图的通道数512,256,64
        # 通道数256和64在后面会体现到
        x = self.conv_more(x) # (B,512,32,32)    # features:(B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)

        '''
        (1)self.block共有4层, 对应上面对DecoderBlock()的定义, 
           其中zip(in_channels, out_channels, skip_channels)这里面的参数有
           4 组:(512, 256, 512); (256, 128, 256); (128, 64, 64); (64, 16, 0)。
        (2)skip为features的三个特征图: (B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)。
        (3)Decoder_Block()作用: 先对 x 进行上采样,然后将 x 与 skip 进行cat, 再对cat后
           的x进行卷积使其channel变成256, 128, 64。
           但是当i=3时,不在进行skip与x拼接(cat),Decoder_Block()的作用:对x的channel降维到16,
           最后输出的x:(B,16,H,W)=(B,16,512,512)
        '''
        for i, decoder_block in enumerate(self.blocks):  # (512, 256, 512); (256, 128, 256); (128, 64, 64); (64, 16, 0)
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None

            # 将 x 和 skip 送如decoder_block,实现见2.1小节
            x = decoder_block(x, skip=skip)
        return x   # x:(B,16,H,W)=(B,16,512,512)

2.1. decoder_block(x, skip)

类DecoderBlock()

class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None): # (B,512,32,32); (B,512,64,64) 
        # 先对x进行上采样,然后将x与skip进行cat
        x = self.up(x)      # (B,512,64,64) 
        if skip is not None:
            x = torch.cat([x, skip], dim=1)   # (B,1024,64,64) 
        x = self.conv1(x)       # (B,256,64,64) 
        x = self.conv2(x)       # (B,256,64,64)
        return x
        '''
        上面只是i=0的时候实现流程
        这个类根据i的取值会执行4次,0、1、2的时候 x 会和 skip 及进行cat操作,4的时候skip为None,只
        对 x 进行卷积操作,最后x:(B,16,512,512),该特征图是整个模型最后的特征图。
        '''

执行完后会跳到2.小节最后一行return x ,然后从一.节类VisionTransformer中的logits = self.segmentation_head(x)开始。

一.(3)

类VisionTransformer()
代码该从logits = self.segmentation_head(x)开始
为了简洁,只看前向传播:

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)   # (B,3,H,W) = (B,3,512,512)
        x, attn_weights, features = self.transformer(x)  # (B, n_patchs, hidden):(B,1024,768) # features:(B,512,H/8,W/8);(B,256,H/4,W/4);(B,64,H/2,W/2)
        x = self.decoder(x, features)     # (B,16,H,W):(B,16,512,512)

        # 从这跳转到3.小节的class SegmentationHead()
        logits = self.segmentation_head(x)  # (B,2,H,W):(B,2,512,512)
        return logits

3. self.segmentation_head(x)

类SegmentationHead()

class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        
        # 分割头:将 out_channels 设置成n_classes, 这里 out_channels=2
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        
        # upsampling=1 不进行上采样
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)

该段代码执行完后,跳转到一.小节中的return logits,即返回最后的分类特征图,至此模型完毕。

附录. 模型图

最后附上模型图:

transunet复现,模型代码解读,深度学习,pytorch文章来源地址https://www.toymoban.com/news/detail-761647.html

到了这里,关于【代码复现】TransUNet代码实现流程的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【模型解读】【代码复现】Segment Anything Model(SAM)

    Segment Anything Model论文 Segment Anything Model官网 Segment Anything Model官网demo网页端 Segment Anything Model模型源码 SAM应用:生成包含语义信息的遥感图像分割结果 We introduce the Segment Anything (SA) project: a new task, model, and dataset for image segmentation. Using our efficient model in a data collection loop, we bu

    2024年02月08日
    浏览(41)
  • Segment Anything Model(SAM)模型解读及代码复现

    Segment Anything Model论文 Segment Anything Model官网 Segment Anything Model官网demo网页端 Segment Anything Model模型源码 SAM应用:生成包含语义信息的遥感图像分割结果 We introduce the Segment Anything (SA) project: a new task, model, and dataset for image segmentation. Using our efficient model in a data collection loop, we bu

    2024年02月06日
    浏览(51)
  • DEMATEL-ISM模型的Python实现——方法介绍以及代码复现

    本文源于笔者的《系统工程》课程的小组作业,笔者尝试运用DEMATEL-ISM方法来进行分析,建模求解,但在网络上并没有找到相应的,特别是集合DEMATEL-ISM方法的代码。因此自己码了DEMATEL-ISM模型的Python代码,并作为第一个博客发布~ 参考文献中,笔者主要参考了李广利等 1 的研

    2023年04月20日
    浏览(42)
  • 目标检测论文解读复现之十:基于YOLOv5的遥感图像目标检测(代码已复现)

    前言        此前出了目标改进算法专栏,但是对于应用于什么场景,需要什么改进方法对应与自己的应用场景有效果,并且多少改进点能发什么水平的文章,为解决大家的困惑,此系列文章旨在给大家解读最新目标检测算法论文,帮助大家解答疑惑。解读的系列文章,本人

    2024年02月06日
    浏览(43)
  • 【EI复现】梯级水光互补系统最大化可消纳电量期望短期优化调度模型(Matlab代码实现)

    💥💥💞💞 欢迎来到本博客 ❤️❤️💥💥 🏆博主优势: 🌞🌞🌞 博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️ 座右铭: 行百里者,半于九十。 📋📋📋 本文目录如下: 🎁🎁🎁 目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码实现 文献来源

    2024年02月13日
    浏览(56)
  • 【EI复现】考虑区域多能源系统集群协同优化的联合需求侧响应模型(Matlab代码实现)

    💥💥💞💞 欢迎来到本博客 ❤️❤️💥💥 🏆博主优势: 🌞🌞🌞 博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️ 座右铭: 行百里者,半于九十。 📋📋📋 本文目录如下: 🎁🎁🎁 目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码实现 文献来源

    2024年02月13日
    浏览(49)
  • 经典神经网络论文超详细解读(六)——DenseNet学习笔记(翻译+精读+代码复现)

    上一篇我们介绍了ResNet:经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现) ResNet通过短路连接,可以训练出更深的CNN模型,从而实现更高的准确度。今天我们要介绍的是 DenseNet(《Densely connected convolutional networks》) 模型,它的基本

    2024年02月03日
    浏览(62)
  • 经典神经网络论文超详细解读(八)——ResNeXt学习笔记(翻译+精读+代码复现)

    今天我们一起来学习何恺明大神的又一经典之作:  ResNeXt(《Aggregated Residual Transformations for Deep Neural Networks》) 。这个网络可以被解释为 VGG、ResNet 和 Inception 的结合体,它通过重复多个block(如在 VGG 中)块组成,每个block块聚合了多种转换(如 Inception),同时考虑到跨层

    2024年02月03日
    浏览(55)
  • 经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现)

    《Deep Residual Learning for Image Recognition》这篇论文是何恺明等大佬写的,在深度学习领域相当经典,在2016CVPR获得best paper。今天就让我们一起来学习一下吧! 论文原文:https://arxiv.org/abs/1512.03385 前情回顾: 经典神经网络论文超详细解读(一)——AlexNet学习笔记(翻译+精读)

    2024年02月08日
    浏览(47)
  • 3DGS代码复现流程(windows本地、colab)

    环境配置 前提:已下载并配置好Anaconda、CUDA、pytorch、COLMAP等。 已在Anaconda中配置好虚拟环境,pytorch最好安装在虚拟环境中,不要在base中。 可参考:过程挺详细的 在python终端运行如下代码,将官方代码进行下载。 复现流程参考一个外国的大佬写的,但是有些文件的存放在我

    2024年04月17日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包