新SCMHSA架构缓解 Transformer 下一帧预测语义稀释,适配损失函数性能更优 !

向量数据库大模型云通信

点击下方卡片,关注

「AI视界引擎」

公众号

( 添加时备注:方向+学校/公司+昵称/姓名 )

picture.image

picture.image

视频中的下一帧预测对于自动驾驶、目标跟踪和运动预测等应用至关重要。下一帧预测的主要挑战在于有效地从先前的视频序列中捕获和处理空间和时间信息。

以擅长处理序列数据著称的Transformer架构,在这一领域取得了显著的进展。然而,基于Transformer的下一帧预测模型存在一些显著问题:

(a)多头自注意力(MHSA)机制需要将输入嵌入分成

个片段,其中

是头的数量。每个片段仅捕获原始嵌入信息的一小部分,这扭曲了嵌入在潜在空间中的表示,导致语义稀释问题;

(b)这些模型预测下一帧的嵌入而不是帧本身,但损失函数基于重建帧的错误,而不是预测嵌入——这造成了训练目标和模型输出之间的差异。作者提出了语义浓度多头自注意力(SCMHSA)架构,有效地缓解了基于Transformer的下一帧预测中的语义稀释问题。

此外,作者引入了一个损失函数,该函数在潜在空间中优化SCMHSA,使训练目标与模型输出更加吻合。作者的方法在性能上优于原始的基于Transformer的预测器。

  1. 引言

在人工智能(AI)迅猛发展的今天,本文旨在探讨该领域的最新研究进展及其对各行各业的影响。随着计算能力的增强和算法的优化,AI在图像识别、自然语言处理、智能决策等领域取得了显著的成果。本文将从以下几个方面对AI研究进行综述:理论基础、关键技术、应用领域以及未来发展趋势。

人类能够利用视觉信息预测短期未来事件,这一能力对于自动驾驶[27]、动作识别[7]、运动预测[7]和异常检测[10]等任务至关重要。为了使机器能够发展出类似的能力,视频帧预测(VFP)任务得到了广泛的研究,该任务涉及从最近过去的帧中预测未来的视频帧。

视频帧预测(VFP)因其涉及现实世界视频动态的复杂性和未来事件固有的不确定性而具有挑战性[32]。有效地捕捉空间和时间信息至关重要[8]。在Transformer出现之前[21],通常的方法是将序列模型(例如,LSTM、RNN、GRU)与卷积神经网络(CNN)结合,从过去的帧中提取时空特征[8, 22, 24, 26]。然而,这些模型在处理长期依赖关系时存在困难,受到梯度消失的影响,而且在处理较长的序列时计算成本高且易出错[2]、[8]、[28]。相比之下,基于Transformer并使用多头自注意力(MHSA)的方法在处理长期依赖关系方面更为有效,并允许进行高效的并行处理。

然而,MHSA要求输入嵌入被分成多个块(对应于注意力头的数量),这可能会扭曲学习到的潜在空间并稀释语义信息,降低预测精度。图1展示了在MHSA实现中输入嵌入是如何划分的。尽管基于Transformer的VFP系统引入了额外的模块来增强预测性能,如局部时空块[16]、自注意力记忆(SAM)[8]或时间MHSA[28]等,但核心的MHSA机制并未改变,使得语义稀释的问题依然存在。此外,这些VFP系统并不是直接预测下一帧;相反,它们预测该帧的嵌入。这需要解码步骤来从其嵌入中重建预测的帧。然而,用于训练这些VFP系统的损失函数是基于重建帧的错误,而不是预测嵌入,这可能导致阻碍有效模型学习的差异。例如,在VFP系统中常用的一个损失函数结构如下[6, 8, 16]:

picture.image

代表真实帧,

代表预测帧,

表示L1损失(平均绝对误差,MAE),

表示L2损失(平均平方误差,MSE),而

代表感知损失或结构损失(如图像梯度损失[13]、Charbonnier损失[4]、Census损失[14]、LPIPS损失[29]等)。显而易见,这种损失函数取决于预测帧

,而非预测嵌入。尽管这种方法仍然有效,但它并不完全符合VFP系统的输出,后者预测的是嵌入。这种不一致性给模型优化带来了挑战,如收敛速度慢、学习效果次优和梯度不匹配。如何解决Transformer基础VFP系统中语义稀释和消除学习目标中的不匹配问题呢?

为了应对这些挑战,作者为基于Transformer的VFP系统引入了一个名为语义集中多头自注意(SCMHSA)的语义保持模块。该模块允许每个注意力头的 Query 、 Key和Value 矩阵使用整个输入嵌入来计算,从而减轻了标准MHSA块中的语义稀释问题。然而,这种方法增加了每个头的维度,使得训练这些头相互支持变得更加困难。为了解决这个问题,作者提出了一种损失函数,确保每个头在使用完整输入嵌入的同时,专注于不同的语义方面,从而促进更高效的收敛。

与在像素空间(基于重建帧)操作现有方法不同,作者的损失函数设计在嵌入空间(基于预测嵌入)工作,使训练目标更接近VFP输出。这种对齐还提高了SCMHSA块中自回归机制的效率,因为自回归是在嵌入空间操作的。通过预测下一帧的嵌入而不是进行完整的像素级重建,作者的方法有效地捕捉了关键特征,特别适合于像目标跟踪中的异常检测等任务,这些任务的焦点是识别显著的偏差或不规则性,而不是需要看到下一帧的完整可见性。作者在KTH[17]、UCSD[11]、UCF Sports[20]和Penn Action[30]四个不同的数据集上评估了作者的方法。

作者的贡献可以概括如下:

  1. 作者提出了一种语义集中多头自注意力(SCMHSA)模块,该模块能够保留VFP系统中输入帧序列嵌入中的语义信息。

  2. 作者提出了一种基于嵌入空间而非像素空间的新的损失函数,这有效地与VFP输出相匹配,并确保每个头可以专注于不同的语义方面。

  3. 通过实证评估,作者证明了作者提出的方法在预测准确性方面优于现有的基于Transformer的VFP技术。

  4. 相关研究


在Transformer模型[21]引入之前,视频帧预测(VFP)系统主要依赖于序列模型,如LSTM、RNN和GRU来处理时间信息,以及图像到向量模型,如CNN和自编码器来处理空间信息。这些时间和空间模型通常结合使用,以处理历史帧序列并预测下一帧。例如,ConvLSTM[18]通过使用卷积操作代替线性操作来增强传统的LSTM,在VFP领域取得了显著的成功。TrajGRU[19]通过引入动态、可学习的空间连接来扩展GRU,从而实现了更有效的时空建模。其他重要的变体包括PredRNN[22]、PredRNN

[23]、内存中的内存(MIM)[25]和运动感知单元(MAU)[3]。然而,序列模型通常面临一些挑战,如处理长期依赖、梯度消失问题、高计算成本、慢速训练和推理,以及易于累积误差[28]。

VFP系统已经发展到更高级的技术,利用了Transformer和注意力机制,这些机制能够比之前的序列模型更好地捕捉长距离依赖关系和并行化计算。[8] 将自注意力机制整合到ConvLSTM中,以捕捉空间和时间域中的长距离依赖关系。通过整合自注意力记忆(SAM)单元,模型能够更好地权衡不同空间和时间区域的重要性,从而在做出预测时关注输入数据中更相关的部分。这导致在视频预测、天气预报或需要理解复杂时空动态的任何应用中的性能得到提升。

另一个重要的贡献是“VPTR:高效的视频预测Transformer”[28],它引入了一种新型的Transformer块,结合了局部空间注意力和时间注意力,以减少计算复杂度同时保持高预测精度。作者提出了两种模型:VPTR-FAR(完全自回归)和VPTR-NAR(非自回归),两者都展示了改进的性能和推理过程中的误差累积减少。

结合Transformer与其他神经网络架构的混合模型也被提出。一个值得注意的例子是用于视频预测的混合Transformer-LSTM模型,它利用了Transformer和LSTM的优点[12]。该模型使用Transformer的注意力机制来捕捉长距离依赖关系,以及LSTM处理时间序列的能力,在视频预测任务中实现了令人印象深刻的结果。基于Transformer的下一帧预测的一个关键挑战是,当输入被分割成多个块用于MHSA时发生的语义稀释。

此外,常用的L1和L2损失函数,它们关注最小化预测帧与实际帧之间的误差,可能会在训练目标与模型输出之间造成不匹配。这种不匹配可能会给模型的学习过程带来不必要的困难。据作者所知,这些问题尚未被探索或解决。

  1. 方法

在本研究中,作者采用了一种基于深度学习的图像识别方法。该方法首先通过卷积神经网络(CNN)对图像进行特征提取,然后利用长短期记忆网络(LSTM)对提取的特征进行序列建模,以实现对图像内容的理解和分类。具体步骤如下:

  1. 数据预处理:对原始图像进行缩放、裁剪等操作,以确保输入网络的数据具有一致性和规范性。
  2. 特征提取:利用CNN提取图像的局部特征,并通过池化操作降低特征维度。
  3. 序列建模:将提取的特征输入LSTM网络,通过LSTM的时序处理能力,对图像内容进行建模。
  4. 分类与评估:利用训练好的模型对测试集进行分类,并计算分类准确率等指标,以评估模型性能。

在模型训练过程中,作者采用了交叉熵损失函数和Adam优化器,通过多次迭代优化模型参数。此外,为了提高模型的泛化能力,作者对数据集进行了数据增强处理,包括随机翻转、旋转等操作。

3.1. 问题表述

单帧视频帧预测(VFP)的目标是利用视频序列

中的前

个帧来预测下一个帧

,其中

表示第

个帧,具有高度

、宽度

个颜色通道。该任务被建模为:

在函数

中,输入帧序列被映射到预测的下一帧

由参数

决定。预测误差通过最小化损失函数

来优化,该函数通常定义为预测帧与实际帧之间的差异。

预测帧

和真实帧

常见的

的选择包括均方误差(MSE)或平均绝对误差(MAE)。

VFP任务可以被视为解决以下优化问题:

表示使预测误差最小化的最优参数。

3.2 基于Transformer的视觉飞行路径规划系统

每一帧

都通过基于CNN的模型

(例如,ResNet)进行处理,以提取高维特征表示。

帧嵌入

属于

,维度为

是卷积神经网络(CNN)的参数。

最后

个帧的嵌入

被输入到一个Transformer模型中,该模型使用多头自注意力(MHSA)来捕捉时间依赖性。对于每个头

, Query 、 Key和Value 矩阵的计算如下:

每个头的输入为

(其中

),

代表头的数量。这种划分在多头自注意力(MHSA)中导致了语义稀释。

注意力分数是计算嵌入

之间的,对于第

个头的:

个头的输出是一个加权求和的价值矩阵:

所有 Head 的输出随后被连接起来,以生成每个嵌入

的最终输出。

输出序列

被传递至解码网络,该网络生成预测帧

通常,解码器由一系列转置卷积或上采样层组成,这些层将最终的隐藏状态映射回原始帧空间。

该模型经过训练以最小化损失函数,例如预测帧

与真实帧

之间的均方误差(MSE)。

3.3. 提出方法

3.3.1 模型架构

作者提出了一种语义集中VFP(SC-VFP)模型(如图2所示),用以解决基于Transformer的VFP系统的局限性。该架构包括:

picture.image

嵌入层:通过视觉Transformer(ViT)[5]将输入帧映射到低维空间。一个可学习的分类 Token [CLS]从整个输入图像中聚合空间信息,代表图像帧嵌入。

语义集中VFP(SC-VFP):仅使用编码器Transformer处理嵌入,用作者提出的语义集中MHSA(SCMHSA)块替换原有的MHSA块,以减轻语义稀释。多个编码器块处理序列中的时间信息。

  1. 语义集中多头自注意力(SCMHSA):通过处理每个头的完整嵌入来增强传统的MHSA,从而得到更丰富的表示。为了管理增加的头维度,一个可学习的投影矩阵

保留了最相关的语义信息。

预测层:通过多层感知器(MLP)综合时空信息,预测下一帧的嵌入。

3.3.2. 语义集中多头自注意力(SCMHSA)

为了解决语义稀释问题,SCMHSA通过保留输入嵌入的全部语义内容来增强标准的MHSA。与传统MHSA不同,传统MHSA将嵌入

分割成

个更小的片段

(对于

个头),SCMHSA将完整的嵌入信息输入到每个注意力头中,实现整体处理并减轻语义损失。

在先前描述的标准Transformer架构的基础上,SCMHSA引入了以下关键改进:

每个 Head 的完整嵌入:每个注意力 Head

处理整个嵌入

,避免除法操作。

是学习得到的投影矩阵,其中

表示特定 Head (head)的投影维度。

可学习投影:所有 Head 输出的结果被连接起来:

为了处理增加的维度并保留关键语义信息,一个可学习的投影矩阵

将维度降低回

,同时保留最相关的语义信息。

SCMHSA提供:

· 语义完整性:每个 Head 都保留了完整的语义上下文,避免了稀释。 · 全局表征:所有 Head 都处理完整的嵌入,以实现更深入的理解。 · 维度管理:可学习的矩阵

过滤掉无关特征,保留关键信息。

3.3.3. 预测层

SC-VFP模块输出一系列向量序列

,其中

代表先前帧的数量,

表示第

帧经过处理的嵌入,其中包含了完整的语义信息。预测层(实现为多层感知器MLP)接收SC-VFP模块的输出,并生成一个向量

,该向量预测序列中下一个帧的嵌入。

预测层旨在将SC-VFP捕获的时间和空间信息转换为对下一个视频帧嵌入的准确预测。

3.3.4. 损失函数

在现有的基于Transformer的VFP系统中,模型预测下一帧的嵌入,但损失是在重建的帧上计算的[8]、[16]、[6]、[28],这可能会引入学习偏差。为了解决这个问题,作者提出了一种新的损失函数,该函数直接优化预测嵌入,并增强了SCMHSA模块的有效性。

所提出的损失函数

包含两个主要组成部分:

嵌入均方误差损失:该组件直接测量真实嵌入

与预测嵌入

之间的误差。这有助于确保模型被训练以生成下一帧的准确嵌入。

语义相似度损失:鉴于SCMHSA允许每个头接收完整的输入嵌入,确保 Head 捕获独特、不重叠的语义信息至关重要。语义相似度损失通过惩罚产生相似输出的 Head 来实现这一目标。

对于每一对头

(其中

),该组件计算行向量之间的余弦相似度,即

之间的平均角度距离。这些行向量余弦相似度的总和在所有头的对和头向量中的所有行上平均。语义相似度损失可以表示为:

在本文中,

代表 Head 的数量,

代表每个 Head 向量的行数(这同样也是输入序列的长度),而

表示 Head

的第

行。

总损失函数

是两个组成部分的加权之和:

在本文中,

是一个超参数,它控制着均方误差损失与语义相似度损失之间的权衡。

  1. 实验

4.1 数据集

作者使用了四个数据集进行了实验:KTH、UCSD行人、UCF运动和Penn动作。每个训练实例包含六个帧:五个输入帧和一个用于标签的第六帧。作者不是选择连续的五帧作为输入,而是每五个帧中选择一帧,以避免连续帧之间的相似性问题。这种方法也减少了需要处理的训练数据量。所有输入帧都被调整到

的大小,以与ViT模型[5]兼容。

4.1.1. KTH [17] - 韩德堡技术大学[17]

KTH数据集在人体动作识别领域被广泛应用,包含600个分辨率为160×120的视频序列,以25帧每秒的速度录制。在作者的实验中,作者使用了行走和跑步这两类动作。

4.1.2. UCSD行人数据集[11]

UCSD行人数据集包含从固定、高空摄像机拍摄的行人通道视频,用于识别异常事件。该数据集分为两个子集:Peds1,包含34个训练视频和36个测试视频,视频内容为行人向摄像机走来或走远;Peds2,包含16个训练视频和12个测试视频,视频内容为行人与摄像机平行移动。每个视频片段包含200帧黑白图像,分辨率为

4.1.3. UCF体育[20]

UCF运动数据集包含150个视频序列,每个序列的分辨率为720×480,涵盖了10个动作类别。该数据集对于研究运动场景中的人类行为非常有用,提供了多样化的场景和动作。

4.1.4. 宾州动作[30]

该宾夕法尼亚动作数据集包含2,326个分辨率为

的视频序列,涵盖了15个动作类别,例如棒球投球、引体向上和吉他弹奏。由于其涵盖了广泛的身体活动,并对动态动作进行了详细的标注,因此该数据集在动作识别任务中具有很高的价值。

4.2. 评估指标

由于作者的方法在嵌入空间而非像素空间中运行,因此常见的VFP指标,如LPIPS [29] 和SSIM不适用。相反,作者采用PSNR和MSE作为作者的评估指标。虽然PSNR传统上是使用像素值计算的,但作者改进了它,使用嵌入值。

4.3 实施细节

所提出的模型采用PyTorch实现。训练和评估在NVIDIA A100 40GB GPU上进行。该模型的架构由6个编码器块组成,每个块包含6个注意力头和768维的帧嵌入维度。序列长度设置为5。为了优化,作者使用了AdamW优化器[9],学习率为1e-4。批处理大小设置为32,模型训练了25个周期。数据集按照0.7、0.15和0.15的比例分别划分为训练集、验证集和测试集。为确保可重复性,作者将随机种子固定在2023。

4.4. 结果

作者展示了作者提出的方法的结果,将其与作者所知的最新下一帧预测器进行了比较,包括:PredRNN [22]、SA-ConvLSTM [8]、MIMO-VP [16]、LFDM [15]、VFP-ImageEvent [32]和ExtDM [31],这些比较是在之前提到的四个数据集上进行的。所有方法都在这四个数据集的测试集上进行了评估。定量结果和比较总结在表1中。除了定量比较外,作者在图3中展示了定性比较。与直接预测下一帧的传统下一帧预测模型不同,作者提出的方法预测的是下一帧的嵌入。下一帧的 GT 嵌入是通过ViT [5]生成的。在图3中,作者可视化了每种方法的预测嵌入和 GT 嵌入的错误图。定性结果展示了作者的方法预测嵌入与 GT 之间的吻合度。错误图越暗,预测与 GT 之间的偏差越大。作者还在图4中比较了作者的方法(SC-VFP)预测的嵌入与 GT 嵌入之间的余弦相似度,以及其他方法的比较。

picture.image

picture.image

picture.image

4.4.1 KTH实验结果

在KTH数据集上,作者的方法的表现相较于其他方法较差。具体来说,SCVFP的峰值信噪比(PSNR)比表现最好的方法低7.01%,均方误差(MSE)低59.94%(见表1)。这可以归因于与实验中使用的其他三个数据集相比,KTH数据集的相对较小规模。因此,在该数据集上进行训练和测试时,语义稀释的问题不太明显。相比之下,规模更大的数据集包含更多样化的语义,增加了语义稀释的可能影响,这可能导致性能差异更加明显。

4.4.2. UCSD行人数据集上的结果

在UCSD数据集上,作者的SC-VFP方法优于所有其他方法,实现了最低均方误差(MSE)为86.71和最高峰值信噪比(PSNR)为28.75。这些结果相比次优方法(VFP-ImageEvent)在MSE上提高了16.14%,在PSNR上提高了2.59%(见表1)。

4.4.3. UCF体育项目结果

在UCF体育数据集上,该数据集比KTH和UCSD更大、更复杂,作者发现SC-VFP与其他算法之间的性能差距更为显著。具体来说,均方误差(MsE)降低了38.3%,与最接近的竞争对手(ExtDM)相比,峰值信噪比(PSNR)提高了4.84%。这些结果进一步验证了作者的假设,即解决语义稀释问题对于处理大型数据集中更为复杂的运动模式至关重要。

不同数据集下方法的定量比较。加粗部分表示性能最佳。

4.4.4. 宾州动作测试结果

在最大的数据集(宾夕法尼亚动作数据集)上进行评估时,作者观察到在更大、更复杂的数据集上,性能持续提升的趋势。SC-VFP再次展现出显著的性能优势,在均方误差(MSE)方面比第二好的方法高出68.71%,在峰值信噪比(PSNR)方面高出6.63%。这证明了作者方法的可扩展性和鲁棒性,进一步确认了SC-VFP适用于需要精细粒度语义信息的高级视频预测任务。

4.5. 消融研究

4.5.1. 参数分析

作者提出的SCMHSA模型(参数量42.7M)与基于Transformer的原模型(参数量31.4M)相比,参数量有所增加,这是由于其每个注意力头中实现了完整的嵌入处理,从而减轻了语义稀释。这种设计使得模型参数比基准模型多出约

,能够实现更丰富的语义表示,并提高预测准确性,特别是在复杂数据集上。虽然作者的模型相对于原始Transformer具有更高的参数量,但实验表明,SCMHSA模型中额外的参数显著增强了模型捕捉复杂时空动态的能力,导致MSE和PSNR指标在多个数据集上均得到显著提升,如下一节所示。作者的方法和原始基于Transformer的方法的参数量见表2。

picture.image

4.5.2.性能分析

为了验证SCMHSA和语义相似度损失模块的贡献,作者在四个数据集(KTH、UCSD、UCF Sports、Penn Action)的测试集上进行了消融研究。表3显示,排除SCMHSA对不同数据集的影响各异。在KTH数据集上,没有SCMHSA的SC-VFP在均方误差(MSE)上降低了0.45%,在峰值信噪比(PSNR)上提高了0.07%。相比之下,在UCSD数据集上,引入SCMHSA取得了更好的性能,MSE降低了28.87%,PSNR提高了3.97%。在UCF Sports数据集上,排除SCMHSA的影响更为明显,MSE显著恶化了45.29%,PSNR恶化了9.89%。最后,在Penn Action数据集上,SCMHSA的引入分别将MSE和PSNR的值提高了35.71%和6.92%。这些结果表明SCMHSA机制在缓解语义稀释问题中的关键作用,从而对准确嵌入预测做出了重大贡献,特别是在更大、更复杂的数据集中。图5展示了带有和没有SCMHSA的SC-VFP在Penn Action数据集训练集上的收敛趋势。结果表明,带有SCMHSA的SCVFP不仅实现了更低的损失,而且比没有SCMHSA的变体收敛更快。

picture.image

picture.image

消融实验结果显示,语义相似度损失(SSL)在提升性能方面发挥着至关重要的作用(如表4所示)。即使在包含SCMHSA的情况下,省略SSL会导致MsE在四个数据集上的性能显著下降,分别降低了26.4%、48.2%、128.5%和96.8%。同样,在这四个数据集上,PSNR性能分别下降了3.27%、6.32%、14.06%和11.86%。这些发现强调了SSL对提高模型准确性的重大影响,即使在存在其他增强措施(如SCMHSA)的情况下也是如此。

picture.image

参考

[1]. Overcoming Semantic Dilution in Transformer-Based Next Frame Prediction .

点击上方卡片,关注

「AI视界引擎」

公众号

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论