Fixed SSD example, added test

pull/1214/head
Aleksandr Rybnikov 8 years ago
parent 38dd47cf40
commit 692ba7bafd
  1. modules/dnn/misc/caffe/caffe.pb.cc (3078 changed lines)
  2. modules/dnn/misc/caffe/caffe.pb.h (1201 changed lines)
  3. modules/dnn/samples/ssd_object_detection.cpp (5 changed lines)
  4. modules/dnn/src/caffe/caffe.proto (59 changed lines)
  5. modules/dnn/src/init.cpp (1 changed line)
  6. modules/dnn/src/layers/detection_output_layer.cpp (314 changed lines)
  7. modules/dnn/src/layers/flatten_layer.cpp (2 changed lines)
  8. modules/dnn/src/layers/permute_layer.cpp (13 changed lines)
  9. modules/dnn/src/layers/prior_box_layer.cpp (32 changed lines)
  10. modules/dnn/test/test_caffe_importer.cpp (28 changed lines)

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -1,4 +1,5 @@
#include <opencv2/dnn.hpp>
#include <opencv2/dnn/shape_utils.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
@@ -30,7 +31,7 @@ Mat getMean(const size_t& imageHeight, const size_t& imageWidth)
Mat preprocess(const Mat& frame)
{
Mat preprocessed;
frame.convertTo(preprocessed, CV_32FC3);
frame.convertTo(preprocessed, CV_32F);
resize(preprocessed, preprocessed, Size(width, height)); //SSD accepts 300x300 RGB images
Mat mean = getMean(width, height);
@@ -98,6 +99,8 @@ int main(int argc, char** argv)
cv::Mat frame = cv::imread(parser.get<string>("image"), -1);
if (frame.channels() == 4)
cvtColor(frame, frame, COLOR_BGRA2BGR);
//! [Prepare blob]
Mat preprocessedFrame = preprocess(frame);
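
A minimal, self-contained sketch of what the preprocessing above amounts to. The 300x300 size matches the sample's width/height; the per-channel mean (104, 117, 123) is an assumption here, since getMean's body is outside this hunk.

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
using namespace cv;

// Hypothetical helper, not the sample's preprocess(): convert to float,
// resize to the SSD input size and subtract an assumed VGG-style mean.
Mat prepareSSDInput(const Mat& frame)
{
    Mat img;
    frame.convertTo(img, CV_32F);                          // same as the fixed convertTo above
    resize(img, img, Size(300, 300));                      // SSD expects 300x300 input
    Mat mean(img.size(), CV_32FC3, Scalar(104, 117, 123)); // assumed mean values
    subtract(img, mean, img);
    return img;
}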

@@ -115,6 +115,21 @@ message PriorBoxParameter {
optional bool clip = 5 [default = true];
// Variance for adjusting the prior bboxes.
repeated float variance = 6;
// By default, we calculate img_height, img_width, step_x, step_y based on
// bottom[0] (feat) and bottom[1] (img), unless these values are explicitly
// provided.
// Explicitly provide the img_size.
optional uint32 img_size = 7;
// Either img_size or img_h/img_w should be specified; not both.
optional uint32 img_h = 8;
optional uint32 img_w = 9;
// Explicitly provide the step size.
optional float step = 10;
// Either step or step_h/step_w should be specified; not both.
optional float step_h = 11;
optional float step_w = 12;
// Offset to the top left corner of each cell.
optional float offset = 13 [default = 0.5];
}
// Message that store parameters used by DetectionOutputLayer
@@ -126,6 +141,10 @@ message DetectionOutputParameter {
// Background label id. If there is no background class,
// set it as -1.
optional int32 background_label_id = 3 [default = 0];
// Parameters used for non maximum suppression.
optional NonMaximumSuppressionParameter nms_param = 4;
// Parameters used for saving detection results.
optional SaveOutputParameter save_output_param = 5;
// Type of coding method for bbox.
optional PriorBoxParameter.CodeType code_type = 6 [default = CORNER];
// If true, variance is encoded in target; otherwise we need to adjust the
@@ -137,11 +156,6 @@ message DetectionOutputParameter {
// Only consider detections whose confidences are larger than a threshold.
// If not provided, consider all boxes.
optional float confidence_threshold = 9;
// Parameters used for non maximum suppression.
// Threshold to be used in nms.
optional float nms_threshold = 10 [default = 0.3];
// Maximum number of results to be kept.
optional int32 top_k = 11;
}
message Datum {
@@ -503,7 +517,7 @@ message LayerParameter {
optional LRNParameter lrn_param = 118;
optional MemoryDataParameter memory_data_param = 119;
optional MVNParameter mvn_param = 120;
optional NormalizeBBoxParameter normalize_bbox_param = 149;
optional NormalizeBBoxParameter norm_param = 149;
optional PermuteParameter permute_param = 148;
optional ParameterParameter parameter_param = 145;
optional PoolingParameter pooling_param = 121;
@@ -781,6 +795,39 @@ message DataParameter {
optional uint32 prefetch = 10 [default = 4];
}
message NonMaximumSuppressionParameter {
// Threshold to be used in nms.
optional float nms_threshold = 1 [default = 0.3];
// Maximum number of results to be kept.
optional int32 top_k = 2;
// Parameter for adaptive nms.
optional float eta = 3 [default = 1.0];
}
message SaveOutputParameter {
// Output directory. If not empty, we will save the results.
optional string output_directory = 1;
// Output name prefix.
optional string output_name_prefix = 2;
// Output format.
// VOC - PASCAL VOC output format.
// COCO - MS COCO output format.
optional string output_format = 3;
// If you want to output results, you must also provide the following two files.
// Otherwise, we will ignore saving results.
// label map file.
optional string label_map_file = 4;
// A file which contains a list of names and sizes in the same order
// as the input DB. The file is in the following format:
// name height width
// ...
optional string name_size_file = 5;
// Number of test images. It can be less than the number of lines in
// name_size_file, for example when we only want to evaluate on part
// of the test images.
optional uint32 num_test_image = 6;
}
message DropoutParameter {
optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio
}

@@ -95,6 +95,7 @@ void initModule()
REG_RUNTIME_LAYER_CLASS(PriorBox, PriorBoxLayer);
REG_RUNTIME_LAYER_CLASS(DetectionOutput, DetectionOutputLayer);
REG_RUNTIME_LAYER_CLASS(NormalizeBBox, NormalizeBBoxLayer);
REG_RUNTIME_LAYER_CLASS(Normalize, NormalizeBBoxLayer);
REG_RUNTIME_LAYER_CLASS(Shift, ShiftLayer);
REG_RUNTIME_LAYER_CLASS(Padding, PaddingLayer);
REG_RUNTIME_LAYER_CLASS(Scale, ScaleLayer);

@@ -228,11 +228,12 @@ public:
std::vector<std::vector<float> > priorVariances;
GetPriorBBoxes(priorData, numPriors, &priorBBoxes, &priorVariances);
const bool clip_bbox = false;
// Decode all loc predictions to bboxes.
std::vector<LabelBBox> allDecodedBBoxes;
DecodeBBoxesAll(allLocationPredictions, priorBBoxes, priorVariances, num,
_shareLocation, _numLocClasses, _backgroundLabelId,
_codeType, _varianceEncodedInTarget, &allDecodedBBoxes);
_codeType, _varianceEncodedInTarget, clip_bbox, &allDecodedBBoxes);
int numKept = 0;
std::vector<std::map<int, std::vector<int> > > allIndices;
@@ -266,7 +267,7 @@ public:
}
const std::vector<caffe::NormalizedBBox>& bboxes =
decodeBBoxes.find(label)->second;
ApplyNMSFast(bboxes, scores, _confidenceThreshold, _nmsThreshold,
ApplyNMSFast(bboxes, scores, _confidenceThreshold, _nmsThreshold, 1.0,
_topK, &(indices[c]));
numDetections += indices[c].size();
}
@@ -358,8 +359,7 @@ public:
outputsData[count * 7] = i;
outputsData[count * 7 + 1] = label;
outputsData[count * 7 + 2] = scores[idx];
caffe::NormalizedBBox clipBBox;
ClipBBox(bboxes[idx], &clipBBox);
caffe::NormalizedBBox clipBBox = bboxes[idx];
outputsData[count * 7 + 3] = clipBBox.xmin();
outputsData[count * 7 + 4] = clipBBox.ymin();
outputsData[count * 7 + 5] = clipBBox.xmax();
@@ -417,142 +417,126 @@ public:
}
// Decode a bbox according to a prior bbox.
void DecodeBBox(const caffe::NormalizedBBox& priorBBox, const std::vector<float>& priorVariance,
const CodeType codeType, const bool varianceEncodedInTarget,
const caffe::NormalizedBBox& bbox, caffe::NormalizedBBox* decodeBBox)
{
if (codeType == caffe::PriorBoxParameter_CodeType_CORNER)
{
if (varianceEncodedInTarget)
{
// variance is encoded in target, we simply need to add the offset
// predictions.
decodeBBox->set_xmin(priorBBox.xmin() + bbox.xmin());
decodeBBox->set_ymin(priorBBox.ymin() + bbox.ymin());
decodeBBox->set_xmax(priorBBox.xmax() + bbox.xmax());
decodeBBox->set_ymax(priorBBox.ymax() + bbox.ymax());
}
else
{
// variance is encoded in bbox, we need to scale the offset accordingly.
decodeBBox->set_xmin(
priorBBox.xmin() + priorVariance[0] * bbox.xmin());
decodeBBox->set_ymin(
priorBBox.ymin() + priorVariance[1] * bbox.ymin());
decodeBBox->set_xmax(
priorBBox.xmax() + priorVariance[2] * bbox.xmax());
decodeBBox->set_ymax(
priorBBox.ymax() + priorVariance[3] * bbox.ymax());
}
void DecodeBBox(
const caffe::NormalizedBBox& prior_bbox, const std::vector<float>& prior_variance,
const CodeType code_type, const bool variance_encoded_in_target,
const bool clip_bbox, const caffe::NormalizedBBox& bbox,
caffe::NormalizedBBox* decode_bbox) {
if (code_type == caffe::PriorBoxParameter_CodeType_CORNER) {
if (variance_encoded_in_target) {
// variance is encoded in target, we simply need to add the offset
// predictions.
decode_bbox->set_xmin(prior_bbox.xmin() + bbox.xmin());
decode_bbox->set_ymin(prior_bbox.ymin() + bbox.ymin());
decode_bbox->set_xmax(prior_bbox.xmax() + bbox.xmax());
decode_bbox->set_ymax(prior_bbox.ymax() + bbox.ymax());
} else {
// variance is encoded in bbox, we need to scale the offset accordingly.
decode_bbox->set_xmin(
prior_bbox.xmin() + prior_variance[0] * bbox.xmin());
decode_bbox->set_ymin(
prior_bbox.ymin() + prior_variance[1] * bbox.ymin());
decode_bbox->set_xmax(
prior_bbox.xmax() + prior_variance[2] * bbox.xmax());
decode_bbox->set_ymax(
prior_bbox.ymax() + prior_variance[3] * bbox.ymax());
}
else if (codeType == caffe::PriorBoxParameter_CodeType_CENTER_SIZE)
{
float priorWidth = priorBBox.xmax() - priorBBox.xmin();
CV_Assert(priorWidth > 0);
float priorHeight = priorBBox.ymax() - priorBBox.ymin();
CV_Assert(priorHeight > 0);
float priorCenterX = (priorBBox.xmin() + priorBBox.xmax()) / 2.;
float priorCenterY = (priorBBox.ymin() + priorBBox.ymax()) / 2.;
float decodeBBoxCenterX, decodeBBoxCenterY;
float decodeBBoxWidth, decodeBBoxHeight;
if (varianceEncodedInTarget)
{
// variance is encoded in target, we simply need to restore the offset
// predictions.
decodeBBoxCenterX = bbox.xmin() * priorWidth + priorCenterX;
decodeBBoxCenterY = bbox.ymin() * priorHeight + priorCenterY;
decodeBBoxWidth = exp(bbox.xmax()) * priorWidth;
decodeBBoxHeight = exp(bbox.ymax()) * priorHeight;
}
else
{
// variance is encoded in bbox, we need to scale the offset accordingly.
decodeBBoxCenterX =
priorVariance[0] * bbox.xmin() * priorWidth + priorCenterX;
decodeBBoxCenterY =
priorVariance[1] * bbox.ymin() * priorHeight + priorCenterY;
decodeBBoxWidth =
exp(priorVariance[2] * bbox.xmax()) * priorWidth;
decodeBBoxHeight =
exp(priorVariance[3] * bbox.ymax()) * priorHeight;
}
decodeBBox->set_xmin(decodeBBoxCenterX - decodeBBoxWidth / 2.);
decodeBBox->set_ymin(decodeBBoxCenterY - decodeBBoxHeight / 2.);
decodeBBox->set_xmax(decodeBBoxCenterX + decodeBBoxWidth / 2.);
decodeBBox->set_ymax(decodeBBoxCenterY + decodeBBoxHeight / 2.);
} else if (code_type == caffe::PriorBoxParameter_CodeType_CENTER_SIZE) {
float prior_width = prior_bbox.xmax() - prior_bbox.xmin();
CV_Assert(prior_width > 0);
float prior_height = prior_bbox.ymax() - prior_bbox.ymin();
CV_Assert(prior_height > 0);
float prior_center_x = (prior_bbox.xmin() + prior_bbox.xmax()) / 2.;
float prior_center_y = (prior_bbox.ymin() + prior_bbox.ymax()) / 2.;
float decode_bbox_center_x, decode_bbox_center_y;
float decode_bbox_width, decode_bbox_height;
if (variance_encoded_in_target) {
// variance is encoded in target, we simply need to restore the offset
// predictions.
decode_bbox_center_x = bbox.xmin() * prior_width + prior_center_x;
decode_bbox_center_y = bbox.ymin() * prior_height + prior_center_y;
decode_bbox_width = exp(bbox.xmax()) * prior_width;
decode_bbox_height = exp(bbox.ymax()) * prior_height;
} else {
// variance is encoded in bbox, we need to scale the offset accordingly.
decode_bbox_center_x =
prior_variance[0] * bbox.xmin() * prior_width + prior_center_x;
decode_bbox_center_y =
prior_variance[1] * bbox.ymin() * prior_height + prior_center_y;
decode_bbox_width =
exp(prior_variance[2] * bbox.xmax()) * prior_width;
decode_bbox_height =
exp(prior_variance[3] * bbox.ymax()) * prior_height;
}
else
{
CV_Error(Error::StsBadArg, "Unknown LocLossType.");
}
float bboxSize = BBoxSize(*decodeBBox);
decodeBBox->set_size(bboxSize);
decode_bbox->set_xmin(decode_bbox_center_x - decode_bbox_width / 2.);
decode_bbox->set_ymin(decode_bbox_center_y - decode_bbox_height / 2.);
decode_bbox->set_xmax(decode_bbox_center_x + decode_bbox_width / 2.);
decode_bbox->set_ymax(decode_bbox_center_y + decode_bbox_height / 2.);
} else {
CV_Error(Error::StsBadArg, "Unknown LocLossType.");
}
float bbox_size = BBoxSize(*decode_bbox);
decode_bbox->set_size(bbox_size);
if (clip_bbox) {
ClipBBox(*decode_bbox, decode_bbox);
}
}
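
For reference, a tiny standalone program that runs the CENTER_SIZE arithmetic above on plain floats; the prior box, variances and predictions are illustrative values, not taken from a real model.

#include <cmath>
#include <cstdio>

int main()
{
    // Illustrative prior box (xmin, ymin, xmax, ymax) and its variances.
    float pxmin = 0.1f, pymin = 0.1f, pxmax = 0.3f, pymax = 0.3f;
    float var[4] = {0.1f, 0.1f, 0.2f, 0.2f};
    // Illustrative location prediction stored as (xmin, ymin, xmax, ymax).
    float tx = 0.5f, ty = -0.5f, tw = 0.2f, th = -0.2f;

    float prior_w = pxmax - pxmin, prior_h = pymax - pymin;
    float prior_cx = (pxmin + pxmax) / 2.f, prior_cy = (pymin + pymax) / 2.f;

    // Same formulas as the variance-in-bbox branch of DecodeBBox.
    float cx = var[0] * tx * prior_w + prior_cx;
    float cy = var[1] * ty * prior_h + prior_cy;
    float w  = std::exp(var[2] * tw) * prior_w;
    float h  = std::exp(var[3] * th) * prior_h;

    std::printf("decoded box: [%.4f, %.4f, %.4f, %.4f]\n",
                cx - w / 2.f, cy - h / 2.f, cx + w / 2.f, cy + h / 2.f);
    return 0;
}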
// Decode a set of bboxes according to a set of prior bboxes.
void DecodeBBoxes(const std::vector<caffe::NormalizedBBox>& priorBBoxes,
const std::vector<std::vector<float> >& priorVariances,
const CodeType codeType, const bool varianceEncodedInTarget,
const std::vector<caffe::NormalizedBBox>& bboxes,
std::vector<caffe::NormalizedBBox>* decodeBBoxes)
{
CV_Assert(priorBBoxes.size() == priorVariances.size());
CV_Assert(priorBBoxes.size() == bboxes.size());
int numBBoxes = priorBBoxes.size();
if (numBBoxes >= 1)
{
CV_Assert(priorVariances[0].size() == 4);
}
decodeBBoxes->clear();
for (int i = 0; i < numBBoxes; ++i)
{
caffe::NormalizedBBox decodeBBox;
DecodeBBox(priorBBoxes[i], priorVariances[i], codeType,
varianceEncodedInTarget, bboxes[i], &decodeBBox);
decodeBBoxes->push_back(decodeBBox);
}
void DecodeBBoxes(
const std::vector<caffe::NormalizedBBox>& prior_bboxes,
const std::vector<std::vector<float> >& prior_variances,
const CodeType code_type, const bool variance_encoded_in_target,
const bool clip_bbox, const std::vector<caffe::NormalizedBBox>& bboxes,
std::vector<caffe::NormalizedBBox>* decode_bboxes) {
CV_Assert(prior_bboxes.size() == prior_variances.size());
CV_Assert(prior_bboxes.size() == bboxes.size());
int num_bboxes = prior_bboxes.size();
if (num_bboxes >= 1) {
CV_Assert(prior_variances[0].size() == 4);
}
decode_bboxes->clear();
for (int i = 0; i < num_bboxes; ++i) {
caffe::NormalizedBBox decode_bbox;
DecodeBBox(prior_bboxes[i], prior_variances[i], code_type,
variance_encoded_in_target, clip_bbox, bboxes[i], &decode_bbox);
decode_bboxes->push_back(decode_bbox);
}
}
// Decode all bboxes in a batch.
void DecodeBBoxesAll(const std::vector<LabelBBox>& allLocPreds,
const std::vector<caffe::NormalizedBBox>& priorBBoxes,
const std::vector<std::vector<float> >& priorVariances,
const size_t num, const bool shareLocation,
const int numLocClasses, const int backgroundLabelId,
const CodeType codeType, const bool varianceEncodedInTarget,
std::vector<LabelBBox>* allDecodeBBoxes)
{
CV_Assert(allLocPreds.size() == num);
allDecodeBBoxes->clear();
allDecodeBBoxes->resize(num);
for (size_t i = 0; i < num; ++i)
{
// Decode predictions into bboxes.
LabelBBox& decodeBBoxes = (*allDecodeBBoxes)[i];
for (int c = 0; c < numLocClasses; ++c)
{
int label = shareLocation ? -1 : c;
if (label == backgroundLabelId)
{
// Ignore background class.
continue;
}
if (allLocPreds[i].find(label) == allLocPreds[i].end())
{
// Something bad happened if there are no predictions for current label.
util::make_error<int>("Could not find location predictions for label ", label);
}
const std::vector<caffe::NormalizedBBox>& labelLocPreds =
allLocPreds[i].find(label)->second;
DecodeBBoxes(priorBBoxes, priorVariances,
codeType, varianceEncodedInTarget,
labelLocPreds, &(decodeBBoxes[label]));
}
void DecodeBBoxesAll(const std::vector<LabelBBox>& all_loc_preds,
const std::vector<caffe::NormalizedBBox>& prior_bboxes,
const std::vector<std::vector<float> >& prior_variances,
const int num, const bool share_location,
const int num_loc_classes, const int background_label_id,
const CodeType code_type, const bool variance_encoded_in_target,
const bool clip, std::vector<LabelBBox>* all_decode_bboxes) {
CV_Assert(all_loc_preds.size() == num);
all_decode_bboxes->clear();
all_decode_bboxes->resize(num);
for (int i = 0; i < num; ++i) {
// Decode predictions into bboxes.
LabelBBox& decode_bboxes = (*all_decode_bboxes)[i];
for (int c = 0; c < num_loc_classes; ++c) {
int label = share_location ? -1 : c;
if (label == background_label_id) {
// Ignore background class.
continue;
}
if (all_loc_preds[i].find(label) == all_loc_preds[i].end()) {
// Something bad happened if there are no predictions for current label.
util::make_error<int>("Could not find location predictions for label ", label);
}
const std::vector<caffe::NormalizedBBox>& label_loc_preds =
all_loc_preds[i].find(label)->second;
DecodeBBoxes(prior_bboxes, prior_variances,
code_type, variance_encoded_in_target, clip,
label_loc_preds, &(decode_bboxes[label]));
}
}
}
// Get prior bounding boxes from prior_data.
@@ -686,43 +670,39 @@ public:
// top_k: if not -1, keep at most top_k picked indices.
// indices: the kept indices of bboxes after nms.
void ApplyNMSFast(const std::vector<caffe::NormalizedBBox>& bboxes,
const std::vector<float>& scores,
const float score_threshold,
const float nms_threshold, const int top_k,
std::vector<int>* indices)
{
// Sanity check.
CV_Assert(bboxes.size() == scores.size());
// Get top_k scores (with corresponding indices).
std::vector<std::pair<float, int> > score_index_vec;
GetMaxScoreIndex(scores, score_threshold, top_k, &score_index_vec);
// Do nms.
indices->clear();
while (score_index_vec.size() != 0)
{
const int idx = score_index_vec.front().second;
bool keep = true;
for (size_t k = 0; k < indices->size(); ++k)
{
if (keep)
{
const int kept_idx = (*indices)[k];
float overlap = JaccardOverlap(bboxes[idx], bboxes[kept_idx]);
keep = overlap <= nms_threshold;
}
else
{
break;
}
}
if (keep)
{
indices->push_back(idx);
}
score_index_vec.erase(score_index_vec.begin());
const std::vector<float>& scores, const float score_threshold,
const float nms_threshold, const float eta, const int top_k,
std::vector<int>* indices) {
// Sanity check.
CV_Assert(bboxes.size() == scores.size());
// Get top_k scores (with corresponding indices).
std::vector<std::pair<float, int> > score_index_vec;
GetMaxScoreIndex(scores, score_threshold, top_k, &score_index_vec);
// Do nms.
float adaptive_threshold = nms_threshold;
indices->clear();
while (score_index_vec.size() != 0) {
const int idx = score_index_vec.front().second;
bool keep = true;
for (int k = 0; k < indices->size(); ++k) {
if (keep) {
const int kept_idx = (*indices)[k];
float overlap = JaccardOverlap(bboxes[idx], bboxes[kept_idx]);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
indices->push_back(idx);
}
score_index_vec.erase(score_index_vec.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
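
A compact standalone sketch of the same greedy NMS, including the new adaptive eta behaviour, written over cv::Rect2f instead of caffe::NormalizedBBox; it illustrates the algorithm above and is not the layer's own code.

#include <opencv2/core.hpp>
#include <algorithm>
#include <utility>
#include <vector>

// Greedy NMS sketch: boxes/scores are parallel arrays; while the adaptive
// threshold stays above 0.5 and eta < 1, every kept box tightens it by eta.
std::vector<int> nmsSketch(const std::vector<cv::Rect2f>& boxes,
                           const std::vector<float>& scores,
                           float scoreThr, float nmsThr, float eta, int topK)
{
    std::vector<std::pair<float, int> > order;
    for (size_t i = 0; i < scores.size(); ++i)
        if (scores[i] > scoreThr)
            order.push_back(std::make_pair(scores[i], (int)i));
    std::sort(order.rbegin(), order.rend());               // highest score first
    if (topK > 0 && (int)order.size() > topK)
        order.resize(topK);

    std::vector<int> kept;
    float adaptiveThr = nmsThr;
    for (size_t i = 0; i < order.size(); ++i)
    {
        int idx = order[i].second;
        bool keep = true;
        for (size_t k = 0; k < kept.size() && keep; ++k)
        {
            float inter = (boxes[idx] & boxes[kept[k]]).area();
            float uni = boxes[idx].area() + boxes[kept[k]].area() - inter;
            keep = uni <= 0.f || inter / uni <= adaptiveThr;   // Jaccard overlap test
        }
        if (keep)
        {
            kept.push_back(idx);
            if (eta < 1.f && adaptiveThr > 0.5f)
                adaptiveThr *= eta;                            // adaptive shrinking
        }
    }
    return kept;
}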
// Get max scores with corresponding indices.

@@ -84,7 +84,7 @@ public:
CV_Assert(startAxis >= 0);
CV_Assert(endAxis >= startAxis && endAxis < (int)numAxes);
size_t flattenedDimensionSize = total(inputs[0], startAxis, endAxis);
size_t flattenedDimensionSize = total(inputs[0], startAxis, endAxis + 1);
MatShape outputShapeVec;
for (int i = 0; i < startAxis; i++)
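
The flatten change above reflects that total() treats its end index as exclusive, so including endAxis itself requires passing endAxis + 1. A small sketch of the intended inclusive product (the 2x3x4x5 shape is made up for illustration):

#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> shape;                       // hypothetical N x C x H x W blob
    shape.push_back(2); shape.push_back(3);
    shape.push_back(4); shape.push_back(5);

    int startAxis = 1, endAxis = 3;               // flatten C, H and W together
    size_t flattened = 1;
    for (int i = startAxis; i <= endAxis; ++i)    // inclusive range, i.e. total(shape, startAxis, endAxis + 1)
        flattened *= shape[i];

    std::printf("flattened size = %zu\n", flattened);  // 3 * 4 * 5 = 60, not 3 * 4 = 12
    return 0;
}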

@@ -124,7 +124,7 @@ public:
MatShape shapeBefore = inputs[0], shapeAfter;
for (size_t i = 0; i < _numAxes; i++)
{
shapeAfter[i] = shapeBefore[_order[i]];
shapeAfter.push_back(shapeBefore[_order[i]]);
}
outputs.clear();
@@ -132,6 +132,7 @@ public:
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(inputs[i][2] == shapeBefore[2] && inputs[i][3] == shapeBefore[3]);
CV_Assert(total(inputs[i]) == total(shapeAfter));
outputs.push_back(shapeAfter);
}
@@ -192,11 +193,11 @@ public:
CV_Assert(inp.dims == numAxes && inp.size == inputs[0]->size);
CV_Assert(out.dims == numAxes && out.size == outputs[0].size);
for( i = 0; i < numAxes; i++ )
{
CV_Assert(inp.size[i] == _oldDimensionSize[i]);
CV_Assert(out.size[i] == _newDimensionSize[i]);
}
// for( i = 0; i < numAxes; i++ )
// {
// CV_Assert(inp.size[i] == _oldDimensionSize[i]);
// CV_Assert(out.size[i] == _newDimensionSize[i]);
// }
CV_Assert(inp.isContinuous() && out.isContinuous());
CV_Assert(inp.type() == CV_32F && out.type() == CV_32F);

@@ -183,6 +183,22 @@ public:
_numPriors += 1;
}
if (params.has("step_h") || params.has("step_w")) {
CV_Assert(!params.has("step"));
_stepY = getParameter<float>(params, "step_h");
CV_Assert(_stepY > 0.);
_stepX = getParameter<float>(params, "step_w");
CV_Assert(_stepX > 0.);
} else if (params.has("step")) {
const float step = getParameter<float>(params, "step");
CV_Assert(step > 0);
_stepY = step;
_stepX = step;
} else {
_stepY = 0;
_stepX = 0;
}
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
@@ -216,8 +232,14 @@ public:
int _imageWidth = inputs[1]->size[3];
int _imageHeight = inputs[1]->size[2];
float _stepX = static_cast<float>(_imageWidth) / _layerWidth;
float _stepY = static_cast<float>(_imageHeight) / _layerHeight;
float stepX, stepY;
if (_stepX == 0 || _stepY == 0) {
stepX = static_cast<float>(_imageWidth) / _layerWidth;
stepY = static_cast<float>(_imageHeight) / _layerHeight;
} else {
stepX = _stepX;
stepY = _stepY;
}
int _outChannelSize = _layerHeight * _layerWidth * _numPriors * 4;
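
A worked example of the step fallback above, combined with the (w + 0.5) * step centre computation used just below; all sizes are illustrative.

#include <cstdio>

int main()
{
    // Illustrative sizes: 300x300 input image, 38x38 feature map, no explicit step.
    int imageWidth = 300, imageHeight = 300;
    int layerWidth = 38, layerHeight = 38;

    float stepX = static_cast<float>(imageWidth) / layerWidth;    // ~7.89 px per cell
    float stepY = static_cast<float>(imageHeight) / layerHeight;  // ~7.89 px per cell

    int w = 10, h = 5;                                            // some feature-map cell
    float centerX = (w + 0.5f) * stepX;                           // ~82.9 px
    float centerY = (h + 0.5f) * stepY;                           // ~43.4 px
    std::printf("prior center: (%.1f, %.1f)\n", centerX, centerY);
    return 0;
}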
@@ -231,8 +253,8 @@
{
_boxWidth = _boxHeight = _minSize;
float center_x = (w + 0.5) * _stepX;
float center_y = (h + 0.5) * _stepY;
float center_x = (w + 0.5) * stepX;
float center_y = (h + 0.5) * stepY;
// xmin
outputPtr[idx++] = (center_x - _boxWidth / 2.) / _imageWidth;
// ymin
@@ -332,6 +354,8 @@ public:
float _boxWidth;
float _boxHeight;
float _stepX, _stepY;
std::vector<float> _aspectRatios;
std::vector<float> _variance;

@@ -130,4 +130,32 @@ TEST(Reproducibility_FCN, Accuracy)
}
#endif
TEST(Reproducibility_SSD, Accuracy)
{
Net net;
{
const string proto = findDataFile("dnn/ssd_vgg16.prototxt", false);
const string model = findDataFile("dnn/VGG_ILSVRC2016_SSD_300x300_iter_440000.caffemodel", false);
Ptr<Importer> importer = createCaffeImporter(proto, model);
ASSERT_TRUE(importer != NULL);
importer->populateNet(net);
}
Mat sample = imread(_tf("street.png"));
ASSERT_TRUE(!sample.empty());
if (sample.channels() == 4)
cvtColor(sample, sample, COLOR_BGRA2BGR);
sample.convertTo(sample, CV_32F);
resize(sample, sample, Size(300, 300));
Mat in_blob = blobFromImage(sample);
net.setBlob(".data", in_blob);
net.forward();
Mat out = net.getBlob("detection_out");
Mat ref = blobFromNPY(_tf("ssd_out.npy"));
normAssert(ref, out);
}
}
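
The detection_out blob that the test compares against ssd_out.npy packs seven floats per detection, in the order the layer writes them above: [image id, label, score, xmin, ymin, xmax, ymax]. A hedged sketch of reading it back; the reshape to a 7-column view and the assumption that coordinates are normalized to [0, 1] follow from the layer code, not from a documented API.

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
using namespace cv;

// Sketch: draw boxes from an SSD detection_out blob onto the original frame.
void drawDetections(const Mat& out, Mat& frame, float confThreshold)
{
    Mat det = out.reshape(1, (int)(out.total() / 7));  // one row per detection (assumed layout)
    for (int i = 0; i < det.rows; ++i)
    {
        const float* d = det.ptr<float>(i);            // [id, label, score, xmin, ymin, xmax, ymax]
        if (d[2] < confThreshold)
            continue;
        Point p1(cvRound(d[3] * frame.cols), cvRound(d[4] * frame.rows));
        Point p2(cvRound(d[5] * frame.cols), cvRound(d[6] * frame.rows));
        rectangle(frame, p1, p2, Scalar(0, 255, 0), 2);
    }
}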
