|
|
|
@ -630,6 +630,21 @@ public: |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class SoftMaxSlimSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
SoftMaxSlimSubgraph() |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int shape = addNodeToMatch("Const"); |
|
|
|
|
int shapeOp = addNodeToMatch("Shape", input); |
|
|
|
|
int reshape = addNodeToMatch("Reshape", input, shape); |
|
|
|
|
int softmax = addNodeToMatch("Softmax", reshape); |
|
|
|
|
addNodeToMatch("Reshape", softmax, shapeOp); |
|
|
|
|
setFusedNode("Softmax", input); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
|
{ |
|
|
|
|
std::vector<Ptr<Subgraph> > subgraphs; |
|
|
|
@ -646,6 +661,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph())); |
|
|
|
|
|
|
|
|
|
int numNodes = net.node_size(); |
|
|
|
|
std::vector<int> matchedNodesIds; |
|
|
|
|