文本识别CRNN模型介绍以及pytorch代码实现

这篇具有很好参考价值的文章主要介绍了文本识别CRNN模型介绍以及pytorch代码实现。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

文本识别是图像领域的一个常见任务,场景文字识别OCR任务中,需要先检测出图像中文字位置,再对检测出的文字进行识别,文本介绍的CRNN模型可用于后者, 对检测出的文字进行识别。
crnn模型,nlp,pytorch,深度学习,人工智能

An End-to-End Trainable Neural Network for Image-Based Sequence Recognition and Its Application to Scene Text Recognition
原论文地址:论文地址


一、CRNN模型介绍

1.模型结构

CRNN模型结合了CNN模型与RNN模型,CNN用于提取图像特征,RNN将CNN提取的特征进行处理得到输出,对应最终的标签。
CRNN包含三层,卷积层,循环层和转录层,由于每张图像中英文单词的长度不一致,但是经过CNN之后提取的特征长度是一定的,所以就需要一个转录层处理,得到最终结果。

crnn模型,nlp,pytorch,深度学习,人工智能
该图为模型的大体结构。

输入模型的是一张图像,其shape是(1,32,100) (channel,width,height),
经过一个卷积神经网络之后,其shape变成(512,1,24)(new_channel,new_height,new_width),把channel和height这两个维度合并,合并后shape(512,24),再将这两个维度交换位置,(24,512)(new_width,new_height*new_channel),由于后续需要将提取的特征输入循环神经网络,这个24就相当于是时间步了,24个时间步。输出特征图shape是(24,512)可以理解为,把原图分成24列,每一列用512维的特征向量表示。如下图所示
crnn模型,nlp,pytorch,深度学习,人工智能
将24个特征向量输入进循环神经网络,论文中循环神经网络层是两个LSTM堆叠而成的,经过后就得到24个时间步的输出,再经过全连接层以及softmax层得到一个概率矩阵,形状为(T,num_class),T是时间步,num_class是要分类的类别数,是0-9数字以及a-z字母组合,还有一个blank标识符,总共37类。时间步输出是24个,但是图片中字符数不一定都是24,长短不一,经过转录层将其处理。

2.CTCLoss

如果使用传统的loss function,需要对齐训练样本,有24个时间步,就需要有24个对应的标签,在该任务中显然不合适,除非可以把图片中的每一个字符都单独检测出来,一个字符对应一个标签,则需要很强大的文字检测算法,CTCLoss不需要对齐样本。

还是24个时间步得到24个标签,再进行一个β变换,才得到最终标签。24个时间步可以看作原图中分成24列,每一列输出一个标签,有时一个字母占据好几列,例如字母S占据三列,则这三列输出类别都应该是S,有的列没有字母,则输出空白类别,可以这么理解。得到最终类别时将连续重复的字符去重(空白符两侧的相同字符不去重,因为真实标签中可能存在连续重复字符,例如green,中的两个连续的e不应该去重,则生成标签的时候就该是类似e-e这种,则不会去重),最终去除空白符即可得到最终标签。
β变换定义如下
β : L ′ T → L < = T \beta :L^{'T} →L^{<=T} βLTL<=T
T代表时间步,长度,由于对连续重复字符去重,则处理后的长度一定小于T
举几个β变换的例子,空白用-表示
β ( − − s s t a a a t − e e ) = s t a t e \beta(--sstaaat-ee)=state β(sstaaatee)=state
β ( − − s − t t − a − t − e ) = s t a t e \beta(--s-tt-a-t-e)=state β(sttate)=state
β ( − s − s t − a a t − e ) = s s t a t e \beta(-s-st-aat-e)=sstate β(sstaate)=sstate
β ( − s − t t a − t t − e e ) = s t a t e \beta(-s-tta-tt-ee)=state β(sttattee)=state

可以看出若想要输出state,不止一条路径可以实现输出state.
经过LSTM后的结果需要送入转录层处理,设LSTM的输出标签序列为x,输出标签为l的概率为:
p ( l ∣ x ) = ∑ π ∈ β − ( l ) p ( π ∣ x ) p(l|x)=\sum_{\pi \in \beta ^{-}(l) }p(\pi |x) p(lx)=πβ(l)p(πx)
π ∈ β − ( l ) \pi \in \beta ^{-}(l) πβ(l)表示经过β变换后为l的路径集合 π \pi π

对于每一条路径 π \pi π
p ( π ∣ x ) = ∏ t = 1 T y π t t p(\pi |x)=\prod_{t=1}^{T}y_{\pi ^{t}}^{t } p(πx)=t=1Tyπtt

y π t t y_{\pi ^{t}}^{t } yπtt表示该路径第t个时间步取得该标签的一个概率,连乘起来就是取得该路径的概率。
CTCLoss的优化目标是使得 p ( l ∣ x ) = ∑ π ∈ β − ( l ) p ( π ∣ x ) p(l|x)=\sum_{\pi \in \beta ^{-}(l) }p(\pi |x) p(lx)=πβ(l)p(πx)最大,所以 l o s s = − p ( l ∣ x ) = ∑ π ∈ β − ( l ) p ( π ∣ x ) loss=-p(l|x)=\sum_{\pi \in \beta ^{-}(l) }p(\pi |x) loss=p(lx)=πβ(l)p(πx),使得该loss最小化,来更新前面lstm以及cnn的参数,由于CTCLoss计算有些复杂,暂不讨论。Pytorch中提供了CTCLoss的计算接口,我们直接使用即可。

from torch.nn import CTCLoss

beam search

训练阶段使用CTCLoss更新参数,测试阶段如果使用暴力解法,算出每条路径的一个概率,最终取最大概率的一个路径,时间复杂度非常大,如果有37个类别,序列长度是24,那么路径总和是 3 7 24 37^{24} 3724,这只是一个样本的路径数 。所以就需要用到beam search来优化计算过程。

crnn模型,nlp,pytorch,深度学习,人工智能
计算过程如图所示,现在第一个时间步中找到概率最大的三(可以自由设置)个标签,以这三个最大概率的标签为基础再往后搜索,在第二步会在第一步的概率基础上(需要以第一步的三个标签的概率乘以后面的标签概率)搜索出九个标签,在这九个标签中取三个最大的 ,继续往后搜索,以此类推,在经过最后一个时间步后会得到三条路径,取概率最大的那条,在经过CTC decode即可得到最终label。

二、使用pytorch实现crnn

数据集

将好几个数据集合并并做了相关处理,得到八千多张图片crnn模型,nlp,pytorch,深度学习,人工智能
只在这里展示关键部分代码
代码以及数据集在链接:https://pan.baidu.com/s/1j1sUFIgdB1qga1Cfrh-jlw
提取码:lf2m
dataset.py

import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np


class Synth90kDataset(Dataset):
    CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
    CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
    LABEL2CHAR = {label: char for char, label in CHAR2LABEL.items()}

    def __init__(self, root_dir=None,image_dir = None, mode=None, file_names=None, img_height=32, img_width=100):
        if mode == "train":
            file_names, texts = self._load_from_raw_files(root_dir, mode)
        else:

            texts = None
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.file_names = file_names
        self.texts = texts
        self.img_height = img_height
        self.img_width = img_width

    def _load_from_raw_files(self, root_dir, mode):

        paths_file = None
        if mode == 'train':
            paths_file = 'train.txt'
        elif mode == 'test':
            paths_file = 'test.txt'

        file_names = []
        texts = []
        with open(os.path.join(root_dir, paths_file), 'r') as fr:
            for line in fr.readlines():

                file_name, ext = line.strip().split('.')
                text = file_name.split('_')[-1].lower()
                file_names.append(file_name + "." + ext)
                texts.append(text)
        return file_names, texts

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        file_name = self.file_names[index]
        file_path = os.path.join(self.image_dir,file_name)
        image = Image.open(file_path).convert('L')  # grey-scale

        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_height, self.img_width))
        image = (image / 127.5) - 1.0

        image = torch.FloatTensor(image)
        if self.texts:
            text = self.texts[index]
            target = [self.CHAR2LABEL[c] for c in text]
            target_length = [len(target)]

            target = torch.LongTensor(target)
            target_length = torch.LongTensor(target_length)
            # 如果DataLoader不设置collate_fn,则此处返回值为迭代DataLoader时取到的值
            return image, target, target_length
        else:
            return image


def synth90k_collate_fn(batch):
    # zip(*batch)拆包
    images, targets, target_lengths = zip(*batch)
    # stack就是向量堆叠的意思。一定是扩张一个维度,然后在扩张的维度上,把多个张量纳入仅一个张量。想象向上摞面包片,摞的操作即是stack,0轴即按块stack
    images = torch.stack(images, 0)
    # cat是指向量拼接的意思。一定不扩张维度,想象把两个长条向量cat成一个更长的向量。
    targets = torch.cat(targets, 0)
    target_lengths = torch.cat(target_lengths, 0)
    # 此处返回的数据即使train_loader每次取到的数据,迭代train_loader,每次都会取到三个值,即此处返回值。
    return images, targets, target_lengths

if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from config import train_config as config

    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']
    train_batch_size = config['train_batch_size']
    cpu_workers = config['cpu_workers']

    train_dataset = Synth90kDataset(root_dir=data_dir, mode='train',
                                    img_height=img_height, img_width=img_width)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)

    

model.py

import torch.nn as nn


class CRNN(nn.Module):

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN, self).__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)

        # 如果接双向lstm输出,则要 *2,固定用法
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    # CNN主干网络
    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # 超参设置
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i+1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        conv_relu(3)
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)

        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    # CNN+LSTM前向计算
    def forward(self, images):
        # shape of images: (batch, channel, height, width)

        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)

        # 卷积接全连接。全连接输入形状为(width, batch, channel*height),
        # 输出形状为(width, batch, hidden_layer),分别对应时序长度,batch,特征数,符合LSTM输入要求
        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)

train.py

import os

import cv2
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

from dataset import Synth90kDataset, synth90k_collate_fn
from model import CRNN
from evaluate import evaluate
from config import train_config as config


def train_batch(crnn, data, optimizer, criterion, device):
    crnn.train()
    images, targets, target_lengths = [d.to(device) for d in data]

    logits = crnn(images)
    log_probs = torch.nn.functional.log_softmax(logits, dim=2)

    batch_size = images.size(0)
    input_lengths = torch.LongTensor([logits.size(0)] * batch_size)
    target_lengths = torch.flatten(target_lengths)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def main():
    epochs = config['epochs']
    train_batch_size = config['train_batch_size']

    lr = config['lr']
    show_interval = config['show_interval']
    valid_interval = config['valid_interval']
    save_interval = config['save_interval']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']

    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    train_dataset = Synth90kDataset(root_dir=data_dir,image_dir='../data/images', mode='train',
                                    img_height=img_height, img_width=img_width)


    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn)


    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1, img_height, img_width, num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    assert save_interval % valid_interval == 0 or valid_interval % save_interval ==0
    i = 1
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in train_loader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print('train_batch_loss[', i, ']: ', loss / train_size)



            if i % save_interval == 0:
                save_model_path = os.path.join(config["checkpoints_dir"],"crnn.pt")
                torch.save(crnn.state_dict(), save_model_path)
                print('save model at ', save_model_path)

            i += 1

        print('train_loss: ', tot_train_loss / tot_train_count)


if __name__ == '__main__':
    main()

crnn模型,nlp,pytorch,深度学习,人工智能
识别效果还算可以

crnn模型,nlp,pytorch,深度学习,人工智能
测试效果文章来源地址https://www.toymoban.com/news/detail-734187.html

到了这里,关于文本识别CRNN模型介绍以及pytorch代码实现的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • DEMATEL-ISM模型的Python实现——方法介绍以及代码复现

    本文源于笔者的《系统工程》课程的小组作业,笔者尝试运用DEMATEL-ISM方法来进行分析,建模求解,但在网络上并没有找到相应的,特别是集合DEMATEL-ISM方法的代码。因此自己码了DEMATEL-ISM模型的Python代码,并作为第一个博客发布~ 参考文献中,笔者主要参考了李广利等 1 的研

    2023年04月20日
    浏览(27)
  • 如何用pytorch做文本摘要生成任务(加载数据集、T5 模型参数、微调、保存和测试模型,以及ROUGE分数计算)

    摘要 :如何使用 Pytorch(或Pytorchlightning) 和 huggingface Transformers 做文本摘要生成任务,包括数据集的加载、模型的加载、模型的微调、模型的验证、模型的保存、ROUGE指标分数的计算、loss的可视化。 ✅ NLP 研 0 选手的学习笔记 ● python 需要 3.8+ ● 文件相对地址 : mian.py 和 tra

    2024年02月05日
    浏览(43)
  • 桃子叶片病害识别(Python代码,pyTorch框架,深度卷积网络模型,很容易替换为其它模型,带有GUI识别界面)

     1.分为三类 健康的桃子叶片 ,251张 桃疮痂病一般,857张     桃疮痂病严重,770 张  2.  GUI界面识别效果和predict.py识别效果如视频所示桃子叶片病害识别(Python代码,pyTorch框架,深度卷积网络模型,很容易替换为其它模型,带有GUI识别界面)_哔哩哔哩_bilibili     已经将

    2024年02月11日
    浏览(24)
  • 人脸情绪识别开源代码、模型以及说明文档

    队名:W03KFgNOc 排名:3 正确率: 0.75564 队员:yyMoming,xkwang,RichardoMu。 比赛链接:人脸情绪识别挑战赛 项目链接:link 该项目分别训练八个模型并生成csv文件,并进行融合 打开 train.sh ,可以看到训练的命令行,依次注释和解注释随后运行 train.sh 。 因为是训练八个模型,分别是

    2023年04月09日
    浏览(29)
  • 人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别,BiLSTM+CRF 模型是一种常用的序列标注算法,可用于词性标注、分词、命名实体识别等任务。本文利用pytorch搭建一个BiLSTM+CRF模型,并给出数据样例,

    2024年02月09日
    浏览(38)
  • Pytorch实现动物识别(含动物数据集和训练代码)

    目录 动物数据集+动物分类识别训练代码(Pytorch) 1. 前言 2. Animals-Dataset动物数据集说明 (1)Animals90动物数据集 (2)Animals10动物数据集 (3)自定义数据集 3. 动物分类识别模型训练 (1)项目安装 (2)准备Train和Test数据 (3)配置文件: config.yaml (4)开始训练 (5)可视化训

    2024年02月02日
    浏览(91)
  • Pytorch实现鸟类品种分类识别(含训练代码和鸟类数据集)

    目录 Pytorch实现鸟类识别(含训练代码和鸟类数据集) 1. 前言 2. 鸟类数据集 (1)Bird-Dataset26 (2)自定义数据集 3. 鸟类分类识别模型训练 (1)项目安装 (2)准备Train和Test数据 (3)配置文件:​config.yaml​ (4)开始训练 (5)可视化训练过程 (6)一些优化建议 (7) 一些运

    2024年02月09日
    浏览(44)
  • 年龄性别预测2:Pytorch实现年龄性别预测和识别(含训练代码和数据)

    目录 年龄性别预测2:Pytorch实现年龄性别预测和识别(含训练代码和数据) 1.年龄性别预测和识别方法 2.年龄性别预测和识别数据集 3.人脸检测模型 4.年龄性别预测和识别模型训练 (1)项目安装 (2)准备数据 (3)年龄性别模型训练(Pytorch) (4) 可视化训练过程 (5) 年龄性

    2024年01月19日
    浏览(57)
  • Pytorch实现中药材(中草药)分类识别(含训练代码和数据集)

    目录 Pytorch实现中药材(中草药)分类识别(含训练代码和数据集) 1. 前言 2. 中药材(中草药)数据集说明 (1)中药材(中草药)数据集:Chinese-Medicine-163 (2)自定义数据集 3. 中草药分类识别模型训练 (1)项目安装 (2)准备Train和Test数据 (3)配置文件: config.yaml (4)开始训练 (

    2023年04月13日
    浏览(24)
  • 从零实现诗词GPT大模型:pytorch框架介绍

    专栏规划: https://qibin.blog.csdn.net/article/details/137728228 因为咱们本系列文章主要基于深度学习框架pytorch进行,所以在正式开始之前,现对pytorch框架进行一个简单的介绍,主要面对深度学习或者pytorch还不熟悉的朋友。 这一步很简单,主要通过pip进行安装即可

    2024年04月16日
    浏览(22)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包