|
|
|
@ -154,6 +154,32 @@ private: |
|
|
|
|
int axis; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class GatherCastSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
GatherCastSubgraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int index = addNodeToMatch("Constant"); |
|
|
|
|
int gather = addNodeToMatch("Gather", input, index); |
|
|
|
|
addNodeToMatch("Cast", gather); |
|
|
|
|
setFusedNode("Gather", input, index); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class MulCastSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
MulCastSubgraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int scaleNode = addNodeToMatch("Constant"); |
|
|
|
|
int mul = addNodeToMatch("Mul", input, scaleNode); |
|
|
|
|
addNodeToMatch("Cast", mul); |
|
|
|
|
setFusedNode("Mul", input, scaleNode); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class ExtractScalesSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
@ -164,20 +190,16 @@ public: |
|
|
|
|
int indexH = addNodeToMatch("Constant"); |
|
|
|
|
int shape1 = addNodeToMatch("Shape", input); |
|
|
|
|
int gather1 = addNodeToMatch("Gather", shape1, indexH); |
|
|
|
|
int castG1 = addNodeToMatch("Cast", gather1); |
|
|
|
|
scaleHNode = addNodeToMatch("Constant"); |
|
|
|
|
int mul1 = addNodeToMatch("Mul", castG1, scaleHNode); |
|
|
|
|
int castM1 = addNodeToMatch("Cast", mul1); |
|
|
|
|
int floor1 = addNodeToMatch("Floor", castM1); |
|
|
|
|
int mul1 = addNodeToMatch("Mul", gather1, scaleHNode); |
|
|
|
|
int floor1 = addNodeToMatch("Floor", mul1); |
|
|
|
|
|
|
|
|
|
int indexW = addNodeToMatch("Constant"); |
|
|
|
|
int shape2 = addNodeToMatch("Shape", input); |
|
|
|
|
int gather2 = addNodeToMatch("Gather", shape2, indexW); |
|
|
|
|
int castG2 = addNodeToMatch("Cast", gather2); |
|
|
|
|
scaleWNode = addNodeToMatch("Constant"); |
|
|
|
|
int mul2 = addNodeToMatch("Mul", castG2, scaleWNode); |
|
|
|
|
int castM2 = addNodeToMatch("Cast", mul2); |
|
|
|
|
int floor2 = addNodeToMatch("Floor", castM2); |
|
|
|
|
int mul2 = addNodeToMatch("Mul", gather2, scaleWNode); |
|
|
|
|
int floor2 = addNodeToMatch("Floor", mul2); |
|
|
|
|
|
|
|
|
|
int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1); |
|
|
|
|
int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2); |
|
|
|
@ -190,19 +212,23 @@ public: |
|
|
|
|
{ |
|
|
|
|
opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t(); |
|
|
|
|
float scaleW = getMatFromTensor(tensor_proto).at<float>(0); |
|
|
|
|
Mat scaleW = getMatFromTensor(tensor_proto); |
|
|
|
|
CV_Assert(scaleW.total() == 1); |
|
|
|
|
scaleW.convertTo(scaleW, CV_32F); |
|
|
|
|
|
|
|
|
|
constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
tensor_proto = constant_node->attribute(0).t(); |
|
|
|
|
float scaleH = getMatFromTensor(tensor_proto).at<float>(0); |
|
|
|
|
Mat scaleH = getMatFromTensor(tensor_proto); |
|
|
|
|
CV_Assert(scaleH.total() == 1); |
|
|
|
|
scaleH.convertTo(scaleH, CV_32F); |
|
|
|
|
|
|
|
|
|
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node; |
|
|
|
|
opencv_onnx::AttributeProto* attrH = node->add_attribute(); |
|
|
|
|
attrH->set_name("height_scale"); |
|
|
|
|
attrH->set_i(scaleH); |
|
|
|
|
attrH->set_i(scaleH.at<float>(0)); |
|
|
|
|
opencv_onnx::AttributeProto* attrW = node->add_attribute(); |
|
|
|
|
attrW->set_name("width_scale"); |
|
|
|
|
attrW->set_i(scaleW); |
|
|
|
|
attrW->set_i(scaleW.at<float>(0)); |
|
|
|
|
|
|
|
|
|
node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
|
|
|
|
|
} |
|
|
|
@ -267,6 +293,8 @@ public: |
|
|
|
|
void simplifySubgraphs(opencv_onnx::GraphProto& net) |
|
|
|
|
{ |
|
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
|
|
subgraphs.push_back(makePtr<GatherCastSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<MulCastSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<UpsampleSubgraph>()); |
|
|
|
|
subgraphs.push_back(makePtr<ResizeSubgraph1>()); |
|
|
|
|
subgraphs.push_back(makePtr<ResizeSubgraph2>()); |
|
|
|
|