|
|
|
@ -15,6 +15,7 @@ |
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
|
import enum |
|
|
|
|
import inspect |
|
|
|
|
import logging |
|
|
|
|
from functools import partial |
|
|
|
|
from typing import AsyncIterable, Awaitable, Optional, Tuple |
|
|
|
@ -25,8 +26,8 @@ from grpc._cython import cygrpc |
|
|
|
|
|
|
|
|
|
from . import _base_call |
|
|
|
|
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, |
|
|
|
|
MetadatumType, RequestType, ResponseType, |
|
|
|
|
SerializingFunction) |
|
|
|
|
MetadatumType, RequestIterableType, RequestType, |
|
|
|
|
ResponseType, SerializingFunction) |
|
|
|
|
|
|
|
|
|
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' |
|
|
|
|
|
|
|
|
@ -363,14 +364,14 @@ class _StreamRequestMixin(Call): |
|
|
|
|
_request_style: _APIStyle |
|
|
|
|
|
|
|
|
|
def _init_stream_request_mixin( |
|
|
|
|
self, request_async_iterator: Optional[AsyncIterable[RequestType]]): |
|
|
|
|
self, request_iterator: Optional[RequestIterableType]): |
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._done_writing_flag = False |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer Task. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
if request_iterator is not None: |
|
|
|
|
self._async_request_poller = self._loop.create_task( |
|
|
|
|
self._consume_request_iterator(request_async_iterator)) |
|
|
|
|
self._consume_request_iterator(request_iterator)) |
|
|
|
|
self._request_style = _APIStyle.ASYNC_GENERATOR |
|
|
|
|
else: |
|
|
|
|
self._async_request_poller = None |
|
|
|
@ -392,11 +393,17 @@ class _StreamRequestMixin(Call): |
|
|
|
|
def _metadata_sent_observer(self): |
|
|
|
|
self._metadata_sent.set() |
|
|
|
|
|
|
|
|
|
async def _consume_request_iterator( |
|
|
|
|
self, request_async_iterator: AsyncIterable[RequestType]) -> None: |
|
|
|
|
async def _consume_request_iterator(self, |
|
|
|
|
request_iterator: RequestIterableType |
|
|
|
|
) -> None: |
|
|
|
|
try: |
|
|
|
|
async for request in request_async_iterator: |
|
|
|
|
if inspect.isasyncgen(request_iterator): |
|
|
|
|
async for request in request_iterator: |
|
|
|
|
await self._write(request) |
|
|
|
|
else: |
|
|
|
|
for request in request_iterator: |
|
|
|
|
await self._write(request) |
|
|
|
|
|
|
|
|
|
await self._done_writing() |
|
|
|
|
except AioRpcError as rpc_error: |
|
|
|
|
# Rpc status should be exposed through other API. Exceptions raised |
|
|
|
@ -538,8 +545,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
def __init__(self, request_iterator: Optional[RequestIterableType], |
|
|
|
|
deadline: Optional[float], metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, |
|
|
|
@ -550,7 +556,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, |
|
|
|
|
channel.call(method, deadline, credentials, wait_for_ready), |
|
|
|
|
metadata, request_serializer, response_deserializer, loop) |
|
|
|
|
|
|
|
|
|
self._init_stream_request_mixin(request_async_iterator) |
|
|
|
|
self._init_stream_request_mixin(request_iterator) |
|
|
|
|
self._init_unary_response_mixin(self._conduct_rpc()) |
|
|
|
|
|
|
|
|
|
async def _conduct_rpc(self) -> ResponseType: |
|
|
|
@ -577,8 +583,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, |
|
|
|
|
_initializer: asyncio.Task |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
def __init__(self, request_iterator: Optional[RequestIterableType], |
|
|
|
|
deadline: Optional[float], metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, |
|
|
|
@ -589,7 +594,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, |
|
|
|
|
channel.call(method, deadline, credentials, wait_for_ready), |
|
|
|
|
metadata, request_serializer, response_deserializer, loop) |
|
|
|
|
self._initializer = self._loop.create_task(self._prepare_rpc()) |
|
|
|
|
self._init_stream_request_mixin(request_async_iterator) |
|
|
|
|
self._init_stream_request_mixin(request_iterator) |
|
|
|
|
self._init_stream_response_mixin(self._initializer) |
|
|
|
|
|
|
|
|
|
async def _prepare_rpc(self): |
|
|
|
|