@ -125,8 +125,13 @@ public:
virtual void removeNode ( int idx ) CV_OVERRIDE
{
CV_Assert ( idx > = numInputs + numInitializers ) ;
net . mutable_node ( ) - > DeleteSubrange ( idx - numInputs - numInitializers , 1 ) ;
if ( idx > = numInputs + numInitializers )
net . mutable_node ( ) - > DeleteSubrange ( idx - numInputs - numInitializers , 1 ) ;
}
virtual inline bool isCommutativeOp ( const std : : string & type ) const CV_OVERRIDE
{
return type = = " Add " | | type = = " Mul " | | type = = " Equal " | | type = = " Max " ;
}
private :
@ -134,6 +139,25 @@ private:
opencv_onnx : : GraphProto & net ;
} ;
static Mat extractConstant ( const Ptr < ImportGraphWrapper > & net , int node_id , int input_id )
{
auto onnx_net = net . dynamicCast < ONNXGraphWrapper > ( ) ;
int initializer_id = onnx_net - > getInputInitializerId ( node_id , input_id ) ;
if ( initializer_id ! = - 1 )
{
return onnx_net - > getMatFromInitializer ( initializer_id ) ;
}
else
{
const Ptr < ImportNodeWrapper > node = net - > getNode ( node_id ) ;
int constant_id = Subgraph : : getInputNodeId ( net , node , input_id ) ;
Ptr < ImportNodeWrapper > constant_ptr = net - > getNode ( constant_id ) ;
opencv_onnx : : NodeProto * constant_node = constant_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
opencv_onnx : : TensorProto constant_proto = constant_node - > attribute ( 0 ) . t ( ) ;
return getMatFromTensor ( constant_proto ) ;
}
}
/* Fusion for Gelu.
Graph before fusion :
@ -151,54 +175,32 @@ public:
GeluSubGraph ( )
{
int input = addNodeToMatch ( " " ) ;
int div = addNodeToMatch ( " Div " , input , addNodeToMatch ( " " ) /* B=sqrt(2) */ ) ;
div = addNodeToMatch ( " Div " , input , addNodeToMatch ( " " ) /* B=sqrt(2) */ ) ;
int erf = addNodeToMatch ( " Erf " , div ) ;
int add = addNodeToMatch ( " Add " , erf , addNodeToMatch ( " " ) /* B=1 */ ) ;
add = addNodeToMatch ( " Add " , erf , addNodeToMatch ( " " ) /* B=1 */ ) ;
int mul = addNodeToMatch ( " Mul " , input , add ) ;
addNodeToMatch ( " Mul " , mul , addNodeToMatch ( " " ) /* B=0.5 */ ) ;
mul2 = addNodeToMatch ( " Mul " , mul , addNodeToMatch ( " " ) /* B=0.5 */ ) ;
setFusedNode ( " Gelu " , input ) ;
}
static float extractConstant ( const Ptr < ImportGraphWrapper > & net , int node_id , int input_id )
{
auto onnx_net = net . dynamicCast < ONNXGraphWrapper > ( ) ;
int initializer_id = onnx_net - > getInputInitializerId ( node_id , input_id ) ;
if ( initializer_id ! = - 1 )
{
Mat const_mat = onnx_net - > getMatFromInitializer ( initializer_id ) ;
return * const_mat . ptr < float > ( ) ;
}
else
{
const Ptr < ImportNodeWrapper > node = net - > getNode ( node_id ) ;
int constant_id = getInputNodeId ( net , node , input_id ) ;
Ptr < ImportNodeWrapper > constant_ptr = net - > getNode ( constant_id ) ;
opencv_onnx : : NodeProto * constant_node = constant_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
opencv_onnx : : TensorProto constant_proto = constant_node - > attribute ( 0 ) . t ( ) ;
Mat constant_mat = getMatFromTensor ( constant_proto ) ;
return * constant_mat . ptr < float > ( ) ;
}
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
// Check Div[B=sqrt(2)]
float divisor = extractConstant ( net , matchedNodesIds [ 0 ] , 1 ) ;
float divisor = extractConstant ( net , matchedNodesIds [ div ] , 1 ) . at < float > ( 0 ) ;
if ( std : : fabs ( divisor - M_SQRT2 ) > = std : : numeric_limits < float > : : epsilon ( ) )
return false ;
// Check Add[B=1]
float add_const = extractConstant ( net , matchedNodesIds [ 2 ] , 1 ) ;
float add_const = extractConstant ( net , matchedNodesIds [ add ] , 1 ) . at < float > ( 0 ) ;
if ( std : : fabs ( add_const - 1.f ) > = std : : numeric_limits < float > : : epsilon ( ) )
return false ;
// Check Mul[B=0.5]
float mul_const = extractConstant ( net , matchedNodesIds [ 4 ] , 1 ) ;
float mul_const = extractConstant ( net , matchedNodesIds [ mul2 ] , 1 ) . at < float > ( 0 ) ;
if ( std : : fabs ( mul_const - 0.5f ) > = std : : numeric_limits < float > : : epsilon ( ) )
return false ;
@ -206,6 +208,9 @@ public:
}
return false ;
}
private :
int div , add , mul2 ;
} ;
/* Fusion for GeluApproximation.
@ -229,61 +234,39 @@ public:
int input = addNodeToMatch ( " " ) ;
int mul0 = addNodeToMatch ( " Mul " , input , input ) ;
int mul1 = addNodeToMatch ( " Mul " , input , mul0 ) ;
int mul2 = addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) /* A=0.044714998453855515 */ , mul1 ) ;
mul2 = addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) /* A=0.044714998453855515 */ , mul1 ) ;
int add0 = addNodeToMatch ( " Add " , input , mul2 ) ;
int mul3 = addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) /* A=sqrt(2/pie) */ , add0 ) ;
mul3 = addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) /* A=sqrt(2/pie) */ , add0 ) ;
int tanh = addNodeToMatch ( " Tanh " , mul3 ) ;
int add1 = addNodeToMatch ( " Add " , addNodeToMatch ( " " ) /* A=1 */ , tanh ) ;
add1 = addNodeToMatch ( " Add " , addNodeToMatch ( " " ) /* A=1 */ , tanh ) ;
int mul4 = addNodeToMatch ( " Mul " , input , add1 ) ;
addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) /* A=0.5 */ , mul4 ) ;
mul5 = addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) /* A=0.5 */ , mul4 ) ;
setFusedNode ( " GeluApproximation " , input ) ;
}
static float extractConstant ( const Ptr < ImportGraphWrapper > & net , int node_id , int input_id )
{
auto onnx_net = net . dynamicCast < ONNXGraphWrapper > ( ) ;
int initializer_id = onnx_net - > getInputInitializerId ( node_id , input_id ) ;
if ( initializer_id ! = - 1 )
{
Mat const_mat = onnx_net - > getMatFromInitializer ( initializer_id ) ;
return * const_mat . ptr < float > ( ) ;
}
else
{
const Ptr < ImportNodeWrapper > node = net - > getNode ( node_id ) ;
int constant_id = getInputNodeId ( net , node , input_id ) ;
Ptr < ImportNodeWrapper > constant_ptr = net - > getNode ( constant_id ) ;
opencv_onnx : : NodeProto * constant_node = constant_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
opencv_onnx : : TensorProto constant_proto = constant_node - > attribute ( 0 ) . t ( ) ;
Mat constant_mat = getMatFromTensor ( constant_proto ) ;
return * constant_mat . ptr < float > ( ) ;
}
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
// Check Mul[A=0.044714998453855515]
float coef = extractConstant ( net , matchedNodesIds [ 2 ] , 0 ) ;
float coef = extractConstant ( net , matchedNodesIds [ mul2 ] , 0 ) . at < float > ( 0 ) ;
if ( coef - 0.044714998453855515 > = 1e-6 )
return false ;
// Check Mul[A=sqrt(2/pie)]
float sqrt_2_pie = extractConstant ( net , matchedNodesIds [ 4 ] , 0 ) ;
float sqrt_2_pie = extractConstant ( net , matchedNodesIds [ mul3 ] , 0 ) . at < float > ( 0 ) ;
if ( sqrt_2_pie - 0.7978845834732056 > = 1e-6 )
return false ;
// Check Add[A=1]
float add_const = extractConstant ( net , matchedNodesIds [ 6 ] , 0 ) ;
float add_const = extractConstant ( net , matchedNodesIds [ add1 ] , 0 ) . at < float > ( 0 ) ;
if ( add_const - 1.f > = 1e-6 )
return false ;
// Check Mul[A=0.5]
float mul_const = extractConstant ( net , matchedNodesIds [ 8 ] , 0 ) ;
float mul_const = extractConstant ( net , matchedNodesIds [ mul5 ] , 0 ) . at < float > ( 0 ) ;
if ( mul_const - 0.5f > = 1e-6 )
return false ;
@ -291,6 +274,9 @@ public:
}
return false ;
}
private :
int mul2 , mul3 , add1 , mul5 ;
} ;
/* Fusion for LayerNormalization.
@ -313,43 +299,22 @@ public:
LayerNormSubGraph ( ) : axis ( - 1 ) , epsilon ( 1e-5 )
{
int input = addNodeToMatch ( " " ) ;
int mean = addNodeToMatch ( " ReduceMean " , input ) ;
mean = addNodeToMatch ( " ReduceMean " , input ) ;
int sub = addNodeToMatch ( " Sub " , input , mean ) ;
int pow = addNodeToMatch ( " Pow " , sub , addNodeToMatch ( " " ) ) ;
int mean1 = addNodeToMatch ( " ReduceMean " , pow ) ;
int add = addNodeToMatch ( " Add " , mean1 , addNodeToMatch ( " " ) ) ;
pow = addNodeToMatch ( " Pow " , sub , addNodeToMatch ( " " ) ) ;
mean1 = addNodeToMatch ( " ReduceMean " , pow ) ;
add = addNodeToMatch ( " Add " , mean1 , addNodeToMatch ( " " ) ) ;
int sqrt = addNodeToMatch ( " Sqrt " , add ) ;
int div = addNodeToMatch ( " Div " , sub , sqrt ) ;
int mul = addNodeToMatch ( " Mul " , div , addNodeToMatch ( " " ) ) ;
addNodeToMatch ( " Add " , mul , addNodeToMatch ( " " ) ) ;
mul = addNodeToMatch ( " Mul " , div , addNodeToMatch ( " " ) ) ;
bias = addNodeToMatch ( " Add " , mul , addNodeToMatch ( " " ) ) ;
setFusedNode ( " LayerNormalization " , input ) ;
}
static float extractConstant ( const Ptr < ImportGraphWrapper > & net , int node_id , int input_id )
{
auto onnx_net = net . dynamicCast < ONNXGraphWrapper > ( ) ;
int initializer_id = onnx_net - > getInputInitializerId ( node_id , input_id ) ;
if ( initializer_id ! = - 1 ) // initializer
{
Mat const_mat = onnx_net - > getMatFromInitializer ( initializer_id ) ;
return * const_mat . ptr < float > ( ) ;
}
else
{
const Ptr < ImportNodeWrapper > node = net - > getNode ( node_id ) ;
int constant_id = getInputNodeId ( net , node , input_id ) ;
Ptr < ImportNodeWrapper > constant_ptr = net - > getNode ( constant_id ) ;
opencv_onnx : : NodeProto * constant_node = constant_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
opencv_onnx : : TensorProto constant_proto = constant_node - > attribute ( 0 ) . t ( ) ;
Mat constant_mat = getMatFromTensor ( constant_proto ) ;
return * constant_mat . ptr < float > ( ) ;
}
}
static float extractAxis ( const Ptr < ImportGraphWrapper > & net , int node_id )
{
Ptr < ImportNodeWrapper > mean_ptr = net - > getNode ( node_id ) ;
@ -381,25 +346,24 @@ public:
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
float pow_exp = extractConstant ( net , matchedNodesIds [ 2 ] , 1 ) ;
float pow_exp = extractConstant ( net , matchedNodesIds [ pow ] , 1 ) . at < float > ( 0 ) ;
if ( pow_exp - 2 > 1e-5 ) // not pow(2)
return false ;
int axis_mean1 = extractAxis ( net , matchedNodesIds [ 0 ] ) ;
int axis_mean2 = extractAxis ( net , matchedNodesIds [ 3 ] ) ;
int axis_mean1 = extractAxis ( net , matchedNodesIds [ mean ] ) ;
int axis_mean2 = extractAxis ( net , matchedNodesIds [ mean1 ] ) ;
if ( axis_mean1 ! = axis_mean2 )
return false ;
axis = axis_mean1 ;
epsilon = extractConstant ( net , matchedNodesIds [ 4 ] , 1 ) ;
epsilon = extractConstant ( net , matchedNodesIds [ add ] , 1 ) . at < float > ( 0 ) ;
weight_name = getInputName ( net , matchedNodesIds [ 7 ] , 1 ) ;
bias_name = getInputName ( net , matchedNodesIds [ 8 ] , 1 ) ;
weight_name = getInputName ( net , matchedNodesIds [ mul ] , 1 ) ;
bias_name = getInputName ( net , matchedNodesIds [ bias ] , 1 ) ;
return true ;
}
@ -429,6 +393,7 @@ protected:
float epsilon ;
std : : string weight_name ;
std : : string bias_name ;
int pow , mean , mean1 , add , mul , bias ;
} ;
class SoftMaxSubgraphBase : public Subgraph
@ -437,10 +402,9 @@ public:
SoftMaxSubgraphBase ( ) : axis ( 1 ) , id ( - 1 ) { }
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
CV_Assert ( id > = 0 & & id < matchedNodesIds . size ( ) ) ;
Ptr < ImportNodeWrapper > sum = net - > getNode ( matchedNodesIds [ id ] ) ;
@ -485,7 +449,7 @@ public:
int inpExp = addNodeToMatch ( " Exp " , input ) ;
int sum = addNodeToMatch ( " ReduceSum " , inpExp ) ;
id = 1 ;
id = sum ;
addNodeToMatch ( " Div " , inpExp , sum ) ;
setFusedNode ( " Softmax " , input ) ;
@ -498,7 +462,7 @@ public:
int input = addNodeToMatch ( " " ) ;
int reducemax = addNodeToMatch ( " ReduceMax " , input ) ;
id = 0 ;
id = reducemax ;
int sub = addNodeToMatch ( " Sub " , input , reducemax ) ;
int exp = addNodeToMatch ( " Exp " , sub ) ;
@ -516,7 +480,7 @@ public:
int input = addNodeToMatch ( " " ) ;
int reducemax = addNodeToMatch ( " ReduceMax " , input ) ;
id = 0 ;
id = reducemax ;
int sub_1 = addNodeToMatch ( " Sub " , input , reducemax ) ;
int exp = addNodeToMatch ( " Exp " , sub_1 ) ;
@ -533,18 +497,17 @@ public:
HardSwishSubgraph ( )
{
int input = addNodeToMatch ( " " ) ;
int hardSigmoid = addNodeToMatch ( " HardSigmoid " , input ) ;
addNodeToMatch ( " Mul " , input , hardSigmoid ) ;
hardSigmoidI d = addNodeToMatch ( " HardSigmoid " , input ) ;
addNodeToMatch ( " Mul " , input , hardSigmoidId ) ;
setFusedNode ( " HardSwish " , input ) ;
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
Ptr < ImportNodeWrapper > hardSigmoid = net - > getNode ( matchedNodesIds [ 0 ] ) ;
Ptr < ImportNodeWrapper > hardSigmoid = net - > getNode ( matchedNodesIds [ hardSigmoidId ] ) ;
opencv_onnx : : NodeProto * node = hardSigmoid . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
uint8_t matched = 0 ;
@ -561,6 +524,9 @@ public:
}
return false ;
}
private :
int hardSigmoidId ;
} ;
class CeluSubgraph : public Subgraph
@ -569,9 +535,9 @@ public:
CeluSubgraph ( ) : alpha ( 1.f )
{
int input = addNodeToMatch ( " " ) ;
int div = addNodeToMatch ( " Div " , input , addNodeToMatch ( " " ) ) ;
int elu = addNodeToMatch ( " Elu " , div ) ;
addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) , elu ) ;
div = addNodeToMatch ( " Div " , input , addNodeToMatch ( " " ) ) ;
elu = addNodeToMatch ( " Elu " , div ) ;
mul = addNodeToMatch ( " Mul " , addNodeToMatch ( " " ) , elu ) ;
setFusedNode ( " Celu " , input ) ;
}
@ -587,16 +553,15 @@ public:
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
float alpha_div = extractAlpha ( net , matchedNodesIds [ 0 ] , 1 ) ;
float alpha_mul = extractAlpha ( net , matchedNodesIds [ 2 ] , 0 ) ;
float alpha_div = extractAlpha ( net , matchedNodesIds [ div ] , 1 ) ;
float alpha_mul = extractAlpha ( net , matchedNodesIds [ mul ] , 0 ) ;
float alpha_elu = 1.f ;
Ptr < ImportNodeWrapper > elu_ptr = net - > getNode ( matchedNodesIds [ 1 ] ) ;
Ptr < ImportNodeWrapper > elu_ptr = net - > getNode ( matchedNodesIds [ elu ] ) ;
opencv_onnx : : NodeProto * elu_node = elu_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
for ( int i = 0 ; i < elu_node - > attribute_size ( ) ; i + + )
@ -625,18 +590,18 @@ public:
protected :
float alpha ;
int div , mul , elu ;
} ;
class NormalizeSubgraphBase : public Subgraph
{
public :
NormalizeSubgraphBase ( int _normNodeOrder = 0 ) : axis ( 1 ) , normNodeOrder ( _normNodeOrder ) { }
NormalizeSubgraphBase ( int _normNodeOrder = 1 ) : axis ( 1 ) , normNodeOrder ( _normNodeOrder ) { }
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) )
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) )
{
Ptr < ImportNodeWrapper > norm = net - > getNode ( matchedNodesIds [ normNodeOrder ] ) ;
opencv_onnx : : NodeProto * node = norm . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
@ -725,7 +690,7 @@ public:
class NormalizeSubgraph3 : public NormalizeSubgraphBase
{
public :
NormalizeSubgraph3 ( ) : NormalizeSubgraphBase ( 1 )
NormalizeSubgraph3 ( ) : NormalizeSubgraphBase ( 3 )
{
int input = addNodeToMatch ( " " ) ;
int power = addNodeToMatch ( " Constant " ) ;
@ -743,7 +708,7 @@ public:
class NormalizeSubgraph4 : public NormalizeSubgraphBase
{
public :
NormalizeSubgraph4 ( ) : NormalizeSubgraphBase ( 1 )
NormalizeSubgraph4 ( ) : NormalizeSubgraphBase ( 2 )
{
int input = addNodeToMatch ( " " ) ;
int mul = addNodeToMatch ( " Mul " , input , input ) ;
@ -760,7 +725,7 @@ public:
class NormalizeSubgraph5 : public NormalizeSubgraphBase
{
public :
NormalizeSubgraph5 ( ) : NormalizeSubgraphBase ( 1 )
NormalizeSubgraph5 ( ) : NormalizeSubgraphBase ( 2 )
{
int input = addNodeToMatch ( " " ) ;
int mul = addNodeToMatch ( " Mul " , input , input ) ;
@ -781,25 +746,24 @@ public:
{
int input = addNodeToMatch ( " " ) ;
int index = addNodeToMatch ( " Constant " ) ;
int gather = addNodeToMatch ( " Gather " , input , index ) ;
addNodeToMatch ( " Cast " , gather ) ;
gather = addNodeToMatch ( " Gather " , input , index ) ;
cast = addNodeToMatch ( " Cast " , gather ) ;
setFusedNode ( " Gather " , input , index ) ;
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE
{
bool retVal = Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) ;
bool retVal = Subgraph : : match ( net , nodeId , matchedNodesIds ) ;
size_t matchedNodesNum = matchedNodesIds . size ( ) ;
// Now we check if merging can be made for these Gather and Cast nodes
if ( ! retVal | | matchedNodesNum < 2 )
return retVal ;
else {
int nodeToMatch = matchedNodesIds [ matchedNodesNum - 1 ] ;
int nodeToMatch = matchedNodesIds [ cast ] ;
const Ptr < ImportNodeWrapper > node = net - > getNode ( nodeToMatch ) ;
if ( node - > getType ( ) = = " Cast " ) {
int inpNodeId = matchedNodesIds [ matchedNodesNum - 2 ] ;
int inpNodeId = matchedNodesIds [ gather ] ;
const Ptr < ImportNodeWrapper > inpNode = net - > getNode ( inpNodeId ) ;
if ( inpNode - > getType ( ) = = " Gather " ) {
int numNodes = net - > getNumNodes ( ) ;
@ -819,6 +783,9 @@ public:
}
return retVal ;
}
private :
int cast , gather ;
} ;
/* Constant folding shape for Expand.
@ -838,12 +805,12 @@ public:
{
int input = addNodeToMatch ( " " ) ;
int values = addNodeToMatch ( " " ) ;
int init = addNodeToMatch ( " ConstantOfShape " , values ) ;
init = addNodeToMatch ( " ConstantOfShape " , values ) ;
int coeff = addNodeToMatch ( " Constant " ) ;
int mul = addNodeToMatch ( " Mul " , init , coeff ) ;
mul = addNodeToMatch ( " Mul " , init , coeff ) ;
int shape = addNodeToMatch ( " Constant " ) ;
int condition = addNodeToMatch ( " Equal " , shape , mul ) ;
int where = addNodeToMatch ( " Where " , condition , init , addNodeToMatch ( " Constant " ) ) ;
condition = addNodeToMatch ( " Equal " , shape , mul ) ;
where = addNodeToMatch ( " Where " , condition , init , addNodeToMatch ( " Constant " ) ) ;
addNodeToMatch ( " Expand " , input , where ) ;
setFusedNode ( " Expand " , input , shape ) ;
}
@ -872,53 +839,28 @@ public:
return 0 ;
}
static std : : vector < int64_t > extractConstant ( const Ptr < ImportGraphWrapper > & net , int node_id , int input_id )
{
auto onnx_net = net . dynamicCast < ONNXGraphWrapper > ( ) ;
int initializer_id = onnx_net - > getInputInitializerId ( node_id , input_id ) ;
Mat mat_constant ;
if ( initializer_id ! = - 1 ) // initializer
{
mat_constant = onnx_net - > getMatFromInitializer ( initializer_id ) ;
}
else
{
const Ptr < ImportNodeWrapper > node = net - > getNode ( node_id ) ;
int constant_id = getInputNodeId ( net , node , input_id ) ;
Ptr < ImportNodeWrapper > constant_ptr = net - > getNode ( constant_id ) ;
opencv_onnx : : NodeProto * constant_node = constant_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
opencv_onnx : : TensorProto constant_proto = constant_node - > attribute ( 0 ) . t ( ) ;
mat_constant = getMatFromTensor ( constant_proto ) ;
}
std : : vector < int64_t > retvals { mat_constant . begin < int > ( ) , mat_constant . end < int > ( ) } ;
return retvals ;
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE {
if ( Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) ) {
std : : vector < int > & matchedNodesIds ) CV_OVERRIDE {
if ( Subgraph : : match ( net , nodeId , matchedNodesIds ) ) {
int64_t value_ConstantOfShape ;
if ( ! extractValue ( net , matchedNodesIds [ 0 ] , value_ConstantOfShape ) ) {
if ( ! extractValue ( net , matchedNodesIds [ init ] , value_ConstantOfShape ) ) {
return false ;
}
std : : vector < int64_t > input_ConstantOfShape = extractConstant ( net , matchedNodesIds [ 0 ] , 0 ) ;
std : : vector < int > input_ConstantOfShape = extractConstant ( net , matchedNodesIds [ init ] , 0 ) ;
if ( input_ConstantOfShape . size ( ) ! = static_cast < size_t > ( 1 ) ) {
return false ;
}
auto B_Mul = extractConstant ( net , matchedNodesIds [ 1 ] , 1 ) ;
std : : vector < int > B_Mul = extractConstant ( net , matchedNodesIds [ mul ] , 1 ) ;
if ( B_Mul . size ( ) ! = static_cast < size_t > ( 1 ) ) {
return false ;
}
auto A_Equal = extractConstant ( net , matchedNodesIds [ 2 ] , 0 ) ;
std : : vector < int > A_Equal = extractConstant ( net , matchedNodesIds [ condition ] , 0 ) ;
if ( A_Equal . size ( ) ! = static_cast < size_t > ( input_ConstantOfShape [ 0 ] ) ) {
return false ;
}
auto Y_Where = extractConstant ( net , matchedNodesIds [ 3 ] , 2 ) ;
std : : vector < int > Y_Where = extractConstant ( net , matchedNodesIds [ where ] , 2 ) ;
if ( Y_Where . size ( ) ! = A_Equal . size ( ) ) {
return false ;
}
@ -969,6 +911,9 @@ public:
protected :
std : : vector < int64_t > shape ;
private :
int init , mul , condition , where ;
} ;
class MishSubgraph : public Subgraph
@ -979,7 +924,7 @@ public:
int input = addNodeToMatch ( " " ) ;
int softplus = addNodeToMatch ( " Softplus " , input ) ;
int tanh = addNodeToMatch ( " Tanh " , softplus ) ;
addNodeToMatch ( " Mul " , input , tanh ) ;
addNodeToMatch ( " Mul " , tanh , input ) ;
setFusedNode ( " Mish " , input ) ;
}
} ;
@ -999,20 +944,6 @@ public:
}
} ;
class SoftplusSubgraph2 : public Subgraph
{
public :
SoftplusSubgraph2 ( )
{
int input = addNodeToMatch ( " " ) ;
int exp = addNodeToMatch ( " Exp " , input ) ;
int addVal = addNodeToMatch ( " " ) ;
int add = addNodeToMatch ( " Add " , exp , addVal ) ;
addNodeToMatch ( " Log " , add ) ;
setFusedNode ( " Softplus " , input ) ;
}
} ;
class MulCastSubgraph : public Subgraph
{
public :
@ -1248,7 +1179,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs . push_back ( makePtr < BatchNormalizationSubgraph2 > ( ) ) ;
subgraphs . push_back ( makePtr < ExpandSubgraph > ( ) ) ;
subgraphs . push_back ( makePtr < SoftplusSubgraph > ( ) ) ;
subgraphs . push_back ( makePtr < SoftplusSubgraph2 > ( ) ) ;
subgraphs . push_back ( makePtr < MishSubgraph > ( ) ) ;
subgraphs . push_back ( makePtr < NormalizeSubgraph4 > ( ) ) ;
subgraphs . push_back ( makePtr < NormalizeSubgraph5 > ( ) ) ;