Fix main thread starvation issues

pull/19465/head
Richard Belleville 6 years ago
parent 1db141accc
commit 915e97b115
  1. 74
      examples/python/cancellation/README.md
  2. 49
      examples/python/cancellation/client.py
  3. 2
      examples/python/cancellation/server.py
  4. 2
      examples/python/cancellation/test/_cancellation_example_test.py

@ -40,30 +40,16 @@ stub = hash_name_pb2_grpc.HashFinderStub(channel)
future = stub.Find.future(hash_name_pb2.HashNameRequest(desired_name=name))
def cancel_request(unused_signum, unused_frame):
future.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
```
It's also important that you not block indefinitely on the RPC. Otherwise, the
signal handler will never have a chance to run.
```python
while True:
try:
result = future.result(timeout=_TIMEOUT_SECONDS)
except grpc.FutureTimeoutError:
continue
except grpc.FutureCancelledError:
break
print("Got response: \n{}".format(result))
break
result = future.result()
print(result)
```
Here, we repeatedly block on a result for up to `_TIMEOUT_SECONDS`. Doing so
gives the signal handlers a chance to run. In the case that our timeout
was reached, we simply continue on in the loop. In the case that the RPC was
cancelled (by our user's ctrl+c), we break out of the loop cleanly. Finally, if
we received the result of the RPC, we print it out for the user and exit the
loop.
We also call `sys.exit(0)` to terminate the process. If we do not do this, then
`future.result()` will throw an `RpcError`. Alternatively, you may catch this
exception.
##### Cancelling a Server-Side Streaming RPC from the Client
@ -78,53 +64,15 @@ stub = hash_name_pb2_grpc.HashFinderStub(channel)
result_generator = stub.FindRange(hash_name_pb2.HashNameRequest(desired_name=name))
def cancel_request(unused_signum, unused_frame):
result_generator.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
for result in result_generator:
print(result)
```
However, the streaming case is complicated by the fact that there is no way to
propagate a timeout to Python generators. As a result, simply iterating over the
results of the RPC can block indefinitely and the signal handler may never run.
Instead, we iterate over the generator on another thread and retrieve the
results on the main thread with a synchronized `Queue`.
```python
result_queue = Queue()
def iterate_responses(result_generator, result_queue):
try:
for result in result_generator:
result_queue.put(result)
except grpc.RpcError as rpc_error:
if rpc_error.code() != grpc.StatusCode.CANCELLED:
result_queue.put(None)
raise rpc_error
result_queue.put(None)
print("RPC complete")
response_thread = threading.Thread(target=iterate_responses, args=(result_generator, result_queue))
response_thread.daemon = True
response_thread.start()
```
While this thread iterating over the results may block indefinitely, we can
structure the code running on our main thread in such a way that signal handlers
are guaranteed to be run at least every `_TIMEOUT_SECONDS` seconds.
```python
while result_generator.running():
try:
result = result_queue.get(timeout=_TIMEOUT_SECONDS)
except QueueEmpty:
continue
if result is None:
break
print("Got result: {}".format(result))
```
Similarly to the unary example above, we continue in a loop waiting for results,
taking care to block for intervals of `_TIMEOUT_SECONDS` at the longest.
Finally, we use `None` as a sentinel value to signal the end of the stream.
We also call `sys.exit(0)` here to terminate the process. Alternatively, you may
catch the `RpcError` raised by the for loop upon cancellation.
Using this scheme, our process responds nicely to `SIGINT`s while also
explicitly cancelling its RPCs.
#### Cancellation on the Server Side

@ -27,6 +27,8 @@ from six.moves.queue import Queue
from six.moves.queue import Empty as QueueEmpty
import grpc
import os
import sys
from examples.python.cancellation import hash_name_pb2
from examples.python.cancellation import hash_name_pb2_grpc
@ -34,8 +36,6 @@ from examples.python.cancellation import hash_name_pb2_grpc
_DESCRIPTION = "A client for finding hashes similar to names."
_LOGGER = logging.getLogger(__name__)
_TIMEOUT_SECONDS = 0.05
def run_unary_client(server_target, name, ideal_distance):
with grpc.insecure_channel(server_target) as channel:
@ -47,21 +47,11 @@ def run_unary_client(server_target, name, ideal_distance):
def cancel_request(unused_signum, unused_frame):
future.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
while True:
try:
result = future.result(timeout=_TIMEOUT_SECONDS)
except grpc.FutureTimeoutError:
continue
except grpc.FutureCancelledError:
break
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.CANCELLED:
break
raise rpc_error
print(result)
break
result = future.result()
print(result)
def run_streaming_client(server_target, name, ideal_distance,
@ -77,35 +67,10 @@ def run_streaming_client(server_target, name, ideal_distance,
def cancel_request(unused_signum, unused_frame):
result_generator.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
result_queue = Queue()
def iterate_responses(result_generator, result_queue):
try:
for result in result_generator:
result_queue.put(result)
except grpc.RpcError as rpc_error:
if rpc_error.code() != grpc.StatusCode.CANCELLED:
result_queue.put(None)
raise rpc_error
# Enqueue a sentinel to signal the end of the stream.
result_queue.put(None)
# TODO(https://github.com/grpc/grpc/issues/19464): Do everything on the
# main thread.
response_thread = threading.Thread(
target=iterate_responses, args=(result_generator, result_queue))
response_thread.daemon = True
response_thread.start()
while result_generator.running():
try:
result = result_queue.get(timeout=_TIMEOUT_SECONDS)
except QueueEmpty:
continue
if result is None:
break
for result in result_generator:
print(result)

@ -117,7 +117,7 @@ def main():
parser.add_argument(
'--maximum-hashes',
type=int,
default=10000,
default=1000000,
nargs='?',
help='The maximum number of hashes to search before cancelling.')
args = parser.parse_args()

@ -74,7 +74,7 @@ class CancellationExampleTest(unittest.TestCase):
client_process1 = _start_client(test_port, 'aaaaaaaaaa', 0)
client_process1.send_signal(signal.SIGINT)
client_process1.wait()
client_process2 = _start_client(test_port, 'aaaaaaaaaa', 0)
client_process2 = _start_client(test_port, 'aa', 0)
client_return_code = client_process2.wait()
self.assertEqual(0, client_return_code)
self.assertIsNone(server_process.poll())

Loading…
Cancel
Save