Merge pull request #24672 from dkurt:adjust_slice_optional_inputs

Replace Slice optional inputs removal to adjustment
pull/24681/head
Alexander Smorkalov 1 year ago committed by GitHub
commit 098efb6d3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 58
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 4
      modules/dnn/src/onnx/onnx_importer.cpp

@ -182,12 +182,12 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
} }
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have /* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
Slice with optional inputs of default values, some of them don't. This Subgraph removes Slice with optional inputs of default values, some of them don't. This Subgraph adjusts
all optional inputs of Slice if values are default. all optional inputs of Slice up to 5.
*/ */
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph { class AdjustSliceAllOptionalInputsSubgraph : public Subgraph {
public: public:
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) { AdjustSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
num_inputs_ = num_inputs; num_inputs_ = num_inputs;
int input = addNodeToMatch(""); int input = addNodeToMatch("");
@ -200,35 +200,17 @@ class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
slice_id = addNodeToMatch("Slice", inputs); slice_id = addNodeToMatch("Slice", inputs);
setFusedNode("Slice", std::vector<int>{input, starts, ends}); setFusedNode("Slice", inputs);
} }
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, virtual void finalize(const Ptr<ImportGraphWrapper>&,
std::vector<int>& matchedNodesIds) CV_OVERRIDE { const Ptr<ImportNodeWrapper>& fusedNode,
if (Subgraph::match(net, nodeId, matchedNodesIds)) { std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
if (num_inputs_ >= 4) { // with axes {
// Check if axes are -1 or last axis opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>(); for (int i = num_inputs_; i < 5; ++i) {
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0); node->add_input("");
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
for (size_t i = 0; i < axes.total(); i++) {
const int axis = *(axes.ptr<const int>() + i);
if (axis != -1 && axis != shape_size - 1) {
return false;
}
}
}
if (num_inputs_ == 5) {
// Check if steps are 1
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
if (countNonZero(steps != 1)) {
return false;
}
}
return true;
} }
return false;
} }
private: private:
@ -1119,7 +1101,11 @@ public:
ResizeSubgraph1() : ExtractScalesSubgraph() ResizeSubgraph1() : ExtractScalesSubgraph()
{ {
int shape = addNodeToMatch("Shape", input); int shape = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant")); int slice = addNodeToMatch("Slice", {shape,
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch("")});
int castConcat = addNodeToMatch("Cast", concatId); int castConcat = addNodeToMatch("Cast", concatId);
int concat = addNodeToMatch("Concat", slice, castConcat); int concat = addNodeToMatch("Concat", slice, castConcat);
@ -1163,7 +1149,11 @@ public:
int cast = addNodeToMatch("Cast", concat1); int cast = addNodeToMatch("Cast", concat1);
int shape2 = addNodeToMatch("Shape", input); int shape2 = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant")); int slice = addNodeToMatch("Slice", {shape2,
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch("")});
int concat2 = addNodeToMatch("Concat", slice, cast); int concat2 = addNodeToMatch("Concat", slice, cast);
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2); addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
@ -1235,8 +1225,8 @@ public:
void simplifySubgraphs(opencv_onnx::GraphProto& net) void simplifySubgraphs(opencv_onnx::GraphProto& net)
{ {
std::vector<Ptr<Subgraph> > subgraphs; std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4)); subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(3));
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5)); subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<GeluSubGraph>()); subgraphs.push_back(makePtr<GeluSubGraph>());
subgraphs.push_back(makePtr<GeluApproximationSubGraph>()); subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
subgraphs.push_back(makePtr<LayerNormSubGraph>()); subgraphs.push_back(makePtr<LayerNormSubGraph>());

@ -1235,7 +1235,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total()); starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total()); ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
if (inp_size > 3) if (inp_size > 3 && !getBlob(node_proto, 3).empty())
{ {
Mat axes_blob = getBlob(node_proto, 3); Mat axes_blob = getBlob(node_proto, 3);
CV_Assert(axes_blob.total() == start_blob.total()); CV_Assert(axes_blob.total() == start_blob.total());
@ -1244,7 +1244,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
has_axes = true; has_axes = true;
} }
if (inp_size == 5) if (inp_size == 5 && !getBlob(node_proto, 4).empty())
{ {
Mat step_blob = getBlob(node_proto, 4); Mat step_blob = getBlob(node_proto, 4);
CV_Assert(step_blob.total() == start_blob.total()); CV_Assert(step_blob.total() == start_blob.total());

Loading…
Cancel
Save