A lone village in the setting sun and fading rosy clouds; light mist, old trees, wintry crows; the shadow of a wild goose flits down.
Hello everyone, I'm the little girl selling hot dry noodles, back with the multimodal LLM series. Following the earlier posts:
Multimodal Series | Google's latest open-source multimodal model: PaliGemma 2 intro & hands-on
Multimodal Series | Google's open-source multimodal model: a full technical walkthrough of PaliGemma 2
Today's post takes fine-tuning the multimodal model PaliGemma 2 as an example to show, hands-on, how to fine-tune a multimodal LLM. For the complete code, add the editor on WeChat via the official account 《小窗幽记机器学习》. More deep dives, evaluations, fine-tuning walkthroughs, and inference-acceleration posts on all kinds of LLMs are on the way; star the account if you're interested!
The walkthrough below shows how to fine-tune the multimodal model PaliGemma 2, using google/paligemma2-3b-pt-224 and a small slice of the VQAv2 dataset. Training uses PEFT with LoRA, detailed in the fine-tuning section below.
Environment Setup
pip3 install -U datasets bitsandbytes peft git+https://github.com/huggingface/transformers.git -i https://mirrors.cloud.tencent.com/pypi/simple
Downloading the Data
The dataset used here is merve/vqav2-small from Hugging Face; for download instructions, see the earlier post:
AI Basics Series | How to elegantly download cutting-edge models
huggingface-cli download --resume-download --repo-type dataset --local-dir-use-symlinks False merve/vqav2-small --local-dir /your_data_dir/merve/vqav2-small/
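If you prefer doing this from Python instead of the CLI, huggingface_hub offers snapshot_download. A minimal sketch; the local_dir is a placeholder to adjust for your environment:

# Programmatic equivalent of the CLI download above.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="merve/vqav2-small",
    repo_type="dataset",
    local_dir="/your_data_dir/merve/vqav2-small/",  # placeholder path
)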
To show the effect of fine-tuning intuitively, below we compare the model's output on the same prompt and image before and after fine-tuning. Test image:
prompt = "这个建筑物叫什么名字?"
Running inference directly with the pretrained model gives:
tower
As you can see, the pretrained model does not recognize this building as the Oriental Pearl Tower; it only knows it is a tower.
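The post does not show the baseline inference code itself; the following is a minimal sketch of how that result can be reproduced, assuming the same model_id, prompt, and test image used in the inference section later in this post:

# Baseline: run the raw pretrained model (no fine-tuning) on the test image.
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="cuda")
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<image>这个建筑物叫什么名字?"
raw_image = Image.open('../../dataset/test_images/东方明珠.jpeg').convert("RGB")
inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])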
Loading the Data
# Load the local data
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_dir="/data_dir/merve/vqav2-small/data",  # change to the directory containing the Parquet files
    split="validation",
    # data_files="validation-*.parquet"  # glob to match all validation shards
)
ds = ds.train_test_split(test_size=0.1)
# we'll use a very small split for demo
ds = ds["train"]
The training data at this point looks like:
Dataset({
    features: ['multiple_choice_answer', 'question', 'image'],
    num_rows: 19291
})
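Before writing the collator, it helps to eyeball one sample. A minimal sketch, using the fields listed above:

# Inspect a single training example.
sample = ds[0]
print(sample["question"])                # natural-language question about the image
print(sample["multiple_choice_answer"])  # short ground-truth answer used as the label
sample["image"]                          # a PIL image; renders inline in a notebook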
Loading the Model & Setting Up LoRA
import os
import torch
from transformers import PaliGemmaProcessor
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration
from peft import get_peft_model, LoraConfig

model_dir = "/model_dir/"
model_id = "google/paligemma2-3b-pt-224"  # or your favorite PaliGemma
model_id = os.path.join(model_dir, model_id)

device = "cuda"

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map=device)
# To fine-tune in 4-bit (QLoRA), pass quantization_config=bnb_config in the call above.
model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()
Note that when the model is loaded with:
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="auto")
and GPU memory is insufficient, you may see the following message:
Some parameters are on the meta device because they were offloaded to the cpu.
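One way to avoid this CPU offloading is to load the weights in 4-bit via the bnb_config defined above, i.e. the QLoRA path that the snippet leaves commented out. A minimal sketch, reusing model_id, bnb_config, and lora_config from earlier:

# QLoRA-style loading: 4-bit weights keep the 3B model on a single GPU,
# so no parameters need to be offloaded to CPU.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
model = get_peft_model(model, lora_config)

If you train this way, pair it with a paged optimizer such as paged_adamw_8bit, as the comment in the training arguments below also suggests.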
Data Preprocessing
processor = PaliGemmaProcessor.from_pretrained(model_id)

import torch

# The original snippet uses DTYPE below without defining it; the model's
# dtype (bfloat16 here) is what the cast is meant to target.
DTYPE = model.dtype
# image_token = processor.tokenizer.convert_tokens_to_ids("<image>")

def collate_fn(examples):
    texts = ["<image>answer en " + example["question"] for example in examples]
    labels = [example["multiple_choice_answer"] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest")
    tokens = tokens.to(DTYPE).to(device)
    return tokens
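Before launching a multi-hour run, it is worth sanity-checking the collator on a couple of samples. A minimal sketch:

# Run the collator on two examples and inspect the batch tensors.
batch = collate_fn([ds[0], ds[1]])
for k, v in batch.items():
    print(k, tuple(v.shape))  # expect input_ids, attention_mask, labels, pixel_values, ...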
Start Training
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    num_train_epochs=2,
    remove_unused_columns=False,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_hf",  # you can use paged optimizers like paged_adamw_8bit for QLoRA
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    output_dir="paligemma_vqav2",
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
)

trainer = Trainer(
    model=model,
    train_dataset=ds,
    data_collator=collate_fn,
    args=args,
)
trainer.train()
Training took about nine and a half hours on a single A100 80G. For the complete code, add the editor on WeChat via the official account 《小窗幽记机器学习》.
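Note that with PEFT, the checkpoints only store the LoRA adapter weights. If you want a standalone model for deployment, you can merge the adapter into the base weights with PEFT's merge_and_unload. A minimal sketch, assuming the checkpoint path produced by the run above:

# Merge the LoRA adapter into the base model and save a standalone checkpoint.
from peft import PeftModel
from transformers import PaliGemmaForConditionalGeneration

base = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
merged = PeftModel.from_pretrained(base, "./paligemma_vqav2/checkpoint-3615")
merged = merged.merge_and_unload()  # folds the LoRA deltas into the base weights
merged.save_pretrained("./paligemma_vqav2_merged")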
Model Inference
# Inference
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Load the fine-tuned model
sft_model_id = "./paligemma_vqav2/checkpoint-3615"
model = PaliGemmaForConditionalGeneration.from_pretrained(sft_model_id)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<image>这个建筑物叫什么名字?"
test_image = '../../dataset/test_images/东方明珠.jpeg'
raw_image = Image.open(test_image).convert("RGB")
inputs = processor(text=prompt, images=raw_image, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
The output is:
oriental pearl tower
As you can see, after fine-tuning paligemma2-3b-pt-224 on the vqav2-small dataset, the model identifies the building much more precisely.
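To quantify the gain beyond a single image, a rough exact-match spot-check on held-out samples works. This is a sketch only: eval_ds is assumed to be the test split from train_test_split (which the loading code above discarded), and model/processor are the fine-tuned ones loaded above:

# Exact-match spot-check on n held-out VQAv2 samples (eval_ds is hypothetical here).
correct, n = 0, 50
for example in eval_ds.select(range(n)):
    prompt = "<image>answer en " + example["question"]
    inputs = processor(text=prompt, images=example["image"].convert("RGB"),
                       return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=20)
    gen = output[0][inputs["input_ids"].shape[1]:]  # keep only newly generated tokens
    pred = processor.decode(gen, skip_special_tokens=True).strip()
    correct += int(pred.lower() == example["multiple_choice_answer"].lower())
print(f"exact match: {correct / n:.2%}")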
Error 1: cannot import name 'shard_checkpoint'
ImportError: cannot import name 'shard_checkpoint' from 'transformers.modeling_utils' (/your_path/opt/minicoda/lib/python3.11/site-packages/transformers/modeling_utils.py)
Solution: downgrade transformers; the version that finally worked was transformers==4.47.0.
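For example:
pip3 install transformers==4.47.0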
Error 2: cannot import name 'log'
File "/usr/local/lib/python3.10/dist-packages/deepspeed/elasticity/__init__.py", line 10, in <module>
    from .elastic_agent import DSElasticAgent
File "/usr/local/lib/python3.10/dist-packages/deepspeed/elasticity/elastic_agent.py", line 9, in <module>
    from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
ImportError: cannot import name 'log' from 'torch.distributed.elastic.agent.server.api' (/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/agent/server/api.py)
Solution: this is a compatibility issue between torch and deepspeed, so upgrade deepspeed. After the upgrade, the deepspeed version was 0.16.1, with pytorch at 2.5.1.
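That is:
pip3 install -U deepspeed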
Error 3: module 'PIL.Image' has no attribute 'ExifTags'
AttributeError: module 'PIL.Image' has no attribute 'ExifTags'. Did you mean: 'TiffTags'?
Solution: upgrade Pillow; the final version was Pillow-9.4.0.
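That is:
pip3 install -U Pillow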