1. 矫正矩阵

2. 添加一个启动脚本
chiebot
captainfffsama 2 years ago
parent c6f6a89d8a
commit e72b217c1e
  1. 63
      src/grpc/base_cfg.py
  2. 72
      src/grpc/loftr_worker.py
  3. 7
      src/grpc/server.py
  4. 62
      start_grpc.py

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:40:37
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:42:41
@FilePath: /LoFTR/src/grpc/base_cfg.py
@Description:
'''
import json
import yaml
# Default configuration; user config files are recursively merged on top of
# these values via ``merge_param``.
param = {
    'grpc': {
        'host': '127.0.0.1',
        'port': '8001',
        'max_workers': 10,
        # gRPC message size caps: 100 MiB each direction.
        'max_send_message_length': 100 * 1024 * 1024,
        'max_receive_message_length': 100 * 1024 * 1024,
    },
    'loftr': {
        'ckpt_path': '',
        'device': 'cuda:0',
        'thr': 0.5,
        'ransc_method': 'USAC_MAGSAC',
        'ransc_thr': 3,
        'ransc_max_iter': 2000,
    },
}
def _update(dic1: dict, dic2: dict):
"""使用dic2 来递归更新 dic1
# NOTE:
1. dic1 本体是会被更改的!!!
2. python 本身没有做尾递归优化的,dict深度超大时候可能爆栈
"""
for k, v in dic2.items():
if k.endswith('args') and v is None:
dic2[k] = {}
if k in dic1:
if isinstance(v, dict) and isinstance(dic1[k], dict):
_update(dic1[k], dic2[k])
else:
dic1[k] = dic2[k]
else:
dic1[k] = dic2[k]
def _merge_yaml(yaml_path: str):
    """Load a YAML file and recursively merge it into the global ``param``."""
    global param
    with open(yaml_path, 'r') as fr:
        loaded = yaml.load(fr, yaml.FullLoader)
    _update(param, loaded)
def _merge_json(json_path: str):
    """Load a JSON file and recursively merge it into the global ``param``."""
    global param
    with open(json_path, 'r') as fr:
        loaded = json.load(fr)
    _update(param, loaded)
def merge_param(file_path: str):
    """Update the base settings ``param`` from a user-supplied config file.

    The handler is chosen by file extension: ``.yaml``/``.yml`` dispatch to
    ``_merge_yaml``, ``.json`` to ``_merge_json``.

    Args:
        file_path: path to the config file.

    Raises:
        ValueError: if the extension has no matching ``_merge_*`` handler.
    """
    # Normalise case so '.YAML' works, and map the common '.yml' suffix to
    # the '_merge_yaml' handler (the original rejected .yml files).
    cfg_ext = file_path.rsplit('.', 1)[-1].lower()
    cfg_ext = {'yml': 'yaml'}.get(cfg_ext, cfg_ext)
    func_name = '_merge_' + cfg_ext
    if func_name not in globals():
        raise ValueError('{} is not support'.format(cfg_ext))
    globals()[func_name](file_path)

@ -12,9 +12,10 @@ import numpy as np
from src.loftr import LoFTR, default_cfg from src.loftr import LoFTR, default_cfg
DebugInfo=namedtuple("DebugInfo", DebugInfo = namedtuple(
["kp0_fake_match","kp1_fake_match", "DebugInfo",
"kp0_true_match","kp1_true_match"]) ["kp0_fake_match", "kp1_fake_match", "kp0_true_match", "kp1_true_match"])
class LoFTRWorker(object): class LoFTRWorker(object):
@ -32,20 +33,27 @@ class LoFTRWorker(object):
device = 'cpu' device = 'cpu'
print("ERROR: cuda can not use, will use cpu") print("ERROR: cuda can not use, will use cpu")
self.model = self.model.eval().to(device) self.model = self.model.eval().to(device)
self.thr=thr self.thr = thr
self.ransc_method = getattr(cv2,ransc_method) self.ransc_method = getattr(cv2, ransc_method)
self.ransc_thr=ransc_thr self.ransc_thr = ransc_thr
self.ransc_max_iter=ransc_max_iter self.ransc_max_iter = ransc_max_iter
def _img2gray(self, img): def _imgdeal(self, img):
if len(img.shape) == 3 and img.shape[-1] == 3: if len(img.shape) == 3 and img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img oh, ow = img.shape[:2]
img = cv2.resize(img, (640, 480))
h, w = img.shape[:2]
fix_matrix = np.array([[w / ow, 0, 0], [0, h / oh, 0], [0, 0, 1]])
return img, fix_matrix
def _fix_H(self, fm0, fm1, H):
return np.linalg.inv(fm0) @ H @ fm1
def __call__(self, img0, img1,debug=""): def __call__(self, img0, img1, debug=""):
img0 = self._img1gray(img0) img0, fm0 = self._imgdeal(img0)
img1 = self._img1gray(img1) img1, fm1 = self._imgdeal(img1)
img0 = torch.from_numpy(img0)[None][None].cuda() / 255. img0 = torch.from_numpy(img0)[None][None].cuda() / 255.
img1 = torch.from_numpy(img1)[None][None].cuda() / 255. img1 = torch.from_numpy(img1)[None][None].cuda() / 255.
@ -56,33 +64,33 @@ class LoFTRWorker(object):
mkpts1 = batch['mkpts1_f'].cpu().numpy() mkpts1 = batch['mkpts1_f'].cpu().numpy()
mconf = batch['mconf'].cpu().numpy() mconf = batch['mconf'].cpu().numpy()
idx=np.where(mconf>self.thr) idx = np.where(mconf > self.thr)
mconf=mconf[idx] mconf = mconf[idx]
mkpts0=mkpts0[idx] mkpts0 = mkpts0[idx]
mkpts1=mkpts1[idx] mkpts1 = mkpts1[idx]
debug_info=None debug_info = None
H = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)
if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4: if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4:
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], return self._fix_H(fm0, fm1, H), False, debug_info
dtype=np.float), False,debug_info
H, Mask = cv2.findHomography(mkpts0[:, :2], H, Mask = cv2.findHomography(mkpts0[:, :2],
mkpts1[:, :2], mkpts1[:, :2],
self.ransc_method, self.ransc_method,
self.ransc_thr, self.ransc_thr,
maxIters=self.ransc_max_iter) maxIters=self.ransc_max_iter)
Mask=np.squeeze(Mask) Mask = np.squeeze(Mask)
if debug: if debug:
kp0_true_matched=mkpts0[Mask.astype(bool),:2] kp0_true_matched = mkpts0[Mask.astype(bool), :2]
kp1_true_matched=mkpts1[Mask.astype(bool),:2] kp1_true_matched = mkpts1[Mask.astype(bool), :2]
kp0_fake_matched=mkpts0[~Mask.astype(bool),:2] kp0_fake_matched = mkpts0[~Mask.astype(bool), :2]
kp1_fake_matched=mkpts1[~Mask.astype(bool),:2] kp1_fake_matched = mkpts1[~Mask.astype(bool), :2]
debug_info=DebugInfo(kp0_fake_matched,kp1_fake_matched,kp0_true_matched,kp1_true_matched) debug_info = DebugInfo(kp0_fake_matched, kp1_fake_matched,
kp0_true_matched, kp1_true_matched)
if H is None: if H is None:
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], return self._fix_H(fm0, fm1, H), False, debug_info
dtype=np.float), False,debug_info
else: else:
return H, True,debug_info return self._fix_H(fm0, fm1, H), True, debug_info

@ -2,12 +2,14 @@
@Author: captainfffsama @Author: captainfffsama
@Date: 2023-02-02 15:59:46 @Date: 2023-02-02 15:59:46
@LastEditors: captainfffsama tuanzhangsama@outlook.com @LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:08:41 @LastEditTime: 2023-02-02 16:43:55
@FilePath: /LoFTR/src/grpc/server.py @FilePath: /LoFTR/src/grpc/server.py
@Description: @Description:
''' '''
import numpy as np import numpy as np
from src.loftr import default_cfg
from . import loftr_pb2 from . import loftr_pb2
from .loftr_pb2_grpc import LoftrServicer from .loftr_pb2_grpc import LoftrServicer
@ -17,7 +19,6 @@ from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img
class LoFTRServer(LoftrServicer): class LoFTRServer(LoftrServicer):
def __init__(self, def __init__(self,
config,
ckpt_path, ckpt_path,
device="cuda:0", device="cuda:0",
thr=0.5, thr=0.5,
@ -28,7 +29,7 @@ class LoFTRServer(LoftrServicer):
**kwargs): **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.worker = LoFTRWorker( self.worker = LoFTRWorker(
config, default_cfg,
ckpt_path, ckpt_path,
device, device,
thr, thr,

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:38:45
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:48:12
@FilePath: /LoFTR/start_grpc.py
@Description:
'''
from concurrent import futures
import sys
from pprint import pprint
import os
import yaml
import grpc
from src.grpc.loftr_pb2_grpc import add_LoftrServicer_to_server
import src.grpc.base_cfg as cfg
from src.grpc.server import LoFTRServer
def start_server(config):
    """Merge the user config into the defaults and serve LoFTR over gRPC.

    Blocks forever once the server is up (``wait_for_termination``).

    Args:
        config: path to a YAML/JSON config file; must exist.

    Raises:
        FileNotFoundError: if ``config`` does not point to an existing file.
        ValueError: if the config file extension is unsupported.
    """
    if not os.path.exists(config):
        # BUG FIX: the original raised FileExistsError, whose meaning is the
        # opposite ("file already exists"); a missing path is FileNotFoundError.
        raise FileNotFoundError('{} 不存在'.format(config))
    cfg.merge_param(config)
    args_dict: dict = cfg.param
    pprint(args_dict)
    grpc_args = args_dict['grpc']
    model_args = args_dict['loftr']
    # Cap gRPC message sizes at 100 MiB in both directions.
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=grpc_args['max_workers']),
        options=[('grpc.max_send_message_length',
                  grpc_args['max_send_message_length']),
                 ('grpc.max_receive_message_length',
                  grpc_args['max_receive_message_length'])])
    loftr_server = LoFTRServer(**model_args)
    add_LoftrServicer_to_server(loftr_server, server)
    server.add_insecure_port("{}:{}".format(grpc_args['host'],
                                            grpc_args['port']))
    server.start()
    server.wait_for_termination()
def main(args=None):
    """Parse CLI options and launch the gRPC server when a config is given.

    Args:
        args: optional argv list for argparse (defaults to ``sys.argv[1:]``).
    """
    import argparse

    parser = argparse.ArgumentParser(description="grpc调用loftr,需要配置文件")
    parser.add_argument("-c", "--config", type=str, default="", help="配置文件地址")
    opts = parser.parse_args(args)
    # Guard clause: with no config path there is nothing to start.
    if not opts.config:
        return
    start_server(opts.config)
if __name__ == "__main__":
    # Exit code reported when startup fails for any reason.
    rc = 1
    try:
        main()
    except Exception as e:
        # Top-level boundary: report the error and exit non-zero;
        # a successful main() falls through and exits 0 normally.
        print('Error: %s' % e, file=sys.stderr)
        sys.exit(rc)
Loading…
Cancel
Save