Merge pull request #20753 from gnossen/unary_stream

Add experimental option to run unary-stream RPCs on a single Python thread.
pull/20774/head
Richard Belleville 5 years ago committed by GitHub
commit 018580fb89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 367
      src/python/grpcio/grpc/_channel.py
  2. 2
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  3. 11
      src/python/grpcio/grpc/experimental/__init__.py
  4. 30
      src/python/grpcio_tests/tests/stress/BUILD.bazel
  5. 27
      src/python/grpcio_tests/tests/stress/unary_stream_benchmark.proto
  6. 104
      src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py
  7. 1
      src/python/grpcio_tests/tests/unit/BUILD.bazel
  8. 13
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  9. 6
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
  10. 3
      src/python/grpcio_tests/tests/unit/_metadata_test.py

@ -20,6 +20,7 @@ import threading
import time
import grpc
import grpc.experimental
from grpc import _compression
from grpc import _common
from grpc import _grpcio_metadata
@ -248,16 +249,47 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
consumption_thread.start()
class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors
class _SingleThreadedRendezvous(grpc.RpcError, grpc.Call): # pylint: disable=too-many-ancestors
"""An RPC iterator operating entirely on a single thread.
The __next__ method of _SingleThreadedRendezvous does not depend on the
existence of any other thread, including the "channel spin thread".
However, this means that its interface is entirely synchronous. So this
class cannot fulfill the grpc.Future interface.
Attributes:
_state: An instance of _RPCState.
_call: An instance of SegregatedCall or (for subclasses) IntegratedCall.
In either case, the _call object is expected to have operate, cancel,
and next_event methods.
_response_deserializer: A callable taking bytes and returning a Python
object.
_deadline: A float representing the deadline of the RPC in seconds. Or
possibly None, to represent an RPC with no deadline at all.
"""
def __init__(self, state, call, response_deserializer, deadline):
super(_Rendezvous, self).__init__()
super(_SingleThreadedRendezvous, self).__init__()
self._state = state
self._call = call
self._response_deserializer = response_deserializer
self._deadline = deadline
def is_active(self):
"""See grpc.RpcContext.is_active"""
with self._state.condition:
return self._state.code is None
def time_remaining(self):
"""See grpc.RpcContext.time_remaining"""
with self._state.condition:
if self._deadline is None:
return None
else:
return max(self._deadline - time.time(), 0)
def cancel(self):
"""See grpc.RpcContext.cancel"""
with self._state.condition:
if self._state.code is None:
code = grpc.StatusCode.CANCELLED
@ -267,7 +299,154 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
self._state.cancelled = True
_abort(self._state, code, details)
self._state.condition.notify_all()
return False
return True
else:
return False
def add_callback(self, callback):
"""See grpc.RpcContext.add_callback"""
with self._state.condition:
if self._state.callbacks is None:
return False
else:
self._state.callbacks.append(callback)
return True
def initial_metadata(self):
"""See grpc.Call.initial_metadata"""
with self._state.condition:
def _done():
return self._state.initial_metadata is not None
_common.wait(self._state.condition.wait, _done)
return self._state.initial_metadata
def trailing_metadata(self):
"""See grpc.Call.trailing_metadata"""
with self._state.condition:
def _done():
return self._state.trailing_metadata is not None
_common.wait(self._state.condition.wait, _done)
return self._state.trailing_metadata
# TODO(https://github.com/grpc/grpc/issues/20763): Drive RPC progress using
# the calling thread.
def code(self):
"""See grpc.Call.code"""
with self._state.condition:
def _done():
return self._state.code is not None
_common.wait(self._state.condition.wait, _done)
return self._state.code
def details(self):
"""See grpc.Call.details"""
with self._state.condition:
def _done():
return self._state.details is not None
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.details)
def _next(self):
    """Drive the RPC on the calling thread until one response is ready.

    If the RPC is still in flight, a receive-message operation is started
    on the call; events are then pulled from the call and handled until
    either a response is stored on the state or the RPC has terminated.

    Returns:
      The next response taken from self._state.response.

    Raises:
      StopIteration: If the RPC has terminated with StatusCode.OK.
      grpc.RpcError: This object itself, if the RPC has terminated with
        any other status code.
    """
    with self._state.condition:
        if self._state.code is None:
            operating = self._call.operate(
                (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None)
            if operating:
                self._state.due.add(cygrpc.OperationType.receive_message)
        elif self._state.code is grpc.StatusCode.OK:
            raise StopIteration()
        else:
            raise self
    while True:
        event = self._call.next_event()
        with self._state.condition:
            callbacks = _handle_event(event, self._state,
                                      self._response_deserializer)
            for callback in callbacks:
                try:
                    callback()
                except Exception as e:  # pylint: disable=broad-except
                    # NOTE(rbellevi): We suppress but log errors here so
                    # that a misbehaving callback does not interrupt
                    # iteration on the calling thread.
                    # Callbacks registered through add_callback are plain
                    # callables without a `.func` attribute, so fall back
                    # to the callback itself rather than raising
                    # AttributeError from inside this handler.
                    logging.error('Exception in callback %s: %s',
                                  repr(getattr(callback, 'func', callback)),
                                  repr(e))
            if self._state.response is not None:
                response = self._state.response
                # Clear the slot so the next call waits for a new message.
                self._state.response = None
                return response
            elif cygrpc.OperationType.receive_message not in self._state.due:
                if self._state.code is grpc.StatusCode.OK:
                    raise StopIteration()
                elif self._state.code is not None:
                    raise self
def __next__(self):
return self._next()
def next(self):
return self._next()
def __iter__(self):
return self
def debug_error_string(self):
with self._state.condition:
def _done():
return self._state.debug_error_string is not None
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.debug_error_string)
def _repr(self):
with self._state.condition:
if self._state.code is None:
return '<{} object of in-flight RPC>'.format(
self.__class__.__name__)
elif self._state.code is grpc.StatusCode.OK:
return _OK_RENDEZVOUS_REPR_FORMAT.format(
self._state.code, self._state.details)
else:
return _NON_OK_RENDEZVOUS_REPR_FORMAT.format(
self._state.code, self._state.details,
self._state.debug_error_string)
def __repr__(self):
return self._repr()
def __str__(self):
return self._repr()
def __del__(self):
with self._state.condition:
if self._state.code is None:
self._state.code = grpc.StatusCode.CANCELLED
self._state.details = 'Cancelled upon garbage collection!'
self._state.cancelled = True
self._call.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code],
self._state.details)
self._state.condition.notify_all()
class _Rendezvous(_SingleThreadedRendezvous, grpc.Future): # pylint: disable=too-many-ancestors
"""An RPC iterator that depends on a channel spin thread.
This iterator relies upon a per-channel thread running in the background,
dequeueing events from the completion queue, and notifying threads waiting
on the threading.Condition object in the _RPCState object.
This extra thread allows _Rendezvous to fulfill the grpc.Future interface
and to mediate a bidirectional streaming RPC.
"""
def cancelled(self):
with self._state.condition:
@ -381,25 +560,6 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
elif self._state.code is not None:
raise self
def __iter__(self):
return self
def __next__(self):
return self._next()
def next(self):
return self._next()
def is_active(self):
with self._state.condition:
return self._state.code is None
def time_remaining(self):
if self._deadline is None:
return None
else:
return max(self._deadline - time.time(), 0)
def add_callback(self, callback):
with self._state.condition:
if self._state.callbacks is None:
@ -408,80 +568,6 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
self._state.callbacks.append(callback)
return True
def initial_metadata(self):
with self._state.condition:
def _done():
return self._state.initial_metadata is not None
_common.wait(self._state.condition.wait, _done)
return self._state.initial_metadata
def trailing_metadata(self):
with self._state.condition:
def _done():
return self._state.trailing_metadata is not None
_common.wait(self._state.condition.wait, _done)
return self._state.trailing_metadata
def code(self):
with self._state.condition:
def _done():
return self._state.code is not None
_common.wait(self._state.condition.wait, _done)
return self._state.code
def details(self):
with self._state.condition:
def _done():
return self._state.details is not None
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.details)
def debug_error_string(self):
with self._state.condition:
def _done():
return self._state.debug_error_string is not None
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.debug_error_string)
def _repr(self):
with self._state.condition:
if self._state.code is None:
return '<_Rendezvous object of in-flight RPC>'
elif self._state.code is grpc.StatusCode.OK:
return _OK_RENDEZVOUS_REPR_FORMAT.format(
self._state.code, self._state.details)
else:
return _NON_OK_RENDEZVOUS_REPR_FORMAT.format(
self._state.code, self._state.details,
self._state.debug_error_string)
def __repr__(self):
return self._repr()
def __str__(self):
return self._repr()
def __del__(self):
with self._state.condition:
if self._state.code is None:
self._state.code = grpc.StatusCode.CANCELLED
self._state.details = 'Cancelled upon garbage collection!'
self._state.cancelled = True
self._call.cancel(
_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code],
self._state.details)
self._state.condition.notify_all()
def _start_unary_request(request, timeout, request_serializer):
deadline = _deadline(timeout)
@ -636,6 +722,54 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
deadline)
class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that returns a _SingleThreadedRendezvous.

    Each invocation serializes the request, starts a segregated call on the
    channel and hands the call to a _SingleThreadedRendezvous.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, channel, method, request_serializer,
                 response_deserializer):
        """Stores the channel, method and (de)serializers.

        A census context is built once here and reused for every call.
        """
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._context = cygrpc.build_census_context()

    def __call__(  # pylint: disable=too-many-locals
            self,
            request,
            timeout=None,
            metadata=None,
            credentials=None,
            wait_for_ready=None,
            compression=None):
        """Starts the RPC and returns an iterator over its responses.

        Returns:
          A _SingleThreadedRendezvous wrapping the started call.

        Raises:
          _Rendezvous: With StatusCode.INTERNAL, if serializing the request
            yields None.
        """
        deadline = _deadline(timeout)
        wire_request = _common.serialize(request, self._request_serializer)
        if wire_request is None:
            failed_state = _RPCState((), (), (), grpc.StatusCode.INTERNAL,
                                     'Exception serializing request!')
            raise _Rendezvous(failed_state, None, None, deadline)

        rpc_state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None,
                              None)
        if credentials is None:
            call_credentials = None
        else:
            call_credentials = credentials._credentials
        metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
            wait_for_ready)
        outgoing_metadata = _compression.augment_metadata(
            metadata, compression)

        # First batch: everything the client sends, plus the status receive.
        request_batch = (
            cygrpc.SendInitialMetadataOperation(outgoing_metadata,
                                                metadata_flags),
            cygrpc.SendMessageOperation(wire_request, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        )
        # Second batch: receive the server's initial metadata.
        initial_metadata_batch = (
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),)
        operations_and_tags = (
            (request_batch, None),
            (initial_metadata_batch, None),
        )

        call = self._channel.segregated_call(
            cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
            None, _determine_deadline(deadline), metadata, call_credentials,
            operations_and_tags, self._context)
        return _SingleThreadedRendezvous(rpc_state, call,
                                         self._response_deserializer,
                                         deadline)
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
# pylint: disable=too-many-arguments
@ -1042,6 +1176,18 @@ def _augment_options(base_options, compression):
),)
def _separate_channel_options(options):
    """Splits channel options into Python-only and core options.

    An option pair is Python-only when its first element equals
    grpc.experimental.ChannelOptions.SingleThreadedUnaryStream; every other
    pair is treated as a core option.

    Returns:
      A (python_options, core_options) tuple of lists, each preserving the
      relative order of the input pairs.
    """
    python_only, core = [], []
    for option in options:
        is_python_option = (
            option[0] ==
            grpc.experimental.ChannelOptions.SingleThreadedUnaryStream)
        destination = python_only if is_python_option else core
        destination.append(option)
    return python_only, core
class Channel(grpc.Channel):
"""A cygrpc.Channel-backed implementation of grpc.Channel."""
@ -1055,13 +1201,22 @@ class Channel(grpc.Channel):
compression: An optional value indicating the compression method to be
used over the lifetime of the channel.
"""
python_options, core_options = _separate_channel_options(options)
self._single_threaded_unary_stream = False
self._process_python_options(python_options)
self._channel = cygrpc.Channel(
_common.encode(target), _augment_options(options, compression),
_common.encode(target), _augment_options(core_options, compression),
credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
cygrpc.fork_register_channel(self)
def _process_python_options(self, python_options):
    """Applies Python-only channel options to this channel.

    Sets self._single_threaded_unary_stream to True when at least one pair
    is keyed by ChannelOptions.SingleThreadedUnaryStream. The value of the
    pair is not consulted, and the attribute is left untouched otherwise.
    """
    single_threaded_requests = [
        option for option in python_options
        if option[0] ==
        grpc.experimental.ChannelOptions.SingleThreadedUnaryStream
    ]
    if single_threaded_requests:
        self._single_threaded_unary_stream = True
def subscribe(self, callback, try_to_connect=None):
_subscribe(self._connectivity_state, callback, try_to_connect)
@ -1080,9 +1235,21 @@ class Channel(grpc.Channel):
method,
request_serializer=None,
response_deserializer=None):
return _UnaryStreamMultiCallable(
self._channel, _channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer)
# NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC
# on a single Python thread results in an appreciable speed-up. However,
# due to slight differences in capability, the multi-threaded variant
# remains the default.
if self._single_threaded_unary_stream:
return _SingleThreadedUnaryStreamMultiCallable(
self._channel, _common.encode(method), request_serializer,
response_deserializer)
else:
return _UnaryStreamMultiCallable(self._channel,
_channel_managed_call_management(
self._call_state),
_common.encode(method),
request_serializer,
response_deserializer)
def stream_unary(self,
method,

@ -420,8 +420,6 @@ cdef _close(Channel channel, grpc_status_code code, object details,
else:
while state.integrated_call_states:
state.condition.wait()
while state.segregated_call_states:
state.condition.wait()
while state.connectivity_due:
state.condition.wait()

@ -15,3 +15,14 @@
These APIs are subject to be removed during any minor version release.
"""
class ChannelOptions(object):
    """Names of channel options that exist only in gRPC Python.

    This enumeration is part of an EXPERIMENTAL API.

    Attributes:
      SingleThreadedUnaryStream: Perform unary-stream RPCs on a single
        thread.
    """

    SingleThreadedUnaryStream = "SingleThreadedUnaryStream"

@ -0,0 +1,30 @@
load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library", "py_grpc_library")
proto_library(
name = "unary_stream_benchmark_proto",
srcs = ["unary_stream_benchmark.proto"],
deps = [],
)
py_proto_library(
name = "unary_stream_benchmark_py_pb2",
deps = [":unary_stream_benchmark_proto"],
)
py_grpc_library(
name = "unary_stream_benchmark_py_pb2_grpc",
srcs = [":unary_stream_benchmark_proto"],
deps = [":unary_stream_benchmark_py_pb2"],
)
py_binary(
name = "unary_stream_benchmark",
srcs_version = "PY3",
python_version = "PY3",
srcs = ["unary_stream_benchmark.py"],
deps = [
"//src/python/grpcio/grpc:grpcio",
":unary_stream_benchmark_py_pb2",
":unary_stream_benchmark_py_pb2_grpc",
]
)

@ -0,0 +1,27 @@
// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
// Parameters of one Benchmark RPC.
message BenchmarkRequest {
// Approximate size in bytes of each response payload; the benchmark server
// in unary_stream_benchmark.py builds the payload from message_size / 2
// two-byte units.
int32 message_size = 1;
// Number of BenchmarkResponse messages the server streams back.
int32 response_count = 2;
}
// One message of the response stream of a Benchmark RPC.
message BenchmarkResponse {
// Payload bytes of this response.
bytes response = 1;
}
// Service used to benchmark unary-stream RPCs.
service UnaryStreamBenchmarkService {
// Takes a single BenchmarkRequest and returns a stream of BenchmarkResponse.
rpc Benchmark(BenchmarkRequest) returns (stream BenchmarkResponse);
}

@ -0,0 +1,104 @@
# Copyright 2019 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import threading
import grpc
import grpc.experimental
import subprocess
import sys
import time
import contextlib
_PORT = 5741
_MESSAGE_SIZE = 4
_RESPONSE_COUNT = 32 * 1024
_SERVER_CODE = """
import datetime
import threading
import grpc
from concurrent import futures
from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2
from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2_grpc
class Handler(unary_stream_benchmark_pb2_grpc.UnaryStreamBenchmarkServiceServicer):
def Benchmark(self, request, context):
payload = b'\\x00\\x01' * int(request.message_size / 2)
for _ in range(request.response_count):
yield unary_stream_benchmark_pb2.BenchmarkResponse(response=payload)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
server.add_insecure_port('[::]:%d')
unary_stream_benchmark_pb2_grpc.add_UnaryStreamBenchmarkServiceServicer_to_server(Handler(), server)
server.start()
server.wait_for_termination()
""" % _PORT
try:
from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2
from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2_grpc
_GRPC_CHANNEL_OPTIONS = [
('grpc.max_metadata_size', 16 * 1024 * 1024),
('grpc.max_receive_message_length', 64 * 1024 * 1024),
(grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1),
]
@contextlib.contextmanager
def _running_server():
    """Runs the benchmark server in a child process for the `with` body.

    The child executes _SERVER_CODE with the current interpreter. When the
    context exits, the child is terminated and waited for, and whatever it
    wrote to its stdout and stderr pipes is echoed to this process's stdout.
    """
    child = subprocess.Popen(
        [sys.executable, '-c', _SERVER_CODE],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    try:
        yield
    finally:
        child.terminate()
        child.wait()
        captured = (
            ("stdout: {}", child.stdout),
            ("stderr: {}", child.stderr),
        )
        for template, stream in captured:
            sys.stdout.write(template.format(stream.read()))
            sys.stdout.flush()
def profile(message_size, response_count):
    """Times a single Benchmark RPC against the local server.

    Args:
      message_size: Value for BenchmarkRequest.message_size.
      response_count: Value for BenchmarkRequest.response_count.

    Returns:
      A datetime.timedelta spanning from just before the RPC is invoked to
      just after the response stream has been fully consumed.
    """
    request = unary_stream_benchmark_pb2.BenchmarkRequest(
        message_size=message_size, response_count=response_count)
    target = '[::]:{}'.format(_PORT)
    with grpc.insecure_channel(
            target, options=_GRPC_CHANNEL_OPTIONS) as channel:
        stub = unary_stream_benchmark_pb2_grpc.UnaryStreamBenchmarkServiceStub(
            channel)
        started_at = datetime.datetime.now()
        response_stream = stub.Benchmark(request, wait_for_ready=True)
        # Drain the stream; only the elapsed time matters.
        for _ in response_stream:
            pass
        finished_at = datetime.datetime.now()
        return finished_at - started_at
def main():
    """Runs 1000 benchmark RPCs and prints each latency in seconds.

    The server process is kept alive for the whole run, and one line is
    written to stdout per RPC.
    """
    # NOTE(review): the response count is hard-coded to 1024 here while the
    # module-level _RESPONSE_COUNT constant goes unused; confirm intended.
    with _running_server():
        for _ in range(1000):
            elapsed = profile(_MESSAGE_SIZE, 1024)
            sys.stdout.write("{}\n".format(elapsed.total_seconds()))
            sys.stdout.flush()
if __name__ == '__main__':
main()
except ImportError:
# NOTE(rbellevi): The test runner should not load this module.
pass

@ -23,6 +23,7 @@ GRPCIO_TESTS_UNIT = [
"_invocation_defects_test.py",
"_local_credentials_test.py",
"_logging_test.py",
"_metadata_flags_test.py",
"_metadata_code_details_test.py",
"_metadata_test.py",
# TODO: Issue 16336

@ -255,8 +255,8 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
list(response_iterator_call)
received_initial_metadata = response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(
@ -349,11 +349,14 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = \
response_iterator_call.initial_metadata()
# NOTE: In the single-threaded case, we cannot grab the initial_metadata
# without running the RPC first (or concurrently, in another
# thread).
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
received_initial_metadata = \
response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA,
@ -454,9 +457,9 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
list(response_iterator_call)
received_initial_metadata = response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(
@ -547,9 +550,9 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
list(response_iterator_call)
received_initial_metadata = response_iterator_call.initial_metadata()
self.assertTrue(
test_common.metadata_transmitted(

@ -94,10 +94,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
def get_free_loopback_tcp_port():
tcp = socket.socket(socket.AF_INET6)
tcp = socket.socket(socket.AF_INET)
tcp.bind(('', 0))
address_tuple = tcp.getsockname()
return tcp, "[::1]:%s" % (address_tuple[1])
return tcp, "localhost:%s" % (address_tuple[1])
def create_dummy_channel():
@ -183,7 +183,7 @@ class MetadataFlagsTest(unittest.TestCase):
fn(channel, wait_for_ready)
self.fail("The Call should fail")
except BaseException as e: # pylint: disable=broad-except
self.assertIn('StatusCode.UNAVAILABLE', str(e))
self.assertIs(grpc.StatusCode.UNAVAILABLE, e.code())
def test_call_wait_for_ready_default(self):
for perform_call in _ALL_CALL_CASES:

@ -202,6 +202,9 @@ class MetadataTest(unittest.TestCase):
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
# TODO(https://github.com/grpc/grpc/issues/20762): Make the call to
# `next()` unnecessary.
next(call)
self.assertTrue(
test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA,
call.initial_metadata()))

Loading…
Cancel
Save