parent c6f6a89d8a
commit e72b217c1e
4 changed files with 169 additions and 35 deletions
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:40:37
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:42:41
@FilePath: /LoFTR/src/grpc/base_cfg.py
@Description:
'''
import json

import yaml

param = dict(grpc=dict(host='127.0.0.1',
                       port='8001',
                       max_workers=10,
                       max_send_message_length=100 * 1024 * 1024,
                       max_receive_message_length=100 * 1024 * 1024),
             loftr=dict(ckpt_path="",
                        device="cuda:0",
                        thr=0.5,
                        ransc_method="USAC_MAGSAC",
                        ransc_thr=3,
                        ransc_max_iter=2000))


def _update(dic1: dict, dic2: dict):
    """Recursively update dic1 with the values from dic2.

    # NOTE:
    1. dic1 is modified in place!
    2. Python does no tail-call optimization, so a very deep dict
       could overflow the stack.
    """
    for k, v in dic2.items():
        if k.endswith('args') and v is None:
            dic2[k] = {}
        if k in dic1:
            if isinstance(v, dict) and isinstance(dic1[k], dict):
                _update(dic1[k], dic2[k])
            else:
                dic1[k] = dic2[k]
        else:
            dic1[k] = dic2[k]


def _merge_yaml(yaml_path: str):
    global param
    with open(yaml_path, 'r') as fr:
        content_dict = yaml.load(fr, yaml.FullLoader)
    _update(param, content_dict)


def _merge_json(json_path: str):
    global param
    with open(json_path, 'r') as fr:
        content_dict = json.load(fr)
    _update(param, content_dict)


def merge_param(file_path: str):
    """Update the default settings from a user-supplied config file."""
    cfg_ext = file_path.split('.')[-1]
    func_name = '_merge_' + cfg_ext
    if func_name not in globals():
        raise ValueError('{} is not supported'.format(cfg_ext))
    else:
        globals()[func_name](file_path)

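For reference, here is a minimal, self-contained sketch of what the recursive merge does (simplified: the special handling of keys ending in 'args' is omitted, and the override values are invented for illustration). A partial user config only touches the keys it mentions, nested dicts are merged rather than replaced, and the target dict is mutated in place; in the real module the same override would come from a YAML or JSON file passed to merge_param.

# Simplified restatement of base_cfg._update, for illustration only;
# the override values below are invented.
import copy

defaults = dict(grpc=dict(host='127.0.0.1', port='8001', max_workers=10),
                loftr=dict(ckpt_path="", device="cuda:0", thr=0.5))

user_cfg = dict(grpc=dict(port='9000'),                      # change only the port
                loftr=dict(ckpt_path="weights/loftr.ckpt"))  # hypothetical path


def _update(dic1: dict, dic2: dict):
    for k, v in dic2.items():
        if k in dic1 and isinstance(v, dict) and isinstance(dic1[k], dict):
            _update(dic1[k], v)
        else:
            dic1[k] = v


merged = copy.deepcopy(defaults)
_update(merged, user_cfg)
# merged['grpc']  -> {'host': '127.0.0.1', 'port': '9000', 'max_workers': 10}
# merged['loftr'] -> {'ckpt_path': 'weights/loftr.ckpt', 'device': 'cuda:0', 'thr': 0.5}
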
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
'''
@Author: captainfffsama
@Date: 2023-02-02 16:38:45
@LastEditors: captainfffsama tuanzhangsama@outlook.com
@LastEditTime: 2023-02-02 16:48:12
@FilePath: /LoFTR/start_grpc.py
@Description:
'''
from concurrent import futures
import sys
from pprint import pprint
import os

import yaml

import grpc

from src.grpc.loftr_pb2_grpc import add_LoftrServicer_to_server
import src.grpc.base_cfg as cfg
from src.grpc.server import LoFTRServer


def start_server(config):
    if not os.path.exists(config):
        raise FileNotFoundError('{} does not exist'.format(config))
    cfg.merge_param(config)
    args_dict: dict = cfg.param
    pprint(args_dict)

    grpc_args = args_dict['grpc']
    model_args = args_dict['loftr']
    # message size is capped at 100 MB
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=grpc_args['max_workers']),
        options=[('grpc.max_send_message_length',
                  grpc_args['max_send_message_length']),
                 ('grpc.max_receive_message_length',
                  grpc_args['max_receive_message_length'])])

    loftr_server = LoFTRServer(**model_args)
    add_LoftrServicer_to_server(loftr_server, server)
    server.add_insecure_port("{}:{}".format(grpc_args['host'],
                                            grpc_args['port']))
    server.start()
    server.wait_for_termination()


def main(args=None):
    import argparse
    parser = argparse.ArgumentParser(
        description="Serve LoFTR over gRPC; a config file is required")
    parser.add_argument("-c", "--config", type=str, default="",
                        help="path to the config file")
    options = parser.parse_args(args)
    if options.config:
        start_server(options.config)


if __name__ == "__main__":
    rc = 1
    try:
        main()
    except Exception as e:
        print('Error: %s' % e, file=sys.stderr)
        sys.exit(rc)
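
The server would typically be launched with something like python start_grpc.py -c my_cfg.yaml (the config path here is hypothetical). On the client side, the generated stub presumably follows the standard gRPC codegen naming, i.e. a LoftrStub class in loftr_pb2_grpc mirroring add_LoftrServicer_to_server; the sketch below only opens a channel and builds the stub, since the actual RPC methods and request messages are defined in loftr.proto, which is not part of this commit.

# Hypothetical client-side sketch; LoftrStub is inferred from the usual
# gRPC codegen naming, and the concrete RPC calls live in loftr.proto
# (not shown in this diff).
import grpc

from src.grpc import loftr_pb2_grpc

# Mirror the server's 100 MB message limits so large payloads are accepted.
_MAX_MSG = 100 * 1024 * 1024
options = [('grpc.max_send_message_length', _MAX_MSG),
           ('grpc.max_receive_message_length', _MAX_MSG)]

with grpc.insecure_channel('127.0.0.1:8001', options=options) as channel:
    stub = loftr_pb2_grpc.LoftrStub(channel)
    # response = stub.SomeRpc(request)  # see loftr.proto for the real methods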