Fix a bug that prevents metadata modification in interceptors

pull/21647/head
Lidi Zheng 5 years ago
parent a3e950adbb
commit 435cf89108
  1. 5
      src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi
  2. 3
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  3. 7
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  4. 24
      src/python/grpcio_tests/tests_aio/unit/_common.py
  5. 61
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py
  6. 17
      src/python/grpcio_tests/tests_aio/unit/metadata_test.py

@ -41,6 +41,11 @@ cdef void _store_c_metadata(
for index, (key, value) in enumerate(metadata):
encoded_key = _encode(key)
encoded_value = value if encoded_key[-4:] == b'-bin' else _encode(value)
if type(encoded_value) != bytes:
raise TypeError('Binary metadata key="%s" expected bytes, got %s' % (
key,
type(encoded_value)
))
c_metadata[0][index].key = _slice_from_bytes(encoded_key)
c_metadata[0][index].value = _slice_from_bytes(encoded_value)

@ -150,7 +150,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
metadata, client_call_details.credentials, self._channel,
client_call_details.metadata,
client_call_details.credentials, self._channel,
client_call_details.method, request_serializer,
response_deserializer)

@ -43,6 +43,12 @@ py_library(
srcs_version = "PY3",
)
py_library(
name = "_common",
srcs = ["_common.py"],
srcs_version = "PY3",
)
[
py_test(
name = test_file_name[:-3],
@ -55,6 +61,7 @@ py_library(
main = test_file_name,
python_version = "PY3",
deps = [
":_common",
":_constants",
":_test_base",
":_test_server",

@ -0,0 +1,24 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def seen_metadata(expected, actual):
metadata_dict = dict(actual)
if type(expected[0]) != tuple:
return metadata_dict.get(expected[0]) == expected[1]
else:
for metadatum in expected:
if metadata_dict.get(metadatum[0]) != metadatum[1]:
return False
return True

@ -18,11 +18,18 @@ import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY
from tests_aio.unit import _constants
from tests_aio.unit import _common
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_INITIAL_METADATA_TO_INJECT = (
(_INITIAL_METADATA_KEY, 'extra info'),
(_TRAILING_METADATA_KEY, b'\x13\x37'),
)
class TestUnaryUnaryClientInterceptor(AioTestBase):
@ -124,7 +131,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
return await continuation(new_client_call_details, request)
@ -165,7 +172,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
@ -342,8 +349,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
call = multicallable(
messages_pb2.SimpleRequest(),
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
@ -375,8 +383,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
call = multicallable(
messages_pb2.SimpleRequest(),
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
@ -532,6 +541,42 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.initial_metadata(), tuple())
self.assertEqual(await call.trailing_metadata(), None)
async def test_initial_metadata_modification(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
if client_call_details.metadata is not None:
new_metadata = client_call_details.metadata + _INITIAL_METADATA_TO_INJECT
else:
new_metadata = _INITIAL_METADATA_TO_INJECT
new_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=new_metadata,
credentials=client_call_details.credentials,
)
return await continuation(new_details, request)
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.UnaryCall(messages_pb2.SimpleRequest())
# Expected to see the echoed initial metadata
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_TO_INJECT[0], await
call.initial_metadata()))
# Expected to see the echoed trailing metadata
self.assertTrue(
_common.seen_metadata(_INITIAL_METADATA_TO_INJECT[1], await
call.trailing_metadata()))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig()

@ -23,6 +23,7 @@ import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit import _common
_TEST_CLIENT_TO_SERVER = '/test/TestClientToServer'
_TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
@ -69,21 +70,13 @@ _INVALID_METADATA_TEST_CASES = (
)
def _seen_metadata(expected, actual):
metadata_dict = dict(actual)
for metadatum in expected:
if metadata_dict.get(metadatum[0]) != metadatum[1]:
return False
return True
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
@staticmethod
async def _test_client_to_server(request, context):
assert _REQUEST == request
assert _seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
context.invocation_metadata())
return _RESPONSE
@staticmethod
@ -120,8 +113,8 @@ class _TestGenericHandlerItself(grpc.GenericRpcHandler):
return _RESPONSE
def service(self, handler_details):
assert _seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
handler_details.invocation_metadata)
assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
handler_details.invocation_metadata)
return grpc.unary_unary_rpc_method_handler(self._method)

Loading…
Cancel
Save