Add typing for some internal python files. (#31514)

* Add typing for some internal python files.
pull/31586/head
Xuan Wang 2 years ago committed by GitHub
parent 90d8754b0e
commit e1978a4fdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      setup.cfg
  2. 6
      src/python/grpcio/grpc/BUILD.bazel
  3. 22
      src/python/grpcio/grpc/_auth.py
  4. 28
      src/python/grpcio/grpc/_common.py
  5. 16
      src/python/grpcio/grpc/_compression.py
  6. 354
      src/python/grpcio/grpc/_interceptor.py
  7. 18
      src/python/grpcio/grpc/_plugin_wrapping.py
  8. 14
      src/python/grpcio/grpc/_runtime_protos.py
  9. 29
      src/python/grpcio/grpc/_typing.py
  10. 49
      src/python/grpcio/grpc/_utilities.py

@ -21,14 +21,23 @@ license_files = LICENSE
# NOTE(lidiz) Adding examples one by one due to pytype aggressive errer:
# ninja: error: build.ninja:178: multiple rules generate helloworld_pb2.pyi [-w dupbuild=err]
# TODO(xuanwn): include all files in src/python/grpcio/grpc
[pytype]
inputs =
src/python/grpcio/grpc/experimental
src/python/grpcio/grpc
src/python/grpcio_tests/tests_aio
examples/python/auth
examples/python/helloworld
exclude =
**/*_pb2.py
src/python/grpcio/grpc/framework
src/python/grpcio/grpc/aio
src/python/grpcio/grpc/beta
src/python/grpcio/grpc/__init__.py
src/python/grpcio/grpc/_channel.py
src/python/grpcio/grpc/_server.py
src/python/grpcio/grpc/_simple_stubs.py
# NOTE(lidiz)
# import-error: C extension triggers import-error.

@ -89,6 +89,11 @@ py_library(
srcs = ["_runtime_protos.py"],
)
py_library(
name = "_typing",
srcs = ["_typing.py"],
)
py_library(
name = "grpcio",
srcs = ["__init__.py"],
@ -99,6 +104,7 @@ py_library(
deps = [
":_runtime_protos",
":_simple_stubs",
":_typing",
":aio",
":auth",
":channel",

@ -14,31 +14,39 @@
"""GRPCAuthMetadataPlugins for standard authentication."""
import inspect
from typing import Any, Optional
import grpc
def _sign_request(callback, token, error):
def _sign_request(callback: grpc.AuthMetadataPluginCallback,
token: Optional[str], error: Optional[Exception]):
metadata = (('authorization', 'Bearer {}'.format(token)),)
callback(metadata, error)
class GoogleCallCredentials(grpc.AuthMetadataPlugin):
"""Metadata wrapper for GoogleCredentials from the oauth2client library."""
_is_jwt: bool
_credentials: Any
def __init__(self, credentials):
# TODO(xuanwn): Give credentials an actual type.
def __init__(self, credentials: Any):
self._credentials = credentials
# Hack to determine if these are JWT creds and we need to pass
# additional_claims when getting a token
self._is_jwt = 'additional_claims' in inspect.getfullargspec(
credentials.get_access_token).args
def __call__(self, context, callback):
def __call__(self, context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback):
try:
if self._is_jwt:
access_token = self._credentials.get_access_token(
additional_claims={
'aud': context.service_url
'aud':
context.
service_url # pytype: disable=attribute-error
}).access_token
else:
access_token = self._credentials.get_access_token().access_token
@ -50,9 +58,11 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""Metadata wrapper for raw access token credentials."""
_access_token: str
def __init__(self, access_token):
def __init__(self, access_token: str):
self._access_token = access_token
def __call__(self, context, callback):
def __call__(self, context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback):
_sign_request(callback, self._access_token, None)

@ -15,9 +15,12 @@
import logging
import time
from typing import Any, AnyStr, Callable, Optional, Union
import grpc
from grpc._cython import cygrpc
from grpc._typing import DeserializingFunction
from grpc._typing import SerializingFunction
_LOGGER = logging.getLogger(__name__)
@ -64,20 +67,22 @@ _ERROR_MESSAGE_PORT_BINDING_FAILED = 'Failed to bind to address %s; set ' \
'GRPC_VERBOSITY=debug environment variable to see detailed error message.'
def encode(s):
def encode(s: AnyStr) -> bytes:
if isinstance(s, bytes):
return s
else:
return s.encode('utf8')
def decode(b):
def decode(b: AnyStr) -> str:
if isinstance(b, bytes):
return b.decode('utf-8', 'replace')
return b
def _transform(message, transformer, exception_message):
def _transform(message: Any, transformer: Union[SerializingFunction,
DeserializingFunction, None],
exception_message: str) -> Any:
if transformer is None:
return message
else:
@ -88,26 +93,31 @@ def _transform(message, transformer, exception_message):
return None
def serialize(message, serializer):
def serialize(message: Any, serializer: Optional[SerializingFunction]) -> bytes:
return _transform(message, serializer, 'Exception serializing message!')
def deserialize(serialized_message, deserializer):
def deserialize(serialized_message: bytes,
deserializer: Optional[DeserializingFunction]) -> Any:
return _transform(serialized_message, deserializer,
'Exception deserializing message!')
def fully_qualified_method(group, method):
def fully_qualified_method(group: str, method: str) -> str:
return '/{}/{}'.format(group, method)
def _wait_once(wait_fn, timeout, spin_cb):
def _wait_once(wait_fn: Callable[..., None], timeout: float,
spin_cb: Optional[Callable[[], None]]):
wait_fn(timeout=timeout)
if spin_cb is not None:
spin_cb()
def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
def wait(wait_fn: Callable[..., None],
wait_complete_fn: Callable[[], bool],
timeout: Optional[float] = None,
spin_cb: Optional[Callable[[], None]] = None) -> bool:
"""Blocks waiting for an event without blocking the thread indefinitely.
See https://github.com/grpc/grpc/issues/19464 for full context. CPython's
@ -148,7 +158,7 @@ def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
return False
def validate_port_binding_result(address, port):
def validate_port_binding_result(address: str, port: int) -> int:
"""Validates if the port binding succeed.
If the port returned by Core is 0, the binding is failed. However, in that

@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
import grpc
from grpc._cython import cygrpc
from grpc._typing import MetadataType
NoCompression = cygrpc.CompressionAlgorithm.none
Deflate = cygrpc.CompressionAlgorithm.deflate
@ -25,21 +31,23 @@ _METADATA_STRING_MAPPING = {
}
def _compression_algorithm_to_metadata_value(compression):
def _compression_algorithm_to_metadata_value(
compression: grpc.Compression) -> str:
return _METADATA_STRING_MAPPING[compression]
def compression_algorithm_to_metadata(compression):
def compression_algorithm_to_metadata(compression: grpc.Compression):
return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY,
_compression_algorithm_to_metadata_value(compression))
def create_channel_option(compression):
def create_channel_option(compression: Optional[grpc.Compression]):
return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM,
int(compression)),) if compression else ()
def augment_metadata(metadata, compression):
def augment_metadata(metadata: Optional[MetadataType],
compression: Optional[grpc.Compression]):
if not metadata and not compression:
return None
base_metadata = tuple(metadata) if metadata else ()

@ -15,19 +15,30 @@
import collections
import sys
import types
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import grpc
from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import SerializingFunction
class _ServicePipeline(object):
interceptors: Tuple[grpc.ServerInterceptor]
def __init__(self, interceptors):
def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
self.interceptors = tuple(interceptors)
def _continuation(self, thunk, index):
def _continuation(self, thunk: Callable, index: int) -> Callable:
return lambda context: self._intercept_at(thunk, index, context)
def _intercept_at(self, thunk, index, context):
def _intercept_at(
self, thunk: Callable, index: int,
context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
if index < len(self.interceptors):
interceptor = self.interceptors[index]
thunk = self._continuation(thunk, index + 1)
@ -35,11 +46,14 @@ class _ServicePipeline(object):
else:
return thunk(context)
def execute(self, thunk, context):
def execute(self, thunk: Callable,
context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler:
return self._intercept_at(thunk, 0, context)
def service_pipeline(interceptors):
def service_pipeline(
interceptors: Optional[Sequence[grpc.ServerInterceptor]]
) -> Optional[_ServicePipeline]:
return _ServicePipeline(interceptors) if interceptors else None
@ -51,90 +65,101 @@ class _ClientCallDetails(
pass
def _unwrap_client_call_details(call_details, default_details):
def _unwrap_client_call_details(
call_details: grpc.ClientCallDetails,
default_details: grpc.ClientCallDetails
) -> Tuple[str, float, MetadataType, grpc.CallCredentials, bool,
grpc.Compression]:
try:
method = call_details.method
method = call_details.method # pytype: disable=attribute-error
except AttributeError:
method = default_details.method
method = default_details.method # pytype: disable=attribute-error
try:
timeout = call_details.timeout
timeout = call_details.timeout # pytype: disable=attribute-error
except AttributeError:
timeout = default_details.timeout
timeout = default_details.timeout # pytype: disable=attribute-error
try:
metadata = call_details.metadata
metadata = call_details.metadata # pytype: disable=attribute-error
except AttributeError:
metadata = default_details.metadata
metadata = default_details.metadata # pytype: disable=attribute-error
try:
credentials = call_details.credentials
credentials = call_details.credentials # pytype: disable=attribute-error
except AttributeError:
credentials = default_details.credentials
credentials = default_details.credentials # pytype: disable=attribute-error
try:
wait_for_ready = call_details.wait_for_ready
wait_for_ready = call_details.wait_for_ready # pytype: disable=attribute-error
except AttributeError:
wait_for_ready = default_details.wait_for_ready
wait_for_ready = default_details.wait_for_ready # pytype: disable=attribute-error
try:
compression = call_details.compression
compression = call_details.compression # pytype: disable=attribute-error
except AttributeError:
compression = default_details.compression
compression = default_details.compression # pytype: disable=attribute-error
return method, timeout, metadata, credentials, wait_for_ready, compression
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
_exception: Exception
_traceback: types.TracebackType
def __init__(self, exception, traceback):
def __init__(self, exception: Exception, traceback: types.TracebackType):
super(_FailureOutcome, self).__init__()
self._exception = exception
self._traceback = traceback
def initial_metadata(self):
def initial_metadata(self) -> Optional[MetadataType]:
return None
def trailing_metadata(self):
def trailing_metadata(self) -> Optional[MetadataType]:
return None
def code(self):
def code(self) -> Optional[grpc.StatusCode]:
return grpc.StatusCode.INTERNAL
def details(self):
def details(self) -> Optional[str]:
return 'Exception raised while intercepting the RPC'
def cancel(self):
def cancel(self) -> bool:
return False
def cancelled(self):
def cancelled(self) -> bool:
return False
def is_active(self):
def is_active(self) -> bool:
return False
def time_remaining(self):
def time_remaining(self) -> Optional[float]:
return None
def running(self):
def running(self) -> bool:
return False
def done(self):
def done(self) -> bool:
return True
def result(self, ignored_timeout=None):
def result(self, ignored_timeout: Optional[float] = None):
raise self._exception
def exception(self, ignored_timeout=None):
def exception(
self,
ignored_timeout: Optional[float] = None) -> Optional[Exception]:
return self._exception
def traceback(self, ignored_timeout=None):
def traceback(
self,
ignored_timeout: Optional[float] = None
) -> Optional[types.TracebackType]:
return self._traceback
def add_callback(self, unused_callback):
def add_callback(self, unused_callback) -> bool:
return False
def add_done_callback(self, fn):
def add_done_callback(self, fn: DoneCallbackType) -> None:
fn(self)
def __iter__(self):
@ -148,71 +173,77 @@ class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable
class _UnaryOutcome(grpc.Call, grpc.Future):
_response: Any
_call: grpc.Call
def __init__(self, response, call):
def __init__(self, response: Any, call: grpc.Call):
self._response = response
self._call = call
def initial_metadata(self):
def initial_metadata(self) -> Optional[MetadataType]:
return self._call.initial_metadata()
def trailing_metadata(self):
def trailing_metadata(self) -> Optional[MetadataType]:
return self._call.trailing_metadata()
def code(self):
def code(self) -> Optional[grpc.StatusCode]:
return self._call.code()
def details(self):
def details(self) -> Optional[str]:
return self._call.details()
def is_active(self):
def is_active(self) -> bool:
return self._call.is_active()
def time_remaining(self):
def time_remaining(self) -> Optional[float]:
return self._call.time_remaining()
def cancel(self):
def cancel(self) -> bool:
return self._call.cancel()
def add_callback(self, callback):
def add_callback(self, callback) -> None:
return self._call.add_callback(callback)
def cancelled(self):
def cancelled(self) -> bool:
return False
def running(self):
def running(self) -> bool:
return False
def done(self):
def done(self) -> bool:
return True
def result(self, ignored_timeout=None):
def result(self, ignored_timeout: Optional[float] = None):
return self._response
def exception(self, ignored_timeout=None):
def exception(self, ignored_timeout: Optional[float] = None):
return None
def traceback(self, ignored_timeout=None):
def traceback(self, ignored_timeout: Optional[float] = None):
return None
def add_done_callback(self, fn):
def add_done_callback(self, fn: DoneCallbackType) -> None:
fn(self)
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.UnaryUnaryClientInterceptor
def __init__(self, thunk, method, interceptor):
def __init__(self, thunk: Callable, method: str,
interceptor: grpc.UnaryUnaryClientInterceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
request: Any,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None) -> Any:
response, ignored_call = self._with_call(request,
timeout=timeout,
metadata=metadata,
@ -221,13 +252,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
compression=compression)
return response
def _with_call(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
def _with_call(
self,
request: Any,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
@ -256,13 +289,15 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
request)
return call.result(), call
def with_call(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
def with_call(
self,
request: Any,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
return self._with_call(request,
timeout=timeout,
metadata=metadata,
@ -271,12 +306,12 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
compression=compression)
def future(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
request: Any,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None) -> Any:
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
@ -302,19 +337,23 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.UnaryStreamClientInterceptor
def __init__(self, thunk, method, interceptor):
def __init__(self, thunk: Callable, method: str,
interceptor: grpc.UnaryStreamClientInterceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
request: Any,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
@ -339,19 +378,23 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.StreamUnaryClientInterceptor
def __init__(self, thunk, method, interceptor):
def __init__(self, thunk: Callable, method: str,
interceptor: grpc.StreamUnaryClientInterceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
request_iterator: RequestIterableType,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None) -> Any:
response, ignored_call = self._with_call(request_iterator,
timeout=timeout,
metadata=metadata,
@ -360,13 +403,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
compression=compression)
return response
def _with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
def _with_call(
self,
request_iterator: RequestIterableType,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
@ -395,13 +440,15 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
request_iterator)
return call.result(), call
def with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
def with_call(
self,
request_iterator: RequestIterableType,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> Tuple[Any, grpc.Call]:
return self._with_call(request_iterator,
timeout=timeout,
metadata=metadata,
@ -410,12 +457,12 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
compression=compression)
def future(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
request_iterator: RequestIterableType,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None) -> Any:
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
@ -441,19 +488,23 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
_thunk: Callable
_method: str
_interceptor: grpc.StreamStreamClientInterceptor
def __init__(self, thunk, method, interceptor):
def __init__(self, thunk: Callable, method: str,
interceptor: grpc.StreamStreamClientInterceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
request_iterator: RequestIterableType,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None):
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials,
wait_for_ready, compression)
@ -478,21 +529,34 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
class _Channel(grpc.Channel):
def __init__(self, channel, interceptor):
_channel: grpc.Channel
_interceptor: Union[grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor]
def __init__(self, channel: grpc.Channel,
interceptor: Union[grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor]):
self._channel = channel
self._interceptor = interceptor
def subscribe(self, callback, try_to_connect=False):
def subscribe(self,
callback: Callable,
try_to_connect: Optional[bool] = False):
self._channel.subscribe(callback, try_to_connect=try_to_connect)
def unsubscribe(self, callback):
def unsubscribe(self, callback: Callable):
self._channel.unsubscribe(callback)
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.UnaryUnaryMultiCallable:
thunk = lambda m: self._channel.unary_unary(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
@ -500,10 +564,12 @@ class _Channel(grpc.Channel):
else:
return thunk(method)
def unary_stream(self,
method,
request_serializer=None,
response_deserializer=None):
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.UnaryStreamMultiCallable:
thunk = lambda m: self._channel.unary_stream(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
@ -511,10 +577,12 @@ class _Channel(grpc.Channel):
else:
return thunk(method)
def stream_unary(self,
method,
request_serializer=None,
response_deserializer=None):
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.StreamUnaryMultiCallable:
thunk = lambda m: self._channel.stream_unary(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
@ -522,10 +590,12 @@ class _Channel(grpc.Channel):
else:
return thunk(method)
def stream_stream(self,
method,
request_serializer=None,
response_deserializer=None):
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None
) -> grpc.StreamStreamMultiCallable:
thunk = lambda m: self._channel.stream_stream(m, request_serializer,
response_deserializer)
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
@ -547,7 +617,13 @@ class _Channel(grpc.Channel):
self._channel.close()
def intercept_channel(channel, *interceptors):
def intercept_channel(
channel: grpc.Channel,
*interceptors: Optional[Sequence[Union[grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor]]]
) -> grpc.Channel:
for interceptor in reversed(list(interceptors)):
if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \
not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \

@ -15,10 +15,12 @@
import collections
import logging
import threading
from typing import Callable, Optional, Type
import grpc
from grpc import _common
from grpc._cython import cygrpc
from grpc._typing import MetadataType
_LOGGER = logging.getLogger(__name__)
@ -40,12 +42,15 @@ class _CallbackState(object):
class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
_state: _CallbackState
_callback: Callable
def __init__(self, state, callback):
def __init__(self, state: _CallbackState, callback: Callable):
self._state = state
self._callback = callback
def __call__(self, metadata, error):
def __call__(self, metadata: MetadataType,
error: Optional[Type[BaseException]]):
with self._state.lock:
if self._state.exception is None:
if self._state.called:
@ -65,8 +70,9 @@ class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
class _Plugin(object):
_metadata_plugin: grpc.AuthMetadataPlugin
def __init__(self, metadata_plugin):
def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin):
self._metadata_plugin = metadata_plugin
self._stored_ctx = None
@ -81,7 +87,7 @@ class _Plugin(object):
# Support versions predating contextvars.
pass
def __call__(self, service_url, method_name, callback):
def __call__(self, service_url: str, method_name: str, callback: Callable):
context = _AuthMetadataContext(_common.decode(service_url),
_common.decode(method_name))
callback_state = _CallbackState()
@ -100,7 +106,9 @@ class _Plugin(object):
_common.encode(str(exception)))
def metadata_plugin_call_credentials(metadata_plugin, name):
def metadata_plugin_call_credentials(
metadata_plugin: grpc.AuthMetadataPlugin,
name: Optional[str]) -> grpc.CallCredentials:
if name is None:
try:
effective_name = metadata_plugin.__name__

@ -13,6 +13,8 @@
# limitations under the License.
import sys
import types
from typing import Tuple, Union
_REQUIRED_SYMBOLS = ("_protos", "_services", "_protos_and_services")
_MINIMUM_VERSION = (3, 5, 0)
@ -21,13 +23,13 @@ _UNINSTALLED_TEMPLATE = "Install the grpcio-tools package (1.32.0+) to use the {
_VERSION_ERROR_TEMPLATE = "The {} function is only on available on Python 3.X interpreters."
def _has_runtime_proto_symbols(mod):
def _has_runtime_proto_symbols(mod: types.ModuleType) -> bool:
return all(hasattr(mod, sym) for sym in _REQUIRED_SYMBOLS)
def _is_grpc_tools_importable():
def _is_grpc_tools_importable() -> bool:
try:
import grpc_tools # pylint: disable=unused-import
import grpc_tools # pylint: disable=unused-import # pytype: disable=import-error
return True
except ImportError as e:
# NOTE: It's possible that we're encountering a transitive ImportError, so
@ -37,7 +39,9 @@ def _is_grpc_tools_importable():
return False
def _call_with_lazy_import(fn_name, protobuf_path):
def _call_with_lazy_import(
fn_name: str, protobuf_path: str
) -> Union[types.ModuleType, Tuple[types.ModuleType, types.ModuleType]]:
"""Calls one of the three functions, lazily importing grpc_tools.
Args:
@ -52,7 +56,7 @@ def _call_with_lazy_import(fn_name, protobuf_path):
else:
if not _is_grpc_tools_importable():
raise NotImplementedError(_UNINSTALLED_TEMPLATE.format(fn_name))
import grpc_tools.protoc
import grpc_tools.protoc # pytype: disable=import-error
if _has_runtime_proto_symbols(grpc_tools.protoc):
fn = getattr(grpc_tools.protoc, '_' + fn_name)
return fn(protobuf_path)

@ -0,0 +1,29 @@
# Copyright 2022 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common types for gRPC Sync API"""
from typing import Any, Callable, Iterable, Sequence, Tuple, TypeVar, Union
from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType')
ResponseType = TypeVar('ResponseType')
SerializingFunction = Callable[[Any], bytes]
DeserializingFunction = Callable[[bytes], Any]
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]
RequestIterableType = Iterable[Any]
ResponseIterableType = Iterable[Any]

@ -17,9 +17,11 @@ import collections
import logging
import threading
import time
from typing import Callable, Dict, Optional, Sequence
import grpc
from grpc import _common
import grpc # pytype: disable=pyi-error
from grpc import _common # pytype: disable=pyi-error
from grpc._typing import DoneCallbackType
_LOGGER = logging.getLogger(__name__)
@ -42,24 +44,35 @@ class RpcMethodHandler(
class DictionaryGenericHandler(grpc.ServiceRpcHandler):
_name: str
_method_handlers: Dict[str, grpc.RpcMethodHandler]
def __init__(self, service, method_handlers):
def __init__(self, service: str,
method_handlers: Dict[str, grpc.RpcMethodHandler]):
self._name = service
self._method_handlers = {
_common.fully_qualified_method(service, method): method_handler
for method, method_handler in method_handlers.items()
}
def service_name(self):
def service_name(self) -> str:
return self._name
def service(self, handler_call_details):
return self._method_handlers.get(handler_call_details.method)
def service(
self, handler_call_details: grpc.HandlerCallDetails
) -> Optional[grpc.RpcMethodHandler]:
details_method = handler_call_details.method
return self._method_handlers.get(details_method) # pytype: disable=attribute-error
class _ChannelReadyFuture(grpc.Future):
_condition: threading.Condition
_channel: grpc.Channel
_matured: bool
_cancelled: bool
_done_callbacks: Sequence[Callable]
def __init__(self, channel):
def __init__(self, channel: grpc.Channel):
self._condition = threading.Condition()
self._channel = channel
@ -67,7 +80,7 @@ class _ChannelReadyFuture(grpc.Future):
self._cancelled = False
self._done_callbacks = []
def _block(self, timeout):
def _block(self, timeout: Optional[float]) -> None:
until = None if timeout is None else time.time() + timeout
with self._condition:
while True:
@ -85,7 +98,7 @@ class _ChannelReadyFuture(grpc.Future):
else:
self._condition.wait(timeout=remaining)
def _update(self, connectivity):
def _update(self, connectivity: Optional[grpc.ChannelConnectivity]) -> None:
with self._condition:
if (not self._cancelled and
connectivity is grpc.ChannelConnectivity.READY):
@ -103,7 +116,7 @@ class _ChannelReadyFuture(grpc.Future):
except Exception: # pylint: disable=broad-except
_LOGGER.exception(_DONE_CALLBACK_EXCEPTION_LOG_MESSAGE)
def cancel(self):
def cancel(self) -> bool:
with self._condition:
if not self._matured:
self._cancelled = True
@ -122,28 +135,28 @@ class _ChannelReadyFuture(grpc.Future):
return True
def cancelled(self):
def cancelled(self) -> bool:
with self._condition:
return self._cancelled
def running(self):
def running(self) -> bool:
with self._condition:
return not self._cancelled and not self._matured
def done(self):
def done(self) -> bool:
with self._condition:
return self._cancelled or self._matured
def result(self, timeout=None):
def result(self, timeout: Optional[float] = None) -> None:
self._block(timeout)
def exception(self, timeout=None):
def exception(self, timeout: Optional[float] = None) -> None:
self._block(timeout)
def traceback(self, timeout=None):
def traceback(self, timeout: Optional[float] = None) -> None:
self._block(timeout)
def add_done_callback(self, fn):
def add_done_callback(self, fn: DoneCallbackType):
with self._condition:
if not self._cancelled and not self._matured:
self._done_callbacks.append(fn)
@ -161,7 +174,7 @@ class _ChannelReadyFuture(grpc.Future):
self._channel.unsubscribe(self._update)
def channel_ready_future(channel):
def channel_ready_future(channel: grpc.Channel) -> _ChannelReadyFuture:
ready_future = _ChannelReadyFuture(channel)
ready_future.start()
return ready_future

Loading…
Cancel
Save