diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py index ef3a99abb3a..88762b7d519 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py @@ -32,22 +32,23 @@ class ReflectionServicer(BaseReflectionServicer): ) -> AsyncIterable[_reflection_pb2.ServerReflectionResponse]: async for request in request_iterator: if request.HasField("file_by_filename"): - yield self._file_by_filename(request.file_by_filename) + yield self._file_by_filename(request, request.file_by_filename) elif request.HasField("file_containing_symbol"): yield self._file_containing_symbol( - request.file_containing_symbol + request, request.file_containing_symbol ) elif request.HasField("file_containing_extension"): yield self._file_containing_extension( + request, request.file_containing_extension.containing_type, request.file_containing_extension.extension_number, ) elif request.HasField("all_extension_numbers_of_type"): yield self._all_extension_numbers_of_type( - request.all_extension_numbers_of_type + request, request.all_extension_numbers_of_type ) elif request.HasField("list_services"): - yield self._list_services() + yield self._list_services(request) else: yield _reflection_pb2.ServerReflectionResponse( error_response=_reflection_pb2.ErrorResponse( @@ -55,7 +56,8 @@ class ReflectionServicer(BaseReflectionServicer): error_message=grpc.StatusCode.INVALID_ARGUMENT.value[ 1 ].encode(), - ) + ), + original_request=request, ) diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py index e73abc9346e..50d671a92d7 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py @@ -22,12 +22,13 @@ from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc _POOL = descriptor_pool.Default() -def _not_found_error(): +def _not_found_error(original_request): return _reflection_pb2.ServerReflectionResponse( error_response=_reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - ) + ), + original_request=original_request, ) @@ -39,7 +40,7 @@ def _collect_transitive_dependencies(descriptor, seen_files): _collect_transitive_dependencies(dependency, seen_files) -def _file_descriptor_response(descriptor): +def _file_descriptor_response(descriptor, original_request): # collect all dependencies descriptors = {} _collect_transitive_dependencies(descriptor, descriptors) @@ -55,6 +56,7 @@ def _file_descriptor_response(descriptor): file_descriptor_response=_reflection_pb2.FileDescriptorResponse( file_descriptor_proto=(serialized_proto_list) ), + original_request=original_request, ) @@ -71,25 +73,27 @@ class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer): self._service_names = tuple(sorted(service_names)) self._pool = _POOL if pool is None else pool - def _file_by_filename(self, filename): + def _file_by_filename(self, request, filename): try: descriptor = self._pool.FindFileByName(filename) except KeyError: - return _not_found_error() + return _not_found_error(request) else: - return _file_descriptor_response(descriptor) + return _file_descriptor_response(descriptor, request) - def _file_containing_symbol(self, fully_qualified_name): + def _file_containing_symbol(self, request, fully_qualified_name): try: descriptor = self._pool.FindFileContainingSymbol( fully_qualified_name ) except KeyError: - return _not_found_error() + return _not_found_error(request) else: - return _file_descriptor_response(descriptor) + return _file_descriptor_response(descriptor, request) - def _file_containing_extension(self, containing_type, extension_number): + def _file_containing_extension( + self, request, containing_type, extension_number + ): try: message_descriptor = self._pool.FindMessageTypeByName( containing_type @@ -101,11 +105,11 @@ class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer): extension_descriptor.full_name ) except KeyError: - return _not_found_error() + return _not_found_error(request) else: - return _file_descriptor_response(descriptor) + return _file_descriptor_response(descriptor, request) - def _all_extension_numbers_of_type(self, containing_type): + def _all_extension_numbers_of_type(self, request, containing_type): try: message_descriptor = self._pool.FindMessageTypeByName( containing_type @@ -119,23 +123,25 @@ class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer): ) ) except KeyError: - return _not_found_error() + return _not_found_error(request) else: return _reflection_pb2.ServerReflectionResponse( all_extension_numbers_response=_reflection_pb2.ExtensionNumberResponse( base_type_name=message_descriptor.full_name, extension_number=extension_numbers, - ) + ), + original_request=request, ) - def _list_services(self): + def _list_services(self, request): return _reflection_pb2.ServerReflectionResponse( list_services_response=_reflection_pb2.ListServiceResponse( service=[ _reflection_pb2.ServiceResponse(name=service_name) for service_name in self._service_names ] - ) + ), + original_request=request, ) diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py index 1c1807f49aa..4315ac9fcbf 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py @@ -32,22 +32,23 @@ class ReflectionServicer(BaseReflectionServicer): # pylint: disable=unused-argument for request in request_iterator: if request.HasField("file_by_filename"): - yield self._file_by_filename(request.file_by_filename) + yield self._file_by_filename(request, request.file_by_filename) elif request.HasField("file_containing_symbol"): yield self._file_containing_symbol( - request.file_containing_symbol + request, request.file_containing_symbol ) elif request.HasField("file_containing_extension"): yield self._file_containing_extension( + request, request.file_containing_extension.containing_type, request.file_containing_extension.extension_number, ) elif request.HasField("all_extension_numbers_of_type"): yield self._all_extension_numbers_of_type( - request.all_extension_numbers_of_type + request, request.all_extension_numbers_of_type ) elif request.HasField("list_services"): - yield self._list_services() + yield self._list_services(request) else: yield _reflection_pb2.ServerReflectionResponse( error_response=_reflection_pb2.ErrorResponse( @@ -55,7 +56,8 @@ class ReflectionServicer(BaseReflectionServicer): error_message=grpc.StatusCode.INVALID_ARGUMENT.value[ 1 ].encode(), - ) + ), + original_request=request, ) diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py index 7568abcd178..810b56f20fb 100644 --- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py @@ -90,6 +90,7 @@ class ReflectionServicerTest(unittest.TestCase): _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), ) ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -97,6 +98,7 @@ class ReflectionServicerTest(unittest.TestCase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertEqual(expected_responses, responses) @@ -119,6 +121,7 @@ class ReflectionServicerTest(unittest.TestCase): _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), ) ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -126,6 +129,7 @@ class ReflectionServicerTest(unittest.TestCase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertEqual(expected_responses, responses) @@ -157,6 +161,7 @@ class ReflectionServicerTest(unittest.TestCase): _file_descriptor_to_proto(empty2_pb2.DESCRIPTOR), ) ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -164,6 +169,7 @@ class ReflectionServicerTest(unittest.TestCase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertEqual(expected_responses, responses) @@ -185,6 +191,7 @@ class ReflectionServicerTest(unittest.TestCase): base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME, extension_number=_EMPTY_EXTENSIONS_NUMBERS, ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -192,6 +199,7 @@ class ReflectionServicerTest(unittest.TestCase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertEqual(expected_responses, responses) @@ -212,6 +220,7 @@ class ReflectionServicerTest(unittest.TestCase): for name in _SERVICE_NAMES ) ), + original_request=requests[0], ), ) self.assertEqual(expected_responses, responses) diff --git a/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py b/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py index de89aca5882..d1cc2c74891 100644 --- a/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py @@ -89,6 +89,7 @@ class ReflectionServicerTest(AioTestBase): _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), ) ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -96,6 +97,7 @@ class ReflectionServicerTest(AioTestBase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertSequenceEqual(expected_responses, responses) @@ -120,6 +122,7 @@ class ReflectionServicerTest(AioTestBase): _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), ) ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -127,6 +130,7 @@ class ReflectionServicerTest(AioTestBase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertSequenceEqual(expected_responses, responses) @@ -160,6 +164,7 @@ class ReflectionServicerTest(AioTestBase): _file_descriptor_to_proto(empty2_pb2.DESCRIPTOR), ) ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -167,6 +172,7 @@ class ReflectionServicerTest(AioTestBase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertSequenceEqual(expected_responses, responses) @@ -190,6 +196,7 @@ class ReflectionServicerTest(AioTestBase): base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME, extension_number=_EMPTY_EXTENSIONS_NUMBERS, ), + original_request=requests[0], ), reflection_pb2.ServerReflectionResponse( valid_host="", @@ -197,6 +204,7 @@ class ReflectionServicerTest(AioTestBase): error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), ), + original_request=requests[1], ), ) self.assertSequenceEqual(expected_responses, responses) @@ -219,6 +227,7 @@ class ReflectionServicerTest(AioTestBase): for name in _SERVICE_NAMES ) ), + original_request=requests[0], ), ) self.assertSequenceEqual(expected_responses, responses)