【Multimodal & LLM】Adding a batch inference mode to the Reyes multimodal LLM to speed up inference


I previously pretrained a multimodal large model, Reyes; for details see 《【Multimodal & LLM】Reyes: a multimodal large model trained from 0 to 1 (technical report)》. This post adds a batch inference mode to Reyes to improve its inference speed.

Reyes-8B open-source address:

Usage

Replace the modeling_reyes.py downloaded from ModelScope with the modeling_reyes.py in this repository and run as usual. For the full batch inference example, see batch_inference.ipynb on GitHub.
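For reference, a minimal sketch of that file swap (the ModelScope model id below is a placeholder; substitute the actual Reyes-8B repo id):

```python
import shutil
from modelscope import snapshot_download

# download the weights (or locate the cached copy); placeholder model id
model_dir = snapshot_download('your-namespace/Reyes-8B')

# overwrite the stock modeling file with the patched one from this repo
shutil.copy('modeling_reyes.py', f'{model_dir}/modeling_reyes.py')
```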

Additions to modeling_reyes.py:

```python
# Methods added to the Reyes model class in modeling_reyes.py.

def chat_batch(
        self,
        tokenizer,
        pixel_values_list,
        questions,
        generation_config,
        histories=None,
        return_histories=False,
        num_patches_lists=None,
        IMG_START_TOKEN='<|vision_start|>',
        IMG_END_TOKEN='<|vision_end|>',
        IMG_CONTEXT_TOKEN='<|vision_pad|>',
        verbose=False,
        visual_features_list=None
):
    if histories is None:
        histories = [[] for _ in questions]

    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_context_token_id = img_context_token_id

    # Get eos_token_id from the template
    template = get_conv_template(self.template)
    template.system_message = self.system_message
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
    generation_config['eos_token_id'] = eos_token_id

    queries = []
    input_ids_list = []
    attention_mask_list = []

    for idx in range(len(questions)):
        question = questions[idx]
        history = histories[idx]
        pixel_values = pixel_values_list[idx] if pixel_values_list[idx] is not None else None
        num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []

        if not history and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        # Build the per-sample conversation prompt
        template_i = get_conv_template(self.template)
        template_i.system_message = self.system_message
        for (old_question, old_answer) in history:
            template_i.append_message(template_i.roles[0], old_question)
            template_i.append_message(template_i.roles[1], old_answer)
        template_i.append_message(template_i.roles[0], question)
        template_i.append_message(template_i.roles[1], None)
        query = template_i.get_prompt()

        # Handle image tokens: expand <image> into tile markers plus context tokens
        if pixel_values is not None:
            for num_patches in num_patches_list:
                tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
                image_tokens = ''
                for tile_pos_identifier in tile_pos_identifiers:
                    image_tokens += tile_pos_identifier + IMG_CONTEXT_TOKEN * self.num_image_token
                image_tokens = IMG_START_TOKEN + image_tokens + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)

        # Fix: record the query so the verbose branch below has something to print
        queries.append(query)

        model_inputs = tokenizer(
            query,
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)

    # Call the generate function
    generation_output = self.generate_batch(
        pixel_values_list=pixel_values_list,
        input_ids_list=input_ids_list,
        attention_mask_list=attention_mask_list,
        **generation_config
    )
    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)

    outputs = []
    for idx, response in enumerate(responses):
        response = response.split(template.sep)[0].strip()
        histories[idx].append((questions[idx], response))
        outputs.append(response)

    if return_histories:
        return outputs, histories
    else:
        if verbose:
            for idx, query in enumerate(queries):
                query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
                query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
                print(query_to_print, outputs[idx])
        return outputs


@torch.no_grad()
def generate_batch(
        self,
        pixel_values_list: Optional[List[torch.FloatTensor]] = None,
        input_ids_list: Optional[List[torch.FloatTensor]] = None,
        attention_mask_list: Optional[List[torch.LongTensor]] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
) -> torch.LongTensor:
    input_embeds_list = []
    attention_mask_padded_list = []

    # Pad every sample to the longest sequence in the batch
    max_seq_length = max(input_ids.shape[1] for input_ids in input_ids_list)

    for pixel_values, input_ids, attention_mask in zip(pixel_values_list, input_ids_list, attention_mask_list):
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features.cuda()
                vit_embeds = self.mlp1(vit_embeds)
            else:
                vit_embeds = self.extract_feature(pixel_values)

            # Splice the ViT embeddings into the positions of the IMG_CONTEXT tokens
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0, "No valid image context token IDs found."
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Right-pad embeddings and attention mask up to max_seq_length
        seq_length = input_embeds.shape[1]
        if seq_length < max_seq_length:
            pad_size = max_seq_length - seq_length
            input_embeds = F.pad(input_embeds, (0, 0, 0, pad_size))
            attention_mask = F.pad(attention_mask, (0, pad_size))

        input_embeds_list.append(input_embeds)
        attention_mask_padded_list.append(attention_mask)

    input_embeds = torch.cat(input_embeds_list, dim=0)
    attention_mask = torch.cat(attention_mask_padded_list, dim=0)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=True,
        **generate_kwargs,
    )

    return outputs
```
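The core of the batching logic is the right-padding step in generate_batch: each sample's embedding sequence is padded to the longest sequence in the batch, and its attention mask is extended with zeros so padded positions are ignored during generation. A standalone illustration of the F.pad semantics used above:

```python
import torch
import torch.nn.functional as F

# one sample with 5 tokens and hidden size 8: shape (B=1, N=5, C=8)
input_embeds = torch.randn(1, 5, 8)
attention_mask = torch.ones(1, 5, dtype=torch.long)

pad_size = 3  # suppose the longest sample in the batch has 8 tokens
# F.pad pads from the last dimension backwards: (0, 0) leaves the hidden
# dimension alone, (0, pad_size) appends pad_size positions to the sequence
input_embeds = F.pad(input_embeds, (0, 0, 0, pad_size))
attention_mask = F.pad(attention_mask, (0, pad_size))

print(input_embeds.shape)  # torch.Size([1, 8, 8])
print(attention_mask)      # tensor([[1, 1, 1, 1, 1, 0, 0, 0]])
```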

Batch inference:

```python
import torch
from modelscope import AutoTokenizer, AutoModel
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the candidate tiling grids
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


def preprocess_image(file_path, dynamic=True, max_num=6, image_size=448):
    try:
        if dynamic:
            return load_image(file_path, max_num=max_num).to(torch.bfloat16).cuda()
        else:
            img = Image.open(file_path).convert('RGB')
            transform = build_transform(image_size)
            pixel_values = transform(img)
            return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
    except Exception as e:
        raise RuntimeError(f"Error processing image: {e}")


path = "Reyes-8B"

model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().cuda()

# print(model)

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=2048, do_sample=False)
questions = [
    "<image>\nDescribe this image.",
    "<image>\nDescribe this image.",
    "<image>\nDescribe this image.",
]

images_path = ["t6.png", "t6.png", "t6.png"]


def conversation(model, tokenizer, questions, images_path, generation_config, histories):
    pixel_values_list = []
    for i in range(len(questions)):
        if images_path[i] is not None:
            # fix: the original referenced an undefined `file_path`; use the current image path
            pixel_values_list.append(preprocess_image(images_path[i], dynamic=True))
        else:
            # keep the list aligned with `questions` for text-only samples
            pixel_values_list.append(None)

    return model.chat_batch(tokenizer, pixel_values_list, questions, generation_config,
                            histories, return_histories=False)


responses = conversation(model, tokenizer, questions, images_path, generation_config, histories=None)
for question, response in zip(questions, responses):
    print(f"User: {question}\nAssistant: {response}\n")
```
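To verify the speedup on your own hardware, you can time the batched path against per-sample calls. A rough sketch, assuming the single-sample chat() interface from the original Reyes technical report (adjust if your checkpoint's API differs):

```python
import time

# sequential baseline: one chat() call per sample (assumed single-sample API)
start = time.time()
for q, p in zip(questions, images_path):
    pv = preprocess_image(p, dynamic=True)
    model.chat(tokenizer, pv, q, generation_config)
print(f"sequential: {time.time() - start:.1f}s")

# batched path added in this post
start = time.time()
conversation(model, tokenizer, questions, images_path, generation_config, histories=None)
print(f"batch: {time.time() - start:.1f}s")
```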