字节提出DDT | 解耦扩散Transformer打破DiT瓶颈,训练提速4倍+推理加速新高度

大模型向量数据库机器学习

点击下方卡片,关注「集智书童」公众号

点击加入👉「集智书童」交流群

picture.image

picture.image

picture.image

picture.image

picture.image

picture.image

精简阅读版本

本文主要解决了什么问题

优化困境 :传统扩散Transformer在生成高质量图像时面临低频语义编码和高频细节解码之间的矛盾,导致训练效率低下。

训练效率低 :扩散Transformer需要更长的训练迭代次数和更多的推理步骤才能达到理想性能。

模型扩展瓶颈 :随着模型规模的增加,如何有效分配计算资源以提升性能是一个挑战。

本文的核心创新是什么

解耦扩散Transformer(DDT)架构 :将低频语义编码和高频细节解码分离为专用组件,通过条件编码器提取语义特征,速度解码器专注于高频分量的恢复。

统计动态规划方法 :提出了一种新的算法来寻找最优的自条件共享策略,从而显著提升推理速度并最小化性能下降。

改进的训练和采样策略 :通过增强表示对齐技术和局部一致性约束,进一步优化了模型的训练过程和推理效率。

结果相较于以前的方法有哪些提升

**ImageNet

数据集** :DDT-XL/2 在仅 256 个 epoch 内达到了 1.31 FID 的新最优性能,相比之前的扩散Transformer训练收敛速度提高了近 4 倍。

**ImageNet

数据集** :DDT-XL/2 达到了 1.28 FID 的新最优性能,显著优于所有先前方法。

推理速度提升 :通过相邻去噪步骤之间的自条件共享机制,显著提升了推理速度,同时保持了高质量的生成效果。

局限性总结

复杂性增加 :引入解耦架构和动态规划方法增加了模型设计和实现的复杂性。

依赖外部技术 :部分性能提升依赖于其他领域的先进技术(如 SwiGLU、RoPE 和 RMSNorm),可能限制其独立应用的价值。

潜在的微调不足 :在某些高分辨率数据集上(如 ImageNet

),可能存在微小的性能退化,可能是由于微调不足或训练时间有限所致。

深入阅读版本

扩散Transformer在生成质量上表现出色,但需要更长的训练迭代次数和更多的推理步骤。在每个去噪步骤中,扩散Transformer将含噪输入编码以提取低频语义分量,然后使用相同的模块解码高频分量。这种方案产生了一个固有的优化困境:编码低频语义需要减少高频分量,从而在语义编码和高频解码之间形成矛盾。为解决这一挑战,作者提出了一种新的解耦扩散Transformer(DDT),其设计包含专用的条件编码器用于语义提取以及专门的速率解码器:

实验表明,随着模型规模的增加,更强大的编码器能带来性能提升。在ImageNet

上,DDT-XL/2达到了1.31 FID的新最优性能(相比之前的扩散Transformer训练收敛速度提高了近4倍)。在ImageNet

上,DDTXL/2达到了1.28的新最优FID。此外,作为有益的副产品,作者的解耦架构通过实现相邻去噪步骤之间的自条件共享,提升了推理速度。为最小化性能退化,作者提出了一种新的统计动态规划方法来识别最优共享策略。

  1. 引言

图像生成是计算机视觉研究中的一个基本任务,旨在捕捉原始图像数据集的内在数据分布,并通过分布采样生成高质量的合成图像。扩散模型[19, 21, 29, 30, 41]最近已成为学习图像生成中潜在数据分布的一种极具前景的解决方案,其性能优于基于GAN的模型[3, 40]和自回归模型[5, 43, 51]。

扩散前向过程按照SDE前向调度逐步向原始数据添加高斯噪声[19, 21, 41]。去噪过程从这种污染过程中学习分数估计。一旦分数函数被准确学习,可以通过数值求解反向SDE合成数据样本[21, 29, 30, 41]。

扩散Transformer[32, 36]将Transformer架构引入扩散模型,以取代传统上占主导地位的基于UNet的模型[2, 10]。实验证据表明,在足够多的训练迭代下,扩散Transformer即使不依赖长残差连接也能超越传统方法[36]。然而,由于成本高昂,其缓慢的收敛速度仍然为开发新模型带来了巨大挑战。

在本论文中,作者希望从模型设计角度解决上述主要缺点。经典计算机视觉算法[4, 17, 23]策略性地采用编码器-解码器架构,优先使用大型编码器进行丰富特征提取,并使用轻量级解码器进行高效推理,而当代扩散模型主要依赖传统的仅解码器结构。作者系统地研究了扩散 Transformer 中解耦编码器-解码器设计的未充分挖掘的潜力,通过回答解耦编码器-解码器 Transformer 能否解锁加速收敛和提升样本质量的能力这一问题。

通过实验研究,作者得出结论:普通扩散 Transformer 在抽象结构信息提取和详细外观信息恢复之间存在优化困境。此外,由于原始像素监督[28, 52, 53],扩散 Transformer 在提取语义表示方面存在局限。为解决这一问题,作者提出了一种新架构,通过定制化的编码器-解码器设计明确解耦低频语义编码和高频细节解码。作者将这种编码器-解码器扩散 Transformer 模型称为DDT(解耦扩散 Transformer )。DDT包含一个条件编码器用于提取语义自条件特征。提取的自条件与噪声潜空间一起输入到速度解码器中,以回归速度场。为保持相邻步骤自条件特征的局部一致性,作者采用表示对齐的直接监督和来自解码器速度回归损失的间接监督。

在ImageNet 256×256数据集上,使用传统的现成VAE[38],作者的解耦扩散Transformer(DDT-XL/2)模型在仅256个epoch内就实现了1.31 FID的间隔引导下的最先进性能,相比REPA[52]的训练速度提高了约4倍。在ImageNet 512×512数据集上,DDTXL/2模型在100K步内实现了1.90 FID。

此外,DDT在自条件特征上实现了强大的局部一致性,这可以显著通过在相邻步骤间共享自条件来提升推理速度。作者将最优编码器共享策略表述为一个经典的最小和路径问题,通过最小化在相邻步骤间共享自条件时的性能下降来求解。作者提出了一种统计动态规划方法,以忽略不计的二级时间成本找到最优编码器共享策略。与简单的均匀共享相比,作者的动态规划实现了最小的FID下降。

作者的贡献总结如下。

  • • 作者提出了一种新的解耦扩散Transformer模型,该模型由一个条件编码器和一个速度解码器组成。
  • • 作者提出统计动态规划方法,以寻找最优的自条件共享策略,从而在保持性能最小下降的同时提升推理速度。
  • • 在ImageNet 256×256数据集中,使用传统SDf8d4 VAE,作者的解耦扩散Transformer(DDT-XL/2)模型在仅256个epoch下即实现了SoTA 1.31 FID,并采用区间引导,相较于REPA [52]的训练速度提升了约4倍。
  • • 在ImageNet

数据集上,DDT-XL/2模型达到了SOTA 1.28 FID,显著优于所有先前方法。

  1. 相关工作

扩散Transformer。DiT [36]的开创性工作将Transformer引入扩散模型,用以取代传统上占主导地位的UNet架构 [2, 10]。实验证据表明,在足够的训练迭代下,扩散Transformer即使不依赖长残差连接也能超越传统方法。SiT [32]进一步通过线性流扩散验证了Transformer架构。基于扩散Transformer的简洁性和可扩展性 [32, 36],SD3 [12]、Lumina [54]和PixArt [6, 7]将扩散Transformer应用于更High-Level的文本到图像领域。此外,最近扩散Transformer凭借确凿的视觉和运动质量在文本到视频领域占据主导地位 [1, 20, 24]。作者的解耦扩散Transformer(DDT)是扩散Transformer家族中的一个新变体。它通过解耦低频编码和高频解码实现更快的收敛。

快速扩散训练。为了提升扩散 Transformer 的训练效率,近期的研究进展致力于多方面的优化。以算子为中心的方法[13, 45, 48, 49]利用高效的注意力机制:线性注意力变体[13, 45, 49]将二次复杂度降低以加速训练,而Sparse注意力架构[48]优先考虑Sparse相关的token交互。重采样方法[12, 16]提出了对数正态采样[12]或损失重加权[16]技术来稳定训练动态。表征学习增强方法整合了外部归纳偏差:REPA[52]、RCG[27]和DoD[53]将特定视觉先验引入扩散训练,而 Mask 建模技术[14, 15]通过在去噪过程中强制执行结构化特征补全来强化空间推理。总体而言,这些策略解决了计算、采样和表征的 Bottleneck 问题。

  1. 初步分析

基于线性流的匹配方法[29, 30, 32]代表了一种特殊的扩散模型家族,作者将其作为主要分析目标,因其具有简洁性和高效性。为便于讨论,在某些情况下,扩散和流匹配将交替使用。在该框架中,

对应于纯噪声时间步。

如图3所示,扩散模型对频谱成分进行自回归细化[11, 37]。扩散Transformer在解码高频细节前,先编码噪声潜在信息以捕获低频语义。然而,这种语义编码过程不可避免地会削弱高频信息,从而产生优化困境。这一观察结果促使作者提出将传统的仅解码扩散Transformer解耦为显式的编码器-解码器架构。

picture.image

引理1 对于在时间步

的线性流匹配噪声调度器,作者用

表示干净数据

的最大频率。在含噪潜在空间中保留的最大频率满足:

引理1直接借鉴自[11, 37],作者将引理1的证明放在附录中。根据引理1,随着

增加到更少噪声的时间步长,语义编码变得更容易(由于噪声减少),而解码复杂度增加(因为残差频率增长)。考虑在去噪步骤

的最坏情况场景,扩散Transformer编码频率高达

,要进入步骤

,它必须解码至少为

的残差频率。在步骤

未能解码这些残差频率会为进入后续步骤造成关键 Bottleneck 。

从这个角度来看,如果将更多计算分配给更噪声的时间步长能带来改进,这意味着扩散Transformer在编码低频以提供语义方面存在困难。否则,如果将更多计算分配给更少噪声的时间步长能带来改进,这意味着流匹配Transformer在解码高频以提供精细细节方面存在困难。

为了找出当前扩散模型的 Bottleneck ,作者使用SiT-XL/2和二阶Adams类线性多步求解器进行了针对性实验。如图4所示,通过改变时间偏移值,作者证明相比于均匀调度,将更多计算分配给早期时间步能够提升最终性能。这揭示扩散模型在低噪声步骤中面临挑战。由此得出一个关键结论:当前的扩散Transformer本质上受限于其低频语义编码能力。这一见解推动了探索具有策略性编码器参数分配的编码器-解码器架构。

picture.image

已有研究进一步支持这一观点。虽然轻量级扩散MLP Head 表现出有限的解码能力,但MAR [28]通过其 Mask Backbone 网络产生的语义隐变量克服了这一限制,实现了高质量图像生成。类似地,REPA [52]通过与前训练视觉基础 [35]的校准来增强低频编码。

  1. 方法

作者的解耦扩散Transformer架构由一个条件编码器和一个速度解码器组成。条件编码器从含噪声输入、类别标签和时间步中提取低频分量,作为速度解码器的自条件;速度解码器利用自条件处理含噪声潜空间,以回归高频速度。作者使用已建立的线性流扩散框架训练该模型。为简洁起见,作者将模型命名为DDT(解耦扩散Transformer)。

4.1. 条件编码器

条件编码器镜像了DiT/SiT的架构设计和输入结构,并进行了微设计改进。它由交替的注意力模块和 FFN 模块构成,不包含长距离残差连接。编码器处理三个输入,即含噪声的潜在变量

、时间步

以及类别标签

,通过一系列堆叠的注意力模块和 FFN 模块提取自条件特征

具体而言,噪声潜在变量

被分块处理为连续的token,然后输入到上述编码器块中以提取自条件

。时间步

和类别标签

作为外部条件信息被映射到嵌入空间中。这些外部条件嵌入通过AdaLN-Zero[36]在每个编码器块中逐步注入到

的编码特征中。

为了保持相邻时间步长中

的局部一致性,作者采用 REPA [52] 中的表示对齐技术。如公式 (3) 所示,该方法将自映射编码器中第

层的中间特征

与 DINOv2 表示

对齐。与 REPA [52] 一致,

是可学习的投影 MLP:

这种简单的正则化方法加速了训练收敛,如REPA [52]所示,并促进了相邻步骤之间

的局部一致性。它允许在相邻步骤之间共享由编码器产生的自条件

。作者的实验表明,这种编码器共享策略显著提高了推理效率,且性能退化可以忽略不计。

此外,编码器还从解码器接收间接监督,作者将在后面详细阐述。

4.2. 速度解码器

速度解码器采用与条件编码器相同的架构设计,由多个堆叠的交错注意力(Attention)和 FFN (FFN)模块组成,类似于DiT/SiT。它以含噪声潜在变量

、时间步长

和自条件

作为输入,以估计速度

。与编码器不同,作者假设类别标签信息已经嵌入在

中。因此,仅将外部条件时间步长

和自条件特征

作为解码器模块的条件输入:

如前所述,为进一步提高相邻步骤中自条件

的一致性,作者采用AdaLN-Zero [36] 将

注入解码器特征中。解码器使用公式(5)所示的流匹配损失进行训练:

4.3. 采样加速

通过将显式表示对齐集成到编码器中,并将隐式自条件注入到解码器中,作者在训练过程中实现了相邻步骤间

的局部一致性(如图5所示)。这使作者能够在合适的局部范围内共享

,从而减轻了自映射编码器上的计算负担。

picture.image

形式上,给定总推理步数

和编码器计算预算

,共享比率为

,作者定义

为一个包含

个时间步的集合,这些时间步用于重新计算自条件,如方程6所示。如果当前时间步

不在

中,作者将先前计算的

重新用作

。否则,作者使用编码器和当前带噪声的潜在变量

重新计算

均匀编码器共享。这种朴素方法每

步重新计算自条件

。先前工作,如 DeepCache [33],使用这种手工制作的均匀

集合来加速 UNet 模型。然而,仅使用去噪损失训练且缺乏鲁棒表示对齐的 UNet 模型,在相邻步骤的深层特征中表现出比作者的 DDT 模型更弱的局部一致性。此外,作者将提出一种简单而优雅的统计动态规划算法来构建

。与朴素方法 [33] 相比,作者的统计动态规划能够更优地利用最优

集合。

统计动态规划。作者使用余弦距离在步骤

中的不同步长

之间构建统计相似度矩阵

。最优的

集合将保证总相似度成本

达到全局最小值。这个问题是一个经典的最小和路径问题,可以通过动态规划解决。如公式 (8) 所示,当

时,作者用

表示成本,用

表示追踪路径。从

的状态转移函数如下:

在获取成本矩阵

和跟踪路径

后,可通过从

回溯

来求解最优

  1. 实验

作者在

的ImageNet数据集上进行了实验。总训练批次大小设置为256。与SiT [32]、DiT [36]和REPA [52]等方法学方法一致,作者在整个训练过程中采用了Adam优化器,并使用恒定学习率0.0001。为确保公平的比较分析,作者没有使用梯度裁剪和学习率预热技术。作者的默认训练基础设施由

A100 GPU组成。对于采样,作者默认选择具有250步的Euler求解器。对于VAE,作者采用Huggingfacel提供的现成VAE-ft-EMA,其下采样因子为8。作者报告了FID [18]、sFID [34]、IS [39]、Precision和Recall[25]。

5.1. 改进的 Baseline 模型

近年来,SwiGLU [46, 47]、RoPE [42] 和 RMSNorm [46, 47] 等架构改进在研究界得到了广泛验证 [8, 31, 50]。此外,lognorm采样 [12] 对训练收敛性表现出显著优势。因此,作者通过结合这些先进技术,并借鉴该领域的最新研究成果,开发了改进的 Baseline 模型。这些改进 Baseline 的性能在表2中全面展示。为验证作者实现的可靠性,作者还复现了REPA-B/2的结果,所获得的指标略微超过了REPA [52] 中最初报告的指标。这些复现结果进一步增强了作者对方法鲁棒性的信心。

picture.image

作者表2中的改进 Baseline 在未使用REPA的情况下始终优于其前身,然而在实施REPA后,性能迅速接近饱和点。这一点在XL模型尺寸上尤为明显,增量技术改进带来的收益微乎其微。

5.2. 与 Baseline 的指标比较

作者在表2中展示了不同尺寸模型在400K训练步骤时的性能表现。作者的扩散编码器解码器Transformer(DDT)系列在各个模型尺寸上均表现出一致且显著的提升。DDT-B/2(8En4De)模型比Improved-REPAB/2提高了2.8 FID。DDT-XL/2(22En6De)模型比REPA-XL/2提高了1.3 FID。虽然仅解码器扩散Transformer在REPA[52]中已接近性能饱和,但DDT模型仍持续取得优异结果。增量技术改进带来的收益逐渐减少,尤其是在较大模型尺寸上。然而,DDT模型保持了显著的性能优势,突显了DDT的有效性。

5.3. 系统级比较

ImageNet 256×256。作者在表1中报告了DDTXL/2(22En6De)和DDT-L/2(20En4De)的最终指标。DDT模型展现出卓越的效率,相较于REPA [52]和其他扩散Transformer模型,在总epoch数的约1/4时即可实现收敛。为了与REPA在方法论上保持一致性,作者在区间[0.3, 1]中采用了2.0的classifier-free guidance。DDT取得了令人印象深刻的结果:DDT-L/2在仅80个epoch内达到了1.64 FID,而DDT-XL/2则达到了1.52 FID。通过将训练扩展至256个epoch(这一方法仍然显著优于传统的800个epoch方法),DDT-XL/2在ImageNet 256×256上建立了新的SOTA基准,FID达到了1.31,明显优于之前的扩散Transformer方法。为了将训练扩展至400个epoch,DDT-XL/2(22En6De)达到了1.26 FID,几乎接近SD-VAE-ft-EMA-f8d4的上限,后者在ImageNet256上的1.20 rFID。

picture.image

ImageNet

作者在表3中提供了DDTXL/2的最终指标。为了验证作者DDT模型的优越性,作者将作者在ImageNet

256上经过256个epoch训练的DDT-XL/2作为初始化,然后在ImageNet

上对DDT-XL/2进行100K步的微调。作者采用了上述的间隔引导[26],并取得了1.90 FID的显著最先进性能,比REPA显著提高了0.28的性能差距。在表3中,一些指标表现出微小的退化,作者将其归因于潜在的微调不足。当为DDT-XL/2分配更多的训练迭代时,在时间间隔[0.3, 1.0]内,它以CFG3.0在500K步达到了1.28 FID。

picture.image

5.4. 通过编码器共享实现加速

如图5所示,作者的条件编码器中的自条件具有强烈的局部一致性。即使

的相似度也超过0.8。这种一致性为通过在相邻步骤之间共享编码器来加速推理提供了机会。

picture.image

作者采用了简单的统一编码共享策略和新的创新统计动态规划策略。具体而言,对于统一策略,作者每

步仅重新计算自条件

。对于统计动态规划,作者通过动态规划在相似度矩阵上求解上述最小和路径,并根据求解策略重新计算

picture.image

如图6 所示,当

小于 6 时,推理速度显著提升,几乎不损失视觉质量。如表4 所示,指标损失仍然很小,而推理速度提升显著。新的统计动态规划策略在 FID 下降方面略微优于简单的统一策略。

5.5. 消融实验

作者对ImageNet

数据集进行了消融实验,使用DDT-B/2和DDT-L/2。在采样过程中,作者默认选择使用250步的欧拉求解器,且不采用无分类器引导。在训练过程中,作者使用80个epoch(

步)训练每个模型,并将批处理大小设置为256。

picture.image

作者系统地探索了不同模型尺寸下从

的编码器-解码器比例,如图7和图8所示。作者的符号mEnnDe表示具有

个编码器层和

个解码器层的模型。图7和图8中的研究实验揭示了架构优化的关键洞察。作者观察到,随着模型尺寸的增加,更大的编码器有利于进一步提高性能。对于图7中的基础模型,最佳配置为8个编码器层和4个解码器层,实现了卓越的性能和收敛速度。

picture.image

值得注意的是,图8中的大型模型表现出明显偏好,在20个编码器层和4个解码器层时达到峰值性能,这是一个出乎意料的激进编码器-解码器比例。这一意外发现促使作者将DDT-XL/2的层比例扩展到22个编码器层和6个解码器层,以探索扩散 Transformer 的性能上限。

解码器模块类型。在作者的解码器模块类型及其对高频解码性能影响的研究中,作者系统地评估了多种架构配置。作者的综合评估包括简单

卷积模块和朴素MLP模块等替代方法。如表5所示,默认设置(带有MLP的注意力机制)取得了更好的结果。得益于编码器-解码器设计,即使是朴素的卷积模块也能取得相当的结果。

picture.image

  1. 结论

本文介绍了一种新型解耦扩散Transformer,重新思考了传统扩散Transformer的优化困境。通过将低频编码和高频解码解耦为专用组件,作者有效解决了制约扩散Transformer的优化困境。此外,作者发现随着整体模型规模的增大,相对于解码器提高编码器容量能带来越来越有益的结果。这一见解为未来的模型扩展工作提供了宝贵指导。

实验表明,DDT-XL/2(22En6De)凭借其非同寻常的编码器-解码器层比例,在仅需256个训练周期的情况下实现了优异性能。这一效率的显著提升解决了扩散模型的主要限制之一:其漫长的训练需求。解耦架构还通过作者提出的编码器结果共享机制为推理优化提供了机会。作者用于确定最佳共享策略的统计动态规划方法能够在保证质量不下降的前提下实现更快的推理,证明架构创新可以带来超越其设计目标的好处。

参考

[1]. DDT: Decoupled Diffusion Transformer

picture.image

扫码加入👉「集智书童」交流群

(备注:方向+学校/公司+昵称)

picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论