Added support for multiple layers to the Torch importer.

Also refactored DictValue.
Branch: pull/265/head
Author: Vitaliy Lyudvichenko, 10 years ago
parent 7d795af6a0
commit 2905c03581
Changed files:
1. modules/dnn/include/opencv2/dnn/dict.hpp (37 changed lines)
2. modules/dnn/src/torch/torch_importer.cpp (221 changed lines)
3. modules/dnn/test/test_torch_importer.cpp (6 changed lines)
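
As the commit message says, the importer can now translate multi-layer Torch models into a dnn::Net. A minimal usage sketch, assuming the umbrella header name and using createTorchImporter/populateNet as they appear in the test file below; the model path is a placeholder:

#include <opencv2/dnn.hpp>   // assumed header for the dnn module

using namespace cv;
using namespace cv::dnn;

int main()
{
    Net net;
    // Parse a serialized Torch model; "model.txt" is a placeholder path and
    // false selects the ASCII (non-binary) format, as in the test below.
    Ptr<Importer> importer = createTorchImporter("model.txt", false);
    if (importer)
        importer->populateNet(net);   // builds the layer graph (via fill() in this commit)
    return 0;
}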

--- a/modules/dnn/include/opencv2/dnn/dict.hpp
+++ b/modules/dnn/include/opencv2/dnn/dict.hpp
@@ -152,9 +152,28 @@ inline DictValue DictValue::get<DictValue>(int idx) const
 template<>
 inline int64 DictValue::get<int64>(int idx) const
 {
-    CV_Assert(isInt());
-    CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
-    return (*pi)[(idx == -1) ? 0 : idx];
+    CV_Assert(idx == -1 && size() == 1 || idx >= 0 && idx < size());
+    idx = (idx == -1) ? 0 : idx;
+
+    if (type == Param::INT)
+    {
+        return (*pi)[idx];
+    }
+    else if (type == Param::REAL)
+    {
+        double doubleValue = (*pd)[idx];
+
+        double fracpart, intpart;
+        fracpart = std::modf(doubleValue, &intpart);
+        CV_Assert(fracpart == 0.0);
+
+        return doubleValue;
+    }
+    else
+    {
+        CV_Assert(isInt() || isReal());
+        return 0;
+    }
 }
 
 template<>
@@ -178,19 +197,20 @@ inline bool DictValue::get<bool>(int idx) const
 template<>
 inline double DictValue::get<double>(int idx) const
 {
+    CV_Assert(idx == -1 && size() == 1 || idx >= 0 && idx < size());
+    idx = (idx == -1) ? 0 : idx;
+
     if (type == Param::REAL)
     {
-        CV_Assert(idx == -1 && pd->size() == 1 || idx >= 0 && idx < (int)pd->size());
-        return (*pd)[0];
+        return (*pd)[idx];
     }
     else if (type == Param::INT)
     {
-        CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
-        return (double)(*pi)[0];;
+        return (double)(*pi)[idx];
     }
     else
     {
-        CV_Assert(isReal());
+        CV_Assert(isReal() || isInt());
         return 0;
     }
 }
@@ -300,6 +320,7 @@ inline int DictValue::size() const
         return (int)pd->size();
         break;
     default:
+        CV_Error(Error::StsInternal, "");
         return -1;
     }
 }
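
The DictValue changes above do two things: get<int64> now also accepts REAL-typed values as long as they have no fractional part, and both accessors index with idx instead of always reading element 0. A minimal sketch of the new behavior, assuming the header location from the file list and the DictValue::arrayReal factory used later in the importer; names and values are illustrative:

#include <opencv2/dnn/dict.hpp>   // assumed header location, per the file list above

using namespace cv;
using namespace cv::dnn;

void dictValueRefactorDemo()
{
    double vals[] = { 3.0, 7.0 };
    DictValue dv = DictValue::arrayReal(vals, 2);   // stored with type == Param::REAL

    int64 a  = dv.get<int64>(0);    // now allowed: 3.0 has no fractional part, yields 3
    double b = dv.get<double>(1);   // indexing fixed: yields 7.0 (the old code always returned element 0)
    // dv.get<int64>() on a value such as 3.5 would fail the CV_Assert(fracpart == 0.0) check
    (void)a; (void)b;
}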

--- a/modules/dnn/src/torch/torch_importer.cpp
+++ b/modules/dnn/src/torch/torch_importer.cpp
@@ -44,11 +44,27 @@ static inline bool endsWith(const String &str, const char *substr)
 struct TorchImporter : public ::cv::dnn::Importer
 {
     Net net;
 
     THFile *file;
     std::set<int> readedIndexes;
     std::map<int, Mat> storages;
     std::map<int, Blob> tensors;
 
+    struct Module
+    {
+        String thName, type;
+        dnn::LayerParams params;
+        std::vector<Module*> modules;
+
+        Module(const String &_thName, const String &_type = String())
+            : thName(_thName), type(_type) {}
+    };
+
+    Module *rootModule;
+    Module *curModule;
+    int moduleCounter;
+
     TorchImporter(String filename, bool isBinary)
     {
         file = THDiskFile_new(filename.c_str(), "r", 0);
@@ -139,6 +155,8 @@ struct TorchImporter : public ::cv::dnn::Importer
             return CV_16S;
         else if (typeStr == "Int")
             return CV_32S;
+        else if (typeStr == "Long") //Carefully! CV_64S type coded as CV_USRTYPE1
+            return CV_USRTYPE1;
         else
             CV_Error(Error::StsNotImplemented, "Unknown type \"" + typeStr + "\" of torch class \"" + str + "\"");
     }
@@ -162,7 +180,7 @@ struct TorchImporter : public ::cv::dnn::Importer
             return;
 
         long size = readLong();
-        Mat storageMat(1, size, type);
+        Mat storageMat(1, size, (type != CV_USRTYPE1) ? type : CV_64F); //handle LongStorage as CV_64F Mat
 
         switch (type)
         {
@@ -182,6 +200,16 @@ struct TorchImporter : public ::cv::dnn::Importer
             break;
         case CV_32S:
             THFile_readIntRaw(file, (int*)storageMat.data, size);
             break;
+        case CV_USRTYPE1:
+        {
+            double *buf = storageMat.ptr<double>();
+            THFile_readLongRaw(file, (long*)buf, size);
+
+            for (size_t i = 0; i < (size_t)size; i++)
+                buf[i] = ((long*)buf)[i];
+        }
+            break;
         default:
             CV_Error(Error::StsInternal, "");
             break;
@@ -210,12 +238,13 @@ struct TorchImporter : public ::cv::dnn::Importer
             if (ktype != TYPE_STRING) //skip non-string fileds
             {
                 THFile_seek(file, fpos);
-                readObject();
-                readObject();
+                readObject(); //key
+                readObject(); //value
                 continue;
             }
 
             String key = readString();
+            std::cout << "key: " << key << "\n";
             fpos = THFile_position(file);
             int vtype = readInt();
@@ -228,8 +257,19 @@ struct TorchImporter : public ::cv::dnn::Importer
                     readTorchObject(index);
                 }
 
-                if (tensors.count(index))
+                if (tensors.count(index)) //tensor was readed
+                {
                     tensorParams.insert(std::make_pair(key, tensors[index]));
+                }
+                else if (storages.count(index)) //storage was readed
+                {
+                    Mat &matStorage = storages[index];
+                    Mat matCasted;
+                    matStorage.convertTo(matCasted, CV_64F);
+
+                    DictValue scalar = DictValue::arrayReal(matCasted.ptr<double>(), matCasted.total());
+                    scalarParams.set(key, scalar);
+                }
             }
             else if (vtype == TYPE_NUMBER)
             {
@@ -250,6 +290,15 @@ struct TorchImporter : public ::cv::dnn::Importer
                 continue;
             }
         }
+
+        //Debug output
+        std::cout << "scalarParams:\n";
+        std::cout << scalarParams;
+
+        std::cout << "#" << tensorParams.size() << "tensorParams:\n";
+        std::map<String,Blob>::const_iterator it;
+        for (it = tensorParams.begin(); it != tensorParams.end(); it++)
+            std::cout << it->first << ": Tensor " << it->second.shape() << "\n";
     }
 
     void readTorchTensor(int indexTensor, int typeTensor)
@@ -324,15 +373,22 @@ struct TorchImporter : public ::cv::dnn::Importer
         return false;
     }
 
-    void readTorchObject(int index, bool skip = false)
+    void convertTorchKernelsParams(const Dict &torchParams, cv::dnn::LayerParams &layerParams)
+    {
+        layerParams.set("kernel_h", torchParams.get<int>("kH"));
+        layerParams.set("kernel_w", torchParams.get<int>("kW"));
+        layerParams.set("stride_h", torchParams.get<int>("dH"));
+        layerParams.set("stride_w", torchParams.get<int>("dW"));
+        layerParams.set("pad_h", torchParams.get<int>("padH", 0));
+        layerParams.set("pad_w", torchParams.get<int>("padW", 0));
+    }
+
+    void readTorchObject(int index)
     {
         String className = readTorchClassName();
         String nnName;
         std::cout << "Class: " << className << std::endl;
 
+        Dict scalarParams;
+        std::map<String, Blob> tensorParams;
+
         int type;
         if ( (type = parseTensorType(className)) >= 0 ) //is Tensor
         {
@@ -345,26 +401,97 @@ struct TorchImporter : public ::cv::dnn::Importer
         else if (isNNClass(className, nnName))
         {
             CV_Assert(!readedIndexes.count(index));
 
-            Dict scalarParams;
-            std::map<String, Blob> tensorParams;
-
-            readTorchTable(scalarParams, tensorParams);
-            std::cout << scalarParams;
-            for (std::map<String,Blob>::const_iterator it = tensorParams.begin(); it != tensorParams.end(); it++)
-                std::cout << it->first << ": Tensor" << "\n";
+            Module *newModule = new Module(nnName);
+            cv::dnn::LayerParams &layerParams = newModule->params;
 
-            if (nnName == "Sequential")
+            if (nnName == "Sequential" || nnName == "Parallel" || nnName == "Concat")
             {
+                Module *parentModule = curModule;
+                curModule->modules.push_back(newModule);
+                curModule = newModule;
+                readTorchTable(scalarParams, tensorParams);
+                curModule = parentModule;
             }
-            else if (nnName == "Concat")
+            else if (nnName == "SpatialConvolution")
             {
+                newModule->type = "Convolution";
+                readTorchTable(scalarParams, tensorParams);
+
+                CV_Assert(tensorParams.count("weight"));
+                layerParams.learnedBlobs.push_back(tensorParams["weight"]);
+
+                bool bias = tensorParams.count("bias");
+                layerParams.set("bias_term", bias);
+                if (bias)
+                    layerParams.learnedBlobs.push_back(tensorParams["bias"]);
+
+                layerParams.set("num_output", scalarParams.get<int>("nOutputPlane"));
+                convertTorchKernelsParams(scalarParams, layerParams);
+
+                curModule->modules.push_back(newModule);
             }
-            else if (nnName == "SpatialConvolution")
+            else if (nnName == "SpatialMaxPooling" || nnName == "SpatialAveragePooling")
             {
+                newModule->type = "Pooling";
+                readTorchTable(scalarParams, tensorParams);
+
+                if (nnName == "SpatialMaxPooling")
+                    layerParams.set("pool", "MAX");
+                if (nnName == "SpatialAveragePooling")
+                    layerParams.set("pool", "AVE");
+
+                convertTorchKernelsParams(scalarParams, layerParams);
+                curModule->modules.push_back(newModule);
+            }
+            else if (nnName == "Linear")
+            {
+                newModule->type = "InnerProduct";
+                readTorchTable(scalarParams, tensorParams);
+
+                CV_Assert(tensorParams.count("weight"));
+                Blob weightBlob = tensorParams["weight"];
+                layerParams.learnedBlobs.push_back(weightBlob);
+
+                bool bias = tensorParams.count("bias");
+                if (bias)
+                    layerParams.learnedBlobs.push_back(tensorParams["bias"]);
+                layerParams.set("bias_term", bias);
+
+                //TODO: axis detect
+                layerParams.set("num_output", weightBlob.size(1));
+                curModule->modules.push_back(newModule);
+            }
+            else if (nnName == "Reshape")
+            {
+                newModule->type = "Reshape";
+                readTorchTable(scalarParams, tensorParams);
+
+                CV_Assert(scalarParams.has("size"));
+                DictValue dimParam = scalarParams.get("size");
+                layerParams.set("dim", dimParam);
+
+                curModule->modules.push_back(newModule);
+            }
+            else if (nnName == "ReLU")
+            {
+                curModule->modules.push_back(new Module(nnName, "ReLU"));
+            }
+            else if (nnName == "Tanh")
+            {
+                curModule->modules.push_back(new Module(nnName, "TanH"));
+            }
+            else if (nnName == "Sigmoid")
+            {
+                curModule->modules.push_back(new Module(nnName, "Sigmoid"));
+            }
             else
             {
+                readTorchTable(scalarParams, tensorParams);
                 CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");
             }
         }
@@ -392,26 +519,82 @@ struct TorchImporter : public ::cv::dnn::Importer
         else if (typeidx == TYPE_NIL)
             return;
         else if (typeidx == TYPE_NUMBER)
-            //readDouble();
-            std::cout << readDouble() << std::endl;
+            readDouble();
         else if (typeidx == TYPE_BOOLEAN)
             readBool();
         else if (typeidx == TYPE_STRING)
-            //readString();
-            std::cout << readString() << std::endl;
+            readString();
         else if (typeidx == TYPE_TABLE)
             readTable();
         else
             CV_Error(Error::StsNotImplemented, "Unsupported Lua type");
     }
 
+    inline String generateLayerName(const String &label = String())
+    {
+        this->moduleCounter++;
+        return "l" + toString(this->moduleCounter) + "_" + label;
+    }
+
+    int fill(Module *module, int prevLayerId = 0, int prevOutNum = 0)
+    {
+        if (module == NULL)
+            return prevLayerId;
+
+        if (module->type.length())
+        {
+            int newLayerId = this->net.addLayer(generateLayerName(module->type), module->type, module->params);
+            net.connect(prevLayerId, prevOutNum, newLayerId, 0);
+            std::cout << "added " << module->thName << " i.e. " << module->type << "\n";
+            return newLayerId;
+        }
+        else
+        {
+            if (module->thName == "Sequential")
+            {
+                for (size_t i = 0; i < module->modules.size(); i++)
+                {
+                    prevLayerId = fill(module->modules[i], prevLayerId, prevOutNum);
+                    prevOutNum = 0;
+                }
+                return prevLayerId;
+            }
+            else if (module->thName == "Parallel" || module->thName == "Concat")
+            {
+                int splitId, mergeId, newId;
+
+                String splitType = (module->thName == "Parallel") ? "Slice" : "Split";
+                splitId = net.addLayer(generateLayerName("torchSplit"), splitType, module->params);
+                net.connect(prevLayerId, prevOutNum, splitId, 0);
+
+                mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", module->params);
+
+                for (size_t i = 0; i < module->modules.size(); i++)
+                {
+                    newId = fill(module->modules[i], splitId, (int)i);
+                    net.connect(newId, 0, mergeId, (int)i);
+                }
+
+                return mergeId;
+            }
+        }
+
+        CV_Error(Error::StsInternal, "Unexpected torch container: " + module->thName);
+        return -1;
+    }
+
     void populateNet(Net net)
     {
         this->net = net;
+
         THFile_seek(file, 0);
         readedIndexes.clear();
         storages.clear();
+
+        rootModule = new Module("Sequential");
+        curModule = rootModule;
         readObject();
+
+        moduleCounter = 0;
+        fill(rootModule);
     }
 };
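
To make the new container handling concrete, here is a standalone sketch that mirrors, in simplified form, how fill() above walks the Module tree; all names and the demo tree are illustrative, not importer code. Leaves with a non-empty type become net layers named by the same counter scheme as generateLayerName; "Sequential" chains its children. The real fill() additionally wraps the children of "Parallel"/"Concat" between a Slice/Split layer and a merging Concat layer.

#include <iostream>
#include <string>
#include <vector>

struct Module
{
    std::string thName, type;          // Torch class name / dnn layer type
    std::vector<Module*> modules;      // children, for container modules
    Module(const std::string &n, const std::string &t = std::string())
        : thName(n), type(t) {}
};

static int counter = 0;

static std::string layerName(const std::string &label)
{
    ++counter;                                          // same scheme as generateLayerName
    return "l" + std::to_string(counter) + "_" + label;
}

static void walk(const Module *m)
{
    if (!m->type.empty())                               // leaf: would become net.addLayer(...)
        std::cout << "addLayer " << layerName(m->type) << "\n";
    else                                                // container: visit children in order
        for (size_t i = 0; i < m->modules.size(); i++)
            walk(m->modules[i]);
}

int main()
{
    Module conv("SpatialConvolution", "Convolution");
    Module relu("ReLU", "ReLU");
    Module root("Sequential");
    root.modules.push_back(&conv);
    root.modules.push_back(&relu);

    walk(&root);   // prints: addLayer l1_Convolution, then addLayer l2_ReLU
    return 0;
}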

--- a/modules/dnn/test/test_torch_importer.cpp
+++ b/modules/dnn/test/test_torch_importer.cpp
@@ -25,9 +25,11 @@ TEST(Torch_Importer, simple_read)
     Net net;
     Ptr<Importer> importer;
-    //ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
-    ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
+    ASSERT_NO_THROW( importer = createTorchImporter("/home/vitaliy/th/conv1.txt", false) );
+    //ASSERT_NO_THROW( importer = createTorchImporter("L:\\home\\vitaliy\\th\\conv1.txt", false) );
     ASSERT_TRUE( importer != NULL );
 
     importer->populateNet(net);
+    //ASSERT_NO_THROW( importer->populateNet(net) );
 }
