|
|
@ -70,13 +70,6 @@ public: |
|
|
|
{ |
|
|
|
{ |
|
|
|
fusedNodeInputs = inputs_; |
|
|
|
fusedNodeInputs = inputs_; |
|
|
|
fusedNodeOp = op; |
|
|
|
fusedNodeOp = op; |
|
|
|
nodesToFuse.clear(); |
|
|
|
|
|
|
|
for (int i = 0; i < nodes.size(); ++i) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end() && |
|
|
|
|
|
|
|
nodes[i] != "Const") |
|
|
|
|
|
|
|
nodesToFuse.push_back(i); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static int getInputNodeId(const tensorflow::GraphDef& net, |
|
|
|
static int getInputNodeId(const tensorflow::GraphDef& net, |
|
|
@ -99,15 +92,17 @@ public: |
|
|
|
|
|
|
|
|
|
|
|
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
|
|
|
|
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
|
|
|
|
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
|
|
|
|
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
|
|
|
|
virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) |
|
|
|
virtual bool match(const tensorflow::GraphDef& net, int nodeId, |
|
|
|
|
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
|
|
|
|
|
std::vector<int>& targetNodesIds) |
|
|
|
{ |
|
|
|
{ |
|
|
|
matchedNodesIds.clear(); |
|
|
|
matchedNodesIds.clear(); |
|
|
|
matchedNodesIds.reserve(nodesToFuse.size()); |
|
|
|
targetNodesIds.clear(); |
|
|
|
|
|
|
|
|
|
|
|
std::queue<int> nodesToMatch; |
|
|
|
std::queue<int> nodesToMatch; |
|
|
|
std::queue<int> targetNodes; |
|
|
|
std::queue<int> targetNodes; |
|
|
|
nodesToMatch.push(nodeId); |
|
|
|
nodesToMatch.push(nodeId); |
|
|
|
targetNodes.push(nodesToFuse.back()); |
|
|
|
targetNodes.push(nodes.size() - 1); |
|
|
|
while (!nodesToMatch.empty()) |
|
|
|
while (!nodesToMatch.empty()) |
|
|
|
{ |
|
|
|
{ |
|
|
|
int nodeToMatch = nodesToMatch.front(); |
|
|
|
int nodeToMatch = nodesToMatch.front(); |
|
|
@ -142,13 +137,25 @@ public: |
|
|
|
return false; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
matchedNodesIds.push_back(nodeToMatch); |
|
|
|
matchedNodesIds.push_back(nodeToMatch); |
|
|
|
|
|
|
|
targetNodesIds.push_back(targetNodeId); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const int n = matchedNodesIds.size(); |
|
|
|
|
|
|
|
std::vector<std::pair<int, int> > elements(n); |
|
|
|
|
|
|
|
for (int i = 0; i < n; ++i) |
|
|
|
|
|
|
|
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]); |
|
|
|
|
|
|
|
std::sort(elements.begin(), elements.end()); |
|
|
|
|
|
|
|
for (int i = 0; i < n; ++i) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
matchedNodesIds[i] = elements[i].first; |
|
|
|
|
|
|
|
targetNodesIds[i] = elements[i].second; |
|
|
|
} |
|
|
|
} |
|
|
|
std::sort(matchedNodesIds.begin(), matchedNodesIds.end()); |
|
|
|
|
|
|
|
return true; |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Fuse matched subgraph.
|
|
|
|
// Fuse matched subgraph.
|
|
|
|
void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds) |
|
|
|
void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds, |
|
|
|
|
|
|
|
const std::vector<int>& targetNodesIds) |
|
|
|
{ |
|
|
|
{ |
|
|
|
// Extract names of input nodes.
|
|
|
|
// Extract names of input nodes.
|
|
|
|
std::vector<std::string> inputsNames(fusedNodeInputs.size()); |
|
|
|
std::vector<std::string> inputsNames(fusedNodeInputs.size()); |
|
|
@ -159,7 +166,7 @@ public: |
|
|
|
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j) |
|
|
|
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j) |
|
|
|
{ |
|
|
|
{ |
|
|
|
const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]); |
|
|
|
const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]); |
|
|
|
std::vector<int>& inpIndices = inputs[nodesToFuse[j]]; |
|
|
|
std::vector<int>& inpIndices = inputs[targetNodesIds[j]]; |
|
|
|
|
|
|
|
|
|
|
|
CV_Assert(node.input_size() == inpIndices.size()); |
|
|
|
CV_Assert(node.input_size() == inpIndices.size()); |
|
|
|
for (int k = 0; k < inpIndices.size(); ++k) |
|
|
|
for (int k = 0; k < inpIndices.size(); ++k) |
|
|
@ -204,7 +211,6 @@ private: |
|
|
|
std::vector<std::vector<int> > inputs; // Connections of an every node to it's inputs.
|
|
|
|
std::vector<std::vector<int> > inputs; // Connections of an every node to it's inputs.
|
|
|
|
|
|
|
|
|
|
|
|
std::string fusedNodeOp; // Operation name of resulting fused node.
|
|
|
|
std::string fusedNodeOp; // Operation name of resulting fused node.
|
|
|
|
std::vector<int> nodesToFuse; // Set of nodes to be fused.
|
|
|
|
|
|
|
|
std::vector<int> fusedNodeInputs; // Inputs of fused node.
|
|
|
|
std::vector<int> fusedNodeInputs; // Inputs of fused node.
|
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
@ -360,9 +366,11 @@ public: |
|
|
|
setFusedNode("Relu6", input); |
|
|
|
setFusedNode("Relu6", input); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE |
|
|
|
virtual bool match(const tensorflow::GraphDef& net, int nodeId, |
|
|
|
|
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
|
|
|
|
|
std::vector<int>& targetNodesIds) CV_OVERRIDE |
|
|
|
{ |
|
|
|
{ |
|
|
|
if (!Subgraph::match(net, nodeId, matchedNodesIds)) |
|
|
|
if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) |
|
|
|
return false; |
|
|
|
return false; |
|
|
|
Mat maxValue = getTensorContent(net.node(matchedNodesIds.front() + 1).attr().at("value").tensor()); |
|
|
|
Mat maxValue = getTensorContent(net.node(matchedNodesIds.front() + 1).attr().at("value").tensor()); |
|
|
|
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6; |
|
|
|
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6; |
|
|
@ -394,14 +402,16 @@ public: |
|
|
|
setFusedNode("Reshape", ids); |
|
|
|
setFusedNode("Reshape", ids); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE |
|
|
|
virtual bool match(const tensorflow::GraphDef& net, int nodeId, |
|
|
|
|
|
|
|
std::vector<int>& matchedNodesIds, |
|
|
|
|
|
|
|
std::vector<int>& targetNodesIds) CV_OVERRIDE |
|
|
|
{ |
|
|
|
{ |
|
|
|
const tensorflow::NodeDef& node = net.node(nodeId); |
|
|
|
const tensorflow::NodeDef& node = net.node(nodeId); |
|
|
|
if (node.input_size() == 0) |
|
|
|
if (node.input_size() == 0) |
|
|
|
return false; |
|
|
|
return false; |
|
|
|
|
|
|
|
|
|
|
|
inpName = node.input(0); |
|
|
|
inpName = node.input(0); |
|
|
|
return Subgraph::match(net, nodeId, matchedNodesIds); |
|
|
|
return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -693,6 +703,40 @@ public: |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KerasMVNSubgraph : public Subgraph |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
public: |
|
|
|
|
|
|
|
KerasMVNSubgraph() |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
|
|
|
int mean = addNodeToMatch("Mean", input, addNodeToMatch("Const")); |
|
|
|
|
|
|
|
int grad = addNodeToMatch("StopGradient", mean); |
|
|
|
|
|
|
|
int diff = addNodeToMatch("SquaredDifference", input, grad); |
|
|
|
|
|
|
|
int var = addNodeToMatch("Mean", diff, addNodeToMatch("Const")); |
|
|
|
|
|
|
|
int sub = addNodeToMatch("Sub", input, mean); |
|
|
|
|
|
|
|
int add_y = addNodeToMatch("Const"); |
|
|
|
|
|
|
|
int add = addNodeToMatch("Add", var, add_y); |
|
|
|
|
|
|
|
int pow_y = addNodeToMatch("Const"); |
|
|
|
|
|
|
|
int powNode = addNodeToMatch("Pow", add, pow_y); |
|
|
|
|
|
|
|
addNodeToMatch("RealDiv", sub, powNode); |
|
|
|
|
|
|
|
setFusedNode("MVN", input, add_y); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, |
|
|
|
|
|
|
|
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
tensorflow::AttrValue eps; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Mat epsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor()); |
|
|
|
|
|
|
|
CV_CheckEQ(epsMat.total(), (size_t)1, ""); |
|
|
|
|
|
|
|
CV_CheckTypeEQ(epsMat.type(), CV_32FC1, ""); |
|
|
|
|
|
|
|
eps.set_f(epsMat.at<float>(0)); |
|
|
|
|
|
|
|
fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("eps", eps)); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fusedNode->mutable_input()->RemoveLast(); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
{ |
|
|
|
{ |
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
@ -712,16 +756,17 @@ void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph())); |
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph())); |
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph())); |
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph())); |
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph())); |
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph())); |
|
|
|
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new KerasMVNSubgraph())); |
|
|
|
|
|
|
|
|
|
|
|
int numNodes = net.node_size(); |
|
|
|
int numNodes = net.node_size(); |
|
|
|
std::vector<int> matchedNodesIds; |
|
|
|
std::vector<int> matchedNodesIds, targetNodesIds; |
|
|
|
for (int i = 0; i < numNodes; ++i) |
|
|
|
for (int i = 0; i < numNodes; ++i) |
|
|
|
{ |
|
|
|
{ |
|
|
|
for (int j = 0; j < subgraphs.size(); ++j) |
|
|
|
for (int j = 0; j < subgraphs.size(); ++j) |
|
|
|
{ |
|
|
|
{ |
|
|
|
if (subgraphs[j]->match(net, i, matchedNodesIds)) |
|
|
|
if (subgraphs[j]->match(net, i, matchedNodesIds, targetNodesIds)) |
|
|
|
{ |
|
|
|
{ |
|
|
|
subgraphs[j]->replace(net, matchedNodesIds); |
|
|
|
subgraphs[j]->replace(net, matchedNodesIds, targetNodesIds); |
|
|
|
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
|
|
|
|
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
|
|
|
|
break; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|