点击下方卡片,关注 「AI视界引擎」 公众号
( 添加时备注:方向+学校/公司+昵称/姓名 )
最近的研究越来越集中在知识蒸馏领域,因为 logit 蒸馏具有简单性、有效性和模型压缩的多样性。在本文中,作者提出了改进型 logit 蒸馏(RLD),以解决现有 logit 蒸馏方法的局限性。
作者的方法是由观察到即使高性能的教师模型也会做出错误的预测而引起的,这种冲突使得标准蒸馏损失与交叉熵损失之间产生了矛盾,进而可能破坏学生模型的学习目标的一致性。
之前使用标签来实证修正教师预测可能削弱了类相关的稳定性。
相比之下,作者的 RLD 方法使用标记信息动态地改进教师 logit。这样,作者的方法可以有效地消除教师中的误导信息,同时保留关键的类相关性,从而提高蒸馏知识的价值和效率。
在 CIFAR-100 和 ImageNet 上的实验结果表明,它优于现有方法。
1 Introduction
知识蒸馏 [12]利用预训练的高性能老师模型来促进学生模型的训练。相比其他模型压缩方法(如修剪和量化和 [1]),知识蒸馏受到模型架构的约束更少。这种灵活性大大拓宽了其应用范围,使其在最近的学术研究中变得日益突出。
Hinton等人 hinton2015effective 是首位引入对数概率蒸馏概念的。这种蒸馏方法旨在在对数操作之后,通过对数对齐教师和学生模型的概率。后续的大部分研究都维持了对数概率蒸馏的原始概念,并深入研究了更深入的特征蒸馏过程,特别是研究了在教师和学生模型之间选择和对齐中间层特征的过程。然而,教师模型和学生模型之间的潜在架构差异会对特征对齐造成重大挑战,这是因为不同的架构提取不同的特征 [2]。此外,特征选择的广泛多样性进一步加剧了特征蒸馏的复杂性,导致蒸馏训练时间增加 [1]。最近,通过将经典对数概率蒸馏损失解耦,赵等人赵2022的研究表明,对数概率蒸馏可以生成与特征蒸馏相当或甚至优于的结果。因此,对数概率蒸馏在研究社区中引起了相当大的关注,主要得益于其简洁性、有效性和普适性。
然而,尽管最近的对数概率蒸馏方法 取得了令人印象深刻的成就,但大多数人都忽视了老师预测精度对训练过程的影响。具体来说,错误的老师预测会导致老师损失和标签损失之间的冲突,这可能会严重阻碍学生模型的进一步改进潜力。现有的更正型蒸馏方法一致性地使用标签信息修改老师概率(目标)。他们要么在预测的最大类别和真实类别之间交换值(_swap_ 操作),要么通过放大预测概率中的真实类别比例(_augment_ 操作)来增强这一比例。作者认为,这些方法可能会改变类别之间的相关性,如图1所示。这种破坏可能会阻碍“暗知识” [10]的传递并阻碍性能改进。
作者提出了一种名为“Refined Logit Distillation” (RLD)的方法来解决这些问题。 RLD 的目标是利用学生模型吸收教师的宝贵知识来缓解学生在训练过程中产生的损失冲突。具体地说,RLD包括两种类型的知识,_sample confidence_ (SC)和_masked correlation_ (MC)。样本信心是指对数概率生成的二进制概率。在老师模型中,SC来自于预测类别及其余类别的概率。这个指标概括了教师对当前样本的预测信心,并可以用于引导学生模型。考虑到老师预测的不准确性,作者将学生真实类别的预测概率与教师的预测概率对齐。这种对齐不仅减弱了老师的错误,而且引导学生模型实现与当前样本相似的信心水平。此外,它也可以有效防止过拟合。 Mask 相关性表示作者的动态方法用于选择用于教师和学生对齐的部分类。它旨在减轻老师模型可能错误对学生在模型中的影响,同时传递重要的类相关性。更具体地说,MC涉及对教师对数中具有等于或高于真实类排名的所有类的屏蔽。本质上,在对数概率模型中,当老师出现更多错误(即真实类排名较低),用于蒸馏的类别会减少,而当老师出现较少错误时,用于蒸馏的类别会增加。使用这两种互补类型的精炼知识,学生可以实现更好的性能。
作者的贡献可以总结如下:
- 作者揭示了目前普遍的蒸馏方法未能考虑错误的老师预测的影响,而现有的更正型策略往往破坏了有价值的类相关性。
- 作者引入了一种新颖的对数概率蒸馏方法,即 Refined Logit Distillation (RLD),以防止过拟合并减轻错误的老师知识影响,同时保留重要的类相关性。
- 在 CIFAR-100 和 ImageNet 数据集上,作者进行了全面的实验,验证了作者提出的 RLD 方法的优越性能。
2 Related Work
知识蒸馏的历史应用主要集中于图像分类任务,逐步扩展到更广泛的任务,包括语义分割、图模型(西利曼斯和霍,2022年;孙等人,2023年)等计算机视觉领域的任务。传统知识蒸馏通常涉及单个教师和单个学生模型。随着领域的发展,提出了一系列其他范式,如在线蒸馏、多教师蒸馏、多学生模型、自蒸馏。由于传统知识蒸馏仍是该领域研究的核心基础,作者将在下文讨论中仅关注此类方法。
在图像分类任务中,现有算法可以大致分为三类:对数离散、学生模型之间的蒸馏。对数离散已经成为当前研究的主要方向,因为其直观、有效、可适应性强。最初的logit蒸馏Hinton等人(2015年)利用KL散度将教师和学生模型的软化输出对齐,从而显著提高学生模型的性能。DKD,赵等人(2022年)通过解耦这一经典的损失,使logit蒸馏与特征蒸馏可比。MLKD,金等人(2023年)利用多级logit知识进一步提高模型性能。CTKD,李等人(2023年)引入了课程温度,应用对抗训练和课程学习来动态确定每个样本的蒸馏温度。LSKD,孙等人(2024年)将logits适当地分配给教师和学生在样本之间的温度,从而在该领域实现了最先进的性能。然而,教师错误预测对蒸馏的影响很少被考虑。
由于logits与预测准确性本质相关,有很多方法利用标签在蒸馏过程之前调整logits。LA,温等人(2021年)交换真实类和预测类的值以纠正教师模型的预测。RC,曹等人(2023年)放大学生输出分布中真实类的预计值,从而帮助学生模型做出准确而自信的预测。LR,兰等人(2024年)将One-Hot标签与教师的软标签结合,产生新的精确目标以进行蒸馏。然而,和之前证明的,这些方法可能会破坏类之间的相关性,从而阻碍性能提升。
3 Preliminaries
为帮助读者更好地理解知识蒸馏相关概念,作者提供一个概述。
考虑一个包含C类别的图像分类任务。作者有一个预训练的教师模型和学生模型。对于任意输入图像,教师和学生的输出logits分别为和。利用softmax函数,作者计算预测分布和如下:
其中代表第i类预测的值。
为了训练学生模型,作者首先计算学生预测和one-hot地面真标签之间的交叉熵损失:
其次,作者将教师模型和学生的软化预测对齐,使用KL散度:
其中表示softmax操作的温度。
通过结合等式2和3,作者得到随机梯度下降的经典logit蒸馏损失。这种方法在实验中被证明比仅使用标签训练更好。
4 Methodology
在本节中,作者将深入介绍作者提出的 RLD(Random Length Dimension)。如图2所示,有一个简要的 RLD 概述。
Sample Confidence Distillation
样本置信度(SC)是来自对数输出的二进制分布。它包含了每个样本的模型置信度,从而帮助学生的模型生成对真实类的置信度较高的预测,同时不过分限制其他类的分布。
在教师知识背景下,样本置信度的组成部分之一是最大预测概率值 ,而另一个是剩余类别的预测概率之和。然而,学生的SC包括真类预测概率 ,其余类别的预测概率包括在内的其他组成部分。可以总结在以下公式中:
为了实现知识的转移,作者使用KL散度将和对齐:
虽然熵可能潜在地可用于衡量样本置信度,但对作者来说并不适合。SC的目标是使学生模型在保持对真实类的置信度相似水平的同时,不阻碍其他类的预测概率。然而,直接传输熵并未对学生的模型中的真实类施加这个约束,如图3(a)和(b)所示。
Masked Correlation Distillation
Mask 相关性(MC)指的是在动态地 Mask 某些类别后获得概率分布。如图3(c)所示,这种 Mask 操作使学生模型免去了不正确的类别排序的匹配,从而让学生模型可以在不造成巨大损失的情况下,从教师模型中生成非常不同的输出。此外,保留部分类别概率使学生模型学习出有价值的类别关联,从而提高模型性能。
具体地, Mask 是从教师逻辑和标签动态地推导而来。作者将所有类别的逻辑值大于或等于(表示为 "ge")真实类别的逻辑值的目标作为 Mask 操作的指定目标,可以表示如下:
虽然 Mask 掉所有大于(表示为 "g")真实类别的类别可以消除错误信息并保留类别关联,但是这个称为 的 Mask 策略(无意间)地将真实类别相关的知识融合到了 Mask 相关性和样本置信中。如表1所示,这可能会引入损失之间的冲突并影响模型性能。因此,作者选择 作为 Mask 策略。
在得到 Mask 后,作者使用以下公式计算对齐的概率分布:
其中 且 成立。
作者总结了遮mask相关性知识的消融损失如下:
当教师模型做出更准确的预测(将真实类别排名更高)时,只有少数类别受到 Mask 操作。这使得大部分类别关联得以保留并传递给学生模型。相反,如果教师预测不准确,大部分类别被 Mask 。因此,学生模型学习到的知识较少,从而降低了错误信息误导训练过程的可能性。同时,这给了学生模型在 Mask 类别的预测上更大自由度,如与教师模型的差异很大的类别。
Refined Logit Distillation
通过将式(2)、式(6)和式(9)相结合,作者得到了最终随机标签分布(RLD)的损失表示式:
其中,和是超参数,分别用于调整样本置信度和 Mask 相关性的重要性。具体算法见算法1,此处省略了温度。
接下来,作者阐明每个损失函数在协同提升模型性能中的作用:
- 鼓励学生模型生成对真实类别的最高概率,但在独立部署时可能导致过拟合。
- 期望学生模型达到对真实类别的合理置信水平,从而避免过拟合。然而,在单独使用时,它可能导致分配给剩余类别的概率超过真实类别的概率。当与结合时,真实类能够保持最高且最适中的概率。然而,这种组合在传递剩余类知识方面有所欠缺。
- 由于在分布对齐过程中 consistently Mask 了真实类,因此在信息传递方面存在局限,但在消除误信息和向学生模型传递有价值的类关联方面具有能力。通过将、与相结合,可以确保传递大量有价值知识。对于DKD(Distribution Knowledge Distillation)的相关性:有趣的是,尽管作者的方法与DKD[17]从不同的角度考虑分位数压缩,但当教师模型一致做出准确预测时,它们实际上是等效的。此外,DKD并没有明确阐述为什么在强调非目标类知识方面它是有效的,而作者的方法暗示,通过在分布对齐过程中 Mask 真实类,学生模型得到了更大的自主性来调整真实类的排序,从而有助于更准确的预测。
5 Experiments
Settings
数据集。作者在两个标准的图像分类数据集上进行实验:CIFAR-100 [10]和ImageNet [21]。CIFAR-100包含100个不同的类别,训练集有50,000张图像,验证集有10,000张图像。这个数据集中的每张图像大小为32 x 32像素。ImageNet是一个更大的、更复杂的数据集,涵盖1000个类别。它包括128万张训练图像和50,000张验证图像,预处理后每个图像大小为224 x 224像素。
模型。教师和学生使用的模型包括ResNet [13],WideResNet(WRN) [12],VGG [23],ShuffleNet(SHN) [14, 17],和MobileNet(MN) [15, 16]。实验结果包含了异质和同质教师-学生模型的蒸馏结果。
比较方法。实验中涉及到的比较方法包括特征蒸馏和logit蒸馏方法。特征蒸馏方法包括FitNet [15],AT [1],RKD [18],CRD [17],OFD [13],ReviewKD [19],SimKD [15],和CAT-KD [13]。logit蒸馏方法包括KD [17],DTKD [18],DKD [19],LR [17]。值得一提的是,后三种方法基于校正技术,并采用作者统一的训练框架重新实现。
Main Results
CIFAR-100. RLD与其他蒸馏方法的top-1验证精确率(%)比较结果如图2(异构蒸馏对)和图3(同构蒸馏对)所示。作者可以看出,在所有情况下,RLD要么是最佳蒸馏算法,要么是次佳蒸馏算法,并且在大部分情况下都是最优的。这强调了RLD的优势,并突出了对教师预测进行纠正的重要性。与此同时,尽管特征蒸馏有时会超过对数分布的蒸馏,但特征蒸馏的最优算法表现出更高的不稳定性和无法长期稳定地与某种算法一起表现良好的倾向。此外,特征蒸馏会增加训练时间,并要求更复杂的算法设计,可能阻碍其实际应用性。
ImageNet. RLD与其他蒸馏方法的top-1和top-5验证精确率(%)比较结果如图4所示。在这更具挑战性的数据集上,RLD成功超越了所有现有特征和对数分布蒸馏算法,始终实现最佳性能,证明了作者的方法的优势。
Visualizations
在本节中,作者通过可视化更直观地表示不同方法的差异。教师模型为ResNet324,学生模型为 ResNet84。
特征可视化。 作者使用t-SNE [16]来可视化KD和RLD生成的学生模型的输出特征。如图4所示,RDL获取的特征具有更强的判别性。
逻辑单元差异可视化。 作者定量表示教师模型和通过DKD和RLD得到的学生模型在每类之间的对数单元差异,可视化这些结果如图5的热力图。尽管RLD优于DKD,但作者观察到RDL产生的对数单元差异大于DKD。这一观察结果与作者的预期相符,考虑到 RLD纠正了某些教师知识的误差,为学生提供了更大的自主性来制定自己的预测。这一发现表明,不考虑教师知识对齐可能不是最优策略,作者认为校正方法值得更多关注和研究。
Extensions
Ablation Study.
作者在RLD的组成部分上进行了一次消融实验,结果如表5所示。结果表明,RLD的每个组成部分对性能提升都具有有效的贡献。值得关注的是,当或单独启用(第2和第3行)时,在各温度设置下都取得了最佳结果。然而,当将和结合在一起(第4行)时,两个组成部分采用相同的温度设置优于设置不同的温度。这表明,损失之间的相互作用对于实现更好的性能也至关重要。
Logit标准化(Logit Standardization)。作者在RLD的基础上研究了添加logit标准化技术Sun等人(2024)的有效性。结果如表6所示。RLD实现的最佳结果证明了其卓越的性能和与其他方法的广泛整合潜力。
6 Conclusion
本文旨在解决现有知识蒸馏方法中存在的问题,例如未知教师预测的影响和 Teacher模型预测的随意修正,以及破坏了类相关性。
作者介绍了改进的逻辑斯蒂区分蒸馏方法(Refined Logit Distillation, RLD),该方法使教师模型可以向学生模型传授两种独特形式的知识:样本置信度和掩盖的相关性。
这些方法可以有效减轻过拟合并从教师模型中消除潜在的错误信息,同时保持类相关性,进而使学生模型获得更有价值的信息。实验结果表明了所提出方法的优越性。
参考
[1].Knowledge Distillation with Refined Logits.
点击上方卡片,关注 「AI视界引擎」 公众号
