From 601977b1697a147758ceebe0004d9627f0c4b2dd Mon Sep 17 00:00:00 2001 From: liqun fu Date: Wed, 15 Apr 2020 20:53:07 -0700 Subject: [PATCH] handle BN with sequence axis --- .../proto/onnx/CNTKToONNX.cpp | 92 ++++++++++++------- 1 file changed, 61 insertions(+), 31 deletions(-) diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index a5a94901f72c..6f6876947fcb 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -24,6 +24,8 @@ #include "Internals/ComputationGraphAlgorithms.h" #include "ControlFlowHelper.h" +#include + using namespace Microsoft::MSR::CNTK; using namespace CNTK::ONNX; using namespace CNTK; @@ -845,6 +847,10 @@ class CNTKToONNXHelper std::unordered_map& variableNodes, std::vector& scanLoops, int createLoopIndex); + static onnxruntime::Node* WrapSequenceOpWithReshape(const FunctionPtr& src, + onnxruntime::Graph* graph, onnxruntime::NodeArg *input, onnxruntime::NodeArg *output, + std::function& node_creator); + static onnxruntime::Node* CreatePoolingNode(const FunctionPtr& src, onnxruntime::Graph* graph, std::unordered_map& functionNodes, @@ -3817,7 +3823,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddIdentityOp(onnxruntime::NodeArg &nodeArg onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(out_arg_name, &outputTypeProto); onnxruntime::Node* identityNode = &graph->AddNode( - nodeArg.Name() + string("_identity"), "Identity", "", {&nodeArg}, {&outputArg}); + nodeArg.Name() + string("_identity_") + out_arg_name, "Identity", "", {&nodeArg}, {&outputArg}); return identityNode; } @@ -8422,8 +8428,20 @@ onnxruntime::Node* CNTKToONNXHelper::CreateBatchNormalization(const FunctionPtr onnxruntime::Node *node = nullptr; if (spatial) { - // input and output are in correct shape. - node = &graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs); + if (src->Inputs()[0].DynamicAxes().size() == 2) + { + std::function node_creator = [&](NodeArg *inputArg, NodeArg *outputArg) + { + inputs[0] = inputArg; + outputs[0] = outputArg; + return &graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs); + }; + node = WrapSequenceOpWithReshape(src, graph, inputs[0], outputs[0], node_creator); + } else + { + // input and output are in correct shape. + node = &graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs); + } } else { @@ -8625,6 +8643,40 @@ onnxruntime::Node* ApplyActivationToSequenceConvolution(Node* convNode, const Fu return activationNode; } +onnxruntime::Node* CNTKToONNXHelper::WrapSequenceOpWithReshape(const FunctionPtr& src, + onnxruntime::Graph* graph, onnxruntime::NodeArg *input, onnxruntime::NodeArg *output, + std::function& node_creator) +{ + // Max/AveragePool/BN takes input of shape [N, C, H, W] or [N, C, D1, D2, ..., Dn]. CNTK input needs to be reshaped to match it. + // reshape [#, *][C, H, W] to [-1, C, H, W] + // onnx Max/AveragePool/BN + // reshape [-1, C_out, H_out, W_out] to [#, *][C_out, H_out, W_out] + vector newDimInputToOpNode; + // collapse extra dims into one axis as N for ONNX Conv + newDimInputToOpNode.push_back(-1); + for (int i = 2; i < input->Shape()->dim_size(); i++) + { + // copy C, H, W + if (!input->Shape()->dim(i).has_dim_value()) + LogicError("wrapped_op: feature dimensions need to have dim value."); + newDimInputToOpNode.push_back(input->Shape()->dim(i).dim_value()); + } + + onnxruntime::Node* preReshape = AddReshapeNode(*input, newDimInputToOpNode, input->Name() + "_reshaped_for_wrapped_op", graph); + const std::vector pooling_inputs({ const_cast(preReshape->OutputDefs()[0]) }); + TypeProto nodeOutputTypeProto; + UpdateONNXType(src->Outputs()[0].GetDataType(), nodeOutputTypeProto); + + NodeArg *opOutputArg = &graph->GetOrCreateNodeArg(output->Name() + "_wrapped_op_of_reshaped", &nodeOutputTypeProto); + + onnxruntime::Node* node = node_creator(const_cast(preReshape->OutputDefs()[0]), opOutputArg); + + vector newDimOutputFromPooling = ToINTS(*output->TypeAsProto()); + onnxruntime::Node* postReshape = AddReshapeNode(*opOutputArg, newDimOutputFromPooling, output->Name(), graph); + + return node; +} + // insert reshape before and after a Pooling op when the CNTK op has both sequence and batch axes. onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src, onnxruntime::Graph* graph, @@ -8642,35 +8694,13 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src, std::vector outputs; ProcessOutputs(src, inputs, outputs, graph); - // Max/AveragePool takes input of shape [N, C, H, W] or [N, C, D1, D2, ..., Dn]. CNTK input needs to be reshaped to match it. - // reshape [#, *][C, H, W] to [-1, C, H, W] - // onnx Max/AveragePool - // reshape [-1, C_out, H_out, W_out] to [#, *][C_out, H_out, W_out] - vector newDimInputToPooling; - // collapse extra dims into one axis as N for ONNX Conv - newDimInputToPooling.push_back(-1); - for (int i = 2; i < inputs[0]->Shape()->dim_size(); i++) - { - // copy C, H, W - if (!inputs[0]->Shape()->dim(i).has_dim_value()) - LogicError("Max/AveragePool: feature dimensions need to have dim value."); - newDimInputToPooling.push_back(inputs[0]->Shape()->dim(i).dim_value()); - } - - onnxruntime::Node* preReshape = AddReshapeNode(*inputs[0], newDimInputToPooling, inputs[0]->Name() + "_reshaped_for_max_pool", graph); - const std::vector pooling_inputs({const_cast(preReshape->OutputDefs()[0])}); - TypeProto poolingOutputTypeProto; - UpdateONNXType(src->Outputs()[0].GetDataType(), poolingOutputTypeProto); - - NodeArg *poolingOutputArg = &graph->GetOrCreateNodeArg(outputs[0]->Name() + "_pooling_of_reshaped", &poolingOutputTypeProto); - - onnxruntime::Node* poolingNode = AddNode(src, graph, pooling_inputs, { poolingOutputArg }); - - vector newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto()); - onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph); + std::function node_creator = [&](NodeArg *inputArg, NodeArg *outputArg) { + return AddNode(src, graph, { inputArg }, { outputArg }); + }; + onnxruntime::Node* poolingNode = WrapSequenceOpWithReshape(src, graph, inputs[0], outputs[0], node_creator); functionNodes.emplace(src, poolingNode); - return postReshape; + return poolingNode; } onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src, @@ -8706,7 +8736,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& sr newDimInputToConv.push_back(inputs[1]->Shape()->dim(i).dim_value()); } - onnxruntime::Node* preReshape = AddReshapeNode(*inputs[1], newDimInputToConv, inputs[1]->Name() + "_reshaped_for_conv", graph); + onnxruntime::Node* preReshape = AddReshapeNode(*inputs[1], newDimInputToConv, inputs[1]->Name() + "_reshaped_for_conv_" + ToLegacyString(ToUTF8(src->Name())), graph); std::vector conv_inputs = inputs; conv_inputs[1] = const_cast(preReshape->OutputDefs()[0]); TypeProto convOutputTypeProto;