A Complete Overview | A Roundup of Knowledge Distillation Algorithms

Author: AI浩 @ Zhihu | Source: https://zhuanlan.zhihu.com/p/583273832 | Editor: 小书童

Knowledge distillation methods fall into two broad categories: **logits distillation** and **feature distillation**.

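**Logits distillation** applies a higher temperature in the softmax, which amplifies the information carried by the negative (non-target) labels, and then uses the KL divergence between the Student's and the Teacher's temperature-softened softmax outputs as the loss.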
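As a minimal sketch of this idea (not code from any of the papers below), the temperature-softened KL term is usually combined with the ordinary cross-entropy on the ground-truth labels. The weighting `alpha` and the default `T=4.0` are illustrative assumptions, and `kd_loss` and `softened_probs` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def softened_probs(logits, T):
    """Softmax at temperature T; a larger T gives a flatter distribution."""
    return F.softmax(logits / T, dim=-1)


def kd_loss(out_s, out_t, target, T=4.0, alpha=0.9):
    """Weighted sum of a soft-target KL term and the hard-label cross-entropy.

    T and alpha are illustrative hyperparameters (assumptions), not prescribed values.
    """
    # KL(teacher || student) on temperature-softened outputs; the teacher is a fixed target.
    soft = F.kl_div(F.log_softmax(out_s / T, dim=1),
                    softened_probs(out_t.detach(), T),
                    reduction='batchmean') * T * T
    # Ordinary cross-entropy on the labels at T = 1.
    hard = F.cross_entropy(out_s, target)
    return alpha * soft + (1.0 - alpha) * hard
```

Raising `T` moves probability mass from the top class to the negative labels: for logits `[5, 1, 0]`, the last class receives about 0.007 of the mass at `T=1` and about 0.17 at `T=4`.
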

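**Intermediate feature distillation** forces the Student to learn the features of certain intermediate layers of the Teacher, either by matching those intermediate features directly or by learning the transformation between features. For example, between feature maps No.1 and No.2, the knowledge can be expressed as how one is transformed into the other. This transformation can be summarized by a matrix, and the Student is trained to produce the same matrix, thereby learning the Teacher's relationship between the two layers (this is the idea behind FSP below).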
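Matching features directly requires the two feature maps to have the same shape. FitNets handles a mismatch with a learnable regressor on the student side; the following is a minimal sketch in that spirit, assuming a 1x1-convolution regressor plus average pooling to align spatial sizes (the pooling step is an assumption, and `HintWithRegressor` is a hypothetical name).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HintWithRegressor(nn.Module):
    """Feature matching with a learnable 1x1-conv regressor (FitNets-style sketch)."""

    def __init__(self, s_channels, t_channels):
        super().__init__()
        # Maps the student's channel count to the teacher's.
        self.regressor = nn.Conv2d(s_channels, t_channels, kernel_size=1, bias=False)

    def forward(self, fm_s, fm_t):
        fm_s = self.regressor(fm_s)
        # Assumption: pool the student map if its spatial size differs from the teacher's.
        if fm_s.shape[2:] != fm_t.shape[2:]:
            fm_s = F.adaptive_avg_pool2d(fm_s, fm_t.shape[2:])
        # The teacher is frozen, so its features act as constants.
        return F.mse_loss(fm_s, fm_t.detach())
```

The regressor's parameters must be given to the optimizer together with the student's, since the two are trained jointly.
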

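This article collects commonly used knowledge distillation papers together with their code (PyTorch), for convenient reference in later study and research.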

1、Logits

Paper: https://proceedings.neurips.cc/paper/2014/file/ea8fcd92d59581717e06eb187f10666d-Paper.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Logits(nn.Module):
    '''
    Do Deep Nets Really Need to be Deep?
    http://papers.nips.cc/paper/5484-do-deep-nets-really-need-to-be-deep.pdf
    '''
    def __init__(self):
        super(Logits, self).__init__()

    def forward(self, out_s, out_t):
        # Regress the student's logits directly onto the teacher's logits
        loss = F.mse_loss(out_s, out_t)

        return loss
```

2、ST

Paper: https://arxiv.org/pdf/1503.02531.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
    '''
    Distilling the Knowledge in a Neural Network
    https://arxiv.org/pdf/1503.02531.pdf
    '''
    def __init__(self, T):
        super(SoftTarget, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        # KL divergence between the temperature-softened distributions;
        # the T * T factor compensates for the 1 / T^2 scaling of the gradients
        loss = F.kl_div(F.log_softmax(out_s / self.T, dim=1),
                        F.softmax(out_t / self.T, dim=1),
                        reduction='batchmean') * self.T * self.T

        return loss
```

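The `* self.T * self.T` factor in `SoftTarget` follows the argument of Hinton et al. Writing $z_i$ for the student logits, $v_i$ for the teacher logits, and $N$ for the number of classes, the gradient of the soft-target loss with respect to a student logit is as follows (the approximation assumes a large $T$ and zero-mean logits):

$$
q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)},\qquad
p_i = \frac{\exp(v_i/T)}{\sum_j \exp(v_j/T)},\qquad
\frac{\partial\,\mathrm{KL}(p\,\|\,q)}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) \approx \frac{1}{N T^2}\left(z_i - v_i\right)
$$

Because the gradient shrinks roughly as $1/T^2$, multiplying the loss by $T^2$ keeps the contribution of the soft targets roughly unchanged when the temperature is tuned.
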
3、AT

Paper: https://arxiv.org/pdf/1612.03928.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# AT with the sum of absolute values raised to the power p
class AT(nn.Module):
    '''
    Paying More Attention to Attention: Improving the Performance of Convolutional
    Neural Networks via Attention Transfer
    https://arxiv.org/pdf/1612.03928.pdf
    '''
    def __init__(self, p):
        super(AT, self).__init__()
        self.p = p

    def forward(self, fm_s, fm_t):
        # Student and teacher maps must share the same spatial size (H, W)
        loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))

        return loss

    def attention_map(self, fm, eps=1e-6):
        # Collapse the channels: (B, C, H, W) -> (B, 1, H, W)
        am = torch.pow(torch.abs(fm), self.p)
        am = torch.sum(am, dim=1, keepdim=True)
        # L2-normalize each spatial map so student and teacher are on the same scale
        norm = torch.norm(am, dim=(2, 3), keepdim=True)
        am = torch.div(am, norm + eps)

        return am
```

4、Fitnet

Paper: https://arxiv.org/pdf/1412.6550.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Hint(nn.Module):
    '''
    FitNets: Hints for Thin Deep Nets
    https://arxiv.org/pdf/1412.6550.pdf
    '''
    def __init__(self):
        super(Hint, self).__init__()

    def forward(self, fm_s, fm_t):
        # fm_s and fm_t must have the same shape; FitNets places a regressor on the
        # student side to make them match when they do not
        loss = F.mse_loss(fm_s, fm_t)

        return loss
```

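Hint, AT, NST, and the other feature-based losses in this list take intermediate feature maps (`fm_s`, `fm_t`) as inputs. One way to collect them without modifying the model is PyTorch's `register_forward_hook`. The sketch below assumes the layers are addressed by their `named_modules()` names; `capture_features` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def capture_features(model, layer_names):
    """Store the outputs of the named sub-modules on every forward pass.

    Returns (feats, handles): feats maps layer name -> latest output;
    call handle.remove() on each handle when the hooks are no longer needed.
    """
    feats = {}
    handles = []
    modules = dict(model.named_modules())
    for name in layer_names:
        # The default argument binds the current name (avoids late-binding closures)
        def hook(module, inputs, output, name=name):
            feats[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    return feats, handles
```

For example, `feats_t, _ = capture_features(teacher, ["layer3"])` (the layer name is hypothetical) followed by a teacher forward pass fills `feats_t["layer3"]`, which can be passed to `AT` or `Hint` as `fm_t`. For losses that only compare features, the teacher can run under `torch.no_grad()`; losses such as LwM, which differentiate the teacher's output with respect to its feature maps, need the teacher's graph and must not use `no_grad`.
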
5、NST

Paper: https://arxiv.org/pdf/1707.01219.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# NST with a polynomial kernel, where d=2 and c=0
class NST(nn.Module):
    '''
    Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
    https://arxiv.org/pdf/1707.01219.pdf
    '''
    def __init__(self):
        super(NST, self).__init__()

    def forward(self, fm_s, fm_t):
        # (B, C, H, W) -> (B, C, H*W); each channel's spatial response is L2-normalized
        fm_s = fm_s.view(fm_s.size(0), fm_s.size(1), -1)
        fm_s = F.normalize(fm_s, dim=2)

        fm_t = fm_t.view(fm_t.size(0), fm_t.size(1), -1)
        fm_t = F.normalize(fm_t, dim=2)

        # Squared MMD between the student's and the teacher's channel distributions
        loss = self.poly_kernel(fm_t, fm_t).mean() \
            + self.poly_kernel(fm_s, fm_s).mean() \
            - 2 * self.poly_kernel(fm_s, fm_t).mean()

        return loss

    def poly_kernel(self, fm1, fm2):
        # Kernel between every channel pair, k(x, y) = (x . y)^2, shape (B, C2, C1)
        fm1 = fm1.unsqueeze(1)
        fm2 = fm2.unsqueeze(2)
        out = (fm1 * fm2).sum(-1).pow(2)

        return out
```

6、PKT

Paper: http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Adapted from https://github.com/passalis/probabilistic_kt/blob/master/nn/pkt.py
class PKTCosSim(nn.Module):
    '''
    Learning Deep Representations with Probabilistic Knowledge Transfer
    http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf
    '''
    def __init__(self):
        super(PKTCosSim, self).__init__()

    def forward(self, feat_s, feat_t, eps=1e-6):
        # Normalize each vector by its norm
        feat_s_norm = torch.sqrt(torch.sum(feat_s ** 2, dim=1, keepdim=True))
        feat_s = feat_s / (feat_s_norm + eps)
        feat_s[feat_s != feat_s] = 0  # replace NaNs with 0

        feat_t_norm = torch.sqrt(torch.sum(feat_t ** 2, dim=1, keepdim=True))
        feat_t = feat_t / (feat_t_norm + eps)
        feat_t[feat_t != feat_t] = 0

        # Calculate the cosine similarity
        feat_s_cos_sim = torch.mm(feat_s, feat_s.transpose(0, 1))
        feat_t_cos_sim = torch.mm(feat_t, feat_t.transpose(0, 1))

        # Scale cosine similarity to [0, 1]
        feat_s_cos_sim = (feat_s_cos_sim + 1.0) / 2.0
        feat_t_cos_sim = (feat_t_cos_sim + 1.0) / 2.0

        # Transform them into conditional probabilities (each row sums to 1)
        feat_s_cond_prob = feat_s_cos_sim / torch.sum(feat_s_cos_sim, dim=1, keepdim=True)
        feat_t_cond_prob = feat_t_cos_sim / torch.sum(feat_t_cos_sim, dim=1, keepdim=True)

        # Calculate the KL-divergence
        loss = torch.mean(feat_t_cond_prob * torch.log((feat_t_cond_prob + eps) / (feat_s_cond_prob + eps)))

        return loss
```

7、FSP

Paper: http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
    '''
    A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
    http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
    '''
    def __init__(self):
        super(FSP, self).__init__()

    def forward(self, fm_s1, fm_s2, fm_t1, fm_t2):
        loss = F.mse_loss(self.fsp_matrix(fm_s1, fm_s2), self.fsp_matrix(fm_t1, fm_t2))

        return loss

    def fsp_matrix(self, fm1, fm2):
        # Pool the earlier (larger) map down to the later map's spatial size
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))

        fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
        fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1, 2)

        # (B, C1, H*W) x (B, H*W, C2) -> (B, C1, C2), averaged over spatial positions
        fsp = torch.bmm(fm1, fm2) / fm1.size(2)

        return fsp
```

8、FT

Paper: http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FT(nn.Module):
    '''
    Paraphrasing Complex Network: Network Compression via Factor Transfer
    http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf
    '''
    def __init__(self):
        super(FT, self).__init__()

    def forward(self, factor_s, factor_t):
        # In the paper, factor_t comes from the teacher's paraphraser and factor_s from
        # the student's translator (neither module is included here)
        loss = F.l1_loss(self.normalize(factor_s), self.normalize(factor_t))

        return loss

    def normalize(self, factor):
        # Flatten each sample's factor and L2-normalize it
        norm_factor = F.normalize(factor.view(factor.size(0), -1))

        return norm_factor
```

9、RKD

Paper: https://arxiv.org/pdf/1904.05068.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# From https://github.com/lenscloth/RKD/blob/master/metric/loss.py
class RKD(nn.Module):
    '''
    Relational Knowledge Distillation
    https://arxiv.org/pdf/1904.05068.pdf
    '''
    def __init__(self, w_dist, w_angle):
        super(RKD, self).__init__()

        self.w_dist = w_dist
        self.w_angle = w_angle

    def forward(self, feat_s, feat_t):
        loss = self.w_dist * self.rkd_dist(feat_s, feat_t) + \
            self.w_angle * self.rkd_angle(feat_s, feat_t)

        return loss

    def rkd_dist(self, feat_s, feat_t):
        # Distance-wise loss: pairwise distances normalized by their mean
        feat_t_dist = self.pdist(feat_t, squared=False)
        mean_feat_t_dist = feat_t_dist[feat_t_dist > 0].mean()
        feat_t_dist = feat_t_dist / mean_feat_t_dist

        feat_s_dist = self.pdist(feat_s, squared=False)
        mean_feat_s_dist = feat_s_dist[feat_s_dist > 0].mean()
        feat_s_dist = feat_s_dist / mean_feat_s_dist

        loss = F.smooth_l1_loss(feat_s_dist, feat_t_dist)

        return loss

    def rkd_angle(self, feat_s, feat_t):
        # Angle-wise loss: cosines of the angles formed by every triplet of samples
        # N x C --> N x N x C
        feat_t_vd = (feat_t.unsqueeze(0) - feat_t.unsqueeze(1))
        norm_feat_t_vd = F.normalize(feat_t_vd, p=2, dim=2)
        feat_t_angle = torch.bmm(norm_feat_t_vd, norm_feat_t_vd.transpose(1, 2)).view(-1)

        feat_s_vd = (feat_s.unsqueeze(0) - feat_s.unsqueeze(1))
        norm_feat_s_vd = F.normalize(feat_s_vd, p=2, dim=2)
        feat_s_angle = torch.bmm(norm_feat_s_vd, norm_feat_s_vd.transpose(1, 2)).view(-1)

        loss = F.smooth_l1_loss(feat_s_angle, feat_t_angle)

        return loss

    def pdist(self, feat, squared=False, eps=1e-12):
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod = torch.mm(feat, feat.t())
        feat_dist = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

        if not squared:
            feat_dist = feat_dist.sqrt()

        feat_dist = feat_dist.clone()
        feat_dist[range(len(feat)), range(len(feat))] = 0

        return feat_dist
```

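Both `RKD.pdist` and the IRG edge functions compute all pairwise distances with the expansion ‖a − b‖² = ‖a‖² + ‖b‖² − 2a·b, which needs one matrix product instead of an explicit double loop. A small stand-alone check of that trick against PyTorch's built-in `torch.cdist` (`pairwise_dist` is a hypothetical name):

```python
import torch


def pairwise_dist(feat, squared=False, eps=1e-12):
    """All pairwise Euclidean distances between the rows of feat: (N, C) -> (N, N)."""
    sq = feat.pow(2).sum(dim=1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped against negative rounding error
    dist = (sq.unsqueeze(0) + sq.unsqueeze(1) - 2 * (feat @ feat.t())).clamp(min=eps)
    if not squared:
        dist = dist.sqrt()
    dist = dist.clone()
    idx = torch.arange(feat.size(0))
    dist[idx, idx] = 0  # exact zeros on the diagonal
    return dist
```

The `clamp(min=eps)` guards against small negative values caused by floating-point cancellation before the square root, and the diagonal is set to exactly zero so that RKD's `feat_t_dist > 0` mask excludes self-distances when computing the mean.
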
10、AB

Paper: https://arxiv.org/pdf/1811.03233.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AB(nn.Module):
    '''
    Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
    https://arxiv.org/pdf/1811.03233.pdf
    '''
    def __init__(self, margin):
        super(AB, self).__init__()

        self.margin = margin

    def forward(self, fm_s, fm_t):
        # fm_s and fm_t are taken before activation (pre-ReLU responses).
        # Penalize the student wherever its activation state (fm > 0) disagrees
        # with the teacher's, with a margin
        loss = ((fm_s + self.margin).pow(2) * ((fm_s > -self.margin) & (fm_t <= 0)).float() +
                (fm_s - self.margin).pow(2) * ((fm_s <= self.margin) & (fm_t > 0)).float())
        loss = loss.mean()

        return loss
```

11、SP

Paper: https://arxiv.org/pdf/1907.09682.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SP(nn.Module):
    '''
    Similarity-Preserving Knowledge Distillation
    https://arxiv.org/pdf/1907.09682.pdf
    '''
    def __init__(self):
        super(SP, self).__init__()

    def forward(self, fm_s, fm_t):
        # Batch similarity (Gram) matrices of shape (B, B), row-normalized
        fm_s = fm_s.view(fm_s.size(0), -1)
        G_s = torch.mm(fm_s, fm_s.t())
        norm_G_s = F.normalize(G_s, p=2, dim=1)

        fm_t = fm_t.view(fm_t.size(0), -1)
        G_t = torch.mm(fm_t, fm_t.t())
        norm_G_t = F.normalize(G_t, p=2, dim=1)

        loss = F.mse_loss(norm_G_s, norm_G_t)

        return loss
```

12、Sobolev

Paper: https://arxiv.org/pdf/1706.04859.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad


class Sobolev(nn.Module):
    '''
    Sobolev Training for Neural Networks
    https://arxiv.org/pdf/1706.04859.pdf

    Knowledge Transfer with Jacobian Matching
    http://de.arxiv.org/pdf/1803.00443
    '''
    def __init__(self):
        super(Sobolev, self).__init__()

    def forward(self, out_s, out_t, img, target):
        # img must have requires_grad=True before the forward passes producing out_s / out_t.
        # Gradient of the target-class logit w.r.t. the input; the student's is kept in the graph
        target_out_s = torch.gather(out_s, 1, target.view(-1, 1))
        grad_s = grad(outputs=target_out_s, inputs=img,
                      grad_outputs=torch.ones_like(target_out_s),
                      create_graph=True, retain_graph=True)[0]
        norm_grad_s = F.normalize(grad_s.view(grad_s.size(0), -1), p=2, dim=1)

        target_out_t = torch.gather(out_t, 1, target.view(-1, 1))
        grad_t = grad(outputs=target_out_t, inputs=img,
                      grad_outputs=torch.ones_like(target_out_t),
                      create_graph=True, retain_graph=True)[0]
        norm_grad_t = F.normalize(grad_t.view(grad_t.size(0), -1), p=2, dim=1)

        loss = F.mse_loss(norm_grad_s, norm_grad_t.detach())

        return loss
```

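`Sobolev.forward` differentiates the target-class logit with respect to the input, so the input batch must require gradients before either network is run, and the student's gradient must be built with `create_graph=True` so the loss can backpropagate through it. The following self-contained sketch shows that call pattern on two toy linear classifiers; the models, shapes, and the helper `input_grad_of_target_logit` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad


def input_grad_of_target_logit(model, img, target, create_graph):
    """L2-normalized d(target-class logit)/d(input), flattened per sample."""
    out = model(img)
    target_out = out.gather(1, target.view(-1, 1))
    g = grad(outputs=target_out, inputs=img,
             grad_outputs=torch.ones_like(target_out),
             create_graph=create_graph)[0]
    return F.normalize(g.view(g.size(0), -1), p=2, dim=1)


torch.manual_seed(0)
student = nn.Sequential(nn.Flatten(), nn.Linear(12, 5))
teacher = nn.Sequential(nn.Flatten(), nn.Linear(12, 5))
img = torch.randn(4, 3, 2, 2).requires_grad_(True)  # inputs must require grad
target = torch.randint(0, 5, (4,))

# Student: keep the graph of the gradient so the loss can update the student.
g_s = input_grad_of_target_logit(student, img, target, create_graph=True)
# Teacher: its input gradient only serves as a fixed target.
g_t = input_grad_of_target_logit(teacher, img, target, create_graph=False)
loss = F.mse_loss(g_s, g_t.detach())
loss.backward()
```

Because the update backpropagates through a gradient (double backprop), this loss is more expensive per step than losses that only compare outputs or features.
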
13、BSS

Paper: https://arxiv.org/pdf/1805.05532.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Modified from https://github.com/bhheo/BSS_distillation


def reduce_sum(x, keepdim=True):
    # Sum over all dimensions except the batch dimension
    for d in reversed(range(1, x.dim())):
        x = x.sum(d, keepdim=keepdim)
    return x


def l2_norm(x, keepdim=True):
    norm = reduce_sum(x * x, keepdim=keepdim)
    return norm.sqrt()


class BSS(nn.Module):
    '''
    Knowledge Distillation with Adversarial Samples Supporting Decision Boundary
    https://arxiv.org/pdf/1805.05532.pdf
    '''
    def __init__(self, T):
        super(BSS, self).__init__()
        self.T = T

    def forward(self, attacked_out_s, attacked_out_t):
        loss = F.kl_div(F.log_softmax(attacked_out_s / self.T, dim=1),
                        F.softmax(attacked_out_t / self.T, dim=1),
                        reduction='batchmean')  # * self.T * self.T

        return loss


class BSSAttacker():
    def __init__(self, step_alpha, num_steps, eps=1e-4):
        self.step_alpha = step_alpha
        self.num_steps = num_steps
        self.eps = eps

    def attack(self, model, img, target, attack_class):
        img = img.detach().requires_grad_(True)

        step = 0
        while step < self.num_steps:
            # torch.autograd.gradcheck.zero_gradients was removed from recent PyTorch
            # releases, so the input gradient is reset directly
            if img.grad is not None:
                img.grad.zero_()
            # Assumes the model returns six outputs with the logits last
            _, _, _, _, _, output = model(img)

            score = F.softmax(output, dim=1)
            score_target = score.gather(1, target.unsqueeze(1))
            score_attack_class = score.gather(1, attack_class.unsqueeze(1))

            loss = (score_attack_class - score_target).sum()
            loss.backward()

            # Only samples still classified as `target` keep moving towards the boundary
            step_alpha = self.step_alpha * (target == output.max(1)[1]).float()
            step_alpha = step_alpha.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            if step_alpha.sum() == 0:
                break

            pert = (score_target - score_attack_class).unsqueeze(1).unsqueeze(1)
            norm_pert = step_alpha * (pert + self.eps) * img.grad / l2_norm(img.grad)

            step_adv = img + norm_pert
            step_adv = torch.clamp(step_adv, -2.5, 2.5)
            img.data = step_adv.data

            step += 1

        return img
```

14、CC

Paper: http://openaccess.thecvf.com/content_ICCV_2019/papers/Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf

Code:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# CC with a P-order Taylor expansion of the Gaussian RBF kernel
class CC(nn.Module):
    '''
    Correlation Congruence for Knowledge Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf
    '''
    def __init__(self, gamma, P_order):
        super(CC, self).__init__()
        self.gamma = gamma
        self.P_order = P_order

    def forward(self, feat_s, feat_t):
        corr_mat_s = self.get_correlation_matrix(feat_s)
        corr_mat_t = self.get_correlation_matrix(feat_t)

        loss = F.mse_loss(corr_mat_s, corr_mat_t)

        return loss

    def get_correlation_matrix(self, feat):
        # Pairwise cosine similarities between the instances in the batch
        feat = F.normalize(feat, p=2, dim=-1)
        sim_mat = torch.matmul(feat, feat.t())
        corr_mat = torch.zeros_like(sim_mat)

        # Truncated Taylor series of the Gaussian RBF kernel
        for p in range(self.P_order + 1):
            corr_mat += math.exp(-2 * self.gamma) * (2 * self.gamma) ** p / \
                math.factorial(p) * torch.pow(sim_mat, p)

        return corr_mat
```

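The loop in `get_correlation_matrix` is a truncated Taylor series of a Gaussian RBF kernel. Because the features are L2-normalized first, $\lVert \mathbf{f}_i - \mathbf{f}_j \rVert_2^2 = 2 - 2 s_{ij}$ with $s_{ij} = \mathbf{f}_i^\top \mathbf{f}_j$, so each summand below corresponds to one iteration `p` of the loop:

$$
\exp\!\left(-\gamma \lVert \mathbf{f}_i - \mathbf{f}_j \rVert_2^2\right)
= e^{-2\gamma}\, e^{2\gamma s_{ij}}
\approx \sum_{p=0}^{P} e^{-2\gamma}\, \frac{(2\gamma)^p}{p!}\, s_{ij}^{\,p}
$$

Here $P$ is `P_order` and $\gamma$ is `gamma`; a larger `P_order` approximates the exact kernel more closely at the cost of more element-wise powers.
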
15、LwM

Paper: https://arxiv.org/pdf/1811.08051.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad


# LwM is originally an incremental learning method with
# classification/distillation/attention distillation losses.
# Here, LwM is only defined as the Grad-CAM based attention distillation.
class LwM(nn.Module):
    '''
    Learning without Memorizing
    https://arxiv.org/pdf/1811.08051.pdf
    '''
    def __init__(self):
        super(LwM, self).__init__()

    def forward(self, out_s, fm_s, out_t, fm_t, target):
        # Grad-CAM: channel weights are the spatially averaged gradients of the
        # target logit, applied to the feature-map activations (not to the gradients)
        target_out_t = torch.gather(out_t, 1, target.view(-1, 1))
        grad_fm_t = grad(outputs=target_out_t, inputs=fm_t,
                         grad_outputs=torch.ones_like(target_out_t),
                         create_graph=True, retain_graph=True)[0]
        weights_t = F.adaptive_avg_pool2d(grad_fm_t, 1)
        cam_t = torch.sum(torch.mul(weights_t, fm_t), dim=1, keepdim=True)
        cam_t = F.relu(cam_t)
        cam_t = cam_t.view(cam_t.size(0), -1)
        norm_cam_t = F.normalize(cam_t, p=2, dim=1)

        target_out_s = torch.gather(out_s, 1, target.view(-1, 1))
        grad_fm_s = grad(outputs=target_out_s, inputs=fm_s,
                         grad_outputs=torch.ones_like(target_out_s),
                         create_graph=True, retain_graph=True)[0]
        weights_s = F.adaptive_avg_pool2d(grad_fm_s, 1)
        cam_s = torch.sum(torch.mul(weights_s, fm_s), dim=1, keepdim=True)
        cam_s = F.relu(cam_s)
        cam_s = cam_s.view(cam_s.size(0), -1)
        norm_cam_s = F.normalize(cam_s, p=2, dim=1)

        loss = F.l1_loss(norm_cam_s, norm_cam_t.detach())

        return loss
```

16、IRG

Paper: http://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class IRG(nn.Module):
    '''
    Knowledge Distillation via Instance Relationship Graph
    http://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf

    The official code is written in Caffe
    https://github.com/yufanLIU/IRG
    '''
    def __init__(self, w_irg_vert, w_irg_edge, w_irg_tran):
        super(IRG, self).__init__()

        self.w_irg_vert = w_irg_vert
        self.w_irg_edge = w_irg_edge
        self.w_irg_tran = w_irg_tran

    def forward(self, irg_s, irg_t):
        fm_s1, fm_s2, feat_s, out_s = irg_s
        fm_t1, fm_t2, feat_t, out_t = irg_t

        # Vertex loss: match the instance outputs directly
        loss_irg_vert = F.mse_loss(out_s, out_t)

        # Edge loss: match the instance-to-instance distance matrices
        irg_edge_feat_s = self.euclidean_dist_feat(feat_s, squared=True)
        irg_edge_feat_t = self.euclidean_dist_feat(feat_t, squared=True)
        irg_edge_fm_s1 = self.euclidean_dist_fm(fm_s1, squared=True)
        irg_edge_fm_t1 = self.euclidean_dist_fm(fm_t1, squared=True)
        irg_edge_fm_s2 = self.euclidean_dist_fm(fm_s2, squared=True)
        irg_edge_fm_t2 = self.euclidean_dist_fm(fm_t2, squared=True)
        loss_irg_edge = (F.mse_loss(irg_edge_feat_s, irg_edge_feat_t) +
                         F.mse_loss(irg_edge_fm_s1, irg_edge_fm_t1) +
                         F.mse_loss(irg_edge_fm_s2, irg_edge_fm_t2)) / 3.0

        # Transformation loss: match how each instance changes between the two layers
        irg_tran_s = self.euclidean_dist_fms(fm_s1, fm_s2, squared=True)
        irg_tran_t = self.euclidean_dist_fms(fm_t1, fm_t2, squared=True)
        loss_irg_tran = F.mse_loss(irg_tran_s, irg_tran_t)

        loss = (self.w_irg_vert * loss_irg_vert +
                self.w_irg_edge * loss_irg_edge +
                self.w_irg_tran * loss_irg_tran)

        return loss

    def euclidean_dist_fms(self, fm1, fm2, squared=False, eps=1e-12):
        '''
        Calculating the IRG Transformation, where fm1 precedes fm2 in the network.
        '''
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
        if fm1.size(1) < fm2.size(1):
            # Assumes fm2 has twice as many channels: average adjacent channel pairs
            fm2 = (fm2[:, 0::2, :, :] + fm2[:, 1::2, :, :]) / 2.0

        fm1 = fm1.view(fm1.size(0), -1)
        fm2 = fm2.view(fm2.size(0), -1)
        fms_dist = torch.sum(torch.pow(fm1 - fm2, 2), dim=-1).clamp(min=eps)

        if not squared:
            fms_dist = fms_dist.sqrt()

        fms_dist = fms_dist / fms_dist.max()

        return fms_dist

    def euclidean_dist_fm(self, fm, squared=False, eps=1e-12):
        '''
        Calculating the IRG edge of feature map.
        '''
        fm = fm.view(fm.size(0), -1)
        fm_square = fm.pow(2).sum(dim=1)
        fm_prod = torch.mm(fm, fm.t())
        fm_dist = (fm_square.unsqueeze(0) + fm_square.unsqueeze(1) - 2 * fm_prod).clamp(min=eps)

        if not squared:
            fm_dist = fm_dist.sqrt()

        fm_dist = fm_dist.clone()
        fm_dist[range(len(fm)), range(len(fm))] = 0
        fm_dist = fm_dist / fm_dist.max()

        return fm_dist

    def euclidean_dist_feat(self, feat, squared=False, eps=1e-12):
        '''
        Calculating the IRG edge of feat.
        '''
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod = torch.mm(feat, feat.t())
        feat_dist = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

        if not squared:
            feat_dist = feat_dist.sqrt()

        feat_dist = feat_dist.clone()
        feat_dist[range(len(feat)), range(len(feat))] = 0
        feat_dist = feat_dist / feat_dist.max()

        return feat_dist
```

17、VID

Paper: https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf

Code:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv1x1(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=1,
                     padding=0, bias=False)


# Modified from https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/VID.py
class VID(nn.Module):
    '''
    Variational Information Distillation for Knowledge Transfer
    https://zpascal.net/cvpr2019/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf
    '''
    def __init__(self, in_channels, mid_channels, out_channels, init_var, eps=1e-6):
        super(VID, self).__init__()
        self.eps = eps
        self.regressor = nn.Sequential(*[
            conv1x1(in_channels, mid_channels),
            nn.ReLU(),
            conv1x1(mid_channels, mid_channels),
            nn.ReLU(),
            conv1x1(mid_channels, out_channels),
        ])
        # Inverse softplus, so that softplus(alpha) + eps == init_var at initialization
        self.alpha = nn.Parameter(
            math.log(math.expm1(init_var - eps)) * torch.ones(out_channels)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        # Gaussian with a regressed mean and a learned per-channel variance
        pred_mean = self.regressor(fm_s)
        pred_var = torch.log(1.0 + torch.exp(self.alpha)) + self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        # Negative log-likelihood of the teacher feature, up to a constant
        neg_log_prob = 0.5 * (torch.log(pred_var) + (pred_mean - fm_t) ** 2 / pred_var)
        loss = torch.mean(neg_log_prob)

        return loss
```

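In `VID`, the per-channel variance is `softplus(alpha) + eps`, which is always positive. The initial value `log(exp(init_var - eps) - 1)` is the inverse of softplus, so at initialization the predicted variance equals `init_var`. A quick numerical check (the values of `init_var` and `eps` are arbitrary, and `inverse_softplus` is a hypothetical helper):

```python
import math
import torch
import torch.nn.functional as F


def inverse_softplus(y):
    """Return x with softplus(x) = log(1 + exp(x)) = y, for y > 0."""
    return math.log(math.expm1(y))


init_var, eps = 5.0, 1e-6  # arbitrary example values
alpha = inverse_softplus(init_var - eps) * torch.ones(4)
# Same expression as in VID.forward, written with F.softplus
pred_var = F.softplus(alpha) + eps
```

Learning `alpha` instead of the variance itself lets the optimizer work in an unconstrained space while the variance stays strictly positive.
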
18、OFD

Paper: http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Modified from https://github.com/clovaai/overhaul-distillation/blob/master/CIFAR-100/distiller.py
class OFD(nn.Module):
    '''
    A Comprehensive Overhaul of Feature Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
    '''
    def __init__(self, in_channels, out_channels):
        super(OFD, self).__init__()
        self.connector = nn.Sequential(*[
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        ])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        # Teacher responses are taken before ReLU and clipped from below by a per-channel margin
        margin = self.get_margin(fm_t)
        fm_t = torch.max(fm_t, margin)
        fm_s = self.connector(fm_s)

        # Partial L2: skip positions where the teacher is non-positive and the student is already below it
        mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
        loss = torch.mean((fm_s - fm_t) ** 2 * mask)

        return loss

    def get_margin(self, fm, eps=1e-6):
        # Per-channel mean of the negative responses
        mask = (fm < 0.0).float()
        masked_fm = fm * mask

        margin = masked_fm.sum(dim=(0, 2, 3), keepdim=True) / (mask.sum(dim=(0, 2, 3), keepdim=True) + eps)

        return margin
```

19、AFD

Paper: https://openreview.net/pdf?id=ryxyCeHtPB

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# In the original paper, AFD is one of the components of AFDS.
# AFDS: Attention Feature Distillation and Selection
# AFD:  Attention Feature Distillation
# AFS:  Attention Feature Selection
#
# We find the original implementation of attention is unstable, thus we replace it with an SE block.
class AFD(nn.Module):
    '''
    Pay Attention to Features, Transfer Learn Faster CNNs
    https://openreview.net/pdf?id=ryxyCeHtPB
    '''
    def __init__(self, in_channels, att_f):
        super(AFD, self).__init__()
        mid_channels = int(in_channels * att_f)

        self.attention = nn.Sequential(*[
            nn.Conv2d(in_channels, mid_channels, 1, 1, 0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, in_channels, 1, 1, 0, bias=True)
        ])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t, eps=1e-6):
        # Channel attention weights from the pooled teacher feature map (SE-style)
        fm_t_pooled = F.adaptive_avg_pool2d(fm_t, 1)
        rho = self.attention(fm_t_pooled)
        # Flatten to (B, C); unlike .squeeze(), this also works for a batch size of 1
        rho = torch.sigmoid(rho.view(rho.size(0), -1))
        rho = rho / torch.sum(rho, dim=1, keepdim=True)

        # L2-normalize every channel's spatial map
        fm_s_norm = torch.norm(fm_s, dim=(2, 3), keepdim=True)
        fm_s = torch.div(fm_s, fm_s_norm + eps)
        fm_t_norm = torch.norm(fm_t, dim=(2, 3), keepdim=True)
        fm_t = torch.div(fm_t, fm_t_norm + eps)

        # Attention-weighted per-channel squared error
        loss = rho * torch.pow(fm_s - fm_t, 2).mean(dim=(2, 3))
        loss = loss.sum(1).mean(0)

        return loss
```

20、CRD

Paper: https://openreview.net/pdf?id=SkgpBJrtvS

Code:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# Modified from https://github.com/HobbitLong/RepDistiller/tree/master/crd
class CRD(nn.Module):
    '''
    Contrastive Representation Distillation
    https://openreview.net/pdf?id=SkgpBJrtvS

    includes two symmetric parts:
    (a) using teacher as anchor, choose positive and negatives over the student side
    (b) using student as anchor, choose positive and negatives over the teacher side

    Args:
        s_dim: the dimension of student's feature
        t_dim: the dimension of teacher's feature
        feat_dim: the dimension of the projection space
        nce_n: number of negatives paired with each positive
        nce_t: the temperature
        nce_mom: the momentum for updating the memory buffer
        n_data: the number of samples in the training set, which is the M in Eq.(19)
    '''
    def __init__(self, s_dim, t_dim, feat_dim, nce_n, nce_t, nce_mom, n_data):
        super(CRD, self).__init__()
        self.embed_s = Embed(s_dim, feat_dim)
        self.embed_t = Embed(t_dim, feat_dim)
        self.contrast = ContrastMemory(feat_dim, n_data, nce_n, nce_t, nce_mom)
        self.criterion_s = ContrastLoss(n_data)
        self.criterion_t = ContrastLoss(n_data)

    def forward(self, feat_s, feat_t, idx, sample_idx):
        # idx: (bs,) dataset indices of the batch; sample_idx: (bs, nce_n + 1), positive first
        feat_s = self.embed_s(feat_s)
        feat_t = self.embed_t(feat_t)
        out_s, out_t = self.contrast(feat_s, feat_t, idx, sample_idx)
        loss_s = self.criterion_s(out_s)
        loss_t = self.criterion_t(out_t)
        loss = loss_s + loss_t

        return loss


class Embed(nn.Module):
    # Linear projection followed by L2 normalization
    def __init__(self, in_dim, out_dim):
        super(Embed, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        x = F.normalize(x, p=2, dim=1)

        return x


class ContrastLoss(nn.Module):
    '''
    contrastive loss, corresponding to Eq.(18)
    '''
    def __init__(self, n_data, eps=1e-7):
        super(ContrastLoss, self).__init__()
        self.n_data = n_data
        self.eps = eps

    def forward(self, x):
        bs = x.size(0)
        N = x.size(1) - 1
        M = float(self.n_data)

        # loss for positive pair
        pos_pair = x.select(1, 0)
        log_pos = torch.div(pos_pair, pos_pair.add(N / M + self.eps)).log_()

        # loss for negative pair
        neg_pair = x.narrow(1, 1, N)
        log_neg = torch.div(neg_pair.clone().fill_(N / M), neg_pair.add(N / M + self.eps)).log_()

        loss = -(log_pos.sum() + log_neg.sum()) / bs

        return loss


class ContrastMemory(nn.Module):
    def __init__(self, feat_dim, n_data, nce_n, nce_t, nce_mom):
        super(ContrastMemory, self).__init__()
        self.N = nce_n
        self.T = nce_t
        self.momentum = nce_mom
        self.Z_t = None
        self.Z_s = None

        stdv = 1. / math.sqrt(feat_dim / 3.)
        self.register_buffer('memory_t', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))
        self.register_buffer('memory_s', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))

    def forward(self, feat_s, feat_t, idx, sample_idx):
        bs = feat_s.size(0)
        feat_dim = self.memory_s.size(1)
        n_data = self.memory_s.size(0)

        # using teacher as anchor
        weight_s = torch.index_select(self.memory_s, 0, sample_idx.view(-1)).detach()
        weight_s = weight_s.view(bs, self.N + 1, feat_dim)
        out_t = torch.bmm(weight_s, feat_t.view(bs, feat_dim, 1))
        # squeeze(-1) rather than squeeze() keeps the batch dimension when bs == 1
        out_t = torch.exp(torch.div(out_t, self.T)).squeeze(-1).contiguous()

        # using student as anchor
        weight_t = torch.index_select(self.memory_t, 0, sample_idx.view(-1)).detach()
        weight_t = weight_t.view(bs, self.N + 1, feat_dim)
        out_s = torch.bmm(weight_t, feat_s.view(bs, feat_dim, 1))
        out_s = torch.exp(torch.div(out_s, self.T)).squeeze(-1).contiguous()

        # set the normalization constants Z on the first call
        if self.Z_t is None:
            self.Z_t = (out_t.mean() * n_data).detach().item()
        if self.Z_s is None:
            self.Z_s = (out_s.mean() * n_data).detach().item()

        out_t = torch.div(out_t, self.Z_t)
        out_s = torch.div(out_s, self.Z_s)

        # update memory with a momentum average of the positives' embeddings
        with torch.no_grad():
            pos_mem_t = torch.index_select(self.memory_t, 0, idx.view(-1))
            pos_mem_t.mul_(self.momentum)
            pos_mem_t.add_(torch.mul(feat_t, 1 - self.momentum))
            pos_mem_t = F.normalize(pos_mem_t, p=2, dim=1)
            self.memory_t.index_copy_(0, idx.view(-1), pos_mem_t)

            pos_mem_s = torch.index_select(self.memory_s, 0, idx.view(-1))
            pos_mem_s.mul_(self.momentum)
            pos_mem_s.add_(torch.mul(feat_s, 1 - self.momentum))
            pos_mem_s = F.normalize(pos_mem_s, p=2, dim=1)
            self.memory_s.index_copy_(0, idx.view(-1), pos_mem_s)

        return out_s, out_t
```

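`CRD.forward` expects `idx`, each sample's index in the training set, and `sample_idx`, an index tensor of shape `(bs, nce_n + 1)` whose first column is the positive and whose remaining columns are negatives; both index into the memory buffers of `ContrastMemory`. The referenced RepDistiller code builds these in its dataset class with its own sampling scheme. The sketch below is a simplified assumption that draws negatives uniformly from all other indices; `make_contrast_idx` is a hypothetical helper.

```python
import torch


def make_contrast_idx(idx, n_data, nce_n, generator=None):
    """Build a CRD-style sample_idx of shape (bs, nce_n + 1).

    Column 0 holds the positive (the sample's own dataset index); the other
    nce_n columns are negatives drawn uniformly from the remaining indices.
    """
    bs = idx.size(0)
    # Draw from n_data - 1 candidates, then shift values >= idx by one to skip the positive.
    neg = torch.randint(0, n_data - 1, (bs, nce_n), generator=generator)
    neg = neg + (neg >= idx.view(-1, 1)).long()
    return torch.cat([idx.view(-1, 1), neg], dim=1)
```

In practice the dataset has to return each sample's index alongside the image and label, for example by wrapping it so that `__getitem__` also yields `index`.
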
21、DML

Paper: https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf

Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# DML with only two networks
class DML(nn.Module):
    '''
    Deep Mutual Learning
    https://zpascal.net/cvpr2018/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf
    '''
    def __init__(self):
        super(DML, self).__init__()

    def forward(self, out1, out2):
        # KL(p2 || p1): pulls network 1 towards network 2's prediction;
        # the peer's output is usually passed detached (out2.detach())
        loss = F.kl_div(F.log_softmax(out1, dim=1),
                        F.softmax(out2, dim=1),
                        reduction='batchmean')

        return loss
```

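The `DML` loss covers one direction only. A minimal sketch of a training step for two peers, assuming each peer adds this KL term (computed against the other's detached prediction) to its own cross-entropy. The paper updates the peers in turn within each iteration; updating both from the same forward pass, as here, is a simplification. `dml_step` is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dml_step(net1, net2, opt1, opt2, x, y):
    """One simplified Deep Mutual Learning step for two peers."""
    out1, out2 = net1(x), net2(x)
    # Each peer: CE on the labels + KL towards the other peer's (fixed) prediction.
    loss1 = F.cross_entropy(out1, y) + F.kl_div(
        F.log_softmax(out1, dim=1), F.softmax(out2.detach(), dim=1), reduction='batchmean')
    loss2 = F.cross_entropy(out2, y) + F.kl_div(
        F.log_softmax(out2, dim=1), F.softmax(out1.detach(), dim=1), reduction='batchmean')
    opt1.zero_grad()
    opt2.zero_grad()
    # Thanks to the detach() calls, loss1 only reaches net1 and loss2 only reaches net2.
    (loss1 + loss2).backward()
    opt1.step()
    opt2.step()
    return loss1.item(), loss2.item()
```
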