U-Net网络结构解析和代码解析

这篇具有很好参考价值的文章主要介绍了U-Net网络结构解析和代码解析。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

U-Net网络结构详解

在语义分割领域,基于深度学习的语义分割算法开山之作是FCN(Fully Convolutional Networks for Semantic Segmentation),而U-Net是遵循FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。U-Net网络在医疗影像领域的应用十分广泛,成为了大多数医疗影像语义分割任务的baseline,同时基于U-Net网络改进网络也纷纷出现,本篇文章主要介绍U-NET网络。

由于医学图像往往包含噪声且边界模糊,仅靠低层次的图像特征难以进行目标检测。同时,由于缺乏图像的细节信息,仅靠图像语义特征无法得到准确的边界。而U-Net通过跳跃连接,将低分辨率和高分辨率的特征映射结合起来,有效地融合了低层次和高层次的图像特征,从而成为医学图像分割任务的一个理想解决方案。目前,U-Net已经成为了大多数医学图像分割任务的一个基准,并且激发了很多有意义的改进方法,其网络结构下图所示。

U-Net网络结构解析和代码解析

U-Net是一个全卷积神经网络,输入与输出都是图像,没有全连接层;并且由图可知,U-Net在宏观上是一个对称的网络结构,左侧为下采样,右侧为上采样,同时按照功能可以将左侧的一系列下采样操作称为encoder,将右侧的一系列上采样操作称为decoder,因此U-Net网络可以划分到Encoder-decoder基础模型类型中;该网络最主要的两个特点是:U型网络结构和Skip Connection跳层连接。

Skip Connection跳层连接中间四条灰色的箭头copy and crop,Skip Connection是在上采样的过程中,融合下采样过过程中的feature map。

Skip Connection跳层连接用到的融合的操作也很简单,就是将feature map的通道进行叠加,俗称Concat。例如,一个大小为256×256×64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64;和一个大小为256×256×32的feature map进行Concat融合,就会得到一个大小为256×256×96的feature map。

在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256×256×64的feature map和240×240×32的feature map进行Concat。解决这个问题有两种办法:

  • 第一种:将大256×256×64的feature map进行裁剪,裁剪为240×240×64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到240×240×96的feature map。

  • 第二种:将小240×240×32的feature map进行padding操作,padding为256×256×32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到256×256×96的feature map。

U-Net网络核心思想:

  • 不含全连接层(fc)的全卷积(fully conv)网络。可适应任意尺寸输入。
  • 增大数据尺寸的反卷积(deconv)层。能够输出精细的结果。
  • 结合不同深度层结果的跳级(skip)结构。同时确保鲁棒性和精确性。

这里使用1×1的卷积替代全连接层还有一个好处:输入的图片形状不再固定了。由于全连接层的输入必须固定形状的,所以输入模型的图片一般都要先resize到固定的shape,而使用1×1卷积代替全连接层之后变不在存在这一问题。在推理的时候,不需要再对图片进行resize,从而最好可能会导致输出的图片的失真。

这么一个不断加深网络并不断增加通道数来提取浅层信息和深层特征的过程就是编码器 (Encoder)

U-Net未能解决的一些问题:

  • 组织器官的顶层截面和底层截面与中部截面差异过大而不易识别;
  • 不同扫描影像之间有较大的外观变异而不易识别;
  • 磁场不均匀引起的伪影和畸变,导致不易识别。

U-Net网络架构实现代码解析

将U-Net网络中的架构分解为四个模块:

  1. 输入层的DoubleConv模块;
  2. 左侧分支从第二层开始的max_pool+DoubleConv,称为Down模块;
  3. 右侧分支的up_conv+copy_crop+DoubleConv,称为Up模块;
  4. 输出层的1x1卷积,称为OutConv模块。

U-Net网络结构解析和代码解析

从上图可以看到,Unet网络的结构比较简单,左侧分支每一层包含两个重复的卷积,命名为DoubleConv;从第二层开始,都是max pool + DoubleConv;右侧分支每一层都是up conv + copy crop + DoubleConv;在最后输出层,有一个1x1 conv。

1. 模块实现

1.1 DoubleConv模块

DoubleConv模块由两个“Conv2d+NatchNorm2d+ReLU”组成:

# unet_parts.py
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
 
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
 
    def forward(self, x):
        return self.double_conv(x)

1.2 Down模块

Down模块由一个“MaxPool2d+DoubleConv”组成:

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
 
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
 
    def forward(self, x):
        return self.maxpool_conv(x)

1.3 Up模块

右侧上行模块涉及到copy and crop,实现起来会略微复杂一些。首先经过一个上采样或转置卷积,然后从左侧路径的同一层feature map中截取相同的size(从图中很容易可以看出,左侧同一层中的feature map比右侧的size要大一些),与右侧feature map合并,最后再进行DoubleConv。代码如下:

class Up(nn.Module):
    """Upscaling then double conv"""
 
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
 
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
 
        self.conv = DoubleConv(in_channels, out_channels)
 
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
 
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
 
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
  • Upsample通过插值方法完成上采样,不需要训练参数。
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

其中mode为选择的插值方法,双线性插值如下图所示:

U-Net网络结构解析和代码解析

已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

  • ConvTranspose2d可以理解为卷积的逆过程,可以训练参数。
nn.ConvTranspose2d(mid_channels, mid_channels, kernel_size=4, stride=2, padding=1)

其中输出尺寸与输入关系如下,所以,k=4, s=2, p=1即2倍上采样。

o u t p u t = ( i n p u t − 1 ) ∗ s t r i d e + o u t p u t p a d d i n g − 2 ∗ p a d d i n g + k e r n e l s i z e output = (input-1)*stride+outputpadding-2*padding+kernelsize output=(input1)stride+outputpadding2padding+kernelsize

具体执行过程为通过对原图插值0,扩大尺寸,然后改变卷积参数,对扩大尺寸后的进行卷积即nn.ConvTranspose2d:

  1. 原图插值,在两两元素之间插0;
  2. 改变参数。新的卷积核:Stride′=1, kernel的size′ = size padding’ 为Size−padding−1;
  3. 卷积。

U-Net网络结构解析和代码解析

1.4 OutConv模块

输出层中用1x1卷积实现:

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
 
    def forward(self, x):
        return self.conv(x)

2. 整体架构

# unet_model.py
import torch.nn.functional as F
from .unet_parts import *

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
 
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)
 
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
 
if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

U-Net网络案例:分割细胞边缘

1. 数据加载

基于模板进行数据加载:  
# ================================================================== #
#                Input pipeline for custom dataset                 #
# ================================================================== #
 
# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0 
 
# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

使用上面的标准模板,进行加载数据、定义标签、数据增强等操作。

# dataset.py
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random
 
class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))
 
    def augment(self, image, flipCode):
        # 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
        flip = cv2.flip(image, flipCode)
        return flip
        
    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        # 根据image_path生成label_path
        label_path = image_path.replace('image', 'label')
        # 读取训练图片和标签图片
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # 将数据转为单通道的图片
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # 处理标签,将像素值为255的改为1
        if label.max() > 1:
            label = label / 255
        # 随机进行数据增强,为2时不做处理
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        return image, label
 
    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)
 
    
if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("数据个数:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2, 
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)

__init__函数是类的初始化函数,根据指定的图片路径data_path,读取所有图片数据,存放到self.imgs_path列表中。

__len__函数返回数据的数量,这个类实例化后,通过len()函数调用。

__getitem__函数是数据获取函数,函数里定义数据读取方式,处理方式,同时数据预处理、数据增强等也在该函数中进行定义。该案例中首先读取图片,并将其处理成单通道图片;将 label 的图片像素点的范围从[0, 255]归一化为[0, 1];最后随机进行了数据增强。

augment函数是定义数据增强函数,案例中进行的是旋转操作,除此之外还可以定义多种其他数据增强函数。

在这个类中,不用进行一些打乱数据集的操作,也不用管怎么按照 batchsize 读取数据。实例化这个类后,用 torch.utils.data.DataLoader 方法指定 batchsize 的大小,决定是否打乱数据。

2. 模型选择

这里使用的第二部分的U-Net网络结构,不再详细解释。

3. 算法选择

Loss函数的选择会对算法拟合数据效果产生较大的影响,分割细胞边缘是一个简单的二分类任务,所以选择使用BCEWithLogitsLoss。BCEWithLogitsLoss 是 Pytorch 提供的用来计算二分类交叉熵的函数:

l o s s ( o , t ) = − 1 / n ∑ i ( t [ i ] ∗ l o g ( o [ i ] ) + ( 1 − t [ i ] ) ∗ l o g ( 1 − o [ i ] ) ) loss(o, t) = -1/n\sum_i(t[i]*log(o[i]) + (1-t[i])*log(1-o[i])) loss(o,t)=1/ni(t[i]log(o[i])+(1t[i])log(1o[i]))

这是 Logistic 回归的损失函数,该函数利用 Sigmoid 函数阈值在[0,1]这个特性来进行分类。

优化目标的算法选择了一种基于AdaGrad改进的自适应的优化算法:RMSProp。优化算法本质上是梯度下降,例如:SGD(随机梯度下降算法)、Momentum(引入了动量的 SGD,以指数衰减的形式累计历史梯度)。而自适应参数的优化算法最大的特点是每个参数有不同的学习率,在整个学习过程中自动适应这些学习率,从而达到更好的收敛效果。

# train.py
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
 
def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # 加载训练集
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size, 
                                               shuffle=True)
    # 定义RMSprop算法
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # 定义Loss算法
    criterion = nn.BCEWithLogitsLoss()
    # best_loss统计,初始化为正无穷
    best_loss = float('inf')
    # 训练epochs次
    for epoch in range(epochs):
        # 训练模式
        net.train()
        # 按照batch_size开始训练
        for image, label in train_loader:
            optimizer.zero_grad()
            # 将数据拷贝到device中
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # 使用网络参数,输出预测结果
            pred = net(image)
            # 计算loss
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            # 保存loss值最小的网络参数
            if loss < best_loss:
                best_loss = loss
                torch.save(net.state_dict(), 'best_model.pth')
            # 更新参数
            loss.backward()
            optimizer.step()
 
if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道1,分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 指定训练集地址,开始训练
    data_path = "data/train/"
    train_net(net, device, data_path)

4. 预测

import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet
 
if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道,分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # 测试模式
    net.eval()
    # 读取所有图片路径
    tests_path = glob.glob('data/test/*.png')
    # 遍历所有图片
    for test_path in tests_path:
        # 保存结果地址
        save_res_path = test_path.split('.')[0] + '_res.png'
        # 读取图片
        img = cv2.imread(test_path)
        # 转为灰度图
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # 转为batch为1,通道为1,大小为512*512的数组
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # 转为tensor
        img_tensor = torch.from_numpy(img)
        # 将tensor拷贝到device中,只用cpu就是拷贝到cpu中,用cuda就是拷贝到cuda中。
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # 预测
        pred = net(img_tensor)
        # 提取结果
        pred = np.array(pred.data.cpu()[0])[0]
        # 处理结果
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # 保存图片
        cv2.imwrite(save_res_path, pred)

输出结果:

U-Net网络结构解析和代码解析

整体目录:

├── data
│ ├── test
│ │ ├──……
│ │ └── *.png
│ └── train
│ ├──……
│ └── *.png
├── model
│ ├── unet_model.py
│ └── unet_parts.py
├── utils
│ └── dataset.py
├── train.py
└── predict.py文章来源地址https://www.toymoban.com/news/detail-450993.html

到了这里,关于U-Net网络结构解析和代码解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 网络安全等级保护通用三级系统整体拓扑结构分值区间解析

    1.等保2.0 三级信息系统 70-80分 拓扑图:   2.设备清单:下一代防火墙(含IPS、AV)+综合日志审计系统+堡垒机+数据库审计系统+杀毒软件。 其他参考方案: 【接入边界NGFW】【必配】:融合防火墙安全策略、访问控制功能。解决安全区域边界要求,并开启AV模块功能;配置网络

    2024年02月06日
    浏览(42)
  • 卷积神经网络模型之——AlexNet网络结构与代码实现

    AlexNet原文地址:https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf AlexNet诞生于2012年,由2012年ImageNet竞赛冠军获得者Hinton和他的学生Alex Krizhevsky设计的。 AlexNet的贡献点: 首次使用GPU加速网络训练 使用ReLU激活函数,代替不是传统的Sigmoid和Tanh,解决了Sigmo

    2024年02月08日
    浏览(43)
  • 计算机网络——第一章体系结构相关习题及详细解析

    在OSI参考模型中,自下而上第一个提供端到端服务的层次是: A.数据链路层        B.传输层        C.会话层        D.应用层 答案选择: B.传输层 即, 在OSI参考模型中,自下而上第一个提供端到端服务的层次是传输层。  解析 为了解决这道题,我们首先要了解OSI体系结构

    2024年02月08日
    浏览(48)
  • YOLO系列 --- YOLOV7算法(四):YOLO V7算法网络结构解析

    今天来讲讲YOLO V7算法网络结构吧~ 在 train.py 中大概95行的地方开始创建网络,如下图(YOLO V7下载的时间不同,可能代码有少许的改动,所以行数跟我不一定一样) 我们进去发现,其实就是在 yolo.py 里面。后期,我们就会发现相关的网络结构都是在该py文件里面。这篇blog就主

    2024年02月05日
    浏览(46)
  • YOLOv8详解 【网络结构+代码+实操】

    🚀🚀🚀 目标检测——Yolo系列(YOLOv1/2/v3/4/5/x/6/7) ✨✨✨ YOLOv8改进——引入可变形卷积DCNv3 YOLOv8是目前YOLO系列算法中最新推出的检测算法,YOLOv8可以完成检测、分类、分割任务,其检测和分割网络结构图如下。 YOLOv8 算法的核心特性和改动可以归结为如下: 提供了一个全新

    2024年02月01日
    浏览(54)
  • YOLO v5 代码精读(3)YOLO网络结构

    YOLO模型共有五种模型规格,规格越大的模型准确率越高,相应的预测时间也就越长。一般默认选择YOLOv5s,也可根据需求选择更大或更小的模型。 这里以YOLO v5s为例,分析YOLO的网络结构。 配置变量 nc:表示检测的类别数量,这里默认取自coco数据集的80个类别 depth_multiple:控制

    2023年04月15日
    浏览(39)
  • U-Net网络

    U-Net 融合了 编码 - 解码结构和跳跃网络 的特点,在模型结构上更加巧妙,主要体现在以下两点: ● ( 1 ) U-Net 模型是一个 编码 - 解码的结构 ,压缩通道是一个编码器,用于逐层提取影像的特征,扩展通道是一个解码器,用于还原影像的位置信息,且 U-Net 模型的每一个隐

    2024年02月02日
    浏览(43)
  • 简单有趣的轻量级网络 Shufflenet v1 、Shufflenet v2(网络结构详解+详细注释代码+核心思想讲解)——pytorch实现

         这期博客咱们来学习一下Shufflenet系列轻量级卷积神经网络,Shufflenet v1 、Shufflenet v2。 本博客代码可以直接生成训练集和测试集的损失和准确率的折线图,便于写论文使用。 论文下载链接: Shufflene系列轻量级卷积神经网络由旷世提出,也是非常有趣的轻量级卷积神经网

    2024年02月01日
    浏览(48)
  • U-net网络详解

    论文地址:https://arxiv.org/abs/1505.04597 学习视频:U-Net网络结构讲解(语义分割)_哔哩哔哩_bilibili 如下图,U-net结构为Encoder-Decoder,左边为Encoder部分,作用是下采样,右边为Decoder部分,作用为上采样 在图中每一个长条代表一个特征层,每一个箭头对应于一种操作 第一步 如下图,

    2024年02月02日
    浏览(36)
  • U-net网络学习记录

    本质上是一个用于图像分割的神经网络 输入是一幅图,输出是目标的分割结果。继续简化就是,一幅图,编码,或者说降采样,然后解码,也就是升采样,然后输出一个分割结果。根据结果和真实分割的差异,反向传播来训练这个分割网络 既然输入和输出都是相同大小的图

    2024年02月09日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包