From 3aa37d297143857cb8c8da5029b0a8db451d275c Mon Sep 17 00:00:00 2001 From: Vitaliy Lyudvichenko Date: Sat, 18 Jul 2015 01:12:08 +0300 Subject: [PATCH] Array of parmaters support into caffe_importer.cpp --- modules/dnn/include/opencv2/dnn/dict.hpp | 61 ++++++++++++++++++++---- modules/dnn/src/caffe_importer.cpp | 49 ++++++++++++++----- 2 files changed, 88 insertions(+), 22 deletions(-) diff --git a/modules/dnn/include/opencv2/dnn/dict.hpp b/modules/dnn/include/opencv2/dnn/dict.hpp index e48dcb201..d192654bb 100644 --- a/modules/dnn/include/opencv2/dnn/dict.hpp +++ b/modules/dnn/include/opencv2/dnn/dict.hpp @@ -16,6 +16,13 @@ struct DictValue DictValue(unsigned p) : type(Param::INT), pi(new AutoBuffer) { (*pi)[0] = p; } DictValue(double p) : type(Param::REAL), pd(new AutoBuffer) { (*pd)[0] = p; } DictValue(const String &p) : type(Param::STRING), ps(new AutoBuffer) { (*ps)[0] = p; } + + template + static DictValue arrayInt(TypeIter begin, int size); + template + static DictValue arrayReal(TypeIter begin, int size); + template + static DictValue arrayString(TypeIter begin, int size); template T get(int idx = -1) const; @@ -24,7 +31,7 @@ struct DictValue bool isInt() const; bool isString() const; - bool isReal() const; + bool isReal() const; DictValue &operator=(const DictValue &r); @@ -39,11 +46,40 @@ protected: AutoBuffer *pi; AutoBuffer *pd; AutoBuffer *ps; + void *p; }; + DictValue(int _type, void *_p) : type(_type), p(_p) {} void release(); }; +template +DictValue DictValue::arrayInt(TypeIter begin, int size) +{ + DictValue res(Param::INT, new AutoBuffer(size)); + for (int j = 0; j < size; begin++, j++) + (*res.pi)[j] = *begin; + return res; +} + +template +DictValue DictValue::arrayReal(TypeIter begin, int size) +{ + DictValue res(Param::REAL, new AutoBuffer(size)); + for (int j = 0; j < size; begin++, j++) + (*res.pd)[j] = *begin; + return res; +} + +template +DictValue DictValue::arrayString(TypeIter begin, int size) +{ + DictValue res(Param::STRING, new AutoBuffer(size)); + for (int j = 0; j < size; begin++, j++) + (*res.ps)[j] = *begin; + return res; +} + class CV_EXPORTS Dict { typedef std::map _Dict; @@ -62,13 +98,18 @@ public: return (i == dict.end()) ? NULL : &i->second; } - template - T get(const String &name) const + const DictValue &get(const String &name) const { _Dict::const_iterator i = dict.find(name); if (i == dict.end()) - CV_Error(cv::Error::StsBadArg, "Required argument \"" + name + "\" not found into dictionary"); - return i->second.get(); + CV_Error(Error::StsBadArg, "Required argument \"" + name + "\" not found into dictionary"); + return i->second; + } + + template + T get(const String &name) const + { + return this->get(name).get(); } template @@ -106,7 +147,7 @@ inline DictValue DictValue::get(int idx) const template<> inline int64 DictValue::get(int idx) const { - CV_Assert(type == Param::INT); + CV_Assert(isInt()); CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size()); return (*pi)[(idx == -1) ? 0 : idx]; } @@ -144,7 +185,7 @@ inline double DictValue::get(int idx) const } else { - CV_Assert(type == Param::REAL || type == Param::INT); + CV_Assert(isReal()); return 0; } } @@ -158,7 +199,7 @@ inline float DictValue::get(int idx) const template<> inline String DictValue::get(int idx) const { - CV_Assert(type == Param::STRING); + CV_Assert(isString()); CV_Assert(idx == -1 && ps->size() == 1 || idx >= 0 && idx < (int)ps->size()); return (*ps)[(idx == -1) ? 0 : idx]; } @@ -228,9 +269,9 @@ inline bool DictValue::isInt() const return (type == Param::INT); } -bool DictValue::isReal() const +inline bool DictValue::isReal() const { - return (type == Param::REAL); + return (type == Param::REAL || type == Param::INT); } int DictValue::size() const diff --git a/modules/dnn/src/caffe_importer.cpp b/modules/dnn/src/caffe_importer.cpp index 16d269e56..d673cbc63 100644 --- a/modules/dnn/src/caffe_importer.cpp +++ b/modules/dnn/src/caffe_importer.cpp @@ -63,32 +63,57 @@ namespace void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams ¶ms) { - const Reflection *msgRefl = msg.GetReflection(); + const Reflection *refl = msg.GetReflection(); int type = field->cpp_type(); bool isRepeated = field->is_repeated(); const std::string &name = field->name(); - #define GET_FIRST(Type) (isRepeated ? msgRefl->GetRepeated##Type(msg, field, 0) : msgRefl->Get##Type(msg, field)) + #define SET_UP_FILED(getter, arrayConstr, gtype) \ + if (isRepeated) { \ + const RepeatedField &v = refl->GetRepeatedField(msg, field); \ + params.set(name, ##arrayConstr(v.begin(), (int)v.size())); \ + } \ + else { \ + params.set(name, refl->##getter(msg, field)); \ + } switch (type) { case FieldDescriptor::CPPTYPE_INT32: - params.set(name, GET_FIRST(Int32)); + SET_UP_FILED(GetInt32, DictValue::arrayInt, ::google::protobuf::int32); break; case FieldDescriptor::CPPTYPE_UINT32: - params.set(name, GET_FIRST(UInt32)); + SET_UP_FILED(GetUInt32, DictValue::arrayInt, ::google::protobuf::uint32); + break; + case FieldDescriptor::CPPTYPE_BOOL: + SET_UP_FILED(GetBool, DictValue::arrayInt, bool); break; case FieldDescriptor::CPPTYPE_DOUBLE: - params.set(name, GET_FIRST(Double)); + SET_UP_FILED(GetDouble, DictValue::arrayReal, double); break; case FieldDescriptor::CPPTYPE_FLOAT: - params.set(name, GET_FIRST(Float)); + SET_UP_FILED(GetFloat, DictValue::arrayReal, float); break; - case FieldDescriptor::CPPTYPE_ENUM: - params.set(name, GET_FIRST(Enum)->name()); + case FieldDescriptor::CPPTYPE_STRING: + if (isRepeated) { + const RepeatedPtrField &v = refl->GetRepeatedPtrField(msg, field); + params.set(name, DictValue::arrayString(v.begin(), (int)v.size())); + } + else { + params.set(name, refl->GetString(msg, field)); + } break; - case FieldDescriptor::CPPTYPE_BOOL: - params.set(name, GET_FIRST(Bool)); + case FieldDescriptor::CPPTYPE_ENUM: + if (isRepeated) { + int size = refl->FieldSize(msg, field); + std::vector buf(size); + for (int i = 0; i < size; i++) + buf[i] = refl->GetRepeatedEnum(msg, field, i)->name(); + params.set(name, DictValue::arrayString(buf.begin(), size)); + } + else { + params.set(name, refl->GetEnum(msg, field)->name()); + } break; default: CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt"); @@ -230,11 +255,11 @@ namespace addInput(layer.bottom(inNum), id, inNum, dstNet, addedBlobs); for (int outNum = 0; outNum < layer.top_size(); outNum++) - addOutput(layer, id, outNum, dstNet, addedBlobs); + addOutput(layer, id, outNum, addedBlobs); } } - void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum, Net &dstNet, std::vector &addedBlobs) + void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum, std::vector &addedBlobs) { const std::string &name = layer.top(outNum);