Fixed Torch parser

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 4da1046d68
commit 7d795af6a0
  1. 40
      modules/dnn/include/opencv2/dnn/dict.hpp
  2. 224
      modules/dnn/src/torch/torch_importer.cpp
  3. 3
      modules/dnn/test/test_torch_importer.cpp

@ -3,6 +3,7 @@
#include <opencv2/core.hpp> #include <opencv2/core.hpp>
#include <map> #include <map>
#include <ostream>
namespace cv namespace cv
{ {
@ -35,6 +36,8 @@ struct DictValue
DictValue &operator=(const DictValue &r); DictValue &operator=(const DictValue &r);
friend std::ostream &operator<<(std::ostream &stream, const DictValue &dictv);
~DictValue(); ~DictValue();
protected: protected:
@ -135,6 +138,8 @@ public:
return value; return value;
} }
friend std::ostream &operator<<(std::ostream &stream, const Dict &dict);
}; };
template<> template<>
@ -299,6 +304,41 @@ inline int DictValue::size() const
} }
} }
inline std::ostream &operator<<(std::ostream &stream, const DictValue &dictv)
{
int i;
if (dictv.isInt())
{
for (i = 0; i < dictv.size() - 1; i++)
stream << dictv.get<int64>(i) << ", ";
stream << dictv.get<int64>(i);
}
else if (dictv.isReal())
{
for (i = 0; i < dictv.size() - 1; i++)
stream << dictv.get<double>(i) << ", ";
stream << dictv.get<double>(i);
}
else if (dictv.isString())
{
for (i = 0; i < dictv.size() - 1; i++)
stream << "\"" << dictv.get<String>(i) << "\", ";
stream << dictv.get<String>(i);
}
return stream;
}
inline std::ostream &operator<<(std::ostream &stream, const Dict &dict)
{
Dict::_Dict::const_iterator it;
for (it = dict.dict.begin(); it != dict.dict.end(); it++)
stream << it->first << " : " << it->second << "\n";
return stream;
}
} }
} }

@ -42,12 +42,12 @@ static inline bool endsWith(const String &str, const char *substr)
return str.rfind(substr) == str.length() - strlen(substr); return str.rfind(substr) == str.length() - strlen(substr);
} }
struct TorchImporter : public ::cv::dnn::Importer struct TorchImporter : public ::cv::dnn::Importer
{ {
THFile *file; THFile *file;
std::set<int> readedIndexes; std::set<int> readedIndexes;
std::map<int, Mat> storages; std::map<int, Mat> storages;
std::map<int, Blob> tensors;
TorchImporter(String filename, bool isBinary) TorchImporter(String filename, bool isBinary)
{ {
@ -99,22 +99,23 @@ struct TorchImporter : public ::cv::dnn::Importer
inline void readFunction() inline void readFunction()
{ {
readString(); readString();
readObject(true); readObject();
} }
void readTable() void readTable(int index = -1)
{ {
std::cout << "Skipping table\n"; index = (index < 0) ? readInt() : index;
if (readedIndexes.count(index))
return;
int index = readInt();
CV_Assert(readedIndexes.count(index) == 0);
readedIndexes.insert(index); readedIndexes.insert(index);
int size = readInt(); int size = readInt();
for (int i = 0; i < size; i++) for (int i = 0; i < size; i++)
{ {
readObject(true); //key readObject(); //key
readObject(true); //value readObject(); //value
} }
} }
@ -157,106 +158,223 @@ struct TorchImporter : public ::cv::dnn::Importer
void readTorchStorage(int index, int type = -1) void readTorchStorage(int index, int type = -1)
{ {
if (storages.count(index))
return;
long size = readLong(); long size = readLong();
Mat storageMat(1, size, type); Mat storageMat(1, size, type);
THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type)); switch (type)
{
case CV_32F:
THFile_readFloatRaw(file, (float*)storageMat.data, size);
break;
case CV_64F:
THFile_readDoubleRaw(file, (double*)storageMat.data, size);
break;
case CV_8S:
case CV_8U:
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
break;
case CV_16S:
case CV_16U:
THFile_readShortRaw(file, (short*)storageMat.data, size);
break;
case CV_32S:
THFile_readIntRaw(file, (int*)storageMat.data, size);
default:
CV_Error(Error::StsInternal, "");
break;
}
storages.insert(std::make_pair(index, storageMat)); storages.insert(std::make_pair(index, storageMat));
readedIndexes.insert(index); readedIndexes.insert(index);
} }
Blob readTorchTensor(int typeTensor, bool skip = false) void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams)
{ {
int ndims = readInt(); int luaType = readInt();
int index = readInt();
CV_Assert(luaType == TYPE_TABLE && readedIndexes.count(index) == 0);
readedIndexes.insert(index);
long fpos;
int numPairs = readInt();
for (int i = 0; i < numPairs; i++)
{
fpos = THFile_position(file);
int ktype = readInt();
if (ktype != TYPE_STRING) //skip non-string fileds
{
THFile_seek(file, fpos);
readObject();
readObject();
continue;
}
String key = readString();
fpos = THFile_position(file);
int vtype = readInt();
if (vtype == TYPE_TORCH)
{
int index = readInt();
if (tensors.count(index) == 0)
{
readTorchObject(index);
}
if (tensors.count(index))
tensorParams.insert(std::make_pair(key, tensors[index]));
}
else if (vtype == TYPE_NUMBER)
{
scalarParams.set(key, readDouble());
}
else if (vtype == TYPE_STRING)
{
scalarParams.set(key, readString());
}
else if (vtype == TYPE_BOOLEAN)
{
scalarParams.set(key, readBool());
}
else
{
THFile_seek(file, fpos);
readObject();
continue;
}
}
}
void readTorchTensor(int indexTensor, int typeTensor)
{
if (tensors.count(indexTensor))
return;
int ndims = readInt();
AutoBuffer<long, 4> sizes(ndims); AutoBuffer<long, 4> sizes(ndims);
AutoBuffer<long, 4> steps(ndims); AutoBuffer<long, 4> steps(ndims);
THFile_readLongRaw(file, sizes, ndims); THFile_readLongRaw(file, sizes, ndims);
THFile_readLongRaw(file, sizes, ndims); THFile_readLongRaw(file, steps, ndims);
long offset = readLong() - 1; long offset = readLong() - 1;
//read Storage //read Storage
int typeidx = readInt(); int typeidx = readInt();
std::cout << "stograge typeidx of tensor: " << typeidx << "\n";
CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0)); CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0));
if (typeidx == TYPE_NIL) if (typeidx == TYPE_NIL)
return Blob(); {
tensors.insert(std::make_pair(indexTensor, Blob()));
return;
}
int index = readInt(); int indexStorage = readInt();
if (readedIndexes.count(index) == 0) if (readedIndexes.count(indexStorage) == 0)
{ {
int typeStorage = parseStorageType(readTorchClassName()); int typeStorage = parseStorageType(readTorchClassName());
CV_Assert(typeStorage >= 0 && typeTensor == typeStorage); CV_Assert(typeStorage >= 0 && typeTensor == typeStorage);
readTorchStorage(typeStorage, index); readTorchStorage(indexStorage, typeStorage);
} }
//allocate Blob //small check
size_t requireElems = (size_t)offset + (size_t)steps[0] * (size_t)sizes[0];
size_t storageElems = storages[indexStorage].total();
if (requireElems > storageElems)
CV_Error(Error::StsBadSize, "Storage has insufficent number of elemements for requested Tensor");
//convert sizes
AutoBuffer<int, 4> isizes(ndims); AutoBuffer<int, 4> isizes(ndims);
AutoBuffer<size_t, 4> ssteps(ndims); AutoBuffer<size_t, 4> ssteps(ndims);
size_t stepExpected = 1;
for (int i = ndims - 1; i >= 0; i--) for (int i = ndims - 1; i >= 0; i--)
{ {
isizes[i] = (int)sizes[i]; isizes[i] = (int)sizes[i];
ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor); ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor);
stepExpected *= sizes[i];
} }
if (skip) //allocate Blob
return Blob(); Mat srcMat(ndims, (int*)isizes, typeTensor , storages[indexStorage].ptr() + offset, (size_t*)ssteps);
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps);
int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F; int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F;
Blob blob; Blob blob;
blob.create(BlobShape(ndims, isizes), dstType); blob.create(BlobShape(ndims, isizes), dstType);
srcMat.convertTo(blob.getMatRef(), dstType); srcMat.convertTo(blob.getMatRef(), dstType);
return blob; tensors.insert(std::make_pair(indexTensor, blob));
}
bool isNNClass(const String &className, String &nnName)
{
const char *prefixes[] = {"nn.", "cunn.", "cudnn.", "fbcunn.", NULL};
for (int i = 0; prefixes[i]; i++)
{
if (startsWith(className, prefixes[i]))
{
nnName = className.substr(strlen(prefixes[i]));
return true;
}
}
return false;
} }
void readTorchObject(int index, bool skip = false) void readTorchObject(int index, bool skip = false)
{ {
String className = readTorchClassName(); String className = readTorchClassName();
String nnName;
std::cout << "Class: " << className << std::endl; std::cout << "Class: " << className << std::endl;
Dict scalarParams;
std::map<String, Blob> tensorParams;
int type; int type;
if ( (type = parseTensorType(className)) >= 0 ) //is Tensor if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
{ {
readTorchTensor(type); readTorchTensor(index, type);
return;
} }
else if ( (type = parseStorageType(className)) >= 0 ) //is Storage else if ( (type = parseStorageType(className)) >= 0 ) //is Storage
{ {
readTorchStorage(index, type); readTorchStorage(index, type);
} }
else if (className == "nn.Sequential") else if (isNNClass(className, nnName))
{
readObject();
}
else if (className == "nn.Concat")
{ {
readObject(); CV_Assert(!readedIndexes.count(index));
} readTorchTable(scalarParams, tensorParams);
else if (className == "nn.SpatialConvolution")
{ std::cout << scalarParams;
readObject(); for (std::map<String,Blob>::const_iterator it = tensorParams.begin(); it != tensorParams.end(); it++)
} std::cout << it->first << ": Tensor" << "\n";
else if (className == "nn.ReLU")
{ if (nnName == "Sequential")
readObject(); {
}
else if (nnName == "Concat")
{
}
else if (nnName == "SpatialConvolution")
{
}
else if (nnName == "ReLU")
{
}
else
{
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
}
} }
else else
{ {
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\""); CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\"");
} }
} }
void readObject(bool skip = false) void readObject()
{ {
int typeidx = readInt(); int typeidx = readInt();
std::cout << "typeidx: " << typeidx << "\n"; std::cout << "typeidx: " << typeidx << "\n";
@ -267,23 +385,20 @@ struct TorchImporter : public ::cv::dnn::Importer
if (readedIndexes.count(index) == 0) if (readedIndexes.count(index) == 0)
{ {
readTorchObject(index, skip); readTorchObject(index);
readedIndexes.insert(index); readedIndexes.insert(index);
} }
else
{
//CV_Error(Error::StsNotImplemented, "");
//TBD
}
} }
else if (typeidx == TYPE_NIL) else if (typeidx == TYPE_NIL)
return; return;
else if (typeidx == TYPE_NUMBER) else if (typeidx == TYPE_NUMBER)
readDouble(); //readDouble();
std::cout << readDouble() << std::endl;
else if (typeidx == TYPE_BOOLEAN) else if (typeidx == TYPE_BOOLEAN)
readBool(); readBool();
else if (typeidx == TYPE_STRING) else if (typeidx == TYPE_STRING)
readString(); //readString();
std::cout << readString() << std::endl;
else if (typeidx == TYPE_TABLE) else if (typeidx == TYPE_TABLE)
readTable(); readTable();
else else
@ -294,6 +409,7 @@ struct TorchImporter : public ::cv::dnn::Importer
{ {
THFile_seek(file, 0); THFile_seek(file, 0);
readedIndexes.clear(); readedIndexes.clear();
storages.clear();
readObject(); readObject();
} }

@ -25,7 +25,8 @@ TEST(Torch_Importer, simple_read)
Net net; Net net;
Ptr<Importer> importer; Ptr<Importer> importer;
ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) ); //ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
ASSERT_TRUE( importer != NULL ); ASSERT_TRUE( importer != NULL );
importer->populateNet(net); importer->populateNet(net);
//ASSERT_NO_THROW( importer->populateNet(net) ); //ASSERT_NO_THROW( importer->populateNet(net) );

Loading…
Cancel
Save