|
|
|
@ -495,8 +495,9 @@ public: |
|
|
|
|
ResizeBilinearSubgraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int shapeSource = addNodeToMatch(""); |
|
|
|
|
|
|
|
|
|
int shape = addNodeToMatch("Shape", input); |
|
|
|
|
int shape = addNodeToMatch("Shape", shapeSource); |
|
|
|
|
int stack = addNodeToMatch("Const"); |
|
|
|
|
int stack_1 = addNodeToMatch("Const"); |
|
|
|
|
int stack_2 = addNodeToMatch("Const"); |
|
|
|
@ -504,7 +505,7 @@ public: |
|
|
|
|
int factorY = addNodeToMatch("Const"); |
|
|
|
|
int mul = addNodeToMatch("Mul", strided_slice, factorY); |
|
|
|
|
|
|
|
|
|
shape = addNodeToMatch("Shape", input); |
|
|
|
|
shape = addNodeToMatch("Shape", shapeSource); |
|
|
|
|
stack = addNodeToMatch("Const"); |
|
|
|
|
stack_1 = addNodeToMatch("Const"); |
|
|
|
|
stack_2 = addNodeToMatch("Const"); |
|
|
|
@ -519,6 +520,51 @@ public: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
// In case of resizing by factor.
|
|
|
|
|
class ResizeBilinearSubgraphDown : public TFSubgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
ResizeBilinearSubgraphDown() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int shapeSource = addNodeToMatch(""); |
|
|
|
|
|
|
|
|
|
int shape = addNodeToMatch("Shape", shapeSource); |
|
|
|
|
int stack = addNodeToMatch("Const"); |
|
|
|
|
int stack_1 = addNodeToMatch("Const"); |
|
|
|
|
int stack_2 = addNodeToMatch("Const"); |
|
|
|
|
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); |
|
|
|
|
int factorY = addNodeToMatch("Const"); |
|
|
|
|
int div = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorY); |
|
|
|
|
int cast = addNodeToMatch("Cast", div); |
|
|
|
|
|
|
|
|
|
shape = addNodeToMatch("Shape", shapeSource); |
|
|
|
|
stack = addNodeToMatch("Const"); |
|
|
|
|
stack_1 = addNodeToMatch("Const"); |
|
|
|
|
stack_2 = addNodeToMatch("Const"); |
|
|
|
|
strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); |
|
|
|
|
int factorX = addNodeToMatch("Const"); |
|
|
|
|
int div_1 = addNodeToMatch("RealDiv", addNodeToMatch("Cast", strided_slice), factorX); |
|
|
|
|
int cast_1 = addNodeToMatch("Cast", div_1); |
|
|
|
|
|
|
|
|
|
int pack = addNodeToMatch("Pack", cast, cast_1); |
|
|
|
|
|
|
|
|
|
addNodeToMatch("ResizeBilinear", input, pack); |
|
|
|
|
setFusedNode("ResizeBilinear", input, factorY, factorX); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, |
|
|
|
|
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
for (int i = 1; i < 3; ++i) |
|
|
|
|
{ |
|
|
|
|
tensorflow::TensorProto* factor = inputNodes[i]->mutable_attr()->at("value").mutable_tensor(); |
|
|
|
|
factor->set_double_val(0, 1.0 / factor->double_val(0)); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
// In case of resizing by factor.
|
|
|
|
|
class UpsamplingKerasSubgraph : public TFSubgraph |
|
|
|
|
{ |
|
|
|
@ -702,6 +748,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true))); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false))); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraphDown())); |
|
|
|
|
|
|
|
|
|
for (int i = 0; i < net.node_size(); ++i) |
|
|
|
|
{ |
|
|
|
|