StridedSlice from TensorFlow

pull/14459/head
Dmitry Kurtaev 6 years ago
parent 190467b6c1
commit 26e426adb1
  1. 1
      modules/dnn/include/opencv2/dnn/dnn.hpp
  2. 37
      modules/dnn/src/tensorflow/tf_importer.cpp
  3. 2
      modules/dnn/test/test_onnx_importer.cpp
  4. 1
      modules/dnn/test/test_tf_importer.cpp
  5. 8
      samples/dnn/tf_text_graph_faster_rcnn.py

@ -820,6 +820,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
* * `*.t7` | `*.net` (Torch, http://torch.ch/) * * `*.t7` | `*.net` (Torch, http://torch.ch/)
* * `*.weights` (Darknet, https://pjreddie.com/darknet/) * * `*.weights` (Darknet, https://pjreddie.com/darknet/)
* * `*.bin` (DLDT, https://software.intel.com/openvino-toolkit) * * `*.bin` (DLDT, https://software.intel.com/openvino-toolkit)
* * `*.onnx` (ONNX, https://onnx.ai/)
* @param[in] config Text file contains network configuration. It could be a * @param[in] config Text file contains network configuration. It could be a
* file with the following extensions: * file with the following extensions:
* * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/) * * `*.prototxt` (Caffe, http://caffe.berkeleyvision.org/)

@ -1423,6 +1423,43 @@ void TFImporter::populateNet(Net dstNet)
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0); connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
} }
else if (type == "StridedSlice")
{
CV_Assert(layer.input_size() == 4);
Mat begins = getTensorContent(getConstBlob(layer, value_id, 1));
Mat ends = getTensorContent(getConstBlob(layer, value_id, 2));
Mat strides = getTensorContent(getConstBlob(layer, value_id, 3));
CV_CheckTypeEQ(begins.type(), CV_32SC1, "");
CV_CheckTypeEQ(ends.type(), CV_32SC1, "");
CV_CheckTypeEQ(strides.type(), CV_32SC1, "");
const int num = begins.total();
CV_Assert_N(num == ends.total(), num == strides.total());
int end_mask = getLayerAttr(layer, "end_mask").i();
for (int i = 0; i < num; ++i)
{
if (end_mask & (1 << i))
ends.at<int>(i) = -1;
if (strides.at<int>(i) != 1)
CV_Error(Error::StsNotImplemented,
format("StridedSlice with stride %d", strides.at<int>(i)));
}
if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
{
// Swap NHWC parameters' order to NCHW.
std::swap(begins.at<int>(2), begins.at<int>(3));
std::swap(begins.at<int>(1), begins.at<int>(2));
std::swap(ends.at<int>(2), ends.at<int>(3));
std::swap(ends.at<int>(1), ends.at<int>(2));
}
layerParams.set("begin", DictValue::arrayInt((int*)begins.data, begins.total()));
layerParams.set("end", DictValue::arrayInt((int*)ends.data, ends.total()));
int id = dstNet.addLayer(name, "Slice", layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
else if (type == "Mul") else if (type == "Mul")
{ {
bool haveConst = false; bool haveConst = false;

@ -248,7 +248,7 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P(Test_ONNX_layers, Softmax) TEST_P(Test_ONNX_layers, Softmax)
{ {
testONNXModels("softmax"); testONNXModels("softmax");
testONNXModels("log_softmax"); testONNXModels("log_softmax", npy, 0, 0, false, false);
} }
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());

@ -663,6 +663,7 @@ TEST_P(Test_TensorFlow_layers, slice)
(target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16)) (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
throw SkipTestException(""); throw SkipTestException("");
runTensorFlowNet("slice_4d"); runTensorFlowNet("slice_4d");
runTensorFlowNet("strided_slice");
} }
TEST_P(Test_TensorFlow_layers, softmax) TEST_P(Test_TensorFlow_layers, softmax)

@ -31,7 +31,13 @@ def createFasterRCNNGraph(modelPath, configPath, outputPath):
aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']] aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
width_stride = float(grid_anchor_generator['width_stride'][0]) width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0]) height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
feature_extractor = config['feature_extractor'][0]
if 'type' in feature_extractor and feature_extractor['type'][0] == 'faster_rcnn_nas':
features_stride = 16.0
else:
features_stride = float(feature_extractor['first_stage_features_stride'][0])
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0]) first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
first_stage_max_proposals = int(config['first_stage_max_proposals'][0]) first_stage_max_proposals = int(config['first_stage_max_proposals'][0])

Loading…
Cancel
Save