|
|
|
@ -16,6 +16,7 @@ |
|
|
|
|
import asyncio |
|
|
|
|
from functools import partial |
|
|
|
|
import logging |
|
|
|
|
import enum |
|
|
|
|
from typing import AsyncIterable, Awaitable, Dict, Optional |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
@ -238,6 +239,12 @@ class Call: |
|
|
|
|
return self._repr() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _APIStyle(enum.IntEnum): |
|
|
|
|
UNKNOWN = 0 |
|
|
|
|
ASYNC_GENERATOR = 1 |
|
|
|
|
READER_WRITER = 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _UnaryResponseMixin(Call): |
|
|
|
|
_call_response: asyncio.Task |
|
|
|
|
|
|
|
|
@ -283,10 +290,19 @@ class _UnaryResponseMixin(Call): |
|
|
|
|
class _StreamResponseMixin(Call): |
|
|
|
|
_message_aiter: AsyncIterable[ResponseType] |
|
|
|
|
_preparation: asyncio.Task |
|
|
|
|
_response_style: _APIStyle |
|
|
|
|
|
|
|
|
|
def _init_stream_response_mixin(self, preparation: asyncio.Task): |
|
|
|
|
self._message_aiter = None |
|
|
|
|
self._preparation = preparation |
|
|
|
|
self._response_style = _APIStyle.UNKNOWN |
|
|
|
|
|
|
|
|
|
def _update_response_style(self, style: _APIStyle): |
|
|
|
|
if self._response_style is _APIStyle.UNKNOWN: |
|
|
|
|
self._response_style = style |
|
|
|
|
elif self._response_style is not style: |
|
|
|
|
raise cygrpc.UsageError( |
|
|
|
|
'Please don\'t mix two styles of API for streaming responses') |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
@ -302,6 +318,7 @@ class _StreamResponseMixin(Call): |
|
|
|
|
message = await self._read() |
|
|
|
|
|
|
|
|
|
def __aiter__(self) -> AsyncIterable[ResponseType]: |
|
|
|
|
self._update_response_style(_APIStyle.ASYNC_GENERATOR) |
|
|
|
|
if self._message_aiter is None: |
|
|
|
|
self._message_aiter = self._fetch_stream_responses() |
|
|
|
|
return self._message_aiter |
|
|
|
@ -328,6 +345,7 @@ class _StreamResponseMixin(Call): |
|
|
|
|
if self.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
self._update_response_style(_APIStyle.READER_WRITER) |
|
|
|
|
|
|
|
|
|
response_message = await self._read() |
|
|
|
|
|
|
|
|
@ -339,20 +357,28 @@ class _StreamResponseMixin(Call): |
|
|
|
|
|
|
|
|
|
class _StreamRequestMixin(Call): |
|
|
|
|
_metadata_sent: asyncio.Event |
|
|
|
|
_done_writing: bool |
|
|
|
|
_done_writing_flag: bool |
|
|
|
|
_async_request_poller: Optional[asyncio.Task] |
|
|
|
|
_request_style: _APIStyle |
|
|
|
|
|
|
|
|
|
def _init_stream_request_mixin( |
|
|
|
|
self, request_async_iterator: Optional[AsyncIterable[RequestType]]): |
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
self._done_writing_flag = False |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer Task. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
self._async_request_poller = self._loop.create_task( |
|
|
|
|
self._consume_request_iterator(request_async_iterator)) |
|
|
|
|
self._request_style = _APIStyle.ASYNC_GENERATOR |
|
|
|
|
else: |
|
|
|
|
self._async_request_poller = None |
|
|
|
|
self._request_style = _APIStyle.READER_WRITER |
|
|
|
|
|
|
|
|
|
def _raise_for_different_style(self, style: _APIStyle): |
|
|
|
|
if self._request_style is not style: |
|
|
|
|
raise cygrpc.UsageError( |
|
|
|
|
'Please don\'t mix two styles of API for streaming requests') |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
@ -369,8 +395,8 @@ class _StreamRequestMixin(Call): |
|
|
|
|
self, request_async_iterator: AsyncIterable[RequestType]) -> None: |
|
|
|
|
try: |
|
|
|
|
async for request in request_async_iterator: |
|
|
|
|
await self.write(request) |
|
|
|
|
await self.done_writing() |
|
|
|
|
await self._write(request) |
|
|
|
|
await self._done_writing() |
|
|
|
|
except AioRpcError as rpc_error: |
|
|
|
|
# Rpc status should be exposed through other API. Exceptions raised |
|
|
|
|
# within this Task won't be retrieved by another coroutine. It's |
|
|
|
@ -378,10 +404,10 @@ class _StreamRequestMixin(Call): |
|
|
|
|
_LOGGER.debug('Exception while consuming the request_iterator: %s', |
|
|
|
|
rpc_error) |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
async def _write(self, request: RequestType) -> None: |
|
|
|
|
if self.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
if self._done_writing_flag: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
|
if not self._metadata_sent.is_set(): |
|
|
|
|
await self._metadata_sent.wait() |
|
|
|
@ -398,14 +424,13 @@ class _StreamRequestMixin(Call): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
async def _done_writing(self) -> None: |
|
|
|
|
if self.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
|
if not self._done_writing_flag: |
|
|
|
|
# If the done writing is not sent before, try to send it. |
|
|
|
|
self._done_writing = True |
|
|
|
|
self._done_writing_flag = True |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_receive_close() |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
@ -413,6 +438,18 @@ class _StreamRequestMixin(Call): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
self._raise_for_different_style(_APIStyle.READER_WRITER) |
|
|
|
|
await self._write(request) |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Signal peer that client is done writing. |
|
|
|
|
|
|
|
|
|
This method is idempotent. |
|
|
|
|
""" |
|
|
|
|
self._raise_for_different_style(_APIStyle.READER_WRITER) |
|
|
|
|
await self._done_writing() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): |
|
|
|
|
"""Object for managing unary-unary RPC calls. |
|
|
|
|