sentence-transformers(SBert)中文文本相似度预测(附代码)

这篇具有很好参考价值的文章主要介绍了sentence-transformers(SBert)中文文本相似度预测(附代码)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

sentence-transformers(SBert)中文文本相似度预测(附代码)

前言

  • 训练文本相似度数据集并进行评估:sentence-transformers(SBert)
  • 预训练模型:chinese-roberta-wwm-ext
  • 数据集:蚂蚁金融文本相似度数据集
  • 前端:Vue2+elementui+axios
  • 后端:flask

训练模型

  1. 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
  2. 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
  3. 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
  4. 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。

首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。

这部分在详细代码里注释得很全。

后端部分

使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。

from sentence_transformers import SentenceTransformer, util
# 后端接口
from flask import Flask, jsonify, request
import re
# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容
app = Flask(__name__)
# 使通过jsonify返回的中文显示正常,否则显示为ASCII码
app.config["JSON_AS_ASCII"] = False
model_path = 'D:/xxx模型路径/'
model = SentenceTransformer(model_path)
@app.route("/evaluate",methods=['POST'])
def evalute_sentence():
    s1 = request.json.get("s1")
    s2 = request.json.get("s2")
    if s1 and s2:
        embedding1 = model.encode(s1, convert_to_tensor=True)
        embedding2 = model.encode(s2, convert_to_tensor=True)
        similarity = util.cos_sim(embedding1, embedding2).tolist()
        return jsonify({"code": 200, "msg": "预测成功", "data": similarity})
    else:
        return jsonify({"code": 400, "msg": "缺少字段"})
if __name__ == '__main__':
    app.run(debug=True)

前端部分

框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。

<template>
  <div class="recommend">
    <el-card class="box">
      <h2 class="title">中文文本相似度预测</h2>
      <el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">
        <el-form-item prop="s1">
          <el-input
            placeholder="请输入句子一"
            maxlength="32"
            show-word-limit
            v-model="evaluateForm.s1"
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item prop="s2">
          <el-input
            maxlength="32"
            placeholder="请输入句子二"
            v-model="evaluateForm.s2"
            show-word-limit
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item class="btn-container">
          <el-button
            type="primary"
            @click="submitForm('evaluateForm')"
            class="btn"
            id="queryButton"
          >开始预测</el-button>
        </el-form-item>
      </el-form>
      <div v-show="result" style="margin-top: 20px">
        <el-progress
          :text-inside="true"
          :stroke-width="26"
          :percentage="result*100 ? result*100 : 0"
          class="el-bg-inner-running"
        ></el-progress>
        <p>预测结果:{{result}}</p>
      </div>
    </el-card>
  </div>
</template>

<script>
import api from "@/api/index"
export default {
  data () {
    return {
      evaluateForm: {
        s1: "",
        s2: ""
      },
      evaluateRules: { // 评估表单校验规则
        s1: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
        s2: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
      },
      result: undefined,
    }
  },
  methods: {
    postEvaluate () { // 调用接口
      api.postEvaluate(this.evaluateForm)
        .then((res) => {
          if (!res) {
            return
          }
          console.log("res", res)
          if (res.data.code !== 200) {
            this.$message({
              message: "请求失败",
              type: "error"
            })
            return
          }
          let data = res.data.data[0]
          this.result = data[0]
          console.log("this.result", this.result)
          this.$message({
            message: "预测成功!",
            type: "success"
          })

        })
        .catch((error) => {
          this.$message.error('资源获取错误!')
        })
    },
    submitForm (formName) { // 提交表单
      this.$refs[formName].validate((valid) => {
        if (valid) {
          this.postEvaluate()
        } else {
          this.$message({
            message: "请按要求填写",
            type: "warning"
          })
          console.log('error in submit form')
          return false
        }
      })
      document.getElementById("queryButton").blur()
    },
  }

}
</script>

<style lang="scss" scoped>
.recommend {
  width: 100%;
  height: 100%;
  text-align: center;
  display: flex;
  text-align: center;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  overflow: hidden;
  background: #00416a 0 / cover fixed; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
  .box {
    width: 48%;
    height: 60%;
    position: relative;
    background: hsla(0, 0%, 100%, 0.3);
    z-index: 5;
    padding: 10px 20px;
    // display: flex;
    // flex-direction: column;
    // justify-content: center;
    box-sizing: border-box;
    &::before {
      content: '';
      position: absolute;
      top: 0;
      right: 0;
      bottom: 0;
      left: 0;
      filter: blur(20px);
    }
    .title {
      color: #143b54;
    }
    .btn-container {
      margin: 10px auto;
      .btn {
        width: 100%;
        border-radius: 20px;
      }
    }
  }
}
::v-deep .el-card {
  border: 0;
  box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);
}
::v-deep .el-progress-bar__outer {
  border: 0;
  background-color: transparent;
  // background-color: #abcbe0;
}
::v-deep .el-bg-inner-running .el-progress-bar__inner {
  background: #9cecfb; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
}
</style>

预训练模型比较

paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低

sentence-transformers(SBert)中文文本相似度预测(附代码)

chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可

sentence-transformers(SBert)中文文本相似度预测(附代码)

chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大

sentence-transformers(SBert)中文文本相似度预测(附代码)

chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好

sentence-transformers(SBert)中文文本相似度预测(附代码)

chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢

sentence-transformers(SBert)中文文本相似度预测(附代码)

最后

代码已上传至sbert中文文本相似度预测,欢迎star!文章来源地址https://www.toymoban.com/news/detail-435251.html

到了这里,关于sentence-transformers(SBert)中文文本相似度预测(附代码)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 自然语言处理14-基于文本向量和欧氏距离相似度的文本匹配,用于找到与查询语句最相似的文本

    大家好,我是微学AI,今天给大家介绍一下自然语言处理实战项目14-基于文本向量和欧氏距离相似度的文本匹配,用于找到与查询语句最相似的文本。NLP中的文本匹配是指通过计算文本之间的相似度来找到与查询语句最相似的文本。其中一种常用的方法是基于文本向量和欧氏

    2024年02月15日
    浏览(49)
  • 【大数据】文本特征提取与文本相似度分析

    写在博客前的话: 本文主要阐述如何对一段简短的文本做 特征提取 的处理以及如何对文本进行 分析 。 本文主要脉络以一个故事 s t o r y story s t ory 为主线,以该主线逐步延申,涉及到: 文本特征提取 、 词汇频率统计 (TF) , 反文档频率 (IDF) 以及 余弦相似度 计算的概念,

    2023年04月27日
    浏览(41)
  • Java 计算文本相似度

    2024年02月11日
    浏览(41)
  • java文本相似度

    在 Java 中,可以使用一些现成的库来比较文本的相似度。这里,我将为您提供一个使用 Jaccard 相似度算法(集合相似度)比较文本相似度的方法。首先,请确保将 commons-collections4-4.4.jar 添加到项目的类路径中。您可以从 Maven Central 仓库下载这个 JAR 文件。 添加 Maven 依赖: 使

    2024年02月11日
    浏览(36)
  • 用java计算文本相似度

    遇到这样一个需求,需要计算两个文本内容的相似度,以前也接触过,下面列举几种方式,也是我在网上查了很多内容整理的,直接上代码,供大家参考,如果你也有这样的需求,希望能帮到你: 引入pom 代码 Jaccard计算文本相似性,效果并不咋地,但在一些应用环境上,使用

    2024年02月10日
    浏览(40)
  • scala 短文本相似度计算

    simHash类的算法更适合长文本的相似度判断,而短文本可考虑一下几种方法: 一、编辑距离+jacard距离 对于dataframe,getLevenshtein可利用原生的levenshtein函数 二、md5 三、语义向量模型 其他思路 python的difflib使用

    2024年02月15日
    浏览(37)
  • python实现文本相似度排名计算

       项目中,客户突然提出需要根据一份企业名单查找对应的内部系统用户信息,然后根据直接的企业社会统一信用号和企业名称进行匹配,发现匹配率只有2.86%,低得可怜。所以根据客户的要求,需要将匹配率提高到70-80%左右,于是开始了折腾之路。     上网一查,各种相

    2024年02月12日
    浏览(38)
  • Python中的文本相似度计算方法

    在自然语言处理(NLP)领域,文本相似度计算是一个常见的任务。本文将介绍如何使用Python计算文本之间的相似度,涵盖了余弦相似度、Jaccard相似度和编辑距离等方法。 1. 余弦相似度 余弦相似度是一种衡量两个向量夹角的方法,用于衡量文本的相似度。首先,将文本转换为

    2024年02月13日
    浏览(41)
  • 利用Redis实现向量相似度搜索:解决文本、图像和音频之间的相似度匹配问题

    在自然语言处理领域,有一个常见且重要的任务就是文本相似度搜索。文本相似度搜索是指根据用户输入的一段文本,从数据库中找出与之最相似或最相关的一段或多段文本。它可以应用在很多场景中,例如问答系统、推荐系统、搜索引擎等。 比如,当用户在知乎上提出一个

    2024年02月15日
    浏览(45)
  • Python案例分析|文本相似度比较分析

     本案例通过设计和实现有关文本相似度比较的类Vector和Sketch,帮助大家进一步掌握设计Python类来解决实际问题的能力。 通过计算并比较文档的摘要可实现文本的相似度比较。 文档摘要的最简单形式可以使用文档中的k-grams(k个连续字符)的相对频率的向量来表示。假设字符

    2024年02月16日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包