|
|
|
@ -531,6 +531,38 @@ public: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class MishSubgraph2 : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
MishSubgraph2() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int exp = addNodeToMatch("Exp", input); |
|
|
|
|
int addVal = addNodeToMatch(""); |
|
|
|
|
int add = addNodeToMatch("Add", addVal, exp); |
|
|
|
|
int log = addNodeToMatch("Log", add); |
|
|
|
|
int tanh = addNodeToMatch("Tanh", log); |
|
|
|
|
addNodeToMatch("Mul", input, tanh); |
|
|
|
|
setFusedNode("Mish", input); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class MishSubgraph3 : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
MishSubgraph3() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int exp = addNodeToMatch("Exp", input); |
|
|
|
|
int addVal = addNodeToMatch(""); |
|
|
|
|
int add = addNodeToMatch("Add", exp, addVal); |
|
|
|
|
int log = addNodeToMatch("Log", add); |
|
|
|
|
int tanh = addNodeToMatch("Tanh", log); |
|
|
|
|
addNodeToMatch("Mul", input, tanh); |
|
|
|
|
setFusedNode("Mish", input); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class MulCastSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
@ -735,6 +767,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); |
|
|
|
|
subgraphs.push_back(makePtr<ExpandSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<MishSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<MishSubgraph2>()); |
|
|
|
|
subgraphs.push_back(makePtr<MishSubgraph3>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph4>()); |
|
|
|
|
subgraphs.push_back(makePtr<NormalizeSubgraph5>()); |
|
|
|
|
|
|
|
|
|