parent
c6f6a89d8a
commit
e72b217c1e
4 changed files with 169 additions and 35 deletions
@ -0,0 +1,63 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
''' |
||||||
|
@Author: captainfffsama |
||||||
|
@Date: 2023-02-02 16:40:37 |
||||||
|
@LastEditors: captainfffsama tuanzhangsama@outlook.com |
||||||
|
@LastEditTime: 2023-02-02 16:42:41 |
||||||
|
@FilePath: /LoFTR/src/grpc/base_cfg.py |
||||||
|
@Description: |
||||||
|
''' |
||||||
|
import json |
||||||
|
import yaml |
||||||
|
|
||||||
|
param = dict(grpc=dict(host='127.0.0.1', |
||||||
|
port='8001', |
||||||
|
max_workers=10, |
||||||
|
max_send_message_length=100 * 1024 * 1024, |
||||||
|
max_receive_message_length=100 * 1024 * 1024), |
||||||
|
loftr=dict(ckpt_path="",device="cuda:0",thr=0.5,ransc_method="USAC_MAGSAC",ransc_thr=3,ransc_max_iter=2000, |
||||||
|
)) |
||||||
|
|
||||||
|
|
||||||
|
def _update(dic1: dict, dic2: dict): |
||||||
|
"""使用dic2 来递归更新 dic1 |
||||||
|
# NOTE: |
||||||
|
1. dic1 本体是会被更改的!!! |
||||||
|
2. python 本身没有做尾递归优化的,dict深度超大时候可能爆栈 |
||||||
|
""" |
||||||
|
for k, v in dic2.items(): |
||||||
|
if k.endswith('args') and v is None: |
||||||
|
dic2[k] = {} |
||||||
|
if k in dic1: |
||||||
|
if isinstance(v, dict) and isinstance(dic1[k], dict): |
||||||
|
_update(dic1[k], dic2[k]) |
||||||
|
else: |
||||||
|
dic1[k] = dic2[k] |
||||||
|
else: |
||||||
|
dic1[k] = dic2[k] |
||||||
|
|
||||||
|
|
||||||
|
def _merge_yaml(yaml_path: str): |
||||||
|
global param |
||||||
|
with open(yaml_path, 'r') as fr: |
||||||
|
content_dict = yaml.load(fr, yaml.FullLoader) |
||||||
|
_update(param, content_dict) |
||||||
|
|
||||||
|
|
||||||
|
def _merge_json(json_path: str): |
||||||
|
global param |
||||||
|
with open(json_path, 'r') as fr: |
||||||
|
content_dict = json.load(fr) |
||||||
|
_update(param, content_dict) |
||||||
|
|
||||||
|
|
||||||
|
def merge_param(file_path: str): |
||||||
|
"""按照用户传入的配置文件更新基本设置 |
||||||
|
""" |
||||||
|
cfg_ext = file_path.split('.')[-1] |
||||||
|
func_name = '_merge_' + cfg_ext |
||||||
|
if func_name not in globals(): |
||||||
|
raise ValueError('{} is not support'.format(cfg_ext)) |
||||||
|
else: |
||||||
|
globals()[func_name](file_path) |
||||||
|
|
@ -0,0 +1,62 @@ |
|||||||
|
# -*- coding: utf-8 -*- |
||||||
|
''' |
||||||
|
@Author: captainfffsama |
||||||
|
@Date: 2023-02-02 16:38:45 |
||||||
|
@LastEditors: captainfffsama tuanzhangsama@outlook.com |
||||||
|
@LastEditTime: 2023-02-02 16:48:12 |
||||||
|
@FilePath: /LoFTR/start_grpc.py |
||||||
|
@Description: |
||||||
|
''' |
||||||
|
from concurrent import futures |
||||||
|
import sys |
||||||
|
from pprint import pprint |
||||||
|
import os |
||||||
|
import yaml |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
from src.grpc.loftr_pb2_grpc import add_LoftrServicer_to_server |
||||||
|
import src.grpc.base_cfg as cfg |
||||||
|
from src.grpc.server import LoFTRServer |
||||||
|
|
||||||
|
|
||||||
|
def start_server(config): |
||||||
|
if not os.path.exists(config): |
||||||
|
raise FileExistsError('{} 不存在'.format(config)) |
||||||
|
cfg.merge_param(config) |
||||||
|
args_dict: dict = cfg.param |
||||||
|
pprint(args_dict) |
||||||
|
|
||||||
|
grpc_args = args_dict['grpc'] |
||||||
|
model_args = args_dict['loftr'] |
||||||
|
# 最大限制为100M |
||||||
|
server = grpc.server( |
||||||
|
futures.ThreadPoolExecutor(max_workers=grpc_args['max_workers']), |
||||||
|
options=[('grpc.max_send_message_length', |
||||||
|
grpc_args['max_send_message_length']), |
||||||
|
('grpc.max_receive_message_length', |
||||||
|
grpc_args['max_receive_message_length'])]) |
||||||
|
|
||||||
|
loftr_server =LoFTRServer(**model_args) |
||||||
|
add_LoftrServicer_to_server(loftr_server,server) |
||||||
|
server.add_insecure_port("{}:{}".format(grpc_args['host'], |
||||||
|
grpc_args['port'])) |
||||||
|
server.start() |
||||||
|
server.wait_for_termination() |
||||||
|
|
||||||
|
|
||||||
|
def main(args=None): |
||||||
|
import argparse |
||||||
|
parser = argparse.ArgumentParser(description="grpc调用loftr,需要配置文件") |
||||||
|
parser.add_argument("-c", "--config", type=str, default="", help="配置文件地址") |
||||||
|
options = parser.parse_args(args) |
||||||
|
if options.config: |
||||||
|
start_server(options.config) |
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
rc = 1 |
||||||
|
try: |
||||||
|
main() |
||||||
|
except Exception as e: |
||||||
|
print('Error: %s' % e, file=sys.stderr) |
||||||
|
sys.exit(rc) |
Loading…
Reference in new issue