|
|
|
@ -185,9 +185,6 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
|
|
|
|
|
void readTorchStorage(int index, int type = -1) |
|
|
|
|
{ |
|
|
|
|
if (storages.count(index)) |
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
long size = readLong(); |
|
|
|
|
Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat
|
|
|
|
|
|
|
|
|
@ -225,7 +222,6 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
storages.insert(std::make_pair(index, storageMat)); |
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams) |
|
|
|
@ -261,10 +257,7 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
if (vtype == TYPE_TORCH) |
|
|
|
|
{ |
|
|
|
|
int index = readInt(); |
|
|
|
|
if (tensors.count(index) == 0) |
|
|
|
|
{ |
|
|
|
|
readTorchObject(index); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (tensors.count(index)) //tensor was readed
|
|
|
|
|
{ |
|
|
|
@ -311,9 +304,6 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
|
|
|
|
|
void readTorchTensor(int indexTensor, int typeTensor) |
|
|
|
|
{ |
|
|
|
|
if (tensors.count(indexTensor)) |
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
int ndims = readInt(); |
|
|
|
|
AutoBuffer<long, 4> sizes(ndims); |
|
|
|
|
AutoBuffer<long, 4> steps(ndims); |
|
|
|
@ -394,6 +384,14 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
|
|
|
|
|
void readTorchObject(int index) |
|
|
|
|
{ |
|
|
|
|
if(readedIndexes.count(index)) |
|
|
|
|
{ |
|
|
|
|
if(!storages.count(index) && !tensors.count(index)) |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Objects which have multiple references are not supported"); |
|
|
|
|
else |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
String className = readTorchClassName(); |
|
|
|
|
String nnName; |
|
|
|
|
std::cout << "Class: " << className << std::endl; |
|
|
|
@ -409,8 +407,6 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
} |
|
|
|
|
else if (isNNClass(className, nnName)) |
|
|
|
|
{ |
|
|
|
|
CV_Assert(!readedIndexes.count(index)); |
|
|
|
|
|
|
|
|
|
Dict scalarParams; |
|
|
|
|
std::map<String, Blob> tensorParams; |
|
|
|
|
|
|
|
|
@ -524,6 +520,8 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
{ |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\""); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void readObject() |
|
|
|
@ -533,13 +531,9 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
if (typeidx == TYPE_TORCH) |
|
|
|
|
{ |
|
|
|
|
int index = readInt(); |
|
|
|
|
|
|
|
|
|
if (readedIndexes.count(index) == 0) |
|
|
|
|
{ |
|
|
|
|
readTorchObject(index); |
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else if (typeidx == TYPE_NIL) |
|
|
|
|
return; |
|
|
|
|
else if (typeidx == TYPE_NUMBER) |
|
|
|
|