Inf-DiT: 低显存占用,超高分辨率图像生成

智能应用RTC机器学习

picture.image

扩散模型在图像生成领域展现了卓越的能力。但是,由于在生成超高分辨率图像(例如40964096)时内存需求呈二次方增长,通常生成的图像分辨率会被限制在10241024以内。

为了解决这个问题,我们提出了一种 单向块注意力机制 ,该机制能够在推理过程中自适应地调整显存使用,并处理全局依赖关系。

通过采用单向块注意力机制,我们显著降低了DiT模型在推理时的显存占用。这使得对任意大小的图像进行上采样成为可能。例如,我们可以在30G显存的条件下支持8192分辨率的图像生成。

此外,该机制还能够根据显存的限制自适应地进行并行生成,进一步优化性能。

目前,基于该机制在超高分辨率生成效果中可以 达到 SOTA 水平

picture.image

项目地址:https://github.com/THUDM/Inf-DiT

论文地址:https://arxiv.org/abs/2405.04312

单向块注意力机制

在我们使用Diffusion模型进行超高分辨率图像生成时,我们注意到模型内部的隐藏状态(hidden state)会消耗大量的显存资源。例如,一个204820481280的隐藏状态将占用大约20GB的显存。尽管对注意力(attention)和卷积神经网络(CNN)算子进行显存优化可以降低一部分显存使用,但这并不能减少隐藏状态所占用的显存。

另一种方法是将图像分块独立生成,并通过一些统计量(如均值和方差)来聚合不同块的信息。然而,这种简单的聚合方法无法统一不同块之间的高阶语义信息,例如纹理的形状等。

针对这一问题,我们提出了一种单向块注意力机制(UniBA)。在这种机制下,每个块只与其自身以及其左上角的三个块进行局部的注意力操作。这种从左上角到右下角的单向依赖关系允许模型不必一次性生成整张图像,从而在推理过程中将隐藏状态的显存占用从O(N^2)降低到O(N)。同时,由于在各层隐藏状态上都进行了交互,这种方法有效地聚合了不同块之间的语义信息。

与自回归模型不同,Inf-DiT能够同时生成多个块。因此,它可以根据显存的限制自动调整每次生成的块数,从而实现加速生成。

picture.image

如上图所示,展示的是单向块注意力(UniBA)机制。在每一层中,每个块都直接依赖于三个一阶相邻块:它上方的块、左侧的块以及左上角的块。在扩散变换器(DiT)架构(这是Inf-DiT的基础架构),块间的依赖关系是通过注意力操作实现的,具体来说,每个块的查询向量会与其左上角以及自身的三个块的关键、值向量进行交互。

模型结构

结合单向块注意力机制和之前提出的DiT(Diffusion Transformer)架构,作者设计出了Inf-DiT上采样模型:

picture.image

图:(左)Inf-DiT的整体架构。(右)Inf-DiT区块内部结构。

该模型采用了与DiT相似的骨干网络,即将视觉变换器(ViT)应用于扩散模型,并验证了其有效性和可扩展性。

除了其出色的性能外,DiT与基于卷积的架构(如UNet)相比,仅通过注意力机制实现斑块间的交互,这为实施单向块注意力提供了便利。

为了保持与原图的局部和全局一致性,低分辨率图片会以多种方式输入模型:

局部一致性: 低分辨率图片在简单的resize后会与带噪图片concat作为DiT的输入,位置一一映射能提供良好的inductive bias。但单向块注意力会导致每个块无法看到低分辨率图片的右下角部分,对此文章提出了nearby LR cross attention来对低分辨率图片的局部做attention。

全局一致性: 为了保证和低分辨率图片的全局语义一致性(艺术风格、物体材质等),作者用CLIP的image encoder获取了低分辨率图片的embedding并与DiT的time embedding相加。同时因为CLIP可以将图文对齐到同一空间中,文章发现还可以用文本来对生成结果进行控制,即使模型没有在任何文本上进行训练:

picture.image

其中C_pos和C_neg分别代表正向和反向的提示,I_LR是输入模型的图像嵌入。例如,可以将C_pos设为"Clear",C_neg设为"Blur"。

Inf-DiT使用了两种位置编码:RoPE和块级别的相对可学习位置编码。为了解决训练和生成时分辨率不匹配的问题,作者预先创建了一个大型的位置编码表,在训练过程中随机选择图像左上角在表中的坐标,以确保每个位置编码都得到训练。

同时,由于注意力机制的输入序列可能非常长,模型在训练中还采用了BF16精度和QK-Layernorm等方法来稳定训练过程。

模型评测

本文从多个角度验证了模型的生成能力,包括超高分辨率生成评测、超分辨率生成评测以及人工评测。

超高分辨率生成评测: 在超高分辨率图片生成方面,本文选取了HPSv2数据集中的1,000个提示(prompt),生成了对应的2048和4096分辨率的图片,并进行比较。

picture.image

表:超高分辨率生成的评测结果。其中FID_crop是指在每张图中随机截取299299的区域计算FID,更能体现高分辨率图片细节的真实度。*

超分辨率生成评测: 对于超分辨率任务,本文使用了DIV2K valid数据集,这是一个包含多种真实场景的摄影数据集。

picture.image

表:超分辨率任务的评测结果。

人工评测: 人工评测环节中,本文请志愿者根据HPSv2提示生成的图片,在细节保真度、全局一致性和原图一致性(针对超分辨率)三个方面对模型进行排序。Inf-DiT在这三个方面均取得了最佳成绩。

picture.image

图:人工评测结果。

图像生成效果。由于Inf-DiT能够接受各种分辨率的图像作为输入,因此可以用于对低分辨率图像进行迭代式上采样。本文测试了从3232分辨率上采样到2048 2048分辨率的过程,可以看出模型能够在不同分辨率下生成不同频率的细节,如脸型、眼球、眉毛等。

picture.image

图:4096分辨率上的生成效果

因为Inf-DiT可以接受各种分辨率的图像作为输入,所以可以拿来对低分辨率图像做迭代式上采样,文中测试了从3232分辨率上采样到20482048分辨率的过程,可以看到模型可以在不同的分辨率下生成不同频率的细节:脸型、眼球、眉毛….

picture.image

图:迭代式上采样。

:Inf-DiT能够对自身生成的图像进行多次上采样,并在相应的分辨率下生成不同频率的细节。 下:如果在1282分辨率下未能生成瞳孔,后续的上采样阶段将难以纠正这一错误。


picture.image

  • GLM-4交流群 -

阅读原文,用智能体读论文!

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论