Merge pull request #15082 from dvd42:segmentation-module

Segmentation module (#15082)
pull/15298/head
Diego 5 years ago committed by Alexander Alekhin
parent 2ad0487cec
commit f7f2438478
  1. 30
      modules/dnn/include/opencv2/dnn/dnn.hpp
  2. 41
      modules/dnn/src/model.cpp
  3. 35
      modules/dnn/test/test_model.cpp

@ -1109,6 +1109,36 @@ CV__DNN_INLINE_NS_BEGIN
CV_WRAP void classify(InputArray frame, CV_OUT int& classId, CV_OUT float& conf);
};
/** @brief High-level API for semantic segmentation models.
*
* SegmentationModel lets the caller configure how the input image is preprocessed.
* It creates a net from a file with trained weights and a config, builds the input
* blob, runs a forward pass and returns the predicted class index for each pixel.
*/
class CV_EXPORTS_W SegmentationModel: public Model
{
public:
/**
* @brief Create a segmentation model from a network stored in one of the supported formats.
* The order of the @p model and @p config arguments does not matter.
* @param[in] model Binary file containing the trained weights.
* @param[in] config Text file containing the network configuration.
*/
CV_WRAP SegmentationModel(const String& model, const String& config = "");
/**
* @brief Create a segmentation model from an already constructed deep learning network.
* @param[in] network Net object.
*/
CV_WRAP SegmentationModel(const Net& network);
/** @brief Given the input @p frame, create the input blob, run the net and
* compute the per-pixel class prediction.
* @param[in] frame The input image.
* @param[out] mask Single-channel 8-bit (CV_8U) matrix, allocated by this call with the
* spatial size of the network output; each element is the index of the output channel
* with the highest score at that pixel (the lowest index wins on ties).
*/
CV_WRAP void segment(InputArray frame, OutputArray mask);
};
/** @brief This class represents high-level API for object detection networks.
*
* DetectionModel allows to set params for preprocessing input image.

@ -137,6 +137,47 @@ void ClassificationModel::classify(InputArray frame, int& classId, float& conf)
std::tie(classId, conf) = classify(frame);
}
// Builds the segmentation model from files on disk; all work is delegated to
// the Model base-class constructor, which receives both paths unchanged.
SegmentationModel::SegmentationModel(const String& model, const String& config)
    : Model(model, config) {}
// Wraps an already constructed network; the Net is forwarded to the Model base class.
SegmentationModel::SegmentationModel(const Net& network) : Model(network) {}
// Runs the network on `frame` and writes a per-pixel class-id map into `mask`.
//
// The single network output is treated as a 4-D float blob whose dimension 1
// holds the per-class scores and whose dimensions 2 and 3 are the spatial rows
// and columns. `mask` is (re)allocated as a rows x cols CV_8U matrix and every
// element receives the index of the channel with the highest score at that
// pixel; on ties the lowest channel index is kept because the comparison is strict.
//
// Raises (via CV_Assert) when the network does not produce exactly one output,
// when that output is not a continuous 4-D CV_32F blob, or when it has more
// channels than an 8-bit class id can represent.
void SegmentationModel::segment(InputArray frame, OutputArray mask)
{
    std::vector<Mat> outs;
    impl->predict(*this, frame.getMat(), outs);
    CV_Assert(outs.size() == 1);
    Mat score = outs[0];

    // The code below indexes size[1..3] and reads raw floats, so validate the
    // blob layout up front instead of reading out of bounds on unexpected output.
    CV_Assert(score.dims == 4);
    CV_Assert(score.type() == CV_32F);
    // The running-maximum header below uses the default (packed) row step.
    CV_Assert(score.isContinuous());

    const int chns = score.size[1];
    const int rows = score.size[2];
    const int cols = score.size[3];

    // Class ids are stored as 8-bit values; more than 256 channels would wrap silently.
    CV_Assert(chns <= 256);

    mask.create(rows, cols, CV_8U);
    Mat classIds = mask.getMat();
    classIds.setTo(0);

    // Header over the first channel plane of `score`: it starts out holding the
    // channel-0 scores and is updated in place as the running per-pixel maximum,
    // i.e. the contents of `score` are overwritten by this function.
    Mat maxVal(rows, cols, CV_32F, score.data);

    // Channel 0 is the initial best guess (classIds is zero-filled), so start at 1.
    for (int ch = 1; ch < chns; ch++)
    {
        for (int row = 0; row < rows; row++)
        {
            const float *ptrScore = score.ptr<float>(0, ch, row);
            uint8_t *ptrMaxCl = classIds.ptr<uint8_t>(row);
            float *ptrMaxVal = maxVal.ptr<float>(row);
            for (int col = 0; col < cols; col++)
            {
                if (ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col];
                    ptrMaxCl[col] = static_cast<uint8_t>(ch);
                }
            }
        }
    }
}
// Builds the detection model from files on disk; both paths are forwarded
// unchanged to the Model base-class constructor.
DetectionModel::DetectionModel(const String& model, const String& config)
    : Model(model, config) {}

@ -69,6 +69,25 @@ public:
EXPECT_EQ(prediction.first, ref.first);
ASSERT_NEAR(prediction.second, ref.second, norm);
}
void testSegmentationModel(const std::string& weights_file, const std::string& config_file,
const std::string& inImgPath, const std::string& outImgPath,
float norm, const Size& size = {-1, -1}, Scalar mean = Scalar(),
double scale = 1.0, bool swapRB = false, bool crop = false)
{
checkBackend();
Mat frame = imread(inImgPath);
Mat mask;
Mat exp = imread(outImgPath, 0);
SegmentationModel model(weights_file, config_file);
model.setInputSize(size).setInputMean(mean).setInputScale(scale)
.setInputSwapRB(swapRB).setInputCrop(crop);
model.segment(frame, mask);
normAssert(mask, exp, "", norm, norm);
}
};
TEST_P(Test_Model, Classify)
@ -202,6 +221,22 @@ TEST_P(Test_Model, DetectionMobilenetSSD)
scoreDiff, iouDiff, confThreshold, nmsThreshold, size, mean, scale);
}
// Checks SegmentationModel::segment() on dog416.png against a stored reference mask,
// using the fcn8s-heavy-pascal model files.
TEST_P(Test_Model, Segmentation)
{
    std::string inp = _tf("dog416.png");
    // Binary weights and text config, each assigned to the matching variable
    // (the model constructor accepts them in either order).
    std::string weights_file = _tf("fcn8s-heavy-pascal.caffemodel");
    std::string config_file = _tf("fcn8s-heavy-pascal.prototxt");
    std::string expected = _tf("segmentation_exp.png");

    Size size{128, 128};
    float norm = 0;  // zero tolerance passed to the mask comparison
    double scale = 1.0;
    Scalar mean = Scalar();
    bool swapRB = false;

    testSegmentationModel(weights_file, config_file, inp, expected, norm, size, mean, scale, swapRB);
}
// Instantiate every Test_Model case once per value produced by
// dnnBackendsAndTargets(); the empty first argument means no instantiation-name prefix.
INSTANTIATE_TEST_CASE_P(/**/, Test_Model, dnnBackendsAndTargets());
}} // namespace

Loading…
Cancel
Save