The relationship between Dataset and DataLoader
- Dataset: defines a dataset object that holds (or knows how to fetch) all of the data samples.
- DataLoader: wraps a constructed Dataset and turns it into an iterable you can train on, handling shuffling (shuffle), batching (batch_size), and parallel loading (num_workers).
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Skeleton of a custom dataset."""
    def __init__(self):
        """Initialize whatever is needed to build the dataset."""
        pass

    def __getitem__(self, index):
        """Return the sample at the given index, enabling dataset[index]."""
        pass

    def __len__(self):
        """Return the number of samples in the dataset."""
        pass

# Instantiate the custom dataset
dataset = MyDataset()

# Wrap the dataset in an iterable suitable for training
train_loader = DataLoader(dataset=dataset,   # the custom dataset
                          batch_size=32,     # size of each mini-batch
                          shuffle=True,      # whether to shuffle the sample order
                          num_workers=2)     # number of subprocesses used for loading
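To make the skeleton concrete before moving on, here is a minimal runnable sketch with all three methods filled in. The data is hypothetical (random tensors); only the structure matters.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """A tiny in-memory dataset: 100 random samples with binary labels."""
    def __init__(self, n=100):
        self.x = torch.randn(n, 3)                    # n samples, 3 features each
        self.y = torch.randint(0, 2, (n, 1)).float()  # one binary label per sample

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.x.shape[0]

loader = DataLoader(ToyDataset(), batch_size=32, shuffle=True, num_workers=0)
for xb, yb in loader:
    # Each iteration yields one mini-batch; the last batch may be smaller than 32.
    print(xb.shape, yb.shape)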
Example 1: a CSV dataset (structured/tabular data)
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """A custom dataset backed by a CSV file."""
    def __init__(self, filepath):
        """Load the CSV and split it into feature and label tensors."""
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]                       # number of samples in the dataset
        self.x_data = torch.from_numpy(xy[:, :-1])   # all columns except the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column (kept 2-D): labels
        print("Data is ready...")

    def __getitem__(self, index):
        """Support indexing, i.e. dataset[index]: return the sample at the given index."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Support len(dataset): return the number of samples."""
        return self.len

file = "D:\\BaiduNetdiskDownload\\Dataset_Dataload\\diabetes1.csv"

""" 1. Build the dataset with the MyDataset class """
mydataset = MyDataset(file)

""" 2. Build the train_loader with DataLoader """
train_loader = DataLoader(dataset=mydataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=0)  # 0 = load in the main process (safest on Windows)
class MyModel(torch.nn.Module):
    """A small fully connected classifier: 8 -> 6 -> 4 -> 1."""
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

# Instantiate the model
model = MyModel()
# Define the loss function (mean-reduced binary cross-entropy)
criterion = torch.nn.BCELoss(reduction='mean')
# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
if __name__ == "__main__":
    for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            # 1. Prepare the data
            inputs, labels = data
            # 2. Forward pass
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward pass
            optimizer.zero_grad()
            loss.backward()
            # 4. Update the parameters
            optimizer.step()
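After training, a quick sanity check can follow inside the same __main__ guard. This is only a sketch: it evaluates on the training data itself, since this example defines no separate test split, and it thresholds the sigmoid output at 0.5.

    # Evaluate once over all samples (no gradients needed)
    with torch.no_grad():
        y_prob = model(mydataset.x_data)
        acc = ((y_prob > 0.5).float() == mydataset.y_data).float().mean().item()
        print("accuracy on the training data: {:.3f}".format(acc))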
Example 2: an image dataset
├── flower_data
│   ├── flower_photos (the unzipped dataset folder, 3670 samples)
│   ├── train (generated training set, 3306 samples)
│   └── val (generated validation set, 364 samples)
Main script: main.py
import os
import torch
from torchvision import transforms
from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image

root = "../data/flower_data/flower_photos"  # root directory of the dataset

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)

    # Normalize uses the standard ImageNet mean/std
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for epoch in range(100):
        for step, data in enumerate(train_loader):
            images, labels = data
            # ... run the actual training step here

if __name__ == '__main__':
    main()
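main() above only builds the training loader. A validation loader can be built the same way from the paths that read_split_data already returns; the sketch below mirrors the training-side code and would go inside main():

    val_data_set = MyDataSet(images_path=val_images_path,
                             images_class=val_images_label,
                             transform=data_transform["val"])
    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,  # keep validation order deterministic
                                             num_workers=nw,
                                             collate_fn=val_data_set.collate_fn)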
Custom dataset file: my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset

class MyDataSet(Dataset):
    """Custom image dataset built from a list of file paths and a list of class indices."""
    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # mode 'RGB' is a color image, 'L' is grayscale
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        # For the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
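collate_fn receives a list of the individual (image, label) pairs produced by __getitem__, one per sample in the batch, and merges them into batch tensors. A quick sketch with dummy tensors shows the shapes involved (the 3x224x224 size just matches the transforms above):

import torch
from my_dataset import MyDataSet

# Two fake "samples": an image tensor plus an integer class label each
samples = [(torch.rand(3, 224, 224), 0), (torch.rand(3, 224, 224), 1)]
images, labels = MyDataSet.collate_fn(samples)
print(images.shape)  # torch.Size([2, 3, 224, 224])
print(labels)        # tensor([0, 1])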
Utility file: utils.py (train/validation split and visualization)
import os
import json
import pickle
import random

import matplotlib.pyplot as plt

def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # fix the seed so the random split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)  # check that the path exists

    # Walk the root directory; each subfolder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # Sort so the class order is consistent across runs
    flower_class.sort()
    # Map class name -> numeric index: {'class name': 0, 'class name': 1, ...}
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    # Save the inverse mapping (index -> class name) to a JSON file, one entry per line
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class index of each validation image
    every_class_num = []     # number of samples in each class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions

    # Iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # Collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # Look up the index for this class
        image_class = class_indices[cla]
        # Record the number of samples in this class
        every_class_num.append(len(images))
        # Randomly sample validation images at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))
        for img_path in images:
            if img_path in val_path:  # sampled for validation: store in the validation lists
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:                     # otherwise: store in the training lists
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = True
    if plot_image:
        # Bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # Replace the x ticks 0, 1, 2, ... with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # Add a count label on top of each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # Label the axes and title the chart
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo the Normalize transform, then scale back to 0-255
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove the x-axis ticks
            plt.yticks([])  # remove the y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()
def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)

def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list
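write_pickle and read_pickle can be used to cache the split, so repeated runs reuse exactly the same train/validation partition. A sketch (the .pkl file name is illustrative):

# Cache the four lists returned by read_split_data, then reload them later
split = list(read_split_data("../data/flower_data/flower_photos"))
write_pickle(split, "flower_split.pkl")
train_paths, train_labels, val_paths, val_labels = read_pickle("flower_split.pkl")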
This concludes the walkthrough of how to use PyTorch's Dataset to define your own dataset.