|
|
|
@ -61,27 +61,28 @@ public: |
|
|
|
|
ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net) |
|
|
|
|
{ |
|
|
|
|
numInputs = net.input_size(); |
|
|
|
|
numInitializers = net.initializer_size(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
opencv_onnx::NodeProto* node = 0; |
|
|
|
|
if (idx >= numInputs) |
|
|
|
|
node = net.mutable_node(idx - numInputs); |
|
|
|
|
if (idx >= numInputs + numInitializers) |
|
|
|
|
node = net.mutable_node(idx - numInputs - numInitializers); |
|
|
|
|
return makePtr<ONNXNodeWrapper>(node); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual int getNumNodes() const CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
return numInputs + net.node_size(); |
|
|
|
|
return numInputs + numInitializers + net.node_size(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual int getNumOutputs(int nodeId) const CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
if (nodeId < numInputs) |
|
|
|
|
if (nodeId < numInputs + numInitializers) |
|
|
|
|
return 1; |
|
|
|
|
else |
|
|
|
|
return net.node(nodeId - numInputs).output_size(); |
|
|
|
|
return net.node(nodeId - numInputs - numInitializers).output_size(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE |
|
|
|
@ -89,18 +90,20 @@ public: |
|
|
|
|
CV_Assert(outId < getNumOutputs(nodeId)); |
|
|
|
|
if (nodeId < numInputs) |
|
|
|
|
return net.input(nodeId).name(); |
|
|
|
|
else if (nodeId < numInputs + numInitializers) |
|
|
|
|
return net.initializer(nodeId - numInputs).name(); |
|
|
|
|
else |
|
|
|
|
return net.node(nodeId - numInputs).output(outId); |
|
|
|
|
return net.node(nodeId - numInputs - numInitializers).output(outId); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void removeNode(int idx) CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
CV_Assert(idx >= numInputs); |
|
|
|
|
net.mutable_node()->DeleteSubrange(idx - numInputs, 1); |
|
|
|
|
CV_Assert(idx >= numInputs + numInitializers); |
|
|
|
|
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
int numInputs; |
|
|
|
|
int numInputs, numInitializers; |
|
|
|
|
opencv_onnx::GraphProto& net; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
@ -382,33 +385,63 @@ public: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class BatchNormalizationSubgraph : public Subgraph |
|
|
|
|
class BatchNormalizationSubgraphBase : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
BatchNormalizationSubgraph() |
|
|
|
|
BatchNormalizationSubgraphBase() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int data1 = addNodeToMatch("Constant"); |
|
|
|
|
int data2 = addNodeToMatch("Constant"); |
|
|
|
|
int data3 = addNodeToMatch("Constant"); |
|
|
|
|
int data4 = addNodeToMatch("Constant"); |
|
|
|
|
int shape1 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape1 = addNodeToMatch("Reshape", data1, shape1); |
|
|
|
|
int shape2 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape2 = addNodeToMatch("Reshape", data2, shape2); |
|
|
|
|
input = addNodeToMatch(""); |
|
|
|
|
var = addNodeToMatch(""); |
|
|
|
|
mean = addNodeToMatch(""); |
|
|
|
|
weight = addNodeToMatch(""); |
|
|
|
|
bias = addNodeToMatch(""); |
|
|
|
|
A = addNodeToMatch(""); |
|
|
|
|
shape1 = addNodeToMatch(""); |
|
|
|
|
shape2 = addNodeToMatch(""); |
|
|
|
|
} |
|
|
|
|
protected: |
|
|
|
|
int input, var, mean, weight, bias, A, shape1, shape2; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
BatchNormalizationSubgraph1() |
|
|
|
|
{ |
|
|
|
|
int reshape1 = addNodeToMatch("Reshape", weight, shape1); |
|
|
|
|
int reshape2 = addNodeToMatch("Reshape", bias, shape2); |
|
|
|
|
int shape3 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape3 = addNodeToMatch("Reshape", data3, shape3); |
|
|
|
|
int reshape3 = addNodeToMatch("Reshape", var, shape3); |
|
|
|
|
int shape4 = addNodeToMatch("Constant"); |
|
|
|
|
int reshape4 = addNodeToMatch("Reshape", data4, shape4); |
|
|
|
|
int reshape4 = addNodeToMatch("Reshape", mean, shape4); |
|
|
|
|
int sqrtNode = addNodeToMatch("Sqrt", reshape3); |
|
|
|
|
int A = addNodeToMatch("Constant"); |
|
|
|
|
int divNode = addNodeToMatch("Div", A, sqrtNode); |
|
|
|
|
int mul1 = addNodeToMatch("Mul", reshape1, divNode); |
|
|
|
|
int mul2 = addNodeToMatch("Mul", reshape4, mul1); |
|
|
|
|
int sub = addNodeToMatch("Sub", reshape2, mul2); |
|
|
|
|
int mul3 = addNodeToMatch("Mul", input, mul1); |
|
|
|
|
addNodeToMatch("Add", mul3, sub); |
|
|
|
|
setFusedNode("BatchNormalization", input, data1, data2, data4 ,data3); |
|
|
|
|
setFusedNode("BatchNormalization", input, weight, bias, mean, var); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
BatchNormalizationSubgraph2() |
|
|
|
|
{ |
|
|
|
|
int sqrtNode = addNodeToMatch("Sqrt", var); |
|
|
|
|
int divNode = addNodeToMatch("Div", A, sqrtNode); |
|
|
|
|
int mul1 = addNodeToMatch("Mul", weight, divNode); |
|
|
|
|
int reshape2 = addNodeToMatch("Reshape", mul1, shape2); |
|
|
|
|
|
|
|
|
|
int mulMean = addNodeToMatch("Mul", mean, mul1); |
|
|
|
|
int sub = addNodeToMatch("Sub", bias, mulMean); |
|
|
|
|
int reshape1 = addNodeToMatch("Reshape", sub, shape1); |
|
|
|
|
|
|
|
|
|
int mulInput = addNodeToMatch("Mul", input, reshape2); |
|
|
|
|
addNodeToMatch("Add", mulInput, reshape1); |
|
|
|
|
setFusedNode("BatchNormalization", input, weight, bias, mean, var); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
@ -424,7 +457,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph1>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph2>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph3>()); |
|
|
|
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>()); |
|
|
|
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); |
|
|
|
|
|
|
|
|
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs); |
|
|
|
|
} |
|
|
|
|