@ -81,26 +81,45 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
{
matchedNodesIds . clear ( ) ;
std : : queue < int > nodesToMatch ;
std : : queue < int > targetNodes ;
std : : vector < std : : pair < int , int > > matchings ;
matchings . reserve ( nodes . size ( ) ) ;
nodesToMatch . push ( nodeId ) ;
targetNodes . push ( nodes . size ( ) - 1 ) ;
while ( ! nodesToMatch . empty ( ) )
// Collection of all matchings states across branching.
// If there is no commutative ops in the subgraph - there would be just a single map.
std : : vector < std : : shared_ptr < std : : map < int , int > > > matchCandidates ;
matchCandidates . push_back ( makePtr < std : : map < int , int > > ( ) ) ;
struct State
{
int nodeToMatch ;
int targetNodeId ;
// Every state refers to current matchings pairs as well as
// matchings from parent branches produced by commutative ops.
std : : vector < std : : shared_ptr < std : : map < int , int > > > matchings ;
// When we register a matching pair we should register it in every parent branch.
// This is actual for branching in case of commutative ops only.
void addMatch ( std : : pair < int , int > match )
{
for ( auto & m : matchings )
m - > insert ( match ) ;
}
} ;
std : : queue < State > states ;
states . push ( { nodeId , ( int ) nodes . size ( ) - 1 , matchCandidates } ) ;
while ( ! states . empty ( ) )
{
int nodeToMatch = nodesToMatch . front ( ) ;
int targetNodeId = targetNodes . front ( ) ;
nodesToMatch . pop ( ) ;
targetNodes . pop ( ) ;
auto state = states . front ( ) ;
states . pop ( ) ;
int nodeToMatch = state . nodeToMatch ;
int targetNodeId = state . targetNodeId ;
auto matchings = state . matchings . back ( ) ;
if ( std : : find_if ( matchings . begin ( ) , matchings . end ( ) , [ & ] ( const std : : pair < int , int > & match ) { return match . first = = targetNodeId ; } ) ! =
matchings . end ( ) )
if ( matchings - > find ( targetNodeId ) ! = matchings - > end ( ) )
continue ;
// Empty placeholder matches with any input type
if ( nodes [ targetNodeId ] . empty ( ) ) {
matchings . push_back ( { targetNodeId , nodeToMatch } ) ;
state . addMatch ( { targetNodeId , nodeToMatch } ) ;
continue ;
}
@ -112,43 +131,51 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
if ( inputNodes . size ( ) ! = node - > getNumInputs ( ) )
continue ;
bool isCommutative = net - > isCommutativeOp ( node - > getType ( ) ) ;
state . addMatch ( { targetNodeId , nodeToMatch } ) ;
for ( int j = 0 ; j < inputNodes . size ( ) ; + + j )
{
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
if ( node - > getInputName ( j ) . empty ( ) )
continue ;
nodeId = getInputNodeId ( net , node , j ) ;
const Ptr < ImportNodeWrapper > inpNode = net - > getNode ( nodeId ) ;
bool isCommutative = net - > isCommutativeOp ( node - > getType ( ) ) ;
if ( isCommutative )
{
for ( int i = 0 ; i < inputNodes . size ( ) ; + + i )
{
nodesToMatch . push ( nodeId ) ;
targetNodes . push ( inputNodes [ i ] ) ;
}
if ( inputNodes . size ( ) ! = 2 )
CV_Error ( Error : : StsNotImplemented , " Commutative op fusion with more than 2 inputs " ) ;
auto newMatchings = makePtr < std : : map < int , int > > ( * matchings ) ;
matchCandidates . push_back ( newMatchings ) ;
state . matchings . push_back ( newMatchings ) ;
states . push ( { getInputNodeId ( net , node , 0 ) , inputNodes [ 0 ] , state . matchings } ) ;
states . push ( { getInputNodeId ( net , node , 1 ) , inputNodes [ 1 ] , state . matchings } ) ;
state . matchings . pop_back ( ) ;
newMatchings = makePtr < std : : map < int , int > > ( * matchings ) ;
matchCandidates . push_back ( newMatchings ) ;
state . matchings . push_back ( newMatchings ) ;
states . push ( { getInputNodeId ( net , node , 0 ) , inputNodes [ 1 ] , state . matchings } ) ;
states . push ( { getInputNodeId ( net , node , 1 ) , inputNodes [ 0 ] , state . matchings } ) ;
state . matchings . pop_back ( ) ;
}
else
{
nodesToMatch . push ( nodeId ) ;
targetNodes . push ( inputNodes [ j ] ) ;
for ( int j = 0 ; j < inputNodes . size ( ) ; + + j )
{
nodeId = getInputNodeId ( net , node , j ) ;
states . push ( { nodeId , inputNodes [ j ] , state . matchings } ) ;
}
}
matchings . push_back ( { targetNodeId , nodeToMatch } ) ;
}
if ( matchings . size ( ) ! = nodes . size ( ) )
return false ;
// Sort matched by pattern nodes order.
std : : sort ( matchings . begin ( ) , matchings . end ( ) ) ;
matchedNodesIds . resize ( matchings . size ( ) ) ;
for ( int i = 0 ; i < matchings . size ( ) ; + + i )
for ( auto & matchings : matchCandidates )
{
if ( matchings - > size ( ) ! = nodes . size ( ) )
continue ;
matchedNodesIds . resize ( matchings - > size ( ) ) ;
for ( int i = 0 ; i < matchings - > size ( ) ; + + i )
{
matchedNodesIds [ i ] = matchings [ i ] . second ;
CV_Assert ( matchings - > find ( i ) ! = matchings - > end ( ) ) ;
matchedNodesIds [ i ] = matchings - > at ( i ) ;
}
return true ;
}
return false ;
}
void Subgraph : : replace ( const Ptr < ImportGraphWrapper > & net , const std : : vector < int > & matchedNodesIds )
{