Re-parameterization Series | Lightweight Models + Re-parameterization: Ready for Takeoff?


1、Opening Notes

1.1、Key Point 1

[Figure: latency of the same architecture equipped with different activation functions]

With the same architecture, different activation functions can produce wildly different latencies. This is why MobileOne sticks with ReLU.
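If you want to sanity-check this on your own machine, here is a minimal timing sketch (my own example, not from the MobileOne code; the helper name benchmark and the tensor shape are arbitrary). Absolute numbers depend heavily on the device, but the ordering between ReLU and heavier activations is usually visible even on CPU.

import time
import torch
import torch.nn as nn

def benchmark(act: nn.Module, x: torch.Tensor, iters: int = 200) -> float:
    """Average forward latency of an activation, in milliseconds."""
    with torch.no_grad():
        for _ in range(10):          # warm-up
            act(x)
        start = time.perf_counter()
        for _ in range(iters):
            act(x)
    return (time.perf_counter() - start) / iters * 1000

x = torch.randn(1, 64, 56, 56)
for name, act in [('ReLU', nn.ReLU()), ('GELU', nn.GELU()), ('SiLU', nn.SiLU())]:
    print(f'{name}: {benchmark(act, x):.3f} ms')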

1.2、Key Point 2

[Figure: latency comparison of multi-branch vs. single-branch structures]

A single-branch structure runs faster at inference time. This was already established with RepVGG; readers can refer to the RepVGG notes.
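The effect is easy to reproduce with a rough RepVGG-style comparison. The sketch below (my own example; MultiBranch and latency are made-up helper names) times a 3x3 + 1x1 + identity multi-branch module against a single plain 3x3 convolution. Even though the FLOPs are similar, the extra branches add memory traffic and kernel launches, so the single-branch path is typically faster.

import time
import torch
import torch.nn as nn

class MultiBranch(nn.Module):
    """RepVGG-style training topology: 3x3 conv + 1x1 conv + identity, summed."""
    def __init__(self, ch: int):
        super().__init__()
        self.conv3 = nn.Conv2d(ch, ch, 3, padding=1)
        self.conv1 = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        return self.conv3(x) + self.conv1(x) + x

def latency(m: nn.Module, x: torch.Tensor, iters: int = 100) -> float:
    with torch.no_grad():
        for _ in range(10):          # warm-up
            m(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            m(x)
    return (time.perf_counter() - t0) / iters * 1000

x = torch.randn(1, 64, 56, 56)
print(f'multi-branch: {latency(MultiBranch(64), x):.3f} ms')
print(f'single 3x3  : {latency(nn.Conv2d(64, 64, 3, padding=1), x):.3f} ms')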

2、MobileOne in Brief

The core block of MobileOne is designed on top of MobileNetV1's depthwise-separable block and incorporates the re-parameterization idea, giving the structure shown in the figure below. Note that there is an extra hyper-parameter k controlling the number of re-parameterizable conv branches (experiments show this over-parameterized variant benefits small models the most); a minimal stage-assembly sketch follows the figure.

[Figure: the MobileOne block, training-time multi-branch structure vs. inference-time re-parameterized structure]
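As a concrete reading of the MobileNetV1 heritage, one stage pairs a depthwise k×k re-param block with a pointwise 1×1 re-param block, each carrying k extra conv branches. The sketch below assembles such a pair using the MobileOneBlock class given in Section 3; mobileone_stage is my own helper name, and the widths and strides are illustrative rather than the paper's exact configuration.

import torch
import torch.nn as nn

def mobileone_stage(in_ch: int, out_ch: int, stride: int = 1, k: int = 3) -> nn.Sequential:
    # Depthwise 3x3 block (groups = channels) followed by a pointwise 1x1 block,
    # both over-parameterized with k conv branches during training.
    return nn.Sequential(
        MobileOneBlock(in_ch, in_ch, kernel_size=3, stride=stride, padding=1,
                       groups=in_ch, num_conv_branches=k),
        MobileOneBlock(in_ch, out_ch, kernel_size=1, stride=1, padding=0,
                       groups=1, num_conv_branches=k),
    )

stage = mobileone_stage(32, 64, stride=2, k=4)
print(stage(torch.randn(1, 32, 56, 56)).shape)  # torch.Size([1, 64, 28, 28])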

As the structure diagram above shows, the block is essentially DBB combined with RepVGG: the number of conv branches can be widened as you like, and re-parameterization simply merges each branch's weights and biases into a single convolution.
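For a single branch, "merging weight and bias" is just folding the BatchNorm statistics into the preceding convolution (the 1×1 scale branch additionally gets zero-padded to the main kernel size before the sum). A minimal sketch of that folding step, with my own helper name fuse_conv_bn:

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    # w' = w * gamma / sqrt(var + eps);  b' = beta - mean * gamma / sqrt(var + eps)
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std

conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
for _ in range(5):                       # populate the running statistics
    bn(conv(torch.randn(4, 8, 16, 16)))
bn.eval()

fused = nn.Conv2d(8, 8, 3, padding=1, bias=True)
fused.weight.data, fused.bias.data = fuse_conv_bn(conv, bn)

x = torch.randn(1, 8, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # expect True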

3、Implementing MobileOne

Below is a PyTorch implementation of the MobileOne block:


        
          
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
  
  
class MobileOneBlock(nn.Module):  
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 inference_mode: bool = False,
                 use_se: bool = False,
                 num_conv_branches: int = 3) -> None:
        super(MobileOneBlock, self).__init__()  
        self.inference_mode = inference_mode  
        self.groups = groups  
        self.stride = stride  
        self.kernel_size = kernel_size  
        self.in_channels = in_channels  
        self.out_channels = out_channels  
        self.num_conv_branches = num_conv_branches  
  
        # SE is omitted in this simplified block (use_se is accepted only for API compatibility); the activation is plain ReLU
        self.se = nn.Identity()  
        self.activation = nn.ReLU()  
  
        if inference_mode:  
            self.reparam_conv = nn.Conv2d(in_channels=in_channels,  
                                          out_channels=out_channels,  
                                          kernel_size=kernel_size,  
                                          stride=stride,  
                                          padding=padding,  
                                          dilation=dilation,  
                                          groups=groups,  
                                          bias=True)  
        else:  
            # skip connection  
            self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None  
  
            # 3x3 conv branches  
            rbr_conv = list()  
            for _ in range(self.num_conv_branches):  
                rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding))  
            self.rbr_conv = nn.ModuleList(rbr_conv)  
  
            # 1x1 conv branch(scale branch)  
            self.rbr_scale = None  
            if kernel_size > 1:  
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)  
  
    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        """ Apply forward pass. """  
        if self.inference_mode:  
            return self.activation(self.se(self.reparam_conv(x)))  
  
        identity_out = 0  
        if self.rbr_skip is not None:  
            identity_out = self.rbr_skip(x)  
  
        # Scale branch output  
        scale_out = 0  
        if self.rbr_scale is not None:  
            scale_out = self.rbr_scale(x)  
  
        # Other branches  
        out = scale_out + identity_out  
        for ix in range(self.num_conv_branches):  
            out += self.rbr_conv[ix](x)  
  
        return self.activation(self.se(out))  
  
    def reparameterize(self):  
        if self.inference_mode:  
            return  
        kernel, bias = self._get_kernel_bias()  
        self.reparam_conv = nn.Conv2d(in_channels=self.rbr_conv[0].conv.in_channels,  
                                      out_channels=self.rbr_conv[0].conv.out_channels,  
                                      kernel_size=self.rbr_conv[0].conv.kernel_size,  
                                      stride=self.rbr_conv[0].conv.stride,  
                                      padding=self.rbr_conv[0].conv.padding,  
                                      dilation=self.rbr_conv[0].conv.dilation,  
                                      groups=self.rbr_conv[0].conv.groups,  
                                      bias=True)  
        self.reparam_conv.weight.data = kernel  
        self.reparam_conv.bias.data = bias  
  
        for para in self.parameters():  
            para.detach_()  
        self.__delattr__('rbr_conv')
        self.__delattr__('rbr_scale')
        if hasattr(self, 'rbr_skip'):
            self.__delattr__('rbr_skip')
  
        self.inference_mode = True  
  
    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # Kernel and bias of the 1x1 scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Zero-pad the 1x1 kernel to kernel_size x kernel_size so it can be summed with the conv branches
            pad = self.kernel_size // 2
            kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  
        # Kernel and bias of the skip (identity BatchNorm) branch
        kernel_identity = 0  
        bias_identity = 0  
        if self.rbr_skip is not None:  
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)  
  
        # Kernel and bias of the k conv branches
        kernel_conv = 0  
        bias_conv = 0  
        for ix in range(self.num_conv_branches):  
            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])  
            kernel_conv += _kernel  
            bias_conv += _bias  
  
        kernel_final = kernel_conv + kernel_scale + kernel_identity  
        bias_final = bias_conv + bias_scale + bias_identity  
        return kernel_final, bias_final  
  
    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        if isinstance(branch, nn.Sequential):  
            kernel = branch.conv.weight  
            running_mean = branch.bn.running_mean  
            running_var = branch.bn.running_var  
            gamma = branch.bn.weight  
            beta = branch.bn.bias  
            eps = branch.bn.eps  
        else:  
            assert isinstance(branch, nn.BatchNorm2d)  
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups  
                kernel_value = torch.zeros((self.in_channels,   
                                            input_dim,   
                                            self.kernel_size,   
                                            self.kernel_size),  
                                            dtype=branch.weight.dtype,   
                                            device=branch.weight.device)  
                for i in range(self.in_channels):  
                    kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1  
                self.id_tensor = kernel_value  
            kernel = self.id_tensor  
            running_mean = branch.running_mean  
            running_var = branch.running_var  
            gamma = branch.weight  
            beta = branch.bias  
            eps = branch.eps  
        std = (running_var + eps).sqrt()  
        t = (gamma / std).reshape(-1, 1, 1, 1)  
        return kernel * t, beta - running_mean * gamma / std  
  
    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
        mod_list = nn.Sequential()  
        mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels,  
                                              out_channels=self.out_channels,  
                                              kernel_size=kernel_size,  
                                              stride=self.stride,  
                                              padding=padding,  
                                              groups=self.groups,  
                                              bias=False))  
        mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))  
        return mod_list  
  
  
if __name__ == '__main__':
    model = MobileOneBlock(16, 16, 3, padding=1, num_conv_branches=1)  
    x = torch.ones(1, 16, 9, 9)  
    y = model(x)  
    torch.onnx.export(model,  
                      (x,),  
                      'mobileone_raw.onnx',
                      opset_version=12,  
                      input_names=['input'],  
                      output_names=['output'])  
    model.reparameterize()  
    torch.onnx.export(model,  
                      (x,),  
                      'mobileone_rep.onnx',
                      opset_version=12,  
                      input_names=['input'],  
                      output_names=['output'])  

      

Without further ado, just compare the exported ONNX graphs before and after re-parameterization; the difference speaks for itself!

[Figure: ONNX graphs of the block before and after re-parameterization]
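Besides eyeballing the ONNX graphs, it is worth verifying numerical equivalence directly. A short check, assuming the imports and the MobileOneBlock class above are in scope; the block must be in eval() mode so that BatchNorm uses the same running statistics that reparameterize() folds in:

# Quick numerical check that fusion preserves the output.
block = MobileOneBlock(16, 16, 3, padding=1, num_conv_branches=3)
block.eval()                      # BN must use running stats for the fused conv to match

x = torch.randn(1, 16, 32, 32)
with torch.no_grad():
    y_multi = block(x)            # multi-branch training structure
    block.reparameterize()
    y_fused = block(x)            # single re-parameterized conv

print(torch.allclose(y_multi, y_fused, atol=1e-5))  # expect True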

4、References

[1] https://github.com/apple/ml-mobileone
[2] MobileOne: An Improved One Millisecond Mobile Backbone
