|
|
@ -1423,6 +1423,43 @@ void TFImporter::populateNet(Net dstNet) |
|
|
|
|
|
|
|
|
|
|
|
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); |
|
|
|
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
else if (type == "StridedSlice") |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
CV_Assert(layer.input_size() == 4); |
|
|
|
|
|
|
|
Mat begins = getTensorContent(getConstBlob(layer, value_id, 1)); |
|
|
|
|
|
|
|
Mat ends = getTensorContent(getConstBlob(layer, value_id, 2)); |
|
|
|
|
|
|
|
Mat strides = getTensorContent(getConstBlob(layer, value_id, 3)); |
|
|
|
|
|
|
|
CV_CheckTypeEQ(begins.type(), CV_32SC1, ""); |
|
|
|
|
|
|
|
CV_CheckTypeEQ(ends.type(), CV_32SC1, ""); |
|
|
|
|
|
|
|
CV_CheckTypeEQ(strides.type(), CV_32SC1, ""); |
|
|
|
|
|
|
|
const int num = begins.total(); |
|
|
|
|
|
|
|
CV_Assert_N(num == ends.total(), num == strides.total()); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int end_mask = getLayerAttr(layer, "end_mask").i(); |
|
|
|
|
|
|
|
for (int i = 0; i < num; ++i) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
if (end_mask & (1 << i)) |
|
|
|
|
|
|
|
ends.at<int>(i) = -1; |
|
|
|
|
|
|
|
if (strides.at<int>(i) != 1) |
|
|
|
|
|
|
|
CV_Error(Error::StsNotImplemented, |
|
|
|
|
|
|
|
format("StridedSlice with stride %d", strides.at<int>(i))); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC) |
|
|
|
|
|
|
|
{ |
|
|
|
|
|
|
|
// Swap NHWC parameters' order to NCHW.
|
|
|
|
|
|
|
|
std::swap(begins.at<int>(2), begins.at<int>(3)); |
|
|
|
|
|
|
|
std::swap(begins.at<int>(1), begins.at<int>(2)); |
|
|
|
|
|
|
|
std::swap(ends.at<int>(2), ends.at<int>(3)); |
|
|
|
|
|
|
|
std::swap(ends.at<int>(1), ends.at<int>(2)); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total())); |
|
|
|
|
|
|
|
layerParams.set("end", DictValue::arrayInt((int*)ends.data, ends.total())); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int id = dstNet.addLayer(name, "Slice", layerParams); |
|
|
|
|
|
|
|
layer_id[name] = id; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); |
|
|
|
|
|
|
|
} |
|
|
|
else if (type == "Mul") |
|
|
|
else if (type == "Mul") |
|
|
|
{ |
|
|
|
{ |
|
|
|
bool haveConst = false; |
|
|
|
bool haveConst = false; |
|
|
|