Merge pull request #24655 from fengyuentau:graph_simplifier_optional_input

dnn onnx graph simplifier: handle optional inputs of Slice #24655

Resolves https://github.com/opencv/opencv/issues/24609

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
pull/24664/head
Yuantao Feng 1 year ago committed by GitHub
parent 22edfd2628
commit f5ec92e4ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 76
      modules/dnn/src/onnx/onnx_graph_simplifier.cpp

@ -82,6 +82,23 @@ public:
return makePtr<ONNXNodeWrapper>(node);
}
int getTensorShapeSize(int node_id, int node_input_id) {
const auto node = getNode(node_id);
const auto &input_name = node->getInputName(node_input_id);
for (int i = 0; i < net.value_info_size(); i++) {
const auto value_info = net.value_info(i);
if (value_info.name() == input_name) {
if (value_info.has_type() && value_info.type().has_tensor_type() &&
value_info.type().tensor_type().has_shape()) {
return value_info.type().tensor_type().shape().dim_size();
} else {
return -1;
}
}
}
return -1;
}
int getInputInitializerId(int node_id, int node_input_id)
{
auto node = getNode(node_id);
@ -164,6 +181,61 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
}
}
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
Slice with optional inputs of default values, some of them don't. This Subgraph removes
all optional inputs of Slice if values are default.
*/
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph {
public:
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) {
num_inputs_ = num_inputs;
int input = addNodeToMatch("");
int starts = addNodeToMatch("");
int ends = addNodeToMatch("");
std::vector<int> inputs{input, starts, ends};
for (size_t i = 3; i < num_inputs_; i++) { // axes and steps
inputs.push_back(addNodeToMatch(""));
}
slice_id = addNodeToMatch("Slice", inputs);
setFusedNode("Slice", std::vector<int>{input, starts, ends});
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
if (num_inputs_ >= 4) { // with axes
// Check if axes are -1 or last axis
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0);
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3);
for (size_t i = 0; i < axes.total(); i++) {
const int axis = *(axes.ptr<const int>() + i);
if (axis != -1 && axis != shape_size - 1) {
return false;
}
}
}
if (num_inputs_ == 5) {
// Check if steps are 1
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4);
if (countNonZero(steps != 1)) {
return false;
}
}
return true;
}
return false;
}
private:
int slice_id;
size_t num_inputs_;
};
/* Fusion for Gelu.
Graph before fusion:
@ -1091,7 +1163,7 @@ public:
int cast = addNodeToMatch("Cast", concat1);
int shape2 = addNodeToMatch("Shape", input);
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"));
int concat2 = addNodeToMatch("Concat", slice, cast);
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
@ -1163,6 +1235,8 @@ public:
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4));
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5));
subgraphs.push_back(makePtr<GeluSubGraph>());
subgraphs.push_back(makePtr<GeluApproximationSubGraph>());
subgraphs.push_back(makePtr<LayerNormSubGraph>());

Loading…
Cancel
Save