详解U-Net分割网络,提供详细代码技术细节及完整项目代码

这篇具有很好参考价值的文章主要介绍了详解U-Net分割网络,提供详细代码技术细节及完整项目代码。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一. 原始模型整体概述

U-Net网络是Ronneberger等人在2015年发表于计算机医学影像顶刊 MICCAI上的一篇论文,该论文首次提出了一种U型结构来进行图像的语义分割,论文的下载链接如下:U-Net: Convolutional Networks for Biomedical Image Segmentation

该网络赢得了2015年ISBI细胞跟踪挑战赛,一经提出便获得了巨大的关注,在计算机视觉领域的影响堪比凯明大神的ResNet网络,也正巧是同一年提出的,2015年真是一个AI的爆发之年,再后来便就是2017年的transformer网络提出了。

闲话少说,我们先来看一下U-Net网络的整体架构图,如下图所示:

u-net代码,深度学习2D/3D分割模型解读,网络,cnn,人工智能

该模型主要由两个部分组成,左边是编码器(文章里称作收缩路径,contracting path),右边是解码器( 文章里称作膨胀路径,expansive path)。这种U型结构和跳跃连接可以更好地获取上下文信息和位置信息。

其中编码器的每一行遵守常规卷积神经网络的典型架构,由两个3x3的卷积(无padding操作)构成。每个卷积层的后面跟随一个整流线性单元ReLU和一个2x2的最大池化操作,步幅为2,用于下采样,在每个下采样的过程中将特征通道数翻倍

解码器的每一步都包含特征图的上采样,然后是一个将特征通道数量减半的2x2卷积(“上卷积”),与收缩路径中相应裁剪的特征map进行连接,以及两个3x3卷积,每个卷积后面都有一个ReLU。由于在每次卷积中边界像素的损失,所以需要裁剪(crop)。在最后一层,使用1x1卷积将每个64个分量的特征向量map到所需的类数。这个网络总共有23个卷积层。

U-Net网络的优点及为什么在医学影像中表现优异:

1.通过跳跃连接,补充了语义信息的丢失。深层卷积得到的特征图有着更大的感受野,关注更多抽象的粗粒度的信息;而浅层卷积则关注纹理,颜色,边缘特征等中高级语义信息,所以深层浅层特征都是有意义的;另外一点是通过反卷积得到的更大的尺寸的特征图的边缘,是缺少信息的,毕竟每一次下采样提炼特征的同时,也必然会损失一些特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现特征信息的找回。由于医学影像中的所有特征都很重要,所以跳跃连接起到了关键作用。

2.U-Net模型的结构比较简单,在模型参数较少的情况下获得了更多的语义信息,所以在数据规模较小的医学影像数据集上表现出更好的效果。

二. U-Net模型的改进

通过上述模型架构设计,我们很容易发现这个架构图有一些不便之处,我们可以将模型改成以下结构来方便模型构建并改善模型性能。这也是当前最流行的U-Net构造方法。

1.将卷积层的padding设置为1。由于编码器和解码器中使用的卷积层都没有进行padding操作,卷积层核大小为3x3,stride为1,padding=0,这就是一个很普通的卷积层,卷积之后,特征图的高和宽会各减少2。具体计算步骤:N=(W-F+2P)/S+1,其中N为输出特征图的大小,W为输入特征图大小,F为卷积核大小,P为padding的像素个数,S为stride的大小。则N = (572-3+0)/1 +1= 570。

我们如果将padding设置为1,则卷积操作后特征图的高和宽都不会发生变化,这样的话就免去了拼接时由于尺寸不一致而产生额外的裁剪操作,改进之后可以直接concat,而不用crop操作。此外,这种做法可以保证输入的图像大小核输出的图像大小完全一致,避免了边缘缺失导致的误差。

2.在每个卷积层和ReLu之间加上Batch normalization,这个做可以加速网络训练并提升网络精度。关于BN层的原理和作用,请参见另外一篇优质博文:BN层的原理和作用

三.改进U-Net模型的完整代码 

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义双卷积类
class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            # 定义padding为1,保持卷积后特征图的宽度和高度不变,具体计算N= (W-F+2P)/S+1
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            # 加入BN层,提升训练速度,并提高模型效果
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # 第二个卷积层,同第一个
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # BN层
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

# 调用上面定义的双卷积类,定义下采样类
class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            # 最大池化,步长为2,池化核大小为2,计算公式同卷积,则 N = (W-F+2P)/S+1,  N= (W-2+0)/4 + 1
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels)
        )

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()

        # 调用转置卷积的方法进行上采样,使特征图的高和宽翻倍,out  =(W−1)×S−2×P+F,通道数减半
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # 调用双层卷积类,通道数是否减半要看out_channels接收的值
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)

        # X的shape为[N, C, H, W],下面三行代码主要是为了保证x1和x2在维度为2和3的地方保持一致,方便cat操作不出错。
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
        # 增加padding操作,padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


# 定义输出卷积类
class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )


class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,  # 默认输入图像的通道数为1,这里一般黑白图像为1,而彩色图像为3
                 num_classes: int = 2,  # 默认输出的分类类别数为2
                 # 默认基础通道为64,这里也可以改成大于2的任意2的次幂,不过越大模型的复杂度越高,参数越大,模型的拟合能力也就越强
                 base_c: int = 64):

        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        # 编码器的第1个双卷积层,不包含下采样过程,输入通道为1,输出通道数为base_c,这个值可以为64或者32        
        self.in_conv = DoubleConv(in_channels, base_c)
        # 编码器的第2个双卷积层,先进行下采样最大池化使得特征图高和宽减半,然后通道数翻倍        
        self.down1 = Down(base_c, base_c * 2)
        # 编码器的第3个双卷积层,先进行下采样最大池化使得特征图高和宽减半,然后通道数翻倍       
        self.down2 = Down(base_c * 2, base_c * 4)
        # 编码器的第4个双卷积层,先进行下采样最大池化使得特征图高和宽减半,然后通道数翻倍        
        self.down3 = Down(base_c * 4, base_c * 8)
        # 编码器的第5个双卷积层,先进行下采样最大池化使得特征图高和宽减半,然后通道数翻倍        
        self.down4 = Down(base_c * 8, base_c * 16)

        # 解码器的第1个上采样模块,首先进行一个转置卷积,使特征图的高和宽翻倍,通道数减半;
        # 对x1(x1可以到总的forward函数中可以知道它代指什么)进行padding,使其与x2的尺寸一致,然后在第1维通道维度进行concat,通道数翻倍。
        # 最后再进行一个双卷积层,通道数减半,高和宽不变。       
        self.up1 = Up(base_c * 16, base_c * 8)
        # 解码器的第2个上采样模块,操作同上        
        self.up2 = Up(base_c * 8, base_c * 4)
        # 解码器的第3个上采样模块,操作同上       
        self.up3 = Up(base_c * 4, base_c * 2)
        # 解码器的第4个上采样模块,操作同上        
        self.up4 = Up(base_c * 2, base_c)
        # 解码器的输出卷积模块,改变输出的通道数为分类的类别数        
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        # 假设输入的特征图尺寸为[N, C, H, W],[4, 3, 480, 480],依次代表BatchSize, 通道数量,高,宽;   则输出为[4, 64, 480,480]
        x1 = self.in_conv(x)
        # 输入的特征图尺寸为[4, 64, 480, 480];  输出为[4, 128, 240,240]
        x2 = self.down1(x1)
        # 输入的特征图尺寸为[4, 128, 240,240];  输出为[4, 256, 120,120]
        x3 = self.down2(x2)
        # 输入的特征图尺寸为[4, 256, 120,120];  输出为[4, 512, 60,60]
        x4 = self.down3(x3)
        # 输入的特征图尺寸为[4, 512, 60,60];  输出为[4, 1024, 30,30]
        x5 = self.down4(x4)
        
        # 输入的特征图尺寸为[4, 1024, 30,30];  输出为[4, 512, 60, 60]
        x = self.up1(x5, x4)
        # 输入的特征图尺寸为[4, 512, 60,60];  输出为[4, 256, 120, 120]
        x = self.up2(x, x3)
        # 输入的特征图尺寸为[4, 256, 120,120];  输出为[4, 128, 240, 240]
        x = self.up3(x, x2)
        # 输入的特征图尺寸为[4, 128, 240,240];  输出为[4, 64, 480, 480]
        x = self.up4(x, x1)
        # 输入的特征图尺寸为[4, 64, 480,480];  输出为[4, 2, 480, 480]
        logits = self.out_conv(x)

        return {"out": logits}

模型代码解读:

1.模型更改:上面是一个完整的U-Net网络的结构,该代码相对于原文的实现主要更改了卷积层的padding和添加了BN层。改善了边缘信息丢失并且由于BN层的添加使模型效率更高;

2.代码主要由两部分组成,一个编码器,一个解码器。编码器由1个初始双卷积层和4个下采样层构成;解码器由4个上采样层和1个输出卷积层构成,其中上采样层由一个转置卷积,对x1进行padding操作,cat操作和1个双卷积层构成。这里解释一下为什么还要padding,如果你在训练数据和测试数据输入时严格按照480*480进行输入,则不需要padding,x1和x2的也会保持相同的形状。而大部分情况下,测试数据不会做裁剪再输入,所以很难保证输入时由于分辨率为奇数时,由于池化和转置卷积的宽和高的减半和翻倍原因导致尺寸有出入而不匹配。如果嫌麻烦,可以删掉padding部分的代码,然后将测试集的图像大小严格resize;

3.模型的执行过程按照UNet类的forward中的顺序执行。

四.模型的关键代码展示

1. transform.py  一些关于数据增强的功能函数类,其实这些类在torchvision中也有实现。

import numpy as np
import random

import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


def pad_if_smaller(img, size, fill=0):
    # 如果图像最小边长小于给定size,则用数值fill进行padding
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomResize(object):
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)
        # 这里size传入的是int类型,所以是将图像的最小边长缩放到size大小
        image = F.resize(image, size)
        # 这里的interpolation注意下,在torchvision(0.9.0)以后才有InterpolationMode.NEAREST
        # 如果是之前的版本需要使用PIL.Image.NEAREST
        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
        return image, target


class RandomHorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomVerticalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.vflip(image)
            target = F.vflip(target)
        return image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target

2. my_dataset.py  ,用于数据集的提取和格式化

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):  # 继承Dataset类
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"  # 根据train这个布尔类型确定需要处理的是训练集还是测试集
        data_root = os.path.join(root, "DRIVE", self.flag)  # 得到数据集根目录
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."   # 判断路径是否存在
        self.transforms = transforms   # 初始化图像变换操作
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]   # 遍历图像文件夹获取每个图像的文件名
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]  # 获取图像路径
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 获取手动标签的路径
                       for i in img_names]
        # 检查手动标签文件是否存在
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")
        # 获取分割的ROI区域掩码
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        # check files
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')  # 加载图像,并转换为RGB模式
        manual = Image.open(self.manual[idx]).convert('L')   # 加载手动标注图像,并转换为灰度模式
        manual = np.array(manual) / 255   # 进行归一化操作
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')  # 加载ROI图像,并转换为灰度模式
        roi_mask = 255 - np.array(roi_mask)   # 对图像数组取反,使用这个方法将背景和前景颜色反转,白色是255,黑色是0,反转后ROI变成了内黑外白
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)   # 将手动标注图像和反转后的ROI图像相加,使用np.clip()将像素值控制在0-255范围,

        # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
        mask = Image.fromarray(mask)

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    # 获取图像数据集长度
    def __len__(self):
        return len(self.img_list)

    # 用于将批量的图像和标签数据合并为一个批张量。
    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))  # 将批量数据拆分为图像和标签两个列表
        batched_imgs = cat_list(images, fill_value=0)  # 使用 cat_list() 函数将图像和标签列表合并成张量。用于将列表中的 PIL 图像数据堆叠成张量,
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))  # 找到图像中最大的形状,以元组形式返回给max_size
    batch_shape = (len(images),) + max_size   # 计算出堆叠后的张量形状,包括批量大小和图像大小两个维度
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)  # 创建一个新的空白张量 batched_imgs,其形状与 batch_shape 相同,并将其填充为指定的填充值 fill_value
    for img, pad_img in zip(images, batched_imgs):  # 使用 zip() 函数将输入列表中的每个图像与其对应的空白张量进行拼接,以得到一个完整的张量。
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)  # 将每个图像按照其实际大小插入到空白张量的左上角,以保持图像的相对位置不变。
    return batched_imgs

3. train.py, 模型训练的代码

import os
import time
import datetime

import torch

from src import UNet
from train_utils import train_one_epoch, evaluate, create_lr_scheduler
from my_dataset import DriveDataset
import transforms as T

# 定义训练集图像的预处理方式
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(1.2 * base_size)

        trans = [T.RandomResize(min_size, max_size)]   # 对图像的短边(长和宽中最短的)进行随机缩放以适应不同图像输入尺寸,缩放范围为【min_size, max_size】
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))  # 加入水平翻转
        if vflip_prob > 0:
            trans.append(T.RandomVerticalFlip(vflip_prob))  # 加入垂直翻转
        trans.extend([
            T.RandomCrop(crop_size),   # 对图像进行随机裁剪
            T.ToTensor(),  # 将数组矩阵转换为tensor类型,规范化到【0,1】范围
            T.Normalize(mean=mean, std=std),  # 加入图像归一化,并定义均值和标准差,RGB三通道的
        ])
        # trans是一个列表类型,包含各种了变换,将这些变换组成一个compose变换,注意transforms.Compose()函数需要接收一个列表类型
        self.transforms = T.Compose(trans)

    # 使用__call__()函数来调用transforms变换
    def __call__(self, img, target):
        return self.transforms(img, target)  # target是指标签图像,img是指待分割图像


# 定义验证集的图像预处理组合类,比较简单,只有张量化和规范化两个操作,这里规范化使用的是ImageNet推荐的参数,注意这种做法是针对彩色图像
class SegmentationPresetEval:
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)


# 定义一个函数根据数据集的类型来调用对应的数据集处理类
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    base_size = 565
    crop_size = 480
    # 检查train是否为True
    if train:
        return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
    else:
        return SegmentationPresetEval(mean=mean, std=std)

# 定义模型创建函数,实例化UNet类创建模型,传入通道数,分割类别数,
def create_model(num_classes):
    model = UNet(in_channels=3, num_classes=num_classes, base_c=64)   # 输入通道数为3,分类类别数为2,
    # model = MobileV3Unet(num_classes=num_classes, pretrain_backbone=True)
    # model = VGG16UNet(num_classes=num_classes, pretrain_backbone=True)

    return model


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size

    # using compute_mean_std.py
    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)

    # 用来保存训练以及验证过程中信息
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))

    val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))

    num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])    # 如果batch_size>1, 线程数num_workers取min(cpu核数,batch_size)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes)
    model.to(device)

    params_to_optimize = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.SGD(
        params_to_optimize,
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # 创建学习率更新策略,这里是每个step更新一次(不是每个epoch)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    best_dice = 0.
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch, args.num_classes,
                                        lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)

        confmat, dice = evaluate(model, val_loader, device=device, num_classes=args.num_classes)
        val_info = str(confmat)
        print(val_info)
        print(f"dice coefficient: {dice:.3f}")
        # write into txt
        with open(results_file, "a") as f:
            # 记录每个epoch对应的train_loss、lr以及验证集各指标
            train_info = f"[epoch: {epoch}]\n" \
                         f"train_loss: {mean_loss:.4f}\n" \
                         f"lr: {lr:.6f}\n" \
                         f"dice coefficient: {dice:.3f}\n"
            f.write(train_info + val_info + "\n\n")

        if args.save_best is True:
            if best_dice < dice:
                best_dice = dice
            else:
                continue

        save_file = {"model": model.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "lr_scheduler": lr_scheduler.state_dict(),
                     "epoch": epoch,
                     "args": args}
        if args.amp:
            save_file["scaler"] = scaler.state_dict()

        if args.save_best is True:
            torch.save(save_file, "save_weights/best_model_mobilenet_unet.pth")
        else:
            torch.save(save_file, "./save_weights/model_amp_{}.pth".format(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("training time {}".format(total_time_str))


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="pytorch unet training")

    parser.add_argument("--data-path", default="./", help="DRIVE root")
    # exclude background
    parser.add_argument("--num-classes", default=2, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("-b", "--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=100, type=int, metavar="N",
                        help="number of total epochs to train")

    parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--print-freq', default=1, type=int, help='print frequency')

    # 从上次训练停止的地方重新开始训练
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    # 保存最佳模型
    parser.add_argument('--save-best', default=True, type=bool, help='only save best dice weights')
    # 混合精度训练参数,用于加快模型训练
    parser.add_argument("--amp", default=True, type=bool,  help="Use torch.cuda.amp for mixed precision training")

    # 解析命令行参数,并将解析结果保存在args对象中
    args = parser.parse_args()
    # 返回解析结果
    return args


if __name__ == '__main__':
    args = parse_args()

    if not os.path.exists("./save_weights"):
        os.mkdir("./save_weights")

    main(args)

五.模型的完整代码和数据集下载,请参考GitHub链接:待更新!!!

备注:此代码基于B站大佬:霹雳吧啦Wz,个人对这些代码进行了理解和注释说明得到。文章来源地址https://www.toymoban.com/news/detail-861139.html

到了这里,关于详解U-Net分割网络,提供详细代码技术细节及完整项目代码的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 图像分割算法U-net

    @[TOC] UNet是一种用于图像分割任务的深度学习模型,最初由Olaf Ronneberger等人在2015年提出。它的名字来源于其U形状的网络结构。 UNet的主要特点是它使用了编码器和解码器结构,其中编码器部分由一系列卷积层和池化层组成,可以对输入图像进行特征提取和压缩。解码器部分则

    2024年01月23日
    浏览(36)
  • 医学图像分割综述:U-Net系列

    论文地址 代码地址 医学图像自动分割是医学领域的一个重要课题,也是计算机辅助诊断范式的一个重要对应。U-Net是最广泛的图像分割架构,由于其灵活性,优化的模块化设计,并在所有医学图像模式的成功。多年来,U-Net模型得到了学术界和工业界研究人员的极大关注。该

    2024年02月05日
    浏览(48)
  • 医学图像分割之Attention U-Net

    目录 一、背景 二、问题 三、解决问题 四、Attention U-Net网络结构         简单总结Attention U-Net的操作:增强目标区域的特征值,抑制背景区域的目标值。抑制也就是设为了0。         为了捕获到足够大的、可接受的范围和语义上下文信息,在标准的CNN结构中,特征图被逐步

    2024年02月02日
    浏览(35)
  • DeU-Net: 用于三维心脏mri视频分割的可变形(Deformable)U-Net

    论文链接:https://arxiv.org/abs/2007.06341 代码链接:文章都看完了实在找不到代码!好崩溃!好崩溃!已经发邮件联系作者! 心脏磁共振成像(MRI)的自动分割促进了临床应用中高效、准确的体积测量。然而,由于分辨率各向异性和边界模糊(如右心室心内膜),现有方法在心脏MRI三

    2024年02月09日
    浏览(39)
  • 【图像处理】经营您的第一个U-Net以进行图像分割

            AI厨师们,今天您将学习如何准备计算机视觉中最重要的食谱之一:U-Net。本文将叙述:1 语义与实例分割,2 图像分割中还使用了其他损失,例如Jaccard损失,焦点损失;3 如果2D图像分割对您来说太容易了,您可以查看3D图像分割,因为模型要大得多,因此要困难得

    2024年02月15日
    浏览(53)
  • 论文阅读笔记——SMU-Net:面向缺失模态脑肿瘤分割的样式匹配U-Net

    论文地址:https://arxiv.org/abs/2204.02961v1 脑胶质瘤:https://baike.baidu.com/item/%E8%84%91%E8%83%B6%E8%B4%A8%E7%98%A4/7242862 互信息:https://zhuanlan.zhihu.com/p/240676850 Gram矩阵:https://zhuanlan.zhihu.com/p/187345192 背景: 绝大多数脑肿瘤都可以通过磁共振成像进行唯一的鉴别。 多模态MRI的好处: 每一种模态

    2024年01月25日
    浏览(54)
  • U-Net Transformer:用于医学图像分割的自我和交叉注意力模块

    对于复杂和低对比度的解剖结构,医学图像分割仍然特别具有挑战性。本文提出的一种U-Transformer网络,它将Transformer中的self-attention和Cross attention融合进了UNet,这样克服了UNet无法建模长程关系和空间依赖的缺点,从而提升对关键上下文的分割。本文集合了两种注意力机制:自

    2024年02月06日
    浏览(39)
  • U-Net网络

    U-Net 融合了 编码 - 解码结构和跳跃网络 的特点,在模型结构上更加巧妙,主要体现在以下两点: ● ( 1 ) U-Net 模型是一个 编码 - 解码的结构 ,压缩通道是一个编码器,用于逐层提取影像的特征,扩展通道是一个解码器,用于还原影像的位置信息,且 U-Net 模型的每一个隐

    2024年02月02日
    浏览(38)
  • U-net网络学习记录

    本质上是一个用于图像分割的神经网络 输入是一幅图,输出是目标的分割结果。继续简化就是,一幅图,编码,或者说降采样,然后解码,也就是升采样,然后输出一个分割结果。根据结果和真实分割的差异,反向传播来训练这个分割网络 既然输入和输出都是相同大小的图

    2024年02月09日
    浏览(36)
  • 图像语义分割 pytorch复现U2Net图像分割网络详解

    U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection 网络的主体类似于U-Net的网络结构,在大的U-Net中,每一个小的block都是一个小型的类似于U-Net的结构,因此作者取名U2Net 仔细观察,可以将网络中的block分成两类: 第一类 :En_1 ~ En_4 与 De_1 ~ De_4这8个block采用的block其实是

    2024年01月22日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包