From e72b217c1eed31637d62ea3ae8f13915d46c27e4 Mon Sep 17 00:00:00 2001 From: captainfffsama Date: Thu, 2 Feb 2023 17:17:49 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E7=9F=AB=E6=AD=A3=E7=9F=A9=E9=98=B5=202.?= =?UTF-8?q?=20=E6=B7=BB=E5=8A=A0=E4=B8=80=E4=B8=AA=E5=90=AF=E5=8A=A8?= =?UTF-8?q?=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/grpc/base_cfg.py | 63 +++++++++++++++++++++++++++++++++++ src/grpc/loftr_worker.py | 72 ++++++++++++++++++++++------------------ src/grpc/server.py | 7 ++-- start_grpc.py | 62 ++++++++++++++++++++++++++++++++++ 4 files changed, 169 insertions(+), 35 deletions(-) create mode 100644 src/grpc/base_cfg.py create mode 100644 start_grpc.py diff --git a/src/grpc/base_cfg.py b/src/grpc/base_cfg.py new file mode 100644 index 0000000..29ab1a4 --- /dev/null +++ b/src/grpc/base_cfg.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +''' +@Author: captainfffsama +@Date: 2023-02-02 16:40:37 +@LastEditors: captainfffsama tuanzhangsama@outlook.com +@LastEditTime: 2023-02-02 16:42:41 +@FilePath: /LoFTR/src/grpc/base_cfg.py +@Description: +''' +import json +import yaml + +param = dict(grpc=dict(host='127.0.0.1', + port='8001', + max_workers=10, + max_send_message_length=100 * 1024 * 1024, + max_receive_message_length=100 * 1024 * 1024), + loftr=dict(ckpt_path="",device="cuda:0",thr=0.5,ransc_method="USAC_MAGSAC",ransc_thr=3,ransc_max_iter=2000, +)) + + +def _update(dic1: dict, dic2: dict): + """使用dic2 来递归更新 dic1 + # NOTE: + 1. dic1 本体是会被更改的!!! + 2. python 本身没有做尾递归优化的,dict深度超大时候可能爆栈 + """ + for k, v in dic2.items(): + if k.endswith('args') and v is None: + dic2[k] = {} + if k in dic1: + if isinstance(v, dict) and isinstance(dic1[k], dict): + _update(dic1[k], dic2[k]) + else: + dic1[k] = dic2[k] + else: + dic1[k] = dic2[k] + + +def _merge_yaml(yaml_path: str): + global param + with open(yaml_path, 'r') as fr: + content_dict = yaml.load(fr, yaml.FullLoader) + _update(param, content_dict) + + +def _merge_json(json_path: str): + global param + with open(json_path, 'r') as fr: + content_dict = json.load(fr) + _update(param, content_dict) + + +def merge_param(file_path: str): + """按照用户传入的配置文件更新基本设置 + """ + cfg_ext = file_path.split('.')[-1] + func_name = '_merge_' + cfg_ext + if func_name not in globals(): + raise ValueError('{} is not support'.format(cfg_ext)) + else: + globals()[func_name](file_path) + diff --git a/src/grpc/loftr_worker.py b/src/grpc/loftr_worker.py index d3a0325..7bdd33c 100644 --- a/src/grpc/loftr_worker.py +++ b/src/grpc/loftr_worker.py @@ -12,9 +12,10 @@ import numpy as np from src.loftr import LoFTR, default_cfg -DebugInfo=namedtuple("DebugInfo", - ["kp0_fake_match","kp1_fake_match", - "kp0_true_match","kp1_true_match"]) +DebugInfo = namedtuple( + "DebugInfo", + ["kp0_fake_match", "kp1_fake_match", "kp0_true_match", "kp1_true_match"]) + class LoFTRWorker(object): @@ -32,20 +33,27 @@ class LoFTRWorker(object): device = 'cpu' print("ERROR: cuda can not use, will use cpu") self.model = self.model.eval().to(device) - self.thr=thr - self.ransc_method = getattr(cv2,ransc_method) - self.ransc_thr=ransc_thr - self.ransc_max_iter=ransc_max_iter + self.thr = thr + self.ransc_method = getattr(cv2, ransc_method) + self.ransc_thr = ransc_thr + self.ransc_max_iter = ransc_max_iter - def _img2gray(self, img): + def _imgdeal(self, img): if len(img.shape) == 3 and img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - return img + oh, ow = img.shape[:2] + img = cv2.resize(img, (640, 480)) + h, w = img.shape[:2] + fix_matrix = np.array([[w / ow, 0, 0], [0, h / oh, 0], [0, 0, 1]]) + return img, fix_matrix + + def _fix_H(self, fm0, fm1, H): + return np.linalg.inv(fm0) @ H @ fm1 - def __call__(self, img0, img1,debug=""): - img0 = self._img1gray(img0) - img1 = self._img1gray(img1) + def __call__(self, img0, img1, debug=""): + img0, fm0 = self._imgdeal(img0) + img1, fm1 = self._imgdeal(img1) img0 = torch.from_numpy(img0)[None][None].cuda() / 255. img1 = torch.from_numpy(img1)[None][None].cuda() / 255. @@ -56,33 +64,33 @@ class LoFTRWorker(object): mkpts1 = batch['mkpts1_f'].cpu().numpy() mconf = batch['mconf'].cpu().numpy() - idx=np.where(mconf>self.thr) - mconf=mconf[idx] - mkpts0=mkpts0[idx] - mkpts1=mkpts1[idx] + idx = np.where(mconf > self.thr) + mconf = mconf[idx] + mkpts0 = mkpts0[idx] + mkpts1 = mkpts1[idx] - debug_info=None + debug_info = None + H = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float) if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4: - return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], - dtype=np.float), False,debug_info + return self._fix_H(fm0, fm1, H), False, debug_info H, Mask = cv2.findHomography(mkpts0[:, :2], - mkpts1[:, :2], - self.ransc_method, - self.ransc_thr, - maxIters=self.ransc_max_iter) - Mask=np.squeeze(Mask) + mkpts1[:, :2], + self.ransc_method, + self.ransc_thr, + maxIters=self.ransc_max_iter) + Mask = np.squeeze(Mask) if debug: - kp0_true_matched=mkpts0[Mask.astype(bool),:2] - kp1_true_matched=mkpts1[Mask.astype(bool),:2] - kp0_fake_matched=mkpts0[~Mask.astype(bool),:2] - kp1_fake_matched=mkpts1[~Mask.astype(bool),:2] + kp0_true_matched = mkpts0[Mask.astype(bool), :2] + kp1_true_matched = mkpts1[Mask.astype(bool), :2] + kp0_fake_matched = mkpts0[~Mask.astype(bool), :2] + kp1_fake_matched = mkpts1[~Mask.astype(bool), :2] - debug_info=DebugInfo(kp0_fake_matched,kp1_fake_matched,kp0_true_matched,kp1_true_matched) + debug_info = DebugInfo(kp0_fake_matched, kp1_fake_matched, + kp0_true_matched, kp1_true_matched) if H is None: - return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], - dtype=np.float), False,debug_info + return self._fix_H(fm0, fm1, H), False, debug_info else: - return H, True,debug_info \ No newline at end of file + return self._fix_H(fm0, fm1, H), True, debug_info diff --git a/src/grpc/server.py b/src/grpc/server.py index 7838c7d..3c333dc 100644 --- a/src/grpc/server.py +++ b/src/grpc/server.py @@ -2,12 +2,14 @@ @Author: captainfffsama @Date: 2023-02-02 15:59:46 @LastEditors: captainfffsama tuanzhangsama@outlook.com -@LastEditTime: 2023-02-02 16:08:41 +@LastEditTime: 2023-02-02 16:43:55 @FilePath: /LoFTR/src/grpc/server.py @Description: ''' import numpy as np +from src.loftr import default_cfg + from . import loftr_pb2 from .loftr_pb2_grpc import LoftrServicer @@ -17,7 +19,6 @@ from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img class LoFTRServer(LoftrServicer): def __init__(self, - config, ckpt_path, device="cuda:0", thr=0.5, @@ -28,7 +29,7 @@ class LoFTRServer(LoftrServicer): **kwargs): super().__init__(*args, **kwargs) self.worker = LoFTRWorker( - config, + default_cfg, ckpt_path, device, thr, diff --git a/start_grpc.py b/start_grpc.py new file mode 100644 index 0000000..0c03cea --- /dev/null +++ b/start_grpc.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +''' +@Author: captainfffsama +@Date: 2023-02-02 16:38:45 +@LastEditors: captainfffsama tuanzhangsama@outlook.com +@LastEditTime: 2023-02-02 16:48:12 +@FilePath: /LoFTR/start_grpc.py +@Description: +''' +from concurrent import futures +import sys +from pprint import pprint +import os +import yaml + +import grpc + +from src.grpc.loftr_pb2_grpc import add_LoftrServicer_to_server +import src.grpc.base_cfg as cfg +from src.grpc.server import LoFTRServer + + +def start_server(config): + if not os.path.exists(config): + raise FileExistsError('{} 不存在'.format(config)) + cfg.merge_param(config) + args_dict: dict = cfg.param + pprint(args_dict) + + grpc_args = args_dict['grpc'] + model_args = args_dict['loftr'] + # 最大限制为100M + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=grpc_args['max_workers']), + options=[('grpc.max_send_message_length', + grpc_args['max_send_message_length']), + ('grpc.max_receive_message_length', + grpc_args['max_receive_message_length'])]) + + loftr_server =LoFTRServer(**model_args) + add_LoftrServicer_to_server(loftr_server,server) + server.add_insecure_port("{}:{}".format(grpc_args['host'], + grpc_args['port'])) + server.start() + server.wait_for_termination() + + +def main(args=None): + import argparse + parser = argparse.ArgumentParser(description="grpc调用loftr,需要配置文件") + parser.add_argument("-c", "--config", type=str, default="", help="配置文件地址") + options = parser.parse_args(args) + if options.config: + start_server(options.config) + +if __name__ == "__main__": + rc = 1 + try: + main() + except Exception as e: + print('Error: %s' % e, file=sys.stderr) + sys.exit(rc) \ No newline at end of file