【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别

这篇具有很好参考价值的文章主要介绍了【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言


参考文章:

最初参考的两篇:
【Pytorch】基于CNN手写汉字的识别
「Pytorch」CNN实现手写汉字识别(数据集制作,网络搭建,训练验证测试全部代码)
模型:
EfficientNetV2网络详解
数据集(不必从这里下载,可以看一下它的介绍):
CASIA Online and Offline Chinese Handwriting Databases

鉴于已经3202年了,大模型都出来了,网上还是缺乏汉字识别这种基础神经网络的能让新手直接上手跑通的手把手教程,我就斗胆自己写一篇好了。

本文的主要特点:

  1. 使用EfficientNetV2模型真正实现3755类汉字识别

  2. Demo开源

  3. 预训练模型可以下载

  4. 预制数据集,无需处理直接使用

数据集


使用中科院制作的手写汉字数据集,链接直达官网,所以我这里不多介绍,只有满腔敬意。

上面参考的博客可能要你自己下载之后按照它的办法再预处理一下,但是在这个环节出现问题的朋友挺多,我把预处理的数据已经传到【夸克云盘】,有人反映说这个有损坏了,但是我自己用bandizip智能解压一切正常,没损坏。如果还是报错,则使用这个别人传的【百度网盘】。

预训练模型已经上传了(后面有链接),但是如果想自己训一下,就需要下载这个数据集,解压到项目结构里的data文件夹如下所示

data文件夹和log文件夹需要自己建。

项目结构


完整源代码:【项目源码】

【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别

目录结构

重点注意data文件夹的结构,不要把数据集放错位置了或者多嵌套了文件夹

├─Chinese_Character_Rec(项目)
│ ├─asserts
│ │ ├─*.png
│ ├─char_dict
│ ├─Data.py
│ ├─EfficientNetV2
│ │ ├─demo.py
│ │ ├─EffNetV2.py
│ │ ├─Evaluate.py
│ │ ├─model.py
│ │ └─Train.py
│ ├─Utils.py
│ ├─VGG19
│ │ ├─demo.py
│ │ ├─Evaluate.py
│ │ ├─model.py
│ │ ├─Train.py
│ │ └─VGG19.py
│ └─README.md
├─data(数据集)
│ ├─test(测试数据集)
│ │ ├─00000
│ │ ├─00001
│ │ └─...
│ ├─test.txt(程序生成)
│ ├─train(训练数据集)
│ │ ├─00000
│ │ ├─00001
│ │ └─ ...
│ └─train.txt(程序生成)
├─log(模型参数存放位置)
    ├─log1.pth
    └─…

神经网络模型


预训练模型【参数链接】(包含vgg19和efficientnetv2)

请将.pth文件重命名为log+数字.pth的格式,例如log1.pth,放入log文件夹。方便识别和retrain。

VGG19

这里先后用了两种神经网络,我先用VGG19试了一下,分类前1000种汉字。训得有点慢,主要还是这模型有点老了,参数量也不小。而且要改到3755类的话还用原参数的话就很难收敛,也不知道该怎么调参数了,估计调好了也会规模很大,所以这里VGG19模型的版本只能分类1000种,就是数据集的前1000种(准确率>92%)。

EfficientNetV2

这个模型很不错,主要是卷积层的部分非常有效,参数量也很少。直接用small版本去分类3755个汉字,半小时就收敛得差不多了。所以本文用来实现3755类汉字的模型就是EfficientNetV2(准确率>89%),后面的教程都是基于这个,VGG19就不管了,在源码里感兴趣的自己看吧。

以下代码不用自己写,前面已经给出完整源代码了,下面的教程是结合源码的讲解而已。

运行环境


显存>=4G(与batchSize有关,batchSize=512时显存占用4.8G;如果是256或者128,应该会低于4G,虽然会导致训得慢一点)

内存>=16G(训练时不太占内存,但是刚开始加载的时候会突然占一下,如果小于16G还是怕爆)

如果你没有安装过Pytorch,啊,我也不知道怎么办,你要不就看看安装Pytorch的教程吧。(总体步骤是,有一个不太老的N卡,先去驱动里看看cuda版本,安装合适的CUDA,然后根据CUDA版本去pytorch.org找到合适的安装指令,然后在本地pip install)

以下是项目运行环境,我是3060 6G,CUDA版本11.6

这个约等号不用在意,可以都安装最新版本,反正我这里应该没用什么特殊的API


torch~=1.12.1+cu116
torchvision~=0.13.1+cu116
Pillow~=9.3.0

数据集准备


首先定义classes_txt方法在Utils.py中(不是我写的,是CSDN那两篇博客的,MyDataset同):

生成每张图片的路径,存储到train.txt或test.txt。方便训练或评估时读取数据


def classes_txt(root, out_path, num_class=None):
    dirs = os.listdir(root)
    if not num_class:
        num_class = len(dirs)

    with open(out_path, 'w') as f:
        end = 0
        if end < num_class - 1:
            dirs.sort()
            dirs = dirs[end:num_class]
            for dir1 in dirs:
                files = os.listdir(os.path.join(root, dir1))
                for file in files:
                    f.write(os.path.join(root, dir1, file) + '\n')

定义Dataset类,用于制作数据集,为每个图片加上对应的标签,即图片所在文件夹的代号


class MyDataset(Dataset):
    def __init__(self, txt_path, num_class, transforms=None):
        super(MyDataset, self).__init__()
        images = []
        labels = []
        with open(txt_path, 'r') as f:
            for line in f:
                if int(line.split('\\')[1]) >= num_class: # 超出规定的类,就不添加,例如VGG19只添加了1000类
                    break
                line = line.strip('\n')
                images.append(line)
                labels.append(int(line.split('\\')[1]))
        self.images = images
        self.labels = labels
        self.transforms = transforms

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        label = self.labels[index]
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        return len(self.labels)

入口


我把各种超参都放在了args里方便改,请根据实际情况自行调整。这套defaults就是我训练这个模型时使用的超参,图片size默认32是因为我显存太小辣!!但是数据集给的图片大小普遍不超过64,如果想训得更精确,可以试试64*64的大小。

如果你训练时爆mem,请调小batch_size,试试256,128,64,32


parser = argparse.ArgumentParser(description='EfficientNetV2 arguments')
parser.add_argument('--mode', dest='mode', type=str, default='demo', help='Mode of net')
parser.add_argument('--epoch', dest='epoch', type=int, default=50, help='Epoch number of training')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=512, help='Value of batch size')
parser.add_argument('--lr', dest='lr', type=float, default=0.0001, help='Value of lr')
parser.add_argument('--img_size', dest='img_size', type=int, default=32, help='reSize of input image')
parser.add_argument('--data_root', dest='data_root', type=str, default='../../data/', help='Path to data')
parser.add_argument('--log_root', dest='log_root', type=str, default='../../log/', help='Path to model.pth')
parser.add_argument('--num_classes', dest='num_classes', type=int, default=3755, help='Classes of character')
parser.add_argument('--demo_img', dest='demo_img', type=str, default='../asserts/fo2.png', help='Path to demo image')
args = parser.parse_args()


if __name__ == '__main__':
    if not os.path.exists(args.data_root + 'train.txt'): # 只生成一次
        classes_txt(args.data_root + 'train', args.data_root + 'train.txt', args.num_classes)
    if not os.path.exists(args.data_root + 'test.txt'): # 只生成一次
        classes_txt(args.data_root + 'test', args.data_root + 'test.txt', args.num_classes)

    if args.mode == 'train':
        train(args)
    elif args.mode == 'evaluate':
        evaluate(args)
    elif args.mode == 'demo':
        demo(args)
    else:
        print('Unknown mode')

训练


在前面CSDN博客的基础上,增加了lr_scheduler自行调整学习率(如果连续2个epoch无改进,就调小lr到一半),增加了连续训练的功能:

先在log文件夹下寻找是否存在参数文件,如果没有,就认为是初次训练;如果有,就找到后缀数字最大的log.pth,在这个基础上继续训练,并且每训练完一个epoch,就保存最新的log.pth,代号是上一次的+1。这样可以多次训练,防止训练过程中出错,参数文件损坏前功尽弃。

其中has_log_file和find_max_log在Utils.py中有定义。


def train(args):
    print("===Train EffNetV2===")
    # 归一化处理,不一定要这样做,看自己的需求,只是预训练模型的训练是这样设置的
    transform = transforms.Compose(
        [transforms.Resize((args.img_size, args.img_size)), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         transforms.ColorJitter()])  

    train_set = MyDataset(args.data_root + 'train.txt', num_class=args.num_classes, transforms=transform)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    device = torch.device('cuda:0')
    # 加载模型
    model = efficientnetv2_s(num_classes=args.num_classes)
    model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # 学习率调整函数,不一定要这样做,可以自定义
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
    print("load model...")
    
    # 加载最近保存了的参数
    if has_log_file(args.log_root):
        max_log = find_max_log(args.log_root)
        print("continue training with " + max_log + "...")
        checkpoint = torch.load(max_log)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loss = checkpoint['loss']
        epoch = checkpoint['epoch'] + 1
    else:
        print("train for the first time...")
        loss = 0.0
        epoch = 0

    while epoch < args.epoch:
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outs = model(inputs)
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 200 == 199:
                print('epoch %5d: batch: %5d, loss: %8f, lr: %f' % (
                    epoch + 1, i + 1, running_loss / 200, optimizer.state_dict()['param_groups'][0]['lr']))
                running_loss = 0.0

        scheduler.step(loss)
        # 每个epoch结束后就保存最新的参数
        print('Save checkpoint...')
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss},
                   args.log_root + 'log' + str(epoch) + '.pth')
        print('Saved')
        epoch += 1

    print('Finish training')

评估


跑测试集,算总体准确率。有一点不完善,就是看不到每一个类具体的准确率。我的预训练模型其实感觉有几类是过拟合的,但是我懒得调整了。


def evaluate(args):
    print("===Evaluate EffNetV2===")
    # 这个地方要和train一致,不过colorJitter可有可无
    transform = transforms.Compose(
        [transforms.Resize((args.img_size, args.img_size)), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         transforms.ColorJitter()])

    model = efficientnetv2_s(num_classes=args.num_classes)
    model.eval()
    if has_log_file(args.log_root):
        file = find_max_log(args.log_root)
        print("Using log file: ", file)
        checkpoint = torch.load(file)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Warning: No log file")

    model.to(torch.device('cuda:0'))
    test_loader = DataLoader(MyDataset(args.data_root + 'test.txt', num_class=args.num_classes, transforms=transform),batch_size=args.batch_size, shuffle=False)
    total = 0.0
    correct = 0.0
    print("Evaluating...")
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            inputs, labels = data[0].cuda(), data[1].cuda()
            outputs = model(inputs)
            _, predict = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predict == labels).sum().item()
    acc = correct / total * 100
    print('Accuracy'': ', acc, '%')

推理


输入文字图片,输出识别结果:

其中char_dict就是每个汉字在数据集里的代号对应的gb2312编码,这个模型的输出结果是它在数据集里的代号,所以要查这个char_dict来获取它对应的汉字。


def demo(args):
    print('==Demo EfficientNetV2===')
    print('Input Image: ', args.demo_img)
    # 这个地方要和train一致,不过colorJitter可有可无
    transform = transforms.Compose(
        [transforms.Resize((args.img_size, args.img_size)), transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    img = Image.open(args.demo_img)
    img = transform(img)
    img = img.unsqueeze(0) # 增维
    model = efficientnetv2_s(num_classes=args.num_classes)
    model.eval()
    if has_log_file(args.log_root):
        file = find_max_log(args.log_root)
        print("Using log file: ", file)
        checkpoint = torch.load(file)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Warning: No log file")

    with torch.no_grad():
        output = model(img)
    _, pred = torch.max(output.data, 1)
    f = open('../char_dict', 'rb')
    dic = pickle.load(f)
    for cha in dic:
        if dic[cha] == int(pred):
            print('predict: ', cha)
    f.close()

例如输入图片为:

【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别

程序运行结果:

【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别

其他说明


如遇到Module not found之类的错,重新写一下import,从实际的位置导入。

这个模型我通过ChaquoPy尝试移植到了Android平台,不过效果一般,我也没好好做:手写汉字识别APP,借用开源手写板

另外,这个模型对于太细太黑的字体,准确度貌似不是很好,可能还是有点过拟合了。建议输入的图片与数据集的风格靠拢,黑色尽量浅一点,线不要太细。

B站同步文章:(【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别 - 哔哩哔哩)

2023年9月更新:本项目已不再做,只是本人的学习实践,教程只是帮你跑通一个简单有效果的深度学习,事实上移动端和PC端应该都有比EffNetV2更加合适的模型,且需要仔细设定学习策略。有链接失效可以B站私信扣我,我过来补链接。文章来源地址https://www.toymoban.com/news/detail-470255.html

到了这里,关于【模型+代码/保姆级教程】使用Pytorch实现手写汉字识别的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【数学建模】常用微分方程模型 + 详细手写公式推导 + Matlab代码实现

    微分方程基本概念 微分方程在数学建模中的应用 微分方程常用模型(人口增长模型、传染病模型) 2022.06.19 微分方程,是指含有未知函数及其导数的关系式。解微分方程就是找出未知函数。 微分方程是伴随着微积分学一起发展起来的。微积分学的奠基人Newton和Leibniz的著作中

    2024年02月09日
    浏览(67)
  • 实践教程|基于 pytorch 实现模型剪枝

    PyTorch剪枝方法详解,附详细代码。 一,剪枝分类 1.1,非结构化剪枝 1.2,结构化剪枝 1.3,本地与全局修剪 二,PyTorch 的剪枝 2.1,pytorch 剪枝工作原理 2.2,局部剪枝 2.3,全局非结构化剪枝 三,总结 参考资料 所谓模型剪枝,其实是一种从神经网络中移除\\\"不必要\\\"权重或偏差(

    2024年02月12日
    浏览(39)
  • 在树莓派上实现numpy的conv2d卷积神经网络做图像分类,加载pytorch的模型参数,推理mnist手写数字识别,并使用多进程加速

    这几天又在玩树莓派,先是搞了个物联网,又在尝试在树莓派上搞一些简单的神经网络,这次搞得是卷积识别mnist手写数字识别 训练代码在电脑上,cpu就能训练,很快的: 然后需要自己在dataset里导出一些图片:我保存在了mnist_pi文件夹下,“_”后面的是标签,主要是在pc端导

    2024年02月07日
    浏览(35)
  • Unity教程2:保姆级教程.几行代码实现输入控制2D人物的移动

    目录 人物的创建以及刚体的设置 图层渲染层级设置 角色碰撞箱设置 使用代码控制人物移动 创建脚本文件  初始函数解释 控制移动代码 初始化变量  获得键盘输入  调用函数 手册链接在这:Unity User Manual (2019.3) - Unity 手册 没有控制人物移动的2D游戏就太说不过去了!那么接

    2024年02月06日
    浏览(43)
  • Unity中使用Mixamo为3D模型添加动画(保姆级教程)

    最近在做为Unity的3D人物添加动画,浅浅记录一下操作方法。 打开Unity Hub,点击New Project,然后按照下图步骤操作: 打开项目——GameObject——3D Object——Plane,这一步非必要,如果已有3D场景,可忽略这一步。 点此打开Mixamo 打开Mixamo后进入如下界面,这里有一些3D角色和动画可

    2024年02月07日
    浏览(63)
  • 文本识别CRNN模型介绍以及pytorch代码实现

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文本识别是图像领域的一个常见任务,场景文字识别OCR任务中,需要先检测出图像中文字位置,再对检测出的文字进行识别,文本介绍的CRNN模型可用于后者, 对检测出的文字进行识别。 An End-to-End Tra

    2024年02月07日
    浏览(40)
  • 【PyTorch 实战2:UNet 分割模型】10min揭秘 UNet 分割网络如何工作以及pytorch代码实现(详细代码实现)

      U-Net,自2015年诞生以来,便以其卓越的性能在生物医学图像分割领域崭露头角。作为FCN的一种变体,U-Net凭借其Encoder-Decoder的精巧结构,不仅在医学图像分析中大放异彩,更在卫星图像分割、工业瑕疵检测等多个领域展现出强大的应用能力。UNet是一种常用于图像分割的卷

    2024年04月28日
    浏览(40)
  • 人工智能概论报告-基于PyTorch的深度学习手写数字识别模型研究与实践

    本文是我人工智能概论的课程大作业实践应用报告,可供各位同学参考,内容写的及其水,部分也借助了gpt自动生成,排版等也基本做好,大家可以参照。如果有需要word版的可以私信我,或者在评论区留下邮箱,我会逐个发给。word版是我最后提交的,已经调整统一了全文格

    2024年02月05日
    浏览(74)
  • 高级圣诞树代码实现合集-保姆级教程【前端三件套实现—0基础直接运行】

    0基础直接运行教程: 1.新建txt文本: 2.将代码粘贴到txt文本里: 3.将后缀改为html 4.双击打开html文件,观察效果~ 这段代码是一个用HTML和JavaScript实现的圣诞树动画效果。我将代码分成几个部分进行讲解。 HTML结构: 在 head 标签中定义了页面的标题、字符集和样式。 样式部分

    2024年02月04日
    浏览(65)
  • CNN实现手写数字识别(Pytorch)

    CNN(卷积神经网络)主要包括卷积层、池化层和全连接层。输入数据经过多个卷积层和池化层提取图片信息后,最后经过若干个全连接层获得最终的输出。 CNN的实现主要包括以下步骤: 数据加载与预处理 模型搭建 定义损失函数、优化器 模型训练 模型测试 以下基于Pytorch框

    2024年02月03日
    浏览(96)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包