Merge pull request #21681 from lidizheng/aio-callbacks

[Aio] Implement add_done_callback and time_remaining
pull/21689/head
Lidi Zheng 5 years ago committed by GitHub
commit b9083a9edb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 10
      src/python/grpcio/grpc/experimental/aio/_base_call.py
  4. 30
      src/python/grpcio/grpc/experimental/aio/_call.py
  5. 1
      src/python/grpcio/grpc/experimental/aio/_typing.py
  6. 1
      src/python/grpcio_tests/tests_aio/tests.json
  7. 529
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  8. 160
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py

@ -28,4 +28,6 @@ cdef class _AioCall(GrpcCallWrapper):
# because Core is holding a pointer for the callback handler.
bint _is_locally_cancelled
object _deadline
cdef void _create_grpc_call(self, object timeout, bytes method, CallCredentials credentials) except *

@ -36,6 +36,7 @@ cdef class _AioCall(GrpcCallWrapper):
self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method, call_credentials)
self._is_locally_cancelled = False
self._deadline = deadline
def __dealloc__(self):
if self.call:
@ -84,6 +85,12 @@ cdef class _AioCall(GrpcCallWrapper):
grpc_slice_unref(method_slice)
def time_remaining(self):
if self._deadline is None:
return None
else:
return max(0, self._deadline - time.time())
def cancel(self, AioRpcStatus status):
"""Cancels the RPC in Core with given RPC status.

@ -19,12 +19,12 @@ RPC, e.g. cancellation.
"""
from abc import ABCMeta, abstractmethod
from typing import (Any, AsyncIterable, Awaitable, Callable, Generic, Optional,
Text, Union)
from typing import AsyncIterable, Awaitable, Generic, Optional, Text, Union
import grpc
from ._typing import EOFType, MetadataType, RequestType, ResponseType
from ._typing import (DoneCallbackType, EOFType, MetadataType, RequestType,
ResponseType)
__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -73,11 +73,11 @@ class RpcContext(metaclass=ABCMeta):
"""
@abstractmethod
def add_done_callback(self, callback: Callable[[Any], None]) -> None:
def add_done_callback(self, callback: DoneCallbackType) -> None:
"""Registers a callback to be called on RPC termination.
Args:
callback: A callable object will be called with the context object as
callback: A callable object will be called with the call object as
its only argument.
"""

@ -14,7 +14,7 @@
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import AsyncIterable, Awaitable, Dict, Optional
from typing import AsyncIterable, Awaitable, List, Dict, Optional
import grpc
from grpc import _common
@ -22,7 +22,7 @@ from grpc._cython import cygrpc
from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestType,
ResponseType, SerializingFunction)
ResponseType, SerializingFunction, DoneCallbackType)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -157,6 +157,7 @@ class Call(_base_call.Call):
_initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool
_cython_call: cygrpc._AioCall
_done_callbacks: List[DoneCallbackType]
def __init__(self, cython_call: cygrpc._AioCall) -> None:
self._loop = asyncio.get_event_loop()
@ -165,6 +166,7 @@ class Call(_base_call.Call):
self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False
self._cython_call = cython_call
self._done_callbacks = []
def __del__(self) -> None:
if not self._status.done():
@ -192,11 +194,14 @@ class Call(_base_call.Call):
def done(self) -> bool:
return self._status.done()
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def add_done_callback(self, callback: DoneCallbackType) -> None:
if self.done():
callback(self)
else:
self._done_callbacks.append(callback)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
return self._cython_call.time_remaining()
async def initial_metadata(self) -> MetadataType:
return await self._initial_metadata
@ -220,9 +225,7 @@ class Call(_base_call.Call):
def _set_status(self, status: cygrpc.AioRpcStatus) -> None:
"""Private method to set final status of the RPC.
This method may be called multiple time due to data race between local
cancellation (by application) and Core receiving status from peer. We
make no promise here which one will win.
This method should only be invoked once.
"""
# In case of local cancellation, flip the flag.
if status.details() is _LOCAL_CANCELLATION_DETAILS:
@ -236,6 +239,9 @@ class Call(_base_call.Call):
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
for callback in self._done_callbacks:
callback(self)
async def _raise_for_status(self) -> None:
if self._locally_cancelled:
raise asyncio.CancelledError()
@ -265,8 +271,6 @@ class Call(_base_call.Call):
return self._repr()
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
@ -338,8 +342,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return response
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
@ -429,8 +431,6 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
return response_message
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
@ -550,8 +550,6 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
await self._raise_for_status()
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression
# pylint: disable=abstract-method
class StreamStreamCall(Call, _base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.

@ -24,3 +24,4 @@ MetadatumType = Tuple[Text, AnyStr]
MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[Text, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]

@ -9,6 +9,7 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_test.TestChannel",
"unit.connectivity_test.TestConnectivityState",
"unit.done_callback_test.TestDoneCallback",
"unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",

@ -14,19 +14,17 @@
"""Tests behavior of the grpc.aio.UnaryUnaryCall class."""
import asyncio
import datetime
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
@ -37,44 +35,41 @@ _UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
class TestUnaryUnaryCall(AioTestBase):
class _MulticallableTestMixin():
async def setUp(self):
self._server_target, self._server = await start_test_server()
address, self._server = await start_test_server()
self._channel = aio.insecure_channel(address)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_call_ok(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertFalse(call.done())
self.assertFalse(call.done())
response = await call
response = await call
self.assertTrue(call.done())
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertTrue(call.done())
self.assertIsInstance(response, messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Response is cached at call object level, reentrance
# returns again the same response
response_retry = await call
self.assertIs(response, response_retry)
# Response is cached at call object level, reentrance
# returns again the same response
response_retry = await call
self.assertIs(response, response_retry)
async def test_call_rpc_error(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
stub = test_pb2_grpc.TestServiceStub(channel)
call = hi(messages_pb2.SimpleRequest(), timeout=0.1)
call = stub.UnaryCall(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
@ -95,327 +90,264 @@ class TestUnaryUnaryCall(AioTestBase):
exception_context_retry.exception)
async def test_call_code_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_call_details_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details())
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual('', await call.details())
async def test_call_initial_metadata_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata())
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.initial_metadata())
async def test_call_trailing_metadata_awaitable(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual((), await call.trailing_metadata())
async def test_cancel_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = hi(messages_pb2.SimpleRequest())
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
with self.assertRaises(asyncio.CancelledError):
await call
# The info in the RpcError should match the info in Call object.
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
# The info in the RpcError should match the info in Call object.
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
async def test_cancel_unary_unary_in_task(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
call = stub.EmptyCall(messages_pb2.SimpleRequest())
async def another_coro():
coro_started.set()
await call
task = self.loop.create_task(another_coro())
await coro_started.wait()
coro_started = asyncio.Event()
call = self._stub.EmptyCall(messages_pb2.SimpleRequest())
self.assertFalse(task.done())
task.cancel()
async def another_coro():
coro_started.set()
await call
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
task = self.loop.create_task(another_coro())
await coro_started.wait()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
class TestUnaryStreamCall(AioTestBase):
with self.assertRaises(asyncio.CancelledError):
await task
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
async def test_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
response = await call.read()
self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
self.assertFalse(call.cancel())
self.assertTrue(call.cancel())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
async def test_multiple_cancel_unary_stream(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
response = await call.read()
self.assertIs(type(response), messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call.read()
with self.assertRaises(asyncio.CancelledError):
await call.read()
async def test_early_cancel_unary_stream(self):
"""Test cancellation before receiving messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call.read()
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
async def test_late_cancel_unary_stream(self):
"""Test cancellation after received all messages."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
# After all messages received, it is possible that the final state
# is received or on its way. It's basically a data race, so our
# expectation here is do not crash :)
call.cancel()
self.assertIn(await call.code(),
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
# After all messages received, it is possible that the final state
# is received or on its way. It's basically a data race, so our
# expectation here is do not crash :)
call.cancel()
self.assertIn(await call.code(),
[grpc.StatusCode.OK, grpc.StatusCode.CANCELLED])
async def test_too_many_reads_unary_stream(self):
"""Test calling read after received all messages fails."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
for _ in range(_NUM_STREAM_RESPONSES):
response = await call.read()
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertIs(await call.read(), aio.EOF)
# After the RPC is finished, further reads will lead to exception.
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertIs(await call.read(), aio.EOF)
async def test_unary_stream_async_generator(self):
"""Sunny day test case for unary_stream."""
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,))
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
async for response in call:
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
async for response in call:
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_unary_stream_in_task_using_read(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
coro_started = asyncio.Event()
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
await call.read()
async def another_coro():
coro_started.set()
await call.read()
task = self.loop.create_task(another_coro())
await coro_started.wait()
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
with self.assertRaises(asyncio.CancelledError):
await task
async def test_cancel_unary_stream_in_task_using_async_for(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
coro_started = asyncio.Event()
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
# Invokes the actual RPC
call = self._stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
async for _ in call:
pass
async def another_coro():
coro_started.set()
async for _ in call:
pass
task = self.loop.create_task(another_coro())
await coro_started.wait()
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(asyncio.CancelledError):
await task
with self.assertRaises(asyncio.CancelledError):
await task
def test_call_credentials(self):
@ -444,17 +376,41 @@ class TestUnaryStreamCall(AioTestBase):
self.loop.run_until_complete(coro())
async def test_time_remaining(self):
request = messages_pb2.StreamingOutputCallRequest()
# First message comes back immediately
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
# Second message comes back after a unit of wait time
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US,
))
call = self._stub.StreamingOutputCall(
request, timeout=test_constants.SHORT_TIMEOUT * 2)
class TestStreamUnaryCall(AioTestBase):
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
# Should be around the same as the timeout
remained_time = call.time_remaining()
self.assertGreater(remained_time, test_constants.SHORT_TIMEOUT * 3 // 2)
self.assertLess(remained_time, test_constants.SHORT_TIMEOUT * 2)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
response = await call.read()
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
# Should be around the timeout minus a unit of wait time
remained_time = call.time_remaining()
self.assertGreater(remained_time, test_constants.SHORT_TIMEOUT // 2)
self.assertLess(remained_time, test_constants.SHORT_TIMEOUT * 3 // 2)
self.assertEqual(grpc.StatusCode.OK, await call.code())
class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
async def test_cancel_stream_unary(self):
call = self._stub.StreamingInputCall()
@ -564,16 +520,7 @@ _STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
class TestStreamStreamCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
self._channel = aio.insecure_channel(self._server_target)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
async def test_cancel(self):
# Invokes the actual RPC

@ -0,0 +1,160 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing the done callbacks mechanism."""
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import time
import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._test_server import start_test_server
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(unused_call):
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(unused_call):
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()
class TestDoneCallback(AioTestBase):
async def setUp(self):
address, self._server = await start_test_server()
self._channel = aio.insecure_channel(address)
self._stub = test_pb2_grpc.TestServiceStub(self._channel)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_add_after_done(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(grpc.StatusCode.OK, await call.code())
validation = _inject_callbacks(call)
await validation
async def test_unary_unary(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
validation = _inject_callbacks(call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
async def test_unary_stream(self):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
call = self._stub.StreamingOutputCall(request)
validation = _inject_callbacks(call)
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
async def test_stream_unary(self):
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def gen():
for _ in range(_NUM_STREAM_RESPONSES):
yield request
call = self._stub.StreamingInputCall(gen())
validation = _inject_callbacks(call)
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
async def test_stream_stream(self):
call = self._stub.FullDuplexCall()
validation = _inject_callbacks(call)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(response,
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
await call.done_writing()
self.assertEqual(grpc.StatusCode.OK, await call.code())
await validation
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save