parent
94e98b695b
commit
c6f6a89d8a
6 changed files with 544 additions and 0 deletions
@ -0,0 +1,21 @@ |
||||
syntax = "proto3"; |
||||
package LOFRT; |
||||
service Loftr{ |
||||
rpc getEssentialMatrix(ImagePair) returns (GetEssentialMatrixReply) {};// 返回投影矩阵 |
||||
} |
||||
message Image{ |
||||
optional bytes image = 1; |
||||
optional string path = 2; |
||||
} |
||||
message ImagePair{ |
||||
Image imageA=1; |
||||
Image imageB=2; |
||||
} |
||||
message Tensor { |
||||
repeated float data =1; |
||||
repeated int32 shape =2; |
||||
} |
||||
message GetEssentialMatrixReply { |
||||
Tensor matrix =1; |
||||
int32 status=2; |
||||
} |
@ -0,0 +1,263 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Generated by the protocol buffer compiler. DO NOT EDIT! |
||||
# source: loftr.proto |
||||
"""Generated protocol buffer code.""" |
||||
from google.protobuf import descriptor as _descriptor |
||||
from google.protobuf import message as _message |
||||
from google.protobuf import reflection as _reflection |
||||
from google.protobuf import symbol_database as _symbol_database |
||||
# @@protoc_insertion_point(imports) |
||||
|
||||
_sym_db = _symbol_database.Default() |
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor.FileDescriptor( |
||||
name='loftr.proto', |
||||
package='LOFRT', |
||||
syntax='proto3', |
||||
serialized_options=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
serialized_pb=b'\n\x0bloftr.proto\x12\x05LOFRT\"A\n\x05Image\x12\x12\n\x05image\x18\x01 \x01(\x0cH\x00\x88\x01\x01\x12\x11\n\x04path\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x08\n\x06_imageB\x07\n\x05_path\"G\n\tImagePair\x12\x1c\n\x06imageA\x18\x01 \x01(\x0b\x32\x0c.LOFRT.Image\x12\x1c\n\x06imageB\x18\x02 \x01(\x0b\x32\x0c.LOFRT.Image\"%\n\x06Tensor\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x12\r\n\x05shape\x18\x02 \x03(\x05\"H\n\x17GetEssentialMatrixReply\x12\x1d\n\x06matrix\x18\x01 \x01(\x0b\x32\r.LOFRT.Tensor\x12\x0e\n\x06status\x18\x02 \x01(\x05\x32Q\n\x05Loftr\x12H\n\x12getEssentialMatrix\x12\x10.LOFRT.ImagePair\x1a\x1e.LOFRT.GetEssentialMatrixReply\"\x00\x62\x06proto3' |
||||
) |
||||
|
||||
|
||||
|
||||
|
||||
_IMAGE = _descriptor.Descriptor( |
||||
name='Image', |
||||
full_name='LOFRT.Image', |
||||
filename=None, |
||||
file=DESCRIPTOR, |
||||
containing_type=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
fields=[ |
||||
_descriptor.FieldDescriptor( |
||||
name='image', full_name='LOFRT.Image.image', index=0, |
||||
number=1, type=12, cpp_type=9, label=1, |
||||
has_default_value=False, default_value=b"", |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
_descriptor.FieldDescriptor( |
||||
name='path', full_name='LOFRT.Image.path', index=1, |
||||
number=2, type=9, cpp_type=9, label=1, |
||||
has_default_value=False, default_value=b"".decode('utf-8'), |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
], |
||||
extensions=[ |
||||
], |
||||
nested_types=[], |
||||
enum_types=[ |
||||
], |
||||
serialized_options=None, |
||||
is_extendable=False, |
||||
syntax='proto3', |
||||
extension_ranges=[], |
||||
oneofs=[ |
||||
_descriptor.OneofDescriptor( |
||||
name='_image', full_name='LOFRT.Image._image', |
||||
index=0, containing_type=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
fields=[]), |
||||
_descriptor.OneofDescriptor( |
||||
name='_path', full_name='LOFRT.Image._path', |
||||
index=1, containing_type=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
fields=[]), |
||||
], |
||||
serialized_start=22, |
||||
serialized_end=87, |
||||
) |
||||
|
||||
|
||||
_IMAGEPAIR = _descriptor.Descriptor( |
||||
name='ImagePair', |
||||
full_name='LOFRT.ImagePair', |
||||
filename=None, |
||||
file=DESCRIPTOR, |
||||
containing_type=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
fields=[ |
||||
_descriptor.FieldDescriptor( |
||||
name='imageA', full_name='LOFRT.ImagePair.imageA', index=0, |
||||
number=1, type=11, cpp_type=10, label=1, |
||||
has_default_value=False, default_value=None, |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
_descriptor.FieldDescriptor( |
||||
name='imageB', full_name='LOFRT.ImagePair.imageB', index=1, |
||||
number=2, type=11, cpp_type=10, label=1, |
||||
has_default_value=False, default_value=None, |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
], |
||||
extensions=[ |
||||
], |
||||
nested_types=[], |
||||
enum_types=[ |
||||
], |
||||
serialized_options=None, |
||||
is_extendable=False, |
||||
syntax='proto3', |
||||
extension_ranges=[], |
||||
oneofs=[ |
||||
], |
||||
serialized_start=89, |
||||
serialized_end=160, |
||||
) |
||||
|
||||
|
||||
_TENSOR = _descriptor.Descriptor( |
||||
name='Tensor', |
||||
full_name='LOFRT.Tensor', |
||||
filename=None, |
||||
file=DESCRIPTOR, |
||||
containing_type=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
fields=[ |
||||
_descriptor.FieldDescriptor( |
||||
name='data', full_name='LOFRT.Tensor.data', index=0, |
||||
number=1, type=2, cpp_type=6, label=3, |
||||
has_default_value=False, default_value=[], |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
_descriptor.FieldDescriptor( |
||||
name='shape', full_name='LOFRT.Tensor.shape', index=1, |
||||
number=2, type=5, cpp_type=1, label=3, |
||||
has_default_value=False, default_value=[], |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
], |
||||
extensions=[ |
||||
], |
||||
nested_types=[], |
||||
enum_types=[ |
||||
], |
||||
serialized_options=None, |
||||
is_extendable=False, |
||||
syntax='proto3', |
||||
extension_ranges=[], |
||||
oneofs=[ |
||||
], |
||||
serialized_start=162, |
||||
serialized_end=199, |
||||
) |
||||
|
||||
|
||||
_GETESSENTIALMATRIXREPLY = _descriptor.Descriptor( |
||||
name='GetEssentialMatrixReply', |
||||
full_name='LOFRT.GetEssentialMatrixReply', |
||||
filename=None, |
||||
file=DESCRIPTOR, |
||||
containing_type=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
fields=[ |
||||
_descriptor.FieldDescriptor( |
||||
name='matrix', full_name='LOFRT.GetEssentialMatrixReply.matrix', index=0, |
||||
number=1, type=11, cpp_type=10, label=1, |
||||
has_default_value=False, default_value=None, |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
_descriptor.FieldDescriptor( |
||||
name='status', full_name='LOFRT.GetEssentialMatrixReply.status', index=1, |
||||
number=2, type=5, cpp_type=1, label=1, |
||||
has_default_value=False, default_value=0, |
||||
message_type=None, enum_type=None, containing_type=None, |
||||
is_extension=False, extension_scope=None, |
||||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), |
||||
], |
||||
extensions=[ |
||||
], |
||||
nested_types=[], |
||||
enum_types=[ |
||||
], |
||||
serialized_options=None, |
||||
is_extendable=False, |
||||
syntax='proto3', |
||||
extension_ranges=[], |
||||
oneofs=[ |
||||
], |
||||
serialized_start=201, |
||||
serialized_end=273, |
||||
) |
||||
|
||||
_IMAGE.oneofs_by_name['_image'].fields.append( |
||||
_IMAGE.fields_by_name['image']) |
||||
_IMAGE.fields_by_name['image'].containing_oneof = _IMAGE.oneofs_by_name['_image'] |
||||
_IMAGE.oneofs_by_name['_path'].fields.append( |
||||
_IMAGE.fields_by_name['path']) |
||||
_IMAGE.fields_by_name['path'].containing_oneof = _IMAGE.oneofs_by_name['_path'] |
||||
_IMAGEPAIR.fields_by_name['imageA'].message_type = _IMAGE |
||||
_IMAGEPAIR.fields_by_name['imageB'].message_type = _IMAGE |
||||
_GETESSENTIALMATRIXREPLY.fields_by_name['matrix'].message_type = _TENSOR |
||||
DESCRIPTOR.message_types_by_name['Image'] = _IMAGE |
||||
DESCRIPTOR.message_types_by_name['ImagePair'] = _IMAGEPAIR |
||||
DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR |
||||
DESCRIPTOR.message_types_by_name['GetEssentialMatrixReply'] = _GETESSENTIALMATRIXREPLY |
||||
_sym_db.RegisterFileDescriptor(DESCRIPTOR) |
||||
|
||||
Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), { |
||||
'DESCRIPTOR' : _IMAGE, |
||||
'__module__' : 'loftr_pb2' |
||||
# @@protoc_insertion_point(class_scope:LOFRT.Image) |
||||
}) |
||||
_sym_db.RegisterMessage(Image) |
||||
|
||||
ImagePair = _reflection.GeneratedProtocolMessageType('ImagePair', (_message.Message,), { |
||||
'DESCRIPTOR' : _IMAGEPAIR, |
||||
'__module__' : 'loftr_pb2' |
||||
# @@protoc_insertion_point(class_scope:LOFRT.ImagePair) |
||||
}) |
||||
_sym_db.RegisterMessage(ImagePair) |
||||
|
||||
Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { |
||||
'DESCRIPTOR' : _TENSOR, |
||||
'__module__' : 'loftr_pb2' |
||||
# @@protoc_insertion_point(class_scope:LOFRT.Tensor) |
||||
}) |
||||
_sym_db.RegisterMessage(Tensor) |
||||
|
||||
GetEssentialMatrixReply = _reflection.GeneratedProtocolMessageType('GetEssentialMatrixReply', (_message.Message,), { |
||||
'DESCRIPTOR' : _GETESSENTIALMATRIXREPLY, |
||||
'__module__' : 'loftr_pb2' |
||||
# @@protoc_insertion_point(class_scope:LOFRT.GetEssentialMatrixReply) |
||||
}) |
||||
_sym_db.RegisterMessage(GetEssentialMatrixReply) |
||||
|
||||
|
||||
|
||||
_LOFTR = _descriptor.ServiceDescriptor( |
||||
name='Loftr', |
||||
full_name='LOFRT.Loftr', |
||||
file=DESCRIPTOR, |
||||
index=0, |
||||
serialized_options=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
serialized_start=275, |
||||
serialized_end=356, |
||||
methods=[ |
||||
_descriptor.MethodDescriptor( |
||||
name='getEssentialMatrix', |
||||
full_name='LOFRT.Loftr.getEssentialMatrix', |
||||
index=0, |
||||
containing_service=None, |
||||
input_type=_IMAGEPAIR, |
||||
output_type=_GETESSENTIALMATRIXREPLY, |
||||
serialized_options=None, |
||||
create_key=_descriptor._internal_create_key, |
||||
), |
||||
]) |
||||
_sym_db.RegisterServiceDescriptor(_LOFTR) |
||||
|
||||
DESCRIPTOR.services_by_name['Loftr'] = _LOFTR |
||||
|
||||
# @@protoc_insertion_point(module_scope) |
@ -0,0 +1,66 @@ |
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! |
||||
"""Client and server classes corresponding to protobuf-defined services.""" |
||||
import grpc |
||||
|
||||
from . import loftr_pb2 as loftr__pb2 |
||||
|
||||
|
||||
class LoftrStub(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
def __init__(self, channel): |
||||
"""Constructor. |
||||
|
||||
Args: |
||||
channel: A grpc.Channel. |
||||
""" |
||||
self.getEssentialMatrix = channel.unary_unary( |
||||
'/LOFRT.Loftr/getEssentialMatrix', |
||||
request_serializer=loftr__pb2.ImagePair.SerializeToString, |
||||
response_deserializer=loftr__pb2.GetEssentialMatrixReply.FromString, |
||||
) |
||||
|
||||
|
||||
class LoftrServicer(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
def getEssentialMatrix(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
|
||||
def add_LoftrServicer_to_server(servicer, server): |
||||
rpc_method_handlers = { |
||||
'getEssentialMatrix': grpc.unary_unary_rpc_method_handler( |
||||
servicer.getEssentialMatrix, |
||||
request_deserializer=loftr__pb2.ImagePair.FromString, |
||||
response_serializer=loftr__pb2.GetEssentialMatrixReply.SerializeToString, |
||||
), |
||||
} |
||||
generic_handler = grpc.method_handlers_generic_handler( |
||||
'LOFRT.Loftr', rpc_method_handlers) |
||||
server.add_generic_rpc_handlers((generic_handler,)) |
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API. |
||||
class Loftr(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
@staticmethod |
||||
def getEssentialMatrix(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/LOFRT.Loftr/getEssentialMatrix', |
||||
loftr__pb2.ImagePair.SerializeToString, |
||||
loftr__pb2.GetEssentialMatrixReply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
@ -0,0 +1,88 @@ |
||||
# -*- coding: utf-8 -*- |
||||
''' |
||||
@Author: CaptainHu |
||||
@Date: 2023年 02月 02日 星期四 15:19:14 CST |
||||
@Description: |
||||
''' |
||||
from collections import namedtuple |
||||
|
||||
import torch |
||||
import cv2 |
||||
import numpy as np |
||||
|
||||
from src.loftr import LoFTR, default_cfg |
||||
|
||||
DebugInfo=namedtuple("DebugInfo", |
||||
["kp0_fake_match","kp1_fake_match", |
||||
"kp0_true_match","kp1_true_match"]) |
||||
|
||||
class LoFTRWorker(object): |
||||
|
||||
def __init__(self, |
||||
config, |
||||
ckpt_path, |
||||
device="cuda:0", |
||||
thr=0.5, |
||||
ransc_method="USAC_MAGSAC", |
||||
ransc_thr=3, |
||||
ransc_max_iter=2000): |
||||
self.model = LoFTR(config=config) |
||||
self.model.load_state_dict(torch.load(ckpt_path)['state_dict']) |
||||
if device != 'cpu' and not torch.cuda.is_available(): |
||||
device = 'cpu' |
||||
print("ERROR: cuda can not use, will use cpu") |
||||
self.model = self.model.eval().to(device) |
||||
self.thr=thr |
||||
self.ransc_method = getattr(cv2,ransc_method) |
||||
self.ransc_thr=ransc_thr |
||||
self.ransc_max_iter=ransc_max_iter |
||||
|
||||
def _img2gray(self, img): |
||||
if len(img.shape) == 3 and img.shape[-1] == 3: |
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) |
||||
|
||||
return img |
||||
|
||||
def __call__(self, img0, img1,debug=""): |
||||
img0 = self._img1gray(img0) |
||||
img1 = self._img1gray(img1) |
||||
img0 = torch.from_numpy(img0)[None][None].cuda() / 255. |
||||
img1 = torch.from_numpy(img1)[None][None].cuda() / 255. |
||||
|
||||
batch = {'image0': img0, 'image1': img1} |
||||
with torch.no_grad(): |
||||
self.model(batch) |
||||
mkpts0 = batch['mkpts0_f'].cpu().numpy() |
||||
mkpts1 = batch['mkpts1_f'].cpu().numpy() |
||||
mconf = batch['mconf'].cpu().numpy() |
||||
|
||||
idx=np.where(mconf>self.thr) |
||||
mconf=mconf[idx] |
||||
mkpts0=mkpts0[idx] |
||||
mkpts1=mkpts1[idx] |
||||
|
||||
debug_info=None |
||||
if mkpts0.shape[0] < 4 or mkpts1.shape[0] < 4: |
||||
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], |
||||
dtype=np.float), False,debug_info |
||||
|
||||
H, Mask = cv2.findHomography(mkpts0[:, :2], |
||||
mkpts1[:, :2], |
||||
self.ransc_method, |
||||
self.ransc_thr, |
||||
maxIters=self.ransc_max_iter) |
||||
Mask=np.squeeze(Mask) |
||||
if debug: |
||||
|
||||
kp0_true_matched=mkpts0[Mask.astype(bool),:2] |
||||
kp1_true_matched=mkpts1[Mask.astype(bool),:2] |
||||
kp0_fake_matched=mkpts0[~Mask.astype(bool),:2] |
||||
kp1_fake_matched=mkpts1[~Mask.astype(bool),:2] |
||||
|
||||
debug_info=DebugInfo(kp0_fake_matched,kp1_fake_matched,kp0_true_matched,kp1_true_matched) |
||||
|
||||
if H is None: |
||||
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], |
||||
dtype=np.float), False,debug_info |
||||
else: |
||||
return H, True,debug_info |
@ -0,0 +1,50 @@ |
||||
''' |
||||
@Author: captainfffsama |
||||
@Date: 2023-02-02 15:59:46 |
||||
@LastEditors: captainfffsama tuanzhangsama@outlook.com |
||||
@LastEditTime: 2023-02-02 16:08:41 |
||||
@FilePath: /LoFTR/src/grpc/server.py |
||||
@Description: |
||||
''' |
||||
import numpy as np |
||||
|
||||
from . import loftr_pb2 |
||||
from .loftr_pb2_grpc import LoftrServicer |
||||
|
||||
from .loftr_worker import LoFTRWorker |
||||
from .utils import decode_img_from_proto, np2tensor_proto, img2pb_img |
||||
|
||||
|
||||
class LoFTRServer(LoftrServicer): |
||||
def __init__(self, |
||||
config, |
||||
ckpt_path, |
||||
device="cuda:0", |
||||
thr=0.5, |
||||
ransc_method="USAC_MAGSAC", |
||||
ransc_thr=3, |
||||
ransc_max_iter=2000, |
||||
*args, |
||||
**kwargs): |
||||
super().__init__(*args, **kwargs) |
||||
self.worker = LoFTRWorker( |
||||
config, |
||||
ckpt_path, |
||||
device, |
||||
thr, |
||||
ransc_method, |
||||
ransc_thr, |
||||
ransc_max_iter, |
||||
) |
||||
|
||||
def getEssentialMatrix(self, request, context): |
||||
imgA = decode_img_from_proto(request.imageA) |
||||
imgB = decode_img_from_proto(request.imageB) |
||||
if imgA is None or imgB is None: |
||||
return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto( |
||||
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float)), |
||||
status=-14) |
||||
H, flag = self.worker(imgA, imgB) |
||||
status = 0 if flag else -14 |
||||
return loftr_pb2.GetEssentialMatrixReply(matrix=np2tensor_proto(H), |
||||
status=status) |
@ -0,0 +1,56 @@ |
||||
# -*- coding: utf-8 -*- |
||||
''' |
||||
@Author: captainfffsama |
||||
@Date: 2023-02-02 16:00:45 |
||||
@LastEditors: captainfffsama tuanzhangsama@outlook.com |
||||
@LastEditTime: 2023-02-02 16:03:15 |
||||
@FilePath: /LoFTR/src/grpc/utils.py |
||||
@Description: |
||||
''' |
||||
import os |
||||
import base64 |
||||
|
||||
import numpy as np |
||||
import cv2 |
||||
|
||||
from . import loftr_pb2 |
||||
|
||||
|
||||
def get_img(img_info): |
||||
if os.path.isfile(img_info): |
||||
if not os.path.exists(img_info): |
||||
return None |
||||
else: |
||||
return cv2.imread(img_info) #ignore |
||||
else: |
||||
img_str = base64.b64decode(img_info) |
||||
img_np = np.fromstring(img_str, np.uint8) |
||||
return cv2.imdecode(img_np, cv2.IMREAD_COLOR) |
||||
|
||||
|
||||
def decode_img_from_proto(proto_image): |
||||
if proto_image.image: |
||||
return get_img(proto_image.image) |
||||
else: |
||||
return get_img(proto_image.path) |
||||
|
||||
|
||||
def np2tensor_proto(np_ndarray: np.ndarray): |
||||
shape = list(np_ndarray.shape) |
||||
data = np_ndarray.flatten().tolist() |
||||
tensor_pb = loftr_pb2.Tensor() |
||||
tensor_pb.shape.extend(shape) |
||||
tensor_pb.data.extend(data) |
||||
return tensor_pb |
||||
|
||||
|
||||
def img2pb_img(img): |
||||
base64_str = cv2.imencode('.jpg', img)[1].tostring() |
||||
base64_str = base64.b64encode(base64_str) |
||||
return loftr_pb2.Image(image=base64_str) |
||||
|
||||
|
||||
def tensor_proto2np(tensor_pb): |
||||
np_matrix = np.array(tensor_pb.data, |
||||
dtype=np.float).reshape(tensor_pb.shape) |
||||
return np_matrix |
Loading…
Reference in new issue