chiebot
captainfffsama 2 years ago
parent e72b217c1e
commit 6a49742e69
  1. 1
      .gitignore
  2. 9
      src/grpc/__init__.py
  3. 18
      src/grpc/base_cfg.py
  4. 174
      src/grpc/debug_tools.py
  5. 70
      src/grpc/loftr_worker.py
  6. 19
      src/grpc/server.py

1
.gitignore vendored

@ -6,6 +6,7 @@ __pycache__/
*.pth
tmp.*
*/.ipynb_checkpoints/*
test/
logs/
weights/

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-03 10:03:55
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-03 10:03:56
@FilePath: /LoFTR/src/grpc/__init__.py
@Description:
'''

@ -3,7 +3,7 @@
@Author: captainfffsama
@Date: 2023-02-02 16:40:37
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:42:41
@LastEditTime: 2023-02-03 10:34:31
@FilePath: /LoFTR/src/grpc/base_cfg.py
@Description:
'''
@ -15,8 +15,19 @@ param = dict(grpc=dict(host='127.0.0.1',
max_workers=10,
max_send_message_length=100 * 1024 * 1024,
max_receive_message_length=100 * 1024 * 1024),
loftr=dict(ckpt_path="",device="cuda:0",thr=0.5,ransc_method="USAC_MAGSAC",ransc_thr=3,ransc_max_iter=2000,
))
loftr=dict(ckpt_path="",
img_size=(640, 480),
device="cuda:0",
thr=0.5,
ransc_method="USAC_MAGSAC",
ransc_thr=3,
ransc_max_iter=2000,
debug=False,
debug_show_type=(
"vis",
"false",
"true",
)))
def _update(dic1: dict, dic2: dict):
@ -60,4 +71,3 @@ def merge_param(file_path: str):
raise ValueError('{} is not support'.format(cfg_ext))
else:
globals()[func_name](file_path)

@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-03 09:56:39
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-03 10:12:32
@FilePath: /LoFTR/src/grpc/debug_tools.py
@Description:
'''
import os
from typing import Optional
import math
from copy import deepcopy
import random
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import numpy as np
import torch
from torchvision.utils import make_grid
COLOR_DICT=mcolors.CSS4_COLORS
def get_color():
return random.sample(COLOR_DICT.keys(),1)[0]
def normalization(data):
_range = np.max(data) - np.min(data)
return (data - np.min(data)) / _range *255
def show_img(img_ori,text:Optional[str]=None,cvreader:bool=True,delay:float=0):
img=deepcopy(img_ori)
if isinstance(img, list) or isinstance(img, tuple):
img_num = len(img)
row_n = math.ceil(math.sqrt(img_num))
col_n = max(math.ceil(img_num / row_n), 1)
fig, axs = plt.subplots(row_n, col_n, figsize=(15 * row_n, 15 * col_n))
for idx, img_ in enumerate(img):
if isinstance(img_,torch.Tensor) or isinstance(img_,np.ndarray):
img_=show_tensor(img_,cvreader)
if 2 == len(axs.shape):
axs[idx % row_n][idx // row_n].imshow(img_)
axs[idx % row_n][idx // row_n].set_title(str(idx))
else:
axs[idx % row_n].imshow(img_)
axs[idx % row_n].set_title(str(idx))
if text:
plt.text(0,0,text,fontsize=15)
if delay <=0:
plt.show()
else:
plt.draw()
plt.pause(delay)
plt.close()
elif isinstance(img,torch.Tensor) or isinstance(img,np.ndarray):
img=show_tensor(img,cvreader)
plt.imshow(img)
if text:
plt.text(0,0,text,fontsize=15)
if delay <=0:
plt.show()
else:
plt.draw()
plt.pause(delay)
plt.close()
else:
if hasattr(img,'show'):
img.show()
def show_tensor(img,cvreader):
if len(img.shape) == 4 and img.shape[0] !=1:
if isinstance(img,np.ndarray):
img=torch.Tensor(img)
img=make_grid(img)
else:
img=make_grid(img)
img: np.ndarray = img.detach().cpu().numpy().squeeze()
if isinstance(img,torch.Tensor):
img=img.detach().cpu().numpy().squeeze()
if 1==img.max():
img=img.astype(np.float)
if img.shape[0] == 1 or img.shape[0] == 3:
img = np.transpose(img, (1, 2, 0))
if img.min() <0 or img.max() >255:
img=normalization(img)
print("img have norm")
if img.shape[-1] == 3:
img=img.astype(np.uint8)
if cvreader and len(img.shape)==3:
img=img[:,:,::-1]
return img
def draw_point(axs,kp:np.ndarray,color:str):
kp=kp.astype(int)
axs.plot(kp[:,0],kp[:,1],'o',color=color)
def plot_kp(debug_info,show_flag=('fake','true','no_match',"vis"),debug_save=""):
fig, axs = plt.subplots(2, 1)
axs[0].set_xticks([])
axs[0].set_yticks([])
axs[0].imshow(debug_info.imgA[:,:,::-1])
axs[1].set_yticks([])
axs[1].set_yticks([])
axs[1].imshow(debug_info.imgB[:,:,::-1])
if "fake" in show_flag:
draw_point(axs[0],debug_info.pts_info.kp0_fake_match,'y')
draw_point(axs[1],debug_info.pts_info.kp1_fake_match,'y')
for kp_a,kp_b in zip(debug_info.pts_info.kp0_fake_match,debug_info.pts_info.kp1_fake_match):
color=get_color()
con = patches.ConnectionPatch(
xyA=kp_a, coordsA=axs[0].transData,
xyB=kp_b, coordsB=axs[1].transData,
arrowstyle="-", shrinkB=5,color=color)
axs[1].add_artist(con)
if "true" in show_flag:
draw_point(axs[0],debug_info.pts_info.kp0_true_match,'g')
draw_point(axs[1],debug_info.pts_info.kp1_true_match,'g')
for kp_a,kp_b in zip(debug_info.pts_info.kp0_true_match,debug_info.pts_info.kp1_true_match):
color=get_color()
con = patches.ConnectionPatch(
xyA=kp_a, coordsA=axs[0].transData,
xyB=kp_b, coordsB=axs[1].transData,
arrowstyle="-", shrinkB=5,color=color)
axs[1].add_artist(con)
if "vis" in show_flag:
fig2=plt.figure()
axe2=fig2.gca()
fix_img=generate_transform_pic((debug_info.imgB.shape[1],debug_info.imgB.shape[0]),debug_info.imgA,debug_info.H)
final_img=cv2.addWeighted(debug_info.imgB,0.6,fix_img,0.4,0)
axe2.imshow(final_img[:,:,::-1])
if isinstance(debug_save,str):
save_path,ext=os.path.splitext(debug_save)
fig.savefig(save_path+"nft.jpg",dpi=400,bbox_inches='tight')
fig2.savefig(save_path+"vis.jpg",dpi=400,bbox_inches='tight')
plt.clf()
plt.close("all")
else:
plt.clf()
plt.show()
def convertpt2matrix(pts):
if isinstance(pts,list):
pts=np.array([x+[1] for x in pts])
if pts.shape[-1]==2:
expend_m=np.array([[1],[1],[1],[1]])
pts=np.concatenate((pts,expend_m),axis=1)
return pts
def generate_transform_pic(imgb_shape,imgA,H):
imgA_warp = cv2.warpPerspective(imgA,
H, imgb_shape,
flags=cv2.INTER_LINEAR)
k_pts=[[0,0],[0,imgA.shape[0]-1],[imgA.shape[1]-1,imgA.shape[0]-1],[imgA.shape[1]-1,0]]
kptm=convertpt2matrix(k_pts)
warp_kpm=np.matmul(H,kptm.T)
warp_kpm=(warp_kpm/warp_kpm[2,:])[:2,:].astype(int).T.reshape((-1,1,2))
# FIXME:这里可能有问题,没做越界
cv2.polylines(imgA_warp,[warp_kpm],True,(0,0,255),thickness=5)
return imgA_warp

@ -5,23 +5,53 @@
@Description:
'''
from collections import namedtuple
from typing import Optional, Union
import torch
import cv2
import numpy as np
from src.loftr import LoFTR, default_cfg
from .debug_tools import plot_kp
DebugInfo = namedtuple(
"DebugInfo",
KeyPointsDebugInfo = namedtuple(
"KeyPointsDebugInfo",
["kp0_fake_match", "kp1_fake_match", "kp0_true_match", "kp1_true_match"])
class DebugInfoCollector(object):
def __init__(self,
imgA: Optional[np.ndarray] = None,
imgB: Optional[np.ndarray] = None,
pts_info: Optional[KeyPointsDebugInfo] = None,
H: Optional[np.ndarray] = None):
self.imgA = imgA
self.imgB = imgB
self.pts_info = pts_info
self.H = H
def clean(self):
self.imgA = None
self.imgB = None
self.pts_info = None
self.H = None
def __str__(self):
return "{}".format({
"imgA": self.imgA,
"imgB": self.imgB,
"pts_info": self.pts_info,
"H": self.H
})
class LoFTRWorker(object):
def __init__(self,
config,
ckpt_path,
img_size=(640, 480),
device="cuda:0",
thr=0.5,
ransc_method="USAC_MAGSAC",
@ -37,13 +67,16 @@ class LoFTRWorker(object):
self.ransc_method = getattr(cv2, ransc_method)
self.ransc_thr = ransc_thr
self.ransc_max_iter = ransc_max_iter
self.img_size = img_size
def _imgdeal(self, img):
def _img2gray(self, img):
if len(img.shape) == 3 and img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img
def _imgdeal(self, img):
oh, ow = img.shape[:2]
img = cv2.resize(img, (640, 480))
img = cv2.resize(img, self.img_size)
h, w = img.shape[:2]
fix_matrix = np.array([[w / ow, 0, 0], [0, h / oh, 0], [0, 0, 1]])
return img, fix_matrix
@ -51,9 +84,19 @@ class LoFTRWorker(object):
def _fix_H(self, fm0, fm1, H):
return np.linalg.inv(fm0) @ H @ fm1
def __call__(self, img0, img1, debug=""):
img0, fm0 = self._imgdeal(img0)
img1, fm1 = self._imgdeal(img1)
def __call__(self,
img0,
img1,
debug: Union[bool, str] = False,
debug_show_type: tuple = (
"vis",
"false",
"true",
)):
img0_o, fm0 = self._imgdeal(img0)
img1_o, fm1 = self._imgdeal(img1)
img0 = self._img2gray(img0_o)
img1 = self._img2gray(img1_o)
img0 = torch.from_numpy(img0)[None][None].cuda() / 255.
img1 = torch.from_numpy(img1)[None][None].cuda() / 255.
@ -69,10 +112,9 @@ class LoFTRWorker(object):
mkpts0 = mkpts0[idx]
mkpts1 = mkpts1[idx]
debug_info = None
H = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)
if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4:
return self._fix_H(fm0, fm1, H), False, debug_info
return self._fix_H(fm0, fm1, H), False
H, Mask = cv2.findHomography(mkpts0[:, :2],
mkpts1[:, :2],
@ -87,10 +129,12 @@ class LoFTRWorker(object):
kp0_fake_matched = mkpts0[~Mask.astype(bool), :2]
kp1_fake_matched = mkpts1[~Mask.astype(bool), :2]
debug_info = DebugInfo(kp0_fake_matched, kp1_fake_matched,
kp0_true_matched, kp1_true_matched)
kpdi = KeyPointsDebugInfo(kp0_fake_matched, kp1_fake_matched,
kp0_true_matched, kp1_true_matched)
debug_info = DebugInfoCollector(img0_o, img1_o, kpdi, H)
plot_kp(debug_info, show_flag=debug_show_type, debug_save=debug)
if H is None:
return self._fix_H(fm0, fm1, H), False, debug_info
return self._fix_H(fm0, fm1, H), False
else:
return self._fix_H(fm0, fm1, H), True, debug_info
return self._fix_H(fm0, fm1, H), True

@ -2,7 +2,7 @@
@Author: captainfffsama
@Date: 2023-02-02 15:59:46
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:43:55
@LastEditTime: 2023-02-03 10:35:44
@FilePath: /LoFTR/src/grpc/server.py
@Description:
'''
@ -18,25 +18,36 @@ from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img
class LoFTRServer(LoftrServicer):
def __init__(self,
ckpt_path,
img_size=(640, 480),
device="cuda:0",
thr=0.5,
ransc_method="USAC_MAGSAC",
ransc_thr=3,
ransc_max_iter=2000,
debug=False,
debug_show_type=(
"vis",
"false",
"true",
),
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.worker = LoFTRWorker(
default_cfg,
ckpt_path,
img_size,
device,
thr,
ransc_method,
ransc_thr,
ransc_max_iter,
)
self.debug = debug
self.debug_show_type = debug_show_type
def getEssentialMatrix(self, request, context):
imgA = decode_img_from_proto(request.imageA)
@ -44,8 +55,8 @@ class LoFTRServer(LoftrServicer):
if imgA is None or imgB is None:
return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)),
status=-14)
H, flag = self.worker(imgA, imgB)
status=-14)
H, flag = self.worker(imgA, imgB, self.debug, self.debug_show_type)
status = 0 if flag else -14
return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(H),
status=status)
status=status)

Loading…
Cancel
Save