|
|
|
@ -182,12 +182,12 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
|
|
|
|
|
Slice with optional inputs of default values, some of them don't. This Subgraph removes |
|
|
|
|
all optional inputs of Slice if values are default. |
|
|
|
|
Slice with optional inputs of default values, some of them don't. This Subgraph adjusts |
|
|
|
|
all optional inputs of Slice up to 5. |
|
|
|
|
*/ |
|
|
|
|
class RemoveSliceAllOptionalInputsSubgraph : public Subgraph { |
|
|
|
|
class AdjustSliceAllOptionalInputsSubgraph : public Subgraph { |
|
|
|
|
public: |
|
|
|
|
RemoveSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) { |
|
|
|
|
AdjustSliceAllOptionalInputsSubgraph(size_t num_inputs = 4) { |
|
|
|
|
num_inputs_ = num_inputs; |
|
|
|
|
|
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
@ -200,35 +200,17 @@ class RemoveSliceAllOptionalInputsSubgraph : public Subgraph { |
|
|
|
|
|
|
|
|
|
slice_id = addNodeToMatch("Slice", inputs); |
|
|
|
|
|
|
|
|
|
setFusedNode("Slice", std::vector<int>{input, starts, ends}); |
|
|
|
|
setFusedNode("Slice", inputs); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId, |
|
|
|
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE { |
|
|
|
|
if (Subgraph::match(net, nodeId, matchedNodesIds)) { |
|
|
|
|
if (num_inputs_ >= 4) { // with axes
|
|
|
|
|
// Check if axes are -1 or last axis
|
|
|
|
|
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>(); |
|
|
|
|
int shape_size = onnx_net->getTensorShapeSize(matchedNodesIds[slice_id], 0); |
|
|
|
|
|
|
|
|
|
auto axes = extractConstant(net, matchedNodesIds[slice_id], 3); |
|
|
|
|
for (size_t i = 0; i < axes.total(); i++) { |
|
|
|
|
const int axis = *(axes.ptr<const int>() + i); |
|
|
|
|
if (axis != -1 && axis != shape_size - 1) { |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if (num_inputs_ == 5) { |
|
|
|
|
// Check if steps are 1
|
|
|
|
|
auto steps = extractConstant(net, matchedNodesIds[slice_id], 4); |
|
|
|
|
if (countNonZero(steps != 1)) { |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return true; |
|
|
|
|
virtual void finalize(const Ptr<ImportGraphWrapper>&, |
|
|
|
|
const Ptr<ImportNodeWrapper>& fusedNode, |
|
|
|
|
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE |
|
|
|
|
{ |
|
|
|
|
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
for (int i = num_inputs_; i < 5; ++i) { |
|
|
|
|
node->add_input(""); |
|
|
|
|
} |
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
@ -1119,7 +1101,11 @@ public: |
|
|
|
|
ResizeSubgraph1() : ExtractScalesSubgraph() |
|
|
|
|
{ |
|
|
|
|
int shape = addNodeToMatch("Shape", input); |
|
|
|
|
int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant")); |
|
|
|
|
int slice = addNodeToMatch("Slice", {shape, |
|
|
|
|
addNodeToMatch(""), |
|
|
|
|
addNodeToMatch(""), |
|
|
|
|
addNodeToMatch(""), |
|
|
|
|
addNodeToMatch("")}); |
|
|
|
|
|
|
|
|
|
int castConcat = addNodeToMatch("Cast", concatId); |
|
|
|
|
int concat = addNodeToMatch("Concat", slice, castConcat); |
|
|
|
@ -1163,7 +1149,11 @@ public: |
|
|
|
|
int cast = addNodeToMatch("Cast", concat1); |
|
|
|
|
|
|
|
|
|
int shape2 = addNodeToMatch("Shape", input); |
|
|
|
|
int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant")); |
|
|
|
|
int slice = addNodeToMatch("Slice", {shape2, |
|
|
|
|
addNodeToMatch(""), |
|
|
|
|
addNodeToMatch(""), |
|
|
|
|
addNodeToMatch(""), |
|
|
|
|
addNodeToMatch("")}); |
|
|
|
|
int concat2 = addNodeToMatch("Concat", slice, cast); |
|
|
|
|
addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2); |
|
|
|
|
|
|
|
|
@ -1235,8 +1225,8 @@ public: |
|
|
|
|
void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
{ |
|
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
|
|
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(4)); |
|
|
|
|
subgraphs.push_back(makePtr<RemoveSliceAllOptionalInputsSubgraph>(5)); |
|
|
|
|
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(3)); |
|
|
|
|
subgraphs.push_back(makePtr<AdjustSliceAllOptionalInputsSubgraph>(4)); |
|
|
|
|
subgraphs.push_back(makePtr<GeluSubGraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<GeluApproximationSubGraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<LayerNormSubGraph>()); |
|
|
|
|