Upgrade yapf to 0.20.0

Upgrade yapf version to 0.20.0 and reformat Python files.
pull/13899/head
Mehrdad Afshari 7 years ago
parent 63392f682e
commit 87cd994b04
  1. 12
      examples/python/interceptors/headers/header_manipulator_client_interceptor.py
  2. 10
      examples/python/multiplex/multiplex_client.py
  3. 4
      examples/python/multiplex/multiplex_server.py
  4. 18
      examples/python/multiplex/run_codegen.py
  5. 9
      examples/python/route_guide/route_guide_client.py
  6. 4
      examples/python/route_guide/route_guide_server.py
  7. 9
      examples/python/route_guide/run_codegen.py
  8. 24
      src/python/grpcio/commands.py
  9. 74
      src/python/grpcio/grpc/__init__.py
  10. 4
      src/python/grpcio/grpc/_auth.py
  11. 96
      src/python/grpcio/grpc/_channel.py
  12. 7
      src/python/grpcio/grpc/_interceptor.py
  13. 9
      src/python/grpcio/grpc/_plugin_wrapping.py
  14. 39
      src/python/grpcio/grpc/_server.py
  15. 12
      src/python/grpcio/grpc/_utilities.py
  16. 117
      src/python/grpcio/grpc/beta/_client_adaptations.py
  17. 5
      src/python/grpcio/grpc/beta/_metadata.py
  18. 75
      src/python/grpcio/grpc/beta/_server_adaptations.py
  19. 4
      src/python/grpcio/grpc/beta/implementations.py
  20. 4
      src/python/grpcio/grpc/framework/foundation/callable_util.py
  21. 15
      src/python/grpcio/grpc/framework/interfaces/base/utilities.py
  22. 15
      src/python/grpcio/grpc/framework/interfaces/face/face.py
  23. 6
      src/python/grpcio_health_checking/health_commands.py
  24. 6
      src/python/grpcio_health_checking/setup.py
  25. 15
      src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py
  26. 6
      src/python/grpcio_reflection/setup.py
  27. 32
      src/python/grpcio_testing/grpc_testing/_channel/_multi_callable.py
  28. 4
      src/python/grpcio_testing/grpc_testing/_channel/_rpc_state.py
  29. 23
      src/python/grpcio_testing/grpc_testing/_common.py
  30. 8
      src/python/grpcio_testing/grpc_testing/_server/_handler.py
  31. 6
      src/python/grpcio_testing/grpc_testing/_server/_server.py
  32. 8
      src/python/grpcio_testing/grpc_testing/_time.py
  33. 6
      src/python/grpcio_testing/setup.py
  34. 3
      src/python/grpcio_tests/setup.py
  35. 4
      src/python/grpcio_tests/tests/_loader.py
  36. 31
      src/python/grpcio_tests/tests/_result.py
  37. 8
      src/python/grpcio_tests/tests/_runner.py
  38. 8
      src/python/grpcio_tests/tests/http2/negative_http2_client.py
  39. 4
      src/python/grpcio_tests/tests/interop/_intraop_test_case.py
  40. 11
      src/python/grpcio_tests/tests/interop/_secure_intraop_test.py
  41. 6
      src/python/grpcio_tests/tests/interop/client.py
  42. 100
      src/python/grpcio_tests/tests/interop/methods.py
  43. 4
      src/python/grpcio_tests/tests/interop/server.py
  44. 19
      src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
  45. 48
      src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
  46. 26
      src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
  47. 3
      src/python/grpcio_tests/tests/qps/benchmark_client.py
  48. 16
      src/python/grpcio_tests/tests/qps/worker_server.py
  49. 75
      src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
  50. 6
      src/python/grpcio_tests/tests/stress/client.py
  51. 8
      src/python/grpcio_tests/tests/testing/_client_application.py
  52. 6
      src/python/grpcio_tests/tests/testing/_client_test.py
  53. 6
      src/python/grpcio_tests/tests/testing/_server_application.py
  54. 11
      src/python/grpcio_tests/tests/testing/_server_test.py
  55. 4
      src/python/grpcio_tests/tests/testing/_time_test.py
  56. 84
      src/python/grpcio_tests/tests/unit/_api_test.py
  57. 13
      src/python/grpcio_tests/tests/unit/_auth_context_test.py
  58. 9
      src/python/grpcio_tests/tests/unit/_channel_args_test.py
  59. 8
      src/python/grpcio_tests/tests/unit/_compression_test.py
  60. 4
      src/python/grpcio_tests/tests/unit/_credentials_test.py
  61. 18
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  62. 5
      src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
  63. 23
      src/python/grpcio_tests/tests/unit/_cython/_common.py
  64. 35
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
  65. 35
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
  66. 9
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  67. 59
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
  68. 8
      src/python/grpcio_tests/tests/unit/_empty_message_test.py
  69. 158
      src/python/grpcio_tests/tests/unit/_interceptor_test.py
  70. 16
      src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
  71. 24
      src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
  72. 48
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  73. 54
      src/python/grpcio_tests/tests/unit/_metadata_test.py
  74. 132
      src/python/grpcio_tests/tests/unit/_rpc_test.py
  75. 3
      src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py
  76. 4
      src/python/grpcio_tests/tests/unit/_thread_cleanup_test.py
  77. 10
      src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py
  78. 19
      src/python/grpcio_tests/tests/unit/beta/_face_interface_test.py
  79. 4
      src/python/grpcio_tests/tests/unit/beta/_implementations_test.py
  80. 6
      src/python/grpcio_tests/tests/unit/beta/test_utilities.py
  81. 72
      src/python/grpcio_tests/tests/unit/framework/interfaces/face/_blocking_invocation_inline_service.py
  82. 22
      src/python/grpcio_tests/tests/unit/framework/interfaces/face/_digest.py
  83. 104
      src/python/grpcio_tests/tests/unit/framework/interfaces/face/_future_invocation_asynchronous_event_service.py
  84. 7
      src/python/grpcio_tests/tests/unit/framework/interfaces/face/_invocation.py
  85. 4
      src/python/grpcio_tests/tests/unit/framework/interfaces/face/_stock_service.py
  86. 8
      src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_cases.py
  87. 6
      src/python/grpcio_tests/tests/unit/resources.py
  88. 24
      src/python/grpcio_tests/tests/unit/test_common.py
  89. 5
      tools/buildgen/bunch.py
  90. 8
      tools/buildgen/mako_renderer.py
  91. 7
      tools/buildgen/plugins/expand_filegroups.py
  92. 13
      tools/buildgen/plugins/generate_vsprojects.py
  93. 4
      tools/buildgen/plugins/transitive_dependencies.py
  94. 9
      tools/codegen/core/gen_settings_ids.py
  95. 16
      tools/codegen/core/gen_static_metadata.py
  96. 13
      tools/codegen/core/gen_stats_data.py
  97. 2
      tools/debug/core/error_ref_leak.py
  98. 12
      tools/distrib/check_copyright.py
  99. 16
      tools/distrib/check_include_guards.py
  100. 4
      tools/distrib/python/grpcio_tools/grpc_tools/command.py
  101. Some files were not shown because too many files have changed in this diff Show More

@ -20,9 +20,10 @@ import generic_client_interceptor
class _ClientCallDetails( class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails', collections.namedtuple(
('method', 'timeout', 'metadata', '_ClientCallDetails',
'credentials')), grpc.ClientCallDetails): ('method', 'timeout', 'metadata', 'credentials')),
grpc.ClientCallDetails):
pass pass
@ -33,7 +34,10 @@ def header_adder_interceptor(header, value):
metadata = [] metadata = []
if client_call_details.metadata is not None: if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata) metadata = list(client_call_details.metadata)
metadata.append((header, value,)) metadata.append((
header,
value,
))
client_call_details = _ClientCallDetails( client_call_details = _ClientCallDetails(
client_call_details.method, client_call_details.timeout, metadata, client_call_details.method, client_call_details.timeout, metadata,
client_call_details.credentials) client_call_details.credentials)

@ -46,9 +46,9 @@ def guide_get_one_feature(route_guide_stub, point):
def guide_get_feature(route_guide_stub): def guide_get_feature(route_guide_stub):
guide_get_one_feature( guide_get_one_feature(route_guide_stub,
route_guide_stub, route_guide_pb2.Point(
route_guide_pb2.Point(latitude=409146138, longitude=-746188906)) latitude=409146138, longitude=-746188906))
guide_get_one_feature(route_guide_stub, guide_get_one_feature(route_guide_stub,
route_guide_pb2.Point(latitude=0, longitude=0)) route_guide_pb2.Point(latitude=0, longitude=0))
@ -101,8 +101,8 @@ def generate_messages():
def guide_route_chat(route_guide_stub): def guide_route_chat(route_guide_stub):
responses = route_guide_stub.RouteChat(generate_messages()) responses = route_guide_stub.RouteChat(generate_messages())
for response in responses: for response in responses:
print("Received message %s at %s" % print("Received message %s at %s" % (response.message,
(response.message, response.location)) response.location))
def run(): def run():

@ -124,8 +124,8 @@ def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
helloworld_pb2_grpc.add_GreeterServicer_to_server(_GreeterServicer(), helloworld_pb2_grpc.add_GreeterServicer_to_server(_GreeterServicer(),
server) server)
route_guide_pb2_grpc.add_RouteGuideServicer_to_server(_RouteGuideServicer(), route_guide_pb2_grpc.add_RouteGuideServicer_to_server(
server) _RouteGuideServicer(), server)
server.add_insecure_port('[::]:50051') server.add_insecure_port('[::]:50051')
server.start() server.start()
try: try:

@ -15,7 +15,17 @@
from grpc_tools import protoc from grpc_tools import protoc
protoc.main(('', '-I../../protos', '--python_out=.', '--grpc_python_out=.', protoc.main((
'../../protos/helloworld.proto',)) '',
protoc.main(('', '-I../../protos', '--python_out=.', '--grpc_python_out=.', '-I../../protos',
'../../protos/route_guide.proto',)) '--python_out=.',
'--grpc_python_out=.',
'../../protos/helloworld.proto',
))
protoc.main((
'',
'-I../../protos',
'--python_out=.',
'--grpc_python_out=.',
'../../protos/route_guide.proto',
))

@ -43,8 +43,9 @@ def guide_get_one_feature(stub, point):
def guide_get_feature(stub): def guide_get_feature(stub):
guide_get_one_feature( guide_get_one_feature(stub,
stub, route_guide_pb2.Point(latitude=409146138, longitude=-746188906)) route_guide_pb2.Point(
latitude=409146138, longitude=-746188906))
guide_get_one_feature(stub, route_guide_pb2.Point(latitude=0, longitude=0)) guide_get_one_feature(stub, route_guide_pb2.Point(latitude=0, longitude=0))
@ -94,8 +95,8 @@ def generate_messages():
def guide_route_chat(stub): def guide_route_chat(stub):
responses = stub.RouteChat(generate_messages()) responses = stub.RouteChat(generate_messages())
for response in responses: for response in responses:
print("Received message %s at %s" % print("Received message %s at %s" % (response.message,
(response.message, response.location)) response.location))
def run(): def run():

@ -113,8 +113,8 @@ class RouteGuideServicer(route_guide_pb2_grpc.RouteGuideServicer):
def serve(): def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
route_guide_pb2_grpc.add_RouteGuideServicer_to_server(RouteGuideServicer(), route_guide_pb2_grpc.add_RouteGuideServicer_to_server(
server) RouteGuideServicer(), server)
server.add_insecure_port('[::]:50051') server.add_insecure_port('[::]:50051')
server.start() server.start()
try: try:

@ -15,5 +15,10 @@
from grpc_tools import protoc from grpc_tools import protoc
protoc.main(('', '-I../../protos', '--python_out=.', '--grpc_python_out=.', protoc.main((
'../../protos/route_guide.proto',)) '',
'-I../../protos',
'--python_out=.',
'--grpc_python_out=.',
'../../protos/route_guide.proto',
))

@ -104,8 +104,8 @@ def _get_grpc_custom_bdist(decorated_basename, target_bdist_basename):
with open(bdist_path, 'w') as bdist_file: with open(bdist_path, 'w') as bdist_file:
bdist_file.write(bdist_data) bdist_file.write(bdist_data)
except IOError as error: except IOError as error:
raise CommandError('{}\n\nCould not write grpcio bdist: {}' raise CommandError('{}\n\nCould not write grpcio bdist: {}'.format(
.format(traceback.format_exc(), error.message)) traceback.format_exc(), error.message))
return bdist_path return bdist_path
@ -141,7 +141,8 @@ class SphinxDocumentation(setuptools.Command):
with open(glossary_filepath, 'a') as glossary_filepath: with open(glossary_filepath, 'a') as glossary_filepath:
glossary_filepath.write(API_GLOSSARY) glossary_filepath.write(API_GLOSSARY)
sphinx.main( sphinx.main(
['', os.path.join('doc', 'src'), os.path.join('doc', 'build')]) ['', os.path.join('doc', 'src'),
os.path.join('doc', 'build')])
class BuildProjectMetadata(setuptools.Command): class BuildProjectMetadata(setuptools.Command):
@ -189,10 +190,11 @@ def check_and_update_cythonization(extensions):
for source in extension.sources: for source in extension.sources:
base, file_ext = os.path.splitext(source) base, file_ext = os.path.splitext(source)
if file_ext == '.pyx': if file_ext == '.pyx':
generated_pyx_source = next((base + gen_ext generated_pyx_source = next(
for gen_ext in ('.c', '.cpp',) (base + gen_ext for gen_ext in (
if os.path.isfile(base + gen_ext)), '.c',
None) '.cpp',
) if os.path.isfile(base + gen_ext)), None)
if generated_pyx_source: if generated_pyx_source:
generated_pyx_sources.append(generated_pyx_source) generated_pyx_sources.append(generated_pyx_source)
else: else:
@ -299,10 +301,10 @@ class Gather(setuptools.Command):
"""Command to gather project dependencies.""" """Command to gather project dependencies."""
description = 'gather dependencies for grpcio' description = 'gather dependencies for grpcio'
user_options = [ user_options = [('test', 't',
('test', 't', 'flag indicating to gather test dependencies'), 'flag indicating to gather test dependencies'),
('install', 'i', 'flag indicating to gather install dependencies') ('install', 'i',
] 'flag indicating to gather install dependencies')]
def initialize_options(self): def initialize_options(self):
self.test = False self.test = False

@ -1376,8 +1376,8 @@ def metadata_call_credentials(metadata_plugin, name=None):
A CallCredentials. A CallCredentials.
""" """
from grpc import _plugin_wrapping # pylint: disable=cyclic-import from grpc import _plugin_wrapping # pylint: disable=cyclic-import
return _plugin_wrapping.metadata_plugin_call_credentials(metadata_plugin, return _plugin_wrapping.metadata_plugin_call_credentials(
name) metadata_plugin, name)
def access_token_call_credentials(access_token): def access_token_call_credentials(access_token):
@ -1631,25 +1631,57 @@ def server(thread_pool,
################################### __all__ ################################# ################################### __all__ #################################
__all__ = ( __all__ = (
'FutureTimeoutError', 'FutureCancelledError', 'Future', 'FutureTimeoutError',
'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', 'Call', 'FutureCancelledError',
'ChannelCredentials', 'CallCredentials', 'AuthMetadataContext', 'Future',
'AuthMetadataPluginCallback', 'AuthMetadataPlugin', 'ClientCallDetails', 'ChannelConnectivity',
'ServerCertificateConfiguration', 'ServerCredentials', 'StatusCode',
'UnaryUnaryMultiCallable', 'UnaryStreamMultiCallable', 'RpcError',
'StreamUnaryMultiCallable', 'StreamStreamMultiCallable', 'RpcContext',
'UnaryUnaryClientInterceptor', 'UnaryStreamClientInterceptor', 'Call',
'StreamUnaryClientInterceptor', 'StreamStreamClientInterceptor', 'Channel', 'ChannelCredentials',
'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails', 'CallCredentials',
'GenericRpcHandler', 'ServiceRpcHandler', 'Server', 'ServerInterceptor', 'AuthMetadataContext',
'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', 'AuthMetadataPluginCallback',
'stream_unary_rpc_method_handler', 'stream_stream_rpc_method_handler', 'AuthMetadataPlugin',
'method_handlers_generic_handler', 'ssl_channel_credentials', 'ClientCallDetails',
'metadata_call_credentials', 'access_token_call_credentials', 'ServerCertificateConfiguration',
'composite_call_credentials', 'composite_channel_credentials', 'ServerCredentials',
'ssl_server_credentials', 'ssl_server_certificate_configuration', 'UnaryUnaryMultiCallable',
'dynamic_ssl_server_credentials', 'channel_ready_future', 'UnaryStreamMultiCallable',
'insecure_channel', 'secure_channel', 'intercept_channel', 'server',) 'StreamUnaryMultiCallable',
'StreamStreamMultiCallable',
'UnaryUnaryClientInterceptor',
'UnaryStreamClientInterceptor',
'StreamUnaryClientInterceptor',
'StreamStreamClientInterceptor',
'Channel',
'ServicerContext',
'RpcMethodHandler',
'HandlerCallDetails',
'GenericRpcHandler',
'ServiceRpcHandler',
'Server',
'ServerInterceptor',
'unary_unary_rpc_method_handler',
'unary_stream_rpc_method_handler',
'stream_unary_rpc_method_handler',
'stream_stream_rpc_method_handler',
'method_handlers_generic_handler',
'ssl_channel_credentials',
'metadata_call_credentials',
'access_token_call_credentials',
'composite_call_credentials',
'composite_channel_credentials',
'ssl_server_credentials',
'ssl_server_certificate_configuration',
'dynamic_ssl_server_credentials',
'channel_ready_future',
'insecure_channel',
'secure_channel',
'intercept_channel',
'server',
)
############################### Extension Shims ################################ ############################### Extension Shims ################################

@ -54,7 +54,9 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
if self._is_jwt: if self._is_jwt:
future = self._pool.submit( future = self._pool.submit(
self._credentials.get_access_token, self._credentials.get_access_token,
additional_claims={'aud': context.service_url}) additional_claims={
'aud': context.service_url
})
else: else:
future = self._pool.submit(self._credentials.get_access_token) future = self._pool.submit(self._credentials.get_access_token)
future.add_done_callback(_create_get_token_callback(callback)) future.add_done_callback(_create_get_token_callback(callback))

@ -29,24 +29,32 @@ _USER_AGENT = 'grpc-python/{}'.format(_grpcio_metadata.__version__)
_EMPTY_FLAGS = 0 _EMPTY_FLAGS = 0
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) _INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
_UNARY_UNARY_INITIAL_DUE = (cygrpc.OperationType.send_initial_metadata, _UNARY_UNARY_INITIAL_DUE = (
cygrpc.OperationType.send_message, cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_close_from_client, cygrpc.OperationType.send_message,
cygrpc.OperationType.receive_initial_metadata, cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_message, cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_status_on_client,) cygrpc.OperationType.receive_message,
_UNARY_STREAM_INITIAL_DUE = (cygrpc.OperationType.send_initial_metadata, cygrpc.OperationType.receive_status_on_client,
cygrpc.OperationType.send_message, )
cygrpc.OperationType.send_close_from_client, _UNARY_STREAM_INITIAL_DUE = (
cygrpc.OperationType.receive_initial_metadata, cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_status_on_client,) cygrpc.OperationType.send_message,
_STREAM_UNARY_INITIAL_DUE = (cygrpc.OperationType.send_initial_metadata, cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata, cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message, cygrpc.OperationType.receive_status_on_client,
cygrpc.OperationType.receive_status_on_client,) )
_STREAM_STREAM_INITIAL_DUE = (cygrpc.OperationType.send_initial_metadata, _STREAM_UNARY_INITIAL_DUE = (
cygrpc.OperationType.receive_initial_metadata, cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_status_on_client,) cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client,
)
_STREAM_STREAM_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_status_on_client,
)
_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
'Exception calling channel subscription callback!') 'Exception calling channel subscription callback!')
@ -457,7 +465,8 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),) cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
return state, operations, deadline, deadline_timespec, None return state, operations, deadline, deadline_timespec, None
def _blocking(self, request, timeout, metadata, credentials): def _blocking(self, request, timeout, metadata, credentials):
@ -538,11 +547,12 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
(cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),),
event_handler) event_handler)
operations = ( operations = (
cygrpc.SendInitialMetadataOperation( cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
metadata, _EMPTY_FLAGS), cygrpc.SendMessageOperation( cygrpc.SendMessageOperation(serialized_request,
serialized_request, _EMPTY_FLAGS), _EMPTY_FLAGS),
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),) cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, event_handler) call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok: if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata) _call_error_set_RPCstate(state, call_error, metadata)
@ -576,7 +586,8 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),) cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, None) call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata) _check_call_error(call_error, metadata)
_consume_request_iterator(request_iterator, state, call, _consume_request_iterator(request_iterator, state, call,
@ -627,7 +638,8 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),) cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, event_handler) call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok: if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata) _call_error_set_RPCstate(state, call_error, metadata)
@ -666,7 +678,8 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
event_handler) event_handler)
operations = ( operations = (
cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS), cygrpc.SendInitialMetadataOperation(metadata, _EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),) cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(operations, event_handler) call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok: if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata) _call_error_set_RPCstate(state, call_error, metadata)
@ -787,7 +800,11 @@ def _deliver(state, initial_connectivity, initial_callbacks):
def _spawn_delivery(state, callbacks): def _spawn_delivery(state, callbacks):
delivering_thread = threading.Thread( delivering_thread = threading.Thread(
target=_deliver, args=(state, state.connectivity, callbacks,)) target=_deliver, args=(
state,
state.connectivity,
callbacks,
))
delivering_thread.start() delivering_thread.start()
state.delivering = True state.delivering = True
@ -862,17 +879,16 @@ def _subscribe(state, callback, try_to_connect):
def _unsubscribe(state, callback): def _unsubscribe(state, callback):
with state.lock: with state.lock:
for index, (subscribed_callback, unused_connectivity for index, (subscribed_callback, unused_connectivity) in enumerate(
) in enumerate(state.callbacks_and_connectivities): state.callbacks_and_connectivities):
if callback == subscribed_callback: if callback == subscribed_callback:
state.callbacks_and_connectivities.pop(index) state.callbacks_and_connectivities.pop(index)
break break
def _options(options): def _options(options):
return list(options) + [ return list(options) + [(cygrpc.ChannelArgKey.primary_user_agent_string,
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT) _USER_AGENT)]
]
class Channel(grpc.Channel): class Channel(grpc.Channel):
@ -887,8 +903,8 @@ class Channel(grpc.Channel):
credentials: A cygrpc.ChannelCredentials or None. credentials: A cygrpc.ChannelCredentials or None.
""" """
self._channel = cygrpc.Channel( self._channel = cygrpc.Channel(
_common.encode(target), _common.encode(target), _common.channel_args(_options(options)),
_common.channel_args(_options(options)), credentials) credentials)
self._call_state = _ChannelCallState(self._channel) self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel)
@ -908,8 +924,7 @@ class Channel(grpc.Channel):
request_serializer=None, request_serializer=None,
response_deserializer=None): response_deserializer=None):
return _UnaryUnaryMultiCallable( return _UnaryUnaryMultiCallable(
self._channel, self._channel, _channel_managed_call_management(self._call_state),
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer) _common.encode(method), request_serializer, response_deserializer)
def unary_stream(self, def unary_stream(self,
@ -917,8 +932,7 @@ class Channel(grpc.Channel):
request_serializer=None, request_serializer=None,
response_deserializer=None): response_deserializer=None):
return _UnaryStreamMultiCallable( return _UnaryStreamMultiCallable(
self._channel, self._channel, _channel_managed_call_management(self._call_state),
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer) _common.encode(method), request_serializer, response_deserializer)
def stream_unary(self, def stream_unary(self,
@ -926,8 +940,7 @@ class Channel(grpc.Channel):
request_serializer=None, request_serializer=None,
response_deserializer=None): response_deserializer=None):
return _StreamUnaryMultiCallable( return _StreamUnaryMultiCallable(
self._channel, self._channel, _channel_managed_call_management(self._call_state),
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer) _common.encode(method), request_serializer, response_deserializer)
def stream_stream(self, def stream_stream(self,
@ -935,8 +948,7 @@ class Channel(grpc.Channel):
request_serializer=None, request_serializer=None,
response_deserializer=None): response_deserializer=None):
return _StreamStreamMultiCallable( return _StreamStreamMultiCallable(
self._channel, self._channel, _channel_managed_call_management(self._call_state),
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer) _common.encode(method), request_serializer, response_deserializer)
def __del__(self): def __del__(self):

@ -44,9 +44,10 @@ def service_pipeline(interceptors):
class _ClientCallDetails( class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails', collections.namedtuple(
('method', 'timeout', 'metadata', '_ClientCallDetails',
'credentials')), grpc.ClientCallDetails): ('method', 'timeout', 'metadata', 'credentials')),
grpc.ClientCallDetails):
pass pass

@ -23,7 +23,9 @@ from grpc._cython import cygrpc
class _AuthMetadataContext( class _AuthMetadataContext(
collections.namedtuple('AuthMetadataContext', ( collections.namedtuple('AuthMetadataContext', (
'service_url', 'method_name',)), grpc.AuthMetadataContext): 'service_url',
'method_name',
)), grpc.AuthMetadataContext):
pass pass
@ -70,8 +72,9 @@ class _Plugin(object):
_common.decode(service_url), _common.decode(method_name)) _common.decode(service_url), _common.decode(method_name))
callback_state = _CallbackState() callback_state = _CallbackState()
try: try:
self._metadata_plugin( self._metadata_plugin(context,
context, _AuthMetadataPluginCallback(callback_state, callback)) _AuthMetadataPluginCallback(
callback_state, callback))
except Exception as exception: # pylint: disable=broad-except except Exception as exception: # pylint: disable=broad-except
logging.exception( logging.exception(
'AuthMetadataPluginCallback "%s" raised exception!', 'AuthMetadataPluginCallback "%s" raised exception!',

@ -78,7 +78,9 @@ def _details(state):
class _HandlerCallDetails( class _HandlerCallDetails(
collections.namedtuple('_HandlerCallDetails', ( collections.namedtuple('_HandlerCallDetails', (
'method', 'invocation_metadata',)), grpc.HandlerCallDetails): 'method',
'invocation_metadata',
)), grpc.HandlerCallDetails):
pass pass
@ -130,10 +132,12 @@ def _abort(state, call, code, details):
effective_code = _abortion_code(state, code) effective_code = _abortion_code(state, code)
effective_details = details if state.details is None else state.details effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed: if state.initial_metadata_allowed:
operations = (cygrpc.SendInitialMetadataOperation( operations = (
None, _EMPTY_FLAGS), cygrpc.SendStatusFromServerOperation( cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation(
state.trailing_metadata, effective_code, effective_details, state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS),) _EMPTY_FLAGS),
)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else: else:
operations = (cygrpc.SendStatusFromServerOperation( operations = (cygrpc.SendStatusFromServerOperation(
@ -422,15 +426,16 @@ def _send_response(rpc_event, state, serialized_response):
return False return False
else: else:
if state.initial_metadata_allowed: if state.initial_metadata_allowed:
operations = (cygrpc.SendInitialMetadataOperation(None, operations = (
_EMPTY_FLAGS), cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
cygrpc.SendMessageOperation(serialized_response, cygrpc.SendMessageOperation(serialized_response,
_EMPTY_FLAGS),) _EMPTY_FLAGS),
)
state.initial_metadata_allowed = False state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else: else:
operations = (cygrpc.SendMessageOperation(serialized_response, operations = (cygrpc.SendMessageOperation(
_EMPTY_FLAGS),) serialized_response, _EMPTY_FLAGS),)
token = _SEND_MESSAGE_TOKEN token = _SEND_MESSAGE_TOKEN
rpc_event.call.start_server_batch(operations, rpc_event.call.start_server_batch(operations,
_send_message(state, token)) _send_message(state, token))
@ -562,10 +567,12 @@ def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
def _reject_rpc(rpc_event, status, details): def _reject_rpc(rpc_event, status, details):
operations = (cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS), operations = (
cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), cygrpc.SendInitialMetadataOperation(None, _EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation(None, status, details, cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
_EMPTY_FLAGS),) cygrpc.SendStatusFromServerOperation(None, status, details,
_EMPTY_FLAGS),
)
rpc_state = _RPCState() rpc_state = _RPCState()
rpc_event.call.start_server_batch(operations, rpc_event.call.start_server_batch(operations,
lambda ignored_event: (rpc_state, (),)) lambda ignored_event: (rpc_state, (),))
@ -798,8 +805,8 @@ class Server(grpc.Server):
return _add_insecure_port(self._state, _common.encode(address)) return _add_insecure_port(self._state, _common.encode(address))
def add_secure_port(self, address, server_credentials): def add_secure_port(self, address, server_credentials):
return _add_secure_port(self._state, return _add_secure_port(self._state, _common.encode(address),
_common.encode(address), server_credentials) server_credentials)
def start(self): def start(self):
_start(self._state) _start(self._state)

@ -29,9 +29,15 @@ _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
class RpcMethodHandler( class RpcMethodHandler(
collections.namedtuple('_RpcMethodHandler', ( collections.namedtuple('_RpcMethodHandler', (
'request_streaming', 'response_streaming', 'request_deserializer', 'request_streaming',
'response_serializer', 'unary_unary', 'unary_stream', 'response_streaming',
'stream_unary', 'stream_stream',)), grpc.RpcMethodHandler): 'request_deserializer',
'response_serializer',
'unary_unary',
'unary_stream',
'stream_unary',
'stream_stream',
)), grpc.RpcMethodHandler):
pass pass

@ -51,8 +51,7 @@ def _abortion(rpc_error_call):
code = rpc_error_call.code() code = rpc_error_call.code()
pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code) pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0] error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
return face.Abortion(error_kind, return face.Abortion(error_kind, rpc_error_call.initial_metadata(),
rpc_error_call.initial_metadata(),
rpc_error_call.trailing_metadata(), code, rpc_error_call.trailing_metadata(), code,
rpc_error_call.details()) rpc_error_call.details())
@ -441,9 +440,14 @@ class _GenericStub(face.GenericStub):
metadata=None, metadata=None,
with_call=None, with_call=None,
protocol_options=None): protocol_options=None):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _blocking_unary_unary(self._channel, group, method, timeout, return _blocking_unary_unary(self._channel, group, method, timeout,
with_call, protocol_options, metadata, with_call, protocol_options, metadata,
self._metadata_transformer, request, self._metadata_transformer, request,
@ -456,9 +460,14 @@ class _GenericStub(face.GenericStub):
timeout, timeout,
metadata=None, metadata=None,
protocol_options=None): protocol_options=None):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _future_unary_unary(self._channel, group, method, timeout, return _future_unary_unary(self._channel, group, method, timeout,
protocol_options, metadata, protocol_options, metadata,
self._metadata_transformer, request, self._metadata_transformer, request,
@ -471,9 +480,14 @@ class _GenericStub(face.GenericStub):
timeout, timeout,
metadata=None, metadata=None,
protocol_options=None): protocol_options=None):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _unary_stream(self._channel, group, method, timeout, return _unary_stream(self._channel, group, method, timeout,
protocol_options, metadata, protocol_options, metadata,
self._metadata_transformer, request, self._metadata_transformer, request,
@ -487,9 +501,14 @@ class _GenericStub(face.GenericStub):
metadata=None, metadata=None,
with_call=None, with_call=None,
protocol_options=None): protocol_options=None):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _blocking_stream_unary( return _blocking_stream_unary(
self._channel, group, method, timeout, with_call, protocol_options, self._channel, group, method, timeout, with_call, protocol_options,
metadata, self._metadata_transformer, request_iterator, metadata, self._metadata_transformer, request_iterator,
@ -502,9 +521,14 @@ class _GenericStub(face.GenericStub):
timeout, timeout,
metadata=None, metadata=None,
protocol_options=None): protocol_options=None):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _future_stream_unary( return _future_stream_unary(
self._channel, group, method, timeout, protocol_options, metadata, self._channel, group, method, timeout, protocol_options, metadata,
self._metadata_transformer, request_iterator, request_serializer, self._metadata_transformer, request_iterator, request_serializer,
@ -517,9 +541,14 @@ class _GenericStub(face.GenericStub):
timeout, timeout,
metadata=None, metadata=None,
protocol_options=None): protocol_options=None):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _stream_stream(self._channel, group, method, timeout, return _stream_stream(self._channel, group, method, timeout,
protocol_options, metadata, protocol_options, metadata,
self._metadata_transformer, request_iterator, self._metadata_transformer, request_iterator,
@ -568,33 +597,53 @@ class _GenericStub(face.GenericStub):
raise NotImplementedError() raise NotImplementedError()
def unary_unary(self, group, method): def unary_unary(self, group, method):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _UnaryUnaryMultiCallable( return _UnaryUnaryMultiCallable(
self._channel, group, method, self._metadata_transformer, self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer) request_serializer, response_deserializer)
def unary_stream(self, group, method): def unary_stream(self, group, method):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _UnaryStreamMultiCallable( return _UnaryStreamMultiCallable(
self._channel, group, method, self._metadata_transformer, self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer) request_serializer, response_deserializer)
def stream_unary(self, group, method): def stream_unary(self, group, method):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _StreamUnaryMultiCallable( return _StreamUnaryMultiCallable(
self._channel, group, method, self._metadata_transformer, self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer) request_serializer, response_deserializer)
def stream_stream(self, group, method): def stream_stream(self, group, method):
request_serializer = self._request_serializers.get((group, method,)) request_serializer = self._request_serializers.get((
response_deserializer = self._response_deserializers.get((group, group,
method,)) method,
))
response_deserializer = self._response_deserializers.get((
group,
method,
))
return _StreamStreamMultiCallable( return _StreamStreamMultiCallable(
self._channel, group, method, self._metadata_transformer, self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer) request_serializer, response_deserializer)
@ -624,8 +673,8 @@ class _DynamicStub(face.DynamicStub):
elif method_cardinality is cardinality.Cardinality.STREAM_STREAM: elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
return self._generic_stub.stream_stream(self._group, attr) return self._generic_stub.stream_stream(self._group, attr)
else: else:
raise AttributeError('_DynamicStub object has no attribute "%s"!' % raise AttributeError(
attr) '_DynamicStub object has no attribute "%s"!' % attr)
def __enter__(self): def __enter__(self):
return self return self

@ -15,7 +15,10 @@
import collections import collections
_Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',)) _Metadatum = collections.namedtuple('_Metadatum', (
'key',
'value',
))
def _beta_metadatum(key, value): def _beta_metadatum(key, value):

@ -245,9 +245,15 @@ def _adapt_stream_stream_event(stream_stream_event):
class _SimpleMethodHandler( class _SimpleMethodHandler(
collections.namedtuple('_MethodHandler', ( collections.namedtuple('_MethodHandler', (
'request_streaming', 'response_streaming', 'request_deserializer', 'request_streaming',
'response_serializer', 'unary_unary', 'unary_stream', 'response_streaming',
'stream_unary', 'stream_stream',)), grpc.RpcMethodHandler): 'request_deserializer',
'response_serializer',
'unary_unary',
'unary_stream',
'stream_unary',
'stream_stream',
)), grpc.RpcMethodHandler):
pass pass
@ -255,15 +261,17 @@ def _simple_method_handler(implementation, request_deserializer,
response_serializer): response_serializer):
if implementation.style is style.Service.INLINE: if implementation.style is style.Service.INLINE:
if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY: if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
return _SimpleMethodHandler( return _SimpleMethodHandler(False, False, request_deserializer,
False, False, request_deserializer, response_serializer, response_serializer,
_adapt_unary_request_inline(implementation.unary_unary_inline), _adapt_unary_request_inline(
None, None, None) implementation.unary_unary_inline),
None, None, None)
elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM: elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
return _SimpleMethodHandler( return _SimpleMethodHandler(False, True, request_deserializer,
False, True, request_deserializer, response_serializer, None, response_serializer, None,
_adapt_unary_request_inline(implementation.unary_stream_inline), _adapt_unary_request_inline(
None, None) implementation.unary_stream_inline),
None, None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY: elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
return _SimpleMethodHandler(True, False, request_deserializer, return _SimpleMethodHandler(True, False, request_deserializer,
response_serializer, None, None, response_serializer, None, None,
@ -278,26 +286,28 @@ def _simple_method_handler(implementation, request_deserializer,
implementation.stream_stream_inline)) implementation.stream_stream_inline))
elif implementation.style is style.Service.EVENT: elif implementation.style is style.Service.EVENT:
if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY: if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
return _SimpleMethodHandler( return _SimpleMethodHandler(False, False, request_deserializer,
False, False, request_deserializer, response_serializer, response_serializer,
_adapt_unary_unary_event(implementation.unary_unary_event), _adapt_unary_unary_event(
None, None, None) implementation.unary_unary_event),
None, None, None)
elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM: elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
return _SimpleMethodHandler( return _SimpleMethodHandler(False, True, request_deserializer,
False, True, request_deserializer, response_serializer, None, response_serializer, None,
_adapt_unary_stream_event(implementation.unary_stream_event), _adapt_unary_stream_event(
None, None) implementation.unary_stream_event),
None, None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY: elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
return _SimpleMethodHandler( return _SimpleMethodHandler(True, False, request_deserializer,
True, False, request_deserializer, response_serializer, None, response_serializer, None, None,
None, _adapt_stream_unary_event(
_adapt_stream_unary_event(implementation.stream_unary_event), implementation.stream_unary_event),
None) None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM: elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
return _SimpleMethodHandler( return _SimpleMethodHandler(True, True, request_deserializer,
True, True, request_deserializer, response_serializer, None, response_serializer, None, None, None,
None, None, _adapt_stream_stream_event(
_adapt_stream_stream_event(implementation.stream_stream_event)) implementation.stream_stream_event))
def _flatten_method_pair_map(method_pair_map): def _flatten_method_pair_map(method_pair_map):
@ -325,10 +335,11 @@ class _GenericRpcHandler(grpc.GenericRpcHandler):
method_implementation = self._method_implementations.get( method_implementation = self._method_implementations.get(
handler_call_details.method) handler_call_details.method)
if method_implementation is not None: if method_implementation is not None:
return _simple_method_handler( return _simple_method_handler(method_implementation,
method_implementation, self._request_deserializers.get(
self._request_deserializers.get(handler_call_details.method), handler_call_details.method),
self._response_serializers.get(handler_call_details.method)) self._response_serializers.get(
handler_call_details.method))
elif self._multi_method_implementation is None: elif self._multi_method_implementation is None:
return None return None
else: else:

@ -110,8 +110,8 @@ def insecure_channel(host, port):
Returns: Returns:
A Channel to the remote host through which RPCs may be conducted. A Channel to the remote host through which RPCs may be conducted.
""" """
channel = grpc.insecure_channel(host channel = grpc.insecure_channel(host if port is None else '%s:%d' % (host,
if port is None else '%s:%d' % (host, port)) port))
return Channel(channel) return Channel(channel)

@ -50,8 +50,8 @@ class _EasyOutcome(
def _call_logging_exceptions(behavior, message, *args, **kwargs): def _call_logging_exceptions(behavior, message, *args, **kwargs):
try: try:
return _EasyOutcome(Outcome.Kind.RETURNED, return _EasyOutcome(Outcome.Kind.RETURNED, behavior(*args, **kwargs),
behavior(*args, **kwargs), None) None)
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
logging.exception(message) logging.exception(message)
return _EasyOutcome(Outcome.Kind.RAISED, None, e) return _EasyOutcome(Outcome.Kind.RAISED, None, e)

@ -19,15 +19,22 @@ from grpc.framework.interfaces.base import base
class _Completion(base.Completion, class _Completion(base.Completion,
collections.namedtuple('_Completion', ('terminal_metadata', collections.namedtuple('_Completion', (
'code', 'message',))): 'terminal_metadata',
'code',
'message',
))):
"""A trivial implementation of base.Completion.""" """A trivial implementation of base.Completion."""
class _Subscription(base.Subscription, class _Subscription(base.Subscription,
collections.namedtuple('_Subscription', ( collections.namedtuple('_Subscription', (
'kind', 'termination_callback', 'allowance', 'operator', 'kind',
'protocol_receiver',))): 'termination_callback',
'allowance',
'operator',
'protocol_receiver',
))):
"""A trivial implementation of base.Subscription.""" """A trivial implementation of base.Subscription."""

@ -50,13 +50,20 @@ class NoSuchMethodError(Exception):
self.method = method self.method = method
def __repr__(self): def __repr__(self):
return 'face.NoSuchMethodError(%s, %s)' % (self.group, self.method,) return 'face.NoSuchMethodError(%s, %s)' % (
self.group,
self.method,
)
class Abortion( class Abortion(
collections.namedtuple('Abortion', collections.namedtuple('Abortion', (
('kind', 'initial_metadata', 'terminal_metadata', 'kind',
'code', 'details',))): 'initial_metadata',
'terminal_metadata',
'code',
'details',
))):
"""A value describing RPC abortion. """A value describing RPC abortion.
Attributes: Attributes:

@ -36,9 +36,9 @@ class CopyProtoModules(setuptools.Command):
def run(self): def run(self):
if os.path.isfile(HEALTH_PROTO): if os.path.isfile(HEALTH_PROTO):
shutil.copyfile( shutil.copyfile(HEALTH_PROTO,
HEALTH_PROTO, os.path.join(ROOT_DIR,
os.path.join(ROOT_DIR, 'grpc_health/v1/health.proto')) 'grpc_health/v1/health.proto'))
class BuildPackageProtos(setuptools.Command): class BuildPackageProtos(setuptools.Command):

@ -56,8 +56,10 @@ PACKAGE_DIRECTORIES = {
'': '.', '': '.',
} }
INSTALL_REQUIRES = ('protobuf>=3.5.0.post1', INSTALL_REQUIRES = (
'grpcio>={version}'.format(version=grpc_version.VERSION),) 'protobuf>=3.5.0.post1',
'grpcio>={version}'.format(version=grpc_version.VERSION),
)
try: try:
import health_commands as _health_commands import health_commands as _health_commands

@ -27,7 +27,8 @@ def _not_found_error():
return reflection_pb2.ServerReflectionResponse( return reflection_pb2.ServerReflectionResponse(
error_response=reflection_pb2.ErrorResponse( error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0], error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),)) error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
))
def _file_descriptor_response(descriptor): def _file_descriptor_response(descriptor):
@ -101,10 +102,11 @@ class ReflectionServicer(reflection_pb2_grpc.ServerReflectionServicer):
def _list_services(self): def _list_services(self):
return reflection_pb2.ServerReflectionResponse( return reflection_pb2.ServerReflectionResponse(
list_services_response=reflection_pb2.ListServiceResponse(service=[ list_services_response=reflection_pb2.ListServiceResponse(
reflection_pb2.ServiceResponse(name=service_name) service=[
for service_name in self._service_names reflection_pb2.ServiceResponse(name=service_name)
])) for service_name in self._service_names
]))
def ServerReflectionInfo(self, request_iterator, context): def ServerReflectionInfo(self, request_iterator, context):
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -128,7 +130,8 @@ class ReflectionServicer(reflection_pb2_grpc.ServerReflectionServicer):
error_response=reflection_pb2.ErrorResponse( error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.INVALID_ARGUMENT.value[0], error_code=grpc.StatusCode.INVALID_ARGUMENT.value[0],
error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1] error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1]
.encode(),)) .encode(),
))
def enable_server_reflection(service_names, server, pool=None): def enable_server_reflection(service_names, server, pool=None):

@ -57,8 +57,10 @@ PACKAGE_DIRECTORIES = {
'': '.', '': '.',
} }
INSTALL_REQUIRES = ('protobuf>=3.5.0.post1', INSTALL_REQUIRES = (
'grpcio>={version}'.format(version=grpc_version.VERSION),) 'protobuf>=3.5.0.post1',
'grpcio>={version}'.format(version=grpc_version.VERSION),
)
try: try:
import reflection_commands as _reflection_commands import reflection_commands as _reflection_commands

@ -27,20 +27,20 @@ class UnaryUnary(grpc.UnaryUnaryMultiCallable):
def __call__(self, request, timeout=None, metadata=None, credentials=None): def __call__(self, request, timeout=None, metadata=None, credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [request], True, timeout) [request], True, timeout)
return _invocation.blocking_unary_response(rpc_handler) return _invocation.blocking_unary_response(rpc_handler)
def with_call(self, request, timeout=None, metadata=None, credentials=None): def with_call(self, request, timeout=None, metadata=None, credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [request], True, timeout) [request], True, timeout)
return _invocation.blocking_unary_response_with_call(rpc_handler) return _invocation.blocking_unary_response_with_call(rpc_handler)
def future(self, request, timeout=None, metadata=None, credentials=None): def future(self, request, timeout=None, metadata=None, credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [request], True, timeout) [request], True, timeout)
return _invocation.future_call(rpc_handler) return _invocation.future_call(rpc_handler)
@ -52,8 +52,8 @@ class UnaryStream(grpc.StreamStreamMultiCallable):
def __call__(self, request, timeout=None, metadata=None, credentials=None): def __call__(self, request, timeout=None, metadata=None, credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [request], True, timeout) [request], True, timeout)
return _invocation.ResponseIteratorCall(rpc_handler) return _invocation.ResponseIteratorCall(rpc_handler)
@ -69,8 +69,8 @@ class StreamUnary(grpc.StreamUnaryMultiCallable):
metadata=None, metadata=None,
credentials=None): credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [], False, timeout) [], False, timeout)
_invocation.consume_requests(request_iterator, rpc_handler) _invocation.consume_requests(request_iterator, rpc_handler)
return _invocation.blocking_unary_response(rpc_handler) return _invocation.blocking_unary_response(rpc_handler)
@ -80,8 +80,8 @@ class StreamUnary(grpc.StreamUnaryMultiCallable):
metadata=None, metadata=None,
credentials=None): credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [], False, timeout) [], False, timeout)
_invocation.consume_requests(request_iterator, rpc_handler) _invocation.consume_requests(request_iterator, rpc_handler)
return _invocation.blocking_unary_response_with_call(rpc_handler) return _invocation.blocking_unary_response_with_call(rpc_handler)
@ -91,8 +91,8 @@ class StreamUnary(grpc.StreamUnaryMultiCallable):
metadata=None, metadata=None,
credentials=None): credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [], False, timeout) [], False, timeout)
_invocation.consume_requests(request_iterator, rpc_handler) _invocation.consume_requests(request_iterator, rpc_handler)
return _invocation.future_call(rpc_handler) return _invocation.future_call(rpc_handler)
@ -109,8 +109,8 @@ class StreamStream(grpc.StreamStreamMultiCallable):
metadata=None, metadata=None,
credentials=None): credentials=None):
rpc_handler = self._channel_handler.invoke_rpc( rpc_handler = self._channel_handler.invoke_rpc(
self._method_full_rpc_name, self._method_full_rpc_name, _common.fuss_with_metadata(metadata),
_common.fuss_with_metadata(metadata), [], False, timeout) [], False, timeout)
_invocation.consume_requests(request_iterator, rpc_handler) _invocation.consume_requests(request_iterator, rpc_handler)
return _invocation.ResponseIteratorCall(rpc_handler) return _invocation.ResponseIteratorCall(rpc_handler)

@ -179,8 +179,8 @@ class State(_common.ChannelRpcHandler):
elif self._code is None: elif self._code is None:
self._condition.wait() self._condition.wait()
else: else:
raise ValueError( raise ValueError('Status code unexpectedly {}!'.format(
'Status code unexpectedly {}!'.format(self._code)) self._code))
def is_active(self): def is_active(self):
raise NotImplementedError() raise NotImplementedError()

@ -20,9 +20,10 @@ import six
def _fuss(tuplified_metadata): def _fuss(tuplified_metadata):
return tuplified_metadata + ( return tuplified_metadata + ((
('grpc.metadata_added_by_runtime', 'grpc.metadata_added_by_runtime',
'gRPC is allowed to add metadata in transmission and does so.',),) 'gRPC is allowed to add metadata in transmission and does so.',
),)
FUSSED_EMPTY_METADATA = _fuss(()) FUSSED_EMPTY_METADATA = _fuss(())
@ -46,9 +47,12 @@ def rpc_names(service_descriptors):
class ChannelRpcRead( class ChannelRpcRead(
collections.namedtuple( collections.namedtuple('ChannelRpcRead', (
'ChannelRpcRead', 'response',
('response', 'trailing_metadata', 'code', 'details',))): 'trailing_metadata',
'code',
'details',
))):
pass pass
@ -100,8 +104,11 @@ class ChannelHandler(six.with_metaclass(abc.ABCMeta)):
class ServerRpcRead( class ServerRpcRead(
collections.namedtuple('ServerRpcRead', collections.namedtuple('ServerRpcRead', (
('request', 'requests_closed', 'terminated',))): 'request',
'requests_closed',
'terminated',
))):
pass pass

@ -170,8 +170,12 @@ class _Handler(Handler):
if self._unary_response is None: if self._unary_response is None:
if self._responses: if self._responses:
self._unary_response = self._responses.pop(0) self._unary_response = self._responses.pop(0)
return (self._unary_response, self._trailing_metadata, return (
self._code, self._details,) self._unary_response,
self._trailing_metadata,
self._code,
self._details,
)
def stream_response_termination(self): def stream_response_termination(self):
with self._condition: with self._condition:

@ -76,7 +76,11 @@ class _Serverish(_common.Serverish):
rpc, self._time, deadline) rpc, self._time, deadline)
service_thread = threading.Thread( service_thread = threading.Thread(
target=service_behavior, target=service_behavior,
args=(implementation, rpc, servicer_context,)) args=(
implementation,
rpc,
servicer_context,
))
service_thread.start() service_thread.start()
def invoke_unary_unary(self, method_descriptor, handler, def invoke_unary_unary(self, method_descriptor, handler,

@ -46,9 +46,11 @@ class _State(object):
class _Delta( class _Delta(
collections.namedtuple('_Delta', collections.namedtuple('_Delta', (
('mature_behaviors', 'earliest_mature_time', 'mature_behaviors',
'earliest_immature_time',))): 'earliest_mature_time',
'earliest_immature_time',
))):
pass pass

@ -28,8 +28,10 @@ PACKAGE_DIRECTORIES = {
'': '.', '': '.',
} }
INSTALL_REQUIRES = ('protobuf>=3.5.0.post1', INSTALL_REQUIRES = (
'grpcio>={version}'.format(version=grpc_version.VERSION),) 'protobuf>=3.5.0.post1',
'grpcio>={version}'.format(version=grpc_version.VERSION),
)
setuptools.setup( setuptools.setup(
name='grpcio-testing', name='grpcio-testing',

@ -99,4 +99,5 @@ setuptools.setup(
tests_require=TESTS_REQUIRE, tests_require=TESTS_REQUIRE,
test_suite=TEST_SUITE, test_suite=TEST_SUITE,
test_loader=TEST_LOADER, test_loader=TEST_LOADER,
test_runner=TEST_RUNNER,) test_runner=TEST_RUNNER,
)

@ -101,5 +101,5 @@ def iterate_suite_cases(suite):
elif isinstance(item, unittest.TestCase): elif isinstance(item, unittest.TestCase):
yield item yield item
else: else:
raise ValueError( raise ValueError('unexpected suite item of type {}'.format(
'unexpected suite item of type {}'.format(type(item))) type(item)))

@ -215,7 +215,8 @@ class AugmentedResult(unittest.TestResult):
Args: Args:
filter (callable): A unary predicate to filter over CaseResult objects. filter (callable): A unary predicate to filter over CaseResult objects.
""" """
return (self.cases[case_id] for case_id in self.cases return (self.cases[case_id]
for case_id in self.cases
if filter(self.cases[case_id])) if filter(self.cases[case_id]))
@ -285,8 +286,8 @@ class TerminalResult(CoverageResult):
def startTestRun(self): def startTestRun(self):
"""See unittest.TestResult.startTestRun.""" """See unittest.TestResult.startTestRun."""
super(TerminalResult, self).startTestRun() super(TerminalResult, self).startTestRun()
self.out.write(_Colors.HEADER + 'Testing gRPC Python...\n' + self.out.write(
_Colors.END) _Colors.HEADER + 'Testing gRPC Python...\n' + _Colors.END)
def stopTestRun(self): def stopTestRun(self):
"""See unittest.TestResult.stopTestRun.""" """See unittest.TestResult.stopTestRun."""
@ -297,43 +298,43 @@ class TerminalResult(CoverageResult):
def addError(self, test, error): def addError(self, test, error):
"""See unittest.TestResult.addError.""" """See unittest.TestResult.addError."""
super(TerminalResult, self).addError(test, error) super(TerminalResult, self).addError(test, error)
self.out.write(_Colors.FAIL + 'ERROR {}\n'.format(test.id()) + self.out.write(
_Colors.END) _Colors.FAIL + 'ERROR {}\n'.format(test.id()) + _Colors.END)
self.out.flush() self.out.flush()
def addFailure(self, test, error): def addFailure(self, test, error):
"""See unittest.TestResult.addFailure.""" """See unittest.TestResult.addFailure."""
super(TerminalResult, self).addFailure(test, error) super(TerminalResult, self).addFailure(test, error)
self.out.write(_Colors.FAIL + 'FAILURE {}\n'.format(test.id()) + self.out.write(
_Colors.END) _Colors.FAIL + 'FAILURE {}\n'.format(test.id()) + _Colors.END)
self.out.flush() self.out.flush()
def addSuccess(self, test): def addSuccess(self, test):
"""See unittest.TestResult.addSuccess.""" """See unittest.TestResult.addSuccess."""
super(TerminalResult, self).addSuccess(test) super(TerminalResult, self).addSuccess(test)
self.out.write(_Colors.OK + 'SUCCESS {}\n'.format(test.id()) + self.out.write(
_Colors.END) _Colors.OK + 'SUCCESS {}\n'.format(test.id()) + _Colors.END)
self.out.flush() self.out.flush()
def addSkip(self, test, reason): def addSkip(self, test, reason):
"""See unittest.TestResult.addSkip.""" """See unittest.TestResult.addSkip."""
super(TerminalResult, self).addSkip(test, reason) super(TerminalResult, self).addSkip(test, reason)
self.out.write(_Colors.INFO + 'SKIP {}\n'.format(test.id()) + self.out.write(
_Colors.END) _Colors.INFO + 'SKIP {}\n'.format(test.id()) + _Colors.END)
self.out.flush() self.out.flush()
def addExpectedFailure(self, test, error): def addExpectedFailure(self, test, error):
"""See unittest.TestResult.addExpectedFailure.""" """See unittest.TestResult.addExpectedFailure."""
super(TerminalResult, self).addExpectedFailure(test, error) super(TerminalResult, self).addExpectedFailure(test, error)
self.out.write(_Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) + self.out.write(
_Colors.END) _Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) + _Colors.END)
self.out.flush() self.out.flush()
def addUnexpectedSuccess(self, test): def addUnexpectedSuccess(self, test):
"""See unittest.TestResult.addUnexpectedSuccess.""" """See unittest.TestResult.addUnexpectedSuccess."""
super(TerminalResult, self).addUnexpectedSuccess(test) super(TerminalResult, self).addUnexpectedSuccess(test)
self.out.write(_Colors.INFO + 'UNEXPECTED_OK {}\n'.format(test.id()) + self.out.write(
_Colors.END) _Colors.INFO + 'UNEXPECTED_OK {}\n'.format(test.id()) + _Colors.END)
self.out.flush() self.out.flush()

@ -181,8 +181,8 @@ class Runner(object):
# Run the tests # Run the tests
result.startTestRun() result.startTestRun()
for augmented_case in augmented_cases: for augmented_case in augmented_cases:
sys.stdout.write( sys.stdout.write('Running {}\n'.format(
'Running {}\n'.format(augmented_case.case.id())) augmented_case.case.id()))
sys.stdout.flush() sys.stdout.flush()
case_thread = threading.Thread( case_thread = threading.Thread(
target=augmented_case.case.run, args=(result,)) target=augmented_case.case.run, args=(result,))
@ -196,8 +196,8 @@ class Runner(object):
except: except:
# re-raise the exception after forcing the with-block to end # re-raise the exception after forcing the with-block to end
raise raise
result.set_output(augmented_case.case, result.set_output(augmented_case.case, stdout_pipe.output(),
stdout_pipe.output(), stderr_pipe.output()) stderr_pipe.output())
sys.stdout.write(result_out.getvalue()) sys.stdout.write(result_out.getvalue())
sys.stdout.flush() sys.stdout.flush()
result_out.truncate(0) result_out.truncate(0)

@ -32,14 +32,14 @@ def _validate_payload_type_and_length(response, expected_type, expected_length):
def _expect_status_code(call, expected_code): def _expect_status_code(call, expected_code):
if call.code() != expected_code: if call.code() != expected_code:
raise ValueError('expected code %s, got %s' % raise ValueError('expected code %s, got %s' % (expected_code,
(expected_code, call.code())) call.code()))
def _expect_status_details(call, expected_details): def _expect_status_details(call, expected_details):
if call.details() != expected_details: if call.details() != expected_details:
raise ValueError('expected message %s, got %s' % raise ValueError('expected message %s, got %s' % (expected_details,
(expected_details, call.details())) call.details()))
def _validate_status_code_and_details(call, expected_code, expected_details): def _validate_status_code_and_details(call, expected_code, expected_details):

@ -39,8 +39,8 @@ class IntraopTestCase(object):
methods.TestCase.PING_PONG.test_interoperability(self.stub, None) methods.TestCase.PING_PONG.test_interoperability(self.stub, None)
def testCancelAfterBegin(self): def testCancelAfterBegin(self):
methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability(self.stub, methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability(
None) self.stub, None)
def testCancelAfterFirstResponse(self): def testCancelAfterFirstResponse(self):
methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability( methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability(

@ -34,15 +34,16 @@ class SecureIntraopTest(_intraop_test_case.IntraopTestCase, unittest.TestCase):
self.server) self.server)
port = self.server.add_secure_port( port = self.server.add_secure_port(
'[::]:0', '[::]:0',
grpc.ssl_server_credentials( grpc.ssl_server_credentials([(resources.private_key(),
[(resources.private_key(), resources.certificate_chain())])) resources.certificate_chain())]))
self.server.start() self.server.start()
self.stub = test_pb2_grpc.TestServiceStub( self.stub = test_pb2_grpc.TestServiceStub(
grpc.secure_channel('localhost:{}'.format(port), grpc.secure_channel('localhost:{}'.format(port),
grpc.ssl_channel_credentials( grpc.ssl_channel_credentials(
resources.test_root_certificates()), ( resources.test_root_certificates()), ((
('grpc.ssl_target_name_override', 'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,),))) _SERVER_HOST_OVERRIDE,
),)))
if __name__ == '__main__': if __name__ == '__main__':

@ -104,8 +104,10 @@ def _stub(args):
channel_credentials = grpc.composite_channel_credentials( channel_credentials = grpc.composite_channel_credentials(
channel_credentials, call_credentials) channel_credentials, call_credentials)
channel = grpc.secure_channel(target, channel_credentials, ( channel = grpc.secure_channel(target, channel_credentials, ((
('grpc.ssl_target_name_override', args.server_host_override,),)) 'grpc.ssl_target_name_override',
args.server_host_override,
),))
else: else:
channel = grpc.insecure_channel(target) channel = grpc.insecure_channel(target)
if args.test_case == "unimplemented_service": if args.test_case == "unimplemented_service":

@ -62,9 +62,10 @@ class TestService(test_pb2_grpc.TestServiceServicer):
def UnaryCall(self, request, context): def UnaryCall(self, request, context):
_maybe_echo_metadata(context) _maybe_echo_metadata(context)
_maybe_echo_status_and_message(request, context) _maybe_echo_status_and_message(request, context)
return messages_pb2.SimpleResponse(payload=messages_pb2.Payload( return messages_pb2.SimpleResponse(
type=messages_pb2.COMPRESSABLE, payload=messages_pb2.Payload(
body=b'\x00' * request.response_size)) type=messages_pb2.COMPRESSABLE,
body=b'\x00' * request.response_size))
def StreamingOutputCall(self, request, context): def StreamingOutputCall(self, request, context):
_maybe_echo_status_and_message(request, context) _maybe_echo_status_and_message(request, context)
@ -100,14 +101,14 @@ class TestService(test_pb2_grpc.TestServiceServicer):
def _expect_status_code(call, expected_code): def _expect_status_code(call, expected_code):
if call.code() != expected_code: if call.code() != expected_code:
raise ValueError('expected code %s, got %s' % raise ValueError('expected code %s, got %s' % (expected_code,
(expected_code, call.code())) call.code()))
def _expect_status_details(call, expected_details): def _expect_status_details(call, expected_details):
if call.details() != expected_details: if call.details() != expected_details:
raise ValueError('expected message %s, got %s' % raise ValueError('expected message %s, got %s' % (expected_details,
(expected_details, call.details())) call.details()))
def _validate_status_code_and_details(call, expected_code, expected_details): def _validate_status_code_and_details(call, expected_code, expected_details):
@ -152,26 +153,38 @@ def _large_unary(stub):
def _client_streaming(stub): def _client_streaming(stub):
payload_body_sizes = (27182, 8, 1828, 45904,) payload_body_sizes = (
27182,
8,
1828,
45904,
)
payloads = (messages_pb2.Payload(body=b'\x00' * size) payloads = (messages_pb2.Payload(body=b'\x00' * size)
for size in payload_body_sizes) for size in payload_body_sizes)
requests = (messages_pb2.StreamingInputCallRequest(payload=payload) requests = (messages_pb2.StreamingInputCallRequest(payload=payload)
for payload in payloads) for payload in payloads)
response = stub.StreamingInputCall(requests) response = stub.StreamingInputCall(requests)
if response.aggregated_payload_size != 74922: if response.aggregated_payload_size != 74922:
raise ValueError('incorrect size %d!' % raise ValueError(
response.aggregated_payload_size) 'incorrect size %d!' % response.aggregated_payload_size)
def _server_streaming(stub): def _server_streaming(stub):
sizes = (31415, 9, 2653, 58979,) sizes = (
31415,
9,
2653,
58979,
)
request = messages_pb2.StreamingOutputCallRequest( request = messages_pb2.StreamingOutputCallRequest(
response_type=messages_pb2.COMPRESSABLE, response_type=messages_pb2.COMPRESSABLE,
response_parameters=(messages_pb2.ResponseParameters(size=sizes[0]), response_parameters=(
messages_pb2.ResponseParameters(size=sizes[1]), messages_pb2.ResponseParameters(size=sizes[0]),
messages_pb2.ResponseParameters(size=sizes[2]), messages_pb2.ResponseParameters(size=sizes[1]),
messages_pb2.ResponseParameters(size=sizes[3]),)) messages_pb2.ResponseParameters(size=sizes[2]),
messages_pb2.ResponseParameters(size=sizes[3]),
))
response_iterator = stub.StreamingOutputCall(request) response_iterator = stub.StreamingOutputCall(request)
for index, response in enumerate(response_iterator): for index, response in enumerate(response_iterator):
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
@ -218,8 +231,18 @@ class _Pipe(object):
def _ping_pong(stub): def _ping_pong(stub):
request_response_sizes = (31415, 9, 2653, 58979,) request_response_sizes = (
request_payload_sizes = (27182, 8, 1828, 45904,) 31415,
9,
2653,
58979,
)
request_payload_sizes = (
27182,
8,
1828,
45904,
)
with _Pipe() as pipe: with _Pipe() as pipe:
response_iterator = stub.FullDuplexCall(pipe) response_iterator = stub.FullDuplexCall(pipe)
@ -247,8 +270,18 @@ def _cancel_after_begin(stub):
def _cancel_after_first_response(stub): def _cancel_after_first_response(stub):
request_response_sizes = (31415, 9, 2653, 58979,) request_response_sizes = (
request_payload_sizes = (27182, 8, 1828, 45904,) 31415,
9,
2653,
58979,
)
request_payload_sizes = (
27182,
8,
1828,
45904,
)
with _Pipe() as pipe: with _Pipe() as pipe:
response_iterator = stub.FullDuplexCall(pipe) response_iterator = stub.FullDuplexCall(pipe)
@ -331,14 +364,14 @@ def _status_code_and_message(stub):
def _unimplemented_method(test_service_stub): def _unimplemented_method(test_service_stub):
response_future = ( response_future = (test_service_stub.UnimplementedCall.future(
test_service_stub.UnimplementedCall.future(empty_pb2.Empty())) empty_pb2.Empty()))
_expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED) _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
def _unimplemented_service(unimplemented_service_stub): def _unimplemented_service(unimplemented_service_stub):
response_future = ( response_future = (unimplemented_service_stub.UnimplementedCall.future(
unimplemented_service_stub.UnimplementedCall.future(empty_pb2.Empty())) empty_pb2.Empty()))
_expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED) _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
@ -392,11 +425,12 @@ def _oauth2_auth_token(stub, args):
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
response = _large_unary_common_behavior(stub, True, True, None) response = _large_unary_common_behavior(stub, True, True, None)
if wanted_email != response.username: if wanted_email != response.username:
raise ValueError('expected username %s, got %s' % raise ValueError('expected username %s, got %s' % (wanted_email,
(wanted_email, response.username)) response.username))
if args.oauth_scope.find(response.oauth_scope) == -1: if args.oauth_scope.find(response.oauth_scope) == -1:
raise ValueError('expected to find oauth scope "{}" in received "{}"'. raise ValueError(
format(response.oauth_scope, args.oauth_scope)) 'expected to find oauth scope "{}" in received "{}"'.format(
response.oauth_scope, args.oauth_scope))
def _jwt_token_creds(stub, args): def _jwt_token_creds(stub, args):
@ -404,8 +438,8 @@ def _jwt_token_creds(stub, args):
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
response = _large_unary_common_behavior(stub, True, False, None) response = _large_unary_common_behavior(stub, True, False, None)
if wanted_email != response.username: if wanted_email != response.username:
raise ValueError('expected username %s, got %s' % raise ValueError('expected username %s, got %s' % (wanted_email,
(wanted_email, response.username)) response.username))
def _per_rpc_creds(stub, args): def _per_rpc_creds(stub, args):
@ -419,8 +453,8 @@ def _per_rpc_creds(stub, args):
request=google_auth_transport_requests.Request())) request=google_auth_transport_requests.Request()))
response = _large_unary_common_behavior(stub, True, False, call_credentials) response = _large_unary_common_behavior(stub, True, False, call_credentials)
if wanted_email != response.username: if wanted_email != response.username:
raise ValueError('expected username %s, got %s' % raise ValueError('expected username %s, got %s' % (wanted_email,
(wanted_email, response.username)) response.username))
@enum.unique @enum.unique
@ -479,5 +513,5 @@ class TestCase(enum.Enum):
elif self is TestCase.PER_RPC_CREDS: elif self is TestCase.PER_RPC_CREDS:
_per_rpc_creds(stub, args) _per_rpc_creds(stub, args)
else: else:
raise NotImplementedError('Test case "%s" not implemented!' % raise NotImplementedError(
self.name) 'Test case "%s" not implemented!' % self.name)

@ -45,8 +45,8 @@ def serve():
if args.use_tls: if args.use_tls:
private_key = resources.private_key() private_key = resources.private_key()
certificate_chain = resources.certificate_chain() certificate_chain = resources.certificate_chain()
credentials = grpc.ssl_server_credentials(( credentials = grpc.ssl_server_credentials(((private_key,
(private_key, certificate_chain),)) certificate_chain),))
server.add_secure_port('[::]:{}'.format(args.port), credentials) server.add_secure_port('[::]:{}'.format(args.port), credentials)
else: else:
server.add_insecure_port('[::]:{}'.format(args.port)) server.add_insecure_port('[::]:{}'.format(args.port))

@ -119,8 +119,11 @@ class _ServicerMethods(object):
class _Service( class _Service(
collections.namedtuple('_Service', ('servicer_methods', 'server', collections.namedtuple('_Service', (
'stub',))): 'servicer_methods',
'server',
'stub',
))):
"""A live and running service. """A live and running service.
Attributes: Attributes:
@ -297,8 +300,8 @@ class PythonPluginTest(unittest.TestCase):
responses = service.stub.StreamingOutputCall(request) responses = service.stub.StreamingOutputCall(request)
expected_responses = service.servicer_methods.StreamingOutputCall( expected_responses = service.servicer_methods.StreamingOutputCall(
request, 'not a real RpcContext!') request, 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(expected_responses, for expected_response, response in moves.zip_longest(
responses): expected_responses, responses):
self.assertEqual(expected_response, response) self.assertEqual(expected_response, response)
def testStreamingOutputCallExpired(self): def testStreamingOutputCallExpired(self):
@ -388,8 +391,8 @@ class PythonPluginTest(unittest.TestCase):
responses = service.stub.FullDuplexCall(_full_duplex_request_iterator()) responses = service.stub.FullDuplexCall(_full_duplex_request_iterator())
expected_responses = service.servicer_methods.FullDuplexCall( expected_responses = service.servicer_methods.FullDuplexCall(
_full_duplex_request_iterator(), 'not a real RpcContext!') _full_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(expected_responses, for expected_response, response in moves.zip_longest(
responses): expected_responses, responses):
self.assertEqual(expected_response, response) self.assertEqual(expected_response, response)
def testFullDuplexCallExpired(self): def testFullDuplexCallExpired(self):
@ -439,8 +442,8 @@ class PythonPluginTest(unittest.TestCase):
responses = service.stub.HalfDuplexCall(half_duplex_request_iterator()) responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
expected_responses = service.servicer_methods.HalfDuplexCall( expected_responses = service.servicer_methods.HalfDuplexCall(
half_duplex_request_iterator(), 'not a real RpcContext!') half_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(expected_responses, for expected_response, response in moves.zip_longest(
responses): expected_responses, responses):
self.assertEqual(expected_response, response) self.assertEqual(expected_response, response)
def testHalfDuplexCallWedged(self): def testHalfDuplexCallWedged(self):

@ -64,8 +64,8 @@ def _massage_proto_content(proto_content, test_name_bytes,
messages_proto_relative_file_name_bytes): messages_proto_relative_file_name_bytes):
package_substitution = (b'package grpc_protoc_plugin.invocation_testing.' + package_substitution = (b'package grpc_protoc_plugin.invocation_testing.' +
test_name_bytes + b';') test_name_bytes + b';')
common_namespace_substituted = proto_content.replace(_COMMON_NAMESPACE, common_namespace_substituted = proto_content.replace(
package_substitution) _COMMON_NAMESPACE, package_substitution)
split_namespace_substituted = common_namespace_substituted.replace( split_namespace_substituted = common_namespace_substituted.replace(
_SPLIT_NAMESPACE, package_substitution) _SPLIT_NAMESPACE, package_substitution)
message_import_replaced = split_namespace_substituted.replace( message_import_replaced = split_namespace_substituted.replace(
@ -163,8 +163,12 @@ class _GrpcBeforeProtoProtocStyle(object):
return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code, return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code,
_PROTOC_STYLES = (_Mid2016ProtocStyle(), _SingleProtocExecutionProtocStyle(), _PROTOC_STYLES = (
_ProtoBeforeGrpcProtocStyle(), _GrpcBeforeProtoProtocStyle(),) _Mid2016ProtocStyle(),
_SingleProtocExecutionProtocStyle(),
_ProtoBeforeGrpcProtocStyle(),
_GrpcBeforeProtoProtocStyle(),
)
@unittest.skipIf(platform.python_implementation() == 'PyPy', @unittest.skipIf(platform.python_implementation() == 'PyPy',
@ -180,18 +184,22 @@ class _Test(six.with_metaclass(abc.ABCMeta, unittest.TestCase)):
os.makedirs(self._python_out) os.makedirs(self._python_out)
proto_directories_and_names = { proto_directories_and_names = {
(self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES, (
self.MESSAGES_PROTO_FILE_NAME,), self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES,
(self.SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES, self.MESSAGES_PROTO_FILE_NAME,
self.SERVICES_PROTO_FILE_NAME,), ),
(
self.SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES,
self.SERVICES_PROTO_FILE_NAME,
),
} }
messages_proto_relative_file_name_forward_slashes = '/'.join( messages_proto_relative_file_name_forward_slashes = '/'.join(
self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES + ( self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES +
self.MESSAGES_PROTO_FILE_NAME,)) (self.MESSAGES_PROTO_FILE_NAME,))
_create_directory_tree(self._proto_path, ( _create_directory_tree(self._proto_path,
relative_proto_directory_names (relative_proto_directory_names
for relative_proto_directory_names, _ in proto_directories_and_names for relative_proto_directory_names, _ in
)) proto_directories_and_names))
self._absolute_proto_file_names = set() self._absolute_proto_file_names = set()
for relative_directory_names, file_name in proto_directories_and_names: for relative_directory_names, file_name in proto_directories_and_names:
absolute_proto_file_name = path.join( absolute_proto_file_name = path.join(
@ -200,8 +208,7 @@ class _Test(six.with_metaclass(abc.ABCMeta, unittest.TestCase)):
'tests.protoc_plugin.protos.invocation_testing', 'tests.protoc_plugin.protos.invocation_testing',
path.join(*relative_directory_names + (file_name,))) path.join(*relative_directory_names + (file_name,)))
massaged_proto_content = _massage_proto_content( massaged_proto_content = _massage_proto_content(
raw_proto_content, raw_proto_content, self.NAME.encode(),
self.NAME.encode(),
messages_proto_relative_file_name_forward_slashes.encode()) messages_proto_relative_file_name_forward_slashes.encode())
with open(absolute_proto_file_name, 'wb') as proto_file: with open(absolute_proto_file_name, 'wb') as proto_file:
proto_file.write(massaged_proto_content) proto_file.write(massaged_proto_content)
@ -275,7 +282,9 @@ def _create_test_case_class(split_proto, protoc_style):
if split_proto: if split_proto:
attributes['MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES'] = ( attributes['MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES'] = (
'split_messages', 'sub',) 'split_messages',
'sub',
)
attributes['MESSAGES_PROTO_FILE_NAME'] = 'messages.proto' attributes['MESSAGES_PROTO_FILE_NAME'] = 'messages.proto'
attributes['SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES'] = ( attributes['SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES'] = (
'split_services',) 'split_services',)
@ -301,7 +310,10 @@ def _create_test_case_class(split_proto, protoc_style):
def _create_test_case_classes(): def _create_test_case_classes():
for split_proto in (False, True,): for split_proto in (
False,
True,
):
for protoc_style in _PROTOC_STYLES: for protoc_style in _PROTOC_STYLES:
yield _create_test_case_class(split_proto, protoc_style) yield _create_test_case_class(split_proto, protoc_style)

@ -36,10 +36,28 @@ _RELATIVE_PROTO_PATH = 'relative_proto_path'
_RELATIVE_PYTHON_OUT = 'relative_python_out' _RELATIVE_PYTHON_OUT = 'relative_python_out'
_PROTO_FILES_PATH_COMPONENTS = ( _PROTO_FILES_PATH_COMPONENTS = (
('beta_grpc_plugin_test', 'payload', 'test_payload.proto',), (
('beta_grpc_plugin_test', 'requests', 'r', 'test_requests.proto',), 'beta_grpc_plugin_test',
('beta_grpc_plugin_test', 'responses', 'test_responses.proto',), 'payload',
('beta_grpc_plugin_test', 'service', 'test_service.proto',),) 'test_payload.proto',
),
(
'beta_grpc_plugin_test',
'requests',
'r',
'test_requests.proto',
),
(
'beta_grpc_plugin_test',
'responses',
'test_responses.proto',
),
(
'beta_grpc_plugin_test',
'service',
'test_service.proto',
),
)
_PAYLOAD_PB2 = 'beta_grpc_plugin_test.payload.test_payload_pb2' _PAYLOAD_PB2 = 'beta_grpc_plugin_test.payload.test_payload_pb2'
_REQUESTS_PB2 = 'beta_grpc_plugin_test.requests.r.test_requests_pb2' _REQUESTS_PB2 = 'beta_grpc_plugin_test.requests.r.test_requests_pb2'

@ -155,7 +155,8 @@ class _SyncStream(object):
_TIMEOUT) _TIMEOUT)
for _ in response_stream: for _ in response_stream:
self._handle_response( self._handle_response(
self, time.time() - self._send_time_queue.get_nowait()) self,
time.time() - self._send_time_queue.get_nowait())
def stop(self): def stop(self):
self._is_streaming = False self._is_streaming = False

@ -72,8 +72,8 @@ class WorkerServer(services_pb2_grpc.WorkerServiceServicer):
server = test_common.test_server(max_workers=server_threads) server = test_common.test_server(max_workers=server_threads)
if config.server_type == control_pb2.ASYNC_SERVER: if config.server_type == control_pb2.ASYNC_SERVER:
servicer = benchmark_server.BenchmarkServer() servicer = benchmark_server.BenchmarkServer()
services_pb2_grpc.add_BenchmarkServiceServicer_to_server(servicer, services_pb2_grpc.add_BenchmarkServiceServicer_to_server(
server) servicer, server)
elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER: elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
resp_size = config.payload_config.bytebuf_params.resp_size resp_size = config.payload_config.bytebuf_params.resp_size
servicer = benchmark_server.GenericBenchmarkServer(resp_size) servicer = benchmark_server.GenericBenchmarkServer(resp_size)
@ -87,12 +87,12 @@ class WorkerServer(services_pb2_grpc.WorkerServiceServicer):
'grpc.testing.BenchmarkService', method_implementations) 'grpc.testing.BenchmarkService', method_implementations)
server.add_generic_rpc_handlers((handler,)) server.add_generic_rpc_handlers((handler,))
else: else:
raise Exception( raise Exception('Unsupported server type {}'.format(
'Unsupported server type {}'.format(config.server_type)) config.server_type))
if config.HasField('security_params'): # Use SSL if config.HasField('security_params'): # Use SSL
server_creds = grpc.ssl_server_credentials(( server_creds = grpc.ssl_server_credentials(
(resources.private_key(), resources.certificate_chain()),)) ((resources.private_key(), resources.certificate_chain()),))
port = server.add_secure_port('[::]:{}'.format(config.port), port = server.add_secure_port('[::]:{}'.format(config.port),
server_creds) server_creds)
else: else:
@ -156,8 +156,8 @@ class WorkerServer(services_pb2_grpc.WorkerServiceServicer):
else: else:
raise Exception('Async streaming client not supported') raise Exception('Async streaming client not supported')
else: else:
raise Exception( raise Exception('Unsupported client type {}'.format(
'Unsupported client type {}'.format(config.client_type)) config.client_type))
# In multi-channel tests, we split the load across all channels # In multi-channel tests, we split the load across all channels
load_factor = float(config.client_channels) load_factor = float(config.client_channels)

@ -33,7 +33,13 @@ _EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty'
_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', _SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman',
'Galilei') 'Galilei')
_EMPTY_EXTENSIONS_SYMBOL_NAME = 'grpc.testing.proto2.EmptyWithExtensions' _EMPTY_EXTENSIONS_SYMBOL_NAME = 'grpc.testing.proto2.EmptyWithExtensions'
_EMPTY_EXTENSIONS_NUMBERS = (124, 125, 126, 127, 128,) _EMPTY_EXTENSIONS_NUMBERS = (
124,
125,
126,
127,
128,
)
def _file_descriptor_to_proto(descriptor): def _file_descriptor_to_proto(descriptor):
@ -54,10 +60,12 @@ class ReflectionServicerTest(unittest.TestCase):
self._stub = reflection_pb2_grpc.ServerReflectionStub(channel) self._stub = reflection_pb2_grpc.ServerReflectionStub(channel)
def testFileByName(self): def testFileByName(self):
requests = (reflection_pb2.ServerReflectionRequest( requests = (
file_by_filename=_EMPTY_PROTO_FILE_NAME), reflection_pb2.ServerReflectionRequest(
reflection_pb2.ServerReflectionRequest( file_by_filename=_EMPTY_PROTO_FILE_NAME),
file_by_filename='i-donut-exist'),) reflection_pb2.ServerReflectionRequest(
file_by_filename='i-donut-exist'),
)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = ( expected_responses = (
reflection_pb2.ServerReflectionResponse( reflection_pb2.ServerReflectionResponse(
@ -70,14 +78,18 @@ class ReflectionServicerTest(unittest.TestCase):
error_response=reflection_pb2.ErrorResponse( error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0], error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)),) )),
)
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
def testFileBySymbol(self): def testFileBySymbol(self):
requests = (reflection_pb2.ServerReflectionRequest( requests = (
file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME reflection_pb2.ServerReflectionRequest(
), reflection_pb2.ServerReflectionRequest( file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME),
file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'),) reflection_pb2.ServerReflectionRequest(
file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'
),
)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = ( expected_responses = (
reflection_pb2.ServerReflectionResponse( reflection_pb2.ServerReflectionResponse(
@ -90,18 +102,23 @@ class ReflectionServicerTest(unittest.TestCase):
error_response=reflection_pb2.ErrorResponse( error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0], error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)),) )),
)
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
def testFileContainingExtension(self): def testFileContainingExtension(self):
requests = (reflection_pb2.ServerReflectionRequest( requests = (
file_containing_extension=reflection_pb2.ExtensionRequest( reflection_pb2.ServerReflectionRequest(
containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME, file_containing_extension=reflection_pb2.ExtensionRequest(
extension_number=125,), containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME,
), reflection_pb2.ServerReflectionRequest( extension_number=125,
file_containing_extension=reflection_pb2.ExtensionRequest( ),),
containing_type='i.donut.exist.co.uk.org.net.me.name.foo', reflection_pb2.ServerReflectionRequest(
extension_number=55,),),) file_containing_extension=reflection_pb2.ExtensionRequest(
containing_type='i.donut.exist.co.uk.org.net.me.name.foo',
extension_number=55,
),),
)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = ( expected_responses = (
reflection_pb2.ServerReflectionResponse( reflection_pb2.ServerReflectionResponse(
@ -114,14 +131,18 @@ class ReflectionServicerTest(unittest.TestCase):
error_response=reflection_pb2.ErrorResponse( error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0], error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)),) )),
)
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
def testExtensionNumbersOfType(self): def testExtensionNumbersOfType(self):
requests = (reflection_pb2.ServerReflectionRequest( requests = (
all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME reflection_pb2.ServerReflectionRequest(
), reflection_pb2.ServerReflectionRequest( all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME),
all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo'),) reflection_pb2.ServerReflectionRequest(
all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo'
),
)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = ( expected_responses = (
reflection_pb2.ServerReflectionResponse( reflection_pb2.ServerReflectionResponse(
@ -135,12 +156,12 @@ class ReflectionServicerTest(unittest.TestCase):
error_response=reflection_pb2.ErrorResponse( error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0], error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)),) )),
)
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
def testListServices(self): def testListServices(self):
requests = (reflection_pb2.ServerReflectionRequest( requests = (reflection_pb2.ServerReflectionRequest(list_services='',),)
list_services='',),)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = (reflection_pb2.ServerReflectionResponse( expected_responses = (reflection_pb2.ServerReflectionResponse(
valid_host='', valid_host='',

@ -102,8 +102,10 @@ def _get_channel(target, args):
root_certificates = None # will load default roots. root_certificates = None # will load default roots.
channel_credentials = grpc.ssl_channel_credentials( channel_credentials = grpc.ssl_channel_credentials(
root_certificates=root_certificates) root_certificates=root_certificates)
options = (('grpc.ssl_target_name_override', options = ((
args.server_host_override,),) 'grpc.ssl_target_name_override',
args.server_host_override,
),)
channel = grpc.secure_channel( channel = grpc.secure_channel(
target, channel_credentials, options=options) target, channel_credentials, options=options)
else: else:

@ -235,8 +235,8 @@ def run(scenario, channel):
elif scenario is Scenario.INFINITE_REQUEST_STREAM: elif scenario is Scenario.INFINITE_REQUEST_STREAM:
return _run_infinite_request_stream(stub) return _run_infinite_request_stream(stub)
except grpc.RpcError as rpc_error: except grpc.RpcError as rpc_error:
return Outcome(Outcome.Kind.RPC_ERROR, return Outcome(Outcome.Kind.RPC_ERROR, rpc_error.code(),
rpc_error.code(), rpc_error.details()) rpc_error.details())
_IMPLEMENTATIONS = { _IMPLEMENTATIONS = {
@ -256,5 +256,5 @@ def run(scenario, channel):
try: try:
return _IMPLEMENTATIONS[scenario](stub) return _IMPLEMENTATIONS[scenario](stub)
except grpc.RpcError as rpc_error: except grpc.RpcError as rpc_error:
return Outcome(Outcome.Kind.RPC_ERROR, return Outcome(Outcome.Kind.RPC_ERROR, rpc_error.code(),
rpc_error.code(), rpc_error.details()) rpc_error.details())

@ -193,8 +193,10 @@ class ClientTest(unittest.TestCase):
rpc.take_request() rpc.take_request()
rpc.take_request() rpc.take_request()
rpc.requests_closed() rpc.requests_closed()
rpc.send_initial_metadata(( rpc.send_initial_metadata(((
('my_metadata_key', 'My Metadata Value!',),)) 'my_metadata_key',
'My Metadata Value!',
),))
for rpc in rpcs[:-1]: for rpc in rpcs[:-1]:
rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (),
grpc.StatusCode.OK, '') grpc.StatusCode.OK, '')

@ -41,8 +41,10 @@ class FirstServiceServicer(services_pb2_grpc.FirstServiceServicer):
yield services_pb2.Strange() yield services_pb2.Strange()
def StreUn(self, request_iterator, context): def StreUn(self, request_iterator, context):
context.send_initial_metadata(( context.send_initial_metadata(((
('server_application_metadata_key', 'Hi there!',),)) 'server_application_metadata_key',
'Hi there!',
),))
for request in request_iterator: for request in request_iterator:
if request != _application_common.STREAM_UNARY_REQUEST: if request != _application_common.STREAM_UNARY_REQUEST:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT) context.set_code(grpc.StatusCode.INVALID_ARGUMENT)

@ -110,14 +110,19 @@ class FirstServiceServicerTest(unittest.TestCase):
second_termination = rpc.termination() second_termination = rpc.termination()
third_termination = rpc.termination() third_termination = rpc.termination()
for later_initial_metadata in (second_initial_metadata, for later_initial_metadata in (
third_initial_metadata,): second_initial_metadata,
third_initial_metadata,
):
self.assertEqual(first_initial_metadata, later_initial_metadata) self.assertEqual(first_initial_metadata, later_initial_metadata)
response = first_termination[0] response = first_termination[0]
terminal_metadata = first_termination[1] terminal_metadata = first_termination[1]
code = first_termination[2] code = first_termination[2]
details = first_termination[3] details = first_termination[3]
for later_termination in (second_termination, third_termination,): for later_termination in (
second_termination,
third_termination,
):
self.assertEqual(response, later_termination[0]) self.assertEqual(response, later_termination[0])
self.assertEqual(terminal_metadata, later_termination[1]) self.assertEqual(terminal_metadata, later_termination[1])
self.assertIs(code, later_termination[2]) self.assertIs(code, later_termination[2])

@ -105,8 +105,8 @@ class TimeTest(object):
test_event.set, _QUANTUM * (2 + random.random())) test_event.set, _QUANTUM * (2 + random.random()))
for _ in range(_MANY): for _ in range(_MANY):
background_noise_futures.append( background_noise_futures.append(
self._time.call_in(threading.Event().set, _QUANTUM * 1000 * self._time.call_in(threading.Event().set,
random.random())) _QUANTUM * 1000 * random.random()))
self._time.sleep_for(_QUANTUM) self._time.sleep_for(_QUANTUM)
cancelled = set() cancelled = set()
for test_event, test_future in possibly_cancelled_futures.items(): for test_event, test_future in possibly_cancelled_futures.items():

@ -26,28 +26,57 @@ class AllTest(unittest.TestCase):
def testAll(self): def testAll(self):
expected_grpc_code_elements = ( expected_grpc_code_elements = (
'FutureTimeoutError', 'FutureCancelledError', 'Future', 'FutureTimeoutError',
'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', 'FutureCancelledError',
'Call', 'ChannelCredentials', 'CallCredentials', 'Future',
'AuthMetadataContext', 'AuthMetadataPluginCallback', 'ChannelConnectivity',
'AuthMetadataPlugin', 'ServerCertificateConfiguration', 'StatusCode',
'ServerCredentials', 'UnaryUnaryMultiCallable', 'RpcError',
'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable', 'RpcContext',
'StreamStreamMultiCallable', 'UnaryUnaryClientInterceptor', 'Call',
'UnaryStreamClientInterceptor', 'StreamUnaryClientInterceptor', 'ChannelCredentials',
'StreamStreamClientInterceptor', 'Channel', 'ServicerContext', 'CallCredentials',
'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler', 'AuthMetadataContext',
'ServiceRpcHandler', 'Server', 'ServerInterceptor', 'AuthMetadataPluginCallback',
'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler', 'AuthMetadataPlugin',
'stream_unary_rpc_method_handler', 'ClientCallDetails', 'ServerCertificateConfiguration',
'ServerCredentials',
'UnaryUnaryMultiCallable',
'UnaryStreamMultiCallable',
'StreamUnaryMultiCallable',
'StreamStreamMultiCallable',
'UnaryUnaryClientInterceptor',
'UnaryStreamClientInterceptor',
'StreamUnaryClientInterceptor',
'StreamStreamClientInterceptor',
'Channel',
'ServicerContext',
'RpcMethodHandler',
'HandlerCallDetails',
'GenericRpcHandler',
'ServiceRpcHandler',
'Server',
'ServerInterceptor',
'unary_unary_rpc_method_handler',
'unary_stream_rpc_method_handler',
'stream_unary_rpc_method_handler',
'ClientCallDetails',
'stream_stream_rpc_method_handler', 'stream_stream_rpc_method_handler',
'method_handlers_generic_handler', 'ssl_channel_credentials', 'method_handlers_generic_handler',
'metadata_call_credentials', 'access_token_call_credentials', 'ssl_channel_credentials',
'composite_call_credentials', 'composite_channel_credentials', 'metadata_call_credentials',
'ssl_server_credentials', 'ssl_server_certificate_configuration', 'access_token_call_credentials',
'dynamic_ssl_server_credentials', 'channel_ready_future', 'composite_call_credentials',
'insecure_channel', 'secure_channel', 'intercept_channel', 'composite_channel_credentials',
'server',) 'ssl_server_credentials',
'ssl_server_certificate_configuration',
'dynamic_ssl_server_credentials',
'channel_ready_future',
'insecure_channel',
'secure_channel',
'intercept_channel',
'server',
)
six.assertCountEqual(self, expected_grpc_code_elements, six.assertCountEqual(self, expected_grpc_code_elements,
_from_grpc_import_star.GRPC_ELEMENTS) _from_grpc_import_star.GRPC_ELEMENTS)
@ -56,12 +85,13 @@ class AllTest(unittest.TestCase):
class ChannelConnectivityTest(unittest.TestCase): class ChannelConnectivityTest(unittest.TestCase):
def testChannelConnectivity(self): def testChannelConnectivity(self):
self.assertSequenceEqual( self.assertSequenceEqual((
(grpc.ChannelConnectivity.IDLE, grpc.ChannelConnectivity.CONNECTING, grpc.ChannelConnectivity.IDLE,
grpc.ChannelConnectivity.READY, grpc.ChannelConnectivity.CONNECTING,
grpc.ChannelConnectivity.TRANSIENT_FAILURE, grpc.ChannelConnectivity.READY,
grpc.ChannelConnectivity.SHUTDOWN,), grpc.ChannelConnectivity.TRANSIENT_FAILURE,
tuple(grpc.ChannelConnectivity)) grpc.ChannelConnectivity.SHUTDOWN,
), tuple(grpc.ChannelConnectivity))
class ChannelTest(unittest.TestCase): class ChannelTest(unittest.TestCase):

@ -29,8 +29,12 @@ _RESPONSE = b'\x00\x00\x00'
_UNARY_UNARY = '/test/UnaryUnary' _UNARY_UNARY = '/test/UnaryUnary'
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
_CLIENT_IDS = (b'*.test.google.fr', b'waterzooi.test.google.be', _CLIENT_IDS = (
b'*.test.youtube.com', b'192.168.1.3',) b'*.test.google.fr',
b'waterzooi.test.google.be',
b'*.test.youtube.com',
b'192.168.1.3',
)
_ID = 'id' _ID = 'id'
_ID_KEY = 'id_key' _ID_KEY = 'id_key'
_AUTH_CTX = 'auth_ctx' _AUTH_CTX = 'auth_ctx'
@ -39,7 +43,10 @@ _PRIVATE_KEY = resources.private_key()
_CERTIFICATE_CHAIN = resources.certificate_chain() _CERTIFICATE_CHAIN = resources.certificate_chain()
_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() _TEST_ROOT_CERTIFICATES = resources.test_root_certificates()
_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) _SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),)
_PROPERTY_OPTIONS = (('grpc.ssl_target_name_override', _SERVER_HOST_OVERRIDE,),) _PROPERTY_OPTIONS = ((
'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,
),)
def handle_unary_unary(request, servicer_context): def handle_unary_unary(request, servicer_context):

@ -24,8 +24,13 @@ class TestPointerWrapper(object):
return 123456 return 123456
TEST_CHANNEL_ARGS = (('arg1', b'bytes_val'), ('arg2', 'str_val'), ('arg3', 1), TEST_CHANNEL_ARGS = (
(b'arg4', 'str_val'), ('arg6', TestPointerWrapper()),) ('arg1', b'bytes_val'),
('arg2', 'str_val'),
('arg3', 1),
(b'arg4', 'str_val'),
('arg6', TestPointerWrapper()),
)
class ChannelArgsTest(unittest.TestCase): class ChannelArgsTest(unittest.TestCase):

@ -26,16 +26,16 @@ _STREAM_STREAM = '/test/StreamStream'
def handle_unary(request, servicer_context): def handle_unary(request, servicer_context):
servicer_context.send_initial_metadata( servicer_context.send_initial_metadata([('grpc-internal-encoding-request',
[('grpc-internal-encoding-request', 'gzip')]) 'gzip')])
return request return request
def handle_stream(request_iterator, servicer_context): def handle_stream(request_iterator, servicer_context):
# TODO(issue:#6891) We should be able to remove this loop, # TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield # and replace with return; yield
servicer_context.send_initial_metadata( servicer_context.send_initial_metadata([('grpc-internal-encoding-request',
[('grpc-internal-encoding-request', 'gzip')]) 'gzip')])
for request in request_iterator: for request in request_iterator:
yield request yield request

@ -26,8 +26,8 @@ class CredentialsTest(unittest.TestCase):
third = grpc.access_token_call_credentials('ghi') third = grpc.access_token_call_credentials('ghi')
first_and_second = grpc.composite_call_credentials(first, second) first_and_second = grpc.composite_call_credentials(first, second)
first_second_and_third = grpc.composite_call_credentials(first, second, first_second_and_third = grpc.composite_call_credentials(
third) first, second, third)
self.assertIsInstance(first_and_second, grpc.CallCredentials) self.assertIsInstance(first_and_second, grpc.CallCredentials)
self.assertIsInstance(first_second_and_third, grpc.CallCredentials) self.assertIsInstance(first_second_and_third, grpc.CallCredentials)

@ -81,7 +81,8 @@ class _Handler(object):
cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS), cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation( cygrpc.SendStatusFromServerOperation(
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),) _EMPTY_FLAGS),
)
self._call.start_server_batch(operations, self._call.start_server_batch(operations,
_SERVER_COMPLETE_CALL_TAG) _SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll() self._completion_queue.poll()
@ -151,8 +152,12 @@ class CancelManyCallsTest(unittest.TestCase):
state = _State() state = _State()
server_thread_args = (state, server, server_completion_queue, server_thread_args = (
server_thread_pool,) state,
server,
server_completion_queue,
server_thread_pool,
)
server_thread = threading.Thread(target=_serve, args=server_thread_args) server_thread = threading.Thread(target=_serve, args=server_thread_args)
server_thread.start() server_thread.start()
@ -176,7 +181,8 @@ class CancelManyCallsTest(unittest.TestCase):
cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),) cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
)
tag = 'client_complete_call_{0:04d}_tag'.format(index) tag = 'client_complete_call_{0:04d}_tag'.format(index)
client_call.start_client_batch(operations, tag) client_call.start_client_batch(operations, tag)
client_due.add(tag) client_due.add(tag)
@ -193,8 +199,8 @@ class CancelManyCallsTest(unittest.TestCase):
state.condition.notify_all() state.condition.notify_all()
break break
client_driver.events(test_constants.RPC_CONCURRENCY * client_driver.events(
_SUCCESS_CALL_FRACTION) test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
with client_condition: with client_condition:
for client_call in client_calls: for client_call in client_calls:
client_call.cancel() client_call.cancel()

@ -56,7 +56,10 @@ class ChannelTest(unittest.TestCase):
def test_single_channel_lonely_connectivity(self): def test_single_channel_lonely_connectivity(self):
channel, completion_queue = _channel_and_completion_queue() channel, completion_queue = _channel_and_completion_queue()
_in_parallel(_connectivity_loop, (channel, completion_queue,)) _in_parallel(_connectivity_loop, (
channel,
completion_queue,
))
completion_queue.shutdown() completion_queue.shutdown()
def test_multiple_channels_lonely_connectivity(self): def test_multiple_channels_lonely_connectivity(self):

@ -23,14 +23,20 @@ RPC_COUNT = 4000
INFINITE_FUTURE = cygrpc.Timespec(float('+inf')) INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
EMPTY_FLAGS = 0 EMPTY_FLAGS = 0
INVOCATION_METADATA = (('client-md-key', 'client-md-key'), INVOCATION_METADATA = (
('client-md-key-bin', b'\x00\x01' * 3000),) ('client-md-key', 'client-md-key'),
('client-md-key-bin', b'\x00\x01' * 3000),
)
INITIAL_METADATA = (('server-initial-md-key', 'server-initial-md-value'), INITIAL_METADATA = (
('server-initial-md-key-bin', b'\x00\x02' * 3000),) ('server-initial-md-key', 'server-initial-md-value'),
('server-initial-md-key-bin', b'\x00\x02' * 3000),
)
TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value'), TRAILING_METADATA = (
('server-trailing-md-key-bin', b'\x00\x03' * 3000),) ('server-trailing-md-key', 'server-trailing-md-value'),
('server-trailing-md-key-bin', b'\x00\x03' * 3000),
)
class QueueDriver(object): class QueueDriver(object):
@ -76,7 +82,10 @@ def execute_many_times(behavior):
class OperationResult( class OperationResult(
collections.namedtuple('OperationResult', ( collections.namedtuple('OperationResult', (
'start_batch_result', 'completion_type', 'success',))): 'start_batch_result',
'completion_type',
'success',
))):
pass pass

@ -101,28 +101,29 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_event = self.client_driver.event_with_tag( client_complete_rpc_event = self.client_driver.event_with_tag(
client_complete_rpc_tag) client_complete_rpc_tag)
return (_common.OperationResult( return (
server_request_call_start_batch_result, _common.OperationResult(server_request_call_start_batch_result,
server_request_call_event.completion_type, server_request_call_event.completion_type,
server_request_call_event.success), _common.OperationResult( server_request_call_event.success),
_common.OperationResult(
client_receive_initial_metadata_start_batch_result, client_receive_initial_metadata_start_batch_result,
client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.completion_type,
client_receive_initial_metadata_event.success), client_receive_initial_metadata_event.success),
_common.OperationResult( _common.OperationResult(client_complete_rpc_start_batch_result,
client_complete_rpc_start_batch_result, client_complete_rpc_event.completion_type,
client_complete_rpc_event.completion_type, client_complete_rpc_event.success),
client_complete_rpc_event.success), _common.OperationResult( _common.OperationResult(
server_send_initial_metadata_start_batch_result, server_send_initial_metadata_start_batch_result,
server_send_initial_metadata_event.completion_type, server_send_initial_metadata_event.completion_type,
server_send_initial_metadata_event.success), server_send_initial_metadata_event.success),
_common.OperationResult( _common.OperationResult(server_complete_rpc_start_batch_result,
server_complete_rpc_start_batch_result, server_complete_rpc_event.completion_type,
server_complete_rpc_event.completion_type, server_complete_rpc_event.success),
server_complete_rpc_event.success),) )
def test_rpcs(self): def test_rpcs(self):
expecteds = [(_common.SUCCESSFUL_OPERATION_RESULT,) * expecteds = [(
5] * _common.RPC_COUNT _common.SUCCESSFUL_OPERATION_RESULT,) * 5] * _common.RPC_COUNT
actuallys = _common.execute_many_times(self._do_rpcs) actuallys = _common.execute_many_times(self._do_rpcs)
self.assertSequenceEqual(expecteds, actuallys) self.assertSequenceEqual(expecteds, actuallys)

@ -92,28 +92,29 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_event = self.client_driver.event_with_tag( client_complete_rpc_event = self.client_driver.event_with_tag(
client_complete_rpc_tag) client_complete_rpc_tag)
return (_common.OperationResult( return (
server_request_call_start_batch_result, _common.OperationResult(server_request_call_start_batch_result,
server_request_call_event.completion_type, server_request_call_event.completion_type,
server_request_call_event.success), _common.OperationResult( server_request_call_event.success),
_common.OperationResult(
client_receive_initial_metadata_start_batch_result, client_receive_initial_metadata_start_batch_result,
client_receive_initial_metadata_event.completion_type, client_receive_initial_metadata_event.completion_type,
client_receive_initial_metadata_event.success), client_receive_initial_metadata_event.success),
_common.OperationResult( _common.OperationResult(client_complete_rpc_start_batch_result,
client_complete_rpc_start_batch_result, client_complete_rpc_event.completion_type,
client_complete_rpc_event.completion_type, client_complete_rpc_event.success),
client_complete_rpc_event.success), _common.OperationResult( _common.OperationResult(
server_send_initial_metadata_start_batch_result, server_send_initial_metadata_start_batch_result,
server_send_initial_metadata_event.completion_type, server_send_initial_metadata_event.completion_type,
server_send_initial_metadata_event.success), server_send_initial_metadata_event.success),
_common.OperationResult( _common.OperationResult(server_complete_rpc_start_batch_result,
server_complete_rpc_start_batch_result, server_complete_rpc_event.completion_type,
server_complete_rpc_event.completion_type, server_complete_rpc_event.success),
server_complete_rpc_event.success),) )
def test_rpcs(self): def test_rpcs(self):
expecteds = [(_common.SUCCESSFUL_OPERATION_RESULT,) * expecteds = [(
5] * _common.RPC_COUNT _common.SUCCESSFUL_OPERATION_RESULT,) * 5] * _common.RPC_COUNT
actuallys = _common.execute_many_times(self._do_rpcs) actuallys = _common.execute_many_times(self._do_rpcs)
self.assertSequenceEqual(expecteds, actuallys) self.assertSequenceEqual(expecteds, actuallys)

@ -137,9 +137,12 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_send_first_message_tag = 'server_send_first_message_tag' server_send_first_message_tag = 'server_send_first_message_tag'
server_send_second_message_tag = 'server_send_second_message_tag' server_send_second_message_tag = 'server_send_second_message_tag'
server_complete_rpc_tag = 'server_complete_rpc_tag' server_complete_rpc_tag = 'server_complete_rpc_tag'
server_call_due = set( server_call_due = set((
(server_send_initial_metadata_tag, server_send_first_message_tag, server_send_initial_metadata_tag,
server_send_second_message_tag, server_complete_rpc_tag,)) server_send_first_message_tag,
server_send_second_message_tag,
server_complete_rpc_tag,
))
server_call_completion_queue = cygrpc.CompletionQueue() server_call_completion_queue = cygrpc.CompletionQueue()
server_call_driver = _QueueDriver(server_call_condition, server_call_driver = _QueueDriver(server_call_condition,
server_call_completion_queue, server_call_completion_queue,

@ -29,8 +29,10 @@ _EMPTY_FLAGS = 0
def _metadata_plugin(context, callback): def _metadata_plugin(context, callback):
callback(((_CALL_CREDENTIALS_METADATA_KEY, callback(((
_CALL_CREDENTIALS_METADATA_VALUE,),), cygrpc.StatusCode.ok, b'') _CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE,
),), cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase): class TypeSmokeTest(unittest.TestCase):
@ -113,13 +115,12 @@ class ServerClientMixin(object):
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override, cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
host_override) host_override)
]) ])
self.client_channel = cygrpc.Channel( self.client_channel = cygrpc.Channel('localhost:{}'.format(
'localhost:{}'.format(self.port).encode(), self.port).encode(), client_channel_arguments,
client_channel_arguments, client_credentials) client_credentials)
else: else:
self.client_channel = cygrpc.Channel( self.client_channel = cygrpc.Channel('localhost:{}'.format(
'localhost:{}'.format(self.port).encode(), self.port).encode(), cygrpc.ChannelArgs([]))
cygrpc.ChannelArgs([]))
if host_override: if host_override:
self.host_argument = None # default host self.host_argument = None # default host
self.expected_host = host_override self.expected_host = host_override
@ -152,8 +153,8 @@ class ServerClientMixin(object):
self.assertTrue(event.success) self.assertTrue(event.success)
self.assertIs(tag, event.tag) self.assertIs(tag, event.tag)
except Exception as error: except Exception as error:
raise Exception( raise Exception("Error in '{}': {}".format(
"Error in '{}': {}".format(description, error.message)) description, error.message))
return event return event
return test_utilities.SimpleFuture(performer) return test_utilities.SimpleFuture(performer)
@ -189,8 +190,15 @@ class ServerClientMixin(object):
None, 0, self.client_completion_queue, METHOD, self.host_argument, None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline) cygrpc_deadline)
client_initial_metadata = ( client_initial_metadata = (
(CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE,), (
(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE,),) CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE,
),
(
CLIENT_METADATA_BIN_KEY,
CLIENT_METADATA_BIN_VALUE,
),
)
client_start_batch_result = client_call.start_client_batch([ client_start_batch_result = client_call.start_client_batch([
cygrpc.SendInitialMetadataOperation(client_initial_metadata, cygrpc.SendInitialMetadataOperation(client_initial_metadata,
_EMPTY_FLAGS), _EMPTY_FLAGS),
@ -220,14 +228,18 @@ class ServerClientMixin(object):
server_call_tag = object() server_call_tag = object()
server_call = request_event.call server_call = request_event.call
server_initial_metadata = ( server_initial_metadata = ((
(SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE,),) SERVER_INITIAL_METADATA_KEY,
server_trailing_metadata = ( SERVER_INITIAL_METADATA_VALUE,
(SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE,),) ),)
server_trailing_metadata = ((
SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE,
),)
server_start_batch_result = server_call.start_server_batch([ server_start_batch_result = server_call.start_server_batch([
cygrpc.SendInitialMetadataOperation( cygrpc.SendInitialMetadataOperation(server_initial_metadata,
server_initial_metadata, _EMPTY_FLAGS),
_EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS), cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
cygrpc.SendStatusFromServerOperation( cygrpc.SendStatusFromServerOperation(
@ -377,10 +389,11 @@ class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin): class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
def setUp(self): def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(None, [ server_credentials = cygrpc.server_credentials_ssl(
cygrpc.SslPemKeyCertPair(resources.private_key(), None, [
resources.certificate_chain()) cygrpc.SslPemKeyCertPair(resources.private_key(),
], False) resources.certificate_chain())
], False)
client_credentials = cygrpc.SSLChannelCredentials( client_credentials = cygrpc.SSLChannelCredentials(
resources.test_root_certificates(), None, None) resources.test_root_certificates(), None, None)
self.setUpMixin(server_credentials, client_credentials, self.setUpMixin(server_credentials, client_credentials,

@ -106,13 +106,13 @@ class EmptyMessageTest(unittest.TestCase):
list(response_iterator)) list(response_iterator))
def testStreamUnary(self): def testStreamUnary(self):
response = self._channel.stream_unary(_STREAM_UNARY)( response = self._channel.stream_unary(_STREAM_UNARY)(iter(
iter([_REQUEST] * test_constants.STREAM_LENGTH)) [_REQUEST] * test_constants.STREAM_LENGTH))
self.assertEqual(_RESPONSE, response) self.assertEqual(_RESPONSE, response)
def testStreamStream(self): def testStreamStream(self):
response_iterator = self._channel.stream_stream(_STREAM_STREAM)( response_iterator = self._channel.stream_stream(_STREAM_STREAM)(iter(
iter([_REQUEST] * test_constants.STREAM_LENGTH)) [_REQUEST] * test_constants.STREAM_LENGTH))
self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH, self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH,
list(response_iterator)) list(response_iterator))

@ -65,7 +65,10 @@ class _Handler(object):
def handle_unary_unary(self, request, servicer_context): def handle_unary_unary(self, request, servicer_context):
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
return request return request
def handle_unary_stream(self, request, servicer_context): def handle_unary_stream(self, request, servicer_context):
@ -74,7 +77,10 @@ class _Handler(object):
yield request yield request
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
def handle_stream_unary(self, request_iterator, servicer_context): def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None: if servicer_context is not None:
@ -86,13 +92,19 @@ class _Handler(object):
response_elements.append(request) response_elements.append(request)
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
return b''.join(response_elements) return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context): def handle_stream_stream(self, request_iterator, servicer_context):
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
for request in request_iterator: for request in request_iterator:
self._control.control() self._control.control()
yield request yield request
@ -162,9 +174,10 @@ def _stream_stream_multi_callable(channel):
class _ClientCallDetails( class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails', collections.namedtuple(
('method', 'timeout', 'metadata', '_ClientCallDetails',
'credentials')), grpc.ClientCallDetails): ('method', 'timeout', 'metadata', 'credentials')),
grpc.ClientCallDetails):
pass pass
@ -262,7 +275,10 @@ def _append_request_header_interceptor(header, value):
metadata = [] metadata = []
if client_call_details.metadata: if client_call_details.metadata:
metadata = list(client_call_details.metadata) metadata = list(client_call_details.metadata)
metadata.append((header, value,)) metadata.append((
header,
value,
))
client_call_details = _ClientCallDetails( client_call_details = _ClientCallDetails(
client_call_details.method, client_call_details.timeout, metadata, client_call_details.method, client_call_details.timeout, metadata,
client_call_details.credentials) client_call_details.credentials)
@ -306,9 +322,11 @@ class InterceptorTest(unittest.TestCase):
self._server = grpc.server( self._server = grpc.server(
self._server_pool, self._server_pool,
options=(('grpc.so_reuseport', 0),), options=(('grpc.so_reuseport', 0),),
interceptors=(_LoggingInterceptor('s1', self._record), interceptors=(
conditional_interceptor, _LoggingInterceptor('s1', self._record),
_LoggingInterceptor('s2', self._record),)) conditional_interceptor,
_LoggingInterceptor('s2', self._record),
))
port = self._server.add_insecure_port('[::]:0') port = self._server.add_insecure_port('[::]:0')
self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
self._server.start() self._server.start()
@ -333,8 +351,8 @@ class InterceptorTest(unittest.TestCase):
interceptor = _wrap_request_iterator_stream_interceptor(triple) interceptor = _wrap_request_iterator_stream_interceptor(triple)
channel = grpc.intercept_channel(self._channel, interceptor) channel = grpc.intercept_channel(self._channel, interceptor)
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
multi_callable = _stream_stream_multi_callable(channel) multi_callable = _stream_stream_multi_callable(channel)
response_iterator = multi_callable( response_iterator = multi_callable(
@ -365,8 +383,8 @@ class InterceptorTest(unittest.TestCase):
multi_callable = _unary_unary_multi_callable(defective_channel) multi_callable = _unary_unary_multi_callable(defective_channel)
call_future = multi_callable.future( call_future = multi_callable.future(
request, request,
metadata=( metadata=(('test',
('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) 'InterceptedUnaryRequestBlockingUnaryResponse'),))
self.assertIsNotNone(call_future.exception()) self.assertIsNotNone(call_future.exception())
self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)
@ -374,12 +392,14 @@ class InterceptorTest(unittest.TestCase):
def testInterceptedHeaderManipulationWithServerSideVerification(self): def testInterceptedHeaderManipulationWithServerSideVerification(self):
request = b'\x07\x08' request = b'\x07\x08'
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _append_request_header_interceptor('secret', '42')) _append_request_header_interceptor(
channel = grpc.intercept_channel( 'secret', '42'))
channel, channel = grpc.intercept_channel(channel,
_LoggingInterceptor('c1', self._record), _LoggingInterceptor(
_LoggingInterceptor('c2', self._record)) 'c1', self._record),
_LoggingInterceptor(
'c2', self._record))
self._record[:] = [] self._record[:] = []
@ -401,16 +421,17 @@ class InterceptorTest(unittest.TestCase):
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _unary_unary_multi_callable(channel) multi_callable = _unary_unary_multi_callable(channel)
multi_callable( multi_callable(
request, request,
metadata=( metadata=(('test',
('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),)) 'InterceptedUnaryRequestBlockingUnaryResponse'),))
self.assertSequenceEqual(self._record, [ self.assertSequenceEqual(self._record, [
'c1:intercept_unary_unary', 'c2:intercept_unary_unary', 'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
@ -420,10 +441,11 @@ class InterceptorTest(unittest.TestCase):
def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
request = b'\x07\x08' request = b'\x07\x08'
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
self._record[:] = [] self._record[:] = []
@ -443,10 +465,11 @@ class InterceptorTest(unittest.TestCase):
request = b'\x07\x08' request = b'\x07\x08'
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _unary_unary_multi_callable(channel) multi_callable = _unary_unary_multi_callable(channel)
response_future = multi_callable.future( response_future = multi_callable.future(
@ -463,10 +486,11 @@ class InterceptorTest(unittest.TestCase):
request = b'\x37\x58' request = b'\x37\x58'
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _unary_stream_multi_callable(channel) multi_callable = _unary_stream_multi_callable(channel)
response_iterator = multi_callable( response_iterator = multi_callable(
@ -480,21 +504,22 @@ class InterceptorTest(unittest.TestCase):
]) ])
def testInterceptedStreamRequestBlockingUnaryResponse(self): def testInterceptedStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _stream_unary_multi_callable(channel) multi_callable = _stream_unary_multi_callable(channel)
multi_callable( multi_callable(
request_iterator, request_iterator,
metadata=( metadata=(('test',
('test', 'InterceptedStreamRequestBlockingUnaryResponse'),)) 'InterceptedStreamRequestBlockingUnaryResponse'),))
self.assertSequenceEqual(self._record, [ self.assertSequenceEqual(self._record, [
'c1:intercept_stream_unary', 'c2:intercept_stream_unary', 'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
@ -502,15 +527,16 @@ class InterceptorTest(unittest.TestCase):
]) ])
def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _stream_unary_multi_callable(channel) multi_callable = _stream_unary_multi_callable(channel)
multi_callable.with_call( multi_callable.with_call(
@ -525,15 +551,16 @@ class InterceptorTest(unittest.TestCase):
]) ])
def testInterceptedStreamRequestFutureUnaryResponse(self): def testInterceptedStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _stream_unary_multi_callable(channel) multi_callable = _stream_unary_multi_callable(channel)
response_future = multi_callable.future( response_future = multi_callable.future(
@ -547,15 +574,16 @@ class InterceptorTest(unittest.TestCase):
]) ])
def testInterceptedStreamRequestStreamResponse(self): def testInterceptedStreamRequestStreamResponse(self):
requests = tuple(b'\x77\x58' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
self._record[:] = [] self._record[:] = []
channel = grpc.intercept_channel( channel = grpc.intercept_channel(self._channel,
self._channel, _LoggingInterceptor(
_LoggingInterceptor('c1', self._record), 'c1', self._record),
_LoggingInterceptor('c2', self._record)) _LoggingInterceptor(
'c2', self._record))
multi_callable = _stream_stream_multi_callable(channel) multi_callable = _stream_stream_multi_callable(channel)
response_iterator = multi_callable( response_iterator = multi_callable(

@ -106,8 +106,8 @@ class InvalidMetadataTest(unittest.TestCase):
self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL) self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
def testStreamRequestBlockingUnaryResponse(self): def testStreamRequestBlockingUnaryResponse(self):
request_iterator = (b'\x07\x08' request_iterator = (
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),) metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata expected_error_details = "metadata was invalid: %s" % metadata
with self.assertRaises(ValueError) as exception_context: with self.assertRaises(ValueError) as exception_context:
@ -115,8 +115,8 @@ class InvalidMetadataTest(unittest.TestCase):
self.assertIn(expected_error_details, str(exception_context.exception)) self.assertIn(expected_error_details, str(exception_context.exception))
def testStreamRequestBlockingUnaryResponseWithCall(self): def testStreamRequestBlockingUnaryResponseWithCall(self):
request_iterator = (b'\x07\x08' request_iterator = (
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),) metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),)
expected_error_details = "metadata was invalid: %s" % metadata expected_error_details = "metadata was invalid: %s" % metadata
multi_callable = _stream_unary_multi_callable(self._channel) multi_callable = _stream_unary_multi_callable(self._channel)
@ -125,8 +125,8 @@ class InvalidMetadataTest(unittest.TestCase):
self.assertIn(expected_error_details, str(exception_context.exception)) self.assertIn(expected_error_details, str(exception_context.exception))
def testStreamRequestFutureUnaryResponse(self): def testStreamRequestFutureUnaryResponse(self):
request_iterator = (b'\x07\x08' request_iterator = (
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata expected_error_details = "metadata was invalid: %s" % metadata
response_future = self._stream_unary.future( response_future = self._stream_unary.future(
@ -141,8 +141,8 @@ class InvalidMetadataTest(unittest.TestCase):
self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL) self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
def testStreamRequestStreamResponse(self): def testStreamRequestStreamResponse(self):
request_iterator = (b'\x07\x08' request_iterator = (
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata expected_error_details = "metadata was invalid: %s" % metadata
response_iterator = self._stream_stream( response_iterator = self._stream_stream(

@ -62,7 +62,10 @@ class _Handler(object):
def handle_unary_unary(self, request, servicer_context): def handle_unary_unary(self, request, servicer_context):
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
return request return request
def handle_unary_stream(self, request, servicer_context): def handle_unary_stream(self, request, servicer_context):
@ -71,7 +74,10 @@ class _Handler(object):
yield request yield request
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
def handle_stream_unary(self, request_iterator, servicer_context): def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None: if servicer_context is not None:
@ -83,13 +89,19 @@ class _Handler(object):
response_elements.append(request) response_elements.append(request)
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
return b''.join(response_elements) return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context): def handle_stream_stream(self, request_iterator, servicer_context):
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
for request in request_iterator: for request in request_iterator:
self._control.control() self._control.control()
yield request yield request
@ -208,8 +220,8 @@ class InvocationDefectsTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError):
response = multi_callable( response = multi_callable(
requests, requests,
metadata=( metadata=(('test',
('test', 'IterableStreamRequestBlockingUnaryResponse'),)) 'IterableStreamRequestBlockingUnaryResponse'),))
def testIterableStreamRequestFutureUnaryResponse(self): def testIterableStreamRequestFutureUnaryResponse(self):
requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)] requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]

@ -36,16 +36,16 @@ _UNARY_STREAM = 'UnaryStream'
_STREAM_UNARY = 'StreamUnary' _STREAM_UNARY = 'StreamUnary'
_STREAM_STREAM = 'StreamStream' _STREAM_STREAM = 'StreamStream'
_CLIENT_METADATA = (('client-md-key', 'client-md-key'), _CLIENT_METADATA = (('client-md-key', 'client-md-key'), ('client-md-key-bin',
('client-md-key-bin', b'\x00\x01')) b'\x00\x01'))
_SERVER_INITIAL_METADATA = ( _SERVER_INITIAL_METADATA = (('server-initial-md-key',
('server-initial-md-key', 'server-initial-md-value'), 'server-initial-md-value'),
('server-initial-md-key-bin', b'\x00\x02')) ('server-initial-md-key-bin', b'\x00\x02'))
_SERVER_TRAILING_METADATA = ( _SERVER_TRAILING_METADATA = (('server-trailing-md-key',
('server-trailing-md-key', 'server-trailing-md-value'), 'server-trailing-md-value'),
('server-trailing-md-key-bin', b'\x00\x03')) ('server-trailing-md-key-bin', b'\x00\x03'))
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND _NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = 'Test details!' _DETAILS = 'Test details!'
@ -193,17 +193,33 @@ class MetadataCodeDetailsTest(unittest.TestCase):
channel = grpc.insecure_channel('localhost:{}'.format(port)) channel = grpc.insecure_channel('localhost:{}'.format(port))
self._unary_unary = channel.unary_unary( self._unary_unary = channel.unary_unary(
'/'.join(('', _SERVICE, _UNARY_UNARY,)), '/'.join((
'',
_SERVICE,
_UNARY_UNARY,
)),
request_serializer=_REQUEST_SERIALIZER, request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,) response_deserializer=_RESPONSE_DESERIALIZER,
self._unary_stream = channel.unary_stream( )
'/'.join(('', _SERVICE, _UNARY_STREAM,)),) self._unary_stream = channel.unary_stream('/'.join((
self._stream_unary = channel.stream_unary( '',
'/'.join(('', _SERVICE, _STREAM_UNARY,)),) _SERVICE,
_UNARY_STREAM,
)),)
self._stream_unary = channel.stream_unary('/'.join((
'',
_SERVICE,
_STREAM_UNARY,
)),)
self._stream_stream = channel.stream_stream( self._stream_stream = channel.stream_stream(
'/'.join(('', _SERVICE, _STREAM_STREAM,)), '/'.join((
'',
_SERVICE,
_STREAM_STREAM,
)),
request_serializer=_REQUEST_SERIALIZER, request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,) response_deserializer=_RESPONSE_DESERIALIZER,
)
def testSuccessfulUnaryUnary(self): def testSuccessfulUnaryUnary(self):
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)

@ -33,18 +33,50 @@ _UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary' _STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream' _STREAM_STREAM = '/test/StreamStream'
_INVOCATION_METADATA = ((b'invocation-md-key', u'invocation-md-value',), _INVOCATION_METADATA = (
(u'invocation-md-key-bin', b'\x00\x01',),) (
_EXPECTED_INVOCATION_METADATA = (('invocation-md-key', 'invocation-md-value',), b'invocation-md-key',
('invocation-md-key-bin', b'\x00\x01',),) u'invocation-md-value',
),
(
u'invocation-md-key-bin',
b'\x00\x01',
),
)
_EXPECTED_INVOCATION_METADATA = (
(
'invocation-md-key',
'invocation-md-value',
),
(
'invocation-md-key-bin',
b'\x00\x01',
),
)
_INITIAL_METADATA = ((b'initial-md-key', u'initial-md-value'), _INITIAL_METADATA = ((b'initial-md-key', u'initial-md-value'),
(u'initial-md-key-bin', b'\x00\x02')) (u'initial-md-key-bin', b'\x00\x02'))
_EXPECTED_INITIAL_METADATA = (('initial-md-key', 'initial-md-value',), _EXPECTED_INITIAL_METADATA = (
('initial-md-key-bin', b'\x00\x02',),) (
'initial-md-key',
_TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value',), 'initial-md-value',
('server-trailing-md-key-bin', b'\x00\x03',),) ),
(
'initial-md-key-bin',
b'\x00\x02',
),
)
_TRAILING_METADATA = (
(
'server-trailing-md-key',
'server-trailing-md-value',
),
(
'server-trailing-md-key-bin',
b'\x00\x03',
),
)
_EXPECTED_TRAILING_METADATA = _TRAILING_METADATA _EXPECTED_TRAILING_METADATA = _TRAILING_METADATA
@ -146,8 +178,8 @@ class MetadataTest(unittest.TestCase):
def setUp(self): def setUp(self):
self._server = test_common.test_server() self._server = test_common.test_server()
self._server.add_generic_rpc_handlers( self._server.add_generic_rpc_handlers((_GenericHandler(
(_GenericHandler(weakref.proxy(self)),)) weakref.proxy(self)),))
port = self._server.add_insecure_port('[::]:0') port = self._server.add_insecure_port('[::]:0')
self._server.start() self._server.start()
self._channel = grpc.insecure_channel( self._channel = grpc.insecure_channel(

@ -64,7 +64,10 @@ class _Handler(object):
def handle_unary_unary(self, request, servicer_context): def handle_unary_unary(self, request, servicer_context):
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
# TODO(https://github.com/grpc/grpc/issues/8483): test the values # TODO(https://github.com/grpc/grpc/issues/8483): test the values
# returned by these methods rather than only "smoke" testing that # returned by these methods rather than only "smoke" testing that
# the return after having been called. # the return after having been called.
@ -78,7 +81,10 @@ class _Handler(object):
yield request yield request
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
def handle_stream_unary(self, request_iterator, servicer_context): def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None: if servicer_context is not None:
@ -90,13 +96,19 @@ class _Handler(object):
response_elements.append(request) response_elements.append(request)
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
return b''.join(response_elements) return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context): def handle_stream_stream(self, request_iterator, servicer_context):
self._control.control() self._control.control()
if servicer_context is not None: if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),)) servicer_context.set_trailing_metadata(((
'testkey',
'testvalue',
),))
for request in request_iterator: for request in request_iterator:
self._control.control() self._control.control()
yield request yield request
@ -244,8 +256,8 @@ class RPCTest(unittest.TestCase):
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
def testSuccessfulStreamRequestBlockingUnaryResponse(self): def testSuccessfulStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary( expected_response = self._handler.handle_stream_unary(
iter(requests), None) iter(requests), None)
request_iterator = iter(requests) request_iterator = iter(requests)
@ -253,14 +265,14 @@ class RPCTest(unittest.TestCase):
multi_callable = _stream_unary_multi_callable(self._channel) multi_callable = _stream_unary_multi_callable(self._channel)
response = multi_callable( response = multi_callable(
request_iterator, request_iterator,
metadata=( metadata=(('test',
('test', 'SuccessfulStreamRequestBlockingUnaryResponse'),)) 'SuccessfulStreamRequestBlockingUnaryResponse'),))
self.assertEqual(expected_response, response) self.assertEqual(expected_response, response)
def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self): def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary( expected_response = self._handler.handle_stream_unary(
iter(requests), None) iter(requests), None)
request_iterator = iter(requests) request_iterator = iter(requests)
@ -276,8 +288,8 @@ class RPCTest(unittest.TestCase):
self.assertIs(grpc.StatusCode.OK, call.code()) self.assertIs(grpc.StatusCode.OK, call.code())
def testSuccessfulStreamRequestFutureUnaryResponse(self): def testSuccessfulStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary( expected_response = self._handler.handle_stream_unary(
iter(requests), None) iter(requests), None)
request_iterator = iter(requests) request_iterator = iter(requests)
@ -293,8 +305,8 @@ class RPCTest(unittest.TestCase):
self.assertIsNone(response_future.traceback()) self.assertIsNone(response_future.traceback())
def testSuccessfulStreamRequestStreamResponse(self): def testSuccessfulStreamRequestStreamResponse(self):
requests = tuple(b'\x77\x58' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
expected_responses = tuple( expected_responses = tuple(
self._handler.handle_stream_stream(iter(requests), None)) self._handler.handle_stream_stream(iter(requests), None))
request_iterator = iter(requests) request_iterator = iter(requests)
@ -326,8 +338,8 @@ class RPCTest(unittest.TestCase):
def testConcurrentBlockingInvocations(self): def testConcurrentBlockingInvocations(self):
pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary( expected_response = self._handler.handle_stream_unary(
iter(requests), None) iter(requests), None)
expected_responses = [expected_response expected_responses = [expected_response
@ -342,15 +354,15 @@ class RPCTest(unittest.TestCase):
request_iterator, request_iterator,
metadata=(('test', 'ConcurrentBlockingInvocations'),)) metadata=(('test', 'ConcurrentBlockingInvocations'),))
response_futures[index] = response_future response_futures[index] = response_future
responses = tuple(response_future.result() responses = tuple(
for response_future in response_futures) response_future.result() for response_future in response_futures)
pool.shutdown(wait=True) pool.shutdown(wait=True)
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
def testConcurrentFutureInvocations(self): def testConcurrentFutureInvocations(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
expected_response = self._handler.handle_stream_unary( expected_response = self._handler.handle_stream_unary(
iter(requests), None) iter(requests), None)
expected_responses = [expected_response expected_responses = [expected_response
@ -364,8 +376,8 @@ class RPCTest(unittest.TestCase):
request_iterator, request_iterator,
metadata=(('test', 'ConcurrentFutureInvocations'),)) metadata=(('test', 'ConcurrentFutureInvocations'),))
response_futures[index] = response_future response_futures[index] = response_future
responses = tuple(response_future.result() responses = tuple(
for response_future in response_futures) response_future.result() for response_future in response_futures)
self.assertSequenceEqual(expected_responses, responses) self.assertSequenceEqual(expected_responses, responses)
@ -424,14 +436,14 @@ class RPCTest(unittest.TestCase):
multi_callable = _unary_stream_multi_callable(self._channel) multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable( response_iterator = multi_callable(
request, request,
metadata=( metadata=(('test',
('test', 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
for _ in range(test_constants.STREAM_LENGTH // 2): for _ in range(test_constants.STREAM_LENGTH // 2):
next(response_iterator) next(response_iterator)
def testConsumingSomeButNotAllStreamResponsesStreamRequest(self): def testConsumingSomeButNotAllStreamResponsesStreamRequest(self):
requests = tuple(b'\x67\x88' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel) multi_callable = _stream_stream_multi_callable(self._channel)
@ -443,15 +455,15 @@ class RPCTest(unittest.TestCase):
next(response_iterator) next(response_iterator)
def testConsumingTooManyStreamResponsesStreamRequest(self): def testConsumingTooManyStreamResponsesStreamRequest(self):
requests = tuple(b'\x67\x88' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel) multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable( response_iterator = multi_callable(
request_iterator, request_iterator,
metadata=( metadata=(('test',
('test', 'ConsumingTooManyStreamResponsesStreamRequest'),)) 'ConsumingTooManyStreamResponsesStreamRequest'),))
for _ in range(test_constants.STREAM_LENGTH): for _ in range(test_constants.STREAM_LENGTH):
next(response_iterator) next(response_iterator)
for _ in range(test_constants.STREAM_LENGTH): for _ in range(test_constants.STREAM_LENGTH):
@ -503,8 +515,8 @@ class RPCTest(unittest.TestCase):
self.assertIsNotNone(response_iterator.trailing_metadata()) self.assertIsNotNone(response_iterator.trailing_metadata())
def testCancelledStreamRequestUnaryResponse(self): def testCancelledStreamRequestUnaryResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel) multi_callable = _stream_unary_multi_callable(self._channel)
@ -528,8 +540,8 @@ class RPCTest(unittest.TestCase):
self.assertIsNotNone(response_future.trailing_metadata()) self.assertIsNotNone(response_future.trailing_metadata())
def testCancelledStreamRequestStreamResponse(self): def testCancelledStreamRequestStreamResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel) multi_callable = _stream_stream_multi_callable(self._channel)
@ -555,8 +567,8 @@ class RPCTest(unittest.TestCase):
multi_callable.with_call( multi_callable.with_call(
request, request,
timeout=test_constants.SHORT_TIMEOUT, timeout=test_constants.SHORT_TIMEOUT,
metadata=( metadata=(('test',
('test', 'ExpiredUnaryRequestBlockingUnaryResponse'),)) 'ExpiredUnaryRequestBlockingUnaryResponse'),))
self.assertIsInstance(exception_context.exception, grpc.Call) self.assertIsInstance(exception_context.exception, grpc.Call)
self.assertIsNotNone(exception_context.exception.initial_metadata()) self.assertIsNotNone(exception_context.exception.initial_metadata())
@ -610,8 +622,8 @@ class RPCTest(unittest.TestCase):
response_iterator.code()) response_iterator.code())
def testExpiredStreamRequestBlockingUnaryResponse(self): def testExpiredStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x07\x08' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel) multi_callable = _stream_unary_multi_callable(self._channel)
@ -620,8 +632,8 @@ class RPCTest(unittest.TestCase):
multi_callable( multi_callable(
request_iterator, request_iterator,
timeout=test_constants.SHORT_TIMEOUT, timeout=test_constants.SHORT_TIMEOUT,
metadata=( metadata=(('test',
('test', 'ExpiredStreamRequestBlockingUnaryResponse'),)) 'ExpiredStreamRequestBlockingUnaryResponse'),))
self.assertIsInstance(exception_context.exception, grpc.RpcError) self.assertIsInstance(exception_context.exception, grpc.RpcError)
self.assertIsInstance(exception_context.exception, grpc.Call) self.assertIsInstance(exception_context.exception, grpc.Call)
@ -632,8 +644,8 @@ class RPCTest(unittest.TestCase):
self.assertIsNotNone(exception_context.exception.trailing_metadata()) self.assertIsNotNone(exception_context.exception.trailing_metadata())
def testExpiredStreamRequestFutureUnaryResponse(self): def testExpiredStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
callback = _Callback() callback = _Callback()
@ -644,8 +656,8 @@ class RPCTest(unittest.TestCase):
timeout=test_constants.SHORT_TIMEOUT, timeout=test_constants.SHORT_TIMEOUT,
metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),)) metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),))
with self.assertRaises(grpc.FutureTimeoutError): with self.assertRaises(grpc.FutureTimeoutError):
response_future.result(timeout=test_constants.SHORT_TIMEOUT / response_future.result(
2.0) timeout=test_constants.SHORT_TIMEOUT / 2.0)
response_future.add_done_callback(callback) response_future.add_done_callback(callback)
value_passed_to_callback = callback.value() value_passed_to_callback = callback.value()
@ -663,8 +675,8 @@ class RPCTest(unittest.TestCase):
self.assertIsNotNone(response_future.trailing_metadata()) self.assertIsNotNone(response_future.trailing_metadata())
def testExpiredStreamRequestStreamResponse(self): def testExpiredStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x18' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel) multi_callable = _stream_stream_multi_callable(self._channel)
@ -689,8 +701,8 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable.with_call( multi_callable.with_call(
request, request,
metadata=( metadata=(('test',
('test', 'FailedUnaryRequestBlockingUnaryResponse'),)) 'FailedUnaryRequestBlockingUnaryResponse'),))
self.assertIs(grpc.StatusCode.UNKNOWN, self.assertIs(grpc.StatusCode.UNKNOWN,
exception_context.exception.code()) exception_context.exception.code())
@ -734,8 +746,8 @@ class RPCTest(unittest.TestCase):
exception_context.exception.code()) exception_context.exception.code())
def testFailedStreamRequestBlockingUnaryResponse(self): def testFailedStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x47\x58' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel) multi_callable = _stream_unary_multi_callable(self._channel)
@ -743,15 +755,15 @@ class RPCTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
multi_callable( multi_callable(
request_iterator, request_iterator,
metadata=( metadata=(('test',
('test', 'FailedStreamRequestBlockingUnaryResponse'),)) 'FailedStreamRequestBlockingUnaryResponse'),))
self.assertIs(grpc.StatusCode.UNKNOWN, self.assertIs(grpc.StatusCode.UNKNOWN,
exception_context.exception.code()) exception_context.exception.code())
def testFailedStreamRequestFutureUnaryResponse(self): def testFailedStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
callback = _Callback() callback = _Callback()
@ -773,8 +785,8 @@ class RPCTest(unittest.TestCase):
self.assertIs(response_future, value_passed_to_callback) self.assertIs(response_future, value_passed_to_callback)
def testFailedStreamRequestStreamResponse(self): def testFailedStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x88' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel) multi_callable = _stream_stream_multi_callable(self._channel)
@ -805,8 +817,8 @@ class RPCTest(unittest.TestCase):
request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),)) request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),))
def testIgnoredStreamRequestFutureUnaryResponse(self): def testIgnoredStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x18' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_unary_multi_callable(self._channel) multi_callable = _stream_unary_multi_callable(self._channel)
@ -815,8 +827,8 @@ class RPCTest(unittest.TestCase):
metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),)) metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),))
def testIgnoredStreamRequestStreamResponse(self): def testIgnoredStreamRequestStreamResponse(self):
requests = tuple(b'\x67\x88' requests = tuple(
for _ in range(test_constants.STREAM_LENGTH)) b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests) request_iterator = iter(requests)
multi_callable = _stream_stream_multi_callable(self._channel) multi_callable = _stream_stream_multi_callable(self._channel)

@ -74,7 +74,8 @@ def _create_client_stub(
expect_success, expect_success,
root_certificates=None, root_certificates=None,
private_key=None, private_key=None,
certificate_chain=None,): certificate_chain=None,
):
channel = grpc.secure_channel('localhost:{}'.format(port), channel = grpc.secure_channel('localhost:{}'.format(port),
grpc.ssl_channel_credentials( grpc.ssl_channel_credentials(
root_certificates=root_certificates, root_certificates=root_certificates,

@ -52,7 +52,9 @@ class CleanupThreadTest(unittest.TestCase):
target=target, target=target,
name='test-name', name='test-name',
args=('arg1', 'arg2'), args=('arg1', 'arg2'),
kwargs={'arg3': 'arg3'}) kwargs={
'arg3': 'arg3'
})
cleanup_thread.start() cleanup_thread.start()
cleanup_thread.join() cleanup_thread.join()
self.assertEqual(cleanup_thread.name, 'test-name') self.assertEqual(cleanup_thread.name, 'test-name')

@ -163,7 +163,10 @@ class BetaFeaturesTest(unittest.TestCase):
self._server = implementations.server( self._server = implementations.server(
method_implementations, options=server_options) method_implementations, options=server_options)
server_credentials = implementations.ssl_server_credentials([ server_credentials = implementations.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain(),), (
resources.private_key(),
resources.certificate_chain(),
),
]) ])
port = self._server.add_secure_port('[::]:0', server_credentials) port = self._server.add_secure_port('[::]:0', server_credentials)
self._server.start() self._server.start()
@ -289,7 +292,10 @@ class ContextManagementAndLifecycleTest(unittest.TestCase):
self._server_options = implementations.server_options( self._server_options = implementations.server_options(
thread_pool_size=test_constants.POOL_SIZE) thread_pool_size=test_constants.POOL_SIZE)
self._server_credentials = implementations.ssl_server_credentials([ self._server_credentials = implementations.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain(),), (
resources.private_key(),
resources.certificate_chain(),
),
]) ])
self._channel_credentials = implementations.ssl_channel_credentials( self._channel_credentials = implementations.ssl_channel_credentials(
resources.test_root_certificates()) resources.test_root_certificates())

@ -32,8 +32,11 @@ _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
class _SerializationBehaviors( class _SerializationBehaviors(
collections.namedtuple('_SerializationBehaviors', ( collections.namedtuple('_SerializationBehaviors', (
'request_serializers', 'request_deserializers', 'request_serializers',
'response_serializers', 'response_deserializers',))): 'request_deserializers',
'response_serializers',
'response_deserializers',
))):
pass pass
@ -73,7 +76,10 @@ class _Implementation(test_interfaces.Implementation):
server = implementations.server( server = implementations.server(
method_implementations, options=server_options) method_implementations, options=server_options)
server_credentials = implementations.ssl_server_credentials([ server_credentials = implementations.ssl_server_credentials([
(resources.private_key(), resources.certificate_chain(),), (
resources.private_key(),
resources.certificate_chain(),
),
]) ])
port = server.add_secure_port('[::]:0', server_credentials) port = server.add_secure_port('[::]:0', server_credentials)
server.start() server.start()
@ -116,9 +122,10 @@ class _Implementation(test_interfaces.Implementation):
def load_tests(loader, tests, pattern): def load_tests(loader, tests, pattern):
return unittest.TestSuite(tests=tuple( return unittest.TestSuite(
loader.loadTestsFromTestCase(test_case_class) tests=tuple(
for test_case_class in test_cases.test_cases(_Implementation()))) loader.loadTestsFromTestCase(test_case_class)
for test_case_class in test_cases.test_cases(_Implementation())))
if __name__ == '__main__': if __name__ == '__main__':

@ -41,8 +41,8 @@ class CallCredentialsTest(unittest.TestCase):
def test_google_call_credentials(self): def test_google_call_credentials(self):
creds = oauth2client_client.GoogleCredentials( creds = oauth2client_client.GoogleCredentials(
'token', 'client_id', 'secret', 'refresh_token', 'token', 'client_id', 'secret', 'refresh_token',
datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/', datetime.datetime(2008, 6,
'user_agent') 24), 'https://refresh.uri.com/', 'user_agent')
call_creds = implementations.google_call_credentials(creds) call_creds = implementations.google_call_credentials(creds)
self.assertIsInstance(call_creds, implementations.CallCredentials) self.assertIsInstance(call_creds, implementations.CallCredentials)

@ -33,6 +33,8 @@ def not_really_secure_channel(host, port, channel_credentials,
conducted. conducted.
""" """
target = '%s:%d' % (host, port) target = '%s:%d' % (host, port)
channel = grpc.secure_channel(target, channel_credentials, ( channel = grpc.secure_channel(target, channel_credentials, ((
('grpc.ssl_target_name_override', server_host_override,),)) 'grpc.ssl_target_name_override',
server_host_override,
),))
return implementations.Channel(channel) return implementations.Channel(channel)

@ -70,8 +70,8 @@ class TestCase(
self.implementation.destantiate(self._memo) self.implementation.destantiate(self._memo)
def testSuccessfulUnaryRequestUnaryResponse(self): def testSuccessfulUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -81,8 +81,8 @@ class TestCase(
test_messages.verify(request, response, self) test_messages.verify(request, response, self)
def testSuccessfulUnaryRequestStreamResponse(self): def testSuccessfulUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -93,8 +93,8 @@ class TestCase(
test_messages.verify(request, responses, self) test_messages.verify(request, responses, self)
def testSuccessfulStreamRequestUnaryResponse(self): def testSuccessfulStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
@ -104,8 +104,8 @@ class TestCase(
test_messages.verify(requests, response, self) test_messages.verify(requests, response, self)
def testSuccessfulStreamRequestStreamResponse(self): def testSuccessfulStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
@ -116,8 +116,8 @@ class TestCase(
test_messages.verify(requests, responses, self) test_messages.verify(requests, responses, self)
def testSequentialInvocations(self): def testSequentialInvocations(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
first_request = test_messages.request() first_request = test_messages.request()
second_request = test_messages.request() second_request = test_messages.request()
@ -134,8 +134,8 @@ class TestCase(
def testParallelInvocations(self): def testParallelInvocations(self):
pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = [] requests = []
response_futures = [] response_futures = []
@ -158,8 +158,8 @@ class TestCase(
def testWaitingForSomeButNotAllParallelInvocations(self): def testWaitingForSomeButNotAllParallelInvocations(self):
pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = [] requests = []
response_futures_to_indices = {} response_futures_to_indices = {}
@ -197,8 +197,8 @@ class TestCase(
raise NotImplementedError() raise NotImplementedError()
def testExpiredUnaryRequestUnaryResponse(self): def testExpiredUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -208,8 +208,8 @@ class TestCase(
request, _3069_test_constant.REALLY_SHORT_TIMEOUT) request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
def testExpiredUnaryRequestStreamResponse(self): def testExpiredUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -220,33 +220,33 @@ class TestCase(
list(response_iterator) list(response_iterator)
def testExpiredStreamRequestUnaryResponse(self): def testExpiredStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
with self._control.pause(), self.assertRaises( with self._control.pause(), self.assertRaises(
face.ExpirationError): face.ExpirationError):
self._invoker.blocking(group, method)( self._invoker.blocking(
iter(requests), group, method)(iter(requests),
_3069_test_constant.REALLY_SHORT_TIMEOUT) _3069_test_constant.REALLY_SHORT_TIMEOUT)
def testExpiredStreamRequestStreamResponse(self): def testExpiredStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
with self._control.pause(), self.assertRaises( with self._control.pause(), self.assertRaises(
face.ExpirationError): face.ExpirationError):
response_iterator = self._invoker.blocking(group, method)( response_iterator = self._invoker.blocking(
iter(requests), group, method)(iter(requests),
_3069_test_constant.REALLY_SHORT_TIMEOUT) _3069_test_constant.REALLY_SHORT_TIMEOUT)
list(response_iterator) list(response_iterator)
def testFailedUnaryRequestUnaryResponse(self): def testFailedUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -255,8 +255,8 @@ class TestCase(
request, test_constants.LONG_TIMEOUT) request, test_constants.LONG_TIMEOUT)
def testFailedUnaryRequestStreamResponse(self): def testFailedUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -266,8 +266,8 @@ class TestCase(
list(response_iterator) list(response_iterator)
def testFailedStreamRequestUnaryResponse(self): def testFailedStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
@ -276,8 +276,8 @@ class TestCase(
iter(requests), test_constants.LONG_TIMEOUT) iter(requests), test_constants.LONG_TIMEOUT)
def testFailedStreamRequestStreamResponse(self): def testFailedStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()

@ -34,11 +34,15 @@ _IDENTITY = lambda x: x
class TestServiceDigest( class TestServiceDigest(
collections.namedtuple('TestServiceDigest', ( collections.namedtuple('TestServiceDigest', (
'methods', 'inline_method_implementations', 'methods',
'event_method_implementations', 'multi_method_implementation', 'inline_method_implementations',
'unary_unary_messages_sequences', 'unary_stream_messages_sequences', 'event_method_implementations',
'multi_method_implementation',
'unary_unary_messages_sequences',
'unary_stream_messages_sequences',
'stream_unary_messages_sequences', 'stream_unary_messages_sequences',
'stream_stream_messages_sequences',))): 'stream_stream_messages_sequences',
))):
"""A transformation of a service.TestService. """A transformation of a service.TestService.
Attributes: Attributes:
@ -421,8 +425,8 @@ def digest(service, control, pool):
events.update(stream_unary.events) events.update(stream_unary.events)
events.update(stream_stream.events) events.update(stream_stream.events)
return TestServiceDigest( return TestServiceDigest(methods, inlines, events,
methods, inlines, events, _MultiMethodImplementation(adaptations, control,
_MultiMethodImplementation(adaptations, control, pool), pool),
unary_unary.messages, unary_stream.messages, stream_unary.messages, unary_unary.messages, unary_stream.messages,
stream_stream.messages) stream_unary.messages, stream_stream.messages)

@ -134,8 +134,8 @@ class TestCase(
self._digest_pool.shutdown(wait=True) self._digest_pool.shutdown(wait=True)
def testSuccessfulUnaryRequestUnaryResponse(self): def testSuccessfulUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback() callback = _Callback()
@ -151,8 +151,8 @@ class TestCase(
self.assertIsNone(response_future.traceback()) self.assertIsNone(response_future.traceback())
def testSuccessfulUnaryRequestStreamResponse(self): def testSuccessfulUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -163,8 +163,8 @@ class TestCase(
test_messages.verify(request, responses, self) test_messages.verify(request, responses, self)
def testSuccessfulStreamRequestUnaryResponse(self): def testSuccessfulStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
request_iterator = _PauseableIterator(iter(requests)) request_iterator = _PauseableIterator(iter(requests))
@ -185,8 +185,8 @@ class TestCase(
self.assertIsNone(response_future.traceback()) self.assertIsNone(response_future.traceback())
def testSuccessfulStreamRequestStreamResponse(self): def testSuccessfulStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
request_iterator = _PauseableIterator(iter(requests)) request_iterator = _PauseableIterator(iter(requests))
@ -201,8 +201,8 @@ class TestCase(
test_messages.verify(requests, responses, self) test_messages.verify(requests, responses, self)
def testSequentialInvocations(self): def testSequentialInvocations(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
first_request = test_messages.request() first_request = test_messages.request()
second_request = test_messages.request() second_request = test_messages.request()
@ -220,8 +220,8 @@ class TestCase(
test_messages.verify(second_request, second_response, self) test_messages.verify(second_request, second_response, self)
def testParallelInvocations(self): def testParallelInvocations(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
first_request = test_messages.request() first_request = test_messages.request()
second_request = test_messages.request() second_request = test_messages.request()
@ -236,8 +236,8 @@ class TestCase(
test_messages.verify(first_request, first_response, self) test_messages.verify(first_request, first_response, self)
test_messages.verify(second_request, second_response, self) test_messages.verify(second_request, second_response, self)
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = [] requests = []
response_futures = [] response_futures = []
@ -258,8 +258,8 @@ class TestCase(
def testWaitingForSomeButNotAllParallelInvocations(self): def testWaitingForSomeButNotAllParallelInvocations(self):
pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = [] requests = []
response_futures_to_indices = {} response_futures_to_indices = {}
@ -282,8 +282,8 @@ class TestCase(
pool.shutdown(wait=True) pool.shutdown(wait=True)
def testCancelledUnaryRequestUnaryResponse(self): def testCancelledUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback() callback = _Callback()
@ -305,8 +305,8 @@ class TestCase(
response_future.traceback() response_future.traceback()
def testCancelledUnaryRequestStreamResponse(self): def testCancelledUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -319,8 +319,8 @@ class TestCase(
next(response_iterator) next(response_iterator)
def testCancelledStreamRequestUnaryResponse(self): def testCancelledStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
callback = _Callback() callback = _Callback()
@ -342,8 +342,8 @@ class TestCase(
response_future.traceback() response_future.traceback()
def testCancelledStreamRequestStreamResponse(self): def testCancelledStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
@ -356,8 +356,8 @@ class TestCase(
next(response_iterator) next(response_iterator)
def testExpiredUnaryRequestUnaryResponse(self): def testExpiredUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback() callback = _Callback()
@ -376,8 +376,8 @@ class TestCase(
self.assertIsNotNone(response_future.traceback()) self.assertIsNotNone(response_future.traceback())
def testExpiredUnaryRequestStreamResponse(self): def testExpiredUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -388,16 +388,16 @@ class TestCase(
list(response_iterator) list(response_iterator)
def testExpiredStreamRequestUnaryResponse(self): def testExpiredStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
callback = _Callback() callback = _Callback()
with self._control.pause(): with self._control.pause():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(
iter(requests), group, method)(iter(requests),
_3069_test_constant.REALLY_SHORT_TIMEOUT) _3069_test_constant.REALLY_SHORT_TIMEOUT)
response_future.add_done_callback(callback) response_future.add_done_callback(callback)
self.assertIs(callback.future(), response_future) self.assertIs(callback.future(), response_future)
self.assertIsInstance(response_future.exception(), self.assertIsInstance(response_future.exception(),
@ -409,21 +409,21 @@ class TestCase(
self.assertIsNotNone(response_future.traceback()) self.assertIsNotNone(response_future.traceback())
def testExpiredStreamRequestStreamResponse(self): def testExpiredStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
with self._control.pause(): with self._control.pause():
response_iterator = self._invoker.future(group, method)( response_iterator = self._invoker.future(
iter(requests), group, method)(iter(requests),
_3069_test_constant.REALLY_SHORT_TIMEOUT) _3069_test_constant.REALLY_SHORT_TIMEOUT)
with self.assertRaises(face.ExpirationError): with self.assertRaises(face.ExpirationError):
list(response_iterator) list(response_iterator)
def testFailedUnaryRequestUnaryResponse(self): def testFailedUnaryRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_unary_messages_sequences)): self._digest.unary_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
callback = _Callback() callback = _Callback()
@ -448,8 +448,8 @@ class TestCase(
self.assertIsNotNone(abortion_callback.future()) self.assertIsNotNone(abortion_callback.future())
def testFailedUnaryRequestStreamResponse(self): def testFailedUnaryRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.unary_stream_messages_sequences)): self._digest.unary_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
request = test_messages.request() request = test_messages.request()
@ -464,17 +464,17 @@ class TestCase(
list(response_iterator) list(response_iterator)
def testFailedStreamRequestUnaryResponse(self): def testFailedStreamRequestUnaryResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_unary_messages_sequences)): self._digest.stream_unary_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
callback = _Callback() callback = _Callback()
abortion_callback = _Callback() abortion_callback = _Callback()
with self._control.fail(): with self._control.fail():
response_future = self._invoker.future(group, method)( response_future = self._invoker.future(
iter(requests), group, method)(iter(requests),
_3069_test_constant.REALLY_SHORT_TIMEOUT) _3069_test_constant.REALLY_SHORT_TIMEOUT)
response_future.add_done_callback(callback) response_future.add_done_callback(callback)
response_future.add_abortion_callback(abortion_callback) response_future.add_abortion_callback(abortion_callback)
@ -491,8 +491,8 @@ class TestCase(
self.assertIsNotNone(abortion_callback.future()) self.assertIsNotNone(abortion_callback.future())
def testFailedStreamRequestStreamResponse(self): def testFailedStreamRequestStreamResponse(self):
for (group, method), test_messages_sequence in ( for (group, method), test_messages_sequence in (six.iteritems(
six.iteritems(self._digest.stream_stream_messages_sequences)): self._digest.stream_stream_messages_sequences)):
for test_messages in test_messages_sequence: for test_messages in test_messages_sequence:
requests = test_messages.requests() requests = test_messages.requests()
@ -502,7 +502,7 @@ class TestCase(
# expiration of the RPC. # expiration of the RPC.
with self._control.fail(), self.assertRaises( with self._control.fail(), self.assertRaises(
face.ExpirationError): face.ExpirationError):
response_iterator = self._invoker.future(group, method)( response_iterator = self._invoker.future(
iter(requests), group, method)(iter(requests),
_3069_test_constant.REALLY_SHORT_TIMEOUT) _3069_test_constant.REALLY_SHORT_TIMEOUT)
list(response_iterator) list(response_iterator)

@ -191,5 +191,8 @@ def invoker_constructors():
Returns: Returns:
A sequence of InvokerConstructors. A sequence of InvokerConstructors.
""" """
return (_GenericInvokerConstructor(), _MultiCallableInvokerConstructor(), return (
_DynamicInvokerConstructor(),) _GenericInvokerConstructor(),
_MultiCallableInvokerConstructor(),
_DynamicInvokerConstructor(),
)

@ -33,8 +33,8 @@ def _get_last_trade_price(stock_request, stock_reply_callback, control, active):
if active(): if active():
stock_reply_callback( stock_reply_callback(
stock_pb2.StockReply( stock_pb2.StockReply(
symbol=stock_request.symbol, price=_price( symbol=stock_request.symbol,
stock_request.symbol))) price=_price(stock_request.symbol)))
else: else:
raise abandonment.Abandoned() raise abandonment.Abandoned()

@ -24,7 +24,8 @@ from tests.unit.framework.interfaces.face import test_interfaces # pylint: disa
_TEST_CASE_SUPERCLASSES = ( _TEST_CASE_SUPERCLASSES = (
_blocking_invocation_inline_service.TestCase, _blocking_invocation_inline_service.TestCase,
_future_invocation_asynchronous_event_service.TestCase,) _future_invocation_asynchronous_event_service.TestCase,
)
def test_cases(implementation): def test_cases(implementation):
@ -42,8 +43,9 @@ def test_cases(implementation):
for invoker_constructor in _invocation.invoker_constructors(): for invoker_constructor in _invocation.invoker_constructors():
for super_class in _TEST_CASE_SUPERCLASSES: for super_class in _TEST_CASE_SUPERCLASSES:
test_case_classes.append( test_case_classes.append(
type(invoker_constructor.name() + super_class.NAME, ( type(
super_class,), { invoker_constructor.name() + super_class.NAME,
(super_class,), {
'implementation': implementation, 'implementation': implementation,
'invoker_constructor': invoker_constructor, 'invoker_constructor': invoker_constructor,
'__module__': implementation.__module__, '__module__': implementation.__module__,

@ -58,7 +58,8 @@ def cert_hier_1_client_1_key():
def cert_hier_1_client_1_cert(): def cert_hier_1_client_1_cert():
return pkg_resources.resource_string( return pkg_resources.resource_string(
__name__, __name__,
'credentials/certificate_hierarchy_1/intermediate/certs/client.cert.pem') 'credentials/certificate_hierarchy_1/intermediate/certs/client.cert.pem'
)
def cert_hier_1_server_1_key(): def cert_hier_1_server_1_key():
@ -97,7 +98,8 @@ def cert_hier_2_client_1_key():
def cert_hier_2_client_1_cert(): def cert_hier_2_client_1_cert():
return pkg_resources.resource_string( return pkg_resources.resource_string(
__name__, __name__,
'credentials/certificate_hierarchy_2/intermediate/certs/client.cert.pem') 'credentials/certificate_hierarchy_2/intermediate/certs/client.cert.pem'
)
def cert_hier_2_server_1_key(): def cert_hier_2_server_1_key():

@ -19,9 +19,21 @@ from concurrent import futures
import grpc import grpc
import six import six
INVOCATION_INITIAL_METADATA = (('0', 'abc'), ('1', 'def'), ('2', 'ghi'),) INVOCATION_INITIAL_METADATA = (
SERVICE_INITIAL_METADATA = (('3', 'jkl'), ('4', 'mno'), ('5', 'pqr'),) ('0', 'abc'),
SERVICE_TERMINAL_METADATA = (('6', 'stu'), ('7', 'vwx'), ('8', 'yza'),) ('1', 'def'),
('2', 'ghi'),
)
SERVICE_INITIAL_METADATA = (
('3', 'jkl'),
('4', 'mno'),
('5', 'pqr'),
)
SERVICE_TERMINAL_METADATA = (
('6', 'stu'),
('7', 'vwx'),
('8', 'yza'),
)
DETAILS = 'test details' DETAILS = 'test details'
@ -80,8 +92,10 @@ def test_secure_channel(target, channel_credentials, server_host_override):
An implementations.Channel to the remote host through which RPCs may be An implementations.Channel to the remote host through which RPCs may be
conducted. conducted.
""" """
channel = grpc.secure_channel(target, channel_credentials, ( channel = grpc.secure_channel(target, channel_credentials, ((
('grpc.ssl_target_name_override', server_host_override,),)) 'grpc.ssl_target_name_override',
server_host_override,
),))
return channel return channel

@ -48,5 +48,6 @@ def merge_json(dst, add):
elif isinstance(dst, list) and isinstance(add, list): elif isinstance(dst, list) and isinstance(add, list):
dst.extend(add) dst.extend(add)
else: else:
raise Exception('Tried to merge incompatible objects %s %s\n\n%r\n\n%r' raise Exception(
% (type(dst).__name__, type(add).__name__, dst, add)) 'Tried to merge incompatible objects %s %s\n\n%r\n\n%r' %
(type(dst).__name__, type(add).__name__, dst, add))

@ -99,10 +99,10 @@ def main(argv):
elif opt == '-P': elif opt == '-P':
assert not got_preprocessed_input assert not got_preprocessed_input
assert json_dict == {} assert json_dict == {}
sys.path.insert( sys.path.insert(0,
0, os.path.abspath(
os.path.abspath( os.path.join(
os.path.join(os.path.dirname(sys.argv[0]), 'plugins'))) os.path.dirname(sys.argv[0]), 'plugins')))
with open(arg, 'r') as dict_file: with open(arg, 'r') as dict_file:
dictionary = pickle.load(dict_file) dictionary = pickle.load(dict_file)
got_preprocessed_input = True got_preprocessed_input = True

@ -104,8 +104,7 @@ def mako_plugin(dictionary):
# build reverse dependency map # build reverse dependency map
things = {} things = {}
for thing in dictionary['libs'] + dictionary['targets'] + dictionary[ for thing in dictionary['libs'] + dictionary['targets'] + dictionary['filegroups']:
'filegroups']:
things[thing['name']] = thing things[thing['name']] = thing
thing['used_by'] = [] thing['used_by'] = []
thing_deps = lambda t: t.get('uses', []) + t.get('filegroups', []) + t.get('deps', []) thing_deps = lambda t: t.get('uses', []) + t.get('filegroups', []) + t.get('deps', [])
@ -148,7 +147,7 @@ def mako_plugin(dictionary):
lib[lst] = vals lib[lst] = vals
lib['plugins'] = plugins lib['plugins'] = plugins
if lib.get('generate_plugin_registry', False): if lib.get('generate_plugin_registry', False):
lib['src'].append('src/core/plugin_registry/%s_plugin_registry.cc' % lib['src'].append(
lib['name']) 'src/core/plugin_registry/%s_plugin_registry.cc' % lib['name'])
for lst in FILEGROUP_LISTS: for lst in FILEGROUP_LISTS:
lib[lst] = uniquify(lib.get(lst, [])) lib[lst] = uniquify(lib.get(lst, []))

@ -56,11 +56,12 @@ def mako_plugin(dictionary):
target['vs_props'] = [] target['vs_props'] = []
target['vs_proj_dir'] = target.get('vs_proj_dir', default_test_dir) target['vs_proj_dir'] = target.get('vs_proj_dir', default_test_dir)
if target.get('vs_project_guid', if target.get('vs_project_guid',
None) is None and 'windows' in target.get('platforms', None) is None and 'windows' in target.get(
['windows']): 'platforms', ['windows']):
name = target['name'] name = target['name']
guid = re.sub('(........)(....)(....)(....)(.*)', guid = re.sub('(........)(....)(....)(....)(.*)',
r'{\1-\2-\3-\4-\5}', hashlib.md5(name).hexdigest()) r'{\1-\2-\3-\4-\5}',
hashlib.md5(name).hexdigest())
target['vs_project_guid'] = guid.upper() target['vs_project_guid'] = guid.upper()
# Exclude projects without a visual project guid, such as the tests. # Exclude projects without a visual project guid, such as the tests.
projects = [ projects = [
@ -69,9 +70,9 @@ def mako_plugin(dictionary):
projects = [ projects = [
project for project in projects project for project in projects
if project['language'] != 'c++' or project['build'] == 'all' or project[ if project['language'] != 'c++' or project['build'] == 'all' or
'build'] == 'protoc' or (project['language'] == 'c++' and (project[ project['build'] == 'protoc' or (project['language'] == 'c++' and (
'build'] == 'test' or project['build'] == 'private')) project['build'] == 'test' or project['build'] == 'private'))
] ]
project_dict = dict([(p['name'], p) for p in projects]) project_dict = dict([(p['name'], p) for p in projects])

@ -54,5 +54,5 @@ def mako_plugin(dictionary):
target['transitive_deps'] = transitive_deps(target, libs) target['transitive_deps'] = transitive_deps(target, libs)
python_dependencies = dictionary.get('python_dependencies') python_dependencies = dictionary.get('python_dependencies')
python_dependencies['transitive_deps'] = ( python_dependencies['transitive_deps'] = (transitive_deps(
transitive_deps(python_dependencies, libs)) python_dependencies, libs))

@ -174,10 +174,13 @@ for decorated_setting in sorted(decorated_settings):
print >> C, "{NULL, 0, 0, 0, GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR}," print >> C, "{NULL, 0, 0, 0, GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR},"
i += 1 i += 1
print >> C, "{\"%s\", %du, %du, %du, GRPC_CHTTP2_%s, GRPC_HTTP2_%s}," % ( print >> C, "{\"%s\", %du, %du, %du, GRPC_CHTTP2_%s, GRPC_HTTP2_%s}," % (
decorated_setting.name, decorated_setting.setting.default, decorated_setting.name,
decorated_setting.setting.min, decorated_setting.setting.max, decorated_setting.setting.default,
decorated_setting.setting.min,
decorated_setting.setting.max,
decorated_setting.setting.on_error.behavior, decorated_setting.setting.on_error.behavior,
decorated_setting.setting.on_error.code,) decorated_setting.setting.on_error.code,
)
i += 1 i += 1
print >> C, "};" print >> C, "};"

@ -387,8 +387,8 @@ for i, elem in enumerate(all_strs):
print >> H, '#define %s (grpc_static_slice_table[%d])' % ( print >> H, '#define %s (grpc_static_slice_table[%d])' % (
mangle(elem).upper(), i) mangle(elem).upper(), i)
print >> H print >> H
print >> C, 'static uint8_t g_bytes[] = {%s};' % ( print >> C, 'static uint8_t g_bytes[] = {%s};' % (','.join(
','.join('%d' % ord(c) for c in ''.join(all_strs))) '%d' % ord(c) for c in ''.join(all_strs)))
print >> C print >> C
print >> C, 'static void static_ref(void *unused) {}' print >> C, 'static void static_ref(void *unused) {}'
print >> C, 'static void static_unref(void *unused) {}' print >> C, 'static void static_unref(void *unused) {}'
@ -444,8 +444,8 @@ for i, elem in enumerate(all_elems):
print >> H print >> H
print >> C, ('uintptr_t grpc_static_mdelem_user_data[GRPC_STATIC_MDELEM_COUNT] ' print >> C, ('uintptr_t grpc_static_mdelem_user_data[GRPC_STATIC_MDELEM_COUNT] '
'= {') '= {')
print >> C, ' %s' % ','.join('%d' % static_userdata.get(elem, 0) print >> C, ' %s' % ','.join(
for elem in all_elems) '%d' % static_userdata.get(elem, 0) for elem in all_elems)
print >> C, '};' print >> C, '};'
print >> C print >> C
@ -520,8 +520,8 @@ for i, k in enumerate(elem_keys):
idxs[h] = i idxs[h] = i
print >> C, 'static const uint16_t elem_keys[] = {%s};' % ','.join( print >> C, 'static const uint16_t elem_keys[] = {%s};' % ','.join(
'%d' % k for k in keys) '%d' % k for k in keys)
print >> C, 'static const uint8_t elem_idxs[] = {%s};' % ','.join('%d' % i print >> C, 'static const uint8_t elem_idxs[] = {%s};' % ','.join(
for i in idxs) '%d' % i for i in idxs)
print >> C print >> C
print >> H, 'grpc_mdelem grpc_static_mdelem_for_static_strings(int a, int b);' print >> H, 'grpc_mdelem grpc_static_mdelem_for_static_strings(int a, int b);'
@ -579,8 +579,8 @@ print >> H, 'extern const uint8_t grpc_static_accept_stream_encoding_metadata[%d
1 << len(STREAM_COMPRESSION_ALGORITHMS)) 1 << len(STREAM_COMPRESSION_ALGORITHMS))
print >> C, 'const uint8_t grpc_static_accept_stream_encoding_metadata[%d] = {' % ( print >> C, 'const uint8_t grpc_static_accept_stream_encoding_metadata[%d] = {' % (
1 << len(STREAM_COMPRESSION_ALGORITHMS)) 1 << len(STREAM_COMPRESSION_ALGORITHMS))
print >> C, '0,%s' % ','.join('%d' % md_idx(elem) print >> C, '0,%s' % ','.join(
for elem in stream_compression_elems) '%d' % md_idx(elem) for elem in stream_compression_elems)
print >> C, '};' print >> C, '};'
print >> H, '#define GRPC_MDELEM_ACCEPT_STREAM_ENCODING_FOR_ALGORITHMS(algs) (GRPC_MAKE_MDELEM(&grpc_static_mdelem_table[grpc_static_accept_stream_encoding_metadata[(algs)]], GRPC_MDELEM_STORAGE_STATIC))' print >> H, '#define GRPC_MDELEM_ACCEPT_STREAM_ENCODING_FOR_ALGORITHMS(algs) (GRPC_MAKE_MDELEM(&grpc_static_mdelem_table[grpc_static_accept_stream_encoding_metadata[(algs)]], GRPC_MDELEM_STORAGE_STATIC))'

@ -28,8 +28,8 @@ REQUIRED_FIELDS = ['name', 'doc']
def make_type(name, fields): def make_type(name, fields):
return (collections.namedtuple( return (collections.namedtuple(name, ' '.join(
name, ' '.join(list(set(REQUIRED_FIELDS + fields)))), []) list(set(REQUIRED_FIELDS + fields)))), [])
def c_str(s, encoding='ascii'): def c_str(s, encoding='ascii'):
@ -44,7 +44,10 @@ def c_str(s, encoding='ascii'):
return '"' + result + '"' return '"' + result + '"'
types = (make_type('Counter', []), make_type('Histogram', ['max', 'buckets']),) types = (
make_type('Counter', []),
make_type('Histogram', ['max', 'buckets']),
)
inst_map = dict((t[0].__name__, t[1]) for t in types) inst_map = dict((t[0].__name__, t[1]) for t in types)
@ -349,8 +352,8 @@ with open('src/core/lib/debug/stats_data.cc', 'w') as C:
print >> C, "const int grpc_stats_histo_start[%d] = {%s};" % ( print >> C, "const int grpc_stats_histo_start[%d] = {%s};" % (
len(inst_map['Histogram']), ','.join('%s' % x for x in histo_start)) len(inst_map['Histogram']), ','.join('%s' % x for x in histo_start))
print >> C, "const int *const grpc_stats_histo_bucket_boundaries[%d] = {%s};" % ( print >> C, "const int *const grpc_stats_histo_bucket_boundaries[%d] = {%s};" % (
len(inst_map['Histogram']), ','.join('grpc_stats_table_%d' % x len(inst_map['Histogram']), ','.join(
for x in histo_bucket_boundaries)) 'grpc_stats_table_%d' % x for x in histo_bucket_boundaries))
print >> C, "void (*const grpc_stats_inc_histogram[%d])(int x) = {%s};" % ( print >> C, "void (*const grpc_stats_inc_histogram[%d])(int x) = {%s};" % (
len(inst_map['Histogram']), ','.join( len(inst_map['Histogram']), ','.join(
'grpc_stats_inc_%s' % histogram.name.lower() 'grpc_stats_inc_%s' % histogram.name.lower()

@ -39,7 +39,7 @@ for line in data:
elif line[0] == "realloc": elif line[0] == "realloc":
errs.remove(line[1]) errs.remove(line[1])
errs.append(line[3]) errs.append(line[3])
# explicitly look for the last dereference # explicitly look for the last dereference
elif line[1] == "1" and line[3] == "0": elif line[1] == "1" and line[3] == "0":
assert (err in errs) assert (err in errs)
errs.remove(err) errs.remove(err)

@ -84,13 +84,15 @@ _EXEMPT = frozenset((
# census.proto copied from github # census.proto copied from github
'tools/grpcz/census.proto', 'tools/grpcz/census.proto',
# status.proto copied from googleapis # status.proto copied from googleapis
'src/proto/grpc/status/status.proto',)) 'src/proto/grpc/status/status.proto',
))
RE_YEAR = r'Copyright (?P<first_year>[0-9]+\-)?(?P<last_year>[0-9]+) gRPC authors.' RE_YEAR = r'Copyright (?P<first_year>[0-9]+\-)?(?P<last_year>[0-9]+) gRPC authors.'
RE_LICENSE = dict((k, r'\n'.join( RE_LICENSE = dict(
LICENSE_PREFIX[k] + (RE_YEAR (k, r'\n'.join(LICENSE_PREFIX[k] +
if re.search(RE_YEAR, line) else re.escape(line)) (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line))
for line in LICENSE_NOTICE)) for k, v in LICENSE_PREFIX.iteritems()) for line in LICENSE_NOTICE))
for k, v in LICENSE_PREFIX.iteritems())
if args.precommit: if args.precommit:
FILE_LIST_COMMAND = 'git status -z | grep -Poz \'(?<=^[MARC][MARCD ] )[^\s]+\'' FILE_LIST_COMMAND = 'git status -z | grep -Poz \'(?<=^[MARC][MARCD ] )[^\s]+\''

@ -95,14 +95,14 @@ class GuardValidator(object):
# Does the guard end with a '_H'? # Does the guard end with a '_H'?
running_guard = match.group(1) running_guard = match.group(1)
if not running_guard.endswith('_H'): if not running_guard.endswith('_H'):
fcontents = self.fail(fpath, match.re, match.string, fcontents = self.fail(fpath, match.re, match.string, match.group(1),
match.group(1), valid_guard, fix) valid_guard, fix)
if fix: save(fpath, fcontents) if fix: save(fpath, fcontents)
# Is it the expected one based on the file path? # Is it the expected one based on the file path?
if running_guard != valid_guard: if running_guard != valid_guard:
fcontents = self.fail(fpath, match.re, match.string, fcontents = self.fail(fpath, match.re, match.string, match.group(1),
match.group(1), valid_guard, fix) valid_guard, fix)
if fix: save(fpath, fcontents) if fix: save(fpath, fcontents)
# Is there a #define? Is it the same as the #ifndef one? # Is there a #define? Is it the same as the #ifndef one?
@ -114,8 +114,8 @@ class GuardValidator(object):
# Is the #define guard the same as the #ifndef guard? # Is the #define guard the same as the #ifndef guard?
if match.group(1) != running_guard: if match.group(1) != running_guard:
fcontents = self.fail(fpath, match.re, match.string, fcontents = self.fail(fpath, match.re, match.string, match.group(1),
match.group(1), valid_guard, fix) valid_guard, fix)
if fix: save(fpath, fcontents) if fix: save(fpath, fcontents)
# Is there a properly commented #endif? # Is there a properly commented #endif?
@ -138,8 +138,8 @@ class GuardValidator(object):
self.fail(fpath, endif_re, flines[-1], '', '', False) self.fail(fpath, endif_re, flines[-1], '', '', False)
elif match.group(1) != running_guard: elif match.group(1) != running_guard:
# Is the #endif guard the same as the #ifndef and #define guards? # Is the #endif guard the same as the #ifndef and #define guards?
fcontents = self.fail(fpath, endif_re, fcontents, fcontents = self.fail(fpath, endif_re, fcontents, match.group(1),
match.group(1), valid_guard, fix) valid_guard, fix)
if fix: save(fpath, fcontents) if fix: save(fpath, fcontents)
return not self.failed # Did the check succeed? (ie, not failed) return not self.failed # Did the check succeed? (ie, not failed)

@ -30,8 +30,8 @@ def build_package_protos(package_root):
proto_files.append( proto_files.append(
os.path.abspath(os.path.join(root, filename))) os.path.abspath(os.path.join(root, filename)))
well_known_protos_include = pkg_resources.resource_filename('grpc_tools', well_known_protos_include = pkg_resources.resource_filename(
'_proto') 'grpc_tools', '_proto')
for proto_file in proto_files: for proto_file in proto_files:
command = [ command = [

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save