@ -260,6 +260,40 @@ public:
addNodeToMatch ( " Cast " , gather ) ;
setFusedNode ( " Gather " , input , index ) ;
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
std : : vector < int > & matchedNodesIds ,
std : : vector < int > & targetNodesIds ) CV_OVERRIDE
{
bool retVal = Subgraph : : match ( net , nodeId , matchedNodesIds , targetNodesIds ) ;
size_t matchedNodesNum = matchedNodesIds . size ( ) ;
// Now we check if merging can be made for these Gather and Cast nodes
if ( ! retVal | | matchedNodesNum < 2 )
return retVal ;
else {
int nodeToMatch = matchedNodesIds [ matchedNodesNum - 1 ] ;
const Ptr < ImportNodeWrapper > node = net - > getNode ( nodeToMatch ) ;
if ( node - > getType ( ) = = " Cast " ) {
int inpNodeId = matchedNodesIds [ matchedNodesNum - 2 ] ;
const Ptr < ImportNodeWrapper > inpNode = net - > getNode ( inpNodeId ) ;
if ( inpNode - > getType ( ) = = " Gather " ) {
int numNodes = net - > getNumNodes ( ) ;
std : : string inpNodeName = node - > getInputName ( 0 ) ;
for ( int i = 0 ; i < numNodes ; + + i ) {
const Ptr < ImportNodeWrapper > node_to_check = net - > getNode ( i ) ;
int numInp = node_to_check - > getNumInputs ( ) ;
for ( int inp = 0 ; inp < numInp ; + + inp ) {
if ( i ! = nodeToMatch & & inpNodeName = = node_to_check - > getInputName ( 0 ) ) {
// Another node has the same input node, so it cannot be merged.
return false ;
}
}
}
}
}
}
return retVal ;
}
} ;
class ExpandSubgraph : public Subgraph