diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 5de3367a86e..04f2f72e275 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -16,7 +16,7 @@ import asyncio import collections import functools from abc import ABCMeta, abstractmethod -from typing import Callable, Optional, Iterator, Sequence, Union +from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable import grpc from grpc._cython import cygrpc @@ -38,8 +38,8 @@ class ServerInterceptor(metaclass=ABCMeta): @abstractmethod async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], grpc. - RpcMethodHandler], + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], handler_call_details: grpc.HandlerCallDetails ) -> grpc.RpcMethodHandler: """Intercepts incoming RPCs before handing them over to a handler. diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index 1680fc98b4a..5aeedbab879 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import unittest -from typing import Callable +from typing import Callable, Awaitable import grpc @@ -26,21 +26,33 @@ from src.proto.grpc.testing import messages_pb2 class _LoggingInterceptor(aio.ServerInterceptor): - def __init__(self, tag, record): + def __init__(self, tag: str, record: list) -> None: self.tag = tag self.record = record - async def intercept_service(self, continuation, handler_call_details): + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: self.record.append(self.tag + ':intercept_service') return await continuation(handler_call_details) class _GenericInterceptor(aio.ServerInterceptor): - def __init__(self, fn): + def __init__(self, fn: Callable[[ + Callable[[grpc.HandlerCallDetails], Awaitable[grpc. + RpcMethodHandler]], + grpc.HandlerCallDetails + ], Awaitable[grpc.RpcMethodHandler]]) -> None: self._fn = fn - async def intercept_service(self, continuation, handler_call_details): + async def intercept_service( + self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: return await self._fn(continuation, handler_call_details) @@ -48,7 +60,11 @@ def _filter_server_interceptor(condition: Callable, interceptor: aio.ServerInterceptor ) -> aio.ServerInterceptor: - async def intercept_service(continuation, handler_call_details): + async def intercept_service( + continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ + grpc.RpcMethodHandler]], + handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: if condition(handler_call_details): return await interceptor.intercept_service(continuation, handler_call_details)