前沿重器[47] | RAG开源项目Qanything源码阅读3-在线推理

技术

前沿重器

栏目主要给大家分享各种大厂、顶会的论文和分享,从中抽取关键精华的部分和大家分享,和大家一起把握前沿技术。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。(算起来,专项启动已经是20年的事了!)

2023年文章合集发布了!在这里:又添十万字-CS的陋室2023年文章合集来袭

往期回顾

书接上文,最近选了一个开源的RAG项目进行进一步学习:https://github.com/netease-youdao/QAnything,后续一连几篇,会分几篇,从我的角度,给大家介绍这个项目,预计的目录如下:

本期是在线推理,即作为服务,一个query进来后到最终输出答案的全流程。

  • 检索&粗排
  • 精排
  • 检索文档后处理
  • prompt和请求大模型

外部服务

回顾一下在“前沿重器[45] RAG开源项目Qanything源码阅读1-概述+服务”中提到的服务核心文件,所有的接口都是在qanything_kernel\qanything_server\sanic_api.py里面启动的:


          
            
app.add_route(document, "/api/docs", methods=['GET'])  
app.add_route(new_knowledge_base, "/api/local\_doc\_qa/new\_knowledge\_base", methods=['POST'])  # tags=["新建知识库"]  
app.add_route(upload_weblink, "/api/local\_doc\_qa/upload\_weblink", methods=['POST'])  # tags=["上传网页链接"]  
app.add_route(upload_files, "/api/local\_doc\_qa/upload\_files", methods=['POST'])  # tags=["上传文件"]   
app.add_route(local_doc_chat, "/api/local\_doc\_qa/local\_doc\_chat", methods=['POST'])  # tags=["问答接口"]   
app.add_route(list_kbs, "/api/local\_doc\_qa/list\_knowledge\_base", methods=['POST'])  # tags=["知识库列表"]   
app.add_route(list_docs, "/api/local\_doc\_qa/list\_files", methods=['POST'])  # tags=["文件列表"]  
app.add_route(get_total_status, "/api/local\_doc\_qa/get\_total\_status", methods=['POST'])  # tags=["获取所有知识库状态"]  
app.add_route(clean_files_by_status, "/api/local\_doc\_qa/clean\_files\_by\_status", methods=['POST'])  # tags=["清理数据库"]  
app.add_route(delete_docs, "/api/local\_doc\_qa/delete\_files", methods=['POST'])  # tags=["删除文件"]   
app.add_route(delete_knowledge_base, "/api/local\_doc\_qa/delete\_knowledge\_base", methods=['POST'])  # tags=["删除知识库"]   
app.add_route(rename_knowledge_base, "/api/local\_doc\_qa/rename\_knowledge\_base", methods=['POST'])  # tags=["重命名知识库"]   

        

而推理,就是这里的local_doc_chat,直接看这个函数,就在qanything_kernel\qanything_server\handler.py里面。


          
            
async def local\_doc\_chat(req: request):  
    local_doc_qa: LocalDocQA = req.app.ctx.local_doc_qa  
    user_id = safe_get(req, 'user\_id')  
    if user_id is None:  
        return sanic_json({"code": 2002, "msg": f'输入非法!request.json:{req.json},请检查!'})  
    is_valid = validate_user_id(user_id)  
    if not is_valid:  
        return sanic_json({"code": 2005, "msg": get_invalid_user_id_msg(user_id=user_id)})  
    debug_logger.info('local\_doc\_chat %s', user_id)  
    kb_ids = safe_get(req, 'kb\_ids')  
    question = safe_get(req, 'question')  
    rerank = safe_get(req, 'rerank', default=True)  
    debug_logger.info('rerank %s', rerank)  
    streaming = safe_get(req, 'streaming', False)  
    history = safe_get(req, 'history', [])  
    debug_logger.info("history: %s ", history)  
    debug_logger.info("question: %s", question)  
    debug_logger.info("kb\_ids: %s", kb_ids)  
    debug_logger.info("user\_id: %s", user_id)  
  
    not_exist_kb_ids = local_doc_qa.milvus_summary.check_kb_exist(user_id, kb_ids)  
    if not_exist_kb_ids:  
        return sanic_json({"code": 2003, "msg": "fail, knowledge Base {} not found".format(not_exist_kb_ids)})  
  
    file_infos = []  
    milvus_kb = local_doc_qa.match_milvus_kb(user_id, kb_ids)  
    for kb_id in kb_ids:  
        file_infos.extend(local_doc_qa.milvus_summary.get_files(user_id, kb_id))  
    valid_files = [fi for fi in file_infos if fi[2] == 'green']  
    if len(valid_files) == 0:  
        return sanic_json({"code": 200, "msg": "当前知识库为空,请上传文件或等待文件解析完毕", "question": question,  
                           "response": "All knowledge bases {} are empty or haven't green file, please upload files".format(  
                               kb_ids), "history": history, "source\_documents": [{}]})  
    else:  
        debug_logger.info("streaming: %s", streaming)  
        if streaming:  
            debug_logger.info("start generate answer")  
  
            async def generate\_answer(response):  
                debug_logger.info("start generate...")  
                for resp, next_history in local_doc_qa.get_knowledge_based_answer(  
                        query=question, milvus_kb=milvus_kb, chat_history=history, streaming=True, rerank=rerank  
                ):  
                    chunk_data = resp["result"]  
                    if not chunk_data:  
                        continue  
                    chunk_str = chunk_data[6:]  
                    if chunk_str.startswith("[DONE]"):  
                        source_documents = []  
                        for inum, doc in enumerate(resp["source\_documents"]):  
                            source_info = {'file\_id': doc.metadata['file\_id'],  
                                           'file\_name': doc.metadata['file\_name'],  
                                           'content': doc.page_content,  
                                           'retrieval\_query': doc.metadata['retrieval\_query'],  
                                           'score': str(doc.metadata['score'])}  
                            source_documents.append(source_info)  
  
                        retrieval_documents = format_source_documents(resp["retrieval\_documents"])  
                        source_documents = format_source_documents(resp["source\_documents"])  
                        chat_data = {'user\_info': user_id, 'kb\_ids': kb_ids, 'query': question, 'history': history,  
                                     'prompt': resp['prompt'], 'result': next_history[-1][1],  
                                     'retrieval\_documents': retrieval_documents, 'source\_documents': source_documents}  
                        qa_logger.info("chat\_data: %s", chat_data)  
                        debug_logger.info("response: %s", chat_data['result'])  
                        stream_res = {  
                            "code": 200,  
                            "msg": "success",  
                            "question": question,  
                            # "response":next\_history[-1][1],  
                            "response": "",  
                            "history": next_history,  
                            "source\_documents": source_documents,  
                        }  
                    else:  
                        chunk_js = json.loads(chunk_str)  
                        delta_answer = chunk_js["answer"]  
                        stream_res = {  
                            "code": 200,  
                            "msg": "success",  
                            "question": "",  
                            "response": delta_answer,  
                            "history": [],  
                            "source\_documents": [],  
                        }  
                    await response.write(f"data: {json.dumps(stream\_res, ensure\_ascii=False)}\n\n")  
                    if chunk_str.startswith("[DONE]"):  
                        await response.eof()  
                    await asyncio.sleep(0.001)  
  
            response_stream = ResponseStream(generate_answer, content_type='text/event-stream')  
            return response_stream  
  
        else:  
            for resp, history in local_doc_qa.get_knowledge_based_answer(  
                    query=question, milvus_kb=milvus_kb, chat_history=history, streaming=False, rerank=rerank  
            ):  
                pass  
            retrieval_documents = format_source_documents(resp["retrieval\_documents"])  
            source_documents = format_source_documents(resp["source\_documents"])  
            chat_data = {'user\_id': user_id, 'kb\_ids': kb_ids, 'query': question, 'history': history,  
                         'retrieval\_documents': retrieval_documents, 'prompt': resp['prompt'], 'result': resp['result'],  
                         '`': source_documents}  
            qa_logger.info("chat\_data: %s", chat_data)  
            debug_logger.info("response: %s", chat_data['result'])  
            return sanic_json({"code": 200, "msg": "success chat", "question": question, "response": resp["result"],  
                               "history": history, "source\_documents": source_documents})  

        

来挑几个重点讲讲:

  • 首先因为是正式项目,在鉴权、数据库检测上都做了很多健壮性的处理,例如,对user_id的判别、对数据库及其对应用户的权限判别check_kb_exist,再者还有知识库的判空等。
  • 此处有区分是否使用了流式streaming
  • 最终结果的输出有进行结构化,结构化这事的业务代码专门弄了个函数format_source_documents
  • 这里区分了retrieval_documentssource_documents,两者有所区别,我在后面展开聊关键算法流程的时候会展开讲。
  • get_knowledge_based_answer是内部获取知识点并进行生成的关键函数,就是上一条所说的关键算法流程。

get_knowledge_based_answer的函数很简单,不过单独拿出来,对可读性是有挺大帮助的。


          
            
def format\_source\_documents(ori\_source\_documents):  
    source_documents = []  
    for inum, doc in enumerate(ori_source_documents):  
        # for inum, doc in enumerate(answer\_source\_documents):  
        # doc\_source = doc.metadata['source']  
        file_id = doc.metadata['file\_id']  
        file_name = doc.metadata['file\_name']  
        # source\_str = doc\_source if isURL(doc\_source) else os.path.split(doc\_source)[-1]  
        source_info = {'file\_id': doc.metadata['file\_id'],  
                       'file\_name': doc.metadata['file\_name'],  
                       'content': doc.page_content,  
                       'retrieval\_query': doc.metadata['retrieval\_query'],  
                       'kernel': doc.metadata['kernel'],  
                       'score': str(doc.metadata['score']),  
                       'embed\_version': doc.metadata['embed\_version']}  
        source_documents.append(source_info)  
    return source_documents  

        

RAG推理流程

RAG说白了就是先搜后交给大模型生成,终于讲到这段代码了,流程在这里qanything_kernel\core\local_doc_qa.py


          
            
@get\_time  
def get\_knowledge\_based\_answer(self, query, milvus\_kb, chat\_history=None, streaming: bool = STREAMING,  
                                rerank: bool = False):  
    if chat_history is None:  
        chat_history = []  
    retrieval_queries = [query]  
  
    source_documents = self.get_source_documents(retrieval_queries, milvus_kb)  
  
    deduplicated_docs = self.deduplicate_documents(source_documents)  
    retrieval_documents = sorted(deduplicated_docs, key=lambda x: x.metadata['score'], reverse=True)  
    if rerank and len(retrieval_documents) > 1:  
        debug_logger.info(f"use rerank, rerank docs num: {len(retrieval\_documents)}")  
        retrieval_documents = self.rerank_documents(query, retrieval_documents)  
  
    source_documents = self.reprocess_source_documents(query=query,  
                                                        source_docs=retrieval_documents,  
                                                        history=chat_history,  
                                                        prompt_template=PROMPT_TEMPLATE)  
    prompt = self.generate_prompt(query=query,  
                                    source_docs=source_documents,  
                                    prompt_template=PROMPT_TEMPLATE)  
    t1 = time.time()  
    for answer_result in self.llm.generatorAnswer(prompt=prompt,  
                                                    history=chat_history,  
                                                    streaming=streaming):  
        resp = answer_result.llm_output["answer"]  
        prompt = answer_result.prompt  
        history = answer_result.history  
  
        # logging.info(f"[debug] get\_knowledge\_based\_answer history = {history}")  
        history[-1][0] = query  
        response = {"query": query,  
                    "prompt": prompt,  
                    "result": resp,  
                    "retrieval\_documents": retrieval_documents,  
                    "source\_documents": source_documents}  
        yield response, history  
    t2 = time.time()  
    debug_logger.info(f"LLM time: {t2 - t1}")  
  

        

首先注意到这里有个装饰器@get_time。可以用来记录执行的时间。


          
            
def get\_time(func):  
    def inner(*arg, **kwargs):  
        s_time = time.time()  
        res = func(*arg, **kwargs)  
        e_time = time.time()  
        print('函数 {} 执行耗时: {} 秒'.format(func.__name__, e_time - s_time))  
        return res  
  
    return inner  

        

检索&粗排

get_source_documents是检索的过程,即给定了retrieval_queriesmilvus_kb,即query和所需要查的数据库,开始进行查询。这个过程应该大家都比较熟悉了,这个的返回结果,会放在retrieval_documents里面,即“检索到的文档”,我把源码放出来。


          
            
def get\_source\_documents(self, queries, milvus\_kb, cosine\_thresh=None, top\_k=None):  
    milvus_kb: MilvusClient  
    if not top_k:  
        top_k = self.top_k  
    source_documents = []  
    embs = self.embeddings._get_len_safe_embeddings(queries)  
    t1 = time.time()  
    batch_result = milvus_kb.search_emb_async(embs=embs, top_k=top_k, queries=queries)  
    t2 = time.time()  
    debug_logger.info(f"milvus search time: {t2 - t1}")  
    for query, query_docs in zip(queries, batch_result):  
        for doc in query_docs:  
            doc.metadata['retrieval\_query'] = query  # 添加查询到文档的元数据中  
            doc.metadata['embed\_version'] = self.embeddings.embed_version  
            source_documents.append(doc)  
    if cosine_thresh:  
        source_documents = [item for item in source_documents if float(item.metadata['score']) > cosine_thresh]  
  
    return source_documents  

        
  • _get_len_safe_embeddings给定query获取向量。在上一期“前沿重器[46] RAG开源项目Qanything源码阅读2-离线文件处理”有讲过,这个内部是请求一个向量模型的服务,背后的模型是需要和离线文件处理那个模型一致,所以部署同一个就会比较稳当,当然的,接口也是triton,一个grpc接口,有关GRPC,上次忘了放链接,这次放这里心法利器[6] | python grpc实践,非常建议大家详细了解并且学会。
  • search_emb_async是用于做向量检索的。这个就是pymilvus的核心功能了。
  • 此处,查询出来还要过一个阈值卡控,对相似度达不到阈值的文档,需要过滤,阈值设置在cosine_thresh
  • 此处注意,这里的检索还涉及一个过程“粗排”,这个粗排是指查询数据库的时候,需要根据相似度进行排序,只取TOPN,毕竟如果不进行这个TOP的卡控,那数据库里所有的数据都会被查出来,这没什么意义了。这里之所以叫粗排,是因为这种相似度的对比是比较粗略的,只能过滤掉“肯定不是”的那些无关结果。具体“哪个好”,用额外的、更精准的模型来做会更好,达到“优中取优”的目的。这块的设计还可以展开,后面我再写文章聊吧。

留意到这一串代码:


          
            
retrieval_documents = sorted(deduplicated_docs, key=lambda x: x.metadata['score'], reverse=True)  
if rerank and len(retrieval_documents) > 1:  
    debug_logger.info(f"use rerank, rerank docs num: {len(retrieval\_documents)}")  
    retrieval_documents = self.rerank_documents(query, retrieval_documents)  

        

此处注意,这里的检索还涉及一个过程“粗排”,这个粗排是指查询数据库的时候,需要根据相似度进行排序,只取TOPN,毕竟如果不进行这个TOP的卡控,那数据库里所有的数据都会被查出来,这没什么意义了。这里之所以叫粗排,是因为这种相似度的对比是比较粗略的,只能过滤掉“肯定不是”的那些无关结果。具体“哪个好”,用额外的、更精准的模型来做会更好,达到“优中取优”的目的。这块的设计还可以展开,后面我再写文章聊吧。

精排

我们继续关注这里的rerank_documents,这个就是精排,或者像这里说的重排。


          
            
def rerank\_documents(self, query, source\_documents):  
    return self.rerank_documents_for_local(query, source_documents)  
  
def rerank\_documents\_for\_local(self, query, source\_documents):  
    if len(query) > 300:  # tokens数量超过300时不使用local rerank  
        return source_documents  
  
    source_documents_reranked = []  
    try:  
        response = requests.post(f"{self.local\_rerank\_service\_url}/rerank",  
                                    json={"passages": [doc.page_content for doc in source_documents], "query": query})  
        scores = response.json()  
        for idx, score in enumerate(scores):  
            source_documents[idx].metadata['score'] = score  
            if score < 0.35 and len(source_documents_reranked) > 0:  
                continue  
            source_documents_reranked.append(source_documents[idx])  
  
        source_documents_reranked = sorted(source_documents_reranked, key=lambda x: x.metadata['score'], reverse=True)  
    except Exception as e:  
        debug_logger.error("rerank error: %s", traceback.format_exc())  
        debug_logger.warning("rerank error, use origin retrieval docs")  
        source_documents_reranked = sorted(source_documents, key=lambda x: x.metadata['score'], reverse=True)  
  
    return source_documents_reranked  

        

简单地,这里就是把所有召回回来的文章请求到重排服务来算分,根据算分来进行过滤和排序,筛选出最优的文章。和向量模型类似,一样是用triton部署的,看模型名像是QAEnsemble_embed_rerank。

检索文档后处理

更进一步,需要对文档进行后处理,即reprocess_source_documents函数。


          
            
def reprocess\_source\_documents(self, query: str,  
                                source\_docs: List[Document],  
                                history: List[str],  
                                prompt\_template: str) -> List[Document]:  
    # 组装prompt,根据max\_token  
    query_token_num = self.llm.num_tokens_from_messages([query])  
    history_token_num = self.llm.num_tokens_from_messages([x for sublist in history for x in sublist])  
    template_token_num = self.llm.num_tokens_from_messages([prompt_template])  
  
    # logging.info(f"<self.llm.token\_window, self.llm.max\_token, self.llm.offcut\_token, query\_token\_num, history\_token\_num, template\_token\_num>, types = {type(self.llm.token\_window), type(self.llm.max\_token), type(self.llm.offcut\_token), type(query\_token\_num), type(history\_token\_num), type(template\_token\_num)}, values = {query\_token\_num, history\_token\_num, template\_token\_num}")  
    limited_token_nums = self.llm.token_window - self.llm.max_token - self.llm.offcut_token - query_token_num - history_token_num - template_token_num  
    new_source_docs = []  
    total_token_num = 0  
    for doc in source_docs:  
        doc_token_num = self.llm.num_tokens_from_docs([doc])  
        if total_token_num + doc_token_num <= limited_token_nums:  
            new_source_docs.append(doc)  
            total_token_num += doc_token_num  
        else:  
            remaining_token_num = limited_token_nums - total_token_num  
            doc_content = doc.page_content  
            doc_content_token_num = self.llm.num_tokens_from_messages([doc_content])  
            while doc_content_token_num > remaining_token_num:  
                # Truncate the doc content to fit the remaining tokens  
                if len(doc_content) > 2 * self.llm.truncate_len:  
                    doc_content = doc_content[self.llm.truncate_len: -self.llm.truncate_len]  
                else:  # 如果最后不够truncate\_len长度的2倍,说明不够切了,直接赋值为空  
                    doc_content = ""  
                    break  
                doc_content_token_num = self.llm.num_tokens_from_messages([doc_content])  
            doc.page_content = doc_content  
            new_source_docs.append(doc)  
            break  
  
    debug_logger.info(f"limited token nums: {limited\_token\_nums}")  
    debug_logger.info(f"template token nums: {template\_token\_num}")  
    debug_logger.info(f"query token nums: {query\_token\_num}")  
    debug_logger.info(f"history token nums: {history\_token\_num}")  
    debug_logger.info(f"new\_source\_docs token nums: {self.llm.num\_tokens\_from\_docs(new\_source\_docs)}")  
    return new_source_docs  

        
  • 这里的llm,是一个自己封装好的大模型工具,具体是在qanything_kernel\connector\llm\llm_for_fastchat.py这个位置。里面支持计算token、请求大模型等通用功能。这个工具可以结合自己场景的需求搬过去直接使用。
  • 计算limited_token_nums主要是方便组装prompt,避免某些文字被吃掉。
  • 这里是需要对文档进行新的拼接和调整,如果查询的文档太多太长,则需要截断,且截断的时候需要注意,要保证截断的位置必须是完整地句子,如果不够长直接不切了。

prompt和请求大模型

然后就是开始生成prompt,generate_prompt。说白了就是一个简单的拼接。另外,这里的prompt拼接,更多使用replace来完成,之前有看过别的模式,例如用字符串的format应该也可以,不过replace的适用范围会更广一些。


          
            
def generate\_prompt(self, query, source\_docs, prompt\_template):  
    context = "\n".join([doc.page_content for doc in source_docs])  
    prompt = prompt_template.replace("{question}", query).replace("{context}", context)  
    return prompt  

        

顺带就看看他们的prompt吧,实际上并不复杂。


          
            
PROMPT_TEMPLATE = """参考信息:  
{context}  
---  
我的问题或指令:  
{question}  
---  
请根据上述参考信息回答我的问题或回复我的指令。前面的参考信息可能有用,也可能没用,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复,  
你的回复:"""  

        

最后一步就是开始请求大模型了。即generatorAnswer函数。


          
            
def generatorAnswer(self, prompt: str,  
                    history: List[List[str]] = [],  
                    streaming: bool = False) -> AnswerResult:  
  
    if history is None or len(history) == 0:  
        history = [[]]  
    logging.info(f"history\_len: {self.history\_len}")  
    logging.info(f"prompt: {prompt}")  
    logging.info(f"prompt tokens: {self.num\_tokens\_from\_messages([{'content': prompt}])}")  
    logging.info(f"streaming: {streaming}")  
              
    response = self._call(prompt, history[:-1], streaming)  
    complete_answer = ""  
    for response_text in response:  
  
        if response_text:  
            chunk_str = response_text[6:]  
            if not chunk_str.startswith("[DONE]"):  
                chunk_js = json.loads(chunk_str)  
                complete_answer += chunk_js["answer"]  
                  
        history[-1] = [prompt, complete_answer]  
        answer_result = AnswerResult()  
        answer_result.history = history  
        answer_result.llm_output = {"answer": response_text}  
        answer_result.prompt = prompt  
        yield answer_result  

        

这里就是请求大模型的基本话术了,相对还是比较简单的,一方面是请求大模型,另一方面是解析大模型内的结果。有留意到,这里有对内容做一些校验:


          
            
if response_text:  
    chunk_str = response_text[6:]  
    if not chunk_str.startswith("[DONE]"):  
        chunk_js = json.loads(chunk_str)  
        complete_answer += chunk_js["answer"]  

        

可以看出应该是有一些泛用性处理,能解决更多复杂的问题吧。

小结

本文把QAnything项目内的重要的推理部分穿讲了一遍,可以看出这个项目已经非常完成,基本具备上线所需的关键部分,同时也有很严格的校验逻辑,严格程度很高也比较稳定,经过这个学习,我自己对工程代码和具体实施的理解有了很大的提升,希望大家也有收获。

当然了,和我之前那个项目类似:心法利器[105] 基础RAG-大模型和中控模块代码(含代码),这只是最通用、常用的方案罢了,相比我之前的basic_rag项目,QAnything在服务的完整性、健壮性,以及文档处理上都有了很多的更新,但都不要指望用上就能达到很高的水准,需要进一步提升还需要更多内里的修炼,例如query理解辅助更好地提升检索的准确性,联合训练提升大模型和检索结果的协同,更深入定制的文档处理提升内容的可读性等。

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动基于 DataLeap 的 DataOps 实践
随着数字化转型的推进以及业务数仓建设不断完善,大数据开发体量及复杂性逐步上升,如何保证数据稳定、正确、持续产出成为数据开发者核心诉求,也成为平台建设面临的挑战之一。本次分享主要介绍字节对于DataOps的理解 以及 DataOps在内部业务如何落地实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论