漫画趣解:彻底搞懂模型蒸馏!

大模型机器学习算法

picture.image

向AI转型的程序员都关注公众号 机器学习AI算法工程

在人工智能的浪潮中,一个新兴的名词——“模型蒸馏”正逐渐走进大众视野。尤其随着DeepSeek的火爆,模型蒸馏更是成为热议的话题。那么,什么是模型蒸馏呢?

一、模型蒸馏的定义

模型蒸馏,简而言之,就是把大模型学到的知识,用“浓缩”的方式传授给小模型的过程。这样做的目的,是在保证一定精度的基础上,大幅降低运算成本和硬件要求。

picture.image

以DeepSeek为例,满血版671B参数量的DeepSeek R1就是“

教授模型

。学生模型包括↓

picture.image

二、模型蒸馏的过程

模型蒸馏的核心,是让大模型以“老师”的身份,向小模型这个“学生”传授解题思路。具体过程如下:

老师做示范:大模型针对每个输入问题,不会直接给出答案,而是提供解题思路(即软标签)。例如,输入一张猫的照片,大模型不会直接说“这是猫”,而是给出一张概率分布图,表明这张图片可能是什么。

picture.image

老师这么干,就是为了让学生具备

举一反三、触类旁通

的能力,用概率分布来对应各种类别的相似程度。

如果只告诉学生这是猫,学生就不知道它和老虎有多少差别。通过这种有概率分布的软标签,学生就知道了老师是如何判断、如何区分。

学生模仿学习:小模型在学习时,会结合自己原有数据集中的硬标签(如“猫就是猫”),再参考大模型的软标签,综合判断。

picture.image

建立学习标准:通过“蒸馏损失”来衡量小模型与大模型输出结果的差异,同时用“真实监督损失”来衡量小模型对基本问题的判断能力。通过设定平衡系数来调节这两种损失,以达到最优效果。

传统神经网络的交叉熵损失

在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:

其中:

  • 是真实类别的独热编码。
  • 是模型的预测概率,通常由 Softmax 变换得到。

其中

是模型最后一层的 logit 值。

picture.image

传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即

仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。

知识蒸馏的损失函数

在知识蒸馏中,教师模型提供了一种软标签(soft targets),即对所有类别的预测分布,而不仅仅是单个类别。

这些软标签由温度化 Softmax 得到。

其中:

  • 其中,

是第

类的未归一化分数(logits),

是温度系数,

是经过温度调整后的概率。

  • 较高的 T 值会使得概率分布更加平滑,保留更多类别之间的关系信息,从而提供更丰富的知识给学生模型。

在训练学生模型时,通常使用两部分损失函数:

  1. 硬标签损失(传统的交叉熵损失)

用于确保学生模型能够正确分类。 2. 软标签损失(基于 Kullback-Leibler 散度的损失)

用于让学生模型学习教师模型的类别间关系。

其中,

是教师模型生成的软标签(概率分布),

是学生模型输出的概率分布。

注意,软标签损失乘上了

,用于平衡温度因子对梯度的影响。

最终的总损失

是硬标签损失和软标签损失的加权和:

其中,

是一个超参数,用于控制硬标签损失和软标签损失的相对重要性。

通过加权组合这两部分损失,可以平衡学生模型对硬标签和软标签的学习。

picture.image

它会结合自己原有数据集中的硬标签(

猫就是猫、狗就是狗

),再参考老师的答案,最终给出自己的判断。

picture.image

实操中,用“蒸馏损失”来衡量学生模型与教授模型输出结果的差异。用“真实监督损失”来衡量学生模型对基本是非问题的判断。

然后,再设定一个平衡系数(α)来调节这两种损失,达到一个最优效果。

picture.image

说白了,学生模型要尽量模仿教授模型的行为,蒸馏损失越小越好,但是又不能学傻了,基本的是非问题都答不对。

标准确定后,就可以进入正式的蒸馏训练了。

把同一批训练样本分别输入到学生模型和教授模型;

根据硬标签和软标签,

对比结果,

结合权重,得到学生模型最终的损失值;

对学生模型进行参数更新,以得到更小的损失值。

不断重复这个过程

,就相当于反复刷题,每刷一轮,就找找学生答案和老师答案的差距,及时纠正。

经过多轮以后,学生的知识就会越来越扎实。

最终,蒸馏得到的小模型,尽量复制大模型的智慧,同时保持自己身轻如燕的优势。

三、模型蒸馏的类型

模型蒸馏主要分为两种类型:

知识蒸馏(输出层蒸馏):这是最常见、最通用的方式。小模型直接模仿大模型的最终输出。这种方式操作简单,即便大模型不开源,只要能调用其API,就能看到其知识输出并进行模仿。

相当于老师直接告诉你最后的答案,学生只需要抄作业,模仿老师的答案就行。

所以,有些模型比如GPT4,是明确声明不允许知识蒸馏的,但只要你能被调用,就没法避免别人偷师。

中间层蒸馏(特征层蒸馏):这种方式不仅学习大模型的最终判断结论,还学习其对图像/文本的内部理解。这需要大模型的配合,操作难度较高,但学习效果更好。

相当于学生不光看老师的最终答案,还要看老师的解题过程或中间步骤,从而更全面地学到思考方法。

但这种蒸馏方案,操作难度较高,通常需要教师模型允许,甚至主动配合,适用定制化的项目合作。

四、现实案例

李飞飞团队的s1模型:虽然被误传为仅用50美元训练而成,但实际上是基于通义Qwen2.5-32B的微调,且微调所用数据集部分蒸馏自Google Gemini 2.0 Flash Thinking。

所以,这个模型的诞生,是先通过知识蒸馏,从Gemini API获取推理轨迹和答案,辅助筛选出1000个高质量的数据样本。

然后,再用这个数据集,对通义Qwen2.5-32B进行微调,最终得到性能表现不错的s1模型。

这个微调过程,消耗了50美元的算力费用,但这背后,却是Gemini和Qwen两大模型无法估量的隐形成本。

这就好比,你“偷了”一位

名师

解题思路,给了一个

学霸

看,学霸本来就很NB,现在看完“思路”,变得更NB了。

严格来讲,Gemini 2.0作为闭源商业模型,虽然支持获得推理轨迹,但原则上是不允许用作蒸馏的,即便蒸馏出来也不能商用。不过如果仅是发发论文、做做学术研究、博博眼球,倒也无可厚非。

当然,不得不说,李的团队为我们打开了一种思路:我们可以站在巨人的肩膀上,用四两拨千斤的方法,去做一些创新。

比如,DeepSeek是MIT开源授权,代码和权重全开放,而且允许蒸馏(且支持获取推理轨迹)。

那么对于很多中小企业来讲,无异于巨大福利,大家可以轻松通过蒸馏和微调,获得自己的专属模型,还能商用。

代码案例分享

下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个教师模型并将其知识蒸馏到学生模型。

这里,我们采用 MNIST 数据集,教师模型使用一个较大的神经网络,而学生模型是一个较小的神经网络。

首先,定义教师模型和学生模型。


        
            

          
 import torch
 
   

 
 import torch.nn as nn
 
   

 
 import torch.optim as optim
 
   

 
 import torch.nn.functional as F
 
   

 
 from torchvision import datasets, transforms
 
   

 
 from torch.utils.data import DataLoader
 
   

 
 import matplotlib.pyplot as plt
 
   

 
 
 # 教师模型(较大的神经网络)
 
 
   

 
 class TeacherModel(nn.Module):
 
   

 
     def \_\_init\_\_(self):
 
   

 
         super(TeacherModel, self).\_\_init\_\_()
 
   

 
         self.fc1 = nn.Linear(28 * 28, 512)
 
   

 
         self.fc2 = nn.Linear(512, 256)
 
   

 
         self.fc3 = nn.Linear(256, 10)
 
   

 
 
   

 
     def forward(self, x):
 
   

 
         x = x.view(-1, 28 * 28)
 
   

 
         x = F.relu(self.fc1(x))
 
   

 
         x = F.relu(self.fc2(x))
 
   

 
         x = self.fc3(x)  
 
 # 注意这里没有 Softmax
 
 
   

 
         
 
 return
 
  x
 
   

 
 
   

 
 
 # 学生模型(较小的神经网络)
 
 
   

 
 class StudentModel(nn.Module):
 
   

 
     def \_\_init\_\_(self):
 
   

 
         super(StudentModel, self).\_\_init\_\_()
 
   

 
         self.fc1 = nn.Linear(28 * 28, 128)
 
   

 
         self.fc2 = nn.Linear(128, 10)
 
   

 
 
   

 
     def forward(self, x):
 
   

 
         x = x.view(-1, 28 * 28)
 
   

 
         x = F.relu(self.fc1(x))
 
   

 
         x = self.fc2(x)  
 
 # 注意这里没有 Softmax
 
 
   

 
         
 
 return
 
  x
 
          
   

 
        
      

然后加载数据集。


        
            

          
 
 # 数据预处理
 
 
   

 
 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
 
   

 
 
   

 
 
 # 加载 MNIST 数据集
 
 
   

 
 train\_dataset = datasets.MNIST(root=
 
 "./data"
 
 , train=True, download=True, transform=transform)
 
   

 
 test\_dataset = datasets.MNIST(root=
 
 "./data"
 
 , train=False, download=True, transform=transform)
 
   

 
 
   

 
 train\_loader = DataLoader(train\_dataset, batch\_size=64, shuffle=True)
 
   

 
 test\_loader = DataLoader(test\_dataset, batch\_size=1000, shuffle=False)
 
          
   

 
        
      

训练教师模型


        
            

          
 def train\_teacher(model, train\_loader, epochs=5, lr=0.001):
 
   

 
     optimizer = optim.Adam(model.parameters(), lr=lr)
 
   

 
     criterion = nn.CrossEntropyLoss()
 
   

 
     
 
   

 
     
 
 for
 
  epoch 
 
 in
 
  range(epochs):
 
   

 
         model.train()
 
   

 
         total\_loss = 0
 
   

 
         
 
   

 
         
 
 for
 
  images, labels 
 
 in
 
  train\_loader:
 
   

 
             optimizer.zero\_grad()
 
   

 
             output = model(images)
 
   

 
             loss = criterion(output, labels)
 
   

 
             loss.backward()
 
   

 
             optimizer.step()
 
   

 
             total\_loss += loss.item()
 
   

 
         
 
   

 
         
 
 print
 
 (f
 
 "Epoch [{epoch+1}/{epochs}], Loss: {total\_loss / len(train\_loader):.4f}"
 
 )
 
   

 
 
   

 
 
 # 初始化并训练教师模型
 
 
   

 
 teacher\_model = TeacherModel()
 
   

 
 train\_teacher(teacher\_model, train\_loader)
 
          
   

 
        
      

知识蒸馏训练学生模型


        
            

          
 def distillation\_loss(student\_logits, teacher\_logits, labels, T=3.0, alpha=0.5):
 
   

 
     
 
 """
 
   

 
     计算蒸馏损失,结合知识蒸馏损失和交叉熵损失
 
   

 
     """
 
 
   

 
     soft\_targets = F.softmax(teacher\_logits / T, dim=1)  
 
 # 教师模型的软标签
 
 
   

 
     soft\_predictions = F.log\_softmax(student\_logits / T, dim=1)  
 
 # 学生模型的预测
 
 
   

 
     
 
   

 
     distillation\_loss = F.kl\_div(soft\_predictions, soft\_targets, reduction=
 
 "batchmean"
 
 ) * (T ** 2)
 
   

 
     ce\_loss = F.cross\_entropy(student\_logits, labels)
 
   

 
     
 
   

 
     
 
 return
 
  alpha * ce\_loss + (1 - alpha) * distillation\_loss
 
   

 
 
   

 
 def train\_student\_with\_distillation(student\_model, teacher\_model, train\_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):
 
   

 
     optimizer = optim.Adam(student\_model.parameters(), lr=lr)
 
   

 
     
 
   

 
     teacher\_model.eval()  
 
 # 设定教师模型为评估模式
 
 
   

 
     
 
 for
 
  epoch 
 
 in
 
  range(epochs):
 
   

 
         student\_model.train()
 
   

 
         total\_loss = 0
 
   

 
         
 
   

 
         
 
 for
 
  images, labels 
 
 in
 
  train\_loader:
 
   

 
             optimizer.zero\_grad()
 
   

 
             student\_logits = student\_model(images)
 
   

 
             with torch.no\_grad():
 
   

 
                 teacher\_logits = teacher\_model(images)  
 
 # 获取教师模型输出
 
 
   

 
             
 
   

 
             loss = distillation\_loss(student\_logits, teacher\_logits, labels, T=T, alpha=alpha)
 
   

 
             loss.backward()
 
   

 
             optimizer.step()
 
   

 
             total\_loss += loss.item()
 
   

 
         
 
   

 
         
 
 print
 
 (f
 
 "Epoch [{epoch+1}/{epochs}], Loss: {total\_loss / len(train\_loader):.4f}"
 
 )
 
   

 
 
   

 
 
 # 初始化学生模型
 
 
   

 
 student\_model = StudentModel()
 
   

 
 train\_student\_with\_distillation(student\_model, teacher\_model, train\_loader)
 
          
   

 
        
      

评估模型

def evaluate(model, test_loader):

model.eval()




correct = 0




total = 0









with torch.no\_grad():




    

for

images, labels

in

test_loader:

        outputs = model(images)




        \_, predicted = torch.max(outputs, 1)




        correct += (predicted == labels).sum().item()




        total += labels.size(0)









accuracy = 100 * correct / total




return

accuracy

评估教师模型

teacher_acc = evaluate(teacher_model, test_loader)

print

(f

"教师模型准确率: {teacher_acc:.2f}%"

)

评估知识蒸馏训练的学生模型

student_acc_distilled = evaluate(student_model, test_loader)

print

(f

"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%"

)

picture.image

结语

模型蒸馏为人工智能领域带来了一种新的优化手段,使得小模型能够在保持低成本的同时,获得接近大模型的能力。随着技术的不断发展,我们有理由相信,模型蒸馏将在更多领域发挥重要作用,推动人工智能技术的普及和应用。

机器学习算法AI大数据技术

搜索公众号添加: datanlp

picture.image

长按图片,识别二维码

阅读过本文的人还看了以下文章:

实时语义分割ENet算法,提取书本/票据边缘

整理开源的中文大语言模型,以规模较小、可私有化部署、训练成本较低的模型为主

《大语言模型》PDF下载

动手学深度学习-(李沐)PyTorch版本

YOLOv9电动车头盔佩戴检测,详细讲解模型训练

TensorFlow 2.0深度学习案例实战

基于40万表格数据集TableBank,用MaskRCNN做表格检测

《基于深度学习的自然语言处理》中/英PDF

Deep Learning 中文版初版-周志华团队

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

《深度学习:基于Keras的Python实践》PDF和代码

特征提取与图像处理(第二版).pdf

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

搜索公众号添加: datayx

picture.image

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论