|
|
|
@ -382,6 +382,36 @@ public: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class BatchNormalizationSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
BatchNormalizationSubgraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int data1 = addNodeToMatch("Constant"); |
|
|
|
|
int data2 = addNodeToMatch("Constant"); |
|
|
|
|
int data3 = addNodeToMatch("Constant"); |
|
|
|
|
int data4 = addNodeToMatch("Constant"); |
|
|
|
|
int shape1 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape1 = addNodeToMatch("Reshape", data1, shape1); |
|
|
|
|
int shape2 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape2 = addNodeToMatch("Reshape", data2, shape2); |
|
|
|
|
int shape3 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape3 = addNodeToMatch("Reshape", data3, shape3); |
|
|
|
|
int shape4 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape4 = addNodeToMatch("Reshape", data4, shape4); |
|
|
|
|
int sqrtNode = addNodeToMatch("Sqrt", reshape3); |
|
|
|
|
int A = addNodeToMatch("Constant"); |
|
|
|
|
int divNode = addNodeToMatch("Div", A, sqrtNode); |
|
|
|
|
int mul1 = addNodeToMatch("Mul", reshape1, divNode); |
|
|
|
|
int mul2 = addNodeToMatch("Mul", reshape4, mul1); |
|
|
|
|
int sub = addNodeToMatch("Sub", reshape2, mul2); |
|
|
|
|
int mul3 = addNodeToMatch("Mul", input, mul1); |
|
|
|
|
addNodeToMatch("Add", mul3, sub); |
|
|
|
|
setFusedNode("BatchNormalization", input, data1, data2, data4 ,data3); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
{ |
|
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
|
@ -394,6 +424,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph1>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph2>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph3>()); |
|
|
|
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph>()); |
|
|
|
|
|
|
|
|
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); |
|
|
|
|
} |
|
|
|
|