U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集

这篇具有很好参考价值的文章主要介绍了U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言

  • 博客很久没有更新了,今天就来更新一篇博客吧,哈哈;
  • 最近在做图像分割相关的任务,因此,写这么一篇博客来简单实现一下分割是怎么做的,内容简单,枯燥,需要耐心看,哈哈;
  • 博客的内容相对简单,比较适合刚接触分割的同学参考学习(这篇博客在算法训练上没有涉及到训练策略、数据增强方法,特意留下余地处给大家自行发挥)

内容简介

  • U2Net算法介绍
  • 本博客训练效果截图展示
  • 本博客代码框架介绍
  • 数据集数据集准备
  • 自定义dataset
  • u2net、u2netp网络结构定义
  • 训练代码
  • 模型推理代码
  • 总结以及博客代码的Github地址

U2Net算法介绍

U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集

  • 关于算法介绍,CSDN上很多大神有详细的解读,大家可自行去搜索阅读学习,本博客目的是实操,所以此处省略上千字,哈哈
  • 官方论文地址:https://arxiv.org/pdf/2005.09007.pdf
  • 官方Github repo 地址:https://github.com/xuebinqin/U-2-Net

本博客代码训练效果截图展示

  • 任务图片分割结果可视化展示
    U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集
  • 如上图所示,模型在测试集上的推理效果(左上为原始标注mask,左下为预测的mask,右边图像为原始图片)可以看出,模型的效果还是比较理想的;

代码框架介绍

  • 项目的整体框架如下图所示
    U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集
  • 第一个Folder :backup
backup为训练过程模型的保存的folder,在训练过程中,代码会自动在该目录下生成文件夹,并保存训练过程的权重pth文件
  • 第二个Folder: dataset

dataset目录为训练数据集存放的目录包括了参与训练的原始图片、以及对应的标注mask,训练数据集的组成方式由图片由如下的方式组成:
-images
   -train
     -0.jpg
     -1.jpg
     -....
   -test
  	 -0.jpg
     -1.jpg
     -....
   -val
  	 -0.jpg
     -1.jpg
     -....
 -masks
   -train
     -0.jpg
     -1.jpg
     -....
   -test
  	 -0.jpg
     -1.jpg
     -....
   -val
  	 -0.jpg
     -1.jpg
     -....
  • 第三个Folder:src
src文件夹下有两个文件,一个是网络模型的定义文件u2net.py,另一个为自定义的dataset.py
  • train_u2net.py文件: 模型训练代码
  • inference_u2net.py文件: 模型的推理代码

训练数据集准备

  • 请参考上一章节中dataset Folder的描述方式来准备您的训练数据集;
  • 注意:请保持原始图片和mask图片的命名一致,若不一致的话,需自行修改调整dataset代码部分

自定义dataset

  • 一般来说dataset的组成部分有核心的两个
__getitiem__ 方法 (根据索引返回样本数据)
__len__ 方法 (返回数据集中样本的个数)
(注意:本博客中dataset类中,未写数据增强部分,特意给大家留下空间自行学习和发挥)
  • 根据上述描述,接下来我们开始自定义dataset
    src/seg_dataset.py
# coding: utf-8
# author: hxy
# 2022-04-20
"""
数据读取dataset
"""
import os
import cv2
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset


# dataset for u2net
class U2netSegDataset(Dataset):
    def __init__(self, img_dir, mask_dir, input_size=(320, 320)):
        """
        :param img_dir: 数据集图片文件夹路径
        :param mask_dir: 数据集mask文件夹路径
        :param input_size: 图片输入的尺寸
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.input_size = input_size
        self.samples = list()
        self.gt_mask = list()
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
        self.load_data()

    def __len__(self):
        return len(self.samples)

    def load_data(self):
        img_dir_full_path = self.img_dir
        mask_dir_full_path = self.mask_dir
        img_files = os.listdir(img_dir_full_path)

        for img_name in tqdm(img_files):
            img_full_path = os.path.join(img_dir_full_path, img_name)
            mask_full_path = os.path.join(mask_dir_full_path, img_name)

            img = cv2.imread(img_full_path)
            img = cv2.resize(img, self.input_size)
            img2rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img2norm = (img2rgb - self.mean) / self.std
            # 图像格式改为nchw
            img2nchw = np.transpose(img2norm, [2, 0, 1]).astype(np.float32)

            gt_mask = cv2.imread(mask_full_path)
            gt_mask = cv2.resize(gt_mask, self.input_size)
            gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
            gt_mask = gt_mask / 255.
            gt_mask = np.expand_dims(gt_mask, axis=0)

            self.samples.append(img2nchw)
            self.gt_mask.append(gt_mask)

        return self.samples, self.gt_mask

    def __getitem__(self, index):
        img = self.samples[index]
        mask = self.gt_mask[index]

        return img, mask

上面的代码块简单描述一下: 用os模块遍历文件夹,获取所有文件的名字,并将他们的全部路径拼接起来,opencv读取,然后对读取的照片array做预处理(resize、归一化、通道转换),最后将预处理好的图片append到对应的list中去即可;

u2net、u2netp网络结构定义

  • 网络结构的定义, 该部分代码是直接从源repo中copy过来的,所以直接贴在下来供大家参考使用;
    src/u2net.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
    # src = F.upsample(src, size=tar.shape[2:], mode='bilinear') # old version torch
    src = F.upsample(src, size=tar.shape[2:], mode='bilinear', align_corners=True)

    return src


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)


### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)

训练代码

  • 训练代码
深度学习训练代码的一般流程是: 模型定义 -> 数据加载 -> 模型训练 ->模型验证
本博客中训练代码的实现逻辑如下:
	1 定义网络
	2 加载数据
	3 定义损失函数和优化器
	 4 开始训练
	   - 训练网络
	   - 将梯度置为0
	   - 求loss
	   - 反向传播
	   - 更新参数
(在本博客的训练代码中未写验证部分代码,留给各位同学自行实现)

** train_u2net.py**

# coding: utf-8
# author: hxy
# 20220420
"""
训练代码:u2net、u2netp
train it from scratch.
"""
import os
import datetime
import torch
import numpy as np
from tqdm import tqdm
from src.u2net import U2NET, U2NETP
from src.seg_dataset import U2netSegDataset
from torch.utils.data import DataLoader

# 参考u2net源码loss的设定
bce_loss = torch.nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    # print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    # loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    # loss6.data.item()))

    return loss0, loss


def load_data(img_folder, mask_folder, batch_size, num_workers, input_size):
    """
    :param img_folder: 图片保存的fodler
    :param mask_folder: mask保存的fodler
    :param batch_size: batch_size的设定
    :param num_workers: 数据加载cpu核心数
    :param input_size: 模型输入尺寸
    :return:
    """
    train_dataset = U2netSegDataset(img_dir=os.path.join(img_folder, 'train'),
                                    mask_dir=os.path.join(mask_folder, 'train'),
                                    input_size=input_size)

    val_dataset = U2netSegDataset(img_dir=os.path.join(img_folder, 'val'),
                                  mask_dir=os.path.join(mask_folder, 'val'),
                                  input_size=input_size)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader


def train_model(epoch_nums, cuda_device, model_save_dir):
    """
    :param epoch_nums: 训练总的epoch
    :param cuda_device: 指定gpu训练
    :param model_save_dir: 模型保存folder
    :return:
    """
    current_time = datetime.datetime.now()
    current_time = datetime.datetime.strftime(current_time, '%Y-%m-%d-%H:%M')
    model_save_dir = os.path.join(model_save_dir, current_time)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    else:
        pass

    device = torch.device(cuda_device)
    train_loader, val_loader = load_data(img_folder='dataset',
                                         mask_folder='dataset',
                                         batch_size=32,
                                         num_workers=10,
                                         input_size=(160, 160))

    # input 3-channels, output 1-channels
    net = U2NET(3, 1)
    #net = U2NETP(3, 1)
    
    # if torch.cuda.device_count() > 1:
    #     net = torch.nn.DataParallel(net, device_ids=[6, 7])
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    for epoch in range(0, epoch_nums):
        run_loss = list()
        run_tar_loss = list()

        net.train()
        for i, (inputs, gt_masks) in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            inputs = inputs.type(torch.FloatTensor)
            gt_masks = gt_masks.type(torch.FloatTensor)
            inputs, gt_masks = inputs.to(device), gt_masks.to(device)

            d0, d1, d2, d3, d4, d5, d6 = net(inputs)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, gt_masks)

            loss.backward()
            optimizer.step()

            run_loss.append(loss.item())
            run_tar_loss.append(loss2.item())
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        print("--Train Epoch:{}--".format(epoch))
        print("--Train run_loss:{:.4f}--".format(np.mean(run_loss)))
        print("--Train run_tar_loss:{:.4f}--\n".format(np.mean(run_tar_loss)))

        if epoch % 20 == 0:
            checkpoint_name = 'checkpoint_' + str(epoch) + '_' + str(np.mean(run_loss)) + '.pth'
            torch.save(net.state_dict(), os.path.join(model_save_dir, checkpoint_name))
            print("--model saved:{}--".format(checkpoint_name))


if __name__ == '__main__':
    train_model(epoch_nums=500, cuda_device='cuda:7',
                model_save_dir='backup')

在这部分训练代码中, 并没有出现很多训练策略,如各种学习率调整策略、多阶段学习等等…该代码实现的为最基础的训练代码,因此,您有足够的空间去自行发挥;

模型推理程序

  • 算法模型推理
推理程序的编写逻辑一般是: 加载模型-> 读取图片 —>图片预处理(需要保持和训练过程中的图片预处理一致) ->模型推理 ->获取结果,进行后处理 ->保存图片,可视化查看结果

inference_u2net.py

# coding: utf-8
# author: hxy
# 20220420
"""
u2net/u2netP模型推理程序
"""

import os
import cv2
import torch
import numpy as np
from time import time
from tqdm import tqdm
from src.u2net import U2NET, U2NETP

"""
初始化模型加载
"""
try:
    print('===loading model===')
    current_project_path = os.getcwd()
    net = U2NET(3, 1)
    # net = U2NETP(3, 1)
    checkpoint_path = os.path.join(current_project_path,
                                   'backup/*****.pth')
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(checkpoint_path, map_location='cuda:1'))
    else:
        net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    net.eval()
    print('===model lode sucessed===')

except Exception as e:
    print('===model load error:{}==='.format(e))


# 计算dice
def dice_coef(output, target):  # output为预测结果 target为真实结果
    smooth = 1e-5  # 防止0除
    intersection = (output * target).sum()
    return (2. * intersection + smooth) / \
           (output.sum() + target.sum() + smooth)


# 图像归一化操作
def img2norm(img_array, input_size):
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    _std = np.array(std).reshape((1, 1, 3))
    _mean = np.array(mean).reshape((1, 1, 3))

    img_array = cv2.resize(img_array, input_size)
    norm_img = (img_array - _mean) / _std

    return norm_img


# 归一化预测结果
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)

    return dn


# 推理
def inference1folder(img_folder, mask_folder, input_size):
    total_times = list()
    total_dices = list()
    img_files = os.listdir(img_folder)
    for img_file in tqdm(img_files):
        img_full_path = os.path.join(img_folder, img_file)
        mask_full_path = os.path.join(mask_folder, img_file)
        img = cv2.imread(img_full_path)
        gt_mask = cv2.imread(mask_full_path)
        gt_mask = cv2.resize(gt_mask, input_size)
        gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
        gt_mask = gt_mask / 255.

        ori_h, ori_w = img.shape[:2]
        img2rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        norm_img = img2norm(img2rgb, input_size)

        x_tensor = torch.from_numpy(norm_img).permute(2, 0, 1).float()
        x_tensor = torch.unsqueeze(x_tensor, 0)

        start_t = time()
        d1, d2, d3, d4, d5, d6, d7 = net(x_tensor)
        end_t = time()

        total_times.append(end_t - start_t)
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        pred = pred.squeeze().cpu().data.numpy()

        dice_value = dice_coef(pred, gt_mask)
        total_dices.append(dice_value)

        # pred[pred>=0.3]=255
        # pred[pred<0.3]=0
        # pred_res = pred
        pred_res = pred * 255
        pred_res = cv2.resize(pred_res, (ori_w, ori_h))

        cv2.imwrite(os.path.join(current_project_path, 'infer_output/', img_file), pred_res)

    print('==inference 1 pic avg cost:{:.4f}ms=='.format(np.mean(total_times) * 1000))
    print('==inference avg dice:{:.4f}=='.format(np.mean(total_dices)))

    return None


if __name__ == '__main__':
    test_img_folder = os.path.join(os.getcwd(), 'dataset/images/test')
    test_gt_mask_folder = os.path.join(os.getcwd(), 'dataset/masks/test')
    inference1folder(img_folder=test_img_folder, mask_folder=test_gt_mask_folder, input_size=(160, 160))

着一部分代码没什么好说的,仔细看就完事,当然我只写了针对于一个folder的推理代码,您可以尝试推理视频file;或者你也可以加一些更加炫酷的后处理让你的推理结果看起来更加具有美观;文章来源地址https://www.toymoban.com/news/detail-450665.html

总结以及博客代码的Github地址

  • 一篇博客写完总归还是要来点总结才完美的!
  1. 本篇博客实现的是最基础的训练过程和训练代码,所以你有很多的发挥空间;
  2. 例如:尝试使用不同的loss函数(dice loss、bce dice loss、iou loss等等)
  3. 添加数据增强操作(建议使用albumentation库,torchversion也行)
  4. 使用不同的调参策略训练模型(不同的学习率衰减策略、多阶段训练等等)
  5. 尝试使用不同的优化器训练模型等等。。。。。
  6. 等你上述尝试都做过了,你可尝试使用不同的网络,src文件夹内不断丰富不同网络结构
  7. 优化一下代码的编写,封装一下之类的,哈哈。。
  8. 总是很多实验可以做,可学习的东西也很多。。
  9. 最后,希望本篇博客能够给你带来帮助~互相学习~文章代码有不知之处多多包涵!
  • 本博客代码Github地址: https://github.com/YingXiuHe/u2net-pytorch.git/

到了这里,关于U2Net、U2NetP分割模型训练---自定义dataset、训练代码训练自己的数据集的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 模型实战(3)之YOLOv7实例分割、模型训练自己数据集

    下载yolov7实例分割模型: 安装环境

    2023年04月08日
    浏览(30)
  • 特征提取网络之res2net

    整理下res2net特征提取网络 论文地址:https://arxiv.org/abs/1904.01169 文中所提供的代码来自:[https://github.com/open-mmlab/mmclassification] 主要是在每个残差块内部构建特征金字塔结构,在特征层内部进行多尺度的卷积,形成不同感受野,获得不同细粒度的特征。 结构图 结构图与resnet类

    2024年02月06日
    浏览(26)
  • 学习Segformer语义分割模型并训练测试cityscapes数据集

    官方的segformer源码是基于MMCV框架,整体包装较多,自己不便于阅读和学习,我这里使用的是Bubbliiiing大佬github复现的segformer版本。 Bubbliiiing大佬代码下载链接: https://github.com/bubbliiiing/segformer-pytorch 大佬的代码很优秀简练,注释也很详细,代码里采用的是VOC数据集的格式,因

    2024年02月15日
    浏览(29)
  • Res2Net: 一种新的多尺度主干体系结构(Res2Net: A New Multi-scale Backbone Architecture )

    如图1所示,视觉模式在自然场景中以多尺度出现。首先, 对象可以在单个图像中以不同的尺寸 出现,例如,沙发和杯子具有不同的尺寸。其次, 对象的基本上下文信息可能比对象本身占据更大的区域 。例如,我们需要依靠大桌子作为上下文,以更好地判断放置在桌子上的

    2024年02月13日
    浏览(34)
  • 【3-D深度学习:肺肿瘤分割】创建和训练 V-Net 神经网络,并从 3D 医学图像中对肺肿瘤进行语义分割研究(Matlab代码实现)

     💥💥💞💞 欢迎来到本博客 ❤️❤️💥💥 🏆博主优势: 🌞🌞🌞 博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️ 座右铭: 行百里者,半于九十。 📋📋📋 本文目录如下: 🎁🎁🎁 目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码实现 使用

    2024年02月15日
    浏览(30)
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)

    由于本人水平有限,难免出现错漏,敬请批评改正。 更多精彩内容,可点击进入YOLO系列专栏或我的个人主页查看 YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制 YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层 YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU YOLOv7训练自己的数据集(口罩检测)

    2024年02月15日
    浏览(33)
  • U2-net网络详解

    学习视频:U2Net网络结构讲解_哔哩哔哩_bilibili 论文名称:U2-Net: Goging Deeper with Nested U-Structure forSalient Object Detetion 论文下载地址:https://arxiv.org/abs/2005.09007 官方源码(Pytorch实现):https://github.com/xuebinqin/U-2-Net U2-net是阿尔伯塔大学(University of Alberta)在2020年发表在CVPR上的一篇

    2024年01月25日
    浏览(23)
  • 基于.Net6使用YoloV8的分割模型

    在目标检测一文中,我们学习了如何处理Onnx模型,并的到目标检测结果,在此基础上,本文实现基于.Net平台的实例分割任务。 执行YoloV8的分割任务后可以得到分割.pt模型。由于Python基本不用于工业软件的部署,最终还是希望能在.Net平台使用训练好的模型进行预测。我们可以

    2024年02月09日
    浏览(23)
  • Yolov5改进算法之添加Res2Net模块

    目录 1. Res2Net介绍 1.1 Res2Net的背景和动机 1.2 Res2Net的基本概念 2. YOLOV5添加Res2Net模块   Res2Net (Residual Resolution Network)是一种用于图像处理和计算机视觉任务的深度卷积神经网络架构。它旨在解决传统的ResNet(Residual Network)存在的问题,如对不同尺度和分辨率特征的建模不足

    2024年02月10日
    浏览(24)
  • CVPR2023最新论文 (含语义分割、扩散模型、多模态、预训练、MAE等方向)

    2023 年 2 月 28 日凌晨,CVPR 2023 顶会论文接收结果出炉! CVPR 2023 收录的工作中 \\\" 扩散模型、多模态、预训练、MAE \\\" 相关工作的数量会显著增长。 Delivering Arbitrary-Modal Semantic Segmentation 论文/Paper: http://arxiv.org/pdf/2303.01480 代码/Code: None Conflict-Based Cross-View Consistency for Semi-Supervised

    2023年04月08日
    浏览(31)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包