Pytorch文本分类入门

这篇具有很好参考价值的文章主要介绍了Pytorch文本分类入门。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

🍨 本文为[🔗365天深度学习训练营学习记录博客

🍦 参考文章:365天深度学习训练营

🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

一、加载数据

import os
import sys
import PIL
from PIL import Image
import time
import copy
import random
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
import torchvision
from torchinfo import summary
import torchsummary
import matplotlib.pyplot as plt
import numpy as np
import warnings


''' 下载或读取AG News数据集中的训练集与测试集 '''
def getDataset(root, dataset):
    if not os.path.exists(root) or not os.path.isdir(root):
        os.makedirs(root)
    if not os.path.exists(dataset) or not os.path.isdir(dataset):
        print('Downloading dataset...\n')
        # 下载AG News数据集 直接运行会报网络错误 无法下载  
        train_ds, test_ds = AG_NEWS(root=root, split=("train", "test"))
    else:
        print('Dataset already downloaded, reading...\n')
        # 读取本地AG News数据集 手动下载了train.csv和test.csv后可从本地加载数据
        train_ds, test_ds = AG_NEWS(root=dataset, split=("train", "test"))
    #print("Train:", next(train_ds), len(list(train_ds))+1)
    #print("Test :", next(test_ds), len(list(test_ds))+1)
    return train_ds, test_ds


''' 设置GPU '''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device\n".format(device))
''' 加载数据 '''
root = './data/'
data_dir = os.path.join(root, 'AG_NEWS.data')
train_ds, test_ds = getDataset(root, data_dir)

 运行结果:

Using cuda device

Dataset already downloaded, reading...

Train: (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.") 120000
Test : (3, "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.") 7600

二、构建词典

''' 构建词典 '''
def buildDict(train_ds):
    tokenizer  = get_tokenizer('basic_english') # 返回分词器函数
    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)
    vocab = build_vocab_from_iterator(yield_tokens(train_ds))
    text_pipeline  = lambda x: vocab.lookup_indices(tokenizer(x))
    label_pipeline = lambda x: int(x)
    #print(vocab.UNK, vocab._default_unk_index())# 打印默认索引,如果找不到单词,则会选择默认索引
    #print(vocab.lookup_indices(['here', 'is', 'an', 'example']))
    #print(text_pipeline('here is the an example'))
    #print(label_pipeline('10'))
    return vocab, text_pipeline, label_pipeline


# 构建词典
text_pipeline, label_pipeline = buildDict(train_ds)

运行结果: 

120001lines [00:04, 27817.88lines/s]
<unk> 0
[471, 22, 31, 5177]
[471, 22, 3, 31, 5177]
10

三、生成数据批次和迭代器

''' 加载数据,并设置batch_size '''
def loadData(train_ds, test_ds, batch_size=8, device='cpu'):
    # 构建词典
    vocab, text_pipeline, label_pipeline = buildDict(train_ds)
    # 生成数据批次和迭代器
    def collate_batch(batch):
        label_list, text_list, offsets = [], [], [0]
        for (_label, _text) in batch:
            # 标签列表
            label_list.append(label_pipeline(_label))
            # 文本列表
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            # 偏移量,即语句的总词汇量
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        text_list  = torch.cat(text_list)
        offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和
        return label_list.to(device), text_list.to(device), offsets.to(device)
    # 从 train_ds 加载训练集
    train_dl = torch.utils.data.DataLoader(train_ds,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           collate_fn=collate_batch,
                                           num_workers=0)
    # 从 test_ds 加载测试集
    test_dl  = torch.utils.data.DataLoader(test_ds,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           collate_fn=collate_batch,
                                           num_workers=0)
    
    # 取一个批次查看数据格式
    #data = train_dl.__iter__()
    #print(type(data), data, '\n')
    return vocab, train_dl, test_dl


# 生成数据批次和迭代器
batch_size = 64
train_dl, test_dl = loadData(train_ds, test_ds, batch_size=batch_size, device=device)

运行结果:文章来源地址https://www.toymoban.com/news/detail-817442.html

120001lines [00:04, 27749.13lines/s]
<class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'> <torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x00000266556204C0>

四、构建模型

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)      # 将tensor用从均匀分布中抽样得到的值填充
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)        # torch.Size([64, 64])
        output = self.fc(embedded)      # torch.Size([64, 4])
        return output
''' 定义实例 '''
train_iter = AG_NEWS(root='./data/AG_NEWS.data', split=("train"))
num_class  = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)
print('num_class', num_class)
print('vocab_size', vocab_size)
print(model)
def train(dataloader):
    model.train()       # 训练模式
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predited_label = model(text, offsets)
        loss = criterion(predited_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)     # 规定了最大不能超过的max_norm
        optimizer.step()
        total_acc += (predited_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()
def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predited_label = model(text, offsets)
            # loss = criterion(predited_label, label)
            total_acc += (predited_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

五、拆分数据集和运行模型

if __name__ == '__main__':
    # 超参数(Hyperparameters)
    EPOCHS = 10  # epoch
    LR = 5  # learning rate
    BATCH_SIZE = 64  # batch size for training
   
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    train_iter, test_iter = AG_NEWS(root=path)
    train_dataset = list(train_iter)
    test_dataset = list(test_iter)
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
   
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle表示随机打乱
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
   
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
        print('-' * 59)
    
    torch.save(model.state_dict(), 'output\\model_TextClassification.pth')
| epoch   1 |   500/ 1782 batches, accuracy    0.687
| epoch   1 |  1000/ 1782 batches, accuracy    0.856
| epoch   1 |  1500/ 1782 batches, accuracy    0.875
-----------------------------------------------------------
| end of epoch   1 | time: 23.15s | valid accuracy    0.881
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches, accuracy    0.898
| epoch   2 |  1000/ 1782 batches, accuracy    0.898
| epoch   2 |  1500/ 1782 batches, accuracy    0.903
-----------------------------------------------------------
| end of epoch   2 | time: 16.20s | valid accuracy    0.897
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches, accuracy    0.917
| epoch   3 |  1000/ 1782 batches, accuracy    0.915
| epoch   3 |  1500/ 1782 batches, accuracy    0.914
-----------------------------------------------------------
| end of epoch   3 | time: 15.98s | valid accuracy    0.902
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches, accuracy    0.924
| epoch   4 |  1000/ 1782 batches, accuracy    0.924
| epoch   4 |  1500/ 1782 batches, accuracy    0.922
-----------------------------------------------------------
| end of epoch   4 | time: 16.63s | valid accuracy    0.901
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches, accuracy    0.937
| epoch   5 |  1000/ 1782 batches, accuracy    0.937
| epoch   5 |  1500/ 1782 batches, accuracy    0.938
-----------------------------------------------------------
| end of epoch   5 | time: 16.37s | valid accuracy    0.912
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches, accuracy    0.938
| epoch   6 |  1000/ 1782 batches, accuracy    0.939
| epoch   6 |  1500/ 1782 batches, accuracy    0.940
-----------------------------------------------------------
| end of epoch   6 | time: 16.17s | valid accuracy    0.912
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches, accuracy    0.940
| epoch   7 |  1000/ 1782 batches, accuracy    0.938
| epoch   7 |  1500/ 1782 batches, accuracy    0.943
-----------------------------------------------------------
| end of epoch   7 | time: 16.20s | valid accuracy    0.911
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches, accuracy    0.941
| epoch   8 |  1000/ 1782 batches, accuracy    0.940
| epoch   8 |  1500/ 1782 batches, accuracy    0.942
-----------------------------------------------------------
| end of epoch   8 | time: 16.46s | valid accuracy    0.911
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches, accuracy    0.941
| epoch   9 |  1000/ 1782 batches, accuracy    0.941
| epoch   9 |  1500/ 1782 batches, accuracy    0.943
-----------------------------------------------------------
| end of epoch   9 | time: 17.50s | valid accuracy    0.912
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches, accuracy    0.940
| epoch  10 |  1000/ 1782 batches, accuracy    0.942
| epoch  10 |  1500/ 1782 batches, accuracy    0.942
-----------------------------------------------------------
| end of epoch  10 | time: 16.12s | valid accuracy    0.912
-----------------------------------------------------------

实验目的

  • 构建一个文本分类模型,用于对AG News数据集中的新闻文章进行分类。

数据集

  • 使用的是AG News数据集,包括新闻文章及其相应类别标签。
  • 数据集被分为训练集和测试集。

数据预处理

  • 构建了一个词典(vocab),用于将文本转换为数字表示。
  • 定义了文本和标签的处理流程(text_pipelinelabel_pipeline)。

模型构建

  • 使用了EmbeddingBagLinear层构建了一个简单的文本分类模型。
  • 模型包含词嵌入层,将文本转换为固定大小的向量,随后通过一个全连接层进行分类。

训练过程

  • 使用交叉熵损失函数(CrossEntropyLoss)和随机梯度下降优化器(SGD)。
  • 实现了训练(train)和评估(evaluate)函数。
  • 训练了10个epoch,每个epoch结束后在验证集上评估模型。

结果和调优

  • 在训练过程中,如果验证集上的准确率没有提升,则减小学习率。
  • 每个epoch结束后打印了时间和验证集上的准确率。
  • 最终模型被保存为model_TextClassification.pth

到了这里,关于Pytorch文本分类入门的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Pytorch文本分类入门

    🍨 本文为[🔗365天深度学习训练营学习记录博客 🍦 参考文章:365天深度学习训练营 🍖 原作者:[K同学啊 | 接辅导、项目定制]n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45) 一、加载数据  运行结果: 二、构建词典 运行结果:  三、生成数据批次和

    2024年01月23日
    浏览(36)
  • 人工智能中的文本分类:技术突破与实战指导

    在本文中,我们全面探讨了文本分类技术的发展历程、基本原理、关键技术、深度学习的应用,以及从RNN到Transformer的技术演进。文章详细介绍了各种模型的原理和实战应用,旨在提供对文本分类技术深入理解的全面视角。 关注TechLead,分享AI全维度知识。作者拥有10+年互联网

    2024年02月05日
    浏览(37)
  • PyTorch 人工智能研讨会:6~7

    原文:The Deep Learning with PyTorch Workshop 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自【ApacheCN 深度学习 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。 不要担心自己的形象,只关心如何实现目标。——《原则》,生活原则 2.3.c 概述 本章扩展了循环神经网络的概念。 您将

    2023年04月20日
    浏览(67)
  • 人工智能学习07--pytorch15(前接pytorch10)--目标检测:FPN结构详解

    backbone:骨干网络,例如cnn的一系列。(特征提取) (a)特征图像金字塔 检测不同尺寸目标。 首先将图片缩放到不同尺度,针对每个尺度图片都一次通过算法进行预测。 但是这样一来,生成多少个尺度就要预测多少次,训练效率很低。 (b)单一特征图 faster rcnn所采用的一种方式

    2023年04月12日
    浏览(76)
  • 人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程,本文将介绍如何使用PyTorch搭建ELMo模型,包括ELMo模型的原理、数据样例、模型训练、损失值和准确率的打印以及预测。文章将提供完整的代码实现。 ELMo模型简介 数据

    2024年02月07日
    浏览(67)
  • 人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型

    大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型,在本文中,我们将学习如何使用PyTorch搭建卷积神经网络ResNet模型,并在生成的假数据上进行训练和测试。本文将涵盖这些内容:ResNet模型简介、ResNet模型结构、生成假

    2024年02月06日
    浏览(78)
  • 人工智能学习07--pytorch14--ResNet网络/BN/迁移学习详解+pytorch搭建

    亮点:网络结构特别深 (突变点是因为学习率除0.1?) 梯度消失 :假设每一层的误差梯度是一个小于1的数,则在反向传播过程中,每向前传播一层,都要乘以一个小于1的误差梯度。当网络越来越深的时候,相乘的这些小于1的系数越多,就越趋近于0,这样梯度就会越来越小

    2023年04月11日
    浏览(159)
  • 人工智能(pytorch)搭建模型10-pytorch搭建脉冲神经网络(SNN)实现及应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型10-pytorch搭建脉冲神经网络(SNN)实现及应用,脉冲神经网络(SNN)是一种基于生物神经系统的神经网络模型,它通过模拟神经元之间的电信号传递来实现信息处理。与传统的人工神经网络(ANN)不同,SNN 中的

    2024年02月08日
    浏览(50)
  • 人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测,RetinaNet 是一种用于目标检测任务的深度学习模型,旨在解决目标检测中存在的困难样本和不平衡类别问题。它是基于单阶段检测器的一种改进方法,通

    2024年02月15日
    浏览(96)
  • 人工智能:Pytorch,TensorFlow,MXNET,PaddlePaddle 啥区别?

    学习人工智能的时候碰到各种深度神经网络框架:pytorch,TensorFlow,MXNET,PaddlePaddle,他们有什么区别? PyTorch、TensorFlow、MXNet和PaddlePaddle都是深度学习领域的开源框架,它们各自具有不同的特点和优势。以下是它们之间的主要区别: PyTorch是一个开源的Python机器学习库,它基

    2024年04月16日
    浏览(69)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包