diff --git a/.dev_scripts/benchmark_test_image.py b/.dev_scripts/benchmark_test_image.py
index b71ffe1ed..cf37382f7 100644
--- a/.dev_scripts/benchmark_test_image.py
+++ b/.dev_scripts/benchmark_test_image.py
@@ -14,7 +14,13 @@ def parse_args():
     parser.add_argument('checkpoint_root', help='Checkpoint file root path')
     parser.add_argument('--img', default='demo/demo.jpg', help='Image file')
     parser.add_argument('--aug', action='store_true', help='aug test')
+    parser.add_argument('--model-name', help='model name to inference')
     parser.add_argument('--show', action='store_true', help='show results')
+    parser.add_argument(
+        '--wait-time',
+        type=float,
+        default=1,
+        help='the interval of show (s), 0 is block')
     parser.add_argument(
         '--device', default='cuda:0', help='Device used for inference')
     parser.add_argument(
@@ -23,9 +29,54 @@ def parse_args():
     return args
 
 
+def inference_model(config_name, checkpoint, args, logger=None):
+    cfg = Config.fromfile(config_name)
+    if args.aug:
+        if 'flip' in cfg.data.test.pipeline[1]:
+            cfg.data.test.pipeline[1].flip = True
+        else:
+            if logger is not None:
+                logger.error(f'{config_name}: unable to start aug test')
+            else:
+                print(f'{config_name}: unable to start aug test', flush=True)
+
+    model = init_detector(cfg, checkpoint, device=args.device)
+    # test a single image
+    result = inference_detector(model, args.img)
+
+    # show the results
+    if args.show:
+        show_result_pyplot(
+            model,
+            args.img,
+            result,
+            score_thr=args.score_thr,
+            wait_time=args.wait_time)
+    return result
+
+
 # Sample test whether the inference code is correct
 def main(args):
     config = Config.fromfile(args.config)
+
+    # test single model
+    if args.model_name:
+        if args.model_name in config:
+            model_infos = config[args.model_name]
+            if not isinstance(model_infos, list):
+                model_infos = [model_infos]
+            model_info = model_infos[0]
+            config_name = model_info['config'].strip()
+            print(f'processing: {config_name}', flush=True)
+            checkpoint = osp.join(args.checkpoint_root,
+                                  model_info['checkpoint'].strip())
+            # build the model from a config file and a checkpoint file
+            inference_model(config_name, checkpoint, args)
+            return
+        else:
+            raise RuntimeError('model name input error.')
+
+    # test all model
     logger = get_root_logger(
         log_file='benchmark_test_image.log', log_level=logging.ERROR)
 
@@ -40,25 +91,7 @@ def main(args):
             model_info['checkpoint'].strip())
         try:
             # build the model from a config file and a checkpoint file
-            cfg = Config.fromfile(config_name)
-            if args.aug:
-                if 'flip' in cfg.data.test.pipeline[1]:
-                    cfg.data.test.pipeline[1].flip = True
-                else:
-                    logger.error(
-                        f'{config_name} " : Unable to start aug test')
-
-            model = init_detector(cfg, checkpoint, device=args.device)
-            # test a single image
-            result = inference_detector(model, args.img)
-            # show the results
-            if args.show:
-                show_result_pyplot(
-                    model,
-                    args.img,
-                    result,
-                    score_thr=args.score_thr,
-                    wait_time=1)
+            inference_model(config_name, checkpoint, args, logger)
         except Exception as e:
             logger.error(f'{config_name} " : {repr(e)}')
 
diff --git a/.dev_scripts/test_benchamrk.sh b/.dev_scripts/test_benchmark.sh
similarity index 100%
rename from .dev_scripts/test_benchamrk.sh
rename to .dev_scripts/test_benchmark.sh