TensorFlow 图变换(一):FoldBatchNorms

技术

picture.image


TensorFlow 的计算是按图(Graph)组织的,构建好的图有时需要根据需要做一些变换(例如将训练好的模型部署到生产环境时,去除无用的节点),在保证计算结果不变(或近似不变)的情况下优化计算速度或降低内存占用。Graph Transform Tool【1】是 TensorFlow 提供的一组可以修改 TensorFlow Graph 的工具,使用方便,易于扩展。

使用 Graph Transform Tool 时,它的操作对象为 GraphDef 对象,通常保存为二进制文件,后缀为 .pb。前面文章《TensorFlow 到底有几种模型格式?》介绍过这种文件的生成方式。

该工具调用方式如下:

bazel build tensorflow/tools/graph_transforms:transform_graph bazel-bin/tensorflow/tools/graph_transforms/transform_graph
--in_graph=tensorflow_inception_graph.pb
--out_graph=optimized_inception_graph.pb
--inputs='Mul:0'
--outputs='softmax:0'
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants
fold_batch_norms'

注意需要在 TensorFlow 源码根目录下运行。

其中参数 --in_graph 指定输入 GraphDef 文件名,--out_graph 制定输出 GraphDef 文件名,--inputs 指定输入 Node,--outputs 指定输出 Node,--transforms 指定变换类型。变换类型使用一串命令构成,每条命令都对应一种变换。

Batch Normalization(后面简称 BN)【2】是一种加速深度模型训练的技术,通过训练时对每个 mini-batch 内的 activations 做归一化降低 internal covariate shift,进而加速模型收敛。目前主流深度学习模型(ResNet,Inception,DenseNet,……)几乎都使用了 BN 技术。训练完毕,Batch Normalization 的参数(均值 E[x] 和方差 Val[x])不再更新,在后续推理计算时,可将这些常数参数通过 constant folding 来简化模型。BN 计算公式如下:

picture.image

一般 BN 位置都在 Convolution 之后(DenseNet 例外),以 TensorFlow 实现的 Inception V3 模型【3】 为例,Conv-BN-Relu 表示为计算图如下:

picture.image

其中 Conv2D 节点实现 Convolution 计算,Rsqrt 实现先求平方根再取倒数运算,Mul, Add, Sub 分别实现乘法、加法、减法计算,Const 为常数,代表计算参数。由于推理计算时 BN 输入参数均为常数,那么经过 constant folding, BN 可在算数上简化为:

y = x * a + b

进一步,当 x 为卷积输出,对卷积权值直接乘上 a,就可以在前向计算时直接得到 x * a 的结果,这一步称为 BN folding。经过两步简化后的计算图为:

picture.image

此时节点数目也有大量缩减。在推理计算时,能降低运行时间和存储开销。

与上面 BN folding 优化对应的 Graph Transform 代码位于 tensorflow/tools/graph_transforms/fold_batch_norms.cc,其中使用了一个非常有用的函数:

Status ReplaceMatchingOpTypes(

const GraphDef& input\_graph\_def, 


const OpTypePattern& pattern,


const std::function<Status(const NodeMatch&, const std::set<string>&,


const std::set<string>&, std::vector<NodeDef>*)>&  node\_generator,


const ReplaceMatchingOpTypesOptions& options, 


GraphDef* output\_graph\_def);

该函数将 input_graph_def 中所有与 pattern 匹配的子图替换为 node_generator 产生的新 op,然后保存到 output_graph_def 中。

pattern 定义为:

  {"Mul",                // mul\_node


    {


      {"Conv2D|MatMul",  // conv\_node


        {


          {"*"},         // input\_node


          {"Const"},     // weights\_node


        }


      },


      {"Const"},         // mul\_values\_node


    }


  },  // clang-format on

上述 pattern 能匹配原计算图中 Conv2D -> Mul 子图。node_generator 代码中将匹配后的子图直接替换为新的 Conv2D(权值常量更新为原权值与乘数因子 a 的乘积)。代码如下:

    // 从匹配模式中得到 Mul、Conv、Input、weight、Mul value 节点


    const NodeDef& mul\_node = match.node;


    const NodeDef& conv\_node = match.inputs[0].node;


    const NodeDef& input\_node = match.inputs[0].inputs[0].node;


    const NodeDef& weights\_node = match.inputs[0].inputs[1].node;


    const NodeDef& mul\_values\_node = match.inputs[1].node;


    // 获取卷积权值、乘数因子数值


    Tensor weights = GetNodeTensorAttr(weights\_node, "value");


    Tensor mul\_values = GetNodeTensorAttr(mul\_values\_node, "value");


    // 原始卷积权值乘上乘数因子


    auto weights\_matrix = weights.flat\_inner\_dims<float>();


    Tensor scaled\_weights(DT\_FLOAT, weights.shape());


    auto scaled\_weights\_matrix = scaled\_weights.flat\_inner\_dims<float>();


    for (int64 row = 0; row < weights\_matrix.dimension(0); ++row) {


      for (int64 col = 0; col < weights\_cols; ++col) {


        scaled\_weights\_matrix(row, col) =


            weights\_matrix(row, col) * mul\_values.flat<float>()(col);


      }


    }


    // 构造新的卷积权值节点,填入更新后的权值


    NodeDef scaled\_weights\_node;


    scaled\_weights\_node.set\_op("Const");


    scaled\_weights\_node.set\_name(weights\_node.name());


    SetNodeAttr("dtype", DT\_FLOAT, &scaled\_weights\_node);


    SetNodeTensorAttr<float>("value", scaled\_weights, &scaled\_weights\_node);


    new\_nodes->push\_back(scaled\_weights\_node);








    new\_nodes->push\_back(input\_node);


    // 构造新的卷积节点,复制旧卷积节点参数,改个名  



    NodeDef new\_conv\_node;


    new\_conv\_node = conv\_node;


    new\_conv\_node.set\_name(mul\_node.name());


    new\_nodes->push\_back(new\_conv\_node);








    return Status::OK();

Graph Transform 工具为离线优化工具,优化后的 GraphDef 文件可以像原先模型一样部署,无需修改生产环境代码。

本文介绍的 BN folding 优化方法可适用于 CPU、GPU、移动端、嵌入式等各种需要推理加速的场景。

【1】 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph\_transforms

【2】 Batch Normalization : Accelerating Deep Network Training by Reducing Internal Covariate Shift, arXiv:1502.03167

【3】 http://download.tensorflow.org/models/inception\_v3\_2016\_08\_28.tar.gz


picture.image

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
DevOps 在字节移动研发中的探索和实践
在日益复杂的APP工程架构下,如何保证APP能高效开发,保障团队效能和工程质量?本次将结合字节内部应用的事件案例,介绍DevOps团队对移动研发效能建设的探索和思考。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论