不同的数据集在形式上千差万别,为了能够统一用于模型的训练,Pytorch框架下定义了一个dataset类和一个dataloader类。
dataset用于获取数据集中的样本,dataloader 用于抽取部分样本用于训练。比如说一个用于分割任务的图像数据集的结构如图1所示,一个样本由原图像和对应的mask组成。
图1 典型数据集的结构
为了获取数据集,典型的代码如下
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
# 定义数据集
train_data_dir = 'dataset/train'
train_GT_dir = 'dataset/train_GT'
class MyData(Dataset):
def __init__(self, imgdir, maskdir,transform):
self.imgdir = imgdir
self.maskdir = maskdir
self.transform = transform
self.img_list = os.listdir(self.imgdir)
self.mask_list= os.listdir(self.maskdir)
self.img_list.sort()
self.mask_list.sort()
def __getitem__(self, idx):
img_name = self.img_list[idx]
mask_name =self.mask_list[idx]
img_item_path = os.path.join(self.imgdir, img_name)
mask_item_path =os.path.join(self.maskdir,mask_name)
img =Image.open(img_item_path)
mask =Image.open(mask_item_path)
img = self.transform(img)
mask = self.transform(mask)
return img, mask
def __len__(self):
assert len(self.img_list) == len(self.mask_list)
return len(self.img_list)
if __name__ == '__main__':
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
train_data_dir = 'dataset/train'
train_GT_dir = 'dataset/train_GT'
dataset = MyData(train_data_dir, train_GT_dir ,transform)
dataloader = DataLoader(dataset, batch_size=4, num_workers=0)
for step, (img,mask) in enumerate(dataloader):
print(step)
print(img.shape)
print(mask.shape)
if step>0:
break
程序运行的结果如下:
返回了一个batch的img 和mask 的尺寸,说明数据集抽取成功了.文章来源:https://www.toymoban.com/news/detail-823266.html
在建立数据集的过程中需用重写__getitem()__和__len()__方法即可。文章来源地址https://www.toymoban.com/news/detail-823266.html
到了这里,关于Pytorch中Dataset和dadaloader的理解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!