diff --git a/modules/dnn/include/opencv2/dnn/dict.hpp b/modules/dnn/include/opencv2/dnn/dict.hpp
index bdfe5e9ba..c1b10c6cc 100644
--- a/modules/dnn/include/opencv2/dnn/dict.hpp
+++ b/modules/dnn/include/opencv2/dnn/dict.hpp
@@ -3,6 +3,7 @@
 
 #include <opencv2/core.hpp>
 #include <map>
+#include <ostream>
 
 namespace cv
 {
@@ -35,6 +36,8 @@ struct DictValue
 
     DictValue &operator=(const DictValue &r);
 
+    friend std::ostream &operator<<(std::ostream &stream, const DictValue &dictv);
+
     ~DictValue();
 
 protected:
@@ -135,6 +138,8 @@ public:
 
         return value;
     }
+
+    friend std::ostream &operator<<(std::ostream &stream, const Dict &dict);
 };
 
 template<>
@@ -299,6 +304,41 @@ inline int DictValue::size() const
     }
 }
 
+inline std::ostream &operator<<(std::ostream &stream, const DictValue &dictv)
+{
+    int i;
+
+    if (dictv.isInt())
+    {
+        for (i = 0; i < dictv.size() - 1; i++)
+            stream << dictv.get<int64>(i) << ", ";
+        stream << dictv.get<int64>(i);
+    }
+    else if (dictv.isReal())
+    {
+        for (i = 0; i < dictv.size() - 1; i++)
+            stream << dictv.get<double>(i) << ", ";
+        stream << dictv.get<double>(i);
+    }
+    else if (dictv.isString())
+    {
+        for (i = 0; i < dictv.size() - 1; i++)
+            stream << "\"" << dictv.get<String>(i) << "\", ";
+        stream << dictv.get<String>(i);
+    }
+
+    return stream;
+}
+
+inline std::ostream &operator<<(std::ostream &stream, const Dict &dict)
+{
+    Dict::_Dict::const_iterator it;
+    for (it = dict.dict.begin(); it != dict.dict.end(); it++)
+        stream << it->first << " : " << it->second << "\n";
+
+    return stream;
+}
+
 }
 }
diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp
index 7ea55c825..716204777 100644
--- a/modules/dnn/src/torch/torch_importer.cpp
+++ b/modules/dnn/src/torch/torch_importer.cpp
@@ -42,12 +42,12 @@ static inline bool endsWith(const String &str, const char *substr)
     return str.rfind(substr) == str.length() - strlen(substr);
 }
 
-
 struct TorchImporter : public ::cv::dnn::Importer
 {
     THFile *file;
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
+    std::map<int, Blob> tensors;
 
     TorchImporter(String filename, bool isBinary)
     {
@@ -99,22 +99,23 @@
     inline void readFunction()
     {
         readString();
-        readObject(true);
+        readObject();
     }
 
-    void readTable()
+    void readTable(int index = -1)
     {
-        std::cout << "Skipping table\n";
+        index = (index < 0) ? readInt() : index;
+
+        if (readedIndexes.count(index))
+            return;
 
-        int index = readInt();
-        CV_Assert(readedIndexes.count(index) == 0);
         readedIndexes.insert(index);
 
         int size = readInt();
         for (int i = 0; i < size; i++)
         {
-            readObject(true); //key
-            readObject(true); //value
+            readObject(); //key
+            readObject(); //value
         }
     }
 
@@ -157,106 +158,223 @@ struct TorchImporter : public ::cv::dnn::Importer
     void readTorchStorage(int index, int type = -1)
     {
+        if (storages.count(index))
+            return;
+
         long size = readLong();
         Mat storageMat(1, size, type);
 
-        THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type));
+        switch (type)
+        {
+        case CV_32F:
+            THFile_readFloatRaw(file, (float*)storageMat.data, size);
+            break;
+        case CV_64F:
+            THFile_readDoubleRaw(file, (double*)storageMat.data, size);
+            break;
+        case CV_8S:
+        case CV_8U:
+            THFile_readByteRaw(file, (uchar*)storageMat.data, size);
+            break;
+        case CV_16S:
+        case CV_16U:
+            THFile_readShortRaw(file, (short*)storageMat.data, size);
+            break;
+        case CV_32S:
+            THFile_readIntRaw(file, (int*)storageMat.data, size);
+            break;
+        default:
+            CV_Error(Error::StsInternal, "");
+            break;
+        }
 
         storages.insert(std::make_pair(index, storageMat));
         readedIndexes.insert(index);
     }
 
-    Blob readTorchTensor(int typeTensor, bool skip = false)
+    void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams)
     {
-        int ndims = readInt();
+        int luaType = readInt();
+        int index = readInt();
+
+        CV_Assert(luaType == TYPE_TABLE && readedIndexes.count(index) == 0);
+        readedIndexes.insert(index);
+
+        long fpos;
+        int numPairs = readInt();
+
+        for (int i = 0; i < numPairs; i++)
+        {
+            fpos = THFile_position(file);
+            int ktype = readInt();
+
+            if (ktype != TYPE_STRING) //skip non-string fields
+            {
+                THFile_seek(file, fpos);
+                readObject();
+                readObject();
+                continue;
+            }
+
+            String key = readString();
+
+            fpos = THFile_position(file);
+            int vtype = readInt();
+
+            if (vtype == TYPE_TORCH)
+            {
+                int index = readInt();
+                if (tensors.count(index) == 0)
+                {
+                    readTorchObject(index);
+                }
+
+                if (tensors.count(index))
+                    tensorParams.insert(std::make_pair(key, tensors[index]));
+            }
+            else if (vtype == TYPE_NUMBER)
+            {
+                scalarParams.set(key, readDouble());
+            }
+            else if (vtype == TYPE_STRING)
+            {
+                scalarParams.set(key, readString());
+            }
+            else if (vtype == TYPE_BOOLEAN)
+            {
+                scalarParams.set(key, readBool());
+            }
+            else
+            {
+                THFile_seek(file, fpos);
+                readObject();
+                continue;
+            }
+        }
+    }
+
+    void readTorchTensor(int indexTensor, int typeTensor)
+    {
+        if (tensors.count(indexTensor))
+            return;
+
+        int ndims = readInt();
         AutoBuffer<long, 4> sizes(ndims);
         AutoBuffer<long, 4> steps(ndims);
         THFile_readLongRaw(file, sizes, ndims);
-        THFile_readLongRaw(file, sizes, ndims);
-
+        THFile_readLongRaw(file, steps, ndims);
         long offset = readLong() - 1;
 
         //read Storage
         int typeidx = readInt();
-        std::cout << "stograge typeidx of tensor: " << typeidx << "\n";
         CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0));
 
         if (typeidx == TYPE_NIL)
-            return Blob();
+        {
+            tensors.insert(std::make_pair(indexTensor, Blob()));
+            return;
+        }
 
-        int index = readInt();
-        if (readedIndexes.count(index) == 0)
+        int indexStorage = readInt();
+        if (readedIndexes.count(indexStorage) == 0)
         {
             int typeStorage = parseStorageType(readTorchClassName());
             CV_Assert(typeStorage >= 0 && typeTensor == typeStorage);
-            readTorchStorage(typeStorage, index);
+            readTorchStorage(indexStorage, typeStorage);
         }
 
-        //allocate Blob
+        //small check
+        size_t requireElems = (size_t)offset + (size_t)steps[0] * (size_t)sizes[0];
+        size_t storageElems = storages[indexStorage].total();
+        if (requireElems > storageElems)
+            CV_Error(Error::StsBadSize, "Storage has insufficient number of elements for requested Tensor");
+
+        //convert sizes
         AutoBuffer<int, 4> isizes(ndims);
         AutoBuffer<size_t, 4> ssteps(ndims);
-
-        size_t stepExpected = 1;
         for (int i = ndims - 1; i >= 0; i--)
        {
             isizes[i] = (int)sizes[i];
             ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor);
-
-            stepExpected *= sizes[i];
         }
 
-        if (skip)
-            return Blob();
-
-        Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps);
+        //allocate Blob
+        Mat srcMat(ndims, (int*)isizes, typeTensor, storages[indexStorage].ptr() + offset * CV_ELEM_SIZE(typeTensor), (size_t*)ssteps);
         int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F;
 
         Blob blob;
         blob.create(BlobShape(ndims, isizes), dstType);
         srcMat.convertTo(blob.getMatRef(), dstType);
 
-        return blob;
+        tensors.insert(std::make_pair(indexTensor, blob));
+    }
+
+    bool isNNClass(const String &className, String &nnName)
+    {
+        const char *prefixes[] = {"nn.", "cunn.", "cudnn.", "fbcunn.", NULL};
+
+        for (int i = 0; prefixes[i]; i++)
+        {
+            if (startsWith(className, prefixes[i]))
+            {
+                nnName = className.substr(strlen(prefixes[i]));
+                return true;
+            }
+        }
+
+        return false;
     }
 
     void readTorchObject(int index, bool skip = false)
     {
         String className = readTorchClassName();
+        String nnName;
         std::cout << "Class: " << className << std::endl;
 
+        Dict scalarParams;
+        std::map<String, Blob> tensorParams;
+
         int type;
         if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
         {
-            readTorchTensor(type);
-            return;
+            readTorchTensor(index, type);
         }
         else if ( (type = parseStorageType(className)) >= 0 ) //is Storage
         {
             readTorchStorage(index, type);
         }
-        else if (className == "nn.Sequential")
-        {
-            readObject();
-        }
-        else if (className == "nn.Concat")
+        else if (isNNClass(className, nnName))
         {
-            readObject();
-        }
-        else if (className == "nn.SpatialConvolution")
-        {
-            readObject();
-        }
-        else if (className == "nn.ReLU")
-        {
-            readObject();
+            CV_Assert(!readedIndexes.count(index));
+            readTorchTable(scalarParams, tensorParams);
+
+            std::cout << scalarParams;
+            for (std::map<String, Blob>::const_iterator it = tensorParams.begin(); it != tensorParams.end(); it++)
+                std::cout << it->first << ": Tensor" << "\n";
+
+            if (nnName == "Sequential")
+            {
+            }
+            else if (nnName == "Concat")
+            {
+            }
+            else if (nnName == "SpatialConvolution")
+            {
+            }
+            else if (nnName == "ReLU")
+            {
+            }
+            else
+            {
+                CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
+            }
         }
         else
         {
-            CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\"");
+            CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\"");
         }
     }
 
-    void readObject(bool skip = false)
+    void readObject()
     {
         int typeidx = readInt();
         std::cout << "typeidx: " << typeidx << "\n";
@@ -267,23 +385,20 @@ struct TorchImporter : public ::cv::dnn::Importer
 
             if (readedIndexes.count(index) == 0)
             {
-                readTorchObject(index, skip);
+                readTorchObject(index);
                 readedIndexes.insert(index);
             }
-            else
-            {
-                //CV_Error(Error::StsNotImplemented, "");
-                //TBD
-            }
         }
         else if (typeidx == TYPE_NIL)
             return;
         else if (typeidx == TYPE_NUMBER)
-            readDouble();
+            //readDouble();
+            std::cout << readDouble() << std::endl;
         else if (typeidx == TYPE_BOOLEAN)
             readBool();
         else if (typeidx == TYPE_STRING)
-            readString();
+            //readString();
+            std::cout << readString() << std::endl;
        else if (typeidx == TYPE_TABLE)
             readTable();
         else
@@ -294,6 +409,7 @@ struct TorchImporter : public ::cv::dnn::Importer
     {
         THFile_seek(file, 0);
         readedIndexes.clear();
+        storages.clear();
 
         readObject();
     }
diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp
index 92b7d88d0..220d6fb58 100644
--- a/modules/dnn/test/test_torch_importer.cpp
+++ b/modules/dnn/test/test_torch_importer.cpp
@@ -25,7 +25,8 @@ TEST(Torch_Importer, simple_read)
     Net net;
     Ptr<Importer> importer;
 
-    ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
+    //ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
+    ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
     ASSERT_TRUE( importer != NULL );
 
     importer->populateNet(net);
     //ASSERT_NO_THROW( importer->populateNet(net) );
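
Note (not part of the patch): a minimal sketch of how the Dict/DictValue stream operators introduced in dict.hpp could be exercised. The keys below are arbitrary examples, and the header path is assumed to match the module layout in this patch.

    #include <iostream>
    #include <opencv2/dnn/dict.hpp>

    int main()
    {
        cv::dnn::Dict d;
        d.set("kernel_size", 3);               // stored as an integer DictValue
        d.set("scale", 0.5);                   // stored as a real DictValue
        d.set("pad_mode", cv::String("same")); // stored as a string DictValue

        // The new operator<<(std::ostream&, const Dict&) prints one "key : value"
        // line per entry, delegating to operator<<(std::ostream&, const DictValue&).
        std::cout << d;
        return 0;
    }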