diff --git a/.gitignore b/.gitignore index 3662ff6..988536e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ *.pth tmp.* */.ipynb_checkpoints/* +test/ logs/ weights/ diff --git a/src/grpc/__init__.py b/src/grpc/__init__.py new file mode 100644 index 0000000..386bd46 --- /dev/null +++ b/src/grpc/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +''' +@Author: captainfffsama +@Date: 2023-02-03 10:03:55 +@LastEditors: captainfffsama tuanzhangsama@outlook.com +@LastEditTime: 2023-02-03 10:03:56 +@FilePath: /LoFTR/src/grpc/__init__.py +@Description: +''' diff --git a/src/grpc/base_cfg.py b/src/grpc/base_cfg.py index 29ab1a4..fe1621e 100644 --- a/src/grpc/base_cfg.py +++ b/src/grpc/base_cfg.py @@ -3,7 +3,7 @@ @Author: captainfffsama @Date: 2023-02-02 16:40:37 @LastEditors: captainfffsama tuanzhangsama@outlook.com -@LastEditTime: 2023-02-02 16:42:41 +@LastEditTime: 2023-02-03 10:34:31 @FilePath: /LoFTR/src/grpc/base_cfg.py @Description: ''' @@ -15,8 +15,19 @@ param = dict(grpc=dict(host='127.0.0.1', max_workers=10, max_send_message_length=100 * 1024 * 1024, max_receive_message_length=100 * 1024 * 1024), - loftr=dict(ckpt_path="",device="cuda:0",thr=0.5,ransc_method="USAC_MAGSAC",ransc_thr=3,ransc_max_iter=2000, -)) + loftr=dict(ckpt_path="", + img_size=(640, 480), + device="cuda:0", + thr=0.5, + ransc_method="USAC_MAGSAC", + ransc_thr=3, + ransc_max_iter=2000, + debug=False, + debug_show_type=( + "vis", + "false", + "true", + ))) def _update(dic1: dict, dic2: dict): @@ -60,4 +71,3 @@ def merge_param(file_path: str): raise ValueError('{} is not support'.format(cfg_ext)) else: globals()[func_name](file_path) - diff --git a/src/grpc/debug_tools.py b/src/grpc/debug_tools.py new file mode 100644 index 0000000..0bfd2fe --- /dev/null +++ b/src/grpc/debug_tools.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +''' +@Author: captainfffsama +@Date: 2023-02-03 09:56:39 +@LastEditors: captainfffsama tuanzhangsama@outlook.com +@LastEditTime: 2023-02-03 10:12:32 +@FilePath: /LoFTR/src/grpc/debug_tools.py +@Description: +''' +import os +from typing import Optional +import math +from copy import deepcopy +import random + +import cv2 +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import matplotlib.patches as patches +import numpy as np +import torch +from torchvision.utils import make_grid + + +COLOR_DICT=mcolors.CSS4_COLORS + +def get_color(): + return random.sample(COLOR_DICT.keys(),1)[0] + +def normalization(data): + _range = np.max(data) - np.min(data) + return (data - np.min(data)) / _range *255 + + +def show_img(img_ori,text:Optional[str]=None,cvreader:bool=True,delay:float=0): + img=deepcopy(img_ori) + if isinstance(img, list) or isinstance(img, tuple): + img_num = len(img) + row_n = math.ceil(math.sqrt(img_num)) + col_n = max(math.ceil(img_num / row_n), 1) + fig, axs = plt.subplots(row_n, col_n, figsize=(15 * row_n, 15 * col_n)) + for idx, img_ in enumerate(img): + if isinstance(img_,torch.Tensor) or isinstance(img_,np.ndarray): + img_=show_tensor(img_,cvreader) + if 2 == len(axs.shape): + axs[idx % row_n][idx // row_n].imshow(img_) + axs[idx % row_n][idx // row_n].set_title(str(idx)) + else: + axs[idx % row_n].imshow(img_) + axs[idx % row_n].set_title(str(idx)) + if text: + plt.text(0,0,text,fontsize=15) + if delay <=0: + plt.show() + else: + plt.draw() + plt.pause(delay) + plt.close() + elif isinstance(img,torch.Tensor) or isinstance(img,np.ndarray): + img=show_tensor(img,cvreader) + plt.imshow(img) + if text: + plt.text(0,0,text,fontsize=15) + if delay <=0: + plt.show() + else: + plt.draw() + plt.pause(delay) + plt.close() + else: + if hasattr(img,'show'): + img.show() + + +def show_tensor(img,cvreader): + if len(img.shape) == 4 and img.shape[0] !=1: + if isinstance(img,np.ndarray): + img=torch.Tensor(img) + img=make_grid(img) + else: + img=make_grid(img) + img: np.ndarray = img.detach().cpu().numpy().squeeze() + if isinstance(img,torch.Tensor): + img=img.detach().cpu().numpy().squeeze() + if 1==img.max(): + img=img.astype(np.float) + if img.shape[0] == 1 or img.shape[0] == 3: + img = np.transpose(img, (1, 2, 0)) + if img.min() <0 or img.max() >255: + img=normalization(img) + print("img have norm") + if img.shape[-1] == 3: + img=img.astype(np.uint8) + if cvreader and len(img.shape)==3: + img=img[:,:,::-1] + return img + +def draw_point(axs,kp:np.ndarray,color:str): + kp=kp.astype(int) + axs.plot(kp[:,0],kp[:,1],'o',color=color) + +def plot_kp(debug_info,show_flag=('fake','true','no_match',"vis"),debug_save=""): + fig, axs = plt.subplots(2, 1) + axs[0].set_xticks([]) + axs[0].set_yticks([]) + axs[0].imshow(debug_info.imgA[:,:,::-1]) + + axs[1].set_yticks([]) + axs[1].set_yticks([]) + axs[1].imshow(debug_info.imgB[:,:,::-1]) + + if "fake" in show_flag: + draw_point(axs[0],debug_info.pts_info.kp0_fake_match,'y') + draw_point(axs[1],debug_info.pts_info.kp1_fake_match,'y') + for kp_a,kp_b in zip(debug_info.pts_info.kp0_fake_match,debug_info.pts_info.kp1_fake_match): + color=get_color() + con = patches.ConnectionPatch( + xyA=kp_a, coordsA=axs[0].transData, + xyB=kp_b, coordsB=axs[1].transData, + arrowstyle="-", shrinkB=5,color=color) + axs[1].add_artist(con) + + if "true" in show_flag: + draw_point(axs[0],debug_info.pts_info.kp0_true_match,'g') + draw_point(axs[1],debug_info.pts_info.kp1_true_match,'g') + for kp_a,kp_b in zip(debug_info.pts_info.kp0_true_match,debug_info.pts_info.kp1_true_match): + color=get_color() + con = patches.ConnectionPatch( + xyA=kp_a, coordsA=axs[0].transData, + xyB=kp_b, coordsB=axs[1].transData, + arrowstyle="-", shrinkB=5,color=color) + axs[1].add_artist(con) + + if "vis" in show_flag: + fig2=plt.figure() + axe2=fig2.gca() + fix_img=generate_transform_pic((debug_info.imgB.shape[1],debug_info.imgB.shape[0]),debug_info.imgA,debug_info.H) + final_img=cv2.addWeighted(debug_info.imgB,0.6,fix_img,0.4,0) + axe2.imshow(final_img[:,:,::-1]) + + + if isinstance(debug_save,str): + save_path,ext=os.path.splitext(debug_save) + fig.savefig(save_path+"nft.jpg",dpi=400,bbox_inches='tight') + fig2.savefig(save_path+"vis.jpg",dpi=400,bbox_inches='tight') + plt.clf() + plt.close("all") + else: + plt.clf() + plt.show() + +def convertpt2matrix(pts): + if isinstance(pts,list): + pts=np.array([x+[1] for x in pts]) + if pts.shape[-1]==2: + expend_m=np.array([[1],[1],[1],[1]]) + pts=np.concatenate((pts,expend_m),axis=1) + return pts + +def generate_transform_pic(imgb_shape,imgA,H): + imgA_warp = cv2.warpPerspective(imgA, + H, imgb_shape, + flags=cv2.INTER_LINEAR) + + k_pts=[[0,0],[0,imgA.shape[0]-1],[imgA.shape[1]-1,imgA.shape[0]-1],[imgA.shape[1]-1,0]] + + kptm=convertpt2matrix(k_pts) + warp_kpm=np.matmul(H,kptm.T) + warp_kpm=(warp_kpm/warp_kpm[2,:])[:2,:].astype(int).T.reshape((-1,1,2)) + # FIXME:这里可能有问题,没做越界 + cv2.polylines(imgA_warp,[warp_kpm],True,(0,0,255),thickness=5) + return imgA_warp + + diff --git a/src/grpc/loftr_worker.py b/src/grpc/loftr_worker.py index 7bdd33c..89de97c 100644 --- a/src/grpc/loftr_worker.py +++ b/src/grpc/loftr_worker.py @@ -5,23 +5,53 @@ @Description: ''' from collections import namedtuple +from typing import Optional, Union import torch import cv2 import numpy as np from src.loftr import LoFTR, default_cfg +from .debug_tools import plot_kp -DebugInfo = namedtuple( - "DebugInfo", +KeyPointsDebugInfo = namedtuple( + "KeyPointsDebugInfo", ["kp0_fake_match", "kp1_fake_match", "kp0_true_match", "kp1_true_match"]) +class DebugInfoCollector(object): + + def __init__(self, + imgA: Optional[np.ndarray] = None, + imgB: Optional[np.ndarray] = None, + pts_info: Optional[KeyPointsDebugInfo] = None, + H: Optional[np.ndarray] = None): + self.imgA = imgA + self.imgB = imgB + self.pts_info = pts_info + self.H = H + + def clean(self): + self.imgA = None + self.imgB = None + self.pts_info = None + self.H = None + + def __str__(self): + return "{}".format({ + "imgA": self.imgA, + "imgB": self.imgB, + "pts_info": self.pts_info, + "H": self.H + }) + + class LoFTRWorker(object): def __init__(self, config, ckpt_path, + img_size=(640, 480), device="cuda:0", thr=0.5, ransc_method="USAC_MAGSAC", @@ -37,13 +67,16 @@ class LoFTRWorker(object): self.ransc_method = getattr(cv2, ransc_method) self.ransc_thr = ransc_thr self.ransc_max_iter = ransc_max_iter + self.img_size = img_size - def _imgdeal(self, img): + def _img2gray(self, img): if len(img.shape) == 3 and img.shape[-1] == 3: img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + return img + def _imgdeal(self, img): oh, ow = img.shape[:2] - img = cv2.resize(img, (640, 480)) + img = cv2.resize(img, self.img_size) h, w = img.shape[:2] fix_matrix = np.array([[w / ow, 0, 0], [0, h / oh, 0], [0, 0, 1]]) return img, fix_matrix @@ -51,9 +84,19 @@ class LoFTRWorker(object): def _fix_H(self, fm0, fm1, H): return np.linalg.inv(fm0) @ H @ fm1 - def __call__(self, img0, img1, debug=""): - img0, fm0 = self._imgdeal(img0) - img1, fm1 = self._imgdeal(img1) + def __call__(self, + img0, + img1, + debug: Union[bool, str] = False, + debug_show_type: tuple = ( + "vis", + "false", + "true", + )): + img0_o, fm0 = self._imgdeal(img0) + img1_o, fm1 = self._imgdeal(img1) + img0 = self._img2gray(img0_o) + img1 = self._img2gray(img1_o) img0 = torch.from_numpy(img0)[None][None].cuda() / 255. img1 = torch.from_numpy(img1)[None][None].cuda() / 255. @@ -69,10 +112,9 @@ class LoFTRWorker(object): mkpts0 = mkpts0[idx] mkpts1 = mkpts1[idx] - debug_info = None H = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float) if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4: - return self._fix_H(fm0, fm1, H), False, debug_info + return self._fix_H(fm0, fm1, H), False H, Mask = cv2.findHomography(mkpts0[:, :2], mkpts1[:, :2], @@ -87,10 +129,12 @@ class LoFTRWorker(object): kp0_fake_matched = mkpts0[~Mask.astype(bool), :2] kp1_fake_matched = mkpts1[~Mask.astype(bool), :2] - debug_info = DebugInfo(kp0_fake_matched, kp1_fake_matched, - kp0_true_matched, kp1_true_matched) + kpdi = KeyPointsDebugInfo(kp0_fake_matched, kp1_fake_matched, + kp0_true_matched, kp1_true_matched) + debug_info = DebugInfoCollector(img0_o, img1_o, kpdi, H) + plot_kp(debug_info, show_flag=debug_show_type, debug_save=debug) if H is None: - return self._fix_H(fm0, fm1, H), False, debug_info + return self._fix_H(fm0, fm1, H), False else: - return self._fix_H(fm0, fm1, H), True, debug_info + return self._fix_H(fm0, fm1, H), True diff --git a/src/grpc/server.py b/src/grpc/server.py index 3c333dc..75c50fe 100644 --- a/src/grpc/server.py +++ b/src/grpc/server.py @@ -2,7 +2,7 @@ @Author: captainfffsama @Date: 2023-02-02 15:59:46 @LastEditors: captainfffsama tuanzhangsama@outlook.com -@LastEditTime: 2023-02-02 16:43:55 +@LastEditTime: 2023-02-03 10:35:44 @FilePath: /LoFTR/src/grpc/server.py @Description: ''' @@ -18,25 +18,36 @@ from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img class LoFTRServer(LoftrServicer): + def __init__(self, ckpt_path, + img_size=(640, 480), device="cuda:0", thr=0.5, ransc_method="USAC_MAGSAC", ransc_thr=3, ransc_max_iter=2000, + debug=False, + debug_show_type=( + "vis", + "false", + "true", + ), *args, **kwargs): super().__init__(*args, **kwargs) self.worker = LoFTRWorker( default_cfg, ckpt_path, + img_size, device, thr, ransc_method, ransc_thr, ransc_max_iter, ) + self.debug = debug + self.debug_show_type = debug_show_type def getEssentialMatrix(self, request, context): imgA = decode_img_from_proto(request.imageA) @@ -44,8 +55,8 @@ class LoFTRServer(LoftrServicer): if imgA is None or imgB is None: return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto( np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)), - status=-14) - H, flag = self.worker(imgA, imgB) + status=-14) + H, flag = self.worker(imgA, imgB, self.debug, self.debug_show_type) status = 0 if flag else -14 return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(H), - status=status) + status=status)