1. 矫正矩阵

2. 添加一个启动脚本
chiebot
captainfffsama 2 years ago
parent c6f6a89d8a
commit e72b217c1e
  1. 63
      src/grpc/base_cfg.py
  2. 34
      src/grpc/loftr_worker.py
  3. 7
      src/grpc/server.py
  4. 62
      start_grpc.py

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:40:37
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:42:41
@FilePath: /LoFTR/src/grpc/base_cfg.py
@Description:
'''
import json
import yaml
param = dict(grpc=dict(host='127.0.0.1',
port='8001',
max_workers=10,
max_send_message_length=100 * 1024 * 1024,
max_receive_message_length=100 * 1024 * 1024),
loftr=dict(ckpt_path="",device="cuda:0",thr=0.5,ransc_method="USAC_MAGSAC",ransc_thr=3,ransc_max_iter=2000,
))
def _update(dic1: dict, dic2: dict):
"""使用dic2 来递归更新 dic1
# NOTE:
1. dic1 本体是会被更改的!!!
2. python 本身没有做尾递归优化的,dict深度超大时候可能爆栈
"""
for k, v in dic2.items():
if k.endswith('args') and v is None:
dic2[k] = {}
if k in dic1:
if isinstance(v, dict) and isinstance(dic1[k], dict):
_update(dic1[k], dic2[k])
else:
dic1[k] = dic2[k]
else:
dic1[k] = dic2[k]
def _merge_yaml(yaml_path: str):
global param
with open(yaml_path, 'r') as fr:
content_dict = yaml.load(fr, yaml.FullLoader)
_update(param, content_dict)
def _merge_json(json_path: str):
global param
with open(json_path, 'r') as fr:
content_dict = json.load(fr)
_update(param, content_dict)
def merge_param(file_path: str):
"""按照用户传入的配置文件更新基本设置
"""
cfg_ext = file_path.split('.')[-1]
func_name = '_merge_' + cfg_ext
if func_name not in globals():
raise ValueError('{} is not support'.format(cfg_ext))
else:
globals()[func_name](file_path)

@ -12,9 +12,10 @@ import numpy as np
from src.loftr import LoFTR, default_cfg
DebugInfo=namedtuple("DebugInfo",
["kp0_fake_match","kp1_fake_match",
"kp0_true_match","kp1_true_match"])
DebugInfo = namedtuple(
"DebugInfo",
["kp0_fake_match", "kp1_fake_match", "kp0_true_match", "kp1_true_match"])
class LoFTRWorker(object):
@ -37,15 +38,22 @@ class LoFTRWorker(object):
self.ransc_thr = ransc_thr
self.ransc_max_iter = ransc_max_iter
def _img2gray(self, img):
def _imgdeal(self, img):
if len(img.shape) == 3 and img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img
oh, ow = img.shape[:2]
img = cv2.resize(img, (640, 480))
h, w = img.shape[:2]
fix_matrix = np.array([[w / ow, 0, 0], [0, h / oh, 0], [0, 0, 1]])
return img, fix_matrix
def _fix_H(self, fm0, fm1, H):
return np.linalg.inv(fm0) @ H @ fm1
def __call__(self, img0, img1, debug=""):
img0 = self._img1gray(img0)
img1 = self._img1gray(img1)
img0, fm0 = self._imgdeal(img0)
img1, fm1 = self._imgdeal(img1)
img0 = torch.from_numpy(img0)[None][None].cuda() / 255.
img1 = torch.from_numpy(img1)[None][None].cuda() / 255.
@ -62,9 +70,9 @@ class LoFTRWorker(object):
mkpts1 = mkpts1[idx]
debug_info = None
H = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)
if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4:
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=np.float), False,debug_info
return self._fix_H(fm0, fm1, H), False, debug_info
H, Mask = cv2.findHomography(mkpts0[:, :2],
mkpts1[:, :2],
@ -79,10 +87,10 @@ class LoFTRWorker(object):
kp0_fake_matched = mkpts0[~Mask.astype(bool), :2]
kp1_fake_matched = mkpts1[~Mask.astype(bool), :2]
debug_info=DebugInfo(kp0_fake_matched,kp1_fake_matched,kp0_true_matched,kp1_true_matched)
debug_info = DebugInfo(kp0_fake_matched, kp1_fake_matched,
kp0_true_matched, kp1_true_matched)
if H is None:
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=np.float), False,debug_info
return self._fix_H(fm0, fm1, H), False, debug_info
else:
return H, True,debug_info
return self._fix_H(fm0, fm1, H), True, debug_info

@ -2,12 +2,14 @@
@Author: captainfffsama
@Date: 2023-02-02 15:59:46
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:08:41
@LastEditTime: 2023-02-02 16:43:55
@FilePath: /LoFTR/src/grpc/server.py
@Description:
'''
import numpy as np
from src.loftr import default_cfg
from . import loftr_pb2
from .loftr_pb2_grpc import LoftrServicer
@ -17,7 +19,6 @@ from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img
class LoFTRServer(LoftrServicer):
def __init__(self,
config,
ckpt_path,
device="cuda:0",
thr=0.5,
@ -28,7 +29,7 @@ class LoFTRServer(LoftrServicer):
**kwargs):
super().__init__(*args, **kwargs)
self.worker = LoFTRWorker(
config,
default_cfg,
ckpt_path,
device,
thr,

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:38:45
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:48:12
@FilePath: /LoFTR/start_grpc.py
@Description:
'''
from concurrent import futures
import sys
from pprint import pprint
import os
import yaml
import grpc
from src.grpc.loftr_pb2_grpc import add_LoftrServicer_to_server
import src.grpc.base_cfg as cfg
from src.grpc.server import LoFTRServer
def start_server(config):
if not os.path.exists(config):
raise FileExistsError('{} 不存在'.format(config))
cfg.merge_param(config)
args_dict: dict = cfg.param
pprint(args_dict)
grpc_args = args_dict['grpc']
model_args = args_dict['loftr']
# 最大限制为100M
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=grpc_args['max_workers']),
options=[('grpc.max_send_message_length',
grpc_args['max_send_message_length']),
('grpc.max_receive_message_length',
grpc_args['max_receive_message_length'])])
loftr_server =LoFTRServer(**model_args)
add_LoftrServicer_to_server(loftr_server,server)
server.add_insecure_port("{}:{}".format(grpc_args['host'],
grpc_args['port']))
server.start()
server.wait_for_termination()
def main(args=None):
import argparse
parser = argparse.ArgumentParser(description="grpc调用loftr,需要配置文件")
parser.add_argument("-c", "--config", type=str, default="", help="配置文件地址")
options = parser.parse_args(args)
if options.config:
start_server(options.config)
if __name__ == "__main__":
rc = 1
try:
main()
except Exception as e:
print('Error: %s' % e, file=sys.stderr)
sys.exit(rc)
Loading…
Cancel
Save