Merge pull request #19465 from gnossen/cancellation_example

Add Python Cancellation Example
pull/19837/head
Richard Belleville 5 years ago committed by GitHub
commit b88a227135
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 75
      examples/python/cancellation/BUILD.bazel
  2. 127
      examples/python/cancellation/README.md
  3. 104
      examples/python/cancellation/client.py
  4. 56
      examples/python/cancellation/hash_name.proto
  5. 148
      examples/python/cancellation/search.py
  6. 134
      examples/python/cancellation/server.py
  7. 87
      examples/python/cancellation/test/_cancellation_example_test.py

@ -0,0 +1,75 @@
# gRPC Bazel BUILD file.
#
# Copyright 2019 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@grpc_python_dependencies//:requirements.bzl", "requirement")
load("//bazel:python_rules.bzl", "py_proto_library")
package(default_testonly = 1)
proto_library(
name = "hash_name_proto",
srcs = ["hash_name.proto"],
)
py_proto_library(
name = "hash_name_proto_pb2",
deps = [":hash_name_proto"],
well_known_protos = False,
)
py_binary(
name = "client",
srcs = ["client.py"],
deps = [
"//src/python/grpcio/grpc:grpcio",
":hash_name_proto_pb2",
requirement("six"),
],
srcs_version = "PY2AND3",
)
py_library(
name = "search",
srcs = ["search.py"],
srcs_version = "PY2AND3",
deps = [
":hash_name_proto_pb2",
],
)
py_binary(
name = "server",
srcs = ["server.py"],
deps = [
"//src/python/grpcio/grpc:grpcio",
":hash_name_proto_pb2",
":search",
] + select({
"//conditions:default": [requirement("futures")],
"//:python3": [],
}),
srcs_version = "PY2AND3",
)
py_test(
name = "test/_cancellation_example_test",
srcs = ["test/_cancellation_example_test.py"],
data = [
":client",
":server"
],
size = "small",
)

@ -0,0 +1,127 @@
### Cancellation
In the example, we implement a silly algorithm. We search for bytestrings whose
hashes are similar to a given search string. For example, say we're looking for
the string "doctor". Our algorithm may return `JrqhZVkTDoctYrUlXDbL6pfYQHU=` or
`RC9/7mlM3ldy4TdoctOc6WzYbO4=`. This is a brute force algorithm, so the server
performing the search must be conscious of the resources it allows to each client
and each client must be conscientious of the resources it demands of the server.
In particular, we ensure that client processes cancel the stream explicitly
before terminating and we ensure that server processes cancel RPCs that have gone on longer
than a certain number of iterations.
#### Cancellation on the Client Side
A client may cancel an RPC for several reasons. Perhaps the data it requested
has been made irrelevant. Perhaps you, as the client, want to be a good citizen
of the server and are conserving compute resources.
##### Cancelling a Server-Side Unary RPC from the Client
The default RPC methods on a stub will simply return the result of an RPC.
```python
>>> stub = hash_name_pb2_grpc.HashFinderStub(channel)
>>> stub.Find(hash_name_pb2.HashNameRequest(desired_name=name))
<hash_name_pb2.HashNameResponse object at 0x7fe2eb8ce2d0>
```
But you may use the `future()` method to receive an instance of `grpc.Future`.
This interface allows you to wait on a response with a timeout, add a callback
to be executed when the RPC completes, or to cancel the RPC before it has
completed.
In the example, we use this interface to cancel our in-progress RPC when the
user interrupts the process with ctrl-c.
```python
stub = hash_name_pb2_grpc.HashFinderStub(channel)
future = stub.Find.future(hash_name_pb2.HashNameRequest(desired_name=name))
def cancel_request(unused_signum, unused_frame):
future.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
result = future.result()
print(result)
```
We also call `sys.exit(0)` to terminate the process. If we do not do this, then
`future.result()` with throw an `RpcError`. Alternatively, you may catch this
exception.
##### Cancelling a Server-Side Streaming RPC from the Client
Cancelling a Server-side streaming RPC is even simpler from the perspective of
the gRPC API. The default stub method is already an instance of `grpc.Future`,
so the methods outlined above still apply. It is also a generator, so we may
iterate over it to yield the results of our RPC.
```python
stub = hash_name_pb2_grpc.HashFinderStub(channel)
result_generator = stub.FindRange(hash_name_pb2.HashNameRequest(desired_name=name))
def cancel_request(unused_signum, unused_frame):
result_generator.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
for result in result_generator:
print(result)
```
We also call `sys.exit(0)` here to terminate the process. Alternatively, you may
catch the `RpcError` raised by the for loop upon cancellation.
#### Cancellation on the Server Side
A server is reponsible for cancellation in two ways. It must respond in some way
when a client initiates a cancellation, otherwise long-running computations
could continue indefinitely.
It may also decide to cancel the RPC for its own reasons. In our example, the
server can be configured to cancel an RPC after a certain number of hashes has
been computed in order to conserve compute resources.
##### Responding to Cancellations from a Servicer Thread
It's important to remember that a gRPC Python server is backed by a thread pool
with a fixed size. When an RPC is cancelled, the library does *not* terminate
your servicer thread. It is your responsibility as the application author to
ensure that your servicer thread terminates soon after the RPC has been
cancelled.
In this example, we use the `ServicerContext.add_callback` method to set a
`threading.Event` object when the RPC is terminated. We pass this `Event` object
down through our hashing algorithm and ensure to check that the RPC is still
ongoing before each iteration.
```python
stop_event = threading.Event()
def on_rpc_done():
# Regain servicer thread.
stop_event.set()
context.add_callback(on_rpc_done)
secret = _find_secret(stop_event)
```
##### Initiating a Cancellation on the Server Side
Initiating a cancellation from the server side is simpler. Just call
`ServicerContext.cancel()`.
In our example, we ensure that no single client is monopolizing the server by
cancelling after a configurable number of hashes have been checked.
```python
try:
for candidate in secret_generator:
yield candidate
except ResourceLimitExceededError:
print("Cancelling RPC due to exhausted resources.")
context.cancel()
```
In this type of situation, you may also consider returning a more specific error
using the [`grpcio-status`](https://pypi.org/project/grpcio-status/) package.

@ -0,0 +1,104 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of cancelling requests in gRPC."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import logging
import signal
import sys
import grpc
from examples.python.cancellation import hash_name_pb2
from examples.python.cancellation import hash_name_pb2_grpc
_DESCRIPTION = "A client for finding hashes similar to names."
_LOGGER = logging.getLogger(__name__)
def run_unary_client(server_target, name, ideal_distance):
with grpc.insecure_channel(server_target) as channel:
stub = hash_name_pb2_grpc.HashFinderStub(channel)
future = stub.Find.future(
hash_name_pb2.HashNameRequest(
desired_name=name, ideal_hamming_distance=ideal_distance),
wait_for_ready=True)
def cancel_request(unused_signum, unused_frame):
future.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
result = future.result()
print(result)
def run_streaming_client(server_target, name, ideal_distance,
interesting_distance):
with grpc.insecure_channel(server_target) as channel:
stub = hash_name_pb2_grpc.HashFinderStub(channel)
result_generator = stub.FindRange(
hash_name_pb2.HashNameRequest(
desired_name=name,
ideal_hamming_distance=ideal_distance,
interesting_hamming_distance=interesting_distance),
wait_for_ready=True)
def cancel_request(unused_signum, unused_frame):
result_generator.cancel()
sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
for result in result_generator:
print(result)
def main():
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument("name", type=str, help='The desired name.')
parser.add_argument(
"--ideal-distance",
default=0,
nargs='?',
type=int,
help="The desired Hamming distance.")
parser.add_argument(
'--server',
default='localhost:50051',
type=str,
nargs='?',
help='The host-port pair at which to reach the server.')
parser.add_argument(
'--show-inferior',
default=None,
type=int,
nargs='?',
help='Also show candidates with a Hamming distance less than this value.'
)
args = parser.parse_args()
if args.show_inferior is not None:
run_streaming_client(args.server, args.name, args.ideal_distance,
args.show_inferior)
else:
run_unary_client(args.server, args.name, args.ideal_distance)
if __name__ == "__main__":
logging.basicConfig()
main()

@ -0,0 +1,56 @@
// Copyright 2019 the gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package hash_name;
// A request for a single secret whose hash is similar to a desired name.
message HashNameRequest {
// The string that is desired in the secret's hash.
string desired_name = 1;
// The ideal Hamming distance betwen desired_name and the secret that will
// be searched for.
int32 ideal_hamming_distance = 2;
// A Hamming distance greater than the ideal Hamming distance. Search results
// with a Hamming distance less than this value but greater than the ideal
// distance will be returned back to the client but will not terminate the
// search.
int32 interesting_hamming_distance = 3;
}
message HashNameResponse {
// The search result.
string secret = 1;
// The hash of the search result. A substring of this is of
// ideal_hamming_distance Hamming distance or less from desired_name.
string hashed_name = 2;
// The Hamming distance between hashed_name and desired_name.
int32 hamming_distance = 3;
}
service HashFinder {
// Search for a single string whose hash is similar to the specified
// desired_name. interesting_hamming_distance is ignored.
rpc Find (HashNameRequest) returns (HashNameResponse) {}
// Search for a string whose hash is similar to the specified desired_name,
// but also stream back less-than-ideal candidates.
rpc FindRange (HashNameRequest) returns (stream HashNameResponse) {}
}

@ -0,0 +1,148 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A search algorithm over the space of all bytestrings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import hashlib
import itertools
import logging
import struct
from examples.python.cancellation import hash_name_pb2
_LOGGER = logging.getLogger(__name__)
_BYTE_MAX = 255
def _get_hamming_distance(a, b):
"""Calculates hamming distance between strings of equal length."""
distance = 0
for char_a, char_b in zip(a, b):
if char_a != char_b:
distance += 1
return distance
def _get_substring_hamming_distance(candidate, target):
"""Calculates the minimum hamming distance between between the target
and any substring of the candidate.
Args:
candidate: The string whose substrings will be tested.
target: The target string.
Returns:
The minimum Hamming distance between candidate and target.
"""
min_distance = None
if len(target) > len(candidate):
raise ValueError("Candidate must be at least as long as target.")
for i in range(len(candidate) - len(target) + 1):
distance = _get_hamming_distance(candidate[i:i + len(target)].lower(),
target.lower())
if min_distance is None or distance < min_distance:
min_distance = distance
return min_distance
def _get_hash(secret):
hasher = hashlib.sha1()
hasher.update(secret)
return base64.b64encode(hasher.digest()).decode('ascii')
class ResourceLimitExceededError(Exception):
"""Signifies the request has exceeded configured limits."""
def _bytestrings_of_length(length):
"""Generates a stream containing all bytestrings of a given length.
Args:
length: A positive integer length.
Yields:
All bytestrings of length `length`.
"""
for digits in itertools.product(range(_BYTE_MAX), repeat=length):
yield b''.join(struct.pack('B', i) for i in digits)
def _all_bytestrings():
"""Generates a stream containing all possible bytestrings.
This generator does not terminate.
Yields:
All bytestrings in ascending order of length.
"""
for bytestring in itertools.chain.from_iterable(
_bytestrings_of_length(length) for length in itertools.count()):
yield bytestring
def search(target,
ideal_distance,
stop_event,
maximum_hashes,
interesting_hamming_distance=None):
"""Find candidate strings.
Search through the space of all bytestrings, in order of increasing length,
indefinitely, until a hash with a Hamming distance of `maximum_distance` or
less has been found.
Args:
target: The search string.
ideal_distance: The desired Hamming distance.
stop_event: An event indicating whether the RPC should terminate.
maximum_hashes: The maximum number of hashes to check before stopping.
interesting_hamming_distance: If specified, strings with a Hamming
distance from the target below this value will be yielded.
Yields:
Instances of HashNameResponse. The final entry in the stream will be of
`maximum_distance` Hamming distance or less from the target string,
while all others will be of less than `interesting_hamming_distance`.
Raises:
ResourceLimitExceededError: If the computation exceeds `maximum_hashes`
iterations.
"""
hashes_computed = 0
for secret in _all_bytestrings():
if stop_event.is_set():
raise StopIteration() # pylint: disable=stop-iteration-return
candidate_hash = _get_hash(secret)
distance = _get_substring_hamming_distance(candidate_hash, target)
if interesting_hamming_distance is not None and distance <= interesting_hamming_distance:
# Surface interesting candidates, but don't stop.
yield hash_name_pb2.HashNameResponse(
secret=base64.b64encode(secret),
hashed_name=candidate_hash,
hamming_distance=distance)
elif distance <= ideal_distance:
# Yield ideal candidate and end the stream.
yield hash_name_pb2.HashNameResponse(
secret=base64.b64encode(secret),
hashed_name=candidate_hash,
hamming_distance=distance)
raise StopIteration() # pylint: disable=stop-iteration-return
hashes_computed += 1
if hashes_computed == maximum_hashes:
raise ResourceLimitExceededError()

@ -0,0 +1,134 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of cancelling requests in gRPC."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from concurrent import futures
import argparse
import contextlib
import logging
import time
import threading
import grpc
import search
from examples.python.cancellation import hash_name_pb2
from examples.python.cancellation import hash_name_pb2_grpc
_LOGGER = logging.getLogger(__name__)
_SERVER_HOST = 'localhost'
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
_DESCRIPTION = "A server for finding hashes similar to names."
class HashFinder(hash_name_pb2_grpc.HashFinderServicer):
def __init__(self, maximum_hashes):
super(HashFinder, self).__init__()
self._maximum_hashes = maximum_hashes
def Find(self, request, context):
stop_event = threading.Event()
def on_rpc_done():
_LOGGER.debug("Attempting to regain servicer thread.")
stop_event.set()
context.add_callback(on_rpc_done)
candidates = []
try:
candidates = list(
search.search(request.desired_name,
request.ideal_hamming_distance, stop_event,
self._maximum_hashes))
except search.ResourceLimitExceededError:
_LOGGER.info("Cancelling RPC due to exhausted resources.")
context.cancel()
_LOGGER.debug("Servicer thread returning.")
if not candidates:
return hash_name_pb2.HashNameResponse()
return candidates[-1]
def FindRange(self, request, context):
stop_event = threading.Event()
def on_rpc_done():
_LOGGER.debug("Attempting to regain servicer thread.")
stop_event.set()
context.add_callback(on_rpc_done)
secret_generator = search.search(
request.desired_name,
request.ideal_hamming_distance,
stop_event,
self._maximum_hashes,
interesting_hamming_distance=request.interesting_hamming_distance)
try:
for candidate in secret_generator:
yield candidate
except search.ResourceLimitExceededError:
_LOGGER.info("Cancelling RPC due to exhausted resources.")
context.cancel()
_LOGGER.debug("Regained servicer thread.")
@contextlib.contextmanager
def _running_server(port, maximum_hashes):
# We use only a single servicer thread here to demonstrate that, if managed
# carefully, cancelled RPCs can need not continue occupying servicers
# threads.
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1), maximum_concurrent_rpcs=1)
hash_name_pb2_grpc.add_HashFinderServicer_to_server(
HashFinder(maximum_hashes), server)
address = '{}:{}'.format(_SERVER_HOST, port)
actual_port = server.add_insecure_port(address)
server.start()
print("Server listening at '{}'".format(address))
try:
yield actual_port
except KeyboardInterrupt:
pass
finally:
server.stop(None)
def main():
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument(
'--port',
type=int,
default=50051,
nargs='?',
help='The port on which the server will listen.')
parser.add_argument(
'--maximum-hashes',
type=int,
default=1000000,
nargs='?',
help='The maximum number of hashes to search before cancelling.')
args = parser.parse_args()
with _running_server(args.port, args.maximum_hashes):
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
if __name__ == "__main__":
logging.basicConfig()
main()

@ -0,0 +1,87 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for cancellation example."""
import contextlib
import os
import signal
import socket
import subprocess
import unittest
_BINARY_DIR = os.path.realpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
_SERVER_PATH = os.path.join(_BINARY_DIR, 'server')
_CLIENT_PATH = os.path.join(_BINARY_DIR, 'client')
@contextlib.contextmanager
def _get_port():
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise RuntimeError("Failed to set SO_REUSEPORT.")
sock.bind(('', 0))
try:
yield sock.getsockname()[1]
finally:
sock.close()
def _start_client(server_port,
desired_string,
ideal_distance,
interesting_distance=None):
interesting_distance_args = () if interesting_distance is None else (
'--show-inferior', interesting_distance)
return subprocess.Popen((_CLIENT_PATH, desired_string, '--server',
'localhost:{}'.format(server_port),
'--ideal-distance',
str(ideal_distance)) + interesting_distance_args)
class CancellationExampleTest(unittest.TestCase):
def test_successful_run(self):
with _get_port() as test_port:
server_process = subprocess.Popen((_SERVER_PATH, '--port',
str(test_port)))
try:
client_process = _start_client(test_port, 'aa', 0)
client_return_code = client_process.wait()
self.assertEqual(0, client_return_code)
self.assertIsNone(server_process.poll())
finally:
server_process.kill()
server_process.wait()
def test_graceful_sigint(self):
with _get_port() as test_port:
server_process = subprocess.Popen((_SERVER_PATH, '--port',
str(test_port)))
try:
client_process1 = _start_client(test_port, 'aaaaaaaaaa', 0)
client_process1.send_signal(signal.SIGINT)
client_process1.wait()
client_process2 = _start_client(test_port, 'aa', 0)
client_return_code = client_process2.wait()
self.assertEqual(0, client_return_code)
self.assertIsNone(server_process.poll())
finally:
server_process.kill()
server_process.wait()
if __name__ == '__main__':
unittest.main(verbosity=2)
Loading…
Cancel
Save