补一个loftr的grpc

chiebot
captainfffsama 2 years ago
parent 94e98b695b
commit c6f6a89d8a
  1. 21
      src/grpc/loftr.proto
  2. 263
      src/grpc/loftr_pb2.py
  3. 66
      src/grpc/loftr_pb2_grpc.py
  4. 88
      src/grpc/loftr_worker.py
  5. 50
      src/grpc/server.py
  6. 56
      src/grpc/utils.py

@ -0,0 +1,21 @@
syntax = "proto3";
package LOFRT;
service Loftr{
rpc getEssentialMatrix(ImagePair) returns (GetEssentialMatrixReply) {};//
}
message Image{
optional bytes image = 1;
optional string path = 2;
}
message ImagePair{
Image imageA=1;
Image imageB=2;
}
message Tensor {
repeated float data =1;
repeated int32 shape =2;
}
message GetEssentialMatrixReply {
Tensor matrix =1;
int32 status=2;
}

@ -0,0 +1,263 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: loftr.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='loftr.proto',
package='LOFRT',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x0bloftr.proto\x12\x05LOFRT\"A\n\x05Image\x12\x12\n\x05image\x18\x01 \x01(\x0cH\x00\x88\x01\x01\x12\x11\n\x04path\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x08\n\x06_imageB\x07\n\x05_path\"G\n\tImagePair\x12\x1c\n\x06imageA\x18\x01 \x01(\x0b\x32\x0c.LOFRT.Image\x12\x1c\n\x06imageB\x18\x02 \x01(\x0b\x32\x0c.LOFRT.Image\"%\n\x06Tensor\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x12\r\n\x05shape\x18\x02 \x03(\x05\"H\n\x17GetEssentialMatrixReply\x12\x1d\n\x06matrix\x18\x01 \x01(\x0b\x32\r.LOFRT.Tensor\x12\x0e\n\x06status\x18\x02 \x01(\x05\x32Q\n\x05Loftr\x12H\n\x12getEssentialMatrix\x12\x10.LOFRT.ImagePair\x1a\x1e.LOFRT.GetEssentialMatrixReply\"\x00\x62\x06proto3'
)
_IMAGE = _descriptor.Descriptor(
name='Image',
full_name='LOFRT.Image',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='image', full_name='LOFRT.Image.image', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='path', full_name='LOFRT.Image.path', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='_image', full_name='LOFRT.Image._image',
index=0, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[]),
_descriptor.OneofDescriptor(
name='_path', full_name='LOFRT.Image._path',
index=1, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[]),
],
serialized_start=22,
serialized_end=87,
)
_IMAGEPAIR = _descriptor.Descriptor(
name='ImagePair',
full_name='LOFRT.ImagePair',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='imageA', full_name='LOFRT.ImagePair.imageA', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='imageB', full_name='LOFRT.ImagePair.imageB', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=89,
serialized_end=160,
)
_TENSOR = _descriptor.Descriptor(
name='Tensor',
full_name='LOFRT.Tensor',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='data', full_name='LOFRT.Tensor.data', index=0,
number=1, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='shape', full_name='LOFRT.Tensor.shape', index=1,
number=2, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=162,
serialized_end=199,
)
_GETESSENTIALMATRIXREPLY = _descriptor.Descriptor(
name='GetEssentialMatrixReply',
full_name='LOFRT.GetEssentialMatrixReply',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='matrix', full_name='LOFRT.GetEssentialMatrixReply.matrix', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='status', full_name='LOFRT.GetEssentialMatrixReply.status', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=201,
serialized_end=273,
)
_IMAGE.oneofs_by_name['_image'].fields.append(
_IMAGE.fields_by_name['image'])
_IMAGE.fields_by_name['image'].containing_oneof = _IMAGE.oneofs_by_name['_image']
_IMAGE.oneofs_by_name['_path'].fields.append(
_IMAGE.fields_by_name['path'])
_IMAGE.fields_by_name['path'].containing_oneof = _IMAGE.oneofs_by_name['_path']
_IMAGEPAIR.fields_by_name['imageA'].message_type = _IMAGE
_IMAGEPAIR.fields_by_name['imageB'].message_type = _IMAGE
_GETESSENTIALMATRIXREPLY.fields_by_name['matrix'].message_type = _TENSOR
DESCRIPTOR.message_types_by_name['Image'] = _IMAGE
DESCRIPTOR.message_types_by_name['ImagePair'] = _IMAGEPAIR
DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR
DESCRIPTOR.message_types_by_name['GetEssentialMatrixReply'] = _GETESSENTIALMATRIXREPLY
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), {
'DESCRIPTOR' : _IMAGE,
'__module__' : 'loftr_pb2'
# @@protoc_insertion_point(class_scope:LOFRT.Image)
})
_sym_db.RegisterMessage(Image)
ImagePair = _reflection.GeneratedProtocolMessageType('ImagePair', (_message.Message,), {
'DESCRIPTOR' : _IMAGEPAIR,
'__module__' : 'loftr_pb2'
# @@protoc_insertion_point(class_scope:LOFRT.ImagePair)
})
_sym_db.RegisterMessage(ImagePair)
Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), {
'DESCRIPTOR' : _TENSOR,
'__module__' : 'loftr_pb2'
# @@protoc_insertion_point(class_scope:LOFRT.Tensor)
})
_sym_db.RegisterMessage(Tensor)
GetEssentialMatrixReply = _reflection.GeneratedProtocolMessageType('GetEssentialMatrixReply', (_message.Message,), {
'DESCRIPTOR' : _GETESSENTIALMATRIXREPLY,
'__module__' : 'loftr_pb2'
# @@protoc_insertion_point(class_scope:LOFRT.GetEssentialMatrixReply)
})
_sym_db.RegisterMessage(GetEssentialMatrixReply)
_LOFTR = _descriptor.ServiceDescriptor(
name='Loftr',
full_name='LOFRT.Loftr',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=275,
serialized_end=356,
methods=[
_descriptor.MethodDescriptor(
name='getEssentialMatrix',
full_name='LOFRT.Loftr.getEssentialMatrix',
index=0,
containing_service=None,
input_type=_IMAGEPAIR,
output_type=_GETESSENTIALMATRIXREPLY,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_LOFTR)
DESCRIPTOR.services_by_name['Loftr'] = _LOFTR
# @@protoc_insertion_point(module_scope)

@ -0,0 +1,66 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from . import loftr_pb2 as loftr__pb2
class LoftrStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.getEssentialMatrix = channel.unary_unary(
'/LOFRT.Loftr/getEssentialMatrix',
request_serializer=loftr__pb2.ImagePair.SerializeToString,
response_deserializer=loftr__pb2.GetEssentialMatrixReply.FromString,
)
class LoftrServicer(object):
"""Missing associated documentation comment in .proto file."""
def getEssentialMatrix(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_LoftrServicer_to_server(servicer, server):
rpc_method_handlers = {
'getEssentialMatrix': grpc.unary_unary_rpc_method_handler(
servicer.getEssentialMatrix,
request_deserializer=loftr__pb2.ImagePair.FromString,
response_serializer=loftr__pb2.GetEssentialMatrixReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'LOFRT.Loftr', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Loftr(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def getEssentialMatrix(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/LOFRT.Loftr/getEssentialMatrix',
loftr__pb2.ImagePair.SerializeToString,
loftr__pb2.GetEssentialMatrixReply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
'''
@Author: CaptainHu
@Date: 2023 02 02 星期四 15:19:14 CST
@Description:
'''
from collections import namedtuple
import torch
import cv2
import numpy as np
from src.loftr import LoFTR, default_cfg
DebugInfo=namedtuple("DebugInfo",
["kp0_fake_match","kp1_fake_match",
"kp0_true_match","kp1_true_match"])
class LoFTRWorker(object):
def __init__(self,
config,
ckpt_path,
device="cuda:0",
thr=0.5,
ransc_method="USAC_MAGSAC",
ransc_thr=3,
ransc_max_iter=2000):
self.model = LoFTR(config=config)
self.model.load_state_dict(torch.load(ckpt_path)['state_dict'])
if device != 'cpu' and not torch.cuda.is_available():
device = 'cpu'
print("ERROR: cuda can not use, will use cpu")
self.model = self.model.eval().to(device)
self.thr=thr
self.ransc_method = getattr(cv2,ransc_method)
self.ransc_thr=ransc_thr
self.ransc_max_iter=ransc_max_iter
def _img2gray(self, img):
if len(img.shape) == 3 and img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img
def __call__(self, img0, img1,debug=""):
img0 = self._img1gray(img0)
img1 = self._img1gray(img1)
img0 = torch.from_numpy(img0)[None][None].cuda() / 255.
img1 = torch.from_numpy(img1)[None][None].cuda() / 255.
batch = {'image0': img0, 'image1': img1}
with torch.no_grad():
self.model(batch)
mkpts0 = batch['mkpts0_f'].cpu().numpy()
mkpts1 = batch['mkpts1_f'].cpu().numpy()
mconf = batch['mconf'].cpu().numpy()
idx=np.where(mconf>self.thr)
mconf=mconf[idx]
mkpts0=mkpts0[idx]
mkpts1=mkpts1[idx]
debug_info=None
if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4:
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=np.float), False,debug_info
H, Mask = cv2.findHomography(mkpts0[:, :2],
mkpts1[:, :2],
self.ransc_method,
self.ransc_thr,
maxIters=self.ransc_max_iter)
Mask=np.squeeze(Mask)
if debug:
kp0_true_matched=mkpts0[Mask.astype(bool),:2]
kp1_true_matched=mkpts1[Mask.astype(bool),:2]
kp0_fake_matched=mkpts0[~Mask.astype(bool),:2]
kp1_fake_matched=mkpts1[~Mask.astype(bool),:2]
debug_info=DebugInfo(kp0_fake_matched,kp1_fake_matched,kp0_true_matched,kp1_true_matched)
if H is None:
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=np.float), False,debug_info
else:
return H, True,debug_info

@ -0,0 +1,50 @@
'''
@Author: captainfffsama
@Date: 2023-02-02 15:59:46
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:08:41
@FilePath: /LoFTR/src/grpc/server.py
@Description:
'''
import numpy as np
from . import loftr_pb2
from .loftr_pb2_grpc import LoftrServicer
from .loftr_worker import LoFTRWorker
from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img
class LoFTRServer(LoftrServicer):
def __init__(self,
config,
ckpt_path,
device="cuda:0",
thr=0.5,
ransc_method="USAC_MAGSAC",
ransc_thr=3,
ransc_max_iter=2000,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.worker = LoFTRWorker(
config,
ckpt_path,
device,
thr,
ransc_method,
ransc_thr,
ransc_max_iter,
)
def getEssentialMatrix(self, request, context):
imgA = decode_img_from_proto(request.imageA)
imgB = decode_img_from_proto(request.imageB)
if imgA is None or imgB is None:
return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)),
status=-14)
H, flag = self.worker(imgA, imgB)
status = 0 if flag else -14
return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(H),
status=status)

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:00:45
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:03:15
@FilePath: /LoFTR/src/grpc/utils.py
@Description:
'''
import os
import base64
import numpy as np
import cv2
from . import loftr_pb2
def get_img(img_info):
if os.path.isfile(img_info):
if not os.path.exists(img_info):
return None
else:
return cv2.imread(img_info) #ignore
else:
img_str = base64.b64decode(img_info)
img_np = np.fromstring(img_str, np.uint8)
return cv2.imdecode(img_np, cv2.IMREAD_COLOR)
def decode_img_from_proto(proto_image):
if proto_image.image:
return get_img(proto_image.image)
else:
return get_img(proto_image.path)
def np2tensor_proto(np_ndarray: np.ndarray):
shape = list(np_ndarray.shape)
data = np_ndarray.flatten().tolist()
tensor_pb = loftr_pb2.Tensor()
tensor_pb.shape.extend(shape)
tensor_pb.data.extend(data)
return tensor_pb
def img2pb_img(img):
base64_str = cv2.imencode('.jpg', img)[1].tostring()
base64_str = base64.b64encode(base64_str)
return loftr_pb2.Image(image=base64_str)
def tensor_proto2np(tensor_pb):
np_matrix = np.array(tensor_pb.data,
dtype=np.float).reshape(tensor_pb.shape)
return np_matrix
Loading…
Cancel
Save