diff --git a/setup.cfg b/setup.cfg index 36d00651684..09aa3860cd2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,8 +35,6 @@ exclude = src/python/grpcio/grpc/aio src/python/grpcio/grpc/beta src/python/grpcio/grpc/__init__.py - src/python/grpcio/grpc/_channel.py - src/python/grpcio/grpc/_server.py src/python/grpcio/grpc/_simple_stubs.py # NOTE(lidiz) diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index b36e70f4a99..d31344fd0e6 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -20,13 +20,24 @@ import os import sys import threading import time - -import grpc -from grpc import _common -from grpc import _compression -from grpc import _grpcio_metadata +import types +from typing import (Any, Callable, Iterator, List, Optional, Sequence, Set, + Tuple, Union) + +import grpc # pytype: disable=pyi-error +from grpc import _common # pytype: disable=pyi-error +from grpc import _compression # pytype: disable=pyi-error +from grpc import _grpcio_metadata # pytype: disable=pyi-error from grpc._cython import cygrpc -import grpc.experimental +from grpc._typing import ChannelArgumentType +from grpc._typing import DeserializingFunction +from grpc._typing import IntegratedCallFactory +from grpc._typing import MetadataType +from grpc._typing import NullaryCallbackType +from grpc._typing import ResponseType +from grpc._typing import SerializingFunction +from grpc._typing import UserTag +import grpc.experimental # pytype: disable=pyi-error _LOGGER = logging.getLogger(__name__) @@ -81,18 +92,33 @@ _NON_OK_RENDEZVOUS_REPR_FORMAT = ('<{} of RPC that terminated with:\n' '>') -def _deadline(timeout): +def _deadline(timeout: Optional[float]) -> Optional[float]: return None if timeout is None else time.time() + timeout -def _unknown_code_details(unknown_cygrpc_code, details): +def _unknown_code_details(unknown_cygrpc_code: Optional[grpc.StatusCode], + details: Optional[str]) -> str: return 'Server sent unknown code {} and details "{}"'.format( unknown_cygrpc_code, details) class _RPCState(object): - - def __init__(self, due, initial_metadata, trailing_metadata, code, details): + condition: threading.Condition + due: Set[cygrpc.OperationType] + initial_metadata: Optional[MetadataType] + response: Any + trailing_metadata: Optional[MetadataType] + code: Optional[grpc.StatusCode] + details: Optional[str] + debug_error_string: Optional[str] + cancelled: bool + callbacks: List[NullaryCallbackType] + fork_epoch: Optional[int] + + def __init__(self, due: Sequence[cygrpc.OperationType], + initial_metadata: Optional[MetadataType], + trailing_metadata: Optional[MetadataType], + code: Optional[grpc.StatusCode], details: Optional[str]): # `condition` guards all members of _RPCState. `notify_all` is called on # `condition` when the state of the RPC has changed. self.condition = threading.Condition() @@ -123,7 +149,7 @@ class _RPCState(object): self.condition = threading.Condition() -def _abort(state, code, details): +def _abort(state: _RPCState, code: grpc.StatusCode, details: str) -> None: if state.code is None: state.code = code state.details = details @@ -132,7 +158,10 @@ def _abort(state, code, details): state.trailing_metadata = () -def _handle_event(event, state, response_deserializer): +def _handle_event( + event: cygrpc.BaseEvent, state: _RPCState, + response_deserializer: Optional[DeserializingFunction] +) -> List[NullaryCallbackType]: callbacks = [] for batch_operation in event.batch_operations: operation_type = batch_operation.type() @@ -167,7 +196,9 @@ def _handle_event(event, state, response_deserializer): return callbacks -def _event_handler(state, response_deserializer): +def _event_handler( + state: _RPCState, + response_deserializer: Optional[DeserializingFunction]) -> UserTag: def handle_event(event): with state.condition: @@ -187,10 +218,14 @@ def _event_handler(state, response_deserializer): return handle_event +# TODO(xuanwn): Create a base class for IntegratedCall and SegregatedCall. #pylint: disable=too-many-statements -def _consume_request_iterator(request_iterator, state, call, request_serializer, - event_handler): - """Consume a request iterator supplied by the user.""" +def _consume_request_iterator(request_iterator: Iterator, state: _RPCState, + call: Union[cygrpc.IntegratedCall, + cygrpc.SegregatedCall], + request_serializer: SerializingFunction, + event_handler: Optional[UserTag]) -> None: + """Consume a request supplied by the user.""" def consume_request_iterator(): # pylint: disable=too-many-branches # Iterate over the request iterator until it is exhausted or an error @@ -266,7 +301,7 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer, consumption_thread.start() -def _rpc_state_string(class_name, rpc_state): +def _rpc_state_string(class_name: str, rpc_state: _RPCState) -> str: """Calculates error string for RPC.""" with rpc_state.condition: if rpc_state.code is None: @@ -289,8 +324,9 @@ class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future): Attributes: _state: An instance of _RPCState. """ + _state: _RPCState - def __init__(self, state): + def __init__(self, state: _RPCState): with state.condition: self._state = _RPCState((), copy.deepcopy(state.initial_metadata), copy.deepcopy(state.trailing_metadata), @@ -298,62 +334,68 @@ class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future): self._state.response = copy.copy(state.response) self._state.debug_error_string = copy.copy(state.debug_error_string) - def initial_metadata(self): + def initial_metadata(self) -> Optional[MetadataType]: return self._state.initial_metadata - def trailing_metadata(self): + def trailing_metadata(self) -> Optional[MetadataType]: return self._state.trailing_metadata - def code(self): + def code(self) -> Optional[grpc.StatusCode]: return self._state.code - def details(self): + def details(self) -> Optional[str]: return _common.decode(self._state.details) - def debug_error_string(self): + def debug_error_string(self) -> Optional[str]: return _common.decode(self._state.debug_error_string) - def _repr(self): + def _repr(self) -> str: return _rpc_state_string(self.__class__.__name__, self._state) - def __repr__(self): + def __repr__(self) -> str: return self._repr() - def __str__(self): + def __str__(self) -> str: return self._repr() - def cancel(self): + def cancel(self) -> bool: """See grpc.Future.cancel.""" return False - def cancelled(self): + def cancelled(self) -> bool: """See grpc.Future.cancelled.""" return False - def running(self): + def running(self) -> bool: """See grpc.Future.running.""" return False - def done(self): + def done(self) -> bool: """See grpc.Future.done.""" return True - def result(self, timeout=None): # pylint: disable=unused-argument + def result(self, timeout: Optional[float] = None) -> Any: # pylint: disable=unused-argument """See grpc.Future.result.""" raise self - def exception(self, timeout=None): # pylint: disable=unused-argument + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: # pylint: disable=unused-argument """See grpc.Future.exception.""" return self - def traceback(self, timeout=None): # pylint: disable=unused-argument + def traceback( + self, + timeout: Optional[float] = None # pylint: disable=unused-argument + ) -> Optional[types.TracebackType]: """See grpc.Future.traceback.""" try: raise self except grpc.RpcError: return sys.exc_info()[2] - def add_done_callback(self, fn, timeout=None): # pylint: disable=unused-argument + def add_done_callback( + self, + fn: Callable[[grpc.Future], None], + timeout: Optional[float] = None) -> None: # pylint: disable=unused-argument """See grpc.Future.add_done_callback.""" fn(self) @@ -371,20 +413,27 @@ class _Rendezvous(grpc.RpcError, grpc.RpcContext): _deadline: A float representing the deadline of the RPC in seconds. Or possibly None, to represent an RPC with no deadline at all. """ - - def __init__(self, state, call, response_deserializer, deadline): + _state: _RPCState + _call: Union[cygrpc.SegregatedCall, cygrpc.IntegratedCall] + _response_deserializer: Optional[DeserializingFunction] + _deadline: Optional[float] + + def __init__(self, state: _RPCState, call: Union[cygrpc.SegregatedCall, + cygrpc.IntegratedCall], + response_deserializer: Optional[DeserializingFunction], + deadline: Optional[float]): super(_Rendezvous, self).__init__() self._state = state self._call = call self._response_deserializer = response_deserializer self._deadline = deadline - def is_active(self): + def is_active(self) -> bool: """See grpc.RpcContext.is_active""" with self._state.condition: return self._state.code is None - def time_remaining(self): + def time_remaining(self) -> Optional[float]: """See grpc.RpcContext.time_remaining""" with self._state.condition: if self._deadline is None: @@ -392,7 +441,7 @@ class _Rendezvous(grpc.RpcError, grpc.RpcContext): else: return max(self._deadline - time.time(), 0) - def cancel(self): + def cancel(self) -> bool: """See grpc.RpcContext.cancel""" with self._state.condition: if self._state.code is None: @@ -407,7 +456,7 @@ class _Rendezvous(grpc.RpcError, grpc.RpcContext): else: return False - def add_callback(self, callback): + def add_callback(self, callback: NullaryCallbackType) -> bool: """See grpc.RpcContext.add_callback""" with self._state.condition: if self._state.callbacks is None: @@ -428,19 +477,19 @@ class _Rendezvous(grpc.RpcError, grpc.RpcContext): def _next(self): raise NotImplementedError() - def debug_error_string(self): + def debug_error_string(self) -> Optional[str]: raise NotImplementedError() - def _repr(self): + def _repr(self) -> str: return _rpc_state_string(self.__class__.__name__, self._state) - def __repr__(self): + def __repr__(self) -> str: return self._repr() - def __str__(self): + def __str__(self) -> str: return self._repr() - def __del__(self): + def __del__(self) -> None: with self._state.condition: if self._state.code is None: self._state.code = grpc.StatusCode.CANCELLED @@ -465,23 +514,24 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: This means that these methods are safe to call from add_done_callback handlers. """ + _state: _RPCState - def _is_complete(self): + def _is_complete(self) -> bool: return self._state.code is not None - def cancelled(self): + def cancelled(self) -> bool: with self._state.condition: return self._state.cancelled - def running(self): + def running(self) -> bool: with self._state.condition: return self._state.code is None - def done(self): + def done(self) -> bool: with self._state.condition: return self._state.code is not None - def result(self, timeout=None): + def result(self, timeout: Optional[float] = None) -> Any: """Returns the result of the computation or raises its exception. This method will never block. Instead, it will raise an exception @@ -503,7 +553,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: else: raise self - def exception(self, timeout=None): + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: """Return the exception raised by the computation. This method will never block. Instead, it will raise an exception @@ -525,7 +575,9 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: else: return self - def traceback(self, timeout=None): + def traceback( + self, + timeout: Optional[float] = None) -> Optional[types.TracebackType]: """Access the traceback of the exception raised by the computation. This method will never block. Instead, it will raise an exception @@ -550,7 +602,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: except grpc.RpcError: return sys.exc_info()[2] - def add_done_callback(self, fn): + def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None: with self._state.condition: if self._state.code is None: self._state.callbacks.append(functools.partial(fn, self)) @@ -558,7 +610,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: fn(self) - def initial_metadata(self): + def initial_metadata(self) -> Optional[MetadataType]: """See grpc.Call.initial_metadata""" with self._state.condition: # NOTE(gnossen): Based on our initial call batch, we are guaranteed @@ -567,7 +619,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: self._consume_next_event() return self._state.initial_metadata - def trailing_metadata(self): + def trailing_metadata(self) -> Optional[MetadataType]: """See grpc.Call.trailing_metadata""" with self._state.condition: if self._state.trailing_metadata is None: @@ -575,7 +627,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: "Cannot get trailing metadata until RPC is completed.") return self._state.trailing_metadata - def code(self): + def code(self) -> Optional[grpc.StatusCode]: """See grpc.Call.code""" with self._state.condition: if self._state.code is None: @@ -583,7 +635,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: "Cannot get code until RPC is completed.") return self._state.code - def details(self): + def details(self) -> Optional[str]: """See grpc.Call.details""" with self._state.condition: if self._state.details is None: @@ -591,7 +643,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: "Cannot get details until RPC is completed.") return _common.decode(self._state.details) - def _consume_next_event(self): + def _consume_next_event(self) -> Optional[cygrpc.BaseEvent]: event = self._call.next_event() with self._state.condition: callbacks = _handle_event(event, self._state, @@ -602,7 +654,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: callback() return event - def _next_response(self): + def _next_response(self) -> Any: while True: self._consume_next_event() with self._state.condition: @@ -616,7 +668,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: elif self._state.code is not None: raise self - def _next(self): + def _next(self) -> Any: with self._state.condition: if self._state.code is None: # We tentatively add the operation as expected and remove @@ -641,7 +693,7 @@ class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: raise self return self._next_response() - def debug_error_string(self): + def debug_error_string(self) -> Optional[str]: with self._state.condition: if self._state.debug_error_string is None: raise grpc.experimental.UsageError( @@ -659,8 +711,9 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: This extra thread allows _MultiThreadedRendezvous to fulfill the grpc.Future interface and to mediate a bidirection streaming RPC. """ + _state: _RPCState - def initial_metadata(self): + def initial_metadata(self) -> Optional[MetadataType]: """See grpc.Call.initial_metadata""" with self._state.condition: @@ -670,7 +723,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: _common.wait(self._state.condition.wait, _done) return self._state.initial_metadata - def trailing_metadata(self): + def trailing_metadata(self) -> Optional[MetadataType]: """See grpc.Call.trailing_metadata""" with self._state.condition: @@ -680,7 +733,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: _common.wait(self._state.condition.wait, _done) return self._state.trailing_metadata - def code(self): + def code(self) -> Optional[grpc.StatusCode]: """See grpc.Call.code""" with self._state.condition: @@ -690,7 +743,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: _common.wait(self._state.condition.wait, _done) return self._state.code - def details(self): + def details(self) -> Optional[str]: """See grpc.Call.details""" with self._state.condition: @@ -700,7 +753,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: _common.wait(self._state.condition.wait, _done) return _common.decode(self._state.details) - def debug_error_string(self): + def debug_error_string(self) -> Optional[str]: with self._state.condition: def _done(): @@ -709,22 +762,22 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: _common.wait(self._state.condition.wait, _done) return _common.decode(self._state.debug_error_string) - def cancelled(self): + def cancelled(self) -> bool: with self._state.condition: return self._state.cancelled - def running(self): + def running(self) -> bool: with self._state.condition: return self._state.code is None - def done(self): + def done(self) -> bool: with self._state.condition: return self._state.code is not None - def _is_complete(self): + def _is_complete(self) -> bool: return self._state.code is not None - def result(self, timeout=None): + def result(self, timeout: Optional[float] = None) -> Any: """Returns the result of the computation or raises its exception. See grpc.Future.result for the full API contract. @@ -743,7 +796,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: else: raise self - def exception(self, timeout=None): + def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: """Return the exception raised by the computation. See grpc.Future.exception for the full API contract. @@ -762,7 +815,9 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: else: return self - def traceback(self, timeout=None): + def traceback( + self, + timeout: Optional[float] = None) -> Optional[types.TracebackType]: """Access the traceback of the exception raised by the computation. See grpc.future.traceback for the full API contract. @@ -784,7 +839,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: except grpc.RpcError: return sys.exc_info()[2] - def add_done_callback(self, fn): + def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None: with self._state.condition: if self._state.code is None: self._state.callbacks.append(functools.partial(fn, self)) @@ -792,7 +847,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: fn(self) - def _next(self): + def _next(self) -> Any: with self._state.condition: if self._state.code is None: event_handler = _event_handler(self._state, @@ -826,7 +881,10 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: raise self -def _start_unary_request(request, timeout, request_serializer): +def _start_unary_request( + request: Any, timeout: Optional[float], + request_serializer: SerializingFunction +) -> Tuple[Optional[float], Optional[bytes], Optional[grpc.RpcError]]: deadline = _deadline(timeout) serialized_request = _common.serialize(request, request_serializer) if serialized_request is None: @@ -838,7 +896,10 @@ def _start_unary_request(request, timeout, request_serializer): return deadline, serialized_request, None -def _end_unary_response_blocking(state, call, with_call, deadline): +def _end_unary_response_blocking( + state: _RPCState, call: cygrpc.SegregatedCall, with_call: bool, + deadline: Optional[float] +) -> Union[ResponseType, Tuple[ResponseType, grpc.Call]]: if state.code is grpc.StatusCode.OK: if with_call: rendezvous = _MultiThreadedRendezvous(state, call, None, deadline) @@ -846,10 +907,12 @@ def _end_unary_response_blocking(state, call, with_call, deadline): else: return state.response else: - raise _InactiveRpcError(state) + raise _InactiveRpcError(state) # pytype: disable=not-instantiable -def _stream_unary_invocation_operationses(metadata, initial_metadata_flags): +def _stream_unary_invocation_operations( + metadata: Optional[MetadataType], + initial_metadata_flags: int) -> Sequence[Sequence[cygrpc.Operation]]: return ( ( cygrpc.SendInitialMetadataOperation(metadata, @@ -861,16 +924,17 @@ def _stream_unary_invocation_operationses(metadata, initial_metadata_flags): ) -def _stream_unary_invocation_operationses_and_tags(metadata, - initial_metadata_flags): +def _stream_unary_invocation_operations_and_tags( + metadata: Optional[MetadataType], initial_metadata_flags: int +) -> Sequence[Tuple[Sequence[cygrpc.Operation], Optional[UserTag]]]: return tuple(( operations, None, - ) for operations in _stream_unary_invocation_operationses( + ) for operations in _stream_unary_invocation_operations( metadata, initial_metadata_flags)) -def _determine_deadline(user_deadline): +def _determine_deadline(user_deadline: Optional[float]) -> Optional[float]: parent_deadline = cygrpc.get_deadline_from_context() if parent_deadline is None and user_deadline is None: return None @@ -883,10 +947,18 @@ def _determine_deadline(user_deadline): class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel, managed_call, method, request_serializer, - response_deserializer): + def __init__(self, channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, method: bytes, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction]): self._channel = channel self._managed_call = managed_call self._method = method @@ -894,7 +966,12 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() - def _prepare(self, request, timeout, metadata, wait_for_ready, compression): + def _prepare( + self, request: Any, timeout: Optional[float], + metadata: Optional[MetadataType], wait_for_ready: Optional[bool], + compression: Optional[grpc.Compression] + ) -> Tuple[Optional[_RPCState], Optional[Sequence[cygrpc.Operation]], + Optional[float], Optional[grpc.RpcError]]: deadline, serialized_request, rendezvous = _start_unary_request( request, timeout, self._request_serializer) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( @@ -916,8 +993,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): ) return state, operations, deadline, None - def _blocking(self, request, timeout, metadata, credentials, wait_for_ready, - compression): + def _blocking( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[_RPCState, cygrpc.SegregatedCall]: state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata, wait_for_ready, compression) if state is None: @@ -935,34 +1019,38 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): return state, call def __call__(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: state, call, = self._blocking(request, timeout, metadata, credentials, wait_for_ready, compression) return _end_unary_response_blocking(state, call, False, None) - def with_call(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def with_call( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: state, call, = self._blocking(request, timeout, metadata, credentials, wait_for_ready, compression) return _end_unary_response_blocking(state, call, True, None) - def future(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def future( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _MultiThreadedRendezvous: state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata, wait_for_ready, compression) if state is None: @@ -980,10 +1068,16 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + _channel: cygrpc.Channel + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel, method, request_serializer, - response_deserializer): + def __init__(self, channel: cygrpc.Channel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction): self._channel = channel self._method = method self._request_serializer = request_serializer @@ -991,13 +1085,14 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._context = cygrpc.build_census_context() def __call__( # pylint: disable=too-many-locals - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _SingleThreadedRendezvous: deadline = _deadline(timeout) serialized_request = _common.serialize(request, self._request_serializer) @@ -1030,10 +1125,18 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel, managed_call, method, request_serializer, - response_deserializer): + def __init__(self, channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction): self._channel = channel self._managed_call = managed_call self._method = method @@ -1042,13 +1145,14 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._context = cygrpc.build_census_context() def __call__( # pylint: disable=too-many-locals - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[ + grpc.Compression] = None) -> _MultiThreadedRendezvous: deadline, serialized_request, rendezvous = _start_unary_request( request, timeout, self._request_serializer) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( @@ -1059,7 +1163,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): augmented_metadata = _compression.augment_metadata( metadata, compression) state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) - operationses = ( + operations = ( ( cygrpc.SendInitialMetadataOperation(augmented_metadata, initial_metadata_flags), @@ -1074,8 +1178,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, None, _determine_deadline(deadline), metadata, None if credentials is None else credentials._credentials, - operationses, _event_handler(state, - self._response_deserializer), + operations, _event_handler(state, self._response_deserializer), self._context) return _MultiThreadedRendezvous(state, call, self._response_deserializer, @@ -1083,10 +1186,18 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel, managed_call, method, request_serializer, - response_deserializer): + def __init__(self, channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, method: bytes, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction]): self._channel = channel self._managed_call = managed_call self._method = method @@ -1094,8 +1205,12 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() - def _blocking(self, request_iterator, timeout, metadata, credentials, - wait_for_ready, compression): + def _blocking( + self, request_iterator: Iterator, timeout: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], compression: Optional[grpc.Compression] + ) -> Tuple[_RPCState, cygrpc.SegregatedCall]: deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( @@ -1106,7 +1221,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, None, _determine_deadline(deadline), augmented_metadata, None if credentials is None else credentials._credentials, - _stream_unary_invocation_operationses_and_tags( + _stream_unary_invocation_operations_and_tags( augmented_metadata, initial_metadata_flags), self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, None) @@ -1120,34 +1235,38 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): return state, call def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: state, call, = self._blocking(request_iterator, timeout, metadata, credentials, wait_for_ready, compression) return _end_unary_response_blocking(state, call, False, None) - def with_call(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def with_call( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: state, call, = self._blocking(request_iterator, timeout, metadata, credentials, wait_for_ready, compression) return _end_unary_response_blocking(state, call, True, None) - def future(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def future( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _MultiThreadedRendezvous: deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) event_handler = _event_handler(state, self._response_deserializer) @@ -1159,8 +1278,8 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, None, deadline, augmented_metadata, None if credentials is None else credentials._credentials, - _stream_unary_invocation_operationses(metadata, - initial_metadata_flags), + _stream_unary_invocation_operations(metadata, + initial_metadata_flags), event_handler, self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) @@ -1169,10 +1288,20 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + _channel: cygrpc.Channel + _managed_call: IntegratedCallFactory + _method: bytes + _request_serializer: Optional[SerializingFunction] + _response_deserializer: Optional[DeserializingFunction] + _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel, managed_call, method, request_serializer, - response_deserializer): + def __init__(self, + channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, + method: bytes, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None): self._channel = channel self._managed_call = managed_call self._method = method @@ -1180,20 +1309,22 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() - def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def __call__( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> _MultiThreadedRendezvous: deadline = _deadline(timeout) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) augmented_metadata = _compression.augment_metadata( metadata, compression) - operationses = ( + operations = ( ( cygrpc.SendInitialMetadataOperation(augmented_metadata, initial_metadata_flags), @@ -1206,7 +1337,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, None, _determine_deadline(deadline), augmented_metadata, None if credentials is None else credentials._credentials, - operationses, event_handler, self._context) + operations, event_handler, self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) return _MultiThreadedRendezvous(state, call, @@ -1216,11 +1347,11 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): class _InitialMetadataFlags(int): """Stores immutable initial metadata flags""" - def __new__(cls, value=_EMPTY_FLAGS): + def __new__(cls, value: int = _EMPTY_FLAGS): value &= cygrpc.InitialMetadataFlags.used_mask return super(_InitialMetadataFlags, cls).__new__(cls, value) - def with_wait_for_ready(self, wait_for_ready): + def with_wait_for_ready(self, wait_for_ready: Optional[bool]) -> int: if wait_for_ready is not None: if wait_for_ready: return self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ @@ -1232,14 +1363,17 @@ class _InitialMetadataFlags(int): class _ChannelCallState(object): + channel: cygrpc.Channel + managed_calls: int + threading: bool - def __init__(self, channel): + def __init__(self, channel: cygrpc.Channel): self.lock = threading.Lock() self.channel = channel self.managed_calls = 0 self.threading = False - def reset_postfork_child(self): + def reset_postfork_child(self) -> None: self.managed_calls = 0 def __del__(self): @@ -1250,7 +1384,7 @@ class _ChannelCallState(object): pass -def _run_channel_spin_thread(state): +def _run_channel_spin_thread(state: _ChannelCallState) -> None: def channel_spin(): while True: @@ -1270,11 +1404,14 @@ def _run_channel_spin_thread(state): channel_spin_thread.start() -def _channel_managed_call_management(state): +def _channel_managed_call_management(state: _ChannelCallState): # pylint: disable=too-many-arguments - def create(flags, method, host, deadline, metadata, credentials, - operationses, event_handler, context): + def create(flags: int, method: bytes, host: Optional[str], + deadline: Optional[float], metadata: Optional[MetadataType], + credentials: Optional[cygrpc.CallCredentials], + operations: Sequence[Sequence[cygrpc.Operation]], + event_handler: UserTag, context) -> cygrpc.IntegratedCall: """Creates a cygrpc.IntegratedCall. Args: @@ -1285,7 +1422,7 @@ def _channel_managed_call_management(state): the call is to have an infinite deadline. metadata: The metadata for the call or None. credentials: A cygrpc.CallCredentials or None. - operationses: An iterable of iterables of cygrpc.Operations to be + operations: A sequence of sequences of cygrpc.Operations to be started on the call. event_handler: A behavior to call to handle the events resultant from the operations on the call. @@ -1293,14 +1430,14 @@ def _channel_managed_call_management(state): Returns: A cygrpc.IntegratedCall with which to conduct an RPC. """ - operationses_and_tags = tuple(( - operations, + operations_and_tags = tuple(( + operation, event_handler, - ) for operations in operationses) + ) for operation in operations) with state.lock: call = state.channel.integrated_call(flags, method, host, deadline, metadata, credentials, - operationses_and_tags, context) + operations_and_tags, context) if state.managed_calls == 0: state.managed_calls = 1 _run_channel_spin_thread(state) @@ -1312,8 +1449,17 @@ def _channel_managed_call_management(state): class _ChannelConnectivityState(object): - - def __init__(self, channel): + lock: threading.RLock + channel: grpc.Channel + polling: bool + connectivity: grpc.ChannelConnectivity + try_to_connect: bool + # TODO(xuanwn): Refactor this: https://github.com/grpc/grpc/issues/31704 + callbacks_and_connectivities: List[Sequence[Union[Callable[ + [grpc.ChannelConnectivity], None], Optional[grpc.ChannelConnectivity]]]] + delivering: bool + + def __init__(self, channel: grpc.Channel): self.lock = threading.RLock() self.channel = channel self.polling = False @@ -1322,7 +1468,7 @@ class _ChannelConnectivityState(object): self.callbacks_and_connectivities = [] self.delivering = False - def reset_postfork_child(self): + def reset_postfork_child(self) -> None: self.polling = False self.connectivity = None self.try_to_connect = False @@ -1330,7 +1476,9 @@ class _ChannelConnectivityState(object): self.delivering = False -def _deliveries(state): +def _deliveries( + state: _ChannelConnectivityState +) -> List[Callable[[grpc.ChannelConnectivity], None]]: callbacks_needing_update = [] for callback_and_connectivity in state.callbacks_and_connectivities: callback, callback_connectivity, = callback_and_connectivity @@ -1340,7 +1488,11 @@ def _deliveries(state): return callbacks_needing_update -def _deliver(state, initial_connectivity, initial_callbacks): +def _deliver( + state: _ChannelConnectivityState, + initial_connectivity: grpc.ChannelConnectivity, + initial_callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]] +) -> None: connectivity = initial_connectivity callbacks = initial_callbacks while True: @@ -1360,7 +1512,10 @@ def _deliver(state, initial_connectivity, initial_callbacks): return -def _spawn_delivery(state, callbacks): +def _spawn_delivery( + state: _ChannelConnectivityState, + callbacks: Sequence[Callable[[grpc.ChannelConnectivity], + None]]) -> None: delivering_thread = cygrpc.ForkManagedThread(target=_deliver, args=( state, @@ -1373,7 +1528,8 @@ def _spawn_delivery(state, callbacks): # NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll. -def _poll_connectivity(state, channel, initial_try_to_connect): +def _poll_connectivity(state: _ChannelConnectivityState, channel: grpc.Channel, + initial_try_to_connect: bool) -> None: try_to_connect = initial_try_to_connect connectivity = channel.check_connectivity_state(try_to_connect) with state.lock: @@ -1410,7 +1566,9 @@ def _poll_connectivity(state, channel, initial_try_to_connect): _spawn_delivery(state, callbacks) -def _subscribe(state, callback, try_to_connect): +def _subscribe(state: _ChannelConnectivityState, + callback: Callable[[grpc.ChannelConnectivity], + None], try_to_connect: bool) -> None: with state.lock: if not state.callbacks_and_connectivities and not state.polling: polling_thread = cygrpc.ForkManagedThread( @@ -1430,7 +1588,8 @@ def _subscribe(state, callback, try_to_connect): state.callbacks_and_connectivities.append([callback, None]) -def _unsubscribe(state, callback): +def _unsubscribe(state: _ChannelConnectivityState, + callback: Callable[[grpc.ChannelConnectivity], None]) -> None: with state.lock: for index, (subscribed_callback, unused_connectivity) in enumerate( state.callbacks_and_connectivities): @@ -1439,7 +1598,10 @@ def _unsubscribe(state, callback): break -def _augment_options(base_options, compression): +def _augment_options( + base_options: Sequence[ChannelArgumentType], + compression: Optional[grpc.Compression] +) -> Sequence[ChannelArgumentType]: compression_option = _compression.create_channel_option(compression) return tuple(base_options) + compression_option + (( cygrpc.ChannelArgKey.primary_user_agent_string, @@ -1447,7 +1609,9 @@ def _augment_options(base_options, compression): ),) -def _separate_channel_options(options): +def _separate_channel_options( + options: Sequence[ChannelArgumentType] +) -> Tuple[Sequence[ChannelArgumentType], Sequence[ChannelArgumentType]]: """Separates core channel options from Python channel options.""" core_options = [] python_options = [] @@ -1461,8 +1625,14 @@ def _separate_channel_options(options): class Channel(grpc.Channel): """A cygrpc.Channel-backed implementation of grpc.Channel.""" - - def __init__(self, target, options, credentials, compression): + _single_threaded_unary_stream: bool + _channel: cygrpc.Channel + _call_state: _ChannelCallState + _connectivity_state: _ChannelConnectivityState + + def __init__(self, target: str, options: Sequence[ChannelArgumentType], + credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression]): """Constructor. Args: @@ -1484,30 +1654,38 @@ class Channel(grpc.Channel): if cygrpc.g_gevent_activated: cygrpc.gevent_increment_channel_count() - def _process_python_options(self, python_options): + def _process_python_options( + self, python_options: Sequence[ChannelArgumentType]) -> None: """Sets channel attributes according to python-only channel options.""" for pair in python_options: if pair[0] == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream: self._single_threaded_unary_stream = True - def subscribe(self, callback, try_to_connect=None): + def subscribe(self, + callback: Callable[[grpc.ChannelConnectivity], None], + try_to_connect: Optional[bool] = None) -> None: _subscribe(self._connectivity_state, callback, try_to_connect) - def unsubscribe(self, callback): + def unsubscribe( + self, callback: Callable[[grpc.ChannelConnectivity], None]) -> None: _unsubscribe(self._connectivity_state, callback) - def unary_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryUnaryMultiCallable: return _UnaryUnaryMultiCallable( self._channel, _channel_managed_call_management(self._call_state), _common.encode(method), request_serializer, response_deserializer) - def unary_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryStreamMultiCallable: # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC # on a single Python thread results in an appreciable speed-up. However, # due to slight differences in capability, the multi-threaded variant @@ -1523,36 +1701,40 @@ class Channel(grpc.Channel): _common.encode(method), request_serializer, response_deserializer) - def stream_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamUnaryMultiCallable: return _StreamUnaryMultiCallable( self._channel, _channel_managed_call_management(self._call_state), _common.encode(method), request_serializer, response_deserializer) - def stream_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamStreamMultiCallable: return _StreamStreamMultiCallable( self._channel, _channel_managed_call_management(self._call_state), _common.encode(method), request_serializer, response_deserializer) - def _unsubscribe_all(self): + def _unsubscribe_all(self) -> None: state = self._connectivity_state if state: with state.lock: del state.callbacks_and_connectivities[:] - def _close(self): + def _close(self) -> None: self._unsubscribe_all() self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') cygrpc.fork_unregister_channel(self) if cygrpc.g_gevent_activated: cygrpc.gevent_decrement_channel_count() - def _close_on_fork(self): + def _close_on_fork(self) -> None: self._unsubscribe_all() self._channel.close_on_fork(cygrpc.StatusCode.cancelled, 'Channel closed due to fork') @@ -1564,7 +1746,7 @@ class Channel(grpc.Channel): self._close() return False - def close(self): + def close(self) -> None: self._close() def __del__(self): diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index b976800e6bb..3b8fd0ff97d 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -107,14 +107,14 @@ def fully_qualified_method(group: str, method: str) -> str: return '/{}/{}'.format(group, method) -def _wait_once(wait_fn: Callable[..., None], timeout: float, +def _wait_once(wait_fn: Callable[..., bool], timeout: float, spin_cb: Optional[Callable[[], None]]): wait_fn(timeout=timeout) if spin_cb is not None: spin_cb() -def wait(wait_fn: Callable[..., None], +def wait(wait_fn: Callable[..., bool], wait_complete_fn: Callable[[], bool], timeout: Optional[float] = None, spin_cb: Optional[Callable[[], None]] = None) -> bool: diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index e327ae17261..865ff17d35b 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -201,7 +201,7 @@ class _UnaryOutcome(grpc.Call, grpc.Future): def cancel(self) -> bool: return self._call.cancel() - def add_callback(self, callback) -> None: + def add_callback(self, callback) -> bool: return self._call.add_callback(callback) def cancelled(self) -> bool: diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 2ad9c27cced..d9f06663ae2 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -13,18 +13,31 @@ # limitations under the License. """Service-side implementation of gRPC Python.""" +from __future__ import annotations + import collections from concurrent import futures import enum import logging import threading import time +from typing import (Any, Callable, Iterable, Iterator, List, Mapping, Optional, + Sequence, Set, Tuple, Union) -import grpc -from grpc import _common -from grpc import _compression -from grpc import _interceptor +import grpc # pytype: disable=pyi-error +from grpc import _common # pytype: disable=pyi-error +from grpc import _compression # pytype: disable=pyi-error +from grpc import _interceptor # pytype: disable=pyi-error from grpc._cython import cygrpc +from grpc._typing import ArityAgnosticMethodHandler +from grpc._typing import ChannelArgumentType +from grpc._typing import DeserializingFunction +from grpc._typing import MetadataType +from grpc._typing import NullaryCallbackType +from grpc._typing import ResponseType +from grpc._typing import SerializingFunction +from grpc._typing import ServerCallbackTag +from grpc._typing import ServerTagCallbackType _LOGGER = logging.getLogger(__name__) @@ -51,30 +64,31 @@ _DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 _INF_TIMEOUT = 1e9 -def _serialized_request(request_event): +def _serialized_request(request_event: cygrpc.BaseEvent) -> bytes: return request_event.batch_operations[0].message() -def _application_code(code): +def _application_code(code: grpc.StatusCode) -> cygrpc.StatusCode: cygrpc_code = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code) return cygrpc.StatusCode.unknown if cygrpc_code is None else cygrpc_code -def _completion_code(state): +def _completion_code(state: _RPCState) -> cygrpc.StatusCode: if state.code is None: return cygrpc.StatusCode.ok else: return _application_code(state.code) -def _abortion_code(state, code): +def _abortion_code(state: _RPCState, + code: cygrpc.StatusCode) -> cygrpc.StatusCode: if state.code is None: return code else: return _application_code(state.code) -def _details(state): +def _details(state: _RPCState) -> bytes: return b'' if state.details is None else state.details @@ -87,6 +101,20 @@ class _HandlerCallDetails( class _RPCState(object): + condition: threading.Condition + due = Set[str] + request: Any + client: str + initial_metadata_allowed: bool + compression_algorithm: Optional[grpc.Compression] + disable_next_compression: bool + trailing_metadata: Optional[MetadataType] + code: Optional[grpc.StatusCode] + details: Optional[bytes] + statused: bool + rpc_errors: List[Exception] + callbacks: Optional[List[NullaryCallbackType]] + aborted: bool def __init__(self): self.condition = threading.Condition() @@ -105,13 +133,14 @@ class _RPCState(object): self.aborted = False -def _raise_rpc_error(state): +def _raise_rpc_error(state: _RPCState) -> None: rpc_error = grpc.RpcError() state.rpc_errors.append(rpc_error) raise rpc_error -def _possibly_finish_call(state, token): +def _possibly_finish_call(state: _RPCState, + token: str) -> ServerTagCallbackType: state.due.remove(token) if not _is_rpc_state_active(state) and not state.due: callbacks = state.callbacks @@ -121,7 +150,7 @@ def _possibly_finish_call(state, token): return None, () -def _send_status_from_server(state, token): +def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag: def send_status_from_server(unused_send_status_from_server_event): with state.condition: @@ -130,7 +159,9 @@ def _send_status_from_server(state, token): return send_status_from_server -def _get_initial_metadata(state, metadata): +def _get_initial_metadata( + state: _RPCState, + metadata: Optional[MetadataType]) -> Optional[MetadataType]: with state.condition: if state.compression_algorithm: compression_metadata = ( @@ -144,13 +175,15 @@ def _get_initial_metadata(state, metadata): return metadata -def _get_initial_metadata_operation(state, metadata): +def _get_initial_metadata_operation( + state: _RPCState, metadata: Optional[MetadataType]) -> cygrpc.Operation: operation = cygrpc.SendInitialMetadataOperation( _get_initial_metadata(state, metadata), _EMPTY_FLAGS) return operation -def _abort(state, call, code, details): +def _abort(state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, + details: bytes) -> None: if state.client is not _CANCELLED: effective_code = _abortion_code(state, code) effective_details = details if state.details is None else state.details @@ -174,7 +207,7 @@ def _abort(state, call, code, details): state.due.add(token) -def _receive_close_on_server(state): +def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag: def receive_close_on_server(receive_close_on_server_event): with state.condition: @@ -188,7 +221,10 @@ def _receive_close_on_server(state): return receive_close_on_server -def _receive_message(state, call, request_deserializer): +def _receive_message( + state: _RPCState, call: cygrpc.Call, + request_deserializer: Optional[DeserializingFunction] +) -> ServerCallbackTag: def receive_message(receive_message_event): serialized_request = _serialized_request(receive_message_event) @@ -213,7 +249,7 @@ def _receive_message(state, call, request_deserializer): return receive_message -def _send_initial_metadata(state): +def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag: def send_initial_metadata(unused_send_initial_metadata_event): with state.condition: @@ -222,7 +258,7 @@ def _send_initial_metadata(state): return send_initial_metadata -def _send_message(state, token): +def _send_message(state: _RPCState, token: str) -> ServerCallbackTag: def send_message(unused_send_message_event): with state.condition: @@ -233,23 +269,27 @@ def _send_message(state, token): class _Context(grpc.ServicerContext): + _rpc_event: cygrpc.BaseEvent + _state: _RPCState + request_deserializer: Optional[DeserializingFunction] - def __init__(self, rpc_event, state, request_deserializer): + def __init__(self, rpc_event: cygrpc.BaseEvent, state: _RPCState, + request_deserializer: Optional[DeserializingFunction]): self._rpc_event = rpc_event self._state = state self._request_deserializer = request_deserializer - def is_active(self): + def is_active(self) -> bool: with self._state.condition: return _is_rpc_state_active(self._state) - def time_remaining(self): + def time_remaining(self) -> float: return max(self._rpc_event.call_details.deadline - time.time(), 0) - def cancel(self): + def cancel(self) -> None: self._rpc_event.call.cancel() - def add_callback(self, callback): + def add_callback(self, callback: NullaryCallbackType) -> bool: with self._state.condition: if self._state.callbacks is None: return False @@ -257,24 +297,24 @@ class _Context(grpc.ServicerContext): self._state.callbacks.append(callback) return True - def disable_next_message_compression(self): + def disable_next_message_compression(self) -> None: with self._state.condition: self._state.disable_next_compression = True - def invocation_metadata(self): + def invocation_metadata(self) -> Optional[MetadataType]: return self._rpc_event.invocation_metadata - def peer(self): + def peer(self) -> str: return _common.decode(self._rpc_event.call.peer()) - def peer_identities(self): + def peer_identities(self) -> Optional[Sequence[bytes]]: return cygrpc.peer_identities(self._rpc_event.call) - def peer_identity_key(self): + def peer_identity_key(self) -> Optional[str]: id_key = cygrpc.peer_identity_key(self._rpc_event.call) return id_key if id_key is None else _common.decode(id_key) - def auth_context(self): + def auth_context(self) -> Mapping[str, Sequence[bytes]]: auth_context = cygrpc.auth_context(self._rpc_event.call) auth_context_dict = {} if auth_context is None else auth_context return { @@ -282,11 +322,11 @@ class _Context(grpc.ServicerContext): for key, value in auth_context_dict.items() } - def set_compression(self, compression): + def set_compression(self, compression: grpc.Compression) -> None: with self._state.condition: self._state.compression_algorithm = compression - def send_initial_metadata(self, initial_metadata): + def send_initial_metadata(self, initial_metadata: MetadataType) -> None: with self._state.condition: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) @@ -301,14 +341,14 @@ class _Context(grpc.ServicerContext): else: raise ValueError('Initial metadata no longer allowed!') - def set_trailing_metadata(self, trailing_metadata): + def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: with self._state.condition: self._state.trailing_metadata = trailing_metadata - def trailing_metadata(self): + def trailing_metadata(self) -> Optional[MetadataType]: return self._state.trailing_metadata - def abort(self, code, details): + def abort(self, code: grpc.StatusCode, details: str) -> None: # treat OK like other invalid arguments: fail the RPC if code == grpc.StatusCode.OK: _LOGGER.error( @@ -321,36 +361,40 @@ class _Context(grpc.ServicerContext): self._state.aborted = True raise Exception() - def abort_with_status(self, status): + def abort_with_status(self, status: grpc.Status) -> None: self._state.trailing_metadata = status.trailing_metadata self.abort(status.code, status.details) - def set_code(self, code): + def set_code(self, code: grpc.StatusCode) -> None: with self._state.condition: self._state.code = code - def code(self): + def code(self) -> grpc.StatusCode: return self._state.code - def set_details(self, details): + def set_details(self, details: str) -> None: with self._state.condition: self._state.details = _common.encode(details) - def details(self): + def details(self) -> bytes: return self._state.details - def _finalize_state(self): + def _finalize_state(self) -> None: pass class _RequestIterator(object): + _state: _RPCState + _call: cygrpc.Call + _request_deserializer: Optional[DeserializingFunction] - def __init__(self, state, call, request_deserializer): + def __init__(self, state: _RPCState, call: cygrpc.Call, + request_deserializer: Optional[DeserializingFunction]): self._state = state self._call = call self._request_deserializer = request_deserializer - def _raise_or_start_receive_message(self): + def _raise_or_start_receive_message(self) -> None: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) elif not _is_rpc_state_active(self._state): @@ -362,7 +406,7 @@ class _RequestIterator(object): self._request_deserializer)) self._state.due.add(_RECEIVE_MESSAGE_TOKEN) - def _look_for_request(self): + def _look_for_request(self) -> Any: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) elif (self._state.request is None and @@ -375,7 +419,7 @@ class _RequestIterator(object): raise AssertionError() # should never run - def _next(self): + def _next(self) -> Any: with self._state.condition: self._raise_or_start_receive_message() while True: @@ -384,17 +428,20 @@ class _RequestIterator(object): if request is not None: return request - def __iter__(self): + def __iter__(self) -> _RequestIterator: return self - def __next__(self): + def __next__(self) -> Any: return self._next() - def next(self): + def next(self) -> Any: return self._next() -def _unary_request(rpc_event, state, request_deserializer): +def _unary_request( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + request_deserializer: Optional[DeserializingFunction] +) -> Callable[[], Any]: def unary_request(): with state.condition: @@ -426,13 +473,15 @@ def _unary_request(rpc_event, state, request_deserializer): return unary_request -def _call_behavior(rpc_event, - state, - behavior, - argument, - request_deserializer, - send_response_callback=None): - from grpc import _create_servicer_context +def _call_behavior( + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + behavior: ArityAgnosticMethodHandler, + argument: Any, + request_deserializer: Optional[DeserializingFunction], + send_response_callback: Optional[Callable[[ResponseType], None]] = None +) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]: + from grpc import _create_servicer_context # pytype: disable=pyi-error with _create_servicer_context(rpc_event, state, request_deserializer) as context: try: @@ -457,7 +506,9 @@ def _call_behavior(rpc_event, return None, False -def _take_response_from_response_iterator(rpc_event, state, response_iterator): +def _take_response_from_response_iterator( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + response_iterator: Iterator[ResponseType]) -> Tuple[ResponseType, bool]: try: return next(response_iterator), True except StopIteration: @@ -475,7 +526,9 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator): return None, False -def _serialize_response(rpc_event, state, response, response_serializer): +def _serialize_response( + rpc_event: cygrpc.BaseEvent, state: _RPCState, response: Any, + response_serializer: Optional[SerializingFunction]) -> Optional[bytes]: serialized_response = _common.serialize(response, response_serializer) if serialized_response is None: with state.condition: @@ -486,19 +539,21 @@ def _serialize_response(rpc_event, state, response, response_serializer): return serialized_response -def _get_send_message_op_flags_from_state(state): +def _get_send_message_op_flags_from_state( + state: _RPCState) -> Union[int, cygrpc.WriteFlag]: if state.disable_next_compression: return cygrpc.WriteFlag.no_compress else: return _EMPTY_FLAGS -def _reset_per_message_state(state): +def _reset_per_message_state(state: _RPCState) -> None: with state.condition: state.disable_next_compression = False -def _send_response(rpc_event, state, serialized_response): +def _send_response(rpc_event: cygrpc.BaseEvent, state: _RPCState, + serialized_response: bytes) -> bool: with state.condition: if not _is_rpc_state_active(state): return False @@ -527,7 +582,8 @@ def _send_response(rpc_event, state, serialized_response): return _is_rpc_state_active(state) -def _status(rpc_event, state, serialized_response): +def _status(rpc_event: cygrpc.BaseEvent, state: _RPCState, + serialized_response: Optional[bytes]) -> None: with state.condition: if state.client is not _CANCELLED: code = _completion_code(state) @@ -552,8 +608,11 @@ def _status(rpc_event, state, serialized_response): state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) -def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk, - request_deserializer, response_serializer): +def _unary_response_in_pool( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any], + request_deserializer: Optional[SerializingFunction], + response_serializer: Optional[SerializingFunction]) -> None: cygrpc.install_context_from_request_call_event(rpc_event) try: argument = argument_thunk() @@ -569,11 +628,14 @@ def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk, cygrpc.uninstall_context() -def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, - request_deserializer, response_serializer): +def _stream_response_in_pool( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any], + request_deserializer: Optional[DeserializingFunction], + response_serializer: Optional[SerializingFunction]) -> None: cygrpc.install_context_from_request_call_event(rpc_event) - def send_response(response): + def send_response(response: Any) -> None: if response is None: _status(rpc_event, state, None) else: @@ -604,13 +666,14 @@ def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, cygrpc.uninstall_context() -def _is_rpc_state_active(state): +def _is_rpc_state_active(state: _RPCState) -> bool: return state.client is not _CANCELLED and not state.statused -def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state, - send_response_callback, - response_iterator): +def _send_message_callback_to_blocking_iterator_adapter( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + send_response_callback: Callable[[ResponseType], None], + response_iterator: Iterator[ResponseType]) -> None: while True: response, proceed = _take_response_from_response_iterator( rpc_event, state, response_iterator) @@ -622,7 +685,10 @@ def _send_message_callback_to_blocking_iterator_adapter(rpc_event, state, break -def _select_thread_pool_for_behavior(behavior, default_thread_pool): +def _select_thread_pool_for_behavior( + behavior: ArityAgnosticMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor +) -> futures.ThreadPoolExecutor: if hasattr(behavior, 'experimental_thread_pool') and isinstance( behavior.experimental_thread_pool, futures.ThreadPoolExecutor): return behavior.experimental_thread_pool @@ -630,7 +696,10 @@ def _select_thread_pool_for_behavior(behavior, default_thread_pool): return default_thread_pool -def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool): +def _handle_unary_unary( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary, @@ -641,7 +710,10 @@ def _handle_unary_unary(rpc_event, state, method_handler, default_thread_pool): method_handler.response_serializer) -def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool): +def _handle_unary_stream( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: unary_request = _unary_request(rpc_event, state, method_handler.request_deserializer) thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream, @@ -652,7 +724,10 @@ def _handle_unary_stream(rpc_event, state, method_handler, default_thread_pool): method_handler.response_serializer) -def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool): +def _handle_stream_unary( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: request_iterator = _RequestIterator(state, rpc_event.call, method_handler.request_deserializer) thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary, @@ -664,8 +739,10 @@ def _handle_stream_unary(rpc_event, state, method_handler, default_thread_pool): method_handler.response_serializer) -def _handle_stream_stream(rpc_event, state, method_handler, - default_thread_pool): +def _handle_stream_stream( + rpc_event: cygrpc.BaseEvent, state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: request_iterator = _RequestIterator(state, rpc_event.call, method_handler.request_deserializer) thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream, @@ -677,9 +754,14 @@ def _handle_stream_stream(rpc_event, state, method_handler, method_handler.response_serializer) -def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): +def _find_method_handler( + rpc_event: cygrpc.BaseEvent, generic_handlers: List[grpc.GenericRpcHandler], + interceptor_pipeline: Optional[_interceptor._ServicePipeline] +) -> Optional[grpc.RpcMethodHandler]: - def query_handlers(handler_call_details): + def query_handlers( + handler_call_details: _HandlerCallDetails + ) -> Optional[grpc.RpcMethodHandler]: for generic_handler in generic_handlers: method_handler = generic_handler.service(handler_call_details) if method_handler is not None: @@ -697,7 +779,8 @@ def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline): return query_handlers(handler_call_details) -def _reject_rpc(rpc_event, status, details): +def _reject_rpc(rpc_event: cygrpc.BaseEvent, status: cygrpc.StatusCode, + details: bytes) -> _RPCState: rpc_state = _RPCState() operations = ( _get_initial_metadata_operation(rpc_state, None), @@ -712,7 +795,10 @@ def _reject_rpc(rpc_event, status, details): return rpc_state -def _handle_with_method_handler(rpc_event, method_handler, thread_pool): +def _handle_with_method_handler( + rpc_event: cygrpc.BaseEvent, method_handler: grpc.RpcMethodHandler, + thread_pool: futures.ThreadPoolExecutor +) -> Tuple[_RPCState, futures.Future]: state = _RPCState() with state.condition: rpc_event.call.start_server_batch( @@ -735,8 +821,11 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool): method_handler, thread_pool) -def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool, - concurrency_exceeded): +def _handle_call( + rpc_event: cygrpc.BaseEvent, generic_handlers: List[grpc.GenericRpcHandler], + interceptor_pipeline: Optional[_interceptor._ServicePipeline], + thread_pool: futures.ThreadPoolExecutor, concurrency_exceeded: bool +) -> Tuple[Optional[_RPCState], Optional[futures.Future]]: if not rpc_event.success: return None, None if rpc_event.call_details.method is not None: @@ -769,10 +858,28 @@ class _ServerStage(enum.Enum): class _ServerState(object): + lock: threading.RLock + completion_queue: cygrpc.CompletionQueue + server: cygrpc.Server + generic_handlers: List[grpc.GenericRpcHandler] + interceptor_pipeline: Optional[_interceptor._ServicePipeline] + thread_pool: futures.ThreadPoolExecutor + stage: _ServerStage + termination_event: threading.Event + shutdown_events: List[threading.Event] + maximum_concurrent_rpcs: Optional[int] + active_rpc_count: int + rpc_states: Set[_RPCState] + due: Set[str] + server_deallocated: bool # pylint: disable=too-many-arguments - def __init__(self, completion_queue, server, generic_handlers, - interceptor_pipeline, thread_pool, maximum_concurrent_rpcs): + def __init__(self, completion_queue: cygrpc.CompletionQueue, + server: cygrpc.Server, + generic_handlers: Sequence[grpc.GenericRpcHandler], + interceptor_pipeline: Optional[_interceptor._ServicePipeline], + thread_pool: futures.ThreadPoolExecutor, + maximum_concurrent_rpcs: Optional[int]): self.lock = threading.RLock() self.completion_queue = completion_queue self.server = server @@ -793,30 +900,33 @@ class _ServerState(object): self.server_deallocated = False -def _add_generic_handlers(state, generic_handlers): +def _add_generic_handlers( + state: _ServerState, + generic_handlers: Iterable[grpc.GenericRpcHandler]) -> None: with state.lock: state.generic_handlers.extend(generic_handlers) -def _add_insecure_port(state, address): +def _add_insecure_port(state: _ServerState, address: bytes) -> int: with state.lock: return state.server.add_http2_port(address) -def _add_secure_port(state, address, server_credentials): +def _add_secure_port(state: _ServerState, address: bytes, + server_credentials: grpc.ServerCredentials) -> int: with state.lock: return state.server.add_http2_port(address, server_credentials._credentials) -def _request_call(state): +def _request_call(state: _ServerState) -> None: state.server.request_call(state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG) state.due.add(_REQUEST_CALL_TAG) # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. -def _stop_serving(state): +def _stop_serving(state: _ServerState) -> bool: if not state.rpc_states and not state.due: state.server.destroy() for shutdown_event in state.shutdown_events: @@ -827,12 +937,13 @@ def _stop_serving(state): return False -def _on_call_completed(state): +def _on_call_completed(state: _ServerState) -> None: with state.lock: state.active_rpc_count -= 1 -def _process_event_and_continue(state, event): +def _process_event_and_continue(state: _ServerState, + event: cygrpc.BaseEvent) -> bool: should_continue = True if event.tag is _SHUTDOWN_TAG: with state.lock: @@ -874,7 +985,7 @@ def _process_event_and_continue(state, event): return should_continue -def _serve(state): +def _serve(state: _ServerState) -> None: while True: timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S event = state.completion_queue.poll(timeout) @@ -889,7 +1000,7 @@ def _serve(state): event = None -def _begin_shutdown_once(state): +def _begin_shutdown_once(state: _ServerState) -> None: with state.lock: if state.stage is _ServerStage.STARTED: state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) @@ -897,7 +1008,7 @@ def _begin_shutdown_once(state): state.due.add(_SHUTDOWN_TAG) -def _stop(state, grace): +def _stop(state: _ServerState, grace: Optional[float]) -> threading.Event: with state.lock: if state.stage is _ServerStage.STOPPED: shutdown_event = threading.Event() @@ -923,7 +1034,7 @@ def _stop(state, grace): return shutdown_event -def _start(state): +def _start(state: _ServerState) -> None: with state.lock: if state.stage is not _ServerStage.STOPPED: raise ValueError('Cannot start already-started server!') @@ -936,7 +1047,8 @@ def _start(state): thread.start() -def _validate_generic_rpc_handlers(generic_rpc_handlers): +def _validate_generic_rpc_handlers( + generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]) -> None: for generic_rpc_handler in generic_rpc_handlers: service_attribute = getattr(generic_rpc_handler, 'service', None) if service_attribute is None: @@ -945,16 +1057,24 @@ def _validate_generic_rpc_handlers(generic_rpc_handlers): 'not have "service" method!'.format(generic_rpc_handler)) -def _augment_options(base_options, compression): +def _augment_options( + base_options: Sequence[ChannelArgumentType], + compression: Optional[grpc.Compression] +) -> Sequence[ChannelArgumentType]: compression_option = _compression.create_channel_option(compression) return tuple(base_options) + compression_option class _Server(grpc.Server): + _state: _ServerState # pylint: disable=too-many-arguments - def __init__(self, thread_pool, generic_handlers, interceptors, options, - maximum_concurrent_rpcs, compression, xds): + def __init__(self, thread_pool: futures.ThreadPoolExecutor, + generic_handlers: Sequence[grpc.GenericRpcHandler], + interceptors: Sequence[grpc.ServerInterceptor], + options: Sequence[ChannelArgumentType], + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression], xds: bool): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(_augment_options(options, compression), xds) server.register_completion_queue(completion_queue) @@ -962,24 +1082,27 @@ class _Server(grpc.Server): _interceptor.service_pipeline(interceptors), thread_pool, maximum_concurrent_rpcs) - def add_generic_rpc_handlers(self, generic_rpc_handlers): + def add_generic_rpc_handlers( + self, + generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]) -> None: _validate_generic_rpc_handlers(generic_rpc_handlers) _add_generic_handlers(self._state, generic_rpc_handlers) - def add_insecure_port(self, address): + def add_insecure_port(self, address: str) -> int: return _common.validate_port_binding_result( address, _add_insecure_port(self._state, _common.encode(address))) - def add_secure_port(self, address, server_credentials): + def add_secure_port(self, address: str, + server_credentials: grpc.ServerCredentials) -> int: return _common.validate_port_binding_result( address, _add_secure_port(self._state, _common.encode(address), server_credentials)) - def start(self): + def start(self) -> None: _start(self._state) - def wait_for_termination(self, timeout=None): + def wait_for_termination(self, timeout: Optional[float] = None) -> bool: # NOTE(https://bugs.python.org/issue35935) # Remove this workaround once threading.Event.wait() is working with # CTRL+C across platforms. @@ -987,7 +1110,7 @@ class _Server(grpc.Server): self._state.termination_event.is_set, timeout=timeout) - def stop(self, grace): + def stop(self, grace: Optional[float]) -> threading.Event: return _stop(self._state, grace) def __del__(self): @@ -997,8 +1120,13 @@ class _Server(grpc.Server): self._state.server_deallocated = True -def create_server(thread_pool, generic_rpc_handlers, interceptors, options, - maximum_concurrent_rpcs, compression, xds): +def create_server(thread_pool: futures.ThreadPoolExecutor, + generic_rpc_handlers: Sequence[grpc.GenericRpcHandler], + interceptors: Sequence[grpc.ServerInterceptor], + options: Sequence[ChannelArgumentType], + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression], + xds: bool) -> _Server: _validate_generic_rpc_handlers(generic_rpc_handlers) return _Server(thread_pool, generic_rpc_handlers, interceptors, options, maximum_concurrent_rpcs, compression, xds) diff --git a/src/python/grpcio/grpc/_typing.py b/src/python/grpcio/grpc/_typing.py index 37e3762afbf..d2a0b472153 100644 --- a/src/python/grpcio/grpc/_typing.py +++ b/src/python/grpcio/grpc/_typing.py @@ -13,17 +13,46 @@ # limitations under the License. """Common types for gRPC Sync API""" -from typing import Any, Callable, Iterable, Sequence, Tuple, TypeVar, Union +from typing import (TYPE_CHECKING, Any, Callable, Iterable, Iterator, Optional, + Sequence, Tuple, TypeVar, Union) -from grpc._cython.cygrpc import EOF +from grpc._cython import cygrpc + +if TYPE_CHECKING: + from grpc import ServicerContext + from grpc._server import _RPCState RequestType = TypeVar('RequestType') ResponseType = TypeVar('ResponseType') SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] MetadataType = Sequence[Tuple[str, Union[str, bytes]]] -ChannelArgumentType = Sequence[Tuple[str, Any]] -EOFType = type(EOF) +ChannelArgumentType = Tuple[str, Any] DoneCallbackType = Callable[[Any], None] +NullaryCallbackType = Callable[[], None] RequestIterableType = Iterable[Any] ResponseIterableType = Iterable[Any] +UserTag = Callable[[cygrpc.BaseEvent], bool] +IntegratedCallFactory = Callable[[ + int, bytes, None, Optional[float], Optional[MetadataType], Optional[ + cygrpc.CallCredentials], Sequence[Sequence[cygrpc. + Operation]], UserTag, Any +], cygrpc.IntegratedCall] +ServerTagCallbackType = Tuple[Optional['_RPCState'], + Sequence[NullaryCallbackType]] +ServerCallbackTag = Callable[[cygrpc.BaseEvent], ServerTagCallbackType] +ArityAgnosticMethodHandler = Union[ + Callable[[RequestType, 'ServicerContext', Callable[[ResponseType], None]], + ResponseType], + Callable[[RequestType, 'ServicerContext', Callable[[ResponseType], None]], + Iterator[ResponseType]], + Callable[[ + Iterator[RequestType], 'ServicerContext', Callable[[ResponseType], None] + ], ResponseType], Callable[[ + Iterator[RequestType], 'ServicerContext', Callable[[ResponseType], None] + ], Iterator[ResponseType]], Callable[[RequestType, 'ServicerContext'], + ResponseType], + Callable[[RequestType, 'ServicerContext'], Iterator[ResponseType]], + Callable[[Iterator[RequestType], 'ServicerContext'], + ResponseType], Callable[[Iterator[RequestType], 'ServicerContext'], + Iterator[ResponseType]]] diff --git a/tools/distrib/check_pytype.sh b/tools/distrib/check_pytype.sh index c8a9a9dfd4e..666c7d31f21 100755 --- a/tools/distrib/check_pytype.sh +++ b/tools/distrib/check_pytype.sh @@ -14,6 +14,6 @@ # limitations under the License. JOBS=$(nproc) || JOBS=4 +# TODO(xuanwn): update pytype version python3 -m pip install pytype==2019.11.27 - python3 -m pytype --keep-going -j "$JOBS" --strict-import --config "setup.cfg"