Merge pull request #22554 from WanliZhong:slice_axes_no_seq

DNN: Let Slice layer support non-sequential and negative axes
pull/22542/head
Alexander Smorkalov 2 years ago committed by GitHub
commit 96844b0ca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 118
      modules/dnn/src/onnx/onnx_importer.cpp
  2. 14
      modules/dnn/test/test_onnx_importer.cpp

@ -1299,72 +1299,59 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{ {
int axis = 0; MatShape inpShape = outShapes[node_proto.input(0)];
std::vector<int> begin; int dims = inpShape.size();
std::vector<int> end; std::vector<int> begin(dims, 0);
std::vector<int> end(dims, INT_MAX);
std::vector<int> steps; std::vector<int> steps;
int inp_size = node_proto.input_size(); int inp_size = node_proto.input_size();
int axis = 0;
bool has_axes = false;
DictValue starts_, ends_, axes_, steps_;
// opset = 1
if (inp_size == 1) if (inp_size == 1)
{ {
if (layerParams.has("axes")) { starts_ = layerParams.get("starts");
DictValue axes = layerParams.get("axes"); ends_ = layerParams.get("ends");
for (int i = 1; i < axes.size(); ++i) { CV_Assert(starts_.size() == ends_.size());
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1); if (layerParams.has("axes"))
}
axis = axes.get<int>(0);
}
DictValue starts = layerParams.get("starts");
DictValue ends = layerParams.get("ends");
CV_Assert(starts.size() == ends.size());
if (axis > 0) {
CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
begin.resize(axis, 0);
end.resize(axis, INT_MAX);
}
for (int i = 0; i < starts.size(); ++i)
{ {
begin.push_back(starts.get<int>(i)); axes_ = layerParams.get("axes");
end.push_back(ends.get<int>(i)); CV_Assert(axes_.size() == starts_.size());
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
has_axes = true;
} }
} else { // inp_size > 1 }
// opset > 1
else
{
CV_Assert(inp_size >= 3); CV_Assert(inp_size >= 3);
for (int i = 1; i < inp_size; i++) { for (int i = 1; i < inp_size; ++i)
{
CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end()); CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
} }
Mat start_blob = getBlob(node_proto, 1); Mat start_blob = getBlob(node_proto, 1);
Mat end_blob = getBlob(node_proto, 2); Mat end_blob = getBlob(node_proto, 2);
CV_Assert(start_blob.total() == end_blob.total()); CV_Assert(start_blob.total() == end_blob.total());
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
if (inp_size > 3) { if (inp_size > 3)
{
Mat axes_blob = getBlob(node_proto, 3); Mat axes_blob = getBlob(node_proto, 3);
const int* axes = (int*)axes_blob.data; CV_Assert(axes_blob.total() == start_blob.total());
for (int i = 1; i < axes_blob.total(); ++i) { axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
CV_Assert(axes[i - 1] == axes[i] - 1); axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
} has_axes = true;
axis = axes[0];
}
const int* starts = start_blob.ptr<int>();
const int* ends = end_blob.ptr<int>();
if (axis > 0) {
begin.resize(axis, 0);
end.resize(axis, INT_MAX);
} }
std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
std::copy(ends, ends + end_blob.total(), std::back_inserter(end));
if (inp_size == 5) { if (inp_size == 5)
CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end()); {
Mat step_blob = getBlob(node_proto, 4); Mat step_blob = getBlob(node_proto, 4);
const int* steps_ptr = step_blob.ptr<int>(); CV_Assert(step_blob.total() == start_blob.total());
steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
if (axis > 0) steps.resize(dims, 1);
steps.resize(axis, 1);
std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
// Very strange application for Slice op with tensor reversing. // Very strange application for Slice op with tensor reversing.
// We just workaround it for 2d constants. // We just workaround it for 2d constants.
@ -1384,12 +1371,45 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
} }
} }
} }
if (!has_axes)
{
// make a default axes [0, 1, 2...]
Mat axes_tmp(1, starts_.size(), CV_32S);
std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
}
int cur_axe;
std::vector<bool> flag(dims, false);
Mat axes(1, starts_.size(), CV_32S);
auto axes_ptr = axes.ptr<int>();
// resize begin and end
for (int i = 0; i < axes_.size(); ++i)
{
// dims should be added to the negative axes
cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
flag[cur_axe] = true;
// change axis to the minimum axe
if (cur_axe < axis) axis = cur_axe;
axes_ptr[i] = cur_axe;
begin[cur_axe] = starts_.getIntValue(i);
end[cur_axe] = ends_.getIntValue(i);
}
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size())); layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size())); layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
layerParams.set("axis", axis); layerParams.set("axis", axis);
if (!steps.empty()) if (!steps.empty())
{
for (int i = 0; i < axes.total(); ++i)
steps[axes_ptr[i]] = steps_.getIntValue(i);
layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size())); layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
}
if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{ {

@ -1172,6 +1172,20 @@ TEST_P(Test_ONNX_layers, Slice_Steps_5DInput)
testONNXModels("slice_opset_11_steps_5d"); testONNXModels("slice_opset_11_steps_5d");
} }
TEST_P(Test_ONNX_layers, Slice_Nonseq_Axes)
{
testONNXModels("slice_nonseq_axes");
testONNXModels("slice_nonseq_axes_steps");
testONNXModels("slice_nonseq_miss_axes_steps");
}
TEST_P(Test_ONNX_layers, Slice_Neg_Axes)
{
testONNXModels("slice_neg_axes");
testONNXModels("slice_neg_axes_steps");
testONNXModels("slice_neg_miss_axes_steps");
}
TEST_P(Test_ONNX_layers, Softmax) TEST_P(Test_ONNX_layers, Softmax)
{ {
testONNXModels("softmax"); testONNXModels("softmax");

Loading…
Cancel
Save