mirror of https://github.com/grpc/grpc.git
Merge branch 'implement-server-interceptor-for-unary-unary-call' of github.com:ZHmao/grpc into implement-server-interceptor-for-unary-unary-call
commit
99e26eb647
126 changed files with 5406 additions and 884 deletions
@ -0,0 +1,3 @@ |
||||
dl.field-list > dt { |
||||
word-break: keep-all !important; |
||||
} |
@ -0,0 +1,132 @@ |
||||
gRPC AsyncIO API |
||||
================ |
||||
|
||||
.. module:: grpc.experimental.aio |
||||
|
||||
Overview |
||||
-------- |
||||
|
||||
gRPC AsyncIO API is the **new version** of gRPC Python whose architecture is |
||||
tailored to AsyncIO. Underlying, it utilizes the same C-extension, gRPC C-Core, |
||||
as existing stack, and it replaces all gRPC IO operations with methods provided |
||||
by the AsyncIO library. |
||||
|
||||
This stack currently is under active development. Feel free to offer |
||||
suggestions by opening issues on our GitHub repo `grpc/grpc <https://github.com/grpc/grpc>`_. |
||||
|
||||
The design doc can be found here as `gRFC <https://github.com/grpc/proposal/pull/155>`_. |
||||
|
||||
|
||||
Caveats |
||||
------- |
||||
|
||||
gRPC Async API objects may only be used on the thread on which they were |
||||
created. AsyncIO doesn't provide thread safety for most of its APIs. |
||||
|
||||
|
||||
Module Contents |
||||
--------------- |
||||
|
||||
Enable AsyncIO in gRPC |
||||
^^^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. function:: init_grpc_aio |
||||
|
||||
Enable AsyncIO for gRPC Python. |
||||
|
||||
This function is idempotent and it should be invoked before creation of |
||||
AsyncIO stack objects. Otherwise, the application might deadlock. |
||||
|
||||
This function configurates the gRPC C-Core to invoke AsyncIO methods for IO |
||||
operations (e.g., socket read, write). The configuration applies to the |
||||
entire process. |
||||
|
||||
After invoking this function, making blocking function calls in coroutines |
||||
or in the thread running event loop will block the event loop, potentially |
||||
starving all RPCs in the process. Refer to the Python language |
||||
documentation on AsyncIO for more details (`running-blocking-code <https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code>`_). |
||||
|
||||
|
||||
Create Channel |
||||
^^^^^^^^^^^^^^ |
||||
|
||||
Channels are the abstraction of clients, where most of networking logic |
||||
happens, for example, managing one or more underlying connections, name |
||||
resolution, load balancing, flow control, etc.. If you are using ProtoBuf, |
||||
Channel objects works best when further encapsulate into stub objects, then the |
||||
application can invoke remote functions as if they are local functions. |
||||
|
||||
.. autofunction:: insecure_channel |
||||
.. autofunction:: secure_channel |
||||
|
||||
|
||||
Channel Object |
||||
^^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: Channel |
||||
|
||||
|
||||
Create Server |
||||
^^^^^^^^^^^^^ |
||||
|
||||
.. autofunction:: server |
||||
|
||||
|
||||
Server Object |
||||
^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: Server |
||||
|
||||
|
||||
gRPC Exceptions |
||||
^^^^^^^^^^^^^^^ |
||||
|
||||
.. autoexception:: BaseError |
||||
.. autoexception:: UsageError |
||||
.. autoexception:: AbortError |
||||
.. autoexception:: InternalError |
||||
.. autoexception:: AioRpcError |
||||
|
||||
|
||||
Shared Context |
||||
^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: RpcContext |
||||
|
||||
|
||||
Client-Side Context |
||||
^^^^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: Call |
||||
.. autoclass:: UnaryUnaryCall |
||||
.. autoclass:: UnaryStreamCall |
||||
.. autoclass:: StreamUnaryCall |
||||
.. autoclass:: StreamStreamCall |
||||
|
||||
|
||||
Server-Side Context |
||||
^^^^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: ServicerContext |
||||
|
||||
|
||||
Client-Side Interceptor |
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: ClientCallDetails |
||||
.. autoclass:: InterceptedUnaryUnaryCall |
||||
.. autoclass:: UnaryUnaryClientInterceptor |
||||
|
||||
.. Service-Side Context |
||||
.. ^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. .. autoclass:: ServicerContext |
||||
|
||||
|
||||
Multi-Callable Interfaces |
||||
^^^^^^^^^^^^^^^^^^^^^^^^^ |
||||
|
||||
.. autoclass:: UnaryUnaryMultiCallable |
||||
.. autoclass:: UnaryStreamMultiCallable() |
||||
.. autoclass:: StreamUnaryMultiCallable() |
||||
.. autoclass:: StreamStreamMultiCallable() |
@ -0,0 +1,450 @@ |
||||
# Copyright 2020 The gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Functions that obviate explicit stubs and explicit channels.""" |
||||
|
||||
import collections |
||||
import datetime |
||||
import os |
||||
import logging |
||||
import threading |
||||
from typing import (Any, AnyStr, Callable, Dict, Iterator, Optional, Sequence, |
||||
Tuple, TypeVar, Union) |
||||
|
||||
import grpc |
||||
from grpc.experimental import experimental_api |
||||
|
||||
RequestType = TypeVar('RequestType') |
||||
ResponseType = TypeVar('ResponseType') |
||||
|
||||
OptionsType = Sequence[Tuple[str, str]] |
||||
CacheKey = Tuple[str, OptionsType, Optional[grpc.ChannelCredentials], Optional[ |
||||
grpc.Compression]] |
||||
|
||||
_LOGGER = logging.getLogger(__name__) |
||||
|
||||
_EVICTION_PERIOD_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" |
||||
if _EVICTION_PERIOD_KEY in os.environ: |
||||
_EVICTION_PERIOD = datetime.timedelta( |
||||
seconds=float(os.environ[_EVICTION_PERIOD_KEY])) |
||||
_LOGGER.debug("Setting managed channel eviction period to %s", |
||||
_EVICTION_PERIOD) |
||||
else: |
||||
_EVICTION_PERIOD = datetime.timedelta(minutes=10) |
||||
|
||||
_MAXIMUM_CHANNELS_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" |
||||
if _MAXIMUM_CHANNELS_KEY in os.environ: |
||||
_MAXIMUM_CHANNELS = int(os.environ[_MAXIMUM_CHANNELS_KEY]) |
||||
_LOGGER.debug("Setting maximum managed channels to %d", _MAXIMUM_CHANNELS) |
||||
else: |
||||
_MAXIMUM_CHANNELS = 2**8 |
||||
|
||||
|
||||
def _create_channel(target: str, options: Sequence[Tuple[str, str]], |
||||
channel_credentials: Optional[grpc.ChannelCredentials], |
||||
compression: Optional[grpc.Compression]) -> grpc.Channel: |
||||
channel_credentials = channel_credentials or grpc.local_channel_credentials( |
||||
) |
||||
if channel_credentials._credentials is grpc.experimental._insecure_channel_credentials: |
||||
_LOGGER.debug(f"Creating insecure channel with options '{options}' " + |
||||
f"and compression '{compression}'") |
||||
return grpc.insecure_channel(target, |
||||
options=options, |
||||
compression=compression) |
||||
else: |
||||
_LOGGER.debug( |
||||
f"Creating secure channel with credentials '{channel_credentials}', " |
||||
+ f"options '{options}' and compression '{compression}'") |
||||
return grpc.secure_channel(target, |
||||
credentials=channel_credentials, |
||||
options=options, |
||||
compression=compression) |
||||
|
||||
|
||||
class ChannelCache: |
||||
# NOTE(rbellevi): Untyped due to reference cycle. |
||||
_singleton = None |
||||
_lock: threading.RLock = threading.RLock() |
||||
_condition: threading.Condition = threading.Condition(lock=_lock) |
||||
_eviction_ready: threading.Event = threading.Event() |
||||
|
||||
_mapping: Dict[CacheKey, Tuple[grpc.Channel, datetime.datetime]] |
||||
_eviction_thread: threading.Thread |
||||
|
||||
def __init__(self): |
||||
self._mapping = collections.OrderedDict() |
||||
self._eviction_thread = threading.Thread( |
||||
target=ChannelCache._perform_evictions, daemon=True) |
||||
self._eviction_thread.start() |
||||
|
||||
@staticmethod |
||||
def get(): |
||||
with ChannelCache._lock: |
||||
if ChannelCache._singleton is None: |
||||
ChannelCache._singleton = ChannelCache() |
||||
ChannelCache._eviction_ready.wait() |
||||
return ChannelCache._singleton |
||||
|
||||
def _evict_locked(self, key: CacheKey): |
||||
channel, _ = self._mapping.pop(key) |
||||
_LOGGER.debug("Evicting channel %s with configuration %s.", channel, |
||||
key) |
||||
channel.close() |
||||
del channel |
||||
|
||||
@staticmethod |
||||
def _perform_evictions(): |
||||
while True: |
||||
with ChannelCache._lock: |
||||
ChannelCache._eviction_ready.set() |
||||
if not ChannelCache._singleton._mapping: |
||||
ChannelCache._condition.wait() |
||||
elif len(ChannelCache._singleton._mapping) > _MAXIMUM_CHANNELS: |
||||
key = next(iter(ChannelCache._singleton._mapping.keys())) |
||||
ChannelCache._singleton._evict_locked(key) |
||||
# And immediately reevaluate. |
||||
else: |
||||
key, (_, eviction_time) = next( |
||||
iter(ChannelCache._singleton._mapping.items())) |
||||
now = datetime.datetime.now() |
||||
if eviction_time <= now: |
||||
ChannelCache._singleton._evict_locked(key) |
||||
continue |
||||
else: |
||||
time_to_eviction = (eviction_time - now).total_seconds() |
||||
# NOTE: We aim to *eventually* coalesce to a state in |
||||
# which no overdue channels are in the cache and the |
||||
# length of the cache is longer than _MAXIMUM_CHANNELS. |
||||
# We tolerate momentary states in which these two |
||||
# criteria are not met. |
||||
ChannelCache._condition.wait(timeout=time_to_eviction) |
||||
|
||||
def get_channel(self, target: str, options: Sequence[Tuple[str, str]], |
||||
channel_credentials: Optional[grpc.ChannelCredentials], |
||||
compression: Optional[grpc.Compression]) -> grpc.Channel: |
||||
key = (target, options, channel_credentials, compression) |
||||
with self._lock: |
||||
channel_data = self._mapping.get(key, None) |
||||
if channel_data is not None: |
||||
channel = channel_data[0] |
||||
self._mapping.pop(key) |
||||
self._mapping[key] = (channel, datetime.datetime.now() + |
||||
_EVICTION_PERIOD) |
||||
return channel |
||||
else: |
||||
channel = _create_channel(target, options, channel_credentials, |
||||
compression) |
||||
self._mapping[key] = (channel, datetime.datetime.now() + |
||||
_EVICTION_PERIOD) |
||||
if len(self._mapping) == 1 or len( |
||||
self._mapping) >= _MAXIMUM_CHANNELS: |
||||
self._condition.notify() |
||||
return channel |
||||
|
||||
def _test_only_channel_count(self) -> int: |
||||
with self._lock: |
||||
return len(self._mapping) |
||||
|
||||
|
||||
# TODO(rbellevi): Consider a credential type that has the |
||||
# following functionality matrix: |
||||
# |
||||
# +----------+-------+--------+ |
||||
# | | local | remote | |
||||
# |----------+-------+--------+ |
||||
# | secure | o | o | |
||||
# | insecure | o | x | |
||||
# +----------+-------+--------+ |
||||
# |
||||
# Make this the default option. |
||||
|
||||
|
||||
@experimental_api |
||||
def unary_unary( |
||||
request: RequestType, |
||||
target: str, |
||||
method: str, |
||||
request_serializer: Optional[Callable[[Any], bytes]] = None, |
||||
request_deserializer: Optional[Callable[[bytes], Any]] = None, |
||||
options: Sequence[Tuple[AnyStr, AnyStr]] = (), |
||||
channel_credentials: Optional[grpc.ChannelCredentials] = None, |
||||
call_credentials: Optional[grpc.CallCredentials] = None, |
||||
compression: Optional[grpc.Compression] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None |
||||
) -> ResponseType: |
||||
"""Invokes a unary-unary RPC without an explicitly specified channel. |
||||
|
||||
THIS IS AN EXPERIMENTAL API. |
||||
|
||||
This is backed by a per-process cache of channels. Channels are evicted |
||||
from the cache after a fixed period by a background. Channels will also be |
||||
evicted if more than a configured maximum accumulate. |
||||
|
||||
The default eviction period is 10 minutes. One may set the environment |
||||
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this. |
||||
|
||||
The default maximum number of channels is 256. One may set the |
||||
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure |
||||
this. |
||||
|
||||
Args: |
||||
request: An iterator that yields request values for the RPC. |
||||
target: The server address. |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the response |
||||
message. Response goes undeserialized in case None is passed. |
||||
options: An optional list of key-value pairs (channel args in gRPC Core |
||||
runtime) to configure the channel. |
||||
channel_credentials: A credential applied to the whole channel, e.g. the |
||||
return value of grpc.ssl_channel_credentials() or |
||||
grpc.insecure_channel_credentials(). |
||||
call_credentials: A call credential applied to each call individually, |
||||
e.g. the output of grpc.metadata_call_credentials() or |
||||
grpc.access_token_call_credentials(). |
||||
compression: An optional value indicating the compression method to be |
||||
used over the lifetime of the channel, e.g. grpc.Compression.Gzip. |
||||
wait_for_ready: An optional flag indicating whether the RPC should fail |
||||
immediately if the connection is not ready at the time the RPC is |
||||
invoked, or if it should wait until the connection to the server |
||||
becomes ready. When using this option, the user will likely also want |
||||
to set a timeout. Defaults to False. |
||||
timeout: An optional duration of time in seconds to allow for the RPC, |
||||
after which an exception will be raised. |
||||
metadata: Optional metadata to send to the server. |
||||
|
||||
Returns: |
||||
The response to the RPC. |
||||
""" |
||||
channel = ChannelCache.get().get_channel(target, options, |
||||
channel_credentials, compression) |
||||
multicallable = channel.unary_unary(method, request_serializer, |
||||
request_deserializer) |
||||
return multicallable(request, |
||||
metadata=metadata, |
||||
wait_for_ready=wait_for_ready, |
||||
credentials=call_credentials, |
||||
timeout=timeout) |
||||
|
||||
|
||||
@experimental_api |
||||
def unary_stream( |
||||
request: RequestType, |
||||
target: str, |
||||
method: str, |
||||
request_serializer: Optional[Callable[[Any], bytes]] = None, |
||||
request_deserializer: Optional[Callable[[bytes], Any]] = None, |
||||
options: Sequence[Tuple[AnyStr, AnyStr]] = (), |
||||
channel_credentials: Optional[grpc.ChannelCredentials] = None, |
||||
call_credentials: Optional[grpc.CallCredentials] = None, |
||||
compression: Optional[grpc.Compression] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None |
||||
) -> Iterator[ResponseType]: |
||||
"""Invokes a unary-stream RPC without an explicitly specified channel. |
||||
|
||||
THIS IS AN EXPERIMENTAL API. |
||||
|
||||
This is backed by a per-process cache of channels. Channels are evicted |
||||
from the cache after a fixed period by a background. Channels will also be |
||||
evicted if more than a configured maximum accumulate. |
||||
|
||||
The default eviction period is 10 minutes. One may set the environment |
||||
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this. |
||||
|
||||
The default maximum number of channels is 256. One may set the |
||||
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure |
||||
this. |
||||
|
||||
Args: |
||||
request: An iterator that yields request values for the RPC. |
||||
target: The server address. |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the response |
||||
message. Response goes undeserialized in case None is passed. |
||||
options: An optional list of key-value pairs (channel args in gRPC Core |
||||
runtime) to configure the channel. |
||||
channel_credentials: A credential applied to the whole channel, e.g. the |
||||
return value of grpc.ssl_channel_credentials(). |
||||
call_credentials: A call credential applied to each call individually, |
||||
e.g. the output of grpc.metadata_call_credentials() or |
||||
grpc.access_token_call_credentials(). |
||||
compression: An optional value indicating the compression method to be |
||||
used over the lifetime of the channel, e.g. grpc.Compression.Gzip. |
||||
wait_for_ready: An optional flag indicating whether the RPC should fail |
||||
immediately if the connection is not ready at the time the RPC is |
||||
invoked, or if it should wait until the connection to the server |
||||
becomes ready. When using this option, the user will likely also want |
||||
to set a timeout. Defaults to False. |
||||
timeout: An optional duration of time in seconds to allow for the RPC, |
||||
after which an exception will be raised. |
||||
metadata: Optional metadata to send to the server. |
||||
|
||||
Returns: |
||||
An iterator of responses. |
||||
""" |
||||
channel = ChannelCache.get().get_channel(target, options, |
||||
channel_credentials, compression) |
||||
multicallable = channel.unary_stream(method, request_serializer, |
||||
request_deserializer) |
||||
return multicallable(request, |
||||
metadata=metadata, |
||||
wait_for_ready=wait_for_ready, |
||||
credentials=call_credentials, |
||||
timeout=timeout) |
||||
|
||||
|
||||
@experimental_api |
||||
def stream_unary( |
||||
request_iterator: Iterator[RequestType], |
||||
target: str, |
||||
method: str, |
||||
request_serializer: Optional[Callable[[Any], bytes]] = None, |
||||
request_deserializer: Optional[Callable[[bytes], Any]] = None, |
||||
options: Sequence[Tuple[AnyStr, AnyStr]] = (), |
||||
channel_credentials: Optional[grpc.ChannelCredentials] = None, |
||||
call_credentials: Optional[grpc.CallCredentials] = None, |
||||
compression: Optional[grpc.Compression] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None |
||||
) -> ResponseType: |
||||
"""Invokes a stream-unary RPC without an explicitly specified channel. |
||||
|
||||
THIS IS AN EXPERIMENTAL API. |
||||
|
||||
This is backed by a per-process cache of channels. Channels are evicted |
||||
from the cache after a fixed period by a background. Channels will also be |
||||
evicted if more than a configured maximum accumulate. |
||||
|
||||
The default eviction period is 10 minutes. One may set the environment |
||||
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this. |
||||
|
||||
The default maximum number of channels is 256. One may set the |
||||
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure |
||||
this. |
||||
|
||||
Args: |
||||
request_iterator: An iterator that yields request values for the RPC. |
||||
target: The server address. |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the response |
||||
message. Response goes undeserialized in case None is passed. |
||||
options: An optional list of key-value pairs (channel args in gRPC Core |
||||
runtime) to configure the channel. |
||||
channel_credentials: A credential applied to the whole channel, e.g. the |
||||
return value of grpc.ssl_channel_credentials(). |
||||
call_credentials: A call credential applied to each call individually, |
||||
e.g. the output of grpc.metadata_call_credentials() or |
||||
grpc.access_token_call_credentials(). |
||||
compression: An optional value indicating the compression method to be |
||||
used over the lifetime of the channel, e.g. grpc.Compression.Gzip. |
||||
wait_for_ready: An optional flag indicating whether the RPC should fail |
||||
immediately if the connection is not ready at the time the RPC is |
||||
invoked, or if it should wait until the connection to the server |
||||
becomes ready. When using this option, the user will likely also want |
||||
to set a timeout. Defaults to False. |
||||
timeout: An optional duration of time in seconds to allow for the RPC, |
||||
after which an exception will be raised. |
||||
metadata: Optional metadata to send to the server. |
||||
|
||||
Returns: |
||||
The response to the RPC. |
||||
""" |
||||
channel = ChannelCache.get().get_channel(target, options, |
||||
channel_credentials, compression) |
||||
multicallable = channel.stream_unary(method, request_serializer, |
||||
request_deserializer) |
||||
return multicallable(request_iterator, |
||||
metadata=metadata, |
||||
wait_for_ready=wait_for_ready, |
||||
credentials=call_credentials, |
||||
timeout=timeout) |
||||
|
||||
|
||||
@experimental_api |
||||
def stream_stream( |
||||
request_iterator: Iterator[RequestType], |
||||
target: str, |
||||
method: str, |
||||
request_serializer: Optional[Callable[[Any], bytes]] = None, |
||||
request_deserializer: Optional[Callable[[bytes], Any]] = None, |
||||
options: Sequence[Tuple[AnyStr, AnyStr]] = (), |
||||
channel_credentials: Optional[grpc.ChannelCredentials] = None, |
||||
call_credentials: Optional[grpc.CallCredentials] = None, |
||||
compression: Optional[grpc.Compression] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None |
||||
) -> Iterator[ResponseType]: |
||||
"""Invokes a stream-stream RPC without an explicitly specified channel. |
||||
|
||||
THIS IS AN EXPERIMENTAL API. |
||||
|
||||
This is backed by a per-process cache of channels. Channels are evicted |
||||
from the cache after a fixed period by a background. Channels will also be |
||||
evicted if more than a configured maximum accumulate. |
||||
|
||||
The default eviction period is 10 minutes. One may set the environment |
||||
variable "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" to configure this. |
||||
|
||||
The default maximum number of channels is 256. One may set the |
||||
environment variable "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM" to configure |
||||
this. |
||||
|
||||
Args: |
||||
request_iterator: An iterator that yields request values for the RPC. |
||||
target: The server address. |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the response |
||||
message. Response goes undeserialized in case None is passed. |
||||
options: An optional list of key-value pairs (channel args in gRPC Core |
||||
runtime) to configure the channel. |
||||
channel_credentials: A credential applied to the whole channel, e.g. the |
||||
return value of grpc.ssl_channel_credentials(). |
||||
call_credentials: A call credential applied to each call individually, |
||||
e.g. the output of grpc.metadata_call_credentials() or |
||||
grpc.access_token_call_credentials(). |
||||
compression: An optional value indicating the compression method to be |
||||
used over the lifetime of the channel, e.g. grpc.Compression.Gzip. |
||||
wait_for_ready: An optional flag indicating whether the RPC should fail |
||||
immediately if the connection is not ready at the time the RPC is |
||||
invoked, or if it should wait until the connection to the server |
||||
becomes ready. When using this option, the user will likely also want |
||||
to set a timeout. Defaults to False. |
||||
timeout: An optional duration of time in seconds to allow for the RPC, |
||||
after which an exception will be raised. |
||||
metadata: Optional metadata to send to the server. |
||||
|
||||
Returns: |
||||
An iterator of responses. |
||||
""" |
||||
channel = ChannelCache.get().get_channel(target, options, |
||||
channel_credentials, compression) |
||||
multicallable = channel.stream_stream(method, request_serializer, |
||||
request_deserializer) |
||||
return multicallable(request_iterator, |
||||
metadata=metadata, |
||||
wait_for_ready=wait_for_ready, |
||||
credentials=call_credentials, |
||||
timeout=timeout) |
@ -0,0 +1,345 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Abstract base classes for Channel objects and Multicallable objects.""" |
||||
|
||||
import abc |
||||
from typing import Any, AsyncIterable, Optional |
||||
|
||||
import grpc |
||||
|
||||
from . import _base_call |
||||
from ._typing import DeserializingFunction, MetadataType, SerializingFunction |
||||
|
||||
_IMMUTABLE_EMPTY_TUPLE = tuple() |
||||
|
||||
|
||||
class UnaryUnaryMultiCallable(abc.ABC): |
||||
"""Enables asynchronous invocation of a unary-call RPC.""" |
||||
|
||||
@abc.abstractmethod |
||||
def __call__(self, |
||||
request: Any, |
||||
*, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, |
||||
credentials: Optional[grpc.CallCredentials] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
compression: Optional[grpc.Compression] = None |
||||
) -> _base_call.UnaryUnaryCall: |
||||
"""Asynchronously invokes the underlying RPC. |
||||
|
||||
Args: |
||||
request: The request value for the RPC. |
||||
timeout: An optional duration of time in seconds to allow |
||||
for the RPC. |
||||
metadata: Optional :term:`metadata` to be transmitted to the |
||||
service-side of the RPC. |
||||
credentials: An optional CallCredentials for the RPC. Only valid for |
||||
secure Channel. |
||||
wait_for_ready: This is an EXPERIMENTAL argument. An optional |
||||
flag to enable wait for ready mechanism |
||||
compression: An element of grpc.compression, e.g. |
||||
grpc.compression.Gzip. This is an EXPERIMENTAL option. |
||||
|
||||
Returns: |
||||
A UnaryUnaryCall object. |
||||
|
||||
Raises: |
||||
RpcError: Indicates that the RPC terminated with non-OK status. The |
||||
raised RpcError will also be a Call for the RPC affording the RPC's |
||||
metadata, status code, and details. |
||||
""" |
||||
|
||||
|
||||
class UnaryStreamMultiCallable(abc.ABC): |
||||
"""Enables asynchronous invocation of a server-streaming RPC.""" |
||||
|
||||
@abc.abstractmethod |
||||
def __call__(self, |
||||
request: Any, |
||||
*, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, |
||||
credentials: Optional[grpc.CallCredentials] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
compression: Optional[grpc.Compression] = None |
||||
) -> _base_call.UnaryStreamCall: |
||||
"""Asynchronously invokes the underlying RPC. |
||||
|
||||
Args: |
||||
request: The request value for the RPC. |
||||
timeout: An optional duration of time in seconds to allow |
||||
for the RPC. |
||||
metadata: Optional :term:`metadata` to be transmitted to the |
||||
service-side of the RPC. |
||||
credentials: An optional CallCredentials for the RPC. Only valid for |
||||
secure Channel. |
||||
wait_for_ready: This is an EXPERIMENTAL argument. An optional |
||||
flag to enable wait for ready mechanism |
||||
compression: An element of grpc.compression, e.g. |
||||
grpc.compression.Gzip. This is an EXPERIMENTAL option. |
||||
|
||||
Returns: |
||||
A UnaryStreamCall object. |
||||
|
||||
Raises: |
||||
RpcError: Indicates that the RPC terminated with non-OK status. The |
||||
raised RpcError will also be a Call for the RPC affording the RPC's |
||||
metadata, status code, and details. |
||||
""" |
||||
|
||||
|
||||
class StreamUnaryMultiCallable(abc.ABC): |
||||
"""Enables asynchronous invocation of a client-streaming RPC.""" |
||||
|
||||
@abc.abstractmethod |
||||
def __call__(self, |
||||
request_async_iterator: Optional[AsyncIterable[Any]] = None, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, |
||||
credentials: Optional[grpc.CallCredentials] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
compression: Optional[grpc.Compression] = None |
||||
) -> _base_call.StreamUnaryCall: |
||||
"""Asynchronously invokes the underlying RPC. |
||||
|
||||
Args: |
||||
request: The request value for the RPC. |
||||
timeout: An optional duration of time in seconds to allow |
||||
for the RPC. |
||||
metadata: Optional :term:`metadata` to be transmitted to the |
||||
service-side of the RPC. |
||||
credentials: An optional CallCredentials for the RPC. Only valid for |
||||
secure Channel. |
||||
wait_for_ready: This is an EXPERIMENTAL argument. An optional |
||||
flag to enable wait for ready mechanism |
||||
compression: An element of grpc.compression, e.g. |
||||
grpc.compression.Gzip. This is an EXPERIMENTAL option. |
||||
|
||||
Returns: |
||||
A StreamUnaryCall object. |
||||
|
||||
Raises: |
||||
RpcError: Indicates that the RPC terminated with non-OK status. The |
||||
raised RpcError will also be a Call for the RPC affording the RPC's |
||||
metadata, status code, and details. |
||||
""" |
||||
|
||||
|
||||
class StreamStreamMultiCallable(abc.ABC): |
||||
"""Enables asynchronous invocation of a bidirectional-streaming RPC.""" |
||||
|
||||
@abc.abstractmethod |
||||
def __call__(self, |
||||
request_async_iterator: Optional[AsyncIterable[Any]] = None, |
||||
timeout: Optional[float] = None, |
||||
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, |
||||
credentials: Optional[grpc.CallCredentials] = None, |
||||
wait_for_ready: Optional[bool] = None, |
||||
compression: Optional[grpc.Compression] = None |
||||
) -> _base_call.StreamStreamCall: |
||||
"""Asynchronously invokes the underlying RPC. |
||||
|
||||
Args: |
||||
request: The request value for the RPC. |
||||
timeout: An optional duration of time in seconds to allow |
||||
for the RPC. |
||||
metadata: Optional :term:`metadata` to be transmitted to the |
||||
service-side of the RPC. |
||||
credentials: An optional CallCredentials for the RPC. Only valid for |
||||
secure Channel. |
||||
wait_for_ready: This is an EXPERIMENTAL argument. An optional |
||||
flag to enable wait for ready mechanism |
||||
compression: An element of grpc.compression, e.g. |
||||
grpc.compression.Gzip. This is an EXPERIMENTAL option. |
||||
|
||||
Returns: |
||||
A StreamStreamCall object. |
||||
|
||||
Raises: |
||||
RpcError: Indicates that the RPC terminated with non-OK status. The |
||||
raised RpcError will also be a Call for the RPC affording the RPC's |
||||
metadata, status code, and details. |
||||
""" |
||||
|
||||
|
||||
class Channel(abc.ABC): |
||||
"""Enables asynchronous RPC invocation as a client. |
||||
|
||||
Channel objects implement the Asynchronous Context Manager (aka. async |
||||
with) type, although they are not supportted to be entered and exited |
||||
multiple times. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def __aenter__(self): |
||||
"""Starts an asynchronous context manager. |
||||
|
||||
Returns: |
||||
Channel the channel that was instantiated. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def __aexit__(self, exc_type, exc_val, exc_tb): |
||||
"""Finishes the asynchronous context manager by closing the channel. |
||||
|
||||
Still active RPCs will be cancelled. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def close(self, grace: Optional[float] = None): |
||||
"""Closes this Channel and releases all resources held by it. |
||||
|
||||
This method immediately stops the channel from executing new RPCs in |
||||
all cases. |
||||
|
||||
If a grace period is specified, this method wait until all active |
||||
RPCs are finshed, once the grace period is reached the ones that haven't |
||||
been terminated are cancelled. If a grace period is not specified |
||||
(by passing None for grace), all existing RPCs are cancelled immediately. |
||||
|
||||
This method is idempotent. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def get_state(self, |
||||
try_to_connect: bool = False) -> grpc.ChannelConnectivity: |
||||
"""Checks the connectivity state of a channel. |
||||
|
||||
This is an EXPERIMENTAL API. |
||||
|
||||
If the channel reaches a stable connectivity state, it is guaranteed |
||||
that the return value of this function will eventually converge to that |
||||
state. |
||||
|
||||
Args: |
||||
try_to_connect: a bool indicate whether the Channel should try to |
||||
connect to peer or not. |
||||
|
||||
Returns: A ChannelConnectivity object. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def wait_for_state_change( |
||||
self, |
||||
last_observed_state: grpc.ChannelConnectivity, |
||||
) -> None: |
||||
"""Waits for a change in connectivity state. |
||||
|
||||
This is an EXPERIMENTAL API. |
||||
|
||||
The function blocks until there is a change in the channel connectivity |
||||
state from the "last_observed_state". If the state is already |
||||
different, this function will return immediately. |
||||
|
||||
There is an inherent race between the invocation of |
||||
"Channel.wait_for_state_change" and "Channel.get_state". The state can |
||||
change arbitrary many times during the race, so there is no way to |
||||
observe every state transition. |
||||
|
||||
If there is a need to put a timeout for this function, please refer to |
||||
"asyncio.wait_for". |
||||
|
||||
Args: |
||||
last_observed_state: A grpc.ChannelConnectivity object representing |
||||
the last known state. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def channel_ready(self) -> None: |
||||
"""Creates a coroutine that blocks until the Channel is READY.""" |
||||
|
||||
@abc.abstractmethod |
||||
def unary_unary( |
||||
self, |
||||
method: str, |
||||
request_serializer: Optional[SerializingFunction] = None, |
||||
response_deserializer: Optional[DeserializingFunction] = None |
||||
) -> UnaryUnaryMultiCallable: |
||||
"""Creates a UnaryUnaryMultiCallable for a unary-unary method. |
||||
|
||||
Args: |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the |
||||
response message. Response goes undeserialized in case None |
||||
is passed. |
||||
|
||||
Returns: |
||||
A UnaryUnaryMultiCallable value for the named unary-unary method. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def unary_stream( |
||||
self, |
||||
method: str, |
||||
request_serializer: Optional[SerializingFunction] = None, |
||||
response_deserializer: Optional[DeserializingFunction] = None |
||||
) -> UnaryStreamMultiCallable: |
||||
"""Creates a UnaryStreamMultiCallable for a unary-stream method. |
||||
|
||||
Args: |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the |
||||
response message. Response goes undeserialized in case None |
||||
is passed. |
||||
|
||||
Returns: |
||||
A UnarySteramMultiCallable value for the named unary-stream method. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def stream_unary( |
||||
self, |
||||
method: str, |
||||
request_serializer: Optional[SerializingFunction] = None, |
||||
response_deserializer: Optional[DeserializingFunction] = None |
||||
) -> StreamUnaryMultiCallable: |
||||
"""Creates a StreamUnaryMultiCallable for a stream-unary method. |
||||
|
||||
Args: |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the |
||||
response message. Response goes undeserialized in case None |
||||
is passed. |
||||
|
||||
Returns: |
||||
A StreamUnaryMultiCallable value for the named stream-unary method. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def stream_stream( |
||||
self, |
||||
method: str, |
||||
request_serializer: Optional[SerializingFunction] = None, |
||||
response_deserializer: Optional[DeserializingFunction] = None |
||||
) -> StreamStreamMultiCallable: |
||||
"""Creates a StreamStreamMultiCallable for a stream-stream method. |
||||
|
||||
Args: |
||||
method: The name of the RPC method. |
||||
request_serializer: Optional behaviour for serializing the request |
||||
message. Request goes unserialized in case None is passed. |
||||
response_deserializer: Optional behaviour for deserializing the |
||||
response message. Response goes undeserialized in case None |
||||
is passed. |
||||
|
||||
Returns: |
||||
A StreamStreamMultiCallable value for the named stream-stream method. |
||||
""" |
@ -0,0 +1,254 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Abstract base classes for server-side classes.""" |
||||
|
||||
import abc |
||||
from typing import Generic, Optional, Sequence |
||||
|
||||
import grpc |
||||
|
||||
from ._typing import MetadataType, RequestType, ResponseType |
||||
|
||||
|
||||
class Server(abc.ABC): |
||||
"""Serves RPCs.""" |
||||
|
||||
@abc.abstractmethod |
||||
def add_generic_rpc_handlers( |
||||
self, |
||||
generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None: |
||||
"""Registers GenericRpcHandlers with this Server. |
||||
|
||||
This method is only safe to call before the server is started. |
||||
|
||||
Args: |
||||
generic_rpc_handlers: A sequence of GenericRpcHandlers that will be |
||||
used to service RPCs. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def add_insecure_port(self, address: str) -> int: |
||||
"""Opens an insecure port for accepting RPCs. |
||||
|
||||
A port is a communication endpoint that used by networking protocols, |
||||
like TCP and UDP. To date, we only support TCP. |
||||
|
||||
This method may only be called before starting the server. |
||||
|
||||
Args: |
||||
address: The address for which to open a port. If the port is 0, |
||||
or not specified in the address, then the gRPC runtime will choose a port. |
||||
|
||||
Returns: |
||||
An integer port on which the server will accept RPC requests. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def add_secure_port(self, address: str, |
||||
server_credentials: grpc.ServerCredentials) -> int: |
||||
"""Opens a secure port for accepting RPCs. |
||||
|
||||
A port is a communication endpoint that used by networking protocols, |
||||
like TCP and UDP. To date, we only support TCP. |
||||
|
||||
This method may only be called before starting the server. |
||||
|
||||
Args: |
||||
address: The address for which to open a port. |
||||
if the port is 0, or not specified in the address, then the gRPC |
||||
runtime will choose a port. |
||||
server_credentials: A ServerCredentials object. |
||||
|
||||
Returns: |
||||
An integer port on which the server will accept RPC requests. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def start(self) -> None: |
||||
"""Starts this Server. |
||||
|
||||
This method may only be called once. (i.e. it is not idempotent). |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def stop(self, grace: Optional[float]) -> None: |
||||
"""Stops this Server. |
||||
|
||||
This method immediately stops the server from servicing new RPCs in |
||||
all cases. |
||||
|
||||
If a grace period is specified, this method returns immediately and all |
||||
RPCs active at the end of the grace period are aborted. If a grace |
||||
period is not specified (by passing None for grace), all existing RPCs |
||||
are aborted immediately and this method blocks until the last RPC |
||||
handler terminates. |
||||
|
||||
This method is idempotent and may be called at any time. Passing a |
||||
smaller grace value in a subsequent call will have the effect of |
||||
stopping the Server sooner (passing None will have the effect of |
||||
stopping the server immediately). Passing a larger grace value in a |
||||
subsequent call will not have the effect of stopping the server later |
||||
(i.e. the most restrictive grace value is used). |
||||
|
||||
Args: |
||||
grace: A duration of time in seconds or None. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def wait_for_termination(self, |
||||
timeout: Optional[float] = None) -> bool: |
||||
"""Continues current coroutine once the server stops. |
||||
|
||||
This is an EXPERIMENTAL API. |
||||
|
||||
The wait will not consume computational resources during blocking, and |
||||
it will block until one of the two following conditions are met: |
||||
|
||||
1) The server is stopped or terminated; |
||||
2) A timeout occurs if timeout is not `None`. |
||||
|
||||
The timeout argument works in the same way as `threading.Event.wait()`. |
||||
https://docs.python.org/3/library/threading.html#threading.Event.wait |
||||
|
||||
Args: |
||||
timeout: A floating point number specifying a timeout for the |
||||
operation in seconds. |
||||
|
||||
Returns: |
||||
A bool indicates if the operation times out. |
||||
""" |
||||
|
||||
|
||||
class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): |
||||
"""A context object passed to method implementations.""" |
||||
|
||||
@abc.abstractmethod |
||||
async def read(self) -> RequestType: |
||||
"""Reads one message from the RPC. |
||||
|
||||
Only one read operation is allowed simultaneously. |
||||
|
||||
Returns: |
||||
A response message of the RPC. |
||||
|
||||
Raises: |
||||
An RpcError exception if the read failed. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def write(self, message: ResponseType) -> None: |
||||
"""Writes one message to the RPC. |
||||
|
||||
Only one write operation is allowed simultaneously. |
||||
|
||||
Raises: |
||||
An RpcError exception if the write failed. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def send_initial_metadata(self, |
||||
initial_metadata: MetadataType) -> None: |
||||
"""Sends the initial metadata value to the client. |
||||
|
||||
This method need not be called by implementations if they have no |
||||
metadata to add to what the gRPC runtime will transmit. |
||||
|
||||
Args: |
||||
initial_metadata: The initial :term:`metadata`. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def abort(self, code: grpc.StatusCode, details: str, |
||||
trailing_metadata: MetadataType) -> None: |
||||
"""Raises an exception to terminate the RPC with a non-OK status. |
||||
|
||||
The code and details passed as arguments will supercede any existing |
||||
ones. |
||||
|
||||
Args: |
||||
code: A StatusCode object to be sent to the client. |
||||
It must not be StatusCode.OK. |
||||
details: A UTF-8-encodable string to be sent to the client upon |
||||
termination of the RPC. |
||||
trailing_metadata: A sequence of tuple represents the trailing |
||||
:term:`metadata`. |
||||
|
||||
Raises: |
||||
Exception: An exception is always raised to signal the abortion the |
||||
RPC to the gRPC runtime. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
async def set_trailing_metadata(self, |
||||
trailing_metadata: MetadataType) -> None: |
||||
"""Sends the trailing metadata for the RPC. |
||||
|
||||
This method need not be called by implementations if they have no |
||||
metadata to add to what the gRPC runtime will transmit. |
||||
|
||||
Args: |
||||
trailing_metadata: The trailing :term:`metadata`. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def invocation_metadata(self) -> Optional[MetadataType]: |
||||
"""Accesses the metadata from the sent by the client. |
||||
|
||||
Returns: |
||||
The invocation :term:`metadata`. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def set_code(self, code: grpc.StatusCode) -> None: |
||||
"""Sets the value to be used as status code upon RPC completion. |
||||
|
||||
This method need not be called by method implementations if they wish |
||||
the gRPC runtime to determine the status code of the RPC. |
||||
|
||||
Args: |
||||
code: A StatusCode object to be sent to the client. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def set_details(self, details: str) -> None: |
||||
"""Sets the value to be used the as detail string upon RPC completion. |
||||
|
||||
This method need not be called by method implementations if they have |
||||
no details to transmit. |
||||
|
||||
Args: |
||||
details: A UTF-8-encodable string to be sent to the client upon |
||||
termination of the RPC. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def set_compression(self, compression: grpc.Compression) -> None: |
||||
"""Set the compression algorithm to be used for the entire call. |
||||
|
||||
This is an EXPERIMENTAL method. |
||||
|
||||
Args: |
||||
compression: An element of grpc.compression, e.g. |
||||
grpc.compression.Gzip. |
||||
""" |
||||
|
||||
@abc.abstractmethod |
||||
def disable_next_message_compression(self) -> None: |
||||
"""Disables compression for the next response message. |
||||
|
||||
This is an EXPERIMENTAL method. |
||||
|
||||
This method will override any compression configuration set during |
||||
server creation or set on the call. |
||||
""" |
@ -0,0 +1,113 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Reference implementation for health checking in gRPC Python.""" |
||||
|
||||
import asyncio |
||||
import collections |
||||
from typing import MutableMapping |
||||
import grpc |
||||
|
||||
from grpc_health.v1 import health_pb2 as _health_pb2 |
||||
from grpc_health.v1 import health_pb2_grpc as _health_pb2_grpc |
||||
|
||||
|
||||
class HealthServicer(_health_pb2_grpc.HealthServicer): |
||||
"""An AsyncIO implementation of health checking servicer.""" |
||||
_server_status: MutableMapping[ |
||||
str, '_health_pb2.HealthCheckResponse.ServingStatus'] |
||||
_server_watchers: MutableMapping[str, asyncio.Condition] |
||||
_gracefully_shutting_down: bool |
||||
|
||||
def __init__(self) -> None: |
||||
self._server_status = dict() |
||||
self._server_watchers = collections.defaultdict(asyncio.Condition) |
||||
self._gracefully_shutting_down = False |
||||
|
||||
async def Check(self, request: _health_pb2.HealthCheckRequest, |
||||
context) -> None: |
||||
status = self._server_status.get(request.service) |
||||
|
||||
if status is None: |
||||
await context.abort(grpc.StatusCode.NOT_FOUND) |
||||
else: |
||||
return _health_pb2.HealthCheckResponse(status=status) |
||||
|
||||
async def Watch(self, request: _health_pb2.HealthCheckRequest, |
||||
context) -> None: |
||||
condition = self._server_watchers[request.service] |
||||
last_status = None |
||||
try: |
||||
async with condition: |
||||
while True: |
||||
status = self._server_status.get( |
||||
request.service, |
||||
_health_pb2.HealthCheckResponse.SERVICE_UNKNOWN) |
||||
|
||||
# NOTE(lidiz) If the observed status is the same, it means |
||||
# there are missing intermediate statuses. It's considered |
||||
# acceptable since peer only interested in eventual status. |
||||
if status != last_status: |
||||
# Responds with current health state |
||||
await context.write( |
||||
_health_pb2.HealthCheckResponse(status=status)) |
||||
|
||||
# Records the last sent status |
||||
last_status = status |
||||
|
||||
# Polling on health state changes |
||||
await condition.wait() |
||||
finally: |
||||
if request.service in self._server_watchers: |
||||
del self._server_watchers[request.service] |
||||
|
||||
async def _set(self, service: str, |
||||
status: _health_pb2.HealthCheckResponse.ServingStatus |
||||
) -> None: |
||||
if service in self._server_watchers: |
||||
condition = self._server_watchers.get(service) |
||||
async with condition: |
||||
self._server_status[service] = status |
||||
condition.notify_all() |
||||
else: |
||||
self._server_status[service] = status |
||||
|
||||
async def set(self, service: str, |
||||
status: _health_pb2.HealthCheckResponse.ServingStatus |
||||
) -> None: |
||||
"""Sets the status of a service. |
||||
|
||||
Args: |
||||
service: string, the name of the service. |
||||
status: HealthCheckResponse.status enum value indicating the status of |
||||
the service |
||||
""" |
||||
if self._gracefully_shutting_down: |
||||
return |
||||
else: |
||||
await self._set(service, status) |
||||
|
||||
async def enter_graceful_shutdown(self) -> None: |
||||
"""Permanently sets the status of all services to NOT_SERVING. |
||||
|
||||
This should be invoked when the server is entering a graceful shutdown |
||||
period. After this method is invoked, future attempts to set the status |
||||
of a service will be ignored. |
||||
""" |
||||
if self._gracefully_shutting_down: |
||||
return |
||||
else: |
||||
self._gracefully_shutting_down = True |
||||
for service in self._server_status: |
||||
await self._set(service, |
||||
_health_pb2.HealthCheckResponse.NOT_SERVING) |
@ -0,0 +1,27 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
package( |
||||
default_testonly = 1, |
||||
default_visibility = ["//visibility:public"], |
||||
) |
||||
|
||||
py_library( |
||||
name = "histogram", |
||||
srcs = ["histogram.py"], |
||||
srcs_version = "PY2AND3", |
||||
deps = [ |
||||
"//src/proto/grpc/testing:stats_py_pb2", |
||||
], |
||||
) |
@ -0,0 +1,155 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""The Python AsyncIO Benchmark Clients.""" |
||||
|
||||
import abc |
||||
import asyncio |
||||
import time |
||||
import logging |
||||
import random |
||||
|
||||
import grpc |
||||
from grpc.experimental import aio |
||||
|
||||
from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2, |
||||
messages_pb2) |
||||
from tests.qps import histogram |
||||
from tests.unit import resources |
||||
|
||||
|
||||
class GenericStub(object): |
||||
|
||||
def __init__(self, channel: aio.Channel): |
||||
self.UnaryCall = channel.unary_unary( |
||||
'/grpc.testing.BenchmarkService/UnaryCall') |
||||
self.StreamingCall = channel.stream_stream( |
||||
'/grpc.testing.BenchmarkService/StreamingCall') |
||||
|
||||
|
||||
class BenchmarkClient(abc.ABC): |
||||
"""Benchmark client interface that exposes a non-blocking send_request().""" |
||||
|
||||
def __init__(self, address: str, config: control_pb2.ClientConfig, |
||||
hist: histogram.Histogram): |
||||
# Disables underlying reuse of subchannels |
||||
unique_option = (('iv', random.random()),) |
||||
|
||||
# Parses the channel argument from config |
||||
channel_args = tuple( |
||||
(arg.name, arg.str_value) if arg.HasField('str_value') else ( |
||||
arg.name, int(arg.int_value)) for arg in config.channel_args) |
||||
|
||||
# Creates the channel |
||||
if config.HasField('security_params'): |
||||
channel_credentials = grpc.ssl_channel_credentials( |
||||
resources.test_root_certificates(),) |
||||
server_host_override_option = (( |
||||
'grpc.ssl_target_name_override', |
||||
config.security_params.server_host_override, |
||||
),) |
||||
self._channel = aio.secure_channel( |
||||
address, channel_credentials, |
||||
unique_option + channel_args + server_host_override_option) |
||||
else: |
||||
self._channel = aio.insecure_channel(address, |
||||
options=unique_option + |
||||
channel_args) |
||||
|
||||
# Creates the stub |
||||
if config.payload_config.WhichOneof('payload') == 'simple_params': |
||||
self._generic = False |
||||
self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( |
||||
self._channel) |
||||
payload = messages_pb2.Payload( |
||||
body=b'\0' * config.payload_config.simple_params.req_size) |
||||
self._request = messages_pb2.SimpleRequest( |
||||
payload=payload, |
||||
response_size=config.payload_config.simple_params.resp_size) |
||||
else: |
||||
self._generic = True |
||||
self._stub = GenericStub(self._channel) |
||||
self._request = b'\0' * config.payload_config.bytebuf_params.req_size |
||||
|
||||
self._hist = hist |
||||
self._response_callbacks = [] |
||||
self._concurrency = config.outstanding_rpcs_per_channel |
||||
|
||||
async def run(self) -> None: |
||||
await self._channel.channel_ready() |
||||
|
||||
async def stop(self) -> None: |
||||
await self._channel.close() |
||||
|
||||
def _record_query_time(self, query_time: float) -> None: |
||||
self._hist.add(query_time * 1e9) |
||||
|
||||
|
||||
class UnaryAsyncBenchmarkClient(BenchmarkClient): |
||||
|
||||
def __init__(self, address: str, config: control_pb2.ClientConfig, |
||||
hist: histogram.Histogram): |
||||
super().__init__(address, config, hist) |
||||
self._running = None |
||||
self._stopped = asyncio.Event() |
||||
|
||||
async def _send_request(self): |
||||
start_time = time.monotonic() |
||||
await self._stub.UnaryCall(self._request) |
||||
self._record_query_time(time.monotonic() - start_time) |
||||
|
||||
async def _send_indefinitely(self) -> None: |
||||
while self._running: |
||||
await self._send_request() |
||||
|
||||
async def run(self) -> None: |
||||
await super().run() |
||||
self._running = True |
||||
senders = (self._send_indefinitely() for _ in range(self._concurrency)) |
||||
await asyncio.gather(*senders) |
||||
self._stopped.set() |
||||
|
||||
async def stop(self) -> None: |
||||
self._running = False |
||||
await self._stopped.wait() |
||||
await super().stop() |
||||
|
||||
|
||||
class StreamingAsyncBenchmarkClient(BenchmarkClient): |
||||
|
||||
def __init__(self, address: str, config: control_pb2.ClientConfig, |
||||
hist: histogram.Histogram): |
||||
super().__init__(address, config, hist) |
||||
self._running = None |
||||
self._stopped = asyncio.Event() |
||||
|
||||
async def _one_streaming_call(self): |
||||
call = self._stub.StreamingCall() |
||||
while self._running: |
||||
start_time = time.time() |
||||
await call.write(self._request) |
||||
await call.read() |
||||
self._record_query_time(time.time() - start_time) |
||||
await call.done_writing() |
||||
|
||||
async def run(self): |
||||
await super().run() |
||||
self._running = True |
||||
senders = (self._one_streaming_call() for _ in range(self._concurrency)) |
||||
await asyncio.gather(*senders) |
||||
self._stopped.set() |
||||
|
||||
async def stop(self): |
||||
self._running = False |
||||
await self._stopped.wait() |
||||
await super().stop() |
@ -0,0 +1,55 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""The Python AsyncIO Benchmark Servicers.""" |
||||
|
||||
import asyncio |
||||
import logging |
||||
import unittest |
||||
|
||||
from grpc.experimental import aio |
||||
|
||||
from src.proto.grpc.testing import benchmark_service_pb2_grpc, messages_pb2 |
||||
|
||||
|
||||
class BenchmarkServicer(benchmark_service_pb2_grpc.BenchmarkServiceServicer): |
||||
|
||||
async def UnaryCall(self, request, unused_context): |
||||
payload = messages_pb2.Payload(body=b'\0' * request.response_size) |
||||
return messages_pb2.SimpleResponse(payload=payload) |
||||
|
||||
async def StreamingFromServer(self, request, unused_context): |
||||
payload = messages_pb2.Payload(body=b'\0' * request.response_size) |
||||
# Sends response at full capacity! |
||||
while True: |
||||
yield messages_pb2.SimpleResponse(payload=payload) |
||||
|
||||
async def StreamingCall(self, request_iterator, unused_context): |
||||
async for request in request_iterator: |
||||
payload = messages_pb2.Payload(body=b'\0' * request.response_size) |
||||
yield messages_pb2.SimpleResponse(payload=payload) |
||||
|
||||
|
||||
class GenericBenchmarkServicer( |
||||
benchmark_service_pb2_grpc.BenchmarkServiceServicer): |
||||
"""Generic (no-codec) Server implementation for the Benchmark service.""" |
||||
|
||||
def __init__(self, resp_size): |
||||
self._response = '\0' * resp_size |
||||
|
||||
async def UnaryCall(self, unused_request, unused_context): |
||||
return self._response |
||||
|
||||
async def StreamingCall(self, request_iterator, unused_context): |
||||
async for _ in request_iterator: |
||||
yield self._response |
@ -0,0 +1,58 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import argparse |
||||
import asyncio |
||||
import logging |
||||
|
||||
from grpc.experimental import aio |
||||
|
||||
from src.proto.grpc.testing import worker_service_pb2_grpc |
||||
from tests_aio.benchmark import worker_servicer |
||||
|
||||
|
||||
async def run_worker_server(port: int) -> None: |
||||
aio.init_grpc_aio() |
||||
server = aio.server() |
||||
|
||||
servicer = worker_servicer.WorkerServicer() |
||||
worker_service_pb2_grpc.add_WorkerServiceServicer_to_server( |
||||
servicer, server) |
||||
|
||||
server.add_insecure_port('[::]:{}'.format(port)) |
||||
|
||||
await server.start() |
||||
|
||||
await servicer.wait_for_quit() |
||||
await server.stop(None) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
logging.basicConfig(level=logging.DEBUG) |
||||
parser = argparse.ArgumentParser( |
||||
description='gRPC Python performance testing worker') |
||||
parser.add_argument('--driver_port', |
||||
type=int, |
||||
dest='port', |
||||
help='The port the worker should listen on') |
||||
parser.add_argument('--uvloop', |
||||
action='store_true', |
||||
help='Use uvloop or not') |
||||
args = parser.parse_args() |
||||
|
||||
if args.uvloop: |
||||
import uvloop |
||||
uvloop.install() |
||||
|
||||
asyncio.get_event_loop().run_until_complete(run_worker_server(args.port)) |
@ -0,0 +1,367 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import asyncio |
||||
import collections |
||||
import logging |
||||
import multiprocessing |
||||
import os |
||||
import sys |
||||
import time |
||||
from typing import Tuple |
||||
|
||||
import grpc |
||||
from grpc.experimental import aio |
||||
|
||||
from src.proto.grpc.testing import (benchmark_service_pb2_grpc, control_pb2, |
||||
stats_pb2, worker_service_pb2_grpc) |
||||
from tests.qps import histogram |
||||
from tests.unit import resources |
||||
from tests.unit.framework.common import get_socket |
||||
from tests_aio.benchmark import benchmark_client, benchmark_servicer |
||||
|
||||
_NUM_CORES = multiprocessing.cpu_count() |
||||
_WORKER_ENTRY_FILE = os.path.join( |
||||
os.path.split(os.path.abspath(__file__))[0], 'worker.py') |
||||
|
||||
_LOGGER = logging.getLogger(__name__) |
||||
|
||||
|
||||
class _SubWorker( |
||||
collections.namedtuple('_SubWorker', |
||||
['process', 'port', 'channel', 'stub'])): |
||||
"""A data class that holds information about a child qps worker.""" |
||||
|
||||
def _repr(self): |
||||
return f'<_SubWorker pid={self.process.pid} port={self.port}>' |
||||
|
||||
def __repr__(self): |
||||
return self._repr() |
||||
|
||||
def __str__(self): |
||||
return self._repr() |
||||
|
||||
|
||||
def _get_server_status(start_time: float, end_time: float, |
||||
port: int) -> control_pb2.ServerStatus: |
||||
"""Creates ServerStatus proto message.""" |
||||
end_time = time.monotonic() |
||||
elapsed_time = end_time - start_time |
||||
# TODO(lidiz) Collect accurate time system to compute QPS/core-second. |
||||
stats = stats_pb2.ServerStats(time_elapsed=elapsed_time, |
||||
time_user=elapsed_time, |
||||
time_system=elapsed_time) |
||||
return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES) |
||||
|
||||
|
||||
def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]: |
||||
"""Creates a server object according to the ServerConfig.""" |
||||
channel_args = tuple( |
||||
(arg.name, |
||||
arg.str_value) if arg.HasField('str_value') else (arg.name, |
||||
int(arg.int_value)) |
||||
for arg in config.channel_args) |
||||
|
||||
server = aio.server(options=channel_args + (('grpc.so_reuseport', 1),)) |
||||
if config.server_type == control_pb2.ASYNC_SERVER: |
||||
servicer = benchmark_servicer.BenchmarkServicer() |
||||
benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( |
||||
servicer, server) |
||||
elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER: |
||||
resp_size = config.payload_config.bytebuf_params.resp_size |
||||
servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size) |
||||
method_implementations = { |
||||
'StreamingCall': |
||||
grpc.stream_stream_rpc_method_handler(servicer.StreamingCall), |
||||
'UnaryCall': |
||||
grpc.unary_unary_rpc_method_handler(servicer.UnaryCall), |
||||
} |
||||
handler = grpc.method_handlers_generic_handler( |
||||
'grpc.testing.BenchmarkService', method_implementations) |
||||
server.add_generic_rpc_handlers((handler,)) |
||||
else: |
||||
raise NotImplementedError('Unsupported server type {}'.format( |
||||
config.server_type)) |
||||
|
||||
if config.HasField('security_params'): # Use SSL |
||||
server_creds = grpc.ssl_server_credentials( |
||||
((resources.private_key(), resources.certificate_chain()),)) |
||||
port = server.add_secure_port('[::]:{}'.format(config.port), |
||||
server_creds) |
||||
else: |
||||
port = server.add_insecure_port('[::]:{}'.format(config.port)) |
||||
|
||||
return server, port |
||||
|
||||
|
||||
def _get_client_status(start_time: float, end_time: float, |
||||
qps_data: histogram.Histogram |
||||
) -> control_pb2.ClientStatus: |
||||
"""Creates ClientStatus proto message.""" |
||||
latencies = qps_data.get_data() |
||||
end_time = time.monotonic() |
||||
elapsed_time = end_time - start_time |
||||
# TODO(lidiz) Collect accurate time system to compute QPS/core-second. |
||||
stats = stats_pb2.ClientStats(latencies=latencies, |
||||
time_elapsed=elapsed_time, |
||||
time_user=elapsed_time, |
||||
time_system=elapsed_time) |
||||
return control_pb2.ClientStatus(stats=stats) |
||||
|
||||
|
||||
def _create_client(server: str, config: control_pb2.ClientConfig, |
||||
qps_data: histogram.Histogram |
||||
) -> benchmark_client.BenchmarkClient: |
||||
"""Creates a client object according to the ClientConfig.""" |
||||
if config.load_params.WhichOneof('load') != 'closed_loop': |
||||
raise NotImplementedError( |
||||
f'Unsupported load parameter {config.load_params}') |
||||
|
||||
if config.client_type == control_pb2.ASYNC_CLIENT: |
||||
if config.rpc_type == control_pb2.UNARY: |
||||
client_type = benchmark_client.UnaryAsyncBenchmarkClient |
||||
elif config.rpc_type == control_pb2.STREAMING: |
||||
client_type = benchmark_client.StreamingAsyncBenchmarkClient |
||||
else: |
||||
raise NotImplementedError( |
||||
f'Unsupported rpc_type [{config.rpc_type}]') |
||||
else: |
||||
raise NotImplementedError( |
||||
f'Unsupported client type {config.client_type}') |
||||
|
||||
return client_type(server, config, qps_data) |
||||
|
||||
|
||||
def _pick_an_unused_port() -> int: |
||||
"""Picks an unused TCP port.""" |
||||
_, port, sock = get_socket() |
||||
sock.close() |
||||
return port |
||||
|
||||
|
||||
async def _create_sub_worker() -> _SubWorker: |
||||
"""Creates a child qps worker as a subprocess.""" |
||||
port = _pick_an_unused_port() |
||||
|
||||
_LOGGER.info('Creating sub worker at port [%d]...', port) |
||||
process = await asyncio.create_subprocess_exec(sys.executable, |
||||
_WORKER_ENTRY_FILE, |
||||
'--driver_port', str(port)) |
||||
_LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port, |
||||
process.pid) |
||||
channel = aio.insecure_channel(f'localhost:{port}') |
||||
_LOGGER.info('Waiting for sub worker at port [%d]', port) |
||||
await channel.channel_ready() |
||||
stub = worker_service_pb2_grpc.WorkerServiceStub(channel) |
||||
return _SubWorker( |
||||
process=process, |
||||
port=port, |
||||
channel=channel, |
||||
stub=stub, |
||||
) |
||||
|
||||
|
||||
class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer): |
||||
"""Python Worker Server implementation.""" |
||||
|
||||
def __init__(self): |
||||
self._loop = asyncio.get_event_loop() |
||||
self._quit_event = asyncio.Event() |
||||
|
||||
async def _run_single_server(self, config, request_iterator, context): |
||||
server, port = _create_server(config) |
||||
await server.start() |
||||
_LOGGER.info('Server started at port [%d]', port) |
||||
|
||||
start_time = time.monotonic() |
||||
await context.write(_get_server_status(start_time, start_time, port)) |
||||
|
||||
async for request in request_iterator: |
||||
end_time = time.monotonic() |
||||
status = _get_server_status(start_time, end_time, port) |
||||
if request.mark.reset: |
||||
start_time = end_time |
||||
await context.write(status) |
||||
await server.stop(None) |
||||
|
||||
async def RunServer(self, request_iterator, context): |
||||
config_request = await context.read() |
||||
config = config_request.setup |
||||
_LOGGER.info('Received ServerConfig: %s', config) |
||||
|
||||
if config.server_processes <= 0: |
||||
_LOGGER.info('Using server_processes == [%d]', _NUM_CORES) |
||||
config.server_processes = _NUM_CORES |
||||
|
||||
if config.port == 0: |
||||
config.port = _pick_an_unused_port() |
||||
_LOGGER.info('Port picked [%d]', config.port) |
||||
|
||||
if config.server_processes == 1: |
||||
# If server_processes == 1, start the server in this process. |
||||
await self._run_single_server(config, request_iterator, context) |
||||
else: |
||||
# If server_processes > 1, offload to other processes. |
||||
sub_workers = await asyncio.gather(*( |
||||
_create_sub_worker() for _ in range(config.server_processes))) |
||||
|
||||
calls = [worker.stub.RunServer() for worker in sub_workers] |
||||
|
||||
config_request.setup.server_processes = 1 |
||||
|
||||
for call in calls: |
||||
await call.write(config_request) |
||||
# An empty status indicates the peer is ready |
||||
await call.read() |
||||
|
||||
start_time = time.monotonic() |
||||
await context.write( |
||||
_get_server_status( |
||||
start_time, |
||||
start_time, |
||||
config.port, |
||||
)) |
||||
|
||||
_LOGGER.info('Servers are ready to serve.') |
||||
|
||||
async for request in request_iterator: |
||||
end_time = time.monotonic() |
||||
|
||||
for call in calls: |
||||
await call.write(request) |
||||
# Reports from sub workers doesn't matter |
||||
await call.read() |
||||
|
||||
status = _get_server_status( |
||||
start_time, |
||||
end_time, |
||||
config.port, |
||||
) |
||||
if request.mark.reset: |
||||
start_time = end_time |
||||
await context.write(status) |
||||
|
||||
for call in calls: |
||||
await call.done_writing() |
||||
|
||||
for worker in sub_workers: |
||||
await worker.stub.QuitWorker(control_pb2.Void()) |
||||
await worker.channel.close() |
||||
_LOGGER.info('Waiting for [%s] to quit...', worker) |
||||
await worker.process.wait() |
||||
|
||||
async def _run_single_client(self, config, request_iterator, context): |
||||
running_tasks = [] |
||||
qps_data = histogram.Histogram(config.histogram_params.resolution, |
||||
config.histogram_params.max_possible) |
||||
start_time = time.monotonic() |
||||
|
||||
# Create a client for each channel as asyncio.Task |
||||
for i in range(config.client_channels): |
||||
server = config.server_targets[i % len(config.server_targets)] |
||||
client = _create_client(server, config, qps_data) |
||||
_LOGGER.info('Client created against server [%s]', server) |
||||
running_tasks.append(self._loop.create_task(client.run())) |
||||
|
||||
end_time = time.monotonic() |
||||
await context.write(_get_client_status(start_time, end_time, qps_data)) |
||||
|
||||
# Respond to stat requests |
||||
async for request in request_iterator: |
||||
end_time = time.monotonic() |
||||
status = _get_client_status(start_time, end_time, qps_data) |
||||
if request.mark.reset: |
||||
qps_data.reset() |
||||
start_time = time.monotonic() |
||||
await context.write(status) |
||||
|
||||
# Cleanup the clients |
||||
for task in running_tasks: |
||||
task.cancel() |
||||
|
||||
async def RunClient(self, request_iterator, context): |
||||
config_request = await context.read() |
||||
config = config_request.setup |
||||
_LOGGER.info('Received ClientConfig: %s', config) |
||||
|
||||
if config.client_processes <= 0: |
||||
_LOGGER.info('client_processes can\'t be [%d]', |
||||
config.client_processes) |
||||
_LOGGER.info('Using client_processes == [%d]', _NUM_CORES) |
||||
config.client_processes = _NUM_CORES |
||||
|
||||
if config.client_processes == 1: |
||||
# If client_processes == 1, run the benchmark in this process. |
||||
await self._run_single_client(config, request_iterator, context) |
||||
else: |
||||
# If client_processes > 1, offload the work to other processes. |
||||
sub_workers = await asyncio.gather(*( |
||||
_create_sub_worker() for _ in range(config.client_processes))) |
||||
|
||||
calls = [worker.stub.RunClient() for worker in sub_workers] |
||||
|
||||
config_request.setup.client_processes = 1 |
||||
|
||||
for call in calls: |
||||
await call.write(config_request) |
||||
# An empty status indicates the peer is ready |
||||
await call.read() |
||||
|
||||
start_time = time.monotonic() |
||||
result = histogram.Histogram(config.histogram_params.resolution, |
||||
config.histogram_params.max_possible) |
||||
end_time = time.monotonic() |
||||
await context.write(_get_client_status(start_time, end_time, |
||||
result)) |
||||
|
||||
async for request in request_iterator: |
||||
end_time = time.monotonic() |
||||
|
||||
for call in calls: |
||||
_LOGGER.debug('Fetching status...') |
||||
await call.write(request) |
||||
sub_status = await call.read() |
||||
result.merge(sub_status.stats.latencies) |
||||
_LOGGER.debug('Update from sub worker count=[%d]', |
||||
sub_status.stats.latencies.count) |
||||
|
||||
status = _get_client_status(start_time, end_time, result) |
||||
if request.mark.reset: |
||||
result.reset() |
||||
start_time = time.monotonic() |
||||
_LOGGER.debug('Reporting count=[%d]', |
||||
status.stats.latencies.count) |
||||
await context.write(status) |
||||
|
||||
for call in calls: |
||||
await call.done_writing() |
||||
|
||||
for worker in sub_workers: |
||||
await worker.stub.QuitWorker(control_pb2.Void()) |
||||
await worker.channel.close() |
||||
_LOGGER.info('Waiting for sub worker [%s] to quit...', worker) |
||||
await worker.process.wait() |
||||
_LOGGER.info('Sub worker [%s] quit', worker) |
||||
|
||||
@staticmethod |
||||
async def CoreCount(unused_request, unused_context): |
||||
return control_pb2.CoreResponse(cores=_NUM_CORES) |
||||
|
||||
async def QuitWorker(self, unused_request, unused_context): |
||||
_LOGGER.info('QuitWorker command received.') |
||||
self._quit_event.set() |
||||
return control_pb2.Void() |
||||
|
||||
async def wait_for_quit(self): |
||||
await self._quit_event.wait() |
@ -0,0 +1,29 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
package(default_testonly = 1) |
||||
|
||||
py_test( |
||||
name = "health_servicer_test", |
||||
size = "small", |
||||
srcs = ["health_servicer_test.py"], |
||||
imports = ["../../"], |
||||
python_version = "PY3", |
||||
deps = [ |
||||
"//src/python/grpcio/grpc:grpcio", |
||||
"//src/python/grpcio_health_checking/grpc_health/v1:grpc_health", |
||||
"//src/python/grpcio_tests/tests/unit/framework/common", |
||||
"//src/python/grpcio_tests/tests_aio/unit:_test_base", |
||||
], |
||||
) |
@ -0,0 +1,13 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
@ -0,0 +1,262 @@ |
||||
# Copyright 2020 The gRPC Authors |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Tests AsyncIO version of grpcio-health-checking.""" |
||||
|
||||
import asyncio |
||||
import logging |
||||
import time |
||||
import random |
||||
import unittest |
||||
|
||||
import grpc |
||||
|
||||
from grpc_health.v1 import health |
||||
from grpc_health.v1 import health_pb2 |
||||
from grpc_health.v1 import health_pb2_grpc |
||||
from grpc.experimental import aio |
||||
|
||||
from tests.unit.framework.common import test_constants |
||||
|
||||
from tests_aio.unit._test_base import AioTestBase |
||||
|
||||
_SERVING_SERVICE = 'grpc.test.TestServiceServing' |
||||
_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown' |
||||
_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing' |
||||
_WATCH_SERVICE = 'grpc.test.WatchService' |
||||
|
||||
_LARGE_NUMBER_OF_STATUS_CHANGES = 1000 |
||||
|
||||
|
||||
async def _pipe_to_queue(call, queue): |
||||
async for response in call: |
||||
await queue.put(response) |
||||
|
||||
|
||||
class HealthServicerTest(AioTestBase): |
||||
|
||||
async def setUp(self): |
||||
self._servicer = health.aio.HealthServicer() |
||||
await self._servicer.set(health.OVERALL_HEALTH, |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
await self._servicer.set(_SERVING_SERVICE, |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
await self._servicer.set(_UNKNOWN_SERVICE, |
||||
health_pb2.HealthCheckResponse.UNKNOWN) |
||||
await self._servicer.set(_NOT_SERVING_SERVICE, |
||||
health_pb2.HealthCheckResponse.NOT_SERVING) |
||||
self._server = aio.server() |
||||
port = self._server.add_insecure_port('[::]:0') |
||||
health_pb2_grpc.add_HealthServicer_to_server(self._servicer, |
||||
self._server) |
||||
await self._server.start() |
||||
|
||||
self._channel = aio.insecure_channel('localhost:%d' % port) |
||||
self._stub = health_pb2_grpc.HealthStub(self._channel) |
||||
|
||||
async def tearDown(self): |
||||
await self._channel.close() |
||||
await self._server.stop(None) |
||||
|
||||
async def test_check_empty_service(self): |
||||
request = health_pb2.HealthCheckRequest() |
||||
resp = await self._stub.Check(request) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) |
||||
|
||||
async def test_check_serving_service(self): |
||||
request = health_pb2.HealthCheckRequest(service=_SERVING_SERVICE) |
||||
resp = await self._stub.Check(request) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) |
||||
|
||||
async def test_check_unknown_service(self): |
||||
request = health_pb2.HealthCheckRequest(service=_UNKNOWN_SERVICE) |
||||
resp = await self._stub.Check(request) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status) |
||||
|
||||
async def test_check_not_serving_service(self): |
||||
request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE) |
||||
resp = await self._stub.Check(request) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, |
||||
resp.status) |
||||
|
||||
async def test_check_not_found_service(self): |
||||
request = health_pb2.HealthCheckRequest(service='not-found') |
||||
with self.assertRaises(aio.AioRpcError) as context: |
||||
await self._stub.Check(request) |
||||
|
||||
self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) |
||||
|
||||
async def test_health_service_name(self): |
||||
self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') |
||||
|
||||
async def test_watch_empty_service(self): |
||||
request = health_pb2.HealthCheckRequest(service=health.OVERALL_HEALTH) |
||||
|
||||
call = self._stub.Watch(request) |
||||
queue = asyncio.Queue() |
||||
task = self.loop.create_task(_pipe_to_queue(call, queue)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, |
||||
(await queue.get()).status) |
||||
|
||||
call.cancel() |
||||
await task |
||||
self.assertTrue(queue.empty()) |
||||
|
||||
async def test_watch_new_service(self): |
||||
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) |
||||
call = self._stub.Watch(request) |
||||
queue = asyncio.Queue() |
||||
task = self.loop.create_task(_pipe_to_queue(call, queue)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, |
||||
(await queue.get()).status) |
||||
|
||||
await self._servicer.set(_WATCH_SERVICE, |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, |
||||
(await queue.get()).status) |
||||
|
||||
await self._servicer.set(_WATCH_SERVICE, |
||||
health_pb2.HealthCheckResponse.NOT_SERVING) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, |
||||
(await queue.get()).status) |
||||
|
||||
call.cancel() |
||||
await task |
||||
self.assertTrue(queue.empty()) |
||||
|
||||
async def test_watch_service_isolation(self): |
||||
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) |
||||
call = self._stub.Watch(request) |
||||
queue = asyncio.Queue() |
||||
task = self.loop.create_task(_pipe_to_queue(call, queue)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, |
||||
(await queue.get()).status) |
||||
|
||||
await self._servicer.set('some-other-service', |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
# The change of health status in other service should be isolated. |
||||
# Hence, no additional notification should be observed. |
||||
with self.assertRaises(asyncio.TimeoutError): |
||||
await asyncio.wait_for(queue.get(), test_constants.SHORT_TIMEOUT) |
||||
|
||||
call.cancel() |
||||
await task |
||||
self.assertTrue(queue.empty()) |
||||
|
||||
async def test_two_watchers(self): |
||||
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) |
||||
queue1 = asyncio.Queue() |
||||
queue2 = asyncio.Queue() |
||||
call1 = self._stub.Watch(request) |
||||
call2 = self._stub.Watch(request) |
||||
task1 = self.loop.create_task(_pipe_to_queue(call1, queue1)) |
||||
task2 = self.loop.create_task(_pipe_to_queue(call2, queue2)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, |
||||
(await queue1.get()).status) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, |
||||
(await queue2.get()).status) |
||||
|
||||
await self._servicer.set(_WATCH_SERVICE, |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, |
||||
(await queue1.get()).status) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, |
||||
(await queue2.get()).status) |
||||
|
||||
call1.cancel() |
||||
call2.cancel() |
||||
await task1 |
||||
await task2 |
||||
self.assertTrue(queue1.empty()) |
||||
self.assertTrue(queue2.empty()) |
||||
|
||||
async def test_cancelled_watch_removed_from_watch_list(self): |
||||
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) |
||||
call = self._stub.Watch(request) |
||||
queue = asyncio.Queue() |
||||
task = self.loop.create_task(_pipe_to_queue(call, queue)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, |
||||
(await queue.get()).status) |
||||
|
||||
call.cancel() |
||||
await self._servicer.set(_WATCH_SERVICE, |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
await task |
||||
|
||||
# Wait for the serving coroutine to process client cancellation. |
||||
timeout = time.monotonic() + test_constants.TIME_ALLOWANCE |
||||
while (time.monotonic() < timeout and self._servicer._server_watchers): |
||||
await asyncio.sleep(1) |
||||
self.assertFalse(self._servicer._server_watchers, |
||||
'There should not be any watcher left') |
||||
self.assertTrue(queue.empty()) |
||||
|
||||
async def test_graceful_shutdown(self): |
||||
request = health_pb2.HealthCheckRequest(service=health.OVERALL_HEALTH) |
||||
call = self._stub.Watch(request) |
||||
queue = asyncio.Queue() |
||||
task = self.loop.create_task(_pipe_to_queue(call, queue)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, |
||||
(await queue.get()).status) |
||||
|
||||
await self._servicer.enter_graceful_shutdown() |
||||
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, |
||||
(await queue.get()).status) |
||||
|
||||
# This should be a no-op. |
||||
await self._servicer.set(health.OVERALL_HEALTH, |
||||
health_pb2.HealthCheckResponse.SERVING) |
||||
|
||||
resp = await self._stub.Check(request) |
||||
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, |
||||
resp.status) |
||||
|
||||
call.cancel() |
||||
await task |
||||
self.assertTrue(queue.empty()) |
||||
|
||||
async def test_no_duplicate_status(self): |
||||
request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) |
||||
call = self._stub.Watch(request) |
||||
queue = asyncio.Queue() |
||||
task = self.loop.create_task(_pipe_to_queue(call, queue)) |
||||
|
||||
self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, |
||||
(await queue.get()).status) |
||||
last_status = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN |
||||
|
||||
for _ in range(_LARGE_NUMBER_OF_STATUS_CHANGES): |
||||
if random.randint(0, 1) == 0: |
||||
status = health_pb2.HealthCheckResponse.SERVING |
||||
else: |
||||
status = health_pb2.HealthCheckResponse.NOT_SERVING |
||||
|
||||
await self._servicer.set(_WATCH_SERVICE, status) |
||||
if status != last_status: |
||||
self.assertEqual(status, (await queue.get()).status) |
||||
last_status = status |
||||
|
||||
call.cancel() |
||||
await task |
||||
self.assertTrue(queue.empty()) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
logging.basicConfig(level=logging.DEBUG) |
||||
unittest.main(verbosity=2) |
@ -0,0 +1,21 @@ |
||||
# Copyright 2020 The gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
from __future__ import absolute_import |
||||
|
||||
from tests import _loader |
||||
from tests import _runner |
||||
|
||||
Loader = _loader.Loader |
||||
Runner = _runner.Runner |
@ -0,0 +1,41 @@ |
||||
# Copyright 2020 The gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
package( |
||||
default_testonly = True, |
||||
) |
||||
|
||||
GRPCIO_PY3_ONLY_TESTS_UNIT = glob([ |
||||
"*_test.py", |
||||
]) |
||||
|
||||
[ |
||||
py_test( |
||||
name = test_file_name[:-len(".py")], |
||||
size = "small", |
||||
srcs = [test_file_name], |
||||
main = test_file_name, |
||||
python_version = "PY3", |
||||
srcs_version = "PY3", |
||||
deps = [ |
||||
"//src/python/grpcio/grpc:grpcio", |
||||
"//src/python/grpcio_tests/tests/testing", |
||||
"//src/python/grpcio_tests/tests/unit:resources", |
||||
"//src/python/grpcio_tests/tests/unit:test_common", |
||||
"//src/python/grpcio_tests/tests/unit/framework/common", |
||||
"@six", |
||||
], |
||||
) |
||||
for test_file_name in GRPCIO_PY3_ONLY_TESTS_UNIT |
||||
] |
@ -0,0 +1,13 @@ |
||||
# Copyright 2019 The gRPC Authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
@ -0,0 +1,276 @@ |
||||
# Copyright 2020 The gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Tests for Simple Stubs.""" |
||||
|
||||
# TODO(https://github.com/grpc/grpc/issues/21965): Run under setuptools. |
||||
|
||||
import os |
||||
|
||||
_MAXIMUM_CHANNELS = 10 |
||||
|
||||
os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "1" |
||||
os.environ["GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM"] = str(_MAXIMUM_CHANNELS) |
||||
|
||||
import contextlib |
||||
import datetime |
||||
import inspect |
||||
import logging |
||||
import unittest |
||||
import sys |
||||
import time |
||||
from typing import Callable, Optional |
||||
|
||||
from tests.unit import test_common |
||||
import grpc |
||||
import grpc.experimental |
||||
|
||||
_REQUEST = b"0000" |
||||
|
||||
_CACHE_EPOCHS = 8 |
||||
_CACHE_TRIALS = 6 |
||||
|
||||
_SERVER_RESPONSE_COUNT = 10 |
||||
_CLIENT_REQUEST_COUNT = _SERVER_RESPONSE_COUNT |
||||
|
||||
_STRESS_EPOCHS = _MAXIMUM_CHANNELS * 10 |
||||
|
||||
_UNARY_UNARY = "/test/UnaryUnary" |
||||
_UNARY_STREAM = "/test/UnaryStream" |
||||
_STREAM_UNARY = "/test/StreamUnary" |
||||
_STREAM_STREAM = "/test/StreamStream" |
||||
|
||||
|
||||
def _unary_unary_handler(request, context): |
||||
return request |
||||
|
||||
|
||||
def _unary_stream_handler(request, context): |
||||
for _ in range(_SERVER_RESPONSE_COUNT): |
||||
yield request |
||||
|
||||
|
||||
def _stream_unary_handler(request_iterator, context): |
||||
request = None |
||||
for single_request in request_iterator: |
||||
request = single_request |
||||
return request |
||||
|
||||
|
||||
def _stream_stream_handler(request_iterator, context): |
||||
for request in request_iterator: |
||||
yield request |
||||
|
||||
|
||||
class _GenericHandler(grpc.GenericRpcHandler): |
||||
|
||||
def service(self, handler_call_details): |
||||
if handler_call_details.method == _UNARY_UNARY: |
||||
return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) |
||||
elif handler_call_details.method == _UNARY_STREAM: |
||||
return grpc.unary_stream_rpc_method_handler(_unary_stream_handler) |
||||
elif handler_call_details.method == _STREAM_UNARY: |
||||
return grpc.stream_unary_rpc_method_handler(_stream_unary_handler) |
||||
elif handler_call_details.method == _STREAM_STREAM: |
||||
return grpc.stream_stream_rpc_method_handler(_stream_stream_handler) |
||||
else: |
||||
raise NotImplementedError() |
||||
|
||||
|
||||
def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta: |
||||
start = datetime.datetime.now() |
||||
to_time() |
||||
return datetime.datetime.now() - start |
||||
|
||||
|
||||
@contextlib.contextmanager |
||||
def _server(credentials: Optional[grpc.ServerCredentials]): |
||||
try: |
||||
server = test_common.test_server() |
||||
target = '[::]:0' |
||||
if credentials is None: |
||||
port = server.add_insecure_port(target) |
||||
else: |
||||
port = server.add_secure_port(target, credentials) |
||||
server.add_generic_rpc_handlers((_GenericHandler(),)) |
||||
server.start() |
||||
yield port |
||||
finally: |
||||
server.stop(None) |
||||
|
||||
|
||||
class SimpleStubsTest(unittest.TestCase): |
||||
|
||||
def assert_cached(self, to_check: Callable[[str], None]) -> None: |
||||
"""Asserts that a function caches intermediate data/state. |
||||
|
||||
To be specific, given a function whose caching behavior is |
||||
deterministic in the value of a supplied string, this function asserts |
||||
that, on average, subsequent invocations of the function for a specific |
||||
string are faster than first invocations with that same string. |
||||
|
||||
Args: |
||||
to_check: A function returning nothing, that caches values based on |
||||
an arbitrary supplied string. |
||||
""" |
||||
initial_runs = [] |
||||
cached_runs = [] |
||||
for epoch in range(_CACHE_EPOCHS): |
||||
runs = [] |
||||
text = str(epoch) |
||||
for trial in range(_CACHE_TRIALS): |
||||
runs.append(_time_invocation(lambda: to_check(text))) |
||||
initial_runs.append(runs[0]) |
||||
cached_runs.extend(runs[1:]) |
||||
average_cold = sum((run for run in initial_runs), |
||||
datetime.timedelta()) / len(initial_runs) |
||||
average_warm = sum((run for run in cached_runs), |
||||
datetime.timedelta()) / len(cached_runs) |
||||
self.assertLess(average_warm, average_cold) |
||||
|
||||
def assert_eventually(self, |
||||
predicate: Callable[[], bool], |
||||
*, |
||||
timeout: Optional[datetime.timedelta] = None, |
||||
message: Optional[Callable[[], str]] = None) -> None: |
||||
message = message or (lambda: "Proposition did not evaluate to true") |
||||
timeout = timeout or datetime.timedelta(seconds=10) |
||||
end = datetime.datetime.now() + timeout |
||||
while datetime.datetime.now() < end: |
||||
if predicate(): |
||||
break |
||||
time.sleep(0.5) |
||||
else: |
||||
self.fail(message() + " after " + str(timeout)) |
||||
|
||||
def test_unary_unary_insecure(self): |
||||
with _server(None) as port: |
||||
target = f'localhost:{port}' |
||||
response = grpc.experimental.unary_unary( |
||||
_REQUEST, |
||||
target, |
||||
_UNARY_UNARY, |
||||
channel_credentials=grpc.experimental. |
||||
insecure_channel_credentials()) |
||||
self.assertEqual(_REQUEST, response) |
||||
|
||||
def test_unary_unary_secure(self): |
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
response = grpc.experimental.unary_unary( |
||||
_REQUEST, |
||||
target, |
||||
_UNARY_UNARY, |
||||
channel_credentials=grpc.local_channel_credentials()) |
||||
self.assertEqual(_REQUEST, response) |
||||
|
||||
def test_channel_credentials_default(self): |
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
response = grpc.experimental.unary_unary(_REQUEST, target, |
||||
_UNARY_UNARY) |
||||
self.assertEqual(_REQUEST, response) |
||||
|
||||
def test_channels_cached(self): |
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
test_name = inspect.stack()[0][3] |
||||
args = (_REQUEST, target, _UNARY_UNARY) |
||||
kwargs = {"channel_credentials": grpc.local_channel_credentials()} |
||||
|
||||
def _invoke(seed: str): |
||||
run_kwargs = dict(kwargs) |
||||
run_kwargs["options"] = ((test_name + seed, ""),) |
||||
grpc.experimental.unary_unary(*args, **run_kwargs) |
||||
|
||||
self.assert_cached(_invoke) |
||||
|
||||
def test_channels_evicted(self): |
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
response = grpc.experimental.unary_unary( |
||||
_REQUEST, |
||||
target, |
||||
_UNARY_UNARY, |
||||
channel_credentials=grpc.local_channel_credentials()) |
||||
self.assert_eventually( |
||||
lambda: grpc._simple_stubs.ChannelCache.get( |
||||
)._test_only_channel_count() == 0, |
||||
message=lambda: |
||||
f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain" |
||||
) |
||||
|
||||
def test_total_channels_enforced(self): |
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
for i in range(_STRESS_EPOCHS): |
||||
# Ensure we get a new channel each time. |
||||
options = (("foo", str(i)),) |
||||
# Send messages at full blast. |
||||
grpc.experimental.unary_unary( |
||||
_REQUEST, |
||||
target, |
||||
_UNARY_UNARY, |
||||
options=options, |
||||
channel_credentials=grpc.local_channel_credentials()) |
||||
self.assert_eventually( |
||||
lambda: grpc._simple_stubs.ChannelCache.get( |
||||
)._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, |
||||
message=lambda: |
||||
f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain" |
||||
) |
||||
|
||||
def test_unary_stream(self): |
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
for response in grpc.experimental.unary_stream( |
||||
_REQUEST, |
||||
target, |
||||
_UNARY_STREAM, |
||||
channel_credentials=grpc.local_channel_credentials()): |
||||
self.assertEqual(_REQUEST, response) |
||||
|
||||
def test_stream_unary(self): |
||||
|
||||
def request_iter(): |
||||
for _ in range(_CLIENT_REQUEST_COUNT): |
||||
yield _REQUEST |
||||
|
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
response = grpc.experimental.stream_unary( |
||||
request_iter(), |
||||
target, |
||||
_STREAM_UNARY, |
||||
channel_credentials=grpc.local_channel_credentials()) |
||||
self.assertEqual(_REQUEST, response) |
||||
|
||||
def test_stream_stream(self): |
||||
|
||||
def request_iter(): |
||||
for _ in range(_CLIENT_REQUEST_COUNT): |
||||
yield _REQUEST |
||||
|
||||
with _server(grpc.local_server_credentials()) as port: |
||||
target = f'localhost:{port}' |
||||
for response in grpc.experimental.stream_stream( |
||||
request_iter(), |
||||
target, |
||||
_STREAM_STREAM, |
||||
channel_credentials=grpc.local_channel_credentials()): |
||||
self.assertEqual(_REQUEST, response) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
logging.basicConfig(level=logging.INFO) |
||||
unittest.main(verbosity=2) |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue