Fixed Torch parser

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 4da1046d68
commit 7d795af6a0
  1. 40
      modules/dnn/include/opencv2/dnn/dict.hpp
  2. 212
      modules/dnn/src/torch/torch_importer.cpp
  3. 3
      modules/dnn/test/test_torch_importer.cpp

@ -3,6 +3,7 @@
#include <opencv2/core.hpp>
#include <map>
#include <ostream>
namespace cv
{
@ -35,6 +36,8 @@ struct DictValue
DictValue &operator=(const DictValue &r);
friend std::ostream &operator<<(std::ostream &stream, const DictValue &dictv);
~DictValue();
protected:
@ -135,6 +138,8 @@ public:
return value;
}
friend std::ostream &operator<<(std::ostream &stream, const Dict &dict);
};
template<>
@ -299,6 +304,41 @@ inline int DictValue::size() const
}
}
// Streams a human-readable rendering of a DictValue: elements are printed
// comma-separated; string elements are double-quoted.  An empty value
// (size() == 0) prints nothing — previously the trailing element was read
// unconditionally, indexing past the end of an empty value.
inline std::ostream &operator<<(std::ostream &stream, const DictValue &dictv)
{
    int i;

    if (dictv.isInt())
    {
        for (i = 0; i < dictv.size() - 1; i++)
            stream << dictv.get<int64>(i) << ", ";
        if (dictv.size() > 0)
            stream << dictv.get<int64>(i);
    }
    else if (dictv.isReal())
    {
        for (i = 0; i < dictv.size() - 1; i++)
            stream << dictv.get<double>(i) << ", ";
        if (dictv.size() > 0)
            stream << dictv.get<double>(i);
    }
    else if (dictv.isString())
    {
        for (i = 0; i < dictv.size() - 1; i++)
            stream << "\"" << dictv.get<String>(i) << "\", ";
        // Bug fix: the last string element was emitted without the quotes the
        // loop above puts around every preceding element.
        if (dictv.size() > 0)
            stream << "\"" << dictv.get<String>(i) << "\"";
    }
    return stream;
}
inline std::ostream &operator<<(std::ostream &stream, const Dict &dict)
{
Dict::_Dict::const_iterator it;
for (it = dict.dict.begin(); it != dict.dict.end(); it++)
stream << it->first << " : " << it->second << "\n";
return stream;
}
}
}

@ -42,12 +42,12 @@ static inline bool endsWith(const String &str, const char *substr)
return str.rfind(substr) == str.length() - strlen(substr);
}
struct TorchImporter : public ::cv::dnn::Importer
{
THFile *file;
std::set<int> readedIndexes;
std::map<int, Mat> storages;
std::map<int, Blob> tensors;
TorchImporter(String filename, bool isBinary)
{
@ -99,22 +99,23 @@ struct TorchImporter : public ::cv::dnn::Importer
// Reads a serialized Lua function: its source/bytecode string followed by its
// upvalue object.
// NOTE(review): this span is a rendered diff — both the removed call
// `readObject(true)` and its replacement `readObject()` appear below; only the
// parameterless form exists in the final revision.  Confirm against the
// repository before reusing this text verbatim.
inline void readFunction()
{
readString();
readObject(true);
readObject();
}
// Skips over a serialized Lua table: reads (and ignores) each of its
// `size` key/value object pairs.  The table's serialization index is recorded
// in readedIndexes so a later back-reference to it is not re-read.
// NOTE(review): diff residue — the old no-argument signature and the new
// `int index = -1` signature both appear; in the new form a negative index
// means "read the index from the stream", and an already-seen index returns
// immediately instead of asserting.
void readTable()
void readTable(int index = -1)
{
std::cout << "Skipping table\n";
index = (index < 0) ? readInt() : index;
if (readedIndexes.count(index))
return;
int index = readInt();
CV_Assert(readedIndexes.count(index) == 0);
readedIndexes.insert(index);
int size = readInt();
for (int i = 0; i < size; i++)
{
readObject(true); //key
readObject(true); //value
readObject(); //key
readObject(); //value
}
}
@ -157,106 +158,223 @@ struct TorchImporter : public ::cv::dnn::Importer
// Reads a Torch Storage object of `size` elements whose OpenCV depth is
// `type`, into a 1 x size Mat cached in `storages` under `index`.
// Already-read storages are skipped (back-reference).
// NOTE(review): diff residue — the old raw-byte read (`THFile_readByteRaw`
// with `size * CV_ELEM_SIZE(type)`) appears alongside its replacement, the
// per-type switch below; only the switch exists in the final revision.
void readTorchStorage(int index, int type = -1)
{
if (storages.count(index))
return;
long size = readLong();
Mat storageMat(1, size, type);
THFile_readByteRaw(file, storageMat.data, size * CV_ELEM_SIZE(type));
switch (type)
{
case CV_32F:
THFile_readFloatRaw(file, (float*)storageMat.data, size);
break;
case CV_64F:
THFile_readDoubleRaw(file, (double*)storageMat.data, size);
break;
case CV_8S:
case CV_8U:
THFile_readByteRaw(file, (uchar*)storageMat.data, size);
break;
case CV_16S:
case CV_16U:
THFile_readShortRaw(file, (short*)storageMat.data, size);
break;
case CV_32S:
THFile_readIntRaw(file, (int*)storageMat.data, size);
// NOTE(review): missing `break` here — a successfully read CV_32S storage
// falls through into the CV_Error below.  Looks like a genuine bug; confirm.
default:
CV_Error(Error::StsInternal, "");
break;
}
// Cache the storage and mark its serialization index as consumed.
storages.insert(std::make_pair(index, storageMat));
readedIndexes.insert(index);
}
Blob readTorchTensor(int typeTensor, bool skip = false)
// Reads a serialized Lua table of parameters for an nn module.  String-keyed
// entries are captured: TYPE_TORCH values become entries in `tensorParams`
// (resolved through the `tensors` cache), TYPE_NUMBER / TYPE_STRING /
// TYPE_BOOLEAN values go into `scalarParams`.  Any pair with a non-string key,
// or a value of another type, is skipped by seeking back and consuming it with
// readObject().
// NOTE(review): diff residue — the stray `int ndims = readInt();` line below
// belongs to the removed readTorchTensor signature above this hunk, not to
// this function.
void readTorchTable(Dict &scalarParams, std::map<String, Blob> &tensorParams)
{
int ndims = readInt();
int luaType = readInt();
int index = readInt();
CV_Assert(luaType == TYPE_TABLE && readedIndexes.count(index) == 0);
readedIndexes.insert(index);
long fpos;
int numPairs = readInt();
for (int i = 0; i < numPairs; i++)
{
// Remember the stream position so a non-string key can be re-read as a
// generic object and skipped.
fpos = THFile_position(file);
int ktype = readInt();
if (ktype != TYPE_STRING) //skip non-string fileds
{
THFile_seek(file, fpos);
readObject();
readObject();
continue;
}
String key = readString();
fpos = THFile_position(file);
int vtype = readInt();
if (vtype == TYPE_TORCH)
{
// Torch object value: read it if not cached yet, then attach the
// resulting tensor (if any) under this key.
int index = readInt();
if (tensors.count(index) == 0)
{
readTorchObject(index);
}
if (tensors.count(index))
tensorParams.insert(std::make_pair(key, tensors[index]));
}
else if (vtype == TYPE_NUMBER)
{
scalarParams.set(key, readDouble());
}
else if (vtype == TYPE_STRING)
{
scalarParams.set(key, readString());
}
else if (vtype == TYPE_BOOLEAN)
{
scalarParams.set(key, readBool());
}
else
{
// Unsupported value type: rewind and skip the whole value object.
THFile_seek(file, fpos);
readObject();
continue;
}
}
}
// Reads a Torch Tensor header (ndims, sizes, strides, 1-based storage offset)
// and its backing Storage, then wraps the storage data in a Mat view and
// converts it into a Blob cached in `tensors` under `indexTensor`.
// An ndims==0 tensor with a NIL storage is cached as an empty Blob.
// NOTE(review): diff residue throughout — the duplicated
// `THFile_readLongRaw(file, sizes, ndims)` line, the `return Blob()`
// statements, the `skip` block, and the first `srcMat` construction (without
// `+ offset`) are removed old lines; their replacements appear adjacent.
void readTorchTensor(int indexTensor, int typeTensor)
{
if (tensors.count(indexTensor))
return;
int ndims = readInt();
AutoBuffer<long, 4> sizes(ndims);
AutoBuffer<long, 4> steps(ndims);
THFile_readLongRaw(file, sizes, ndims);
THFile_readLongRaw(file, sizes, ndims);
THFile_readLongRaw(file, steps, ndims);
// Torch offsets are 1-based; convert to 0-based.
long offset = readLong() - 1;
//read Storage
int typeidx = readInt();
std::cout << "stograge typeidx of tensor: " << typeidx << "\n";
CV_Assert(typeidx == TYPE_TORCH || (typeidx == TYPE_NIL && ndims == 0));
if (typeidx == TYPE_NIL)
return Blob();
{
tensors.insert(std::make_pair(indexTensor, Blob()));
return;
}
int index = readInt();
if (readedIndexes.count(index) == 0)
int indexStorage = readInt();
if (readedIndexes.count(indexStorage) == 0)
{
// First reference to this storage: its class name determines the element
// type, which must agree with the tensor's element type.
int typeStorage = parseStorageType(readTorchClassName());
CV_Assert(typeStorage >= 0 && typeTensor == typeStorage);
readTorchStorage(typeStorage, index);
readTorchStorage(indexStorage, typeStorage);
}
//allocate Blob
//small check
size_t requireElems = (size_t)offset + (size_t)steps[0] * (size_t)sizes[0];
size_t storageElems = storages[indexStorage].total();
if (requireElems > storageElems)
// NOTE(review): "insufficent"/"elemements" typos in this user-facing message.
CV_Error(Error::StsBadSize, "Storage has insufficent number of elemements for requested Tensor");
//convert sizes
AutoBuffer<int, 4> isizes(ndims);
AutoBuffer<size_t, 4> ssteps(ndims);
size_t stepExpected = 1;
for (int i = ndims - 1; i >= 0; i--)
{
// Torch strides are in elements; Mat steps are in bytes.
isizes[i] = (int)sizes[i];
ssteps[i] = (size_t)steps[i] * CV_ELEM_SIZE(typeTensor);
stepExpected *= sizes[i];
}
if (skip)
return Blob();
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[index].ptr(), (size_t*)ssteps);
//allocate Blob
Mat srcMat(ndims, (int*)isizes, typeTensor , storages[indexStorage].ptr() + offset, (size_t*)ssteps);
// Blobs are stored as 32F unless the source is 64F.
int dstType = (typeTensor == CV_64F) ? CV_64F : CV_32F;
Blob blob;
blob.create(BlobShape(ndims, isizes), dstType);
srcMat.convertTo(blob.getMatRef(), dstType);
return blob;
tensors.insert(std::make_pair(indexTensor, blob));
}
// Tests whether a Torch class name belongs to one of the known NN packages
// ("nn.", "cunn.", "cudnn.", "fbcunn.").  On success, stores the class name
// with the package prefix stripped into `nnName` and returns true; otherwise
// leaves `nnName` untouched and returns false.
bool isNNClass(const String &className, String &nnName)
{
    const char *knownPrefixes[] = {"nn.", "cunn.", "cudnn.", "fbcunn.", NULL};

    for (const char **prefix = knownPrefixes; *prefix != NULL; prefix++)
    {
        if (!startsWith(className, *prefix))
            continue;

        nnName = className.substr(strlen(*prefix));
        return true;
    }

    return false;
}
// Dispatches on a serialized Torch object's class name: Tensor classes go to
// readTorchTensor, Storage classes to readTorchStorage, and nn-package module
// classes have their parameter table read via readTorchTable before being
// handled per layer type.  Unknown classes raise StsNotImplemented.
// NOTE(review): diff residue — the removed `skip` parameter, the old
// `className == "nn.Xxx"` comparisons, and the old error-message line appear
// alongside their `nnName`-based replacements.
void readTorchObject(int index, bool skip = false)
{
String className = readTorchClassName();
String nnName;
std::cout << "Class: " << className << std::endl;
Dict scalarParams;
std::map<String, Blob> tensorParams;
int type;
if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
{
readTorchTensor(type);
return;
readTorchTensor(index, type);
}
else if ( (type = parseStorageType(className)) >= 0 ) //is Storage
{
readTorchStorage(index, type);
}
else if (className == "nn.Sequential")
else if (isNNClass(className, nnName))
{
CV_Assert(!readedIndexes.count(index));
// Every nn module serializes its fields as a table; pull scalars and
// tensors out of it before dispatching on the layer name.
readTorchTable(scalarParams, tensorParams);
std::cout << scalarParams;
for (std::map<String,Blob>::const_iterator it = tensorParams.begin(); it != tensorParams.end(); it++)
std::cout << it->first << ": Tensor" << "\n";
if (nnName == "Sequential")
{
readObject();
}
else if (className == "nn.Concat")
else if (nnName == "Concat")
{
readObject();
}
else if (className == "nn.SpatialConvolution")
else if (nnName == "SpatialConvolution")
{
readObject();
}
else if (className == "nn.ReLU")
else if (nnName == "ReLU")
{
readObject();
}
else
{
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className +"\"");
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
}
}
else
{
CV_Error(Error::StsNotImplemented, "Unsupported Torch class \"" + className + "\"");
}
}
void readObject(bool skip = false)
void readObject()
{
int typeidx = readInt();
std::cout << "typeidx: " << typeidx << "\n";
@ -267,23 +385,20 @@ struct TorchImporter : public ::cv::dnn::Importer
if (readedIndexes.count(index) == 0)
{
readTorchObject(index, skip);
readTorchObject(index);
readedIndexes.insert(index);
}
else
{
//CV_Error(Error::StsNotImplemented, "");
//TBD
}
}
else if (typeidx == TYPE_NIL)
return;
else if (typeidx == TYPE_NUMBER)
readDouble();
//readDouble();
std::cout << readDouble() << std::endl;
else if (typeidx == TYPE_BOOLEAN)
readBool();
else if (typeidx == TYPE_STRING)
readString();
//readString();
std::cout << readString() << std::endl;
else if (typeidx == TYPE_TABLE)
readTable();
else
@ -294,6 +409,7 @@ struct TorchImporter : public ::cv::dnn::Importer
{
THFile_seek(file, 0);
readedIndexes.clear();
storages.clear();
readObject();
}

@ -25,7 +25,8 @@ TEST(Torch_Importer, simple_read)
Net net;
Ptr<Importer> importer;
ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
//ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
ASSERT_TRUE( importer != NULL );
importer->populateNet(net);
//ASSERT_NO_THROW( importer->populateNet(net) );

Loading…
Cancel
Save