[reflection]: python: reflection returns `original_request` (#36944)

Closes #36943

This PR adds `original_request` to any `_reflection_pb2.ServerReflectionResponse` generated by `grpc_reflection.v1alpha`

Closes #36944

COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36944 from Drarig29:corentin.girard/add-original-request-to-ServerReflectionResponse e6a94789e1
PiperOrigin-RevId: 660746169
pull/37421/head
Corentin Girard 7 months ago committed by Copybara-Service
parent d3560d9176
commit 8271b14c41
  1. 12
      src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py
  2. 40
      src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py
  3. 12
      src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py
  4. 9
      src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
  5. 9
      src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py

@ -32,22 +32,23 @@ class ReflectionServicer(BaseReflectionServicer):
) -> AsyncIterable[_reflection_pb2.ServerReflectionResponse]:
async for request in request_iterator:
if request.HasField("file_by_filename"):
yield self._file_by_filename(request.file_by_filename)
yield self._file_by_filename(request, request.file_by_filename)
elif request.HasField("file_containing_symbol"):
yield self._file_containing_symbol(
request.file_containing_symbol
request, request.file_containing_symbol
)
elif request.HasField("file_containing_extension"):
yield self._file_containing_extension(
request,
request.file_containing_extension.containing_type,
request.file_containing_extension.extension_number,
)
elif request.HasField("all_extension_numbers_of_type"):
yield self._all_extension_numbers_of_type(
request.all_extension_numbers_of_type
request, request.all_extension_numbers_of_type
)
elif request.HasField("list_services"):
yield self._list_services()
yield self._list_services(request)
else:
yield _reflection_pb2.ServerReflectionResponse(
error_response=_reflection_pb2.ErrorResponse(
@ -55,7 +56,8 @@ class ReflectionServicer(BaseReflectionServicer):
error_message=grpc.StatusCode.INVALID_ARGUMENT.value[
1
].encode(),
)
),
original_request=request,
)

@ -22,12 +22,13 @@ from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc
_POOL = descriptor_pool.Default()
def _not_found_error():
def _not_found_error(original_request):
return _reflection_pb2.ServerReflectionResponse(
error_response=_reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)
),
original_request=original_request,
)
@ -39,7 +40,7 @@ def _collect_transitive_dependencies(descriptor, seen_files):
_collect_transitive_dependencies(dependency, seen_files)
def _file_descriptor_response(descriptor):
def _file_descriptor_response(descriptor, original_request):
# collect all dependencies
descriptors = {}
_collect_transitive_dependencies(descriptor, descriptors)
@ -55,6 +56,7 @@ def _file_descriptor_response(descriptor):
file_descriptor_response=_reflection_pb2.FileDescriptorResponse(
file_descriptor_proto=(serialized_proto_list)
),
original_request=original_request,
)
@ -71,25 +73,27 @@ class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
self._service_names = tuple(sorted(service_names))
self._pool = _POOL if pool is None else pool
def _file_by_filename(self, filename):
def _file_by_filename(self, request, filename):
try:
descriptor = self._pool.FindFileByName(filename)
except KeyError:
return _not_found_error()
return _not_found_error(request)
else:
return _file_descriptor_response(descriptor)
return _file_descriptor_response(descriptor, request)
def _file_containing_symbol(self, fully_qualified_name):
def _file_containing_symbol(self, request, fully_qualified_name):
try:
descriptor = self._pool.FindFileContainingSymbol(
fully_qualified_name
)
except KeyError:
return _not_found_error()
return _not_found_error(request)
else:
return _file_descriptor_response(descriptor)
return _file_descriptor_response(descriptor, request)
def _file_containing_extension(self, containing_type, extension_number):
def _file_containing_extension(
self, request, containing_type, extension_number
):
try:
message_descriptor = self._pool.FindMessageTypeByName(
containing_type
@ -101,11 +105,11 @@ class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
extension_descriptor.full_name
)
except KeyError:
return _not_found_error()
return _not_found_error(request)
else:
return _file_descriptor_response(descriptor)
return _file_descriptor_response(descriptor, request)
def _all_extension_numbers_of_type(self, containing_type):
def _all_extension_numbers_of_type(self, request, containing_type):
try:
message_descriptor = self._pool.FindMessageTypeByName(
containing_type
@ -119,23 +123,25 @@ class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
)
)
except KeyError:
return _not_found_error()
return _not_found_error(request)
else:
return _reflection_pb2.ServerReflectionResponse(
all_extension_numbers_response=_reflection_pb2.ExtensionNumberResponse(
base_type_name=message_descriptor.full_name,
extension_number=extension_numbers,
)
),
original_request=request,
)
def _list_services(self):
def _list_services(self, request):
return _reflection_pb2.ServerReflectionResponse(
list_services_response=_reflection_pb2.ListServiceResponse(
service=[
_reflection_pb2.ServiceResponse(name=service_name)
for service_name in self._service_names
]
)
),
original_request=request,
)

@ -32,22 +32,23 @@ class ReflectionServicer(BaseReflectionServicer):
# pylint: disable=unused-argument
for request in request_iterator:
if request.HasField("file_by_filename"):
yield self._file_by_filename(request.file_by_filename)
yield self._file_by_filename(request, request.file_by_filename)
elif request.HasField("file_containing_symbol"):
yield self._file_containing_symbol(
request.file_containing_symbol
request, request.file_containing_symbol
)
elif request.HasField("file_containing_extension"):
yield self._file_containing_extension(
request,
request.file_containing_extension.containing_type,
request.file_containing_extension.extension_number,
)
elif request.HasField("all_extension_numbers_of_type"):
yield self._all_extension_numbers_of_type(
request.all_extension_numbers_of_type
request, request.all_extension_numbers_of_type
)
elif request.HasField("list_services"):
yield self._list_services()
yield self._list_services(request)
else:
yield _reflection_pb2.ServerReflectionResponse(
error_response=_reflection_pb2.ErrorResponse(
@ -55,7 +56,8 @@ class ReflectionServicer(BaseReflectionServicer):
error_message=grpc.StatusCode.INVALID_ARGUMENT.value[
1
].encode(),
)
),
original_request=request,
)

@ -90,6 +90,7 @@ class ReflectionServicerTest(unittest.TestCase):
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
)
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -97,6 +98,7 @@ class ReflectionServicerTest(unittest.TestCase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertEqual(expected_responses, responses)
@ -119,6 +121,7 @@ class ReflectionServicerTest(unittest.TestCase):
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
)
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -126,6 +129,7 @@ class ReflectionServicerTest(unittest.TestCase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertEqual(expected_responses, responses)
@ -157,6 +161,7 @@ class ReflectionServicerTest(unittest.TestCase):
_file_descriptor_to_proto(empty2_pb2.DESCRIPTOR),
)
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -164,6 +169,7 @@ class ReflectionServicerTest(unittest.TestCase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertEqual(expected_responses, responses)
@ -185,6 +191,7 @@ class ReflectionServicerTest(unittest.TestCase):
base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME,
extension_number=_EMPTY_EXTENSIONS_NUMBERS,
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -192,6 +199,7 @@ class ReflectionServicerTest(unittest.TestCase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertEqual(expected_responses, responses)
@ -212,6 +220,7 @@ class ReflectionServicerTest(unittest.TestCase):
for name in _SERVICE_NAMES
)
),
original_request=requests[0],
),
)
self.assertEqual(expected_responses, responses)

@ -89,6 +89,7 @@ class ReflectionServicerTest(AioTestBase):
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
)
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -96,6 +97,7 @@ class ReflectionServicerTest(AioTestBase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertSequenceEqual(expected_responses, responses)
@ -120,6 +122,7 @@ class ReflectionServicerTest(AioTestBase):
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
)
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -127,6 +130,7 @@ class ReflectionServicerTest(AioTestBase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertSequenceEqual(expected_responses, responses)
@ -160,6 +164,7 @@ class ReflectionServicerTest(AioTestBase):
_file_descriptor_to_proto(empty2_pb2.DESCRIPTOR),
)
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -167,6 +172,7 @@ class ReflectionServicerTest(AioTestBase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertSequenceEqual(expected_responses, responses)
@ -190,6 +196,7 @@ class ReflectionServicerTest(AioTestBase):
base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME,
extension_number=_EMPTY_EXTENSIONS_NUMBERS,
),
original_request=requests[0],
),
reflection_pb2.ServerReflectionResponse(
valid_host="",
@ -197,6 +204,7 @@ class ReflectionServicerTest(AioTestBase):
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
),
original_request=requests[1],
),
)
self.assertSequenceEqual(expected_responses, responses)
@ -219,6 +227,7 @@ class ReflectionServicerTest(AioTestBase):
for name in _SERVICE_NAMES
)
),
original_request=requests[0],
),
)
self.assertSequenceEqual(expected_responses, responses)

Loading…
Cancel
Save