diff --git a/src/python/grpcio_testing/grpc_testing/_server/_service.py b/src/python/grpcio_testing/grpc_testing/_server/_service.py index 36b0a2f7fff..97b3ffa4229 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_service.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_service.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import grpc @@ -59,7 +60,7 @@ def _stream_response(argument, implementation, rpc, servicer_context): else: while True: try: - response = next(response_iterator) + response = copy.deepcopy(next(response_iterator)) except StopIteration: rpc.stream_response_complete() break diff --git a/src/python/grpcio_tests/tests/testing/_application_common.py b/src/python/grpcio_tests/tests/testing/_application_common.py index b7d5b236051..3226d1fb020 100644 --- a/src/python/grpcio_tests/tests/testing/_application_common.py +++ b/src/python/grpcio_tests/tests/testing/_application_common.py @@ -37,5 +37,7 @@ ABORT_SUCCESS_QUERY = requests_pb2.Up(first_up_field=43) ABORT_NO_STATUS_RESPONSE = services_pb2.Down(first_down_field=50) ABORT_SUCCESS_RESPONSE = services_pb2.Down(first_down_field=51) ABORT_FAILURE_RESPONSE = services_pb2.Down(first_down_field=52) +STREAM_STREAM_MUTATING_REQUEST = requests_pb2.Top(first_top_field=24601) +STREAM_STREAM_MUTATING_COUNT = 2 INFINITE_REQUEST_STREAM_TIMEOUT = 0.2 diff --git a/src/python/grpcio_tests/tests/testing/_server_application.py b/src/python/grpcio_tests/tests/testing/_server_application.py index 1dc5e8f3917..51ed977b8fe 100644 --- a/src/python/grpcio_tests/tests/testing/_server_application.py +++ b/src/python/grpcio_tests/tests/testing/_server_application.py @@ -75,13 +75,21 @@ class FirstServiceServicer(services_pb2_grpc.FirstServiceServicer): return _application_common.STREAM_UNARY_RESPONSE def StreStre(self, request_iterator, context): + valid_requests = (_application_common.STREAM_STREAM_REQUEST, + _application_common.STREAM_STREAM_MUTATING_REQUEST) for request in request_iterator: - if request != _application_common.STREAM_STREAM_REQUEST: + if request not in valid_requests: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) context.set_details('Something is wrong with your request!') return elif not context.is_active(): return - else: + elif request == _application_common.STREAM_STREAM_REQUEST: yield _application_common.STREAM_STREAM_RESPONSE yield _application_common.STREAM_STREAM_RESPONSE + elif request == _application_common.STREAM_STREAM_MUTATING_REQUEST: + response = services_pb2.Bottom() + for i in range( + _application_common.STREAM_STREAM_MUTATING_COUNT): + response.first_bottom_field = i + yield response diff --git a/src/python/grpcio_tests/tests/testing/_server_test.py b/src/python/grpcio_tests/tests/testing/_server_test.py index 45975f229bc..617a41b7e54 100644 --- a/src/python/grpcio_tests/tests/testing/_server_test.py +++ b/src/python/grpcio_tests/tests/testing/_server_test.py @@ -21,6 +21,7 @@ import grpc_testing from tests.testing import _application_common from tests.testing import _application_testing_common from tests.testing import _server_application +from tests.testing.proto import services_pb2 class FirstServiceServicerTest(unittest.TestCase): @@ -94,6 +95,30 @@ class FirstServiceServicerTest(unittest.TestCase): response) self.assertIs(code, grpc.StatusCode.OK) + def test_mutating_stream_stream(self): + rpc = self._real_time_server.invoke_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE, (), None) + rpc.send_request(_application_common.STREAM_STREAM_MUTATING_REQUEST) + initial_metadata = rpc.initial_metadata() + responses = [ + rpc.take_response() + for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) + ] + rpc.send_request(_application_common.STREAM_STREAM_MUTATING_REQUEST) + responses.extend([ + rpc.take_response() + for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) + ]) + rpc.requests_closed() + _, _, _ = rpc.termination() + expected_responses = ( + services_pb2.Bottom(first_bottom_field=0), + services_pb2.Bottom(first_bottom_field=1), + services_pb2.Bottom(first_bottom_field=0), + services_pb2.Bottom(first_bottom_field=1), + ) + self.assertSequenceEqual(expected_responses, responses) + def test_server_rpc_idempotence(self): rpc = self._real_time_server.invoke_unary_unary( _application_testing_common.FIRST_SERVICE_UNUN, (),