Merge pull request #24672 from dkurt:adjust_slice_optional_inputs

Replace Slice optional inputs removal to adjustment
pull/24681/head
Alexander Smorkalov 1 year ago committed by GitHub
commit 098efb6d3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 58
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp
  2. 4
      modules/dnn/src/onnx/onnx_importer.cpp

@ -182,12 +182,12 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
}
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
Slice with optional inputs of default values, some of them don't. This Subgraph removes
all optional inputs of Slice if values are default.
Slice with optional inputs of default values, some of them don't. This Subgraph adjusts
all optional inputs of Slice up to 5.
*/
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
class AdjustSliceAllOptionalInputsSubgraph : public Subgraph {
public:
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
AdjustSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
num_inputs_ = num_inputs;
int input = addNodeToMatch("");
@ -200,35 +200,17 @@ class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
slice_id = addNodeToMatch("Slice", inputs);
setFusedNode("Slice", std::vector<int>{input, starts, ends});
setFusedNode("Slice", inputs);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
if (num_inputs_ >= 4) { // with axes
// Check if axes are -1 or last axis
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
for (size_t i = 0; i < axes.total(); i++) {
const int axis = *(axes.ptr<const int>() + i);
if (axis != -1 && axis != shape_size - 1) {
return false;
}
}
}
if (num_inputs_ == 5) {
// Check if steps are 1
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
if (countNonZero(steps != 1)) {
return false;
}
}
return true;
virtual void finalize(const Ptr<ImportGraphWrapper>&,
const Ptr<ImportNodeWrapper>& fusedNode,
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
{
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
for (int i = num_inputs_; i < 5; ++i) {
node->add_input("");
}
return false;
}
private:
@ -1119,7 +1101,11 @@ public:
ResizeSubgraph1() : ExtractScalesSubgraph()
{
int shape = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int slice = addNodeToMatch("Slice", {shape,
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch("")});
int castConcat = addNodeToMatch("Cast", concatId);
int concat = addNodeToMatch("Concat", slice, castConcat);
@ -1163,7 +1149,11 @@ public:
int cast = addNodeToMatch("Cast", concat1);
int shape2 = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int slice = addNodeToMatch("Slice", {shape2,
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch(""),
addNodeToMatch("")});
int concat2 = addNodeToMatch("Concat", slice, cast);
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
@ -1235,8 +1225,8 @@ public:
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(3));
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<GeluSubGraph>());
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
subgraphs.push_back(makePtr<LayerNormSubGraph>());

@ -1235,7 +1235,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
if (inp_size > 3)
if (inp_size > 3 && !getBlob(node_proto, 3).empty())
{
Mat axes_blob = getBlob(node_proto, 3);
CV_Assert(axes_blob.total() == start_blob.total());
@ -1244,7 +1244,7 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
has_axes = true;
}
if (inp_size == 5)
if (inp_size == 5 && !getBlob(node_proto, 4).empty())
{
Mat step_blob = getBlob(node_proto, 4);
CV_Assert(step_blob.total() == start_blob.total());

Loading…
Cancel
Save