Add new method to Net

pull/265/head
Vitaliy Lyudvichenko 10 years ago
parent 7acfda2c11
commit 9d932af746
  1. 7
      modules/dnn/include/opencv2/dnn/dnn.hpp
  2. 33
      modules/dnn/src/dnn.cpp
  3. 4
      modules/dnn/src/layers/im2col.cpp
  4. 6
      modules/dnn/src/layers/im2col.hpp

@ -49,13 +49,14 @@ namespace dnn
~Net(); ~Net();
int addLayer(const String &name, const String &type, LayerParams &params); int addLayer(const String &name, const String &type, LayerParams &params);
int getLayerId(LayerId layer); int addLayerToPrev(const String &name, const String &type, LayerParams &params);
void deleteLayer(LayerId layer);
void setNetInputs(const std::vector<String> &inputBlobNames); int getLayerId(const String &layer);
void deleteLayer(LayerId layer);
void connect(String outPin, String inpPin); void connect(String outPin, String inpPin);
void connect(int outLayerId, int outNum, int inLayerId, int inNum); void connect(int outLayerId, int outNum, int inLayerId, int inNum);
void setNetInputs(const std::vector<String> &inputBlobNames);
void forward(); void forward();
void forward(LayerId toLayer); void forward(LayerId toLayer);

@ -383,7 +383,7 @@ int Net::addLayer(const String &name, const String &type, LayerParams &params)
{ {
if (name.find('.') != String::npos) if (name.find('.') != String::npos)
{ {
CV_Error(Error::StsBadArg, "Added layer name \"" + name + "\" should not contain dot symbol"); CV_Error(Error::StsBadArg, "Added layer name \"" + name + "\" must not contain dot symbol");
return -1; return -1;
} }
@ -400,6 +400,14 @@ int Net::addLayer(const String &name, const String &type, LayerParams &params)
return id; return id;
} }
int Net::addLayerToPrev(const String &name, const String &type, LayerParams &params)
{
int prvLid = impl->lastLayerId;
int newLid = this->addLayer(name, type, params);
this->connect(prvLid, 0, newLid, 0);
return newLid;
}
void Net::connect(int outLayerId, int outNum, int inLayerId, int inNum) void Net::connect(int outLayerId, int outNum, int inLayerId, int inNum)
{ {
impl->connect(outLayerId, outNum, inLayerId, inNum); impl->connect(outLayerId, outNum, inLayerId, inNum);
@ -467,6 +475,18 @@ Blob Net::getParam(LayerId layer, int numParam)
return layerBlobs[numParam]; return layerBlobs[numParam];
} }
int Net::getLayerId(const String &layer)
{
return impl->getLayerId(layer);
}
void Net::deleteLayer(LayerId)
{
CV_Error(Error::StsNotImplemented, "");
}
//////////////////////////////////////////////////////////////////////////
Importer::~Importer() Importer::~Importer()
{ {
@ -530,16 +550,5 @@ Ptr<Layer> LayerRegister::createLayerInstance(const String &_type, LayerParams&
} }
} }
int Net::getLayerId(LayerId)
{
CV_Error(Error::StsNotImplemented, "");
return -1;
}
void Net::deleteLayer(LayerId)
{
CV_Error(Error::StsNotImplemented, "");
}
} }
} }

@ -9,10 +9,10 @@ namespace dnn
{ {
void im2col_ocl(UMat &img, void im2col_ocl(UMat &img,
int channels, int height, int width, int channels, int height, int width,
int kernel_h, int kernel_w, int kernel_h, int kernel_w,
int pad_h, int pad_w, int pad_h, int pad_w,
int stride_h, int stride_w, int stride_h, int stride_w,
UMat &col) UMat &col)
{ {
int h_out = (height + 2 * pad_h - kernel_h) / stride_h + 1; int h_out = (height + 2 * pad_h - kernel_h) / stride_h + 1;

@ -7,8 +7,8 @@ namespace dnn
{ {
template <typename Dtype> template <typename Dtype>
void im2col_cpu(const Dtype* data_im, void im2col_cpu(const Dtype* data_im,
int channels, int height, int width, int channels, int height, int width,
int kernel_h, int kernel_w, int kernel_h, int kernel_w,
int pad_h, int pad_w, int pad_h, int pad_w,
int stride_h, int stride_w, int stride_h, int stride_w,
@ -36,7 +36,7 @@ void im2col_cpu(const Dtype* data_im,
} }
template <typename Dtype> template <typename Dtype>
void col2im_cpu(const Dtype* data_col, void col2im_cpu(const Dtype* data_col,
int channels, int height, int width, int channels, int height, int width,
int patch_h, int patch_w, int patch_h, int patch_w,
int pad_h, int pad_w, int pad_h, int pad_w,

Loading…
Cancel
Save