I previously pretrained a multimodal large model, Reyes, from scratch; see the earlier post 《【多模态&LLM】Reyes：一个从0到1开始训练的多模态大模型（技术报告）》 for details. This post adds a batch inference mode to Reyes to improve its inference throughput.
Reyes-8B open-source links:
- modelscope权重下载地址:https://modelscope.cn/models/yujunhuinlp/Reyes-8B
- github:https://github.com/yujunhuics/Reyes
Usage
Replace the `modeling_reyes.py` downloaded from ModelScope with the `modeling_reyes.py` from this repository, then run as usual. For the full batch-inference walkthrough, see `batch_inference.ipynb` on GitHub.
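For reference, the replacement step can be scripted; a minimal sketch, assuming a recent `modelscope` that exposes `snapshot_download` at the top level, and with `Reyes/modeling_reyes.py` as an illustrative path to your local clone of the repo:

```python
import shutil
from modelscope import snapshot_download

# Download (or locate the locally cached) Reyes-8B weights from ModelScope.
model_dir = snapshot_download('yujunhuinlp/Reyes-8B')

# Overwrite the stock modeling_reyes.py with the patched version from this repo
# ('Reyes/modeling_reyes.py' is an illustrative local path to your clone).
shutil.copy('Reyes/modeling_reyes.py', f'{model_dir}/modeling_reyes.py')
print(f'Patched modeling_reyes.py in {model_dir}')
```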
Additions to `modeling_reyes.py`:
```python
def chat_batch(
        self,
        tokenizer,
        pixel_values_list,
        questions,
        generation_config,
        histories=None,
        return_histories=False,
        num_patches_lists=None,
        IMG_START_TOKEN='<|vision_start|>',
        IMG_END_TOKEN='<|vision_end|>',
        IMG_CONTEXT_TOKEN='<|vision_pad|>',
        verbose=False,
        visual_features_list=None,
):
    if histories is None:
        histories = [[] for _ in questions]

    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_context_token_id = img_context_token_id

    # Get eos_token_id from the template
    template = get_conv_template(self.template)
    template.system_message = self.system_message
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
    generation_config['eos_token_id'] = eos_token_id

    queries = []
    input_ids_list = []
    attention_mask_list = []

    for idx in range(len(questions)):
        question = questions[idx]
        history = histories[idx]
        pixel_values = pixel_values_list[idx] if pixel_values_list[idx] is not None else None
        num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []

        if not history and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        # Build the per-sample conversation prompt from the chat template
        template_i = get_conv_template(self.template)
        template_i.system_message = self.system_message
        for (old_question, old_answer) in history:
            template_i.append_message(template_i.roles[0], old_question)
            template_i.append_message(template_i.roles[1], old_answer)
        template_i.append_message(template_i.roles[0], question)
        template_i.append_message(template_i.roles[1], None)
        query = template_i.get_prompt()

        # Handle image tokens: expand '<image>' into tile markers plus context tokens
        if pixel_values is not None:
            for num_patches in num_patches_list:
                tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
                image_tokens = ''
                for tile_pos_identifier in tile_pos_identifiers:
                    image_tokens += tile_pos_identifier + IMG_CONTEXT_TOKEN * self.num_image_token
                image_tokens = IMG_START_TOKEN + image_tokens + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)

        queries.append(query)  # keep the raw query for optional verbose printing

        model_inputs = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)

    # Call the batched generate function
    generation_output = self.generate_batch(
        pixel_values_list=pixel_values_list,
        input_ids_list=input_ids_list,
        attention_mask_list=attention_mask_list,
        **generation_config
    )

    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)

    outputs = []
    for idx, response in enumerate(responses):
        response = response.split(template.sep)[0].strip()
        histories[idx].append((questions[idx], response))
        outputs.append(response)

    if return_histories:
        return outputs, histories
    else:
        if verbose:
            for idx, query in enumerate(queries):
                query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
                query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
                print(query_to_print, outputs[idx])
        return outputs
```
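To make the image-token expansion in `chat_batch` concrete, here is a standalone illustration of what the `<image>` placeholder turns into for one dynamically tiled image. The value `num_image_token = 256` is an assumption for this sketch only; the real value comes from the model.

```python
# Illustration only: mirrors the expansion performed inside chat_batch.
IMG_START_TOKEN = '<|vision_start|>'
IMG_END_TOKEN = '<|vision_end|>'
IMG_CONTEXT_TOKEN = '<|vision_pad|>'
num_image_token = 256   # assumed per-tile token count for this sketch
num_patches = 4         # e.g. 3 local tiles + 1 global thumbnail from dynamic_preprocess

tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
image_tokens = ''.join(t + IMG_CONTEXT_TOKEN * num_image_token for t in tile_pos_identifiers)
image_tokens = IMG_START_TOKEN + image_tokens + IMG_END_TOKEN

prompt = "<image>\nDescribe this image.".replace('<image>', image_tokens, 1)
print(prompt[:80], '...')  # <|vision_start|><tile_1><|vision_pad|><|vision_pad|>...
```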
```python
@torch.no_grad()
def generate_batch(
        self,
        pixel_values_list: Optional[List[torch.FloatTensor]] = None,
        input_ids_list: Optional[List[torch.FloatTensor]] = None,
        attention_mask_list: Optional[List[torch.LongTensor]] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
) -> torch.LongTensor:
    input_embeds_list = []
    attention_mask_padded_list = []
    max_seq_length = max(input_ids.shape[1] for input_ids in input_ids_list)

    for pixel_values, input_ids, attention_mask in zip(pixel_values_list, input_ids_list, attention_mask_list):
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features.cuda()
                vit_embeds = self.mlp1(vit_embeds)
            else:
                vit_embeds = self.extract_feature(pixel_values)

            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0, "No valid image context token IDs found."
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        seq_length = input_embeds.shape[1]
        if seq_length < max_seq_length:
            pad_size = max_seq_length - seq_length
            input_embeds = F.pad(input_embeds, (0, 0, 0, pad_size))
            attention_mask = F.pad(attention_mask, (0, pad_size))

        input_embeds_list.append(input_embeds)
        attention_mask_padded_list.append(attention_mask)

    input_embeds = torch.cat(input_embeds_list, dim=0)
    attention_mask = torch.cat(attention_mask_padded_list, dim=0)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=True,
        **generate_kwargs,
    )

    return outputs
```
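The padding step in `generate_batch` right-pads each sample's embeddings and attention mask to the longest sequence in the batch before concatenating along the batch dimension. A tiny standalone illustration of the `F.pad` calls used there:

```python
import torch
import torch.nn.functional as F

# Toy shapes: one sample with sequence length 3, batch max length 5, hidden size 4.
input_embeds = torch.randn(1, 3, 4)
attention_mask = torch.ones(1, 3, dtype=torch.long)
pad_size = 5 - input_embeds.shape[1]

# (0, 0) leaves the hidden dim untouched; (0, pad_size) right-pads the sequence dim.
input_embeds = F.pad(input_embeds, (0, 0, 0, pad_size))
attention_mask = F.pad(attention_mask, (0, pad_size))

print(input_embeds.shape)  # torch.Size([1, 5, 4])
print(attention_mask)      # tensor([[1, 1, 1, 0, 0]])
```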
Batch inference:
```python
import torch
from modelscope import AutoTokenizer, AutoModel
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


def preprocess_image(file_path, dynamic=True, max_num=6, image_size=448):
    try:
        if dynamic:
            return load_image(file_path, max_num=max_num).to(torch.bfloat16).cuda()
        else:
            img = Image.open(file_path).convert('RGB')
            transform = build_transform(image_size)
            pixel_values = transform(img)
            return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
    except Exception as e:
        raise RuntimeError(f"Error processing image: {e}")


path = "Reyes-8B"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().cuda()
# print(model)

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=2048, do_sample=False)

questions = [
    "<image>\nDescribe this image.",
    "<image>\nDescribe this image.",
    "<image>\nDescribe this image.",
]

images_path = [
    "t6.png",
    "t6.png",
    "t6.png"
]


def conversation(model, tokenizer, questions, images_path, generation_config, histories):
    pixel_values_list = []
    for i in range(len(questions)):
        if images_path[i] is not None:
            pixel_values = preprocess_image(images_path[i], dynamic=True)
            pixel_values_list.append(pixel_values)
        else:
            # keep positions aligned with `questions` for text-only samples
            pixel_values_list.append(None)
    return model.chat_batch(tokenizer, pixel_values_list, questions, generation_config, histories,
                            return_histories=False)


responses = conversation(model, tokenizer, questions, images_path, generation_config, histories=None)

for question, response in zip(questions, responses):
    print(f"User: {question}\nAssistant: {response}\n")
```
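To quantify the speedup on your own hardware, a rough timing sketch follows. It assumes the single-sample `model.chat(tokenizer, pixel_values, question, generation_config)` interface described in the original Reyes report, so adjust that call if your copy differs:

```python
import time

# Batched path
start = time.perf_counter()
_ = conversation(model, tokenizer, questions, images_path, generation_config, histories=None)
print(f"chat_batch: {time.perf_counter() - start:.2f}s for {len(questions)} samples")

# Sequential path (assumed single-sample interface; see the original report)
start = time.perf_counter()
for question, image_path in zip(questions, images_path):
    pixel_values = preprocess_image(image_path, dynamic=True)
    _ = model.chat(tokenizer, pixel_values, question, generation_config)
print(f"sequential chat: {time.perf_counter() - start:.2f}s for {len(questions)} samples")
```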