[pytorch] 3D U-Net with a ResNet Encoder


This post shows how to implement a 3D version of U-Net and how to replace the original U-Net encoder with a ResNet.
Original U-Net implementation: U-Net (Convolutional Networks for Biomedical Image Segmentation)
ResNet implementation: adapted from [pytorch] 2D + 3D ResNet代码实现

It helps to have a basic understanding of both architectures first; if you just want working code, jump straight to the complete UNet_3d_resnet_encoder implementation in Section 3.

1. U-Net

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable
import math
from functools import partial

1.1 2D U-Net



class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )


class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels)
        )


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, H, W]
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        # padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )


class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)

        return logits

Usage example

model = UNet(in_channels= 1,
             num_classes= 2,
             bilinear= True,
             base_c= 64)
x=torch.randn(1,1,224,224)
X=model(x)
print('model output shape =')
print(X.shape)

model output shape =
torch.Size([1, 2, 224, 224])
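
Since out_conv returns raw logits with num_classes channels, they can be fed straight into nn.CrossEntropyLoss against an integer label map. A minimal training-step sketch (the dummy target, optimizer and learning rate below are my own assumptions, not part of the original post):

# Sketch of one training step; `model` and `x` are from the example above.
criterion = nn.CrossEntropyLoss()                 # expects logits [N, C, H, W] and labels [N, H, W]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

target = torch.randint(0, 2, (1, 224, 224))       # dummy integer label map for num_classes=2
logits = model(x)
loss = criterion(logits, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('loss =', loss.item())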

1.2 3D U-Net

Every layer is swapped for its 3D counterpart; in addition, the bilinear upsampling branch is changed to trilinear interpolation.

class DoubleConv_3d(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv_3d, self).__init__(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )


class Down_3d(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down_3d, self).__init__(
            nn.MaxPool3d(2, stride=2),
            DoubleConv_3d(in_channels, out_channels)
        )


class Up_3d(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up_3d, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv_3d(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv_3d(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, D, H, W]
        diff_d = x2.size()[2] - x1.size()[2]
        diff_h = x2.size()[3] - x1.size()[3]
        diff_w = x2.size()[4] - x1.size()[4]

        # F.pad pads the last dimension first:
        # (W_left, W_right, H_left, H_right, D_left, D_right)
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2,
                        diff_d // 2, diff_d - diff_d // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv_3d(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv_3d, self).__init__(
            nn.Conv3d(in_channels, num_classes, kernel_size=1)
        )
        
class UNet_3d(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet_3d, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv_3d(in_channels, base_c)
        self.down1 = Down_3d(base_c, base_c * 2)
        self.down2 = Down_3d(base_c * 2, base_c * 4)
        self.down3 = Down_3d(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down_3d(base_c * 8, base_c * 16 // factor)
        self.up1 = Up_3d(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up_3d(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up_3d(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up_3d(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv_3d(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.in_conv(x)
        print('x1=',x1.shape)
        x2 = self.down1(x1)
        print('x2=',x2.shape)
        x3 = self.down2(x2)
        print('x3=',x3.shape)
        x4 = self.down3(x3)
        print('x4=',x4.shape)
        x5 = self.down4(x4)
        print('x5=',x5.shape)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)

        return logits

Examples. First, with bilinear=False (upsampling via transposed convolutions):

model = UNet_3d(in_channels= 1,
             num_classes= 2,
             bilinear= False,
             base_c= 64)
x=torch.randn(1,1,64,64,64)
X=model(x)
print('model output shape =')
print(X.shape)


x1= torch.Size([1, 64, 64, 64, 64])
x2= torch.Size([1, 128, 32, 32, 32])
x3= torch.Size([1, 256, 16, 16, 16])
x4= torch.Size([1, 512, 8, 8, 8])
x5= torch.Size([1, 1024, 4, 4, 4])
model output shape =
torch.Size([1, 2, 64, 64, 64])

With bilinear=True (trilinear upsampling), the bottleneck x5 ends up with 512 channels instead of 1024:

model = UNet_3d(in_channels= 1,
             num_classes= 2,
             bilinear= True,
             base_c= 64)
x=torch.randn(1,1,64,64,64)
X=model(x)
print('model output shape =')
print(X.shape)

x1= torch.Size([1, 64, 64, 64, 64])
x2= torch.Size([1, 128, 32, 32, 32])
x3= torch.Size([1, 256, 16, 16, 16])
x4= torch.Size([1, 512, 8, 8, 8])
x5= torch.Size([1, 512, 4, 4, 4])
model output shape =
torch.Size([1, 2, 64, 64, 64])

I recommend the bilinear=False version here, where upsampling is done with transposed convolutions.
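
To see the difference between the two Up_3d upsampling choices, here is a quick shape check (a sketch using the bottleneck size from the bilinear=False example above): the transposed convolution halves the channel count while doubling the spatial size, whereas trilinear interpolation keeps the channels unchanged and leaves the reduction to DoubleConv_3d.

x5 = torch.randn(1, 1024, 4, 4, 4)   # bottleneck from the bilinear=False example
deconv = nn.ConvTranspose3d(1024, 512, kernel_size=2, stride=2)
interp = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
print(deconv(x5).shape)   # torch.Size([1, 512, 8, 8, 8])   channels halved while upsampling
print(interp(x5).shape)   # torch.Size([1, 1024, 8, 8, 8])  channels unchanged; DoubleConv_3d must reduce them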

2. ResNet

You cannot use the stock ResNet to replace the U-Net encoder directly: its convolution settings are different, so each block's output size does not match the corresponding U-Net encoder stage. The goal is to modify the ResNet so that the outputs of its blocks have the same sizes as the U-Net encoder's. Two modifications are needed (a quick arithmetic check follows the list below):

    1. Change the stride of the first convolution from (2, 2, 2) to (1, 1, 1).
    2. Change the Bottleneck expansion from 4 to 2.
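
As a back-of-the-envelope check (a sketch, assuming a 64x64x64 input and base_c = 64), the per-stage (channels, spatial size) pairs of the modified ResNet line up with the 3D U-Net encoder from Section 1.2:

# Sketch: stage-by-stage (channels, spatial size) for a 64^3 input with base_c = 64.
base_c, size = 64, 64
expansion = 2                                               # modified Bottleneck expansion (originally 4)
unet = [(base_c * 2 ** i, 64 // 2 ** i) for i in range(5)]  # in_conv, down1..down4 (bilinear=False)
resnet = [(64, size)]                                       # conv1 with stride 1 (originally stride 2)
for planes in (64, 128, 256, 512):                          # layer1 .. layer4
    size //= 2                                              # maxpool before layer1; stride-2 in layer2-4
    resnet.append((planes * expansion, size))
print(unet == resnet)  # True: [(64, 64), (128, 32), (256, 16), (512, 8), (1024, 4)]

The modified ResNet code: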
def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        dilation=dilation,
        stride=stride,
        padding=dilation,
        bias=False)


def downsample_basic_block(x, planes, stride, no_cuda=False):
    # 'A'-type shortcut: average-pool to the target spatial size,
    # then zero-pad the extra output channels.
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.zeros(out.size(0), planes - out.size(1),
                            out.size(2), out.size(3), out.size(4),
                            dtype=out.dtype, device=out.device)
    # no_cuda is kept for API compatibility; the pads follow x's device automatically.
    out = torch.cat([out, zero_pads], dim=1)

    return out


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 2  # 4 -> 2

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 2, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 2)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet_3d(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 shortcut_type='B',
                 no_cuda = False,
                 include_top=True):
        super(ResNet_3d, self).__init__()
        self.inplanes = 64
        self.no_cuda = no_cuda
        self.include_top = include_top
        
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(1, 1, 1),  # changed from (2, 2, 2) to (1, 1, 1)
            padding=(3, 3, 3),
            bias=False)
            
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=2)
        
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))  # adaptive average pooling to output size (1, 1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                


    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    no_cuda=self.no_cuda)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        print('x1=',x.shape)
        x = self.maxpool(x)
        x = self.layer1(x)
        print('x2=',x.shape)
        x = self.layer2(x)
        print('x3=',x.shape)
        x = self.layer3(x)
        print('x4=',x.shape)
        x = self.layer4(x)
        print('x5=',x.shape)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


Example

resnet50_3d = ResNet_3d(Bottleneck, [3, 4, 6, 3],shortcut_type='B',no_cuda=False,num_classes=3,include_top=True)
x=torch.randn(1,1,64,64,64)
X=resnet50_3d(x)
print(X.shape)

x1= torch.Size([1, 64, 64, 64, 64])
x2= torch.Size([1, 128, 32, 32, 32])
x3= torch.Size([1, 256, 16, 16, 16])
x4= torch.Size([1, 512, 8, 8, 8])
x5= torch.Size([1, 1024, 4, 4, 4])
torch.Size([1, 3])
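
For convenience, the depths used later in this post can be wrapped in small constructor helpers (the function names are my own, but the layer configurations are the standard ResNet50/101/152 ones):

# Hypothetical convenience constructors for the modified 3D ResNet.
def resnet50_3d(**kwargs):
    return ResNet_3d(Bottleneck, [3, 4, 6, 3], **kwargs)

def resnet101_3d(**kwargs):
    return ResNet_3d(Bottleneck, [3, 4, 23, 3], **kwargs)

def resnet152_3d(**kwargs):
    return ResNet_3d(Bottleneck, [3, 8, 36, 3], **kwargs)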

3. UNet_3d_resnet_encoder

The complete code is below and can be used directly.

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable
import math
from functools import partial


def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        dilation=dilation,
        stride=stride,
        padding=dilation,
        bias=False)


def downsample_basic_block(x, planes, stride, no_cuda=False):
    # 'A'-type shortcut: average-pool to the target spatial size,
    # then zero-pad the extra output channels.
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.zeros(out.size(0), planes - out.size(1),
                            out.size(2), out.size(3), out.size(4),
                            dtype=out.dtype, device=out.device)
    # no_cuda is kept for API compatibility; the pads follow x's device automatically.
    out = torch.cat([out, zero_pads], dim=1)

    return out


class Bottleneck(nn.Module):
    expansion = 2  # 4 -> 2

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 2, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 2)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class DoubleConv_3d(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv_3d, self).__init__(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )


class Down_3d(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down_3d, self).__init__(
            nn.MaxPool3d(2, stride=2),
            DoubleConv_3d(in_channels, out_channels)
        )


class Up_3d(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up_3d, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv_3d(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv_3d(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, D, H, W]
        diff_d = x2.size()[2] - x1.size()[2]
        diff_h = x2.size()[3] - x1.size()[3]
        diff_w = x2.size()[4] - x1.size()[4]

        # F.pad pads the last dimension first:
        # (W_left, W_right, H_left, H_right, D_left, D_right)
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2,
                        diff_d // 2, diff_d - diff_d // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv_3d(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv_3d, self).__init__(
            nn.Conv3d(in_channels, num_classes, kernel_size=1)
        )
    
    

class UNet_3d_resnet_encoder(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64,
                 shortcut_type='B',
                 no_cuda = False):
        super(UNet_3d_resnet_encoder, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear
        
        self.inplanes = 64
        self.no_cuda = no_cuda

        # The original UNet_3d encoder (in_conv / down1..down4) is replaced by the ResNet stages below.
        #self.in_conv = DoubleConv_3d(in_channels, base_c)
        #self.down1 = Down_3d(base_c, base_c * 2)
        #self.down2 = Down_3d(base_c * 2, base_c * 4)
        #self.down3 = Down_3d(base_c * 4, base_c * 8)
        #self.down4 = Down_3d(base_c * 8, base_c * 16 // factor)
        # unet decoder
        factor = 2 if bilinear else 1
        self.up1 = Up_3d(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up_3d(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up_3d(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up_3d(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv_3d(base_c, num_classes)
        
        # resnet
        
        self.conv1 = nn.Conv3d(
            in_channels,
            64,
            kernel_size=7,
            stride=(1, 1, 1),  # changed from (2, 2, 2) to (1, 1, 1)
            padding=(3, 3, 3),
            bias=False)
            
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=2)


        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                


    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    no_cuda=self.no_cuda)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x1 = self.relu(x)
        #print('x1=',x1.shape)
        x2 = self.maxpool(x1)
        x2 = self.layer1(x2)
        #print('x2=',x2.shape)
        x3 = self.layer2(x2)
        #print('x3=',x3.shape)
        x4 = self.layer3(x3)
        #print('x4=',x4.shape)
        x5 = self.layer4(x4)
        #print('x5=',x5.shape)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)

        return logits

Example 1: ResNet50 as the encoder, input size (1, 64, 64, 64).

model = UNet_3d_resnet_encoder(block = Bottleneck, # only Bottleneck-based ResNets (ResNet50 and deeper) are supported, not ResNet18/34
             layers = [3, 4, 6, 3],
             shortcut_type='B',
             no_cuda=False,
             in_channels= 1,
             num_classes= 2,
             bilinear= False, # bilinear=True is not supported yet
             base_c= 64)
x=torch.randn(1,1,64,64,64)
X=model(x)
print('model output shape =')
print(X.shape)

model output shape =
torch.Size([1, 2, 64, 64, 64])

With an input of size (3, 128, 128, 64):

model = UNet_3d_resnet_encoder(block = Bottleneck, # only Bottleneck-based ResNets (ResNet50 and deeper) are supported, not ResNet18/34
             layers = [3, 4, 6, 3],
             shortcut_type='B',
             no_cuda=False,
             in_channels= 3,
             num_classes= 2,
             bilinear= False, # bilinear=True is not supported yet
             base_c= 64)
x=torch.randn(1,3,128,128,64)
X=model(x)
print('model output shape =')
print(X.shape)

model output shape =
torch.Size([1, 2, 128, 128, 64])

Example 2: ResNet101 as the encoder.

model = UNet_3d_resnet_encoder(block = Bottleneck, # only Bottleneck-based ResNets (ResNet50 and deeper) are supported, not ResNet18/34
             layers = [3, 4, 23, 3],
             shortcut_type='B',
             no_cuda=False,
             in_channels= 1,
             num_classes= 2,
             bilinear= False, # bilinear=True is not supported yet
             base_c= 64)
x=torch.randn(1,1,64,64,64)
X=model(x)
print('model output shape =')
print(X.shape)

model output shape =
torch.Size([1, 2, 64, 64, 64])

Example 3: ResNet152 as the encoder.

model = UNet_3d_resnet_encoder(block = Bottleneck, # only Bottleneck-based ResNets (ResNet50 and deeper) are supported, not ResNet18/34
             layers = [3, 8, 36, 3],
             shortcut_type='B',
             no_cuda=False,
             in_channels= 1,
             num_classes= 2,
             bilinear= False, # bilinear=True is not supported yet
             base_c= 64)
x=torch.randn(1,1,64,64,64)
X=model(x)
print('model output shape =')
print(X.shape)

model output shape =
torch.Size([1, 2, 64, 64, 64])
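
If you want a rough sense of how much heavier the deeper encoders are, a simple parameter count works (a sketch; count_params is a hypothetical helper, not from the original post):

# Sketch: compare trainable parameter counts of the three encoder depths.
def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

for name, layers in [('ResNet50', [3, 4, 6, 3]),
                     ('ResNet101', [3, 4, 23, 3]),
                     ('ResNet152', [3, 8, 36, 3])]:
    m = UNet_3d_resnet_encoder(block=Bottleneck, layers=layers,
                               in_channels=1, num_classes=2,
                               bilinear=False, base_c=64)
    print(name, count_params(m))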

That concludes this walkthrough of the 3D U-Net with a ResNet encoder.
