【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking

这篇具有很好参考价值的文章主要介绍了【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

论文信息

论文地址:https://arxiv.org/pdf/2210.17168.pdf

Abstract

论文提出了一种token-level的自蒸馏对比学习(self-distillation contrastive learning)方法。

1. Introduction

【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking,机器学习,论文阅读,深度学习,CSC,中文拼写纠错,自然语言处理

传统方法使用BERT后,会对confusion chars进行聚类,但使用作者提出的方法,会让其变得分布更均匀。

confusion chars: 指的应该是易出错的字。

2. Methodology

2.1 The Main Model

作者提取特征的方式:① 先用MacBERT得到hidden states,然后用word embedding和hidden states进行点乘。写成公式为:

H = M a c B E R T ( X ) ⋅ W \bf{H} = MacBERT(X) \cdot W H=MacBERT(X)W

这里的 W W W 应该就是BERT最前面的embedding层对X编码后的向量。

后面就是正常接个输出层再计算CrossEntropyLoss

2.2 Contrastive Loss

【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking,机器学习,论文阅读,深度学习,CSC,中文拼写纠错,自然语言处理

基本思路:让错字token的特征向量和其对应正确字的token特征向量距离越近越好。这样BERT就能拿着错字,然后编码出对应正确字的向量,最后的预测层就能预测对了。

作者的做法:

  1. 错误句子从左边进入BERT,正确句子从右边进入BERT
  2. 对于错字,进行对比学习,让其与对应的正确字的特征向量距离越近越好。即这个错字的正样本为
  3. 将错误句子的其他token作为错字的负样本,使错字token的特征向量与其他向量的距离越远越好。上图中,字有5个负样本,即我、有、吃、旱、饭

上图中双头实线(↔)表示这两个token要距离越近越好,双头虚线表示这两个token要距离越远越好

损失函数公式如下:

L c = − ∑ i = 1 n L ( x ~ i ) log ⁡ exp ⁡ ( sim ⁡ ( h ~ i , h i ) / τ ) ∑ j = 1 n exp ⁡ ( sim ⁡ ( h ~ i , h j ) / τ ) L_c = -\sum_{i=1}^n \Bbb{L}\left(\tilde{x}_i\right) \log \frac{\exp \left(\operatorname{sim}\left(\tilde{h}_i, h_i\right) / \tau\right)}{\sum_{j=1}^n \exp \left(\operatorname{sim}\left(\tilde{h}_i, h_j\right) / \tau\right)} Lc=i=1nL(x~i)logj=1nexp(sim(h~i,hj)/τ)exp(sim(h~i,hi)/τ)

其中:

  • n n n : 为n个token
  • L ( x ~ i ) \Bbb{L}\left(\tilde{x}_i\right) L(x~i): 当 x i x_i xi为错字时, L ( x ~ i ) = 1 \Bbb{L}\left(\tilde{x}_i\right)=1 L(x~i)=1,否则为 0 0 0。即只算错字的损失
  • sim ( ⋅ ) \text{sim}(\cdot) sim():余弦相似度函数
  • h ~ i \tilde{h}_i h~i: 正确句子(右边BERT)输出的token的特征向量
  • h i h_i hi:错误句子(左边BERT)输出的token的特征向量
  • τ \tau τ:温度超参

上面损失使用CrossEntropyLoss实现。

作者还为右边的BERT增加了一个Loss L y L_y Ly,目的是让右边可以输出它的输入,即copy-paste任务。

最终的损失如下:

L = L x + α L y + β L c L = L_x+\alpha L_y+\beta L_c L=Lx+αLy+βLc

2.3 Implementation Details(Hyperparameters)

  • BERT:MacBERT
  • optimizer: AdamW
  • 学习率: 7e-5
  • batch_size: 48
  • λ \lambda λ: 0.9 (TODO,作者说的这个lambda不知道是啥)
  • α \alpha α: 1
  • β \beta β: 0.5
  • τ \tau τ: 0.9
  • epoch: 20次

3. Experiments

【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking,机器学习,论文阅读,深度学习,CSC,中文拼写纠错,自然语言处理文章来源地址https://www.toymoban.com/news/detail-783766.html

代码实现

import torch
import torch.nn as nn
from transformers import BertTokenizerFast, BertForMaskedLM
import torch.nn.functional as F


class SDCLModel(nn.Module):

    def __init__(self):
        super(SDCLModel, self).__init__()
        self.tokenizer = BertTokenizerFast.from_pretrained('hfl/chinese-macbert-base')
        self.model = BertForMaskedLM.from_pretrained('hfl/chinese-macbert-base')

        self.alpha = 1
        self.beta = 0.5
        self.temperature = 0.9

    def forward(self, inputs, targets=None):
    	"""
    	inputs: 为tokenizer对原文本编码后的输入,包括input_ids, attention_mask等
    	targets:与inputs相同,只不过是对目标文本编码后的结果。
    	"""
        if targets is not None:
        	# 提取labels的input_ids
            text_labels = targets['input_ids'].clone()
            text_labels[text_labels == 0] = -100  # -100计算损失时会忽略
        else:
            text_labels = None
		
        word_embeddings = self.model.bert.embeddings.word_embeddings(inputs['input_ids'])
        hidden_states = self.model.bert(**inputs).last_hidden_state
        logits = self.model.cls(hidden_states * word_embeddings)

        if targets:
            loss = F.cross_entropy(logits.view(logits.shape[0] * logits.shape[1], logits.shape[2]), text_labels.view(-1))
        else:
            loss = 0.

        return logits, hidden_states, loss

    def extract_outputs(self, outputs):
        logits, _, _ = outputs
        return logits.argmax(-1)

    def compute_loss(self, outputs, targets, inputs, detect_targets, *args, **kwargs):
        logits_x, hidden_states_x, loss_x = outputs
        logits_y, hidden_states_y, loss_y = self.forward(targets, targets)

        # FIXME
        anchor_samples = hidden_states_x[detect_targets.bool()]
        positive_samples = hidden_states_y[detect_targets.bool()]
        negative_samples = hidden_states_x[~detect_targets.bool() & inputs['attention_mask'].bool()]

        # 错字和对应正确的字计算余弦相似度
        positive_sim = F.cosine_similarity(anchor_samples, positive_samples)
        # 错字与所有batch内的所有其他字计算余弦相似度
        # (FIXME,这里与原论文不一致,原论文说的是与当前句子的其他字计算,但我除了for循环,不知道该怎么写)
        negative_sim = F.cosine_similarity(anchor_samples.unsqueeze(1), negative_samples.unsqueeze(0), dim=-1)

        sims = torch.concat([positive_sim.unsqueeze(1), negative_sim], dim=1) / self.temperature
        sim_labels = torch.zeros(sims.shape[0]).long().to(self.args.device)

        loss_c = F.cross_entropy(sims, sim_labels)

        self.loss_c = float(loss_c)  # 记录一下

        return loss_x + self.alpha * loss_y + self.beta * loss_c

    def get_optimizer(self):
        return torch.optim.AdamW(self.parameters(), lr=7e-5)

    def predict(self, src):
        src = ' '.join(src.replace(" ", ""))
        inputs = self.tokenizer(src, return_tensors='pt').to(self.args.device)
        outputs = self.forward(inputs)
        outputs = self.extract_outputs(outputs)[0][1:-1]
        return self.tokenizer.decode(outputs).replace(' ', '')

个人总结

值得借鉴的地方

  1. 作者并没有直接使用BERT的输出作为token embedding,而是使用点乘的方式融合了BERT的输出和word embeddings

到了这里,关于【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【论文阅读笔记】Contrastive Multiview Coding

     这篇文章主要探讨人类通过多种感官通道来观察世界,比如左眼观察到的长波长光通道,或右耳听到的高频振动通道。每个观察角度都带有噪音且是不完整的,但一些重要的因素,如物理、几何和语义,往往在所有观点之间共享(例如,“狗”可以被看到、听到和感受到)

    2024年01月18日
    浏览(50)
  • 【论文阅读笔记】Contrastive Learning with Stronger Augmentations

    基于提供的摘要,该论文的核心焦点是在对比学习领域提出的一个新框架——利用强数据增强的对比学习(Contrastive Learning with Stronger Augmentations,简称CLSA)。以下是对摘要的解析: 问题陈述: 表征学习(representation learning)已在对比学习方法的推动下得到了显著发展。 当前

    2024年02月19日
    浏览(49)
  • 【论文阅读笔记】 Representation Learning with Contrastive Predictive Coding

    这段文字是论文的摘要,作者讨论了监督学习在许多应用中取得的巨大进展,然而无监督学习并没有得到如此广泛的应用,仍然是人工智能中一个重要且具有挑战性的任务。在这项工作中,作者提出了一种通用的无监督学习方法,用于从高维数据中提取有用的表示,被称为“

    2024年01月25日
    浏览(43)
  • 论文笔记|CVPR2023:Supervised Masked Knowledge Distillation for Few-Shot Transformers

    这篇论文的题目是 用于小样本Transformers的监督遮掩知识蒸馏 论文接收: CVPR 2023 论文地址: https://arxiv.org/pdf/2303.15466.pdf 代码链接: https://github.com/HL-hanlin/SMKD 1.ViT在小样本学习(只有少量标记数据的小型数据集)中往往会 过拟合,并且由于缺乏 归纳偏置 而导致性能较差;

    2024年02月06日
    浏览(52)
  • 【论文笔记_对比学习_2021】CONTRASTIVE LEARNING WITH HARD NEGATIVE SAMPLES

    用困难负样本进行对比性学习 如何才能为对比性学习提供好的负面例子?我们认为,就像度量学习一样,表征的对比性学习得益于硬性负面样本(即难以与锚点区分的点)。使用硬阴性样本的关键挑战是,对比性方法必须保持无监督状态,这使得采用现有的使用真实相似性信

    2023年04月08日
    浏览(39)
  • 相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记

    😄 额,本想学学XLNet的,然后XLNet又是以transformer-XL为主要结构,然后transformer-XL做了两个改进:一个是结构上做了segment-level的循环机制,一个是在attention机制里引入了相对位置编码信息来避免不同segment的同一位置采用相同的绝对位置编码的不合理。但无奈看到相对位置编码

    2024年02月17日
    浏览(41)
  • 【论文笔记】Triplet attention and dual-pool contrastive learning for clinic-driven multi-label medical...

    多标签分类Multi-label classification (MLC)可在单张图像上附加多个标签,在医学图像上取得了可喜的成果。但现有的多标签分类方法在实际应用中仍面临着严峻的临床现实挑战,例如: 错误分类带来的医疗风险, 不同疾病之间的样本不平衡问题 无法对未预先定义的疾病(未见疾

    2024年02月03日
    浏览(47)
  • Low-Light Image Enhancement via Self-Reinforced Retinex Projection Model 论文阅读笔记

    这是马龙博士2022年在TMM期刊发表的基于改进的retinex方法去做暗图增强(非深度学习)的一篇论文 文章用一张图展示了其动机,第一行是估计的亮度层,第二列是通常的retinex方法会对估计的亮度层进行RTV约束优化,从而产生平滑的亮度层,然后原图除以亮度层产生照度层作为

    2024年02月16日
    浏览(47)
  • [读论文][backbone]Knowledge Diffusion for Distillation

    DiffKD 摘要 The representation gap between teacher and student is an emerging topic in knowledge distillation (KD). To reduce the gap and improve the performance, current methods often resort to complicated training schemes, loss functions, and feature alignments, which are task-specific and feature-specific. In this paper, we state that the essence of the

    2024年02月08日
    浏览(55)
  • 【自监督论文阅读笔记】Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture

    2023         本文展示了一种 学习高度语义图像表示 的方法,而 不依赖于手工制作的数据增强 。本文介绍了 基于图像的联合嵌入预测架构 (I-JEPA) ,这是一种用于从图像进行自监督学习的 非生成方法 。 I-JEPA 背后的想法很简单: 从单个上下文块,预测同一图像中各种目

    2024年02月09日
    浏览(47)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包