Added ngraph::op::v4::Interpolation

pull/19716/head
Liubov Batanina 4 years ago
parent b995de4ff3
commit 95ab9468c1
  1. 31
      modules/dnn/src/layers/resize_layer.cpp

@ -257,6 +257,7 @@ public:
{
auto& ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
#if INF_ENGINE_VER_MAJOR_LE(INF_ENGINE_RELEASE_2021_2)
ngraph::op::InterpolateAttrs attrs;
attrs.pads_begin.push_back(0);
attrs.pads_end.push_back(0);
@ -275,6 +276,36 @@ public:
std::vector<int64_t> shape = {outHeight, outWidth};
auto out_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, shape.data());
auto interp = std::make_shared<ngraph::op::Interpolate>(ieInpNode, out_shape, attrs);
#else
ngraph::op::v4::Interpolate::InterpolateAttrs attrs;
if (interpolation == "nearest") {
attrs.mode = ngraph::op::v4::Interpolate::InterpolateMode::nearest;
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel;
} else if (interpolation == "bilinear") {
attrs.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::asymmetric;
} else {
CV_Error(Error::StsNotImplemented, "Unsupported interpolation: " + interpolation);
}
attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
if (alignCorners) {
attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners;
}
attrs.nearest_mode = ngraph::op::v4::Interpolate::NearestMode::round_prefer_floor;
std::vector<int64_t> shape = {outHeight, outWidth};
auto out_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, shape.data());
auto& input_shape = ieInpNode->get_shape();
std::vector<float> scales = {static_cast<float>(outHeight) / input_shape[2], static_cast<float>(outHeight) / input_shape[2]};
auto scales_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, ngraph::Shape{2}, scales.data());
auto axes = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{2, 3});
auto interp = std::make_shared<ngraph::op::v4::Interpolate>(ieInpNode, out_shape, scales_shape, axes, attrs);
#endif
return Ptr<BackendNode>(new InfEngineNgraphNode(interp));
}
#endif // HAVE_DNN_NGRAPH

Loading…
Cancel
Save