|
|
|
@ -18,6 +18,7 @@ Implementation of Tensorflow models parser |
|
|
|
|
#include <fstream> |
|
|
|
|
#include <algorithm> |
|
|
|
|
#include <string> |
|
|
|
|
#include <queue> |
|
|
|
|
#include "tf_graph_simplifier.hpp" |
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
@ -558,9 +559,7 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If all inputs of specific layer have the same data layout we can say that
|
|
|
|
|
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
|
|
|
|
|
static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts) |
|
|
|
|
static int getDataLayout(const tensorflow::NodeDef& layer) |
|
|
|
|
{ |
|
|
|
|
if (hasLayerAttr(layer, "data_format")) |
|
|
|
|
{ |
|
|
|
@ -572,27 +571,48 @@ static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std:: |
|
|
|
|
else |
|
|
|
|
CV_Error(Error::StsParseError, "Unknown data_format value: " + format); |
|
|
|
|
} |
|
|
|
|
return DATA_LAYOUT_UNKNOWN; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static inline std::string getNodeName(const std::string& tensorName) |
|
|
|
|
{ |
|
|
|
|
return tensorName.substr(0, tensorName.rfind(':')); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If all inputs of specific layer have the same data layout we can say that
|
|
|
|
|
// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
|
|
|
|
|
static int predictOutputDataLayout(const tensorflow::GraphDef& net, |
|
|
|
|
const tensorflow::NodeDef& layer, |
|
|
|
|
const std::map<String, int>& data_layouts) |
|
|
|
|
{ |
|
|
|
|
int layout = getDataLayout(layer); |
|
|
|
|
if (layout != DATA_LAYOUT_UNKNOWN) |
|
|
|
|
return layout; |
|
|
|
|
|
|
|
|
|
// Determine layout by layer's inputs
|
|
|
|
|
int layout = DATA_LAYOUT_UNKNOWN; |
|
|
|
|
std::map<String, int>::const_iterator it; |
|
|
|
|
for (int i = 0, n = layer.input_size(); i < n; ++i) |
|
|
|
|
{ |
|
|
|
|
it = data_layouts.find(layer.input(i).substr(0, layer.input(i).rfind(':'))); |
|
|
|
|
it = data_layouts.find(getNodeName(layer.input(i))); |
|
|
|
|
if (it != data_layouts.end()) |
|
|
|
|
{ |
|
|
|
|
if (it->second == DATA_LAYOUT_UNKNOWN) |
|
|
|
|
return DATA_LAYOUT_UNKNOWN; |
|
|
|
|
else if (it->second != layout) |
|
|
|
|
if (layout != DATA_LAYOUT_UNKNOWN) |
|
|
|
|
{ |
|
|
|
|
if (layout == DATA_LAYOUT_UNKNOWN) |
|
|
|
|
layout = it->second; |
|
|
|
|
else |
|
|
|
|
if (it->second != layout && it->second != DATA_LAYOUT_UNKNOWN) |
|
|
|
|
return DATA_LAYOUT_UNKNOWN; |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
layout = it->second; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return layout; |
|
|
|
|
|
|
|
|
|
if (layout != DATA_LAYOUT_UNKNOWN) |
|
|
|
|
return layout; |
|
|
|
|
|
|
|
|
|
// Determine layout by layer's consumers recursively.
|
|
|
|
|
it = data_layouts.find(layer.name()); |
|
|
|
|
CV_Assert(it != data_layouts.end()); |
|
|
|
|
return it->second; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void TFImporter::populateNet(Net dstNet) |
|
|
|
@ -610,6 +630,52 @@ void TFImporter::populateNet(Net dstNet) |
|
|
|
|
int layersSize = net.node_size(); |
|
|
|
|
|
|
|
|
|
std::map<String, int> data_layouts; |
|
|
|
|
// Pre-fill data layouts where they are set explicitly.
|
|
|
|
|
// Assuming that nodes are in topological order
|
|
|
|
|
for (int i = net.node_size() - 1; i >= 0; --i) |
|
|
|
|
{ |
|
|
|
|
const tensorflow::NodeDef& layer = net.node(i); |
|
|
|
|
std::string name = layer.name(); |
|
|
|
|
|
|
|
|
|
int layout = getDataLayout(layer); |
|
|
|
|
std::map<String, int>::iterator it = data_layouts.find(name); |
|
|
|
|
if (it != data_layouts.end()) |
|
|
|
|
{ |
|
|
|
|
if (layout != DATA_LAYOUT_UNKNOWN) |
|
|
|
|
{ |
|
|
|
|
if (it->second == DATA_LAYOUT_UNKNOWN) |
|
|
|
|
it->second = layout; |
|
|
|
|
else if (it->second != layout) |
|
|
|
|
{ |
|
|
|
|
it->second = DATA_LAYOUT_UNKNOWN; |
|
|
|
|
layout = DATA_LAYOUT_UNKNOWN; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
layout = it->second; |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
data_layouts[name] = layout; |
|
|
|
|
|
|
|
|
|
// Specify input layers to have the same data layout.
|
|
|
|
|
for (int j = 0; j < layer.input_size(); ++j) |
|
|
|
|
{ |
|
|
|
|
name = getNodeName(layer.input(j)); |
|
|
|
|
it = data_layouts.find(name); |
|
|
|
|
if (it != data_layouts.end()) |
|
|
|
|
{ |
|
|
|
|
if (layout != DATA_LAYOUT_UNKNOWN) |
|
|
|
|
{ |
|
|
|
|
if (it->second == DATA_LAYOUT_UNKNOWN) |
|
|
|
|
it->second = layout; |
|
|
|
|
else if (it->second != layout) |
|
|
|
|
it->second = DATA_LAYOUT_UNKNOWN; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
data_layouts[name] = layout; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// find all Const layers for params
|
|
|
|
|
std::map<String, int> value_id; |
|
|
|
@ -628,7 +694,8 @@ void TFImporter::populateNet(Net dstNet) |
|
|
|
|
if(layers_to_ignore.find(name) != layers_to_ignore.end()) |
|
|
|
|
continue; |
|
|
|
|
|
|
|
|
|
data_layouts[name] = predictOutputDataLayout(layer, data_layouts); |
|
|
|
|
int predictedLayout = predictOutputDataLayout(net, layer, data_layouts); |
|
|
|
|
data_layouts[name] = predictedLayout; |
|
|
|
|
|
|
|
|
|
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative") |
|
|
|
|
{ |
|
|
|
@ -885,6 +952,7 @@ void TFImporter::populateNet(Net dstNet) |
|
|
|
|
|
|
|
|
|
// one input only
|
|
|
|
|
connect(layer_id, dstNet, inpId, id, 0); |
|
|
|
|
data_layouts[name] = DATA_LAYOUT_UNKNOWN; |
|
|
|
|
} |
|
|
|
|
else if (type == "Flatten" || type == "Squeeze") |
|
|
|
|
{ |
|
|
|
@ -1013,7 +1081,10 @@ void TFImporter::populateNet(Net dstNet) |
|
|
|
|
{ |
|
|
|
|
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1); |
|
|
|
|
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0); |
|
|
|
|
layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW(axis) : axis); |
|
|
|
|
|
|
|
|
|
if (data_layouts[name] == DATA_LAYOUT_NHWC) |
|
|
|
|
axis = toNCHW(axis); |
|
|
|
|
layerParams.set("axis", axis); |
|
|
|
|
|
|
|
|
|
int id = dstNet.addLayer(name, "Concat", layerParams); |
|
|
|
|
layer_id[name] = id; |
|
|
|
|