Use a poller thread to replace custom IO manager

pull/22258/head
Lidi Zheng 5 years ago
parent d9c55675c4
commit e00f8b3492
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  2. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 25
      src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi
  4. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/socket.pyx.pxi
  5. 11
      src/python/grpcio/grpc/_cython/_cygrpc/aio/poller.pxd.pxi
  6. 89
      src/python/grpcio/grpc/_cython/_cygrpc/aio/poller.pyx.pxi
  7. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  8. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  9. 1
      src/python/grpcio/grpc/_cython/cygrpc.pxd
  10. 1
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  11. 564
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -21,7 +21,8 @@ cdef enum AioChannelStatus:
cdef class AioChannel: cdef class AioChannel:
cdef: cdef:
grpc_channel * channel grpc_channel * channel
CallbackCompletionQueue cq # CallbackCompletionQueue cq
BackgroundCompletionQueue cq
object loop object loop
bytes _target bytes _target
AioChannelStatus _status AioChannelStatus _status

@ -31,7 +31,8 @@ cdef class AioChannel:
options = () options = ()
cdef _ChannelArgs channel_args = _ChannelArgs(options) cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target self._target = target
self.cq = CallbackCompletionQueue() # self.cq = CallbackCompletionQueue()
self.cq = BackgroundCompletionQueue()
self.loop = loop self.loop = loop
self._status = AIO_CHANNEL_STATUS_READY self._status = AIO_CHANNEL_STATUS_READY

@ -19,22 +19,25 @@ cdef bint _grpc_aio_initialized = False
# we should support this use case. So, the gRPC Python Async Stack should use # we should support this use case. So, the gRPC Python Async Stack should use
# a single event loop picked by "init_grpc_aio". # a single event loop picked by "init_grpc_aio".
cdef object _grpc_aio_loop cdef object _grpc_aio_loop
cdef object _event_loop_thread_ident
def init_grpc_aio(): def init_grpc_aio():
global _grpc_aio_initialized global _grpc_aio_initialized
global _grpc_aio_loop global _grpc_aio_loop
global _event_loop_thread_ident
if _grpc_aio_initialized: if _grpc_aio_initialized:
return return
else: else:
_grpc_aio_initialized = True _grpc_aio_initialized = True
_event_loop_thread_ident = threading.current_thread().ident
# Anchors the event loop that the gRPC library going to use. # Anchors the event loop that the gRPC library going to use.
_grpc_aio_loop = asyncio.get_event_loop() _grpc_aio_loop = asyncio.get_event_loop()
# Activates asyncio IO manager # Activates asyncio IO manager
install_asyncio_iomgr() # install_asyncio_iomgr()
# TODO(https://github.com/grpc/grpc/issues/22244) we need a the # TODO(https://github.com/grpc/grpc/issues/22244) we need a the
# grpc_shutdown_blocking() counterpart for this call. Otherwise, the gRPC # grpc_shutdown_blocking() counterpart for this call. Otherwise, the gRPC
@ -44,11 +47,11 @@ def init_grpc_aio():
# Timers are triggered by the Asyncio loop. We disable # Timers are triggered by the Asyncio loop. We disable
# the background thread that is being used by the native # the background thread that is being used by the native
# gRPC iomgr. # gRPC iomgr.
grpc_timer_manager_set_threading(False) # grpc_timer_manager_set_threading(False)
# gRPC callbaks are executed within the same thread used by the Asyncio # gRPC callbaks are executed within the same thread used by the Asyncio
# event loop, as it is being done by the other Asyncio callbacks. # event loop, as it is being done by the other Asyncio callbacks.
Executor.SetThreadingAll(False) # Executor.SetThreadingAll(False)
_grpc_aio_initialized = False _grpc_aio_initialized = False
@ -56,3 +59,19 @@ def init_grpc_aio():
def grpc_aio_loop(): def grpc_aio_loop():
"""Returns the one-and-only gRPC Aio event loop.""" """Returns the one-and-only gRPC Aio event loop."""
return _grpc_aio_loop return _grpc_aio_loop
cdef grpc_schedule_coroutine(object coro):
"""Thread-safely schedules coroutine to gRPC Aio event loop.
If invoked within the same thread as the event loop, return an
Asyncio.Task. Otherwise, return a concurrent.futures.Future (the sync
Future). For non-asyncio threads, sync Future objects are probably easier
to handle (without worrying other thread-safety stuff).
"""
assert _event_loop_thread_ident != threading.current_thread().ident
return asyncio.run_coroutine_threadsafe(coro, _grpc_aio_loop)
def grpc_call_soon_threadsafe(object func, *args):
return _grpc_aio_loop.call_soon_threadsafe(func, *args)

@ -159,6 +159,7 @@ cdef class _AsyncioSocket:
return self._reader and not self._reader._transport.is_closing() return self._reader and not self._reader._transport.is_closing()
cdef void close(self): cdef void close(self):
_LOGGER.debug('closed!')
if self.is_connected(): if self.is_connected():
self._writer.close() self._writer.close()
if self._server: if self._server:
@ -196,7 +197,9 @@ cdef class _AsyncioSocket:
self._new_connection_callback, self._new_connection_callback,
sock=self._py_socket, sock=self._py_socket,
) )
_LOGGER.debug('start listen')
_LOGGER.debug('want to listen')
grpc_aio_loop().create_task(create_asyncio_server()) grpc_aio_loop().create_task(create_asyncio_server())
cdef accept(self, cdef accept(self,

@ -0,0 +1,11 @@
cdef gpr_timespec _GPR_INF_FUTURE = gpr_inf_future(GPR_CLOCK_REALTIME)
cdef class BackgroundCompletionQueue:
cdef grpc_completion_queue *_cq
cdef bint _shutdown
cdef object _shutdown_completed
cdef object _poller
cdef object _poller_running
cdef _polling(self)
cdef grpc_completion_queue* c_ptr(self)

@ -0,0 +1,89 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef gpr_timespec _GPR_INF_FUTURE = gpr_inf_future(GPR_CLOCK_REALTIME)
def _handle_callback_wrapper(CallbackWrapper callback_wrapper, int success):
try:
CallbackWrapper.functor_run(callback_wrapper.c_functor(), success)
_LOGGER.debug('_handle_callback_wrapper Done')
except Exception as e:
_LOGGER.debug('_handle_callback_wrapper EXP')
_LOGGER.exception(e)
raise
cdef class BackgroundCompletionQueue:
def __cinit__(self):
self._cq = grpc_completion_queue_create_for_next(NULL)
self._shutdown = False
self._shutdown_completed = asyncio.get_event_loop().create_future()
self._poller = None
self._poller_running = asyncio.get_event_loop().create_future()
# asyncio.get_event_loop().create_task(self._start_poller())
self._poller = threading.Thread(target=self._polling_wrapper)
self._poller.daemon = True
self._poller.start()
# async def _start_poller(self):
# if self._poller:
# raise UsageError('Poller can only be started once.')
# self._poller = threading.Thread(target=self._polling_wrapper)
# self._poller.daemon = True
# self._poller.start()
# await self._poller_running
cdef _polling(self):
cdef grpc_event event
cdef CallbackContext *context
cdef object waiter
grpc_call_soon_threadsafe(self._poller_running.set_result, None)
while not self._shutdown:
_LOGGER.debug('BackgroundCompletionQueue polling')
with nogil:
event = grpc_completion_queue_next(self._cq,
_GPR_INF_FUTURE,
NULL)
_LOGGER.debug('BackgroundCompletionQueue polling 1')
if event.type == GRPC_QUEUE_TIMEOUT:
_LOGGER.debug('BackgroundCompletionQueue timeout???')
raise NotImplementedError()
elif event.type == GRPC_QUEUE_SHUTDOWN:
_LOGGER.debug('BackgroundCompletionQueue shutdown!')
self._shutdown = True
grpc_call_soon_threadsafe(self._shutdown_completed.set_result, None)
else:
_LOGGER.debug('BackgroundCompletionQueue event! %d', event.success)
context = <CallbackContext *>event.tag
grpc_call_soon_threadsafe(
_handle_callback_wrapper,
<CallbackWrapper>context.callback_wrapper,
event.success)
_LOGGER.debug('BackgroundCompletionQueue polling 2')
def _polling_wrapper(self):
self._polling()
async def shutdown(self):
grpc_completion_queue_shutdown(self._cq)
await self._shutdown_completed
grpc_completion_queue_destroy(self._cq)
cdef grpc_completion_queue* c_ptr(self):
return self._cq

@ -51,7 +51,8 @@ cdef enum AioServerStatus:
cdef class AioServer: cdef class AioServer:
cdef Server _server cdef Server _server
cdef CallbackCompletionQueue _cq # cdef CallbackCompletionQueue _cq
cdef BackgroundCompletionQueue _cq
cdef list _generic_handlers cdef list _generic_handlers
cdef AioServerStatus _status cdef AioServerStatus _status
cdef object _loop # asyncio.EventLoop cdef object _loop # asyncio.EventLoop

@ -613,7 +613,8 @@ cdef class AioServer:
# NOTE(lidiz) Core objects won't be deallocated automatically. # NOTE(lidiz) Core objects won't be deallocated automatically.
# If AioServer.shutdown is not called, those objects will leak. # If AioServer.shutdown is not called, those objects will leak.
self._server = Server(options) self._server = Server(options)
self._cq = CallbackCompletionQueue() # self._cq = CallbackCompletionQueue()
self._cq = BackgroundCompletionQueue()
grpc_server_register_completion_queue( grpc_server_register_completion_queue(
self._server.c_server, self._server.c_server,
self._cq.c_ptr(), self._cq.c_ptr(),

@ -45,6 +45,7 @@ IF UNAME_SYSNAME != "Windows":
include "_cygrpc/aio/iomgr/socket.pxd.pxi" include "_cygrpc/aio/iomgr/socket.pxd.pxi"
include "_cygrpc/aio/iomgr/timer.pxd.pxi" include "_cygrpc/aio/iomgr/timer.pxd.pxi"
include "_cygrpc/aio/iomgr/resolver.pxd.pxi" include "_cygrpc/aio/iomgr/resolver.pxd.pxi"
include "_cygrpc/aio/poller.pxd.pxi"
include "_cygrpc/aio/rpc_status.pxd.pxi" include "_cygrpc/aio/rpc_status.pxd.pxi"
include "_cygrpc/aio/grpc_aio.pxd.pxi" include "_cygrpc/aio/grpc_aio.pxd.pxi"
include "_cygrpc/aio/callback_common.pxd.pxi" include "_cygrpc/aio/callback_common.pxd.pxi"

@ -72,6 +72,7 @@ include "_cygrpc/aio/iomgr/resolver.pyx.pxi"
include "_cygrpc/aio/common.pyx.pxi" include "_cygrpc/aio/common.pyx.pxi"
include "_cygrpc/aio/rpc_status.pyx.pxi" include "_cygrpc/aio/rpc_status.pyx.pxi"
include "_cygrpc/aio/callback_common.pyx.pxi" include "_cygrpc/aio/callback_common.pyx.pxi"
include "_cygrpc/aio/poller.pyx.pxi"
include "_cygrpc/aio/grpc_aio.pyx.pxi" include "_cygrpc/aio/grpc_aio.pyx.pxi"
include "_cygrpc/aio/call.pyx.pxi" include "_cygrpc/aio/call.pyx.pxi"
include "_cygrpc/aio/channel.pyx.pxi" include "_cygrpc/aio/channel.pyx.pxi"

@ -47,338 +47,338 @@ class _MulticallableTestMixin():
await self._server.stop(None) await self._server.stop(None)
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): # class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_to_string(self): # async def test_call_to_string(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertTrue(str(call) is not None) # self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None) # self.assertTrue(repr(call) is not None)
response = await call # response = await call
self.assertTrue(str(call) is not None) # self.assertTrue(str(call) is not None)
self.assertTrue(repr(call) is not None) # self.assertTrue(repr(call) is not None)
async def test_call_ok(self): # async def test_call_ok(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertFalse(call.done()) # self.assertFalse(call.done())
response = await call # response = await call
self.assertTrue(call.done()) # self.assertTrue(call.done())
self.assertIsInstance(response, messages_pb2.SimpleResponse) # self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK) # self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Response is cached at call object level, reentrance # # Response is cached at call object level, reentrance
# returns again the same response # # returns again the same response
response_retry = await call # response_retry = await call
self.assertIs(response, response_retry) # self.assertIs(response, response_retry)
async def test_call_rpc_error(self): # async def test_call_rpc_error(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: # async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
stub = test_pb2_grpc.TestServiceStub(channel) # stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.UnaryCall(messages_pb2.SimpleRequest()) # call = stub.UnaryCall(messages_pb2.SimpleRequest())
with self.assertRaises(aio.AioRpcError) as exception_context: # with self.assertRaises(aio.AioRpcError) as exception_context:
await call # await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE, # self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code()) # exception_context.exception.code())
self.assertTrue(call.done()) # self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) # self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
async def test_call_code_awaitable(self): # async def test_call_code_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK) # self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_details_awaitable(self): # async def test_call_details_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details()) # self.assertEqual('', await call.details())
async def test_call_initial_metadata_awaitable(self): # async def test_call_initial_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata()) # self.assertEqual((), await call.initial_metadata())
async def test_call_trailing_metadata_awaitable(self): # async def test_call_trailing_metadata_awaitable(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata()) # self.assertEqual((), await call.trailing_metadata())
async def test_call_initial_metadata_cancelable(self): # async def test_call_initial_metadata_cancelable(self):
coro_started = asyncio.Event() # coro_started = asyncio.Event()
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro(): # async def coro():
coro_started.set() # coro_started.set()
await call.initial_metadata() # await call.initial_metadata()
task = self.loop.create_task(coro()) # task = self.loop.create_task(coro())
await coro_started.wait() # await coro_started.wait()
task.cancel() # task.cancel()
# Test that initial metadata can still be asked thought # # Test that initial metadata can still be asked thought
# a cancellation happened with the previous task # # a cancellation happened with the previous task
self.assertEqual((), await call.initial_metadata()) # self.assertEqual((), await call.initial_metadata())
async def test_call_initial_metadata_multiple_waiters(self): # async def test_call_initial_metadata_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro(): # async def coro():
return await call.initial_metadata() # return await call.initial_metadata()
task1 = self.loop.create_task(coro()) # task1 = self.loop.create_task(coro())
task2 = self.loop.create_task(coro()) # task2 = self.loop.create_task(coro())
await call # await call
self.assertEqual([(), ()], await asyncio.gather(*[task1, task2])) # self.assertEqual([(), ()], await asyncio.gather(*[task1, task2]))
async def test_call_code_cancelable(self): # async def test_call_code_cancelable(self):
coro_started = asyncio.Event() # coro_started = asyncio.Event()
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro(): # async def coro():
coro_started.set() # coro_started.set()
await call.code() # await call.code()
task = self.loop.create_task(coro()) # task = self.loop.create_task(coro())
await coro_started.wait() # await coro_started.wait()
task.cancel() # task.cancel()
# Test that code can still be asked thought # # Test that code can still be asked thought
# a cancellation happened with the previous task # # a cancellation happened with the previous task
self.assertEqual(grpc.StatusCode.OK, await call.code()) # self.assertEqual(grpc.StatusCode.OK, await call.code())
async def test_call_code_multiple_waiters(self): # async def test_call_code_multiple_waiters(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
async def coro(): # async def coro():
return await call.code() # return await call.code()
task1 = self.loop.create_task(coro()) # task1 = self.loop.create_task(coro())
task2 = self.loop.create_task(coro()) # task2 = self.loop.create_task(coro())
await call # await call
self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await # self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await
asyncio.gather(task1, task2)) # asyncio.gather(task1, task2))
async def test_cancel_unary_unary(self): # async def test_cancel_unary_unary(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) # call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled()) # self.assertFalse(call.cancelled())
self.assertTrue(call.cancel()) # self.assertTrue(call.cancel())
self.assertFalse(call.cancel()) # self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError): # with self.assertRaises(asyncio.CancelledError):
await call # await call
# The info in the RpcError should match the info in Call object. # # The info in the RpcError should match the info in Call object.
self.assertTrue(call.cancelled()) # self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) # self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(), # self.assertEqual(await call.details(),
'Locally cancelled by application!') # 'Locally cancelled by application!')
async def test_cancel_unary_unary_in_task(self): # async def test_cancel_unary_unary_in_task(self):
coro_started = asyncio.Event() # coro_started = asyncio.Event()
call = self._stub.EmptyCall(messages_pb2.SimpleRequest()) # call = self._stub.EmptyCall(messages_pb2.SimpleRequest())
async def another_coro(): # async def another_coro():
coro_started.set() # coro_started.set()
await call # await call
task = self.loop.create_task(another_coro()) # task = self.loop.create_task(another_coro())
await coro_started.wait() # await coro_started.wait()
self.assertFalse(task.done()) # self.assertFalse(task.done())
task.cancel() # task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) # self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError): # with self.assertRaises(asyncio.CancelledError):
await task # await task
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
async def test_cancel_unary_stream(self): # async def test_cancel_unary_stream(self):
# Prepares the request # # Prepares the request
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters( # messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE, # size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US, # interval_us=_RESPONSE_INTERVAL_US,
)) # ))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled()) # self.assertFalse(call.cancelled())
response = await call.read() # response = await call.read()
self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse) # self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel()) # self.assertTrue(call.cancel())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) # self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await # self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details()) # call.details())
self.assertFalse(call.cancel()) # self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError): # with self.assertRaises(asyncio.CancelledError):
await call.read() # await call.read()
self.assertTrue(call.cancelled()) # self.assertTrue(call.cancelled())
async def test_multiple_cancel_unary_stream(self): # async def test_multiple_cancel_unary_stream(self):
# Prepares the request # # Prepares the request
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters( # messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE, # size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US, # interval_us=_RESPONSE_INTERVAL_US,
)) # ))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled()) # self.assertFalse(call.cancelled())
response = await call.read() # response = await call.read()
self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse) # self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel()) # self.assertTrue(call.cancel())
self.assertFalse(call.cancel()) # self.assertFalse(call.cancel())
self.assertFalse(call.cancel()) # self.assertFalse(call.cancel())
self.assertFalse(call.cancel()) # self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError): # with self.assertRaises(asyncio.CancelledError):
await call.read() # await call.read()
async def test_early_cancel_unary_stream(self): # async def test_early_cancel_unary_stream(self):
"""Test cancellation before receiving messages.""" # """Test cancellation before receiving messages."""
# Prepares the request # # Prepares the request
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters( # messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE, # size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US, # interval_us=_RESPONSE_INTERVAL_US,
)) # ))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled()) # self.assertFalse(call.cancelled())
self.assertTrue(call.cancel()) # self.assertTrue(call.cancel())
self.assertFalse(call.cancel()) # self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError): # with self.assertRaises(asyncio.CancelledError):
await call.read() # await call.read()
self.assertTrue(call.cancelled()) # self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) # self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await # self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details()) # call.details())
async def test_late_cancel_unary_stream(self): # async def test_late_cancel_unary_stream(self):
"""Test cancellation after received all messages.""" # """Test cancellation after received all messages."""
# Prepares the request # # Prepares the request
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) # messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read() # response = await call.read()
self.assertIs(type(response), # self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse) # messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# After all messages received, it is possible that the final state # # After all messages received, it is possible that the final state
# is received or on its way. It's basically a data race, so our # # is received or on its way. It's basically a data race, so our
# expectation here is do not crash :) # # expectation here is do not crash :)
call.cancel() # call.cancel()
self.assertIn(await call.code(), # self.assertIn(await call.code(),
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]) # [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self): # async def test_too_many_reads_unary_stream(self):
"""Test calling read after received all messages fails.""" # """Test calling read after received all messages fails."""
# Prepares the request # # Prepares the request
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) # messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read() # response = await call.read()
self.assertIs(type(response), # self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse) # messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertIs(await call.read(), aio.EOF) # self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception. # # After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK) # self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertIs(await call.read(), aio.EOF) # self.assertIs(await call.read(), aio.EOF)
async def test_unary_stream_async_generator(self): # async def test_unary_stream_async_generator(self):
"""Sunny day test case for unary_stream.""" # """Sunny day test case for unary_stream."""
# Prepares the request # # Prepares the request
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): # for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) # messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled()) # self.assertFalse(call.cancelled())
async for response in call: # async for response in call:
self.assertIs(type(response), # self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse) # messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK) # self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_unary_stream_in_task_using_read(self): # async def test_cancel_unary_stream_in_task_using_read(self):
coro_started = asyncio.Event() # coro_started = asyncio.Event()
# Configs the server method to block forever # # Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest() # request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append( # request.response_parameters.append(
messages_pb2.ResponseParameters( # messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE, # size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US, # interval_us=_INFINITE_INTERVAL_US,
)) # ))
# Invokes the actual RPC # # Invokes the actual RPC
call = self._stub.StreamingOutputCall(request) # call = self._stub.StreamingOutputCall(request)
async def another_coro(): # async def another_coro():
coro_started.set() # coro_started.set()
await call.read() # await call.read()
task = self.loop.create_task(another_coro()) # task = self.loop.create_task(another_coro())
await coro_started.wait() # await coro_started.wait()
self.assertFalse(task.done()) # self.assertFalse(task.done())
task.cancel() # task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) # self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError): # with self.assertRaises(asyncio.CancelledError):
await task # await task
async def test_cancel_unary_stream_in_task_using_async_for(self): async def test_cancel_unary_stream_in_task_using_async_for(self):
coro_started = asyncio.Event() coro_started = asyncio.Event()
@ -755,5 +755,5 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2) unittest.main(verbosity=2)

Loading…
Cancel
Save