向AI转型的程序员都关注公众号 机器学习AI算法工程
在人工智能的浪潮中,一个新兴的名词——“模型蒸馏”正逐渐走进大众视野。尤其随着DeepSeek的火爆,模型蒸馏更是成为热议的话题。那么,什么是模型蒸馏呢?
一、模型蒸馏的定义
模型蒸馏,简而言之,就是把大模型学到的知识,用“浓缩”的方式传授给小模型的过程。这样做的目的,是在保证一定精度的基础上,大幅降低运算成本和硬件要求。
以DeepSeek为例,满血版671B参数量的DeepSeek R1就是“
教授模型 ”
。学生模型包括↓
二、模型蒸馏的过程
模型蒸馏的核心,是让大模型以“老师”的身份,向小模型这个“学生”传授解题思路。具体过程如下:
老师做示范:大模型针对每个输入问题,不会直接给出答案,而是提供解题思路(即软标签)。例如,输入一张猫的照片,大模型不会直接说“这是猫”,而是给出一张概率分布图,表明这张图片可能是什么。
老师这么干,就是为了让学生具备
举一反三、触类旁通
的能力,用概率分布来对应各种类别的相似程度。
如果只告诉学生这是猫,学生就不知道它和老虎有多少差别。通过这种有概率分布的软标签,学生就知道了老师是如何判断、如何区分。
学生模仿学习:小模型在学习时,会结合自己原有数据集中的硬标签(如“猫就是猫”),再参考大模型的软标签,综合判断。
建立学习标准:通过“蒸馏损失”来衡量小模型与大模型输出结果的差异,同时用“真实监督损失”来衡量小模型对基本问题的判断能力。通过设定平衡系数来调节这两种损失,以达到最优效果。
传统神经网络的交叉熵损失
在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:
其中:
- 是真实类别的独热编码。
- 是模型的预测概率,通常由 Softmax 变换得到。
其中
是模型最后一层的 logit 值。
传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即
仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。
知识蒸馏的损失函数
在知识蒸馏中,教师模型提供了一种软标签(soft targets),即对所有类别的预测分布,而不仅仅是单个类别。
这些软标签由温度化 Softmax 得到。
其中:
- 其中,
是第
类的未归一化分数(logits),
是温度系数,
是经过温度调整后的概率。
- 较高的 T 值会使得概率分布更加平滑,保留更多类别之间的关系信息,从而提供更丰富的知识给学生模型。
在训练学生模型时,通常使用两部分损失函数:
- 硬标签损失(传统的交叉熵损失)
用于确保学生模型能够正确分类。 2. 软标签损失(基于 Kullback-Leibler 散度的损失)
用于让学生模型学习教师模型的类别间关系。
其中,
是教师模型生成的软标签(概率分布),
是学生模型输出的概率分布。
注意,软标签损失乘上了
,用于平衡温度因子对梯度的影响。
最终的总损失
是硬标签损失和软标签损失的加权和:
其中,
是一个超参数,用于控制硬标签损失和软标签损失的相对重要性。
通过加权组合这两部分损失,可以平衡学生模型对硬标签和软标签的学习。
它会结合自己原有数据集中的硬标签(
猫就是猫、狗就是狗
),再参考老师的答案,最终给出自己的判断。
实操中,用“蒸馏损失”来衡量学生模型与教授模型输出结果的差异。用“真实监督损失”来衡量学生模型对基本是非问题的判断。
然后,再设定一个平衡系数(α)来调节这两种损失,达到一个最优效果。
说白了,学生模型要尽量模仿教授模型的行为,蒸馏损失越小越好,但是又不能学傻了,基本的是非问题都答不对。
标准确定后,就可以进入正式的蒸馏训练了。
❶
把同一批训练样本分别输入到学生模型和教授模型;
❷
根据硬标签和软标签,
对比结果,
结合权重,得到学生模型最终的损失值;
❸
对学生模型进行参数更新,以得到更小的损失值。
不断重复这个过程
❶
→
❷
→
❸
,就相当于反复刷题,每刷一轮,就找找学生答案和老师答案的差距,及时纠正。
经过多轮以后,学生的知识就会越来越扎实。
最终,蒸馏得到的小模型,尽量复制大模型的智慧,同时保持自己身轻如燕的优势。
三、模型蒸馏的类型
模型蒸馏主要分为两种类型:
知识蒸馏(输出层蒸馏):这是最常见、最通用的方式。小模型直接模仿大模型的最终输出。这种方式操作简单,即便大模型不开源,只要能调用其API,就能看到其知识输出并进行模仿。
相当于老师直接告诉你最后的答案,学生只需要抄作业,模仿老师的答案就行。
所以,有些模型比如GPT4,是明确声明不允许知识蒸馏的,但只要你能被调用,就没法避免别人偷师。
中间层蒸馏(特征层蒸馏):这种方式不仅学习大模型的最终判断结论,还学习其对图像/文本的内部理解。这需要大模型的配合,操作难度较高,但学习效果更好。
相当于学生不光看老师的最终答案,还要看老师的解题过程或中间步骤,从而更全面地学到思考方法。
但这种蒸馏方案,操作难度较高,通常需要教师模型允许,甚至主动配合,适用定制化的项目合作。
四、现实案例
李飞飞团队的s1模型:虽然被误传为仅用50美元训练而成,但实际上是基于通义Qwen2.5-32B的微调,且微调所用数据集部分蒸馏自Google Gemini 2.0 Flash Thinking。
所以,这个模型的诞生,是先通过知识蒸馏,从Gemini API获取推理轨迹和答案,辅助筛选出1000个高质量的数据样本。
然后,再用这个数据集,对通义Qwen2.5-32B进行微调,最终得到性能表现不错的s1模型。
这个微调过程,消耗了50美元的算力费用,但这背后,却是Gemini和Qwen两大模型无法估量的隐形成本。
这就好比,你“偷了”一位
名师
解题思路,给了一个
学霸
看,学霸本来就很NB,现在看完“思路”,变得更NB了。
严格来讲,Gemini 2.0作为闭源商业模型,虽然支持获得推理轨迹,但原则上是不允许用作蒸馏的,即便蒸馏出来也不能商用。不过如果仅是发发论文、做做学术研究、博博眼球,倒也无可厚非。
当然,不得不说,李的团队为我们打开了一种思路:我们可以站在巨人的肩膀上,用四两拨千斤的方法,去做一些创新。
比如,DeepSeek是MIT开源授权,代码和权重全开放,而且允许蒸馏(且支持获取推理轨迹)。
那么对于很多中小企业来讲,无异于巨大福利,大家可以轻松通过蒸馏和微调,获得自己的专属模型,还能商用。
代码案例分享
下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个教师模型并将其知识蒸馏到学生模型。
这里,我们采用 MNIST 数据集,教师模型使用一个较大的神经网络,而学生模型是一个较小的神经网络。
首先,定义教师模型和学生模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 教师模型(较大的神经网络)
class TeacherModel(nn.Module):
def \_\_init\_\_(self):
super(TeacherModel, self).\_\_init\_\_()
self.fc1 = nn.Linear(28 * 28, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
# 注意这里没有 Softmax
return
x
# 学生模型(较小的神经网络)
class StudentModel(nn.Module):
def \_\_init\_\_(self):
super(StudentModel, self).\_\_init\_\_()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = F.relu(self.fc1(x))
x = self.fc2(x)
# 注意这里没有 Softmax
return
x
然后加载数据集。
# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# 加载 MNIST 数据集
train\_dataset = datasets.MNIST(root=
"./data"
, train=True, download=True, transform=transform)
test\_dataset = datasets.MNIST(root=
"./data"
, train=False, download=True, transform=transform)
train\_loader = DataLoader(train\_dataset, batch\_size=64, shuffle=True)
test\_loader = DataLoader(test\_dataset, batch\_size=1000, shuffle=False)
训练教师模型
def train\_teacher(model, train\_loader, epochs=5, lr=0.001):
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
for
epoch
in
range(epochs):
model.train()
total\_loss = 0
for
images, labels
in
train\_loader:
optimizer.zero\_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
total\_loss += loss.item()
print
(f
"Epoch [{epoch+1}/{epochs}], Loss: {total\_loss / len(train\_loader):.4f}"
)
# 初始化并训练教师模型
teacher\_model = TeacherModel()
train\_teacher(teacher\_model, train\_loader)
知识蒸馏训练学生模型
def distillation\_loss(student\_logits, teacher\_logits, labels, T=3.0, alpha=0.5):
"""
计算蒸馏损失,结合知识蒸馏损失和交叉熵损失
"""
soft\_targets = F.softmax(teacher\_logits / T, dim=1)
# 教师模型的软标签
soft\_predictions = F.log\_softmax(student\_logits / T, dim=1)
# 学生模型的预测
distillation\_loss = F.kl\_div(soft\_predictions, soft\_targets, reduction=
"batchmean"
) * (T ** 2)
ce\_loss = F.cross\_entropy(student\_logits, labels)
return
alpha * ce\_loss + (1 - alpha) * distillation\_loss
def train\_student\_with\_distillation(student\_model, teacher\_model, train\_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):
optimizer = optim.Adam(student\_model.parameters(), lr=lr)
teacher\_model.eval()
# 设定教师模型为评估模式
for
epoch
in
range(epochs):
student\_model.train()
total\_loss = 0
for
images, labels
in
train\_loader:
optimizer.zero\_grad()
student\_logits = student\_model(images)
with torch.no\_grad():
teacher\_logits = teacher\_model(images)
# 获取教师模型输出
loss = distillation\_loss(student\_logits, teacher\_logits, labels, T=T, alpha=alpha)
loss.backward()
optimizer.step()
total\_loss += loss.item()
print
(f
"Epoch [{epoch+1}/{epochs}], Loss: {total\_loss / len(train\_loader):.4f}"
)
# 初始化学生模型
student\_model = StudentModel()
train\_student\_with\_distillation(student\_model, teacher\_model, train\_loader)
评估模型
def evaluate(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no\_grad():
for
images, labels
in
test_loader:
outputs = model(images)
\_, predicted = torch.max(outputs, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
accuracy = 100 * correct / total
return
accuracy
teacher_acc = evaluate(teacher_model, test_loader)
(f
"教师模型准确率: {teacher_acc:.2f}%"
)
student_acc_distilled = evaluate(student_model, test_loader)
(f
"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%"
)
结语
模型蒸馏为人工智能领域带来了一种新的优化手段,使得小模型能够在保持低成本的同时,获得接近大模型的能力。随着技术的不断发展,我们有理由相信,模型蒸馏将在更多领域发挥重要作用,推动人工智能技术的普及和应用。
机器学习算法AI大数据技术
搜索公众号添加: datanlp
长按图片,识别二维码
阅读过本文的人还看了以下文章:
整理开源的中文大语言模型,以规模较小、可私有化部署、训练成本较低的模型为主
基于40万表格数据集TableBank,用MaskRCNN做表格检测
《深度学习入门:基于Python的理论与实现》高清中文PDF+源码
2019最新《PyTorch自然语言处理》英、中文版PDF+源码
《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码
PyTorch深度学习快速实战入门《pytorch-handbook》
【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》
李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材
【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类
如何利用全新的决策树集成级联结构gcForest做特征工程并打分?
Machine Learning Yearning 中文翻译稿
斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)
中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程
不断更新资源
深度学习、机器学习、数据分析、python
搜索公众号添加: datayx