点击下方卡片,关注 「AI视界引擎」 公众号
扩散模型在大规模预训练方面已成功在视觉内容生成领域取得了显著成就,尤其是以扩散 Transformer (DiT)为代表。然而,DiT模型在可扩展性和二次复杂度效率方面遇到了挑战。在本文中,作者旨在利用门控线性注意力(GLA) Transformer 在长序列建模方面的能力,将其适用性扩展到扩散模型。
作者引入了扩散门控线性注意力 Transformer (DiG),这是一个简单、易于采用的解决方案,具有最小的参数开销,遵循DiT的设计,但提供了更优的效率和有效性。除了性能优于DiT之外,DiG-S/2在训练速度上比DiT-S/2快了,在的分辨率下节省了的GPU内存。
此外,作者分析了DiG在各种计算复杂度下的可扩展性。增加深度/宽度或增强输入 Token 的DiG模型,其FID值持续降低。作者进一步将DiG与其他次二次时间扩散模型进行了比较。
在相同模型大小下,DiG-XL/2在分辨率下比最近的基于Mamba的扩散模型快,在分辨率下比使用CUDA优化的FlashAttention-2的DiT快。所有这些结果都表明,DiG在最新的扩散模型中具有卓越的效率。
1 Introduction
近年来,扩散模型已成为强大的深度生成模型,因其能够生成高质量图像而闻名。它们的迅速发展催生了包括图像到图像生成、文本到图像生成、语音合成、视频生成和3D生成在内的广泛应用。与此同时,采样算法的快速发展,主要技术根据其架构 Backbone 分为两类:基于U-Net的方法[20; 53]和基于ViT的方法[14]。基于U-Net的方法继续利用卷积神经网络(CNN)架构[31; 48],其分层特征建模能力有利于视觉生成任务。
另一方面,基于ViT的方法通过融入自注意力机制[56]而非传统的采样块进行创新,使得性能简化而有效。
由于在性能上的优异可扩展性,基于ViT的方法[39]已被采纳为最先进扩散作品的主干,包括PixArt[6, 5]、Sora[3]、Stable Diffusion 3[15]等。然而,基于ViT架构中的自注意力机制与输入序列长度的平方成正比,在处理长序列生成任务时,如高分辨率图像生成、视频生成等,会变得资源密集。
最近在次平方时间方法方面的进展,即Mamba、RWKV和门控线性注意力 Transformer (GLA)[59],试图通过整合类似递归神经网络(RNN)的架构和硬件感知算法来提高长序列处理的效率。其中,GLA将数据相关门控操作和硬件高效实现融入到线性注意力 Transformer 中,显示出具有竞争力的性能但更高的吞吐量。
受到自然语言处理领域中GLA成功的启发,作者希望将这种成功从语言生成转移到视觉内容生成上,即设计一个可扩展且高效的扩散 Backbone 网络,采用先进的线性注意力方法[26, 10, 25]。然而,使用GLA进行视觉生成面临两个挑战,即单向扫描建模和缺乏局部感知。为了解决这些挑战,作者提出了扩散GLA(DiG)模型,它融合了一个轻量级的空间重新定位与增强模块(SREM),用于逐层控制扫描方向和局部感知。在每个模块的末端,SREM会通过有效的矩阵操作改变序列索引,为下一个模块的不同扫描提供支持。
扫描方向包含四种基本模式,并使序列中的每个块能够感知到沿交叉方向的其他块。此外,作者还将在SREM中融合深度卷积(DWConv)[9],以极小的参数量提供局部感知。至关重要的是,本文呈现了一个系统的消融研究,包括SREM的融合以及模型架构的全面评估。重要的是要强调,DiG遵循扩散生成中线性注意力 Transformer 的首要实践,因其在大规模图像生成任务中卓越的可扩展性和效率而著称。
与基于ViT的方法,即DiT [39]相比,在相同的超参数下,DiG在ImageNet [12]生成上表现出更优越的性能。此外,在训练速度和GPU内存方面,DiG在生成高分辨率图像上更为高效。内存和速度上的高效使DiG能够缓解长序列视觉生成任务中的资源限制问题。值得注意的是,像DiS [16]这类基于Mamba的次二次时间扩散方法,由于复杂的模块设计和无法有效利用GPU张量核心,随着模型尺寸的扩大,通常表现出较低的效率,如图2所示。
得益于DiG模块的简洁而有效的设计,DiG可以在较大的模型尺寸上保持高效率,甚至在的分辨率下,超过了最精心设计且高度优化的线性注意力方法,即FlashAttention-2 [11]。
作者的主要贡献可以总结如下:
作者提出了扩散型全局局部注意力(DiG)结构,它通过层扫描高效地结合了对全局视觉上下文的建模和局部视觉感知。
据作者所知,DiG是首次探索具有线性注意力 Transformer 的扩散型主干网络。在没有二次注意力负担的情况下,所提出的DiG在保持与DiT相似建模能力的同时,在训练速度和GPU内存成本方面表现出更高的效率。
具体来说,如图1所示,在1792×1792分辨率下,DiG比DiT快2.5倍,并节省了75.7%的GPU内存。作者在ImageNet数据集上进行了大量实验。结果表明,与DiT相比,DiG表现出可扩展的能力,并取得了更优的性能。DiG有望成为下一代在大规模长序列生成背景下扩散模型的 Backbone 网络。
2 Related Work
Linear Attention Transformer
与标准的自回归Transformer [57](该模型处理全局注意力矩阵)不同,原始的线性注意力 [26] 本质上是一个具有矩阵值隐藏状态的线性RNN。线性注意力引入了一个相似性核 及其相关的特征映射 ,即 。输出 (这里 是序列长度, 是维度)的计算可以表示如下:
其中 Query ,键 ,值 的形状为 , 是当前 Token 的索引。通过表示隐藏状态 和归一化器 ,其中 ,方程(1)可以重写为:
最近的工作将 设置为恒等变换 [35; 55] 并移除了 [43],得到了以下格式的线性注意力Transformer:
直接将线性注意力Transformer用于视觉生成,由于单向建模的原因,会导致性能不佳,因此作者提出了一个轻量级的空间重新定位与增强模块,以处理十字交叉方向的全局上下文建模以及局部信息。
Backbones in Diffusion Models
现有的扩散模型通常使用U-Net作为 Backbone 网络进行图像生成。近期,基于Vision Transformer (ViT)的 Backbone 网络因 Transformer 架构的可扩展性和其对多模态学习的自然适应性而受到广泛关注。然而,基于ViT的架构存在二次复杂度问题,限制了它们在高分辨率图像合成、视频生成等长序列生成任务中的实用性。
为了缓解这一问题,近期的研究探索了次二次时间复杂度的方法,以高效处理长序列。例如,DiS [16],DiffuSSM [58] 和 ZigMa [23] 使用状态空间模型作为扩散 Backbone ,以提高计算效率。Diffusion-RWKV [58] 在扩散模型中采用了RWKV架构进行图像生成。
作者的DiG也遵循这一研究方向,通过采用门控线性注意力 Transformer (GLA)作为扩散 Backbone ,旨在提高长序列处理的效率。作者提出的适配方法保持了GLA的基本结构和优势,同时引入了一些必要的修改,以便生成高保真的视觉数据。
3 Method
Preliminaries
门控线性注意力 Transformer (Gated Linear Attention Transformer, GLA)[59]结合了数据相关门控机制和线性注意力,实现了卓越的循环建模性能。给定一个输入 (这里 是序列长度, 是维度),GLA以下面的方式计算 Query (query)、键(key)和值(value)向量:
其中 、 和 是线性投影权重。 和 是维度数量。接下来,GLA计算门控矩阵 如下:
其中 是标记的索引, 是sigmoid函数, 是偏置项, 是温度项。如图3所示,最终的输出 按以下方式获得:
(7) (8) (9)
其中Swish是Swish[44]激活函数,是逐元素乘法操作。在后续章节中,作者用来指代输入序列的门控线性注意力计算。
扩散模型。在介绍所提出的方法之前,作者简要回顾一下关于扩散模型(DDPM)[20]的一些基本概念。DDPM将噪声作为输入,并通过迭代去噪输入来采样图像。DDPM的前向过程始于一个随机过程,初始图像逐渐被噪声破坏,最终转化为一个更简单、以噪声为主导的状态。前向加噪过程可以表示为:
其中是从时间到的噪声图像序列。然后,DDPM学习逆向过程,用学习到的和恢复原始图像:
其中是去噪器的参数,通过观察数据的对数似然的变分下界[51]进行训练。
其中是完整损失。为了进一步简化DDPM的训练过程,研究者将重新参数化为一个噪声预测网络,并最小化与真实高斯噪声之间的均方误差损失:
然而,要训练一个能够学习可变反向过程协方差 的扩散模型,作者需要优化完整的 项。在本文中,作者遵循 DiT [39] 的方法训练网络,作者使用简单的损失 来训练噪声预测网络 ,并使用完整的损失 来训练协方差预测网络 。在训练过程之后,作者遵循随机采样过程,从学习到的 和 生成图像。
Diffusion GLA
作者提出了扩散GLA(DiG),这是一种新的扩散生成架构。作者的目标是尽可能忠实于标准的GLA架构,以保留其扩展能力和高效率特性。图3展示了所提出GLA的概览。标准的GLA是为处理1-D序列的因果语言建模而设计的。为了处理图像的DDPM训练,作者遵循了先前视觉 Transformer 架构的一些最佳实践[14; 39]。DiG首先将VAE编码器[27; 47]输出的空间表示作为输入。对于的图像输入到VAE编码器,空间表示的形状是。DiG随后通过 Patch 层将空间输入转换成标记序列,其中是标记序列的长度,是空间表示通道的数量,是图像块的大小,减半将会使翻四倍。接下来,作者将线性投影到维度为的向量,并为所有投影标记添加基于频率的位置嵌入,如下所示:
其中是的第个块,是可学习的投影矩阵。至于如噪声时间步和类别标签这样的条件信息,作者分别采用多层感知机(MLP)和嵌入层作为时间步嵌入器和标签嵌入器。
其中是时间嵌入,是标签嵌入。然后作者将标记序列()发送到DiG编码器的第层,并获得输出。最后,作者对输出标记序列进行归一化,并将其输入到线性投影头以得到最终的预测噪声和预测协方差,如下所示:
其中是第个扩散GLA块,是层数,是归一化层。和与输入空间表示具有相同的形状,即。
DiG Block
原始的GLA块以循环格式处理输入序列,这仅能实现对一维序列的因果建模。在本节中,作者介绍了DiG块,它包含了一个空间重新定位与增强模块(SREM),该模块实现了轻量级空间识别并控制逐层扫描方向。DiG块如图4所示。
具体来说,作者在算法1中展示了DiG块的向前过程。遵循自适应归一化层[41]在生成对抗网络(GANs)[2; 24]和扩散模型[13; 39]中的广泛使用,作者添加并归一化输入时间步嵌入和标签嵌入以回归尺度参数、和偏移参数。接下来,作者启动带门控的线性注意力(GLA)和前馈网络(FFN),并使用回归的自适应层归一化(adaLN)参数进行调整。然后,作者将序列 Reshape 为2D,并启动轻量级的深度卷积(DWConv2d)层来感知局部空间信息。特别是,使用传统的初始化方法对于DWConv2d会导致收敛缓慢,因为卷积权重是分散的。为了解决这个问题,作者提出了身份初始化方法,只将卷积核中心设置为1,周围设置为0。最后,每两个块作者将2D Token 矩阵转置,并翻转扁平的序列以控制下一块的扫描方向。如图4的右侧部分所示,每个层只处理一个方向的扫描。
Architecture Details
作者使用了总共个DiG块,每个块的隐藏维度大小为。遵循先前的工作[39, 14, 59],作者使用了标准的 Transformer 配置,这些配置会调整、以及注意力头的数量。具体来说,作者提供了四种配置:DiG-S、DiG-B、DiG-L和DiG-XL,如表1所示。它们涵盖了从31.5M到644.6M的参数和浮点运算分配,以及从1.09 Gflops到22.53 Gflops,为衡量扩展性能和效率提供了一种方法。值得注意的是,与同等大小的 Baseline 模型(即DiTs)相比,DiG仅消耗了77.0%到78.9%的Gflops。
Efficiency Analysis
GPU包含两个重要组件,即高带宽内存(HBM)和SRAM。HBM的内存容量更大,但SRAM的带宽更宽。为了充分利用SRAM并以并行形式建模序列,作者遵循GLA将整个序列分割成许多块,这些块可以在SRAM上完成计算。作者将块大小表示为,训练复杂度因此为,当时,这小于传统注意力的复杂度。此外,DiG块中的轻量级DWConv2d和高效的矩阵操作也保证了效率,如图1和图2所示。
4 Experiment
Experimental Settings
数据集和评价指标。遵循之前的工作[39],作者在的分辨率下使用ImageNet[12]进行类条件图像生成学习。ImageNet数据集包含了1,281,167张训练图像,涵盖了1,000个不同的类别。作者将水平翻转作为数据增强手段。作者使用Frechet初始距离(FID)[37],初始分数(Inception Score)[50],sFID[37],以及精确度/召回率[30]来衡量生成性能。
实现细节。作者使用AdamW优化器,并设置恒定的学习率为。遵循之前的工作[39],在训练过程中,作者使用DiG权重的指数移动平均(EMA),衰减率为0.9999。作者使用EMA模型生成所有图像。对于ImageNet的训练,作者使用现成的预训练变分自编码器(VAE)[46; 28]。
Model Analysis
空间重新定位与增强模块的影响。如表2所示,作者分析了提出的空间重新定位与增强模块(SREM)的有效性。作者以DiT-S/2作为作者的 Baseline 方法。仅具有因果建模的简单DiG的FLOPs和参数显著较少,但由于缺乏全局上下文,FID性能也较差。作者首先在DiG中添加双向扫描,并观察到显著改进,即69.28 FID,这证明了全局上下文的重要性。在没有为DWConv2d进行身份初始化的实验中,即半对半错符号,会导致更差的FID,而带有身份初始化的DWConv2d可以大幅提高性能。
带有DWConv2d的实验证明了身份初始化和局部意识的重要性。最后一行的实验表明,完整的SREM可以带来最佳性能,同时关注局部信息和全局上下文。模型尺寸的缩放。作者在ImageNet数据集上针对四种不同模型尺寸的DiG调查其缩放能力。如图5(a)所示,随着模型尺寸从S/2缩放到XL/2,性能得到提升。这些结果表明了DiG的缩放能力,表明其作为大型基础扩散模型的潜力。
Patch 尺寸的影响。作者在ImageNet数据集上用2、4和8的 Patch 尺寸训练DiG-S。如图5(b)所示,通过增加DiG的 Patch 尺寸,可以在训练过程中观察到明显的FID增强。因此,最佳性能需要较小的 Patch 尺寸和较长的序列长度。与DiT [39] Baseline 相比,DiG在处理长序列生成任务时更为高效。
Main Results
作者主要将提出的DiG与作者的 Baseline 方法DiT [39]进行了比较,两者使用了相同的超参数。在400K次训练迭代中,提出的DiG在四种模型规模上都优于DiT。此外,当使用无分类器指导时,DiG-XL/2-1200K也展现出与先前最先进方法相比具有竞争力的结果。
Case Study
5 结论
在这项工作中,作者提出了DiG,它是扩散模型在图像生成任务中传统Transformer的有效且经济的替代品。特别是,DiG探索了门控线性注意力Transformer(GLA),在长序列图像生成任务中获得了卓越的效率和有效性。
实验上,DiG在与先前扩散模型在类条件ImageNet基准上的表现相当的同时,显著降低了计算负担。作者希望这项工作能够为其他长序列生成任务开辟可能性,例如视频和音频建模。
参考
[1].DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention
点击上方卡片,关注 「AI视界引擎」 公众号