chiebot
captainfffsama 2 years ago
parent e72b217c1e
commit 6a49742e69
  1. .gitignore (1)
  2. src/grpc/__init__.py (9)
  3. src/grpc/base_cfg.py (18)
  4. src/grpc/debug_tools.py (174)
  5. src/grpc/loftr_worker.py (68)
  6. src/grpc/server.py (15)

.gitignore

@@ -6,6 +6,7 @@ __pycache__/
 *.pth
 tmp.*
 */.ipynb_checkpoints/*
+test/
 logs/
 weights/

src/grpc/__init__.py

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-03 10:03:55
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-03 10:03:56
@FilePath: /LoFTR/src/grpc/__init__.py
@Description:
'''

src/grpc/base_cfg.py

@@ -3,7 +3,7 @@
 @Author: captainfffsama
 @Date: 2023-02-02 16:40:37
 @LastEditors: captainfffsama tuanzhangsama@outlook.com
-@LastEditTime: 2023-02-02 16:42:41
+@LastEditTime: 2023-02-03 10:34:31
 @FilePath: /LoFTR/src/grpc/base_cfg.py
 @Description:
 '''
@@ -15,8 +15,19 @@ param = dict(grpc=dict(host='127.0.0.1',
                        max_workers=10,
                        max_send_message_length=100 * 1024 * 1024,
                        max_receive_message_length=100 * 1024 * 1024),
-             loftr=dict(ckpt_path="",device="cuda:0",thr=0.5,ransc_method="USAC_MAGSAC",ransc_thr=3,ransc_max_iter=2000,
-             ))
+             loftr=dict(ckpt_path="",
+                        img_size=(640, 480),
+                        device="cuda:0",
+                        thr=0.5,
+                        ransc_method="USAC_MAGSAC",
+                        ransc_thr=3,
+                        ransc_max_iter=2000,
+                        debug=False,
+                        debug_show_type=(
+                            "vis",
+                            "false",
+                            "true",
+                        )))


 def _update(dic1: dict, dic2: dict):
@@ -60,4 +71,3 @@ def merge_param(file_path: str):
         raise ValueError('{} is not support'.format(cfg_ext))
     else:
         globals()[func_name](file_path)
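The new img_size, debug, and debug_show_type keys make param["loftr"] line up one-to-one with the keyword arguments of LoFTRServer below. A minimal wiring sketch, assuming the server is built by unpacking that sub-dict and that merge_param accepts the override file passed to it; neither call site is shown in this commit, and the override filename is a placeholder:

from src.grpc import base_cfg
from src.grpc.server import LoFTRServer

# hypothetical override file; merge_param dispatches on the file extension
base_cfg.merge_param("override_cfg.yaml")
# ckpt_path, img_size, device, thr, ransc_*, debug, debug_show_type
servicer = LoFTRServer(**base_cfg.param["loftr"])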

src/grpc/debug_tools.py

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-03 09:56:39
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-03 10:12:32
@FilePath: /LoFTR/src/grpc/debug_tools.py
@Description:
'''
import os
from typing import Optional
import math
from copy import deepcopy
import random
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import numpy as np
import torch
from torchvision.utils import make_grid
COLOR_DICT = mcolors.CSS4_COLORS


def get_color():
    # pick a random CSS4 color name for drawing match lines
    return random.choice(list(COLOR_DICT.keys()))


def normalization(data):
    _range = np.max(data) - np.min(data)
    return (data - np.min(data)) / _range * 255


def show_img(img_ori, text: Optional[str] = None, cvreader: bool = True, delay: float = 0):
    img = deepcopy(img_ori)
    if isinstance(img, (list, tuple)):
        img_num = len(img)
        row_n = math.ceil(math.sqrt(img_num))
        col_n = max(math.ceil(img_num / row_n), 1)
        fig, axs = plt.subplots(row_n, col_n, figsize=(15 * row_n, 15 * col_n))
        for idx, img_ in enumerate(img):
            if isinstance(img_, (torch.Tensor, np.ndarray)):
                img_ = show_tensor(img_, cvreader)
            if 2 == len(axs.shape):
                axs[idx % row_n][idx // row_n].imshow(img_)
                axs[idx % row_n][idx // row_n].set_title(str(idx))
            else:
                axs[idx % row_n].imshow(img_)
                axs[idx % row_n].set_title(str(idx))
        if text:
            plt.text(0, 0, text, fontsize=15)
        if delay <= 0:
            plt.show()
        else:
            plt.draw()
            plt.pause(delay)
            plt.close()
    elif isinstance(img, (torch.Tensor, np.ndarray)):
        img = show_tensor(img, cvreader)
        plt.imshow(img)
        if text:
            plt.text(0, 0, text, fontsize=15)
        if delay <= 0:
            plt.show()
        else:
            plt.draw()
            plt.pause(delay)
            plt.close()
    else:
        if hasattr(img, 'show'):
            img.show()


def show_tensor(img, cvreader):
    # batched 4D input: tile into a single grid image
    if len(img.shape) == 4 and img.shape[0] != 1:
        if isinstance(img, np.ndarray):
            img = torch.Tensor(img)
            img = make_grid(img)
        else:
            img = make_grid(img)
        img: np.ndarray = img.detach().cpu().numpy().squeeze()
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy().squeeze()
    if 1 == img.max():
        img = img.astype(float)
    # CHW -> HWC
    if img.shape[0] == 1 or img.shape[0] == 3:
        img = np.transpose(img, (1, 2, 0))
    if img.min() < 0 or img.max() > 255:
        img = normalization(img)
        print("image rescaled to 0-255 for display")
    if img.shape[-1] == 3:
        img = img.astype(np.uint8)
    # cv2 reads BGR; flip to RGB for matplotlib
    if cvreader and len(img.shape) == 3:
        img = img[:, :, ::-1]
    return img


def draw_point(axs, kp: np.ndarray, color: str):
    kp = kp.astype(int)
    axs.plot(kp[:, 0], kp[:, 1], 'o', color=color)


def plot_kp(debug_info, show_flag=('fake', 'true', 'no_match', 'vis'), debug_save=""):
    fig, axs = plt.subplots(2, 1)
    axs[0].set_xticks([])
    axs[0].set_yticks([])
    axs[0].imshow(debug_info.imgA[:, :, ::-1])
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    axs[1].imshow(debug_info.imgB[:, :, ::-1])
    if "fake" in show_flag:
        # matches rejected by RANSAC, drawn in yellow
        draw_point(axs[0], debug_info.pts_info.kp0_fake_match, 'y')
        draw_point(axs[1], debug_info.pts_info.kp1_fake_match, 'y')
        for kp_a, kp_b in zip(debug_info.pts_info.kp0_fake_match,
                              debug_info.pts_info.kp1_fake_match):
            color = get_color()
            con = patches.ConnectionPatch(xyA=kp_a, coordsA=axs[0].transData,
                                          xyB=kp_b, coordsB=axs[1].transData,
                                          arrowstyle="-", shrinkB=5, color=color)
            axs[1].add_artist(con)
    if "true" in show_flag:
        # RANSAC inliers, drawn in green
        draw_point(axs[0], debug_info.pts_info.kp0_true_match, 'g')
        draw_point(axs[1], debug_info.pts_info.kp1_true_match, 'g')
        for kp_a, kp_b in zip(debug_info.pts_info.kp0_true_match,
                              debug_info.pts_info.kp1_true_match):
            color = get_color()
            con = patches.ConnectionPatch(xyA=kp_a, coordsA=axs[0].transData,
                                          xyB=kp_b, coordsB=axs[1].transData,
                                          arrowstyle="-", shrinkB=5, color=color)
            axs[1].add_artist(con)
    if "vis" in show_flag:
        # warp imgA onto imgB with H and blend the two for a visual check
        fig2 = plt.figure()
        axe2 = fig2.gca()
        fix_img = generate_transform_pic(
            (debug_info.imgB.shape[1], debug_info.imgB.shape[0]),
            debug_info.imgA, debug_info.H)
        final_img = cv2.addWeighted(debug_info.imgB, 0.6, fix_img, 0.4, 0)
        axe2.imshow(final_img[:, :, ::-1])
    if isinstance(debug_save, str):
        save_path, ext = os.path.splitext(debug_save)
        fig.savefig(save_path + "nft.jpg", dpi=400, bbox_inches='tight')
        if "vis" in show_flag:
            fig2.savefig(save_path + "vis.jpg", dpi=400, bbox_inches='tight')
        plt.clf()
        plt.close("all")
    else:
        plt.clf()
        plt.show()


def convertpt2matrix(pts):
    # promote 2D points to homogeneous coordinates
    if isinstance(pts, list):
        pts = np.array([x + [1] for x in pts])
    if pts.shape[-1] == 2:
        expend_m = np.array([[1], [1], [1], [1]])
        pts = np.concatenate((pts, expend_m), axis=1)
    return pts


def generate_transform_pic(imgb_shape, imgA, H):
    imgA_warp = cv2.warpPerspective(imgA, H, imgb_shape, flags=cv2.INTER_LINEAR)
    # project imgA's four corners through H and outline them on the warped image
    k_pts = [[0, 0], [0, imgA.shape[0] - 1],
             [imgA.shape[1] - 1, imgA.shape[0] - 1], [imgA.shape[1] - 1, 0]]
    kptm = convertpt2matrix(k_pts)
    warp_kpm = np.matmul(H, kptm.T)
    warp_kpm = (warp_kpm / warp_kpm[2, :])[:2, :].astype(int).T.reshape((-1, 1, 2))
    # FIXME: possible issue here, out-of-bounds corners are not clipped
    cv2.polylines(imgA_warp, [warp_kpm], True, (0, 0, 255), thickness=5)
    return imgA_warp
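A quick sanity check of the helpers above; the random images and the two-second display window are made-up values for illustration:

import numpy as np
from src.grpc.debug_tools import show_img

# two fake BGR images; show_img also accepts tensors or a single array
img_a = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
img_b = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
show_img([img_a, img_b], text="two random images", delay=2)  # auto-closes after ~2 s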

src/grpc/loftr_worker.py

@@ -5,23 +5,53 @@
 @Description:
 '''
 from collections import namedtuple
+from typing import Optional, Union

 import torch
 import cv2
 import numpy as np

 from src.loftr import LoFTR, default_cfg
+from .debug_tools import plot_kp

-DebugInfo = namedtuple(
-    "DebugInfo",
+KeyPointsDebugInfo = namedtuple(
+    "KeyPointsDebugInfo",
     ["kp0_fake_match", "kp1_fake_match", "kp0_true_match", "kp1_true_match"])

+
+class DebugInfoCollector(object):
+    def __init__(self,
+                 imgA: Optional[np.ndarray] = None,
+                 imgB: Optional[np.ndarray] = None,
+                 pts_info: Optional[KeyPointsDebugInfo] = None,
+                 H: Optional[np.ndarray] = None):
+        self.imgA = imgA
+        self.imgB = imgB
+        self.pts_info = pts_info
+        self.H = H
+
+    def clean(self):
+        self.imgA = None
+        self.imgB = None
+        self.pts_info = None
+        self.H = None
+
+    def __str__(self):
+        return "{}".format({
+            "imgA": self.imgA,
+            "imgB": self.imgB,
+            "pts_info": self.pts_info,
+            "H": self.H
+        })
+
+
 class LoFTRWorker(object):
     def __init__(self,
                  config,
                  ckpt_path,
+                 img_size=(640, 480),
                  device="cuda:0",
                  thr=0.5,
                  ransc_method="USAC_MAGSAC",
@@ -37,13 +67,16 @@ class LoFTRWorker(object):
         self.ransc_method = getattr(cv2, ransc_method)
         self.ransc_thr = ransc_thr
         self.ransc_max_iter = ransc_max_iter
+        self.img_size = img_size

-    def _imgdeal(self, img):
+    def _img2gray(self, img):
         if len(img.shape) == 3 and img.shape[-1] == 3:
             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        return img

+    def _imgdeal(self, img):
         oh, ow = img.shape[:2]
-        img = cv2.resize(img, (640, 480))
+        img = cv2.resize(img, self.img_size)
         h, w = img.shape[:2]
         fix_matrix = np.array([[w / ow, 0, 0], [0, h / oh, 0], [0, 0, 1]])
         return img, fix_matrix
@@ -51,9 +84,19 @@ class LoFTRWorker(object):
     def _fix_H(self, fm0, fm1, H):
         return np.linalg.inv(fm0) @ H @ fm1

-    def __call__(self, img0, img1, debug=""):
-        img0, fm0 = self._imgdeal(img0)
-        img1, fm1 = self._imgdeal(img1)
+    def __call__(self,
+                 img0,
+                 img1,
+                 debug: Union[bool, str] = False,
+                 debug_show_type: tuple = (
+                     "vis",
+                     "false",
+                     "true",
+                 )):
+        img0_o, fm0 = self._imgdeal(img0)
+        img1_o, fm1 = self._imgdeal(img1)
+        img0 = self._img2gray(img0_o)
+        img1 = self._img2gray(img1_o)

         img0 = torch.from_numpy(img0)[None][None].cuda() / 255.
         img1 = torch.from_numpy(img1)[None][None].cuda() / 255.
@@ -69,10 +112,9 @@ class LoFTRWorker(object):
             mkpts0 = mkpts0[idx]
             mkpts1 = mkpts1[idx]

-        debug_info = None
         H = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)
         if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4:
-            return self._fix_H(fm0, fm1, H), False, debug_info
+            return self._fix_H(fm0, fm1, H), False

         H, Mask = cv2.findHomography(mkpts0[:, :2],
                                      mkpts1[:, :2],
@@ -87,10 +129,12 @@ class LoFTRWorker(object):
             kp0_fake_matched = mkpts0[~Mask.astype(bool), :2]
             kp1_fake_matched = mkpts1[~Mask.astype(bool), :2]
-            debug_info = DebugInfo(kp0_fake_matched, kp1_fake_matched,
-                                   kp0_true_matched, kp1_true_matched)
+            kpdi = KeyPointsDebugInfo(kp0_fake_matched, kp1_fake_matched,
+                                      kp0_true_matched, kp1_true_matched)
+            debug_info = DebugInfoCollector(img0_o, img1_o, kpdi, H)
+            plot_kp(debug_info, show_flag=debug_show_type, debug_save=debug)

         if H is None:
-            return self._fix_H(fm0, fm1, H), False, debug_info
+            return self._fix_H(fm0, fm1, H), False
         else:
-            return self._fix_H(fm0, fm1, H), True, debug_info
+            return self._fix_H(fm0, fm1, H), True
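With this change __call__ returns only (H, flag); the debug plots are produced internally by plot_kp instead of a returned debug_info. A minimal usage sketch, where the checkpoint and image paths are placeholders and passing a string as debug is assumed to act as the save-path prefix that plot_kp interprets:

import cv2
from src.loftr import default_cfg
from src.grpc.loftr_worker import LoFTRWorker

worker = LoFTRWorker(default_cfg, "weights/indoor_ds.ckpt", img_size=(640, 480))
img_a = cv2.imread("a.jpg")   # BGR input; the worker grays/resizes internally
img_b = cv2.imread("b.jpg")
H, ok = worker(img_a, img_b, debug="logs/pair0_", debug_show_type=("vis", "true"))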

src/grpc/server.py

@@ -2,7 +2,7 @@
 @Author: captainfffsama
 @Date: 2023-02-02 15:59:46
 @LastEditors: captainfffsama tuanzhangsama@outlook.com
-@LastEditTime: 2023-02-02 16:43:55
+@LastEditTime: 2023-02-03 10:35:44
 @FilePath: /LoFTR/src/grpc/server.py
 @Description:
 '''
@@ -18,25 +18,36 @@ from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img
 class LoFTRServer(LoftrServicer):
     def __init__(self,
                  ckpt_path,
+                 img_size=(640, 480),
                  device="cuda:0",
                  thr=0.5,
                  ransc_method="USAC_MAGSAC",
                  ransc_thr=3,
                  ransc_max_iter=2000,
+                 debug=False,
+                 debug_show_type=(
+                     "vis",
+                     "false",
+                     "true",
+                 ),
                  *args,
                  **kwargs):
         super().__init__(*args, **kwargs)
         self.worker = LoFTRWorker(
             default_cfg,
             ckpt_path,
+            img_size,
             device,
             thr,
             ransc_method,
             ransc_thr,
             ransc_max_iter,
         )
+        self.debug = debug
+        self.debug_show_type = debug_show_type

     def getEssentialMatrix(self, request, context):
         imgA = decode_img_from_proto(request.imageA)
@@ -45,7 +56,7 @@ class LoFTRServer(LoftrServicer):
             return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(
                 np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)),
                                                      status=-14)
-        H, flag = self.worker(imgA, imgB)
+        H, flag = self.worker(imgA, imgB, self.debug, self.debug_show_type)
         status = 0 if flag else -14
         return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(H),
                                                  status=status)
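For completeness, a sketch of standing the servicer up with the new debug switches. The generated-stub module, the add_LoftrServicer_to_server helper, the checkpoint path, and the port follow the usual protoc conventions but are assumptions; none of them are shown in this commit:

from concurrent import futures
import grpc

from src.grpc.server import LoFTRServer
from src.grpc import loftr_pb2_grpc   # assumed location of the generated stubs

servicer = LoFTRServer("weights/indoor_ds.ckpt",
                       img_size=(640, 480),
                       debug="logs/loftr_debug_",      # a str is forwarded to plot_kp as a save prefix
                       debug_show_type=("vis", "true"))
server = grpc.server(
    futures.ThreadPoolExecutor(max_workers=10),
    options=[("grpc.max_send_message_length", 100 * 1024 * 1024),
             ("grpc.max_receive_message_length", 100 * 1024 * 1024)])
loftr_pb2_grpc.add_LoftrServicer_to_server(servicer, server)  # assumed generated helper
server.add_insecure_port("127.0.0.1:50051")                   # port is a placeholder
server.start()
server.wait_for_termination()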
