Resolve reference counting problem in Torch importer.

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 1c220cf03b
commit 436d929578
  1. 26
      modules/dnn/src/torch/torch_importer.cpp
  2. 2
      modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua

@ -185,9 +185,6 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchStorage(int index, int type = -1)
{
if (storages.count(index))
return;
long size = readLong();
Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat
@ -225,7 +222,6 @@ struct TorchImporter : public ::cv::dnn::Importer
}
storages.insert(std::make_pair(index, storageMat));
readedIndexes.insert(index);
}
void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams)
@ -261,10 +257,7 @@ struct TorchImporter : public ::cv::dnn::Importer
if (vtype == TYPE_TORCH)
{
int index = readInt();
if (tensors.count(index) == 0)
{
readTorchObject(index);
}
if (tensors.count(index)) //tensor was readed
{
@ -311,9 +304,6 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchTensor(int indexTensor, int typeTensor)
{
if (tensors.count(indexTensor))
return;
int ndims = readInt();
AutoBuffer<long, 4> sizes(ndims);
AutoBuffer<long, 4> steps(ndims);
@ -394,6 +384,14 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchObject(int index)
{
if(readedIndexes.count(index))
{
if(!storages.count(index) && !tensors.count(index))
CV_Error(Error::StsNotImplemented, "Objects which have multiple references are not supported");
else
return;
}
String className = readTorchClassName();
String nnName;
std::cout << "Class: " << className << std::endl;
@ -409,8 +407,6 @@ struct TorchImporter : public ::cv::dnn::Importer
}
else if (isNNClass(className, nnName))
{
CV_Assert(!readedIndexes.count(index));
Dict scalarParams;
std::map<String, Blob> tensorParams;
@ -524,6 +520,8 @@ struct TorchImporter : public ::cv::dnn::Importer
{
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\"");
}
readedIndexes.insert(index);
}
void readObject()
@ -533,13 +531,9 @@ struct TorchImporter : public ::cv::dnn::Importer
if (typeidx == TYPE_TORCH)
{
int index = readInt();
if (readedIndexes.count(index) == 0)
{
readTorchObject(index);
readedIndexes.insert(index);
}
}
else if (typeidx == TYPE_NIL)
return;
else if (typeidx == TYPE_NUMBER)

@ -27,7 +27,7 @@ function save(net, input, label)
end
local net_simple = nn.Sequential()
--net_simple:add(nn.ReLU())
net_simple:add(nn.ReLU())
net_simple:add(nn.SpatialConvolution(3,64, 11,7, 3,4, 3,2))
net_simple:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2))
net_simple:add(nn.Sigmoid())

Loading…
Cancel
Save