备注好友:方向-学校/公司-姓名/昵称
【AIGC 先锋科技】交流群
随着生成式AI图像技术的普及和进步,对强大的归因模型的需求日益增长。这些模型对于验证图像的真实性以及识别其来源生成模型的架构至关重要,这是维护媒体完整性的关键。
然而,归因模型难以泛化到未见过的模型,而用于更新这些模型的传统微调方法在实际应用中已被证明是不切实际的。为了应对这些挑战,作者提出了可扩展的低秩自适应网络(LoRAX),这是一种参数高效的类别增量算法,能够在无需完整重新训练的情况下适应新的生成式图像模型。
llm-LoRAX_2504通过低秩自适应为每个持续学习任务训练一个参数极其高效的特征提取器。每个特定任务的特征提取器学习不同的特征,而仅需底层特征提取器 Backbone 模型参数的一小部分。
作者的大量实验表明,LoRAX在持续深度伪造检测基准测试的所有训练场景和内存设置下,均优于或与最先进的类别增量学习算法保持竞争力,同时每个特征提取器的可训练参数数量不到全秩实现的3%。
1 引言
图像生成技术的快速民主化和进步强调了归属模型验证视觉内容真实性的必要性。这些模型是2024年。本文件的版权归其作者所有。
它可以未经更改地以印刷或电子形式自由分发。
https://github.com/mit-11/lorax\_ci1vital 用于确定图像的来源、分类其真实性或识别其生成模型的架构。此外,通过预测合成图像背后的生成模型,媒体取证分析师可以有效地调查事件并揭露有组织的虚假信息活动。这项能力对于维护媒体完整性、国家安全、机构信任以及公众对媒体来源的信心至关重要。
在合成图像(深度伪造)生成/分类的猫鼠游戏中开发归因模型的主要挑战之一是新型生成模型的快速迭代发展[18]。这些频繁的进展扩展了可能的源架构集合;因此需要持续对归因模型进行重新训练或微调[17]。一种方法是完整重新训练,但由于它需要使用每个遇到的类别的整个数据集,因此计算成本高,可能因存储限制和隐私问题而不可行[6, 26]。第二种方法,迭代微调,会遭受灾难性遗忘(CF)[15],在训练新类别后,对先前学习类别的性能显著下降,除非微调过程包含缓解技术。鉴于生成图像模型的输出包含独特的模式,或称为“指纹”,可用于推理源模型的详细信息[2, 33],作者可以将深度伪造检测视为一个持续学习问题。这种新型类别随时间出现的场景式微调分类模型的问题,是被称为类别增量学习(CIL)[36]的成熟研究问题。
CIL算法旨在实现稳定性与可塑性之间的最优平衡[3]稳定性使模型能够保留先前训练阶段的知识,而可塑性则允许其在后续阶段学习额外信息。近年来,以模型为中心的CIL算法——即每轮训练扩展网络主干的算法——在持续学习基准测试中取得了最先进的性能[36]。然而,在将现有模型应用于深度伪造溯源问题时,作者发现这些模型要么存在灾难性遗忘问题,要么导致模型参数在实际应用中呈指数级增长。为解决这两个问题,作者利用低秩自适应(LoRA)[12],并将其应用于持续合成图像溯源问题。llm-LoRAX_2504——LoRA可扩展网络(LoRAX),为每个CIL任务训练参数高效的特征提取器。每个任务的特征提取器是通过将特定任务的LoRA权重更新应用于单个冻结的主干网络形成的。特定任务的特征提取器捕获每个生成模型留下的独特模式,每个任务提取器的特征最终被输入到统一的分类头进行溯源。LoRAX通过在整个训练过程中冻结底层主干模型,并在各自训练阶段结束后冻结特定任务的特征提取器,有效避免了灾难性遗忘。
在本文中,作者提出了以下贡献:
- 作者针对ConViT Backbone 网络,对动态网络CIL算法DER和MEMO进行了适配。实验结果表明,对于每个算法,ConViT Backbone 网络始终在参数数量相等或更少的情况下优于ResNet Backbone 网络,突显了 Backbone 网络选择在持续学习性能中的重要性。
- 作者介绍了LoRAX,一种新颖的参数高效增量分类学习算法。
- 作者在持续深度伪造检测(CDDB)基准数据集上完成了大量实验,以验证LoRAX方法在不同内存设置和CIL任务数据流中的有效性。LoRAX在所有测试的学习场景和内存预算下,均与其他CIL算法具有竞争力或表现更优。
2 相关工作
数据驱动的深度伪造检测方法擅长识别其训练集中包含的模型生成的图像[18, 34],但在识别由未见过的技术生成的图像时表现不佳[4]。这种静态、非鲁棒性的模型训练设置在快速发展的生成式AI领域并不实用,并为维持现实世界的分类精度带来了重大挑战。为应对这些挑战,研究行人已成功将持续学习算法应用于深度伪造检测问题领域[16, 17],突显了持续学习在增强深度伪造归因模型鲁棒性和适应性的潜力。尽管取得了这些进展,分类精度仍有提升空间。特别是,在适应新型生成技术的同时减轻遗忘问题,需要进一步的研究和改进。
2.1 类别增量学习
增量学习(Incremental Learning, CIL)算法处理持续演化的数据流,其中新类别随时间引入[36]。CIL数据流由特定任务的数据集组成,记作
。每个任务数据集
代表特定时间点的数据子集,定义为
,其中
是第
个阶段(episode)的训练实例集合,
是第
个阶段中独有类别的集合。重要的是,每个阶段的类别集合不重叠,即
,而训练结束时的完整类别集合为
。
在增量学习中,模型通过每个新任务分阶段更新以纳入新类别。初始时,模型
被训练以分类第一阶段出现的原始类别集(
)。随后的每个阶段,模型增量更新以包含新类别,演变为
。此更新过程通常在无法或仅有限访问先前阶段数据的情况下进行。为帮助模型"记住"先前学习的任务,某些增量学习方法在微调时使用先前训练数据的子集,称为示例(exemplars)。这些示例被纳入未来阶段训练数据中。增量学习的目标是持续适应单个模型以分类新遇到的类别,同时保持对先前学习类别的准确性,并最小化对历史数据的访问。
2.1.1 基准动态网络CIL算法
近期研究见证了CIL算法的快速发展[8, 15, 24, 28, 30, 32, 37]。在这些算法中,以模型为中心[36]的 Backbone 扩展方法[8, 32, 37]最近取得了最先进的结果。这些方法通过扩展 Backbone 模型来容纳学习更多类别,同时尽量减少对先前学习类别的干扰。 Backbone 扩展方法特别适用于作者的深度伪造分类专业应用,因为它们不严重依赖预训练网络,而预训练网络可能无法很好地泛化到特定应用任务[36]。通过动态扩展网络的架构, Backbone 扩展方法确保在变化的数据环境中获得更鲁棒的性能。
微调 微调方法是作者的 Baseline CIL算法,它对所有CIL任务训练单个 Backbone 模型。对于每个任务,它通过扩展最终分类头以包含最新任务的类来修改模型。在每个CIL回合中,模型权重在当前任务上重新训练,未采取任何措施来减轻遗忘。
动态扩展表示(DER)是一种早期的基于 Backbone 网络扩展的持续学习(CIL)算法[32]。DER为每个CIL任务添加一个 Backbone 网络特征提取器。所有提取的特征被连接形成一个"超级特征",然后输入到一个统一的分类器中。除了传统的交叉熵损失,DER还引入了一个辅助损失。辅助损失使用最新任务的 Backbone 网络特征提取器来训练一个独立的分类器,以区分当前任务
中的所有类别,以及一个额外的类别代表所有先前见过的类别
。辅助损失的目标是鼓励模型从现有的特征提取器集合中学习多样化的特征集。2MEMO 基于浅层神经网络层倾向于提取一般特征的观察结果,记忆高效扩展模型(MEMO)[37]算法将专用模块集成到每个CIL任务的共享基础中。专用模块在利用浅层层提取的一般特征的同时高效地集成新任务。MEMO还使用了DER的辅助损失。
DyTox动态 Token 扩展(DyTox)[8]采用了一种基于transformer的架构,专门用于CIL任务。该算法具有共享的编码器和解码器层,以及一组动态扩展的任务特定 Token 。每个任务 Token 被添加到共享编码器的输出特征之前,并输入到共享解码器中以生成任务特定的嵌入。然后使用这组任务特定嵌入进行分类。
示例的作用作者评估的每种CIL算法都将过去训练过程中的示例纳入后续任务的训练数据集中,以帮助保留已学习的信息。一种常用的示例选择策略是群集[24]。群集通过选择那些特征最接近其类别特征均值的数据来选择"最具代表性的样本"。
3LoRA可扩展网络
在扩展现有动态网络CIL算法的工作基础上,作者定义了一种使用任务特定 Backbone 网络来学习每个任务的鲁棒表示,同时最小化任务间干扰的CIL算法。通过利用参数高效的微调技术低秩适配(LoRA)[12],作者最小化了与每个任务的特征提取器相关联的附加参数数量。
3.1 参数高效微调
微调用于针对特定下游应用更新预训练网络;需要存储网络中每个参数的权重更新。受模型权重更新具有低“固有秩”[1]这一假设的启发,LoRA算法[12]被定义为:对于预训练权重矩阵
及其权重更新AW,AW可以表示为低秩分解
,其中
,
,且秩
。这种低秩表示在
时,将存储权重更新所需的参数数量大大减少至原始网络
所需的参数数量。
LoRA微调初始化网络每个指定部分的低秩矩阵
和
。在模型训练过程中,仅更新
和
矩阵,而底层 Backbone 模型保持不变。微调后的模型通过将
和
矩阵的乘积加到原始模型权重上计算得到(1)。CoLoR CIL算法[30]也利用LoRA为每个CIL任务训练一个独立的分类器。然而,它依赖于高度预训练的 Backbone 模型来为每个输入选择任务分类器。如第2节所述,对高度预训练网络的依赖通常会在特定应用中降低性能,因此作者将其排除在分析之外。
3.2 LoRAX算法
LoRAX ConViT模型结果使用了建议的配置,其中LoRA秩为
。关于选择LoRAX模型的详细信息,请参见补充材料。
动态特征扩展LoRAX CIL算法(如图1所示)为每个CIL任务训练一个特征提取器
,以捕获每个生成器留下的独特指纹。每个特征提取器网络通过将LoRA权重更新应用于预训练模型
进行训练,即
。作者通过仅存储LoRA权重更新矩阵
来限制与每个特征提取器相关的模型参数数量。为缓解灾难性遗忘,作者在整个训练过程中冻结预训练网络
,并在每个CIL周期的结束时冻结每个
。遵循DER的特征扩展模式,作者将从每个任务的特征提取器中提取的特征连接起来,生成"超级特征"
。
₁
₂
被输入到统一的分类头 CLF 中,用于模型归因。
3.3 LoRAX损失
遵循DER和MEMO,LoRAX使用一个简单的两项损失函数:交叉熵损失
和多样性损失
交叉熵损失函数(公式4b)有助于模型学习当前训练周期中遇到的全新任务,并通过包含当前任务训练数据集中的样本来减轻对先前已学习任务的遗忘。在每个CIL周期开始时,作者扩展模型的统一分类器头,以整合任务的全新类别和特征提取器。从第二个任务开始,权重将继承自上一周期的CLF分类器。
多样性损失多样性损失分类器被引入以最小化每个任务 Adapter 提取的特征之间的冗余性。该分类器仅在训练期间需要,并在每个任务训练结束时被移除。超参数
决定了多样性损失的权重(公式4c),影响 Adapter 特征多样性和分类精度之间的平衡。作者对
进行了超参数扫描,并选择
进行作者的实验。
范例 LoRaX算法将其训练过程融入了先前情节中的范例样本,以防止遗忘先前学习的类别。范例通过iCarl的[24]羊群过程进行选择。
持续深度伪造检测基准(CDDB)[16] 是一个用于在持续学习环境下评估合成图像检测/分类模型的深度伪造检测基准。该基准通过整合来自十二个知名合成图像分类数据集的图像创建而成。CDDB基准为持续学习模型评估定义了三种训练场景:简单(7个任务)、困难(5个任务)和长期(12个任务)。每个场景定义了训练持续学习模型的任务顺序。每个任务的数据库包括一组真实图像和一组由已知及未知生成模型生成的合成图像。对于已知生成模型的任务,真实图像对应于合成源的训练数据。
4.1 多真实设置
CDDB数据集中的每个任务都包含一组真实图像和合成图像;因此,在作者的CIL过程中,每个任务都会产生一个额外的真实图像类别。作者采用一种多真实分类方案,该方案不会对真实类别之间的混淆进行惩罚(例如,来自任务i的真实图像被分类为来自任务j的真实图像,其中j≠i)。作者在多真实环境下计算作者的性能指标,即一个真实图像被分类为任何真实图像类型都被视为正确。
4.2 实现细节
作者在PyTorch [19]中实现了每个CIL模型。基准模型代码基于以下开源实现:DER [35]、MEMO [35]、DyToX [8, 16]。作者的工作扩展了DER和MEMO实现,以包含使用PyTorch Image Model [29] ConViT实现的ConViT主干网络。所有测试的主干模型均使用ImageNet [7]预训练权重初始化。作者的LoRAX模型的LoRA微调组件使用HuggingFace [31]库。所有模型均在单个GPU上进行训练(微调、DER、LoRAX:NVIDIA A5000、A6000、L40或A1000 GPU;MEMO、DyTox:NVIDIA Volta V100)。作者在包含500个样本的CDDB硬场景中,使用训练数据15%的验证集,对每个CIL和主干模型组合的超参数进行调优,该验证集用于所有其他场景和内存设置。
4.3 评估指标
为了评估CIL算法在一系列任务中的分类准确度及其在先前学习任务上保持性能的能力,作者追踪三个持续学习指标:平均准确度(AA)、最终任务平均准确度(AAF)和逆向迁移性(BWT)。AA表示每个阶段平均分类准确度的均值。AAF表示最终阶段所有任务的平均分类准确度。BWT衡量学习新任务对先前学习任务性能的影响,负BWT值表示性能退化。BWT值越不负面,遗忘越少。这些指标在
矩阵
上计算,其中
是任务数量,
是任务
在训练任务
后的准确度
4.4 主干网络对CIL性能的影响
作者对ResNet [11] 和ConViT [9] 主干架构在示例微调、DER w/o P 和 MEMO CIL 算法中进行了比较分析。3 作者的实验结果,如图2所示,表明主干模型的选择会影响CIL算法的性能。在每个测试的动态网络CIL算法中,基于ConViT的实现始终在平均精度、最终平均精度和反向迁移方面优于参数数量相等或更少的基于ResNet的实现。此外,在DER w/o P 和 MEMO 算法中,ConViT Small模型(2730万参数)在所有三个指标上均优于更大的ResNet152模型(5810万参数),这突显了其在尺寸较小的情况下仍具有的高效性。
4.5 CDDB CIL算法评估
LoRAX在保持高分类精度的同时,每个CIL回合所需的训练参数数量相对较少。如图3所示,在每种CDDB场景中,LoRAX在所有测试的CIL算法中所需的训练参数数量最少。对于内存
的设置,LoRAX ConViT Base模型表现优异,在BWT、AA和AAF方面始终排名顶尖算法,同时与其他方法相比,所需的训练参数数量仅占一小部分。4
与联合训练设置对比,作者为每个 Backbone 网络训练了一个Oracle模型。Oracle模型作为分类模型性能的上限,其中所有情节的训练数据同时可用。在内存
的情况下,每个CDDB场景中表现最佳的CIL算法几乎与Oracle模型相当,在AA和AAF指标上分别仅落后1.5%。这一稳健的性能表明,即使重演缓冲区有限,CIL也是深度伪造分类的有效方法。
LoRAX通过示例改进微调性能 如表1所示,将LoRA Adapter 应用于微调的ConViT Backbone 网络,在所有场景中均将准确率提高了1-7% AA和1-10% AAF,并且这一改进在所有测试的内存设置下均成立(参见补充材料)。通过连接每个任务特定特征提取器的特征维度,LoRAX相对于使用示例微调的 Baseline CIL框架,减少了任务间的遗忘现象。
5 结论
在本文中,作者提出了LoRAX,一种利用LoRA进行训练的新型类别增量学习算法,该算法为每个类别增量学习任务训练一个参数极其高效的特征提取器。作者的任务特定特征提取器使得LoRAX训练的模型能够识别每个任务特有的伪影,同时最小化类间学习干扰。
此外,通过在每个相应的CIL(类别增量学习)阶段后冻结每个特征提取器,LoRAX减少了灾难性遗忘。与底层的ConViT Base Backbone 模型相比,作者提出的LoRAX模型显著降低了每个任务特定特征提取器的内存。作者在持续深度伪造检测数据集上评估了LoRAX方法,并展示了它在一系列当代动态网络CIL算法中取得了具有竞争力的性能。
参考
[1]. joRAX: LoRA eXpandable Networks for Continual Synthetic Image Attribution