Support for multiple layer types added to the Torch importer.

DictValue was also refactored.
pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 7d795af6a0
commit 2905c03581
1. modules/dnn/include/opencv2/dnn/dict.hpp (37 changes)
2. modules/dnn/src/torch/torch_importer.cpp (225 changes)
3. modules/dnn/test/test_torch_importer.cpp (6 changes)

modules/dnn/include/opencv2/dnn/dict.hpp

@@ -152,9 +152,28 @@ inline DictValue DictValue::get<DictValue>(int idx) const
 template<>
 inline int64 DictValue::get<int64>(int idx) const
 {
-    CV_Assert(isInt());
-    CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
-    return (*pi)[(idx == -1) ? 0 : idx];
+    CV_Assert(idx == -1 && size() == 1 || idx >= 0 && idx < size());
+    idx = (idx == -1) ? 0 : idx;
+
+    if (type == Param::INT)
+    {
+        return (*pi)[idx];
+    }
+    else if (type == Param::REAL)
+    {
+        double doubleValue = (*pd)[idx];
+
+        double fracpart, intpart;
+        fracpart = std::modf(doubleValue, &intpart);
+        CV_Assert(fracpart == 0.0);
+
+        return doubleValue;
+    }
+    else
+    {
+        CV_Assert(isInt() || isReal());
+        return 0;
+    }
 }
 
 template<>
@@ -178,19 +197,20 @@ inline bool DictValue::get<bool>(int idx) const
 template<>
 inline double DictValue::get<double>(int idx) const
 {
+    CV_Assert(idx == -1 && size() == 1 || idx >= 0 && idx < size());
+    idx = (idx == -1) ? 0 : idx;
+
     if (type == Param::REAL)
     {
-        CV_Assert(idx == -1 && pd->size() == 1 || idx >= 0 && idx < (int)pd->size());
-        return (*pd)[0];
+        return (*pd)[idx];
     }
     else if (type == Param::INT)
     {
-        CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
-        return (double)(*pi)[0];;
+        return (double)(*pi)[idx];
     }
     else
     {
-        CV_Assert(isReal());
+        CV_Assert(isReal() || isInt());
         return 0;
     }
 }
@@ -300,6 +320,7 @@ inline int DictValue::size() const
         return (int)pd->size();
         break;
     default:
+        CV_Error(Error::StsInternal, "");
         return -1;
     }
 }
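
The refactored accessors above let integer getters tolerate whole-valued reals. A minimal sketch of the new behavior, assuming the dict.hpp API shown in the diff; the surrounding test scaffolding is hypothetical:

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        using namespace cv::dnn;

        // Whole-valued reals, as they arrive from Torch tables (stored as doubles).
        double vals[] = {2.0, 5.0};
        DictValue dv = DictValue::arrayReal(vals, 2);

        // get<int64>() now accepts REAL-typed entries whose fractional part is
        // zero (the std::modf check in the diff); non-integral reals assert.
        std::cout << dv.get<int64>(0) + dv.get<int64>(1) << std::endl; // prints 7
        return 0;
    }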

modules/dnn/src/torch/torch_importer.cpp

@@ -44,11 +44,27 @@ static inline bool endsWith(const String &str, const char *substr)
 struct TorchImporter : public ::cv::dnn::Importer
 {
+    Net net;
+
     THFile *file;
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
     std::map<int, Blob> tensors;
 
+    struct Module
+    {
+        String thName, type;
+        dnn::LayerParams params;
+        std::vector<Module*> modules;
+
+        Module(const String &_thName, const String &_type = String())
+            : thName(_thName), type(_type) {}
+    };
+
+    Module *rootModule;
+    Module *curModule;
+    int moduleCounter;
+
     TorchImporter(String filename, bool isBinary)
     {
         file = THDiskFile_new(filename.c_str(), "r", 0);
@@ -139,6 +155,8 @@ struct TorchImporter : public ::cv::dnn::Importer
             return CV_16S;
         else if (typeStr == "Int")
             return CV_32S;
+        else if (typeStr == "Long") //Careful: CV_64S type is coded as CV_USRTYPE1
+            return CV_USRTYPE1;
         else
             CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\"");
 }
@@ -162,7 +180,7 @@ struct TorchImporter : public ::cv::dnn::Importer
             return;
 
         long size = readLong();
-        Mat storageMat(1, size, type);
+        Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat
 
         switch (type)
        {
@@ -182,6 +200,16 @@ struct TorchImporter : public ::cv::dnn::Importer
             break;
         case CV_32S:
             THFile_readIntRaw(file, (int*)storageMat.data, size);
+            break;
+        case CV_USRTYPE1:
+        {
+            double *buf = storageMat.ptr<double>();
+            THFile_readLongRaw(file, (long*)buf, size);
+
+            for (size_t i = 0; i < (size_t)size; i++)
+                buf[i] = ((long*)buf)[i];
+        }
+            break;
         default:
             CV_Error(Error::StsInternal, "");
             break;
@@ -210,13 +238,14 @@ struct TorchImporter : public ::cv::dnn::Importer
             if (ktype != TYPE_STRING) //skip non-string fields
             {
                 THFile_seek(file, fpos);
-                readObject();
-                readObject();
+                readObject(); //key
+                readObject(); //value
                 continue;
             }
 
             String key = readString();
+            std::cout << "key: " << key << "\n";
             fpos = THFile_position(file);
             int vtype = readInt();
@@ -227,9 +256,20 @@ struct TorchImporter : public ::cv::dnn::Importer
                 {
                     readTorchObject(index);
                 }
-                if (tensors.count(index))
+
+                if (tensors.count(index)) //tensor was read
+                {
                     tensorParams.insert(std::make_pair(key, tensors[index]));
+                }
+                else if (storages.count(index)) //storage was read
+                {
+                    Mat &matStorage = storages[index];
+                    Mat matCasted;
+                    matStorage.convertTo(matCasted, CV_64F);
+
+                    DictValue scalar = DictValue::arrayReal(matCasted.ptr<double>(), matCasted.total());
+                    scalarParams.set(key, scalar);
+                }
             }
             else if (vtype == TYPE_NUMBER)
             {
@@ -250,6 +290,15 @@ struct TorchImporter : public ::cv::dnn::Importer
                 continue;
             }
         }
+
+        //Debug output
+        std::cout << "scalarParams:\n";
+        std::cout << scalarParams;
+        std::cout << "#" << tensorParams.size() << " tensorParams:\n";
+
+        std::map<String,Blob>::const_iterator it;
+        for (it = tensorParams.begin(); it != tensorParams.end(); it++)
+            std::cout << it->first << ": Tensor " << it->second.shape() << "\n";
     }
 
     void readTorchTensor(int indexTensor, int typeTensor)
@@ -324,15 +373,22 @@ struct TorchImporter : public ::cv::dnn::Importer
         return false;
     }
 
-    void readTorchObject(int index, bool skip = false)
+    void convertTorchKernelsParams(const Dict &torchParams, cv::dnn::LayerParams &layerParams)
+    {
+        layerParams.set("kernel_h", torchParams.get<int>("kH"));
+        layerParams.set("kernel_w", torchParams.get<int>("kW"));
+        layerParams.set("stride_h", torchParams.get<int>("dH"));
+        layerParams.set("stride_w", torchParams.get<int>("dW"));
+        layerParams.set("pad_h", torchParams.get<int>("padH", 0));
+        layerParams.set("pad_w", torchParams.get<int>("padW", 0));
+    }
+
+    void readTorchObject(int index)
     {
         String className = readTorchClassName();
         String nnName;
         std::cout << "Class: " << className << std::endl;
 
-        Dict scalarParams;
-        std::map<String, Blob> tensorParams;
-
         int type;
         if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
         {
@@ -345,26 +401,97 @@ struct TorchImporter : public ::cv::dnn::Importer
         else if (isNNClass(className, nnName))
         {
             CV_Assert(!readedIndexes.count(index));
-            readTorchTable(scalarParams, tensorParams);
 
-            std::cout << scalarParams;
-            for (std::map<String,Blob>::const_iterator it = tensorParams.begin(); it != tensorParams.end(); it++)
-                std::cout << it->first << ": Tensor" << "\n";
+            Dict scalarParams;
+            std::map<String, Blob> tensorParams;
+
+            Module *newModule = new Module(nnName);
+            cv::dnn::LayerParams &layerParams = newModule->params;
 
-            if (nnName == "Sequential")
+            if (nnName == "Sequential" || nnName == "Parallel" || nnName == "Concat")
             {
+                Module *parentModule = curModule;
+                curModule->modules.push_back(newModule);
+                curModule = newModule;
+                readTorchTable(scalarParams, tensorParams);
+                curModule = parentModule;
+            }
+            else if (nnName == "SpatialConvolution")
+            {
+                newModule->type = "Convolution";
+                readTorchTable(scalarParams, tensorParams);
+
+                CV_Assert(tensorParams.count("weight"));
+                layerParams.learnedBlobs.push_back(tensorParams["weight"]);
+
+                bool bias = tensorParams.count("bias");
+                layerParams.set("bias_term", bias);
+                if (bias)
+                    layerParams.learnedBlobs.push_back(tensorParams["bias"]);
+
+                layerParams.set("num_output", scalarParams.get<int>("nOutputPlane"));
+                convertTorchKernelsParams(scalarParams, layerParams);
+
+                curModule->modules.push_back(newModule);
             }
-            else if (nnName == "Concat")
+            else if (nnName == "SpatialMaxPooling" || nnName == "SpatialAveragePooling")
             {
+                newModule->type = "Pooling";
+                readTorchTable(scalarParams, tensorParams);
+
+                if (nnName == "SpatialMaxPooling")
+                    layerParams.set("pool", "MAX");
+                if (nnName == "SpatialAveragePooling")
+                    layerParams.set("pool", "AVE");
+
+                convertTorchKernelsParams(scalarParams, layerParams);
+                curModule->modules.push_back(newModule);
             }
-            else if (nnName == "SpatialConvolution")
+            else if (nnName == "Linear")
+            {
+                newModule->type = "InnerProduct";
+                readTorchTable(scalarParams, tensorParams);
+
+                CV_Assert(tensorParams.count("weight"));
+                Blob weightBlob = tensorParams["weight"];
+                layerParams.learnedBlobs.push_back(weightBlob);
+
+                bool bias = tensorParams.count("bias");
+                if (bias)
+                    layerParams.learnedBlobs.push_back(tensorParams["bias"]);
+                layerParams.set("bias_term", bias);
+
+                //TODO: axis detect
+                layerParams.set("num_output", weightBlob.size(1));
+                curModule->modules.push_back(newModule);
+            }
+            else if (nnName == "Reshape")
             {
+                newModule->type = "Reshape";
+                readTorchTable(scalarParams, tensorParams);
+
+                CV_Assert(scalarParams.has("size"));
+                DictValue dimParam = scalarParams.get("size");
+                layerParams.set("dim", dimParam);
+
+                curModule->modules.push_back(newModule);
             }
             else if (nnName == "ReLU")
             {
+                curModule->modules.push_back(new Module(nnName, "ReLU"));
+            }
+            else if (nnName == "Tanh")
+            {
+                curModule->modules.push_back(new Module(nnName, "TanH"));
+            }
+            else if (nnName == "Sigmoid")
+            {
+                curModule->modules.push_back(new Module(nnName, "Sigmoid"));
             }
             else
             {
+                readTorchTable(scalarParams, tensorParams);
                 CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
             }
         }
@@ -392,26 +519,82 @@ struct TorchImporter : public ::cv::dnn::Importer
         else if (typeidx == TYPE_NIL)
             return;
         else if (typeidx == TYPE_NUMBER)
-            //readDouble();
-            std::cout << readDouble() << std::endl;
+            readDouble();
         else if (typeidx == TYPE_BOOLEAN)
             readBool();
         else if (typeidx == TYPE_STRING)
-            //readString();
-            std::cout << readString() << std::endl;
+            readString();
         else if (typeidx == TYPE_TABLE)
             readTable();
         else
             CV_Error(Error::StsNotImplemented, "Unsupported Lua type");
     }
 
+    inline String generateLayerName(const String &label = String())
+    {
+        this->moduleCounter++;
+        return "l" + toString(this->moduleCounter) + "_" + label;
+    }
+
+    int fill(Module *module, int prevLayerId = 0, int prevOutNum = 0)
+    {
+        if (module == NULL)
+            return prevLayerId;
+
+        if (module->type.length())
+        {
+            int newLayerId = this->net.addLayer(generateLayerName(module->type), module->type, module->params);
+            net.connect(prevLayerId, prevOutNum, newLayerId, 0);
+            std::cout << "added " << module->thName << " i.e. " << module->type << "\n";
+            return newLayerId;
+        }
+        else
+        {
+            if (module->thName == "Sequential")
+            {
+                for (size_t i = 0; i < module->modules.size(); i++)
+                {
+                    prevLayerId = fill(module->modules[i], prevLayerId, prevOutNum);
+                    prevOutNum = 0;
+                }
+                return prevLayerId;
+            }
+            else if (module->thName == "Parallel" || module->thName == "Concat")
+            {
+                int splitId, mergeId, newId;
+
+                String splitType = (module->thName == "Parallel") ? "Slice" : "Split";
+                splitId = net.addLayer(generateLayerName("torchSplit"), splitType, module->params);
+                net.connect(prevLayerId, prevOutNum, splitId, 0);
+
+                mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", module->params);
+
+                for (size_t i = 0; i < module->modules.size(); i++)
+                {
+                    newId = fill(module->modules[i], splitId, (int)i);
+                    net.connect(newId, 0, mergeId, (int)i);
+                }
+
+                return mergeId;
+            }
+        }
+
+        CV_Error(Error::StsInternal, "Unexpected torch container: " + module->thName);
+        return -1;
+    }
+
     void populateNet(Net net)
     {
+        this->net = net;
+
         THFile_seek(file, 0);
         readedIndexes.clear();
         storages.clear();
 
+        rootModule = new Module("Sequential");
+        curModule = rootModule;
         readObject();
+
+        moduleCounter = 0;
+        fill(rootModule);
     }
 };
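
For reference, the graph that fill() builds can be reproduced with the same public Net calls it uses. A hand-built sketch of the layers generated for a tiny Torch net Sequential { SpatialConvolution, ReLU }, assuming the addLayer()/connect() signatures seen above; the layer names and empty LayerParams are illustrative:

    #include <opencv2/dnn.hpp>
    using namespace cv::dnn;

    int main()
    {
        Net net;
        LayerParams convParams, reluParams; // would be filled in by readTorchObject()

        // fill() walks the Module tree depth-first, starting from layer id 0,
        // and names layers "l<counter>_<label>" via generateLayerName().
        int convId = net.addLayer("l1_Convolution", "Convolution", convParams);
        net.connect(0, 0, convId, 0);

        int reluId = net.addLayer("l2_ReLU", "ReLU", reluParams);
        net.connect(convId, 0, reluId, 0); // Sequential chains children in order
        return 0;
    }

For Parallel/Concat containers, fill() instead brackets the branches between a generated "torchSplit" layer (of type "Slice" or "Split") and a "torchMerge" layer of type "Concat", wiring branch i to output/input pin i.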

modules/dnn/test/test_torch_importer.cpp

@@ -25,9 +25,11 @@ TEST(Torch_Importer, simple_read)
     Net net;
     Ptr<Importer> importer;
-    //ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
-    ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
+    ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
+    //ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
     ASSERT_TRUE( importer != NULL );
 
     importer->populateNet(net);
     //ASSERT_NO_THROW( importer->populateNet(net) );
 }
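
The updated test exercises the whole pipeline end to end. The same flow outside of gtest, as a sketch (the model path is a placeholder; conv1.txt is the author's local fixture):

    #include <opencv2/dnn.hpp>
    using namespace cv::dnn;

    int main()
    {
        Net net;
        // isBinary = false: the fixture is a text-serialized Torch object.
        Ptr<Importer> importer = createTorchImporter("conv1.txt", false);
        CV_Assert(importer != NULL);

        // Parses the file into the Module tree (readObject), then flattens
        // it into dnn layers (fill), as implemented above.
        importer->populateNet(net);
        return 0;
    }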
