Array of parmaters support into caffe_importer.cpp

pull/265/head
Vitaliy Lyudvichenko 9 years ago
parent 0362da927c
commit 3aa37d2971
  1. 59
      modules/dnn/include/opencv2/dnn/dict.hpp
  2. 49
      modules/dnn/src/caffe_importer.cpp

@ -17,6 +17,13 @@ struct DictValue
DictValue(double p) : type(Param::REAL), pd(new AutoBuffer<double,1>) { (*pd)[0] = p; }
DictValue(const String &p) : type(Param::STRING), ps(new AutoBuffer<String,1>) { (*ps)[0] = p; }
template<typename TypeIter>
static DictValue arrayInt(TypeIter begin, int size);
template<typename TypeIter>
static DictValue arrayReal(TypeIter begin, int size);
template<typename TypeIter>
static DictValue arrayString(TypeIter begin, int size);
template<typename T>
T get(int idx = -1) const;
@ -39,11 +46,40 @@ protected:
AutoBuffer<int64, 1> *pi;
AutoBuffer<double, 1> *pd;
AutoBuffer<String, 1> *ps;
void *p;
};
DictValue(int _type, void *_p) : type(_type), p(_p) {}
void release();
};
template<typename TypeIter>
DictValue DictValue::arrayInt(TypeIter begin, int size)
{
DictValue res(Param::INT, new AutoBuffer<int64, 1>(size));
for (int j = 0; j < size; begin++, j++)
(*res.pi)[j] = *begin;
return res;
}
template<typename TypeIter>
DictValue DictValue::arrayReal(TypeIter begin, int size)
{
DictValue res(Param::REAL, new AutoBuffer<double, 1>(size));
for (int j = 0; j < size; begin++, j++)
(*res.pd)[j] = *begin;
return res;
}
template<typename TypeIter>
DictValue DictValue::arrayString(TypeIter begin, int size)
{
DictValue res(Param::STRING, new AutoBuffer<String, 1>(size));
for (int j = 0; j < size; begin++, j++)
(*res.ps)[j] = *begin;
return res;
}
class CV_EXPORTS Dict
{
typedef std::map<String, DictValue> _Dict;
@ -62,13 +98,18 @@ public:
return (i == dict.end()) ? NULL : &i->second;
}
template <typename T>
T get(const String &name) const
const DictValue &get(const String &name) const
{
_Dict::const_iterator i = dict.find(name);
if (i == dict.end())
CV_Error(cv::Error::StsBadArg, "Required argument \"" + name + "\" not found into dictionary");
return i->second.get<T>();
CV_Error(Error::StsBadArg, "Required argument \"" + name + "\" not found into dictionary");
return i->second;
}
template <typename T>
T get(const String &name) const
{
return this->get(name).get<T>();
}
template <typename T>
@ -106,7 +147,7 @@ inline DictValue DictValue::get<DictValue>(int idx) const
template<>
inline int64 DictValue::get<int64>(int idx) const
{
CV_Assert(type == Param::INT);
CV_Assert(isInt());
CV_Assert(idx == -1 && pi->size() == 1 || idx >= 0 && idx < (int)pi->size());
return (*pi)[(idx == -1) ? 0 : idx];
}
@ -144,7 +185,7 @@ inline double DictValue::get<double>(int idx) const
}
else
{
CV_Assert(type == Param::REAL || type == Param::INT);
CV_Assert(isReal());
return 0;
}
}
@ -158,7 +199,7 @@ inline float DictValue::get<float>(int idx) const
template<>
inline String DictValue::get<String>(int idx) const
{
CV_Assert(type == Param::STRING);
CV_Assert(isString());
CV_Assert(idx == -1 && ps->size() == 1 || idx >= 0 && idx < (int)ps->size());
return (*ps)[(idx == -1) ? 0 : idx];
}
@ -228,9 +269,9 @@ inline bool DictValue::isInt() const
return (type == Param::INT);
}
bool DictValue::isReal() const
inline bool DictValue::isReal() const
{
return (type == Param::REAL);
return (type == Param::REAL || type == Param::INT);
}
int DictValue::size() const

@ -63,32 +63,57 @@ namespace
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
{
const Reflection *msgRefl = msg.GetReflection();
const Reflection *refl = msg.GetReflection();
int type = field->cpp_type();
bool isRepeated = field->is_repeated();
const std::string &name = field->name();
#define GET_FIRST(Type) (isRepeated ? msgRefl->GetRepeated##Type(msg, field, 0) : msgRefl->Get##Type(msg, field))
#define SET_UP_FILED(getter, arrayConstr, gtype) \
if (isRepeated) { \
const RepeatedField<gtype> &v = refl->GetRepeatedField<gtype>(msg, field); \
params.set(name, ##arrayConstr(v.begin(), (int)v.size())); \
} \
else { \
params.set(name, refl->##getter(msg, field)); \
}
switch (type)
{
case FieldDescriptor::CPPTYPE_INT32:
params.set(name, GET_FIRST(Int32));
SET_UP_FILED(GetInt32, DictValue::arrayInt, ::google::protobuf::int32);
break;
case FieldDescriptor::CPPTYPE_UINT32:
params.set(name, GET_FIRST(UInt32));
SET_UP_FILED(GetUInt32, DictValue::arrayInt, ::google::protobuf::uint32);
break;
case FieldDescriptor::CPPTYPE_BOOL:
SET_UP_FILED(GetBool, DictValue::arrayInt, bool);
break;
case FieldDescriptor::CPPTYPE_DOUBLE:
params.set(name, GET_FIRST(Double));
SET_UP_FILED(GetDouble, DictValue::arrayReal, double);
break;
case FieldDescriptor::CPPTYPE_FLOAT:
params.set(name, GET_FIRST(Float));
SET_UP_FILED(GetFloat, DictValue::arrayReal, float);
break;
case FieldDescriptor::CPPTYPE_ENUM:
params.set(name, GET_FIRST(Enum)->name());
case FieldDescriptor::CPPTYPE_STRING:
if (isRepeated) {
const RepeatedPtrField<std::string> &v = refl->GetRepeatedPtrField<std::string>(msg, field);
params.set(name, DictValue::arrayString(v.begin(), (int)v.size()));
}
else {
params.set(name, refl->GetString(msg, field));
}
break;
case FieldDescriptor::CPPTYPE_BOOL:
params.set(name, GET_FIRST(Bool));
case FieldDescriptor::CPPTYPE_ENUM:
if (isRepeated) {
int size = refl->FieldSize(msg, field);
std::vector<cv::String> buf(size);
for (int i = 0; i < size; i++)
buf[i] = refl->GetRepeatedEnum(msg, field, i)->name();
params.set(name, DictValue::arrayString(buf.begin(), size));
}
else {
params.set(name, refl->GetEnum(msg, field)->name());
}
break;
default:
CV_Error(Error::StsError, "Unknown type \"" + String(field->type_name()) + "\" in prototxt");
@ -230,11 +255,11 @@ namespace
addInput(layer.bottom(inNum), id, inNum, dstNet, addedBlobs);
for (int outNum = 0; outNum < layer.top_size(); outNum++)
addOutput(layer, id, outNum, dstNet, addedBlobs);
addOutput(layer, id, outNum, addedBlobs);
}
}
void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum, Net &dstNet, std::vector<BlobNote> &addedBlobs)
void addOutput(const caffe::LayerParameter &layer, int layerId, int outNum, std::vector<BlobNote> &addedBlobs)
{
const std::string &name = layer.top(outNum);

Loading…
Cancel
Save