diff --git a/samples/dnn/ssd_mobilenet_object_detection.cpp b/samples/dnn/ssd_mobilenet_object_detection.cpp index d7b2cbf8be..e04f1c3e5e 100644 --- a/samples/dnn/ssd_mobilenet_object_detection.cpp +++ b/samples/dnn/ssd_mobilenet_object_detection.cpp @@ -13,7 +13,6 @@ using namespace std; const size_t inWidth = 300; const size_t inHeight = 300; -const float WHRatio = inWidth / (float)inHeight; const float inScaleFactor = 0.007843f; const float meanVal = 127.5; const char* classNames[] = {"background", @@ -23,13 +22,6 @@ const char* classNames[] = {"background", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"}; -const char* about = "This sample uses MobileNet Single-Shot Detector " - "(https://arxiv.org/abs/1704.04861) " - "to detect objects on camera/video/image.\n" - ".caffemodel model's file is available here: " - "https://github.com/chuanqi305/MobileNet-SSD\n" - "Default network is 300x300 and 20-classes VOC.\n"; - const char* params = "{ help | false | print usage }" "{ proto | MobileNetSSD_deploy.prototxt | model configuration }" @@ -44,16 +36,22 @@ const char* params int main(int argc, char** argv) { CommandLineParser parser(argc, argv, params); - - if (parser.get("help")) + parser.about("This sample uses MobileNet Single-Shot Detector " + "(https://arxiv.org/abs/1704.04861) " + "to detect objects on camera/video/image.\n" + ".caffemodel model's file is available here: " + "https://github.com/chuanqi305/MobileNet-SSD\n" + "Default network is 300x300 and 20-classes VOC.\n"); + + if (parser.get("help") || argc == 1) { - cout << about << endl; parser.printMessage(); return 0; } String modelConfiguration = parser.get("proto"); String modelBinary = parser.get("model"); + CV_Assert(!modelConfiguration.empty() && !modelBinary.empty()); //! [Initialize network] dnn::Net net = readNetFromCaffe(modelConfiguration, modelBinary); @@ -75,7 +73,7 @@ int main(int argc, char** argv) } VideoCapture cap; - if (parser.get("video").empty()) + if (!parser.has("video")) { int cameraDevice = parser.get("camera_device"); cap = VideoCapture(cameraDevice); @@ -95,32 +93,16 @@ int main(int argc, char** argv) } } - Size inVideoSize; - inVideoSize = Size((int) cap.get(CV_CAP_PROP_FRAME_WIDTH), //Acquire input size - (int) cap.get(CV_CAP_PROP_FRAME_HEIGHT)); - - Size cropSize; - if (inVideoSize.width / (float)inVideoSize.height > WHRatio) - { - cropSize = Size(static_cast(inVideoSize.height * WHRatio), - inVideoSize.height); - } - else - { - cropSize = Size(inVideoSize.width, - static_cast(inVideoSize.width / WHRatio)); - } - - Rect crop(Point((inVideoSize.width - cropSize.width) / 2, - (inVideoSize.height - cropSize.height) / 2), - cropSize); + //Acquire input size + Size inVideoSize((int) cap.get(CV_CAP_PROP_FRAME_WIDTH), + (int) cap.get(CV_CAP_PROP_FRAME_HEIGHT)); double fps = cap.get(CV_CAP_PROP_FPS); int fourcc = static_cast(cap.get(CV_CAP_PROP_FOURCC)); VideoWriter outputVideo; outputVideo.open(parser.get("out") , (fourcc != 0 ? fourcc : VideoWriter::fourcc('M','J','P','G')), - (fps != 0 ? fps : 10.0), cropSize, true); + (fps != 0 ? fps : 10.0), inVideoSize, true); for(;;) { @@ -138,15 +120,17 @@ int main(int argc, char** argv) //! [Prepare blob] Mat inputBlob = blobFromImage(frame, inScaleFactor, - Size(inWidth, inHeight), meanVal, false); //Convert Mat to batch of images + Size(inWidth, inHeight), + Scalar(meanVal, meanVal, meanVal), + false, false); //Convert Mat to batch of images //! [Prepare blob] //! [Set input blob] - net.setInput(inputBlob, "data"); //set the network input + net.setInput(inputBlob); //set the network input //! [Set input blob] //! [Make forward pass] - Mat detection = net.forward("detection_out"); //compute output + Mat detection = net.forward(); //compute output //! [Make forward pass] vector layersTimings; @@ -155,13 +139,10 @@ int main(int argc, char** argv) Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr()); - frame = frame(crop); - - ostringstream ss; if (!outputVideo.isOpened()) { - ss << "FPS: " << 1000/time << " ; time: " << time << " ms"; - putText(frame, ss.str(), Point(20,20), 0, 0.5, Scalar(0,0,255)); + putText(frame, format("FPS: %.2f ; time: %.2f ms", 1000.f/time, time), + Point(20,20), 0, 0.5, Scalar(0,0,255)); } else cout << "Inference time, ms: " << time << endl; @@ -175,27 +156,20 @@ int main(int argc, char** argv) { size_t objectClass = (size_t)(detectionMat.at(i, 1)); - int xLeftBottom = static_cast(detectionMat.at(i, 3) * frame.cols); - int yLeftBottom = static_cast(detectionMat.at(i, 4) * frame.rows); - int xRightTop = static_cast(detectionMat.at(i, 5) * frame.cols); - int yRightTop = static_cast(detectionMat.at(i, 6) * frame.rows); - - ss.str(""); - ss << confidence; - String conf(ss.str()); - - Rect object((int)xLeftBottom, (int)yLeftBottom, - (int)(xRightTop - xLeftBottom), - (int)(yRightTop - yLeftBottom)); + int left = static_cast(detectionMat.at(i, 3) * frame.cols); + int top = static_cast(detectionMat.at(i, 4) * frame.rows); + int right = static_cast(detectionMat.at(i, 5) * frame.cols); + int bottom = static_cast(detectionMat.at(i, 6) * frame.rows); - rectangle(frame, object, Scalar(0, 255, 0)); - String label = String(classNames[objectClass]) + ": " + conf; + rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 255, 0)); + String label = format("%s: %.2f", classNames[objectClass], confidence); int baseLine = 0; Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); - rectangle(frame, Rect(Point(xLeftBottom, yLeftBottom - labelSize.height), - Size(labelSize.width, labelSize.height + baseLine)), + top = max(top, labelSize.height); + rectangle(frame, Point(left, top - labelSize.height), + Point(left + labelSize.width, top + baseLine), Scalar(255, 255, 255), CV_FILLED); - putText(frame, label, Point(xLeftBottom, yLeftBottom), + putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0)); } }