Elasticsearch 混合检索优化大模型 RAG 任务

这篇具有很好参考价值的文章主要介绍了Elasticsearch 混合检索优化大模型 RAG 任务。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Elastic 社区在自然语言处理上面做的很不错官方博客更新速度也很快,现阶段大模型的应用场景主要在 Rag 和 Agent 上,国内 Rag(Retrieval-Augmented Generation 检索增强生成) 的尤其多,而搜索对于 Elasticsearch 来说是强项特别是 8.9 之后的版本提供了 ESRE 模块(集成了高级相关性排序如 BM25f、强大的矢量数据库、自然语言处理技术、与第三方模型如 GPT-3 和 GPT-4 的集成,并支持开发者自定义模型与应用),经过我的各种尝试在 Elasticsearch 上做 NLP 是一个很不错的选择,要做大规模的 RAG 任务甚至是针对图像、声音、多模态、关键词等大数据量的向量召回且搭配生成式模型这种复杂的业务场景 Elasticsearch 是天生支持的。此篇文章主要记录混合检索(BM25 +HNSW)倒数融合排序(RRF)完整测试。

官博有几篇不错的文章可以看看:

何时应用 RAG 与微调-CSDN博客
在 Elasticsearch 中扩展 ML 推理管道:如何避免问题并解决瓶颈-CSDN博客
Elastic:加速生成式人工智能体验-CSDN博客
Elastic AI Assistant for Observability 和 Microsoft Azure OpenAI 入门-CSDN博客
Elasticsearch 开放 inference API 增加了对 Cohere Embeddings 的支持-CSDN博客
Elasticsearch:使用在本地计算机上运行的 LLM 以及 Ollama 和 Langchain 构建 RAG 应用程序_ollama+rag-CSDN博客
Elasticsearch:倒数排序融合 - Reciprocal rank fusion (RRF)_rrf算法-CSDN博客

先说一下 RAG 任务的流程,以民法典为例 LLM 可以在现有资料上分析出确切的回答:
文档分割 -> 文本向量化 -> 问句向量化 -> 向量相似 top k个 -> 拼接 prompt 上下文  -> 提交给 LLM 生成回答。
倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

1.混合检索

全文检索 + ANN 检索。因为全文检索能查找更加准确的文档,直观都会感觉比单一的相似度检索更强。一个混合检索的查询语句例如:

{
  "query": {
    "bool": {
      "must": [
        { "match": {"content": {"query": "结婚领证登记需要双发到场吗?","boost": 1}}}
      ]
    }
  },
  "knn": {
    "field": "content_embed",
    "k": 5,
    "num_candidates": 100,
    "query_vector": []   // 向量、省略
  },
  "size": 5
}

2.倒数融合排序

倒数排序融合 - Reciprocal rank fusion:
由于全文搜索及向量搜索是使用不同的算法进行打分的,这就造成把两个不同搜索结果综合起来统一排名的困难。向量搜索的分数处于 0-1.0 之间,而全文搜索的结果排名分数可能是高于10或者更大的值。我们需要一种方法把两种搜索方法的结果进行综合处理,并得出一个唯一的排名。
倒数排序融合(RRF)是一种将具有不同相关性指标的多个结果集组合成单个结果集的方法。 
RRF 无需调优,不同的相关性指标也不必相互关联即可获得高质量的结果。该方法的优势在于不利用相关分数,而仅靠排名计算。相关分数存在的问题在于不同模型的分数范围差。
针对不同的 RAG 任务有不同的处理方式比如 法律、历史、人文类型的任务还可以加入命名实体识别 。或者使用其他语义转换模型将长文本总结为短文本。将拆分的长文本先调用 embed 转为向量后存储到 index 上。然后执行混合检索。

3.Embedding

第一步是文本向量化,这一步可以放在客户端做也可以放在 Elasticsearch 服务端做,不过模型推理是 Elasticsearch 新版中的重大功能,下面演示如何做。

在抱脸上直接搜索  sentence-similarity 模型,最靠前的就是 bge 由智源开源,基本上从去年开始一直是榜一,输入 zh 筛选中文:

倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

使用 langchain 测试推理,模型输出是  dim=1024:

倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

ElasticSearch支持最大 2048,目前 Es 还不支持非固定长度的向量,Elasticsearch 提供了 Eland 工具用于 pytorch 模型的推理和上传,源码安装该工具:

git clone https://github.com/elastic/eland
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
python setup.py install

然后执行上传脚本:
eland_import_hub_model --url http://192.168.197.128:9200 --hub-model-id .\Langchain-Chatchat-0.2.10\model\bge-large-zh-v1.5 --task-type text_embedding --start --clear-previous

上传过程不太顺利发现源码有一些问题需要修改,大致两处:
eland_import_hub_model.py  =>    上传前会把模型和一些文件放到临时目录,因为我的 windwos user name 是中文会找不到路径。直接将 tmp 写死即可。
            with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = 'C:\\tmp' 
transformers.py    =>        函数里面将 token 这个参数去掉
            # model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
            model = AutoModel.from_pretrained(model_id, torchscript=True)

等待执行完成:
倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

上传成功后在 kibana 模型管理位置点击 Synchronize your jobs and trained models.,同步一下刚刚上传的模型看到,调用推理接口,复制模型id,可以看到模型输出和前面 embed_demo.py 中测试的一样:
POST _ml/trained_models/m_workspace__langchain-chatchat-0.2.10__model__bge-large-zh-v1.5/_infer
{
  "docs": [
    {"text_field": "你好,请问你在干什么?"}
  ]
}

倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

4.文本分割

向量 dim=1024 是无法将一个超长文本完整的语义全部嵌入的,且大模型 token 的限制需要将文档进行分割,最简单的做法是指定 chunk_size(单个文档token数) 和 chunk_overlap(向量文档重叠token数)对文档进行分割,也有按句分割的做法,更加准确的是使用现成的语义分割模型,可以看看 github 上 Langchain-Chatchat 这个项目,提供了多种分割方式:

倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

5.部署 LLM 做增强生成

对于 RAG 任务,更大参数量的 LLM 对效果并没有显著提升, 即使是最小参数量的大模型也涵盖了基本的理解能力,这里部署清华 ChatGLM-6b  int4 量化模型 6G显存就够,这样可以将 token 开到很大。

git clone https://github.com/THUDM/ChatGLM-6B
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
git clone https://huggingface.co/THUDM/chatglm-6b-int4

模型 README.md 中有测试代码,替换一下模型路径就可以了:
倒数排序融合rrf算法论文,ElasticSearch 源码分析,深度学习/机器学习/强化学习,elasticsearch,大数据,搜索引擎,人工智能,语言模型,RAG

然后写一个 ELasticsearch Query 例子,根据搜索文档拼接 Prompt 做问答,Java 完整代码:

package tool.elk;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.nio.entity.NStringEntity;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.*;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.GetIndexRequest;
import org.elasticsearch.common.xcontent.XContentType;
import java.io.BufferedReader;
import java.io.FileReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
 *   @desc : elatcisearch rag 测试
 *   @auth : tyf
 *   @date : 2024-04-16 10:06:24
*/
public class RAGDemo {

    public static String es_host = "192.168.197.128";
    public static Integer es_port = 9200;

    public static String llm_host = "http://0.0.0.0:8000";

    public static RestHighLevelClient highLevelClient;
    public static RestClient lowLevelClient;
    static {
        String[] ipArr = es_host.split(",");
        HttpHost[] httpHosts = new HttpHost[ipArr.length];
        for (int i = 0; i < ipArr.length; i++) {
            httpHosts[i] = new HttpHost(ipArr[i], es_port, "http");
        }
        RestClientBuilder builder = RestClient.builder(httpHosts);
        highLevelClient = new RestHighLevelClient(builder);
        lowLevelClient = highLevelClient.getLowLevelClient();
        System.out.println("初始化成功");
    }

    
    // 索引名称
    public static String indexName = "doc_split";
    // 索引 mapping
    public static String indexMapping =
            "{\n" +
                    "  \"settings\": {\n" +
                    "    \"number_of_shards\": 1,\n" +
                    "    \"number_of_replicas\": 0\n" +
                    "  },\n" +
                    "  \"mappings\": {\n" +
                    "    \"properties\": {\n" +
                    "      \"content\": {\n" +
                    "        \"type\": \"text\"\n" +
                    "      },\n" +
                    "      \"timestamp\": {\n" +
                    "        \"type\": \"long\"\n" +
                    "      },\n" +
                    "      \"content_embed\": {\n" +
                    "        \"type\": \"dense_vector\",\n" +
                    "        \"dims\": 1024,\n" +
                    "        \"index\": true,\n" +
                    "        \"similarity\": \"cosine\"\n" +
                    "      }\n" +
                    "    }\n" +
                    "  }\n" +
                    "}";

    // embed 模型编号
    public static String modelId = "m_workspace__langchain-chatchat-0.2.10__model__bge-large-zh-v1.5";

    // 文档召回 _score 阈值
    public static double scoreThreshold = 3d;

    // 本地文档路径
    public static String docPath = "C:\\Users\\唐于凡\\Desktop\\中华人民共和国民法典.txt";



    // 创建索引
    public static void createIndex() throws Exception{
//        System.out.println(indexMapping);
        // 索引不存在则创建
        GetIndexRequest request1 = new GetIndexRequest(indexName);
        boolean response1 = highLevelClient.indices().exists(request1, RequestOptions.DEFAULT);
        if(!response1){
            CreateIndexRequest request2 = new CreateIndexRequest(indexName);
            request2.source(indexMapping, XContentType.JSON);
            highLevelClient.indices().create(request2, RequestOptions.DEFAULT);
        }
    }

    // 读取并拆分文档
    public static List<String> parseDoc(int chunkSize,int chunkOverlap) throws Exception{
        List<String> splitTexts = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(docPath))) {
            StringBuilder sb = new StringBuilder();
            String line;
            while ((line = br.readLine()) != null) {
                // 去掉没用的空格
                line = line.trim();
                if (!line.isEmpty()) {
                    sb.append(line).append(" "); // 可以根据需要调整分隔符
                }
            }
            String fullText = sb.toString().trim();
            // 拆分文本
            for (int i = 0; i < fullText.length(); i += chunkSize - chunkOverlap) {
                if (i + chunkSize < fullText.length()) {
                    splitTexts.add(fullText.substring(i, i + chunkSize));
                } else {
                    splitTexts.add(fullText.substring(i));
                }
            }
        }
        System.out.println("文档总数:"+splitTexts.size());
        return splitTexts;
    }


    // 调用 embed 模型转为向量
    public static Object embedDoc(String text){

        Object rt = null;
        // POST
        try {
            String entity = "{ \"docs\": [{\"text_field\": \""+text+"\"}]}";
            Request req = new Request("POST","_ml/trained_models/"+modelId+"/_infer");
            HttpEntity params = new NStringEntity(entity, ContentType.APPLICATION_JSON);
            req.setEntity(params);
            Response rsp = lowLevelClient.performRequest(req);
            HttpEntity en = rsp.getEntity();
            String body = EntityUtils.toString(en);
            JSONObject data = JSON.parseObject(body);
            rt = data.getJSONArray("inference_results").getJSONObject(0).getJSONArray("predicted_value");
        }
        catch (Exception e){
            e.printStackTrace();
        }
        return rt;
    }

    // 提交 Elasticsearch
    public static void uploadDoc(List<String> docSplits) throws Exception{

        // 遍历每个文档
        for (int i = 0; i < docSplits.size(); i++) {
            // 原始文本
            String content = docSplits.get(i);
            // 转为向量
            Object content_embed = embedDoc(content);
            // 时间
            Long timestamp = System.currentTimeMillis();

            // 上传
            JSONObject data = new JSONObject();
            data.put("content",content);
            data.put("content_embed",content_embed);
            data.put("timestamp",timestamp);

            Request req = new Request("POST","/"+indexName+"/_doc");
            HttpEntity params = new NStringEntity(data.toJSONString(), ContentType.APPLICATION_JSON);
            req.setEntity(params);
            Response res = lowLevelClient.performRequest(req);
            System.out.println("上传第"+i+"条:"+res);
        }

    }

    // 执行混合检索
    public static List<String> search(String q) throws Exception{

        // 转为向量
        Object vector = embedDoc(q);

        // 查询语句
        String query =
                "{\n" +
                        "  \"query\": {\n" +
                        "    \"bool\": {\n" +
                        "      \"must\": [\n" +
                        "        {\n" +
                        "          \"match\": {\n" +
                        "            \"content\": {\n" +
                        "              \"query\": \""+q+"\",\n" +
                        "              \"boost\": 1\n" +
                        "            }\n" +
                        "          }\n" +
                        "        }\n" +
                        "      ]\n" +
                        "    }\n" +
                        "  },\n" +
                        "  \"knn\": {\n" +
                        "    \"field\": \"content_embed\",\n" +
                        "    \"k\": 5,\n" +
                        "    \"num_candidates\": 100,\n" +
                        "    \"query_vector\": "+vector+"\n" +
                        "  },\n" +
                        "  \"size\": 5\n" +
                        "}\n";

//        System.out.println("查询语句:");
//        System.out.println(query);

        // 调用查询
        Request req = new Request("POST","/"+indexName+"/_search?pretty");
        HttpEntity params = new NStringEntity(query, ContentType.APPLICATION_JSON);
        req.setEntity(params);
        Response res = lowLevelClient.performRequest(req);

        // 解析
        String body = EntityUtils.toString(res.getEntity());
        JSONArray data = JSON.parseObject(body).getJSONObject("hits").getJSONArray("hits");

        // 遍历每个文档、将高的分的文档保存
        List<String> contents = new ArrayList<>();
        data.stream().map(n->JSONObject.parseObject(n.toString())).forEach(n->{
            // 得分高的才作为资料避免 llm 幻觉
            Double _score = n.getDouble("_score");
            if(_score >= scoreThreshold){
                // 文本
                String content = n.getJSONObject("_source").getString("content");
                contents.add(content);
                System.out.println("召回文档数据:"+n);
            }
        });

        System.out.println();
        return contents;
    }

    // 拼接 prompt
    public static String prompt(List<String> content,String q){

        StringBuilder question = new StringBuilder();

        question.append("你好,下面是我搜索得到的资料:\n");
        if(content.size()==0){
            question.append("无。\n");
        }
        for (int i = 0; i < content.size() ; i++) {
            question.append("("+(i+1)+")").append(content.get(i)).append("\n");
        }
        question.append("\n");
        question.append("请帮我根据上面的资料分析下面的问题,并帮我根据资料列出相关依据:\n");
        question.append(q).append("\n");
        question.append("\n");
        question.append("如果根据资料无法分析请回复不知道!");

        return question.toString();
    }

    // 调用 LLM 生成回答
    public static String llmAnswer(String question) throws Exception{

        JSONObject data = new JSONObject();
        data.put("prompt",question);
        data.put("history",null);

        HttpPost httpPost = new HttpPost(llm_host);
        httpPost.addHeader("Content-Type", "application/json;charset=utf-8");
        httpPost.setEntity(new StringEntity(data.toString(), StandardCharsets.UTF_8));
        CloseableHttpResponse response = HttpClients.createDefault().execute(httpPost);
        HttpEntity resEntity = response.getEntity();
        String resp = EntityUtils.toString(resEntity,"utf-8");
        return JSONObject.parseObject(resp).getString("response");
    }

    public static void main(String[] args) throws Exception{

        // 创建索引
//        createIndex();

        // 读取并拆分文档、提交 Elasticsearch
//        uploadDoc(parseDoc(500,100));

        // 执行混合检索
        String question = "结婚领证登记需要双发到场吗?";
        List<String> contents = search(question);

        // 执行混合检索并拼接 prompt
        String prompt = prompt(contents,question);

        // 调用 LLM 生成回答
        String answer = llmAnswer(prompt);

        System.out.println("-----------");
        System.out.println("Question:");
        System.out.println(question);
        System.out.println("-----------");
        System.out.println("Prompt:");
        System.out.println(prompt);
        System.out.println("-----------");
        System.out.println("Answer:");
        System.out.println(answer);
    }

}


 
 文章来源地址https://www.toymoban.com/news/detail-858087.html

到了这里,关于Elasticsearch 混合检索优化大模型 RAG 任务的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 探索检索增强生成(RAG)技术的无限可能:Vector+KG RAG、Self-RAG、多向量检索器多模态RAG集成

    由于 RAG 的整体思路是首先将文本切分成不同的组块,然后存储到向量数据库中。在实际使用时,将计算用户的问题和文本块的相似度,并召回 top k 的组块,然后将 top k 的组块和问题拼接生成提示词输入到大模型中,最终得到回答。 优化点: 优化文本切分的方式,组块大小

    2024年02月02日
    浏览(36)
  • ElasticSearch 7.X系列之: 检索性能优化实战指南

    检索响应慢! 并发检索用户多时,响应时间不达标 卡死了! 怎么还没有出结果? 怎么这么慢? 为啥竞品产品的很快就返回结果了? 宕机了 等等...... 这些都与可能检索有关,确切的说和检索性能有关。 检索性能的优化涉及知识点比较零散,我以官方文档的检索性能优化部

    2023年04月08日
    浏览(47)
  • Elasticsearch 8.X DSL 如何优化更有助于提升检索性能?

    根据我的实战和咨询经验,我发现如下几个问题。 当然,这是在和球友交流确认问题之后总结出来的。 2.1 问题1:bool 组合嵌套过深。 官方实际是有参数来约束的,indices.query.bool.max_nested_depth——bool 最大支持的嵌套层数是 20 ,并且过大的嵌套层数会导致“堆栈溢出”异常问

    2024年02月16日
    浏览(43)
  • 【高级RAG技巧】使用二阶段检索器平衡检索的效率和精度

    之前的文章已经介绍过向量数据库在RAG(Retrieval Augmented Generative)中的应用,本文将会讨论另一个重要的工具-Embedding模型。 一般来说,构建生产环境下的RAG系统是直接使用Embedding模型对用户输入的Query进行向量化表示,并且从已经构建好的向量数据库中检索出相关的段落用户

    2024年04月26日
    浏览(36)
  • AI数据技术02:RAG数据检索

            在人工智能的动态环境中,检索增强生成(RAG)已成为游戏规则的改变者,彻底改变了我们生成文本和与文本交互的方式。RAG 使用大型语言模型 (LLM) 等工具将信息检索的强大功能与自然语言生成无缝结合,为内容创建提供了一种变革性的方法。         在

    2024年02月03日
    浏览(40)
  • TS版LangChain实战:基于文档的增强检索(RAG)

    LangChain是一个以 LLM (大语言模型)模型为核心的开发框架,LangChain的主要特性: 可以连接多种数据源,比如网页链接、本地PDF文件、向量数据库等 允许语言模型与其环境交互 封装了Model I/O(输入/输出)、Retrieval(检索器)、Memory(记忆)、Agents(决策和调度)等核心组件

    2024年02月05日
    浏览(56)
  • RAG实战3-如何追踪哪些文档片段被用于检索增强生成

    本文是RAG实战2-如何使用LlamaIndex存储和读取embedding向量的续集,在阅读本文之前请先阅读前篇。 在前篇中,我们介绍了如何使用LlamaIndex存储和读取embedding向量。在本文中,我们将介绍在LlamaIndex中如何获得被用于检索增强生成的文档片段。 下面的代码展示了如何使用LlamaInd

    2024年03月09日
    浏览(47)
  • RAG应用开发实战02-相似性检索的关键 - Embedding

    将整个文本转化为实数向量的技术。 Embedding优点是可将离散的词语或句子转化为连续的向量,就可用数学方法来处理词语或句子,捕捉到文本的语义信息,文本和文本的关系信息。 ◉ 优质的Embedding通常会让语义相似的文本在空间中彼此接* ◉ 优质的Embedding相似的语义关系可

    2024年04月14日
    浏览(49)
  • Python实现HBA混合蝙蝠智能算法优化循环神经网络分类模型(LSTM分类算法)项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。 蝙蝠算法是2010年杨教授基于群体智能提出的启发式搜索算法,是一种搜索全局最优解的有效方法。该算法基于迭代优化,初始化为一组随机解,然

    2024年02月17日
    浏览(163)
  • RAG检索式增强技术是什么——OJAC近屿智能带你一探究竟

    Look!👀我们的大模型商业化落地产品 📖更多AI资讯请👉🏾关注 Free三天集训营助教在线为您火热答疑👩🏼‍🏫 RAG(Retrieval-Augmented Generation)模型是一个创新的自然语言处理(NLP)技术,它结合了传统的信息检索方法和现代的生成式语言模型,旨在通过引入外部知识源来

    2024年02月01日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包