Added int64 values support to scatter, scatterND and maxunpool layers

pull/25212/head
Alexander Lyulkov 9 months ago
parent 85cc02f4de
commit d2d6869a26
  1. 36
      modules/dnn/src/layers/max_unpooling_layer.cpp
  2. 3
      modules/dnn/src/layers/scatterND_layer.cpp
  3. 3
      modules/dnn/src/layers/scatter_layer.cpp

@ -94,14 +94,34 @@ public:
Mat& input = inputs[0];
Mat& indices = inputs[1];
if (input.type() == CV_32F && indices.type() == CV_32S)
run<float, int32_t>(input, indices, outputs);
else if (input.type() == CV_32F && indices.type() == CV_64S)
run<float, int64_t>(input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_32S)
run<int16_t, int32_t>(input, indices, outputs);
else if (input.type() == CV_16F && indices.type() == CV_64S)
run<int16_t, int64_t>(input, indices, outputs);
if (indices.depth() == CV_32S)
typeDispatch<int32_t>(input.type(), input, indices, outputs);
else if (indices.depth() == CV_64S)
typeDispatch<int64_t>(input.type(), input, indices, outputs);
else
CV_Error(cv::Error::BadDepth, "Unsupported type.");
}
template<typename T_INDEX, typename... Args>
inline void typeDispatch(const int type, Args&&... args)
{
switch (type)
{
case CV_32S:
run<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
run<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
run<float, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_16F:
run<int16_t, T_INDEX>(std::forward<Args>(args)...);
break;
default:
CV_Error(cv::Error::BadDepth, "Unsupported type.");
};
}
template<typename T, typename INDEX_TYPE>

@ -190,6 +190,9 @@ public:
case CV_32S:
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break;

@ -185,6 +185,9 @@ public:
case CV_32S:
reductionDispatch<int32_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_64S:
reductionDispatch<int64_t, T_INDEX>(std::forward<Args>(args)...);
break;
case CV_32F:
reductionDispatch<float, T_INDEX>(std::forward<Args>(args)...);
break;

Loading…
Cancel
Save