【RAG最新研究】优化RAG系统的最佳实践与深度解析

向量数据库大模型机器学习

今天给大家分享一篇最新的RAG论文:

论文题目:Enhancing Retrieval-Augmented Generation: A Study of Best Practices

论文链接: https://arxiv.org/abs/2501.07391

论文代码: https://github.com/ali-bahrainian/RAG\_best\_practices

picture.image

图片来源:https://x.com/shao\_\_meng/status/1879329913293209734

研究概述

这篇论文主要关注的是 检索增强型生成(RAG)系统 中的一个核心问题: 不同的组件和配置如何影响系统的性能

简单来说,RAG系统通过结合语言模型和外部知识库来生成更准确的回答,但之前的研究并没有深入探讨哪些因素(比如模型大小、提示设计、知识库大小等)对系统性能的影响最大。这篇论文的目标就是通过系统的实验和分析,找出这些关键因素,并提出一些新的配置方法,帮助提升RAG系统在各种复杂任务中的表现。

论文亮点

  • ✅查询扩展:使初始查询多样化以获得更多相关信息。
  • ✅对比上下文学习(Contrastive ICL):利用真假例子消除虚假信息,提高准确性!
  • ✅聚焦模式:仅提取必要的上下文并最大限度地减少噪音。

查询阶段

  • 查询扩展 查询扩展是扩展用户输入查询(q)以生成各种关键字和查询变体的过程。
  • 利用 Flan-T5 等生成模型创建增强查询(Raffel et al., 2020)。

            
              
例子)  
最初询问:“COVID-19 有哪些症状?”  
  
•生成的关键词:  
•“冠状病毒感染的迹象”  
•“SARS-CoV-2 的症状”  
•“常见的 COVID-19 症状”  

          

检索阶段

  • 第一步:在知识库中搜索每个关键字并检索广泛的相关文档。
  • 第二阶段:使用初始查询将范围缩小到仅相关文档。

RAG有哪些相关研究?

在RAG领域,已经有不少研究为这篇论文奠定了基础。以下是一些重要的相关研究:

  1. RAG系统的初步研究
  • Guu et al. (2020) 展示了语言模型可以通过实时检索文档来提高生成文本的准确性,而不需要增加模型的大小。
  • Shi et al. (2024b) 则证明了即使对于没有直接访问权限的黑盒模型,检索模块也能发挥作用。
  1. RAG系统的优化
  • Wang et al. (2024) 提出了优化检索组件的策略,比如改进文档索引和检索算法,以减少延迟并保持准确性。
  • Hsia et al. (2024) 研究了如何通过架构决策(如语料库选择、检索深度等)来提升RAG系统的效率。
  • Wu et al. (2024) 探讨了如何平衡模型内部知识和外部检索到的信息,避免两者之间的冲突。
  1. RAG系统的应用
  • Lewis et al. (2020) 提出了将外部知识源集成到推理过程中的RAG模型,确保生成的信息是最新且准确的。
  • Borgeaud et al. (2022) 和 Lee et al. (2024) 讨论了RAG模型如何通过整合可验证的信息来提高回答的事实准确性。
  1. RAG系统的评估
  • Semnani et al. (2023) 和 Chang et al. (2024) 研究了大型语言模型(LLMs)生成不准确信息的问题,并探讨了RAG系统如何解决这一问题。
  • Tran and Litman (2024) 则讨论了如何通过增强知识检索来实现基于知识的对话。

论文如何探索这个问题?

论文通过以下几个步骤来解决RAG系统中不同组件和配置对性能影响的问题:

  1. 提出研究问题 :论文首先提出了九个关键的研究问题,涵盖了语言模型大小、提示设计、文档块大小、知识库大小、检索步长、查询扩展、对比上下文学习、多语言知识库和焦点模式 等方面。
  2. 设计RAG系统变体 :基于这些研究问题,论文设计了多种RAG系统的变体,包括查询扩展模块、检索模块和文本生成模块。
  3. 实验设置 :论文详细描述了实验的设置,包括使用的数据集(TruthfulQA和MMLU)、知识库(Wikipedia Vital Articles)、评估指标(如ROUGE、余弦相似度、MAUVE、FActScore等)以及RAG 方法的具体实现细节。
  4. 实验和结果分析 :论文在两个数据集上进行了广泛的实验,评估了不同RAG变体的性能,并进行了相关性评估、事实性评估和定性分析
  5. 对比分析 :论文对比了不同RAG配置的效果,分析了语言模型大小、提示设计、文档大小、知识库大小、检索步长、查询扩展、对比上下文学习、多语言知识库和焦点模式 对生成响应质量的影响。
  6. 提出新方法 :论文提出了四种新的RAG配置方法,包括查询扩展、对比上下文学习示例、多语言知识库和焦点模式RAG ,这些都是本文的新贡献。

picture.image

通过这些步骤,论文系统地研究了RAG系统的架构,并提出了具体的改进措施,为开发和优化RAG系统提供了实证基础和理论支持。

论文做了哪些实验?

实验分类

论文进行了以下几类实验:

  1. 相关性评估
  • 对比了不同RAG变体生成的文本与参考文本的相关性。
  • 使用了ROUGE-1、ROUGE-2、ROUGE-L、嵌入余弦相似度和MAUVE等指标来评估性能差异。
  • 评估了九个研究问题对RAG系统性能的影响。
  1. 事实性评估
  • 使用FActScore指标评估了RAG变体在TruthfulQA和MMLU数据集上的事实性表现。
  • 对比了有无RAG模块的模型(w/o_RAG)与包含RAG模块的模型之间的事实性表现。
  1. 定性分析
  • 提供了在TruthfulQA和MMLU数据集上由模型变体生成的示例。
  • 展示了所提出的模块如何通过专门的检索技术显著提高RAG系统的性能。

具体实验设置:

  • 数据集 :使用了TruthfulQA和MMLU两个公开数据集。
  • 知识库 :使用了Wikipedia Vital Articles作为知识库,包括法语和德语文章。
  • 评估指标 :采用了ROUGE、嵌入余弦相似度、MAUVE和FActScore等指标。
  • RAG方法的具体实现 :包括使用T5模型进行查询扩展、FAISS用于向量索引和相似性搜索、Sentence Transformer作为文本编码器等。

picture.image

基于74次实验的结果,论文总结了关键发现,并提出了对比上下文学习RAG和焦点模式RAG在性能上的优越性。

论文的核心代码

下面代码实现了查询扩展和聚焦搜索,

  1. 查询扩展 :
  • 使用序列到序列模型对查询进行扩展,生成更多的关键词,以提高检索的准确性。
  • 扩展后的查询用于在FAISS索引中搜索相似的标题,从而找到更多相关的文档。
  1. 聚焦检索 :
  • 如果指定了focus参数,系统不仅会检索相关文档,还会进一步聚焦于文档中的最相关句子,提供更精确的结果。

更多细节请查看:https://github.com/ali-bahrainian/RAG\_best\_practices/tree/main


            
              
from faiss import IDSelectorArray, SearchParameters  
import numpy as np  
import pandas as pd  
from sentence_transformers import SentenceTransformer  
import torch  
  
import spacy  
import faiss  
  
# Load the English model  
nlp = spacy.load("en\_core\_web\_sm")  
  
class Retriever:  
    """  
    Handles the retrieval of relevant documents from a pre-built FAISS index.  
    Enables querying with sentence transformers embeddings.  
  
    Attributes:  
        index (faiss.Index): FAISS index for fast similarity search.  
        doc\_info (pd.DataFrame): DataFrame containing detailed information about documents.  
        documents (list of str): List of original documents.  
        embedding\_model (SentenceTransformer): Model used for embedding the documents and queries.  
    """  
  
    def \_\_init\_\_(self, index, doc\_info, embedding\_model\_name, model\_loader\_seq2seq, index\_titles):  
        """Initializes the Retriever class with necessary components.  
  
        Args:  
            index: FAISS index for fast retrieval.  
            doc\_info (DataFrame): DataFrame containing info about embedded document; aligned indices with index embeddings.  
            documents (list): List of original documents.  
            embedding\_model\_name (str): Name of the sentence transformer model.  
        """  
        self.index = index  
        self.doc_info = doc_info  
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
        self.embedding_model = SentenceTransformer(embedding_model_name).to(self.device)  
        self.sent_info = None  
        self.index_sents = None  
  
        self.model_seq2seq = model_loader_seq2seq.model  
        self.tokenizer_seq2seq = model_loader_seq2seq.tokenizer  
        # Define text-query pairs for query expansion  
        self.text_query_pairs = [  
            {"text": "Mitochondria play a crucial role in cellular respiration and energy production within human cells.", "query": "Cell Biology, Mitochondria, Energy Metabolism"},  
            {"text": "The Treaty of Versailles had significant repercussions that contributed to the onset of World War II.", "query": "World History, Treaty of Versailles, World War II"},  
            {"text": "What are the implications of the Higgs boson discovery for particle physics and the Standard Model?", "query": "Particle Physics, Higgs Boson, Standard Model"},  
            {"text": "How did the Silk Road influence cultural and economic interactions during the Middle Ages?", "query": "Silk Road, Middle Ages, Cultural Exchange"}  
        ]  
        self.index_titles = index_titles  
  
    def build\_index(self, documents):  
        """  
        Builds a FAISS index from document embeddings for efficient similarity searches which  
        includes embedding document chunks and initializing a FAISS index with these embeddings.  
  
        Args:  
            chunk\_size (int): The size of each text chunk in tokens.  
            overlap (int): The number of tokens that overlap between consecutive chunks.  
  
        Returns:  
            faiss.IndexFlatIP: The FAISS index containing the embeddings of the document chunks.  
        """  
        embeddings = self.embed_sents(documents)  
        index = faiss.IndexFlatIP(embeddings.shape[1])  
        index.add(embeddings)  
  
        return index  
  
    def embed\_sents(self, documents):  
        """  
        Generates embeddings for document chunks.  
  
        The process involves:  
        1. Preparing chunks of documents:  
          - Splits each document into overlapping chunks based on `chunk\_size` and `overlap`.  
        2. Encoding these chunks/documents into embeddings using the Sentence Transformer.  
  
        Args:  
            chunk\_size (int): Size of each chunk in tokens.  
            overlap (int): Overlap between consecutive chunks in tokens.  
  
        Returns:  
            np.ndarray: An array of embeddings for all the documents (chunks).  
        """  
        self.sent_info = self.prepare_sents(documents)  
        self.sent_info = pd.DataFrame(self.sent_info)  
        embeddings = self.embedding_model.encode(self.sent_info["text"].tolist(), show_progress_bar=True)  
        self.sent_info['embedding'] = embeddings.tolist()  
  
        return np.array(embeddings)  
      
    def prepare\_sents(self, documents):  
        """  
        Splits each document into sentences and  
        creates dictionary for DataFrame associated with index.  
  
        Returns:  
            Tuple[List[str], List[dict]]: Tuple containing list of all sents and their info.  
        """  
        sent_info = []  
        sent_id = 0  
        for document in documents:  
              
            doc = nlp(document)  
            sents = [sent.text for sent in doc.sents]  
              
            # Prepend same document to its chunks and store document/chunk details  
            for sent in sents:  
                sent_dict = {"text": sent, "org\_sent\_id": sent_id}  
                sent_info.append(sent_dict)  
                sent_id += 1  
        return sent_info  
  
    def retrieve(self, query\_batch, k, expand\_query, k\_titles, icl\_kb\_idx\_batch=None, focus=None):  
        """  
        Retrieves the top-k most similar documents for each query in a batch of queries.  
  
        Args:  
            query\_batch (list of str): List of query strings.  
            k (int): Number of documents to retrieve.  
  
        Returns:  
            List[List[dict]]: List of lists containing formatted results of retrieved documents for each query.  
        """  
  
        if k == 0:  
            return [[] for _ in query_batch]  
  
        if expand_query:  
            # Expand the query using a seq2seq model  
            eq_prompt_batch_str = []  
            for query in query_batch:  
                examples = self.text_query_pairs.copy()  
                examples.append({"text": query, "query": ""})  
                eq_prompt = "\n".join([f"Question: {example['text']}\nQuery Keywords: {example['query']}" for example in examples])  
                eq_prompt_batch_str.append(eq_prompt)  
  
            eq_prompt_batch_enc = self.tokenizer_seq2seq(eq_prompt_batch_str, return_tensors='pt', padding=True).to(self.device)  
            eq_batch_enc = self.model_seq2seq.generate(**eq_prompt_batch_enc, max_length=25, num_return_sequences=1)  
            eq_batch = self.tokenizer_seq2seq.batch_decode(eq_batch_enc, skip_special_tokens=True)  
            eq_batch = [eq.split(", ") for eq in eq_batch] # Split the expanded queries  
  
            # Encode the expanded queries and search the index for similar titles  
            eq_batch_indexed = [(eq, i) for i, eqs in enumerate(eq_batch) for eq in eqs]  
            eq_batch_flat = [eq for eq, _ in eq_batch_indexed]  
            eq_embeddings = self.embedding_model.encode(eq_batch_flat, show_progress_bar=False)  
            _, indices_eq = self.index_titles.search(np.array(eq_embeddings), k_titles)  
  
            # Retrieve the indices of the documents associated with the similar titles  
            indices_eq_batch = [[] for _ in range(len(query_batch))]  
            for ids, (_, i) in zip(indices_eq, eq_batch_indexed):  
                indices_eq_batch[i].append(self.doc_info[self.doc_info['org\_doc\_id'].isin(ids)].index.tolist())  
        else:  
            # If not expanding the query, set the indices to an empty list  
            if icl_kb_idx_batch:  
                # Remove the correct answer from the retrieved documents  
                all_ids_batch = [list(range(self.index.ntotal)) for _ in range(len(query_batch))]  
                for all_ids, icl_kb_idx in zip(all_ids_batch, icl_kb_idx_batch):  
                    all_ids.remove(icl_kb_idx)  
                all_ids_batch = [[all_ids] for all_ids in all_ids_batch]  
                indices_eq_batch = all_ids_batch  
            else:  
                indices_eq_batch = [[] for _ in range(len(query_batch))]  
  
        # Batch encode the queries  
        query_embeddings = self.embedding_model.encode(query_batch, show_progress_bar=False)  
  
        # Process each query separately  
        results_batch = []  
        for query_embedding, ids_filter in zip(query_embeddings, indices_eq_batch):  
            ids_filter = ids_filter if ids_filter else [list(range(self.index.ntotal))]  
  
            id_filter_set = set()  
            for id_filter in ids_filter:  
                id_filter_set.update(id_filter)  
  
            id_filter = list(id_filter_set)  
            id_selector = IDSelectorArray(id_filter)  
            # Search the index for similar documents, retrieve a larger set of documents  
            similarities, indices = self.index.search(np.array([query_embedding]), k, params=SearchParameters(sel=id_selector))  
            indices, similarities = indices[0], similarities[0]  
              
            # Focus on the most relevant sentences from the retrieved documents  
            if focus:  
                docs = self.doc_info.loc[indices]["text"].tolist()  
                self.index_sents = self.build_index(docs)     
                similarities, indices = self.index_sents.search(np.array([query_embedding]), focus)  
                indices, similarities = indices[0], similarities[0]  
  
            icl_kb = icl_kb_idx_batch!=None  
            if focus:  
                # Retrieve the most relevant sentences from the retrieved documents  
                results_batch.append([self._create_result(idx, sim, icl_kb, focus) for idx, sim in zip(indices[:focus], similarities)])  
            else:  
                results_batch.append([self._create_result(idx, sim, icl_kb, focus) for idx, sim in zip(indices[:k], similarities)])  
  
        return results_batch  
  
  
    def \_create\_result(self, idx, score, icl\_kb, focus):  
        """  
        Creates/builds a result dictionary of the retrieved document.  
  
        Args:  
            idx (int): Index of the result/document in doc\_info.  
            score (float): Similarity (& Diversity) score of document.  
  
        Returns:  
            dict: Dictionary containing the document text and additional information.  
        """  
        if focus:   
            # Retrieve the most relevant sentences from the retrieved documents  
            sent = self.sent_info.iloc[idx]  
            result_dict = {  
            "text": sent["text"],  
            "sent\_id": sent["org\_sent\_id"],  
            "score": score  
        }  
        else:  
            doc = self.doc_info.iloc[idx]  
            # Create the result dictionary  
            result_dict = {  
                "text": doc["text"],  
                "doc\_id": doc["org\_doc\_id"],  
                "score": score  
            }  
  
            if icl_kb:  
                # Include the correct and incorrect answers for ICL KB  
                result_dict['correct\_answer'] = doc["correct\_answer"]  
                result_dict['incorrect\_answer'] = doc["incorrect\_answer"]  
  
        return result_dict
          

picture.image

添加微信,回复”RAG“进入交流群

picture.image

picture.image

0
0
0
0
关于作者
相关资源
高性能存储虚拟化方案 NVMe over Fabric 在火山引擎的演进
在云计算中,虚拟化存储扮演着重要角色,其中 iSCSI 协议在业界开放、流行多年。近年来,拥有更优性能的 NVMe over Fabrics 协议也得到了发展。本次分享介绍了 NVMe over Fabrics 在云原生和虚拟化方向的演进工作和成果。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论