diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 6b7a912a941..1d2d374ad19 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -75,10 +75,10 @@ def _unwrap_client_call_details(call_details, default_details): return method, timeout, metadata, credentials -class _LocalFailure(grpc.RpcError, grpc.Future, grpc.Call): +class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): def __init__(self, exception, traceback): - super(_LocalFailure, self).__init__() + super(_FailureOutcome, self).__init__() self._exception = exception self._traceback = traceback @@ -134,6 +134,58 @@ class _LocalFailure(grpc.RpcError, grpc.Future, grpc.Call): raise self._exception +class _UnaryOutcome(grpc.Call, grpc.Future): + + def __init__(self, response, call): + self._response = response + self._call = call + + def initial_metadata(self): + return self._call.initial_metadata() + + def trailing_metadata(self): + return self._call.trailing_metadata() + + def code(self): + return self._call.code() + + def details(self): + return self._call.details() + + def is_active(self): + return self._call.is_active() + + def time_remaining(self): + return self._call.time_remaining() + + def cancel(self): + return self._call.cancel() + + def add_callback(self, callback): + return self._call.add_callback(callback) + + def cancelled(self): + return False + + def running(self): + return False + + def done(self): + return True + + def result(self, ignored_timeout=None): + return self._response + + def exception(self, ignored_timeout=None): + return None + + def traceback(self, ignored_timeout=None): + return None + + def add_done_callback(self, fn): + fn(self) + + class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): def __init__(self, thunk, method, interceptor): @@ -142,23 +194,45 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._interceptor = interceptor def __call__(self, request, timeout=None, metadata=None, credentials=None): - call_future = self.future( + response, ignored_call = self._with_call( request, timeout=timeout, metadata=metadata, credentials=credentials) - return call_future.result() + return response + + def _with_call(self, request, timeout=None, metadata=None, + credentials=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials) + + def continuation(new_details, request): + new_method, new_timeout, new_metadata, new_credentials = ( + _unwrap_client_call_details(new_details, client_call_details)) + try: + response, call = self._thunk(new_method).with_call( + request, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials) + return _UnaryOutcome(response, call) + except grpc.RpcError: + raise + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + call = self._interceptor.intercept_unary_unary( + continuation, client_call_details, request) + return call.result(), call def with_call(self, request, timeout=None, metadata=None, credentials=None): - call_future = self.future( + return self._with_call( request, timeout=timeout, metadata=metadata, credentials=credentials) - return call_future.result(), call_future def future(self, request, timeout=None, metadata=None, credentials=None): - client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials) @@ -175,7 +249,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): return self._interceptor.intercept_unary_unary( continuation, client_call_details, request) except Exception as exception: # pylint:disable=broad-except - return _LocalFailure(exception, sys.exc_info()[2]) + return _FailureOutcome(exception, sys.exc_info()[2]) class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): @@ -202,7 +276,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): return self._interceptor.intercept_unary_stream( continuation, client_call_details, request) except Exception as exception: # pylint:disable=broad-except - return _LocalFailure(exception, sys.exc_info()[2]) + return _FailureOutcome(exception, sys.exc_info()[2]) class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): @@ -217,24 +291,50 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): timeout=None, metadata=None, credentials=None): - call_future = self.future( + response, ignored_call = self._with_call( request_iterator, timeout=timeout, metadata=metadata, credentials=credentials) - return call_future.result() + return response + + def _with_call(self, + request_iterator, + timeout=None, + metadata=None, + credentials=None): + client_call_details = _ClientCallDetails(self._method, timeout, + metadata, credentials) + + def continuation(new_details, request_iterator): + new_method, new_timeout, new_metadata, new_credentials = ( + _unwrap_client_call_details(new_details, client_call_details)) + try: + response, call = self._thunk(new_method).with_call( + request_iterator, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials) + return _UnaryOutcome(response, call) + except grpc.RpcError: + raise + except Exception as exception: # pylint:disable=broad-except + return _FailureOutcome(exception, sys.exc_info()[2]) + + call = self._interceptor.intercept_stream_unary( + continuation, client_call_details, request_iterator) + return call.result(), call def with_call(self, request_iterator, timeout=None, metadata=None, credentials=None): - call_future = self.future( + return self._with_call( request_iterator, timeout=timeout, metadata=metadata, credentials=credentials) - return call_future.result(), call_future def future(self, request_iterator, @@ -257,7 +357,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): return self._interceptor.intercept_stream_unary( continuation, client_call_details, request_iterator) except Exception as exception: # pylint:disable=broad-except - return _LocalFailure(exception, sys.exc_info()[2]) + return _FailureOutcome(exception, sys.exc_info()[2]) class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): @@ -288,7 +388,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): return self._interceptor.intercept_stream_stream( continuation, client_call_details, request_iterator) except Exception as exception: # pylint:disable=broad-except - return _LocalFailure(exception, sys.exc_info()[2]) + return _FailureOutcome(exception, sys.exc_info()[2]) class _Channel(grpc.Channel):