@ -86,6 +86,7 @@ public:
// NOTE(review): this span is a rendered-diff view — most lines appear twice and the hunk
// marker further down hides part of the value_info branch, so only the visible statements
// are described here.
//
// getTensorShapeSize: resolves the name of input `node_input_id` of node `node_id`, then
// searches the graph for an entry with that name, first in value_info and then in input.
// From the visible input branch: returns the number of dims in the matching entry's tensor
// shape, or -1 when that entry lacks type / tensor_type / shape information. Returns -1 when
// no earlier statement returned. What the value_info branch returns on a match is not
// visible here — presumably the same dim count; TODO confirm.
int getTensorShapeSize ( int node_id , int node_input_id ) {
int getTensorShapeSize ( int node_id , int node_input_id ) {
const auto node = getNode ( node_id ) ;
const auto node = getNode ( node_id ) ;
// Name of the tensor feeding the requested input slot; used as the lookup key below.
const auto & input_name = node - > getInputName ( node_input_id ) ;
const auto & input_name = node - > getInputName ( node_input_id ) ;
// try to get from value_info
for ( int i = 0 ; i < net . value_info_size ( ) ; i + + ) {
for ( int i = 0 ; i < net . value_info_size ( ) ; i + + ) {
// NOTE(review): `const auto` takes a copy if value_info(i) returns a reference — confirm
// whether `const auto &` was intended.
const auto value_info = net . value_info ( i ) ;
const auto value_info = net . value_info ( i ) ;
if ( value_info . name ( ) = = input_name ) {
if ( value_info . name ( ) = = input_name ) {
@ -97,6 +98,18 @@ public:
}
}
}
}
}
}
// try to get from input
for ( int i = 0 ; i < net . input_size ( ) ; i + + ) {
const auto input = net . input ( i ) ;
if ( input . name ( ) = = input_name ) {
// Only report a dim count when type, tensor_type and shape are all present.
if ( input . has_type ( ) & & input . type ( ) . has_tensor_type ( ) & &
input . type ( ) . tensor_type ( ) . has_shape ( ) ) {
return input . type ( ) . tensor_type ( ) . shape ( ) . dim_size ( ) ;
} else {
// Name matched but no shape information: stop searching and report -1.
return - 1 ;
}
}
}
// Nothing above returned: report -1.
return - 1 ;
return - 1 ;
}
}
@ -660,6 +673,10 @@ private:
[ Input ] - > LayerNorm - > [ Output ]
[ Input ] - > LayerNorm - > [ Output ]
\
\
[ weight ] , [ bias ]
[ weight ] , [ bias ]
Note: the axes of ReduceMean must satisfy both of the following:
- the last element is the axis of the last dimension (-1 or (input_ndims - 1))
- the axes form a list of adjacent axes, e.g. [1, 2, 3, ..., input_ndims - 1]
*/
*/
class LayerNormSubGraph : public Subgraph
class LayerNormSubGraph : public Subgraph
{
{
@ -683,19 +700,22 @@ public:
setFusedNode ( " LayerNormalization " , input ) ;
setFusedNode ( " LayerNormalization " , input ) ;
}
}
// NOTE(review): this span is a rendered-diff view — most lines appear twice, and the older
// scalar variant (`float` return, `axis_`) is interleaved with the newer vector variant
// (`std::vector<int64_t>` return, `axes`). Confirm against the real file which one is live.
//
// extractAxis: fetches node `node_id` from `net`, treats it as an ONNX node, and scans its
// attributes for the one named "axes".
//   - vector variant: returns every integer stored in that attribute, in stored order; the
//     vector stays empty when no such attribute is found.
//   - scalar variant: returns only the first integer of that attribute, or -1 when no such
//     attribute is found.
static float extractAxis ( const Ptr < ImportGraphWrapper > & net , int node_id )
static std : : vector < int64_t > extractAxis ( const Ptr < ImportGraphWrapper > & net , int node_id )
{
{
// TODO: consider ReduceMean-18 which has axes as one of the inputs instead of attributes
Ptr < ImportNodeWrapper > mean_ptr = net - > getNode ( node_id ) ;
Ptr < ImportNodeWrapper > mean_ptr = net - > getNode ( node_id ) ;
// Downcast to ONNXNodeWrapper to reach the underlying NodeProto; the result of the cast is
// dereferenced without a null check.
opencv_onnx : : NodeProto * mean_node = mean_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
opencv_onnx : : NodeProto * mean_node = mean_ptr . dynamicCast < ONNXNodeWrapper > ( ) - > node ;
// Defaults when the attribute is absent: -1 (scalar variant) / empty vector (vector variant).
int axis_ = - 1 ;
std : : vector < int64_t > axes ;
for ( int i = 0 ; i < mean_node - > attribute_size ( ) ; i + + )
for ( int i = 0 ; i < mean_node - > attribute_size ( ) ; i + + )
{
{
opencv_onnx : : AttributeProto attr = mean_node - > attribute ( i ) ;
opencv_onnx : : AttributeProto attr = mean_node - > attribute ( i ) ;
// Skip every attribute other than the axes attribute.
if ( attr . name ( ) ! = " axes " )
if ( attr . name ( ) ! = " axes " )
continue ;
continue ;
// Scalar variant: keep only the first axis value.
axis_ = static_cast < int > ( attr . ints ( 0 ) ) ;
// Vector variant: copy all axis values in order; values are stored as-is, with no
// normalisation of negative axes.
for ( int j = 0 ; j < attr . ints_size ( ) ; j + + ) {
axes . push_back ( attr . ints ( j ) ) ;
}
}
}
return axis_ ;
return axes ;
}
}
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
virtual bool match ( const Ptr < ImportGraphWrapper > & net , int nodeId ,
@ -707,11 +727,31 @@ public:
if ( pow_exp - 2 > 1e-5 ) // not pow(2)
if ( pow_exp - 2 > 1e-5 ) // not pow(2)
return false ;
return false ;
int axis_mean1 = extractAxis ( net , matchedNodesIds [ mean ] ) ;
std : : vector < int64_t > axes = extractAxis ( net , matchedNodesIds [ mean ] ) ;
int axis_mean2 = extractAxis ( net , matchedNodesIds [ mean1 ] ) ;
// check whether it is -1 or last_axis or [axis, ..., last_axis]
if ( axis_mean1 ! = axis_mean2 )
int64_t input_ndims = static_cast < int64_t > ( net . dynamicCast < ONNXGraphWrapper > ( ) - > getTensorShapeSize ( matchedNodesIds [ mean ] , 0 ) ) ;
if ( input_ndims = = - 1 ) {
return false ; // input shape unknown
}
// assume that axes are sorted in ascending order, e.g. [0, 1, 2, 3] or [-3, -2, -1]
if ( axes . back ( ) ! = - 1 & & axes . back ( ) ! = ( input_ndims - 1 ) ) {
return false ;
return false ;
axis = axis_mean1 ;
}
for ( size_t i = 0 ; i < axes . size ( ) - 1 ; i + + ) {
if ( axes [ i ] - axes [ i + 1 ] ! = - 1 ) {
return false ;
}
}
std : : vector < int64_t > axes1 = extractAxis ( net , matchedNodesIds [ mean1 ] ) ;
if ( axes . size ( ) ! = axes1 . size ( ) )
return false ;
for ( size_t i = 0 ; i < axes . size ( ) ; i + + ) {
if ( ( ( axes [ i ] + input_ndims ) % input_ndims ) ! = ( ( axes1 [ i ] + input_ndims ) % input_ndims ) ) {
return false ;
}
}
axis = axes [ 0 ] ;
epsilon = extractConstant ( net , matchedNodesIds [ add ] , 1 ) . at < float > ( 0 ) ;
epsilon = extractConstant ( net , matchedNodesIds [ add ] , 1 ) . at < float > ( 0 ) ;