Support gRPC Python client-side fork with epoll1

A process may fork after invoking grpc_init() and use gRPC in the child
if and only if the child process first destroys all gRPC resources
inherited from the parent process and invokes grpc_shutdown().
Subsequent to this, the child will be able to re-initialize and use
gRPC. After fork, the parent process will be able to continue to use
existing gRPC resources such as channels and calls without interference
from the child process.

To facilitate gRPC Python applications meeting the above constraints,
gRPC Python will automatically destroy and shutdown all gRPC Core
resources in the child's post-fork handler, including cancelling
in-flight calls (see detailed design below). From the client's
perspective, the child process is now free to create new channels and
use gRPC.
pull/16264/head
Eric Gribkoff 7 years ago
parent f10596f0f3
commit f8cf7ee56d
  1. 2
      .pylintrc-tests
  2. 73
      src/core/lib/gprpp/fork.cc
  3. 17
      src/core/lib/gprpp/fork.h
  4. 72
      src/core/lib/iomgr/ev_epoll1_linux.cc
  5. 5
      src/core/lib/iomgr/fork_posix.cc
  6. 59
      src/python/grpcio/grpc/_channel.py
  7. 2
      src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
  8. 1
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
  9. 49
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  10. 2
      src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
  11. 8
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
  12. 29
      src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pxd.pxi
  13. 203
      src/python/grpcio/grpc/_cython/_cygrpc/fork_posix.pyx.pxi
  14. 63
      src/python/grpcio/grpc/_cython/_cygrpc/fork_windows.pyx.pxi
  15. 2
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
  16. 2
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  17. 3
      src/python/grpcio/grpc/_cython/cygrpc.pxd
  18. 5
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  19. 25
      src/python/grpcio_tests/commands.py
  20. 1
      src/python/grpcio_tests/setup.py
  21. 13
      src/python/grpcio_tests/tests/fork/__init__.py
  22. 76
      src/python/grpcio_tests/tests/fork/client.py
  23. 445
      src/python/grpcio_tests/tests/fork/methods.py
  24. 2
      src/python/grpcio_tests/tests/tests.json
  25. 68
      src/python/grpcio_tests/tests/unit/_cython/_fork_test.py

@ -20,6 +20,8 @@ notes=FIXME,XXX
[MESSAGES CONTROL]
extension-pkg-whitelist=grpc._cython.cygrpc
disable=
# These suppressions are specific to tests:
#

@ -157,11 +157,11 @@ class ThreadState {
} // namespace
void Fork::GlobalInit() {
if (!overrideEnabled_) {
if (!override_enabled_) {
#ifdef GRPC_ENABLE_FORK_SUPPORT
supportEnabled_ = true;
support_enabled_ = true;
#else
supportEnabled_ = false;
support_enabled_ = false;
#endif
bool env_var_set = false;
char* env = gpr_getenv("GRPC_ENABLE_FORK_SUPPORT");
@ -172,7 +172,7 @@ void Fork::GlobalInit() {
"False", "FALSE", "0"};
for (size_t i = 0; i < GPR_ARRAY_SIZE(truthy); i++) {
if (0 == strcmp(env, truthy[i])) {
supportEnabled_ = true;
support_enabled_ = true;
env_var_set = true;
break;
}
@ -180,7 +180,7 @@ void Fork::GlobalInit() {
if (!env_var_set) {
for (size_t i = 0; i < GPR_ARRAY_SIZE(falsey); i++) {
if (0 == strcmp(env, falsey[i])) {
supportEnabled_ = false;
support_enabled_ = false;
env_var_set = true;
break;
}
@ -189,72 +189,79 @@ void Fork::GlobalInit() {
gpr_free(env);
}
}
if (supportEnabled_) {
execCtxState_ = grpc_core::New<internal::ExecCtxState>();
threadState_ = grpc_core::New<internal::ThreadState>();
if (support_enabled_) {
exec_ctx_state_ = grpc_core::New<internal::ExecCtxState>();
thread_state_ = grpc_core::New<internal::ThreadState>();
}
}
void Fork::GlobalShutdown() {
if (supportEnabled_) {
grpc_core::Delete(execCtxState_);
grpc_core::Delete(threadState_);
if (support_enabled_) {
grpc_core::Delete(exec_ctx_state_);
grpc_core::Delete(thread_state_);
}
}
bool Fork::Enabled() { return supportEnabled_; }
bool Fork::Enabled() { return support_enabled_; }
// Testing Only
void Fork::Enable(bool enable) {
overrideEnabled_ = true;
supportEnabled_ = enable;
override_enabled_ = true;
support_enabled_ = enable;
}
void Fork::IncExecCtxCount() {
if (supportEnabled_) {
execCtxState_->IncExecCtxCount();
if (support_enabled_) {
exec_ctx_state_->IncExecCtxCount();
}
}
void Fork::DecExecCtxCount() {
if (supportEnabled_) {
execCtxState_->DecExecCtxCount();
if (support_enabled_) {
exec_ctx_state_->DecExecCtxCount();
}
}
void Fork::SetResetChildPollingEngineFunc(Fork::child_postfork_func func) {
reset_child_polling_engine_ = func;
}
Fork::child_postfork_func Fork::GetResetChildPollingEngineFunc() {
return reset_child_polling_engine_;
}
bool Fork::BlockExecCtx() {
if (supportEnabled_) {
return execCtxState_->BlockExecCtx();
if (support_enabled_) {
return exec_ctx_state_->BlockExecCtx();
}
return false;
}
void Fork::AllowExecCtx() {
if (supportEnabled_) {
execCtxState_->AllowExecCtx();
if (support_enabled_) {
exec_ctx_state_->AllowExecCtx();
}
}
void Fork::IncThreadCount() {
if (supportEnabled_) {
threadState_->IncThreadCount();
if (support_enabled_) {
thread_state_->IncThreadCount();
}
}
void Fork::DecThreadCount() {
if (supportEnabled_) {
threadState_->DecThreadCount();
if (support_enabled_) {
thread_state_->DecThreadCount();
}
}
void Fork::AwaitThreads() {
if (supportEnabled_) {
threadState_->AwaitThreads();
if (support_enabled_) {
thread_state_->AwaitThreads();
}
}
internal::ExecCtxState* Fork::execCtxState_ = nullptr;
internal::ThreadState* Fork::threadState_ = nullptr;
bool Fork::supportEnabled_ = false;
bool Fork::overrideEnabled_ = false;
internal::ExecCtxState* Fork::exec_ctx_state_ = nullptr;
internal::ThreadState* Fork::thread_state_ = nullptr;
bool Fork::support_enabled_ = false;
bool Fork::override_enabled_ = false;
Fork::child_postfork_func Fork::reset_child_polling_engine_ = nullptr;
} // namespace grpc_core

@ -33,6 +33,8 @@ class ThreadState;
class Fork {
public:
typedef void (*child_postfork_func)(void);
static void GlobalInit();
static void GlobalShutdown();
@ -46,6 +48,12 @@ class Fork {
// Decrement the count of active ExecCtxs
static void DecExecCtxCount();
// Provide a function that will be invoked in the child's postfork handler to
// reset the polling engine's internal state.
static void SetResetChildPollingEngineFunc(
child_postfork_func reset_child_polling_engine);
static child_postfork_func GetResetChildPollingEngineFunc();
// Check if there is a single active ExecCtx
// (the one used to invoke this function). If there are more,
// return false. Otherwise, return true and block creation of
@ -68,10 +76,11 @@ class Fork {
static void Enable(bool enable);
private:
static internal::ExecCtxState* execCtxState_;
static internal::ThreadState* threadState_;
static bool supportEnabled_;
static bool overrideEnabled_;
static internal::ExecCtxState* exec_ctx_state_;
static internal::ThreadState* thread_state_;
static bool support_enabled_;
static bool override_enabled_;
static child_postfork_func reset_child_polling_engine_;
};
} // namespace grpc_core

@ -131,6 +131,13 @@ static void epoll_set_shutdown() {
* Fd Declarations
*/
/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */
struct grpc_fork_fd_list {
grpc_fd* fd;
grpc_fd* next;
grpc_fd* prev;
};
struct grpc_fd {
int fd;
@ -141,6 +148,9 @@ struct grpc_fd {
struct grpc_fd* freelist_next;
grpc_iomgr_object iomgr_object;
/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */
grpc_fork_fd_list* fork_fd_list;
};
static void fd_global_init(void);
@ -256,6 +266,10 @@ static bool append_error(grpc_error** composite, grpc_error* error,
static grpc_fd* fd_freelist = nullptr;
static gpr_mu fd_freelist_mu;
/* Only used when GRPC_ENABLE_FORK_SUPPORT=1 */
static grpc_fd* fork_fd_list_head = nullptr;
static gpr_mu fork_fd_list_mu;
static void fd_global_init(void) { gpr_mu_init(&fd_freelist_mu); }
static void fd_global_shutdown(void) {
@ -269,6 +283,38 @@ static void fd_global_shutdown(void) {
gpr_mu_destroy(&fd_freelist_mu);
}
static void fork_fd_list_add_grpc_fd(grpc_fd* fd) {
if (grpc_core::Fork::Enabled()) {
gpr_mu_lock(&fork_fd_list_mu);
fd->fork_fd_list =
static_cast<grpc_fork_fd_list*>(gpr_malloc(sizeof(grpc_fork_fd_list)));
fd->fork_fd_list->next = fork_fd_list_head;
fd->fork_fd_list->prev = nullptr;
if (fork_fd_list_head != nullptr) {
fork_fd_list_head->fork_fd_list->prev = fd;
}
fork_fd_list_head = fd;
gpr_mu_unlock(&fork_fd_list_mu);
}
}
static void fork_fd_list_remove_grpc_fd(grpc_fd* fd) {
if (grpc_core::Fork::Enabled()) {
gpr_mu_lock(&fork_fd_list_mu);
if (fork_fd_list_head == fd) {
fork_fd_list_head = fd->fork_fd_list->next;
}
if (fd->fork_fd_list->prev != nullptr) {
fd->fork_fd_list->prev->fork_fd_list->next = fd->fork_fd_list->next;
}
if (fd->fork_fd_list->next != nullptr) {
fd->fork_fd_list->next->fork_fd_list->prev = fd->fork_fd_list->prev;
}
gpr_free(fd->fork_fd_list);
gpr_mu_unlock(&fork_fd_list_mu);
}
}
static grpc_fd* fd_create(int fd, const char* name, bool track_err) {
grpc_fd* new_fd = nullptr;
@ -295,6 +341,7 @@ static grpc_fd* fd_create(int fd, const char* name, bool track_err) {
char* fd_name;
gpr_asprintf(&fd_name, "%s fd=%d", name, fd);
grpc_iomgr_register_object(&new_fd->iomgr_object, fd_name);
fork_fd_list_add_grpc_fd(new_fd);
#ifndef NDEBUG
if (grpc_trace_fd_refcount.enabled()) {
gpr_log(GPR_DEBUG, "FD %d %p create %s", fd, new_fd, fd_name);
@ -361,6 +408,7 @@ static void fd_orphan(grpc_fd* fd, grpc_closure* on_done, int* release_fd,
GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_REF(error));
grpc_iomgr_unregister_object(&fd->iomgr_object);
fork_fd_list_remove_grpc_fd(fd);
fd->read_closure->DestroyEvent();
fd->write_closure->DestroyEvent();
fd->error_closure->DestroyEvent();
@ -1190,6 +1238,10 @@ static void shutdown_engine(void) {
fd_global_shutdown();
pollset_global_shutdown();
epoll_set_shutdown();
if (grpc_core::Fork::Enabled()) {
gpr_mu_destroy(&fork_fd_list_mu);
grpc_core::Fork::SetResetChildPollingEngineFunc(nullptr);
}
}
static const grpc_event_engine_vtable vtable = {
@ -1227,6 +1279,21 @@ static const grpc_event_engine_vtable vtable = {
shutdown_engine,
};
/* Called by the child process's post-fork handler to close open fds, including
* the global epoll fd. This allows gRPC to shutdown in the child process
* without interfering with connections or RPCs ongoing in the parent. */
static void reset_event_manager_on_fork() {
gpr_mu_lock(&fork_fd_list_mu);
while (fork_fd_list_head != nullptr) {
close(fork_fd_list_head->fd);
fork_fd_list_head->fd = -1;
fork_fd_list_head = fork_fd_list_head->fork_fd_list->next;
}
gpr_mu_unlock(&fork_fd_list_mu);
shutdown_engine();
grpc_init_epoll1_linux(true);
}
/* It is possible that GLIBC has epoll but the underlying kernel doesn't.
* Create epoll_fd (epoll_set_init() takes care of that) to make sure epoll
* support is available */
@ -1248,6 +1315,11 @@ const grpc_event_engine_vtable* grpc_init_epoll1_linux(bool explicit_request) {
return nullptr;
}
if (grpc_core::Fork::Enabled()) {
gpr_mu_init(&fork_fd_list_mu);
grpc_core::Fork::SetResetChildPollingEngineFunc(
reset_event_manager_on_fork);
}
return &vtable;
}

@ -84,6 +84,11 @@ void grpc_postfork_child() {
if (!skipped_handler) {
grpc_core::Fork::AllowExecCtx();
grpc_core::ExecCtx exec_ctx;
grpc_core::Fork::child_postfork_func reset_polling_engine =
grpc_core::Fork::GetResetChildPollingEngineFunc();
if (reset_polling_engine != nullptr) {
reset_polling_engine();
}
grpc_timer_manager_set_threading(true);
grpc_executor_set_threading(true);
}

@ -111,6 +111,10 @@ class _RPCState(object):
# prior to termination of the RPC.
self.cancelled = False
self.callbacks = []
self.fork_epoch = cygrpc.get_fork_epoch()
def reset_postfork_child(self):
self.condition = threading.Condition()
def _abort(state, code, details):
@ -166,21 +170,30 @@ def _event_handler(state, response_deserializer):
done = not state.due
for callback in callbacks:
callback()
return done
return done and state.fork_epoch >= cygrpc.get_fork_epoch()
return handle_event
def _consume_request_iterator(request_iterator, state, call, request_serializer,
event_handler):
if cygrpc.is_fork_support_enabled():
condition_wait_timeout = 1.0
else:
condition_wait_timeout = None
def consume_request_iterator(): # pylint: disable=too-many-branches
while True:
return_from_user_request_generator_invoked = False
try:
# The thread may die in user-code. Do not block fork for this.
cygrpc.enter_user_request_generator()
request = next(request_iterator)
except StopIteration:
break
except Exception: # pylint: disable=broad-except
cygrpc.return_from_user_request_generator()
return_from_user_request_generator_invoked = True
code = grpc.StatusCode.UNKNOWN
details = 'Exception iterating requests!'
_LOGGER.exception(details)
@ -188,6 +201,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
details)
_abort(state, code, details)
return
finally:
if not return_from_user_request_generator_invoked:
cygrpc.return_from_user_request_generator()
serialized_request = _common.serialize(request, request_serializer)
with state.condition:
if state.code is None and not state.cancelled:
@ -208,7 +224,8 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
else:
return
while True:
state.condition.wait()
state.condition.wait(condition_wait_timeout)
cygrpc.block_if_fork_in_progress(state)
if state.code is None:
if cygrpc.OperationType.send_message not in state.due:
break
@ -224,8 +241,9 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
if operating:
state.due.add(cygrpc.OperationType.send_close_from_client)
consumption_thread = threading.Thread(target=consume_request_iterator)
consumption_thread.daemon = True
consumption_thread = cygrpc.ForkManagedThread(
target=consume_request_iterator)
consumption_thread.setDaemon(True)
consumption_thread.start()
@ -671,13 +689,20 @@ class _ChannelCallState(object):
self.lock = threading.Lock()
self.channel = channel
self.managed_calls = 0
self.threading = False
def reset_postfork_child(self):
self.managed_calls = 0
def _run_channel_spin_thread(state):
def channel_spin():
while True:
cygrpc.block_if_fork_in_progress(state)
event = state.channel.next_call_event()
if event.completion_type == cygrpc.CompletionType.queue_timeout:
continue
call_completed = event.tag(event)
if call_completed:
with state.lock:
@ -685,8 +710,8 @@ def _run_channel_spin_thread(state):
if state.managed_calls == 0:
return
channel_spin_thread = threading.Thread(target=channel_spin)
channel_spin_thread.daemon = True
channel_spin_thread = cygrpc.ForkManagedThread(target=channel_spin)
channel_spin_thread.setDaemon(True)
channel_spin_thread.start()
@ -742,6 +767,13 @@ class _ChannelConnectivityState(object):
self.callbacks_and_connectivities = []
self.delivering = False
def reset_postfork_child(self):
self.polling = False
self.connectivity = None
self.try_to_connect = False
self.callbacks_and_connectivities = []
self.delivering = False
def _deliveries(state):
callbacks_needing_update = []
@ -758,6 +790,7 @@ def _deliver(state, initial_connectivity, initial_callbacks):
callbacks = initial_callbacks
while True:
for callback in callbacks:
cygrpc.block_if_fork_in_progress(state)
callable_util.call_logging_exceptions(
callback, _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE,
connectivity)
@ -771,7 +804,7 @@ def _deliver(state, initial_connectivity, initial_callbacks):
def _spawn_delivery(state, callbacks):
delivering_thread = threading.Thread(
delivering_thread = cygrpc.ForkManagedThread(
target=_deliver, args=(
state,
state.connectivity,
@ -799,6 +832,7 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
while True:
event = channel.watch_connectivity_state(connectivity,
time.time() + 0.2)
cygrpc.block_if_fork_in_progress(state)
with state.lock:
if not state.callbacks_and_connectivities and not state.try_to_connect:
state.polling = False
@ -826,10 +860,10 @@ def _moot(state):
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
polling_thread = threading.Thread(
polling_thread = cygrpc.ForkManagedThread(
target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
polling_thread.daemon = True
polling_thread.setDaemon(True)
polling_thread.start()
state.polling = True
state.callbacks_and_connectivities.append([callback, None])
@ -876,6 +910,7 @@ class Channel(grpc.Channel):
_common.encode(target), _options(options), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
cygrpc.fork_register_channel(self)
def subscribe(self, callback, try_to_connect=None):
_subscribe(self._connectivity_state, callback, try_to_connect)
@ -919,6 +954,11 @@ class Channel(grpc.Channel):
self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!')
_moot(self._connectivity_state)
def _close_on_fork(self):
self._channel.close_on_fork(cygrpc.StatusCode.cancelled,
'Channel closed due to fork')
_moot(self._connectivity_state)
def __enter__(self):
return self
@ -939,4 +979,5 @@ class Channel(grpc.Channel):
# for as long as they are in use and to close them after using them,
# then deletion of this grpc._channel.Channel instance can be made to
# effect closure of the underlying cygrpc.Channel instance.
cygrpc.fork_unregister_channel(self)
_moot(self._connectivity_state)

@ -19,7 +19,7 @@ cdef class Call:
def __cinit__(self):
# Create an *empty* call
grpc_init()
fork_handlers_and_grpc_init()
self.c_call = NULL
self.references = []

@ -40,6 +40,7 @@ cdef class _ChannelState:
# field and just use the NULLness of c_channel as an indication that the
# channel is closed.
cdef object open
cdef object closed_reason
# A dict from _BatchOperationTag to _CallState
cdef dict integrated_call_states

@ -15,6 +15,7 @@
cimport cpython
import threading
import time
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = (
'Internal gRPC call error %d. ' +
@ -83,6 +84,7 @@ cdef class _ChannelState:
self.integrated_call_states = {}
self.segregated_call_states = set()
self.connectivity_due = set()
self.closed_reason = None
cdef tuple _operate(grpc_call *c_call, object operations, object user_tag):
@ -142,10 +144,10 @@ cdef _cancel(
_check_and_raise_call_error_no_metadata(c_call_error)
cdef BatchOperationEvent _next_call_event(
cdef _next_call_event(
_ChannelState channel_state, grpc_completion_queue *c_completion_queue,
on_success):
tag, event = _latent_event(c_completion_queue, None)
on_success, deadline):
tag, event = _latent_event(c_completion_queue, deadline)
with channel_state.condition:
on_success(tag)
channel_state.condition.notify_all()
@ -229,8 +231,7 @@ cdef void _call(
call_state.due.update(started_tags)
on_success(started_tags)
else:
raise ValueError('Cannot invoke RPC on closed channel!')
raise ValueError('Cannot invoke RPC: %s' % channel_state.closed_reason)
cdef void _process_integrated_call_tag(
_ChannelState state, _BatchOperationTag tag) except *:
cdef _CallState call_state = state.integrated_call_states.pop(tag)
@ -302,7 +303,7 @@ cdef class SegregatedCall:
_process_segregated_call_tag(
self._channel_state, self._call_state, self._c_completion_queue, tag)
return _next_call_event(
self._channel_state, self._c_completion_queue, on_success)
self._channel_state, self._c_completion_queue, on_success, None)
cdef SegregatedCall _segregated_call(
@ -346,7 +347,7 @@ cdef object _watch_connectivity_state(
state.c_connectivity_completion_queue, <cpython.PyObject *>tag)
state.connectivity_due.add(tag)
else:
raise ValueError('Cannot invoke RPC on closed channel!')
raise ValueError('Cannot invoke RPC: %s' % state.closed_reason)
completed_tag, event = _latent_event(
state.c_connectivity_completion_queue, None)
with state.condition:
@ -355,12 +356,15 @@ cdef object _watch_connectivity_state(
return event
cdef _close(_ChannelState state, grpc_status_code code, object details):
cdef _close(Channel channel, grpc_status_code code, object details,
drain_calls):
cdef _ChannelState state = channel._state
cdef _CallState call_state
encoded_details = _encode(details)
with state.condition:
if state.open:
state.open = False
state.closed_reason = details
for call_state in set(state.integrated_call_states.values()):
grpc_call_cancel_with_status(
call_state.c_call, code, encoded_details, NULL)
@ -370,6 +374,13 @@ cdef _close(_ChannelState state, grpc_status_code code, object details):
# TODO(https://github.com/grpc/grpc/issues/3064): Cancel connectivity
# watching.
if drain_calls:
while not _calls_drained(state):
event = channel.next_call_event()
if event.completion_type == CompletionType.queue_timeout:
continue
event.tag(event)
else:
while state.integrated_call_states:
state.condition.wait()
while state.segregated_call_states:
@ -390,13 +401,17 @@ cdef _close(_ChannelState state, grpc_status_code code, object details):
state.condition.wait()
cdef _calls_drained(_ChannelState state):
return not (state.integrated_call_states or state.segregated_call_states or
state.connectivity_due)
cdef class Channel:
def __cinit__(
self, bytes target, object arguments,
ChannelCredentials channel_credentials):
arguments = () if arguments is None else tuple(arguments)
grpc_init()
fork_handlers_and_grpc_init()
self._state = _ChannelState()
self._vtable.copy = &_copy_pointer
self._vtable.destroy = &_destroy_pointer
@ -435,9 +450,14 @@ cdef class Channel:
def next_call_event(self):
def on_success(tag):
if tag is not None:
_process_integrated_call_tag(self._state, tag)
return _next_call_event(
self._state, self._state.c_call_completion_queue, on_success)
if is_fork_support_enabled():
queue_deadline = time.time() + 1.0
else:
queue_deadline = None
return _next_call_event(self._state, self._state.c_call_completion_queue,
on_success, queue_deadline)
def segregated_call(
self, int flags, method, host, object deadline, object metadata,
@ -452,11 +472,14 @@ cdef class Channel:
return grpc_channel_check_connectivity_state(
self._state.c_channel, try_to_connect)
else:
raise ValueError('Cannot invoke RPC on closed channel!')
raise ValueError('Cannot invoke RPC: %s' % self._state.closed_reason)
def watch_connectivity_state(
self, grpc_connectivity_state last_observed_state, object deadline):
return _watch_connectivity_state(self._state, last_observed_state, deadline)
def close(self, code, details):
_close(self._state, code, details)
_close(self, code, details, False)
def close_on_fork(self, code, details):
_close(self, code, details, True)

@ -71,7 +71,7 @@ cdef class CompletionQueue:
def __cinit__(self, shutdown_cq=False):
cdef grpc_completion_queue_attributes c_attrs
grpc_init()
fork_handlers_and_grpc_init()
if shutdown_cq:
c_attrs.version = 1
c_attrs.cq_completion_type = GRPC_CQ_NEXT

@ -21,7 +21,7 @@ from libc.stdint cimport uintptr_t
def _spawn_callback_in_thread(cb_func, args):
threading.Thread(target=cb_func, args=args).start()
ForkManagedThread(target=cb_func, args=args).start()
async_callback_func = _spawn_callback_in_thread
@ -114,7 +114,7 @@ cdef class ChannelCredentials:
cdef class SSLSessionCacheLRU:
def __cinit__(self, capacity):
grpc_init()
fork_handlers_and_grpc_init()
self._cache = grpc_ssl_session_cache_create_lru(capacity)
def __int__(self):
@ -172,7 +172,7 @@ cdef class CompositeChannelCredentials(ChannelCredentials):
cdef class ServerCertificateConfig:
def __cinit__(self):
grpc_init()
fork_handlers_and_grpc_init()
self.c_cert_config = NULL
self.c_pem_root_certs = NULL
self.c_ssl_pem_key_cert_pairs = NULL
@ -187,7 +187,7 @@ cdef class ServerCertificateConfig:
cdef class ServerCredentials:
def __cinit__(self):
grpc_init()
fork_handlers_and_grpc_init()
self.c_credentials = NULL
self.references = []
self.initial_cert_config = None

@ -0,0 +1,29 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef extern from "pthread.h" nogil:
int pthread_atfork(
void (*prepare)() nogil,
void (*parent)() nogil,
void (*child)() nogil)
cdef void __prefork() nogil
cdef void __postfork_parent() nogil
cdef void __postfork_child() nogil

@ -0,0 +1,203 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import threading
_LOGGER = logging.getLogger(__name__)
_AWAIT_THREADS_TIMEOUT_SECONDS = 5
_TRUE_VALUES = ['yes', 'Yes', 'YES', 'true', 'True', 'TRUE', '1']
# This flag enables experimental support within gRPC Python for applications
# that will fork() without exec(). When enabled, gRPC Python will attempt to
# pause all of its internally created threads before the fork syscall proceeds.
#
# For this to be successful, the application must not have multiple threads of
# its own calling into gRPC when fork is invoked. Any callbacks from gRPC
# Python-spawned threads into user code (e.g., callbacks for asynchronous RPCs)
# must not block and should execute quickly.
#
# This flag is not supported on Windows.
_GRPC_ENABLE_FORK_SUPPORT = (
os.environ.get('GRPC_ENABLE_FORK_SUPPORT', '0')
.lower() in _TRUE_VALUES)
_GRPC_POLL_STRATEGY = os.environ.get('GRPC_POLL_STRATEGY')
cdef void __prefork() nogil:
with gil:
with _fork_state.fork_in_progress_condition:
_fork_state.fork_in_progress = True
if not _fork_state.active_thread_count.await_zero_threads(
_AWAIT_THREADS_TIMEOUT_SECONDS):
_LOGGER.error(
'Failed to shutdown gRPC Python threads prior to fork. '
'Behavior after fork will be undefined.')
cdef void __postfork_parent() nogil:
with gil:
with _fork_state.fork_in_progress_condition:
_fork_state.fork_in_progress = False
_fork_state.fork_in_progress_condition.notify_all()
cdef void __postfork_child() nogil:
with gil:
# Thread could be holding the fork_in_progress_condition inside of
# block_if_fork_in_progress() when fork occurs. Reset the lock here.
_fork_state.fork_in_progress_condition = threading.Condition()
# A thread in return_from_user_request_generator() may hold this lock
# when fork occurs.
_fork_state.active_thread_count = _ActiveThreadCount()
for state_to_reset in _fork_state.postfork_states_to_reset:
state_to_reset.reset_postfork_child()
_fork_state.fork_epoch += 1
for channel in _fork_state.channels:
channel._close_on_fork()
# TODO(ericgribkoff) Check and abort if core is not shutdown
with _fork_state.fork_in_progress_condition:
_fork_state.fork_in_progress = False
def fork_handlers_and_grpc_init():
grpc_init()
if _GRPC_ENABLE_FORK_SUPPORT:
# TODO(ericgribkoff) epoll1 is default for grpcio distribution. Decide whether to expose
# grpc_get_poll_strategy_name() from ev_posix.cc to get actual polling choice.
if _GRPC_POLL_STRATEGY is not None and _GRPC_POLL_STRATEGY != "epoll1":
_LOGGER.error(
'gRPC Python fork support is only compatible with the epoll1 '
'polling engine')
return
with _fork_state.fork_handler_registered_lock:
if not _fork_state.fork_handler_registered:
pthread_atfork(&__prefork, &__postfork_parent, &__postfork_child)
_fork_state.fork_handler_registered = True
class ForkManagedThread(object):
def __init__(self, target, args=()):
if _GRPC_ENABLE_FORK_SUPPORT:
def managed_target(*args):
try:
target(*args)
finally:
_fork_state.active_thread_count.decrement()
self._thread = threading.Thread(target=managed_target, args=args)
else:
self._thread = threading.Thread(target=target, args=args)
def setDaemon(self, daemonic):
self._thread.daemon = daemonic
def start(self):
if _GRPC_ENABLE_FORK_SUPPORT:
_fork_state.active_thread_count.increment()
self._thread.start()
def join(self):
self._thread.join()
def block_if_fork_in_progress(postfork_state_to_reset=None):
if _GRPC_ENABLE_FORK_SUPPORT:
with _fork_state.fork_in_progress_condition:
if not _fork_state.fork_in_progress:
return
if postfork_state_to_reset is not None:
_fork_state.postfork_states_to_reset.append(postfork_state_to_reset)
_fork_state.active_thread_count.decrement()
_fork_state.fork_in_progress_condition.wait()
_fork_state.active_thread_count.increment()
def enter_user_request_generator():
if _GRPC_ENABLE_FORK_SUPPORT:
_fork_state.active_thread_count.decrement()
def return_from_user_request_generator():
if _GRPC_ENABLE_FORK_SUPPORT:
_fork_state.active_thread_count.increment()
block_if_fork_in_progress()
def get_fork_epoch():
return _fork_state.fork_epoch
def is_fork_support_enabled():
return _GRPC_ENABLE_FORK_SUPPORT
def fork_register_channel(channel):
if _GRPC_ENABLE_FORK_SUPPORT:
_fork_state.channels.add(channel)
def fork_unregister_channel(channel):
if _GRPC_ENABLE_FORK_SUPPORT:
_fork_state.channels.remove(channel)
class _ActiveThreadCount(object):
def __init__(self):
self._num_active_threads = 0
self._condition = threading.Condition()
def increment(self):
with self._condition:
self._num_active_threads += 1
def decrement(self):
with self._condition:
self._num_active_threads -= 1
if self._num_active_threads == 0:
self._condition.notify_all()
def await_zero_threads(self, timeout_secs):
end_time = time.time() + timeout_secs
wait_time = timeout_secs
with self._condition:
while True:
if self._num_active_threads > 0:
self._condition.wait(wait_time)
if self._num_active_threads == 0:
return True
# Thread count may have increased before this re-obtains the
# lock after a notify(). Wait again until timeout_secs has
# elapsed.
wait_time = end_time - time.time()
if wait_time <= 0:
return False
class _ForkState(object):
def __init__(self):
self.fork_in_progress_condition = threading.Condition()
self.fork_in_progress = False
self.postfork_states_to_reset = []
self.fork_handler_registered_lock = threading.Lock()
self.fork_handler_registered = False
self.active_thread_count = _ActiveThreadCount()
self.fork_epoch = 0
self.channels = set()
_fork_state = _ForkState()

@ -0,0 +1,63 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
# No-op implementations for Windows.
def fork_handlers_and_grpc_init():
grpc_init()
class ForkManagedThread(object):
def __init__(self, target, args=()):
self._thread = threading.Thread(target=target, args=args)
def setDaemon(self, daemonic):
self._thread.daemon = daemonic
def start(self):
self._thread.start()
def join(self):
self._thread.join()
def block_if_fork_in_progress(postfork_state_to_reset=None):
pass
def enter_user_request_generator():
pass
def return_from_user_request_generator():
pass
def get_fork_epoch():
return 0
def is_fork_support_enabled():
return False
def fork_register_channel(channel):
pass
def fork_unregister_channel(channel):
pass

@ -127,7 +127,7 @@ class CompressionLevel:
cdef class CallDetails:
def __cinit__(self):
grpc_init()
fork_handlers_and_grpc_init()
with nogil:
grpc_call_details_init(&self.c_details)

@ -60,7 +60,7 @@ cdef grpc_ssl_certificate_config_reload_status _server_cert_config_fetcher_wrapp
cdef class Server:
def __cinit__(self, object arguments):
grpc_init()
fork_handlers_and_grpc_init()
self.references = []
self.registered_completion_queues = []
self._vtable.copy = &_copy_pointer

@ -31,3 +31,6 @@ include "_cygrpc/time.pxd.pxi"
include "_cygrpc/_hooks.pxd.pxi"
include "_cygrpc/grpc_gevent.pxd.pxi"
IF UNAME_SYSNAME != "Windows":
include "_cygrpc/fork_posix.pxd.pxi"

@ -39,6 +39,11 @@ include "_cygrpc/_hooks.pyx.pxi"
include "_cygrpc/grpc_gevent.pyx.pxi"
IF UNAME_SYSNAME == "Windows":
include "_cygrpc/fork_windows.pyx.pxi"
ELSE:
include "_cygrpc/fork_posix.pyx.pxi"
#
# initialize gRPC
#

@ -202,3 +202,28 @@ class RunInterop(test.test):
from tests.interop import client
sys.argv[1:] = self.args.split()
client.test_interoperability()
class RunFork(test.test):
description = 'run fork test client'
user_options = [('args=', 'a', 'pass-thru arguments for the client')]
def initialize_options(self):
self.args = ''
def finalize_options(self):
# distutils requires this override.
pass
def run(self):
if self.distribution.install_requires:
self.distribution.fetch_build_eggs(
self.distribution.install_requires)
if self.distribution.tests_require:
self.distribution.fetch_build_eggs(self.distribution.tests_require)
# We import here to ensure that our setuptools parent has had a chance to
# edit the Python system path.
from tests.fork import client
sys.argv[1:] = self.args.split()
client.test_fork()

@ -52,6 +52,7 @@ COMMAND_CLASS = {
'preprocess': commands.GatherProto,
'build_package_protos': grpc_tools.command.BuildPackageProtos,
'build_py': commands.BuildPy,
'run_fork': commands.RunFork,
'run_interop': commands.RunInterop,
'test_lite': commands.TestLite,
'test_gevent': commands.TestGevent,

@ -0,0 +1,13 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,76 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the GRPC interoperability test client."""
import argparse
import logging
import sys
from tests.fork import methods
def _args():
def parse_bool(value):
if value == 'true':
return True
if value == 'false':
return False
raise argparse.ArgumentTypeError('Only true/false allowed')
parser = argparse.ArgumentParser()
parser.add_argument(
'--server_host',
default="localhost",
type=str,
help='the host to which to connect')
parser.add_argument(
'--server_port',
type=int,
required=True,
help='the port to which to connect')
parser.add_argument(
'--test_case',
default='large_unary',
type=str,
help='the test case to execute')
parser.add_argument(
'--use_tls',
default=False,
type=parse_bool,
help='require a secure connection')
return parser.parse_args()
def _test_case_from_arg(test_case_arg):
for test_case in methods.TestCase:
if test_case_arg == test_case.value:
return test_case
else:
raise ValueError('No test case "%s"!' % test_case_arg)
def test_fork():
logging.basicConfig(level=logging.INFO)
args = _args()
if args.test_case == "all":
for test_case in methods.TestCase:
test_case.run_test(args)
else:
test_case = _test_case_from_arg(args.test_case)
test_case.run_test(args)
if __name__ == '__main__':
test_fork()

@ -0,0 +1,445 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of fork support test methods."""
import enum
import json
import logging
import multiprocessing
import os
import threading
import time
import grpc
from six.moves import queue
from src.proto.grpc.testing import empty_pb2
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
_LOGGER = logging.getLogger(__name__)
def _channel(args):
target = '{}:{}'.format(args.server_host, args.server_port)
if args.use_tls:
channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(target, channel_credentials)
else:
channel = grpc.insecure_channel(target)
return channel
def _validate_payload_type_and_length(response, expected_type, expected_length):
if response.payload.type is not expected_type:
raise ValueError('expected payload type %s, got %s' %
(expected_type, type(response.payload.type)))
elif len(response.payload.body) != expected_length:
raise ValueError('expected payload body size %d, got %d' %
(expected_length, len(response.payload.body)))
def _async_unary(stub):
size = 314159
request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE,
response_size=size,
payload=messages_pb2.Payload(body=b'\x00' * 271828))
response_future = stub.UnaryCall.future(request)
response = response_future.result()
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
def _blocking_unary(stub):
size = 314159
request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE,
response_size=size,
payload=messages_pb2.Payload(body=b'\x00' * 271828))
response = stub.UnaryCall(request)
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
class _Pipe(object):
def __init__(self):
self._condition = threading.Condition()
self._values = []
self._open = True
def __iter__(self):
return self
def __next__(self):
return self.next()
def next(self):
with self._condition:
while not self._values and self._open:
self._condition.wait()
if self._values:
return self._values.pop(0)
else:
raise StopIteration()
def add(self, value):
with self._condition:
self._values.append(value)
self._condition.notify()
def close(self):
with self._condition:
self._open = False
self._condition.notify()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
class _ChildProcess(object):
def __init__(self, task, args=None):
if args is None:
args = ()
self._exceptions = multiprocessing.Queue()
def record_exceptions():
try:
task(*args)
except Exception as e: # pylint: disable=broad-except
self._exceptions.put(e)
self._process = multiprocessing.Process(target=record_exceptions)
def start(self):
self._process.start()
def finish(self):
self._process.join()
if self._process.exitcode != 0:
raise ValueError('Child process failed with exitcode %d' %
self._process.exitcode)
try:
exception = self._exceptions.get(block=False)
raise ValueError('Child process failed: %s' % exception)
except queue.Empty:
pass
def _async_unary_same_channel(channel):
def child_target():
try:
_async_unary(stub)
raise Exception(
'Child should not be able to re-use channel after fork')
except ValueError as expected_value_error:
pass
stub = test_pb2_grpc.TestServiceStub(channel)
_async_unary(stub)
child_process = _ChildProcess(child_target)
child_process.start()
_async_unary(stub)
child_process.finish()
def _async_unary_new_channel(channel, args):
def child_target():
child_channel = _channel(args)
child_stub = test_pb2_grpc.TestServiceStub(child_channel)
_async_unary(child_stub)
child_channel.close()
stub = test_pb2_grpc.TestServiceStub(channel)
_async_unary(stub)
child_process = _ChildProcess(child_target)
child_process.start()
_async_unary(stub)
child_process.finish()
def _blocking_unary_same_channel(channel):
def child_target():
try:
_blocking_unary(stub)
raise Exception(
'Child should not be able to re-use channel after fork')
except ValueError as expected_value_error:
pass
stub = test_pb2_grpc.TestServiceStub(channel)
_blocking_unary(stub)
child_process = _ChildProcess(child_target)
child_process.start()
child_process.finish()
def _blocking_unary_new_channel(channel, args):
def child_target():
child_channel = _channel(args)
child_stub = test_pb2_grpc.TestServiceStub(child_channel)
_blocking_unary(child_stub)
child_channel.close()
stub = test_pb2_grpc.TestServiceStub(channel)
_blocking_unary(stub)
child_process = _ChildProcess(child_target)
child_process.start()
_blocking_unary(stub)
child_process.finish()
# Verify that the fork channel registry can handle already closed channels
def _close_channel_before_fork(channel, args):
def child_target():
new_channel.close()
child_channel = _channel(args)
child_stub = test_pb2_grpc.TestServiceStub(child_channel)
_blocking_unary(child_stub)
child_channel.close()
stub = test_pb2_grpc.TestServiceStub(channel)
_blocking_unary(stub)
channel.close()
new_channel = _channel(args)
new_stub = test_pb2_grpc.TestServiceStub(new_channel)
child_process = _ChildProcess(child_target)
child_process.start()
_blocking_unary(new_stub)
child_process.finish()
def _connectivity_watch(channel, args):
def child_target():
def child_connectivity_callback(state):
child_states.append(state)
child_states = []
child_channel = _channel(args)
child_stub = test_pb2_grpc.TestServiceStub(child_channel)
child_channel.subscribe(child_connectivity_callback)
_async_unary(child_stub)
if len(child_states
) < 2 or child_states[-1] != grpc.ChannelConnectivity.READY:
raise ValueError('Channel did not move to READY')
if len(parent_states) > 1:
raise ValueError('Received connectivity updates on parent callback')
child_channel.unsubscribe(child_connectivity_callback)
child_channel.close()
def parent_connectivity_callback(state):
parent_states.append(state)
parent_states = []
channel.subscribe(parent_connectivity_callback)
stub = test_pb2_grpc.TestServiceStub(channel)
child_process = _ChildProcess(child_target)
child_process.start()
_async_unary(stub)
if len(parent_states
) < 2 or parent_states[-1] != grpc.ChannelConnectivity.READY:
raise ValueError('Channel did not move to READY')
channel.unsubscribe(parent_connectivity_callback)
child_process.finish()
# Need to unsubscribe or _channel.py in _poll_connectivity triggers a
# "Cannot invoke RPC on closed channel!" error.
# TODO(ericgribkoff) Fix issue with channel.close() and connectivity polling
channel.unsubscribe(parent_connectivity_callback)
def _ping_pong_with_child_processes_after_first_response(
channel, args, child_target, run_after_close=True):
request_response_sizes = (
31415,
9,
2653,
58979,
)
request_payload_sizes = (
27182,
8,
1828,
45904,
)
stub = test_pb2_grpc.TestServiceStub(channel)
pipe = _Pipe()
parent_bidi_call = stub.FullDuplexCall(pipe)
child_processes = []
first_message_received = False
for response_size, payload_size in zip(request_response_sizes,
request_payload_sizes):
request = messages_pb2.StreamingOutputCallRequest(
response_type=messages_pb2.COMPRESSABLE,
response_parameters=(
messages_pb2.ResponseParameters(size=response_size),),
payload=messages_pb2.Payload(body=b'\x00' * payload_size))
pipe.add(request)
if first_message_received:
child_process = _ChildProcess(child_target,
(parent_bidi_call, channel, args))
child_process.start()
child_processes.append(child_process)
response = next(parent_bidi_call)
first_message_received = True
child_process = _ChildProcess(child_target,
(parent_bidi_call, channel, args))
child_process.start()
child_processes.append(child_process)
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
response_size)
pipe.close()
if run_after_close:
child_process = _ChildProcess(child_target,
(parent_bidi_call, channel, args))
child_process.start()
child_processes.append(child_process)
for child_process in child_processes:
child_process.finish()
def _in_progress_bidi_continue_call(channel):
def child_target(parent_bidi_call, parent_channel, args):
stub = test_pb2_grpc.TestServiceStub(parent_channel)
try:
_async_unary(stub)
raise Exception(
'Child should not be able to re-use channel after fork')
except ValueError as expected_value_error:
pass
inherited_code = parent_bidi_call.code()
inherited_details = parent_bidi_call.details()
if inherited_code != grpc.StatusCode.CANCELLED:
raise ValueError(
'Expected inherited code CANCELLED, got %s' % inherited_code)
if inherited_details != 'Channel closed due to fork':
raise ValueError(
'Expected inherited details Channel closed due to fork, got %s'
% inherited_details)
# Don't run child_target after closing the parent call, as the call may have
# received a status from the server before fork occurs.
_ping_pong_with_child_processes_after_first_response(
channel, None, child_target, run_after_close=False)
def _in_progress_bidi_same_channel_async_call(channel):
def child_target(parent_bidi_call, parent_channel, args):
stub = test_pb2_grpc.TestServiceStub(parent_channel)
try:
_async_unary(stub)
raise Exception(
'Child should not be able to re-use channel after fork')
except ValueError as expected_value_error:
pass
_ping_pong_with_child_processes_after_first_response(
channel, None, child_target)
def _in_progress_bidi_same_channel_blocking_call(channel):
def child_target(parent_bidi_call, parent_channel, args):
stub = test_pb2_grpc.TestServiceStub(parent_channel)
try:
_blocking_unary(stub)
raise Exception(
'Child should not be able to re-use channel after fork')
except ValueError as expected_value_error:
pass
_ping_pong_with_child_processes_after_first_response(
channel, None, child_target)
def _in_progress_bidi_new_channel_async_call(channel, args):
def child_target(parent_bidi_call, parent_channel, args):
channel = _channel(args)
stub = test_pb2_grpc.TestServiceStub(channel)
_async_unary(stub)
_ping_pong_with_child_processes_after_first_response(
channel, args, child_target)
def _in_progress_bidi_new_channel_blocking_call(channel, args):
def child_target(parent_bidi_call, parent_channel, args):
channel = _channel(args)
stub = test_pb2_grpc.TestServiceStub(channel)
_blocking_unary(stub)
_ping_pong_with_child_processes_after_first_response(
channel, args, child_target)
@enum.unique
class TestCase(enum.Enum):
CONNECTIVITY_WATCH = 'connectivity_watch'
CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork'
ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel'
ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel'
BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel'
BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel'
IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call'
IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call'
IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call'
IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call'
IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call'
def run_test(self, args):
_LOGGER.info("Running %s", self)
channel = _channel(args)
if self is TestCase.ASYNC_UNARY_SAME_CHANNEL:
_async_unary_same_channel(channel)
elif self is TestCase.ASYNC_UNARY_NEW_CHANNEL:
_async_unary_new_channel(channel, args)
elif self is TestCase.BLOCKING_UNARY_SAME_CHANNEL:
_blocking_unary_same_channel(channel)
elif self is TestCase.BLOCKING_UNARY_NEW_CHANNEL:
_blocking_unary_new_channel(channel, args)
elif self is TestCase.CLOSE_CHANNEL_BEFORE_FORK:
_close_channel_before_fork(channel, args)
elif self is TestCase.CONNECTIVITY_WATCH:
_connectivity_watch(channel, args)
elif self is TestCase.IN_PROGRESS_BIDI_CONTINUE_CALL:
_in_progress_bidi_continue_call(channel)
elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL:
_in_progress_bidi_same_channel_async_call(channel)
elif self is TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL:
_in_progress_bidi_same_channel_blocking_call(channel)
elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL:
_in_progress_bidi_new_channel_async_call(channel, args)
elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL:
_in_progress_bidi_new_channel_blocking_call(channel, args)
else:
raise NotImplementedError(
'Test case "%s" not implemented!' % self.name)
channel.close()

@ -32,6 +32,8 @@
"unit._credentials_test.CredentialsTest",
"unit._cython._cancel_many_calls_test.CancelManyCallsTest",
"unit._cython._channel_test.ChannelTest",
"unit._cython._fork_test.ForkPosixTester",
"unit._cython._fork_test.ForkWindowsTester",
"unit._cython._no_messages_server_completion_queue_per_call_test.Test",
"unit._cython._no_messages_single_server_completion_queue_test.Test",
"unit._cython._read_some_but_not_all_responses_test.ReadSomeButNotAllResponsesTest",

@ -0,0 +1,68 @@
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading
import unittest
from grpc._cython import cygrpc
def _get_number_active_threads():
return cygrpc._fork_state.active_thread_count._num_active_threads
@unittest.skipIf(os.name == 'nt', 'Posix-specific tests')
class ForkPosixTester(unittest.TestCase):
def setUp(self):
cygrpc._GRPC_ENABLE_FORK_SUPPORT = True
def testForkManagedThread(self):
def cb():
self.assertEqual(1, _get_number_active_threads())
thread = cygrpc.ForkManagedThread(cb)
thread.start()
thread.join()
self.assertEqual(0, _get_number_active_threads())
def testForkManagedThreadThrowsException(self):
def cb():
self.assertEqual(1, _get_number_active_threads())
raise Exception("expected exception")
thread = cygrpc.ForkManagedThread(cb)
thread.start()
thread.join()
self.assertEqual(0, _get_number_active_threads())
@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests')
class ForkWindowsTester(unittest.TestCase):
def testForkManagedThreadIsNoOp(self):
def cb():
pass
thread = cygrpc.ForkManagedThread(cb)
thread.start()
thread.join()
if __name__ == '__main__':
unittest.main(verbosity=2)
Loading…
Cancel
Save