diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index a89b5013030..29dbc3a668b 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -195,7 +195,8 @@ def _consume_request_iterator( cygrpc.operation_send_message( serialized_request, _EMPTY_FLAGS), ) - call.start_batch(cygrpc.Operations(operations), event_handler) + call.start_client_batch(cygrpc.Operations(operations), + event_handler) state.due.add(cygrpc.OperationType.send_message) while True: state.condition.wait() @@ -211,7 +212,7 @@ def _consume_request_iterator( operations = ( cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), ) - call.start_batch(cygrpc.Operations(operations), event_handler) + call.start_client_batch(cygrpc.Operations(operations), event_handler) state.due.add(cygrpc.OperationType.send_close_from_client) def stop_consumption_thread(timeout): @@ -312,7 +313,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): if self._state.code is None: event_handler = _event_handler( self._state, self._call, self._response_deserializer) - self._call.start_batch( + self._call.start_client_batch( cygrpc.Operations( (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), event_handler) @@ -471,7 +472,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): None, 0, completion_queue, self._method, None, deadline_timespec) if credentials is not None: call.set_credentials(credentials._credentials) - call.start_batch(cygrpc.Operations(operations), None) + call.start_client_batch(cygrpc.Operations(operations), None) _handle_event(completion_queue.poll(), state, self._response_deserializer) return state, deadline @@ -495,7 +496,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): call.set_credentials(credentials._credentials) event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: - call.start_batch(cygrpc.Operations(operations), event_handler) + call.start_client_batch(cygrpc.Operations(operations), event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -523,7 +524,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): call.set_credentials(credentials._credentials) event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: - call.start_batch( + call.start_client_batch( cygrpc.Operations( (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), event_handler) @@ -534,7 +535,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_batch(cygrpc.Operations(operations), event_handler) + call.start_client_batch(cygrpc.Operations(operations), event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -558,7 +559,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): if credentials is not None: call.set_credentials(credentials._credentials) with state.condition: - call.start_batch( + call.start_client_batch( cygrpc.Operations( (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), None) @@ -568,7 +569,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_batch(cygrpc.Operations(operations), None) + call.start_client_batch(cygrpc.Operations(operations), None) _consume_request_iterator( request_iterator, state, call, self._request_serializer) while True: @@ -602,7 +603,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): call.set_credentials(credentials._credentials) event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: - call.start_batch( + call.start_client_batch( cygrpc.Operations( (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), event_handler) @@ -612,7 +613,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_batch(cygrpc.Operations(operations), event_handler) + call.start_client_batch(cygrpc.Operations(operations), event_handler) _consume_request_iterator( request_iterator, state, call, self._request_serializer) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -639,7 +640,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): call.set_credentials(credentials._credentials) event_handler = _event_handler(state, call, self._response_deserializer) with state.condition: - call.start_batch( + call.start_client_batch( cygrpc.Operations( (cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)), event_handler) @@ -648,7 +649,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): _common.cygrpc_metadata(metadata), _EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) - call.start_batch(cygrpc.Operations(operations), event_handler) + call.start_client_batch(cygrpc.Operations(operations), event_handler) _consume_request_iterator( request_iterator, state, call, self._request_serializer) return _Rendezvous(state, call, self._response_deserializer, deadline) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi index a09bbc75fe6..6570dcdb852 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi @@ -37,13 +37,16 @@ cdef class Call: self.c_call = NULL self.references = [] - def start_batch(self, operations, tag): + def _start_batch(self, operations, tag, retain_self): if not self.is_valid: raise ValueError("invalid call object cannot be used from Python") cdef grpc_call_error result cdef Operations cy_operations = Operations(operations) cdef OperationTag operation_tag = OperationTag(tag) - operation_tag.operation_call = self + if retain_self: + operation_tag.operation_call = self + else: + operation_tag.operation_call = None operation_tag.batch_operations = cy_operations cpython.Py_INCREF(operation_tag) with nogil: @@ -52,6 +55,14 @@ cdef class Call: operation_tag, NULL) return result + def start_client_batch(self, operations, tag): + # We don't reference this call in the operations tag because + # it should be cancelled when it goes out of scope + return self._start_batch(operations, tag, False) + + def start_server_batch(self, operations, tag): + return self._start_batch(operations, tag, True) + def cancel( self, grpc_status_code error_code=GRPC_STATUS__DO_NOT_USE, details=None): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi index 0474697af82..96c5b02bc2e 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi @@ -58,14 +58,14 @@ cdef class Event: cdef readonly bint success cdef readonly object tag - # For operations with calls - cdef readonly Call operation_call - # For Server.request_call cdef readonly bint is_new_request cdef readonly CallDetails request_call_details cdef readonly Metadata request_metadata + # For server calls + cdef readonly Call operation_call + # For Call.start_batch cdef readonly Operations batch_operations diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index f4c114056fe..5d805b53e40 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -157,7 +157,7 @@ def _abort(state, call, code, details): effective_details, _EMPTY_FLAGS), ) token = _SEND_STATUS_FROM_SERVER_TOKEN - call.start_batch( + call.start_server_batch( cygrpc.Operations(operations), _send_status_from_server(state, token)) state.statused = True @@ -257,7 +257,7 @@ class _Context(grpc.ServicerContext): if self._state.initial_metadata_allowed: operation = cygrpc.operation_send_initial_metadata( _common.cygrpc_metadata(initial_metadata), _EMPTY_FLAGS) - self._rpc_event.operation_call.start_batch( + self._rpc_event.operation_call.start_server_batch( cygrpc.Operations((operation,)), _send_initial_metadata(self._state)) self._state.initial_metadata_allowed = False @@ -292,7 +292,7 @@ class _RequestIterator(object): elif self._state.client is _CLOSED or self._state.statused: raise StopIteration() else: - self._call.start_batch( + self._call.start_server_batch( cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)), _receive_message(self._state, self._call, self._request_deserializer)) self._state.due.add(_RECEIVE_MESSAGE_TOKEN) @@ -333,7 +333,7 @@ def _unary_request(rpc_event, state, request_deserializer): if state.client is _CANCELLED or state.statused: return None else: - start_batch_result = rpc_event.operation_call.start_batch( + start_server_batch_result = rpc_event.operation_call.start_server_batch( cygrpc.Operations( (cygrpc.operation_receive_message(_EMPTY_FLAGS),)), _receive_message( @@ -417,7 +417,7 @@ def _send_response(rpc_event, state, serialized_response): cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS), ) token = _SEND_MESSAGE_TOKEN - rpc_event.operation_call.start_batch( + rpc_event.operation_call.start_server_batch( cygrpc.Operations(operations), _send_message(state, token)) state.due.add(token) while True: @@ -443,7 +443,7 @@ def _status(rpc_event, state, serialized_response): if serialized_response is not None: operations.append(cygrpc.operation_send_message( serialized_response, _EMPTY_FLAGS)) - rpc_event.operation_call.start_batch( + rpc_event.operation_call.start_server_batch( cygrpc.Operations(operations), _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) state.statused = True @@ -550,7 +550,7 @@ def _handle_unrecognized_method(rpc_event): b'Method not found!', _EMPTY_FLAGS), ) rpc_state = _RPCState() - rpc_event.operation_call.start_batch( + rpc_event.operation_call.start_server_batch( operations, lambda ignored_event: (rpc_state, (),)) return rpc_state @@ -558,7 +558,7 @@ def _handle_unrecognized_method(rpc_event): def _handle_with_method_handler(rpc_event, method_handler, thread_pool): state = _RPCState() with state.condition: - rpc_event.operation_call.start_batch( + rpc_event.operation_call.start_server_batch( cygrpc.Operations( (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), _receive_close_on_server(state)) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index cac0c8b3b93..cf212c56539 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -81,11 +81,11 @@ class _Handler(object): self._state.condition.wait() with self._lock: - self._call.start_batch( + self._call.start_server_batch( cygrpc.Operations( (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)), _RECEIVE_CLOSE_ON_SERVER_TAG) - self._call.start_batch( + self._call.start_server_batch( cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)), _RECEIVE_MESSAGE_TAG) first_event = self._completion_queue.poll() @@ -101,7 +101,7 @@ class _Handler(object): _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', _EMPTY_FLAGS), ) - self._call.start_batch( + self._call.start_server_batch( cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG) self._completion_queue.poll() self._completion_queue.poll() @@ -193,7 +193,7 @@ class CancelManyCallsTest(unittest.TestCase): cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS), ) tag = 'client_complete_call_{0:04d}_tag'.format(index) - client_call.start_batch(cygrpc.Operations(operations), tag) + client_call.start_client_batch(cygrpc.Operations(operations), tag) client_due.add(tag) client_calls.append(client_call) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index 27fcee0d6f1..152d8edde3b 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -168,12 +168,12 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): client_complete_rpc_tag = 'client_complete_rpc_tag' with client_condition: client_receive_initial_metadata_start_batch_result = ( - client_call.start_batch(cygrpc.Operations([ + client_call.start_client_batch(cygrpc.Operations([ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), ]), client_receive_initial_metadata_tag)) client_due.add(client_receive_initial_metadata_tag) client_complete_rpc_start_batch_result = ( - client_call.start_batch(cygrpc.Operations([ + client_call.start_client_batch(cygrpc.Operations([ cygrpc.operation_send_initial_metadata( _EMPTY_METADATA, _EMPTY_FLAGS), cygrpc.operation_send_close_from_client(_EMPTY_FLAGS), @@ -185,30 +185,30 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): with server_call_condition: server_send_initial_metadata_start_batch_result = ( - server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + server_rpc_event.operation_call.start_server_batch([ cygrpc.operation_send_initial_metadata( _EMPTY_METADATA, _EMPTY_FLAGS), - ]), server_send_initial_metadata_tag)) + ], server_send_initial_metadata_tag)) server_send_first_message_start_batch_result = ( - server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + server_rpc_event.operation_call.start_server_batch([ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), - ]), server_send_first_message_tag)) + ], server_send_first_message_tag)) server_send_initial_metadata_event = server_call_driver.event_with_tag( server_send_initial_metadata_tag) server_send_first_message_event = server_call_driver.event_with_tag( server_send_first_message_tag) with server_call_condition: server_send_second_message_start_batch_result = ( - server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + server_rpc_event.operation_call.start_server_batch([ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS), - ]), server_send_second_message_tag)) + ], server_send_second_message_tag)) server_complete_rpc_start_batch_result = ( - server_rpc_event.operation_call.start_batch(cygrpc.Operations([ + server_rpc_event.operation_call.start_server_batch([ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS), cygrpc.operation_send_status_from_server( cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details', _EMPTY_FLAGS), - ]), server_complete_rpc_tag)) + ], server_complete_rpc_tag)) server_send_second_message_event = server_call_driver.event_with_tag( server_send_second_message_tag) server_complete_rpc_event = server_call_driver.event_with_tag( @@ -218,7 +218,7 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase): with client_condition: client_receive_first_message_tag = 'client_receive_first_message_tag' client_receive_first_message_start_batch_result = ( - client_call.start_batch(cygrpc.Operations([ + client_call.start_client_batch(cygrpc.Operations([ cygrpc.operation_receive_message(_EMPTY_FLAGS), ]), client_receive_first_message_tag)) client_due.add(client_receive_first_message_tag) diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index b740695e35b..9d1dbc189b8 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -186,7 +186,8 @@ class ServerClientMixin(object): def performer(): tag = object() try: - call_result = call.start_batch(cygrpc.Operations(operations), tag) + call_result = call.start_client_batch( + cygrpc.Operations(operations), tag) self.assertEqual(cygrpc.CallError.ok, call_result) event = queue.poll(deadline) self.assertEqual(cygrpc.CompletionType.operation_complete, event.type) @@ -231,7 +232,7 @@ class ServerClientMixin(object): cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE), cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)]) - client_start_batch_result = client_call.start_batch(cygrpc.Operations([ + client_start_batch_result = client_call.start_client_batch([ cygrpc.operation_send_initial_metadata(client_initial_metadata, _EMPTY_FLAGS), cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS), @@ -239,7 +240,7 @@ class ServerClientMixin(object): cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS), cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS) - ]), client_call_tag) + ], client_call_tag) self.assertEqual(cygrpc.CallError.ok, client_start_batch_result) client_event_future = test_utilities.CompletionQueuePollFuture( self.client_completion_queue, cygrpc_deadline) @@ -268,7 +269,7 @@ class ServerClientMixin(object): server_trailing_metadata = cygrpc.Metadata([ cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE)]) - server_start_batch_result = server_call.start_batch([ + server_start_batch_result = server_call.start_server_batch([ cygrpc.operation_send_initial_metadata(server_initial_metadata, _EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS),