mirror of https://github.com/grpc/grpc.git
commit
8acfeef374
92 changed files with 2240 additions and 754 deletions
@ -0,0 +1,49 @@ |
|||||||
|
# Copyright 2020 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
licenses(["notice"]) # 3-clause BSD |
||||||
|
|
||||||
|
load("@grpc_python_dependencies//:requirements.bzl", "requirement") |
||||||
|
|
||||||
|
py_binary( |
||||||
|
name = "alts_server", |
||||||
|
srcs = [ |
||||||
|
"alts_server.py", |
||||||
|
"demo_pb2.py", |
||||||
|
"demo_pb2_grpc.py", |
||||||
|
"server.py", |
||||||
|
], |
||||||
|
main = "alts_server.py", |
||||||
|
python_version = "PY3", |
||||||
|
srcs_version = "PY2AND3", |
||||||
|
deps = [ |
||||||
|
"//src/python/grpcio/grpc:grpcio", |
||||||
|
], |
||||||
|
) |
||||||
|
|
||||||
|
py_binary( |
||||||
|
name = "alts_client", |
||||||
|
srcs = [ |
||||||
|
"alts_client.py", |
||||||
|
"client.py", |
||||||
|
"demo_pb2.py", |
||||||
|
"demo_pb2_grpc.py", |
||||||
|
], |
||||||
|
main = "alts_client.py", |
||||||
|
python_version = "PY3", |
||||||
|
srcs_version = "PY2AND3", |
||||||
|
deps = [ |
||||||
|
"//src/python/grpcio/grpc:grpcio", |
||||||
|
], |
||||||
|
) |
@ -0,0 +1,39 @@ |
|||||||
|
# Copyright 2020 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
"""The example of using ALTS credentials to setup gRPC client. |
||||||
|
|
||||||
|
The example would only successfully run in GCP environment.""" |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
import demo_pb2_grpc |
||||||
|
from client import (bidirectional_streaming_method, client_streaming_method, |
||||||
|
server_streaming_method, simple_method) |
||||||
|
|
||||||
|
SERVER_ADDRESS = "localhost:23333" |
||||||
|
|
||||||
|
|
||||||
|
def main(): |
||||||
|
with grpc.secure_channel( |
||||||
|
SERVER_ADDRESS, |
||||||
|
credentials=grpc.alts_channel_credentials()) as channel: |
||||||
|
stub = demo_pb2_grpc.GRPCDemoStub(channel) |
||||||
|
simple_method(stub) |
||||||
|
client_streaming_method(stub) |
||||||
|
server_streaming_method(stub) |
||||||
|
bidirectional_streaming_method(stub) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
main() |
@ -0,0 +1,39 @@ |
|||||||
|
# Copyright 2020 gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
"""The example of using ALTS credentials to setup gRPC server in python. |
||||||
|
|
||||||
|
The example would only successfully run in GCP environment.""" |
||||||
|
|
||||||
|
from concurrent import futures |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
import demo_pb2_grpc |
||||||
|
from server import DemoServer |
||||||
|
|
||||||
|
SERVER_ADDRESS = 'localhost:23333' |
||||||
|
|
||||||
|
|
||||||
|
def main(): |
||||||
|
svr = grpc.server(futures.ThreadPoolExecutor()) |
||||||
|
demo_pb2_grpc.add_GRPCDemoServicer_to_server(DemoServer(), svr) |
||||||
|
svr.add_secure_port(SERVER_ADDRESS, |
||||||
|
server_credentials=grpc.alts_server_credentials()) |
||||||
|
print("------------------start Python GRPC server with ALTS encryption") |
||||||
|
svr.start() |
||||||
|
svr.wait_for_termination() |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
main() |
@ -1,7 +1,7 @@ |
|||||||
<!-- This file is generated --> |
<!-- This file is generated --> |
||||||
<Project> |
<Project> |
||||||
<PropertyGroup> |
<PropertyGroup> |
||||||
<GrpcCsharpVersion>2.29.0-dev</GrpcCsharpVersion> |
<GrpcCsharpVersion>2.30.0-dev</GrpcCsharpVersion> |
||||||
<GoogleProtobufVersion>3.11.4</GoogleProtobufVersion> |
<GoogleProtobufVersion>3.11.4</GoogleProtobufVersion> |
||||||
</PropertyGroup> |
</PropertyGroup> |
||||||
</Project> |
</Project> |
||||||
|
@ -0,0 +1,531 @@ |
|||||||
|
# Copyright 2020 The gRPC Authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
import asyncio |
||||||
|
import logging |
||||||
|
import unittest |
||||||
|
import datetime |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
from grpc.experimental import aio |
||||||
|
from tests_aio.unit._constants import UNREACHABLE_TARGET |
||||||
|
from tests_aio.unit._common import inject_callbacks |
||||||
|
from tests_aio.unit._test_server import start_test_server |
||||||
|
from tests_aio.unit._test_base import AioTestBase |
||||||
|
from tests.unit.framework.common import test_constants |
||||||
|
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc |
||||||
|
|
||||||
|
_SHORT_TIMEOUT_S = 1.0 |
||||||
|
|
||||||
|
_NUM_STREAM_REQUESTS = 5 |
||||||
|
_REQUEST_PAYLOAD_SIZE = 7 |
||||||
|
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) |
||||||
|
|
||||||
|
|
||||||
|
class _CountingRequestIterator: |
||||||
|
|
||||||
|
def __init__(self, request_iterator): |
||||||
|
self.request_cnt = 0 |
||||||
|
self._request_iterator = request_iterator |
||||||
|
|
||||||
|
async def _forward_requests(self): |
||||||
|
async for request in self._request_iterator: |
||||||
|
self.request_cnt += 1 |
||||||
|
yield request |
||||||
|
|
||||||
|
def __aiter__(self): |
||||||
|
return self._forward_requests() |
||||||
|
|
||||||
|
|
||||||
|
class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor): |
||||||
|
|
||||||
|
async def intercept_stream_unary(self, continuation, client_call_details, |
||||||
|
request_iterator): |
||||||
|
return await continuation(client_call_details, request_iterator) |
||||||
|
|
||||||
|
def assert_in_final_state(self, test: unittest.TestCase): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
class _StreamUnaryInterceptorWithRequestIterator( |
||||||
|
aio.StreamUnaryClientInterceptor): |
||||||
|
|
||||||
|
async def intercept_stream_unary(self, continuation, client_call_details, |
||||||
|
request_iterator): |
||||||
|
self.request_iterator = _CountingRequestIterator(request_iterator) |
||||||
|
call = await continuation(client_call_details, self.request_iterator) |
||||||
|
return call |
||||||
|
|
||||||
|
def assert_in_final_state(self, test: unittest.TestCase): |
||||||
|
test.assertEqual(_NUM_STREAM_REQUESTS, |
||||||
|
self.request_iterator.request_cnt) |
||||||
|
|
||||||
|
|
||||||
|
class TestStreamUnaryClientInterceptor(AioTestBase): |
||||||
|
|
||||||
|
async def setUp(self): |
||||||
|
self._server_target, self._server = await start_test_server() |
||||||
|
|
||||||
|
async def tearDown(self): |
||||||
|
await self._server.stop(None) |
||||||
|
|
||||||
|
async def test_intercepts(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
interceptor = interceptor_class() |
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=[interceptor]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
async def request_iterator(): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
yield request |
||||||
|
|
||||||
|
call = stub.StreamingInputCall(request_iterator()) |
||||||
|
|
||||||
|
response = await call |
||||||
|
|
||||||
|
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, |
||||||
|
response.aggregated_payload_size) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
||||||
|
self.assertEqual(await call.initial_metadata(), ()) |
||||||
|
self.assertEqual(await call.trailing_metadata(), ()) |
||||||
|
self.assertEqual(await call.details(), '') |
||||||
|
self.assertEqual(await call.debug_error_string(), '') |
||||||
|
self.assertEqual(call.cancel(), False) |
||||||
|
self.assertEqual(call.cancelled(), False) |
||||||
|
self.assertEqual(call.done(), True) |
||||||
|
|
||||||
|
interceptor.assert_in_final_state(self) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_intercepts_using_write(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
interceptor = interceptor_class() |
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=[interceptor]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
await call.done_writing() |
||||||
|
|
||||||
|
response = await call |
||||||
|
|
||||||
|
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, |
||||||
|
response.aggregated_payload_size) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
||||||
|
self.assertEqual(await call.initial_metadata(), ()) |
||||||
|
self.assertEqual(await call.trailing_metadata(), ()) |
||||||
|
self.assertEqual(await call.details(), '') |
||||||
|
self.assertEqual(await call.debug_error_string(), '') |
||||||
|
self.assertEqual(call.cancel(), False) |
||||||
|
self.assertEqual(call.cancelled(), False) |
||||||
|
self.assertEqual(call.done(), True) |
||||||
|
|
||||||
|
interceptor.assert_in_final_state(self) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_add_done_callback_interceptor_task_not_finished(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
interceptor = interceptor_class() |
||||||
|
|
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=[interceptor]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
async def request_iterator(): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
yield request |
||||||
|
|
||||||
|
call = stub.StreamingInputCall(request_iterator()) |
||||||
|
|
||||||
|
validation = inject_callbacks(call) |
||||||
|
|
||||||
|
response = await call |
||||||
|
|
||||||
|
await validation |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_add_done_callback_interceptor_task_finished(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
interceptor = interceptor_class() |
||||||
|
|
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=[interceptor]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
async def request_iterator(): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
yield request |
||||||
|
|
||||||
|
call = stub.StreamingInputCall(request_iterator()) |
||||||
|
|
||||||
|
response = await call |
||||||
|
|
||||||
|
validation = inject_callbacks(call) |
||||||
|
|
||||||
|
await validation |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_multiple_interceptors_request_iterator(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
|
||||||
|
interceptors = [interceptor_class(), interceptor_class()] |
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=interceptors) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
async def request_iterator(): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
yield request |
||||||
|
|
||||||
|
call = stub.StreamingInputCall(request_iterator()) |
||||||
|
|
||||||
|
response = await call |
||||||
|
|
||||||
|
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, |
||||||
|
response.aggregated_payload_size) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
||||||
|
self.assertEqual(await call.initial_metadata(), ()) |
||||||
|
self.assertEqual(await call.trailing_metadata(), ()) |
||||||
|
self.assertEqual(await call.details(), '') |
||||||
|
self.assertEqual(await call.debug_error_string(), '') |
||||||
|
self.assertEqual(call.cancel(), False) |
||||||
|
self.assertEqual(call.cancelled(), False) |
||||||
|
self.assertEqual(call.done(), True) |
||||||
|
|
||||||
|
for interceptor in interceptors: |
||||||
|
interceptor.assert_in_final_state(self) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_intercepts_request_iterator_rpc_error(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
channel = aio.insecure_channel( |
||||||
|
UNREACHABLE_TARGET, interceptors=[interceptor_class()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
# When there is an error the request iterator is no longer |
||||||
|
# consumed. |
||||||
|
async def request_iterator(): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
yield request |
||||||
|
|
||||||
|
call = stub.StreamingInputCall(request_iterator()) |
||||||
|
|
||||||
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
||||||
|
await call |
||||||
|
|
||||||
|
self.assertEqual(grpc.StatusCode.UNAVAILABLE, |
||||||
|
exception_context.exception.code()) |
||||||
|
self.assertTrue(call.done()) |
||||||
|
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_intercepts_request_iterator_rpc_error_using_write(self): |
||||||
|
for interceptor_class in (_StreamUnaryInterceptorEmpty, |
||||||
|
_StreamUnaryInterceptorWithRequestIterator): |
||||||
|
|
||||||
|
with self.subTest(name=interceptor_class): |
||||||
|
channel = aio.insecure_channel( |
||||||
|
UNREACHABLE_TARGET, interceptors=[interceptor_class()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
# When there is an error during the write, exception is raised. |
||||||
|
with self.assertRaises(asyncio.InvalidStateError): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
||||||
|
await call |
||||||
|
|
||||||
|
self.assertEqual(grpc.StatusCode.UNAVAILABLE, |
||||||
|
exception_context.exception.code()) |
||||||
|
self.assertTrue(call.done()) |
||||||
|
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_cancel_before_rpc(self): |
||||||
|
|
||||||
|
interceptor_reached = asyncio.Event() |
||||||
|
wait_for_ever = self.loop.create_future() |
||||||
|
|
||||||
|
class Interceptor(aio.StreamUnaryClientInterceptor): |
||||||
|
|
||||||
|
async def intercept_stream_unary(self, continuation, |
||||||
|
client_call_details, |
||||||
|
request_iterator): |
||||||
|
interceptor_reached.set() |
||||||
|
await wait_for_ever |
||||||
|
|
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=[Interceptor()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
self.assertFalse(call.cancelled()) |
||||||
|
self.assertFalse(call.done()) |
||||||
|
|
||||||
|
await interceptor_reached.wait() |
||||||
|
self.assertTrue(call.cancel()) |
||||||
|
|
||||||
|
# When there is an error during the write, exception is raised. |
||||||
|
with self.assertRaises(asyncio.InvalidStateError): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(asyncio.CancelledError): |
||||||
|
await call |
||||||
|
|
||||||
|
self.assertTrue(call.cancelled()) |
||||||
|
self.assertTrue(call.done()) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||||
|
self.assertEqual(await call.initial_metadata(), None) |
||||||
|
self.assertEqual(await call.trailing_metadata(), None) |
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_cancel_after_rpc(self): |
||||||
|
|
||||||
|
interceptor_reached = asyncio.Event() |
||||||
|
wait_for_ever = self.loop.create_future() |
||||||
|
|
||||||
|
class Interceptor(aio.StreamUnaryClientInterceptor): |
||||||
|
|
||||||
|
async def intercept_stream_unary(self, continuation, |
||||||
|
client_call_details, |
||||||
|
request_iterator): |
||||||
|
call = await continuation(client_call_details, request_iterator) |
||||||
|
interceptor_reached.set() |
||||||
|
await wait_for_ever |
||||||
|
|
||||||
|
channel = aio.insecure_channel(self._server_target, |
||||||
|
interceptors=[Interceptor()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
self.assertFalse(call.cancelled()) |
||||||
|
self.assertFalse(call.done()) |
||||||
|
|
||||||
|
await interceptor_reached.wait() |
||||||
|
self.assertTrue(call.cancel()) |
||||||
|
|
||||||
|
# When there is an error during the write, exception is raised. |
||||||
|
with self.assertRaises(asyncio.InvalidStateError): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(asyncio.CancelledError): |
||||||
|
await call |
||||||
|
|
||||||
|
self.assertTrue(call.cancelled()) |
||||||
|
self.assertTrue(call.done()) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||||
|
self.assertEqual(await call.initial_metadata(), None) |
||||||
|
self.assertEqual(await call.trailing_metadata(), None) |
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_cancel_while_writing(self): |
||||||
|
# Test cancelation before making any write or after doing at least 1 |
||||||
|
for num_writes_before_cancel in (0, 1): |
||||||
|
with self.subTest(name="Num writes before cancel: {}".format( |
||||||
|
num_writes_before_cancel)): |
||||||
|
|
||||||
|
channel = aio.insecure_channel( |
||||||
|
UNREACHABLE_TARGET, |
||||||
|
interceptors=[_StreamUnaryInterceptorWithRequestIterator()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * |
||||||
|
_REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest( |
||||||
|
payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
with self.assertRaises(asyncio.InvalidStateError): |
||||||
|
for i in range(_NUM_STREAM_REQUESTS): |
||||||
|
if i == num_writes_before_cancel: |
||||||
|
self.assertTrue(call.cancel()) |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(asyncio.CancelledError): |
||||||
|
await call |
||||||
|
|
||||||
|
self.assertTrue(call.cancelled()) |
||||||
|
self.assertTrue(call.done()) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_cancel_by_the_interceptor(self): |
||||||
|
|
||||||
|
class Interceptor(aio.StreamUnaryClientInterceptor): |
||||||
|
|
||||||
|
async def intercept_stream_unary(self, continuation, |
||||||
|
client_call_details, |
||||||
|
request_iterator): |
||||||
|
call = await continuation(client_call_details, request_iterator) |
||||||
|
call.cancel() |
||||||
|
return call |
||||||
|
|
||||||
|
channel = aio.insecure_channel(UNREACHABLE_TARGET, |
||||||
|
interceptors=[Interceptor()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
with self.assertRaises(asyncio.InvalidStateError): |
||||||
|
for i in range(_NUM_STREAM_REQUESTS): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(asyncio.CancelledError): |
||||||
|
await call |
||||||
|
|
||||||
|
self.assertTrue(call.cancelled()) |
||||||
|
self.assertTrue(call.done()) |
||||||
|
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_exception_raised_by_interceptor(self): |
||||||
|
|
||||||
|
class InterceptorException(Exception): |
||||||
|
pass |
||||||
|
|
||||||
|
class Interceptor(aio.StreamUnaryClientInterceptor): |
||||||
|
|
||||||
|
async def intercept_stream_unary(self, continuation, |
||||||
|
client_call_details, |
||||||
|
request_iterator): |
||||||
|
raise InterceptorException |
||||||
|
|
||||||
|
channel = aio.insecure_channel(UNREACHABLE_TARGET, |
||||||
|
interceptors=[Interceptor()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
||||||
|
|
||||||
|
call = stub.StreamingInputCall() |
||||||
|
|
||||||
|
with self.assertRaises(InterceptorException): |
||||||
|
for i in range(_NUM_STREAM_REQUESTS): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(InterceptorException): |
||||||
|
await call |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
async def test_intercepts_prohibit_mixing_style(self): |
||||||
|
channel = aio.insecure_channel( |
||||||
|
self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()]) |
||||||
|
stub = test_pb2_grpc.TestServiceStub(channel) |
||||||
|
|
||||||
|
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) |
||||||
|
request = messages_pb2.StreamingInputCallRequest(payload=payload) |
||||||
|
|
||||||
|
async def request_iterator(): |
||||||
|
for _ in range(_NUM_STREAM_REQUESTS): |
||||||
|
yield request |
||||||
|
|
||||||
|
call = stub.StreamingInputCall(request_iterator()) |
||||||
|
|
||||||
|
with self.assertRaises(grpc._cython.cygrpc.UsageError): |
||||||
|
await call.write(request) |
||||||
|
|
||||||
|
with self.assertRaises(grpc._cython.cygrpc.UsageError): |
||||||
|
await call.done_writing() |
||||||
|
|
||||||
|
await channel.close() |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
logging.basicConfig(level=logging.DEBUG) |
||||||
|
unittest.main(verbosity=2) |
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue