Merge pull request #13784 from nathanielmanistaatgoogle/13752

Reallow out-of-spec metadata.
pull/13787/head^2
Nathaniel Manista 7 years ago committed by GitHub
commit bbb6270dc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 19
      src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
  2. 62
      src/python/grpcio_tests/tests/unit/_metadata_test.py

@ -26,15 +26,20 @@ cdef bytes str_to_bytes(object s):
raise TypeError('Expected bytes, str, or unicode, not {}'.format(type(s))) raise TypeError('Expected bytes, str, or unicode, not {}'.format(type(s)))
cdef bytes _encode(str native_string_or_none): # TODO(https://github.com/grpc/grpc/issues/13782): It would be nice for us if
if native_string_or_none is None: # the type of metadata that we accept were exactly the same as the type of
# metadata that we deliver to our users (so "str" for this function's
# parameter rather than "object"), but would it be nice for our users? Right
# now we haven't yet heard from enough users to know one way or another.
cdef bytes _encode(object string_or_none):
if string_or_none is None:
return b'' return b''
elif isinstance(native_string_or_none, (bytes,)): elif isinstance(string_or_none, (bytes,)):
return <bytes>native_string_or_none return <bytes>string_or_none
elif isinstance(native_string_or_none, (unicode,)): elif isinstance(string_or_none, (unicode,)):
return native_string_or_none.encode('ascii') return string_or_none.encode('ascii')
else: else:
raise TypeError('Expected str, not {}'.format(type(native_string_or_none))) raise TypeError('Expected str, not {}'.format(type(string_or_none)))
cdef str _decode(bytes bytestring): cdef str _decode(bytes bytestring):

@ -34,16 +34,19 @@ _UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary' _STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream' _STREAM_STREAM = '/test/StreamStream'
_CLIENT_METADATA = (('client-md-key', 'client-md-key'), _INVOCATION_METADATA = ((b'invocation-md-key', u'invocation-md-value',),
('client-md-key-bin', b'\x00\x01')) (u'invocation-md-key-bin', b'\x00\x01',),)
_EXPECTED_INVOCATION_METADATA = (('invocation-md-key', 'invocation-md-value',),
('invocation-md-key-bin', b'\x00\x01',),)
_SERVER_INITIAL_METADATA = ( _INITIAL_METADATA = ((b'initial-md-key', u'initial-md-value'),
('server-initial-md-key', 'server-initial-md-value'), (u'initial-md-key-bin', b'\x00\x02'))
('server-initial-md-key-bin', b'\x00\x02')) _EXPECTED_INITIAL_METADATA = (('initial-md-key', 'initial-md-value',),
('initial-md-key-bin', b'\x00\x02',),)
_SERVER_TRAILING_METADATA = ( _TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value',),
('server-trailing-md-key', 'server-trailing-md-value'), ('server-trailing-md-key-bin', b'\x00\x03',),)
('server-trailing-md-key-bin', b'\x00\x03')) _EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
def user_agent(metadata): def user_agent(metadata):
@ -56,7 +59,8 @@ def user_agent(metadata):
def validate_client_metadata(test, servicer_context): def validate_client_metadata(test, servicer_context):
test.assertTrue( test.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
_CLIENT_METADATA, servicer_context.invocation_metadata())) _EXPECTED_INVOCATION_METADATA,
servicer_context.invocation_metadata()))
test.assertTrue( test.assertTrue(
user_agent(servicer_context.invocation_metadata()) user_agent(servicer_context.invocation_metadata())
.startswith('primary-agent ' + _channel._USER_AGENT)) .startswith('primary-agent ' + _channel._USER_AGENT))
@ -67,23 +71,23 @@ def validate_client_metadata(test, servicer_context):
def handle_unary_unary(test, request, servicer_context): def handle_unary_unary(test, request, servicer_context):
validate_client_metadata(test, servicer_context) validate_client_metadata(test, servicer_context)
servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA) servicer_context.send_initial_metadata(_INITIAL_METADATA)
servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA) servicer_context.set_trailing_metadata(_TRAILING_METADATA)
return _RESPONSE return _RESPONSE
def handle_unary_stream(test, request, servicer_context): def handle_unary_stream(test, request, servicer_context):
validate_client_metadata(test, servicer_context) validate_client_metadata(test, servicer_context)
servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA) servicer_context.send_initial_metadata(_INITIAL_METADATA)
servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA) servicer_context.set_trailing_metadata(_TRAILING_METADATA)
for _ in range(test_constants.STREAM_LENGTH): for _ in range(test_constants.STREAM_LENGTH):
yield _RESPONSE yield _RESPONSE
def handle_stream_unary(test, request_iterator, servicer_context): def handle_stream_unary(test, request_iterator, servicer_context):
validate_client_metadata(test, servicer_context) validate_client_metadata(test, servicer_context)
servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA) servicer_context.send_initial_metadata(_INITIAL_METADATA)
servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA) servicer_context.set_trailing_metadata(_TRAILING_METADATA)
# TODO(issue:#6891) We should be able to remove this loop # TODO(issue:#6891) We should be able to remove this loop
for request in request_iterator: for request in request_iterator:
pass pass
@ -92,8 +96,8 @@ def handle_stream_unary(test, request_iterator, servicer_context):
def handle_stream_stream(test, request_iterator, servicer_context): def handle_stream_stream(test, request_iterator, servicer_context):
validate_client_metadata(test, servicer_context) validate_client_metadata(test, servicer_context)
servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA) servicer_context.send_initial_metadata(_INITIAL_METADATA)
servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA) servicer_context.set_trailing_metadata(_TRAILING_METADATA)
# TODO(issue:#6891) We should be able to remove this loop, # TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield # and replace with return; yield
for request in request_iterator: for request in request_iterator:
@ -156,50 +160,50 @@ class MetadataTest(unittest.TestCase):
def testUnaryUnary(self): def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY) multi_callable = self._channel.unary_unary(_UNARY_UNARY)
unused_response, call = multi_callable.with_call( unused_response, call = multi_callable.with_call(
_REQUEST, metadata=_CLIENT_METADATA) _REQUEST, metadata=_INVOCATION_METADATA)
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
call.initial_metadata())) call.initial_metadata()))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA,
call.trailing_metadata())) call.trailing_metadata()))
def testUnaryStream(self): def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM) multi_callable = self._channel.unary_stream(_UNARY_STREAM)
call = multi_callable(_REQUEST, metadata=_CLIENT_METADATA) call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
call.initial_metadata())) call.initial_metadata()))
for _ in call: for _ in call:
pass pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA,
call.trailing_metadata())) call.trailing_metadata()))
def testStreamUnary(self): def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY) multi_callable = self._channel.stream_unary(_STREAM_UNARY)
unused_response, call = multi_callable.with_call( unused_response, call = multi_callable.with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA) metadata=_INVOCATION_METADATA)
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
call.initial_metadata())) call.initial_metadata()))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA,
call.trailing_metadata())) call.trailing_metadata()))
def testStreamStream(self): def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM) multi_callable = self._channel.stream_stream(_STREAM_STREAM)
call = multi_callable( call = multi_callable(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA) metadata=_INVOCATION_METADATA)
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
call.initial_metadata())) call.initial_metadata()))
for _ in call: for _ in call:
pass pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA,
call.trailing_metadata())) call.trailing_metadata()))

Loading…
Cancel
Save