Merge pull request #13745 from mehrdada/upmerge180

Upmerge v1.8.x into master
pull/13747/head
Mehrdad Afshari 7 years ago committed by GitHub
commit b88b342dab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      README.md
  2. 68
      examples/python/interceptors/default_value/default_value_client_interceptor.py
  3. 38
      examples/python/interceptors/default_value/greeter_client.py
  4. 134
      examples/python/interceptors/default_value/helloworld_pb2.py
  5. 46
      examples/python/interceptors/default_value/helloworld_pb2_grpc.py
  6. 55
      examples/python/interceptors/headers/generic_client_interceptor.py
  7. 36
      examples/python/interceptors/headers/greeter_client.py
  8. 52
      examples/python/interceptors/headers/greeter_server.py
  9. 42
      examples/python/interceptors/headers/header_manipulator_client_interceptor.py
  10. 134
      examples/python/interceptors/headers/helloworld_pb2.py
  11. 46
      examples/python/interceptors/headers/helloworld_pb2_grpc.py
  12. 39
      examples/python/interceptors/headers/request_header_validator_interceptor.py
  13. 335
      src/python/grpcio/grpc/__init__.py
  14. 68
      src/python/grpcio/grpc/_channel.py
  15. 17
      src/python/grpcio/grpc/_common.py
  16. 12
      src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
  17. 2
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  18. 9
      src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
  19. 8
      src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
  20. 24
      src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi
  21. 26
      src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi
  22. 62
      src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi
  23. 41
      src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
  24. 245
      src/python/grpcio/grpc/_cython/_cygrpc/records.pyx.pxi
  25. 22
      src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
  26. 1
      src/python/grpcio/grpc/_cython/cygrpc.pxd
  27. 1
      src/python/grpcio/grpc/_cython/cygrpc.pyx
  28. 318
      src/python/grpcio/grpc/_interceptor.py
  29. 4
      src/python/grpcio/grpc/_plugin_wrapping.py
  30. 123
      src/python/grpcio/grpc/_server.py
  31. 21
      src/python/grpcio/grpc/beta/_client_adaptations.py
  32. 49
      src/python/grpcio/grpc/beta/_metadata.py
  33. 10
      src/python/grpcio/grpc/beta/_server_adaptations.py
  34. 14
      src/python/grpcio/grpc/beta/implementations.py
  35. 3
      src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
  36. 1
      src/python/grpcio_tests/tests/tests.json
  37. 13
      src/python/grpcio_tests/tests/unit/_api_test.py
  38. 15
      src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
  39. 15
      src/python/grpcio_tests/tests/unit/_cython/_common.py
  40. 13
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py
  41. 13
      src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py
  42. 33
      src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
  43. 81
      src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
  44. 571
      src/python/grpcio_tests/tests/unit/_interceptor_test.py
  45. 237
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  46. 2
      src/ruby/ext/grpc/extconf.rb

@ -27,13 +27,13 @@ Libraries in different languages may be in different states of development. We a
| Language | Source | Status |
|-------------------------|-------------------------------------|---------|
| Shared C [core library] | [src/core](src/core) | 1.6 |
| C++ | [src/cpp](src/cpp) | 1.6 |
| Ruby | [src/ruby](src/ruby) | 1.6 |
| Python | [src/python](src/python) | 1.6 |
| PHP | [src/php](src/php) | 1.6 |
| C# | [src/csharp](src/csharp) | 1.6 |
| Objective-C | [src/objective-c](src/objective-c) | 1.6 |
| Shared C [core library] | [src/core](src/core) | 1.8 |
| C++ | [src/cpp](src/cpp) | 1.8 |
| Ruby | [src/ruby](src/ruby) | 1.8 |
| Python | [src/python](src/python) | 1.8 |
| PHP | [src/php](src/php) | 1.8 |
| C# | [src/csharp](src/csharp) | 1.8 |
| Objective-C | [src/objective-c](src/objective-c) | 1.8 |
Java source code is in the [grpc-java](http://github.com/grpc/grpc-java)
repository. Go source code is in the

@ -0,0 +1,68 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptor that adds headers to outgoing requests."""
import collections
import grpc
class _ConcreteValue(grpc.Future):
def __init__(self, result):
self._result = result
def cancel(self):
return False
def cancelled(self):
return False
def running(self):
return False
def done(self):
return True
def result(self, timeout=None):
return self._result
def exception(self, timeout=None):
return None
def traceback(self, timeout=None):
return None
def add_done_callback(self, fn):
fn(self._result)
class DefaultValueClientInterceptor(grpc.UnaryUnaryClientInterceptor,
grpc.StreamUnaryClientInterceptor):
def __init__(self, value):
self._default = _ConcreteValue(value)
def _intercept_call(self, continuation, client_call_details,
request_or_iterator):
response = continuation(client_call_details, request_or_iterator)
return self._default if response.exception() else response
def intercept_unary_unary(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request)
def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
return self._intercept_call(continuation, client_call_details,
request_iterator)

@ -0,0 +1,38 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the gRPC helloworld.Greeter client."""
from __future__ import print_function
import grpc
import helloworld_pb2
import helloworld_pb2_grpc
import default_value_client_interceptor
def run():
default_value = helloworld_pb2.HelloReply(
message='Hello from your local interceptor!')
default_value_interceptor = default_value_client_interceptor.DefaultValueClientInterceptor(
default_value)
channel = grpc.insecure_channel('localhost:50051')
channel = grpc.intercept_channel(channel, default_value_interceptor)
stub = helloworld_pb2_grpc.GreeterStub(channel)
response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'))
print("Greeter client received: " + response.message)
if __name__ == '__main__':
run()

@ -0,0 +1,134 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: helloworld.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='helloworld.proto',
package='helloworld',
syntax='proto3',
serialized_pb=_b('\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
)
_HELLOREQUEST = _descriptor.Descriptor(
name='HelloRequest',
full_name='helloworld.HelloRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='helloworld.HelloRequest.name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=32,
serialized_end=60,
)
_HELLOREPLY = _descriptor.Descriptor(
name='HelloReply',
full_name='helloworld.HelloReply',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='message', full_name='helloworld.HelloReply.message', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=62,
serialized_end=91,
)
DESCRIPTOR.message_types_by_name['HelloRequest'] = _HELLOREQUEST
DESCRIPTOR.message_types_by_name['HelloReply'] = _HELLOREPLY
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
HelloRequest = _reflection.GeneratedProtocolMessageType('HelloRequest', (_message.Message,), dict(
DESCRIPTOR = _HELLOREQUEST,
__module__ = 'helloworld_pb2'
# @@protoc_insertion_point(class_scope:helloworld.HelloRequest)
))
_sym_db.RegisterMessage(HelloRequest)
HelloReply = _reflection.GeneratedProtocolMessageType('HelloReply', (_message.Message,), dict(
DESCRIPTOR = _HELLOREPLY,
__module__ = 'helloworld_pb2'
# @@protoc_insertion_point(class_scope:helloworld.HelloReply)
))
_sym_db.RegisterMessage(HelloReply)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'))
_GREETER = _descriptor.ServiceDescriptor(
name='Greeter',
full_name='helloworld.Greeter',
file=DESCRIPTOR,
index=0,
options=None,
serialized_start=93,
serialized_end=166,
methods=[
_descriptor.MethodDescriptor(
name='SayHello',
full_name='helloworld.Greeter.SayHello',
index=0,
containing_service=None,
input_type=_HELLOREQUEST,
output_type=_HELLOREPLY,
options=None,
),
])
_sym_db.RegisterServiceDescriptor(_GREETER)
DESCRIPTOR.services_by_name['Greeter'] = _GREETER
# @@protoc_insertion_point(module_scope)

@ -0,0 +1,46 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
import grpc
import helloworld_pb2 as helloworld__pb2
class GreeterStub(object):
"""The greeting service definition.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SayHello = channel.unary_unary(
'/helloworld.Greeter/SayHello',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
)
class GreeterServicer(object):
"""The greeting service definition.
"""
def SayHello(self, request, context):
"""Sends a greeting
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_GreeterServicer_to_server(servicer, server):
rpc_method_handlers = {
'SayHello': grpc.unary_unary_rpc_method_handler(
servicer.SayHello,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'helloworld.Greeter', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))

@ -0,0 +1,55 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for interceptors that operate on all RPC types."""
import grpc
class _GenericClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
def __init__(self, interceptor_function):
self._fn = interceptor_function
def intercept_unary_unary(self, continuation, client_call_details, request):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, False)
response = continuation(new_details, next(new_request_iterator))
return postprocess(response) if postprocess else response
def intercept_unary_stream(self, continuation, client_call_details,
request):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, True)
response_it = continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it
def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, False)
response = continuation(new_details, next(new_request_iterator))
return postprocess(response) if postprocess else response
def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, True)
response_it = continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it
def create(intercept_call):
return _GenericClientInterceptor(intercept_call)

@ -0,0 +1,36 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the GRPC helloworld.Greeter client."""
from __future__ import print_function
import grpc
import helloworld_pb2
import helloworld_pb2_grpc
import header_manipulator_client_interceptor
def run():
header_adder_interceptor = header_manipulator_client_interceptor.header_adder_interceptor(
'one-time-password', '42')
channel = grpc.insecure_channel('localhost:50051')
channel = grpc.intercept_channel(channel, header_adder_interceptor)
stub = helloworld_pb2_grpc.GreeterStub(channel)
response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'))
print("Greeter client received: " + response.message)
if __name__ == '__main__':
run()

@ -0,0 +1,52 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Python implementation of the GRPC helloworld.Greeter server."""
from concurrent import futures
import time
import grpc
import helloworld_pb2
import helloworld_pb2_grpc
from request_header_validator_interceptor import RequestHeaderValidatorInterceptor
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
class Greeter(helloworld_pb2_grpc.GreeterServicer):
def SayHello(self, request, context):
return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name)
def serve():
header_validator = RequestHeaderValidatorInterceptor(
'one-time-password', '42', grpc.StatusCode.UNAUTHENTICATED,
'Access denied!')
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=10),
interceptors=(header_validator,))
helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
server.add_insecure_port('[::]:50051')
server.start()
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == '__main__':
serve()

@ -0,0 +1,42 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptor that adds headers to outgoing requests."""
import collections
import grpc
import generic_client_interceptor
class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails',
('method', 'timeout', 'metadata',
'credentials')), grpc.ClientCallDetails):
pass
def header_adder_interceptor(header, value):
def intercept_call(client_call_details, request_iterator, request_streaming,
response_streaming):
metadata = []
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
metadata.append((header, value,))
client_call_details = _ClientCallDetails(
client_call_details.method, client_call_details.timeout, metadata,
client_call_details.credentials)
return client_call_details, request_iterator, None
return generic_client_interceptor.create(intercept_call)

@ -0,0 +1,134 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: helloworld.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='helloworld.proto',
package='helloworld',
syntax='proto3',
serialized_pb=_b('\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
)
_HELLOREQUEST = _descriptor.Descriptor(
name='HelloRequest',
full_name='helloworld.HelloRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='helloworld.HelloRequest.name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=32,
serialized_end=60,
)
_HELLOREPLY = _descriptor.Descriptor(
name='HelloReply',
full_name='helloworld.HelloReply',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='message', full_name='helloworld.HelloReply.message', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=62,
serialized_end=91,
)
DESCRIPTOR.message_types_by_name['HelloRequest'] = _HELLOREQUEST
DESCRIPTOR.message_types_by_name['HelloReply'] = _HELLOREPLY
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
HelloRequest = _reflection.GeneratedProtocolMessageType('HelloRequest', (_message.Message,), dict(
DESCRIPTOR = _HELLOREQUEST,
__module__ = 'helloworld_pb2'
# @@protoc_insertion_point(class_scope:helloworld.HelloRequest)
))
_sym_db.RegisterMessage(HelloRequest)
HelloReply = _reflection.GeneratedProtocolMessageType('HelloReply', (_message.Message,), dict(
DESCRIPTOR = _HELLOREPLY,
__module__ = 'helloworld_pb2'
# @@protoc_insertion_point(class_scope:helloworld.HelloReply)
))
_sym_db.RegisterMessage(HelloReply)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'))
_GREETER = _descriptor.ServiceDescriptor(
name='Greeter',
full_name='helloworld.Greeter',
file=DESCRIPTOR,
index=0,
options=None,
serialized_start=93,
serialized_end=166,
methods=[
_descriptor.MethodDescriptor(
name='SayHello',
full_name='helloworld.Greeter.SayHello',
index=0,
containing_service=None,
input_type=_HELLOREQUEST,
output_type=_HELLOREPLY,
options=None,
),
])
_sym_db.RegisterServiceDescriptor(_GREETER)
DESCRIPTOR.services_by_name['Greeter'] = _GREETER
# @@protoc_insertion_point(module_scope)

@ -0,0 +1,46 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
import grpc
import helloworld_pb2 as helloworld__pb2
class GreeterStub(object):
"""The greeting service definition.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SayHello = channel.unary_unary(
'/helloworld.Greeter/SayHello',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
)
class GreeterServicer(object):
"""The greeting service definition.
"""
def SayHello(self, request, context):
"""Sends a greeting
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_GreeterServicer_to_server(servicer, server):
rpc_method_handlers = {
'SayHello': grpc.unary_unary_rpc_method_handler(
servicer.SayHello,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'helloworld.Greeter', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))

@ -0,0 +1,39 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptor that ensures a specific header is present."""
import grpc
def _unary_unary_rpc_terminator(code, details):
def terminate(ignored_request, context):
context.abort(code, details)
return grpc.unary_unary_rpc_method_handler(terminate)
class RequestHeaderValidatorInterceptor(grpc.ServerInterceptor):
def __init__(self, header, value, code, details):
self._header = header
self._value = value
self._terminator = _unary_unary_rpc_terminator(code, details)
def intercept_service(self, continuation, handler_call_details):
if (self._header,
self._value) in handler_call_details.invocation_metadata:
return continuation(handler_call_details)
else:
return self._terminator

@ -342,6 +342,170 @@ class Call(six.with_metaclass(abc.ABCMeta, RpcContext)):
raise NotImplementedError()
############## Invocation-Side Interceptor Interfaces & Classes ##############
class ClientCallDetails(six.with_metaclass(abc.ABCMeta)):
"""Describes an RPC to be invoked.
This is an EXPERIMENTAL API.
Attributes:
method: The method name of the RPC.
timeout: An optional duration of time in seconds to allow for the RPC.
metadata: Optional :term:`metadata` to be transmitted to
the service-side of the RPC.
credentials: An optional CallCredentials for the RPC.
"""
class UnaryUnaryClientInterceptor(six.with_metaclass(abc.ABCMeta)):
"""Affords intercepting unary-unary invocations.
This is an EXPERIMENTAL API.
"""
@abc.abstractmethod
def intercept_unary_unary(self, continuation, client_call_details, request):
"""Intercepts a unary-unary invocation asynchronously.
Args:
continuation: A function that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_future = continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns an object that is
both a Call for the RPC and a Future. In the event of RPC
completion, the return Call-Future's result value will be
the response message of the RPC. Should the event terminate
with non-OK status, the returned Call-Future's exception value
will be an RpcError.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
An object that is both a Call for the RPC and a Future.
In the event of RPC completion, the return Call-Future's
result value will be the response message of the RPC.
Should the event terminate with non-OK status, the returned
Call-Future's exception value will be an RpcError.
"""
raise NotImplementedError()
class UnaryStreamClientInterceptor(six.with_metaclass(abc.ABCMeta)):
"""Affords intercepting unary-stream invocations.
This is an EXPERIMENTAL API.
"""
@abc.abstractmethod
def intercept_unary_stream(self, continuation, client_call_details,
request):
"""Intercepts a unary-stream invocation.
Args:
continuation: A function that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_iterator = continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns an object that is
both a Call for the RPC and an iterator for response values.
Drawing response values from the returned Call-iterator may
raise RpcError indicating termination of the RPC with non-OK
status.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
An object that is both a Call for the RPC and an iterator of
response values. Drawing response values from the returned
Call-iterator may raise RpcError indicating termination of
the RPC with non-OK status.
"""
raise NotImplementedError()
class StreamUnaryClientInterceptor(six.with_metaclass(abc.ABCMeta)):
"""Affords intercepting stream-unary invocations.
This is an EXPERIMENTAL API.
"""
@abc.abstractmethod
def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
"""Intercepts a stream-unary invocation asynchronously.
Args:
continuation: A function that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_future = continuation(client_call_details,
request_iterator)`
to continue with the RPC. `continuation` returns an object that is
both a Call for the RPC and a Future. In the event of RPC completion,
the return Call-Future's result value will be the response message
of the RPC. Should the event terminate with non-OK status, the
returned Call-Future's exception value will be an RpcError.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request_iterator: An iterator that yields request values for the RPC.
Returns:
An object that is both a Call for the RPC and a Future.
In the event of RPC completion, the return Call-Future's
result value will be the response message of the RPC.
Should the event terminate with non-OK status, the returned
Call-Future's exception value will be an RpcError.
"""
raise NotImplementedError()
class StreamStreamClientInterceptor(six.with_metaclass(abc.ABCMeta)):
"""Affords intercepting stream-stream invocations.
This is an EXPERIMENTAL API.
"""
@abc.abstractmethod
def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
"""Intercepts a stream-stream invocation.
continuation: A function that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_iterator = continuation(client_call_details,
request_iterator)`
to continue with the RPC. `continuation` returns an object that is
both a Call for the RPC and an iterator for response values.
Drawing response values from the returned Call-iterator may
raise RpcError indicating termination of the RPC with non-OK
status.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request_iterator: An iterator that yields request values for the RPC.
Returns:
An object that is both a Call for the RPC and an iterator of
response values. Drawing response values from the returned
Call-iterator may raise RpcError indicating termination of
the RPC with non-OK status.
"""
raise NotImplementedError()
############ Authentication & Authorization Interfaces & Classes #############
@ -834,28 +998,48 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
"""
raise NotImplementedError()
@abc.abstractmethod
def abort(self, code, details):
"""Raises an exception to terminate the RPC with a non-OK status.
The code and details passed as arguments will supercede any existing
ones.
Args:
code: A StatusCode object to be sent to the client.
It must not be StatusCode.OK.
details: An ASCII-encodable string to be sent to the client upon
termination of the RPC.
Raises:
Exception: An exception is always raised to signal the abortion the
RPC to the gRPC runtime.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_code(self, code):
"""Sets the value to be used as status code upon RPC completion.
This method need not be called by method implementations if they wish the
gRPC runtime to determine the status code of the RPC.
This method need not be called by method implementations if they wish
the gRPC runtime to determine the status code of the RPC.
Args:
code: A StatusCode object to be sent to the client.
"""
Args:
code: A StatusCode object to be sent to the client.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_details(self, details):
"""Sets the value to be used as detail string upon RPC completion.
This method need not be called by method implementations if they have no
details to transmit.
This method need not be called by method implementations if they have
no details to transmit.
Args:
details: An arbitrary string to be sent to the client upon completion.
"""
Args:
details: An ASCII-encodable string to be sent to the client upon
termination of the RPC.
"""
raise NotImplementedError()
@ -942,6 +1126,34 @@ class ServiceRpcHandler(six.with_metaclass(abc.ABCMeta, GenericRpcHandler)):
raise NotImplementedError()
#################### Service-Side Interceptor Interfaces #####################
class ServerInterceptor(six.with_metaclass(abc.ABCMeta)):
"""Affords intercepting incoming RPCs on the service-side.
This is an EXPERIMENTAL API.
"""
@abc.abstractmethod
def intercept_service(self, continuation, handler_call_details):
"""Intercepts incoming RPCs before handing them over to a handler.
Args:
continuation: A function that takes a HandlerCallDetails and
proceeds to invoke the next interceptor in the chain, if any,
or the RPC handler lookup logic, with the call details passed
as an argument, and returns an RpcMethodHandler instance if
the RPC is considered serviced, or None otherwise.
handler_call_details: A HandlerCallDetails describing the RPC.
Returns:
An RpcMethodHandler with which the RPC may be serviced if the
interceptor chooses to service this RPC, or None otherwise.
"""
raise NotImplementedError()
############################# Server Interface ###############################
@ -1356,53 +1568,88 @@ def secure_channel(target, credentials, options=None):
credentials._credentials)
def intercept_channel(channel, *interceptors):
"""Intercepts a channel through a set of interceptors.
This is an EXPERIMENTAL API.
Args:
channel: A Channel.
interceptors: Zero or more objects of type
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, or
StreamStreamClientInterceptor.
Interceptors are given control in the order they are listed.
Returns:
A Channel that intercepts each invocation via the provided interceptors.
Raises:
TypeError: If interceptor does not derive from any of
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, or
StreamStreamClientInterceptor.
"""
from grpc import _interceptor # pylint: disable=cyclic-import
return _interceptor.intercept_channel(channel, *interceptors)
def server(thread_pool,
handlers=None,
interceptors=None,
options=None,
maximum_concurrent_rpcs=None):
"""Creates a Server with which RPCs can be serviced.
Args:
thread_pool: A futures.ThreadPoolExecutor to be used by the Server
to execute RPC handlers.
handlers: An optional list of GenericRpcHandlers used for executing RPCs.
More handlers may be added by calling add_generic_rpc_handlers any time
before the server is started.
options: An optional list of key-value pairs (channel args in gRPC runtime)
to configure the channel.
maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
will service before returning RESOURCE_EXHAUSTED status, or None to
indicate no limit.
Args:
thread_pool: A futures.ThreadPoolExecutor to be used by the Server
to execute RPC handlers.
handlers: An optional list of GenericRpcHandlers used for executing RPCs.
More handlers may be added by calling add_generic_rpc_handlers any time
before the server is started.
interceptors: An optional list of ServerInterceptor objects that observe
and optionally manipulate the incoming RPCs before handing them over to
handlers. The interceptors are given control in the order they are
specified. This is an EXPERIMENTAL API.
options: An optional list of key-value pairs (channel args in gRPC runtime)
to configure the channel.
maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
will service before returning RESOURCE_EXHAUSTED status, or None to
indicate no limit.
Returns:
A Server object.
"""
Returns:
A Server object.
"""
from grpc import _server # pylint: disable=cyclic-import
return _server.Server(thread_pool, () if handlers is None else handlers, ()
if options is None else options,
maximum_concurrent_rpcs)
if interceptors is None else interceptors, () if
options is None else options, maximum_concurrent_rpcs)
################################### __all__ #################################
__all__ = ('FutureTimeoutError', 'FutureCancelledError', 'Future',
'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext',
'Call', 'ChannelCredentials', 'CallCredentials',
'AuthMetadataContext', 'AuthMetadataPluginCallback',
'AuthMetadataPlugin', 'ServerCertificateConfiguration',
'ServerCredentials', 'UnaryUnaryMultiCallable',
'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable',
'StreamStreamMultiCallable', 'Channel', 'ServicerContext',
'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler',
'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler',
'unary_stream_rpc_method_handler', 'stream_unary_rpc_method_handler',
'stream_stream_rpc_method_handler',
'method_handlers_generic_handler', 'ssl_channel_credentials',
'metadata_call_credentials', 'access_token_call_credentials',
'composite_call_credentials', 'composite_channel_credentials',
'ssl_server_credentials', 'ssl_server_certificate_configuration',
'dynamic_ssl_server_credentials', 'channel_ready_future',
'insecure_channel', 'secure_channel', 'server',)
__all__ = (
'FutureTimeoutError', 'FutureCancelledError', 'Future',
'ChannelConnectivity', 'StatusCode', 'RpcError', 'RpcContext', 'Call',
'ChannelCredentials', 'CallCredentials', 'AuthMetadataContext',
'AuthMetadataPluginCallback', 'AuthMetadataPlugin', 'ClientCallDetails',
'ServerCertificateConfiguration', 'ServerCredentials',
'UnaryUnaryMultiCallable', 'UnaryStreamMultiCallable',
'StreamUnaryMultiCallable', 'StreamStreamMultiCallable',
'UnaryUnaryClientInterceptor', 'UnaryStreamClientInterceptor',
'StreamUnaryClientInterceptor', 'StreamStreamClientInterceptor', 'Channel',
'ServicerContext', 'RpcMethodHandler', 'HandlerCallDetails',
'GenericRpcHandler', 'ServiceRpcHandler', 'Server', 'ServerInterceptor',
'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler',
'stream_unary_rpc_method_handler', 'stream_stream_rpc_method_handler',
'method_handlers_generic_handler', 'ssl_channel_credentials',
'metadata_call_credentials', 'access_token_call_credentials',
'composite_call_credentials', 'composite_channel_credentials',
'ssl_server_credentials', 'ssl_server_certificate_configuration',
'dynamic_ssl_server_credentials', 'channel_ready_future',
'insecure_channel', 'secure_channel', 'intercept_channel', 'server',)
############################### Extension Shims ################################

@ -122,8 +122,8 @@ def _abort(state, code, details):
state.code = code
state.details = details
if state.initial_metadata is None:
state.initial_metadata = _common.EMPTY_METADATA
state.trailing_metadata = _common.EMPTY_METADATA
state.initial_metadata = ()
state.trailing_metadata = ()
def _handle_event(event, state, response_deserializer):
@ -202,8 +202,7 @@ def _consume_request_iterator(request_iterator, state, call,
else:
operations = (cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),)
call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call.start_client_batch(operations, event_handler)
state.due.add(cygrpc.OperationType.send_message)
while True:
state.condition.wait()
@ -218,8 +217,7 @@ def _consume_request_iterator(request_iterator, state, call,
if state.code is None:
operations = (
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),)
call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call.start_client_batch(operations, event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
def stop_consumption_thread(timeout): # pylint: disable=unused-argument
@ -321,8 +319,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
event_handler = _event_handler(self._state, self._call,
self._response_deserializer)
self._call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
event_handler)
self._state.due.add(cygrpc.OperationType.receive_message)
elif self._state.code is grpc.StatusCode.OK:
@ -372,14 +369,13 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
with self._state.condition:
while self._state.initial_metadata is None:
self._state.condition.wait()
return _common.to_application_metadata(self._state.initial_metadata)
return self._state.initial_metadata
def trailing_metadata(self):
with self._state.condition:
while self._state.trailing_metadata is None:
self._state.condition.wait()
return _common.to_application_metadata(
self._state.trailing_metadata)
return self._state.trailing_metadata
def code(self):
with self._state.condition:
@ -420,8 +416,7 @@ def _start_unary_request(request, timeout, request_serializer):
deadline, deadline_timespec = _deadline(timeout)
serialized_request = _common.serialize(request, request_serializer)
if serialized_request is None:
state = _RPCState((), _common.EMPTY_METADATA, _common.EMPTY_METADATA,
grpc.StatusCode.INTERNAL,
state = _RPCState((), (), (), grpc.StatusCode.INTERNAL,
'Exception serializing request!')
rendezvous = _Rendezvous(state, None, None, deadline)
return deadline, deadline_timespec, None, rendezvous
@ -458,8 +453,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
else:
state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None)
operations = (
cygrpc.operation_send_initial_metadata(
_common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
@ -479,8 +473,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
call_error = call.start_client_batch(
cygrpc.Operations(operations), None)
call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata)
_handle_event(completion_queue.poll(), state,
self._response_deserializer)
@ -509,8 +502,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
event_handler = _event_handler(state, call,
self._response_deserializer)
with state.condition:
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
@ -544,18 +536,15 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations((
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
)), event_handler)
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.to_cygrpc_metadata(metadata),
_EMPTY_FLAGS), cygrpc.operation_send_message(
metadata, _EMPTY_FLAGS), cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
@ -584,16 +573,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
call.set_credentials(credentials._credentials)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
None)
operations = (
cygrpc.operation_send_initial_metadata(
_common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), None)
call_error = call.start_client_batch(operations, None)
_check_call_error(call_error, metadata)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer)
@ -638,16 +624,13 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)
@ -681,15 +664,12 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),),
event_handler)
operations = (
cygrpc.operation_send_initial_metadata(
_common.to_cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_send_initial_metadata(metadata, _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
call_error = call.start_client_batch(
cygrpc.Operations(operations), event_handler)
call_error = call.start_client_batch(operations, event_handler)
if call_error != cygrpc.CallError.ok:
_call_error_set_RPCstate(state, call_error, metadata)
return _Rendezvous(state, None, None, deadline)

@ -22,8 +22,6 @@ import six
import grpc
from grpc._cython import cygrpc
EMPTY_METADATA = cygrpc.Metadata(())
CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = {
cygrpc.ConnectivityState.idle:
grpc.ChannelConnectivity.IDLE,
@ -91,21 +89,6 @@ def channel_args(options):
return cygrpc.ChannelArgs(cygrpc_args)
def to_cygrpc_metadata(application_metadata):
return EMPTY_METADATA if application_metadata is None else cygrpc.Metadata(
cygrpc.Metadatum(encode(key), encode(value))
for key, value in application_metadata)
def to_application_metadata(cygrpc_metadata):
if cygrpc_metadata is None:
return ()
else:
return tuple((decode(key), value
if key[-4:] == b'-bin' else decode(value))
for key, value in cygrpc_metadata)
def _transform(message, transformer, exception_message):
if transformer is None:
return message

@ -26,20 +26,16 @@ cdef class Call:
def _start_batch(self, operations, tag, retain_self):
if not self.is_valid:
raise ValueError("invalid call object cannot be used from Python")
cdef grpc_call_error result
cdef Operations cy_operations = Operations(operations)
cdef OperationTag operation_tag = OperationTag(tag)
cdef OperationTag operation_tag = OperationTag(tag, operations)
if retain_self:
operation_tag.operation_call = self
else:
operation_tag.operation_call = None
operation_tag.batch_operations = cy_operations
operation_tag.store_ops()
cpython.Py_INCREF(operation_tag)
with nogil:
result = grpc_call_start_batch(
self.c_call, cy_operations.c_ops, cy_operations.c_nops,
return grpc_call_start_batch(
self.c_call, operation_tag.c_ops, operation_tag.c_nops,
<cpython.PyObject *>operation_tag, NULL)
return result
def start_client_batch(self, operations, tag):
# We don't reference this call in the operations tag because

@ -76,7 +76,7 @@ cdef class Channel:
def watch_connectivity_state(
self, grpc_connectivity_state last_observed_state,
Timespec deadline not None, CompletionQueue queue not None, tag):
cdef OperationTag operation_tag = OperationTag(tag)
cdef OperationTag operation_tag = OperationTag(tag, None)
cpython.Py_INCREF(operation_tag)
with nogil:
grpc_channel_watch_connectivity_state(

@ -42,7 +42,7 @@ cdef class CompletionQueue:
cdef Call operation_call = None
cdef CallDetails request_call_details = None
cdef object request_metadata = None
cdef Operations batch_operations = None
cdef object batch_operations = None
if event.type == GRPC_QUEUE_TIMEOUT:
return Event(
event.type, False, None, None, None, None, False, None)
@ -61,9 +61,10 @@ cdef class CompletionQueue:
user_tag = tag.user_tag
operation_call = tag.operation_call
request_call_details = tag.request_call_details
if tag.request_metadata is not None:
request_metadata = tuple(tag.request_metadata)
batch_operations = tag.batch_operations
if tag.is_new_request:
request_metadata = _metadata(&tag._c_request_metadata)
grpc_metadata_array_destroy(&tag._c_request_metadata)
batch_operations = tag.release_ops()
if tag.is_new_request:
# Stuff in the tag not explicitly handled by us needs to live through
# the life of the call

@ -30,9 +30,13 @@ cdef int _get_metadata(
grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX],
size_t *num_creds_md, grpc_status_code *status,
const char **error_details) with gil:
def callback(Metadata metadata, grpc_status_code status, bytes error_details):
cdef size_t metadata_count
cdef grpc_metadata *c_metadata
def callback(metadata, grpc_status_code status, bytes error_details):
if status is StatusCode.ok:
cb(user_data, metadata.c_metadata, metadata.c_count, status, NULL)
_store_c_metadata(metadata, &c_metadata, &metadata_count)
cb(user_data, c_metadata, metadata_count, status, NULL)
_release_c_metadata(c_metadata, metadata_count)
else:
cb(user_data, NULL, 0, status, error_details)
args = context.service_url, context.method_name, callback,

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
# This function will ascii encode unicode string inputs if neccesary.
# In Python3, unicode strings are the default str type.
@ -22,3 +24,25 @@ cdef bytes str_to_bytes(object s):
return s.encode('ascii')
else:
raise TypeError('Expected bytes, str, or unicode, not {}'.format(type(s)))
cdef bytes _encode(str native_string_or_none):
if native_string_or_none is None:
return b''
elif isinstance(native_string_or_none, (bytes,)):
return <bytes>native_string_or_none
elif isinstance(native_string_or_none, (unicode,)):
return native_string_or_none.encode('ascii')
else:
raise TypeError('Expected str, not {}'.format(type(native_string_or_none)))
cdef str _decode(bytes bytestring):
if isinstance(bytestring, (str,)):
return <str>bytestring
else:
try:
return bytestring.decode('utf8')
except UnicodeDecodeError:
logging.exception('Invalid encoding on %s', bytestring)
return bytestring.decode('latin1')

@ -0,0 +1,26 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cdef void _store_c_metadata(
metadata, grpc_metadata **c_metadata, size_t *c_count)
cdef void _release_c_metadata(grpc_metadata *c_metadata, int count)
cdef tuple _metadatum(grpc_slice key_slice, grpc_slice value_slice)
cdef tuple _metadata(grpc_metadata_array *c_metadata_array)

@ -0,0 +1,62 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
_Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',))
cdef void _store_c_metadata(
metadata, grpc_metadata **c_metadata, size_t *c_count):
if metadata is None:
c_count[0] = 0
c_metadata[0] = NULL
else:
metadatum_count = len(metadata)
if metadatum_count == 0:
c_count[0] = 0
c_metadata[0] = NULL
else:
c_count[0] = metadatum_count
c_metadata[0] = <grpc_metadata *>gpr_malloc(
metadatum_count * sizeof(grpc_metadata))
for index, (key, value) in enumerate(metadata):
encoded_key = _encode(key)
encoded_value = value if encoded_key[-4:] == b'-bin' else _encode(value)
c_metadata[0][index].key = _slice_from_bytes(encoded_key)
c_metadata[0][index].value = _slice_from_bytes(encoded_value)
cdef void _release_c_metadata(grpc_metadata *c_metadata, int count):
if 0 < count:
for index in range(count):
grpc_slice_unref(c_metadata[index].key)
grpc_slice_unref(c_metadata[index].value)
gpr_free(c_metadata)
cdef tuple _metadatum(grpc_slice key_slice, grpc_slice value_slice):
cdef bytes key = _slice_bytes(key_slice)
cdef bytes value = _slice_bytes(value_slice)
return <tuple>_Metadatum(
_decode(key), value if key[-4:] == b'-bin' else _decode(value))
cdef tuple _metadata(grpc_metadata_array *c_metadata_array):
return tuple(
_metadatum(
c_metadata_array.metadata[index].key,
c_metadata_array.metadata[index].value)
for index in range(c_metadata_array.count))

@ -37,10 +37,15 @@ cdef class OperationTag:
cdef Server shutting_down_server
cdef Call operation_call
cdef CallDetails request_call_details
cdef MetadataArray request_metadata
cdef Operations batch_operations
cdef grpc_metadata_array _c_request_metadata
cdef grpc_op *c_ops
cdef size_t c_nops
cdef readonly object _operations
cdef bint is_new_request
cdef void store_ops(self)
cdef object release_ops(self)
cdef class Event:
@ -57,7 +62,7 @@ cdef class Event:
cdef readonly Call operation_call
# For Call.start_batch
cdef readonly Operations batch_operations
cdef readonly object batch_operations
cdef class ByteBuffer:
@ -84,28 +89,15 @@ cdef class ChannelArgs:
cdef list args
cdef class Metadatum:
cdef grpc_metadata c_metadata
cdef void _copy_metadatum(self, grpc_metadata *destination) nogil
cdef class Metadata:
cdef grpc_metadata *c_metadata
cdef readonly size_t c_count
cdef class MetadataArray:
cdef grpc_metadata_array c_metadata_array
cdef class Operation:
cdef grpc_op c_op
cdef bint _c_metadata_needs_release
cdef size_t _c_metadata_count
cdef grpc_metadata *_c_metadata
cdef ByteBuffer _received_message
cdef MetadataArray _received_metadata
cdef bint _c_metadata_array_needs_destruction
cdef grpc_metadata_array _c_metadata_array
cdef grpc_status_code _received_status_code
cdef grpc_slice _status_details
cdef int _received_cancelled
@ -113,13 +105,6 @@ cdef class Operation:
cdef object references
cdef class Operations:
cdef grpc_op *c_ops
cdef size_t c_nops
cdef list operations
cdef class CompressionOptions:
cdef grpc_compression_options c_options

@ -220,9 +220,26 @@ cdef class CallDetails:
cdef class OperationTag:
def __cinit__(self, user_tag):
def __cinit__(self, user_tag, operations):
self.user_tag = user_tag
self.references = []
self._operations = operations
cdef void store_ops(self):
self.c_nops = 0 if self._operations is None else len(self._operations)
if 0 < self.c_nops:
self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op) * self.c_nops)
for index in range(self.c_nops):
self.c_ops[index] = (<Operation>(self._operations[index])).c_op
cdef object release_ops(self):
if 0 < self.c_nops:
for index, operation in enumerate(self._operations):
(<Operation>operation).c_op = self.c_ops[index]
gpr_free(self.c_ops)
return self._operations
else:
return ()
cdef class Event:
@ -232,7 +249,7 @@ cdef class Event:
CallDetails request_call_details,
object request_metadata,
bint is_new_request,
Operations batch_operations):
object batch_operations):
self.type = type
self.success = success
self.tag = tag
@ -390,140 +407,13 @@ cdef class ChannelArgs:
return self.args[i]
cdef class Metadatum:
def __cinit__(self, bytes key, bytes value):
self.c_metadata.key = _slice_from_bytes(key)
self.c_metadata.value = _slice_from_bytes(value)
cdef void _copy_metadatum(self, grpc_metadata *destination) nogil:
destination[0].key = _copy_slice(self.c_metadata.key)
destination[0].value = _copy_slice(self.c_metadata.value)
@property
def key(self):
return _slice_bytes(self.c_metadata.key)
@property
def value(self):
return _slice_bytes(self.c_metadata.value)
def __len__(self):
return 2
def __getitem__(self, size_t i):
if i == 0:
return self.key
elif i == 1:
return self.value
else:
raise IndexError("index must be 0 (key) or 1 (value)")
def __iter__(self):
return iter((self.key, self.value))
def __dealloc__(self):
grpc_slice_unref(self.c_metadata.key)
grpc_slice_unref(self.c_metadata.value)
cdef class _MetadataIterator:
cdef size_t i
cdef size_t _length
cdef object _metadatum_indexable
def __cinit__(self, length, metadatum_indexable):
self._length = length
self._metadatum_indexable = metadatum_indexable
self.i = 0
def __iter__(self):
return self
def __next__(self):
if self.i < self._length:
result = self._metadatum_indexable[self.i]
self.i = self.i + 1
return result
else:
raise StopIteration()
# TODO(https://github.com/grpc/grpc/issues/7950): Eliminate this; just use an
# ordinary sequence of pairs of bytestrings all the way down to the
# grpc_call_start_batch call.
cdef class Metadata:
"""Metadata being passed from application to core."""
def __cinit__(self, metadata_iterable):
metadata_sequence = tuple(metadata_iterable)
cdef size_t count = len(metadata_sequence)
with nogil:
grpc_init()
self.c_metadata = <grpc_metadata *>gpr_malloc(
count * sizeof(grpc_metadata))
self.c_count = count
for index, metadatum in enumerate(metadata_sequence):
self.c_metadata[index].key = grpc_slice_copy(
(<Metadatum>metadatum).c_metadata.key)
self.c_metadata[index].value = grpc_slice_copy(
(<Metadatum>metadatum).c_metadata.value)
def __dealloc__(self):
with nogil:
for index in range(self.c_count):
grpc_slice_unref(self.c_metadata[index].key)
grpc_slice_unref(self.c_metadata[index].value)
gpr_free(self.c_metadata)
grpc_shutdown()
def __len__(self):
return self.c_count
def __getitem__(self, size_t index):
if index < self.c_count:
key = _slice_bytes(self.c_metadata[index].key)
value = _slice_bytes(self.c_metadata[index].value)
return Metadatum(key, value)
else:
raise IndexError()
def __iter__(self):
return _MetadataIterator(self.c_count, self)
cdef class MetadataArray:
"""Metadata being passed from core to application."""
def __cinit__(self):
with nogil:
grpc_init()
grpc_metadata_array_init(&self.c_metadata_array)
def __dealloc__(self):
with nogil:
grpc_metadata_array_destroy(&self.c_metadata_array)
grpc_shutdown()
def __len__(self):
return self.c_metadata_array.count
def __getitem__(self, size_t i):
if i >= self.c_metadata_array.count:
raise IndexError()
key = _slice_bytes(self.c_metadata_array.metadata[i].key)
value = _slice_bytes(self.c_metadata_array.metadata[i].value)
return Metadatum(key=key, value=value)
def __iter__(self):
return _MetadataIterator(self.c_metadata_array.count, self)
cdef class Operation:
def __cinit__(self):
grpc_init()
self.references = []
self._c_metadata_needs_release = False
self._c_metadata_array_needs_destruction = False
self._status_details = grpc_empty_slice()
self.is_valid = False
@ -556,13 +446,7 @@ cdef class Operation:
if (self.c_op.type != GRPC_OP_RECV_INITIAL_METADATA and
self.c_op.type != GRPC_OP_RECV_STATUS_ON_CLIENT):
raise TypeError("self must be an operation receiving metadata")
# TODO(https://github.com/grpc/grpc/issues/7950): Drop the "all Cython
# objects must be legitimate for use from Python at any time" policy in
# place today, shift the policy toward "Operation objects are only usable
# while their calls are active", and move this making-a-copy-because-this-
# data-needs-to-live-much-longer-than-the-call-from-which-it-arose to the
# lowest Python layer.
return tuple(self._received_metadata)
return _metadata(&self._c_metadata_array)
@property
def received_status_code(self):
@ -602,16 +486,21 @@ cdef class Operation:
return False if self._received_cancelled == 0 else True
def __dealloc__(self):
if self._c_metadata_needs_release:
_release_c_metadata(self._c_metadata, self._c_metadata_count)
if self._c_metadata_array_needs_destruction:
grpc_metadata_array_destroy(&self._c_metadata_array)
grpc_slice_unref(self._status_details)
grpc_shutdown()
def operation_send_initial_metadata(Metadata metadata, int flags):
def operation_send_initial_metadata(metadata, int flags):
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_SEND_INITIAL_METADATA
op.c_op.flags = flags
op.c_op.data.send_initial_metadata.count = metadata.c_count
op.c_op.data.send_initial_metadata.metadata = metadata.c_metadata
op.references.append(metadata)
_store_c_metadata(metadata, &op._c_metadata, &op._c_metadata_count)
op._c_metadata_needs_release = True
op.c_op.data.send_initial_metadata.count = op._c_metadata_count
op.c_op.data.send_initial_metadata.metadata = op._c_metadata
op.is_valid = True
return op
@ -633,18 +522,19 @@ def operation_send_close_from_client(int flags):
return op
def operation_send_status_from_server(
Metadata metadata, grpc_status_code code, bytes details, int flags):
metadata, grpc_status_code code, bytes details, int flags):
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER
op.c_op.flags = flags
_store_c_metadata(metadata, &op._c_metadata, &op._c_metadata_count)
op._c_metadata_needs_release = True
op.c_op.data.send_status_from_server.trailing_metadata_count = (
metadata.c_count)
op.c_op.data.send_status_from_server.trailing_metadata = metadata.c_metadata
op._c_metadata_count)
op.c_op.data.send_status_from_server.trailing_metadata = op._c_metadata
op.c_op.data.send_status_from_server.status = code
grpc_slice_unref(op._status_details)
op._status_details = _slice_from_bytes(details)
op.c_op.data.send_status_from_server.status_details = &op._status_details
op.references.append(metadata)
op.is_valid = True
return op
@ -652,9 +542,10 @@ def operation_receive_initial_metadata(int flags):
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_RECV_INITIAL_METADATA
op.c_op.flags = flags
op._received_metadata = MetadataArray()
grpc_metadata_array_init(&op._c_metadata_array)
op.c_op.data.receive_initial_metadata.receive_initial_metadata = (
&op._received_metadata.c_metadata_array)
&op._c_metadata_array)
op._c_metadata_array_needs_destruction = True
op.is_valid = True
return op
@ -675,9 +566,10 @@ def operation_receive_status_on_client(int flags):
cdef Operation op = Operation()
op.c_op.type = GRPC_OP_RECV_STATUS_ON_CLIENT
op.c_op.flags = flags
op._received_metadata = MetadataArray()
grpc_metadata_array_init(&op._c_metadata_array)
op.c_op.data.receive_status_on_client.trailing_metadata = (
&op._received_metadata.c_metadata_array)
&op._c_metadata_array)
op._c_metadata_array_needs_destruction = True
op.c_op.data.receive_status_on_client.status = (
&op._received_status_code)
op.c_op.data.receive_status_on_client.status_details = (
@ -694,59 +586,6 @@ def operation_receive_close_on_server(int flags):
return op
cdef class _OperationsIterator:
cdef size_t i
cdef Operations operations
def __cinit__(self, Operations operations not None):
self.i = 0
self.operations = operations
def __iter__(self):
return self
def __next__(self):
if self.i < len(self.operations):
result = self.operations[self.i]
self.i = self.i + 1
return result
else:
raise StopIteration()
cdef class Operations:
def __cinit__(self, operations):
grpc_init()
self.operations = list(operations) # normalize iterable
self.c_ops = NULL
self.c_nops = 0
for operation in self.operations:
if not isinstance(operation, Operation):
raise TypeError("expected operations to be iterable of Operation")
self.c_nops = len(self.operations)
with nogil:
self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op)*self.c_nops)
for i in range(self.c_nops):
self.c_ops[i] = (<Operation>(self.operations[i])).c_op
def __len__(self):
return self.c_nops
def __getitem__(self, size_t i):
# self.operations is never stale; it's only updated from this file
return self.operations[i]
def __dealloc__(self):
with nogil:
gpr_free(self.c_ops)
grpc_shutdown()
def __iter__(self):
return _OperationsIterator(self)
cdef class CompressionOptions:
def __cinit__(self):

@ -78,23 +78,19 @@ cdef class Server:
raise ValueError("server must be started and not shutting down")
if server_queue not in self.registered_completion_queues:
raise ValueError("server_queue must be a registered completion queue")
cdef grpc_call_error result
cdef OperationTag operation_tag = OperationTag(tag)
cdef OperationTag operation_tag = OperationTag(tag, None)
operation_tag.operation_call = Call()
operation_tag.request_call_details = CallDetails()
operation_tag.request_metadata = MetadataArray()
grpc_metadata_array_init(&operation_tag._c_request_metadata)
operation_tag.references.extend([self, call_queue, server_queue])
operation_tag.is_new_request = True
operation_tag.batch_operations = Operations([])
cpython.Py_INCREF(operation_tag)
with nogil:
result = grpc_server_request_call(
self.c_server, &operation_tag.operation_call.c_call,
&operation_tag.request_call_details.c_details,
&operation_tag.request_metadata.c_metadata_array,
call_queue.c_completion_queue, server_queue.c_completion_queue,
<cpython.PyObject *>operation_tag)
return result
return grpc_server_request_call(
self.c_server, &operation_tag.operation_call.c_call,
&operation_tag.request_call_details.c_details,
&operation_tag._c_request_metadata,
call_queue.c_completion_queue, server_queue.c_completion_queue,
<cpython.PyObject *>operation_tag)
def register_completion_queue(
self, CompletionQueue queue not None):
@ -135,7 +131,7 @@ cdef class Server:
cdef _c_shutdown(self, CompletionQueue queue, tag):
self.is_shutting_down = True
operation_tag = OperationTag(tag)
operation_tag = OperationTag(tag, None)
operation_tag.shutting_down_server = self
cpython.Py_INCREF(operation_tag)
with nogil:

@ -18,6 +18,7 @@ include "_cygrpc/call.pxd.pxi"
include "_cygrpc/channel.pxd.pxi"
include "_cygrpc/credentials.pxd.pxi"
include "_cygrpc/completion_queue.pxd.pxi"
include "_cygrpc/metadata.pxd.pxi"
include "_cygrpc/records.pxd.pxi"
include "_cygrpc/security.pxd.pxi"
include "_cygrpc/server.pxd.pxi"

@ -25,6 +25,7 @@ include "_cygrpc/call.pyx.pxi"
include "_cygrpc/channel.pyx.pxi"
include "_cygrpc/credentials.pyx.pxi"
include "_cygrpc/completion_queue.pyx.pxi"
include "_cygrpc/metadata.pyx.pxi"
include "_cygrpc/records.pyx.pxi"
include "_cygrpc/security.pyx.pxi"
include "_cygrpc/server.pyx.pxi"

@ -0,0 +1,318 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""
import collections
import sys
import grpc
class _ServicePipeline(object):
def __init__(self, interceptors):
self.interceptors = tuple(interceptors)
def _continuation(self, thunk, index):
return lambda context: self._intercept_at(thunk, index, context)
def _intercept_at(self, thunk, index, context):
if index < len(self.interceptors):
interceptor = self.interceptors[index]
thunk = self._continuation(thunk, index + 1)
return interceptor.intercept_service(thunk, context)
else:
return thunk(context)
def execute(self, thunk, context):
return self._intercept_at(thunk, 0, context)
def service_pipeline(interceptors):
return _ServicePipeline(interceptors) if interceptors else None
class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails',
('method', 'timeout', 'metadata',
'credentials')), grpc.ClientCallDetails):
pass
class _LocalFailure(grpc.RpcError, grpc.Future, grpc.Call):
def __init__(self, exception, traceback):
super(_LocalFailure, self).__init__()
self._exception = exception
self._traceback = traceback
def initial_metadata(self):
return None
def trailing_metadata(self):
return None
def code(self):
return grpc.StatusCode.INTERNAL
def details(self):
return 'Exception raised while intercepting the RPC'
def cancel(self):
return False
def cancelled(self):
return False
def running(self):
return False
def done(self):
return True
def result(self, ignored_timeout=None):
raise self._exception
def exception(self, ignored_timeout=None):
return self._exception
def traceback(self, ignored_timeout=None):
return self._traceback
def add_done_callback(self, fn):
fn(self)
def __iter__(self):
return self
def next(self):
raise self._exception
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self, request, timeout=None, metadata=None, credentials=None):
call_future = self.future(
request,
timeout=timeout,
metadata=metadata,
credentials=credentials)
return call_future.result()
def with_call(self, request, timeout=None, metadata=None, credentials=None):
call_future = self.future(
request,
timeout=timeout,
metadata=metadata,
credentials=credentials)
return call_future.result(), call_future
def future(self, request, timeout=None, metadata=None, credentials=None):
def continuation(client_call_details, request):
return self._thunk(client_call_details.method).future(
request,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials)
try:
return self._interceptor.intercept_unary_unary(
continuation, client_call_details, request)
except Exception as exception: # pylint:disable=broad-except
return _LocalFailure(exception, sys.exc_info()[2])
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self, request, timeout=None, metadata=None, credentials=None):
def continuation(client_call_details, request):
return self._thunk(client_call_details.method)(
request,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials)
try:
return self._interceptor.intercept_unary_stream(
continuation, client_call_details, request)
except Exception as exception: # pylint:disable=broad-except
return _LocalFailure(exception, sys.exc_info()[2])
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
call_future = self.future(
request_iterator,
timeout=timeout,
metadata=metadata,
credentials=credentials)
return call_future.result()
def with_call(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
call_future = self.future(
request_iterator,
timeout=timeout,
metadata=metadata,
credentials=credentials)
return call_future.result(), call_future
def future(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
def continuation(client_call_details, request_iterator):
return self._thunk(client_call_details.method).future(
request_iterator,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials)
try:
return self._interceptor.intercept_stream_unary(
continuation, client_call_details, request_iterator)
except Exception as exception: # pylint:disable=broad-except
return _LocalFailure(exception, sys.exc_info()[2])
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
def __init__(self, thunk, method, interceptor):
self._thunk = thunk
self._method = method
self._interceptor = interceptor
def __call__(self,
request_iterator,
timeout=None,
metadata=None,
credentials=None):
def continuation(client_call_details, request_iterator):
return self._thunk(client_call_details.method)(
request_iterator,
timeout=client_call_details.timeout,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
client_call_details = _ClientCallDetails(self._method, timeout,
metadata, credentials)
try:
return self._interceptor.intercept_stream_stream(
continuation, client_call_details, request_iterator)
except Exception as exception: # pylint:disable=broad-except
return _LocalFailure(exception, sys.exc_info()[2])
class _Channel(grpc.Channel):
def __init__(self, channel, interceptor):
self._channel = channel
self._interceptor = interceptor
def subscribe(self, *args, **kwargs):
self._channel.subscribe(*args, **kwargs)
def unsubscribe(self, *args, **kwargs):
self._channel.unsubscribe(*args, **kwargs)
def unary_unary(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.unary_unary(m, request_serializer, response_deserializer)
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def unary_stream(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.unary_stream(m, request_serializer, response_deserializer)
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def stream_unary(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.stream_unary(m, request_serializer, response_deserializer)
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def stream_stream(self,
method,
request_serializer=None,
response_deserializer=None):
thunk = lambda m: self._channel.stream_stream(m, request_serializer, response_deserializer)
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
return _StreamStreamMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
def intercept_channel(channel, *interceptors):
for interceptor in reversed(list(interceptors)):
if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \
not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \
not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \
not isinstance(interceptor, grpc.StreamStreamClientInterceptor):
raise TypeError('interceptor must be '
'grpc.UnaryUnaryClientInterceptor or '
'grpc.UnaryStreamClientInterceptor or '
'grpc.StreamUnaryClientInterceptor or '
'grpc.StreamStreamClientInterceptor or ')
channel = _Channel(channel, interceptor)
return channel

@ -54,9 +54,7 @@ class _AuthMetadataPluginCallback(grpc.AuthMetadataPluginCallback):
'AuthMetadataPluginCallback raised exception "{}"!'.format(
self._state.exception))
if error is None:
self._callback(
_common.to_cygrpc_metadata(metadata), cygrpc.StatusCode.ok,
None)
self._callback(metadata, cygrpc.StatusCode.ok, None)
else:
self._callback(None, cygrpc.StatusCode.internal,
_common.encode(str(error)))

@ -23,6 +23,7 @@ import six
import grpc
from grpc import _common
from grpc import _interceptor
from grpc._cython import cygrpc
from grpc.framework.foundation import callable_util
@ -96,6 +97,7 @@ class _RPCState(object):
self.statused = False
self.rpc_errors = []
self.callbacks = []
self.abortion = None
def _raise_rpc_error(state):
@ -129,19 +131,17 @@ def _abort(state, call, code, details):
effective_details = details if state.details is None else state.details
if state.initial_metadata_allowed:
operations = (cygrpc.operation_send_initial_metadata(
_common.EMPTY_METADATA,
_EMPTY_FLAGS), cygrpc.operation_send_status_from_server(
_common.to_cygrpc_metadata(state.trailing_metadata),
effective_code, effective_details, _EMPTY_FLAGS),)
(), _EMPTY_FLAGS), cygrpc.operation_send_status_from_server(
state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS),)
token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
else:
operations = (cygrpc.operation_send_status_from_server(
_common.to_cygrpc_metadata(state.trailing_metadata),
effective_code, effective_details, _EMPTY_FLAGS),)
state.trailing_metadata, effective_code, effective_details,
_EMPTY_FLAGS),)
token = _SEND_STATUS_FROM_SERVER_TOKEN
call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, token))
call.start_server_batch(operations,
_send_status_from_server(state, token))
state.statused = True
state.due.add(token)
@ -237,7 +237,7 @@ class _Context(grpc.ServicerContext):
self._state.disable_next_compression = True
def invocation_metadata(self):
return _common.to_application_metadata(self._rpc_event.request_metadata)
return self._rpc_event.request_metadata
def peer(self):
return _common.decode(self._rpc_event.operation_call.peer())
@ -263,11 +263,9 @@ class _Context(grpc.ServicerContext):
else:
if self._state.initial_metadata_allowed:
operation = cygrpc.operation_send_initial_metadata(
_common.to_cygrpc_metadata(initial_metadata),
_EMPTY_FLAGS)
initial_metadata, _EMPTY_FLAGS)
self._rpc_event.operation_call.start_server_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
(operation,), _send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
self._state.due.add(_SEND_INITIAL_METADATA_TOKEN)
else:
@ -275,8 +273,14 @@ class _Context(grpc.ServicerContext):
def set_trailing_metadata(self, trailing_metadata):
with self._state.condition:
self._state.trailing_metadata = _common.to_cygrpc_metadata(
trailing_metadata)
self._state.trailing_metadata = trailing_metadata
def abort(self, code, details):
with self._state.condition:
self._state.code = code
self._state.details = _common.encode(details)
self._state.abortion = Exception()
raise self._state.abortion
def set_code(self, code):
with self._state.condition:
@ -301,8 +305,7 @@ class _RequestIterator(object):
raise StopIteration()
else:
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_receive_message(self._state, self._call,
self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
@ -345,8 +348,7 @@ def _unary_request(rpc_event, state, request_deserializer):
return None
else:
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_receive_message(state, rpc_event.operation_call,
request_deserializer))
state.due.add(_RECEIVE_MESSAGE_TOKEN)
@ -376,7 +378,10 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
return behavior(argument, context), True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if exception not in state.rpc_errors:
if exception is state.abortion:
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unknown, b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception calling application: {}'.format(exception)
logging.exception(details)
_abort(state, rpc_event.operation_call,
@ -391,7 +396,10 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator):
return None, True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if exception not in state.rpc_errors:
if exception is state.abortion:
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unknown, b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception iterating responses: {}'.format(exception)
logging.exception(details)
_abort(state, rpc_event.operation_call,
@ -417,9 +425,8 @@ def _send_response(rpc_event, state, serialized_response):
else:
if state.initial_metadata_allowed:
operations = (cygrpc.operation_send_initial_metadata(
_common.EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_message(serialized_response,
_EMPTY_FLAGS),)
(), _EMPTY_FLAGS), cygrpc.operation_send_message(
serialized_response, _EMPTY_FLAGS),)
state.initial_metadata_allowed = False
token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN
else:
@ -427,7 +434,7 @@ def _send_response(rpc_event, state, serialized_response):
_EMPTY_FLAGS),)
token = _SEND_MESSAGE_TOKEN
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations), _send_message(state, token))
operations, _send_message(state, token))
state.due.add(token)
while True:
state.condition.wait()
@ -438,24 +445,21 @@ def _send_response(rpc_event, state, serialized_response):
def _status(rpc_event, state, serialized_response):
with state.condition:
if state.client is not _CANCELLED:
trailing_metadata = _common.to_cygrpc_metadata(
state.trailing_metadata)
code = _completion_code(state)
details = _details(state)
operations = [
cygrpc.operation_send_status_from_server(
trailing_metadata, code, details, _EMPTY_FLAGS),
state.trailing_metadata, code, details, _EMPTY_FLAGS),
]
if state.initial_metadata_allowed:
operations.append(
cygrpc.operation_send_initial_metadata(
_common.EMPTY_METADATA, _EMPTY_FLAGS))
cygrpc.operation_send_initial_metadata((), _EMPTY_FLAGS))
if serialized_response is not None:
operations.append(
cygrpc.operation_send_message(serialized_response,
_EMPTY_FLAGS))
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations),
operations,
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN)
@ -538,24 +542,31 @@ def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
method_handler.request_deserializer, method_handler.response_serializer)
def _find_method_handler(rpc_event, generic_handlers):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(
_HandlerCallDetails(
_common.decode(rpc_event.request_call_details.method),
rpc_event.request_metadata))
if method_handler is not None:
return method_handler
else:
def _find_method_handler(rpc_event, generic_handlers, interceptor_pipeline):
def query_handlers(handler_call_details):
for generic_handler in generic_handlers:
method_handler = generic_handler.service(handler_call_details)
if method_handler is not None:
return method_handler
return None
handler_call_details = _HandlerCallDetails(
_common.decode(rpc_event.request_call_details.method),
rpc_event.request_metadata)
if interceptor_pipeline is not None:
return interceptor_pipeline.execute(query_handlers,
handler_call_details)
else:
return query_handlers(handler_call_details)
def _reject_rpc(rpc_event, status, details):
operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA,
_EMPTY_FLAGS),
operations = (cygrpc.operation_send_initial_metadata((), _EMPTY_FLAGS),
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),)
cygrpc.operation_send_status_from_server((), status, details,
_EMPTY_FLAGS),)
rpc_state = _RPCState()
rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
@ -566,8 +577,7 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state = _RPCState()
with state.condition:
rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),),
_receive_close_on_server(state))
state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
if method_handler.request_streaming:
@ -586,13 +596,14 @@ def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
method_handler, thread_pool)
def _handle_call(rpc_event, generic_handlers, thread_pool,
def _handle_call(rpc_event, generic_handlers, interceptor_pipeline, thread_pool,
concurrency_exceeded):
if not rpc_event.success:
return None, None
if rpc_event.request_call_details.method is not None:
try:
method_handler = _find_method_handler(rpc_event, generic_handlers)
method_handler = _find_method_handler(rpc_event, generic_handlers,
interceptor_pipeline)
except Exception as exception: # pylint: disable=broad-except
details = 'Exception servicing handler: {}'.format(exception)
logging.exception(details)
@ -620,12 +631,14 @@ class _ServerStage(enum.Enum):
class _ServerState(object):
def __init__(self, completion_queue, server, generic_handlers, thread_pool,
maximum_concurrent_rpcs):
# pylint: disable=too-many-arguments
def __init__(self, completion_queue, server, generic_handlers,
interceptor_pipeline, thread_pool, maximum_concurrent_rpcs):
self.lock = threading.Lock()
self.completion_queue = completion_queue
self.server = server
self.generic_handlers = list(generic_handlers)
self.interceptor_pipeline = interceptor_pipeline
self.thread_pool = thread_pool
self.stage = _ServerStage.STOPPED
self.shutdown_events = None
@ -690,8 +703,8 @@ def _serve(state):
state.maximum_concurrent_rpcs is not None and
state.active_rpc_count >= state.maximum_concurrent_rpcs)
rpc_state, rpc_future = _handle_call(
event, state.generic_handlers, state.thread_pool,
concurrency_exceeded)
event, state.generic_handlers, state.interceptor_pipeline,
state.thread_pool, concurrency_exceeded)
if rpc_state is not None:
state.rpc_states.add(rpc_state)
if rpc_future is not None:
@ -779,12 +792,14 @@ def _start(state):
class Server(grpc.Server):
def __init__(self, thread_pool, generic_handlers, options,
# pylint: disable=too-many-arguments
def __init__(self, thread_pool, generic_handlers, interceptors, options,
maximum_concurrent_rpcs):
completion_queue = cygrpc.CompletionQueue()
server = cygrpc.Server(_common.channel_args(options))
server.register_completion_queue(completion_queue)
self._state = _ServerState(completion_queue, server, generic_handlers,
_interceptor.service_pipeline(interceptors),
thread_pool, maximum_concurrent_rpcs)
def add_generic_rpc_handlers(self, generic_rpc_handlers):

@ -15,6 +15,7 @@
import grpc
from grpc import _common
from grpc.beta import _metadata
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.foundation import future
@ -157,10 +158,10 @@ class _Rendezvous(future.Future, face.Call):
return _InvocationProtocolContext()
def initial_metadata(self):
return self._call.initial_metadata()
return _metadata.beta(self._call.initial_metadata())
def terminal_metadata(self):
return self._call.terminal_metadata()
return _metadata.beta(self._call.terminal_metadata())
def code(self):
return self._call.code()
@ -182,14 +183,14 @@ def _blocking_unary_unary(channel, group, method, timeout, with_call,
response, call = multi_callable.with_call(
request,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return response, _Rendezvous(None, None, call)
else:
return multi_callable(
request,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
@ -206,7 +207,7 @@ def _future_unary_unary(channel, group, method, timeout, protocol_options,
response_future = multi_callable.future(
request,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(response_future, None, response_future)
@ -222,7 +223,7 @@ def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
response_iterator = multi_callable(
request,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(None, response_iterator, response_iterator)
@ -241,14 +242,14 @@ def _blocking_stream_unary(channel, group, method, timeout, with_call,
response, call = multi_callable.with_call(
request_iterator,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return response, _Rendezvous(None, None, call)
else:
return multi_callable(
request_iterator,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
except grpc.RpcError as rpc_error_call:
raise _abortion_error(rpc_error_call)
@ -265,7 +266,7 @@ def _future_stream_unary(channel, group, method, timeout, protocol_options,
response_future = multi_callable.future(
request_iterator,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(response_future, None, response_future)
@ -281,7 +282,7 @@ def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
response_iterator = multi_callable(
request_iterator,
timeout=timeout,
metadata=effective_metadata,
metadata=_metadata.unbeta(effective_metadata),
credentials=_credentials(protocol_options))
return _Rendezvous(None, response_iterator, response_iterator)

@ -0,0 +1,49 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""API metadata conversion utilities."""
import collections
_Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',))
def _beta_metadatum(key, value):
beta_key = key if isinstance(key, (bytes,)) else key.encode('ascii')
beta_value = value if isinstance(value, (bytes,)) else value.encode('ascii')
return _Metadatum(beta_key, beta_value)
def _metadatum(beta_key, beta_value):
key = beta_key if isinstance(beta_key, (str,)) else beta_key.decode('utf8')
if isinstance(beta_value, (str,)) or key[-4:] == '-bin':
value = beta_value
else:
value = beta_value.decode('utf8')
return _Metadatum(key, value)
def beta(metadata):
if metadata is None:
return ()
else:
return tuple(_beta_metadatum(key, value) for key, value in metadata)
def unbeta(beta_metadata):
if beta_metadata is None:
return ()
else:
return tuple(
_metadatum(beta_key, beta_value)
for beta_key, beta_value in beta_metadata)

@ -18,6 +18,7 @@ import threading
import grpc
from grpc import _common
from grpc.beta import _metadata
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.common import style
@ -65,14 +66,15 @@ class _FaceServicerContext(face.ServicerContext):
return _ServerProtocolContext(self._servicer_context)
def invocation_metadata(self):
return _common.to_cygrpc_metadata(
self._servicer_context.invocation_metadata())
return _metadata.beta(self._servicer_context.invocation_metadata())
def initial_metadata(self, initial_metadata):
self._servicer_context.send_initial_metadata(initial_metadata)
self._servicer_context.send_initial_metadata(
_metadata.unbeta(initial_metadata))
def terminal_metadata(self, terminal_metadata):
self._servicer_context.set_terminal_metadata(terminal_metadata)
self._servicer_context.set_terminal_metadata(
_metadata.unbeta(terminal_metadata))
def code(self, code):
self._servicer_context.set_code(code)

@ -21,6 +21,7 @@ import threading # pylint: disable=unused-import
import grpc
from grpc import _auth
from grpc.beta import _client_adaptations
from grpc.beta import _metadata
from grpc.beta import _server_adaptations
from grpc.beta import interfaces # pylint: disable=unused-import
from grpc.framework.common import cardinality # pylint: disable=unused-import
@ -31,7 +32,18 @@ from grpc.framework.interfaces.face import face # pylint: disable=unused-import
ChannelCredentials = grpc.ChannelCredentials
ssl_channel_credentials = grpc.ssl_channel_credentials
CallCredentials = grpc.CallCredentials
metadata_call_credentials = grpc.metadata_call_credentials
def metadata_call_credentials(metadata_plugin, name=None):
def plugin(context, callback):
def wrapped_callback(beta_metadata, error):
callback(_metadata.unbeta(beta_metadata), error)
metadata_plugin(context, wrapped_callback)
return grpc.metadata_call_credentials(plugin, name=name)
def google_call_credentials(credentials):

@ -67,6 +67,9 @@ class ServicerContext(grpc.ServicerContext):
self._rpc.set_trailing_metadata(
_common.fuss_with_metadata(trailing_metadata))
def abort(self, code, details):
raise NotImplementedError()
def set_code(self, code):
self._rpc.set_code(code)

@ -39,6 +39,7 @@
"unit._cython.cygrpc_test.TypeSmokeTest",
"unit._empty_message_test.EmptyMessageTest",
"unit._exit_test.ExitTest",
"unit._interceptor_test.InterceptorTest",
"unit._invalid_metadata_test.InvalidMetadataTest",
"unit._invocation_defects_test.InvocationDefectsTest",
"unit._metadata_code_details_test.MetadataCodeDetailsTest",

@ -33,18 +33,21 @@ class AllTest(unittest.TestCase):
'AuthMetadataPlugin', 'ServerCertificateConfiguration',
'ServerCredentials', 'UnaryUnaryMultiCallable',
'UnaryStreamMultiCallable', 'StreamUnaryMultiCallable',
'StreamStreamMultiCallable', 'Channel', 'ServicerContext',
'StreamStreamMultiCallable', 'UnaryUnaryClientInterceptor',
'UnaryStreamClientInterceptor', 'StreamUnaryClientInterceptor',
'StreamStreamClientInterceptor', 'Channel', 'ServicerContext',
'RpcMethodHandler', 'HandlerCallDetails', 'GenericRpcHandler',
'ServiceRpcHandler', 'Server', 'unary_unary_rpc_method_handler',
'unary_stream_rpc_method_handler',
'stream_unary_rpc_method_handler',
'ServiceRpcHandler', 'Server', 'ServerInterceptor',
'unary_unary_rpc_method_handler', 'unary_stream_rpc_method_handler',
'stream_unary_rpc_method_handler', 'ClientCallDetails',
'stream_stream_rpc_method_handler',
'method_handlers_generic_handler', 'ssl_channel_credentials',
'metadata_call_credentials', 'access_token_call_credentials',
'composite_call_credentials', 'composite_channel_credentials',
'ssl_server_credentials', 'ssl_server_certificate_configuration',
'dynamic_ssl_server_credentials', 'channel_ready_future',
'insecure_channel', 'secure_channel', 'server',)
'insecure_channel', 'secure_channel', 'intercept_channel',
'server',)
six.assertCountEqual(self, expected_grpc_code_elements,
_from_grpc_import_star.GRPC_ELEMENTS)

@ -22,7 +22,7 @@ from tests.unit.framework.common import test_constants
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
_EMPTY_FLAGS = 0
_EMPTY_METADATA = cygrpc.Metadata(())
_EMPTY_METADATA = ()
_SERVER_SHUTDOWN_TAG = 'server_shutdown'
_REQUEST_CALL_TAG = 'request_call'
@ -65,12 +65,10 @@ class _Handler(object):
with self._lock:
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),),
_RECEIVE_CLOSE_ON_SERVER_TAG)
self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
(cygrpc.operation_receive_message(_EMPTY_FLAGS),),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
if _is_cancellation_event(first_event):
@ -84,8 +82,8 @@ class _Handler(object):
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),)
self._call.start_server_batch(
cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
self._call.start_server_batch(operations,
_SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
self._completion_queue.poll()
@ -179,8 +177,7 @@ class CancelManyCallsTest(unittest.TestCase):
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
client_call.start_client_batch(
cygrpc.Operations(operations), tag)
client_call.start_client_batch(operations, tag)
client_due.add(tag)
client_calls.append(client_call)

@ -23,17 +23,14 @@ RPC_COUNT = 4000
INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
EMPTY_FLAGS = 0
INVOCATION_METADATA = cygrpc.Metadata(
(cygrpc.Metadatum(b'client-md-key', b'client-md-key'),
cygrpc.Metadatum(b'client-md-key-bin', b'\x00\x01' * 3000),))
INVOCATION_METADATA = (('client-md-key', 'client-md-key'),
('client-md-key-bin', b'\x00\x01' * 3000),)
INITIAL_METADATA = cygrpc.Metadata(
(cygrpc.Metadatum(b'server-initial-md-key', b'server-initial-md-value'),
cygrpc.Metadatum(b'server-initial-md-key-bin', b'\x00\x02' * 3000),))
INITIAL_METADATA = (('server-initial-md-key', 'server-initial-md-value'),
('server-initial-md-key-bin', b'\x00\x02' * 3000),)
TRAILING_METADATA = cygrpc.Metadata(
(cygrpc.Metadatum(b'server-trailing-md-key', b'server-trailing-md-value'),
cygrpc.Metadatum(b'server-trailing-md-key-bin', b'\x00\x03' * 3000),))
TRAILING_METADATA = (('server-trailing-md-key', 'server-trailing-md-value'),
('server-trailing-md-key-bin', b'\x00\x03' * 3000),)
class QueueDriver(object):

@ -48,20 +48,19 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_call.start_client_batch([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
cygrpc.Operations([
[
cygrpc.operation_send_initial_metadata(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(
_common.EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(
_common.EMPTY_FLAGS),
]), client_complete_rpc_tag)
], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,

@ -43,20 +43,19 @@ class Test(_common.RpcTest, unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with self.client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_call.start_client_batch([
cygrpc.operation_receive_initial_metadata(
_common.EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_complete_rpc_start_batch_result = client_call.start_client_batch(
cygrpc.Operations([
[
cygrpc.operation_send_initial_metadata(
_common.INVOCATION_METADATA, _common.EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(
_common.EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(
_common.EMPTY_FLAGS),
]), client_complete_rpc_tag)
], client_complete_rpc_tag)
self.client_driver.add_due({
client_receive_initial_metadata_tag,
client_complete_rpc_tag,

@ -20,7 +20,7 @@ from grpc._cython import cygrpc
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
_EMPTY_FLAGS = 0
_EMPTY_METADATA = cygrpc.Metadata(())
_EMPTY_METADATA = ()
class _ServerDriver(object):
@ -157,19 +157,17 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_call.start_client_batch([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
], client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
]), client_complete_rpc_tag))
client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
_EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
], client_complete_rpc_tag))
client_due.add(client_complete_rpc_tag)
server_rpc_event = server_driver.first_event()
@ -197,8 +195,8 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
cygrpc.Metadata(()), cygrpc.StatusCode.ok,
b'test details', _EMPTY_FLAGS),
(), cygrpc.StatusCode.ok, b'test details',
_EMPTY_FLAGS),
], server_complete_rpc_tag))
server_send_second_message_event = server_call_driver.event_with_tag(
server_send_second_message_tag)
@ -209,10 +207,9 @@ class ReadSomeButNotAllResponsesTest(unittest.TestCase):
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
client_call.start_client_batch(
cygrpc.Operations([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
]), client_receive_first_message_tag))
client_call.start_client_batch([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
], client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)
client_receive_first_message_event = client_driver.event_with_tag(
client_receive_first_message_tag)

@ -29,50 +29,12 @@ _EMPTY_FLAGS = 0
def _metadata_plugin(context, callback):
callback(
cygrpc.Metadata([
cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE)
]), cygrpc.StatusCode.ok, b'')
callback(((_CALL_CREDENTIALS_METADATA_KEY,
_CALL_CREDENTIALS_METADATA_VALUE,),), cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase):
def testStringsInUtilitiesUpDown(self):
self.assertEqual(0, cygrpc.StatusCode.ok)
metadatum = cygrpc.Metadatum(b'a', b'b')
self.assertEqual(b'a', metadatum.key)
self.assertEqual(b'b', metadatum.value)
metadata = cygrpc.Metadata([metadatum])
self.assertEqual(1, len(metadata))
self.assertEqual(metadatum.key, metadata[0].key)
def testMetadataIteration(self):
metadata = cygrpc.Metadata(
[cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
iterator = iter(metadata)
metadatum = next(iterator)
self.assertIsInstance(metadatum, cygrpc.Metadatum)
self.assertEqual(metadatum.key, b'a')
self.assertEqual(metadatum.value, b'b')
metadatum = next(iterator)
self.assertIsInstance(metadatum, cygrpc.Metadatum)
self.assertEqual(metadatum.key, b'c')
self.assertEqual(metadatum.value, b'd')
with self.assertRaises(StopIteration):
next(iterator)
def testOperationsIteration(self):
operations = cygrpc.Operations(
[cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
iterator = iter(operations)
operation = next(iterator)
self.assertIsInstance(operation, cygrpc.Operation)
# `Operation`s are write-only structures; can't directly debug anything out
# of them. Just check that we stop iterating.
with self.assertRaises(StopIteration):
next(iterator)
def testOperationFlags(self):
operation = cygrpc.operation_send_message(b'asdf',
cygrpc.WriteFlag.no_compress)
@ -182,8 +144,7 @@ class ServerClientMixin(object):
def performer():
tag = object()
try:
call_result = call.start_client_batch(
cygrpc.Operations(operations), tag)
call_result = call.start_client_batch(operations, tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete,
@ -200,14 +161,14 @@ class ServerClientMixin(object):
def test_echo(self):
DEADLINE = time.time() + 5
DEADLINE_TOLERANCE = 0.25
CLIENT_METADATA_ASCII_KEY = b'key'
CLIENT_METADATA_ASCII_VALUE = b'val'
CLIENT_METADATA_BIN_KEY = b'key-bin'
CLIENT_METADATA_ASCII_KEY = 'key'
CLIENT_METADATA_ASCII_VALUE = 'val'
CLIENT_METADATA_BIN_KEY = 'key-bin'
CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
SERVER_STATUS_CODE = cygrpc.StatusCode.ok
SERVER_STATUS_DETAILS = b'our work is never over'
REQUEST = b'in death a member of project mayhem has a name'
@ -227,11 +188,9 @@ class ServerClientMixin(object):
client_call = self.client_channel.create_call(
None, 0, self.client_completion_queue, METHOD, self.host_argument,
cygrpc_deadline)
client_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)
])
client_initial_metadata = (
(CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE,),
(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE,),)
client_start_batch_result = client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
@ -263,14 +222,10 @@ class ServerClientMixin(object):
server_call_tag = object()
server_call = request_event.operation_call
server_initial_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
SERVER_INITIAL_METADATA_VALUE)
])
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)
])
server_initial_metadata = (
(SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE,),)
server_trailing_metadata = (
(SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE,),)
server_start_batch_result = server_call.start_server_batch([
cygrpc.operation_send_initial_metadata(
server_initial_metadata,
@ -347,7 +302,7 @@ class ServerClientMixin(object):
METHOD = b'twinkies'
cygrpc_deadline = cygrpc.Timespec(DEADLINE)
empty_metadata = cygrpc.Metadata([])
empty_metadata = ()
server_request_tag = object()
self.server.request_call(self.server_completion_queue,

@ -0,0 +1,571 @@
# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of gRPC Python interceptors."""
import collections
import itertools
import threading
import unittest
from concurrent import futures
import grpc
from grpc.framework.foundation import logging_pool
from tests.unit.framework.common import test_constants
from tests.unit.framework.common import test_control
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
class _Callback(object):
def __init__(self):
self._condition = threading.Condition()
self._value = None
self._called = False
def __call__(self, value):
with self._condition:
self._value = value
self._called = True
self._condition.notify_all()
def value(self):
with self._condition:
while not self._called:
self._condition.wait()
return self._value
class _Handler(object):
def __init__(self, control):
self._control = control
def handle_unary_unary(self, request, servicer_context):
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
return request
def handle_unary_stream(self, request, servicer_context):
for _ in range(test_constants.STREAM_LENGTH):
self._control.control()
yield request
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
def handle_stream_unary(self, request_iterator, servicer_context):
if servicer_context is not None:
servicer_context.invocation_metadata()
self._control.control()
response_elements = []
for request in request_iterator:
self._control.control()
response_elements.append(request)
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
return b''.join(response_elements)
def handle_stream_stream(self, request_iterator, servicer_context):
self._control.control()
if servicer_context is not None:
servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
for request in request_iterator:
self._control.control()
yield request
self._control.control()
class _MethodHandler(grpc.RpcMethodHandler):
def __init__(self, request_streaming, response_streaming,
request_deserializer, response_serializer, unary_unary,
unary_stream, stream_unary, stream_stream):
self.request_streaming = request_streaming
self.response_streaming = response_streaming
self.request_deserializer = request_deserializer
self.response_serializer = response_serializer
self.unary_unary = unary_unary
self.unary_stream = unary_stream
self.stream_unary = stream_unary
self.stream_stream = stream_stream
class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self, handler):
self._handler = handler
def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY:
return _MethodHandler(False, False, None, None,
self._handler.handle_unary_unary, None, None,
None)
elif handler_call_details.method == _UNARY_STREAM:
return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
_SERIALIZE_RESPONSE, None,
self._handler.handle_unary_stream, None, None)
elif handler_call_details.method == _STREAM_UNARY:
return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
_SERIALIZE_RESPONSE, None, None,
self._handler.handle_stream_unary, None)
elif handler_call_details.method == _STREAM_STREAM:
return _MethodHandler(True, True, None, None, None, None, None,
self._handler.handle_stream_stream)
else:
return None
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
def _unary_stream_multi_callable(channel):
return channel.unary_stream(
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_unary_multi_callable(channel):
return channel.stream_unary(
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
class _ClientCallDetails(
collections.namedtuple('_ClientCallDetails',
('method', 'timeout', 'metadata',
'credentials')), grpc.ClientCallDetails):
pass
class _GenericClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
def __init__(self, interceptor_function):
self._fn = interceptor_function
def intercept_unary_unary(self, continuation, client_call_details, request):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, False)
response = continuation(new_details, next(new_request_iterator))
return postprocess(response) if postprocess else response
def intercept_unary_stream(self, continuation, client_call_details,
request):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, iter((request,)), False, True)
response_it = continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it
def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, False)
response = continuation(new_details, next(new_request_iterator))
return postprocess(response) if postprocess else response
def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
new_details, new_request_iterator, postprocess = self._fn(
client_call_details, request_iterator, True, True)
response_it = continuation(new_details, new_request_iterator)
return postprocess(response_it) if postprocess else response_it
class _LoggingInterceptor(
grpc.ServerInterceptor, grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor):
def __init__(self, tag, record):
self.tag = tag
self.record = record
def intercept_service(self, continuation, handler_call_details):
self.record.append(self.tag + ':intercept_service')
return continuation(handler_call_details)
def intercept_unary_unary(self, continuation, client_call_details, request):
self.record.append(self.tag + ':intercept_unary_unary')
return continuation(client_call_details, request)
def intercept_unary_stream(self, continuation, client_call_details,
request):
self.record.append(self.tag + ':intercept_unary_stream')
return continuation(client_call_details, request)
def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
self.record.append(self.tag + ':intercept_stream_unary')
return continuation(client_call_details, request_iterator)
def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
self.record.append(self.tag + ':intercept_stream_stream')
return continuation(client_call_details, request_iterator)
class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor):
def intercept_unary_unary(self, ignored_continuation,
ignored_client_call_details, ignored_request):
raise test_control.Defect()
def _wrap_request_iterator_stream_interceptor(wrapper):
def intercept_call(client_call_details, request_iterator, request_streaming,
ignored_response_streaming):
if request_streaming:
return client_call_details, wrapper(request_iterator), None
else:
return client_call_details, request_iterator, None
return _GenericClientInterceptor(intercept_call)
def _append_request_header_interceptor(header, value):
def intercept_call(client_call_details, request_iterator,
ignored_request_streaming, ignored_response_streaming):
metadata = []
if client_call_details.metadata:
metadata = list(client_call_details.metadata)
metadata.append((header, value,))
client_call_details = _ClientCallDetails(
client_call_details.method, client_call_details.timeout, metadata,
client_call_details.credentials)
return client_call_details, request_iterator, None
return _GenericClientInterceptor(intercept_call)
class _GenericServerInterceptor(grpc.ServerInterceptor):
def __init__(self, fn):
self._fn = fn
def intercept_service(self, continuation, handler_call_details):
return self._fn(continuation, handler_call_details)
def _filter_server_interceptor(condition, interceptor):
def intercept_service(continuation, handler_call_details):
if condition(handler_call_details):
return interceptor.intercept_service(continuation,
handler_call_details)
return continuation(handler_call_details)
return _GenericServerInterceptor(intercept_service)
class InterceptorTest(unittest.TestCase):
def setUp(self):
self._control = test_control.PauseFailControl()
self._handler = _Handler(self._control)
self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
self._record = []
conditional_interceptor = _filter_server_interceptor(
lambda x: ('secret', '42') in x.invocation_metadata,
_LoggingInterceptor('s3', self._record))
self._server = grpc.server(
self._server_pool,
interceptors=(_LoggingInterceptor('s1', self._record),
conditional_interceptor,
_LoggingInterceptor('s2', self._record),))
port = self._server.add_insecure_port('[::]:0')
self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
self._server.start()
self._channel = grpc.insecure_channel('localhost:%d' % port)
def tearDown(self):
self._server.stop(None)
self._server_pool.shutdown(wait=True)
def testTripleRequestMessagesClientInterceptor(self):
def triple(request_iterator):
while True:
try:
item = next(request_iterator)
yield item
yield item
yield item
except StopIteration:
break
interceptor = _wrap_request_iterator_stream_interceptor(triple)
channel = grpc.intercept_channel(self._channel, interceptor)
requests = tuple(b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
multi_callable = _stream_stream_multi_callable(channel)
response_iterator = multi_callable(
iter(requests),
metadata=(
('test',
'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
responses = tuple(response_iterator)
self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH)
multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
iter(requests),
metadata=(
('test',
'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
responses = tuple(response_iterator)
self.assertEqual(len(responses), test_constants.STREAM_LENGTH)
def testDefectiveClientInterceptor(self):
interceptor = _DefectiveClientInterceptor()
defective_channel = grpc.intercept_channel(self._channel, interceptor)
request = b'\x07\x08'
multi_callable = _unary_unary_multi_callable(defective_channel)
call_future = multi_callable.future(
request,
metadata=(
('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),))
self.assertIsNotNone(call_future.exception())
self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)
def testInterceptedHeaderManipulationWithServerSideVerification(self):
request = b'\x07\x08'
channel = grpc.intercept_channel(
self._channel, _append_request_header_interceptor('secret', '42'))
channel = grpc.intercept_channel(
channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
self._record[:] = []
multi_callable = _unary_unary_multi_callable(channel)
multi_callable.with_call(
request,
metadata=(
('test',
'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),))
self.assertSequenceEqual(self._record, [
'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
's1:intercept_service', 's3:intercept_service',
's2:intercept_service'
])
def testInterceptedUnaryRequestBlockingUnaryResponse(self):
request = b'\x07\x08'
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _unary_unary_multi_callable(channel)
multi_callable(
request,
metadata=(
('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),))
self.assertSequenceEqual(self._record, [
'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
request = b'\x07\x08'
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
self._record[:] = []
multi_callable = _unary_unary_multi_callable(channel)
multi_callable.with_call(
request,
metadata=(
('test',
'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),))
self.assertSequenceEqual(self._record, [
'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedUnaryRequestFutureUnaryResponse(self):
request = b'\x07\x08'
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _unary_unary_multi_callable(channel)
response_future = multi_callable.future(
request,
metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),))
response_future.result()
self.assertSequenceEqual(self._record, [
'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedUnaryRequestStreamResponse(self):
request = b'\x37\x58'
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _unary_stream_multi_callable(channel)
response_iterator = multi_callable(
request,
metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),))
tuple(response_iterator)
self.assertSequenceEqual(self._record, [
'c1:intercept_unary_stream', 'c2:intercept_unary_stream',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedStreamRequestBlockingUnaryResponse(self):
requests = tuple(b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _stream_unary_multi_callable(channel)
multi_callable(
request_iterator,
metadata=(
('test', 'InterceptedStreamRequestBlockingUnaryResponse'),))
self.assertSequenceEqual(self._record, [
'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self):
requests = tuple(b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _stream_unary_multi_callable(channel)
multi_callable.with_call(
request_iterator,
metadata=(
('test',
'InterceptedStreamRequestBlockingUnaryResponseWithCall'),))
self.assertSequenceEqual(self._record, [
'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedStreamRequestFutureUnaryResponse(self):
requests = tuple(b'\x07\x08'
for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _stream_unary_multi_callable(channel)
response_future = multi_callable.future(
request_iterator,
metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),))
response_future.result()
self.assertSequenceEqual(self._record, [
'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
's1:intercept_service', 's2:intercept_service'
])
def testInterceptedStreamRequestStreamResponse(self):
requests = tuple(b'\x77\x58'
for _ in range(test_constants.STREAM_LENGTH))
request_iterator = iter(requests)
self._record[:] = []
channel = grpc.intercept_channel(
self._channel,
_LoggingInterceptor('c1', self._record),
_LoggingInterceptor('c2', self._record))
multi_callable = _stream_stream_multi_callable(channel)
response_iterator = multi_callable(
request_iterator,
metadata=(('test', 'InterceptedStreamRequestStreamResponse'),))
tuple(response_iterator)
self.assertSequenceEqual(self._record, [
'c1:intercept_stream_stream', 'c2:intercept_stream_stream',
's1:intercept_service', 's2:intercept_service'
])
if __name__ == '__main__':
unittest.main(verbosity=2)

@ -56,6 +56,7 @@ class _Servicer(object):
def __init__(self):
self._lock = threading.Lock()
self._abort_call = False
self._code = None
self._details = None
self._exception = False
@ -67,10 +68,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@ -81,10 +85,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 2):
yield _SERIALIZED_RESPONSE
if self._exception:
@ -95,14 +102,16 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
for ignored_request in request_iterator:
pass
list(request_iterator)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@ -113,19 +122,25 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
for ignored_request in request_iterator:
pass
list(request_iterator)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 3):
yield object()
if self._exception:
raise test_control.Defect()
def set_abort_call(self):
with self._lock:
self._abort_call = True
def set_code(self, code):
with self._lock:
self._code = code
@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulUnaryStream(self):
self._servicer.set_details(_DETAILS)
call = self._unary_stream(
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
for _ in call:
pass
received_initial_metadata = response_iterator_call.initial_metadata()
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testSuccessfulStreamUnary(self):
self._servicer.set_details(_DETAILS)
@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulStreamStream(self):
self._servicer.set_details(_DETAILS)
call = self._stream_stream(
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
for _ in call:
pass
received_initial_metadata = response_iterator_call.initial_metadata()
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
def testAbortedUnaryStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._stream_unary.with_call(
iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
def testAbortedStreamStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
call = self._unary_stream(
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
call = self._stream_stream(
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError) as exception_context:
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS)
self._servicer.set_exception()
call = self._unary_stream(
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeExceptionStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS)
self._servicer.set_exception()
call = self._stream_stream(
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeReturnNoneUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)

@ -61,7 +61,7 @@ ENV['EMBED_ZLIB'] = 'true'
ENV['EMBED_CARES'] = 'true'
ENV['ARCH_FLAGS'] = RbConfig::CONFIG['ARCH_FLAG']
ENV['ARCH_FLAGS'] = '-arch i386 -arch x86_64' if RUBY_PLATFORM =~ /darwin/
ENV['CFLAGS'] = '-DGPR_BACKWARDS_COMPATIBILITY_MODE'
ENV['CPPFLAGS'] = '-DGPR_BACKWARDS_COMPATIBILITY_MODE'
output_dir = File.expand_path(RbConfig::CONFIG['topdir'])
grpc_lib_dir = File.join(output_dir, 'libs', grpc_config)

Loading…
Cancel
Save