深度学习技术栈 —— Pytorch之TensorDataset、DataLoader

这篇具有很好参考价值的文章主要介绍了深度学习技术栈 —— Pytorch之TensorDataset、DataLoader。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。


前言

简单来说,TensorDatasetDataLoader这两个类的作用, 就是将数据读入并做整合,以便交给模型处理。就像石油加工厂一样,你不关心石油是如何采集与加工的,你关心的是自己去哪加油,油价是多少,对于一个模型而言,DataLoader就是这样的一个予取予求的数据服务商。

参考文章或视频链接
[1] How to use TensorDataset, Dataloader (pytorch)

一、TensorDataset、DataLoader的用法?

# coding:utf-8
# @Time: 2024/1/23 上午9:57
# @Author: 键盘国治理专家
# @File: __init__.py.py
# @Description: 

import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader


def test_TensorDataset():
    input = np.random.rand(4, 2)  # Input data
    correct = np.random.rand(4, 1)  # Correct answer data

    input = torch.FloatTensor(input)  # Change to an array that can be handled by pytorch
    correct = torch.FloatTensor(correct)  # Same as above

    print(input)
    print(correct)

    dataset = TensorDataset(input, correct)  # set the data,注意,是TensorDataset而不是Dataset,Dataset是个abstract class不能实例化

    print(dataset)  # 打印地址
    print(vars(dataset))  # vars prints the contents of the object
    return dataset


def test_DataLoader(dataset):
    train_load = DataLoader(dataset, batch_size=3, shuffle=False)  # Data shuffle with shuffle=True
    for x, t in train_load:
        print('x-->', x)
        print('t-->', t)


if __name__ == '__main__':
    dataset = test_TensorDataset()
    print("========================================================================================")
    test_DataLoader(dataset)

二、从.csv文件–>tensor张量

一般说来,大部分Kaggle比赛的数据都是以.csv为格式的,而Pytorch处理的是tensor张量,所以我们要了解如何将.csv文件的数据变成tensor张量数据。

"""
步骤如下
(1) xx.csv --> 经由pandas 变成 numpy 数组
(2) numpy 变成 tensor 张量
(3) tensor张量经过TensorDataset的组合
(4) dataset再经过DataLoader的处理,进而保证数据可用,以上为清洗过程
.csv --> numpy --> tensor --> dataset --> dataloader 四个过程,五个数据中转形式。
"""
# coding:utf-8
# @Time: 2024/1/23 下午1:01
# @Author: 键盘国治理专家
# @File: csv2tensor.py
# @Description:

import numpy
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

def csv2numpy(csv_path):
    data = pd.read_csv(csv_path, dtype=np.float64)
    # numpy_data = data.iloc[:, data.columns != "xx"]  # 另一种用法,data.columns != "xx" 可以过滤掉你不想读入的字段
    numpy_data = data.iloc[:].values
    return numpy_data


def numpy2tensor(numpy_data):
    tensor_data = torch.from_numpy(numpy_data)
    return tensor_data


def tensor2DataLoader(tensor_data):  # 一步到位,直接变成DataLoader。最简单的实现方式,这个func还有改进空间,DataSet可以接收多个tensor数据
    dataset = torch.utils.data.TensorDataset(tensor_data)
    data_loader = torch.utils.data.DataLoader(dataset, shuffle=False)
    return data_loader


# 你甚至可以直接将.csv处理成DataLoader了,把这几个过程简单组合下形成一个新函数
def csv2DataLoader(csv_path):
    numpy_data = csv2numpy(csv_path)
    tensor_data = numpy2tensor(numpy_data)
    data_loader = tensor2DataLoader(tensor_data)
    return data_loader


if __name__ == '__main__':

    numpy_data = csv2numpy("./test.csv")
    # print(type(numpy_data))
    # print(numpy_data.shape)
    # print(numpy_data)

    tensor_data = numpy2tensor(numpy_data)
    # print(type(tensor_data))
    # print(tensor_data.shape)
    # print(tensor_data)

    data_loader = tensor2DataLoader(tensor_data)
    # print(type(data_loader))
    # print(data_loader)
    # print(data_loader.dataset)
    # # 用遍历的方式才能输出data_loader里的数据
    # for data_item in data_loader:
    #     print('data_item-->', data_item)
    # # 把数据的索引也一起输出
    # for i, data_item in enumerate(data_loader):
    #     print('i', i)
    #     print('data_item-->', data_item)

总结

本篇工作虽然简单,但确是进阶的一个不大不小的绊脚石,功夫虽小,也不能不练。文章来源地址https://www.toymoban.com/news/detail-820159.html

到了这里,关于深度学习技术栈 —— Pytorch之TensorDataset、DataLoader的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【使用机器学习和深度学习对城市声音进行分类】基于两种技术(ML和DL)对音频数据(城市声音)进行分类(Matlab代码实现)

     💥💥💞💞 欢迎来到本博客 ❤️❤️💥💥 🏆博主优势: 🌞🌞🌞 博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️ 座右铭: 行百里者,半于九十。 📋📋📋 本文目录如下: 🎁🎁🎁 目录 💥1 概述 📚2 运行结果 2.1 算例1 2.2 算例2 2.3 算例3 2.4 算例4

    2024年02月16日
    浏览(33)
  • 【代码笔记】Pytorch学习 DataLoader模块详解

    dataloader主要有6个class构成(可见下图) _DatasetKind: _InfiniteConstantSampler: DataLoader: _BaseDataLoaderIter: _SingleProcessDataLoaderIter: _MultiProcessingDataLoaderIter: 我们首先看一下DataLoader的整体结构: init : _get_iterator: multiprocessing_context: multiprocessing_context: setattr : iter : _auto_collation: _ind

    2023年04月11日
    浏览(27)
  • PyTorch机器学习与深度学习技术方法

    近年来,随着AlphaGo、无人驾驶汽车、医学影像智慧辅助诊疗、ImageNet竞赛等热点事件的发生,人工智能迎来了新一轮的发展浪潮。尤其是深度学习技术,在许多行业都取得了颠覆性的成果。另外,近年来,Pytorch深度学习框架受到越来越多科研人员的关注和喜爱。 Python基础知

    2024年02月02日
    浏览(35)
  • pytorch进阶学习(二):使用DataLoader读取自己的数据集

    上一节使用的是官方数据集fashionminist进行训练,这节课使用自己搜集的数据集来进行数据的获取和训练。 教学视频:https://www.bilibili.com/video/BV1by4y1b7hX/?spm_id_from=333.1007.top_right_bar_window_history.content.clickvd_source=e482aea0f5ebf492c0b0220fb64f98d3 pytorch进阶学习(一):https://blog.csdn.net/w

    2024年02月09日
    浏览(28)
  • 深度学习技术栈 —— Pytorch中保存与加载权重文件

    权重文件是指训练好的模型参数文件,不同的深度学习框架和模型可能使用不同的权重文件格式。以下是一些常见的权重文件格式: PyTorch 的模型格式: .pt 文件。 Darknet 的模型格式: .weight 文件。 TensorFlow 的模型格式: .ckpt 文件。 一、参考文章或视频链接 [1] Navigating Mode

    2024年01月19日
    浏览(41)
  • 手把手写深度学习(23):视频扩散模型之Video DataLoader

    手把手写深度学习(0):专栏文章导航 前言: 训练自己的视频扩散模型的第一步就是准备数据集,而且这个数据集是text-video或者image-video的多模态数据集,这篇博客手把手教读者如何写一个这样扩散模型的的Video DataLoader。 目录 准备工作 下载数据集 视频数据打标签

    2024年03月21日
    浏览(39)
  • PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化实践技术应用

    我国高分辨率对地观测系统重大专项已全面启动,高空间、高光谱、高时间分辨率和宽地面覆盖于一体的全球天空地一体化立体对地观测网逐步形成,将成为保障国家安全的基础性和战略性资源。未来10年全球每天获取的观测数据将超过10PB,遥感大数据时代已然来临。随着小

    2024年02月10日
    浏览(51)
  • 用pytorch给深度学习加速:正交与谱归一化技术

    目录 torch.nn参数优化 parametrizations.orthogonal 用途 用法 使用技巧 参数 注意事项 示例代码 parametrizations.spectral_norm 用途 用法 使用技巧 参数 注意事项 示例代码 总结 这个 torch.nn.utils.parametrizations.orthogonal 模块是PyTorch库中的一个功能,用于对神经网络中的矩阵或一批矩阵应用正交

    2024年01月17日
    浏览(30)
  • TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11

    原文:Mobile Deep Learning with TensorFlow Lite, ML Kit and Flutter 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自【ApacheCN 深度学习 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。 不要担心自己的形象,只关心如何实现目标。——《原则》,生活原则 2.3.c 认证是任何应用中最突出的

    2023年04月24日
    浏览(68)
  • 深度学习推荐系统(三)NeuralCF及其在ml-1m电影数据集上的应用

    在2016年, 随着微软的Deep Crossing, 谷歌的WideDeep以及FNN、PNN等一大批优秀的深度学习模型被提出, 推荐系统全面进入了深度学习时代, 时至今日, 依然是主流。 推荐模型主要有下面两个进展: 与传统的机器学习模型相比, 深度学习模型的表达能力更强, 能够挖掘更多数据

    2024年02月10日
    浏览(44)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包