|
|
|
@ -249,6 +249,40 @@ public: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class NormalizeSubgraph4 : public NormalizeSubgraphBase |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
NormalizeSubgraph4() : NormalizeSubgraphBase(1) |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int mul = addNodeToMatch("Mul", input, input); |
|
|
|
|
int sum = addNodeToMatch("ReduceSum", mul); |
|
|
|
|
int eps = addNodeToMatch(""); |
|
|
|
|
int max = addNodeToMatch("Max", sum, eps); |
|
|
|
|
int sqrt = addNodeToMatch("Sqrt", max); |
|
|
|
|
int reciprocal = addNodeToMatch("Reciprocal", sqrt); |
|
|
|
|
addNodeToMatch("Mul", input, reciprocal); |
|
|
|
|
setFusedNode("Normalize", input); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class NormalizeSubgraph5 : public NormalizeSubgraphBase |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
NormalizeSubgraph5() : NormalizeSubgraphBase(1) |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int mul = addNodeToMatch("Mul", input, input); |
|
|
|
|
int sum = addNodeToMatch("ReduceSum", mul); |
|
|
|
|
int clip = addNodeToMatch("Clip", sum); |
|
|
|
|
int sqrt = addNodeToMatch("Sqrt", clip); |
|
|
|
|
int one = addNodeToMatch("Constant"); |
|
|
|
|
int div = addNodeToMatch("Div", one, sqrt); |
|
|
|
|
addNodeToMatch("Mul", input, div); |
|
|
|
|
setFusedNode("Normalize", input); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class GatherCastSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); |
|
|
|
|
subgraphs.push_back(makePtr<ExpandSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<MishSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph4>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph5>()); |
|
|
|
|
|
|
|
|
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); |
|
|
|
|
} |
|
|
|
|