[Python Aio] Convert metadata to aio.Metadata when necessary (#34347)

Fix: https://github.com/grpc/grpc/issues/34298


<!--

If you know who should review your pull request, please assign it to
that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the
appropriate
lang label.

-->
pull/34761/head
Xuan Wang 2 years ago committed by GitHub
parent b08b07f611
commit 65d4df25ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      src/python/grpcio/grpc/aio/_channel.py
  2. 1
      src/python/grpcio_tests/tests_aio/tests.json
  3. 94
      src/python/grpcio_tests/tests_aio/unit/_metadata_test.py

@ -41,6 +41,7 @@ from ._interceptor import UnaryUnaryClientInterceptor
from ._metadata import Metadata
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
@ -115,13 +116,15 @@ class _BaseMultiCallable:
@staticmethod
def _init_metadata(
metadata: Optional[Metadata] = None,
metadata: Optional[MetadataType] = None,
compression: Optional[grpc.Compression] = None,
) -> Metadata:
"""Based on the provided values for <metadata> or <compression> initialise the final
metadata, as it should be used for the current call.
"""
metadata = metadata or Metadata()
if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
metadata = Metadata.from_tuple(metadata)
if compression:
metadata = Metadata(
*_compression.augment_metadata(metadata, compression)
@ -137,7 +140,7 @@ class UnaryUnaryMultiCallable(
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None,
@ -182,7 +185,7 @@ class UnaryStreamMultiCallable(
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None,
@ -227,7 +230,7 @@ class StreamUnaryMultiCallable(
self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None,
@ -272,7 +275,7 @@ class StreamStreamMultiCallable(
self,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None,

@ -6,6 +6,7 @@
"tests_aio.interop.local_interop_test.SecureLocalInteropTest",
"tests_aio.reflection.reflection_servicer_test.ReflectionServicerTest",
"tests_aio.status.grpc_status_test.StatusTest",
"tests_aio.unit._metadata_test.TestMetadataWithServer",
"tests_aio.unit._metadata_test.TestTypeMetadata",
"tests_aio.unit.abort_test.TestAbort",
"tests_aio.unit.aio_rpc_error_test.TestAioRpcError",

@ -15,8 +15,78 @@
import logging
import unittest
import grpc
from grpc.experimental import aio
from grpc.experimental.aio import Metadata
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
_TEST_UNARY_UNARY = "/test/TestUnaryUnary"
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata(
("client-to-server", "question"),
("client-to-server-bin", b"\x07\x07\x07"),
)
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER_TUPLE = (
("client-to-server", "question"),
("client-to-server-bin", b"\x07\x07\x07"),
)
_INTERCEPTOR_METADATA_KEY = "interceptor-metadata-key"
_INTERCEPTOR_METADATA_VALUE = "interceptor-metadata-value"
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER_ALL = aio.Metadata(
(_INTERCEPTOR_METADATA_KEY, _INTERCEPTOR_METADATA_VALUE),
("client-to-server", "question"),
("client-to-server-bin", b"\x07\x07\x07"),
)
_REQUEST = b"\x01" * 100
_RESPONSE = b"\x02" * 100
def validate_client_metadata(servicer_context):
invocation_metadata = servicer_context.invocation_metadata()
assert _common.seen_metadata(
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER_ALL,
invocation_metadata,
)
async def _test_unary_unary(unused_request, servicer_context):
validate_client_metadata(servicer_context)
return _RESPONSE
_ROUTING_TABLE = {
_TEST_UNARY_UNARY: grpc.unary_unary_rpc_method_handler(_test_unary_unary),
}
class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details):
return _ROUTING_TABLE.get(handler_call_details.method)
class UnaryUnaryAddMetadataInterceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(
self,
continuation,
client_call_details,
request,
):
client_call_details.metadata.add(
_INTERCEPTOR_METADATA_KEY, _INTERCEPTOR_METADATA_VALUE
)
response = await continuation(client_call_details, request)
return response
async def _start_test_server(options=None):
server = aio.server(options=options)
port = server.add_insecure_port("[::]:0")
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
return f"localhost:{port}", server
class TestTypeMetadata(unittest.TestCase):
"""Tests for the metadata type"""
@ -137,6 +207,30 @@ class TestTypeMetadata(unittest.TestCase):
self.assertEqual(expected, Metadata.from_tuple(source))
class TestMetadataWithServer(AioTestBase):
async def setUp(self):
self._address, self._server = await _start_test_server()
self._channel = aio.insecure_channel(self._address)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_init_metadata_with_client_interceptor(self):
async with aio.insecure_channel(
self._address,
interceptors=[UnaryUnaryAddMetadataInterceptor()],
) as channel:
multicallable = channel.unary_unary(_TEST_UNARY_UNARY)
for metadata in [
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
_INITIAL_METADATA_FROM_CLIENT_TO_SERVER_TUPLE,
]:
call = multicallable(_REQUEST, metadata=metadata)
await call
self.assertEqual(grpc.StatusCode.OK, await call.code())
if __name__ == "__main__":
logging.basicConfig()
unittest.main(verbosity=2)

Loading…
Cancel
Save