【多模态&RAG】多模态RAG ColPali实践

机器学习

关于【RAG&多模态】多模态RAG-ColPali:使用视觉语言模型实现高效的文档检索前面已经介绍了(供参考),这次来看看ColPali实践。

所需权重:

  1. 多模态问答模型:Qwen2-VL-72B-Instruct,https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct
  2. 基于 PaliGemma-3B 和 ColBERT 策略的视觉检索器:

多模态检索问答实践


        
          
  
from byaldi import RAGMultiModalModel  
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor  
from qwen_vl_utils import process_vision_info  
import torch  
from pdf2image import convert_from_path  
  
class DocumentQA:  
    def __init__(self, rag_model_name: str, vlm_model_name: str, device: str = 'cuda', system_prompt: str = None):  
        self.rag_engine = RAGMultiModalModel.from_pretrained(rag_model_name)  
        self.vlm = Qwen2VLForConditionalGeneration.from_pretrained(  
            vlm_model_name,  
            torch_dtype=torch.bfloat16,  
            attn_implementation="flash\_attention\_2",  
            device_map=device  
        )  
        self.processor = AutoProcessor.from_pretrained(vlm_model_name, trust_remote_code=True)  
        self.device = device  
        if system_prompt is None:  
            self.system_prompt = (  
                "你是一位专精于计算机科学和机器学习的AI研究助理。"  
                "你的任务是分析学术论文,尤其是关于文档检索和多模态模型的研究。"  
                "请仔细分析提供的图像和文本,提供深入的见解和解释。"  
            )  
        else:  
            self.system_prompt = system_prompt  
  
    def index_document(self, pdf_path: str, index_name: str = 'index', overwrite: bool = True):  
        self.pdf_path = pdf_path  
        self.rag_engine.index(  
            input_path=pdf_path,  
            index_name=index_name,  
            store_collection_with_index=False,  
            overwrite=overwrite  
        )  
        self.images = convert_from_path(pdf_path)  
  
    def query(self, text_query: str, k: int = 3) -> str:  
        results = self.rag_engine.search(text_query, k=k)  
        print("搜索结果:", results)  
  
        if not results:  
            print("未找到相关查询结果。")  
            return None  
  
        try:  
            page_num = results[0]["page\_num"]  
            image_index = page_num - 1  
            image = self.images[image_index]  
        except (KeyError, IndexError) as e:  
            print("获取页面图像时出错:", e)  
            return None  
  
        messages = [  
            {  
                "role": "system",  
                "content": self.system_prompt  
            },  
            {  
                "role": "user",  
                "content": [  
                    {"type": "image", "image": image},  
                    {"type": "text", "text": text_query},  
                ],  
            }  
        ]  
  
        text = self.processor.apply_chat_template(  
            messages, tokenize=False, add_generation_prompt=True  
        )  
  
        image_inputs, video_inputs = process_vision_info(messages)  
  
        # 准备模型输入  
        inputs = self.processor(  
            text=[text],  
            images=image_inputs,  
            videos=video_inputs,  
            padding=True,  
            return_tensors="pt",  
        )  
        inputs = inputs.to(self.device)  
  
        generated_ids = self.vlm.generate(**inputs, max_new_tokens=1024)  
  
        generated_ids_trimmed = [  
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)  
        ]  
        output_text = self.processor.batch_decode(  
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False  
        )  
  
        return output_text[0]  
  
if __name__ == "\_\_main\_\_":  
    # 初始化 DocumentQA 实例  
    document_qa = DocumentQA(  
        rag_model_name="./colpali",  
        vlm_model_name="./Qwen2-VL-7B-Instruct",  
        device='cuda'  
    )  
  
    # 索引 PDF 文档  
    document_qa.index_document("test.pdf")  
  
    # 定义查询  
    text_query = (  
        "文中模型在哪个数据集上相比其他模型有最大的优势?"  
        "该优势的改进幅度是多少?"  
    )  
  
    # 执行查询并打印答案  
    answer = document_qa.query(text_query)  
    print("答案:", answer)  

      
0
0
0
0
关于作者
相关资源
火山引擎大规模机器学习平台架构设计与应用实践
围绕数据加速、模型分布式训练框架建设、大规模异构集群调度、模型开发过程标准化等AI工程化实践,全面分享如何以开发者的极致体验为核心,进行机器学习平台的设计与实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论