计算机视觉-02-基于U-Net的肝脏肿瘤分割(包含数据和代码)

这篇具有很好参考价值的文章主要介绍了计算机视觉-02-基于U-Net的肝脏肿瘤分割(包含数据和代码)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。


关注公众号:『 AI学习星球
回复: 肝脏肿瘤分割 即可获取数据下载。
需要 论文辅导4对1辅导算法学习,可以通过 CSDN公众号滴滴我
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理

1. 介绍

1.1 简介

该任务分为三个阶段,这是第一个阶段,三个阶段分别是:

  1. 第一阶段分割出腹部图像中的肝脏,作为第二阶段的ROI(region of interest)
  2. 第二阶段利用ROI对腹部图像进行裁剪,裁剪后的非ROI区域变成黑色,作为该阶段输入,分割出肝脏中的肿瘤。
  3. 第三阶段用随机场的后处理方法进行优化。

1.2 任务介绍

分割出CT腹部图像的肝脏区域。

1.3 数据集介绍

1.3.1 介绍

实验用到的数据为3Dircadb,即腹部的CT图。一个病人为一个文件夹。
例如3Dircadb1.1(一共20位),该文件夹下会使用到的子文件夹为PATIENT_DICOM(病人原始腹部3D CT图),MASKS_DICOM(该文件夹下具有不同部位的分割结果,比如liver和liver tumor等等)。如下图所示:
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理

PATIENT_DICOM利用软件展示效果如下:一个dcm文件包含129张切片:
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理
MASKS_DICOM下的liver分割图效果如下:
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理

1.3.2 数据预处理建议
  1. 将ct值转化为标准的hu值
  2. 窗口化操作
  3. 直方图均衡化
  4. 归一化
  5. 仅提取腹部所有切片中包含了肝脏的那些切片,其余的不要

1.5 整体流程梳理

1.5.1 数据读取:从原始dcm格式读入成我们需要的数组格式
#part1
import numpy as np
import pydicom
import os
import matplotlib.pyplot as plt
import cv2
from keras.preprocessing.image import ImageDataGenerator
from HDF5DatasetWriter import HDF5DatasetWriter
from HDF5DatasetGenerator import HDF5DatasetGenerator

for i in range(1,18): # 前17个人作为测试集
   full_images = [] # 后面用来存储目标切片的列表
   full_livers = [] #功能同上
   # 注意不同的系统,文件分割符的区别
   label_path = '~/3Dircadb/3Dircadb1.%d/MASKS_DICOM/liver'%i
   data_path = '~/3Dircadb/3Dircadb1.%d/PATIENT_DICOM'%i
   liver_slices = [pydicom.dcmread(label_path + '/' + s) for s in os.listdir(label_path)]
   # 注意需要排序,即使文件夹中显示的是有序的,读进来后就是随机的了
   liver_slices.sort(key = lambda x: int(x.InstanceNumber))
   # s.pixel_array 获取dicom格式中的像素值
   livers = np.stack([s.pixel_array for s in liver_slices])
   image_slices = [pydicom.dcmread(data_path + '/' + s) for s in os.listdir(data_path)]
   image_slices.sort(key = lambda x: int(x.InstanceNumber))
   
   """ 省略进行的预处理操作,具体见part2"""
   
   full_images.append(images)
   full_livers.append(livers)
      
   full_images = np.vstack(full_images)
   full_images = np.expand_dims(full_images,axis=-1)
   full_livers = np.vstack(full_livers)
   full_livers = np.expand_dims(full_livers,axis=-1)
1.5.2 数据预处理:上面给出了提示
a. 将ct值转化为标准的hu值

若是dicom格式的图片,得先转化为hu值(即ct图像值,若图片本来就是ct,则不需要转换) ,因为hu值是与设备无关的,不同范围之内的值可以代表不同器官。常见的对应如下(w代表ct值width,L代表ct值level,即 [level - width/2 , level + width/2]为其窗口化目标范围),其实无论对于dcm还是nii格式的图片,只要是ct图,就都可以选择将储存的原始数据转化为Hu值,因为Hu值即代表了物体真正的密度。对于nii格式的图片,经过测试,nibabel, simpleitk常用的api接口,都会自动的进行上述转化过程,即取出来的值已经是Hu了。(除非专门用nib.load(‘xx’).dataobj.get_unscaled()或者itk.ReadImage(‘xx’).GetPixel(x,y,z)才能取得原始数据)对于dcm格式的图片,经过测试,simpleitk, pydicom常用的api接口都不会将原始数据自动转化为Hu!!(itk snap软件读入dcm或nii都不会对数据进行scale操作)
dicom转化为hu的示例代码:

def get_pixels_hu(scans):
    #type(scans[0].pixel_array)
    #Out[15]: numpy.ndarray
    #scans[0].pixel_array.shape
    #Out[16]: (512, 512)
    # image.shape: (129,512,512)
    image = np.stack([s.pixel_array for s in scans])
    # Convert to int16 (from sometimes int16), 
    # should be possible as values should always be low enough (<32k)
    image = image.astype(np.int16)

    # Set outside-of-scan pixels to 1
    # The intercept is usually -1024, so air is approximately 0
    image[image == -2000] = 0
    
    # Convert to Hounsfield units (HU)
    intercept = scans[0].RescaleIntercept
    slope = scans[0].RescaleSlope
    
    if slope != 1:
        image = slope * image.astype(np.float64)
        image = image.astype(np.int16)
        
    image += np.int16(intercept)
    
    return np.array(image, dtype=np.int16)
b. 窗口化操作

然而,hu的范围一般来说很大,这就导致了对比度很差,如果需要针对具体的器官进行处理,效果会不好,于是就有了windowing的方法:

Windowing, also known as grey-level mapping, contrast stretching, histogram modification or contrast enhancement is the process in which the CT image greyscale component of an image is manipulated via the CT numbers; doing this will change the appearance of the picture to highlight particular structures. The brightness of the image is, adjusted via the window level. The contrast is adjusted via the window width.

windowing代码

def transform_ctdata(self, windowWidth, windowCenter, normal=False):
        """
        注意,这个函数的self.image一定得是float类型的,否则就无效!
        return: trucated image according to window center and window width
        """
        minWindow = float(windowCenter) - 0.5*float(windowWidth)
        newimg = (self.image - minWindow) / float(windowWidth)
        newimg[newimg < 0] = 0
        newimg[newimg > 1] = 1
        if not normal:
            newimg = (newimg * 255).astype('uint8')
        return newimg

windowCenter和windowWidth的选择:
一是可以根据所需部位的hu值(对于CT数据而言)分布范围选取(注意若是增强ct的话hu值会有一些差别,可以画一下几个随机数据的hu分布图看一看)
二是可以根据图像的统计信息,例如图像均值作为窗口中心,正负2.5(这个值并非固定)的方差作为窗口宽度,下面是github上,vnet一个预处理过程,使用了基于统计的窗口化操作。

class StatisticalNormalization(object):
    """
    Normalize an image by mapping intensity with intensity distribution
    """

    def __init__(self, sigma):
        self.name = 'StatisticalNormalization'
        assert isinstance(sigma, float)
        self.sigma = sigma

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        statisticsFilter = sitk.StatisticsImageFilter()
        statisticsFilter.Execute(image)

        intensityWindowingFilter = sitk.IntensityWindowingImageFilter()
        intensityWindowingFilter.SetOutputMaximum(255)
        intensityWindowingFilter.SetOutputMinimum(0)
        intensityWindowingFilter.SetWindowMaximum(
            statisticsFilter.GetMean() + self.sigma * statisticsFilter.GetSigma())
        intensityWindowingFilter.SetWindowMinimum(
            statisticsFilter.GetMean() - self.sigma * statisticsFilter.GetSigma())

        image = intensityWindowingFilter.Execute(image)

        return {'image': image, 'label': label}
c. 直方图均衡化
def clahe_equalized(imgs,start,end):
   assert (len(imgs.shape)==3)  #3D arrays
   #create a CLAHE object (Arguments are optional).
   clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
   imgs_equalized = np.empty(imgs.shape)
   for i in range(start, end+1):
       imgs_equalized[i,:,:] = clahe.apply(np.array(imgs[i,:,:], dtype = np.uint8))
   return imgs_equalized
d. 归一化
e. 仅提取腹部所有切片中包含了肝脏的那些切片,其余的不要

在训练网络时,一般目标区域越有针对性效果越好,因此经常会在训练前对数据进行预处理,提取出包含有目标的那些切片。下面是示例代码(原始数据是3D ct images)
简单的方法:

def getRangImageDepth(image):
	"""
	args:
	image ndarray of shape (depth, height, weight)
	"""
	# 得到轴向上出现过目标(label>=1)的切片
    z = np.any(image, axis=(1,2)) # z.shape:(depth,)
    startposition,endposition = np.where(z)[0][[0,-1]]
    return startposition, endposition

复杂的方法:

def getRangImageDepth(image):
	"""
	args:
	image ndarray of shape (depth, height, weight)
	"""
	firstflag = True
	startposition = 0
	endposition = 0
	for z in range(image.shape[0]):
		notzeroflag = np.max(image[z])
		if notzeroflag and firstflag:
			startposition = z
			firstflag = False
		if notzeroflag:
			endposition = z
    return startposition, endposition

本实例的代码:

#part2
    # 接part1
   images = get_pixels_hu(image_slices)
   
   images = transform_ctdata(images,500,150)
   
   start,end = getRangImageDepth(livers)
   images = clahe_equalized(images,start,end)
   
   images /= 255.
   # 仅提取腹部所有切片中包含了肝脏的那些切片,其余的不要
  
   total = (end - 4) - (start+4) +1
   print("%d person, total slices %d"%(i,total))
   # 首和尾目标区域都太小,舍弃
   images = images[start+5:end-5]
   print("%d person, images.shape:(%d,)"%(i,images.shape[0]))
   
   livers[livers>0] = 1
   
   livers = livers[start+5:end-5]
1.5.3 数据增强

利用keras的数据增强接口,可以实现分割问题的数据增强。一般的增强是分类问题,这种情况,只需要对image变形,label保持不变。但分割问题,就需要image和mask进行同样的变形处理。具体怎么实现,参考下面代码,注意种子设定成一样的。

# 可以在part1之前设定好(即循环外)
seed=1
data_gen_args = dict(rotation_range=3,
                    width_shift_range=0.01,
                    height_shift_range=0.01,
                    shear_range=0.01,
                    zoom_range=0.01,
                    fill_mode='nearest')

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
#part3 接part2
    image_datagen.fit(full_images, augment=True, seed=seed)
    mask_datagen.fit(full_livers, augment=True, seed=seed)
    image_generator = image_datagen.flow(full_images,seed=seed)
    mask_generator = mask_datagen.flow(full_livers,seed=seed)

    train_generator = zip(image_generator, mask_generator)
    x=[]
    y=[]
    i = 0
    for x_batch, y_batch in train_generator:
        i += 1
        x.append(x_batch)
        y.append(y_batch)
        if i>=2: # 因为我不需要太多的数据
            break
    x = np.vstack(x)
    y = np.vstack(y)
1.5.4 数据存储

一般而言,数据量较大的话,都会先将原始数据库的东西转化为np或者h5格式的文件,我感觉这样有两个好处,一是真正输入网络训练的时候io量会大大减少(特别是h5很适用于大的数据库),二是数据分享或者上传至服务器时也方便一点。

实验中会出现两个类,分别是写h5和读h5文件的辅助类:
这读文件的类写成了generator,这样可以结合训练网络时,keras的fit_generator来使用,降低内存开销。
HDF5DatasetGenerator.py

# -*- coding: utf-8 -*-
import h5py
import os
import numpy as np

class HDF5DatasetGenerator:
    
    def __init__(self, dbPath, batchSize, preprocessors=None,
                 aug=None, binarize=True, classes=2):
        self.batchSize = batchSize
        self.preprocessors = preprocessors
        self.aug = aug
        self.binarize = binarize
        self.classes = classes
        
        self.db = h5py.File(dbPath)
        self.numImages = self.db["images"].shape[0]
#        self.numImages = total
        print("total images:",self.numImages)
        self.num_batches_per_epoch = int((self.numImages-1)/batchSize) + 1
        
    
    def generator(self, shuffle=True, passes=np.inf):
        epochs = 0
        
        while epochs < passes:
            shuffle_indices = np.arange(self.numImages) 
            shuffle_indices = np.random.permutation(shuffle_indices)
            for batch_num in range(self.num_batches_per_epoch):
                
                start_index = batch_num * self.batchSize
                end_index = min((batch_num + 1) * self.batchSize, self.numImages)
                
                # h5py get item by index,参数为list,而且必须是增序
                batch_indices = sorted(list(shuffle_indices[start_index:end_index]))
                
                images = self.db["images"][batch_indices,:,:,:]
                labels = self.db["masks"][batch_indices,:,:,:]
                
#                if self.binarize:
#                    labels = np_utils.to_categorical(labels, self.classes)
                
                if self.preprocessors is not None:
                    procImages = []
                    for image in images:
                        for p in self.preprocessors:
                            image = p.preprocess(image)
                        procImages.append(image)
                    
                    images = np.array(procImages)
                
                if self.aug is not None:
                    # 不知道意义何在?本身images就有batchsize个了
                    (images, labels) = next(self.aug.flow(images, labels,
                                                        batch_size=self.batchSize))
                yield (images, labels)
            
            epochs += 1
            
    def close(self):
        self.db.close()

HDF5DatasetWriter.py

# -*- coding: utf-8 -*-
import h5py
import os

class HDF5DatasetWriter:
    def __init__(self, image_dims, mask_dims, outputPath, bufSize=200):
        """
        Args:
        - bufSize: 当内存储存了bufSize个数据时,就需要flush到外存
        """
        if os.path.exists(outputPath):
            raise ValueError("The supplied 'outputPath' already"
                             "exists and cannot be overwritten. Manually delete"
                             "the file before continuing", outputPath)
        
        self.db = h5py.File(outputPath, "w")
        self.data = self.db.create_dataset("images", image_dims, maxshape=(None,)+image_dims[1:], dtype="float")
        self.masks = self.db.create_dataset("masks", mask_dims, maxshape=(None,)+mask_dims[1:], dtype="int")
        self.dims = image_dims
        self.bufSize = bufSize
        self.buffer = {"data": [], "masks": []}
        self.idx = 0
    

    def add(self, rows, masks):
        # extend() 函数用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)
        # 注意,用extend还有好处,添加的数据不会是之前list的引用!!
        self.buffer["data"].extend(rows)
        self.buffer["masks"].extend(masks)
        print("len ",len(self.buffer["data"]))
        
        if len(self.buffer["data"]) >= self.bufSize:
            self.flush()
    
    def flush(self):
        i = self.idx + len(self.buffer["data"])
        if i>self.data.shape[0]:
        	# 扩展大小的策略可以自定义
            new_shape = (self.data.shape[0]*2,)+self.dims[1:]
            print("resize to new_shape:",new_shape)
            self.data.resize(new_shape)
            self.masks.resize(new_shape)
        self.data[self.idx:i,:,:,:] = self.buffer["data"]
        self.masks[self.idx:i,:,:,:] = self.buffer["masks"]
        print("h5py have writen %d data"%i)
        self.idx = i
        self.buffer = {"data": [], "masks": []}
        
    
    def close(self):
        if len(self.buffer["data"]) > 0:
            self.flush()

h5文件操作需要的包:import h5py

# 可以在part1之前设定好(即循环外)
# 这儿的数量需要提前写好,感觉很不方便,但不知道怎么改,我是先跑了之前的程序,计算了一共有多少
# 张图片后再写的,但这样明显不是好的解决方案
dataset = HDF5DatasetWriter(image_dims=(2782, 512, 512, 1),
                            mask_dims=(2782, 512, 512, 1),
                            outputPath="data_train/train_liver.h5")
#part4 接part3
    dataset.add(full_images, full_livers)
    dataset.add(x, y)
    # end of lop
dataset.close()

测试数据存储的全部过程
测试数据与上面训练数据的处理过程几乎一样,但测试数据不要进行数据增强

full_images2 = []
full_livers2 = []
for i in range(18,21):#后3个人作为测试样本
    label_path = '~/3Dircadb/3Dircadb1.%d/MASKS_DICOM/liver'%i
    data_path = '~/3Dircadb/3Dircadb1.%d/PATIENT_DICOM'%i
    liver_slices = [pydicom.dcmread(label_path + '/' + s) for s in os.listdir(label_path)]
    liver_slices.sort(key = lambda x: int(x.InstanceNumber))
    livers = np.stack([s.pixel_array for s in liver_slices])
    start,end = getRangImageDepth(livers)
    total = (end - 4) - (start+4) +1
    print("%d person, total slices %d"%(i,total))
    
    image_slices = [pydicom.dcmread(data_path + '/' + s) for s in os.listdir(data_path)]
    image_slices.sort(key = lambda x: int(x.InstanceNumber))
    
    images = get_pixels_hu(image_slices)
    images = transform_ctdata(images,500,150)
    images = clahe_equalized(images,start,end)
    images /= 255.
    images = images[start+5:end-5]
    print("%d person, images.shape:(%d,)"%(i,images.shape[0]))
    livers[livers>0] = 1
    livers = livers[start+5:end-5]

    full_images2.append(images)
    full_livers2.append(livers)
    
full_images2 = np.vstack(full_images2)
full_images2 = np.expand_dims(full_images2,axis=-1)
full_livers2 = np.vstack(full_livers2)
full_livers2 = np.expand_dims(full_livers2,axis=-1)

dataset = HDF5DatasetWriter(image_dims=(full_images2.shape[0], full_images2.shape[1], full_images2.shape[2], 1),
                            mask_dims=(full_images2.shape[0], full_images2.shape[1], full_images2.shape[2], 1),
                            outputPath="data_train/val_liver.h5")


dataset.add(full_images2, full_livers2)

print("total images in val ",dataset.close())
1.3.5 构建网络

这一部分不多说,是直接用的别人写好的Unet。
唯一进行的更改,就是Crop的那部分。因为,如果输入图片的高或者宽不能等于2^n(n>=m),m为网络收缩路径的层数。那么出现的问题就是,下采样的时候进行了取整操作,但是,后续在扩张路径又会有上采样,且上采样的结果会和收缩路径的特征图进行拼接,即long skip connection。若没有达到上面提到的要求,会导致拼接的两个特征图size不同,报错。

# partA
import os
import sys
import numpy as np
import random
import math
import tensorflow as tf
from HDF5DatasetGenerator import HDF5DatasetGenerator
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,Cropping2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from skimage import io


K.set_image_data_format('channels_last')
    
def dice_coef(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)

def get_crop_shape(target, refer):
        # width, the 3rd dimension
        print(target.shape)
        print(refer._keras_shape)
        cw = (target._keras_shape[2] - refer._keras_shape[2])
        assert (cw >= 0)
        if cw % 2 != 0:
            cw1, cw2 = int(cw/2), int(cw/2) + 1
        else:
            cw1, cw2 = int(cw/2), int(cw/2)
        # height, the 2nd dimension
        ch = (target._keras_shape[1] - refer._keras_shape[1])
        assert (ch >= 0)
        if ch % 2 != 0:
            ch1, ch2 = int(ch/2), int(ch/2) + 1
        else:
            ch1, ch2 = int(ch/2), int(ch/2)

        return (ch1, ch2), (cw1, cw2)

def get_unet():
    inputs = Input((IMG_HEIGHT, IMG_WIDTH , 1))
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    up_conv5 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5)
    
    ch, cw = get_crop_shape(conv4, up_conv5)
    
    crop_conv4 = Cropping2D(cropping=(ch,cw), data_format="channels_last")(conv4)
    up6 = concatenate([up_conv5, crop_conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    
    up_conv6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6)

    ch, cw = get_crop_shape(conv3, up_conv6)
    crop_conv3 = Cropping2D(cropping=(ch,cw), data_format="channels_last")(conv3)
    
    up7 = concatenate([up_conv6, crop_conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    
    up_conv7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    crop_conv2 = Cropping2D(cropping=(ch,cw), data_format="channels_last")(conv2)

    up8 = concatenate([up_conv7, crop_conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    
    up_conv8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    crop_conv1 = Cropping2D(cropping=(ch,cw), data_format="channels_last")(conv1)


    up9 = concatenate([up_conv8, crop_conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])

    model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])

    return model
1.3.6 进行训练并测试

这其中主要包括训练文件与测试文件的读取,Checkpoint回掉函数的设定,fit_generator的使用,模型预测,并对预测结果进行保存。

# partB 接partA
IMG_WIDTH = 512
IMG_HEIGHT = 512
IMG_CHANNELS = 1
TOTAL = 2782 # 总共的训练数据
TOTAL_VAL = 152 # 总共的validation数据
# part1部分储存的数据文件
outputPath = './data_train/train_liver.h5' # 训练文件
val_outputPath = './data_train/val_liver.h5'
#checkpoint_path = 'model.ckpt'
BATCH_SIZE = 8 # 根据服务器的GPU显存进行调整

class UnetModel:
    def train_and_predict(self):
        
        reader = HDF5DatasetGenerator(dbPath=outputPath,batchSize=BATCH_SIZE)
        train_iter = reader.generator()
        
        test_reader = HDF5DatasetGenerator(dbPath=val_outputPath,batchSize=BATCH_SIZE)
        test_iter = test_reader.generator()
        fixed_test_images, fixed_test_masks = test_iter.__next__()
#   
        
        model = get_unet()
        model_checkpoint = ModelCheckpoint('weights.h5', monitor='val_loss', save_best_only=True)
        # 注:感觉validation的方式写的不对,应该不是这样弄的
        model.fit_generator(train_iter,steps_per_epoch=int(TOTAL/BATCH_SIZE),verbose=1,epochs=500,shuffle=True,
                            validation_data=(fixed_test_images, fixed_test_masks),callbacks=[model_checkpoint])
#        
        reader.close()
        test_reader.close()
        
        
        print('-'*30)
        print('Loading and preprocessing test data...')
        print('-'*30)
        
        print('-'*30)
        print('Loading saved weights...')
        print('-'*30)
        model.load_weights('weights.h5')
    
        print('-'*30)
        print('Predicting masks on test data...')
        print('-'*30)
        
        
        imgs_mask_test = model.predict(fixed_test_images, verbose=1)
        np.save('imgs_mask_test.npy', imgs_mask_test)
    
        print('-' * 30)
        print('Saving predicted masks to files...')
        print('-' * 30)
        pred_dir = 'preds'
        if not os.path.exists(pred_dir):
            os.mkdir(pred_dir)
        i = 0
        
        
        for image in imgs_mask_test:
            image = (image[:, :, 0] * 255.).astype(np.uint8)
            gt = (fixed_test_masks[i,:,:,0] * 255.).astype(np.uint8)
            ini = (fixed_test_images[i,:,:,0] *255.).astype(np.uint8)
            io.imsave(os.path.join(pred_dir, str(i) + '_ini.png'), ini)
            io.imsave(os.path.join(pred_dir, str(i) + '_pred.png'), image)
            io.imsave(os.path.join(pred_dir, str(i) + '_gt.png'), gt)
            i += 1

unet = UnetModel()
unet.train_and_predict()

模型跑的过程如图。
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理
预测结果可视化展示
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理

关注公众号:『AI学习星球
回复:肝脏肿瘤分割 即可获取数据下载。
需要论文辅导4对1辅导算法学习,可以通过公众号滴滴我
3dircadb数据集介绍,# 计算机视觉,计算机视觉,人工智能,python,图像处理文章来源地址https://www.toymoban.com/news/detail-780424.html

到了这里,关于计算机视觉-02-基于U-Net的肝脏肿瘤分割(包含数据和代码)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 基于计算机视觉的葡萄检测分级系统

    【摘要】设计了一套基于计算机视觉的葡萄检测分级系统,包括驱动装置、输送机构、夹持机构、图像釆集与 处理系统和分级控制系统,葡萄以悬挂方式连续输送,两个 CCD 摄像机在外触发模式下实时采集葡萄的两面图像。 基于 RGB 色彩空间计算果面着色率,采用投影面积法和

    2024年02月19日
    浏览(36)
  • 基于计算机视觉的学生上课姿态识别

    数据集 1.1  A VA 数据集介绍 AVA数据集为目前行为数据集中背景最复杂、人体目标最多的数据集,是由Google在2018年所发表的一个用于训练动作检测的数据集,该数据集注释430个15分钟电影切片中的80个原子视觉动作,在空间和时间上定位了动作,从而产生了1.62万个动作标签。这

    2024年02月02日
    浏览(53)
  • 计算机竞赛 - 基于机器视觉的图像拼接算法

    图像拼接在实际的应用场景很广,比如无人机航拍,遥感图像等等,图像拼接是进一步做图像理解基础步骤,拼接效果的好坏直接影响接下来的工作,所以一个好的图像拼接算法非常重要。 再举一个身边的例子吧,你用你的手机对某一场景拍照,但是你没有办法一次将所有你

    2024年02月13日
    浏览(67)
  • 基于计算机视觉的坑洼道路检测和识别

    本研究论文提出了一种使用深度学习和图像处理技术进行坑洼检测的新方法。所提出的系统利用VGG16模型进行特征提取,并利用具有三重损失的自定义Siamese网络,称为RoadScan。该系统旨在解决道路上的坑洼这一关键问题,这对道路使用者构成重大风险。由于道路上的坑洼造成

    2024年02月05日
    浏览(52)
  • 基于机器视觉的车道线检测 计算机竞赛

    🔥 优质竞赛项目系列,今天要分享的是 基于深度学习的视频多目标跟踪实现 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/postgraduate 理解车道检测的概念 那么什么是车道检测?以下是百度百科对车道的定义:

    2024年02月08日
    浏览(52)
  • 基于cv2的手势识别-计算机视觉

      闲的无聊做的一个小玩意,可以调用你的计算机相机,识别框内的手势(剪刀、石头和布),提供一个判决平台,感兴趣的可以继续完善。 用到的参考小文献: 具体实现结果如下 并且我另写了一个框架平台,可以进行下一步的功能拓展,发在我的资源界面了;   我们

    2024年02月01日
    浏览(51)
  • 计算机视觉 - 基于黄金模板比较技术的缺陷检测

            基于黄金模板比对的检测是一种常见的视觉应用。当进行缺陷检查而其他缺陷检测方法是不可行的时候,使用金模板比较。另外当物体的表面或物体的形状非常复杂时,此技术特别有用。          虽然说黄金模板比较的技术的思路很简单,但是真正落地实施确

    2024年02月09日
    浏览(39)
  • 基于计算机视觉的物流和供应链管理

    作者:禅与计算机程序设计艺术 物流、供应链是一个非常重要的现代经济活动,许多企业都面临着如何提高效率,降低成本,改善供应链服务质量的问题。目前,人们已经在探索如何通过人工智能、物联网等新兴技术,实现自动化运输过程和管理。基于计算机视觉技术的物流

    2024年02月10日
    浏览(37)
  • 计算机竞赛 基于机器视觉的行人口罩佩戴检测

    简介 2020新冠爆发以来,疫情牵动着全国人民的心,一线医护工作者在最前线抗击疫情的同时,我们也可以看到很多科技行业和人工智能领域的从业者,也在贡献着他们的力量。近些天来,旷视、商汤、海康、百度都多家科技公司研发出了带有AI人脸检测算法的红外测温、口罩

    2024年02月10日
    浏览(48)
  • 基于深度学习的计算机视觉垃圾分类项目解析

    项目地址:https://gitcode.com/YaoHaozhe/Computer-vision-based-on-deep-learning-garbage-classification 在这个数字化的时代,数据已经成为我们生活和工作的重要组成部分,而其中,图像数据的处理能力更是关键。YaoHaozhe 创建的这个基于深度学习的计算机视觉垃圾分类项目,提供了一个实用的解决

    2024年04月12日
    浏览(30)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包