|
|
|
@ -42,12 +42,12 @@ static inline bool endsWith(const String &str, const char *substr) |
|
|
|
|
return str.rfind(substr) == str.length() - strlen(substr); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
{ |
|
|
|
|
THFile *file; |
|
|
|
|
std::set<int> readedIndexes; |
|
|
|
|
std::map<int, Mat> storages; |
|
|
|
|
std::map<int, Blob> tensors; |
|
|
|
|
|
|
|
|
|
TorchImporter(String filename, bool isBinary) |
|
|
|
|
{ |
|
|
|
@ -99,22 +99,23 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
inline void readFunction() |
|
|
|
|
{ |
|
|
|
|
readString(); |
|
|
|
|
readObject(true); |
|
|
|
|
readObject(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void readTable() |
|
|
|
|
void readTable(int index = -1) |
|
|
|
|
{ |
|
|
|
|
std::cout << "Skipping table\n"; |
|
|
|
|
index = (index < 0) ? readInt() : index; |
|
|
|
|
|
|
|
|
|
if (readedIndexes.count(index)) |
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
int index = readInt(); |
|
|
|
|
CV_Assert(readedIndexes.count(index) == 0); |
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
|
|
|
|
|
int size = readInt(); |
|
|
|
|
for (int i = 0; i < size; i++) |
|
|
|
|
{ |
|
|
|
|
readObject(true); //key
|
|
|
|
|
readObject(true); //value
|
|
|
|
|
readObject(); //key
|
|
|
|
|
readObject(); //value
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -157,106 +158,223 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
|
|
|
|
|
void readTorchStorage(int index, int type = -1) |
|
|
|
|
{ |
|
|
|
|
if (storages.count(index)) |
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
long size = readLong(); |
|
|
|
|
Mat storageMat(1, size, type); |
|
|
|
|
|
|
|
|
|
THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type)); |
|
|
|
|
switch (type) |
|
|
|
|
{ |
|
|
|
|
case CV_32F: |
|
|
|
|
THFile_readFloatRaw(file, (float*)storageMat.data, size); |
|
|
|
|
break; |
|
|
|
|
case CV_64F: |
|
|
|
|
THFile_readDoubleRaw(file, (double*)storageMat.data, size); |
|
|
|
|
break; |
|
|
|
|
case CV_8S: |
|
|
|
|
case CV_8U: |
|
|
|
|
THFile_readByteRaw(file, (uchar*)storageMat.data, size); |
|
|
|
|
break; |
|
|
|
|
case CV_16S: |
|
|
|
|
case CV_16U: |
|
|
|
|
THFile_readShortRaw(file, (short*)storageMat.data, size); |
|
|
|
|
break; |
|
|
|
|
case CV_32S: |
|
|
|
|
THFile_readIntRaw(file, (int*)storageMat.data, size); |
|
|
|
|
default: |
|
|
|
|
CV_Error(Error::StsInternal, ""); |
|
|
|
|
break; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
storages.insert(std::make_pair(index, storageMat)); |
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
Blob readTorchTensor(int typeTensor, bool skip = false) |
|
|
|
|
void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams) |
|
|
|
|
{ |
|
|
|
|
int ndims = readInt(); |
|
|
|
|
int luaType = readInt(); |
|
|
|
|
int index = readInt(); |
|
|
|
|
|
|
|
|
|
CV_Assert(luaType == TYPE_TABLE && readedIndexes.count(index) == 0); |
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
|
|
|
|
|
long fpos; |
|
|
|
|
int numPairs = readInt(); |
|
|
|
|
|
|
|
|
|
for (int i = 0; i < numPairs; i++) |
|
|
|
|
{ |
|
|
|
|
fpos = THFile_position(file); |
|
|
|
|
int ktype = readInt(); |
|
|
|
|
|
|
|
|
|
if (ktype != TYPE_STRING) //skip non-string fileds
|
|
|
|
|
{ |
|
|
|
|
THFile_seek(file, fpos); |
|
|
|
|
readObject(); |
|
|
|
|
readObject(); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
String key = readString(); |
|
|
|
|
|
|
|
|
|
fpos = THFile_position(file); |
|
|
|
|
int vtype = readInt(); |
|
|
|
|
|
|
|
|
|
if (vtype == TYPE_TORCH) |
|
|
|
|
{ |
|
|
|
|
int index = readInt(); |
|
|
|
|
if (tensors.count(index) == 0) |
|
|
|
|
{ |
|
|
|
|
readTorchObject(index); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (tensors.count(index)) |
|
|
|
|
tensorParams.insert(std::make_pair(key, tensors[index])); |
|
|
|
|
} |
|
|
|
|
else if (vtype == TYPE_NUMBER) |
|
|
|
|
{ |
|
|
|
|
scalarParams.set(key, readDouble()); |
|
|
|
|
} |
|
|
|
|
else if (vtype == TYPE_STRING) |
|
|
|
|
{ |
|
|
|
|
scalarParams.set(key, readString()); |
|
|
|
|
} |
|
|
|
|
else if (vtype == TYPE_BOOLEAN) |
|
|
|
|
{ |
|
|
|
|
scalarParams.set(key, readBool()); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
THFile_seek(file, fpos); |
|
|
|
|
readObject(); |
|
|
|
|
continue; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void readTorchTensor(int indexTensor, int typeTensor) |
|
|
|
|
{ |
|
|
|
|
if (tensors.count(indexTensor)) |
|
|
|
|
return; |
|
|
|
|
|
|
|
|
|
int ndims = readInt(); |
|
|
|
|
AutoBuffer<long, 4> sizes(ndims); |
|
|
|
|
AutoBuffer<long, 4> steps(ndims); |
|
|
|
|
THFile_readLongRaw(file, sizes, ndims); |
|
|
|
|
THFile_readLongRaw(file, sizes, ndims); |
|
|
|
|
|
|
|
|
|
THFile_readLongRaw(file, steps, ndims); |
|
|
|
|
long offset = readLong() - 1; |
|
|
|
|
|
|
|
|
|
//read Storage
|
|
|
|
|
int typeidx = readInt(); |
|
|
|
|
std::cout << "stograge typeidx of tensor: " << typeidx << "\n"; |
|
|
|
|
CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0)); |
|
|
|
|
|
|
|
|
|
if (typeidx == TYPE_NIL) |
|
|
|
|
return Blob(); |
|
|
|
|
{ |
|
|
|
|
tensors.insert(std::make_pair(indexTensor, Blob())); |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
int index = readInt(); |
|
|
|
|
if (readedIndexes.count(index) == 0) |
|
|
|
|
int indexStorage = readInt(); |
|
|
|
|
if (readedIndexes.count(indexStorage) == 0) |
|
|
|
|
{ |
|
|
|
|
int typeStorage = parseStorageType(readTorchClassName()); |
|
|
|
|
CV_Assert(typeStorage >= 0 && typeTensor == typeStorage); |
|
|
|
|
readTorchStorage(typeStorage, index); |
|
|
|
|
readTorchStorage(indexStorage, typeStorage); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
//allocate Blob
|
|
|
|
|
//small check
|
|
|
|
|
size_t requireElems = (size_t)offset + (size_t)steps[0] * (size_t)sizes[0]; |
|
|
|
|
size_t storageElems = storages[indexStorage].total(); |
|
|
|
|
if (requireElems > storageElems) |
|
|
|
|
CV_Error(Error::StsBadSize, "Storage has insufficent number of elemements for requested Tensor"); |
|
|
|
|
|
|
|
|
|
//convert sizes
|
|
|
|
|
AutoBuffer<int, 4> isizes(ndims); |
|
|
|
|
AutoBuffer<size_t, 4> ssteps(ndims); |
|
|
|
|
|
|
|
|
|
size_t stepExpected = 1; |
|
|
|
|
for (int i = ndims - 1; i >= 0; i--) |
|
|
|
|
{ |
|
|
|
|
isizes[i] = (int)sizes[i]; |
|
|
|
|
ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor); |
|
|
|
|
|
|
|
|
|
stepExpected *= sizes[i]; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (skip) |
|
|
|
|
return Blob(); |
|
|
|
|
|
|
|
|
|
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps); |
|
|
|
|
//allocate Blob
|
|
|
|
|
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[indexStorage].ptr() + offset, (size_t*)ssteps); |
|
|
|
|
int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F; |
|
|
|
|
|
|
|
|
|
Blob blob; |
|
|
|
|
blob.create(BlobShape(ndims, isizes), dstType); |
|
|
|
|
srcMat.convertTo(blob.getMatRef(), dstType); |
|
|
|
|
|
|
|
|
|
return blob; |
|
|
|
|
tensors.insert(std::make_pair(indexTensor, blob)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
bool isNNClass(const String &className, String &nnName) |
|
|
|
|
{ |
|
|
|
|
const char *prefixes[] = {"nn.", "cunn.", "cudnn.", "fbcunn.", NULL}; |
|
|
|
|
|
|
|
|
|
for (int i = 0; prefixes[i]; i++) |
|
|
|
|
{ |
|
|
|
|
if (startsWith(className, prefixes[i])) |
|
|
|
|
{ |
|
|
|
|
nnName = className.substr(strlen(prefixes[i])); |
|
|
|
|
return true; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return false; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void readTorchObject(int index, bool skip = false) |
|
|
|
|
{ |
|
|
|
|
String className = readTorchClassName(); |
|
|
|
|
String nnName; |
|
|
|
|
std::cout << "Class: " << className << std::endl; |
|
|
|
|
|
|
|
|
|
Dict scalarParams; |
|
|
|
|
std::map<String, Blob> tensorParams; |
|
|
|
|
|
|
|
|
|
int type; |
|
|
|
|
if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
|
|
|
|
|
{ |
|
|
|
|
readTorchTensor(type); |
|
|
|
|
return; |
|
|
|
|
readTorchTensor(index, type); |
|
|
|
|
} |
|
|
|
|
else if ( (type = parseStorageType(className)) >= 0 ) //is Storage
|
|
|
|
|
{ |
|
|
|
|
readTorchStorage(index, type); |
|
|
|
|
} |
|
|
|
|
else if (className == "nn.Sequential") |
|
|
|
|
{ |
|
|
|
|
readObject(); |
|
|
|
|
} |
|
|
|
|
else if (className == "nn.Concat") |
|
|
|
|
else if (isNNClass(className, nnName)) |
|
|
|
|
{ |
|
|
|
|
readObject(); |
|
|
|
|
} |
|
|
|
|
else if (className == "nn.SpatialConvolution") |
|
|
|
|
{ |
|
|
|
|
readObject(); |
|
|
|
|
} |
|
|
|
|
else if (className == "nn.ReLU") |
|
|
|
|
{ |
|
|
|
|
readObject(); |
|
|
|
|
CV_Assert(!readedIndexes.count(index)); |
|
|
|
|
readTorchTable(scalarParams, tensorParams); |
|
|
|
|
|
|
|
|
|
std::cout << scalarParams; |
|
|
|
|
for (std::map<String,Blob>::const_iterator it = tensorParams.begin(); it != tensorParams.end(); it++) |
|
|
|
|
std::cout << it->first << ": Tensor" << "\n"; |
|
|
|
|
|
|
|
|
|
if (nnName == "Sequential") |
|
|
|
|
{ |
|
|
|
|
} |
|
|
|
|
else if (nnName == "Concat") |
|
|
|
|
{ |
|
|
|
|
} |
|
|
|
|
else if (nnName == "SpatialConvolution") |
|
|
|
|
{ |
|
|
|
|
} |
|
|
|
|
else if (nnName == "ReLU") |
|
|
|
|
{ |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\""); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\""); |
|
|
|
|
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\""); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void readObject(bool skip = false) |
|
|
|
|
void readObject() |
|
|
|
|
{ |
|
|
|
|
int typeidx = readInt(); |
|
|
|
|
std::cout << "typeidx: " << typeidx << "\n"; |
|
|
|
@ -267,23 +385,20 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
|
|
|
|
|
if (readedIndexes.count(index) == 0) |
|
|
|
|
{ |
|
|
|
|
readTorchObject(index, skip); |
|
|
|
|
readTorchObject(index); |
|
|
|
|
readedIndexes.insert(index); |
|
|
|
|
} |
|
|
|
|
else |
|
|
|
|
{ |
|
|
|
|
//CV_Error(Error::StsNotImplemented, "");
|
|
|
|
|
//TBD
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
else if (typeidx == TYPE_NIL) |
|
|
|
|
return; |
|
|
|
|
else if (typeidx == TYPE_NUMBER) |
|
|
|
|
readDouble(); |
|
|
|
|
//readDouble();
|
|
|
|
|
std::cout << readDouble() << std::endl; |
|
|
|
|
else if (typeidx == TYPE_BOOLEAN) |
|
|
|
|
readBool(); |
|
|
|
|
else if (typeidx == TYPE_STRING) |
|
|
|
|
readString(); |
|
|
|
|
//readString();
|
|
|
|
|
std::cout << readString() << std::endl; |
|
|
|
|
else if (typeidx == TYPE_TABLE) |
|
|
|
|
readTable(); |
|
|
|
|
else |
|
|
|
@ -294,6 +409,7 @@ struct TorchImporter : public ::cv::dnn::Importer |
|
|
|
|
{ |
|
|
|
|
THFile_seek(file, 0); |
|
|
|
|
readedIndexes.clear(); |
|
|
|
|
storages.clear(); |
|
|
|
|
|
|
|
|
|
readObject(); |
|
|
|
|
} |
|
|
|
|