|
|
|
@ -601,7 +601,7 @@ public: |
|
|
|
|
class UpsamplingKerasSubgraph : public Subgraph |
|
|
|
|
{ |
|
|
|
|
public: |
|
|
|
|
UpsamplingKerasSubgraph() |
|
|
|
|
UpsamplingKerasSubgraph(const std::string& type) |
|
|
|
|
{ |
|
|
|
|
int input = addNodeToMatch(""); |
|
|
|
|
int shape = addNodeToMatch("Shape", input); |
|
|
|
@ -611,8 +611,8 @@ public: |
|
|
|
|
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2); |
|
|
|
|
int factors = addNodeToMatch("Const"); |
|
|
|
|
int mul = addNodeToMatch("Mul", strided_slice, factors); |
|
|
|
|
addNodeToMatch("ResizeNearestNeighbor", input, mul); |
|
|
|
|
setFusedNode("ResizeNearestNeighbor", input, factors); |
|
|
|
|
addNodeToMatch(type, input, mul); |
|
|
|
|
setFusedNode(type, input, factors); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode, |
|
|
|
@ -707,7 +707,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net) |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeNearestNeighbor"))); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph("ResizeBilinear"))); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph())); |
|
|
|
|
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph())); |
|
|
|
@ -752,6 +753,8 @@ void RemoveIdentityOps(tensorflow::GraphDef& net) |
|
|
|
|
tensorflow::NodeDef* layer = net.mutable_node(li); |
|
|
|
|
for (int input_id = 0; input_id < layer->input_size(); input_id++) { |
|
|
|
|
String input_op_name = layer->input(input_id); |
|
|
|
|
input_op_name = input_op_name.substr(input_op_name.find('^') + 1, |
|
|
|
|
input_op_name.rfind(':')); |
|
|
|
|
IdentityOpsMap::iterator it = identity_ops.find(input_op_name); |
|
|
|
|
|
|
|
|
|
if (it != identity_ops.end()) { |
|
|
|
|