diff --git a/setup.cfg b/setup.cfg index 00333b9fa13..36d00651684 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,14 +21,23 @@ license_files = LICENSE # NOTE(lidiz) Adding examples one by one due to pytype aggressive errer: # ninja: error: build.ninja:178: multiple rules generate helloworld_pb2.pyi [-w dupbuild=err] +# TODO(xuanwn): include all files in src/python/grpcio/grpc [pytype] inputs = src/python/grpcio/grpc/experimental + src/python/grpcio/grpc src/python/grpcio_tests/tests_aio examples/python/auth examples/python/helloworld exclude = **/*_pb2.py + src/python/grpcio/grpc/framework + src/python/grpcio/grpc/aio + src/python/grpcio/grpc/beta + src/python/grpcio/grpc/__init__.py + src/python/grpcio/grpc/_channel.py + src/python/grpcio/grpc/_server.py + src/python/grpcio/grpc/_simple_stubs.py # NOTE(lidiz) # import-error: C extension triggers import-error. diff --git a/src/python/grpcio/grpc/BUILD.bazel b/src/python/grpcio/grpc/BUILD.bazel index d40e1459b7a..91d55d762ce 100644 --- a/src/python/grpcio/grpc/BUILD.bazel +++ b/src/python/grpcio/grpc/BUILD.bazel @@ -89,6 +89,11 @@ py_library( srcs = ["_runtime_protos.py"], ) +py_library( + name = "_typing", + srcs = ["_typing.py"], +) + py_library( name = "grpcio", srcs = ["__init__.py"], @@ -99,6 +104,7 @@ py_library( deps = [ ":_runtime_protos", ":_simple_stubs", + ":_typing", ":aio", ":auth", ":channel", diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py index 67113ceb586..2095957072f 100644 --- a/src/python/grpcio/grpc/_auth.py +++ b/src/python/grpcio/grpc/_auth.py @@ -14,31 +14,39 @@ """GRPCAuthMetadataPlugins for standard authentication.""" import inspect +from typing import Any, Optional import grpc -def _sign_request(callback, token, error): +def _sign_request(callback: grpc.AuthMetadataPluginCallback, + token: Optional[str], error: Optional[Exception]): metadata = (('authorization', 'Bearer {}'.format(token)),) callback(metadata, error) class GoogleCallCredentials(grpc.AuthMetadataPlugin): """Metadata wrapper for GoogleCredentials from the oauth2client library.""" + _is_jwt: bool + _credentials: Any - def __init__(self, credentials): + # TODO(xuanwn): Give credentials an actual type. + def __init__(self, credentials: Any): self._credentials = credentials # Hack to determine if these are JWT creds and we need to pass # additional_claims when getting a token self._is_jwt = 'additional_claims' in inspect.getfullargspec( credentials.get_access_token).args - def __call__(self, context, callback): + def __call__(self, context: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback): try: if self._is_jwt: access_token = self._credentials.get_access_token( additional_claims={ - 'aud': context.service_url + 'aud': + context. + service_url # pytype: disable=attribute-error }).access_token else: access_token = self._credentials.get_access_token().access_token @@ -50,9 +58,11 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin): class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): """Metadata wrapper for raw access token credentials.""" + _access_token: str - def __init__(self, access_token): + def __init__(self, access_token: str): self._access_token = access_token - def __call__(self, context, callback): + def __call__(self, context: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback): _sign_request(callback, self._access_token, None) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 7ab6b0fd306..b976800e6bb 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -15,9 +15,12 @@ import logging import time +from typing import Any, AnyStr, Callable, Optional, Union import grpc from grpc._cython import cygrpc +from grpc._typing import DeserializingFunction +from grpc._typing import SerializingFunction _LOGGER = logging.getLogger(__name__) @@ -64,20 +67,22 @@ _ERROR_MESSAGE_PORT_BINDING_FAILED = 'Failed to bind to address %s; set ' \ 'GRPC_VERBOSITY=debug environment variable to see detailed error message.' -def encode(s): +def encode(s: AnyStr) -> bytes: if isinstance(s, bytes): return s else: return s.encode('utf8') -def decode(b): +def decode(b: AnyStr) -> str: if isinstance(b, bytes): return b.decode('utf-8', 'replace') return b -def _transform(message, transformer, exception_message): +def _transform(message: Any, transformer: Union[SerializingFunction, + DeserializingFunction, None], + exception_message: str) -> Any: if transformer is None: return message else: @@ -88,26 +93,31 @@ def _transform(message, transformer, exception_message): return None -def serialize(message, serializer): +def serialize(message: Any, serializer: Optional[SerializingFunction]) -> bytes: return _transform(message, serializer, 'Exception serializing message!') -def deserialize(serialized_message, deserializer): +def deserialize(serialized_message: bytes, + deserializer: Optional[DeserializingFunction]) -> Any: return _transform(serialized_message, deserializer, 'Exception deserializing message!') -def fully_qualified_method(group, method): +def fully_qualified_method(group: str, method: str) -> str: return '/{}/{}'.format(group, method) -def _wait_once(wait_fn, timeout, spin_cb): +def _wait_once(wait_fn: Callable[..., None], timeout: float, + spin_cb: Optional[Callable[[], None]]): wait_fn(timeout=timeout) if spin_cb is not None: spin_cb() -def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None): +def wait(wait_fn: Callable[..., None], + wait_complete_fn: Callable[[], bool], + timeout: Optional[float] = None, + spin_cb: Optional[Callable[[], None]] = None) -> bool: """Blocks waiting for an event without blocking the thread indefinitely. See https://github.com/grpc/grpc/issues/19464 for full context. CPython's @@ -148,7 +158,7 @@ def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None): return False -def validate_port_binding_result(address, port): +def validate_port_binding_result(address: str, port: int) -> int: """Validates if the port binding succeed. If the port returned by Core is 0, the binding is failed. However, in that diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py index 45339c3afe2..5eb6f2ac6d8 100644 --- a/src/python/grpcio/grpc/_compression.py +++ b/src/python/grpcio/grpc/_compression.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import Optional + +import grpc from grpc._cython import cygrpc +from grpc._typing import MetadataType NoCompression = cygrpc.CompressionAlgorithm.none Deflate = cygrpc.CompressionAlgorithm.deflate @@ -25,21 +31,23 @@ _METADATA_STRING_MAPPING = { } -def _compression_algorithm_to_metadata_value(compression): +def _compression_algorithm_to_metadata_value( + compression: grpc.Compression) -> str: return _METADATA_STRING_MAPPING[compression] -def compression_algorithm_to_metadata(compression): +def compression_algorithm_to_metadata(compression: grpc.Compression): return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, _compression_algorithm_to_metadata_value(compression)) -def create_channel_option(compression): +def create_channel_option(compression: Optional[grpc.Compression]): return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, int(compression)),) if compression else () -def augment_metadata(metadata, compression): +def augment_metadata(metadata: Optional[MetadataType], + compression: Optional[grpc.Compression]): if not metadata and not compression: return None base_metadata = tuple(metadata) if metadata else () diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index ee63cb31452..e327ae17261 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -15,19 +15,30 @@ import collections import sys +import types +from typing import Any, Callable, Optional, Sequence, Tuple, Union import grpc +from ._typing import DeserializingFunction +from ._typing import DoneCallbackType +from ._typing import MetadataType +from ._typing import RequestIterableType +from ._typing import SerializingFunction + class _ServicePipeline(object): + interceptors: Tuple[grpc.ServerInterceptor] - def __init__(self, interceptors): + def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]): self.interceptors = tuple(interceptors) - def _continuation(self, thunk, index): + def _continuation(self, thunk: Callable, index: int) -> Callable: return lambda context: self._intercept_at(thunk, index, context) - def _intercept_at(self, thunk, index, context): + def _intercept_at( + self, thunk: Callable, index: int, + context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler: if index < len(self.interceptors): interceptor = self.interceptors[index] thunk = self._continuation(thunk, index + 1) @@ -35,11 +46,14 @@ class _ServicePipeline(object): else: return thunk(context) - def execute(self, thunk, context): + def execute(self, thunk: Callable, + context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler: return self._intercept_at(thunk, 0, context) -def service_pipeline(interceptors): +def service_pipeline( + interceptors: Optional[Sequence[grpc.ServerInterceptor]] +) -> Optional[_ServicePipeline]: return _ServicePipeline(interceptors) if interceptors else None @@ -51,90 +65,101 @@ class _ClientCallDetails( pass -def _unwrap_client_call_details(call_details, default_details): +def _unwrap_client_call_details( + call_details: grpc.ClientCallDetails, + default_details: grpc.ClientCallDetails +) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool, + grpc.Compression]: try: - method = call_details.method + method = call_details.method # pytype: disable=attribute-error except AttributeError: - method = default_details.method + method = default_details.method # pytype: disable=attribute-error try: - timeout = call_details.timeout + timeout = call_details.timeout # pytype: disable=attribute-error except AttributeError: - timeout = default_details.timeout + timeout = default_details.timeout # pytype: disable=attribute-error try: - metadata = call_details.metadata + metadata = call_details.metadata # pytype: disable=attribute-error except AttributeError: - metadata = default_details.metadata + metadata = default_details.metadata # pytype: disable=attribute-error try: - credentials = call_details.credentials + credentials = call_details.credentials # pytype: disable=attribute-error except AttributeError: - credentials = default_details.credentials + credentials = default_details.credentials # pytype: disable=attribute-error try: - wait_for_ready = call_details.wait_for_ready + wait_for_ready = call_details.wait_for_ready # pytype: disable=attribute-error except AttributeError: - wait_for_ready = default_details.wait_for_ready + wait_for_ready = default_details.wait_for_ready # pytype: disable=attribute-error try: - compression = call_details.compression + compression = call_details.compression # pytype: disable=attribute-error except AttributeError: - compression = default_details.compression + compression = default_details.compression # pytype: disable=attribute-error return method, timeout, metadata, credentials, wait_for_ready, compression class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors + _exception: Exception + _traceback: types.TracebackType - def __init__(self, exception, traceback): + def __init__(self, exception: Exception, traceback: types.TracebackType): super(_FailureOutcome, self).__init__() self._exception = exception self._traceback = traceback - def initial_metadata(self): + def initial_metadata(self) -> Optional[MetadataType]: return None - def trailing_metadata(self): + def trailing_metadata(self) -> Optional[MetadataType]: return None - def code(self): + def code(self) -> Optional[grpc.StatusCode]: return grpc.StatusCode.INTERNAL - def details(self): + def details(self) -> Optional[str]: return 'Exception raised while intercepting the RPC' - def cancel(self): + def cancel(self) -> bool: return False - def cancelled(self): + def cancelled(self) -> bool: return False - def is_active(self): + def is_active(self) -> bool: return False - def time_remaining(self): + def time_remaining(self) -> Optional[float]: return None - def running(self): + def running(self) -> bool: return False - def done(self): + def done(self) -> bool: return True - def result(self, ignored_timeout=None): + def result(self, ignored_timeout: Optional[float] = None): raise self._exception - def exception(self, ignored_timeout=None): + def exception( + self, + ignored_timeout: Optional[float] = None) -> Optional[Exception]: return self._exception - def traceback(self, ignored_timeout=None): + def traceback( + self, + ignored_timeout: Optional[float] = None + ) -> Optional[types.TracebackType]: return self._traceback - def add_callback(self, unused_callback): + def add_callback(self, unused_callback) -> bool: return False - def add_done_callback(self, fn): + def add_done_callback(self, fn: DoneCallbackType) -> None: fn(self) def __iter__(self): @@ -148,71 +173,77 @@ class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable class _UnaryOutcome(grpc.Call, grpc.Future): + _response: Any + _call: grpc.Call - def __init__(self, response, call): + def __init__(self, response: Any, call: grpc.Call): self._response = response self._call = call - def initial_metadata(self): + def initial_metadata(self) -> Optional[MetadataType]: return self._call.initial_metadata() - def trailing_metadata(self): + def trailing_metadata(self) -> Optional[MetadataType]: return self._call.trailing_metadata() - def code(self): + def code(self) -> Optional[grpc.StatusCode]: return self._call.code() - def details(self): + def details(self) -> Optional[str]: return self._call.details() - def is_active(self): + def is_active(self) -> bool: return self._call.is_active() - def time_remaining(self): + def time_remaining(self) -> Optional[float]: return self._call.time_remaining() - def cancel(self): + def cancel(self) -> bool: return self._call.cancel() - def add_callback(self, callback): + def add_callback(self, callback) -> None: return self._call.add_callback(callback) - def cancelled(self): + def cancelled(self) -> bool: return False - def running(self): + def running(self) -> bool: return False - def done(self): + def done(self) -> bool: return True - def result(self, ignored_timeout=None): + def result(self, ignored_timeout: Optional[float] = None): return self._response - def exception(self, ignored_timeout=None): + def exception(self, ignored_timeout: Optional[float] = None): return None - def traceback(self, ignored_timeout=None): + def traceback(self, ignored_timeout: Optional[float] = None): return None - def add_done_callback(self, fn): + def add_done_callback(self, fn: DoneCallbackType) -> None: fn(self) class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.UnaryUnaryClientInterceptor - def __init__(self, thunk, method, interceptor): + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.UnaryUnaryClientInterceptor): self._thunk = thunk self._method = method self._interceptor = interceptor def __call__(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: response, ignored_call = self._with_call(request, timeout=timeout, metadata=metadata, @@ -221,13 +252,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): compression=compression) return response - def _with_call(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def _with_call( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials, wait_for_ready, compression) @@ -256,13 +289,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): request) return call.result(), call - def with_call(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def with_call( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: return self._with_call(request, timeout=timeout, metadata=metadata, @@ -271,12 +306,12 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): compression=compression) def future(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials, wait_for_ready, compression) @@ -302,19 +337,23 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.UnaryStreamClientInterceptor - def __init__(self, thunk, method, interceptor): + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.UnaryStreamClientInterceptor): self._thunk = thunk self._method = method self._interceptor = interceptor def __call__(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None): client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials, wait_for_ready, compression) @@ -339,19 +378,23 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.StreamUnaryClientInterceptor - def __init__(self, thunk, method, interceptor): + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.StreamUnaryClientInterceptor): self._thunk = thunk self._method = method self._interceptor = interceptor def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: response, ignored_call = self._with_call(request_iterator, timeout=timeout, metadata=metadata, @@ -360,13 +403,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): compression=compression) return response - def _with_call(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def _with_call( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials, wait_for_ready, compression) @@ -395,13 +440,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): request_iterator) return call.result(), call - def with_call(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def with_call( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None + ) -> Tuple[Any, grpc.Call]: return self._with_call(request_iterator, timeout=timeout, metadata=metadata, @@ -410,12 +457,12 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): compression=compression) def future(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None) -> Any: client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials, wait_for_ready, compression) @@ -441,19 +488,23 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): + _thunk: Callable + _method: str + _interceptor: grpc.StreamStreamClientInterceptor - def __init__(self, thunk, method, interceptor): + def __init__(self, thunk: Callable, method: str, + interceptor: grpc.StreamStreamClientInterceptor): self._thunk = thunk self._method = method self._interceptor = interceptor def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None): client_call_details = _ClientCallDetails(self._method, timeout, metadata, credentials, wait_for_ready, compression) @@ -478,21 +529,34 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): class _Channel(grpc.Channel): - - def __init__(self, channel, interceptor): + _channel: grpc.Channel + _interceptor: Union[grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor] + + def __init__(self, channel: grpc.Channel, + interceptor: Union[grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor]): self._channel = channel self._interceptor = interceptor - def subscribe(self, callback, try_to_connect=False): + def subscribe(self, + callback: Callable, + try_to_connect: Optional[bool] = False): self._channel.subscribe(callback, try_to_connect=try_to_connect) - def unsubscribe(self, callback): + def unsubscribe(self, callback: Callable): self._channel.unsubscribe(callback) - def unary_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryUnaryMultiCallable: thunk = lambda m: self._channel.unary_unary(m, request_serializer, response_deserializer) if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): @@ -500,10 +564,12 @@ class _Channel(grpc.Channel): else: return thunk(method) - def unary_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.UnaryStreamMultiCallable: thunk = lambda m: self._channel.unary_stream(m, request_serializer, response_deserializer) if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): @@ -511,10 +577,12 @@ class _Channel(grpc.Channel): else: return thunk(method) - def stream_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_unary( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamUnaryMultiCallable: thunk = lambda m: self._channel.stream_unary(m, request_serializer, response_deserializer) if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): @@ -522,10 +590,12 @@ class _Channel(grpc.Channel): else: return thunk(method) - def stream_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_stream( + self, + method: str, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None + ) -> grpc.StreamStreamMultiCallable: thunk = lambda m: self._channel.stream_stream(m, request_serializer, response_deserializer) if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): @@ -547,7 +617,13 @@ class _Channel(grpc.Channel): self._channel.close() -def intercept_channel(channel, *interceptors): +def intercept_channel( + channel: grpc.Channel, + *interceptors: Optional[Sequence[Union[grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor]]] +) -> grpc.Channel: for interceptor in reversed(list(interceptors)): if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \ not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \ diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index ad74b256a37..942264cdaea 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -15,10 +15,12 @@ import collections import logging import threading +from typing import Callable, Optional, Type import grpc from grpc import _common from grpc._cython import cygrpc +from grpc._typing import MetadataType _LOGGER = logging.getLogger(__name__) @@ -40,12 +42,15 @@ class _CallbackState(object): class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback): + _state: _CallbackState + _callback: Callable - def __init__(self, state, callback): + def __init__(self, state: _CallbackState, callback: Callable): self._state = state self._callback = callback - def __call__(self, metadata, error): + def __call__(self, metadata: MetadataType, + error: Optional[Type[BaseException]]): with self._state.lock: if self._state.exception is None: if self._state.called: @@ -65,8 +70,9 @@ class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback): class _Plugin(object): + _metadata_plugin: grpc.AuthMetadataPlugin - def __init__(self, metadata_plugin): + def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin): self._metadata_plugin = metadata_plugin self._stored_ctx = None @@ -81,7 +87,7 @@ class _Plugin(object): # Support versions predating contextvars. pass - def __call__(self, service_url, method_name, callback): + def __call__(self, service_url: str, method_name: str, callback: Callable): context = _AuthMetadataContext(_common.decode(service_url), _common.decode(method_name)) callback_state = _CallbackState() @@ -100,7 +106,9 @@ class _Plugin(object): _common.encode(str(exception))) -def metadata_plugin_call_credentials(metadata_plugin, name): +def metadata_plugin_call_credentials( + metadata_plugin: grpc.AuthMetadataPlugin, + name: Optional[str]) -> grpc.CallCredentials: if name is None: try: effective_name = metadata_plugin.__name__ diff --git a/src/python/grpcio/grpc/_runtime_protos.py b/src/python/grpcio/grpc/_runtime_protos.py index 2a3e1d459a2..fcc37038dac 100644 --- a/src/python/grpcio/grpc/_runtime_protos.py +++ b/src/python/grpcio/grpc/_runtime_protos.py @@ -13,6 +13,8 @@ # limitations under the License. import sys +import types +from typing import Tuple, Union _REQUIRED_SYMBOLS = ("_protos", "_services", "_protos_and_services") _MINIMUM_VERSION = (3, 5, 0) @@ -21,13 +23,13 @@ _UNINSTALLED_TEMPLATE = "Install the grpcio-tools package (1.32.0+) to use the { _VERSION_ERROR_TEMPLATE = "The {} function is only on available on Python 3.X interpreters." -def _has_runtime_proto_symbols(mod): +def _has_runtime_proto_symbols(mod: types.ModuleType) -> bool: return all(hasattr(mod, sym) for sym in _REQUIRED_SYMBOLS) -def _is_grpc_tools_importable(): +def _is_grpc_tools_importable() -> bool: try: - import grpc_tools # pylint: disable=unused-import + import grpc_tools # pylint: disable=unused-import # pytype: disable=import-error return True except ImportError as e: # NOTE: It's possible that we're encountering a transitive ImportError, so @@ -37,7 +39,9 @@ def _is_grpc_tools_importable(): return False -def _call_with_lazy_import(fn_name, protobuf_path): +def _call_with_lazy_import( + fn_name: str, protobuf_path: str +) -> Union[types.ModuleType, Tuple[types.ModuleType, types.ModuleType]]: """Calls one of the three functions, lazily importing grpc_tools. Args: @@ -52,7 +56,7 @@ def _call_with_lazy_import(fn_name, protobuf_path): else: if not _is_grpc_tools_importable(): raise NotImplementedError(_UNINSTALLED_TEMPLATE.format(fn_name)) - import grpc_tools.protoc + import grpc_tools.protoc # pytype: disable=import-error if _has_runtime_proto_symbols(grpc_tools.protoc): fn = getattr(grpc_tools.protoc, '_' + fn_name) return fn(protobuf_path) diff --git a/src/python/grpcio/grpc/_typing.py b/src/python/grpcio/grpc/_typing.py new file mode 100644 index 00000000000..37e3762afbf --- /dev/null +++ b/src/python/grpcio/grpc/_typing.py @@ -0,0 +1,29 @@ +# Copyright 2022 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common types for gRPC Sync API""" + +from typing import Any, Callable, Iterable, Sequence, Tuple, TypeVar, Union + +from grpc._cython.cygrpc import EOF + +RequestType = TypeVar('RequestType') +ResponseType = TypeVar('ResponseType') +SerializingFunction = Callable[[Any], bytes] +DeserializingFunction = Callable[[bytes], Any] +MetadataType = Sequence[Tuple[str, Union[str, bytes]]] +ChannelArgumentType = Sequence[Tuple[str, Any]] +EOFType = type(EOF) +DoneCallbackType = Callable[[Any], None] +RequestIterableType = Iterable[Any] +ResponseIterableType = Iterable[Any] diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py index a8d15dc3f4b..3dafa7a03d3 100644 --- a/src/python/grpcio/grpc/_utilities.py +++ b/src/python/grpcio/grpc/_utilities.py @@ -17,9 +17,11 @@ import collections import logging import threading import time +from typing import Callable, Dict, Optional, Sequence -import grpc -from grpc import _common +import grpc # pytype: disable=pyi-error +from grpc import _common # pytype: disable=pyi-error +from grpc._typing import DoneCallbackType _LOGGER = logging.getLogger(__name__) @@ -42,24 +44,35 @@ class RpcMethodHandler( class DictionaryGenericHandler(grpc.ServiceRpcHandler): + _name: str + _method_handlers: Dict[str, grpc.RpcMethodHandler] - def __init__(self, service, method_handlers): + def __init__(self, service: str, + method_handlers: Dict[str, grpc.RpcMethodHandler]): self._name = service self._method_handlers = { _common.fully_qualified_method(service, method): method_handler for method, method_handler in method_handlers.items() } - def service_name(self): + def service_name(self) -> str: return self._name - def service(self, handler_call_details): - return self._method_handlers.get(handler_call_details.method) + def service( + self, handler_call_details: grpc.HandlerCallDetails + ) -> Optional[grpc.RpcMethodHandler]: + details_method = handler_call_details.method + return self._method_handlers.get(details_method) # pytype: disable=attribute-error class _ChannelReadyFuture(grpc.Future): + _condition: threading.Condition + _channel: grpc.Channel + _matured: bool + _cancelled: bool + _done_callbacks: Sequence[Callable] - def __init__(self, channel): + def __init__(self, channel: grpc.Channel): self._condition = threading.Condition() self._channel = channel @@ -67,7 +80,7 @@ class _ChannelReadyFuture(grpc.Future): self._cancelled = False self._done_callbacks = [] - def _block(self, timeout): + def _block(self, timeout: Optional[float]) -> None: until = None if timeout is None else time.time() + timeout with self._condition: while True: @@ -85,7 +98,7 @@ class _ChannelReadyFuture(grpc.Future): else: self._condition.wait(timeout=remaining) - def _update(self, connectivity): + def _update(self, connectivity: Optional[grpc.ChannelConnectivity]) -> None: with self._condition: if (not self._cancelled and connectivity is grpc.ChannelConnectivity.READY): @@ -103,7 +116,7 @@ class _ChannelReadyFuture(grpc.Future): except Exception: # pylint: disable=broad-except _LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE) - def cancel(self): + def cancel(self) -> bool: with self._condition: if not self._matured: self._cancelled = True @@ -122,28 +135,28 @@ class _ChannelReadyFuture(grpc.Future): return True - def cancelled(self): + def cancelled(self) -> bool: with self._condition: return self._cancelled - def running(self): + def running(self) -> bool: with self._condition: return not self._cancelled and not self._matured - def done(self): + def done(self) -> bool: with self._condition: return self._cancelled or self._matured - def result(self, timeout=None): + def result(self, timeout: Optional[float] = None) -> None: self._block(timeout) - def exception(self, timeout=None): + def exception(self, timeout: Optional[float] = None) -> None: self._block(timeout) - def traceback(self, timeout=None): + def traceback(self, timeout: Optional[float] = None) -> None: self._block(timeout) - def add_done_callback(self, fn): + def add_done_callback(self, fn: DoneCallbackType): with self._condition: if not self._cancelled and not self._matured: self._done_callbacks.append(fn) @@ -161,7 +174,7 @@ class _ChannelReadyFuture(grpc.Future): self._channel.unsubscribe(self._update) -def channel_ready_future(channel): +def channel_ready_future(channel: grpc.Channel) -> _ChannelReadyFuture: ready_future = _ChannelReadyFuture(channel) ready_future.start() return ready_future