[Multimodal & LLM] Reyes: A Multimodal LLM Trained from 0 to 1 (Technical Report)


Recently, I systematically went through the implementations of several classic multimodal LLMs and, to put theory into practice, built a multimodal LLM from scratch, named Reyes (睿视), where R stands for 睿 (insight) and eyes for vision. Reyes has 8B parameters: the vision encoder is InternViT-300M-448px-V2_5, the language model is Qwen2.5-7B-Instruct, and, like NVLM-1.0 and related multimodal models, Reyes connects the vision encoder to the language model through a two-layer MLP projector. With a smaller parameter count, Reyes-8B (0.447) outperforms llava1.5-13B (0.367) on the MMMU benchmark.

[Figure: overall architecture of the Reyes model]

Model implementation: ReyesModel

```python
import torch
import torch.nn as nn
import transformers
from typing import List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from peft import LoraConfig, get_peft_model

# ReyesConfig, InternVisionModel, get_conv_template, version_cmp,
# has_flash_attn and logger come from the Reyes codebase itself.


class ReyesModel(PreTrainedModel):
    config_class = ReyesConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'Qwen2DecoderLayer']

    def __init__(self, config: ReyesConfig, vision_model=None, language_model=None, use_flash_attn=True):
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.44.2', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.llm_arch_name = config.llm_config.architectures[0]
        self.template = config.template
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        use_flash_attn = use_flash_attn if has_flash_attn else False
        config.vision_config.use_flash_attn = True if use_flash_attn else False
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)
        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.llm_config)
                # self.language_model = AutoLigerKernelForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_intermediate_size = config.llm_config.intermediate_size
        llm_hidden_size = config.llm_config.hidden_size

        # Two-layer MLP projector from the ViT feature space to the LLM embedding space
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_intermediate_size, bias=False),
            nn.GELU(),
            nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False)
        )

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        # Determine the target modules based on the architecture of the language model
        if self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
            target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                              'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
        else:
            raise NotImplementedError
        lora_config = LoraConfig(
            r=r,
            target_modules=target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        vit_embeds = self.extract_feature(pixel_values)
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        # Scatter the visual embeddings into the positions of the
        # <IMG_CONTEXT> placeholder tokens
        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
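forward() relies on self.extract_feature, which is not reproduced above. A sketch consistent with the InternVL-style design Reyes follows (ViT features, pixel-shuffle downsampling, then the mlp1 projector); treat this as an assumption rather than the exact Reyes implementation. Assuming the usual downsample ratio of 0.5, each 448×448 tile yields (448/14)² = 1024 patch tokens, reduced to num_image_token = 256 after pixel shuffle:

```python
    # (methods of ReyesModel; sketched after InternVL's implementation,
    # not copied from the Reyes source)
    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C -> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # -> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # -> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]  # drop the CLS token
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)  # project into the LLM embedding space
        return vit_embeds
```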
          
   

 
          
   

 
        
      

Data engine

Image directory layout:


              
              
                  

```
├── coco
│   └── train2017
├── gqa
│   └── images
├── ocr_vqa
│   └── images
├── textvqa
│   └── train_images
└── vg
    ├── VG_100K
    └── VG_100K_2
```
                
   

 
              
            

Data format

  • Pretraining data format:

              
              
                  

```
<Image><image></Image> Give a brief description of the image. A group of airplanes flying through the sky.<|im_end|>
```
                
   

 
              
            
  • SFT data format (ChatML):

              
              
                  

```
<|im_start|>system
Follow the user's instruction and answer questions.<|im_end|>
<|im_start|>user
<Image><image></Image>
What activity are the people engaged in on the green field?<|im_end|>
<|im_start|>assistant
The people are engaged in playing a game of frisbee on the lush green field.<|im_end|>
<|im_start|>user
Is there a person wearing a fairy outfit in the image? If so, what are they doing?<|im_end|>
<|im_start|>assistant
Yes, there is a person wearing a fairy outfit in the image. They are walking in the field, likely
participating in the fun activities along with the other people.<|im_end|>
<|im_start|>user
What can you say about the spectator's location in relation to the houses?<|im_end|>
<|im_start|>assistant
The spectators are watching a match being played in a location close to the houses, possibly indicating that the field is within a residential area or park.<|im_end|>
```
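Before tokenization, the single <image> placeholder in these templates is expanded into one <IMG_CONTEXT> token per visual embedding; forward() then overwrites exactly those positions (selected = input_ids == img_context_token_id) with the ViT features. A minimal sketch of that expansion, following the InternVL token convention (the token strings and the helper are illustrative assumptions, not from the Reyes source):

```python
# Illustrative sketch: expand '<image>' into per-tile context tokens.
IMG_START_TOKEN = '<img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_END_TOKEN = '</img>'

def expand_image_placeholder(prompt: str, num_image_token: int, num_patches: int) -> str:
    # num_image_token is the per-tile token count (e.g. 256 under the
    # assumptions above); num_patches is the number of DHR tiles
    # (plus thumbnail) produced for this image.
    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * (num_image_token * num_patches) + IMG_END_TOKEN
    return prompt.replace('<image>', image_tokens, 1)
```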
                
   

 
              
            

Dynamic high resolution

During data processing, a dynamic high-resolution (DHR) strategy handles image inputs of varying resolution. As shown in the figures below, each image is split into at most 6 tiles:

[Figures: the DHR tiling pipeline]

The two figures above illustrate the DHR pipeline: a complete image-preprocessing flow covering normalization, resizing, cropping, and aspect-ratio-aware tiling. The code is as follows:


        
        
            

          
```python
import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the tiling with more tiles when the image is large enough
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate all (i, j) tilings with min_num <= i * j <= max_num
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Find the tiling whose aspect ratio is closest to the image's
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize, then crop into image_size x image_size tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    # Append a global thumbnail so the model also sees the whole image
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
```
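To make the tiling concrete, here is a small usage sketch (the 1000×600 input is an arbitrary illustrative size): an aspect ratio of about 1.67 is closest to the 3:2 tiling, so the image is resized to 1344×896 and cut into six 448×448 tiles, plus one global thumbnail.

```python
img = Image.new('RGB', (1000, 600))  # hypothetical input size
tiles = dynamic_preprocess(img, max_num=6, image_size=448, use_thumbnail=True)
print(len(tiles), tiles[0].size)     # 7 (448, 448): 3x2 grid + thumbnail

pixel_values = torch.stack([build_transform(448)(t) for t in tiles])
print(pixel_values.shape)            # torch.Size([7, 3, 448, 448])
```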
          
   

 
          
   

 
        
      

Loss curves

  • Pretraining loss

[Figure: pretraining loss, epoch = 1]

  • SFT loss

[Figure: SFT loss, epoch = 1]

Training configuration

For a fair comparison with llava1.5-13B, the training data and several training hyperparameters were aligned with it.

  • Pretrain stage: freeze both the vision encoder and the LLM and train only the MLP projector; max-len=2048, gradient_accumulation_steps=4, per-GPU batch size 8 on 8×H100, giving a global batch size of 8×4×8=256.
  • SFT stage: keep the vision encoder frozen, unfreeze the LLM, and train it together with the MLP projector (see the sketch below); max-len=2048, gradient_accumulation_steps=2, per-GPU batch size 8 on 8×H100, giving a global batch size of 8×2×8=128.
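A minimal sketch of this two-stage freezing schedule, assuming the ReyesModel attributes defined earlier (the set_stage helper itself is illustrative, not taken from the training code):

```python
def set_stage(model, stage: str):
    # Vision encoder stays frozen in both stages.
    for p in model.vision_model.parameters():
        p.requires_grad = False
    # The MLP projector is trained in both stages.
    for p in model.mlp1.parameters():
        p.requires_grad = True
    # The LLM is unfrozen only during SFT.
    for p in model.language_model.parameters():
        p.requires_grad = (stage == 'sft')
```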

Inference


        
        
            

          
   

 
          
```python
import torch
from modelscope import AutoTokenizer, AutoModel
from PIL import Image

# build_transform, find_closest_aspect_ratio, dynamic_preprocess and load_image
# are identical to the DHR preprocessing code above (the original inference
# script copied them with a default max_num of 12 instead of 6).


def preprocess_image(file_path, dynamic=True, max_num=6, image_size=448):
    try:
        if dynamic:
            return load_image(file_path, max_num=max_num).to(torch.bfloat16).cuda()
        else:
            img = Image.open(file_path).convert('RGB')
            transform = build_transform(image_size)
            pixel_values = transform(img)
            return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
    except Exception as e:
        raise RuntimeError(f"Error processing image: {e}")


path = "Reyes-8B"

model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=2048, do_sample=False)

# single-image single-round conversation
file_path = 'tmp.png'
pixel_values = preprocess_image(file_path, dynamic=True)
question = '<image>\nPlease describe the image shortly.'
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(f'User: {question}\nAssistant: {response}')

# pure-text conversation
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config,
                               history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
```
          
   

 
          
   

 
        
      

Evaluation

  1. MMMU evaluation (MMMU: A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI)

Overview: MMMU is a benchmark designed to evaluate multimodal models on massive multi-discipline tasks that demand college-level subject knowledge and deliberate reasoning. It contains 11.5K carefully collected multimodal questions from college exams, quizzes, and textbooks, covering six core disciplines: Art & Design, Business, Science, Health & Medicine, Humanities & Social Science, and Tech & Engineering. The questions span 30 subjects and 183 subfields and include 32 highly heterogeneous image types such as charts, diagrams, maps, tables, music sheets, and chemical structures. Unlike existing benchmarks, MMMU focuses on advanced perception and reasoning with domain-specific knowledge, challenging models to perform tasks akin to those faced by experts.


The evaluation shows that Reyes-8B achieves a better result than llava1.5-13B. Detailed scores:

  • llava1.5-13B score: 0.367


  • Reyes-8B score: 0.447


  2. Some test cases
  • Case 1

[Image: case 1]

              
              
                  

```
Question: Who painted <image 1>?
Options: {'A': 'Claude Monet', 'B': 'Henri Matisse', 'C': 'Andy Warhol', 'D': "Georgia O'Keefe"}
Predicted answer: C
Ground-truth answer: C
```
                
   

 
              
            
  • Case 2

[Image: case 2]


              
              
                  

```
Question: Each situation below relates to an independent company's Owners' Equity. <image 1> Calculate the missing values of company 2.
Options: {'A': '$1,620', 'B': '$12,000', 'C': '$51,180', 'D': '$0'}
Predicted answer: D
Ground-truth answer: D
```
                
   

 
              
            
  • Case 3

[Image: case 3]


              
              
                  

```
Question: A survey line ABC crossing a river at right angles cut its banks at B and C, as shown in Fig. 2.39. To determine the width BC of the river, the following operation was carried out. A 60 m long line BE was set out roughly parallel to the river. Line CE was extended to D and mid-point F of DB was established. Then EF was extended to G such that FG = EF. Line DG was extended to cut the survey line ABC at H. GH and HB were measured and found to be 40 m and 80 m, respectively. Find the width of the river. <image 1>
Options: {'A': '120 m', 'B': '122 m', 'C': '123 m', 'D': '121 m'}
Predicted answer: A
Ground-truth answer: A
```
                
   

 
              
            

Summary

This post documents the full process of building a multimodal LLM from 0 to 1: model architecture, data engine, and evaluation. The training data is aligned with llava1.5-13B, and the model surpasses llava1.5-13B on MMMU with a smaller parameter count. Since only image-text multimodal data was used, and no text-only data was mixed into the SFT stage, the language side degrades somewhat. Time permitting, I plan to train a stronger Reyes with more multimodal data plus my own private data (e.g., from the earlier post "[Multimodal & Document AI] A hands-on exploration of table recognition and parsing with a multimodal LLM").

Related posts:

[Multimodal & LLM] A brief look at the POINTS multimodal LLM

[Multimodal & LLM] Evolution of the LLaVA architecture: LLaVA (1.0 -> 1.5 -> Next (1.6) -> NeXT (Video))

[Multimodal & LLM] Details and datasets of NVIDIA's NVLM multimodal model

[Multimodal & Document AI] OCR-free perception multimodal LLMs: technical pipeline and training-data details

[Multimodal & Document AI] A hands-on exploration of table recognition and parsing with a multimodal LLM
