@ -1746,43 +1746,45 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
for ( int i = 1 ; i < node_proto . input_size ( ) ; i + + )
CV_Assert ( layer_id . find ( node_proto . input ( i ) ) = = layer_id . end ( ) ) ;
String interp_mode ;
if ( layerParams . has ( " coordinate_transformation_mode " ) )
interp_mode = layerParams . get < String > ( " coordinate_transformation_mode " ) ;
else
interp_mode = layerParams . get < String > ( " mode " ) ;
CV_Assert_N ( interp_mode ! = " tf_crop_and_resize " , interp_mode ! = " tf_half_pixel_for_nn " ) ;
layerParams . set ( " align_corners " , interp_mode = = " align_corners " ) ;
Mat shapes = getBlob ( node_proto , node_proto . input_size ( ) - 1 ) ;
CV_CheckEQ ( shapes . size [ 0 ] , 4 , " " ) ;
CV_CheckEQ ( shapes . size [ 1 ] , 1 , " " ) ;
CV_CheckDepth ( shapes . depth ( ) , shapes . depth ( ) = = CV_32S | | shapes . depth ( ) = = CV_32F , " " ) ;
if ( shapes . depth ( ) = = CV_32F )
shapes . convertTo ( shapes , CV_32S ) ;
int height = shapes . at < int > ( 2 ) ;
int width = shapes . at < int > ( 3 ) ;
if ( hasDynamicShapes )
{
layerParams . set ( " zoom_factor_x " , width ) ;
layerParams . set ( " zoom_factor_y " , height ) ;
String interp_mode = layerParams . get < String > ( " coordinate_transformation_mode " ) ;
CV_Assert_N ( interp_mode ! = " tf_crop_and_resize " , interp_mode ! = " tf_half_pixel_for_nn " ) ;
layerParams . set ( " align_corners " , interp_mode = = " align_corners " ) ;
if ( layerParams . get < String > ( " mode " ) = = " linear " )
{
layerParams . set ( " mode " , interp_mode = = " pytorch_half_pixel " ?
" opencv_linear " : " bilinear " ) ;
}
}
if ( layerParams . get < String > ( " mode " ) = = " linear " & & framework_name = = " pytorch " )
layerParams . set ( " mode " , " opencv_linear " ) ;
// input = [X, scales], [X, roi, scales] or [x, roi, scales, sizes]
int foundScaleId = hasDynamicShapes ? node_proto . input_size ( ) - 1
: node_proto . input_size ( ) > 2 ? 2 : 1 ;
Mat scales = getBlob ( node_proto , foundScaleId ) ;
if ( scales . total ( ) = = 4 )
{
layerParams . set ( " zoom_factor_y " , scales . at < float > ( 2 ) ) ;
layerParams . set ( " zoom_factor_x " , scales . at < float > ( 3 ) ) ;
}
else
{
if ( node_proto . input_size ( ) = = 3 ) {
IterShape_t shapeIt = outShapes . find ( node_proto . input ( 0 ) ) ;
CV_Assert ( shapeIt ! = outShapes . end ( ) ) ;
MatShape scales = shapeIt - > second ;
height * = scales [ 2 ] ;
width * = scales [ 3 ] ;
const std : : string & inputLast = node_proto . input ( node_proto . input_size ( ) - 1 ) ;
if ( constBlobs . find ( inputLast ) ! = constBlobs . end ( ) )
{
Mat shapes = getBlob ( inputLast ) ;
CV_CheckEQ ( shapes . size [ 0 ] , 4 , " " ) ;
CV_CheckEQ ( shapes . size [ 1 ] , 1 , " " ) ;
CV_CheckDepth ( shapes . depth ( ) , shapes . depth ( ) = = CV_32S | | shapes . depth ( ) = = CV_32F , " " ) ;
if ( shapes . depth ( ) = = CV_32F )
shapes . convertTo ( shapes , CV_32S ) ;
layerParams . set ( " width " , shapes . at < int > ( 3 ) ) ;
layerParams . set ( " height " , shapes . at < int > ( 2 ) ) ;
}
layerParams . set ( " width " , width ) ;
layerParams . set ( " height " , height ) ;
}
if ( layerParams . get < String > ( " mode " ) = = " linear " ) {
layerParams . set ( " mode " , interp_mode = = " pytorch_half_pixel " ?
" opencv_linear " : " bilinear " ) ;
}
replaceLayerParam ( layerParams , " mode " , " interpolation " ) ;
}
@ -1822,10 +1824,14 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
else
{
// scales as input
Mat scales = getBlob ( node_proto , 1 ) ;
CV_Assert ( scales . total ( ) = = 4 ) ;
layerParams . set ( " zoom_factor_y " , scales . at < float > ( 2 ) ) ;
layerParams . set ( " zoom_factor_x " , scales . at < float > ( 3 ) ) ;
const std : : string & input1 = node_proto . input ( 1 ) ;
if ( constBlobs . find ( input1 ) ! = constBlobs . end ( ) )
{
Mat scales = getBlob ( input1 ) ;
CV_Assert ( scales . total ( ) = = 4 ) ;
layerParams . set ( " zoom_factor_y " , scales . at < float > ( 2 ) ) ;
layerParams . set ( " zoom_factor_x " , scales . at < float > ( 3 ) ) ;
}
}
replaceLayerParam ( layerParams , " mode " , " interpolation " ) ;
}