|
|
|
@ -132,6 +132,183 @@ private: |
|
|
|
|
opencv_onnx::GraphProto& net; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
/* Fusion for Gelu.
|
|
|
|
|
|
|
|
|
|
Graph before fusion: |
|
|
|
|
+---------------------------------------------+ |
|
|
|
|
| | |
|
|
|
|
[Input] -> Div[B=sqrt(2)] -> Erf -> Add[B=1] -> Mul -> Mul[B=0.5] -> [Output] |
|
|
|
|
|
|
|
|
|
Graph after fusion: |
|
|
|
|
[Input] -> Gelu -> [Output] |
|
|
|
|
|
|
|
|
|
*/ |
|
|
|
|
class GeluSubGraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
GeluSubGraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ ); |
|
|
|
|
int erf = addNodeToMatch("Erf", div); |
|
|
|
|
int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ ); |
|
|
|
|
int mul = addNodeToMatch("Mul", input, add); |
|
|
|
|
addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ; |
|
|
|
|
|
|
|
|
|
setFusedNode("Gelu", input); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static bool isWithInitializer(const std::vector<int>& matchedNodesIds) |
|
|
|
|
{ |
|
|
|
|
// if node.getType() is Constant, Constant nodes are placed between other nodes
|
|
|
|
|
if (matchedNodesIds[2] - matchedNodesIds[1] != 1) |
|
|
|
|
return false; |
|
|
|
|
// if Initializer, there is no Constant node between other nodes
|
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer) |
|
|
|
|
{ |
|
|
|
|
if (withInitializer) |
|
|
|
|
{ |
|
|
|
|
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>(); |
|
|
|
|
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); |
|
|
|
|
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id); |
|
|
|
|
return *const_mat.ptr<float>(); |
|
|
|
|
} else { |
|
|
|
|
const Ptr<ImportNodeWrapper> node = net->getNode(node_id); |
|
|
|
|
int constant_id = getInputNodeId(net, node, input_id); |
|
|
|
|
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id); |
|
|
|
|
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); |
|
|
|
|
Mat constant_mat = getMatFromTensor(constant_proto); |
|
|
|
|
return *constant_mat.ptr<float>(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, |
|
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
|
|
std::vector<int>& targetNodesIds) CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) |
|
|
|
|
{ |
|
|
|
|
bool withInitializer = isWithInitializer(matchedNodesIds); |
|
|
|
|
|
|
|
|
|
// Check Div[B=sqrt(2)]
|
|
|
|
|
float divisor = extractConstant(net, matchedNodesIds[0], 1, withInitializer); |
|
|
|
|
if (divisor - M_SQRT2 >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
// Check Add[B=1]
|
|
|
|
|
float add_const = extractConstant(net, matchedNodesIds[2], 1, withInitializer); |
|
|
|
|
if (add_const - 1.f >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
// Check Mul[B=0.5]
|
|
|
|
|
float mul_const = extractConstant(net, matchedNodesIds[4], 1, withInitializer); |
|
|
|
|
if (mul_const - 0.5f >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
/* Fusion for GeluApproximation.
|
|
|
|
|
|
|
|
|
|
Graph before fusion: |
|
|
|
|
+--------+------+----------------+------------------------------------+ |
|
|
|
|
| | | | | |
|
|
|
|
[Input] -> Mul -> Mul -> Mul[ ] -> Add -> Mul[ ] -> Tanh -> Add[A=1] -> Mul -> Mul(A=0.5) -> [Output] |
|
|
|
|
/ \
|
|
|
|
|
A=0.044714998453855515 A=sqrt(2/pie) |
|
|
|
|
|
|
|
|
|
Graph after fusion: |
|
|
|
|
[Input] -> GeluApproximation -> [Output] |
|
|
|
|
|
|
|
|
|
*/ |
|
|
|
|
class GeluApproximationSubGraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
GeluApproximationSubGraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int mul0 = addNodeToMatch("Mul", input, input); |
|
|
|
|
int mul1 = addNodeToMatch("Mul", input, mul0); |
|
|
|
|
int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1); |
|
|
|
|
int add0 = addNodeToMatch("Add", input, mul2); |
|
|
|
|
int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0); |
|
|
|
|
int tanh = addNodeToMatch("Tanh", mul3); |
|
|
|
|
int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh); |
|
|
|
|
int mul4 = addNodeToMatch("Mul", input, add1); |
|
|
|
|
addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4); |
|
|
|
|
|
|
|
|
|
setFusedNode("GeluApproximation", input); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static bool isWithInitializer(const std::vector<int>& matchedNodesIds) |
|
|
|
|
{ |
|
|
|
|
// if node.getType() is Constant, Constant nodes are placed between other nodes
|
|
|
|
|
if (matchedNodesIds[2] - matchedNodesIds[1] != 1) |
|
|
|
|
return false; |
|
|
|
|
// if Initializer, there is no Constant node between other nodes
|
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer) |
|
|
|
|
{ |
|
|
|
|
if (withInitializer) |
|
|
|
|
{ |
|
|
|
|
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>(); |
|
|
|
|
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id); |
|
|
|
|
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id); |
|
|
|
|
return *const_mat.ptr<float>(); |
|
|
|
|
} else { |
|
|
|
|
const Ptr<ImportNodeWrapper> node = net->getNode(node_id); |
|
|
|
|
int constant_id = getInputNodeId(net, node, input_id); |
|
|
|
|
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id); |
|
|
|
|
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t(); |
|
|
|
|
Mat constant_mat = getMatFromTensor(constant_proto); |
|
|
|
|
return *constant_mat.ptr<float>(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, |
|
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
|
|
std::vector<int>& targetNodesIds) CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) |
|
|
|
|
{ |
|
|
|
|
bool withInitializer = isWithInitializer(matchedNodesIds); |
|
|
|
|
|
|
|
|
|
// Check Mul[A=0.044714998453855515]
|
|
|
|
|
float coef = extractConstant(net, matchedNodesIds[2], 0, withInitializer); |
|
|
|
|
if (coef - 0.044714998453855515 >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
// Check Mul[A=sqrt(2/pie)]
|
|
|
|
|
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0, withInitializer); |
|
|
|
|
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
// Check Add[A=1]
|
|
|
|
|
float add_const = extractConstant(net, matchedNodesIds[6], 0, withInitializer); |
|
|
|
|
if (add_const - 1.f >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
// Check Mul[A=0.5]
|
|
|
|
|
float mul_const = extractConstant(net, matchedNodesIds[8], 0, withInitializer); |
|
|
|
|
if (mul_const - 0.5f >= 1e-6) |
|
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class LayerNormSubGraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
@ -904,6 +1081,8 @@ public: |
|
|
|
|
void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
{ |
|
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
|
|
subgraphs.push_back(makePtr<GeluSubGraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<GeluApproximationSubGraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<LayerNormSubGraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<GatherCastSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<MulCastSubgraph>()); |
|
|
|
|