// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// Copyright (C) 2020, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.

#include "../precomp.hpp"

#include "../graph_simplifier.hpp"
#include "onnx_graph_simplifier.hpp"

#include <opencv2/core/utils/logger.hpp>
#include <queue>

namespace cv { namespace dnn {
CV__DNN_INLINE_NS_BEGIN

extern bool DNN_DIAGNOSTICS_RUN;

// This wrapper can behave differently for fake input nodes and real graph nodes.
class ONNXNodeWrapper : public ImportNodeWrapper
{
public:
    ONNXNodeWrapper(opencv_onnx::NodeProto* _node = 0) : node(_node) {}

    virtual int getNumInputs() const CV_OVERRIDE
    {
        return node ? node->input_size() : 0;
    }

    virtual std::string getInputName(int idx) const CV_OVERRIDE
    {
        CV_Assert_N(node, idx < node->input_size());
        return node->input(idx);
    }

    virtual std::string getType() const CV_OVERRIDE
    {
        return node ? node->op_type() : "";
    }

    virtual void setType(const std::string& type) CV_OVERRIDE
    {
        CV_Assert(node);
        node->set_op_type(type);
    }

    virtual void setInputNames(const std::vector<std::string>& inputs) CV_OVERRIDE
    {
        CV_Assert(node);
        node->clear_input();
        for (int i = 0; i < inputs.size(); ++i)
            node->add_input(inputs[i]);
    }

    opencv_onnx::NodeProto* node;
};

// ONNX graph's inputs are separate from nodes, so we index them before the rest of the nodes.
class ONNXGraphWrapper : public ImportGraphWrapper
{
public:
    ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
    {
        numInputs = net.input_size();
        numInitializers = net.initializer_size();
    }

    virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE
    {
        opencv_onnx::NodeProto* node = 0;
        if (idx >= numInputs + numInitializers)
            node = net.mutable_node(idx - numInputs - numInitializers);
        return makePtr<ONNXNodeWrapper>(node);
    }

    virtual int getNumNodes() const CV_OVERRIDE
    {
        return numInputs + numInitializers + net.node_size();
    }

    virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
    {
        if (nodeId < numInputs + numInitializers)
            return 1;
        else
            return net.node(nodeId - numInputs - numInitializers).output_size();
    }

    virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
    {
        CV_Assert(outId < getNumOutputs(nodeId));
        if (nodeId < numInputs)
            return net.input(nodeId).name();
        else if (nodeId < numInputs + numInitializers)
            return net.initializer(nodeId - numInputs).name();
        else
            return net.node(nodeId - numInputs - numInitializers).output(outId);
    }

    virtual void removeNode(int idx) CV_OVERRIDE
    {
        CV_Assert(idx >= numInputs + numInitializers);
        net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
    }

private:
    int numInputs, numInitializers;
    opencv_onnx::GraphProto& net;
};

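// Softmax and LogSoftmax are frequently exported in decomposed form. The
// matchers below collapse the following patterns into a single node:
//   SoftMaxSubgraph:    y = exp(x) / reduce_sum(exp(x), axis)
//   SoftMaxSubgraph2:   y = exp(x - reduce_max(x)) / reduce_sum(exp(x - reduce_max(x)))
//   LogSoftMaxSubgraph: y = (x - reduce_max(x)) - log(reduce_sum(exp(x - reduce_max(x))))
// The base class reads the "axes" attribute of the reduction node (whose index
// within the matched pattern is stored in `id`) and re-emits it as the "axis"
// attribute of the fused node.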
class SoftMaxSubgraphBase : public Subgraph
{
public:
    SoftMaxSubgraphBase() : axis(1), id(-1) {}

    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                       std::vector<int>& matchedNodesIds,
                       std::vector<int>& targetNodesIds) CV_OVERRIDE
    {
        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
        {
            CV_Assert(id >= 0 && id < matchedNodesIds.size());
            Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
            opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;

            for (int i = 0; i < node->attribute_size(); i++)
            {
                const opencv_onnx::AttributeProto& attr = node->attribute(i);
                if (attr.name() != "axes")
                    continue;
                if (attr.ints_size() != 1)
                    CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
                axis = attr.ints(0);
                return true;
            }
            CV_Error(Error::StsNotImplemented, "Missing 'axes' attribute");
        }
        return false;
    }

    virtual void finalize(const Ptr<ImportGraphWrapper>&,
                          const Ptr<ImportNodeWrapper>& fusedNode,
                          std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
    {
        opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
        opencv_onnx::AttributeProto* attr = node->add_attribute();
        attr->set_name("axis");
        attr->set_i(axis);
    }

protected:
    int axis;
    int id;
};

class SoftMaxSubgraph : public SoftMaxSubgraphBase
{
public:
    SoftMaxSubgraph()
    {
        int input = addNodeToMatch("");
        int inpExp = addNodeToMatch("Exp", input);

        int sum = addNodeToMatch("ReduceSum", inpExp);
        id = 1;  // index of the ReduceSum node in the matched pattern

        addNodeToMatch("Div", inpExp, sum);
        setFusedNode("Softmax", input);
    }
};

class SoftMaxSubgraph2 : public SoftMaxSubgraphBase
{
public:
    SoftMaxSubgraph2()
    {
        int input = addNodeToMatch("");

        int reducemax = addNodeToMatch("ReduceMax", input);
        id = 0;  // index of the ReduceMax node in the matched pattern

        int sub = addNodeToMatch("Sub", input, reducemax);
        int exp = addNodeToMatch("Exp", sub);
        int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
        addNodeToMatch("Div", exp, reducesum);
        setFusedNode("Softmax", input);
    }
};

class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
{
public:
    LogSoftMaxSubgraph()
    {
        int input = addNodeToMatch("");

        int reducemax = addNodeToMatch("ReduceMax", input);
        id = 0;  // index of the ReduceMax node in the matched pattern

        int sub_1 = addNodeToMatch("Sub", input, reducemax);
        int exp = addNodeToMatch("Exp", sub_1);
        int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
        int log = addNodeToMatch("Log", reducesum);
        addNodeToMatch("Sub", sub_1, log);
        setFusedNode("LogSoftmax", input);
    }
};

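// HardSwish(x) = x * HardSigmoid(x). The fusion is only valid when HardSigmoid
// uses alpha = 1/6 and beta = 0.5, which match() verifies below.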
class HardSwishSubgraph : public Subgraph
{
public:
    HardSwishSubgraph()
    {
        int input = addNodeToMatch("");
        int hardSigmoid = addNodeToMatch("HardSigmoid", input);
        addNodeToMatch("Mul", input, hardSigmoid);
        setFusedNode("HardSwish", input);
    }

    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                       std::vector<int>& matchedNodesIds,
                       std::vector<int>& targetNodesIds) CV_OVERRIDE
    {
        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
        {
            Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[0]);
            opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node;

            uint8_t matched = 0;
            for (int i = 0; i < node->attribute_size(); i++)
            {
                const opencv_onnx::AttributeProto& attr = node->attribute(i);
                if ((attr.name() == "alpha" && attr.f() == 1.f / 6.f) ||
                    (attr.name() == "beta" && attr.f() == 0.5f))
                {
                    ++matched;
                }
            }
            return matched == 2;
        }
        return false;
    }
};

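// Celu(x, alpha) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)).
// Exporters expand it as Mul(alpha, Elu(Div(x, alpha))); the fusion checks that
// the Elu alpha is 1 and that the Div and Mul constants agree.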
class CeluSubgraph : public Subgraph
{
public:
    CeluSubgraph() : alpha(1.f)
    {
        int input = addNodeToMatch("");
        int div = addNodeToMatch("Div", input, addNodeToMatch(""));
        int elu = addNodeToMatch("Elu", div);
        addNodeToMatch("Mul", addNodeToMatch(""), elu);
        setFusedNode("Celu", input);
    }

    static float extractAlpha(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
    {
        const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
        int const_id = getInputNodeId(net, node, input_id);
        Ptr<ImportNodeWrapper> alpha_ptr = net->getNode(const_id);
        opencv_onnx::NodeProto* alpha_node = alpha_ptr.dynamicCast<ONNXNodeWrapper>()->node;
        opencv_onnx::TensorProto alpha_proto = alpha_node->attribute(0).t();
        Mat alpha_mat = getMatFromTensor(alpha_proto);
        return *alpha_mat.ptr<float>();
    }

    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                       std::vector<int>& matchedNodesIds,
                       std::vector<int>& targetNodesIds) CV_OVERRIDE
    {
        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
        {
            float alpha_div = extractAlpha(net, matchedNodesIds[0], 1);
            float alpha_mul = extractAlpha(net, matchedNodesIds[2], 0);
            float alpha_elu = 1.f;

            Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[1]);
            opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node;

            for (int i = 0; i < elu_node->attribute_size(); i++)
            {
                const opencv_onnx::AttributeProto& attr = elu_node->attribute(i);
                if (attr.name() != "alpha")
                    continue;
                alpha_elu = attr.f();
            }

            alpha = alpha_div;
            return alpha_elu == 1.f && alpha_div == alpha_mul;
        }
        return false;
    }

    virtual void finalize(const Ptr<ImportGraphWrapper>&,
                          const Ptr<ImportNodeWrapper>& fusedNode,
                          std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
    {
        opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
        opencv_onnx::AttributeProto* alpha_attr = node->add_attribute();
        alpha_attr->set_name("alpha");
        alpha_attr->set_f(alpha);
    }

protected:
    float alpha;
};

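// L2 normalization: y = x / ||x||_2 along `axis`; exporters typically guard
// the norm with Clip, Max or an added epsilon to avoid division by zero. The
// subgraphs below match the decompositions produced by different exporters
// (ReduceL2, Pow+ReduceSum+Sqrt, Mul+ReduceSum+Sqrt). The base class reads the
// reduction axis and emits it as the "axis"/"end_axis" attributes of the fused
// Normalize node.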
class NormalizeSubgraphBase : public Subgraph
{
public:
    NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}

    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                       std::vector<int>& matchedNodesIds,
                       std::vector<int>& targetNodesIds) CV_OVERRIDE
    {
        if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
        {
            Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
            opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;

            for (int i = 0; i < node->attribute_size(); i++)
            {
                const opencv_onnx::AttributeProto& attr = node->attribute(i);
                if (attr.name() != "axes")
                    continue;
                if (attr.ints_size() != 1)
                    CV_Error(Error::StsNotImplemented, format("Unexpected number of axes: %d", attr.ints_size()));
                axis = attr.ints(0);
                return true;
            }
            CV_Error(Error::StsNotImplemented, "Missing 'axes' attribute");
        }
        return false;
    }

    virtual void finalize(const Ptr<ImportGraphWrapper>&,
                          const Ptr<ImportNodeWrapper>& fusedNode,
                          std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE
    {
        opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
        opencv_onnx::AttributeProto* axis_attr = node->add_attribute();
        axis_attr->set_name("axis");
        axis_attr->set_i(axis);

        opencv_onnx::AttributeProto* end_axis_attr = node->add_attribute();
        end_axis_attr->set_name("end_axis");
        end_axis_attr->set_i(axis);
    }

protected:
    int axis, normNodeOrder;
};

class NormalizeSubgraph1 : public NormalizeSubgraphBase
{
public:
    NormalizeSubgraph1()
    {
        int input = addNodeToMatch("");
        int norm = addNodeToMatch("ReduceL2", input);
        addNodeToMatch("Div", input, norm);
        setFusedNode("Normalize", input);
    }
};

class NormalizeSubgraph2 : public NormalizeSubgraphBase
{
public:
    NormalizeSubgraph2()
    {
        int input = addNodeToMatch("");
        int norm = addNodeToMatch("ReduceL2", input);
        int clip = addNodeToMatch("Clip", norm);
        int shape = addNodeToMatch("Shape", input);
        int expand = addNodeToMatch("Expand", clip, shape);
        addNodeToMatch("Div", input, expand);
        setFusedNode("Normalize", input);
    }
};

class NormalizeSubgraph2_2 : public NormalizeSubgraphBase
{
public:
    NormalizeSubgraph2_2()
    {
        int input = addNodeToMatch("");
        int norm = addNodeToMatch("ReduceL2", input);

        int min = addNodeToMatch("");
        int max = addNodeToMatch("");
        int clip = addNodeToMatch("Clip", norm, min, max);

        int shape = addNodeToMatch("");
        int expand = addNodeToMatch("Expand", clip, shape);

        addNodeToMatch("Div", input, expand);

        setFusedNode("Normalize", input);
    }
};

class NormalizeSubgraph3 : public NormalizeSubgraphBase
{
public:
    NormalizeSubgraph3() : NormalizeSubgraphBase(1)
    {
        int input = addNodeToMatch("");
        int power = addNodeToMatch("Constant");
        int squared = addNodeToMatch("Pow", input, power);
        int sum = addNodeToMatch("ReduceSum", squared);
        int sqrtNode = addNodeToMatch("Sqrt", sum);
        int eps = addNodeToMatch("Constant");
        int add = addNodeToMatch("Add", sqrtNode, eps);

        addNodeToMatch("Div", input, add);
        setFusedNode("Normalize", input);
    }
};

class NormalizeSubgraph4 : public NormalizeSubgraphBase
{
public:
    NormalizeSubgraph4() : NormalizeSubgraphBase(1)
    {
        int input = addNodeToMatch("");
        int mul = addNodeToMatch("Mul", input, input);
        int sum = addNodeToMatch("ReduceSum", mul);
        int eps = addNodeToMatch("");
        int max = addNodeToMatch("Max", sum, eps);
        int sqrt = addNodeToMatch("Sqrt", max);
        int reciprocal = addNodeToMatch("Reciprocal", sqrt);
        addNodeToMatch("Mul", input, reciprocal);
        setFusedNode("Normalize", input);
    }
};

class NormalizeSubgraph5 : public NormalizeSubgraphBase
{
public:
    NormalizeSubgraph5() : NormalizeSubgraphBase(1)
    {
        int input = addNodeToMatch("");
        int mul = addNodeToMatch("Mul", input, input);
        int sum = addNodeToMatch("ReduceSum", mul);
        int clip = addNodeToMatch("Clip", sum);
        int sqrt = addNodeToMatch("Sqrt", clip);
        int one = addNodeToMatch("Constant");
        int div = addNodeToMatch("Div", one, sqrt);
        addNodeToMatch("Mul", input, div);
        setFusedNode("Normalize", input);
    }
};

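// Gather followed by Cast can be collapsed to a bare Gather, but only when the
// Gather output feeds that Cast exclusively; match() scans the inputs of all
// other nodes to rule out additional consumers.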
class GatherCastSubgraph : public Subgraph
{
public:
    GatherCastSubgraph()
    {
        int input = addNodeToMatch("");
        int index = addNodeToMatch("Constant");
        int gather = addNodeToMatch("Gather", input, index);
        addNodeToMatch("Cast", gather);
        setFusedNode("Gather", input, index);
    }

    virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
                       std::vector<int>& matchedNodesIds,
                       std::vector<int>& targetNodesIds) CV_OVERRIDE
    {
        bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
        size_t matchedNodesNum = matchedNodesIds.size();
        // Now we check if merging can be done for these Gather and Cast nodes.
        if (!retVal || matchedNodesNum < 2)
            return retVal;
        else {
            int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
            const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
            if (node->getType() == "Cast") {
                int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
                const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
                if (inpNode->getType() == "Gather") {
                    int numNodes = net->getNumNodes();
                    std::string inpNodeName = node->getInputName(0);
                    for (int i = 0; i < numNodes; ++i) {
                        const Ptr<ImportNodeWrapper> node_to_check = net->getNode(i);
                        int numInp = node_to_check->getNumInputs();
                        for (int inp = 0; inp < numInp; ++inp) {
                            if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(inp)) {
                                // Another node consumes the Gather output, so the nodes cannot be merged.
                                return false;
                            }
                        }
                    }
                }
            }
        }
        return retVal;
    }
};

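// Some exporters compute the target shape of Expand through a
// ConstantOfShape -> Mul -> Equal -> Where chain; this pattern bypasses that
// arithmetic and wires the matched shape constant directly into the fused
// Expand node.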
class ExpandSubgraph : public Subgraph
{
public:
    ExpandSubgraph()
    {
        int input = addNodeToMatch("");
        int values = addNodeToMatch("");
        int init = addNodeToMatch("ConstantOfShape", values);
        int coeff = addNodeToMatch("Constant");
        int mul = addNodeToMatch("Mul", init, coeff);
        int shape = addNodeToMatch("Constant");
        int condition = addNodeToMatch("Equal", shape, mul);
        int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
        addNodeToMatch("Expand", input, where);
        setFusedNode("Expand", input, shape);
    }
};

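// Mish(x) = x * tanh(softplus(x)) = x * tanh(log(1 + exp(x))).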
class MishSubgraph : public Subgraph
{
public:
    MishSubgraph()
    {
        int input = addNodeToMatch("");
        int softplus = addNodeToMatch("Softplus", input);
        int tanh = addNodeToMatch("Tanh", softplus);
        addNodeToMatch("Mul", input, tanh);
        setFusedNode("Mish", input);
    }
};

// softplus(x) = log(exp(x) + 1)
class SoftplusSubgraph : public Subgraph
{
public:
    SoftplusSubgraph()
    {
        int input = addNodeToMatch("");
        int exp = addNodeToMatch("Exp", input);
        int addVal = addNodeToMatch("");
        int add = addNodeToMatch("Add", addVal, exp);
        addNodeToMatch("Log", add);
        setFusedNode("Softplus", input);
    }
};

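// Same pattern as SoftplusSubgraph, with the Add operands in the opposite order.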
class SoftplusSubgraph2 : public Subgraph
{
public:
    SoftplusSubgraph2()
    {
        int input = addNodeToMatch("");
        int exp = addNodeToMatch("Exp", input);
        int addVal = addNodeToMatch("");
        int add = addNodeToMatch("Add", exp, addVal);
        addNodeToMatch("Log", add);
        setFusedNode("Softplus", input);
    }
};

class MulCastSubgraph : public Subgraph
{
public:
    MulCastSubgraph()
    {
        int input = addNodeToMatch("");
        int scaleNode = addNodeToMatch("Constant");
        int mul = addNodeToMatch("Mul", input, scaleNode);
        addNodeToMatch("Cast", mul);
        setFusedNode("Mul", input, scaleNode);
    }
};

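// Matches the scale computation emitted by Upsample/Resize exports:
// out_size = floor(shape(x)[i] * scale), built from Shape/Gather/Mul/Floor
// separately for H and W and then concatenated. finalize() converts the two
// scale constants into "height_scale"/"width_scale" attributes and drops them
// from the fused node's inputs.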
class ExtractScalesSubgraph : public Subgraph
{
public:
    ExtractScalesSubgraph()
    {
        input = addNodeToMatch("");

        int indexH = addNodeToMatch("Constant");
        int shape1 = addNodeToMatch("Shape", input);
        int gather1 = addNodeToMatch("Gather", shape1, indexH);
        scaleHNode = addNodeToMatch("Constant");
        int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
        int floor1 = addNodeToMatch("Floor", mul1);

        int indexW = addNodeToMatch("Constant");
        int shape2 = addNodeToMatch("Shape", input);
        int gather2 = addNodeToMatch("Gather", shape2, indexW);
        scaleWNode = addNodeToMatch("Constant");
        int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
        int floor2 = addNodeToMatch("Floor", mul2);

        int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
        int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
        concatId = addNodeToMatch("Concat", unsqueeze1, unsqueeze2);
    }

    void finalize(const Ptr<ImportGraphWrapper>& net,
                  const Ptr<ImportNodeWrapper>& fusedNode,
                  std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE
    {
        opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
        opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
        Mat scaleW = getMatFromTensor(tensor_proto);
        CV_Assert(scaleW.total() == 1);
        scaleW.convertTo(scaleW, CV_32F);

        constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
        tensor_proto = constant_node->attribute(0).t();
        Mat scaleH = getMatFromTensor(tensor_proto);
        CV_Assert(scaleH.total() == 1);
        scaleH.convertTo(scaleH, CV_32F);

        opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
        opencv_onnx::AttributeProto* attrH = node->add_attribute();
        attrH->set_name("height_scale");
        attrH->set_i(scaleH.at<float>(0));
        opencv_onnx::AttributeProto* attrW = node->add_attribute();
        attrW->set_name("width_scale");
        attrW->set_i(scaleW.at<float>(0));

        node->mutable_input()->DeleteSubrange(1, 2);  // Remove the last two inputs
    }

protected:
    int input, concatId;
    int scaleHNode, scaleWNode;
};

class UpsampleSubgraph : public ExtractScalesSubgraph
{
public:
    UpsampleSubgraph() : ExtractScalesSubgraph()
    {
        int shape = addNodeToMatch("Shape", input);
        int slice = addNodeToMatch("Slice", shape);

        int castConcat = addNodeToMatch("Cast", concatId);
        int castSlice = addNodeToMatch("Cast", slice);
        int divide = addNodeToMatch("Div", castConcat, castSlice);

        int constant = addNodeToMatch("Constant");
        int concat = addNodeToMatch("Concat", constant, divide);

        addNodeToMatch("Upsample", input, concat);
        setFusedNode("Upsample", input, scaleWNode, scaleHNode);
    }
};

class ResizeSubgraph1 : public ExtractScalesSubgraph
{
public:
    ResizeSubgraph1() : ExtractScalesSubgraph()
    {
        int shape = addNodeToMatch("Shape", input);
        int slice = addNodeToMatch("Slice", shape, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));

        int castConcat = addNodeToMatch("Cast", concatId);
        int concat = addNodeToMatch("Concat", slice, castConcat);
        int constant = addNodeToMatch("Constant");

        addNodeToMatch("Resize", input, constant, constant, concat);
        setFusedNode("Upsample", input, scaleWNode, scaleHNode);
    }
};

class ResizeSubgraph2 : public ExtractScalesSubgraph
{
public:
    ResizeSubgraph2() : ExtractScalesSubgraph()
    {
        int constantConcat = addNodeToMatch("Constant");
        int castConcat = addNodeToMatch("Cast", concatId);
        int concat = addNodeToMatch("Concat", constantConcat, castConcat);
        int constant = addNodeToMatch("Constant");

        addNodeToMatch("Resize", input, constant, constant, concat);
        setFusedNode("Upsample", input, scaleWNode, scaleHNode);
    }
};

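// BatchNormalization exported in expanded (folded-constant) form:
//   y = x * s + (bias - mean * s),   where s = weight * A / sqrt(var)
// A is a matched constant (typically 1), and any epsilon is assumed to have
// been pre-added to var by the exporter. The two subgraphs below match this
// expansion with different Reshape placements and restore a single
// BatchNormalization node with inputs (x, weight, bias, mean, var).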
class BatchNormalizationSubgraphBase : public Subgraph
{
public:
    BatchNormalizationSubgraphBase()
    {
        input  = addNodeToMatch("");
        var    = addNodeToMatch("");
        mean   = addNodeToMatch("");
        weight = addNodeToMatch("");
        bias   = addNodeToMatch("");
        A      = addNodeToMatch("");
        shape1 = addNodeToMatch("");
        shape2 = addNodeToMatch("");
    }

protected:
    int input, var, mean, weight, bias, A, shape1, shape2;
};

class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase
{
public:
    BatchNormalizationSubgraph1()
    {
        int reshape1 = addNodeToMatch("Reshape", weight, shape1);
        int reshape2 = addNodeToMatch("Reshape", bias, shape2);
        int shape3 = addNodeToMatch("Constant");
        int reshape3 = addNodeToMatch("Reshape", var, shape3);
        int shape4 = addNodeToMatch("Constant");
        int reshape4 = addNodeToMatch("Reshape", mean, shape4);
        int sqrtNode = addNodeToMatch("Sqrt", reshape3);
        int divNode = addNodeToMatch("Div", A, sqrtNode);
        int mul1 = addNodeToMatch("Mul", reshape1, divNode);
        int mul2 = addNodeToMatch("Mul", reshape4, mul1);
        int sub = addNodeToMatch("Sub", reshape2, mul2);
        int mul3 = addNodeToMatch("Mul", input, mul1);
        addNodeToMatch("Add", mul3, sub);
        setFusedNode("BatchNormalization", input, weight, bias, mean, var);
    }
};

class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
{
public:
    BatchNormalizationSubgraph2()
    {
        int sqrtNode = addNodeToMatch("Sqrt", var);
        int divNode = addNodeToMatch("Div", A, sqrtNode);
        int mul1 = addNodeToMatch("Mul", weight, divNode);
        int reshape2 = addNodeToMatch("Reshape", mul1, shape2);

        int mulMean = addNodeToMatch("Mul", mean, mul1);
        int sub = addNodeToMatch("Sub", bias, mulMean);
        int reshape1 = addNodeToMatch("Reshape", sub, shape1);

        int mulInput = addNodeToMatch("Mul", input, reshape2);
        addNodeToMatch("Add", mulInput, reshape1);
        setFusedNode("BatchNormalization", input, weight, bias, mean, var);
    }
};

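// Entry point used by the ONNX importer: wraps the GraphProto and runs the
// generic pattern matcher over the registered subgraphs. Registration order
// matters when patterns overlap, since candidates are attempted in order.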
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
    std::vector<Ptr<Subgraph> > subgraphs;
    subgraphs.push_back(makePtr<GatherCastSubgraph>());
    subgraphs.push_back(makePtr<MulCastSubgraph>());
    subgraphs.push_back(makePtr<UpsampleSubgraph>());
    subgraphs.push_back(makePtr<ResizeSubgraph1>());
    subgraphs.push_back(makePtr<ResizeSubgraph2>());
    subgraphs.push_back(makePtr<SoftMaxSubgraph>());
    subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
    subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
    subgraphs.push_back(makePtr<HardSwishSubgraph>());
    subgraphs.push_back(makePtr<CeluSubgraph>());
    subgraphs.push_back(makePtr<NormalizeSubgraph1>());
    subgraphs.push_back(makePtr<NormalizeSubgraph2>());
    subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());
    subgraphs.push_back(makePtr<NormalizeSubgraph3>());
    subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
    subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
    subgraphs.push_back(makePtr<ExpandSubgraph>());
    subgraphs.push_back(makePtr<SoftplusSubgraph>());
    subgraphs.push_back(makePtr<SoftplusSubgraph2>());
    subgraphs.push_back(makePtr<MishSubgraph>());
    subgraphs.push_back(makePtr<NormalizeSubgraph4>());
    subgraphs.push_back(makePtr<NormalizeSubgraph5>());

    simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}

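// Converts an ONNX TensorProto into a cv::Mat. Type mapping used below:
//   FLOAT -> CV_32F, FLOAT16 -> CV_32F (widened), DOUBLE -> CV_32F (narrowed),
//   INT32 -> CV_32S, INT64 -> CV_32S (narrowed), INT8/UINT8 -> CV_8S.
// Data may arrive either in the typed repeated fields or in raw_data.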
Mat getMatFromTensor(const opencv_onnx::TensorProto& tensor_proto)
{
    if (tensor_proto.raw_data().empty() && tensor_proto.float_data().empty() &&
        tensor_proto.double_data().empty() && tensor_proto.int64_data().empty() &&
        tensor_proto.int32_data().empty())
        return Mat();

    opencv_onnx::TensorProto_DataType datatype = tensor_proto.data_type();
    Mat blob;
    std::vector<int> sizes;
    for (int i = 0; i < tensor_proto.dims_size(); i++) {
        sizes.push_back(tensor_proto.dims(i));
    }
    if (sizes.empty())
        sizes.assign(1, 1);
    if (datatype == opencv_onnx::TensorProto_DataType_FLOAT) {
        if (!tensor_proto.float_data().empty()) {
            const ::google::protobuf::RepeatedField<float> field = tensor_proto.float_data();
            Mat(sizes, CV_32FC1, (void*)field.data()).copyTo(blob);
        }
        else {
            char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
            Mat(sizes, CV_32FC1, val).copyTo(blob);
        }
    }
    else if (datatype == opencv_onnx::TensorProto_DataType_FLOAT16)
    {
        // FIXME: for now FP16 tensors are loaded as FP32 Mats; full FP16 support is left for the future.
        CV_LOG_ONCE_WARNING(NULL, "DNN: FP16 model is loaded as FP32, which doubles the memory needed for the weights.");

        // ONNX stores float16 data in one of two formats: int32_data and raw_data.
        // Link: https://github.com/onnx/onnx/issues/4460#issuecomment-1224373746
        if (!tensor_proto.int32_data().empty())
        {
            int offset = 0;
#ifdef WORDS_BIGENDIAN
            offset = 1;
#endif
            const ::google::protobuf::RepeatedField<int32_t> field = tensor_proto.int32_data();

            AutoBuffer<float16_t, 16> aligned_val;
            size_t sz = tensor_proto.int32_data().size();
            aligned_val.allocate(sz);
            float16_t* bufPtr = aligned_val.data();

            // Each int32 element holds one fp16 value in its lower half
            // (upper half on big-endian hosts), hence the stride of 2.
            float16_t *fp16Ptr = (float16_t *)field.data();
            for (size_t i = 0; i < sz; i++)
            {
                bufPtr[i] = fp16Ptr[i*2 + offset];
            }
            Mat(sizes, CV_16FC1, bufPtr).convertTo(blob, CV_32FC1);
        }
        else
        {
            char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
#if CV_STRONG_ALIGNMENT
            // Aligned pointer is required.
            AutoBuffer<float16_t, 16> aligned_val;
            if (!isAligned<sizeof(float16_t)>(val))
            {
                size_t sz = tensor_proto.raw_data().size();
                aligned_val.allocate(divUp(sz, sizeof(float16_t)));
                memcpy(aligned_val.data(), val, sz);
                val = (char*)aligned_val.data();
            }
#endif
            Mat(sizes, CV_16FC1, val).convertTo(blob, CV_32FC1);
        }
    }
    else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
    {
        const ::google::protobuf::RepeatedField<double> field = tensor_proto.double_data();
        CV_Assert(!field.empty());
        char* val = (char *)field.data();
#if CV_STRONG_ALIGNMENT
        // Aligned pointer is required.
        AutoBuffer<double, 16> aligned_val;
        if (!isAligned<sizeof(double)>(val))
        {
            size_t sz = field.size() * sizeof(double);
            aligned_val.allocate(divUp(sz, sizeof(double)));
            memcpy(aligned_val.data(), val, sz);
            val = (char*)aligned_val.data();
        }
#endif
        Mat(sizes, CV_64FC1, val).convertTo(blob, CV_32FC1);
    }
    else if (datatype == opencv_onnx::TensorProto_DataType_INT32)
    {
        if (!tensor_proto.int32_data().empty())
        {
            const ::google::protobuf::RepeatedField<int32_t> field = tensor_proto.int32_data();
            Mat(sizes, CV_32SC1, (void*)field.data()).copyTo(blob);
        }
        else
        {
            char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
            Mat(sizes, CV_32SC1, val).copyTo(blob);
        }
    }
    else if (datatype == opencv_onnx::TensorProto_DataType_INT64)
    {
        blob.create(sizes, CV_32SC1);
        int32_t* dst = reinterpret_cast<int32_t*>(blob.data);

        if (!tensor_proto.int64_data().empty()) {
            ::google::protobuf::RepeatedField< ::google::protobuf::int64> src = tensor_proto.int64_data();
            convertInt64ToInt32(src, dst, blob.total());
        }
        else
        {
            const char* val = tensor_proto.raw_data().c_str();
#if CV_STRONG_ALIGNMENT
            // Aligned pointer is required: https://github.com/opencv/opencv/issues/16373
            // this doesn't work: typedef int64_t CV_DECL_ALIGNED(1) unaligned_int64_t;
            AutoBuffer<int64_t, 16> aligned_val;
            if (!isAligned<sizeof(int64_t)>(val))
            {
                size_t sz = tensor_proto.raw_data().size();
                aligned_val.allocate(divUp(sz, sizeof(int64_t)));
                memcpy(aligned_val.data(), val, sz);
                val = (const char*)aligned_val.data();
            }
#endif
            const int64_t* src = reinterpret_cast<const int64_t*>(val);
            convertInt64ToInt32(src, dst, blob.total());
        }
    }
    else if (datatype == opencv_onnx::TensorProto_DataType_INT8 ||
             datatype == opencv_onnx::TensorProto_DataType_UINT8)
    {
        // TODO: add support for uint8 weights and activations. For now uint8
        // tensors are converted to int8 by shifting them with an offset of -128.
        int offset = datatype == opencv_onnx::TensorProto_DataType_INT8 ? 0 : -128;
        int depth = datatype == opencv_onnx::TensorProto_DataType_INT8 ? CV_8S : CV_8U;

        if (!tensor_proto.int32_data().empty())
        {
            const ::google::protobuf::RepeatedField<int32_t> field = tensor_proto.int32_data();
            Mat(sizes, CV_32SC1, (void*)field.data()).convertTo(blob, CV_8S, 1.0, offset);
        }
        else
        {
            char* val = const_cast<char*>(tensor_proto.raw_data().c_str());
            Mat(sizes, depth, val).convertTo(blob, CV_8S, 1.0, offset);
        }
    }
    else
    {
        std::string errorMsg = "Unsupported data type: " +
                opencv_onnx::TensorProto_DataType_Name(datatype);

        if (!DNN_DIAGNOSTICS_RUN)
        {
            CV_Error(Error::StsUnsupportedFormat, errorMsg);
        }
        CV_LOG_ERROR(NULL, errorMsg);
        return blob;
    }
    if (tensor_proto.dims_size() == 0)
        blob.dims = 1;  // To force a 1-dimensional cv::Mat for scalars.
    return blob;
}

CV__DNN_INLINE_NS_END
}}  // namespace cv::dnn