Merge pull request #9276 from soltanmm-google/remember-the-blue-flowers-they-are-important!

Enable yapf (Python formatting).
pull/9310/head
Nathaniel Manista 8 years ago committed by GitHub
commit fc4b07e10c
  1. 1
      .gitignore
  2. 4
      setup.cfg
  3. 6
      src/python/grpcio/_spawn_patch.py
  4. 54
      src/python/grpcio/commands.py
  5. 152
      src/python/grpcio/grpc/__init__.py
  6. 7
      src/python/grpcio/grpc/_auth.py
  7. 281
      src/python/grpcio/grpc/_channel.py
  8. 21
      src/python/grpcio/grpc/_common.py
  9. 4
      src/python/grpcio/grpc/_credential_composition.py
  10. 19
      src/python/grpcio/grpc/_plugin_wrapping.py
  11. 192
      src/python/grpcio/grpc/_server.py
  12. 19
      src/python/grpcio/grpc/_utilities.py
  13. 424
      src/python/grpcio/grpc/beta/_client_adaptations.py
  14. 27
      src/python/grpcio/grpc/beta/_connectivity_channel.py
  15. 116
      src/python/grpcio/grpc/beta/_server_adaptations.py
  16. 65
      src/python/grpcio/grpc/beta/implementations.py
  17. 2
      src/python/grpcio/grpc/beta/interfaces.py
  18. 2
      src/python/grpcio/grpc/beta/utilities.py
  19. 2
      src/python/grpcio/grpc/framework/__init__.py
  20. 2
      src/python/grpcio/grpc/framework/common/__init__.py
  21. 1
      src/python/grpcio/grpc/framework/common/cardinality.py
  22. 1
      src/python/grpcio/grpc/framework/common/style.py
  23. 2
      src/python/grpcio/grpc/framework/foundation/__init__.py
  24. 1
      src/python/grpcio/grpc/framework/foundation/abandonment.py
  25. 11
      src/python/grpcio/grpc/framework/foundation/callable_util.py
  26. 1
      src/python/grpcio/grpc/framework/foundation/future.py
  27. 11
      src/python/grpcio/grpc/framework/foundation/logging_pool.py
  28. 2
      src/python/grpcio/grpc/framework/foundation/stream.py
  29. 1
      src/python/grpcio/grpc/framework/foundation/stream_util.py
  30. 2
      src/python/grpcio/grpc/framework/interfaces/__init__.py
  31. 2
      src/python/grpcio/grpc/framework/interfaces/base/__init__.py
  32. 20
      src/python/grpcio/grpc/framework/interfaces/base/base.py
  33. 30
      src/python/grpcio/grpc/framework/interfaces/base/utilities.py
  34. 2
      src/python/grpcio/grpc/framework/interfaces/face/__init__.py
  35. 182
      src/python/grpcio/grpc/framework/interfaces/face/face.py
  36. 69
      src/python/grpcio/grpc/framework/interfaces/face/utilities.py
  37. 21
      src/python/grpcio/support.py
  38. 2
      src/python/grpcio_health_checking/grpc_health/__init__.py
  39. 2
      src/python/grpcio_health_checking/grpc_health/v1/__init__.py
  40. 1
      src/python/grpcio_health_checking/grpc_health/v1/health.py
  41. 1
      src/python/grpcio_health_checking/health_commands.py
  42. 14
      src/python/grpcio_health_checking/setup.py
  43. 1
      src/python/grpcio_reflection/grpc_reflection/__init__.py
  44. 1
      src/python/grpcio_reflection/grpc_reflection/v1alpha/__init__.py
  45. 38
      src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py
  46. 7
      src/python/grpcio_reflection/reflection_commands.py
  47. 14
      src/python/grpcio_reflection/setup.py
  48. 16
      src/python/grpcio_tests/commands.py
  49. 20
      src/python/grpcio_tests/setup.py
  50. 3
      src/python/grpcio_tests/tests/_loader.py
  51. 134
      src/python/grpcio_tests/tests/_result.py
  52. 31
      src/python/grpcio_tests/tests/_runner.py
  53. 7
      src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
  54. 65
      src/python/grpcio_tests/tests/http2/_negative_http2_client.py
  55. 2
      src/python/grpcio_tests/tests/interop/__init__.py
  56. 8
      src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py
  57. 10
      src/python/grpcio_tests/tests/interop/_intraop_test_case.py
  58. 21
      src/python/grpcio_tests/tests/interop/_secure_intraop_test.py
  59. 43
      src/python/grpcio_tests/tests/interop/client.py
  60. 159
      src/python/grpcio_tests/tests/interop/methods.py
  61. 9
      src/python/grpcio_tests/tests/interop/resources.py
  62. 15
      src/python/grpcio_tests/tests/interop/server.py
  63. 2
      src/python/grpcio_tests/tests/protoc_plugin/__init__.py
  64. 82
      src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
  65. 54
      src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
  66. 58
      src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
  67. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/__init__.py
  68. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/__init__.py
  69. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_messages/__init__.py
  70. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_services/__init__.py
  71. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/payload/__init__.py
  72. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/requests/__init__.py
  73. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/requests/r/__init__.py
  74. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/responses/__init__.py
  75. 2
      src/python/grpcio_tests/tests/protoc_plugin/protos/service/__init__.py
  76. 16
      src/python/grpcio_tests/tests/qps/benchmark_client.py
  77. 1
      src/python/grpcio_tests/tests/qps/client_runner.py
  78. 4
      src/python/grpcio_tests/tests/qps/qps_worker.py
  79. 27
      src/python/grpcio_tests/tests/qps/worker_server.py
  80. 80
      src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
  81. 41
      src/python/grpcio_tests/tests/stress/client.py
  82. 1
      src/python/grpcio_tests/tests/stress/metrics_server.py
  83. 1
      src/python/grpcio_tests/tests/stress/test_runner.py
  84. 2
      src/python/grpcio_tests/tests/unit/__init__.py
  85. 11
      src/python/grpcio_tests/tests/unit/_api_test.py
  86. 1
      src/python/grpcio_tests/tests/unit/_auth_test.py
  87. 6
      src/python/grpcio_tests/tests/unit/_channel_args_test.py
  88. 49
      src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
  89. 4
      src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
  90. 21
      src/python/grpcio_tests/tests/unit/_compression_test.py
  91. 16
      src/python/grpcio_tests/tests/unit/_credentials_test.py
  92. 50
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  93. 13
      src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
  94. 68
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  95. 108
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
  96. 7
      src/python/grpcio_tests/tests/unit/_cython/test_utilities.py
  97. 17
      src/python/grpcio_tests/tests/unit/_empty_message_test.py
  98. 1
      src/python/grpcio_tests/tests/unit/_exit_scenarios.py
  99. 61
      src/python/grpcio_tests/tests/unit/_exit_test.py
  100. 51
      src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
  101. Some files were not shown because too many files have changed in this diff Show More

1
.gitignore vendored

@ -8,6 +8,7 @@ objs
# Python items
cython_debug/
python_build/
python_format_venv/
.coverage*
.eggs
htmlcov/

@ -11,3 +11,7 @@ inplace=1
[build_package_protos]
exclude=.*protoc_plugin/protoc_plugin_test\.proto$
# Style settings
[yapf]
based_on_style = google

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Patches the spawn() command for windows compilers.
Windows has an 8191 character command line limit, but some compilers
@ -45,6 +44,7 @@ MAX_COMMAND_LENGTH = 8191
_classic_spawn = ccompiler.CCompiler.spawn
def _commandfile_spawn(self, command):
command_length = sum([len(arg) for arg in command])
if os.name == 'nt' and command_length > MAX_COMMAND_LENGTH:
@ -56,7 +56,9 @@ def _commandfile_spawn(self, command):
command_filename = os.path.abspath(
os.path.join(temporary_directory, 'command'))
with open(command_filename, 'w') as command_file:
escaped_args = ['"' + arg.replace('\\', '\\\\') + '"' for arg in command[1:]]
escaped_args = [
'"' + arg.replace('\\', '\\\\') + '"' for arg in command[1:]
]
command_file.write(' '.join(escaped_args))
modified_command = command[:1] + ['@{}'.format(command_filename)]
try:

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides distutils command classes for the GRPC Python setup process."""
import distutils
@ -112,17 +111,15 @@ def _get_grpc_custom_bdist(decorated_basename, target_bdist_basename):
url = BINARIES_REPOSITORY + '/{target}'.format(target=decorated_path)
bdist_data = request.urlopen(url).read()
except IOError as error:
raise CommandError(
'{}\n\nCould not find the bdist {}: {}'
.format(traceback.format_exc(), decorated_path, error.message))
raise CommandError('{}\n\nCould not find the bdist {}: {}'.format(
traceback.format_exc(), decorated_path, error.message))
# Our chosen local bdist path.
bdist_path = target_bdist_basename + GRPC_CUSTOM_BDIST_EXT
try:
with open(bdist_path, 'w') as bdist_file:
bdist_file.write(bdist_data)
except IOError as error:
raise CommandError(
'{}\n\nCould not write grpcio bdist: {}'
raise CommandError('{}\n\nCould not write grpcio bdist: {}'
.format(traceback.format_exc(), error.message))
return bdist_path
@ -149,15 +146,17 @@ class SphinxDocumentation(setuptools.Command):
sys.path.append(src_dir)
sphinx.apidoc.main([
'', '--force', '--full', '-H', metadata.name, '-A', metadata.author,
'-V', metadata.version, '-R', metadata.version,
'-o', os.path.join('doc', 'src'), src_dir])
'-V', metadata.version, '-R', metadata.version, '-o',
os.path.join('doc', 'src'), src_dir
])
conf_filepath = os.path.join('doc', 'src', 'conf.py')
with open(conf_filepath, 'a') as conf_file:
conf_file.write(CONF_PY_ADDENDUM)
glossary_filepath = os.path.join('doc', 'src', 'grpc.rst')
with open(glossary_filepath, 'a') as glossary_filepath:
glossary_filepath.write(API_GLOSSARY)
sphinx.main(['', os.path.join('doc', 'src'), os.path.join('doc', 'build')])
sphinx.main(
['', os.path.join('doc', 'src'), os.path.join('doc', 'build')])
class BuildProjectMetadata(setuptools.Command):
@ -173,7 +172,8 @@ class BuildProjectMetadata(setuptools.Command):
pass
def run(self):
with open(os.path.join(PYTHON_STEM, 'grpc/_grpcio_metadata.py'), 'w') as module_file:
with open(os.path.join(PYTHON_STEM, 'grpc/_grpcio_metadata.py'),
'w') as module_file:
module_file.write('__version__ = """{}"""'.format(
self.distribution.get_version()))
@ -194,6 +194,7 @@ def _poison_extensions(extensions, message):
for extension in extensions:
extension.sources = [poison_filename]
def check_and_update_cythonization(extensions):
"""Replace .pyx files with their generated counterparts and return whether or
not cythonization still needs to occur."""
@ -203,9 +204,12 @@ def check_and_update_cythonization(extensions):
for source in extension.sources:
base, file_ext = os.path.splitext(source)
if file_ext == '.pyx':
generated_pyx_source = next(
(base + gen_ext for gen_ext in ('.c', '.cpp',)
if os.path.isfile(base + gen_ext)), None)
generated_pyx_source = next((base + gen_ext
for gen_ext in (
'.c',
'.cpp',)
if os.path.isfile(base + gen_ext)),
None)
if generated_pyx_source:
generated_pyx_sources.append(generated_pyx_source)
else:
@ -217,6 +221,7 @@ def check_and_update_cythonization(extensions):
sys.stderr.write('Found cython-generated files...\n')
return True
def try_cythonize(extensions, linetracing=False, mandatory=True):
"""Attempt to cythonize the extensions.
@ -236,7 +241,8 @@ def try_cythonize(extensions, linetracing=False, mandatory=True):
"Poisoning extension sources to disallow extension commands...")
_poison_extensions(
extensions,
"Extensions have been poisoned due to missing Cython-generated code.")
"Extensions have been poisoned due to missing Cython-generated code."
)
return extensions
cython_compiler_directives = {}
if linetracing:
@ -245,10 +251,11 @@ def try_cythonize(extensions, linetracing=False, mandatory=True):
return Cython.Build.cythonize(
extensions,
include_path=[
include_dir for extension in extensions for include_dir in extension.include_dirs
include_dir
for extension in extensions
for include_dir in extension.include_dirs
] + [CYTHON_STEM],
compiler_directives=cython_compiler_directives
)
compiler_directives=cython_compiler_directives)
class BuildExt(build_ext.build_ext):
@ -264,10 +271,12 @@ class BuildExt(build_ext.build_ext):
compiler = self.compiler.compiler_type
if compiler in BuildExt.C_OPTIONS:
for extension in self.extensions:
extension.extra_compile_args += list(BuildExt.C_OPTIONS[compiler])
extension.extra_compile_args += list(BuildExt.C_OPTIONS[
compiler])
if compiler in BuildExt.LINK_OPTIONS:
for extension in self.extensions:
extension.extra_link_args += list(BuildExt.LINK_OPTIONS[compiler])
extension.extra_link_args += list(BuildExt.LINK_OPTIONS[
compiler])
if not check_and_update_cythonization(self.extensions):
self.extensions = try_cythonize(self.extensions)
try:
@ -275,8 +284,8 @@ class BuildExt(build_ext.build_ext):
except Exception as error:
formatted_exception = traceback.format_exc()
support.diagnose_build_ext_error(self, error, formatted_exception)
raise CommandError(
"Failed `build_ext` step:\n{}".format(formatted_exception))
raise CommandError("Failed `build_ext` step:\n{}".format(
formatted_exception))
class Gather(setuptools.Command):
@ -298,6 +307,7 @@ class Gather(setuptools.Command):
def run(self):
if self.install and self.distribution.install_requires:
self.distribution.fetch_build_eggs(self.distribution.install_requires)
self.distribution.fetch_build_eggs(
self.distribution.install_requires)
if self.test and self.distribution.tests_require:
self.distribution.fetch_build_eggs(self.distribution.tests_require)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""gRPC's Python API."""
import abc
@ -37,7 +36,6 @@ import six
from grpc._cython import cygrpc as _cygrpc
############################## Future Interface ###############################
@ -216,8 +214,8 @@ class ChannelConnectivity(enum.Enum):
IDLE = (_cygrpc.ConnectivityState.idle, 'idle')
CONNECTING = (_cygrpc.ConnectivityState.connecting, 'connecting')
READY = (_cygrpc.ConnectivityState.ready, 'ready')
TRANSIENT_FAILURE = (
_cygrpc.ConnectivityState.transient_failure, 'transient failure')
TRANSIENT_FAILURE = (_cygrpc.ConnectivityState.transient_failure,
'transient failure')
SHUTDOWN = (_cygrpc.ConnectivityState.shutdown, 'shutdown')
@ -227,18 +225,17 @@ class StatusCode(enum.Enum):
OK = (_cygrpc.StatusCode.ok, 'ok')
CANCELLED = (_cygrpc.StatusCode.cancelled, 'cancelled')
UNKNOWN = (_cygrpc.StatusCode.unknown, 'unknown')
INVALID_ARGUMENT = (
_cygrpc.StatusCode.invalid_argument, 'invalid argument')
DEADLINE_EXCEEDED = (
_cygrpc.StatusCode.deadline_exceeded, 'deadline exceeded')
INVALID_ARGUMENT = (_cygrpc.StatusCode.invalid_argument, 'invalid argument')
DEADLINE_EXCEEDED = (_cygrpc.StatusCode.deadline_exceeded,
'deadline exceeded')
NOT_FOUND = (_cygrpc.StatusCode.not_found, 'not found')
ALREADY_EXISTS = (_cygrpc.StatusCode.already_exists, 'already exists')
PERMISSION_DENIED = (
_cygrpc.StatusCode.permission_denied, 'permission denied')
RESOURCE_EXHAUSTED = (
_cygrpc.StatusCode.resource_exhausted, 'resource exhausted')
FAILED_PRECONDITION = (
_cygrpc.StatusCode.failed_precondition, 'failed precondition')
PERMISSION_DENIED = (_cygrpc.StatusCode.permission_denied,
'permission denied')
RESOURCE_EXHAUSTED = (_cygrpc.StatusCode.resource_exhausted,
'resource exhausted')
FAILED_PRECONDITION = (_cygrpc.StatusCode.failed_precondition,
'failed precondition')
ABORTED = (_cygrpc.StatusCode.aborted, 'aborted')
OUT_OF_RANGE = (_cygrpc.StatusCode.out_of_range, 'out of range')
UNIMPLEMENTED = (_cygrpc.StatusCode.unimplemented, 'unimplemented')
@ -523,8 +520,11 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
"""Affords invoking a stream-unary RPC in any call style."""
@abc.abstractmethod
def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None):
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
"""Synchronously invokes the underlying RPC.
Args:
@ -546,8 +546,11 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def with_call(
self, request_iterator, timeout=None, metadata=None, credentials=None):
def with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
"""Synchronously invokes the underlying RPC.
Args:
@ -568,8 +571,11 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def future(
self, request_iterator, timeout=None, metadata=None, credentials=None):
def future(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
"""Asynchronously invokes the underlying RPC.
Args:
@ -592,8 +598,11 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
"""Affords invoking a stream-stream RPC in any call style."""
@abc.abstractmethod
def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None):
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
"""Invokes the underlying RPC.
Args:
@ -644,8 +653,10 @@ class Channel(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def unary_unary(
self, method, request_serializer=None, response_deserializer=None):
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
Args:
@ -661,8 +672,10 @@ class Channel(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def unary_stream(
self, method, request_serializer=None, response_deserializer=None):
def unary_stream(self,
method,
request_serializer=None,
response_deserializer=None):
"""Creates a UnaryStreamMultiCallable for a unary-stream method.
Args:
@ -678,8 +691,10 @@ class Channel(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def stream_unary(
self, method, request_serializer=None, response_deserializer=None):
def stream_unary(self,
method,
request_serializer=None,
response_deserializer=None):
"""Creates a StreamUnaryMultiCallable for a stream-unary method.
Args:
@ -695,8 +710,10 @@ class Channel(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def stream_stream(
self, method, request_serializer=None, response_deserializer=None):
def stream_stream(self,
method,
request_serializer=None,
response_deserializer=None):
"""Creates a StreamStreamMultiCallable for a stream-stream method.
Args:
@ -973,8 +990,9 @@ class Server(six.with_metaclass(abc.ABCMeta)):
################################# Functions ################################
def unary_unary_rpc_method_handler(
behavior, request_deserializer=None, response_serializer=None):
def unary_unary_rpc_method_handler(behavior,
request_deserializer=None,
response_serializer=None):
"""Creates an RpcMethodHandler for a unary-unary RPC method.
Args:
@ -988,13 +1006,14 @@ def unary_unary_rpc_method_handler(
parameters.
"""
from grpc import _utilities
return _utilities.RpcMethodHandler(
False, False, request_deserializer, response_serializer, behavior, None,
return _utilities.RpcMethodHandler(False, False, request_deserializer,
response_serializer, behavior, None,
None, None)
def unary_stream_rpc_method_handler(
behavior, request_deserializer=None, response_serializer=None):
def unary_stream_rpc_method_handler(behavior,
request_deserializer=None,
response_serializer=None):
"""Creates an RpcMethodHandler for a unary-stream RPC method.
Args:
@ -1008,13 +1027,14 @@ def unary_stream_rpc_method_handler(
given parameters.
"""
from grpc import _utilities
return _utilities.RpcMethodHandler(
False, True, request_deserializer, response_serializer, None, behavior,
return _utilities.RpcMethodHandler(False, True, request_deserializer,
response_serializer, None, behavior,
None, None)
def stream_unary_rpc_method_handler(
behavior, request_deserializer=None, response_serializer=None):
def stream_unary_rpc_method_handler(behavior,
request_deserializer=None,
response_serializer=None):
"""Creates an RpcMethodHandler for a stream-unary RPC method.
Args:
@ -1028,13 +1048,14 @@ def stream_unary_rpc_method_handler(
given parameters.
"""
from grpc import _utilities
return _utilities.RpcMethodHandler(
True, False, request_deserializer, response_serializer, None, None,
return _utilities.RpcMethodHandler(True, False, request_deserializer,
response_serializer, None, None,
behavior, None)
def stream_stream_rpc_method_handler(
behavior, request_deserializer=None, response_serializer=None):
def stream_stream_rpc_method_handler(behavior,
request_deserializer=None,
response_serializer=None):
"""Creates an RpcMethodHandler for a stream-stream RPC method.
Args:
@ -1049,8 +1070,8 @@ def stream_stream_rpc_method_handler(
given parameters.
"""
from grpc import _utilities
return _utilities.RpcMethodHandler(
True, True, request_deserializer, response_serializer, None, None, None,
return _utilities.RpcMethodHandler(True, True, request_deserializer,
response_serializer, None, None, None,
behavior)
@ -1069,8 +1090,9 @@ def method_handlers_generic_handler(service, method_handlers):
return _utilities.DictionaryGenericHandler(service, method_handlers)
def ssl_channel_credentials(
root_certificates=None, private_key=None, certificate_chain=None):
def ssl_channel_credentials(root_certificates=None,
private_key=None,
certificate_chain=None):
"""Creates a ChannelCredentials for use with an SSL-enabled Channel.
Args:
@ -1112,8 +1134,8 @@ def metadata_call_credentials(metadata_plugin, name=None):
else:
effective_name = name
return CallCredentials(
_plugin_wrapping.call_credentials_metadata_plugin(
metadata_plugin, effective_name))
_plugin_wrapping.call_credentials_metadata_plugin(metadata_plugin,
effective_name))
def access_token_call_credentials(access_token):
@ -1164,12 +1186,12 @@ def composite_channel_credentials(channel_credentials, *call_credentials):
single_call_credentials._credentials
for single_call_credentials in call_credentials)
return ChannelCredentials(
_credential_composition.channel(
channel_credentials._credentials, cygrpc_call_credentials))
_credential_composition.channel(channel_credentials._credentials,
cygrpc_call_credentials))
def ssl_server_credentials(
private_key_certificate_chain_pairs, root_certificates=None,
def ssl_server_credentials(private_key_certificate_chain_pairs,
root_certificates=None,
require_client_auth=False):
"""Creates a ServerCredentials for use with an SSL-enabled Server.
@ -1192,14 +1214,14 @@ def ssl_server_credentials(
'At least one private key-certificate chain pair is required!')
elif require_client_auth and root_certificates is None:
raise ValueError(
'Illegal to require client auth without providing root certificates!')
'Illegal to require client auth without providing root certificates!'
)
else:
return ServerCredentials(
_cygrpc.server_credentials_ssl(
root_certificates,
[_cygrpc.SslPemKeyCertPair(key, pem)
for key, pem in private_key_certificate_chain_pairs],
require_client_auth))
_cygrpc.server_credentials_ssl(root_certificates, [
_cygrpc.SslPemKeyCertPair(key, pem)
for key, pem in private_key_certificate_chain_pairs
], require_client_auth))
def channel_ready_future(channel):
@ -1270,13 +1292,12 @@ def server(thread_pool, handlers=None, options=None):
A Server with which RPCs can be serviced.
"""
from grpc import _server
return _server.Server(thread_pool, () if handlers is None else handlers,
() if options is None else options)
return _server.Server(thread_pool, () if handlers is None else handlers, ()
if options is None else options)
################################### __all__ #################################
__all__ = (
'FutureTimeoutError',
'FutureCancelledError',
@ -1317,13 +1338,10 @@ __all__ = (
'channel_ready_future',
'insecure_channel',
'secure_channel',
'server',
)
'server',)
############################### Extension Shims ################################
# Here to maintain backwards compatibility; avoid using these in new code!
try:
import grpc_tools

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""GRPCAuthMetadataPlugins for standard authentication."""
import inspect
@ -58,11 +57,13 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
def __call__(self, context, callback):
# MetadataPlugins cannot block (see grpc.beta.interfaces.py)
if self._is_jwt:
future = self._pool.submit(self._credentials.get_access_token,
future = self._pool.submit(
self._credentials.get_access_token,
additional_claims={'aud': context.service_url})
else:
future = self._pool.submit(self._credentials.get_access_token)
future.add_done_callback(lambda x: self._get_token_callback(callback, x))
future.add_done_callback(
lambda x: self._get_token_callback(callback, x))
def _get_token_callback(self, callback, future):
try:

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Invocation-side implementation of gRPC Python."""
import sys
@ -52,26 +51,22 @@ _UNARY_UNARY_INITIAL_DUE = (
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client,
)
cygrpc.OperationType.receive_status_on_client,)
_UNARY_STREAM_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_status_on_client,
)
cygrpc.OperationType.receive_status_on_client,)
_STREAM_UNARY_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.receive_status_on_client,
)
cygrpc.OperationType.receive_status_on_client,)
_STREAM_STREAM_INITIAL_DUE = (
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_initial_metadata,
cygrpc.OperationType.receive_status_on_client,
)
cygrpc.OperationType.receive_status_on_client,)
_CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
'Exception calling channel subscription callback!')
@ -100,23 +95,28 @@ def _wait_once_until(condition, until):
else:
condition.wait(timeout=remaining)
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT = (
'Internal gRPC call error %d. ' +
'Please report to https://github.com/grpc/grpc/issues')
def _check_call_error(call_error, metadata):
if call_error == cygrpc.CallError.invalid_metadata:
raise ValueError('metadata was invalid: %s' % metadata)
elif call_error != cygrpc.CallError.ok:
raise ValueError(_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error)
def _call_error_set_RPCstate(state, call_error, metadata):
if call_error == cygrpc.CallError.invalid_metadata:
_abort(state, grpc.StatusCode.INTERNAL, 'metadata was invalid: %s' % metadata)
_abort(state, grpc.StatusCode.INTERNAL,
'metadata was invalid: %s' % metadata)
else:
_abort(state, grpc.StatusCode.INTERNAL,
_INTERNAL_CALL_ERROR_MESSAGE_FORMAT % call_error)
class _RPCState(object):
def __init__(self, due, initial_metadata, trailing_metadata, code, details):
@ -156,8 +156,8 @@ def _handle_event(event, state, response_deserializer):
elif operation_type == cygrpc.OperationType.receive_message:
serialized_response = batch_operation.received_message.bytes()
if serialized_response is not None:
response = _common.deserialize(
serialized_response, response_deserializer)
response = _common.deserialize(serialized_response,
response_deserializer)
if response is None:
details = 'Exception deserializing response!'
_abort(state, grpc.StatusCode.INTERNAL, details)
@ -182,6 +182,7 @@ def _handle_event(event, state, response_deserializer):
def _event_handler(state, call, response_deserializer):
def handle_event(event):
with state.condition:
callbacks = _handle_event(event, state, response_deserializer)
@ -190,11 +191,12 @@ def _event_handler(state, call, response_deserializer):
for callback in callbacks:
callback()
return call if done else None
return handle_event
def _consume_request_iterator(
request_iterator, state, call, request_serializer):
def _consume_request_iterator(request_iterator, state, call,
request_serializer):
event_handler = _event_handler(state, call, None)
def consume_request_iterator():
@ -206,7 +208,8 @@ def _consume_request_iterator(
except Exception as e:
logging.exception("Exception iterating requests!")
call.cancel()
_abort(state, grpc.StatusCode.UNKNOWN, "Exception iterating requests!")
_abort(state, grpc.StatusCode.UNKNOWN,
"Exception iterating requests!")
return
serialized_request = _common.serialize(request, request_serializer)
with state.condition:
@ -217,12 +220,10 @@ def _consume_request_iterator(
_abort(state, grpc.StatusCode.INTERNAL, details)
return
else:
operations = (
cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),
)
call.start_client_batch(cygrpc.Operations(operations),
event_handler)
operations = (cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),)
call.start_client_batch(
cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_message)
while True:
state.condition.wait()
@ -236,9 +237,9 @@ def _consume_request_iterator(
with state.condition:
if state.code is None:
operations = (
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
)
call.start_client_batch(cygrpc.Operations(operations), event_handler)
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),)
call.start_client_batch(
cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
def stop_consumption_thread(timeout):
@ -337,8 +338,8 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
def _next(self):
with self._state.condition:
if self._state.code is None:
event_handler = _event_handler(
self._state, self._call, self._response_deserializer)
event_handler = _event_handler(self._state, self._call,
self._response_deserializer)
self._call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
@ -438,8 +439,8 @@ def _start_unary_request(request, timeout, request_serializer):
deadline, deadline_timespec = _deadline(timeout)
serialized_request = _common.serialize(request, request_serializer)
if serialized_request is None:
state = _RPCState(
(), _EMPTY_METADATA, _EMPTY_METADATA, grpc.StatusCode.INTERNAL,
state = _RPCState((), _EMPTY_METADATA, _EMPTY_METADATA,
grpc.StatusCode.INTERNAL,
'Exception serializing request!')
rendezvous = _Rendezvous(state, None, None, deadline)
return deadline, deadline_timespec, None, rendezvous
@ -460,8 +461,7 @@ def _end_unary_response_blocking(state, with_call, deadline):
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def __init__(
self, channel, managed_call, method, request_serializer,
def __init__(self, channel, managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._managed_call = managed_call
@ -483,8 +483,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
return state, operations, deadline, deadline_timespec, None
def _blocking(self, request, timeout, metadata, credentials):
@ -494,21 +493,26 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
raise rendezvous
else:
completion_queue = cygrpc.CompletionQueue()
call = self._channel.create_call(
None, 0, completion_queue, self._method, None, deadline_timespec)
call = self._channel.create_call(None, 0, completion_queue,
self._method, None,
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
call_error = call.start_client_batch(cygrpc.Operations(operations), None)
call_error = call.start_client_batch(
cygrpc.Operations(operations), None)
_check_call_error(call_error, metadata)
_handle_event(completion_queue.poll(), state, self._response_deserializer)
_handle_event(completion_queue.poll(), state,
self._response_deserializer)
return state, deadline
def __call__(self, request, timeout=None, metadata=None, credentials=None):
state, deadline, = self._blocking(request, timeout, metadata, credentials)
state, deadline, = self._blocking(request, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, False, deadline)
def with_call(self, request, timeout=None, metadata=None, credentials=None):
state, deadline, = self._blocking(request, timeout, metadata, credentials)
state, deadline, = self._blocking(request, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, True, deadline)
def future(self, request, timeout=None, metadata=None, credentials=None):
@ -517,25 +521,26 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
if rendezvous:
return rendezvous
else:
call, drive_call = self._managed_call(
None, 0, self._method, None, deadline_timespec)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
event_handler = _event_handler(state, call,
self._response_deserializer)
with state.condition:
call_error = call.start_client_batch(cygrpc.Operations(operations),
event_handler)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
return _Rendezvous(state, call, self._response_deserializer, deadline)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
def __init__(
self, channel, managed_call, method, request_serializer,
def __init__(self, channel, managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._managed_call = managed_call
@ -550,36 +555,37 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
raise rendezvous
else:
state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
call, drive_call = self._managed_call(
None, 0, self._method, None, deadline_timespec)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
event_handler = _event_handler(state, call,
self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
cygrpc.Operations((
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
)), event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request,
_EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(cygrpc.Operations(operations),
event_handler)
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
return _Rendezvous(state, call, self._response_deserializer, deadline)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
def __init__(
self, channel, managed_call, method, request_serializer,
def __init__(self, channel, managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._managed_call = managed_call
@ -591,8 +597,8 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
completion_queue = cygrpc.CompletionQueue()
call = self._channel.create_call(
None, 0, completion_queue, self._method, None, deadline_timespec)
call = self._channel.create_call(None, 0, completion_queue,
self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
with state.condition:
@ -604,12 +610,12 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
cygrpc.operation_send_initial_metadata(
_common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(cygrpc.Operations(operations), None)
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), None)
_check_call_error(call_error, metadata)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
while True:
event = completion_queue.poll()
with state.condition:
@ -619,24 +625,33 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
break
return state, deadline
def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None):
state, deadline, = self._blocking(
request_iterator, timeout, metadata, credentials)
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
state, deadline, = self._blocking(request_iterator, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, False, deadline)
def with_call(
self, request_iterator, timeout=None, metadata=None, credentials=None):
state, deadline, = self._blocking(
request_iterator, timeout, metadata, credentials)
def with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
state, deadline, = self._blocking(request_iterator, timeout, metadata,
credentials)
return _end_unary_response_blocking(state, True, deadline)
def future(
self, request_iterator, timeout=None, metadata=None, credentials=None):
def future(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None)
call, drive_call = self._managed_call(
None, 0, self._method, None, deadline_timespec)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
@ -649,23 +664,21 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
cygrpc.operation_send_initial_metadata(
_common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(cygrpc.Operations(operations),
event_handler)
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
def __init__(
self, channel, managed_call, method, request_serializer,
def __init__(self, channel, managed_call, method, request_serializer,
response_deserializer):
self._channel = channel
self._managed_call = managed_call
@ -673,12 +686,15 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(
self, request_iterator, timeout=None, metadata=None, credentials=None):
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
deadline, deadline_timespec = _deadline(timeout)
state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None)
call, drive_call = self._managed_call(
None, 0, self._method, None, deadline_timespec)
call, drive_call = self._managed_call(None, 0, self._method, None,
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
@ -690,16 +706,15 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
operations = (
cygrpc.operation_send_initial_metadata(
_common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
call_error = call.start_client_batch(cygrpc.Operations(operations),
event_handler)
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
drive_call()
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@ -713,6 +728,7 @@ class _ChannelCallState(object):
def _run_channel_spin_thread(state):
def channel_spin():
while True:
event = state.completion_queue.poll()
@ -736,6 +752,7 @@ def _run_channel_spin_thread(state):
def _channel_managed_call_management(state):
def create(parent, flags, method, host, deadline):
"""Creates a managed cygrpc.Call and a function to call to drive it.
@ -754,8 +771,8 @@ def _channel_managed_call_management(state):
A cygrpc.Call with which to conduct an RPC and a function to call if
operations are successfully started on the call.
"""
call = state.channel.create_call(
parent, flags, state.completion_queue, method, host, deadline)
call = state.channel.create_call(parent, flags, state.completion_queue,
method, host, deadline)
def drive():
with state.lock:
@ -766,6 +783,7 @@ def _channel_managed_call_management(state):
state.managed_calls.add(call)
return call, drive
return create
@ -810,7 +828,10 @@ def _deliver(state, initial_connectivity, initial_callbacks):
def _spawn_delivery(state, callbacks):
delivering_thread = threading.Thread(
target=_deliver, args=(state, state.connectivity, callbacks,))
target=_deliver, args=(
state,
state.connectivity,
callbacks,))
delivering_thread.start()
state.delivering = True
@ -823,8 +844,8 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
state.connectivity = (
_common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
connectivity])
callbacks = tuple(
callback for callback, unused_but_known_to_be_none_connectivity
callbacks = tuple(callback
for callback, unused_but_known_to_be_none_connectivity
in state.callbacks_and_connectivities)
for callback_and_connectivity in state.callbacks_and_connectivities:
callback_and_connectivity[1] = state.connectivity
@ -832,8 +853,8 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
_spawn_delivery(state, callbacks)
completion_queue = cygrpc.CompletionQueue()
while True:
channel.watch_connectivity_state(
connectivity, cygrpc.Timespec(time.time() + 0.2),
channel.watch_connectivity_state(connectivity,
cygrpc.Timespec(time.time() + 0.2),
completion_queue, None)
event = completion_queue.poll()
with state.lock:
@ -863,10 +884,13 @@ def _moot(state):
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
def cancel_all_subscriptions(timeout):
_moot(state)
polling_thread = _common.CleanupThread(
cancel_all_subscriptions, target=_poll_connectivity,
cancel_all_subscriptions,
target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
polling_thread.start()
state.polling = True
@ -883,8 +907,8 @@ def _subscribe(state, callback, try_to_connect):
def _unsubscribe(state, callback):
with state.lock:
for index, (subscribed_callback, unused_connectivity) in enumerate(
state.callbacks_and_connectivities):
for index, (subscribed_callback, unused_connectivity
) in enumerate(state.callbacks_and_connectivities):
if callback == subscribed_callback:
state.callbacks_and_connectivities.pop(index)
break
@ -892,7 +916,8 @@ def _unsubscribe(state, callback):
def _options(options):
return list(options) + [
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)]
(cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT)
]
class Channel(grpc.Channel):
@ -907,8 +932,8 @@ class Channel(grpc.Channel):
credentials: A cygrpc.ChannelCredentials or None.
"""
self._channel = cygrpc.Channel(
_common.encode(target), _common.channel_args(_options(options)),
credentials)
_common.encode(target),
_common.channel_args(_options(options)), credentials)
self._call_state = _ChannelCallState(self._channel)
self._connectivity_state = _ChannelConnectivityState(self._channel)
@ -918,28 +943,40 @@ class Channel(grpc.Channel):
def unsubscribe(self, callback):
_unsubscribe(self._connectivity_state, callback)
def unary_unary(
self, method, request_serializer=None, response_deserializer=None):
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
return _UnaryUnaryMultiCallable(
self._channel, _channel_managed_call_management(self._call_state),
self._channel,
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer)
def unary_stream(
self, method, request_serializer=None, response_deserializer=None):
def unary_stream(self,
method,
request_serializer=None,
response_deserializer=None):
return _UnaryStreamMultiCallable(
self._channel, _channel_managed_call_management(self._call_state),
self._channel,
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer)
def stream_unary(
self, method, request_serializer=None, response_deserializer=None):
def stream_unary(self,
method,
request_serializer=None,
response_deserializer=None):
return _StreamUnaryMultiCallable(
self._channel, _channel_managed_call_management(self._call_state),
self._channel,
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer)
def stream_stream(
self, method, request_serializer=None, response_deserializer=None):
def stream_stream(self,
method,
request_serializer=None,
response_deserializer=None):
return _StreamStreamMultiCallable(
self._channel, _channel_managed_call_management(self._call_state),
self._channel,
_channel_managed_call_management(self._call_state),
_common.encode(method), request_serializer, response_deserializer)
def __del__(self):

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Shared implementation."""
import logging
@ -46,8 +45,7 @@ CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
cygrpc.ConnectivityState.ready: grpc.ChannelConnectivity.READY,
cygrpc.ConnectivityState.transient_failure:
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
cygrpc.ConnectivityState.shutdown:
grpc.ChannelConnectivity.SHUTDOWN,
cygrpc.ConnectivityState.shutdown: grpc.ChannelConnectivity.SHUTDOWN,
}
CYGRPC_STATUS_CODE_TO_STATUS_CODE = {
@ -114,8 +112,8 @@ def application_metadata(cygrpc_metadata):
if cygrpc_metadata is None:
return ()
else:
return tuple(
(decode(key), value if key[-4:] == b'-bin' else decode(value))
return tuple((decode(key), value
if key[-4:] == b'-bin' else decode(value))
for key, value in cygrpc_metadata)
@ -151,8 +149,13 @@ class CleanupThread(threading.Thread):
we accomplish this by overriding the join() method.
"""
def __init__(self, behavior, group=None, target=None, name=None,
args=(), kwargs={}):
def __init__(self,
behavior,
group=None,
target=None,
name=None,
args=(),
kwargs={}):
"""Constructor.
Args:
@ -169,8 +172,8 @@ class CleanupThread(threading.Thread):
kwargs (dict[str,object]): A dictionary of keyword arguments to
pass to `target`.
"""
super(CleanupThread, self).__init__(group=group, target=target,
name=name, args=args, kwargs=kwargs)
super(CleanupThread, self).__init__(
group=group, target=target, name=name, args=args, kwargs=kwargs)
self._behavior = behavior
def join(self, timeout=None):

@ -44,5 +44,5 @@ def call(call_credentialses):
def channel(channel_credentials, call_credentialses):
return cygrpc.channel_credentials_composite(
channel_credentials, _call(call_credentialses))
return cygrpc.channel_credentials_composite(channel_credentials,
_call(call_credentialses))

@ -36,9 +36,9 @@ from grpc._cython import cygrpc
class AuthMetadataContext(
collections.namedtuple(
'AuthMetadataContext', ('service_url', 'method_name',)),
grpc.AuthMetadataContext):
collections.namedtuple('AuthMetadataContext', (
'service_url',
'method_name',)), grpc.AuthMetadataContext):
pass
@ -62,8 +62,7 @@ class _WrappedCygrpcCallback(object):
def _invoke_failure(self, error):
# TODO(atash) translate different Exception superclasses into different
# status codes.
self.cygrpc_callback(
_common.EMPTY_METADATA, cygrpc.StatusCode.internal,
self.cygrpc_callback(_common.EMPTY_METADATA, cygrpc.StatusCode.internal,
_common.encode(str(error)))
def _invoke_success(self, metadata):
@ -101,10 +100,11 @@ class _WrappedPlugin(object):
def __call__(self, context, cygrpc_callback):
wrapped_cygrpc_callback = _WrappedCygrpcCallback(cygrpc_callback)
wrapped_context = AuthMetadataContext(
_common.decode(context.service_url), _common.decode(context.method_name))
_common.decode(context.service_url),
_common.decode(context.method_name))
try:
self.plugin(
wrapped_context, AuthMetadataPluginCallback(wrapped_cygrpc_callback))
self.plugin(wrapped_context,
AuthMetadataPluginCallback(wrapped_cygrpc_callback))
except Exception as error:
wrapped_cygrpc_callback.notify_failure(error)
raise
@ -120,4 +120,5 @@ def call_credentials_metadata_plugin(plugin, name):
plugin's invocation must be non-blocking.
"""
return cygrpc.call_credentials_metadata_plugin(
cygrpc.CredentialsMetadataPlugin(_WrappedPlugin(plugin), _common.encode(name)))
cygrpc.CredentialsMetadataPlugin(
_WrappedPlugin(plugin), _common.encode(name)))

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Service-side implementation of gRPC Python."""
import collections
@ -91,9 +90,9 @@ def _details(state):
class _HandlerCallDetails(
collections.namedtuple(
'_HandlerCallDetails', ('method', 'invocation_metadata',)),
grpc.HandlerCallDetails):
collections.namedtuple('_HandlerCallDetails', (
'method',
'invocation_metadata',)), grpc.HandlerCallDetails):
pass
@ -131,9 +130,11 @@ def _possibly_finish_call(state, token):
def _send_status_from_server(state, token):
def send_status_from_server(unused_send_status_from_server_event):
with state.condition:
return _possibly_finish_call(state, token)
return send_status_from_server
@ -143,19 +144,16 @@ def _abort(state, call, code, details):
effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_common.cygrpc_metadata(state.trailing_metadata), effective_code,
effective_details, _EMPTY_FLAGS),
)
_common.cygrpc_metadata(state.trailing_metadata),
effective_code, effective_details, _EMPTY_FLAGS),)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else:
operations = (
cygrpc.operation_send_status_from_server(
_common.cygrpc_metadata(state.trailing_metadata), effective_code,
effective_details, _EMPTY_FLAGS),
)
operations = (cygrpc.operation_send_status_from_server(
_common.cygrpc_metadata(state.trailing_metadata),
effective_code, effective_details, _EMPTY_FLAGS),)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_server_batch(
cygrpc.Operations(operations),
@ -165,18 +163,22 @@ def _abort(state, call, code, details):
def _receive_close_on_server(state):
def receive_close_on_server(receive_close_on_server_event):
with state.condition:
if receive_close_on_server_event.batch_operations[0].received_cancelled:
if receive_close_on_server_event.batch_operations[
0].received_cancelled:
state.client = _CANCELLED
elif state.client is _OPEN:
state.client = _CLOSED
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_CLOSE_ON_SERVER_TOKEN)
return receive_close_on_server
def _receive_message(state, call, request_deserializer):
def receive_message(receive_message_event):
serialized_request = _serialized_request(receive_message_event)
if serialized_request is None:
@ -186,31 +188,36 @@ def _receive_message(state, call, request_deserializer):
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
else:
request = _common.deserialize(serialized_request, request_deserializer)
request = _common.deserialize(serialized_request,
request_deserializer)
with state.condition:
if request is None:
_abort(
state, call, cygrpc.StatusCode.internal,
_abort(state, call, cygrpc.StatusCode.internal,
b'Exception deserializing request!')
else:
state.request = request
state.condition.notify_all()
return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)
return receive_message
def _send_initial_metadata(state):
def send_initial_metadata(unused_send_initial_metadata_event):
with state.condition:
return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN)
return send_initial_metadata
def _send_message(state, token):
def send_message(unused_send_message_event):
with state.condition:
state.condition.notify_all()
return _possibly_finish_call(state, token)
return send_message
@ -226,7 +233,8 @@ class _Context(grpc.ServicerContext):
return self._state.client is not _CANCELLED and not self._state.statused
def time_remaining(self):
return max(self._rpc_event.request_call_details.deadline - time.time(), 0)
return max(self._rpc_event.request_call_details.deadline - time.time(),
0)
def cancel(self):
self._rpc_event.operation_call.cancel()
@ -293,8 +301,10 @@ class _RequestIterator(object):
raise StopIteration()
else:
self._call.start_server_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(self._state, self._call, self._request_deserializer))
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(self._state, self._call,
self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
def _look_for_request(self):
@ -328,6 +338,7 @@ class _RequestIterator(object):
def _unary_request(rpc_event, state, request_deserializer):
def unary_request():
with state.condition:
if state.client is _CANCELLED or state.statused:
@ -336,8 +347,8 @@ def _unary_request(rpc_event, state, request_deserializer):
start_server_batch_result = rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(
state, rpc_event.operation_call, request_deserializer))
_receive_message(state, rpc_event.operation_call,
request_deserializer))
state.due.add(_RECEIVE_MESSAGE_TOKEN)
while True:
state.condition.wait()
@ -345,9 +356,9 @@ def _unary_request(rpc_event, state, request_deserializer):
if state.client is _CLOSED:
details = '"{}" requires exactly one request message.'.format(
rpc_event.request_call_details.method)
_abort(
state, rpc_event.operation_call,
cygrpc.StatusCode.unimplemented, _common.encode(details))
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unimplemented,
_common.encode(details))
return None
elif state.client is _CANCELLED:
return None
@ -355,6 +366,7 @@ def _unary_request(rpc_event, state, request_deserializer):
request = state.request
state.request = None
return request
return unary_request
@ -391,8 +403,7 @@ def _serialize_response(rpc_event, state, response, response_serializer):
serialized_response = _common.serialize(response, response_serializer)
if serialized_response is None:
with state.condition:
_abort(
state, rpc_event.operation_call, cygrpc.StatusCode.internal,
_abort(state, rpc_event.operation_call, cygrpc.StatusCode.internal,
b'Failed to serialize response!')
return None
else:
@ -406,16 +417,15 @@ def _send_response(rpc_event, state, serialized_response):
else:
if state.initial_metadata_allowed:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
)
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_response,
_EMPTY_FLAGS),)
state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else:
operations = (
cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
)
operations = (cygrpc.operation_send_message(serialized_response,
_EMPTY_FLAGS),)
token = _SEND_MESSAGE_TOKEN
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations), _send_message(state, token))
@ -438,11 +448,12 @@ def _status(rpc_event, state, serialized_response):
]
if state.initial_metadata_allowed:
operations.append(
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS))
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS))
if serialized_response is not None:
operations.append(cygrpc.operation_send_message(
serialized_response, _EMPTY_FLAGS))
operations.append(
cygrpc.operation_send_message(serialized_response,
_EMPTY_FLAGS))
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
@ -450,13 +461,12 @@ def _status(rpc_event, state, serialized_response):
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
def _unary_response_in_pool(
rpc_event, state, behavior, argument_thunk, request_deserializer,
response_serializer):
def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk,
request_deserializer, response_serializer):
argument = argument_thunk()
if argument is not None:
response, proceed = _call_behavior(
rpc_event, state, behavior, argument, request_deserializer)
response, proceed = _call_behavior(rpc_event, state, behavior, argument,
request_deserializer)
if proceed:
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
@ -464,9 +474,8 @@ def _unary_response_in_pool(
_status(rpc_event, state, serialized_response)
def _stream_response_in_pool(
rpc_event, state, behavior, argument_thunk, request_deserializer,
response_serializer):
def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk,
request_deserializer, response_serializer):
argument = argument_thunk()
if argument is not None:
response_iterator, proceed = _call_behavior(
@ -483,7 +492,8 @@ def _stream_response_in_pool(
serialized_response = _serialize_response(
rpc_event, state, response, response_serializer)
if serialized_response is not None:
proceed = _send_response(rpc_event, state, serialized_response)
proceed = _send_response(rpc_event, state,
serialized_response)
if not proceed:
break
else:
@ -493,38 +503,38 @@ def _stream_response_in_pool(
def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(
rpc_event, state, method_handler.request_deserializer)
thread_pool.submit(
_unary_response_in_pool, rpc_event, state, method_handler.unary_unary,
unary_request, method_handler.request_deserializer,
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.unary_unary, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
unary_request = _unary_request(
rpc_event, state, method_handler.request_deserializer)
thread_pool.submit(
_stream_response_in_pool, rpc_event, state, method_handler.unary_stream,
unary_request, method_handler.request_deserializer,
unary_request = _unary_request(rpc_event, state,
method_handler.request_deserializer)
thread_pool.submit(_stream_response_in_pool, rpc_event, state,
method_handler.unary_stream, unary_request,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(
state, rpc_event.operation_call, method_handler.request_deserializer)
thread_pool.submit(
_unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
lambda: request_iterator, method_handler.request_deserializer,
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
thread_pool.submit(_unary_response_in_pool, rpc_event, state,
method_handler.stream_unary, lambda: request_iterator,
method_handler.request_deserializer,
method_handler.response_serializer)
def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
request_iterator = _RequestIterator(
state, rpc_event.operation_call, method_handler.request_deserializer)
thread_pool.submit(
_stream_response_in_pool, rpc_event, state, method_handler.stream_stream,
lambda: request_iterator, method_handler.request_deserializer,
request_iterator = _RequestIterator(state, rpc_event.operation_call,
method_handler.request_deserializer)
thread_pool.submit(_stream_response_in_pool, rpc_event, state,
method_handler.stream_stream, lambda: request_iterator,
method_handler.request_deserializer,
method_handler.response_serializer)
@ -546,11 +556,12 @@ def _handle_unrecognized_method(rpc_event):
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
b'Method not found!', _EMPTY_FLAGS),
)
b'Method not found!', _EMPTY_FLAGS),)
rpc_state = _RPCState()
rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
rpc_event.operation_call.start_server_batch(operations,
lambda ignored_event: (
rpc_state,
(),))
return rpc_state
@ -564,14 +575,18 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
if method_handler.response_streaming:
_handle_stream_stream(rpc_event, state, method_handler, thread_pool)
_handle_stream_stream(rpc_event, state, method_handler,
thread_pool)
else:
_handle_stream_unary(rpc_event, state, method_handler, thread_pool)
_handle_stream_unary(rpc_event, state, method_handler,
thread_pool)
else:
if method_handler.response_streaming:
_handle_unary_stream(rpc_event, state, method_handler, thread_pool)
_handle_unary_stream(rpc_event, state, method_handler,
thread_pool)
else:
_handle_unary_unary(rpc_event, state, method_handler, thread_pool)
_handle_unary_unary(rpc_event, state, method_handler,
thread_pool)
return state
@ -581,7 +596,8 @@ def _handle_call(rpc_event, generic_handlers, thread_pool):
if method_handler is None:
return _handle_unrecognized_method(rpc_event)
else:
return _handle_with_method_handler(rpc_event, method_handler, thread_pool)
return _handle_with_method_handler(rpc_event, method_handler,
thread_pool)
else:
return None
@ -621,12 +637,13 @@ def _add_insecure_port(state, address):
def _add_secure_port(state, address, server_credentials):
with state.lock:
return state.server.add_http2_port(address, server_credentials._credentials)
return state.server.add_http2_port(address,
server_credentials._credentials)
def _request_call(state):
state.server.request_call(
state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG)
state.server.request_call(state.completion_queue, state.completion_queue,
_REQUEST_CALL_TAG)
state.due.add(_REQUEST_CALL_TAG)
@ -652,8 +669,8 @@ def _serve(state):
elif event.tag is _REQUEST_CALL_TAG:
with state.lock:
state.due.remove(_REQUEST_CALL_TAG)
rpc_state = _handle_call(
event, state.generic_handlers, state.thread_pool)
rpc_state = _handle_call(event, state.generic_handlers,
state.thread_pool)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
if state.stage is _ServerStage.STARTED:
@ -694,6 +711,7 @@ def _stop(state, grace):
rpc_state.client = _CANCELLED
rpc_state.condition.notify_all()
else:
def cancel_all_calls_after_grace():
shutdown_event.wait(timeout=grace)
with state.lock:
@ -703,6 +721,7 @@ def _stop(state, grace):
with rpc_state.condition:
rpc_state.client = _CANCELLED
rpc_state.condition.notify_all()
thread = threading.Thread(target=cancel_all_calls_after_grace)
thread.start()
return shutdown_event
@ -717,6 +736,7 @@ def _start(state):
state.server.start()
state.stage = _ServerStage.STARTED
_request_call(state)
def cleanup_server(timeout):
if timeout is None:
_stop(state, _UNEXPECTED_EXIT_SERVER_GRACE).wait()
@ -727,14 +747,15 @@ def _start(state):
cleanup_server, target=_serve, args=(state,))
thread.start()
class Server(grpc.Server):
def __init__(self, thread_pool, generic_handlers, options):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue)
self._state = _ServerState(
completion_queue, server, generic_handlers, thread_pool)
self._state = _ServerState(completion_queue, server, generic_handlers,
thread_pool)
def add_generic_rpc_handlers(self, generic_rpc_handlers):
_add_generic_handlers(self._state, generic_rpc_handlers)
@ -743,7 +764,8 @@ class Server(grpc.Server):
return _add_insecure_port(self._state, _common.encode(address))
def add_secure_port(self, address, server_credentials):
return _add_secure_port(self._state, _common.encode(address), server_credentials)
return _add_secure_port(self._state,
_common.encode(address), server_credentials)
def start(self):
_start(self._state)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Internal utilities for gRPC Python."""
import collections
@ -44,12 +43,15 @@ _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = (
class RpcMethodHandler(
collections.namedtuple(
'_RpcMethodHandler',
('request_streaming', 'response_streaming', 'request_deserializer',
'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary',
'stream_stream',)),
grpc.RpcMethodHandler):
collections.namedtuple('_RpcMethodHandler', (
'request_streaming',
'response_streaming',
'request_deserializer',
'response_serializer',
'unary_unary',
'unary_stream',
'stream_unary',
'stream_stream',)), grpc.RpcMethodHandler):
pass
@ -59,7 +61,8 @@ class DictionaryGenericHandler(grpc.ServiceRpcHandler):
self._name = service
self._method_handlers = {
_common.fully_qualified_method(service, method): method_handler
for method, method_handler in six.iteritems(method_handlers)}
for method, method_handler in six.iteritems(method_handlers)
}
def service_name(self):
return self._name

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
import grpc
@ -38,14 +37,14 @@ from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
grpc.StatusCode.CANCELLED: (
face.Abortion.Kind.CANCELLED, face.CancellationError),
grpc.StatusCode.UNKNOWN: (
face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError),
grpc.StatusCode.DEADLINE_EXCEEDED: (
face.Abortion.Kind.EXPIRED, face.ExpirationError),
grpc.StatusCode.UNIMPLEMENTED: (
face.Abortion.Kind.LOCAL_FAILURE, face.LocalError),
grpc.StatusCode.CANCELLED: (face.Abortion.Kind.CANCELLED,
face.CancellationError),
grpc.StatusCode.UNKNOWN: (face.Abortion.Kind.REMOTE_FAILURE,
face.RemoteError),
grpc.StatusCode.DEADLINE_EXCEEDED: (face.Abortion.Kind.EXPIRED,
face.ExpirationError),
grpc.StatusCode.UNIMPLEMENTED: (face.Abortion.Kind.LOCAL_FAILURE,
face.LocalError),
}
@ -65,18 +64,19 @@ def _abortion(rpc_error_call):
code = rpc_error_call.code()
pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
return face.Abortion(
error_kind, rpc_error_call.initial_metadata(),
rpc_error_call.trailing_metadata(), code, rpc_error_call.details())
return face.Abortion(error_kind,
rpc_error_call.initial_metadata(),
rpc_error_call.trailing_metadata(), code,
rpc_error_call.details())
def _abortion_error(rpc_error_call):
code = rpc_error_call.code()
pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
exception_class = face.AbortionError if pair is None else pair[1]
return exception_class(
rpc_error_call.initial_metadata(), rpc_error_call.trailing_metadata(),
code, rpc_error_call.details())
return exception_class(rpc_error_call.initial_metadata(),
rpc_error_call.trailing_metadata(), code,
rpc_error_call.details())
class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
@ -159,9 +159,11 @@ class _Rendezvous(future.Future, face.Call):
return self._call.time_remaining()
def add_abortion_callback(self, abortion_callback):
def done_callback():
if self.code() is not grpc.StatusCode.OK:
abortion_callback(_abortion(self._call))
registered = self._call.add_callback(done_callback)
return None if registered else done_callback()
@ -181,9 +183,9 @@ class _Rendezvous(future.Future, face.Call):
return self._call.details()
def _blocking_unary_unary(
channel, group, method, timeout, with_call, protocol_options, metadata,
metadata_transformer, request, request_serializer, response_deserializer):
def _blocking_unary_unary(channel, group, method, timeout, with_call,
protocol_options, metadata, metadata_transformer,
request, request_serializer, response_deserializer):
try:
multi_callable = channel.unary_unary(
_common.fully_qualified_method(group, method),
@ -192,48 +194,56 @@ def _blocking_unary_unary(
effective_metadata = _effective_metadata(metadata, metadata_transformer)
if with_call:
response, call = multi_callable.with_call(
request, timeout=timeout, metadata=effective_metadata,
request,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
return response, _Rendezvous(None, None, call)
else:
return multi_callable(
request, timeout=timeout, metadata=effective_metadata,
request,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
def _future_unary_unary(
channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request, request_serializer, response_deserializer):
def _future_unary_unary(channel, group, method, timeout, protocol_options,
metadata, metadata_transformer, request,
request_serializer, response_deserializer):
multi_callable = channel.unary_unary(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_future = multi_callable.future(
request, timeout=timeout, metadata=effective_metadata,
request,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
return _Rendezvous(response_future, None, response_future)
def _unary_stream(
channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request, request_serializer, response_deserializer):
def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request, request_serializer,
response_deserializer):
multi_callable = channel.unary_stream(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_iterator = multi_callable(
request, timeout=timeout, metadata=effective_metadata,
request,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
return _Rendezvous(None, response_iterator, response_iterator)
def _blocking_stream_unary(
channel, group, method, timeout, with_call, protocol_options, metadata,
metadata_transformer, request_iterator, request_serializer,
def _blocking_stream_unary(channel, group, method, timeout, with_call,
protocol_options, metadata, metadata_transformer,
request_iterator, request_serializer,
response_deserializer):
try:
multi_callable = channel.stream_unary(
@ -243,34 +253,38 @@ def _blocking_stream_unary(
effective_metadata = _effective_metadata(metadata, metadata_transformer)
if with_call:
response, call = multi_callable.with_call(
request_iterator, timeout=timeout, metadata=effective_metadata,
request_iterator,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
return response, _Rendezvous(None, None, call)
else:
return multi_callable(
request_iterator, timeout=timeout, metadata=effective_metadata,
request_iterator,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
def _future_stream_unary(
channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request_iterator, request_serializer,
response_deserializer):
def _future_stream_unary(channel, group, method, timeout, protocol_options,
metadata, metadata_transformer, request_iterator,
request_serializer, response_deserializer):
multi_callable = channel.stream_unary(
_common.fully_qualified_method(group, method),
request_serializer=request_serializer,
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_future = multi_callable.future(
request_iterator, timeout=timeout, metadata=effective_metadata,
request_iterator,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
return _Rendezvous(response_future, None, response_future)
def _stream_stream(
channel, group, method, timeout, protocol_options, metadata,
def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
metadata_transformer, request_iterator, request_serializer,
response_deserializer):
multi_callable = channel.stream_stream(
@ -279,16 +293,17 @@ def _stream_stream(
response_deserializer=response_deserializer)
effective_metadata = _effective_metadata(metadata, metadata_transformer)
response_iterator = multi_callable(
request_iterator, timeout=timeout, metadata=effective_metadata,
request_iterator,
timeout=timeout,
metadata=effective_metadata,
credentials=_credentials(protocol_options))
return _Rendezvous(None, response_iterator, response_iterator)
class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
def __init__(
self, channel, group, method, metadata_transformer, request_serializer,
response_deserializer):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
@ -296,8 +311,11 @@ class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(
self, request, timeout, metadata=None, with_call=False,
def __call__(self,
request,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
return _blocking_unary_unary(
self._channel, self._group, self._method, timeout, with_call,
@ -307,20 +325,23 @@ class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
def future(self, request, timeout, metadata=None, protocol_options=None):
return _future_unary_unary(
self._channel, self._group, self._method, timeout, protocol_options,
metadata, self._metadata_transformer, request, self._request_serializer,
self._response_deserializer)
metadata, self._metadata_transformer, request,
self._request_serializer, self._response_deserializer)
def event(
self, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event(self,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
def __init__(
self, channel, group, method, metadata_transformer, request_serializer,
response_deserializer):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
@ -331,20 +352,23 @@ class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
def __call__(self, request, timeout, metadata=None, protocol_options=None):
return _unary_stream(
self._channel, self._group, self._method, timeout, protocol_options,
metadata, self._metadata_transformer, request, self._request_serializer,
self._response_deserializer)
metadata, self._metadata_transformer, request,
self._request_serializer, self._response_deserializer)
def event(
self, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event(self,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
def __init__(
self, channel, group, method, metadata_transformer, request_serializer,
response_deserializer):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
@ -352,32 +376,41 @@ class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(
self, request_iterator, timeout, metadata=None, with_call=False,
def __call__(self,
request_iterator,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
return _blocking_stream_unary(
self._channel, self._group, self._method, timeout, with_call,
protocol_options, metadata, self._metadata_transformer,
request_iterator, self._request_serializer, self._response_deserializer)
request_iterator, self._request_serializer,
self._response_deserializer)
def future(
self, request_iterator, timeout, metadata=None, protocol_options=None):
def future(self,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
return _future_stream_unary(
self._channel, self._group, self._method, timeout, protocol_options,
metadata, self._metadata_transformer, request_iterator,
self._request_serializer, self._response_deserializer)
def event(
self, receiver, abortion_callback, timeout, metadata=None,
def event(self,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
def __init__(
self, channel, group, method, metadata_transformer, request_serializer,
response_deserializer):
def __init__(self, channel, group, method, metadata_transformer,
request_serializer, response_deserializer):
self._channel = channel
self._group = group
self._method = method
@ -385,133 +418,226 @@ class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __call__(
self, request_iterator, timeout, metadata=None, protocol_options=None):
def __call__(self,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
return _stream_stream(
self._channel, self._group, self._method, timeout, protocol_options,
metadata, self._metadata_transformer, request_iterator,
self._request_serializer, self._response_deserializer)
def event(
self, receiver, abortion_callback, timeout, metadata=None,
def event(self,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
class _GenericStub(face.GenericStub):
def __init__(
self, channel, metadata_transformer, request_serializers,
def __init__(self, channel, metadata_transformer, request_serializers,
response_deserializers):
self._channel = channel
self._metadata_transformer = metadata_transformer
self._request_serializers = request_serializers or {}
self._response_deserializers = response_deserializers or {}
def blocking_unary_unary(
self, group, method, request, timeout, metadata=None,
with_call=None, protocol_options=None):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
return _blocking_unary_unary(
self._channel, group, method, timeout, with_call, protocol_options,
metadata, self._metadata_transformer, request, request_serializer,
response_deserializer)
def blocking_unary_unary(self,
group,
method,
request,
timeout,
metadata=None,
with_call=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _blocking_unary_unary(self._channel, group, method, timeout,
with_call, protocol_options, metadata,
self._metadata_transformer, request,
request_serializer, response_deserializer)
def future_unary_unary(
self, group, method, request, timeout, metadata=None,
def future_unary_unary(self,
group,
method,
request,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
return _future_unary_unary(
self._channel, group, method, timeout, protocol_options, metadata,
self._metadata_transformer, request, request_serializer,
response_deserializer)
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _future_unary_unary(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer, request,
request_serializer, response_deserializer)
def inline_unary_stream(
self, group, method, request, timeout, metadata=None,
def inline_unary_stream(self,
group,
method,
request,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
return _unary_stream(
self._channel, group, method, timeout, protocol_options, metadata,
self._metadata_transformer, request, request_serializer,
response_deserializer)
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _unary_stream(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer, request,
request_serializer, response_deserializer)
def blocking_stream_unary(
self, group, method, request_iterator, timeout, metadata=None,
with_call=None, protocol_options=None):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
def blocking_stream_unary(self,
group,
method,
request_iterator,
timeout,
metadata=None,
with_call=None,
protocol_options=None):
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _blocking_stream_unary(
self._channel, group, method, timeout, with_call, protocol_options,
metadata, self._metadata_transformer, request_iterator,
request_serializer, response_deserializer)
def future_stream_unary(
self, group, method, request_iterator, timeout, metadata=None,
def future_stream_unary(self,
group,
method,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _future_stream_unary(
self._channel, group, method, timeout, protocol_options, metadata,
self._metadata_transformer, request_iterator, request_serializer,
response_deserializer)
def inline_stream_stream(
self, group, method, request_iterator, timeout, metadata=None,
def inline_stream_stream(self,
group,
method,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
return _stream_stream(
self._channel, group, method, timeout, protocol_options, metadata,
self._metadata_transformer, request_iterator, request_serializer,
response_deserializer)
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _stream_stream(self._channel, group, method, timeout,
protocol_options, metadata,
self._metadata_transformer, request_iterator,
request_serializer, response_deserializer)
def event_unary_unary(
self, group, method, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_unary_unary(self,
group,
method,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def event_unary_stream(
self, group, method, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_unary_stream(self,
group,
method,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def event_stream_unary(
self, group, method, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_stream_unary(self,
group,
method,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def event_stream_stream(
self, group, method, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_stream_stream(self,
group,
method,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
raise NotImplementedError()
def unary_unary(self, group, method):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _UnaryUnaryMultiCallable(
self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer)
def unary_stream(self, group, method):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _UnaryStreamMultiCallable(
self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer)
def stream_unary(self, group, method):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _StreamUnaryMultiCallable(
self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer)
def stream_stream(self, group, method):
request_serializer = self._request_serializers.get((group, method,))
response_deserializer = self._response_deserializers.get((group, method,))
request_serializer = self._request_serializers.get((
group,
method,))
response_deserializer = self._response_deserializers.get((
group,
method,))
return _StreamStreamMultiCallable(
self._channel, group, method, self._metadata_transformer,
request_serializer, response_deserializer)
@ -541,7 +667,8 @@ class _DynamicStub(face.DynamicStub):
elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
return self._generic_stub.stream_stream(self._group, attr)
else:
raise AttributeError('_DynamicStub object has no attribute "%s"!' % attr)
raise AttributeError('_DynamicStub object has no attribute "%s"!' %
attr)
def __enter__(self):
return self
@ -550,19 +677,14 @@ class _DynamicStub(face.DynamicStub):
return False
def generic_stub(
channel, host, metadata_transformer, request_serializers,
def generic_stub(channel, host, metadata_transformer, request_serializers,
response_deserializers):
return _GenericStub(
channel, metadata_transformer, request_serializers,
return _GenericStub(channel, metadata_transformer, request_serializers,
response_deserializers)
def dynamic_stub(
channel, service, cardinalities, host, metadata_transformer,
def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
request_serializers, response_deserializers):
return _DynamicStub(
_GenericStub(
channel, metadata_transformer, request_serializers,
response_deserializers),
service, cardinalities)
_GenericStub(channel, metadata_transformer, request_serializers,
response_deserializers), service, cardinalities)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Affords a connectivity-state-listenable channel."""
import threading
@ -41,8 +40,9 @@ _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = (
'Exception calling channel subscription callback!')
_LOW_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
state: connectivity for state, connectivity in zip(
_types.ConnectivityState, interfaces.ChannelConnectivity)
state: connectivity
for state, connectivity in zip(_types.ConnectivityState,
interfaces.ChannelConnectivity)
}
@ -85,7 +85,9 @@ class ConnectivityChannel(object):
def _spawn_delivery(self, connectivity, callbacks):
delivering_thread = threading.Thread(
target=self._deliver, args=(connectivity, callbacks,))
target=self._deliver, args=(
connectivity,
callbacks,))
delivering_thread.start()
self._delivering = True
@ -97,16 +99,18 @@ class ConnectivityChannel(object):
self._connectivity = _LOW_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
low_connectivity]
callbacks = tuple(
callback for callback, unused_but_known_to_be_none_connectivity
in self._callbacks_and_connectivities)
callback
for callback, unused_but_known_to_be_none_connectivity in
self._callbacks_and_connectivities)
for callback_and_connectivity in self._callbacks_and_connectivities:
callback_and_connectivity[1] = self._connectivity
if callbacks:
self._spawn_delivery(self._connectivity, callbacks)
completion_queue = _low.CompletionQueue()
while True:
low_channel.watch_connectivity_state(
low_connectivity, time.time() + 0.2, completion_queue, None)
low_channel.watch_connectivity_state(low_connectivity,
time.time() + 0.2,
completion_queue, None)
event = completion_queue.next()
with self._lock:
if not self._callbacks_and_connectivities and not self._try_to_connect:
@ -117,7 +121,8 @@ class ConnectivityChannel(object):
try_to_connect = self._try_to_connect
self._try_to_connect = False
if event.success or try_to_connect:
low_connectivity = low_channel.check_connectivity_state(try_to_connect)
low_connectivity = low_channel.check_connectivity_state(
try_to_connect)
with self._lock:
self._connectivity = _LOW_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
low_connectivity]
@ -146,8 +151,8 @@ class ConnectivityChannel(object):
def unsubscribe(self, callback):
with self._lock:
for index, (subscribed_callback, unused_connectivity) in enumerate(
self._callbacks_and_connectivities):
for index, (subscribed_callback, unused_connectivity
) in enumerate(self._callbacks_and_connectivities):
if callback == subscribed_callback:
self._callbacks_and_connectivities.pop(index)
break

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Translates gRPC's server-side API into gRPC's server-side Beta API."""
import collections
@ -96,15 +95,20 @@ class _FaceServicerContext(face.ServicerContext):
def _adapt_unary_request_inline(unary_request_inline):
def adaptation(request, servicer_context):
return unary_request_inline(request, _FaceServicerContext(servicer_context))
return unary_request_inline(request,
_FaceServicerContext(servicer_context))
return adaptation
def _adapt_stream_request_inline(stream_request_inline):
def adaptation(request_iterator, servicer_context):
return stream_request_inline(
request_iterator, _FaceServicerContext(servicer_context))
return stream_request_inline(request_iterator,
_FaceServicerContext(servicer_context))
return adaptation
@ -165,6 +169,7 @@ class _Callback(stream.Consumer):
def _run_request_pipe_thread(request_iterator, request_consumer,
servicer_context):
thread_joined = threading.Event()
def pipe_requests():
for request in request_iterator:
if not servicer_context.is_active() or thread_joined.is_set():
@ -183,116 +188,132 @@ def _run_request_pipe_thread(request_iterator, request_consumer,
def _adapt_unary_unary_event(unary_unary_event):
def adaptation(request, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
unary_unary_event(
request, callback.consume_and_terminate,
unary_unary_event(request, callback.consume_and_terminate,
_FaceServicerContext(servicer_context))
return callback.draw_all_values()[0]
return adaptation
def _adapt_unary_stream_event(unary_stream_event):
def adaptation(request, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
unary_stream_event(
request, callback, _FaceServicerContext(servicer_context))
unary_stream_event(request, callback,
_FaceServicerContext(servicer_context))
while True:
response = callback.draw_one_value()
if response is None:
return
else:
yield response
return adaptation
def _adapt_stream_unary_event(stream_unary_event):
def adaptation(request_iterator, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
request_consumer = stream_unary_event(
callback.consume_and_terminate, _FaceServicerContext(servicer_context))
_run_request_pipe_thread(
request_iterator, request_consumer, servicer_context)
callback.consume_and_terminate,
_FaceServicerContext(servicer_context))
_run_request_pipe_thread(request_iterator, request_consumer,
servicer_context)
return callback.draw_all_values()[0]
return adaptation
def _adapt_stream_stream_event(stream_stream_event):
def adaptation(request_iterator, servicer_context):
callback = _Callback()
if not servicer_context.add_callback(callback.cancel):
raise abandonment.Abandoned()
request_consumer = stream_stream_event(
callback, _FaceServicerContext(servicer_context))
_run_request_pipe_thread(
request_iterator, request_consumer, servicer_context)
_run_request_pipe_thread(request_iterator, request_consumer,
servicer_context)
while True:
response = callback.draw_one_value()
if response is None:
return
else:
yield response
return adaptation
class _SimpleMethodHandler(
collections.namedtuple(
'_MethodHandler',
('request_streaming', 'response_streaming', 'request_deserializer',
'response_serializer', 'unary_unary', 'unary_stream', 'stream_unary',
'stream_stream',)),
grpc.RpcMethodHandler):
collections.namedtuple('_MethodHandler', (
'request_streaming',
'response_streaming',
'request_deserializer',
'response_serializer',
'unary_unary',
'unary_stream',
'stream_unary',
'stream_stream',)), grpc.RpcMethodHandler):
pass
def _simple_method_handler(
implementation, request_deserializer, response_serializer):
def _simple_method_handler(implementation, request_deserializer,
response_serializer):
if implementation.style is style.Service.INLINE:
if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
return _SimpleMethodHandler(
False, False, request_deserializer, response_serializer,
_adapt_unary_request_inline(implementation.unary_unary_inline), None,
None, None)
_adapt_unary_request_inline(implementation.unary_unary_inline),
None, None, None)
elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
return _SimpleMethodHandler(
False, True, request_deserializer, response_serializer, None,
_adapt_unary_request_inline(implementation.unary_stream_inline), None,
None)
_adapt_unary_request_inline(implementation.unary_stream_inline),
None, None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
return _SimpleMethodHandler(
True, False, request_deserializer, response_serializer, None, None,
_adapt_stream_request_inline(implementation.stream_unary_inline),
return _SimpleMethodHandler(True, False, request_deserializer,
response_serializer, None, None,
_adapt_stream_request_inline(
implementation.stream_unary_inline),
None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
return _SimpleMethodHandler(
True, True, request_deserializer, response_serializer, None, None,
None,
_adapt_stream_request_inline(implementation.stream_stream_inline))
True, True, request_deserializer, response_serializer, None,
None, None,
_adapt_stream_request_inline(
implementation.stream_stream_inline))
elif implementation.style is style.Service.EVENT:
if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY:
return _SimpleMethodHandler(
False, False, request_deserializer, response_serializer,
_adapt_unary_unary_event(implementation.unary_unary_event), None,
None, None)
_adapt_unary_unary_event(implementation.unary_unary_event),
None, None, None)
elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM:
return _SimpleMethodHandler(
False, True, request_deserializer, response_serializer, None,
_adapt_unary_stream_event(implementation.unary_stream_event), None,
None)
_adapt_unary_stream_event(implementation.unary_stream_event),
None, None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY:
return _SimpleMethodHandler(
True, False, request_deserializer, response_serializer, None, None,
_adapt_stream_unary_event(implementation.stream_unary_event), None)
True, False, request_deserializer, response_serializer, None,
None,
_adapt_stream_unary_event(implementation.stream_unary_event),
None)
elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM:
return _SimpleMethodHandler(
True, True, request_deserializer, response_serializer, None, None,
None, _adapt_stream_stream_event(implementation.stream_stream_event))
True, True, request_deserializer, response_serializer, None,
None, None,
_adapt_stream_stream_event(implementation.stream_stream_event))
def _flatten_method_pair_map(method_pair_map):
@ -306,8 +327,7 @@ def _flatten_method_pair_map(method_pair_map):
class _GenericRpcHandler(grpc.GenericRpcHandler):
def __init__(
self, method_implementations, multi_method_implementation,
def __init__(self, method_implementations, multi_method_implementation,
request_deserializers, response_serializers):
self._method_implementations = _flatten_method_pair_map(
method_implementations)
@ -360,16 +380,18 @@ class _Server(interfaces.Server):
return False
def server(
service_implementations, multi_method_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size):
def server(service_implementations, multi_method_implementation,
request_deserializers, response_serializers, thread_pool,
thread_pool_size):
generic_rpc_handler = _GenericRpcHandler(
service_implementations, multi_method_implementation,
request_deserializers, response_serializers)
if thread_pool is None:
effective_thread_pool = logging_pool.pool(
_DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
effective_thread_pool = logging_pool.pool(_DEFAULT_POOL_SIZE
if thread_pool_size is None
else thread_pool_size)
else:
effective_thread_pool = thread_pool
return _Server(
grpc.server(effective_thread_pool, handlers=(generic_rpc_handler,)))
grpc.server(
effective_thread_pool, handlers=(generic_rpc_handler,)))

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Entry points into the Beta API of gRPC Python."""
# threading is referenced from specification in this module.
@ -43,7 +42,6 @@ from grpc.beta import interfaces
from grpc.framework.common import cardinality # pylint: disable=unused-import
from grpc.framework.interfaces.face import face # pylint: disable=unused-import
ChannelCredentials = grpc.ChannelCredentials
ssl_channel_credentials = grpc.ssl_channel_credentials
CallCredentials = grpc.CallCredentials
@ -61,6 +59,7 @@ def google_call_credentials(credentials):
"""
return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))
access_token_call_credentials = grpc.access_token_call_credentials
composite_call_credentials = grpc.composite_call_credentials
composite_channel_credentials = grpc.composite_channel_credentials
@ -113,8 +112,8 @@ def insecure_channel(host, port):
Returns:
A Channel to the remote host through which RPCs may be conducted.
"""
channel = grpc.insecure_channel(
host if port is None else '%s:%d' % (host, port))
channel = grpc.insecure_channel(host
if port is None else '%s:%d' % (host, port))
return Channel(channel)
@ -130,8 +129,8 @@ def secure_channel(host, port, channel_credentials):
Returns:
A secure Channel to the remote host through which RPCs may be conducted.
"""
channel = grpc.secure_channel(
host if port is None else '%s:%d' % (host, port), channel_credentials)
channel = grpc.secure_channel(host if port is None else
'%s:%d' % (host, port), channel_credentials)
return Channel(channel)
@ -143,8 +142,7 @@ class StubOptions(object):
functions.
"""
def __init__(
self, host, request_serializers, response_deserializers,
def __init__(self, host, request_serializers, response_deserializers,
metadata_transformer, thread_pool, thread_pool_size):
self.host = host
self.request_serializers = request_serializers
@ -153,13 +151,16 @@ class StubOptions(object):
self.thread_pool = thread_pool
self.thread_pool_size = thread_pool_size
_EMPTY_STUB_OPTIONS = StubOptions(
None, None, None, None, None, None)
_EMPTY_STUB_OPTIONS = StubOptions(None, None, None, None, None, None)
def stub_options(
host=None, request_serializers=None, response_deserializers=None,
metadata_transformer=None, thread_pool=None, thread_pool_size=None):
def stub_options(host=None,
request_serializers=None,
response_deserializers=None,
metadata_transformer=None,
thread_pool=None,
thread_pool_size=None):
"""Creates a StubOptions value to be passed at stub creation.
All parameters are optional and should always be passed by keyword.
@ -180,8 +181,7 @@ def stub_options(
Returns:
A StubOptions value created from the passed parameters.
"""
return StubOptions(
host, request_serializers, response_deserializers,
return StubOptions(host, request_serializers, response_deserializers,
metadata_transformer, thread_pool, thread_pool_size)
@ -198,7 +198,8 @@ def generic_stub(channel, options=None):
effective_options = _EMPTY_STUB_OPTIONS if options is None else options
return _client_adaptations.generic_stub(
channel._channel, # pylint: disable=protected-access
effective_options.host, effective_options.metadata_transformer,
effective_options.host,
effective_options.metadata_transformer,
effective_options.request_serializers,
effective_options.response_deserializers)
@ -220,7 +221,9 @@ def dynamic_stub(channel, service, cardinalities, options=None):
effective_options = StubOptions() if options is None else options
return _client_adaptations.dynamic_stub(
channel._channel, # pylint: disable=protected-access
service, cardinalities, effective_options.host,
service,
cardinalities,
effective_options.host,
effective_options.metadata_transformer,
effective_options.request_serializers,
effective_options.response_deserializers)
@ -238,10 +241,9 @@ class ServerOptions(object):
functions.
"""
def __init__(
self, multi_method_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size, default_timeout,
maximum_timeout):
def __init__(self, multi_method_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size,
default_timeout, maximum_timeout):
self.multi_method_implementation = multi_method_implementation
self.request_deserializers = request_deserializers
self.response_serializers = response_serializers
@ -250,14 +252,17 @@ class ServerOptions(object):
self.default_timeout = default_timeout
self.maximum_timeout = maximum_timeout
_EMPTY_SERVER_OPTIONS = ServerOptions(
None, None, None, None, None, None, None)
_EMPTY_SERVER_OPTIONS = ServerOptions(None, None, None, None, None, None, None)
def server_options(
multi_method_implementation=None, request_deserializers=None,
response_serializers=None, thread_pool=None, thread_pool_size=None,
default_timeout=None, maximum_timeout=None):
def server_options(multi_method_implementation=None,
request_deserializers=None,
response_serializers=None,
thread_pool=None,
thread_pool_size=None,
default_timeout=None,
maximum_timeout=None):
"""Creates a ServerOptions value to be passed at server creation.
All parameters are optional and should always be passed by keyword.
@ -282,9 +287,9 @@ def server_options(
Returns:
A StubOptions value created from the passed parameters.
"""
return ServerOptions(
multi_method_implementation, request_deserializers, response_serializers,
thread_pool, thread_pool_size, default_timeout, maximum_timeout)
return ServerOptions(multi_method_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size,
default_timeout, maximum_timeout)
def server(service_implementations, options=None):

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Constants and interfaces of the Beta API of gRPC Python."""
import abc
@ -69,6 +68,7 @@ def grpc_call_options(disable_compression=False, credentials=None):
"""
return GRPCCallOptions(disable_compression, None, credentials)
GRPCAuthMetadataContext = grpc.AuthMetadataContext
GRPCAuthMetadataPluginCallback = grpc.AuthMetadataPluginCallback
GRPCAuthMetadataPlugin = grpc.AuthMetadataPlugin

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Utilities for the gRPC Python Beta API."""
import threading
@ -161,4 +160,3 @@ def channel_ready_future(channel):
ready_future = _ChannelReadyFuture(channel)
ready_future.start()
return ready_future

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines an enum for classifying RPC methods by streaming semantics."""
import enum

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines an enum for classifying RPC methods by control flow semantics."""
import enum

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Utilities for indicating abandonment of computation."""

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Utilities for working with callables."""
import abc
@ -59,15 +58,15 @@ class Outcome(six.with_metaclass(abc.ABCMeta)):
class _EasyOutcome(
collections.namedtuple(
'_EasyOutcome', ['kind', 'return_value', 'exception']),
Outcome):
collections.namedtuple('_EasyOutcome',
['kind', 'return_value', 'exception']), Outcome):
"""A trivial implementation of Outcome."""
def _call_logging_exceptions(behavior, message, *args, **kwargs):
try:
return _EasyOutcome(Outcome.Kind.RETURNED, behavior(*args, **kwargs), None)
return _EasyOutcome(Outcome.Kind.RETURNED,
behavior(*args, **kwargs), None)
except Exception as e: # pylint: disable=broad-except
logging.exception(message)
return _EasyOutcome(Outcome.Kind.RAISED, None, e)
@ -86,9 +85,11 @@ def with_exceptions_logged(behavior, message):
future.Outcome describing whether the given behavior returned a value or
raised an exception.
"""
@functools.wraps(behavior)
def wrapped_behavior(*args, **kwargs):
return _call_logging_exceptions(behavior, message, *args, **kwargs)
return wrapped_behavior

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""A Future interface.
Python doesn't have a Future interface in its standard library. In the absence

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""A thread pool that logs exceptions raised by tasks executed within it."""
import logging
@ -36,13 +35,16 @@ from concurrent import futures
def _wrap(behavior):
"""Wraps an arbitrary callable behavior in exception-logging."""
def _wrapping(*args, **kwargs):
try:
return behavior(*args, **kwargs)
except Exception as e:
logging.exception(
'Unexpected exception from %s executed in logging pool!', behavior)
'Unexpected exception from %s executed in logging pool!',
behavior)
raise
return _wrapping
@ -62,8 +64,9 @@ class _LoggingPool(object):
return self._backing_pool.submit(_wrap(fn), *args, **kwargs)
def map(self, func, *iterables, **kwargs):
return self._backing_pool.map(
_wrap(func), *iterables, timeout=kwargs.get('timeout', None))
return self._backing_pool.map(_wrap(func),
*iterables,
timeout=kwargs.get('timeout', None))
def shutdown(self, wait=True):
self._backing_pool.shutdown(wait=wait)

@ -26,13 +26,13 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Interfaces related to streams of values or objects."""
import abc
import six
class Consumer(six.with_metaclass(abc.ABCMeta)):
"""Interface for consumers of finite streams of values or objects."""

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Helpful utilities related to the stream module."""
import logging

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""The base interface of RPC Framework.
Implementations of this interface support the conduct of "operations":
@ -166,8 +165,10 @@ class Operator(six.with_metaclass(abc.ABCMeta)):
"""An interface through which to participate in an operation."""
@abc.abstractmethod
def advance(
self, initial_metadata=None, payload=None, completion=None,
def advance(self,
initial_metadata=None,
payload=None,
completion=None,
allowance=None):
"""Progresses the operation.
@ -183,6 +184,7 @@ class Operator(six.with_metaclass(abc.ABCMeta)):
"""
raise NotImplementedError()
class ProtocolReceiver(six.with_metaclass(abc.ABCMeta)):
"""A means of receiving protocol values during an operation."""
@ -284,9 +286,15 @@ class End(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def operate(
self, group, method, subscription, timeout, initial_metadata=None,
payload=None, completion=None, protocol_options=None):
def operate(self,
group,
method,
subscription,
timeout,
initial_metadata=None,
payload=None,
completion=None,
protocol_options=None):
"""Commences an operation.
Args:

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Utilities for use with the base interface of RPC Framework."""
import collections
@ -34,23 +33,26 @@ import collections
from grpc.framework.interfaces.base import base
class _Completion(
base.Completion,
collections.namedtuple(
'_Completion', ('terminal_metadata', 'code', 'message',))):
class _Completion(base.Completion,
collections.namedtuple('_Completion', (
'terminal_metadata',
'code',
'message',))):
"""A trivial implementation of base.Completion."""
class _Subscription(
base.Subscription,
collections.namedtuple(
'_Subscription',
('kind', 'termination_callback', 'allowance', 'operator',
class _Subscription(base.Subscription,
collections.namedtuple('_Subscription', (
'kind',
'termination_callback',
'allowance',
'operator',
'protocol_receiver',))):
"""A trivial implementation of base.Subscription."""
_NONE_SUBSCRIPTION = _Subscription(
base.Subscription.Kind.NONE, None, None, None, None)
_NONE_SUBSCRIPTION = _Subscription(base.Subscription.Kind.NONE, None, None,
None, None)
def completion(terminal_metadata, code, message):
@ -78,5 +80,5 @@ def full_subscription(operator, protocol_receiver):
A base.Subscription of kind base.Subscription.Kind.FULL wrapping the given
base.Operator and base.ProtocolReceiver.
"""
return _Subscription(
base.Subscription.Kind.FULL, None, None, operator, protocol_receiver)
return _Subscription(base.Subscription.Kind.FULL, None, None, operator,
protocol_receiver)

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Interfaces defining the Face layer of RPC Framework."""
import abc
@ -64,13 +63,18 @@ class NoSuchMethodError(Exception):
self.method = method
def __repr__(self):
return 'face.NoSuchMethodError(%s, %s)' % (self.group, self.method,)
return 'face.NoSuchMethodError(%s, %s)' % (
self.group,
self.method,)
class Abortion(
collections.namedtuple(
'Abortion',
('kind', 'initial_metadata', 'terminal_metadata', 'code', 'details',))):
collections.namedtuple('Abortion', (
'kind',
'initial_metadata',
'terminal_metadata',
'code',
'details',))):
"""A value describing RPC abortion.
Attributes:
@ -119,8 +123,8 @@ class AbortionError(six.with_metaclass(abc.ABCMeta, Exception)):
self.details = details
def __str__(self):
return '%s(code=%s, details="%s")' % (
self.__class__.__name__, self.code, self.details)
return '%s(code=%s, details="%s")' % (self.__class__.__name__,
self.code, self.details)
class CancellationError(AbortionError):
@ -363,8 +367,11 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
"""Affords invoking a unary-unary RPC in any call style."""
@abc.abstractmethod
def __call__(
self, request, timeout, metadata=None, with_call=False,
def __call__(self,
request,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
"""Synchronously invokes the underlying RPC.
@ -408,9 +415,13 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event(
self, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event(self,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Asynchronously invokes the underlying RPC.
Args:
@ -453,9 +464,13 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event(
self, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event(self,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Asynchronously invokes the underlying RPC.
Args:
@ -479,9 +494,12 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
"""Affords invoking a stream-unary RPC in any call style."""
@abc.abstractmethod
def __call__(
self, request_iterator, timeout, metadata=None,
with_call=False, protocol_options=None):
def __call__(self,
request_iterator,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
"""Synchronously invokes the underlying RPC.
Args:
@ -504,8 +522,11 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def future(
self, request_iterator, timeout, metadata=None, protocol_options=None):
def future(self,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
"""Asynchronously invokes the underlying RPC.
Args:
@ -525,8 +546,11 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event(
self, receiver, abortion_callback, timeout, metadata=None,
def event(self,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Asynchronously invokes the underlying RPC.
@ -551,8 +575,11 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
"""Affords invoking a stream-stream RPC in any call style."""
@abc.abstractmethod
def __call__(
self, request_iterator, timeout, metadata=None, protocol_options=None):
def __call__(self,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
"""Invokes the underlying RPC.
Args:
@ -571,8 +598,11 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event(
self, receiver, abortion_callback, timeout, metadata=None,
def event(self,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Asynchronously invokes the underlying RPC.
@ -673,9 +703,14 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
"""Affords RPC invocation via generic methods."""
@abc.abstractmethod
def blocking_unary_unary(
self, group, method, request, timeout, metadata=None,
with_call=False, protocol_options=None):
def blocking_unary_unary(self,
group,
method,
request,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
"""Invokes a unary-request-unary-response method.
This method blocks until either returning the response value of the RPC
@ -703,8 +738,12 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def future_unary_unary(
self, group, method, request, timeout, metadata=None,
def future_unary_unary(self,
group,
method,
request,
timeout,
metadata=None,
protocol_options=None):
"""Invokes a unary-request-unary-response method.
@ -726,8 +765,12 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def inline_unary_stream(
self, group, method, request, timeout, metadata=None,
def inline_unary_stream(self,
group,
method,
request,
timeout,
metadata=None,
protocol_options=None):
"""Invokes a unary-request-stream-response method.
@ -748,9 +791,14 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def blocking_stream_unary(
self, group, method, request_iterator, timeout, metadata=None,
with_call=False, protocol_options=None):
def blocking_stream_unary(self,
group,
method,
request_iterator,
timeout,
metadata=None,
with_call=False,
protocol_options=None):
"""Invokes a stream-request-unary-response method.
This method blocks until either returning the response value of the RPC
@ -778,8 +826,12 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def future_stream_unary(
self, group, method, request_iterator, timeout, metadata=None,
def future_stream_unary(self,
group,
method,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
"""Invokes a stream-request-unary-response method.
@ -801,8 +853,12 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def inline_stream_stream(
self, group, method, request_iterator, timeout, metadata=None,
def inline_stream_stream(self,
group,
method,
request_iterator,
timeout,
metadata=None,
protocol_options=None):
"""Invokes a stream-request-stream-response method.
@ -823,9 +879,15 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event_unary_unary(
self, group, method, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_unary_unary(self,
group,
method,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Event-driven invocation of a unary-request-unary-response method.
Args:
@ -846,9 +908,15 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event_unary_stream(
self, group, method, request, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_unary_stream(self,
group,
method,
request,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Event-driven invocation of a unary-request-stream-response method.
Args:
@ -869,9 +937,14 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event_stream_unary(
self, group, method, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_stream_unary(self,
group,
method,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Event-driven invocation of a unary-request-unary-response method.
Args:
@ -892,9 +965,14 @@ class GenericStub(six.with_metaclass(abc.ABCMeta)):
raise NotImplementedError()
@abc.abstractmethod
def event_stream_stream(
self, group, method, receiver, abortion_callback, timeout,
metadata=None, protocol_options=None):
def event_stream_stream(self,
group,
method,
receiver,
abortion_callback,
timeout,
metadata=None,
protocol_options=None):
"""Event-driven invocation of a unary-request-stream-response method.
Args:

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Utilities for RPC Framework's Face interface."""
import collections
@ -38,13 +37,19 @@ from grpc.framework.foundation import stream # pylint: disable=unused-import
from grpc.framework.interfaces.face import face
class _MethodImplementation(
face.MethodImplementation,
collections.namedtuple(
'_MethodImplementation',
['cardinality', 'style', 'unary_unary_inline', 'unary_stream_inline',
'stream_unary_inline', 'stream_stream_inline', 'unary_unary_event',
'unary_stream_event', 'stream_unary_event', 'stream_stream_event',])):
class _MethodImplementation(face.MethodImplementation,
collections.namedtuple('_MethodImplementation', [
'cardinality',
'style',
'unary_unary_inline',
'unary_stream_inline',
'stream_unary_inline',
'stream_stream_inline',
'unary_unary_event',
'unary_stream_event',
'stream_unary_event',
'stream_stream_event',
])):
pass
@ -59,9 +64,9 @@ def unary_unary_inline(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.UNARY_UNARY, style.Service.INLINE, behavior,
None, None, None, None, None, None, None)
return _MethodImplementation(cardinality.Cardinality.UNARY_UNARY,
style.Service.INLINE, behavior, None, None,
None, None, None, None, None)
def unary_stream_inline(behavior):
@ -75,9 +80,9 @@ def unary_stream_inline(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.UNARY_STREAM, style.Service.INLINE, None,
behavior, None, None, None, None, None, None)
return _MethodImplementation(cardinality.Cardinality.UNARY_STREAM,
style.Service.INLINE, None, behavior, None,
None, None, None, None, None)
def stream_unary_inline(behavior):
@ -91,9 +96,9 @@ def stream_unary_inline(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.STREAM_UNARY, style.Service.INLINE, None, None,
behavior, None, None, None, None, None)
return _MethodImplementation(cardinality.Cardinality.STREAM_UNARY,
style.Service.INLINE, None, None, behavior,
None, None, None, None, None)
def stream_stream_inline(behavior):
@ -107,9 +112,9 @@ def stream_stream_inline(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.STREAM_STREAM, style.Service.INLINE, None, None,
None, behavior, None, None, None, None)
return _MethodImplementation(cardinality.Cardinality.STREAM_STREAM,
style.Service.INLINE, None, None, None,
behavior, None, None, None, None)
def unary_unary_event(behavior):
@ -123,9 +128,9 @@ def unary_unary_event(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.UNARY_UNARY, style.Service.EVENT, None, None,
None, None, behavior, None, None, None)
return _MethodImplementation(cardinality.Cardinality.UNARY_UNARY,
style.Service.EVENT, None, None, None, None,
behavior, None, None, None)
def unary_stream_event(behavior):
@ -139,9 +144,9 @@ def unary_stream_event(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.UNARY_STREAM, style.Service.EVENT, None, None,
None, None, None, behavior, None, None)
return _MethodImplementation(cardinality.Cardinality.UNARY_STREAM,
style.Service.EVENT, None, None, None, None,
None, behavior, None, None)
def stream_unary_event(behavior):
@ -156,9 +161,9 @@ def stream_unary_event(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.STREAM_UNARY, style.Service.EVENT, None, None,
None, None, None, None, behavior, None)
return _MethodImplementation(cardinality.Cardinality.STREAM_UNARY,
style.Service.EVENT, None, None, None, None,
None, None, behavior, None)
def stream_stream_event(behavior):
@ -173,6 +178,6 @@ def stream_stream_event(behavior):
Returns:
An face.MethodImplementation derived from the given behavior.
"""
return _MethodImplementation(
cardinality.Cardinality.STREAM_STREAM, style.Service.EVENT, None, None,
None, None, None, None, None, behavior)
return _MethodImplementation(cardinality.Cardinality.STREAM_STREAM,
style.Service.EVENT, None, None, None, None,
None, None, None, behavior)

@ -27,7 +27,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import os.path
import shutil
@ -38,7 +37,6 @@ from distutils import errors
import commands
C_PYTHON_DEV = """
#include <Python.h>
int main(int argc, char **argv) { return 0; }
@ -55,9 +53,8 @@ Could not find <Python.h>. This could mean the following:
(check your environment variables or try re-installing?)
"""
C_CHECKS = {
C_PYTHON_DEV: C_PYTHON_DEV_ERROR_MESSAGE,
}
C_CHECKS = {C_PYTHON_DEV: C_PYTHON_DEV_ERROR_MESSAGE,}
def _compile(compiler, source_string):
tempdir = tempfile.mkdtemp()
@ -71,6 +68,7 @@ def _compile(compiler, source_string):
finally:
shutil.rmtree(tempdir)
def _expect_compile(compiler, source_string, error_message):
if _compile(compiler, source_string) is not None:
sys.stderr.write(error_message)
@ -78,6 +76,7 @@ def _expect_compile(compiler, source_string, error_message):
"Diagnostics found a compilation environment issue:\n{}"
.format(error_message))
def diagnose_compile_error(build_ext, error):
"""Attempt to diagnose an error during compilation."""
for c_check, message in C_CHECKS.items():
@ -88,17 +87,16 @@ def diagnose_compile_error(build_ext, error):
]
for source in python_sources:
if not os.path.isfile(source):
raise commands.CommandError(
("Diagnostics found a missing Python extension source file:\n{}\n\n"
raise commands.CommandError((
"Diagnostics found a missing Python extension source file:\n{}\n\n"
"This is usually because the Cython sources haven't been transpiled "
"into C yet and you're building from source.\n"
"Try setting the environment variable "
"`GRPC_PYTHON_BUILD_WITH_CYTHON=1` when invoking `setup.py` or "
"when using `pip`, e.g.:\n\n"
"pip install -rrequirements.txt\n"
"GRPC_PYTHON_BUILD_WITH_CYTHON=1 pip install .")
.format(source)
)
"GRPC_PYTHON_BUILD_WITH_CYTHON=1 pip install .").format(source))
def diagnose_attribute_error(build_ext, error):
if any('_needs_stub' in arg for arg in error.args):
@ -106,11 +104,13 @@ def diagnose_attribute_error(build_ext, error):
"We expect a missing `_needs_stub` attribute from older versions of "
"setuptools. Consider upgrading setuptools.")
_ERROR_DIAGNOSES = {
errors.CompileError: diagnose_compile_error,
AttributeError: diagnose_attribute_error
}
def diagnose_build_ext_error(build_ext, error, formatted):
diagnostic = _ERROR_DIAGNOSES.get(type(error))
if diagnostic is None:
@ -120,4 +120,3 @@ def diagnose_build_ext_error(build_ext, error, formatted):
"\n\n{}".format(formatted))
else:
diagnostic(build_ext, error)

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Reference implementation for health checking in gRPC Python."""
import threading

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides distutils command classes for the GRPC Python setup process."""
import os

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Setup module for the GRPC Python package's optional health checking."""
import os
@ -41,18 +40,14 @@ os.chdir(os.path.dirname(os.path.abspath(__file__)))
import health_commands
import grpc_version
PACKAGE_DIRECTORIES = {
'': '.',
}
PACKAGE_DIRECTORIES = {'': '.',}
SETUP_REQUIRES = (
'grpcio-tools>={version}'.format(version=grpc_version.VERSION),
)
'grpcio-tools>={version}'.format(version=grpc_version.VERSION),)
INSTALL_REQUIRES = (
'protobuf>=3.0.0',
'grpcio>={version}'.format(version=grpc_version.VERSION),
)
'grpcio>={version}'.format(version=grpc_version.VERSION),)
COMMAND_CLASS = {
# Run preprocess from the repository *before* doing any packaging!
@ -68,5 +63,4 @@ setuptools.setup(
packages=setuptools.find_packages('.'),
install_requires=INSTALL_REQUIRES,
setup_requires=SETUP_REQUIRES,
cmdclass=COMMAND_CLASS
)
cmdclass=COMMAND_CLASS)

@ -26,4 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,4 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Reference implementation for reflection in gRPC Python."""
import threading
@ -39,13 +38,13 @@ from grpc_reflection.v1alpha import reflection_pb2
_POOL = descriptor_pool.Default()
def _not_found_error():
return reflection_pb2.ServerReflectionResponse(
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)
)
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),))
def _file_descriptor_response(descriptor):
proto = descriptor_pb2.FileDescriptorProto()
@ -53,9 +52,7 @@ def _file_descriptor_response(descriptor):
serialized_proto = proto.SerializeToString()
return reflection_pb2.ServerReflectionResponse(
file_descriptor_response=reflection_pb2.FileDescriptorResponse(
file_descriptor_proto=(serialized_proto,)
),
)
file_descriptor_proto=(serialized_proto,)),)
class ReflectionServicer(reflection_pb2.ServerReflectionServicer):
@ -80,7 +77,8 @@ class ReflectionServicer(reflection_pb2.ServerReflectionServicer):
def _file_containing_symbol(self, fully_qualified_name):
try:
descriptor = self._pool.FindFileContainingSymbol(fully_qualified_name)
descriptor = self._pool.FindFileContainingSymbol(
fully_qualified_name)
except KeyError:
return _not_found_error()
else:
@ -92,9 +90,7 @@ class ReflectionServicer(reflection_pb2.ServerReflectionServicer):
return reflection_pb2.ServerReflectionResponse(
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.UNIMPLEMENTED.value[0],
error_message=grpc.StatusCode.UNIMPLMENTED.value[1].encode(),
)
)
error_message=grpc.StatusCode.UNIMPLMENTED.value[1].encode(),))
def _extension_numbers_of_type(fully_qualified_name):
# TODO(atash) We're allowed to leave this unsupported according to the
@ -104,26 +100,22 @@ class ReflectionServicer(reflection_pb2.ServerReflectionServicer):
return reflection_pb2.ServerReflectionResponse(
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.UNIMPLEMENTED.value[0],
error_message=grpc.StatusCode.UNIMPLMENTED.value[1].encode(),
)
)
error_message=grpc.StatusCode.UNIMPLMENTED.value[1].encode(),))
def _list_services(self):
return reflection_pb2.ServerReflectionResponse(
list_services_response=reflection_pb2.ListServiceResponse(
service=[
list_services_response=reflection_pb2.ListServiceResponse(service=[
reflection_pb2.ServiceResponse(name=service_name)
for service_name in self._service_names
]
)
)
]))
def ServerReflectionInfo(self, request_iterator, context):
for request in request_iterator:
if request.HasField('file_by_filename'):
yield self._file_by_filename(request.file_by_filename)
elif request.HasField('file_containing_symbol'):
yield self._file_containing_symbol(request.file_containing_symbol)
yield self._file_containing_symbol(
request.file_containing_symbol)
elif request.HasField('file_containing_extension'):
yield self._file_containing_extension(
request.file_containing_extension.containing_type,
@ -137,7 +129,5 @@ class ReflectionServicer(reflection_pb2.ServerReflectionServicer):
yield reflection_pb2.ServerReflectionResponse(
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.INVALID_ARGUMENT.value[0],
error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1].encode(),
)
)
error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1]
.encode(),))

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides distutils command classes for the GRPC Python setup process."""
import os
@ -35,7 +34,8 @@ import shutil
import setuptools
ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
HEALTH_PROTO = os.path.join(ROOT_DIR, '../../proto/grpc/reflection/v1alpha/reflection.proto')
HEALTH_PROTO = os.path.join(
ROOT_DIR, '../../proto/grpc/reflection/v1alpha/reflection.proto')
class CopyProtoModules(setuptools.Command):
@ -54,7 +54,8 @@ class CopyProtoModules(setuptools.Command):
if os.path.isfile(HEALTH_PROTO):
shutil.copyfile(
HEALTH_PROTO,
os.path.join(ROOT_DIR, 'grpc_reflection/v1alpha/reflection.proto'))
os.path.join(ROOT_DIR,
'grpc_reflection/v1alpha/reflection.proto'))
class BuildPackageProtos(setuptools.Command):

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Setup module for the GRPC Python package's optional reflection."""
import os
@ -41,18 +40,14 @@ os.chdir(os.path.dirname(os.path.abspath(__file__)))
import reflection_commands
import grpc_version
PACKAGE_DIRECTORIES = {
'': '.',
}
PACKAGE_DIRECTORIES = {'': '.',}
SETUP_REQUIRES = (
'grpcio-tools>={version}'.format(version=grpc_version.VERSION),
)
'grpcio-tools>={version}'.format(version=grpc_version.VERSION),)
INSTALL_REQUIRES = (
'protobuf>=3.0.0',
'grpcio>={version}'.format(version=grpc_version.VERSION),
)
'grpcio>={version}'.format(version=grpc_version.VERSION),)
COMMAND_CLASS = {
# Run preprocess from the repository *before* doing any packaging!
@ -68,5 +63,4 @@ setuptools.setup(
packages=setuptools.find_packages('.'),
install_requires=INSTALL_REQUIRES,
setup_requires=SETUP_REQUIRES,
cmdclass=COMMAND_CLASS
)
cmdclass=COMMAND_CLASS)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides distutils command classes for the gRPC Python setup process."""
import distutils
@ -122,8 +121,7 @@ class BuildProtoModules(setuptools.Command):
'--grpc_python_out={}'.format(PROTO_STEM),
] + [path]
if protoc.main(command) != 0:
sys.stderr.write(
'warning: Command:\n{}\nFailed'.format(
sys.stderr.write('warning: Command:\n{}\nFailed'.format(
command))
# Generated proto directories dont include __init__.py, but
@ -177,11 +175,9 @@ class TestLite(setuptools.Command):
class RunInterop(test.test):
description = 'run interop test client/server'
user_options = [
('args=', 'a', 'pass-thru arguments for the client/server'),
user_options = [('args=', 'a', 'pass-thru arguments for the client/server'),
('client', 'c', 'flag indicating to run the client'),
('server', 's', 'flag indicating to run the server')
]
('server', 's', 'flag indicating to run the server')]
def initialize_options(self):
self.args = ''
@ -190,11 +186,13 @@ class RunInterop(test.test):
def finalize_options(self):
if self.client and self.server:
raise DistutilsOptionError('you may only specify one of client or server')
raise DistutilsOptionError(
'you may only specify one of client or server')
def run(self):
if self.distribution.install_requires:
self.distribution.fetch_build_eggs(self.distribution.install_requires)
self.distribution.fetch_build_eggs(
self.distribution.install_requires)
if self.distribution.tests_require:
self.distribution.fetch_build_eggs(self.distribution.tests_require)
if self.client:

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""A setup module for the gRPC Python package."""
import os
@ -48,9 +47,7 @@ import grpc_version
LICENSE = '3-clause BSD'
PACKAGE_DIRECTORIES = {
'': '.',
}
PACKAGE_DIRECTORIES = {'': '.',}
INSTALL_REQUIRES = (
'coverage>=4.0',
@ -61,13 +58,11 @@ INSTALL_REQUIRES = (
'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION),
'oauth2client>=1.4.7',
'protobuf>=3.0.0',
'six>=1.10',
)
'six>=1.10',)
COMMAND_CLASS = {
# Run `preprocess` *before* doing any packaging!
'preprocess': commands.GatherProto,
'build_package_protos': grpc_tools.command.BuildPackageProtos,
'build_py': commands.BuildPy,
'run_interop': commands.RunInterop,
@ -80,9 +75,7 @@ PACKAGE_DATA = {
'credentials/server1.key',
'credentials/server1.pem',
],
'tests.protoc_plugin.protos.invocation_testing': [
'same.proto',
],
'tests.protoc_plugin.protos.invocation_testing': ['same.proto',],
'tests.protoc_plugin.protos.invocation_testing.split_messages': [
'messages.proto',
],
@ -94,9 +87,7 @@ PACKAGE_DATA = {
'credentials/server1.key',
'credentials/server1.pem',
],
'tests': [
'tests.json'
],
'tests': ['tests.json'],
}
TEST_SUITE = 'tests'
@ -118,5 +109,4 @@ setuptools.setup(
tests_require=TESTS_REQUIRE,
test_suite=TEST_SUITE,
test_loader=TEST_LOADER,
test_runner=TEST_RUNNER,
)
test_runner=TEST_RUNNER,)

@ -116,4 +116,5 @@ def iterate_suite_cases(suite):
elif isinstance(item, unittest.TestCase):
yield item
else:
raise ValueError('unexpected suite item of type {}'.format(type(item)))
raise ValueError('unexpected suite item of type {}'.format(
type(item)))

@ -41,8 +41,10 @@ from six import moves
from tests import _loader
class CaseResult(collections.namedtuple('CaseResult', [
'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback'])):
class CaseResult(
collections.namedtuple('CaseResult', [
'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback'
])):
"""A serializable result of a single test case.
Attributes:
@ -69,8 +71,14 @@ class CaseResult(collections.namedtuple('CaseResult', [
EXPECTED_FAILURE = 'expected failure'
UNEXPECTED_SUCCESS = 'unexpected success'
def __new__(cls, id=None, name=None, kind=None, stdout=None, stderr=None,
skip_reason=None, traceback=None):
def __new__(cls,
id=None,
name=None,
kind=None,
stdout=None,
stderr=None,
skip_reason=None,
traceback=None):
"""Helper keyword constructor for the namedtuple.
See this class' attributes for information on the arguments."""
@ -94,11 +102,16 @@ class CaseResult(collections.namedtuple('CaseResult', [
pass
else:
assert False
return super(cls, CaseResult).__new__(
cls, id, name, kind, stdout, stderr, skip_reason, traceback)
def updated(self, name=None, kind=None, stdout=None, stderr=None,
skip_reason=None, traceback=None):
return super(cls, CaseResult).__new__(cls, id, name, kind, stdout,
stderr, skip_reason, traceback)
def updated(self,
name=None,
kind=None,
stdout=None,
stderr=None,
skip_reason=None,
traceback=None):
"""Get a new validated CaseResult with the fields updated.
See this class' attributes for information on the arguments."""
@ -108,8 +121,13 @@ class CaseResult(collections.namedtuple('CaseResult', [
stderr = self.stderr if stderr is None else stderr
skip_reason = self.skip_reason if skip_reason is None else skip_reason
traceback = self.traceback if traceback is None else traceback
return CaseResult(id=self.id, name=name, kind=kind, stdout=stdout,
stderr=stderr, skip_reason=skip_reason,
return CaseResult(
id=self.id,
name=name,
kind=kind,
stdout=stdout,
stderr=stderr,
skip_reason=skip_reason,
traceback=traceback)
@ -282,9 +300,7 @@ class TerminalResult(CoverageResult):
def startTestRun(self):
"""See unittest.TestResult.startTestRun."""
super(TerminalResult, self).startTestRun()
self.out.write(
_Colors.HEADER +
'Testing gRPC Python...\n' +
self.out.write(_Colors.HEADER + 'Testing gRPC Python...\n' +
_Colors.END)
def stopTestRun(self):
@ -296,57 +312,46 @@ class TerminalResult(CoverageResult):
def addError(self, test, error):
"""See unittest.TestResult.addError."""
super(TerminalResult, self).addError(test, error)
self.out.write(
_Colors.FAIL +
'ERROR {}\n'.format(test.id()) +
self.out.write(_Colors.FAIL + 'ERROR {}\n'.format(test.id()) +
_Colors.END)
self.out.flush()
def addFailure(self, test, error):
"""See unittest.TestResult.addFailure."""
super(TerminalResult, self).addFailure(test, error)
self.out.write(
_Colors.FAIL +
'FAILURE {}\n'.format(test.id()) +
self.out.write(_Colors.FAIL + 'FAILURE {}\n'.format(test.id()) +
_Colors.END)
self.out.flush()
def addSuccess(self, test):
"""See unittest.TestResult.addSuccess."""
super(TerminalResult, self).addSuccess(test)
self.out.write(
_Colors.OK +
'SUCCESS {}\n'.format(test.id()) +
self.out.write(_Colors.OK + 'SUCCESS {}\n'.format(test.id()) +
_Colors.END)
self.out.flush()
def addSkip(self, test, reason):
"""See unittest.TestResult.addSkip."""
super(TerminalResult, self).addSkip(test, reason)
self.out.write(
_Colors.INFO +
'SKIP {}\n'.format(test.id()) +
self.out.write(_Colors.INFO + 'SKIP {}\n'.format(test.id()) +
_Colors.END)
self.out.flush()
def addExpectedFailure(self, test, error):
"""See unittest.TestResult.addExpectedFailure."""
super(TerminalResult, self).addExpectedFailure(test, error)
self.out.write(
_Colors.INFO +
'FAILURE_OK {}\n'.format(test.id()) +
self.out.write(_Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) +
_Colors.END)
self.out.flush()
def addUnexpectedSuccess(self, test):
"""See unittest.TestResult.addUnexpectedSuccess."""
super(TerminalResult, self).addUnexpectedSuccess(test)
self.out.write(
_Colors.INFO +
'UNEXPECTED_OK {}\n'.format(test.id()) +
self.out.write(_Colors.INFO + 'UNEXPECTED_OK {}\n'.format(test.id()) +
_Colors.END)
self.out.flush()
def _traceback_string(type, value, trace):
"""Generate a descriptive string of a Python exception traceback.
@ -362,6 +367,7 @@ def _traceback_string(type, value, trace):
traceback.print_exception(type, value, trace, file=buffer)
return buffer.getvalue()
def summary(result):
"""A summary string of a result object.
@ -372,56 +378,62 @@ def summary(result):
str: The summary string.
"""
assert isinstance(result, AugmentedResult)
untested = list(result.augmented_results(
untested = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.UNTESTED))
running = list(result.augmented_results(
running = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.RUNNING))
failures = list(result.augmented_results(
failures = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.FAILURE))
errors = list(result.augmented_results(
errors = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.ERROR))
successes = list(result.augmented_results(
successes = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.SUCCESS))
skips = list(result.augmented_results(
skips = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.SKIP))
expected_failures = list(result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.EXPECTED_FAILURE))
unexpected_successes = list(result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.UNEXPECTED_SUCCESS))
expected_failures = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.EXPECTED_FAILURE
))
unexpected_successes = list(
result.augmented_results(
lambda case_result: case_result.kind is CaseResult.Kind.UNEXPECTED_SUCCESS
))
running_names = [case.name for case in running]
finished_count = (len(failures) + len(errors) + len(successes) +
len(expected_failures) + len(unexpected_successes))
statistics = (
'{finished} tests finished:\n'
statistics = ('{finished} tests finished:\n'
'\t{successful} successful\n'
'\t{unsuccessful} unsuccessful\n'
'\t{skipped} skipped\n'
'\t{expected_fail} expected failures\n'
'\t{unexpected_successful} unexpected successes\n'
'Interrupted Tests:\n'
'\t{interrupted}\n'
.format(finished=finished_count,
'\t{interrupted}\n'.format(
finished=finished_count,
successful=len(successes),
unsuccessful=(len(failures) + len(errors)),
skipped=len(skips),
expected_fail=len(expected_failures),
unexpected_successful=len(unexpected_successes),
interrupted=str(running_names)))
tracebacks = '\n\n'.join([
(_Colors.FAIL + '{test_name}' + _Colors.END + '\n' +
_Colors.BOLD + 'traceback:' + _Colors.END + '\n' +
'{traceback}\n' +
_Colors.BOLD + 'stdout:' + _Colors.END + '\n' +
'{stdout}\n' +
_Colors.BOLD + 'stderr:' + _Colors.END + '\n' +
'{stderr}\n').format(
tracebacks = '\n\n'.join(
[(_Colors.FAIL + '{test_name}' + _Colors.END + '\n' + _Colors.BOLD +
'traceback:' + _Colors.END + '\n' + '{traceback}\n' + _Colors.BOLD +
'stdout:' + _Colors.END + '\n' + '{stdout}\n' + _Colors.BOLD +
'stderr:' + _Colors.END + '\n' + '{stderr}\n').format(
test_name=result.name,
traceback=_traceback_string(*result.traceback),
stdout=result.stdout, stderr=result.stderr)
for result in itertools.chain(failures, errors)
])
notes = 'Unexpected successes: {}\n'.format([
result.name for result in unexpected_successes])
stdout=result.stdout,
stderr=result.stderr)
for result in itertools.chain(failures, errors)])
notes = 'Unexpected successes: {}\n'.format(
[result.name for result in unexpected_successes])
return statistics + '\nErrors/Failures: \n' + tracebacks + '\n' + notes
@ -441,9 +453,7 @@ def jenkins_junit_xml(result):
})
for case in result.cases.values():
if case.kind is CaseResult.Kind.SUCCESS:
ElementTree.SubElement(suite, 'testcase', {
'name': case.name,
})
ElementTree.SubElement(suite, 'testcase', {'name': case.name,})
elif case.kind in (CaseResult.Kind.ERROR, CaseResult.Kind.FAILURE):
case_xml = ElementTree.SubElement(suite, 'testcase', {
'name': case.name,

@ -114,8 +114,7 @@ class CaptureFile(object):
os.close(self._saved_fd)
class AugmentedCase(collections.namedtuple('AugmentedCase', [
'case', 'id'])):
class AugmentedCase(collections.namedtuple('AugmentedCase', ['case', 'id'])):
"""A test case with a guaranteed unique externally specified identifier.
Attributes:
@ -144,8 +143,9 @@ class Runner(object):
# Ensure that every test case has no collision with any other test case in
# the augmented results.
augmented_cases = [AugmentedCase(case, uuid.uuid4())
for case in filtered_cases]
augmented_cases = [
AugmentedCase(case, uuid.uuid4()) for case in filtered_cases
]
case_id_by_case = dict((augmented_case.case, augmented_case.id)
for augmented_case in augmented_cases)
result_out = moves.cStringIO()
@ -162,9 +162,8 @@ class Runner(object):
def fault_handler(signal_number, frame):
stdout_pipe.write_bypass(
'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'
.format(signal_number, stdout_pipe.output(),
stderr_pipe.output()))
'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'.format(
signal_number, stdout_pipe.output(), stderr_pipe.output()))
os._exit(1)
def check_kill_self():
@ -172,16 +171,18 @@ class Runner(object):
stdout_pipe.write_bypass('Stopping tests short...')
result.stopTestRun()
stdout_pipe.write_bypass(result_out.getvalue())
stdout_pipe.write_bypass(
'\ninterrupted stdout:\n{}\n'.format(stdout_pipe.output().decode()))
stderr_pipe.write_bypass(
'\ninterrupted stderr:\n{}\n'.format(stderr_pipe.output().decode()))
stdout_pipe.write_bypass('\ninterrupted stdout:\n{}\n'.format(
stdout_pipe.output().decode()))
stderr_pipe.write_bypass('\ninterrupted stderr:\n{}\n'.format(
stderr_pipe.output().decode()))
os._exit(1)
def try_set_handler(name, handler):
try:
signal.signal(getattr(signal, name), handler)
except AttributeError:
pass
try_set_handler('SIGINT', sigint_handler)
try_set_handler('SIGSEGV', fault_handler)
try_set_handler('SIGBUS', fault_handler)
@ -195,7 +196,8 @@ class Runner(object):
# Run the tests
result.startTestRun()
for augmented_case in augmented_cases:
sys.stdout.write('Running {}\n'.format(augmented_case.case.id()))
sys.stdout.write('Running {}\n'.format(augmented_case.case.id(
)))
sys.stdout.flush()
case_thread = threading.Thread(
target=augmented_case.case.run, args=(result,))
@ -209,8 +211,8 @@ class Runner(object):
except:
# re-raise the exception after forcing the with-block to end
raise
result.set_output(
augmented_case.case, stdout_pipe.output(), stderr_pipe.output())
result.set_output(augmented_case.case,
stdout_pipe.output(), stderr_pipe.output())
sys.stdout.write(result_out.getvalue())
sys.stdout.flush()
result_out.truncate(0)
@ -226,4 +228,3 @@ class Runner(object):
with open('report.xml', 'wb') as report_xml_file:
_result.jenkins_junit_xml(result).write(report_xml_file)
return result

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of grpc_health.v1.health."""
import unittest
@ -80,11 +79,11 @@ class HealthServicerTest(unittest.TestCase):
request = health_pb2.HealthCheckRequest(
service='grpc.test.TestServiceNotServing')
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, resp.status)
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
resp.status)
def test_not_found_service(self):
request = health_pb2.HealthCheckRequest(
service='not-found')
request = health_pb2.HealthCheckRequest(service='not-found')
with self.assertRaises(grpc.RpcError) as context:
resp = self._stub.Check(request)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""The Python client used to test negative http2 conditions."""
import argparse
@ -35,30 +34,33 @@ import grpc
from src.proto.grpc.testing import test_pb2
from src.proto.grpc.testing import messages_pb2
def _validate_payload_type_and_length(response, expected_type, expected_length):
if response.payload.type is not expected_type:
raise ValueError(
'expected payload type %s, got %s' %
raise ValueError('expected payload type %s, got %s' %
(expected_type, type(response.payload.type)))
elif len(response.payload.body) != expected_length:
raise ValueError(
'expected payload body size %d, got %d' %
raise ValueError('expected payload body size %d, got %d' %
(expected_length, len(response.payload.body)))
def _expect_status_code(call, expected_code):
if call.code() != expected_code:
raise ValueError(
'expected code %s, got %s' % (expected_code, call.code()))
raise ValueError('expected code %s, got %s' %
(expected_code, call.code()))
def _expect_status_details(call, expected_details):
if call.details() != expected_details:
raise ValueError(
'expected message %s, got %s' % (expected_details, call.details()))
raise ValueError('expected message %s, got %s' %
(expected_details, call.details()))
def _validate_status_code_and_details(call, expected_code, expected_details):
_expect_status_code(call, expected_code)
_expect_status_details(call, expected_details)
# common requests
_REQUEST_SIZE = 314159
_RESPONSE_SIZE = 271828
@ -68,46 +70,54 @@ _SIMPLE_REQUEST = messages_pb2.SimpleRequest(
response_size=_RESPONSE_SIZE,
payload=messages_pb2.Payload(body=b'\x00' * _REQUEST_SIZE))
def _goaway(stub):
first_response = stub.UnaryCall(_SIMPLE_REQUEST)
_validate_payload_type_and_length(first_response,
messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
_validate_payload_type_and_length(first_response, messages_pb2.COMPRESSABLE,
_RESPONSE_SIZE)
second_response = stub.UnaryCall(_SIMPLE_REQUEST)
_validate_payload_type_and_length(second_response,
messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
def _rst_after_header(stub):
resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
_validate_status_code_and_details(resp_future, grpc.StatusCode.UNAVAILABLE, "")
_validate_status_code_and_details(resp_future, grpc.StatusCode.UNAVAILABLE,
"")
def _rst_during_data(stub):
resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
_validate_status_code_and_details(resp_future, grpc.StatusCode.UNKNOWN, "")
def _rst_after_data(stub):
resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
_validate_payload_type_and_length(next(resp_future),
messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
_validate_payload_type_and_length(
next(resp_future), messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
_validate_status_code_and_details(resp_future, grpc.StatusCode.UNKNOWN, "")
def _ping(stub):
response = stub.UnaryCall(_SIMPLE_REQUEST)
_validate_payload_type_and_length(response,
messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
_RESPONSE_SIZE)
def _max_streams(stub):
# send one req to ensure server sets MAX_STREAMS
response = stub.UnaryCall(_SIMPLE_REQUEST)
_validate_payload_type_and_length(response,
messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
_RESPONSE_SIZE)
# give the streams a workout
futures = []
for _ in range(15):
futures.append(stub.UnaryCall.future(_SIMPLE_REQUEST))
for future in futures:
_validate_payload_type_and_length(future.result(),
messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
_validate_payload_type_and_length(
future.result(), messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
def _run_test_case(test_case, stub):
if test_case == 'goaway':
@ -125,24 +135,33 @@ def _run_test_case(test_case, stub):
else:
raise ValueError("Invalid test case: %s" % test_case)
def _args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--server_host', help='the host to which to connect', type=str,
'--server_host',
help='the host to which to connect',
type=str,
default="127.0.0.1")
parser.add_argument(
'--server_port', help='the port to which to connect', type=int,
'--server_port',
help='the port to which to connect',
type=int,
default="8080")
parser.add_argument(
'--test_case', help='the test case to execute', type=str,
'--test_case',
help='the test case to execute',
type=str,
default="goaway")
return parser.parse_args()
def _stub(server_host, server_port):
target = '{}:{}'.format(server_host, server_port)
channel = grpc.insecure_channel(target)
return test_pb2.TestServiceStub(channel)
def main():
args = _args()
stub = _stub(args.server_host, args.server_port)

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Insecure client-server interoperability as a unit test."""
from concurrent import futures
@ -40,14 +39,13 @@ from tests.interop import methods
from tests.interop import server
class InsecureIntraopTest(
_intraop_test_case.IntraopTestCase,
class InsecureIntraopTest(_intraop_test_case.IntraopTestCase,
unittest.TestCase):
def setUp(self):
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
test_pb2.add_TestServiceServicer_to_server(
methods.TestService(), self.server)
test_pb2.add_TestServiceServicer_to_server(methods.TestService(),
self.server)
port = self.server.add_insecure_port('[::]:0')
self.server.start()
self.stub = test_pb2.TestServiceStub(

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Common code for unit tests of the interoperability test code."""
from tests.interop import methods
@ -55,10 +54,13 @@ class IntraopTestCase(object):
methods.TestCase.PING_PONG.test_interoperability(self.stub, None)
def testCancelAfterBegin(self):
methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability(self.stub, None)
methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability(self.stub,
None)
def testCancelAfterFirstResponse(self):
methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability(self.stub, None)
methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability(
self.stub, None)
def testTimeoutOnSleepingServer(self):
methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER.test_interoperability(self.stub, None)
methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER.test_interoperability(
self.stub, None)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Secure client-server interoperability as a unit test."""
from concurrent import futures
@ -42,23 +41,23 @@ from tests.interop import resources
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
class SecureIntraopTest(
_intraop_test_case.IntraopTestCase,
unittest.TestCase):
class SecureIntraopTest(_intraop_test_case.IntraopTestCase, unittest.TestCase):
def setUp(self):
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
test_pb2.add_TestServiceServicer_to_server(
methods.TestService(), self.server)
test_pb2.add_TestServiceServicer_to_server(methods.TestService(),
self.server)
port = self.server.add_secure_port(
'[::]:0', grpc.ssl_server_credentials(
'[::]:0',
grpc.ssl_server_credentials(
[(resources.private_key(), resources.certificate_chain())]))
self.server.start()
self.stub = test_pb2.TestServiceStub(
grpc.secure_channel(
'localhost:{}'.format(port),
grpc.ssl_channel_credentials(resources.test_root_certificates()),
(('grpc.ssl_target_name_override', _SERVER_HOST_OVERRIDE,),)))
grpc.secure_channel('localhost:{}'.format(port),
grpc.ssl_channel_credentials(
resources.test_root_certificates()), ((
'grpc.ssl_target_name_override',
_SERVER_HOST_OVERRIDE,),)))
if __name__ == '__main__':

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""The Python implementation of the GRPC interoperability test client."""
import argparse
@ -43,26 +42,38 @@ from tests.interop import resources
def _args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--server_host', help='the host to which to connect', type=str,
'--server_host',
help='the host to which to connect',
type=str,
default="127.0.0.1")
parser.add_argument(
'--server_port', help='the port to which to connect', type=int)
parser.add_argument(
'--test_case', help='the test case to execute', type=str,
'--test_case',
help='the test case to execute',
type=str,
default="large_unary")
parser.add_argument(
'--use_tls', help='require a secure connection', default=False,
'--use_tls',
help='require a secure connection',
default=False,
type=resources.parse_bool)
parser.add_argument(
'--use_test_ca',
help='replace platform root CAs with ca.pem',
default=False,
type=resources.parse_bool)
parser.add_argument(
'--use_test_ca', help='replace platform root CAs with ca.pem',
default=False, type=resources.parse_bool)
'--server_host_override',
default="foo.test.google.fr",
help='the server host to which to claim to connect',
type=str)
parser.add_argument(
'--server_host_override', default="foo.test.google.fr",
help='the server host to which to claim to connect', type=str)
parser.add_argument('--oauth_scope', help='scope for OAuth tokens', type=str)
'--oauth_scope', help='scope for OAuth tokens', type=str)
parser.add_argument(
'--default_service_account',
help='email address of the default service account', type=str)
help='email address of the default service account',
type=str)
return parser.parse_args()
@ -74,12 +85,14 @@ def _stub(args):
target = '{}:{}'.format(args.server_host, args.server_port)
if args.test_case == 'oauth2_auth_token':
google_credentials = _application_default_credentials()
scoped_credentials = google_credentials.create_scoped([args.oauth_scope])
scoped_credentials = google_credentials.create_scoped(
[args.oauth_scope])
access_token = scoped_credentials.get_access_token().access_token
call_credentials = grpc.access_token_call_credentials(access_token)
elif args.test_case == 'compute_engine_creds':
google_credentials = _application_default_credentials()
scoped_credentials = google_credentials.create_scoped([args.oauth_scope])
scoped_credentials = google_credentials.create_scoped(
[args.oauth_scope])
# TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
# remaining use of the Beta API.
call_credentials = implementations.google_call_credentials(
@ -103,9 +116,9 @@ def _stub(args):
channel_credentials = grpc.composite_channel_credentials(
channel_credentials, call_credentials)
channel = grpc.secure_channel(
target, channel_credentials,
(('grpc.ssl_target_name_override', args.server_host_override,),))
channel = grpc.secure_channel(target, channel_credentials, ((
'grpc.ssl_target_name_override',
args.server_host_override,),))
else:
channel = grpc.insecure_channel(target)
if args.test_case == "unimplemented_service":

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Implementations of interoperability test methods."""
import enum
@ -46,24 +45,27 @@ from src.proto.grpc.testing import test_pb2
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
def _maybe_echo_metadata(servicer_context):
"""Copies metadata from request to response if it is present."""
invocation_metadata = dict(servicer_context.invocation_metadata())
if _INITIAL_METADATA_KEY in invocation_metadata:
initial_metadatum = (
_INITIAL_METADATA_KEY, invocation_metadata[_INITIAL_METADATA_KEY])
initial_metadatum = (_INITIAL_METADATA_KEY,
invocation_metadata[_INITIAL_METADATA_KEY])
servicer_context.send_initial_metadata((initial_metadatum,))
if _TRAILING_METADATA_KEY in invocation_metadata:
trailing_metadatum = (
_TRAILING_METADATA_KEY, invocation_metadata[_TRAILING_METADATA_KEY])
trailing_metadatum = (_TRAILING_METADATA_KEY,
invocation_metadata[_TRAILING_METADATA_KEY])
servicer_context.set_trailing_metadata((trailing_metadatum,))
def _maybe_echo_status_and_message(request, servicer_context):
"""Sets the response context code and details if the request asks for them"""
if request.HasField('response_status'):
servicer_context.set_code(request.response_status.code)
servicer_context.set_details(request.response_status.message)
class TestService(test_pb2.TestServiceServicer):
def EmptyCall(self, request, context):
@ -73,8 +75,7 @@ class TestService(test_pb2.TestServiceServicer):
def UnaryCall(self, request, context):
_maybe_echo_metadata(context)
_maybe_echo_status_and_message(request, context)
return messages_pb2.SimpleResponse(
payload=messages_pb2.Payload(
return messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
type=messages_pb2.COMPRESSABLE,
body=b'\x00' * request.response_size))
@ -112,14 +113,14 @@ class TestService(test_pb2.TestServiceServicer):
def _expect_status_code(call, expected_code):
if call.code() != expected_code:
raise ValueError(
'expected code %s, got %s' % (expected_code, call.code()))
raise ValueError('expected code %s, got %s' %
(expected_code, call.code()))
def _expect_status_details(call, expected_details):
if call.details() != expected_details:
raise ValueError(
'expected message %s, got %s' % (expected_details, call.details()))
raise ValueError('expected message %s, got %s' %
(expected_details, call.details()))
def _validate_status_code_and_details(call, expected_code, expected_details):
@ -129,22 +130,22 @@ def _validate_status_code_and_details(call, expected_code, expected_details):
def _validate_payload_type_and_length(response, expected_type, expected_length):
if response.payload.type is not expected_type:
raise ValueError(
'expected payload type %s, got %s' %
raise ValueError('expected payload type %s, got %s' %
(expected_type, type(response.payload.type)))
elif len(response.payload.body) != expected_length:
raise ValueError(
'expected payload body size %d, got %d' %
raise ValueError('expected payload body size %d, got %d' %
(expected_length, len(response.payload.body)))
def _large_unary_common_behavior(
stub, fill_username, fill_oauth_scope, call_credentials):
def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
call_credentials):
size = 314159
request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE, response_size=size,
response_type=messages_pb2.COMPRESSABLE,
response_size=size,
payload=messages_pb2.Payload(body=b'\x00' * 271828),
fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
fill_username=fill_username,
fill_oauth_scope=fill_oauth_scope)
response_future = stub.UnaryCall.future(
request, credentials=call_credentials)
response = response_future.result()
@ -155,8 +156,8 @@ def _large_unary_common_behavior(
def _empty_unary(stub):
response = stub.EmptyCall(empty_pb2.Empty())
if not isinstance(response, empty_pb2.Empty):
raise TypeError(
'response is of type "%s", not empty_pb2.Empty!', type(response))
raise TypeError('response is of type "%s", not empty_pb2.Empty!',
type(response))
def _large_unary(stub):
@ -164,21 +165,27 @@ def _large_unary(stub):
def _client_streaming(stub):
payload_body_sizes = (27182, 8, 1828, 45904,)
payloads = (
messages_pb2.Payload(body=b'\x00' * size)
payload_body_sizes = (
27182,
8,
1828,
45904,)
payloads = (messages_pb2.Payload(body=b'\x00' * size)
for size in payload_body_sizes)
requests = (
messages_pb2.StreamingInputCallRequest(payload=payload)
requests = (messages_pb2.StreamingInputCallRequest(payload=payload)
for payload in payloads)
response = stub.StreamingInputCall(requests)
if response.aggregated_payload_size != 74922:
raise ValueError(
'incorrect size %d!' % response.aggregated_payload_size)
raise ValueError('incorrect size %d!' %
response.aggregated_payload_size)
def _server_streaming(stub):
sizes = (31415, 9, 2653, 58979,)
sizes = (
31415,
9,
2653,
58979,)
request = messages_pb2.StreamingOutputCallRequest(
response_type=messages_pb2.COMPRESSABLE,
@ -186,14 +193,11 @@ def _server_streaming(stub):
messages_pb2.ResponseParameters(size=sizes[0]),
messages_pb2.ResponseParameters(size=sizes[1]),
messages_pb2.ResponseParameters(size=sizes[2]),
messages_pb2.ResponseParameters(size=sizes[3]),
)
)
messages_pb2.ResponseParameters(size=sizes[3]),))
response_iterator = stub.StreamingOutputCall(request)
for index, response in enumerate(response_iterator):
_validate_payload_type_and_length(
response, messages_pb2.COMPRESSABLE, sizes[index])
_validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
sizes[index])
class _Pipe(object):
@ -236,13 +240,21 @@ class _Pipe(object):
def _ping_pong(stub):
request_response_sizes = (31415, 9, 2653, 58979,)
request_payload_sizes = (27182, 8, 1828, 45904,)
request_response_sizes = (
31415,
9,
2653,
58979,)
request_payload_sizes = (
27182,
8,
1828,
45904,)
with _Pipe() as pipe:
response_iterator = stub.FullDuplexCall(pipe)
for response_size, payload_size in zip(
request_response_sizes, request_payload_sizes):
for response_size, payload_size in zip(request_response_sizes,
request_payload_sizes):
request = messages_pb2.StreamingOutputCallRequest(
response_type=messages_pb2.COMPRESSABLE,
response_parameters=(
@ -265,8 +277,16 @@ def _cancel_after_begin(stub):
def _cancel_after_first_response(stub):
request_response_sizes = (31415, 9, 2653, 58979,)
request_payload_sizes = (27182, 8, 1828, 45904,)
request_response_sizes = (
31415,
9,
2653,
58979,)
request_payload_sizes = (
27182,
8,
1828,
45904,)
with _Pipe() as pipe:
response_iterator = stub.FullDuplexCall(pipe)
@ -331,8 +351,8 @@ def _status_code_and_message(stub):
response_type=messages_pb2.COMPRESSABLE,
response_size=1,
payload=messages_pb2.Payload(body=b'\x00'),
response_status=messages_pb2.EchoStatus(code=code, message=details)
)
response_status=messages_pb2.EchoStatus(
code=code, message=details))
response_future = stub.UnaryCall.future(request)
_validate_status_code_and_details(response_future, status, details)
@ -341,10 +361,10 @@ def _status_code_and_message(stub):
response_iterator = stub.FullDuplexCall(pipe)
request = messages_pb2.StreamingOutputCallRequest(
response_type=messages_pb2.COMPRESSABLE,
response_parameters=(
messages_pb2.ResponseParameters(size=1),),
response_parameters=(messages_pb2.ResponseParameters(size=1),),
payload=messages_pb2.Payload(body=b'\x00'),
response_status=messages_pb2.EchoStatus(code=code, message=details))
response_status=messages_pb2.EchoStatus(
code=code, message=details))
pipe.add(request) # sends the initial request.
# Dropping out of with block closes the pipe
_validate_status_code_and_details(response_iterator, status, details)
@ -365,21 +385,20 @@ def _unimplemented_service(unimplemented_service_stub):
def _custom_metadata(stub):
initial_metadata_value = "test_initial_metadata_value"
trailing_metadata_value = "\x0a\x0b\x0a\x0b\x0a\x0b"
metadata = (
(_INITIAL_METADATA_KEY, initial_metadata_value),
metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
(_TRAILING_METADATA_KEY, trailing_metadata_value))
def _validate_metadata(response):
initial_metadata = dict(response.initial_metadata())
if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
raise ValueError(
'expected initial metadata %s, got %s' % (
initial_metadata_value, initial_metadata[_INITIAL_METADATA_KEY]))
raise ValueError('expected initial metadata %s, got %s' %
(initial_metadata_value,
initial_metadata[_INITIAL_METADATA_KEY]))
trailing_metadata = dict(response.trailing_metadata())
if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
raise ValueError(
'expected trailing metadata %s, got %s' % (
trailing_metadata_value, initial_metadata[_TRAILING_METADATA_KEY]))
raise ValueError('expected trailing metadata %s, got %s' %
(trailing_metadata_value,
initial_metadata[_TRAILING_METADATA_KEY]))
# Testing with UnaryCall
request = messages_pb2.SimpleRequest(
@ -394,19 +413,18 @@ def _custom_metadata(stub):
response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
request = messages_pb2.StreamingOutputCallRequest(
response_type=messages_pb2.COMPRESSABLE,
response_parameters=(
messages_pb2.ResponseParameters(size=1),))
response_parameters=(messages_pb2.ResponseParameters(size=1),))
pipe.add(request) # Sends the request
next(response_iterator) # Causes server to send trailing metadata
# Dropping out of the with block closes the pipe
_validate_metadata(response_iterator)
def _compute_engine_creds(stub, args):
response = _large_unary_common_behavior(stub, True, True, None)
if args.default_service_account != response.username:
raise ValueError(
'expected username %s, got %s' % (
args.default_service_account, response.username))
raise ValueError('expected username %s, got %s' %
(args.default_service_account, response.username))
def _oauth2_auth_token(stub, args):
@ -415,12 +433,11 @@ def _oauth2_auth_token(stub, args):
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
response = _large_unary_common_behavior(stub, True, True, None)
if wanted_email != response.username:
raise ValueError(
'expected username %s, got %s' % (wanted_email, response.username))
raise ValueError('expected username %s, got %s' %
(wanted_email, response.username))
if args.oauth_scope.find(response.oauth_scope) == -1:
raise ValueError(
'expected to find oauth scope "{}" in received "{}"'.format(
response.oauth_scope, args.oauth_scope))
raise ValueError('expected to find oauth scope "{}" in received "{}"'.
format(response.oauth_scope, args.oauth_scope))
def _jwt_token_creds(stub, args):
@ -429,15 +446,16 @@ def _jwt_token_creds(stub, args):
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
response = _large_unary_common_behavior(stub, True, False, None)
if wanted_email != response.username:
raise ValueError(
'expected username %s, got %s' % (wanted_email, response.username))
raise ValueError('expected username %s, got %s' %
(wanted_email, response.username))
def _per_rpc_creds(stub, args):
json_key_filename = os.environ[
oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
credentials = oauth2client_client.GoogleCredentials.get_application_default()
credentials = oauth2client_client.GoogleCredentials.get_application_default(
)
scoped_credentials = credentials.create_scoped([args.oauth_scope])
# TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
# remaining use of the Beta API.
@ -445,8 +463,8 @@ def _per_rpc_creds(stub, args):
scoped_credentials)
response = _large_unary_common_behavior(stub, True, False, call_credentials)
if wanted_email != response.username:
raise ValueError(
'expected username %s, got %s' % (wanted_email, response.username))
raise ValueError('expected username %s, got %s' %
(wanted_email, response.username))
@enum.unique
@ -505,4 +523,5 @@ class TestCase(enum.Enum):
elif self is TestCase.PER_RPC_CREDS:
_per_rpc_creds(stub, args)
else:
raise NotImplementedError('Test case "%s" not implemented!' % self.name)
raise NotImplementedError('Test case "%s" not implemented!' %
self.name)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Constants and functions for data used in interoperability testing."""
import argparse
@ -40,8 +39,8 @@ _CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem'
def test_root_certificates():
return pkg_resources.resource_string(
__name__, _ROOT_CERTIFICATES_RESOURCE_PATH)
return pkg_resources.resource_string(__name__,
_ROOT_CERTIFICATES_RESOURCE_PATH)
def private_key():
@ -49,8 +48,8 @@ def private_key():
def certificate_chain():
return pkg_resources.resource_string(
__name__, _CERTIFICATE_CHAIN_RESOURCE_PATH)
return pkg_resources.resource_string(__name__,
_CERTIFICATE_CHAIN_RESOURCE_PATH)
def parse_bool(value):

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""The Python implementation of the GRPC interoperability test server."""
import argparse
@ -45,11 +44,12 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
def serve():
parser = argparse.ArgumentParser()
parser.add_argument('--port', help='the port on which to serve', type=int)
parser.add_argument(
'--port', help='the port on which to serve', type=int)
parser.add_argument(
'--use_tls', help='require a secure connection',
default=False, type=resources.parse_bool)
'--use_tls',
help='require a secure connection',
default=False,
type=resources.parse_bool)
args = parser.parse_args()
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
@ -57,8 +57,8 @@ def serve():
if args.use_tls:
private_key = resources.private_key()
certificate_chain = resources.certificate_chain()
credentials = grpc.ssl_server_credentials(
((private_key, certificate_chain),))
credentials = grpc.ssl_server_credentials((
(private_key, certificate_chain),))
server.add_secure_port('[::]:{}'.format(args.port), credentials)
else:
server.add_insecure_port('[::]:{}'.format(args.port))
@ -73,5 +73,6 @@ def serve():
server.stop(None)
logging.info('Server stopped; exiting.')
if __name__ == '__main__':
serve()

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -134,8 +134,10 @@ class _ServicerMethods(object):
class _Service(
collections.namedtuple(
'_Service', ('servicer_methods', 'server', 'stub',))):
collections.namedtuple('_Service', (
'servicer_methods',
'server',
'stub',))):
"""A live and running service.
Attributes:
@ -238,10 +240,8 @@ class PythonPluginTest(unittest.TestCase):
def testImportAttributes(self):
# check that we can access the generated module and its members.
self.assertIsNotNone(
getattr(service_pb2, STUB_IDENTIFIER, None))
self.assertIsNotNone(
getattr(service_pb2, SERVICER_IDENTIFIER, None))
self.assertIsNotNone(getattr(service_pb2, STUB_IDENTIFIER, None))
self.assertIsNotNone(getattr(service_pb2, SERVICER_IDENTIFIER, None))
self.assertIsNotNone(
getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))
@ -256,8 +256,8 @@ class PythonPluginTest(unittest.TestCase):
request = request_pb2.SimpleRequest(response_size=13)
with self.assertRaises(grpc.RpcError) as exception_context:
service.stub.UnaryCall(request)
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.UNIMPLEMENTED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.UNIMPLEMENTED)
def testUnaryCall(self):
service = _CreateService()
@ -286,8 +286,8 @@ class PythonPluginTest(unittest.TestCase):
request, timeout=test_constants.SHORT_TIMEOUT)
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
def testUnaryCallFutureCancelled(self):
@ -313,8 +313,8 @@ class PythonPluginTest(unittest.TestCase):
responses = service.stub.StreamingOutputCall(request)
expected_responses = service.servicer_methods.StreamingOutputCall(
request, 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
expected_responses, responses):
for expected_response, response in moves.zip_longest(expected_responses,
responses):
self.assertEqual(expected_response, response)
def testStreamingOutputCallExpired(self):
@ -325,8 +325,8 @@ class PythonPluginTest(unittest.TestCase):
request, timeout=test_constants.SHORT_TIMEOUT)
with self.assertRaises(grpc.RpcError) as exception_context:
list(responses)
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
def testStreamingOutputCallCancelled(self):
service = _CreateService()
@ -346,15 +346,15 @@ class PythonPluginTest(unittest.TestCase):
self.assertIsNotNone(responses)
with self.assertRaises(grpc.RpcError) as exception_context:
next(responses)
self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.UNKNOWN)
def testStreamingInputCall(self):
service = _CreateService()
response = service.stub.StreamingInputCall(
_streaming_input_request_iterator())
expected_response = service.servicer_methods.StreamingInputCall(
_streaming_input_request_iterator(),
'not a real RpcContext!')
_streaming_input_request_iterator(), 'not a real RpcContext!')
self.assertEqual(expected_response, response)
def testStreamingInputCallFuture(self):
@ -364,8 +364,7 @@ class PythonPluginTest(unittest.TestCase):
_streaming_input_request_iterator())
response = response_future.result()
expected_response = service.servicer_methods.StreamingInputCall(
_streaming_input_request_iterator(),
'not a real RpcContext!')
_streaming_input_request_iterator(), 'not a real RpcContext!')
self.assertEqual(expected_response, response)
def testStreamingInputCallFutureExpired(self):
@ -377,10 +376,10 @@ class PythonPluginTest(unittest.TestCase):
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertIsInstance(response_future.exception(), grpc.RpcError)
self.assertIs(
response_future.exception().code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(response_future.exception().code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
def testStreamingInputCallFutureCancelled(self):
service = _CreateService()
@ -402,13 +401,11 @@ class PythonPluginTest(unittest.TestCase):
def testFullDuplexCall(self):
service = _CreateService()
responses = service.stub.FullDuplexCall(
_full_duplex_request_iterator())
responses = service.stub.FullDuplexCall(_full_duplex_request_iterator())
expected_responses = service.servicer_methods.FullDuplexCall(
_full_duplex_request_iterator(),
'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
expected_responses, responses):
_full_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(expected_responses,
responses):
self.assertEqual(expected_response, response)
def testFullDuplexCallExpired(self):
@ -419,8 +416,8 @@ class PythonPluginTest(unittest.TestCase):
request_iterator, timeout=test_constants.SHORT_TIMEOUT)
with self.assertRaises(grpc.RpcError) as exception_context:
list(responses)
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
def testFullDuplexCallCancelled(self):
service = _CreateService()
@ -430,8 +427,8 @@ class PythonPluginTest(unittest.TestCase):
responses.cancel()
with self.assertRaises(grpc.RpcError) as exception_context:
next(responses)
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.CANCELLED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.CANCELLED)
def testFullDuplexCallFailed(self):
request_iterator = _full_duplex_request_iterator()
@ -440,10 +437,12 @@ class PythonPluginTest(unittest.TestCase):
responses = service.stub.FullDuplexCall(request_iterator)
with self.assertRaises(grpc.RpcError) as exception_context:
next(responses)
self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.UNKNOWN)
def testHalfDuplexCall(self):
service = _CreateService()
def half_duplex_request_iterator():
request = request_pb2.StreamingOutputCallRequest()
request.response_parameters.add(size=1, interval_us=0)
@ -452,16 +451,18 @@ class PythonPluginTest(unittest.TestCase):
request.response_parameters.add(size=2, interval_us=0)
request.response_parameters.add(size=3, interval_us=0)
yield request
responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
expected_responses = service.servicer_methods.HalfDuplexCall(
half_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
expected_responses, responses):
for expected_response, response in moves.zip_longest(expected_responses,
responses):
self.assertEqual(expected_response, response)
def testHalfDuplexCallWedged(self):
condition = threading.Condition()
wait_cell = [False]
@contextlib.contextmanager
def wait(): # pylint: disable=invalid-name
# Where's Python 3's 'nonlocal' statement when you need it?
@ -471,6 +472,7 @@ class PythonPluginTest(unittest.TestCase):
with condition:
wait_cell[0] = False
condition.notify_all()
def half_duplex_request_iterator():
request = request_pb2.StreamingOutputCallRequest()
request.response_parameters.add(size=1, interval_us=0)
@ -478,15 +480,17 @@ class PythonPluginTest(unittest.TestCase):
with condition:
while wait_cell[0]:
condition.wait()
service = _CreateService()
with wait():
responses = service.stub.HalfDuplexCall(
half_duplex_request_iterator(), timeout=test_constants.SHORT_TIMEOUT)
half_duplex_request_iterator(),
timeout=test_constants.SHORT_TIMEOUT)
# half-duplex waits for the client to send all info
with self.assertRaises(grpc.RpcError) as exception_context:
next(responses)
self.assertIs(
exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertIs(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
if __name__ == '__main__':

@ -49,6 +49,7 @@ from tests.unit.framework.common import test_constants
_MESSAGES_IMPORT = b'import "messages.proto";'
@contextlib.contextmanager
def _system_path(path):
old_system_path = sys.path[:]
@ -96,8 +97,7 @@ class SeparateTestMixin(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
pb2_grpc.add_TestServiceServicer_to_server(
DummySplitServicer(
pb2.Request, pb2.Response), server)
DummySplitServicer(pb2.Request, pb2.Response), server)
port = server.add_insecure_port('[::]:0')
server.start()
channel = grpc.insecure_channel('localhost:{}'.format(port))
@ -137,8 +137,7 @@ class CommonTestMixin(object):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
pb2_grpc.add_TestServiceServicer_to_server(
DummySplitServicer(
pb2.Request, pb2.Response), server)
DummySplitServicer(pb2.Request, pb2.Response), server)
port = server.add_insecure_port('[::]:0')
server.start()
channel = grpc.insecure_channel('localhost:{}'.format(port))
@ -157,23 +156,28 @@ class SameSeparateTest(unittest.TestCase, SeparateTestMixin):
self.directory = tempfile.mkdtemp(suffix='same_separate', dir='.')
self.proto_directory = os.path.join(self.directory, 'proto_path')
self.python_out_directory = os.path.join(self.directory, 'python_out')
self.grpc_python_out_directory = os.path.join(self.directory, 'grpc_python_out')
self.grpc_python_out_directory = os.path.join(self.directory,
'grpc_python_out')
os.makedirs(self.proto_directory)
os.makedirs(self.python_out_directory)
os.makedirs(self.grpc_python_out_directory)
same_proto_file = os.path.join(self.proto_directory, 'same_separate.proto')
same_proto_file = os.path.join(self.proto_directory,
'same_separate.proto')
open(same_proto_file, 'wb').write(same_proto_contents)
protoc_result = protoc.main([
'',
'--proto_path={}'.format(self.proto_directory),
'--python_out={}'.format(self.python_out_directory),
'--grpc_python_out=grpc_2_0:{}'.format(self.grpc_python_out_directory),
'--grpc_python_out=grpc_2_0:{}'.format(
self.grpc_python_out_directory),
same_proto_file,
])
if protoc_result != 0:
raise Exception("unexpected protoc error")
open(os.path.join(self.grpc_python_out_directory, '__init__.py'), 'w').write('')
open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
open(os.path.join(self.grpc_python_out_directory, '__init__.py'),
'w').write('')
open(os.path.join(self.python_out_directory, '__init__.py'),
'w').write('')
self.pb2_import = 'same_separate_pb2'
self.pb2_grpc_import = 'same_separate_pb2_grpc'
self.should_find_services_in_pb2 = False
@ -193,7 +197,8 @@ class SameCommonTest(unittest.TestCase, CommonTestMixin):
self.grpc_python_out_directory = self.python_out_directory
os.makedirs(self.proto_directory)
os.makedirs(self.python_out_directory)
same_proto_file = os.path.join(self.proto_directory, 'same_common.proto')
same_proto_file = os.path.join(self.proto_directory,
'same_common.proto')
open(same_proto_file, 'wb').write(same_proto_contents)
protoc_result = protoc.main([
'',
@ -204,7 +209,8 @@ class SameCommonTest(unittest.TestCase, CommonTestMixin):
])
if protoc_result != 0:
raise Exception("unexpected protoc error")
open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
open(os.path.join(self.python_out_directory, '__init__.py'),
'w').write('')
self.pb2_import = 'same_common_pb2'
self.pb2_grpc_import = 'same_common_pb2_grpc'
self.should_find_services_in_pb2 = True
@ -232,10 +238,9 @@ class SplitCommonTest(unittest.TestCase, CommonTestMixin):
'split_common_services.proto')
messages_proto_file = os.path.join(self.proto_directory,
'split_common_messages.proto')
open(services_proto_file, 'wb').write(services_proto_contents.replace(
_MESSAGES_IMPORT,
b'import "split_common_messages.proto";'
))
open(services_proto_file, 'wb').write(
services_proto_contents.replace(
_MESSAGES_IMPORT, b'import "split_common_messages.proto";'))
open(messages_proto_file, 'wb').write(messages_proto_contents)
protoc_result = protoc.main([
'',
@ -247,7 +252,8 @@ class SplitCommonTest(unittest.TestCase, CommonTestMixin):
])
if protoc_result != 0:
raise Exception("unexpected protoc error")
open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
open(os.path.join(self.python_out_directory, '__init__.py'),
'w').write('')
self.pb2_import = 'split_common_messages_pb2'
self.pb2_grpc_import = 'split_common_services_pb2_grpc'
self.should_find_services_in_pb2 = False
@ -268,7 +274,8 @@ class SplitSeparateTest(unittest.TestCase, SeparateTestMixin):
self.directory = tempfile.mkdtemp(suffix='split_separate', dir='.')
self.proto_directory = os.path.join(self.directory, 'proto_path')
self.python_out_directory = os.path.join(self.directory, 'python_out')
self.grpc_python_out_directory = os.path.join(self.directory, 'grpc_python_out')
self.grpc_python_out_directory = os.path.join(self.directory,
'grpc_python_out')
os.makedirs(self.proto_directory)
os.makedirs(self.python_out_directory)
os.makedirs(self.grpc_python_out_directory)
@ -276,22 +283,23 @@ class SplitSeparateTest(unittest.TestCase, SeparateTestMixin):
'split_separate_services.proto')
messages_proto_file = os.path.join(self.proto_directory,
'split_separate_messages.proto')
open(services_proto_file, 'wb').write(services_proto_contents.replace(
_MESSAGES_IMPORT,
b'import "split_separate_messages.proto";'
))
open(services_proto_file, 'wb').write(
services_proto_contents.replace(
_MESSAGES_IMPORT, b'import "split_separate_messages.proto";'))
open(messages_proto_file, 'wb').write(messages_proto_contents)
protoc_result = protoc.main([
'',
'--proto_path={}'.format(self.proto_directory),
'--python_out={}'.format(self.python_out_directory),
'--grpc_python_out=grpc_2_0:{}'.format(self.grpc_python_out_directory),
'--grpc_python_out=grpc_2_0:{}'.format(
self.grpc_python_out_directory),
services_proto_file,
messages_proto_file,
])
if protoc_result != 0:
raise Exception("unexpected protoc error")
open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
open(os.path.join(self.python_out_directory, '__init__.py'),
'w').write('')
self.pb2_import = 'split_separate_messages_pb2'
self.pb2_grpc_import = 'split_separate_services_pb2_grpc'
self.should_find_services_in_pb2 = False

@ -244,10 +244,8 @@ class PythonPluginTest(unittest.TestCase):
def testImportAttributes(self):
# check that we can access the generated module and its members.
self.assertIsNotNone(
getattr(service_pb2, SERVICER_IDENTIFIER, None))
self.assertIsNotNone(
getattr(service_pb2, STUB_IDENTIFIER, None))
self.assertIsNotNone(getattr(service_pb2, SERVICER_IDENTIFIER, None))
self.assertIsNotNone(getattr(service_pb2, STUB_IDENTIFIER, None))
self.assertIsNotNone(
getattr(service_pb2, SERVER_FACTORY_IDENTIFIER, None))
self.assertIsNotNone(
@ -263,7 +261,8 @@ class PythonPluginTest(unittest.TestCase):
try:
stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
except face.AbortionError as error:
self.assertEqual(interfaces.StatusCode.UNIMPLEMENTED, error.code)
self.assertEqual(interfaces.StatusCode.UNIMPLEMENTED,
error.code)
def testUnaryCall(self):
with _CreateService() as (methods, stub):
@ -311,8 +310,8 @@ class PythonPluginTest(unittest.TestCase):
def testStreamingOutputCall(self):
with _CreateService() as (methods, stub):
request = _streaming_output_request()
responses = stub.StreamingOutputCall(
request, test_constants.LONG_TIMEOUT)
responses = stub.StreamingOutputCall(request,
test_constants.LONG_TIMEOUT)
expected_responses = methods.StreamingOutputCall(
request, 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
@ -331,8 +330,8 @@ class PythonPluginTest(unittest.TestCase):
def testStreamingOutputCallCancelled(self):
with _CreateService() as (methods, stub):
request = _streaming_output_request()
responses = stub.StreamingOutputCall(
request, test_constants.LONG_TIMEOUT)
responses = stub.StreamingOutputCall(request,
test_constants.LONG_TIMEOUT)
next(responses)
responses.cancel()
with self.assertRaises(face.CancellationError):
@ -353,8 +352,7 @@ class PythonPluginTest(unittest.TestCase):
_streaming_input_request_iterator(),
test_constants.LONG_TIMEOUT)
expected_response = methods.StreamingInputCall(
_streaming_input_request_iterator(),
'not a real RpcContext!')
_streaming_input_request_iterator(), 'not a real RpcContext!')
self.assertEqual(expected_response, response)
def testStreamingInputCallFuture(self):
@ -365,8 +363,7 @@ class PythonPluginTest(unittest.TestCase):
test_constants.LONG_TIMEOUT)
response = response_future.result()
expected_response = methods.StreamingInputCall(
_streaming_input_request_iterator(),
'not a real RpcContext!')
_streaming_input_request_iterator(), 'not a real RpcContext!')
self.assertEqual(expected_response, response)
def testStreamingInputCallFutureExpired(self):
@ -377,8 +374,8 @@ class PythonPluginTest(unittest.TestCase):
test_constants.SHORT_TIMEOUT)
with self.assertRaises(face.ExpirationError):
response_future.result()
self.assertIsInstance(
response_future.exception(), face.ExpirationError)
self.assertIsInstance(response_future.exception(),
face.ExpirationError)
def testStreamingInputCallFutureCancelled(self):
with _CreateService() as (methods, stub):
@ -401,12 +398,10 @@ class PythonPluginTest(unittest.TestCase):
def testFullDuplexCall(self):
with _CreateService() as (methods, stub):
responses = stub.FullDuplexCall(
_full_duplex_request_iterator(),
responses = stub.FullDuplexCall(_full_duplex_request_iterator(),
test_constants.LONG_TIMEOUT)
expected_responses = methods.FullDuplexCall(
_full_duplex_request_iterator(),
'not a real RpcContext!')
_full_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
expected_responses, responses):
self.assertEqual(expected_response, response)
@ -415,16 +410,16 @@ class PythonPluginTest(unittest.TestCase):
request_iterator = _full_duplex_request_iterator()
with _CreateService() as (methods, stub):
with methods.pause():
responses = stub.FullDuplexCall(
request_iterator, test_constants.SHORT_TIMEOUT)
responses = stub.FullDuplexCall(request_iterator,
test_constants.SHORT_TIMEOUT)
with self.assertRaises(face.ExpirationError):
list(responses)
def testFullDuplexCallCancelled(self):
with _CreateService() as (methods, stub):
request_iterator = _full_duplex_request_iterator()
responses = stub.FullDuplexCall(
request_iterator, test_constants.LONG_TIMEOUT)
responses = stub.FullDuplexCall(request_iterator,
test_constants.LONG_TIMEOUT)
next(responses)
responses.cancel()
with self.assertRaises(face.CancellationError):
@ -434,14 +429,15 @@ class PythonPluginTest(unittest.TestCase):
request_iterator = _full_duplex_request_iterator()
with _CreateService() as (methods, stub):
with methods.fail():
responses = stub.FullDuplexCall(
request_iterator, test_constants.LONG_TIMEOUT)
responses = stub.FullDuplexCall(request_iterator,
test_constants.LONG_TIMEOUT)
self.assertIsNotNone(responses)
with self.assertRaises(face.RemoteError):
next(responses)
def testHalfDuplexCall(self):
with _CreateService() as (methods, stub):
def half_duplex_request_iterator():
request = request_pb2.StreamingOutputCallRequest()
request.response_parameters.add(size=1, interval_us=0)
@ -450,8 +446,9 @@ class PythonPluginTest(unittest.TestCase):
request.response_parameters.add(size=2, interval_us=0)
request.response_parameters.add(size=3, interval_us=0)
yield request
responses = stub.HalfDuplexCall(
half_duplex_request_iterator(), test_constants.LONG_TIMEOUT)
responses = stub.HalfDuplexCall(half_duplex_request_iterator(),
test_constants.LONG_TIMEOUT)
expected_responses = methods.HalfDuplexCall(
half_duplex_request_iterator(), 'not a real RpcContext!')
for check in moves.zip_longest(expected_responses, responses):
@ -461,6 +458,7 @@ class PythonPluginTest(unittest.TestCase):
def testHalfDuplexCallWedged(self):
condition = threading.Condition()
wait_cell = [False]
@contextlib.contextmanager
def wait(): # pylint: disable=invalid-name
# Where's Python 3's 'nonlocal' statement when you need it?
@ -470,6 +468,7 @@ class PythonPluginTest(unittest.TestCase):
with condition:
wait_cell[0] = False
condition.notify_all()
def half_duplex_request_iterator():
request = request_pb2.StreamingOutputCallRequest()
request.response_parameters.add(size=1, interval_us=0)
@ -477,10 +476,11 @@ class PythonPluginTest(unittest.TestCase):
with condition:
while wait_cell[0]:
condition.wait()
with _CreateService() as (methods, stub):
with wait():
responses = stub.HalfDuplexCall(
half_duplex_request_iterator(), test_constants.SHORT_TIMEOUT)
responses = stub.HalfDuplexCall(half_duplex_request_iterator(),
test_constants.SHORT_TIMEOUT)
# half-duplex waits for the client to send all info
with self.assertRaises(face.ExpirationError):
next(responses)

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""
import abc
@ -62,7 +61,8 @@ class BenchmarkClient:
def __init__(self, server, config, hist):
# Create the stub
if config.HasField('security_params'):
creds = grpc.ssl_channel_credentials(resources.test_root_certificates())
creds = grpc.ssl_channel_credentials(
resources.test_root_certificates())
channel = test_common.test_secure_channel(
server, creds, config.security_params.server_host_override)
else:
@ -166,8 +166,8 @@ class _SyncStream(object):
def start(self):
self._is_streaming = True
response_stream = self._stub.StreamingCall(
self._request_generator(), _TIMEOUT)
response_stream = self._stub.StreamingCall(self._request_generator(),
_TIMEOUT)
for _ in response_stream:
self._handle_response(
self, time.time() - self._send_time_queue.get_nowait())
@ -190,9 +190,11 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
self._pool = futures.ThreadPoolExecutor(
max_workers=config.outstanding_rpcs_per_channel)
self._streams = [_SyncStream(self._stub, self._generic,
self._request, self._handle_response)
for _ in xrange(config.outstanding_rpcs_per_channel)]
self._streams = [
_SyncStream(self._stub, self._generic, self._request,
self._handle_response)
for _ in xrange(config.outstanding_rpcs_per_channel)
]
self._curr_stream = 0
def send_request(self):

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines behavior for WHEN clients send requests.
Each client exposes a non-blocking send_request() method that the

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""The entry point for the qps worker."""
import argparse
@ -52,7 +51,8 @@ def run_worker_server(port):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='gRPC Python performance testing worker')
parser.add_argument('--driver_port',
parser.add_argument(
'--driver_port',
type=int,
dest='port',
help='The port the worker should listen on')

@ -70,7 +70,8 @@ class WorkerServer(services_pb2.WorkerServiceServicer):
def _get_server_status(self, start_time, end_time, port, cores):
end_time = time.time()
elapsed_time = end_time - start_time
stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
stats = stats_pb2.ServerStats(
time_elapsed=elapsed_time,
time_user=elapsed_time,
time_system=elapsed_time)
return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)
@ -82,11 +83,12 @@ class WorkerServer(services_pb2.WorkerServiceServicer):
server_threads = multiprocessing.cpu_count() * 5
else:
server_threads = config.async_server_threads
server = grpc.server(futures.ThreadPoolExecutor(
max_workers=server_threads))
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=server_threads))
if config.server_type == control_pb2.ASYNC_SERVER:
servicer = benchmark_server.BenchmarkServer()
services_pb2.add_BenchmarkServiceServicer_to_server(servicer, server)
services_pb2.add_BenchmarkServiceServicer_to_server(servicer,
server)
elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
resp_size = config.payload_config.bytebuf_params.resp_size
servicer = benchmark_server.GenericBenchmarkServer(resp_size)
@ -100,12 +102,14 @@ class WorkerServer(services_pb2.WorkerServiceServicer):
'grpc.testing.BenchmarkService', method_implementations)
server.add_generic_rpc_handlers((handler,))
else:
raise Exception('Unsupported server type {}'.format(config.server_type))
raise Exception('Unsupported server type {}'.format(
config.server_type))
if config.HasField('security_params'): # Use SSL
server_creds = grpc.ssl_server_credentials(
((resources.private_key(), resources.certificate_chain()),))
port = server.add_secure_port('[::]:{}'.format(config.port), server_creds)
server_creds = grpc.ssl_server_credentials((
(resources.private_key(), resources.certificate_chain()),))
port = server.add_secure_port('[::]:{}'.format(config.port),
server_creds)
else:
port = server.add_insecure_port('[::]:{}'.format(config.port))
@ -145,7 +149,8 @@ class WorkerServer(services_pb2.WorkerServiceServicer):
latencies = qps_data.get_data()
end_time = time.time()
elapsed_time = end_time - start_time
stats = stats_pb2.ClientStats(latencies=latencies,
stats = stats_pb2.ClientStats(
latencies=latencies,
time_elapsed=elapsed_time,
time_user=elapsed_time,
time_system=elapsed_time)
@ -166,7 +171,8 @@ class WorkerServer(services_pb2.WorkerServiceServicer):
else:
raise Exception('Async streaming client not supported')
else:
raise Exception('Unsupported client type {}'.format(config.client_type))
raise Exception('Unsupported client type {}'.format(
config.client_type))
# In multi-channel tests, we split the load across all channels
load_factor = float(config.client_channels)
@ -175,6 +181,7 @@ class WorkerServer(services_pb2.WorkerServiceServicer):
client, config.outstanding_rpcs_per_channel)
else: # Open loop Poisson
alpha = config.load_params.poisson.offered_load / load_factor
def poisson():
while True:
yield random.expovariate(alpha)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of grpc_reflection.v1alpha.reflection."""
import unittest
@ -45,14 +44,16 @@ from tests.unit.framework.common import test_constants
_EMPTY_PROTO_FILE_NAME = 'src/proto/grpc/testing/empty.proto'
_EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty'
_SERVICE_NAMES = (
'Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', 'Galilei')
_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman',
'Galilei')
def _file_descriptor_to_proto(descriptor):
proto = descriptor_pb2.FileDescriptorProto()
descriptor.CopyToProto(proto)
return proto.SerializeToString()
class ReflectionServicerTest(unittest.TestCase):
def setUp(self):
@ -60,7 +61,8 @@ class ReflectionServicerTest(unittest.TestCase):
server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
self._server = grpc.server(server_pool)
port = self._server.add_insecure_port('[::]:0')
reflection_pb2.add_ServerReflectionServicer_to_server(servicer, self._server)
reflection_pb2.add_ServerReflectionServicer_to_server(servicer,
self._server)
self._server.start()
channel = grpc.insecure_channel('localhost:%d' % port)
@ -69,117 +71,85 @@ class ReflectionServicerTest(unittest.TestCase):
def testFileByName(self):
requests = (
reflection_pb2.ServerReflectionRequest(
file_by_filename=_EMPTY_PROTO_FILE_NAME
),
file_by_filename=_EMPTY_PROTO_FILE_NAME),
reflection_pb2.ServerReflectionRequest(
file_by_filename='i-donut-exist'
),
)
file_by_filename='i-donut-exist'),)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = (
reflection_pb2.ServerReflectionResponse(
valid_host='',
file_descriptor_response=reflection_pb2.FileDescriptorResponse(
file_descriptor_proto=(
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
)
)
),
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))),
reflection_pb2.ServerReflectionResponse(
valid_host='',
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)
),
)
)),)
self.assertSequenceEqual(expected_responses, responses)
def testFileBySymbol(self):
requests = (
reflection_pb2.ServerReflectionRequest(
file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME
),
file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME),
reflection_pb2.ServerReflectionRequest(
file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'
),
)
),)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = (
reflection_pb2.ServerReflectionResponse(
valid_host='',
file_descriptor_response=reflection_pb2.FileDescriptorResponse(
file_descriptor_proto=(
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
)
)
),
_file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))),
reflection_pb2.ServerReflectionResponse(
valid_host='',
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)
),
)
)),)
self.assertSequenceEqual(expected_responses, responses)
@unittest.skip('TODO(atash): implement file-containing-extension reflection '
@unittest.skip(
'TODO(atash): implement file-containing-extension reflection '
'(see https://github.com/google/protobuf/issues/2248)')
def testFileContainingExtension(self):
requests = (
reflection_pb2.ServerReflectionRequest(
file_containing_extension=reflection_pb2.ExtensionRequest(
containing_type='grpc.testing.proto2.Empty',
extension_number=125,
),
),
extension_number=125,),),
reflection_pb2.ServerReflectionRequest(
file_containing_extension=reflection_pb2.ExtensionRequest(
containing_type='i.donut.exist.co.uk.org.net.me.name.foo',
extension_number=55,
),
),
)
extension_number=55,),),)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = (
reflection_pb2.ServerReflectionResponse(
valid_host='',
file_descriptor_response=reflection_pb2.FileDescriptorResponse(
file_descriptor_proto=(
_file_descriptor_to_proto(empty_extensions_pb2.DESCRIPTOR),
)
)
),
file_descriptor_proto=(_file_descriptor_to_proto(
empty_extensions_pb2.DESCRIPTOR),))),
reflection_pb2.ServerReflectionResponse(
valid_host='',
error_response=reflection_pb2.ErrorResponse(
error_code=grpc.StatusCode.NOT_FOUND.value[0],
error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
)
),
)
)),)
self.assertSequenceEqual(expected_responses, responses)
def testListServices(self):
requests = (
reflection_pb2.ServerReflectionRequest(
list_services='',
),
)
requests = (reflection_pb2.ServerReflectionRequest(list_services='',),)
responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
expected_responses = (
reflection_pb2.ServerReflectionResponse(
expected_responses = (reflection_pb2.ServerReflectionResponse(
valid_host='',
list_services_response=reflection_pb2.ListServiceResponse(
service=tuple(
reflection_pb2.ServiceResponse(name=name)
for name in _SERVICE_NAMES
)
)
),
)
for name in _SERVICE_NAMES))),)
self.assertSequenceEqual(expected_responses, responses)
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Entry point for running stress tests."""
import argparse
@ -46,11 +45,13 @@ from tests.stress import test_runner
def _args():
parser = argparse.ArgumentParser(description='gRPC Python stress test client')
parser = argparse.ArgumentParser(
description='gRPC Python stress test client')
parser.add_argument(
'--server_addresses',
help='comma seperated list of hostname:port to run servers on',
default='localhost:8080', type=str)
default='localhost:8080',
type=str)
parser.add_argument(
'--test_cases',
help='comma seperated list of testcase:weighting of tests to run',
@ -59,29 +60,35 @@ def _args():
parser.add_argument(
'--test_duration_secs',
help='number of seconds to run the stress test',
default=-1, type=int)
default=-1,
type=int)
parser.add_argument(
'--num_channels_per_server',
help='number of channels per server',
default=1, type=int)
default=1,
type=int)
parser.add_argument(
'--num_stubs_per_channel',
help='number of stubs to create per channel',
default=1, type=int)
default=1,
type=int)
parser.add_argument(
'--metrics_port',
help='the port to listen for metrics requests on',
default=8081, type=int)
default=8081,
type=int)
parser.add_argument(
'--use_test_ca',
help='Whether to use our fake CA. Requires --use_tls=true',
default=False, type=bool)
default=False,
type=bool)
parser.add_argument(
'--use_tls',
help='Whether to use TLS', default=False, type=bool)
'--use_tls', help='Whether to use TLS', default=False, type=bool)
parser.add_argument(
'--server_host_override', default="foo.test.google.fr",
help='the server host to which to claim to connect', type=str)
'--server_host_override',
default="foo.test.google.fr",
help='the server host to which to claim to connect',
type=str)
return parser.parse_args()
@ -101,6 +108,7 @@ def _parse_weighted_test_cases(test_case_args):
weighted_test_cases[test_case] = int(weight)
return weighted_test_cases
def _get_channel(target, args):
if args.use_tls:
if args.use_test_ca:
@ -109,8 +117,11 @@ def _get_channel(target, args):
root_certificates = None # will load default roots.
channel_credentials = grpc.ssl_channel_credentials(
root_certificates=root_certificates)
options = (('grpc.ssl_target_name_override', args.server_host_override,),)
channel = grpc.secure_channel(target, channel_credentials, options=options)
options = ((
'grpc.ssl_target_name_override',
args.server_host_override,),)
channel = grpc.secure_channel(
target, channel_credentials, options=options)
else:
channel = grpc.insecure_channel(target)
@ -118,6 +129,7 @@ def _get_channel(target, args):
grpc.channel_ready_future(channel).result()
return channel
def run_test(args):
test_cases = _parse_weighted_test_cases(args.test_cases)
test_server_targets = args.server_addresses.split(',')
@ -159,5 +171,6 @@ def run_test(args):
runner = None
server.stop(None)
if __name__ == '__main__':
run_test(_args())

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""MetricsService for publishing stress test qps data."""
import time

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Thread that sends random weighted requests on a TestService stub."""
import random

@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test of gRPC Python's application-layer API."""
import unittest
@ -81,19 +80,17 @@ class AllTest(unittest.TestCase):
'channel_ready_future',
'insecure_channel',
'secure_channel',
'server',
)
'server',)
six.assertCountEqual(
self, expected_grpc_code_elements,
six.assertCountEqual(self, expected_grpc_code_elements,
_from_grpc_import_star.GRPC_ELEMENTS)
class ChannelConnectivityTest(unittest.TestCase):
def testChannelConnectivity(self):
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,
self.assertSequenceEqual((
grpc.ChannelConnectivity.IDLE,
grpc.ChannelConnectivity.CONNECTING,
grpc.ChannelConnectivity.READY,
grpc.ChannelConnectivity.TRANSIENT_FAILURE,

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of standard AuthMetadataPlugins."""
import collections

@ -26,13 +26,13 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of Channel Args on client/server side."""
import unittest
import grpc
class TestPointerWrapper(object):
def __int__(self):
@ -44,8 +44,7 @@ TEST_CHANNEL_ARGS = (
('arg2', 'str_val'),
('arg3', 1),
(b'arg4', 'str_val'),
('arg6', TestPointerWrapper()),
)
('arg6', TestPointerWrapper()),)
class ChannelArgsTest(unittest.TestCase):
@ -56,5 +55,6 @@ class ChannelArgsTest(unittest.TestCase):
def test_server(self):
grpc.server(None, options=TEST_CHANNEL_ARGS)
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of grpc._channel.Channel connectivity."""
import threading
@ -90,16 +89,12 @@ class ChannelConnectivityTest(unittest.TestCase):
channel.unsubscribe(callback.update)
fifth_connectivities = callback.connectivities()
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,), first_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, second_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, fourth_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.READY, fifth_connectivities)
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
first_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.READY, third_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.READY, fourth_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.READY, fifth_connectivities)
def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
@ -123,23 +118,24 @@ class ChannelConnectivityTest(unittest.TestCase):
fourth_connectivities = second_callback.block_until_connectivities_satisfy(
bool)
# Wait for a connection that will happen (or may already have happened).
first_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
second_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
first_callback.block_until_connectivities_satisfy(
_ready_in_connectivities)
second_callback.block_until_connectivities_satisfy(
_ready_in_connectivities)
del channel
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,), first_connectivities)
self.assertSequenceEqual(
(grpc.ChannelConnectivity.IDLE,), second_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.SHUTDOWN, third_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.TRANSIENT_FAILURE,
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
first_connectivities)
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
second_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE,
third_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN,
third_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE,
fourth_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN,
fourth_connectivities)
self.assertNotIn(
grpc.ChannelConnectivity.SHUTDOWN, fourth_connectivities)
self.assertFalse(thread_pool.was_used())
def test_reachable_then_unreachable_channel_connectivity(self):
@ -154,7 +150,8 @@ class ChannelConnectivityTest(unittest.TestCase):
callback.block_until_connectivities_satisfy(_ready_in_connectivities)
# Now take down the server and confirm that channel readiness is repudiated.
server.stop(None)
callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready)
callback.block_until_connectivities_satisfy(
_last_connectivity_is_not_ready)
channel.unsubscribe(callback.update)
self.assertFalse(thread_pool.was_used())

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of grpc.channel_ready_future."""
import threading
@ -85,7 +84,8 @@ class ChannelReadyFutureTest(unittest.TestCase):
ready_future = grpc.channel_ready_future(channel)
ready_future.add_done_callback(callback.accept_value)
self.assertIsNone(ready_future.result(timeout=test_constants.LONG_TIMEOUT))
self.assertIsNone(
ready_future.result(timeout=test_constants.LONG_TIMEOUT))
value_passed_to_callback = callback.block_until_called()
self.assertIs(ready_future, value_passed_to_callback)
self.assertFalse(ready_future.cancelled())

@ -42,16 +42,16 @@ _STREAM_STREAM = '/test/StreamStream'
def handle_unary(request, servicer_context):
servicer_context.send_initial_metadata([
('grpc-internal-encoding-request', 'gzip')])
servicer_context.send_initial_metadata(
[('grpc-internal-encoding-request', 'gzip')])
return request
def handle_stream(request_iterator, servicer_context):
# TODO(issue:#6891) We should be able to remove this loop,
# and replace with return; yield
servicer_context.send_initial_metadata([
('grpc-internal-encoding-request', 'gzip')])
servicer_context.send_initial_metadata(
[('grpc-internal-encoding-request', 'gzip')])
for request in request_iterator:
yield request
@ -100,7 +100,8 @@ class CompressionTest(unittest.TestCase):
# settings. Server -> client compressed via server-side metadata setting.
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
# literal with proper use of the public API.
compressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
compressed_channel = grpc.insecure_channel(
'localhost:%d' % self._port,
options=[('grpc.default_compression_algorithm', 1)])
multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
response = multi_callable(request)
@ -110,11 +111,12 @@ class CompressionTest(unittest.TestCase):
# client compressed via server-side metadata setting.
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer
# literal with proper use of the public API.
uncompressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
uncompressed_channel = grpc.insecure_channel(
'localhost:%d' % self._port,
options=[('grpc.default_compression_algorithm', 0)])
multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
response = multi_callable(request, metadata=[
('grpc-internal-encoding-request', 'gzip')])
response = multi_callable(
request, metadata=[('grpc-internal-encoding-request', 'gzip')])
self.assertEqual(request, response)
def testStreaming(self):
@ -122,7 +124,8 @@ class CompressionTest(unittest.TestCase):
# TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
# literal with proper use of the public API.
compressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
compressed_channel = grpc.insecure_channel(
'localhost:%d' % self._port,
options=[('grpc.default_compression_algorithm', 1)])
multi_callable = compressed_channel.stream_stream(_STREAM_STREAM)
call = multi_callable(iter([request] * test_constants.STREAM_LENGTH))

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests of credentials."""
import unittest
@ -42,8 +41,8 @@ class CredentialsTest(unittest.TestCase):
third = grpc.access_token_call_credentials('ghi')
first_and_second = grpc.composite_call_credentials(first, second)
first_second_and_third = grpc.composite_call_credentials(
first, second, third)
first_second_and_third = grpc.composite_call_credentials(first, second,
third)
self.assertIsInstance(first_and_second, grpc.CallCredentials)
self.assertIsInstance(first_second_and_third, grpc.CallCredentials)
@ -57,15 +56,16 @@ class CredentialsTest(unittest.TestCase):
channel_and_first = grpc.composite_channel_credentials(
channel_credentials, first_call_credentials)
channel_first_and_second = grpc.composite_channel_credentials(
channel_credentials, first_call_credentials, second_call_credentials)
channel_credentials, first_call_credentials,
second_call_credentials)
channel_first_second_and_third = grpc.composite_channel_credentials(
channel_credentials, first_call_credentials, second_call_credentials,
third_call_credentials)
channel_credentials, first_call_credentials,
second_call_credentials, third_call_credentials)
self.assertIsInstance(channel_and_first, grpc.ChannelCredentials)
self.assertIsInstance(channel_first_and_second, grpc.ChannelCredentials)
self.assertIsInstance(
channel_first_second_and_third, grpc.ChannelCredentials)
self.assertIsInstance(channel_first_second_and_third,
grpc.ChannelCredentials)
if __name__ == '__main__':

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test making many calls and immediately cancelling most of them."""
import threading
@ -59,8 +58,7 @@ class _State(object):
def _is_cancellation_event(event):
return (
event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
return (event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
event.batch_operations[0].received_cancelled)
@ -86,7 +84,8 @@ class _Handler(object):
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_RECEIVE_CLOSE_ON_SERVER_TAG)
self._call.start_server_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
if _is_cancellation_event(first_event):
@ -94,13 +93,12 @@ class _Handler(object):
else:
with self._lock:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),
)
_EMPTY_FLAGS),)
self._call.start_server_batch(
cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
@ -110,8 +108,8 @@ class _Handler(object):
def _serve(state, server, server_completion_queue, thread_pool):
for _ in range(test_constants.RPC_CONCURRENCY):
call_completion_queue = cygrpc.CompletionQueue()
server.request_call(
call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG)
server.request_call(call_completion_queue, server_completion_queue,
_REQUEST_CALL_TAG)
rpc_event = server_completion_queue.poll()
thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
with state.condition:
@ -131,6 +129,7 @@ class _QueueDriver(object):
self._returned = False
def start(self):
def in_thread():
while True:
event = self._completion_queue.poll()
@ -141,6 +140,7 @@ class _QueueDriver(object):
if not self._due:
self._returned = True
return
thread = threading.Thread(target=in_thread)
thread.start()
@ -154,7 +154,8 @@ class _QueueDriver(object):
class CancelManyCallsTest(unittest.TestCase):
def testCancelManyCalls(self):
server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
server_thread_pool = logging_pool.pool(
test_constants.THREAD_CONCURRENCY)
server_completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(cygrpc.ChannelArgs([]))
@ -167,34 +168,37 @@ class CancelManyCallsTest(unittest.TestCase):
state = _State()
server_thread_args = (
state, server, server_completion_queue, server_thread_pool,)
state,
server,
server_completion_queue,
server_thread_pool,)
server_thread = threading.Thread(target=_serve, args=server_thread_args)
server_thread.start()
client_condition = threading.Condition()
client_due = set()
client_completion_queue = cygrpc.CompletionQueue()
client_driver = _QueueDriver(
client_condition, client_completion_queue, client_due)
client_driver = _QueueDriver(client_condition, client_completion_queue,
client_due)
client_driver.start()
with client_condition:
client_calls = []
for index in range(test_constants.RPC_CONCURRENCY):
client_call = channel.create_call(
None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
_INFINITE_FUTURE)
None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies',
None, _INFINITE_FUTURE)
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
client_call.start_client_batch(cygrpc.Operations(operations), tag)
client_call.start_client_batch(
cygrpc.Operations(operations), tag)
client_due.add(tag)
client_calls.append(client_call)
@ -209,8 +213,8 @@ class CancelManyCallsTest(unittest.TestCase):
state.condition.notify_all()
break
client_driver.events(
test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
client_driver.events(test_constants.RPC_CONCURRENCY *
_SUCCESS_CALL_FRACTION)
with client_condition:
for client_call in client_calls:
client_call.cancel()

@ -45,9 +45,9 @@ def _channel_and_completion_queue():
def _connectivity_loop(channel, completion_queue):
for _ in range(100):
connectivity = channel.check_connectivity_state(True)
channel.watch_connectivity_state(
connectivity, cygrpc.Timespec(time.time() + 0.2), completion_queue,
None)
channel.watch_connectivity_state(connectivity,
cygrpc.Timespec(time.time() + 0.2),
completion_queue, None)
completion_queue.poll(deadline=cygrpc.Timespec(float('+inf')))
@ -59,7 +59,8 @@ def _create_loop_destroy():
def _in_parallel(behavior, arguments):
threads = tuple(
threading.Thread(target=behavior, args=arguments)
threading.Thread(
target=behavior, args=arguments)
for _ in range(test_constants.THREAD_CONCURRENCY))
for thread in threads:
thread.start()
@ -71,7 +72,9 @@ class ChannelTest(unittest.TestCase):
def test_single_channel_lonely_connectivity(self):
channel, completion_queue = _channel_and_completion_queue()
_in_parallel(_connectivity_loop, (channel, completion_queue,))
_in_parallel(_connectivity_loop, (
channel,
completion_queue,))
completion_queue.shutdown()
def test_multiple_channels_lonely_connectivity(self):

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test a corner-case at the level of the Cython API."""
import threading
@ -49,6 +48,7 @@ class _ServerDriver(object):
self._saw_shutdown_tag = False
def start(self):
def in_thread():
while True:
event = self._completion_queue.poll()
@ -58,6 +58,7 @@ class _ServerDriver(object):
if event.tag is self._shutdown_tag:
self._saw_shutdown_tag = True
break
thread = threading.Thread(target=in_thread)
thread.start()
@ -88,6 +89,7 @@ class _QueueDriver(object):
self._returned = False
def start(self):
def in_thread():
while True:
event = self._completion_queue.poll()
@ -98,6 +100,7 @@ class _QueueDriver(object):
if not self._due:
self._returned = True
return
thread = threading.Thread(target=in_thread)
thread.start()
@ -132,14 +135,15 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
cygrpc.ChannelArgs([]))
server_shutdown_tag = 'server_shutdown_tag'
server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)
server_driver = _ServerDriver(server_completion_queue,
server_shutdown_tag)
server_driver.start()
client_condition = threading.Condition()
client_due = set()
client_completion_queue = cygrpc.CompletionQueue()
client_driver = _QueueDriver(
client_condition, client_completion_queue, client_due)
client_driver = _QueueDriver(client_condition, client_completion_queue,
client_due)
client_driver.start()
server_call_condition = threading.Condition()
@ -151,32 +155,35 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_send_initial_metadata_tag,
server_send_first_message_tag,
server_send_second_message_tag,
server_complete_rpc_tag,
))
server_complete_rpc_tag,))
server_call_completion_queue = cygrpc.CompletionQueue()
server_call_driver = _QueueDriver(
server_call_condition, server_call_completion_queue, server_call_due)
server_call_driver = _QueueDriver(server_call_condition,
server_call_completion_queue,
server_call_due)
server_call_driver.start()
server_rpc_tag = 'server_rpc_tag'
request_call_result = server.request_call(
server_call_completion_queue, server_completion_queue, server_rpc_tag)
request_call_result = server.request_call(server_call_completion_queue,
server_completion_queue,
server_rpc_tag)
client_call = channel.create_call(
None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
_INFINITE_FUTURE)
client_call = channel.create_call(None, _EMPTY_FLAGS,
client_completion_queue, b'/twinkies',
None, _INFINITE_FUTURE)
client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(cygrpc.Operations([
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
]), client_complete_rpc_tag))
@ -187,8 +194,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with server_call_condition:
server_send_initial_metadata_start_batch_result = (
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
], server_send_initial_metadata_tag))
server_send_first_message_start_batch_result = (
server_rpc_event.operation_call.start_server_batch([
@ -207,8 +214,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details',
_EMPTY_FLAGS),
cygrpc.Metadata(()), cygrpc.StatusCode.ok,
b'test details', _EMPTY_FLAGS),
], server_complete_rpc_tag))
server_send_second_message_event = server_call_driver.event_with_tag(
server_send_second_message_tag)
@ -219,7 +226,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
client_call.start_client_batch(cygrpc.Operations([
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
]), client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)
@ -234,16 +242,16 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_driver.events()
self.assertEqual(cygrpc.CallError.ok, request_call_result)
self.assertEqual(
cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result)
self.assertEqual(
cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result)
self.assertEqual(
cygrpc.CallError.ok, client_complete_rpc_start_batch_result)
self.assertEqual(cygrpc.CallError.ok,
server_send_initial_metadata_start_batch_result)
self.assertEqual(cygrpc.CallError.ok,
client_receive_initial_metadata_start_batch_result)
self.assertEqual(cygrpc.CallError.ok,
client_complete_rpc_start_batch_result)
self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
self.assertIs(server_rpc_tag, server_rpc_event.tag)
self.assertEqual(
cygrpc.CompletionType.operation_complete, server_rpc_event.type)
self.assertEqual(cygrpc.CompletionType.operation_complete,
server_rpc_event.type)
self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call)
self.assertEqual(0, len(server_rpc_event.batch_operations))

@ -37,17 +37,18 @@ from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0
def _metadata_plugin_callback(context, callback):
callback(cygrpc.Metadata(
[cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE)]),
cygrpc.StatusCode.ok, b'')
callback(
cygrpc.Metadata([
cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE)
]), cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase):
@ -62,8 +63,8 @@ class TypeSmokeTest(unittest.TestCase):
self.assertEqual(metadatum.key, metadata[0].key)
def testMetadataIteration(self):
metadata = cygrpc.Metadata([
cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
metadata = cygrpc.Metadata(
[cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
iterator = iter(metadata)
metadatum = next(iterator)
self.assertIsInstance(metadatum, cygrpc.Metadatum)
@ -77,8 +78,8 @@ class TypeSmokeTest(unittest.TestCase):
next(iterator)
def testOperationsIteration(self):
operations = cygrpc.Operations([
cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
operations = cygrpc.Operations(
[cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
iterator = iter(operations)
operation = next(iterator)
self.assertIsInstance(operation, cygrpc.Operation)
@ -115,7 +116,8 @@ class TypeSmokeTest(unittest.TestCase):
del plugin
def testCallCredentialsFromPluginUpDown(self):
plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, b'')
plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback,
b'')
call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
del plugin
del call_credentials
@ -151,7 +153,8 @@ class ServerClientMixin(object):
self.server = cygrpc.Server(cygrpc.ChannelArgs([]))
self.server.register_completion_queue(self.server_completion_queue)
if server_credentials:
self.port = self.server.add_http2_port(b'[::]:0', server_credentials)
self.port = self.server.add_http2_port(b'[::]:0',
server_credentials)
else:
self.port = self.server.add_http2_port(b'[::]:0')
self.server.start()
@ -159,13 +162,15 @@ class ServerClientMixin(object):
if client_credentials:
client_channel_arguments = cygrpc.ChannelArgs([
cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
host_override)])
host_override)
])
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port).encode(), client_channel_arguments,
client_credentials)
'localhost:{}'.format(self.port).encode(),
client_channel_arguments, client_credentials)
else:
self.client_channel = cygrpc.Channel(
'localhost:{}'.format(self.port).encode(), cygrpc.ChannelArgs([]))
'localhost:{}'.format(self.port).encode(),
cygrpc.ChannelArgs([]))
if host_override:
self.host_argument = None # default host
self.expected_host = host_override
@ -179,12 +184,14 @@ class ServerClientMixin(object):
del self.client_completion_queue
del self.server_completion_queue
def _perform_operations(self, operations, call, queue, deadline, description):
def _perform_operations(self, operations, call, queue, deadline,
description):
"""Perform the list of operations with given call, queue, and deadline.
Invocation errors are reported with as an exception with `description` in
the message. Performs the operations asynchronously, returning a future.
"""
def performer():
tag = object()
try:
@ -192,12 +199,15 @@ class ServerClientMixin(object):
cygrpc.Operations(operations), tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
self.assertEqual(cygrpc.CompletionType.operation_complete,
event.type)
self.assertTrue(event.success)
self.assertIs(tag, event.tag)
except Exception as error:
raise Exception("Error in '{}': {}".format(description, error.message))
raise Exception("Error in '{}': {}".format(description,
error.message))
return event
return test_utilities.SimpleFuture(performer)
def testEcho(self):
@ -233,7 +243,8 @@ class ServerClientMixin(object):
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)
])
client_start_batch_result = client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
@ -267,14 +278,16 @@ class ServerClientMixin(object):
server_call = request_event.operation_call
server_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
SERVER_INITIAL_METADATA_VALUE)])
SERVER_INITIAL_METADATA_VALUE)
])
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)])
SERVER_TRAILING_METADATA_VALUE)
])
server_start_batch_result = server_call.start_server_batch([
cygrpc.operation_send_initial_metadata(server_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(
server_initial_metadata,
_EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
@ -294,18 +307,23 @@ class ServerClientMixin(object):
found_client_op_types.add(client_result.type)
if client_result.type == cygrpc.OperationType.receive_initial_metadata:
self.assertTrue(
test_common.metadata_transmitted(server_initial_metadata,
test_common.metadata_transmitted(
server_initial_metadata,
client_result.received_metadata))
elif client_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(RESPONSE, client_result.received_message.bytes())
self.assertEqual(RESPONSE,
client_result.received_message.bytes())
elif client_result.type == cygrpc.OperationType.receive_status_on_client:
self.assertTrue(
test_common.metadata_transmitted(server_trailing_metadata,
test_common.metadata_transmitted(
server_trailing_metadata,
client_result.received_metadata))
self.assertEqual(SERVER_STATUS_DETAILS,
client_result.received_status_details)
self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
self.assertEqual(set([
self.assertEqual(SERVER_STATUS_CODE,
client_result.received_status_code)
self.assertEqual(
set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.send_message,
cygrpc.OperationType.send_close_from_client,
@ -320,10 +338,12 @@ class ServerClientMixin(object):
self.assertNotIn(client_result.type, found_server_op_types)
found_server_op_types.add(server_result.type)
if server_result.type == cygrpc.OperationType.receive_message:
self.assertEqual(REQUEST, server_result.received_message.bytes())
self.assertEqual(REQUEST,
server_result.received_message.bytes())
elif server_result.type == cygrpc.OperationType.receive_close_on_server:
self.assertFalse(server_result.received_cancelled)
self.assertEqual(set([
self.assertEqual(
set([
cygrpc.OperationType.send_initial_metadata,
cygrpc.OperationType.receive_message,
cygrpc.OperationType.send_message,
@ -343,8 +363,8 @@ class ServerClientMixin(object):
empty_metadata = cygrpc.Metadata([])
server_request_tag = object()
self.server.request_call(
self.server_completion_queue, self.server_completion_queue,
self.server.request_call(self.server_completion_queue,
self.server_completion_queue,
server_request_tag)
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
@ -352,9 +372,9 @@ class ServerClientMixin(object):
# Prologue
def perform_client_operations(operations, description):
return self._perform_operations(
operations, client_call,
self.client_completion_queue, cygrpc_deadline, description)
return self._perform_operations(operations, client_call,
self.client_completion_queue,
cygrpc_deadline, description)
client_event_future = perform_client_operations([
cygrpc.operation_send_initial_metadata(empty_metadata,
@ -366,9 +386,9 @@ class ServerClientMixin(object):
server_call = request_event.operation_call
def perform_server_operations(operations, description):
return self._perform_operations(
operations, server_call,
self.server_completion_queue, cygrpc_deadline, description)
return self._perform_operations(operations, server_call,
self.server_completion_queue,
cygrpc_deadline, description)
server_event_future = perform_server_operations([
cygrpc.operation_send_initial_metadata(empty_metadata,
@ -420,12 +440,14 @@ class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
def setUp(self):
server_credentials = cygrpc.server_credentials_ssl(
None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())], False)
server_credentials = cygrpc.server_credentials_ssl(None, [
cygrpc.SslPemKeyCertPair(resources.private_key(),
resources.certificate_chain())
], False)
client_credentials = cygrpc.channel_credentials_ssl(
resources.test_root_certificates(), None)
self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE)
self.setUpMixin(server_credentials, client_credentials,
_SSL_HOST_OVERRIDE)
def tearDown(self):
self.tearDownMixin()

@ -36,11 +36,13 @@ class SimpleFuture(object):
"""A simple future mechanism."""
def __init__(self, function, *args, **kwargs):
def wrapped_function():
try:
self._result = function(*args, **kwargs)
except Exception as error:
self._error = error
self._result = None
self._error = None
self._thread = threading.Thread(target=wrapped_function)
@ -61,6 +63,5 @@ class SimpleFuture(object):
class CompletionQueuePollFuture(SimpleFuture):
def __init__(self, completion_queue, deadline):
super(CompletionQueuePollFuture, self).__init__(
lambda: completion_queue.poll(deadline))
super(CompletionQueuePollFuture,
self).__init__(lambda: completion_queue.poll(deadline))

@ -118,21 +118,20 @@ class EmptyMessageTest(unittest.TestCase):
def testUnaryStream(self):
response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH,
list(response_iterator))
def testStreamUnary(self):
response = self._channel.stream_unary(_STREAM_UNARY)(
iter([_REQUEST] * test_constants.STREAM_LENGTH))
response = self._channel.stream_unary(_STREAM_UNARY)(iter(
[_REQUEST] * test_constants.STREAM_LENGTH))
self.assertEqual(_RESPONSE, response)
def testStreamStream(self):
response_iterator = self._channel.stream_stream(_STREAM_STREAM)(
iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
response_iterator = self._channel.stream_stream(_STREAM_STREAM)(iter(
[_REQUEST] * test_constants.STREAM_LENGTH))
self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH,
list(response_iterator))
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines a number of module-scope gRPC scenarios to test clean exit."""
import argparse

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests clean exit of server/client on Python Interpreter exit/sigint.
The tests in this module spawn a subprocess for each test case, the
@ -45,7 +44,8 @@ import unittest
from tests.unit import _exit_scenarios
SCENARIO_FILE = os.path.abspath(os.path.join(
SCENARIO_FILE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py'))
INTERPRETER = sys.executable
BASE_COMMAND = [INTERPRETER, SCENARIO_FILE]
@ -53,7 +53,6 @@ BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt']
INIT_TIME = 1.0
processes = []
process_lock = threading.Lock()
@ -67,6 +66,8 @@ def cleanup_processes():
process.kill()
except Exception:
pass
atexit.register(cleanup_processes)
@ -90,7 +91,8 @@ class ExitTest(unittest.TestCase):
def test_unstarted_server(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
wait(process)
def test_unstarted_server_terminate(self):
@ -102,83 +104,100 @@ class ExitTest(unittest.TestCase):
def test_running_server(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
wait(process)
def test_running_server_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
def test_poll_connectivity_no_server(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
wait(process)
def test_poll_connectivity_no_server_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
BASE_SIGTERM_COMMAND +
[_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
def test_poll_connectivity(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
wait(process)
def test_poll_connectivity_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
def test_in_flight_unary_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_unary_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
def test_in_flight_stream_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_stream_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_partial_unary_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
BASE_COMMAND +
[_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL],
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
def test_in_flight_partial_stream_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL],
stdout=sys.stdout, stderr=sys.stderr)
BASE_COMMAND +
[_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL],
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_partial_stream_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
BASE_COMMAND +
[_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL],
stdout=sys.stdout,
stderr=sys.stderr)
interrupt_and_wait(process)

@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test of RPCs made against gRPC Python's application-layer API."""
import unittest
@ -100,10 +99,10 @@ class InvalidMetadataTest(unittest.TestCase):
response_future = self._unary_unary.future(request, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertEqual(
exception_context.exception.details(), expected_error_details)
self.assertEqual(
exception_context.exception.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_future.details(), expected_error_details)
self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
@ -114,15 +113,16 @@ class InvalidMetadataTest(unittest.TestCase):
response_iterator = self._unary_stream(request, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertEqual(
exception_context.exception.details(), expected_error_details)
self.assertEqual(
exception_context.exception.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_iterator.details(), expected_error_details)
self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
def testStreamRequestBlockingUnaryResponse(self):
request_iterator = (b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = (b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
with self.assertRaises(ValueError) as exception_context:
@ -130,8 +130,8 @@ class InvalidMetadataTest(unittest.TestCase):
self.assertIn(expected_error_details, str(exception_context.exception))
def testStreamRequestBlockingUnaryResponseWithCall(self):
request_iterator = (
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = (b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),)
expected_error_details = "metadata was invalid: %s" % metadata
multi_callable = _stream_unary_multi_callable(self._channel)
@ -140,33 +140,34 @@ class InvalidMetadataTest(unittest.TestCase):
self.assertIn(expected_error_details, str(exception_context.exception))
def testStreamRequestFutureUnaryResponse(self):
request_iterator = (
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = (b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
response_future = self._stream_unary.future(
request_iterator, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
response_future.result()
self.assertEqual(
exception_context.exception.details(), expected_error_details)
self.assertEqual(
exception_context.exception.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_future.details(), expected_error_details)
self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
def testStreamRequestStreamResponse(self):
request_iterator = (
b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
request_iterator = (b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)
expected_error_details = "metadata was invalid: %s" % metadata
response_iterator = self._stream_stream(request_iterator, metadata=metadata)
response_iterator = self._stream_stream(
request_iterator, metadata=metadata)
with self.assertRaises(grpc.RpcError) as exception_context:
next(response_iterator)
self.assertEqual(
exception_context.exception.details(), expected_error_details)
self.assertEqual(
exception_context.exception.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(exception_context.exception.details(),
expected_error_details)
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.INTERNAL)
self.assertEqual(response_iterator.details(), expected_error_details)
self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save