Resolve reference counting problem in Torch importer.

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 1c220cf03b
commit 436d929578
  1. 32
      modules/dnn/src/torch/torch_importer.cpp
  2. 2
      modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua

@ -185,9 +185,6 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchStorage(int index, int type = -1) void readTorchStorage(int index, int type = -1)
{ {
if (storages.count(index))
return;
long size = readLong(); long size = readLong();
Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat
@ -225,7 +222,6 @@ struct TorchImporter : public ::cv::dnn::Importer
} }
storages.insert(std::make_pair(index, storageMat)); storages.insert(std::make_pair(index, storageMat));
readedIndexes.insert(index);
} }
void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams) void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams)
@ -261,10 +257,7 @@ struct TorchImporter : public ::cv::dnn::Importer
if (vtype == TYPE_TORCH) if (vtype == TYPE_TORCH)
{ {
int index = readInt(); int index = readInt();
if (tensors.count(index) == 0) readTorchObject(index);
{
readTorchObject(index);
}
if (tensors.count(index)) //tensor was readed if (tensors.count(index)) //tensor was readed
{ {
@ -311,9 +304,6 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchTensor(int indexTensor, int typeTensor) void readTorchTensor(int indexTensor, int typeTensor)
{ {
if (tensors.count(indexTensor))
return;
int ndims = readInt(); int ndims = readInt();
AutoBuffer<long, 4> sizes(ndims); AutoBuffer<long, 4> sizes(ndims);
AutoBuffer<long, 4> steps(ndims); AutoBuffer<long, 4> steps(ndims);
@ -394,6 +384,14 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchObject(int index) void readTorchObject(int index)
{ {
if(readedIndexes.count(index))
{
if(!storages.count(index) && !tensors.count(index))
CV_Error(Error::StsNotImplemented, "Objects which have multiple references are not supported");
else
return;
}
String className = readTorchClassName(); String className = readTorchClassName();
String nnName; String nnName;
std::cout << "Class: " << className << std::endl; std::cout << "Class: " << className << std::endl;
@ -409,8 +407,6 @@ struct TorchImporter : public ::cv::dnn::Importer
} }
else if (isNNClass(className, nnName)) else if (isNNClass(className, nnName))
{ {
CV_Assert(!readedIndexes.count(index));
Dict scalarParams; Dict scalarParams;
std::map<String, Blob> tensorParams; std::map<String, Blob> tensorParams;
@ -524,6 +520,8 @@ struct TorchImporter : public ::cv::dnn::Importer
{ {
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\""); CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\"");
} }
readedIndexes.insert(index);
} }
void readObject() void readObject()
@ -533,12 +531,8 @@ struct TorchImporter : public ::cv::dnn::Importer
if (typeidx == TYPE_TORCH) if (typeidx == TYPE_TORCH)
{ {
int index = readInt(); int index = readInt();
readTorchObject(index);
if (readedIndexes.count(index) == 0) readedIndexes.insert(index);
{
readTorchObject(index);
readedIndexes.insert(index);
}
} }
else if (typeidx == TYPE_NIL) else if (typeidx == TYPE_NIL)
return; return;

@ -27,7 +27,7 @@ function save(net, input, label)
end end
local net_simple = nn.Sequential() local net_simple = nn.Sequential()
--net_simple:add(nn.ReLU()) net_simple:add(nn.ReLU())
net_simple:add(nn.SpatialConvolution(3,64, 11,7, 3,4, 3,2)) net_simple:add(nn.SpatialConvolution(3,64, 11,7, 3,4, 3,2))
net_simple:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2)) net_simple:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2))
net_simple:add(nn.Sigmoid()) net_simple:add(nn.Sigmoid())

Loading…
Cancel
Save