mirror of https://github.com/grpc/grpc.git
Merge pull request #16264 from ericgribkoff/fork_support_v2
Support gRPC Python client-side fork with epoll1pull/16318/head
commit
2cec9c5344
25 changed files with 1167 additions and 74 deletions
@ -0,0 +1,29 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
|
||||||
|
cdef extern from "pthread.h" nogil: |
||||||
|
int pthread_atfork( |
||||||
|
void (*prepare)() nogil, |
||||||
|
void (*parent)() nogil, |
||||||
|
void (*child)() nogil) |
||||||
|
|
||||||
|
|
||||||
|
cdef void __prefork() nogil |
||||||
|
|
||||||
|
|
||||||
|
cdef void __postfork_parent() nogil |
||||||
|
|
||||||
|
|
||||||
|
cdef void __postfork_child() nogil |
@ -0,0 +1,203 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
|
||||||
|
import logging |
||||||
|
import os |
||||||
|
import threading |
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__) |
||||||
|
|
||||||
|
_AWAIT_THREADS_TIMEOUT_SECONDS = 5 |
||||||
|
|
||||||
|
_TRUE_VALUES = ['yes', 'Yes', 'YES', 'true', 'True', 'TRUE', '1'] |
||||||
|
|
||||||
|
# This flag enables experimental support within gRPC Python for applications |
||||||
|
# that will fork() without exec(). When enabled, gRPC Python will attempt to |
||||||
|
# pause all of its internally created threads before the fork syscall proceeds. |
||||||
|
# |
||||||
|
# For this to be successful, the application must not have multiple threads of |
||||||
|
# its own calling into gRPC when fork is invoked. Any callbacks from gRPC |
||||||
|
# Python-spawned threads into user code (e.g., callbacks for asynchronous RPCs) |
||||||
|
# must not block and should execute quickly. |
||||||
|
# |
||||||
|
# This flag is not supported on Windows. |
||||||
|
_GRPC_ENABLE_FORK_SUPPORT = ( |
||||||
|
os.environ.get('GRPC_ENABLE_FORK_SUPPORT', '0') |
||||||
|
.lower() in _TRUE_VALUES) |
||||||
|
|
||||||
|
_GRPC_POLL_STRATEGY = os.environ.get('GRPC_POLL_STRATEGY') |
||||||
|
|
||||||
|
cdef void __prefork() nogil: |
||||||
|
with gil: |
||||||
|
with _fork_state.fork_in_progress_condition: |
||||||
|
_fork_state.fork_in_progress = True |
||||||
|
if not _fork_state.active_thread_count.await_zero_threads( |
||||||
|
_AWAIT_THREADS_TIMEOUT_SECONDS): |
||||||
|
_LOGGER.error( |
||||||
|
'Failed to shutdown gRPC Python threads prior to fork. ' |
||||||
|
'Behavior after fork will be undefined.') |
||||||
|
|
||||||
|
|
||||||
|
cdef void __postfork_parent() nogil: |
||||||
|
with gil: |
||||||
|
with _fork_state.fork_in_progress_condition: |
||||||
|
_fork_state.fork_in_progress = False |
||||||
|
_fork_state.fork_in_progress_condition.notify_all() |
||||||
|
|
||||||
|
|
||||||
|
cdef void __postfork_child() nogil: |
||||||
|
with gil: |
||||||
|
# Thread could be holding the fork_in_progress_condition inside of |
||||||
|
# block_if_fork_in_progress() when fork occurs. Reset the lock here. |
||||||
|
_fork_state.fork_in_progress_condition = threading.Condition() |
||||||
|
# A thread in return_from_user_request_generator() may hold this lock |
||||||
|
# when fork occurs. |
||||||
|
_fork_state.active_thread_count = _ActiveThreadCount() |
||||||
|
for state_to_reset in _fork_state.postfork_states_to_reset: |
||||||
|
state_to_reset.reset_postfork_child() |
||||||
|
_fork_state.fork_epoch += 1 |
||||||
|
for channel in _fork_state.channels: |
||||||
|
channel._close_on_fork() |
||||||
|
# TODO(ericgribkoff) Check and abort if core is not shutdown |
||||||
|
with _fork_state.fork_in_progress_condition: |
||||||
|
_fork_state.fork_in_progress = False |
||||||
|
|
||||||
|
|
||||||
|
def fork_handlers_and_grpc_init(): |
||||||
|
grpc_init() |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
# TODO(ericgribkoff) epoll1 is default for grpcio distribution. Decide whether to expose |
||||||
|
# grpc_get_poll_strategy_name() from ev_posix.cc to get actual polling choice. |
||||||
|
if _GRPC_POLL_STRATEGY is not None and _GRPC_POLL_STRATEGY != "epoll1": |
||||||
|
_LOGGER.error( |
||||||
|
'gRPC Python fork support is only compatible with the epoll1 ' |
||||||
|
'polling engine') |
||||||
|
return |
||||||
|
with _fork_state.fork_handler_registered_lock: |
||||||
|
if not _fork_state.fork_handler_registered: |
||||||
|
pthread_atfork(&__prefork, &__postfork_parent, &__postfork_child) |
||||||
|
_fork_state.fork_handler_registered = True |
||||||
|
|
||||||
|
|
||||||
|
class ForkManagedThread(object): |
||||||
|
def __init__(self, target, args=()): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
def managed_target(*args): |
||||||
|
try: |
||||||
|
target(*args) |
||||||
|
finally: |
||||||
|
_fork_state.active_thread_count.decrement() |
||||||
|
self._thread = threading.Thread(target=managed_target, args=args) |
||||||
|
else: |
||||||
|
self._thread = threading.Thread(target=target, args=args) |
||||||
|
|
||||||
|
def setDaemon(self, daemonic): |
||||||
|
self._thread.daemon = daemonic |
||||||
|
|
||||||
|
def start(self): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
_fork_state.active_thread_count.increment() |
||||||
|
self._thread.start() |
||||||
|
|
||||||
|
def join(self): |
||||||
|
self._thread.join() |
||||||
|
|
||||||
|
|
||||||
|
def block_if_fork_in_progress(postfork_state_to_reset=None): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
with _fork_state.fork_in_progress_condition: |
||||||
|
if not _fork_state.fork_in_progress: |
||||||
|
return |
||||||
|
if postfork_state_to_reset is not None: |
||||||
|
_fork_state.postfork_states_to_reset.append(postfork_state_to_reset) |
||||||
|
_fork_state.active_thread_count.decrement() |
||||||
|
_fork_state.fork_in_progress_condition.wait() |
||||||
|
_fork_state.active_thread_count.increment() |
||||||
|
|
||||||
|
|
||||||
|
def enter_user_request_generator(): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
_fork_state.active_thread_count.decrement() |
||||||
|
|
||||||
|
|
||||||
|
def return_from_user_request_generator(): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
_fork_state.active_thread_count.increment() |
||||||
|
block_if_fork_in_progress() |
||||||
|
|
||||||
|
|
||||||
|
def get_fork_epoch(): |
||||||
|
return _fork_state.fork_epoch |
||||||
|
|
||||||
|
|
||||||
|
def is_fork_support_enabled(): |
||||||
|
return _GRPC_ENABLE_FORK_SUPPORT |
||||||
|
|
||||||
|
|
||||||
|
def fork_register_channel(channel): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
_fork_state.channels.add(channel) |
||||||
|
|
||||||
|
|
||||||
|
def fork_unregister_channel(channel): |
||||||
|
if _GRPC_ENABLE_FORK_SUPPORT: |
||||||
|
_fork_state.channels.remove(channel) |
||||||
|
|
||||||
|
|
||||||
|
class _ActiveThreadCount(object): |
||||||
|
def __init__(self): |
||||||
|
self._num_active_threads = 0 |
||||||
|
self._condition = threading.Condition() |
||||||
|
|
||||||
|
def increment(self): |
||||||
|
with self._condition: |
||||||
|
self._num_active_threads += 1 |
||||||
|
|
||||||
|
def decrement(self): |
||||||
|
with self._condition: |
||||||
|
self._num_active_threads -= 1 |
||||||
|
if self._num_active_threads == 0: |
||||||
|
self._condition.notify_all() |
||||||
|
|
||||||
|
def await_zero_threads(self, timeout_secs): |
||||||
|
end_time = time.time() + timeout_secs |
||||||
|
wait_time = timeout_secs |
||||||
|
with self._condition: |
||||||
|
while True: |
||||||
|
if self._num_active_threads > 0: |
||||||
|
self._condition.wait(wait_time) |
||||||
|
if self._num_active_threads == 0: |
||||||
|
return True |
||||||
|
# Thread count may have increased before this re-obtains the |
||||||
|
# lock after a notify(). Wait again until timeout_secs has |
||||||
|
# elapsed. |
||||||
|
wait_time = end_time - time.time() |
||||||
|
if wait_time <= 0: |
||||||
|
return False |
||||||
|
|
||||||
|
|
||||||
|
class _ForkState(object): |
||||||
|
def __init__(self): |
||||||
|
self.fork_in_progress_condition = threading.Condition() |
||||||
|
self.fork_in_progress = False |
||||||
|
self.postfork_states_to_reset = [] |
||||||
|
self.fork_handler_registered_lock = threading.Lock() |
||||||
|
self.fork_handler_registered = False |
||||||
|
self.active_thread_count = _ActiveThreadCount() |
||||||
|
self.fork_epoch = 0 |
||||||
|
self.channels = set() |
||||||
|
|
||||||
|
|
||||||
|
_fork_state = _ForkState() |
@ -0,0 +1,63 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
|
||||||
|
import threading |
||||||
|
|
||||||
|
# No-op implementations for Windows. |
||||||
|
|
||||||
|
def fork_handlers_and_grpc_init(): |
||||||
|
grpc_init() |
||||||
|
|
||||||
|
|
||||||
|
class ForkManagedThread(object): |
||||||
|
def __init__(self, target, args=()): |
||||||
|
self._thread = threading.Thread(target=target, args=args) |
||||||
|
|
||||||
|
def setDaemon(self, daemonic): |
||||||
|
self._thread.daemon = daemonic |
||||||
|
|
||||||
|
def start(self): |
||||||
|
self._thread.start() |
||||||
|
|
||||||
|
def join(self): |
||||||
|
self._thread.join() |
||||||
|
|
||||||
|
|
||||||
|
def block_if_fork_in_progress(postfork_state_to_reset=None): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
def enter_user_request_generator(): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
def return_from_user_request_generator(): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
def get_fork_epoch(): |
||||||
|
return 0 |
||||||
|
|
||||||
|
|
||||||
|
def is_fork_support_enabled(): |
||||||
|
return False |
||||||
|
|
||||||
|
|
||||||
|
def fork_register_channel(channel): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
def fork_unregister_channel(channel): |
||||||
|
pass |
@ -0,0 +1,13 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
@ -0,0 +1,76 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
"""The Python implementation of the GRPC interoperability test client.""" |
||||||
|
|
||||||
|
import argparse |
||||||
|
import logging |
||||||
|
import sys |
||||||
|
|
||||||
|
from tests.fork import methods |
||||||
|
|
||||||
|
|
||||||
|
def _args(): |
||||||
|
|
||||||
|
def parse_bool(value): |
||||||
|
if value == 'true': |
||||||
|
return True |
||||||
|
if value == 'false': |
||||||
|
return False |
||||||
|
raise argparse.ArgumentTypeError('Only true/false allowed') |
||||||
|
|
||||||
|
parser = argparse.ArgumentParser() |
||||||
|
parser.add_argument( |
||||||
|
'--server_host', |
||||||
|
default="localhost", |
||||||
|
type=str, |
||||||
|
help='the host to which to connect') |
||||||
|
parser.add_argument( |
||||||
|
'--server_port', |
||||||
|
type=int, |
||||||
|
required=True, |
||||||
|
help='the port to which to connect') |
||||||
|
parser.add_argument( |
||||||
|
'--test_case', |
||||||
|
default='large_unary', |
||||||
|
type=str, |
||||||
|
help='the test case to execute') |
||||||
|
parser.add_argument( |
||||||
|
'--use_tls', |
||||||
|
default=False, |
||||||
|
type=parse_bool, |
||||||
|
help='require a secure connection') |
||||||
|
return parser.parse_args() |
||||||
|
|
||||||
|
|
||||||
|
def _test_case_from_arg(test_case_arg): |
||||||
|
for test_case in methods.TestCase: |
||||||
|
if test_case_arg == test_case.value: |
||||||
|
return test_case |
||||||
|
else: |
||||||
|
raise ValueError('No test case "%s"!' % test_case_arg) |
||||||
|
|
||||||
|
|
||||||
|
def test_fork(): |
||||||
|
logging.basicConfig(level=logging.INFO) |
||||||
|
args = _args() |
||||||
|
if args.test_case == "all": |
||||||
|
for test_case in methods.TestCase: |
||||||
|
test_case.run_test(args) |
||||||
|
else: |
||||||
|
test_case = _test_case_from_arg(args.test_case) |
||||||
|
test_case.run_test(args) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
test_fork() |
@ -0,0 +1,445 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
"""Implementations of fork support test methods.""" |
||||||
|
|
||||||
|
import enum |
||||||
|
import json |
||||||
|
import logging |
||||||
|
import multiprocessing |
||||||
|
import os |
||||||
|
import threading |
||||||
|
import time |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
from six.moves import queue |
||||||
|
|
||||||
|
from src.proto.grpc.testing import empty_pb2 |
||||||
|
from src.proto.grpc.testing import messages_pb2 |
||||||
|
from src.proto.grpc.testing import test_pb2_grpc |
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def _channel(args): |
||||||
|
target = '{}:{}'.format(args.server_host, args.server_port) |
||||||
|
if args.use_tls: |
||||||
|
channel_credentials = grpc.ssl_channel_credentials() |
||||||
|
channel = grpc.secure_channel(target, channel_credentials) |
||||||
|
else: |
||||||
|
channel = grpc.insecure_channel(target) |
||||||
|
return channel |
||||||
|
|
||||||
|
|
||||||
|
def _validate_payload_type_and_length(response, expected_type, expected_length): |
||||||
|
if response.payload.type is not expected_type: |
||||||
|
raise ValueError('expected payload type %s, got %s' % |
||||||
|
(expected_type, type(response.payload.type))) |
||||||
|
elif len(response.payload.body) != expected_length: |
||||||
|
raise ValueError('expected payload body size %d, got %d' % |
||||||
|
(expected_length, len(response.payload.body))) |
||||||
|
|
||||||
|
|
||||||
|
def _async_unary(stub): |
||||||
|
size = 314159 |
||||||
|
request = messages_pb2.SimpleRequest( |
||||||
|
response_type=messages_pb2.COMPRESSABLE, |
||||||
|
response_size=size, |
||||||
|
payload=messages_pb2.Payload(body=b'\x00' * 271828)) |
||||||
|
response_future = stub.UnaryCall.future(request) |
||||||
|
response = response_future.result() |
||||||
|
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) |
||||||
|
|
||||||
|
|
||||||
|
def _blocking_unary(stub): |
||||||
|
size = 314159 |
||||||
|
request = messages_pb2.SimpleRequest( |
||||||
|
response_type=messages_pb2.COMPRESSABLE, |
||||||
|
response_size=size, |
||||||
|
payload=messages_pb2.Payload(body=b'\x00' * 271828)) |
||||||
|
response = stub.UnaryCall(request) |
||||||
|
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) |
||||||
|
|
||||||
|
|
||||||
|
class _Pipe(object): |
||||||
|
|
||||||
|
def __init__(self): |
||||||
|
self._condition = threading.Condition() |
||||||
|
self._values = [] |
||||||
|
self._open = True |
||||||
|
|
||||||
|
def __iter__(self): |
||||||
|
return self |
||||||
|
|
||||||
|
def __next__(self): |
||||||
|
return self.next() |
||||||
|
|
||||||
|
def next(self): |
||||||
|
with self._condition: |
||||||
|
while not self._values and self._open: |
||||||
|
self._condition.wait() |
||||||
|
if self._values: |
||||||
|
return self._values.pop(0) |
||||||
|
else: |
||||||
|
raise StopIteration() |
||||||
|
|
||||||
|
def add(self, value): |
||||||
|
with self._condition: |
||||||
|
self._values.append(value) |
||||||
|
self._condition.notify() |
||||||
|
|
||||||
|
def close(self): |
||||||
|
with self._condition: |
||||||
|
self._open = False |
||||||
|
self._condition.notify() |
||||||
|
|
||||||
|
def __enter__(self): |
||||||
|
return self |
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback): |
||||||
|
self.close() |
||||||
|
|
||||||
|
|
||||||
|
class _ChildProcess(object): |
||||||
|
|
||||||
|
def __init__(self, task, args=None): |
||||||
|
if args is None: |
||||||
|
args = () |
||||||
|
self._exceptions = multiprocessing.Queue() |
||||||
|
|
||||||
|
def record_exceptions(): |
||||||
|
try: |
||||||
|
task(*args) |
||||||
|
except Exception as e: # pylint: disable=broad-except |
||||||
|
self._exceptions.put(e) |
||||||
|
|
||||||
|
self._process = multiprocessing.Process(target=record_exceptions) |
||||||
|
|
||||||
|
def start(self): |
||||||
|
self._process.start() |
||||||
|
|
||||||
|
def finish(self): |
||||||
|
self._process.join() |
||||||
|
if self._process.exitcode != 0: |
||||||
|
raise ValueError('Child process failed with exitcode %d' % |
||||||
|
self._process.exitcode) |
||||||
|
try: |
||||||
|
exception = self._exceptions.get(block=False) |
||||||
|
raise ValueError('Child process failed: %s' % exception) |
||||||
|
except queue.Empty: |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
def _async_unary_same_channel(channel): |
||||||
|
|
||||||
|
def child_target(): |
||||||
|
try: |
||||||
|
_async_unary(stub) |
||||||
|
raise Exception( |
||||||
|
'Child should not be able to re-use channel after fork') |
||||||
|
except ValueError as expected_value_error: |
||||||
|
pass |
||||||
|
|
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_async_unary(stub) |
||||||
|
child_process = _ChildProcess(child_target) |
||||||
|
child_process.start() |
||||||
|
_async_unary(stub) |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
|
||||||
|
def _async_unary_new_channel(channel, args): |
||||||
|
|
||||||
|
def child_target(): |
||||||
|
child_channel = _channel(args) |
||||||
|
child_stub = test_pb2_grpc.TestServiceStub(child_channel) |
||||||
|
_async_unary(child_stub) |
||||||
|
child_channel.close() |
||||||
|
|
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_async_unary(stub) |
||||||
|
child_process = _ChildProcess(child_target) |
||||||
|
child_process.start() |
||||||
|
_async_unary(stub) |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
|
||||||
|
def _blocking_unary_same_channel(channel): |
||||||
|
|
||||||
|
def child_target(): |
||||||
|
try: |
||||||
|
_blocking_unary(stub) |
||||||
|
raise Exception( |
||||||
|
'Child should not be able to re-use channel after fork') |
||||||
|
except ValueError as expected_value_error: |
||||||
|
pass |
||||||
|
|
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_blocking_unary(stub) |
||||||
|
child_process = _ChildProcess(child_target) |
||||||
|
child_process.start() |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
|
||||||
|
def _blocking_unary_new_channel(channel, args): |
||||||
|
|
||||||
|
def child_target(): |
||||||
|
child_channel = _channel(args) |
||||||
|
child_stub = test_pb2_grpc.TestServiceStub(child_channel) |
||||||
|
_blocking_unary(child_stub) |
||||||
|
child_channel.close() |
||||||
|
|
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_blocking_unary(stub) |
||||||
|
child_process = _ChildProcess(child_target) |
||||||
|
child_process.start() |
||||||
|
_blocking_unary(stub) |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
|
||||||
|
# Verify that the fork channel registry can handle already closed channels |
||||||
|
def _close_channel_before_fork(channel, args): |
||||||
|
|
||||||
|
def child_target(): |
||||||
|
new_channel.close() |
||||||
|
child_channel = _channel(args) |
||||||
|
child_stub = test_pb2_grpc.TestServiceStub(child_channel) |
||||||
|
_blocking_unary(child_stub) |
||||||
|
child_channel.close() |
||||||
|
|
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_blocking_unary(stub) |
||||||
|
channel.close() |
||||||
|
|
||||||
|
new_channel = _channel(args) |
||||||
|
new_stub = test_pb2_grpc.TestServiceStub(new_channel) |
||||||
|
child_process = _ChildProcess(child_target) |
||||||
|
child_process.start() |
||||||
|
_blocking_unary(new_stub) |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
|
||||||
|
def _connectivity_watch(channel, args): |
||||||
|
|
||||||
|
def child_target(): |
||||||
|
|
||||||
|
def child_connectivity_callback(state): |
||||||
|
child_states.append(state) |
||||||
|
|
||||||
|
child_states = [] |
||||||
|
child_channel = _channel(args) |
||||||
|
child_stub = test_pb2_grpc.TestServiceStub(child_channel) |
||||||
|
child_channel.subscribe(child_connectivity_callback) |
||||||
|
_async_unary(child_stub) |
||||||
|
if len(child_states |
||||||
|
) < 2 or child_states[-1] != grpc.ChannelConnectivity.READY: |
||||||
|
raise ValueError('Channel did not move to READY') |
||||||
|
if len(parent_states) > 1: |
||||||
|
raise ValueError('Received connectivity updates on parent callback') |
||||||
|
child_channel.unsubscribe(child_connectivity_callback) |
||||||
|
child_channel.close() |
||||||
|
|
||||||
|
def parent_connectivity_callback(state): |
||||||
|
parent_states.append(state) |
||||||
|
|
||||||
|
parent_states = [] |
||||||
|
channel.subscribe(parent_connectivity_callback) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
child_process = _ChildProcess(child_target) |
||||||
|
child_process.start() |
||||||
|
_async_unary(stub) |
||||||
|
if len(parent_states |
||||||
|
) < 2 or parent_states[-1] != grpc.ChannelConnectivity.READY: |
||||||
|
raise ValueError('Channel did not move to READY') |
||||||
|
channel.unsubscribe(parent_connectivity_callback) |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
# Need to unsubscribe or _channel.py in _poll_connectivity triggers a |
||||||
|
# "Cannot invoke RPC on closed channel!" error. |
||||||
|
# TODO(ericgribkoff) Fix issue with channel.close() and connectivity polling |
||||||
|
channel.unsubscribe(parent_connectivity_callback) |
||||||
|
|
||||||
|
|
||||||
|
def _ping_pong_with_child_processes_after_first_response( |
||||||
|
channel, args, child_target, run_after_close=True): |
||||||
|
request_response_sizes = ( |
||||||
|
31415, |
||||||
|
9, |
||||||
|
2653, |
||||||
|
58979, |
||||||
|
) |
||||||
|
request_payload_sizes = ( |
||||||
|
27182, |
||||||
|
8, |
||||||
|
1828, |
||||||
|
45904, |
||||||
|
) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
pipe = _Pipe() |
||||||
|
parent_bidi_call = stub.FullDuplexCall(pipe) |
||||||
|
child_processes = [] |
||||||
|
first_message_received = False |
||||||
|
for response_size, payload_size in zip(request_response_sizes, |
||||||
|
request_payload_sizes): |
||||||
|
request = messages_pb2.StreamingOutputCallRequest( |
||||||
|
response_type=messages_pb2.COMPRESSABLE, |
||||||
|
response_parameters=( |
||||||
|
messages_pb2.ResponseParameters(size=response_size),), |
||||||
|
payload=messages_pb2.Payload(body=b'\x00' * payload_size)) |
||||||
|
pipe.add(request) |
||||||
|
if first_message_received: |
||||||
|
child_process = _ChildProcess(child_target, |
||||||
|
(parent_bidi_call, channel, args)) |
||||||
|
child_process.start() |
||||||
|
child_processes.append(child_process) |
||||||
|
response = next(parent_bidi_call) |
||||||
|
first_message_received = True |
||||||
|
child_process = _ChildProcess(child_target, |
||||||
|
(parent_bidi_call, channel, args)) |
||||||
|
child_process.start() |
||||||
|
child_processes.append(child_process) |
||||||
|
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, |
||||||
|
response_size) |
||||||
|
pipe.close() |
||||||
|
if run_after_close: |
||||||
|
child_process = _ChildProcess(child_target, |
||||||
|
(parent_bidi_call, channel, args)) |
||||||
|
child_process.start() |
||||||
|
child_processes.append(child_process) |
||||||
|
for child_process in child_processes: |
||||||
|
child_process.finish() |
||||||
|
|
||||||
|
|
||||||
|
def _in_progress_bidi_continue_call(channel): |
||||||
|
|
||||||
|
def child_target(parent_bidi_call, parent_channel, args): |
||||||
|
stub = test_pb2_grpc.TestServiceStub(parent_channel) |
||||||
|
try: |
||||||
|
_async_unary(stub) |
||||||
|
raise Exception( |
||||||
|
'Child should not be able to re-use channel after fork') |
||||||
|
except ValueError as expected_value_error: |
||||||
|
pass |
||||||
|
inherited_code = parent_bidi_call.code() |
||||||
|
inherited_details = parent_bidi_call.details() |
||||||
|
if inherited_code != grpc.StatusCode.CANCELLED: |
||||||
|
raise ValueError( |
||||||
|
'Expected inherited code CANCELLED, got %s' % inherited_code) |
||||||
|
if inherited_details != 'Channel closed due to fork': |
||||||
|
raise ValueError( |
||||||
|
'Expected inherited details Channel closed due to fork, got %s' |
||||||
|
% inherited_details) |
||||||
|
|
||||||
|
# Don't run child_target after closing the parent call, as the call may have |
||||||
|
# received a status from the server before fork occurs. |
||||||
|
_ping_pong_with_child_processes_after_first_response( |
||||||
|
channel, None, child_target, run_after_close=False) |
||||||
|
|
||||||
|
|
||||||
|
def _in_progress_bidi_same_channel_async_call(channel): |
||||||
|
|
||||||
|
def child_target(parent_bidi_call, parent_channel, args): |
||||||
|
stub = test_pb2_grpc.TestServiceStub(parent_channel) |
||||||
|
try: |
||||||
|
_async_unary(stub) |
||||||
|
raise Exception( |
||||||
|
'Child should not be able to re-use channel after fork') |
||||||
|
except ValueError as expected_value_error: |
||||||
|
pass |
||||||
|
|
||||||
|
_ping_pong_with_child_processes_after_first_response( |
||||||
|
channel, None, child_target) |
||||||
|
|
||||||
|
|
||||||
|
def _in_progress_bidi_same_channel_blocking_call(channel): |
||||||
|
|
||||||
|
def child_target(parent_bidi_call, parent_channel, args): |
||||||
|
stub = test_pb2_grpc.TestServiceStub(parent_channel) |
||||||
|
try: |
||||||
|
_blocking_unary(stub) |
||||||
|
raise Exception( |
||||||
|
'Child should not be able to re-use channel after fork') |
||||||
|
except ValueError as expected_value_error: |
||||||
|
pass |
||||||
|
|
||||||
|
_ping_pong_with_child_processes_after_first_response( |
||||||
|
channel, None, child_target) |
||||||
|
|
||||||
|
|
||||||
|
def _in_progress_bidi_new_channel_async_call(channel, args): |
||||||
|
|
||||||
|
def child_target(parent_bidi_call, parent_channel, args): |
||||||
|
channel = _channel(args) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_async_unary(stub) |
||||||
|
|
||||||
|
_ping_pong_with_child_processes_after_first_response( |
||||||
|
channel, args, child_target) |
||||||
|
|
||||||
|
|
||||||
|
def _in_progress_bidi_new_channel_blocking_call(channel, args): |
||||||
|
|
||||||
|
def child_target(parent_bidi_call, parent_channel, args): |
||||||
|
channel = _channel(args) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
_blocking_unary(stub) |
||||||
|
|
||||||
|
_ping_pong_with_child_processes_after_first_response( |
||||||
|
channel, args, child_target) |
||||||
|
|
||||||
|
|
||||||
|
@enum.unique |
||||||
|
class TestCase(enum.Enum): |
||||||
|
|
||||||
|
CONNECTIVITY_WATCH = 'connectivity_watch' |
||||||
|
CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork' |
||||||
|
ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel' |
||||||
|
ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel' |
||||||
|
BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel' |
||||||
|
BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel' |
||||||
|
IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call' |
||||||
|
IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call' |
||||||
|
IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call' |
||||||
|
IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call' |
||||||
|
IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call' |
||||||
|
|
||||||
|
def run_test(self, args): |
||||||
|
_LOGGER.info("Running %s", self) |
||||||
|
channel = _channel(args) |
||||||
|
if self is TestCase.ASYNC_UNARY_SAME_CHANNEL: |
||||||
|
_async_unary_same_channel(channel) |
||||||
|
elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL: |
||||||
|
_async_unary_new_channel(channel, args) |
||||||
|
elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL: |
||||||
|
_blocking_unary_same_channel(channel) |
||||||
|
elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL: |
||||||
|
_blocking_unary_new_channel(channel, args) |
||||||
|
elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK: |
||||||
|
_close_channel_before_fork(channel, args) |
||||||
|
elif self is TestCase.CONNECTIVITY_WATCH: |
||||||
|
_connectivity_watch(channel, args) |
||||||
|
elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL: |
||||||
|
_in_progress_bidi_continue_call(channel) |
||||||
|
elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL: |
||||||
|
_in_progress_bidi_same_channel_async_call(channel) |
||||||
|
elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL: |
||||||
|
_in_progress_bidi_same_channel_blocking_call(channel) |
||||||
|
elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL: |
||||||
|
_in_progress_bidi_new_channel_async_call(channel, args) |
||||||
|
elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL: |
||||||
|
_in_progress_bidi_new_channel_blocking_call(channel, args) |
||||||
|
else: |
||||||
|
raise NotImplementedError( |
||||||
|
'Test case "%s" not implemented!' % self.name) |
||||||
|
channel.close() |
@ -0,0 +1,68 @@ |
|||||||
|
# Copyright 2018 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import os |
||||||
|
import threading |
||||||
|
import unittest |
||||||
|
|
||||||
|
from grpc._cython import cygrpc |
||||||
|
|
||||||
|
|
||||||
|
def _get_number_active_threads(): |
||||||
|
return cygrpc._fork_state.active_thread_count._num_active_threads |
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(os.name == 'nt', 'Posix-specific tests') |
||||||
|
class ForkPosixTester(unittest.TestCase): |
||||||
|
|
||||||
|
def setUp(self): |
||||||
|
cygrpc._GRPC_ENABLE_FORK_SUPPORT = True |
||||||
|
|
||||||
|
def testForkManagedThread(self): |
||||||
|
|
||||||
|
def cb(): |
||||||
|
self.assertEqual(1, _get_number_active_threads()) |
||||||
|
|
||||||
|
thread = cygrpc.ForkManagedThread(cb) |
||||||
|
thread.start() |
||||||
|
thread.join() |
||||||
|
self.assertEqual(0, _get_number_active_threads()) |
||||||
|
|
||||||
|
def testForkManagedThreadThrowsException(self): |
||||||
|
|
||||||
|
def cb(): |
||||||
|
self.assertEqual(1, _get_number_active_threads()) |
||||||
|
raise Exception("expected exception") |
||||||
|
|
||||||
|
thread = cygrpc.ForkManagedThread(cb) |
||||||
|
thread.start() |
||||||
|
thread.join() |
||||||
|
self.assertEqual(0, _get_number_active_threads()) |
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests') |
||||||
|
class ForkWindowsTester(unittest.TestCase): |
||||||
|
|
||||||
|
def testForkManagedThreadIsNoOp(self): |
||||||
|
|
||||||
|
def cb(): |
||||||
|
pass |
||||||
|
|
||||||
|
thread = cygrpc.ForkManagedThread(cb) |
||||||
|
thread.start() |
||||||
|
thread.join() |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
unittest.main(verbosity=2) |
Loading…
Reference in new issue