Merge pull request #21154 from pccvlab:MatMul_with_two_inputs

Add BatchMatMul layer support for tf_importer

* two inputs

* support batch_matmul

* refactor: remove useless code

* refactor: decrease nesting
pull/21238/head
Gruhuang 3 years ago committed by GitHub
parent c08954c18b
commit 17bc8565f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 20
      modules/dnn/src/tensorflow/tf_importer.cpp
  2. 8
      modules/dnn/test/test_tf_importer.cpp

@ -646,7 +646,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
dispatch["Conv2D"] = dispatch["SpaceToBatchND"] = dispatch["DepthwiseConv2dNative"] =
dispatch["Pad"] = dispatch["MirrorPad"] = dispatch["Conv3D"] = &TFImporter::parseConvolution;
dispatch["BiasAdd"] = dispatch["Add"] = dispatch["AddV2"] = dispatch["Sub"] = dispatch["AddN"] = &TFImporter::parseBias;
dispatch["MatMul"] = &TFImporter::parseMatMul;
dispatch["MatMul"] = dispatch["BatchMatMul"] = &TFImporter::parseMatMul;
dispatch["Reshape"] = &TFImporter::parseReshape;
dispatch["Flatten"] = dispatch["Squeeze"] = &TFImporter::parseFlatten;
dispatch["Transpose"] = &TFImporter::parseTranspose;
@ -983,6 +983,24 @@ void TFImporter::parseMatMul(tensorflow::GraphDef& net, const tensorflow::NodeDe
layerParams.set("bias_term", false);
layerParams.blobs.resize(1);
bool hasConstBlob = false;
for(int i = 0; i < layer.input_size(); i++) {
if (value_id.find(layer.input(i)) != value_id.end())
hasConstBlob = true;
}
if (!hasConstBlob)
{
layerParams.blobs.clear();
int id = dstNet.addLayer(name, "InnerProduct", layerParams);
layer_id[name] = id;
// two inputs
for(int ii=0; ii<layer.input_size(); ii++){
connect(layer_id, dstNet, parsePin(layer.input(ii)), id, ii);
}
return;
}
StrIntVector next_layers = getNextLayers(net, name, "BiasAdd"); // FIXIT Use layers fusion instead
if (next_layers.empty())
{

@ -660,6 +660,14 @@ TEST_P(Test_TensorFlow_layers, matmul)
double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1;
runTensorFlowNet("nhwc_reshape_matmul", false, l1);
runTensorFlowNet("matmul_layout");
runTensorFlowNet("two_inputs_matmul");
}
TEST_P(Test_TensorFlow_layers, batch_matmul)
{
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
runTensorFlowNet("batch_matmul");
}
TEST_P(Test_TensorFlow_layers, reshape)

Loading…
Cancel
Save