Merge pull request #11890 from dkurt:keras_resize_nearest

pull/11703/head
Alexander Alekhin 6 years ago
commit ccd2370bb7
  1. 45
      modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
  2. 1
      modules/dnn/test/test_tf_importer.cpp

@ -571,6 +571,50 @@ public:
} }
}; };
// In case of resizing by factor.
class UpsamplingKerasSubgraph : public Subgraph
{
public:
UpsamplingKerasSubgraph()
{
int input = addNodeToMatch("");
int shape = addNodeToMatch("Shape", input);
int stack = addNodeToMatch("Const");
int stack_1 = addNodeToMatch("Const");
int stack_2 = addNodeToMatch("Const");
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
int factors = addNodeToMatch("Const");
int mul = addNodeToMatch("Mul", strided_slice, factors);
addNodeToMatch("ResizeNearestNeighbor", input, mul);
setFusedNode("ResizeNearestNeighbor", input, factors);
}
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
{
Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
CV_Assert(factorsMat.total() == 2, factorsMat.type() == CV_32SC1);
// Height scale factor
tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
factorY->clear_int_val();
factorY->clear_tensor_content();
factorY->add_int_val(factorsMat.at<int>(0, 0));
// Width scale factor.
tensorflow::NodeDef* factorXNode = net.add_node();
factorXNode->set_op("Const");
factorXNode->set_name(fusedNode->name() + "/factor_y");
tensorflow::AttrValue factorX;
factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32);
factorX.mutable_tensor()->add_int_val(factorsMat.at<int>(0, 1));
factorXNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", factorX));
fusedNode->add_input(factorXNode->name());
}
};
void simplifySubgraphs(tensorflow::GraphDef& net) void simplifySubgraphs(tensorflow::GraphDef& net)
{ {
std::vector<Ptr<Subgraph> > subgraphs; std::vector<Ptr<Subgraph> > subgraphs;
@ -585,6 +629,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph())); subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
int numNodes = net.node_size(); int numNodes = net.node_size();
std::vector<int> matchedNodesIds; std::vector<int> matchedNodesIds;

@ -403,6 +403,7 @@ TEST(Test_TensorFlow, split)
TEST(Test_TensorFlow, resize_nearest_neighbor) TEST(Test_TensorFlow, resize_nearest_neighbor)
{ {
runTensorFlowNet("resize_nearest_neighbor"); runTensorFlowNet("resize_nearest_neighbor");
runTensorFlowNet("keras_upsampling2d");
} }
TEST(Test_TensorFlow, slice) TEST(Test_TensorFlow, slice)

Loading…
Cancel
Save