mirror of https://github.com/grpc/grpc.git
Merge pull request #19465 from gnossen/cancellation_example
Add Python Cancellation Example
commit b88a227135
7 changed files with 731 additions and 0 deletions

@@ -0,0 +1,75 @@
# gRPC Bazel BUILD file.
#
# Copyright 2019 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@grpc_python_dependencies//:requirements.bzl", "requirement")
load("//bazel:python_rules.bzl", "py_proto_library")

package(default_testonly = 1)

proto_library(
    name = "hash_name_proto",
    srcs = ["hash_name.proto"],
)

py_proto_library(
    name = "hash_name_proto_pb2",
    deps = [":hash_name_proto"],
    well_known_protos = False,
)

py_binary(
    name = "client",
    srcs = ["client.py"],
    deps = [
        "//src/python/grpcio/grpc:grpcio",
        ":hash_name_proto_pb2",
        requirement("six"),
    ],
    srcs_version = "PY2AND3",
)

py_library(
    name = "search",
    srcs = ["search.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":hash_name_proto_pb2",
    ],
)

py_binary(
    name = "server",
    srcs = ["server.py"],
    deps = [
        "//src/python/grpcio/grpc:grpcio",
        ":hash_name_proto_pb2",
        ":search",
    ] + select({
        "//conditions:default": [requirement("futures")],
        "//:python3": [],
    }),
    srcs_version = "PY2AND3",
)

py_test(
    name = "test/_cancellation_example_test",
    srcs = ["test/_cancellation_example_test.py"],
    data = [
        ":client",
        ":server",
    ],
    size = "small",
)

@@ -0,0 +1,127 @@
### Cancellation

In the example, we implement a silly algorithm: we search for bytestrings whose
hashes contain a substring similar to a given search string. For example, say
we're looking for the string "doctor". Our algorithm may return
`JrqhZVkTDoctYrUlXDbL6pfYQHU=` or `RC9/7mlM3ldy4TdoctOc6WzYbO4=`. Because this
is a brute-force algorithm, the server performing the search must limit the
resources it devotes to any one client, and each client must be mindful of the
resources it demands of the server.

In particular, we ensure that client processes cancel the stream explicitly
before terminating, and that server processes cancel RPCs that have run for
more than a configured number of iterations.

#### Cancellation on the Client Side

A client may cancel an RPC for several reasons. Perhaps the data it requested
has been made irrelevant. Perhaps you, as the client, want to be a good citizen
of the server and are conserving compute resources.

##### Cancelling a Server-Side Unary RPC from the Client

The default RPC methods on a stub will simply return the result of an RPC.

```python
>>> stub = hash_name_pb2_grpc.HashFinderStub(channel)
>>> stub.Find(hash_name_pb2.HashNameRequest(desired_name=name))
<hash_name_pb2.HashNameResponse object at 0x7fe2eb8ce2d0>
```

But you may use the `future()` method to receive an instance of `grpc.Future`.
This interface allows you to wait on a response with a timeout, add a callback
to be executed when the RPC completes, or to cancel the RPC before it has
completed.
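
For instance, here is a minimal sketch of those other `grpc.Future`
capabilities, assuming the `stub` and `name` set up above; the `timeout_secs`
value and the `on_done` callback are illustrative and not part of the example
code:

```python
future = stub.Find.future(hash_name_pb2.HashNameRequest(desired_name=name))

def on_done(completed_future):
    # Runs once the RPC terminates, whether it succeeded, failed, or was cancelled.
    if not completed_future.cancelled() and completed_future.exception() is None:
        print(completed_future.result())

future.add_done_callback(on_done)

# Block for at most timeout_secs; raises grpc.FutureTimeoutError if the RPC
# has not completed by then.
timeout_secs = 5.0
result = future.result(timeout=timeout_secs)
```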

In the example, we use this interface to cancel our in-progress RPC when the
user interrupts the process with ctrl-c.

```python
stub = hash_name_pb2_grpc.HashFinderStub(channel)
future = stub.Find.future(hash_name_pb2.HashNameRequest(desired_name=name))
def cancel_request(unused_signum, unused_frame):
    future.cancel()
    sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)

result = future.result()
print(result)
```

We also call `sys.exit(0)` to terminate the process. If we do not do this, then
`future.result()` will throw an `RpcError`. Alternatively, you may catch this
exception.
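
If you would rather catch the exception than exit, a minimal sketch follows.
Note that this is an assumption about error handling rather than part of the
example: depending on timing, a locally cancelled RPC may surface as
`grpc.FutureCancelledError` instead of `grpc.RpcError`, so both are handled.

```python
try:
    print(future.result())
except (grpc.RpcError, grpc.FutureCancelledError):
    # The RPC was cancelled or otherwise terminated without a result.
    print("RPC terminated before producing a result.")
```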

##### Cancelling a Server-Side Streaming RPC from the Client

Cancelling a server-side streaming RPC is even simpler from the perspective of
the gRPC API. The default stub method is already an instance of `grpc.Future`,
so the methods outlined above still apply. It is also a generator, so we may
iterate over it to yield the results of our RPC.

```python
stub = hash_name_pb2_grpc.HashFinderStub(channel)
result_generator = stub.FindRange(hash_name_pb2.HashNameRequest(desired_name=name))
def cancel_request(unused_signum, unused_frame):
    result_generator.cancel()
    sys.exit(0)
signal.signal(signal.SIGINT, cancel_request)
for result in result_generator:
    print(result)
```

We also call `sys.exit(0)` here to terminate the process. Alternatively, you may
catch the `RpcError` raised by the for loop upon cancellation.


#### Cancellation on the Server Side

A server is responsible for cancellation in two ways. It must respond in some
way when a client initiates a cancellation; otherwise, long-running
computations could continue indefinitely.

It may also decide to cancel the RPC for its own reasons. In our example, the
server can be configured to cancel an RPC after a certain number of hashes have
been computed, in order to conserve compute resources.

##### Responding to Cancellations from a Servicer Thread

It's important to remember that a gRPC Python server is backed by a thread pool
with a fixed size. When an RPC is cancelled, the library does *not* terminate
your servicer thread. It is your responsibility as the application author to
ensure that your servicer thread terminates soon after the RPC has been
cancelled.

In this example, we use the `ServicerContext.add_callback` method to set a
`threading.Event` object when the RPC is terminated. We pass this `Event`
object down through our hashing algorithm and check before each iteration that
the RPC is still ongoing.

```python
stop_event = threading.Event()
def on_rpc_done():
    # Regain servicer thread.
    stop_event.set()
context.add_callback(on_rpc_done)
secret = _find_secret(stop_event)
```

##### Initiating a Cancellation on the Server Side

Initiating a cancellation from the server side is simpler. Just call
`ServicerContext.cancel()`.

In our example, we ensure that no single client is monopolizing the server by
cancelling after a configurable number of hashes have been checked.

```python
try:
    for candidate in secret_generator:
        yield candidate
except ResourceLimitExceededError:
    print("Cancelling RPC due to exhausted resources.")
    context.cancel()
```

In this type of situation, you may also consider returning a more specific error
using the [`grpcio-status`](https://pypi.org/project/grpcio-status/) package.
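
For example, here is a sketch of aborting the RPC with a richer status via
`grpcio-status`; the particular status code and message are illustrative
choices, not part of this example:

```python
from google.rpc import code_pb2
from google.rpc import status_pb2
from grpc_status import rpc_status

try:
    for candidate in secret_generator:
        yield candidate
except ResourceLimitExceededError:
    # Terminate the RPC with a machine-readable status instead of a plain cancel.
    context.abort_with_status(
        rpc_status.to_status(
            status_pb2.Status(
                code=code_pb2.RESOURCE_EXHAUSTED,
                message="Exceeded the configured maximum number of hashes.")))
```

`abort_with_status` raises an exception to terminate the RPC handler, so no
explicit `context.cancel()` call is needed afterwards.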

@@ -0,0 +1,104 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of cancelling requests in gRPC."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging
import signal
import sys

import grpc

from examples.python.cancellation import hash_name_pb2
from examples.python.cancellation import hash_name_pb2_grpc

_DESCRIPTION = "A client for finding hashes similar to names."
_LOGGER = logging.getLogger(__name__)


def run_unary_client(server_target, name, ideal_distance):
    with grpc.insecure_channel(server_target) as channel:
        stub = hash_name_pb2_grpc.HashFinderStub(channel)
        future = stub.Find.future(
            hash_name_pb2.HashNameRequest(
                desired_name=name, ideal_hamming_distance=ideal_distance),
            wait_for_ready=True)

        def cancel_request(unused_signum, unused_frame):
            future.cancel()
            sys.exit(0)

        signal.signal(signal.SIGINT, cancel_request)
        result = future.result()
        print(result)


def run_streaming_client(server_target, name, ideal_distance,
                         interesting_distance):
    with grpc.insecure_channel(server_target) as channel:
        stub = hash_name_pb2_grpc.HashFinderStub(channel)
        result_generator = stub.FindRange(
            hash_name_pb2.HashNameRequest(
                desired_name=name,
                ideal_hamming_distance=ideal_distance,
                interesting_hamming_distance=interesting_distance),
            wait_for_ready=True)

        def cancel_request(unused_signum, unused_frame):
            result_generator.cancel()
            sys.exit(0)

        signal.signal(signal.SIGINT, cancel_request)
        for result in result_generator:
            print(result)


def main():
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument("name", type=str, help='The desired name.')
    parser.add_argument(
        "--ideal-distance",
        default=0,
        nargs='?',
        type=int,
        help="The desired Hamming distance.")
    parser.add_argument(
        '--server',
        default='localhost:50051',
        type=str,
        nargs='?',
        help='The host-port pair at which to reach the server.')
    parser.add_argument(
        '--show-inferior',
        default=None,
        type=int,
        nargs='?',
        help='Also show candidates with a Hamming distance less than this value.'
    )

    args = parser.parse_args()
    if args.show_inferior is not None:
        run_streaming_client(args.server, args.name, args.ideal_distance,
                             args.show_inferior)
    else:
        run_unary_client(args.server, args.name, args.ideal_distance)


if __name__ == "__main__":
    logging.basicConfig()
    main()

@@ -0,0 +1,56 @@
// Copyright 2019 the gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package hash_name;

// A request for a single secret whose hash is similar to a desired name.
message HashNameRequest {
  // The string that is desired in the secret's hash.
  string desired_name = 1;

  // The ideal Hamming distance between desired_name and the secret that will
  // be searched for.
  int32 ideal_hamming_distance = 2;

  // A Hamming distance greater than the ideal Hamming distance. Search results
  // with a Hamming distance less than this value but greater than the ideal
  // distance will be returned to the client but will not terminate the
  // search.
  int32 interesting_hamming_distance = 3;
}

message HashNameResponse {
  // The search result.
  string secret = 1;

  // The hash of the search result. A substring of this is of
  // ideal_hamming_distance Hamming distance or less from desired_name.
  string hashed_name = 2;

  // The Hamming distance between hashed_name and desired_name.
  int32 hamming_distance = 3;
}

service HashFinder {

  // Search for a single string whose hash is similar to the specified
  // desired_name. interesting_hamming_distance is ignored.
  rpc Find (HashNameRequest) returns (HashNameResponse) {}

  // Search for a string whose hash is similar to the specified desired_name,
  // but also stream back less-than-ideal candidates.
  rpc FindRange (HashNameRequest) returns (stream HashNameResponse) {}
}

@@ -0,0 +1,148 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A search algorithm over the space of all bytestrings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import hashlib
import itertools
import logging
import struct

from examples.python.cancellation import hash_name_pb2

_LOGGER = logging.getLogger(__name__)
_BYTE_MAX = 255


def _get_hamming_distance(a, b):
    """Calculates the Hamming distance between strings of equal length."""
    distance = 0
    for char_a, char_b in zip(a, b):
        if char_a != char_b:
            distance += 1
    return distance


def _get_substring_hamming_distance(candidate, target):
    """Calculates the minimum Hamming distance between the target
    and any substring of the candidate.

    Args:
        candidate: The string whose substrings will be tested.
        target: The target string.

    Returns:
        The minimum Hamming distance between candidate and target.
    """
    min_distance = None
    if len(target) > len(candidate):
        raise ValueError("Candidate must be at least as long as target.")
    for i in range(len(candidate) - len(target) + 1):
        distance = _get_hamming_distance(candidate[i:i + len(target)].lower(),
                                         target.lower())
        if min_distance is None or distance < min_distance:
            min_distance = distance
    return min_distance


def _get_hash(secret):
    hasher = hashlib.sha1()
    hasher.update(secret)
    return base64.b64encode(hasher.digest()).decode('ascii')


class ResourceLimitExceededError(Exception):
    """Signifies the request has exceeded configured limits."""


def _bytestrings_of_length(length):
    """Generates a stream containing all bytestrings of a given length.

    Args:
        length: A positive integer length.

    Yields:
        All bytestrings of length `length`.
    """
    for digits in itertools.product(range(_BYTE_MAX + 1), repeat=length):
        yield b''.join(struct.pack('B', i) for i in digits)


def _all_bytestrings():
    """Generates a stream containing all possible bytestrings.

    This generator does not terminate.

    Yields:
        All bytestrings in ascending order of length.
    """
    for bytestring in itertools.chain.from_iterable(
            _bytestrings_of_length(length) for length in itertools.count()):
        yield bytestring


def search(target,
           ideal_distance,
           stop_event,
           maximum_hashes,
           interesting_hamming_distance=None):
    """Find candidate strings.

    Search through the space of all bytestrings, in order of increasing length,
    indefinitely, until a hash with a Hamming distance of `ideal_distance` or
    less has been found.

    Args:
        target: The search string.
        ideal_distance: The desired Hamming distance.
        stop_event: An event indicating whether the RPC should terminate.
        maximum_hashes: The maximum number of hashes to check before stopping.
        interesting_hamming_distance: If specified, strings with a Hamming
            distance from the target below this value will be yielded.

    Yields:
        Instances of HashNameResponse. The final entry in the stream will be of
        `ideal_distance` Hamming distance or less from the target string, while
        all others will be of `interesting_hamming_distance` or less.

    Raises:
        ResourceLimitExceededError: If the computation exceeds `maximum_hashes`
            iterations.
    """
    hashes_computed = 0
    for secret in _all_bytestrings():
        if stop_event.is_set():
            return
        candidate_hash = _get_hash(secret)
        distance = _get_substring_hamming_distance(candidate_hash, target)
        if interesting_hamming_distance is not None and distance <= interesting_hamming_distance:
            # Surface interesting candidates, but don't stop.
            yield hash_name_pb2.HashNameResponse(
                secret=base64.b64encode(secret),
                hashed_name=candidate_hash,
                hamming_distance=distance)
        elif distance <= ideal_distance:
            # Yield ideal candidate and end the stream.
            yield hash_name_pb2.HashNameResponse(
                secret=base64.b64encode(secret),
                hashed_name=candidate_hash,
                hamming_distance=distance)
            return
        hashes_computed += 1
        if hashes_computed == maximum_hashes:
            raise ResourceLimitExceededError()

@@ -0,0 +1,134 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of cancelling requests in gRPC."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from concurrent import futures
import argparse
import contextlib
import logging
import time
import threading

import grpc
import search

from examples.python.cancellation import hash_name_pb2
from examples.python.cancellation import hash_name_pb2_grpc

_LOGGER = logging.getLogger(__name__)
_SERVER_HOST = 'localhost'
_ONE_DAY_IN_SECONDS = 60 * 60 * 24

_DESCRIPTION = "A server for finding hashes similar to names."


class HashFinder(hash_name_pb2_grpc.HashFinderServicer):

    def __init__(self, maximum_hashes):
        super(HashFinder, self).__init__()
        self._maximum_hashes = maximum_hashes

    def Find(self, request, context):
        stop_event = threading.Event()

        def on_rpc_done():
            _LOGGER.debug("Attempting to regain servicer thread.")
            stop_event.set()

        context.add_callback(on_rpc_done)
        candidates = []
        try:
            candidates = list(
                search.search(request.desired_name,
                              request.ideal_hamming_distance, stop_event,
                              self._maximum_hashes))
        except search.ResourceLimitExceededError:
            _LOGGER.info("Cancelling RPC due to exhausted resources.")
            context.cancel()
        _LOGGER.debug("Servicer thread returning.")
        if not candidates:
            return hash_name_pb2.HashNameResponse()
        return candidates[-1]

    def FindRange(self, request, context):
        stop_event = threading.Event()

        def on_rpc_done():
            _LOGGER.debug("Attempting to regain servicer thread.")
            stop_event.set()

        context.add_callback(on_rpc_done)
        secret_generator = search.search(
            request.desired_name,
            request.ideal_hamming_distance,
            stop_event,
            self._maximum_hashes,
            interesting_hamming_distance=request.interesting_hamming_distance)
        try:
            for candidate in secret_generator:
                yield candidate
        except search.ResourceLimitExceededError:
            _LOGGER.info("Cancelling RPC due to exhausted resources.")
            context.cancel()
        _LOGGER.debug("Regained servicer thread.")


@contextlib.contextmanager
def _running_server(port, maximum_hashes):
    # We use only a single servicer thread here to demonstrate that, if managed
    # carefully, cancelled RPCs need not continue occupying servicer threads.
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=1), maximum_concurrent_rpcs=1)
    hash_name_pb2_grpc.add_HashFinderServicer_to_server(
        HashFinder(maximum_hashes), server)
    address = '{}:{}'.format(_SERVER_HOST, port)
    actual_port = server.add_insecure_port(address)
    server.start()
    print("Server listening at '{}'".format(address))
    try:
        yield actual_port
    except KeyboardInterrupt:
        pass
    finally:
        server.stop(None)


def main():
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument(
        '--port',
        type=int,
        default=50051,
        nargs='?',
        help='The port on which the server will listen.')
    parser.add_argument(
        '--maximum-hashes',
        type=int,
        default=1000000,
        nargs='?',
        help='The maximum number of hashes to search before cancelling.')
    args = parser.parse_args()
    with _running_server(args.port, args.maximum_hashes):
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)


if __name__ == "__main__":
    logging.basicConfig()
    main()

@@ -0,0 +1,87 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for cancellation example."""

import contextlib
import os
import signal
import socket
import subprocess
import unittest

_BINARY_DIR = os.path.realpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
_SERVER_PATH = os.path.join(_BINARY_DIR, 'server')
_CLIENT_PATH = os.path.join(_BINARY_DIR, 'client')


@contextlib.contextmanager
def _get_port():
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
        raise RuntimeError("Failed to set SO_REUSEPORT.")
    sock.bind(('', 0))
    try:
        yield sock.getsockname()[1]
    finally:
        sock.close()


def _start_client(server_port,
                  desired_string,
                  ideal_distance,
                  interesting_distance=None):
    interesting_distance_args = () if interesting_distance is None else (
        '--show-inferior', str(interesting_distance))
    return subprocess.Popen((_CLIENT_PATH, desired_string, '--server',
                             'localhost:{}'.format(server_port),
                             '--ideal-distance',
                             str(ideal_distance)) + interesting_distance_args)


class CancellationExampleTest(unittest.TestCase):

    def test_successful_run(self):
        with _get_port() as test_port:
            server_process = subprocess.Popen((_SERVER_PATH, '--port',
                                               str(test_port)))
            try:
                client_process = _start_client(test_port, 'aa', 0)
                client_return_code = client_process.wait()
                self.assertEqual(0, client_return_code)
                self.assertIsNone(server_process.poll())
            finally:
                server_process.kill()
                server_process.wait()

    def test_graceful_sigint(self):
        with _get_port() as test_port:
            server_process = subprocess.Popen((_SERVER_PATH, '--port',
                                               str(test_port)))
            try:
                client_process1 = _start_client(test_port, 'aaaaaaaaaa', 0)
                client_process1.send_signal(signal.SIGINT)
                client_process1.wait()
                client_process2 = _start_client(test_port, 'aa', 0)
                client_return_code = client_process2.wait()
                self.assertEqual(0, client_return_code)
                self.assertIsNone(server_process.poll())
            finally:
                server_process.kill()
                server_process.wait()


if __name__ == '__main__':
    unittest.main(verbosity=2)