From 36288eebe7a75a047dd983cb60ff1c1ab705f602 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Wed, 4 Jul 2018 11:53:24 +0300 Subject: [PATCH] Nearest neighbor resize from Keras --- .../src/tensorflow/tf_graph_simplifier.cpp | 45 +++++++++++++++++++ modules/dnn/test/test_tf_importer.cpp | 1 + 2 files changed, 46 insertions(+) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index a537358a1f..3d8a97f240 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -571,6 +571,50 @@ public: } }; +// In case of resizing by factor. +class UpsamplingKerasSubgraph : public Subgraph +{ +public: + UpsamplingKerasSubgraph() + { + int input = addNodeToMatch(""); + int shape = addNodeToMatch("Shape", input); + int stack = addNodeToMatch("Const"); + int stack_1 = addNodeToMatch("Const"); + int stack_2 = addNodeToMatch("Const"); + int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); + int factors = addNodeToMatch("Const"); + int mul = addNodeToMatch("Mul", strided_slice, factors); + addNodeToMatch("ResizeNearestNeighbor", input, mul); + setFusedNode("ResizeNearestNeighbor", input, factors); + } + + virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode, + std::vector& inputNodes) CV_OVERRIDE + { + Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor()); + CV_Assert(factorsMat.total() == 2, factorsMat.type() == CV_32SC1); + + // Height scale factor + tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor(); + factorY->clear_int_val(); + factorY->clear_tensor_content(); + factorY->add_int_val(factorsMat.at(0, 0)); + + // Width scale factor. + tensorflow::NodeDef* factorXNode = net.add_node(); + factorXNode->set_op("Const"); + factorXNode->set_name(fusedNode->name() + "/factor_y"); + + tensorflow::AttrValue factorX; + factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32); + factorX.mutable_tensor()->add_int_val(factorsMat.at(0, 1)); + factorXNode->mutable_attr()->insert(MapPair("value", factorX)); + + fusedNode->add_input(factorXNode->name()); + } +}; + void simplifySubgraphs(tensorflow::GraphDef& net) { std::vector > subgraphs; @@ -585,6 +629,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new DeconvolutionValidKerasSubgraph())); subgraphs.push_back(Ptr(new DeconvolutionSameKerasSubgraph())); subgraphs.push_back(Ptr(new ResizeBilinearSubgraph())); + subgraphs.push_back(Ptr(new UpsamplingKerasSubgraph())); int numNodes = net.node_size(); std::vector matchedNodesIds; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index d4ffc94399..bb60d46d6f 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -402,6 +402,7 @@ TEST(Test_TensorFlow, split) TEST(Test_TensorFlow, resize_nearest_neighbor) { runTensorFlowNet("resize_nearest_neighbor"); + runTensorFlowNet("keras_upsampling2d"); } TEST(Test_TensorFlow, slice)