diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 960f5317446..43557e68733 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -96,12 +96,14 @@ cdef class _AioCall: else: # By implementation, grpc_call_cancel always return OK grpc_call_cancel(self._grpc_call_wrapper.call, NULL) - return AioRpcStatus( + status = AioRpcStatus( StatusCode.cancelled, _UNKNOWN_CANCELLATION_DETAILS, None, None, ) + cancellation_future.set_result(status) + return status async def unary_unary(self, bytes method, diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi index 363d9b6eea5..b205e3f70e4 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pxd.pxi @@ -30,7 +30,6 @@ cdef class _AsyncioSocket: # Server-side attributes grpc_custom_accept_callback _grpc_accept_cb - grpc_custom_write_callback _grpc_write_cb grpc_custom_socket * _grpc_client_socket object _server object _py_socket diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi index 4b14166be50..367148afe8a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi @@ -23,7 +23,6 @@ cdef class _AsyncioSocket: self._grpc_socket = NULL self._grpc_connect_cb = NULL self._grpc_read_cb = NULL - self._grpc_write_cb = NULL self._reader = None self._writer = None self._task_connect = None @@ -131,28 +130,22 @@ cdef class _AsyncioSocket: self._grpc_read_cb = grpc_read_cb self._task_read.add_done_callback(self._read_cb) self._read_buffer = buffer_ - - async def _async_write(self, bytearray buffer): - self._writer.write(buffer) - await self._writer.drain() - - self._grpc_write_cb( - self._grpc_socket, - 0 - ) cdef void write(self, grpc_slice_buffer * g_slice_buffer, grpc_custom_write_callback grpc_write_cb): - # For each socket, C-Core guarantees there'll be only one ongoing write - self._grpc_write_cb = grpc_write_cb - + """Performs write to network socket in AsyncIO. + + For each socket, C-Core guarantees there'll be only one ongoing write. + When the write is finished, we need to call grpc_write_cb to notify + C-Core that the work is done. + """ cdef char* start - buffer = bytearray() + cdef bytearray outbound_buffer = bytearray() for i in range(g_slice_buffer.count): start = grpc_slice_buffer_start(g_slice_buffer, i) length = grpc_slice_buffer_length(g_slice_buffer, i) - buffer.extend(start[:length]) + outbound_buffer.extend(start[:length]) - self._writer.write(buffer) + self._writer.write(outbound_buffer) grpc_write_cb( self._grpc_socket, 0 diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi index 62add7f33d1..a4414b8cfbe 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pxd.pxi @@ -16,13 +16,13 @@ cdef class AioRpcStatus(Exception): cdef readonly: - int _code + grpc_status_code _code str _details # On spec, only client-side status has trailing metadata. tuple _trailing_metadata str _debug_error_string - cpdef int code(self) + cpdef grpc_status_code code(self) cpdef str details(self) cpdef tuple trailing_metadata(self) cpdef str debug_error_string(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi index 9784db19a1f..07669fc1575 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/rpc_status.pyx.pxi @@ -19,7 +19,7 @@ cdef class AioRpcStatus(Exception): # The final status of gRPC is represented by three trailing metadata: # `grpc-status`, `grpc-status-message`, abd `grpc-status-details`. def __cinit__(self, - int code, + grpc_status_code code, str details, tuple trailing_metadata, str debug_error_string): @@ -28,7 +28,7 @@ cdef class AioRpcStatus(Exception): self._trailing_metadata = trailing_metadata self._debug_error_string = debug_error_string - cpdef int code(self): + cpdef grpc_status_code code(self): return self._code cpdef str details(self): diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 0e7334501e7..a62b8d61c7a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -246,7 +246,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ - _loop: asyncio.AbstractEventLoop _request: RequestType _deadline: Optional[float] _channel: cygrpc.AioChannel @@ -260,7 +259,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): request_serializer: SerializingFunction, response_deserializer: DeserializingFunction) -> None: super().__init__() - self._loop = asyncio.get_event_loop() self._request = request self._deadline = deadline self._channel = channel @@ -330,28 +328,26 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): Returned when an instance of `UnaryStreamMultiCallable` object is called. """ - _loop: asyncio.AbstractEventLoop _request: RequestType _deadline: Optional[float] _channel: cygrpc.AioChannel _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - _call: AsyncIterable[ResponseType] + _aiter: AsyncIterable[ResponseType] def __init__(self, request: RequestType, deadline: Optional[float], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction) -> None: super().__init__() - self._loop = asyncio.get_event_loop() self._request = request self._deadline = deadline self._channel = channel self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer - self._call = self._invoke() + self._aiter = self._invoke() def __del__(self) -> None: if not self._status.done(): @@ -406,10 +402,10 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): _LOCAL_CANCELLATION_DETAILS, None, None)) def __aiter__(self) -> AsyncIterable[ResponseType]: - return self._call + return self._aiter async def read(self) -> ResponseType: if self._status.done(): await self._raise_rpc_error_if_not_ok() raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - return await self._call.__anext__() + return await self._aiter.__anext__() diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 39b4d93de43..1a46c9d5e69 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -1,7 +1,8 @@ [ "_sanity._sanity_test.AioSanityTest", "unit.aio_rpc_error_test.TestAioRpcError", - "unit.call_test.TestCall", + "unit.call_test.TestUnaryUnaryCall", + "unit.call_test.TestUnaryStreamCall", "unit.channel_test.TestChannel", "unit.init_test.TestInsecureChannel", "unit.server_test.TestServer" diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 81480b62180..44477c36159 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -30,11 +30,10 @@ from tests_aio.unit._test_base import AioTestBase _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' -# _RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 -_RESPONSE_INTERVAL_US = 200 * 1000 +_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000 -class TestCall(AioTestBase): +class TestUnaryUnaryCall(AioTestBase): async def setUp(self): self._server_target, self._server = await start_test_server() @@ -141,6 +140,15 @@ class TestCall(AioTestBase): # so we might not want to use it to transmit data. # https://github.com/python/cpython/blob/master/Lib/asyncio/tasks.py#L785 + +class TestUnaryStreamCall(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + async def test_cancel_unary_stream(self): async with aio.insecure_channel(self._server_target) as channel: stub = test_pb2_grpc.TestServiceStub(channel)