diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp index f22d9eff1..29f7fd02d 100644 --- a/modules/dnn/src/torch/torch_importer.cpp +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -185,9 +185,6 @@ struct TorchImporter : public ::cv::dnn::Importer void readTorchStorage(int index, int type = -1) { - if (storages.count(index)) - return; - long size = readLong(); Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat @@ -225,7 +222,6 @@ struct TorchImporter : public ::cv::dnn::Importer } storages.insert(std::make_pair(index, storageMat)); - readedIndexes.insert(index); } void readTorchTable(Dict &scalarParams, std::map &tensorParams) @@ -261,10 +257,7 @@ struct TorchImporter : public ::cv::dnn::Importer if (vtype == TYPE_TORCH) { int index = readInt(); - if (tensors.count(index) == 0) - { - readTorchObject(index); - } + readTorchObject(index); if (tensors.count(index)) //tensor was readed { @@ -311,9 +304,6 @@ struct TorchImporter : public ::cv::dnn::Importer void readTorchTensor(int indexTensor, int typeTensor) { - if (tensors.count(indexTensor)) - return; - int ndims = readInt(); AutoBuffer sizes(ndims); AutoBuffer steps(ndims); @@ -394,6 +384,14 @@ struct TorchImporter : public ::cv::dnn::Importer void readTorchObject(int index) { + if(readedIndexes.count(index)) + { + if(!storages.count(index) && !tensors.count(index)) + CV_Error(Error::StsNotImplemented, "Objects which have multiple references are not supported"); + else + return; + } + String className = readTorchClassName(); String nnName; std::cout << "Class: " << className << std::endl; @@ -409,8 +407,6 @@ struct TorchImporter : public ::cv::dnn::Importer } else if (isNNClass(className, nnName)) { - CV_Assert(!readedIndexes.count(index)); - Dict scalarParams; std::map tensorParams; @@ -524,6 +520,8 @@ struct TorchImporter : public ::cv::dnn::Importer { CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\""); } + + readedIndexes.insert(index); } void readObject() @@ -533,12 +531,8 @@ struct TorchImporter : public ::cv::dnn::Importer if (typeidx == TYPE_TORCH) { int index = readInt(); - - if (readedIndexes.count(index) == 0) - { - readTorchObject(index); - readedIndexes.insert(index); - } + readTorchObject(index); + readedIndexes.insert(index); } else if (typeidx == TYPE_NIL) return; diff --git a/modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua b/modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua index a05f65b09..ae25caafd 100644 --- a/modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua +++ b/modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua @@ -27,7 +27,7 @@ function save(net, input, label) end local net_simple = nn.Sequential() ---net_simple:add(nn.ReLU()) +net_simple:add(nn.ReLU()) net_simple:add(nn.SpatialConvolution(3,64, 11,7, 3,4, 3,2)) net_simple:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2)) net_simple:add(nn.Sigmoid())