diff --git a/samples/python/tracker.py b/samples/python/tracker.py index 753e166ad8..3b04c57e8a 100644 --- a/samples/python/tracker.py +++ b/samples/python/tracker.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - ''' Tracker demo @@ -36,43 +35,49 @@ class App(object): def __init__(self, args): self.args = args - - def initializeTracker(self, image, trackerAlgorithm): + self.trackerAlgorithm = args.tracker_algo + self.tracker = self.createTracker() + + def createTracker(self): + if self.trackerAlgorithm == 'mil': + tracker = cv.TrackerMIL_create() + elif self.trackerAlgorithm == 'goturn': + params = cv.TrackerGOTURN_Params() + params.modelTxt = self.args.goturn + params.modelBin = self.args.goturn_model + tracker = cv.TrackerGOTURN_create(params) + elif self.trackerAlgorithm == 'dasiamrpn': + params = cv.TrackerDaSiamRPN_Params() + params.model = self.args.dasiamrpn_net + params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1 + params.kernel_r1 = self.args.dasiamrpn_kernel_r1 + tracker = cv.TrackerDaSiamRPN_create(params) + else: + sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(self.trackerAlgorithm)) + return tracker + + def initializeTracker(self, image): while True: - if trackerAlgorithm == 'mil': - tracker = cv.TrackerMIL_create() - elif trackerAlgorithm == 'goturn': - params = cv.TrackerGOTURN_Params() - params.modelTxt = self.args.goturn - params.modelBin = self.args.goturn_model - tracker = cv.TrackerGOTURN_create(params) - elif trackerAlgorithm == 'dasiamrpn': - params = cv.TrackerDaSiamRPN_Params() - params.model = self.args.dasiamrpn_net - params.kernel_cls1 = self.args.dasiamrpn_kernel_cls1 - params.kernel_r1 = self.args.dasiamrpn_kernel_r1 - tracker = cv.TrackerDaSiamRPN_create(params) - else: - sys.exit("Tracker {} is not recognized. Please use one of three available: mil, goturn, dasiamrpn.".format(trackerAlgorithm)) - print('==> Select object ROI for tracker ...') bbox = cv.selectROI('tracking', image) print('ROI: {}'.format(bbox)) + if bbox[2] <= 0 or bbox[3] <= 0: + sys.exit("ROI selection cancelled. Exiting...") try: - tracker.init(image, bbox) + self.tracker.init(image, bbox) except Exception as e: print('Unable to initialize tracker with requested bounding box. Is there any object?') print(e) print('Try again ...') continue - return tracker + return def run(self): videoPath = self.args.input - trackerAlgorithm = self.args.tracker_algo - camera = create_capture(videoPath, presets['cube']) + print('Using video: {}'.format(videoPath)) + camera = create_capture(cv.samples.findFileOrKeep(videoPath), presets['cube']) if not camera.isOpened(): sys.exit("Can't open video stream: {}".format(videoPath)) @@ -82,7 +87,7 @@ class App(object): assert image is not None cv.namedWindow('tracking') - tracker = self.initializeTracker(image, trackerAlgorithm) + self.initializeTracker(image) print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...") @@ -92,7 +97,7 @@ class App(object): print("Can't read frame") break - ok, newbox = tracker.update(image) + ok, newbox = self.tracker.update(image) #print(ok, newbox) if ok: @@ -101,7 +106,7 @@ class App(object): cv.imshow("tracking", image) k = cv.waitKey(1) if k == 32: # SPACE - tracker = self.initializeTracker(image) + self.initializeTracker(image) if k == 27: # ESC break @@ -112,22 +117,13 @@ if __name__ == '__main__': print(__doc__) parser = argparse.ArgumentParser(description="Run tracker") parser.add_argument("--input", type=str, default="vtest.avi", help="Path to video source") - parser.add_argument("--tracker_algo", type=str, default="mil", help="One of three available tracking algorithms: mil, goturn, dasiamrpn") + parser.add_argument("--tracker_algo", type=str, default="mil", help="One of available tracking algorithms: mil, goturn, dasiamrpn") parser.add_argument("--goturn", type=str, default="goturn.prototxt", help="Path to GOTURN architecture") parser.add_argument("--goturn_model", type=str, default="goturn.caffemodel", help="Path to GOTERN model") parser.add_argument("--dasiamrpn_net", type=str, default="dasiamrpn_model.onnx", help="Path to onnx model of DaSiamRPN net") parser.add_argument("--dasiamrpn_kernel_r1", type=str, default="dasiamrpn_kernel_r1.onnx", help="Path to onnx model of DaSiamRPN kernel_r1") parser.add_argument("--dasiamrpn_kernel_cls1", type=str, default="dasiamrpn_kernel_cls1.onnx", help="Path to onnx model of DaSiamRPN kernel_cls1") - parser.add_argument("--dasiamrpn_backend", type=int, default=0, help="Choose one of computation backends:\ - 0: automatically (by default),\ - 1: Halide language (http://halide-lang.org/),\ - 2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit),\ - 3: OpenCV implementation") - parser.add_argument("--dasiamrpn_target", type=int, default=0, help="Choose one of target computation devices:\ - 0: CPU target (by default),\ - 1: OpenCL,\ - 2: OpenCL fp16 (half-float precision),\ - 3: VPU") + args = parser.parse_args() App(args).run() cv.destroyAllWindows()