Add typing for some internal python files. (#31514)

* Add typing for some internal python files.
pull/31586/head
Xuan Wang 2 years ago committed by GitHub
parent 90d8754b0e
commit e1978a4fdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      setup.cfg
  2. 6
      src/python/grpcio/grpc/BUILD.bazel
  3. 22
      src/python/grpcio/grpc/_auth.py
  4. 28
      src/python/grpcio/grpc/_common.py
  5. 16
      src/python/grpcio/grpc/_compression.py
  6. 354
      src/python/grpcio/grpc/_interceptor.py
  7. 18
      src/python/grpcio/grpc/_plugin_wrapping.py
  8. 14
      src/python/grpcio/grpc/_runtime_protos.py
  9. 29
      src/python/grpcio/grpc/_typing.py
  10. 49
      src/python/grpcio/grpc/_utilities.py

@ -21,14 +21,23 @@ license_files = LICENSE
# NOTE(lidiz) Adding examples one by one due to pytype aggressive errer: # NOTE(lidiz) Adding examples one by one due to pytype aggressive errer:
# ninja: error: build.ninja:178: multiple rules generate helloworld_pb2.pyi [-w dupbuild=err] # ninja: error: build.ninja:178: multiple rules generate helloworld_pb2.pyi [-w dupbuild=err]
# TODO(xuanwn): include all files in src/python/grpcio/grpc
[pytype] [pytype]
inputs = inputs =
src/python/grpcio/grpc/experimental src/python/grpcio/grpc/experimental
src/python/grpcio/grpc
src/python/grpcio_tests/tests_aio src/python/grpcio_tests/tests_aio
examples/python/auth examples/python/auth
examples/python/helloworld examples/python/helloworld
exclude = exclude =
**/*_pb2.py **/*_pb2.py
src/python/grpcio/grpc/framework
src/python/grpcio/grpc/aio
src/python/grpcio/grpc/beta
src/python/grpcio/grpc/__init__.py
src/python/grpcio/grpc/_channel.py
src/python/grpcio/grpc/_server.py
src/python/grpcio/grpc/_simple_stubs.py
# NOTE(lidiz) # NOTE(lidiz)
# import-error: C extension triggers import-error. # import-error: C extension triggers import-error.

@ -89,6 +89,11 @@ py_library(
srcs = ["_runtime_protos.py"], srcs = ["_runtime_protos.py"],
) )
py_library(
name = "_typing",
srcs = ["_typing.py"],
)
py_library( py_library(
name = "grpcio", name = "grpcio",
srcs = ["__init__.py"], srcs = ["__init__.py"],
@ -99,6 +104,7 @@ py_library(
deps = [ deps = [
":_runtime_protos", ":_runtime_protos",
":_simple_stubs", ":_simple_stubs",
":_typing",
":aio", ":aio",
":auth", ":auth",
":channel", ":channel",

@ -14,31 +14,39 @@
"""GRPCAuthMetadataPlugins for standard authentication.""" """GRPCAuthMetadataPlugins for standard authentication."""
import inspect import inspect
from typing import Any, Optional
import grpc import grpc
def _sign_request(callback, token, error): def _sign_request(callback: grpc.AuthMetadataPluginCallback,
token: Optional[str], error: Optional[Exception]):
metadata = (('authorization', 'Bearer {}'.format(token)),) metadata = (('authorization', 'Bearer {}'.format(token)),)
callback(metadata, error) callback(metadata, error)
class GoogleCallCredentials(grpc.AuthMetadataPlugin): class GoogleCallCredentials(grpc.AuthMetadataPlugin):
"""Metadata wrapper for GoogleCredentials from the oauth2client library.""" """Metadata wrapper for GoogleCredentials from the oauth2client library."""
_is_jwt: bool
_credentials: Any
def __init__(self, credentials): # TODO(xuanwn): Give credentials an actual type.
def __init__(self, credentials: Any):
self._credentials = credentials self._credentials = credentials
# Hack to determine if these are JWT creds and we need to pass # Hack to determine if these are JWT creds and we need to pass
# additional_claims when getting a token # additional_claims when getting a token
self._is_jwt = 'additional_claims' in inspect.getfullargspec( self._is_jwt = 'additional_claims' in inspect.getfullargspec(
credentials.get_access_token).args credentials.get_access_token).args
def __call__(self, context, callback): def __call__(self, context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback):
try: try:
if self._is_jwt: if self._is_jwt:
access_token = self._credentials.get_access_token( access_token = self._credentials.get_access_token(
additional_claims={ additional_claims={
'aud': context.service_url 'aud':
context.
service_url # pytype: disable=attribute-error
}).access_token }).access_token
else: else:
access_token = self._credentials.get_access_token().access_token access_token = self._credentials.get_access_token().access_token
@ -50,9 +58,11 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""Metadata wrapper for raw access token credentials.""" """Metadata wrapper for raw access token credentials."""
_access_token: str
def __init__(self, access_token): def __init__(self, access_token: str):
self._access_token = access_token self._access_token = access_token
def __call__(self, context, callback): def __call__(self, context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback):
_sign_request(callback, self._access_token, None) _sign_request(callback, self._access_token, None)

@ -15,9 +15,12 @@
import logging import logging
import time import time
from typing import Any, AnyStr, Callable, Optional, Union
import grpc import grpc
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._typing import DeserializingFunction
from grpc._typing import SerializingFunction
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -64,20 +67,22 @@ _ERROR_MESSAGE_PORT_BINDING_FAILED = 'Failed to bind to address %s; set ' \
'GRPC_VERBOSITY=debug environment variable to see detailed error message.' 'GRPC_VERBOSITY=debug environment variable to see detailed error message.'
def encode(s): def encode(s: AnyStr) -> bytes:
if isinstance(s, bytes): if isinstance(s, bytes):
return s return s
else: else:
return s.encode('utf8') return s.encode('utf8')
def decode(b): def decode(b: AnyStr) -> str:
if isinstance(b, bytes): if isinstance(b, bytes):
return b.decode('utf-8', 'replace') return b.decode('utf-8', 'replace')
return b return b
def _transform(message, transformer, exception_message): def _transform(message: Any, transformer: Union[SerializingFunction,
DeserializingFunction, None],
exception_message: str) -> Any:
if transformer is None: if transformer is None:
return message return message
else: else:
@ -88,26 +93,31 @@ def _transform(message, transformer, exception_message):
return None return None
def serialize(message, serializer): def serialize(message: Any, serializer: Optional[SerializingFunction]) -> bytes:
return _transform(message, serializer, 'Exception serializing message!') return _transform(message, serializer, 'Exception serializing message!')
def deserialize(serialized_message, deserializer): def deserialize(serialized_message: bytes,
deserializer: Optional[DeserializingFunction]) -> Any:
return _transform(serialized_message, deserializer, return _transform(serialized_message, deserializer,
'Exception deserializing message!') 'Exception deserializing message!')
def fully_qualified_method(group, method): def fully_qualified_method(group: str, method: str) -> str:
return '/{}/{}'.format(group, method) return '/{}/{}'.format(group, method)
def _wait_once(wait_fn, timeout, spin_cb): def _wait_once(wait_fn: Callable[..., None], timeout: float,
spin_cb: Optional[Callable[[], None]]):
wait_fn(timeout=timeout) wait_fn(timeout=timeout)
if spin_cb is not None: if spin_cb is not None:
spin_cb() spin_cb()
def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None): def wait(wait_fn: Callable[..., None],
wait_complete_fn: Callable[[], bool],
timeout: Optional[float] = None,
spin_cb: Optional[Callable[[], None]] = None) -> bool:
"""Blocks waiting for an event without blocking the thread indefinitely. """Blocks waiting for an event without blocking the thread indefinitely.
See https://github.com/grpc/grpc/issues/19464 for full context. CPython's See https://github.com/grpc/grpc/issues/19464 for full context. CPython's
@ -148,7 +158,7 @@ def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
return False return False
def validate_port_binding_result(address, port): def validate_port_binding_result(address: str, port: int) -> int:
"""Validates if the port binding succeed. """Validates if the port binding succeed.
If the port returned by Core is 0, the binding is failed. However, in that If the port returned by Core is 0, the binding is failed. However, in that

@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
from typing import Optional
import grpc
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._typing import MetadataType
NoCompression = cygrpc.CompressionAlgorithm.none NoCompression = cygrpc.CompressionAlgorithm.none
Deflate = cygrpc.CompressionAlgorithm.deflate Deflate = cygrpc.CompressionAlgorithm.deflate
@ -25,21 +31,23 @@ _METADATA_STRING_MAPPING = {
} }
def _compression_algorithm_to_metadata_value(compression): def _compression_algorithm_to_metadata_value(
compression: grpc.Compression) -> str:
return _METADATA_STRING_MAPPING[compression] return _METADATA_STRING_MAPPING[compression]
def compression_algorithm_to_metadata(compression): def compression_algorithm_to_metadata(compression: grpc.Compression):
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_compression_algorithm_to_metadata_value(compression)) _compression_algorithm_to_metadata_value(compression))
def create_channel_option(compression): def create_channel_option(compression: Optional[grpc.Compression]):
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM,
int(compression)),) if compression else () int(compression)),) if compression else ()
def augment_metadata(metadata, compression): def augment_metadata(metadata: Optional[MetadataType],
compression: Optional[grpc.Compression]):
if not metadata and not compression: if not metadata and not compression:
return None return None
base_metadata = tuple(metadata) if metadata else () base_metadata = tuple(metadata) if metadata else ()

@ -15,19 +15,30 @@
import collections import collections
import sys import sys
import types
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import grpc import grpc
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import SerializingFunction
class _ServicePipeline(object): class _ServicePipeline(object):
interceptors: Tuple[grpc.ServerInterceptor]
def __init__(self, interceptors): def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
self.interceptors = tuple(interceptors) self.interceptors = tuple(interceptors)
def _continuation(self, thunk, index): def _continuation(self, thunk: Callable, index: int) -> Callable:
return lambda context: self._intercept_at(thunk, index, context) return lambda context: self._intercept_at(thunk, index, context)
def _intercept_at(self, thunk, index, context): def _intercept_at(
self, thunk: Callable, index: int,
context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
if index < len(self.interceptors): if index < len(self.interceptors):
interceptor = self.interceptors[index] interceptor = self.interceptors[index]
thunk = self._continuation(thunk, index + 1) thunk = self._continuation(thunk, index + 1)
@ -35,11 +46,14 @@ class _ServicePipeline(object):
else: else:
return thunk(context) return thunk(context)
def execute(self, thunk, context): def execute(self, thunk: Callable,
context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
return self._intercept_at(thunk, 0, context) return self._intercept_at(thunk, 0, context)
def service_pipeline(interceptors): def service_pipeline(
interceptors: Optional[Sequence[grpc.ServerInterceptor]]
) -> Optional[_ServicePipeline]:
return _ServicePipeline(interceptors) if interceptors else None return _ServicePipeline(interceptors) if interceptors else None
@ -51,90 +65,101 @@ class _ClientCallDetails(
pass pass
def _unwrap_client_call_details(call_details, default_details): def _unwrap_client_call_details(
call_details: grpc.ClientCallDetails,
default_details: grpc.ClientCallDetails
) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool,
grpc.Compression]:
try: try:
method = call_details.method method = call_details.method # pytype: disable=attribute-error
except AttributeError: except AttributeError:
method = default_details.method method = default_details.method # pytype: disable=attribute-error
try: try:
timeout = call_details.timeout timeout = call_details.timeout # pytype: disable=attribute-error
except AttributeError: except AttributeError:
timeout = default_details.timeout timeout = default_details.timeout # pytype: disable=attribute-error
try: try:
metadata = call_details.metadata metadata = call_details.metadata # pytype: disable=attribute-error
except AttributeError: except AttributeError:
metadata = default_details.metadata metadata = default_details.metadata # pytype: disable=attribute-error
try: try:
credentials = call_details.credentials credentials = call_details.credentials # pytype: disable=attribute-error
except AttributeError: except AttributeError:
credentials = default_details.credentials credentials = default_details.credentials # pytype: disable=attribute-error
try: try:
wait_for_ready = call_details.wait_for_ready wait_for_ready = call_details.wait_for_ready # pytype: disable=attribute-error
except AttributeError: except AttributeError:
wait_for_ready = default_details.wait_for_ready wait_for_ready = default_details.wait_for_ready # pytype: disable=attribute-error
try: try:
compression = call_details.compression compression = call_details.compression # pytype: disable=attribute-error
except AttributeError: except AttributeError:
compression = default_details.compression compression = default_details.compression # pytype: disable=attribute-error
return method, timeout, metadata, credentials, wait_for_ready, compression return method, timeout, metadata, credentials, wait_for_ready, compression
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
_exception: Exception
_traceback: types.TracebackType
def __init__(self, exception, traceback): def __init__(self, exception: Exception, traceback: types.TracebackType):
super(_FailureOutcome, self).__init__() super(_FailureOutcome, self).__init__()
self._exception = exception self._exception = exception
self._traceback = traceback self._traceback = traceback
def initial_metadata(self): def initial_metadata(self) -> Optional[MetadataType]:
return None return None
def trailing_metadata(self): def trailing_metadata(self) -> Optional[MetadataType]:
return None return None
def code(self): def code(self) -> Optional[grpc.StatusCode]:
return grpc.StatusCode.INTERNAL return grpc.StatusCode.INTERNAL
def details(self): def details(self) -> Optional[str]:
return 'Exception raised while intercepting the RPC' return 'Exception raised while intercepting the RPC'
def cancel(self): def cancel(self) -> bool:
return False return False
def cancelled(self): def cancelled(self) -> bool:
return False return False
def is_active(self): def is_active(self) -> bool:
return False return False
def time_remaining(self): def time_remaining(self) -> Optional[float]:
return None return None
def running(self): def running(self) -> bool:
return False return False
def done(self): def done(self) -> bool:
return True return True
def result(self, ignored_timeout=None): def result(self, ignored_timeout: Optional[float] = None):
raise self._exception raise self._exception
def exception(self, ignored_timeout=None): def exception(
self,
ignored_timeout: Optional[float] = None) -> Optional[Exception]:
return self._exception return self._exception
def traceback(self, ignored_timeout=None): def traceback(
self,
ignored_timeout: Optional[float] = None
) -> Optional[types.TracebackType]:
return self._traceback return self._traceback
def add_callback(self, unused_callback): def add_callback(self, unused_callback) -> bool:
return False return False
def add_done_callback(self, fn): def add_done_callback(self, fn: DoneCallbackType) -> None:
fn(self) fn(self)
def __iter__(self): def __iter__(self):
@ -148,71 +173,77 @@ class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable
class _UnaryOutcome(grpc.Call, grpc.Future): class _UnaryOutcome(grpc.Call, grpc.Future):
_response: Any
_call: grpc.Call
def __init__(self, response, call): def __init__(self, response: Any, call: grpc.Call):
self._response = response self._response = response
self._call = call self._call = call
def initial_metadata(self): def initial_metadata(self) -> Optional[MetadataType]:
return self._call.initial_metadata() return self._call.initial_metadata()
def trailing_metadata(self): def trailing_metadata(self) -> Optional[MetadataType]:
return self._call.trailing_metadata() return self._call.trailing_metadata()
def code(self): def code(self) -> Optional[grpc.StatusCode]:
return self._call.code() return self._call.code()
def details(self): def details(self) -> Optional[str]:
return self._call.details() return self._call.details()
def is_active(self): def is_active(self) -> bool:
return self._call.is_active() return self._call.is_active()
def time_remaining(self): def time_remaining(self) -> Optional[float]:
return self._call.time_remaining() return self._call.time_remaining()
def cancel(self): def cancel(self) -> bool:
return self._call.cancel() return self._call.cancel()
def add_callback(self, callback): def add_callback(self, callback) -> None:
return self._call.add_callback(callback) return self._call.add_callback(callback)
def cancelled(self): def cancelled(self) -> bool:
return False return False
def running(self): def running(self) -> bool:
return False return False
def done(self): def done(self) -> bool:
return True return True
def result(self, ignored_timeout=None): def result(self, ignored_timeout: Optional[float] = None):
return self._response return self._response
def exception(self, ignored_timeout=None): def exception(self, ignored_timeout: Optional[float] = None):
return None return None
def traceback(self, ignored_timeout=None): def traceback(self, ignored_timeout: Optional[float] = None):
return None return None
def add_done_callback(self, fn): def add_done_callback(self, fn: DoneCallbackType) -> None:
fn(self) fn(self)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.UnaryUnaryClientInterceptor
def __init__(self, thunk, method, interceptor): def __init__(self, thunk: Callable, method: str,
interceptor: grpc.UnaryUnaryClientInterceptor):
self._thunk = thunk self._thunk = thunk
self._method = method self._method = method
self._interceptor = interceptor self._interceptor = interceptor
def __call__(self, def __call__(self,
request, request: Any,
timeout=None, timeout: Optional[float] = None,
metadata=None, metadata: Optional[MetadataType] = None,
credentials=None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready=None, wait_for_ready: Optional[bool] = None,
compression=None): compression: Optional[grpc.Compression] = None) -> Any:
response, ignored_call = self._with_call(request, response, ignored_call = self._with_call(request,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
@ -221,13 +252,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
compression=compression) compression=compression)
return response return response
def _with_call(self, def _with_call(
request, self,
timeout=None, request: Any,
metadata=None, timeout: Optional[float] = None,
credentials=None, metadata: Optional[MetadataType] = None,
wait_for_ready=None, credentials: Optional[grpc.CallCredentials] = None,
compression=None): wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
client_call_details = _ClientCallDetails(self._method, timeout, client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials, metadata, credentials,
wait_for_ready, compression) wait_for_ready, compression)
@ -256,13 +289,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
request) request)
return call.result(), call return call.result(), call
def with_call(self, def with_call(
request, self,
timeout=None, request: Any,
metadata=None, timeout: Optional[float] = None,
credentials=None, metadata: Optional[MetadataType] = None,
wait_for_ready=None, credentials: Optional[grpc.CallCredentials] = None,
compression=None): wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
return self._with_call(request, return self._with_call(request,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
@ -271,12 +306,12 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
compression=compression) compression=compression)
def future(self, def future(self,
request, request: Any,
timeout=None, timeout: Optional[float] = None,
metadata=None, metadata: Optional[MetadataType] = None,
credentials=None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready=None, wait_for_ready: Optional[bool] = None,
compression=None): compression: Optional[grpc.Compression] = None) -> Any:
client_call_details = _ClientCallDetails(self._method, timeout, client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials, metadata, credentials,
wait_for_ready, compression) wait_for_ready, compression)
@ -302,19 +337,23 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.UnaryStreamClientInterceptor
def __init__(self, thunk, method, interceptor): def __init__(self, thunk: Callable, method: str,
interceptor: grpc.UnaryStreamClientInterceptor):
self._thunk = thunk self._thunk = thunk
self._method = method self._method = method
self._interceptor = interceptor self._interceptor = interceptor
def __call__(self, def __call__(self,
request, request: Any,
timeout=None, timeout: Optional[float] = None,
metadata=None, metadata: Optional[MetadataType] = None,
credentials=None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready=None, wait_for_ready: Optional[bool] = None,
compression=None): compression: Optional[grpc.Compression] = None):
client_call_details = _ClientCallDetails(self._method, timeout, client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials, metadata, credentials,
wait_for_ready, compression) wait_for_ready, compression)
@ -339,19 +378,23 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.StreamUnaryClientInterceptor
def __init__(self, thunk, method, interceptor): def __init__(self, thunk: Callable, method: str,
interceptor: grpc.StreamUnaryClientInterceptor):
self._thunk = thunk self._thunk = thunk
self._method = method self._method = method
self._interceptor = interceptor self._interceptor = interceptor
def __call__(self, def __call__(self,
request_iterator, request_iterator: RequestIterableType,
timeout=None, timeout: Optional[float] = None,
metadata=None, metadata: Optional[MetadataType] = None,
credentials=None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready=None, wait_for_ready: Optional[bool] = None,
compression=None): compression: Optional[grpc.Compression] = None) -> Any:
response, ignored_call = self._with_call(request_iterator, response, ignored_call = self._with_call(request_iterator,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
@ -360,13 +403,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
compression=compression) compression=compression)
return response return response
def _with_call(self, def _with_call(
request_iterator, self,
timeout=None, request_iterator: RequestIterableType,
metadata=None, timeout: Optional[float] = None,
credentials=None, metadata: Optional[MetadataType] = None,
wait_for_ready=None, credentials: Optional[grpc.CallCredentials] = None,
compression=None): wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
client_call_details = _ClientCallDetails(self._method, timeout, client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials, metadata, credentials,
wait_for_ready, compression) wait_for_ready, compression)
@ -395,13 +440,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
request_iterator) request_iterator)
return call.result(), call return call.result(), call
def with_call(self, def with_call(
request_iterator, self,
timeout=None, request_iterator: RequestIterableType,
metadata=None, timeout: Optional[float] = None,
credentials=None, metadata: Optional[MetadataType] = None,
wait_for_ready=None, credentials: Optional[grpc.CallCredentials] = None,
compression=None): wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
return self._with_call(request_iterator, return self._with_call(request_iterator,
timeout=timeout, timeout=timeout,
metadata=metadata, metadata=metadata,
@ -410,12 +457,12 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
compression=compression) compression=compression)
def future(self, def future(self,
request_iterator, request_iterator: RequestIterableType,
timeout=None, timeout: Optional[float] = None,
metadata=None, metadata: Optional[MetadataType] = None,
credentials=None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready=None, wait_for_ready: Optional[bool] = None,
compression=None): compression: Optional[grpc.Compression] = None) -> Any:
client_call_details = _ClientCallDetails(self._method, timeout, client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials, metadata, credentials,
wait_for_ready, compression) wait_for_ready, compression)
@ -441,19 +488,23 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.StreamStreamClientInterceptor
def __init__(self, thunk, method, interceptor): def __init__(self, thunk: Callable, method: str,
interceptor: grpc.StreamStreamClientInterceptor):
self._thunk = thunk self._thunk = thunk
self._method = method self._method = method
self._interceptor = interceptor self._interceptor = interceptor
def __call__(self, def __call__(self,
request_iterator, request_iterator: RequestIterableType,
timeout=None, timeout: Optional[float] = None,
metadata=None, metadata: Optional[MetadataType] = None,
credentials=None, credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready=None, wait_for_ready: Optional[bool] = None,
compression=None): compression: Optional[grpc.Compression] = None):
client_call_details = _ClientCallDetails(self._method, timeout, client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials, metadata, credentials,
wait_for_ready, compression) wait_for_ready, compression)
@ -478,21 +529,34 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
class _Channel(grpc.Channel): class _Channel(grpc.Channel):
_channel: grpc.Channel
def __init__(self, channel, interceptor): _interceptor: Union[grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor]
def __init__(self, channel: grpc.Channel,
interceptor: Union[grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor]):
self._channel = channel self._channel = channel
self._interceptor = interceptor self._interceptor = interceptor
def subscribe(self, callback, try_to_connect=False): def subscribe(self,
callback: Callable,
try_to_connect: Optional[bool] = False):
self._channel.subscribe(callback, try_to_connect=try_to_connect) self._channel.subscribe(callback, try_to_connect=try_to_connect)
def unsubscribe(self, callback): def unsubscribe(self, callback: Callable):
self._channel.unsubscribe(callback) self._channel.unsubscribe(callback)
def unary_unary(self, def unary_unary(
method, self,
request_serializer=None, method: str,
response_deserializer=None): request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.UnaryUnaryMultiCallable:
thunk = lambda m: self._channel.unary_unary(m, request_serializer, thunk = lambda m: self._channel.unary_unary(m, request_serializer,
response_deserializer) response_deserializer)
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
@ -500,10 +564,12 @@ class _Channel(grpc.Channel):
else: else:
return thunk(method) return thunk(method)
def unary_stream(self, def unary_stream(
method, self,
request_serializer=None, method: str,
response_deserializer=None): request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.UnaryStreamMultiCallable:
thunk = lambda m: self._channel.unary_stream(m, request_serializer, thunk = lambda m: self._channel.unary_stream(m, request_serializer,
response_deserializer) response_deserializer)
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
@ -511,10 +577,12 @@ class _Channel(grpc.Channel):
else: else:
return thunk(method) return thunk(method)
def stream_unary(self, def stream_unary(
method, self,
request_serializer=None, method: str,
response_deserializer=None): request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.StreamUnaryMultiCallable:
thunk = lambda m: self._channel.stream_unary(m, request_serializer, thunk = lambda m: self._channel.stream_unary(m, request_serializer,
response_deserializer) response_deserializer)
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
@ -522,10 +590,12 @@ class _Channel(grpc.Channel):
else: else:
return thunk(method) return thunk(method)
def stream_stream(self, def stream_stream(
method, self,
request_serializer=None, method: str,
response_deserializer=None): request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.StreamStreamMultiCallable:
thunk = lambda m: self._channel.stream_stream(m, request_serializer, thunk = lambda m: self._channel.stream_stream(m, request_serializer,
response_deserializer) response_deserializer)
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
@ -547,7 +617,13 @@ class _Channel(grpc.Channel):
self._channel.close() self._channel.close()
def intercept_channel(channel, *interceptors): def intercept_channel(
channel: grpc.Channel,
*interceptors: Optional[Sequence[Union[grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor]]]
) -> grpc.Channel:
for interceptor in reversed(list(interceptors)): for interceptor in reversed(list(interceptors)):
if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \ if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \
not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \ not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \

@ -15,10 +15,12 @@
import collections import collections
import logging import logging
import threading import threading
from typing import Callable, Optional, Type
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from grpc._typing import MetadataType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -40,12 +42,15 @@ class _CallbackState(object):
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback): class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
_state: _CallbackState
_callback: Callable
def __init__(self, state, callback): def __init__(self, state: _CallbackState, callback: Callable):
self._state = state self._state = state
self._callback = callback self._callback = callback
def __call__(self, metadata, error): def __call__(self, metadata: MetadataType,
error: Optional[Type[BaseException]]):
with self._state.lock: with self._state.lock:
if self._state.exception is None: if self._state.exception is None:
if self._state.called: if self._state.called:
@ -65,8 +70,9 @@ class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
class _Plugin(object): class _Plugin(object):
_metadata_plugin: grpc.AuthMetadataPlugin
def __init__(self, metadata_plugin): def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
self._metadata_plugin = metadata_plugin self._metadata_plugin = metadata_plugin
self._stored_ctx = None self._stored_ctx = None
@ -81,7 +87,7 @@ class _Plugin(object):
# Support versions predating contextvars. # Support versions predating contextvars.
pass pass
def __call__(self, service_url, method_name, callback): def __call__(self, service_url: str, method_name: str, callback: Callable):
context = _AuthMetadataContext(_common.decode(service_url), context = _AuthMetadataContext(_common.decode(service_url),
_common.decode(method_name)) _common.decode(method_name))
callback_state = _CallbackState() callback_state = _CallbackState()
@ -100,7 +106,9 @@ class _Plugin(object):
_common.encode(str(exception))) _common.encode(str(exception)))
def metadata_plugin_call_credentials(metadata_plugin, name): def metadata_plugin_call_credentials(
metadata_plugin: grpc.AuthMetadataPlugin,
name: Optional[str]) -> grpc.CallCredentials:
if name is None: if name is None:
try: try:
effective_name = metadata_plugin.__name__ effective_name = metadata_plugin.__name__

@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import sys import sys
import types
from typing import Tuple, Union
_REQUIRED_SYMBOLS = ("_protos", "_services", "_protos_and_services") _REQUIRED_SYMBOLS = ("_protos", "_services", "_protos_and_services")
_MINIMUM_VERSION = (3, 5, 0) _MINIMUM_VERSION = (3, 5, 0)
@ -21,13 +23,13 @@ _UNINSTALLED_TEMPLATE = "Install the grpcio-tools package (1.32.0+) to use the {
_VERSION_ERROR_TEMPLATE = "The {} function is only on available on Python 3.X interpreters." _VERSION_ERROR_TEMPLATE = "The {} function is only on available on Python 3.X interpreters."
def _has_runtime_proto_symbols(mod): def _has_runtime_proto_symbols(mod: types.ModuleType) -> bool:
return all(hasattr(mod, sym) for sym in _REQUIRED_SYMBOLS) return all(hasattr(mod, sym) for sym in _REQUIRED_SYMBOLS)
def _is_grpc_tools_importable(): def _is_grpc_tools_importable() -> bool:
try: try:
import grpc_tools # pylint: disable=unused-import import grpc_tools # pylint: disable=unused-import # pytype: disable=import-error
return True return True
except ImportError as e: except ImportError as e:
# NOTE: It's possible that we're encountering a transitive ImportError, so # NOTE: It's possible that we're encountering a transitive ImportError, so
@ -37,7 +39,9 @@ def _is_grpc_tools_importable():
return False return False
def _call_with_lazy_import(fn_name, protobuf_path): def _call_with_lazy_import(
fn_name: str, protobuf_path: str
) -> Union[types.ModuleType, Tuple[types.ModuleType, types.ModuleType]]:
"""Calls one of the three functions, lazily importing grpc_tools. """Calls one of the three functions, lazily importing grpc_tools.
Args: Args:
@ -52,7 +56,7 @@ def _call_with_lazy_import(fn_name, protobuf_path):
else: else:
if not _is_grpc_tools_importable(): if not _is_grpc_tools_importable():
raise NotImplementedError(_UNINSTALLED_TEMPLATE.format(fn_name)) raise NotImplementedError(_UNINSTALLED_TEMPLATE.format(fn_name))
import grpc_tools.protoc import grpc_tools.protoc # pytype: disable=import-error
if _has_runtime_proto_symbols(grpc_tools.protoc): if _has_runtime_proto_symbols(grpc_tools.protoc):
fn = getattr(grpc_tools.protoc, '_' + fn_name) fn = getattr(grpc_tools.protoc, '_' + fn_name)
return fn(protobuf_path) return fn(protobuf_path)

@ -0,0 +1,29 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common types for gRPC Sync API"""
from typing import Any, Callable, Iterable, Sequence, Tuple, TypeVar, Union
from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]
RequestIterableType = Iterable[Any]
ResponseIterableType = Iterable[Any]

@ -17,9 +17,11 @@ import collections
import logging import logging
import threading import threading
import time import time
from typing import Callable, Dict, Optional, Sequence
import grpc import grpc # pytype: disable=pyi-error
from grpc import _common from grpc import _common # pytype: disable=pyi-error
from grpc._typing import DoneCallbackType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -42,24 +44,35 @@ class RpcMethodHandler(
class DictionaryGenericHandler(grpc.ServiceRpcHandler): class DictionaryGenericHandler(grpc.ServiceRpcHandler):
_name: str
_method_handlers: Dict[str, grpc.RpcMethodHandler]
def __init__(self, service, method_handlers): def __init__(self, service: str,
method_handlers: Dict[str, grpc.RpcMethodHandler]):
self._name = service self._name = service
self._method_handlers = { self._method_handlers = {
_common.fully_qualified_method(service, method): method_handler _common.fully_qualified_method(service, method): method_handler
for method, method_handler in method_handlers.items() for method, method_handler in method_handlers.items()
} }
def service_name(self): def service_name(self) -> str:
return self._name return self._name
def service(self, handler_call_details): def service(
return self._method_handlers.get(handler_call_details.method) self, handler_call_details: grpc.HandlerCallDetails
) -> Optional[grpc.RpcMethodHandler]:
details_method = handler_call_details.method
return self._method_handlers.get(details_method) # pytype: disable=attribute-error
class _ChannelReadyFuture(grpc.Future): class _ChannelReadyFuture(grpc.Future):
_condition: threading.Condition
_channel: grpc.Channel
_matured: bool
_cancelled: bool
_done_callbacks: Sequence[Callable]
def __init__(self, channel): def __init__(self, channel: grpc.Channel):
self._condition = threading.Condition() self._condition = threading.Condition()
self._channel = channel self._channel = channel
@ -67,7 +80,7 @@ class _ChannelReadyFuture(grpc.Future):
self._cancelled = False self._cancelled = False
self._done_callbacks = [] self._done_callbacks = []
def _block(self, timeout): def _block(self, timeout: Optional[float]) -> None:
until = None if timeout is None else time.time() + timeout until = None if timeout is None else time.time() + timeout
with self._condition: with self._condition:
while True: while True:
@ -85,7 +98,7 @@ class _ChannelReadyFuture(grpc.Future):
else: else:
self._condition.wait(timeout=remaining) self._condition.wait(timeout=remaining)
def _update(self, connectivity): def _update(self, connectivity: Optional[grpc.ChannelConnectivity]) -> None:
with self._condition: with self._condition:
if (not self._cancelled and if (not self._cancelled and
connectivity is grpc.ChannelConnectivity.READY): connectivity is grpc.ChannelConnectivity.READY):
@ -103,7 +116,7 @@ class _ChannelReadyFuture(grpc.Future):
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE) _LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE)
def cancel(self): def cancel(self) -> bool:
with self._condition: with self._condition:
if not self._matured: if not self._matured:
self._cancelled = True self._cancelled = True
@ -122,28 +135,28 @@ class _ChannelReadyFuture(grpc.Future):
return True return True
def cancelled(self): def cancelled(self) -> bool:
with self._condition: with self._condition:
return self._cancelled return self._cancelled
def running(self): def running(self) -> bool:
with self._condition: with self._condition:
return not self._cancelled and not self._matured return not self._cancelled and not self._matured
def done(self): def done(self) -> bool:
with self._condition: with self._condition:
return self._cancelled or self._matured return self._cancelled or self._matured
def result(self, timeout=None): def result(self, timeout: Optional[float] = None) -> None:
self._block(timeout) self._block(timeout)
def exception(self, timeout=None): def exception(self, timeout: Optional[float] = None) -> None:
self._block(timeout) self._block(timeout)
def traceback(self, timeout=None): def traceback(self, timeout: Optional[float] = None) -> None:
self._block(timeout) self._block(timeout)
def add_done_callback(self, fn): def add_done_callback(self, fn: DoneCallbackType):
with self._condition: with self._condition:
if not self._cancelled and not self._matured: if not self._cancelled and not self._matured:
self._done_callbacks.append(fn) self._done_callbacks.append(fn)
@ -161,7 +174,7 @@ class _ChannelReadyFuture(grpc.Future):
self._channel.unsubscribe(self._update) self._channel.unsubscribe(self._update)
def channel_ready_future(channel): def channel_ready_future(channel: grpc.Channel) -> _ChannelReadyFuture:
ready_future = _ChannelReadyFuture(channel) ready_future = _ChannelReadyFuture(channel)
ready_future.start() ready_future.start()
return ready_future return ready_future

Loading…
Cancel
Save