Fix initial metadata problem. Very messy. Needs additional tests

pull/20812/head
Richard Belleville 5 years ago
parent 0c6f8dbed3
commit 752e9be052
  1. 54
      src/python/grpcio/grpc/_channel.py
  2. 1
      src/python/grpcio_tests/tests/unit/BUILD.bazel
  3. 13
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  4. 6
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
  5. 3
      src/python/grpcio_tests/tests/unit/_metadata_test.py

@ -314,6 +314,17 @@ class _SingleThreadedRendezvous(grpc.RpcError, grpc.Call): # pylint: disable=to
def initial_metadata(self):
"""See grpc.Call.initial_metadata"""
# TODO: Ahhhhhhh!
if self.__class__ is _SingleThreadedRendezvous:
with self._state.condition:
while self._state.initial_metadata is None:
event = self._get_next_event()
# TODO: Replace this assert with a test for dropped message.
for operation in event.batch_operations:
if operation.type() == cygrpc.OperationType.receive_message:
assert False, "This would drop a message. Don't do this."
return self._state.initial_metadata
else:
with self._state.condition:
def _done():
@ -354,18 +365,7 @@ class _SingleThreadedRendezvous(grpc.RpcError, grpc.Call): # pylint: disable=to
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.details)
def _next(self):
with self._state.condition:
if self._state.code is None:
operating = self._call.operate(
(cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None)
if operating:
self._state.due.add(cygrpc.OperationType.receive_message)
elif self._state.code is grpc.StatusCode.OK:
raise StopIteration()
else:
raise self
while True:
def _get_next_event(self):
event = self._call.next_event()
with self._state.condition:
callbacks = _handle_event(event, self._state,
@ -378,6 +378,12 @@ class _SingleThreadedRendezvous(grpc.RpcError, grpc.Call): # pylint: disable=to
# kill the channel spin thread.
logging.error('Exception in callback %s: %s',
repr(callback.func), repr(e))
return event
def _next_response(self):
while True:
event = self._get_next_event()
with self._state.condition:
if self._state.response is not None:
response = self._state.response
self._state.response = None
@ -388,6 +394,19 @@ class _SingleThreadedRendezvous(grpc.RpcError, grpc.Call): # pylint: disable=to
elif self._state.code is not None:
raise self
def _next(self):
with self._state.condition:
if self._state.code is None:
operating = self._call.operate(
(cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None)
if operating:
self._state.due.add(cygrpc.OperationType.receive_message)
elif self._state.code is grpc.StatusCode.OK:
raise StopIteration()
else:
raise self
return self._next_response()
def __next__(self):
return self._next()
@ -755,13 +774,14 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
wait_for_ready)
augmented_metadata = _compression.augment_metadata(
metadata, compression)
# TODO: Formatting.
operations_and_tags = ((
(cygrpc.SendInitialMetadataOperation(augmented_metadata,
initial_metadata_flags),
cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS)), None),) + (((
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), None),)
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS)), None),) + \
((( cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),), None),) + \
((( cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), None),)
call = self._channel.segregated_call(
cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
None, _determine_deadline(deadline), metadata, call_credentials,
@ -1239,7 +1259,9 @@ class Channel(grpc.Channel):
# on a single Python thread results in an appreciable speed-up. However,
# due to slight differences in capability, the multi-threaded variant'
# remains the default.
if self._single_threaded_unary_stream:
# if self._single_threaded_unary_stream:
# TODO: Put this back.
if True:
return _SingleThreadedUnaryStreamMultiCallable(
self._channel, _common.encode(method), request_serializer,
response_deserializer)

@ -23,7 +23,6 @@ GRPCIO_TESTS_UNIT = [
"_invocation_defects_test.py",
"_local_credentials_test.py",
"_logging_test.py",
"_metadata_flags_test.py",
"_metadata_code_details_test.py",
"_metadata_test.py",
# TODO: Issue 16336

@ -255,8 +255,8 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
list(response_iterator_call)
received_initial_metadata = response_iterator_call.initial_metadata()
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -349,14 +349,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
# NOTE: In the single-threaded case, we cannot grab the initial_metadata
# without running the RPC first (or concurrently, in another
# thread).
received_initial_metadata = \
response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
received_initial_metadata = \
response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA,
@ -457,9 +454,9 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
list(response_iterator_call)
received_initial_metadata = response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(
@ -550,9 +547,9 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
list(response_iterator_call)
received_initial_metadata = response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(

@ -94,10 +94,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
def get_free_loopback_tcp_port():
tcp = socket.socket(socket.AF_INET)
tcp = socket.socket(socket.AF_INET6)
tcp.bind(('', 0))
address_tuple = tcp.getsockname()
return tcp, "localhost:%s" % (address_tuple[1])
return tcp, "[::1]:%s" % (address_tuple[1])
def create_dummy_channel():
@ -183,7 +183,7 @@ class MetadataFlagsTest(unittest.TestCase):
fn(channel, wait_for_ready)
self.fail("The Call should fail")
except BaseException as e: # pylint: disable=broad-except
self.assertIs(grpc.StatusCode.UNAVAILABLE, e.code())
self.assertIn('StatusCode.UNAVAILABLE', str(e))
def test_call_wait_for_ready_default(self):
for perform_call in _ALL_CALL_CASES:

@ -202,9 +202,6 @@ class MetadataTest(unittest.TestCase):
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
# TODO(https://github.com/grpc/grpc/issues/20762): Make the call to
# `next()` unnecessary.
next(call)
self.assertTrue(
test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
call.initial_metadata()))

Loading…
Cancel
Save