pull/35470/head
Craig Tiller 1 year ago
commit 789b365616
  1. 52
      CMakeLists.txt
  2. 44
      build_autogenerated.yaml
  3. 27
      examples/python/helloworld/helloworld_pb2.py
  4. 14
      examples/python/helloworld/helloworld_pb2.pyi
  5. 104
      examples/python/helloworld/helloworld_pb2_grpc.py
  6. 12
      include/grpc/event_engine/internal/slice_cast.h
  7. 5
      include/grpc/event_engine/slice.h
  8. 23
      src/compiler/python_generator.cc
  9. 85
      src/core/BUILD
  10. 19
      src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
  11. 111
      src/core/ext/transport/chaotic_good/chaotic_good_transport.h
  12. 325
      src/core/ext/transport/chaotic_good/client_transport.cc
  13. 202
      src/core/ext/transport/chaotic_good/client_transport.h
  14. 165
      src/core/ext/transport/chaotic_good/frame.cc
  15. 83
      src/core/ext/transport/chaotic_good/frame.h
  16. 11
      src/core/ext/transport/chaotic_good/frame_header.cc
  17. 2
      src/core/ext/transport/chaotic_good/frame_header.h
  18. 332
      src/core/ext/transport/chaotic_good/server_transport.cc
  19. 145
      src/core/ext/transport/chaotic_good/server_transport.h
  20. 8
      src/core/ext/transport/inproc/inproc_transport.cc
  21. 9
      src/core/lib/gprpp/debug_location.h
  22. 45
      src/core/lib/promise/detail/status.h
  23. 4
      src/core/lib/promise/event_engine_wakeup_scheduler.h
  24. 4
      src/core/lib/promise/if.h
  25. 10
      src/core/lib/promise/inter_activity_pipe.h
  26. 22
      src/core/lib/promise/mpsc.h
  27. 24
      src/core/lib/promise/status_flag.h
  28. 6
      src/core/lib/promise/try_join.h
  29. 29
      src/core/lib/promise/try_seq.h
  30. 4
      src/core/lib/resource_quota/arena.h
  31. 3
      src/core/lib/slice/slice_buffer.h
  32. 13
      src/core/lib/surface/call.cc
  33. 3
      src/core/lib/surface/call.h
  34. 23
      src/core/lib/surface/server.cc
  35. 6
      src/core/lib/surface/server.h
  36. 36
      src/core/lib/transport/promise_endpoint.h
  37. 14
      src/core/lib/transport/transport.cc
  38. 151
      src/core/lib/transport/transport.h
  39. 4
      src/python/.gitignore
  40. 75
      src/python/grpcio/grpc/_channel.py
  41. 7
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
  42. 71
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  43. 6
      src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
  44. 36
      src/python/grpcio/grpc/_interceptor.py
  45. 81
      src/python/grpcio/grpc/_simple_stubs.py
  46. 21
      src/python/grpcio/grpc/aio/_channel.py
  47. 27
      src/python/grpcio_testing/grpc_testing/_channel/_channel.py
  48. 17
      src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py
  49. 5
      src/python/grpcio_tests/tests/csds/csds_test.py
  50. 11
      src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
  51. 9
      src/python/grpcio_tests/tests/qps/benchmark_client.py
  52. 25
      src/python/grpcio_tests/tests/status/_grpc_status_test.py
  53. 20
      src/python/grpcio_tests/tests/unit/_abort_test.py
  54. 20
      src/python/grpcio_tests/tests/unit/_auth_context_test.py
  55. 35
      src/python/grpcio_tests/tests/unit/_channel_close_test.py
  56. 20
      src/python/grpcio_tests/tests/unit/_compression_test.py
  57. 10
      src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py
  58. 5
      src/python/grpcio_tests/tests/unit/_dns_resolver_test.py
  59. 24
      src/python/grpcio_tests/tests/unit/_empty_message_test.py
  60. 5
      src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
  61. 20
      src/python/grpcio_tests/tests/unit/_exit_scenarios.py
  62. 8
      src/python/grpcio_tests/tests/unit/_interceptor_test.py
  63. 9
      src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
  64. 22
      src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
  65. 14
      src/python/grpcio_tests/tests/unit/_local_credentials_test.py
  66. 80
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  67. 39
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
  68. 16
      src/python/grpcio_tests/tests/unit/_metadata_test.py
  69. 5
      src/python/grpcio_tests/tests/unit/_reconnect_test.py
  70. 20
      src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
  71. 18
      src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py
  72. 5
      src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py
  73. 5
      src/python/grpcio_tests/tests/unit/_session_cache_test.py
  74. 24
      src/python/grpcio_tests/tests/unit/_signal_client.py
  75. 14
      src/python/grpcio_tests/tests/unit/_xds_credentials_test.py
  76. 2
      src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py
  77. 5
      src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py
  78. 46
      src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py
  79. 22
      test/core/end2end/fuzzers/api_fuzzer.cc
  80. 2
      test/core/promise/mpsc_test.cc
  81. 82
      test/core/transport/chaotic_good/BUILD
  82. 549
      test/core/transport/chaotic_good/client_transport_error_test.cc
  83. 618
      test/core/transport/chaotic_good/client_transport_test.cc
  84. 63
      test/core/transport/chaotic_good/frame_fuzzer.cc
  85. 23
      test/core/transport/chaotic_good/frame_fuzzer.proto
  86. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328
  87. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672
  88. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb
  89. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a
  90. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c
  91. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe
  92. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff
  93. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540
  94. BIN
      test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675
  95. 1
      test/core/transport/chaotic_good/frame_fuzzer_corpus/empty
  96. 12
      test/core/transport/chaotic_good/frame_header_test.cc
  97. 31
      test/core/transport/chaotic_good/frame_test.cc
  98. 89
      test/core/transport/chaotic_good/mock_promise_endpoint.cc
  99. 77
      test/core/transport/chaotic_good/mock_promise_endpoint.h
  100. 198
      test/core/transport/chaotic_good/server_transport_test.cc
  101. Some files were not shown because too many files have changed in this diff Show More

52
CMakeLists.txt generated

@ -1326,6 +1326,7 @@ if(gRPC_BUILD_TESTS)
endif()
add_dependencies(buildtests_cxx server_streaming_test)
add_dependencies(buildtests_cxx server_test)
add_dependencies(buildtests_cxx server_transport_test)
add_dependencies(buildtests_cxx service_config_end2end_test)
add_dependencies(buildtests_cxx service_config_test)
add_dependencies(buildtests_cxx settings_timeout_test)
@ -9533,6 +9534,7 @@ add_executable(client_transport_error_test
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h
src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
src/core/ext/transport/chaotic_good/client_transport.cc
src/core/ext/transport/chaotic_good/frame.cc
src/core/ext/transport/chaotic_good/frame_header.cc
@ -9563,6 +9565,7 @@ target_include_directories(client_transport_error_test
target_link_libraries(client_transport_error_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
grpc_unsecure
${_gRPC_PROTOBUF_LIBRARIES}
grpc_test_util
)
@ -9576,12 +9579,15 @@ add_executable(client_transport_test
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h
src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
src/core/ext/transport/chaotic_good/client_transport.cc
src/core/ext/transport/chaotic_good/frame.cc
src/core/ext/transport/chaotic_good/frame_header.cc
src/core/lib/transport/promise_endpoint.cc
test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
test/core/transport/chaotic_good/client_transport_test.cc
test/core/transport/chaotic_good/mock_promise_endpoint.cc
test/core/transport/chaotic_good/transport_test.cc
)
target_compile_features(client_transport_test PUBLIC cxx_std_14)
target_include_directories(client_transport_test
@ -22380,6 +22386,52 @@ target_link_libraries(server_test
)
endif()
if(gRPC_BUILD_TESTS)
add_executable(server_transport_test
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h
src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
src/core/ext/transport/chaotic_good/frame.cc
src/core/ext/transport/chaotic_good/frame_header.cc
src/core/ext/transport/chaotic_good/server_transport.cc
src/core/lib/transport/promise_endpoint.cc
test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
test/core/transport/chaotic_good/mock_promise_endpoint.cc
test/core/transport/chaotic_good/server_transport_test.cc
test/core/transport/chaotic_good/transport_test.cc
)
target_compile_features(server_transport_test PUBLIC cxx_std_14)
target_include_directories(server_transport_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${_gRPC_ADDRESS_SORTING_INCLUDE_DIR}
${_gRPC_RE2_INCLUDE_DIR}
${_gRPC_SSL_INCLUDE_DIR}
${_gRPC_UPB_GENERATED_DIR}
${_gRPC_UPB_GRPC_GENERATED_DIR}
${_gRPC_UPB_INCLUDE_DIR}
${_gRPC_XXHASH_INCLUDE_DIR}
${_gRPC_ZLIB_INCLUDE_DIR}
third_party/googletest/googletest/include
third_party/googletest/googletest
third_party/googletest/googlemock/include
third_party/googletest/googlemock
${_gRPC_PROTO_GENS_DIR}
)
target_link_libraries(server_transport_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
${_gRPC_PROTOBUF_LIBRARIES}
grpc_test_util
)
endif()
if(gRPC_BUILD_TESTS)

@ -7646,6 +7646,7 @@ targets:
build: test
language: c++
headers:
- src/core/ext/transport/chaotic_good/chaotic_good_transport.h
- src/core/ext/transport/chaotic_good/client_transport.h
- src/core/ext/transport/chaotic_good/frame.h
- src/core/ext/transport/chaotic_good/frame_header.h
@ -7658,6 +7659,7 @@ targets:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
src:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto
- src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
- src/core/ext/transport/chaotic_good/client_transport.cc
- src/core/ext/transport/chaotic_good/frame.cc
- src/core/ext/transport/chaotic_good/frame_header.cc
@ -7666,6 +7668,7 @@ targets:
- test/core/transport/chaotic_good/client_transport_error_test.cc
deps:
- gtest
- grpc_unsecure
- protobuf
- grpc_test_util
uses_polling: false
@ -7674,24 +7677,29 @@ targets:
build: test
language: c++
headers:
- src/core/ext/transport/chaotic_good/chaotic_good_transport.h
- src/core/ext/transport/chaotic_good/client_transport.h
- src/core/ext/transport/chaotic_good/frame.h
- src/core/ext/transport/chaotic_good/frame_header.h
- src/core/lib/promise/event_engine_wakeup_scheduler.h
- src/core/lib/promise/inter_activity_pipe.h
- src/core/lib/promise/join.h
- src/core/lib/promise/mpsc.h
- src/core/lib/promise/wait_set.h
- src/core/lib/transport/promise_endpoint.h
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
- test/core/transport/chaotic_good/mock_promise_endpoint.h
- test/core/transport/chaotic_good/transport_test.h
src:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto
- src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
- src/core/ext/transport/chaotic_good/client_transport.cc
- src/core/ext/transport/chaotic_good/frame.cc
- src/core/ext/transport/chaotic_good/frame_header.cc
- src/core/lib/transport/promise_endpoint.cc
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
- test/core/transport/chaotic_good/client_transport_test.cc
- test/core/transport/chaotic_good/mock_promise_endpoint.cc
- test/core/transport/chaotic_good/transport_test.cc
deps:
- gtest
- protobuf
@ -15568,6 +15576,40 @@ targets:
deps:
- gtest
- grpc_test_util
- name: server_transport_test
gtest: true
build: test
language: c++
headers:
- src/core/ext/transport/chaotic_good/chaotic_good_transport.h
- src/core/ext/transport/chaotic_good/frame.h
- src/core/ext/transport/chaotic_good/frame_header.h
- src/core/ext/transport/chaotic_good/server_transport.h
- src/core/lib/promise/event_engine_wakeup_scheduler.h
- src/core/lib/promise/inter_activity_pipe.h
- src/core/lib/promise/mpsc.h
- src/core/lib/promise/switch.h
- src/core/lib/promise/wait_set.h
- src/core/lib/transport/promise_endpoint.h
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
- test/core/transport/chaotic_good/mock_promise_endpoint.h
- test/core/transport/chaotic_good/transport_test.h
src:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto
- src/core/ext/transport/chaotic_good/chaotic_good_transport.cc
- src/core/ext/transport/chaotic_good/frame.cc
- src/core/ext/transport/chaotic_good/frame_header.cc
- src/core/ext/transport/chaotic_good/server_transport.cc
- src/core/lib/transport/promise_endpoint.cc
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
- test/core/transport/chaotic_good/mock_promise_endpoint.cc
- test/core/transport/chaotic_good/server_transport_test.cc
- test/core/transport/chaotic_good/transport_test.cc
deps:
- gtest
- protobuf
- grpc_test_util
uses_polling: false
- name: service_config_end2end_test
gtest: true
build: test

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: helloworld.proto
# Protobuf Python Version: 4.25.0
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,18 +14,18 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2\xe4\x01\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x12K\n\x13SayHelloStreamReply\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x30\x01\x12L\n\x12SayHelloBidiStream\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00(\x01\x30\x01\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'
_HELLOREQUEST._serialized_start=32
_HELLOREQUEST._serialized_end=60
_HELLOREPLY._serialized_start=62
_HELLOREPLY._serialized_end=91
_GREETER._serialized_start=93
_GREETER._serialized_end=166
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'
_globals['_HELLOREQUEST']._serialized_start=32
_globals['_HELLOREQUEST']._serialized_end=60
_globals['_HELLOREPLY']._serialized_start=62
_globals['_HELLOREPLY']._serialized_end=91
_globals['_GREETER']._serialized_start=94
_globals['_GREETER']._serialized_end=322
# @@protoc_insertion_point(module_scope)

@ -4,14 +4,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor
class HelloReply(_message.Message):
__slots__ = ["message"]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...
class HelloRequest(_message.Message):
__slots__ = ["name"]
__slots__ = ("name",)
NAME_FIELD_NUMBER: _ClassVar[int]
name: str
def __init__(self, name: _Optional[str] = ...) -> None: ...
class HelloReply(_message.Message):
__slots__ = ("message",)
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...

@ -19,7 +19,17 @@ class GreeterStub(object):
'/helloworld.Greeter/SayHello',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
)
_registered_method=True)
self.SayHelloStreamReply = channel.unary_stream(
'/helloworld.Greeter/SayHelloStreamReply',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
_registered_method=True)
self.SayHelloBidiStream = channel.stream_stream(
'/helloworld.Greeter/SayHelloBidiStream',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
_registered_method=True)
class GreeterServicer(object):
@ -33,6 +43,18 @@ class GreeterServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SayHelloStreamReply(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SayHelloBidiStream(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_GreeterServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -41,6 +63,16 @@ def add_GreeterServicer_to_server(servicer, server):
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
'SayHelloStreamReply': grpc.unary_stream_rpc_method_handler(
servicer.SayHelloStreamReply,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
'SayHelloBidiStream': grpc.stream_stream_rpc_method_handler(
servicer.SayHelloBidiStream,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'helloworld.Greeter', rpc_method_handlers)
@ -63,8 +95,72 @@ class Greeter(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/helloworld.Greeter/SayHello',
return grpc.experimental.unary_unary(
request,
target,
'/helloworld.Greeter/SayHello',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SayHelloStreamReply(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/helloworld.Greeter/SayHelloStreamReply',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SayHelloBidiStream(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/helloworld.Greeter/SayHelloBidiStream',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@ -60,6 +60,18 @@ Result& SliceCast(T& value, SliceCastable<Result, T> = {}) {
return reinterpret_cast<Result&>(value);
}
// Cast to `Result&&` from `T&&` without any runtime checks.
// Rvalue-reference counterpart of the `T&`/`const T&` overloads above, so a
// temporary (or moved-from) value can be cast without first binding an lvalue.
// This is only valid if `sizeof(Result) == sizeof(T)`, and if `Result`, `T` are
// opted in as compatible via `SliceCastable`.
template <typename Result, typename T>
Result&& SliceCast(T&& value, SliceCastable<Result, T> = {}) {
  // Insist upon sizes being equal to catch mismatches.
  // We assume if sizes are opted in and sizes are equal then yes, these two
  // types are expected to be layout compatible and actually appear to be.
  static_assert(sizeof(Result) == sizeof(T), "size mismatch");
  return reinterpret_cast<Result&&>(value);
}
} // namespace internal
} // namespace experimental
} // namespace grpc_event_engine

@ -169,6 +169,11 @@ struct CopyConstructors {
return Out(grpc_slice_from_copied_buffer(p, len));
}
// Byte-pointer convenience overload: copies `len` bytes starting at `p` into
// a newly allocated slice. Forwards to the `const char*` copy path.
static Out FromCopiedBuffer(const uint8_t* p, size_t len) {
  const char* chars = reinterpret_cast<const char*>(p);
  return Out(grpc_slice_from_copied_buffer(chars, len));
}
template <typename Buffer>
static Out FromCopiedBuffer(const Buffer& buffer) {
return FromCopiedBuffer(reinterpret_cast<const char*>(buffer.data()),

@ -467,7 +467,7 @@ bool PrivateGenerator::PrintStub(
out->Print(
method_dict,
"response_deserializer=$ResponseModuleAndClass$.FromString,\n");
out->Print(")\n");
out->Print("_registered_method=True)\n");
}
}
}
@ -642,22 +642,27 @@ bool PrivateGenerator::PrintServiceClass(
args_dict["ArityMethodName"] = arity_method_name;
args_dict["PackageQualifiedService"] = package_qualified_service_name;
args_dict["Method"] = method->name();
out->Print(args_dict,
"return "
"grpc.experimental.$ArityMethodName$($RequestParameter$, "
"target, '/$PackageQualifiedService$/$Method$',\n");
out->Print(args_dict, "return grpc.experimental.$ArityMethodName$(\n");
{
IndentScope continuation_indent(out);
StringMap serializer_dict;
out->Print(args_dict, "$RequestParameter$,\n");
out->Print("target,\n");
out->Print(args_dict, "'/$PackageQualifiedService$/$Method$',\n");
serializer_dict["RequestModuleAndClass"] = request_module_and_class;
serializer_dict["ResponseModuleAndClass"] = response_module_and_class;
out->Print(serializer_dict,
"$RequestModuleAndClass$.SerializeToString,\n");
out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n");
out->Print("options, channel_credentials,\n");
out->Print(
"insecure, call_credentials, compression, wait_for_ready, "
"timeout, metadata)\n");
out->Print("options,\n");
out->Print("channel_credentials,\n");
out->Print("insecure,\n");
out->Print("call_credentials,\n");
out->Print("compression,\n");
out->Print("wait_for_ready,\n");
out->Print("timeout,\n");
out->Print("metadata,\n");
out->Print("_registered_method=True)\n");
}
}
}

@ -6213,6 +6213,7 @@ grpc_cc_library(
"bitset",
"chaotic_good_frame_header",
"context",
"match",
"no_destruct",
"slice",
"slice_buffer",
@ -6368,6 +6369,29 @@ grpc_cc_library(
],
)
grpc_cc_library(
name = "chaotic_good_transport",
srcs = [
"ext/transport/chaotic_good/chaotic_good_transport.cc",
],
hdrs = [
"ext/transport/chaotic_good/chaotic_good_transport.h",
],
external_deps = ["absl/random"],
language = "c++",
deps = [
"chaotic_good_frame",
"chaotic_good_frame_header",
"grpc_promise_endpoint",
"if",
"try_join",
"try_seq",
"//:gpr_platform",
"//:hpack_encoder",
"//:promise",
],
)
grpc_cc_library(
name = "chaotic_good_client_transport",
srcs = [
@ -6378,6 +6402,7 @@ grpc_cc_library(
],
external_deps = [
"absl/base:core_headers",
"absl/container:flat_hash_map",
"absl/random",
"absl/random:bit_gen_ref",
"absl/status",
@ -6388,9 +6413,11 @@ grpc_cc_library(
language = "c++",
deps = [
"activity",
"all_ok",
"arena",
"chaotic_good_frame",
"chaotic_good_frame_header",
"chaotic_good_transport",
"context",
"event_engine_wakeup_scheduler",
"for_each",
@ -6398,6 +6425,7 @@ grpc_cc_library(
"if",
"inter_activity_pipe",
"loop",
"map",
"match",
"memory_quota",
"mpsc",
@ -6414,6 +6442,63 @@ grpc_cc_library(
"//:grpc_base",
"//:hpack_encoder",
"//:hpack_parser",
"//:promise",
"//:ref_counted_ptr",
],
)
grpc_cc_library(
name = "chaotic_good_server_transport",
srcs = [
"ext/transport/chaotic_good/server_transport.cc",
],
hdrs = [
"ext/transport/chaotic_good/server_transport.h",
],
external_deps = [
"absl/base:core_headers",
"absl/container:flat_hash_map",
"absl/functional:any_invocable",
"absl/random",
"absl/random:bit_gen_ref",
"absl/status",
"absl/status:statusor",
"absl/types:optional",
"absl/types:variant",
],
language = "c++",
deps = [
"1999",
"activity",
"arena",
"chaotic_good_frame",
"chaotic_good_frame_header",
"chaotic_good_transport",
"context",
"default_event_engine",
"event_engine_wakeup_scheduler",
"for_each",
"grpc_promise_endpoint",
"if",
"inter_activity_pipe",
"loop",
"memory_quota",
"mpsc",
"pipe",
"poll",
"resource_quota",
"seq",
"slice",
"slice_buffer",
"switch",
"try_join",
"try_seq",
"//:exec_ctx",
"//:gpr",
"//:gpr_platform",
"//:grpc_base",
"//:hpack_encoder",
"//:hpack_parser",
"//:ref_counted_ptr",
],
)

@ -0,0 +1,19 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <grpc/support/port_platform.h>
#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h"
namespace grpc_core {} // namespace grpc_core

@ -0,0 +1,111 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H
#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H
#include <grpc/support/port_platform.h>
#include "absl/random/random.h"
#include "src/core/ext/transport/chaotic_good/frame.h"
#include "src/core/ext/transport/chaotic_good/frame_header.h"
#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h"
#include "src/core/lib/promise/if.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/try_join.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/transport/promise_endpoint.h"
namespace grpc_core {
namespace chaotic_good {
// Shared frame I/O for the chaotic-good transport: owns the pair of promise
// endpoints (control stream for frame headers/metadata, data stream for
// message payloads) plus the hpack encoder/parser state shared across frames.
// Not thread-safe; intended to be driven from a single activity.
class ChaoticGoodTransport {
 public:
  // Takes exclusive ownership of both endpoints.
  ChaoticGoodTransport(std::unique_ptr<PromiseEndpoint> control_endpoint,
                       std::unique_ptr<PromiseEndpoint> data_endpoint)
      : control_endpoint_(std::move(control_endpoint)),
        data_endpoint_(std::move(data_endpoint)) {}

  // Serialize `frame` using the shared hpack encoder and write its control
  // and data portions to their respective endpoints.
  // Resolves once both writes complete; fails if either write fails.
  auto WriteFrame(const FrameInterface& frame) {
    auto buffers = frame.Serialize(&encoder_);
    return TryJoin<absl::StatusOr>(
        control_endpoint_->Write(std::move(buffers.control)),
        data_endpoint_->Write(std::move(buffers.data)));
  }

  // Read frame header and payloads for control and data portions of one frame.
  // Resolves to StatusOr<tuple<FrameHeader, BufferPair>>.
  auto ReadFrameBytes() {
    return TrySeq(
        control_endpoint_->ReadSlice(FrameHeader::frame_header_size_),
        [this](Slice read_buffer) {
          auto frame_header =
              FrameHeader::Parse(reinterpret_cast<const uint8_t*>(
                  GRPC_SLICE_START_PTR(read_buffer.c_slice())));
          // Read header and trailers from control endpoint.
          // Read message padding and message from data endpoint.
          return If(
              frame_header.ok(),
              [this, &frame_header] {
                // The padding consumed before this frame's message is the value
                // announced by the *previous* header: std::exchange returns the
                // stored value and saves this header's padding for the next
                // read.
                const uint32_t message_padding = std::exchange(
                    last_message_padding_, frame_header->message_padding);
                const uint32_t message_length = frame_header->message_length;
                return Map(
                    TryJoin<absl::StatusOr>(
                        control_endpoint_->Read(frame_header->GetFrameLength()),
                        TrySeq(data_endpoint_->Read(message_padding),
                               [this, message_length]() {
                                 return data_endpoint_->Read(message_length);
                               })),
                    // Repackage the two SliceBuffers as a BufferPair alongside
                    // the parsed header.
                    [frame_header = *frame_header](
                        absl::StatusOr<std::tuple<SliceBuffer, SliceBuffer>>
                            buffers)
                        -> absl::StatusOr<std::tuple<FrameHeader, BufferPair>> {
                      if (!buffers.ok()) return buffers.status();
                      return std::tuple<FrameHeader, BufferPair>(
                          frame_header,
                          BufferPair{std::move(std::get<0>(*buffers)),
                                     std::move(std::get<1>(*buffers))});
                    });
              },
              // Header failed to parse: surface the error without reading
              // from the data endpoint.
              [&frame_header]()
                  -> absl::StatusOr<std::tuple<FrameHeader, BufferPair>> {
                return frame_header.status();
              });
        });
  }

  // Deserialize one frame from the header/buffers produced by ReadFrameBytes(),
  // passing the transport's shared hpack parser (its state carries across
  // frames) and `bitgen_` through to FrameInterface::Deserialize.
  absl::Status DeserializeFrame(FrameHeader header, BufferPair buffers,
                                Arena* arena, FrameInterface& frame) {
    return frame.Deserialize(&parser_, header, bitgen_, arena,
                             std::move(buffers));
  }

  // Skip a frame, but correctly handle any hpack state updates.
  void SkipFrame(FrameHeader, BufferPair) { Crash("not implemented"); }

 private:
  const std::unique_ptr<PromiseEndpoint> control_endpoint_;
  const std::unique_ptr<PromiseEndpoint> data_endpoint_;
  // Padding value announced by the most recently parsed frame header; consumed
  // by the *next* ReadFrameBytes() (see std::exchange above).
  uint32_t last_message_padding_ = 0;
  HPackCompressor encoder_;
  HPackParser parser_;
  absl::BitGen bitgen_;
};
} // namespace chaotic_good
} // namespace grpc_core
#endif // GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H

@ -17,9 +17,11 @@
#include "src/core/ext/transport/chaotic_good/client_transport.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
@ -36,9 +38,13 @@
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/all_ok.h"
#include "src/core/lib/promise/event_engine_wakeup_scheduler.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/try_join.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice.h"
@ -49,59 +55,15 @@
namespace grpc_core {
namespace chaotic_good {
ClientTransport::ClientTransport(
std::unique_ptr<PromiseEndpoint> control_endpoint,
std::unique_ptr<PromiseEndpoint> data_endpoint,
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine)
: outgoing_frames_(MpscReceiver<ClientFrame>(4)),
control_endpoint_(std::move(control_endpoint)),
data_endpoint_(std::move(data_endpoint)),
control_endpoint_write_buffer_(SliceBuffer()),
data_endpoint_write_buffer_(SliceBuffer()),
hpack_compressor_(std::make_unique<HPackCompressor>()),
hpack_parser_(std::make_unique<HPackParser>()),
memory_allocator_(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
"client_transport")),
arena_(MakeScopedArena(1024, &memory_allocator_)),
context_(arena_.get()),
event_engine_(event_engine) {
auto write_loop = Loop([this] {
auto ChaoticGoodClientTransport::TransportWriteLoop() {
return Loop([this] {
return TrySeq(
// Get next outgoing frame.
this->outgoing_frames_.Next(),
// Construct data buffers that will be sent to the endpoints.
outgoing_frames_.Next(),
// Serialize and write it out.
[this](ClientFrame client_frame) {
MatchMutable(
&client_frame,
[this](ClientFragmentFrame* frame) mutable {
control_endpoint_write_buffer_.Append(
frame->Serialize(hpack_compressor_.get()));
if (frame->message != nullptr) {
std::string message_padding(frame->message_padding, '0');
Slice slice(grpc_slice_from_cpp_string(message_padding));
// Append message padding to data_endpoint_buffer.
data_endpoint_write_buffer_.Append(std::move(slice));
// Append message payload to data_endpoint_buffer.
frame->message->payload()->MoveFirstNBytesIntoSliceBuffer(
frame->message->payload()->Length(),
data_endpoint_write_buffer_);
}
},
[this](CancelFrame* frame) mutable {
control_endpoint_write_buffer_.Append(
frame->Serialize(hpack_compressor_.get()));
});
return absl::OkStatus();
return transport_.WriteFrame(GetFrameInterface(client_frame));
},
// Write buffers to corresponding endpoints concurrently.
[this]() {
return TryJoin<absl::StatusOr>(
control_endpoint_->Write(
std::move(control_endpoint_write_buffer_)),
data_endpoint_->Write(std::move(data_endpoint_write_buffer_)));
},
// Finish writes to difference endpoints and continue the loop.
[]() -> LoopCtl<absl::Status> {
// The write failures will be caught in TrySeq and exit loop.
// Therefore, only need to return Continue() in the last lambda
@ -109,78 +71,215 @@ ClientTransport::ClientTransport(
return Continue();
});
});
writer_ = MakeActivity(
// Continuously write next outgoing frames to promise endpoints.
std::move(write_loop), EventEngineWakeupScheduler(event_engine_),
[this](absl::Status status) {
if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) {
this->AbortWithError();
}
}
absl::optional<CallHandler> ChaoticGoodClientTransport::LookupStream(
    uint32_t stream_id) {
  // Returns the call handler registered for `stream_id`, or nullopt if the
  // stream is unknown (already finished, or never created). Holds mu_ only
  // for the duration of the map lookup.
  MutexLock lock(&mu_);
  const auto entry = stream_map_.find(stream_id);
  if (entry != stream_map_.end()) return entry->second;
  return absl::nullopt;
}
auto ChaoticGoodClientTransport::PushFrameIntoCall(ServerFragmentFrame frame,
CallHandler call_handler) {
auto& headers = frame.headers;
return TrySeq(
If(
headers != nullptr,
[call_handler, &headers]() mutable {
return call_handler.PushServerInitialMetadata(std::move(headers));
},
[]() -> StatusFlag { return Success{}; }),
[call_handler, message = std::move(frame.message)]() mutable {
return If(
message.has_value(),
[&call_handler, &message]() mutable {
return call_handler.PushMessage(std::move(message->message));
},
[]() -> StatusFlag { return Success{}; });
},
// Hold Arena in activity for GetContext<Arena> usage.
arena_.get());
auto read_loop = Loop([this] {
[call_handler, trailers = std::move(frame.trailers)]() mutable {
return If(
trailers != nullptr,
[&call_handler, &trailers]() mutable {
return call_handler.PushServerTrailingMetadata(
std::move(trailers));
},
[]() -> StatusFlag { return Success{}; });
});
}
auto ChaoticGoodClientTransport::TransportReadLoop() {
return Loop([this] {
return TrySeq(
// Read frame header from control endpoint.
// TODO(ladynana): remove memcpy in ReadSlice.
this->control_endpoint_->ReadSlice(FrameHeader::frame_header_size_),
// Read different parts of the server frame from control/data endpoints
// based on frame header.
[this](Slice read_buffer) mutable {
frame_header_ = std::make_shared<FrameHeader>(
FrameHeader::Parse(
reinterpret_cast<const uint8_t*>(
GRPC_SLICE_START_PTR(read_buffer.c_slice())))
.value());
// Read header and trailers from control endpoint.
// Read message padding and message from data endpoint.
return TryJoin<absl::StatusOr>(
control_endpoint_->Read(frame_header_->GetFrameLength()),
data_endpoint_->Read(frame_header_->message_padding +
frame_header_->message_length));
transport_.ReadFrameBytes(),
[](std::tuple<FrameHeader, BufferPair> frame_bytes)
-> absl::StatusOr<std::tuple<FrameHeader, BufferPair>> {
const auto& frame_header = std::get<0>(frame_bytes);
if (frame_header.type != FrameType::kFragment) {
return absl::InternalError(
absl::StrCat("Expected fragment frame, got ",
static_cast<int>(frame_header.type)));
}
return frame_bytes;
},
// Construct and send the server frame to corresponding stream.
[this](std::tuple<SliceBuffer, SliceBuffer> ret) mutable {
control_endpoint_read_buffer_ = std::move(std::get<0>(ret));
// Discard message padding and only keep message in data read buffer.
std::get<1>(ret).MoveLastNBytesIntoSliceBuffer(
frame_header_->message_length, data_endpoint_read_buffer_);
[this](std::tuple<FrameHeader, BufferPair> frame_bytes) {
const auto& frame_header = std::get<0>(frame_bytes);
auto& buffers = std::get<1>(frame_bytes);
absl::optional<CallHandler> call_handler =
LookupStream(frame_header.stream_id);
ServerFragmentFrame frame;
// Initialized to get this_cpu() info in global_stat().
ExecCtx exec_ctx;
// Deserialize frame from read buffer.
absl::BitGen bitgen;
auto status = frame.Deserialize(hpack_parser_.get(), *frame_header_,
absl::BitGenRef(bitgen),
control_endpoint_read_buffer_);
GPR_ASSERT(status.ok());
// Move message into frame.
frame.message = arena_->MakePooled<Message>(
std::move(data_endpoint_read_buffer_), 0);
MutexLock lock(&mu_);
const uint32_t stream_id = frame_header_->stream_id;
return stream_map_[stream_id]->Push(ServerFrame(std::move(frame)));
},
// Check if send frame to corresponding stream successfully.
[](bool ret) -> LoopCtl<absl::Status> {
if (ret) {
// Send incoming frames successfully.
return Continue();
absl::Status deserialize_status;
if (call_handler.has_value()) {
deserialize_status = transport_.DeserializeFrame(
frame_header, std::move(buffers), call_handler->arena(), frame);
} else {
return absl::InternalError("Send incoming frames failed.");
// Stream not found, skip the frame.
transport_.SkipFrame(frame_header, std::move(buffers));
deserialize_status = absl::OkStatus();
}
});
return If(
deserialize_status.ok() && call_handler.has_value(),
[this, &frame, &call_handler]() {
return call_handler->SpawnWaitable(
"push-frame", [this, call_handler = *call_handler,
frame = std::move(frame)]() mutable {
return Map(call_handler.CancelIfFails(PushFrameIntoCall(
std::move(frame), call_handler)),
[](StatusFlag f) {
return StatusCast<absl::Status>(f);
});
});
},
[&deserialize_status]() -> absl::Status {
// Stream not found, nothing to do.
return std::move(deserialize_status);
});
},
[]() -> LoopCtl<absl::Status> { return Continue{}; });
});
reader_ = MakeActivity(
// Continuously read next incoming frames from promise endpoints.
std::move(read_loop), EventEngineWakeupScheduler(event_engine_),
[this](absl::Status status) {
if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) {
this->AbortWithError();
}
}
auto ChaoticGoodClientTransport::OnTransportActivityDone() {
return [this](absl::Status status) {
if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) {
this->AbortWithError();
}
};
}
// Constructs the client transport: wraps the two endpoints in a
// ChaoticGoodTransport and spawns the writer and reader activities, both
// scheduled on `event_engine` and sharing the same failure handler.
ChaoticGoodClientTransport::ChaoticGoodClientTransport(
    std::unique_ptr<PromiseEndpoint> control_endpoint,
    std::unique_ptr<PromiseEndpoint> data_endpoint,
    std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine)
    // Outgoing frame queue is bounded at 4 entries (see outgoing_frames_).
    : outgoing_frames_(4),
      transport_(std::move(control_endpoint), std::move(data_endpoint)),
      writer_{
          MakeActivity(
              // Continuously write next outgoing frames to promise endpoints.
              TransportWriteLoop(), EventEngineWakeupScheduler(event_engine),
              OnTransportActivityDone()),
      },
      reader_{MakeActivity(
          // Continuously read next incoming frames from promise endpoints.
          TransportReadLoop(), EventEngineWakeupScheduler(event_engine),
          OnTransportActivityDone())} {}
// Destroys the transport, shutting down the writer activity before the
// reader (preserving the original teardown order).
ChaoticGoodClientTransport::~ChaoticGoodClientTransport() {
  // reset() on an empty ActivityPtr is a no-op, so the previous
  // `!= nullptr` guards were redundant and have been dropped.
  writer_.reset();
  reader_.reset();
}
// Tears the transport down after an endpoint read/write failure: closes the
// outgoing frame queue and cancels every in-flight call.
void ChaoticGoodClientTransport::AbortWithError() {
  // Mark transport as unavailable when the endpoint write/read failed.
  // Close all the available pipes.
  outgoing_frames_.MarkClosed();
  ReleasableMutexLock lock(&mu_);
  // Snapshot and empty the stream map under the lock, then release the lock
  // before spawning cancellations so the Cancel callbacks (which may run
  // OnDone handlers that re-acquire mu_) cannot deadlock.
  StreamMap stream_map = std::move(stream_map_);
  stream_map_.clear();
  lock.Release();
  for (const auto& pair : stream_map) {
    auto call_handler = pair.second;
    call_handler.SpawnInfallible("cancel", [call_handler]() mutable {
      call_handler.Cancel(ServerMetadataFromStatus(
          absl::UnavailableError("Transport closed.")));
      return Empty{};
    });
  }
}
// Allocates the next stream id for `call_handler` and registers it in the
// stream map. Returns the new id.
uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) {
  ReleasableMutexLock lock(&mu_);
  const uint32_t stream_id = next_stream_id_++;
  stream_map_.emplace(stream_id, call_handler);
  lock.Release();
  // OnDone is installed after releasing mu_ because its callback re-acquires
  // mu_ to erase the map entry when the call completes.
  call_handler.OnDone([this, stream_id]() {
    MutexLock lock(&mu_);
    stream_map_.erase(stream_id);
  });
  return stream_id;
}
auto ChaoticGoodClientTransport::CallOutboundLoop(uint32_t stream_id,
CallHandler call_handler) {
auto send_fragment = [stream_id,
outgoing_frames = outgoing_frames_.MakeSender()](
ClientFragmentFrame frame) mutable {
frame.stream_id = stream_id;
return Map(outgoing_frames.Send(std::move(frame)),
[](bool success) -> absl::Status {
if (!success) {
// Failed to send outgoing frame.
return absl::UnavailableError("Transport closed.");
}
return absl::OkStatus();
});
};
return TrySeq(
// Wait for initial metadata then send it out.
call_handler.PullClientInitialMetadata(),
[send_fragment](ClientMetadataHandle md) mutable {
ClientFragmentFrame frame;
frame.headers = std::move(md);
return send_fragment(std::move(frame));
},
// Hold Arena in activity for GetContext<Arena> usage.
arena_.get());
// Continuously send client frame with client to server messages.
ForEach(OutgoingMessages(call_handler),
[send_fragment,
aligned_bytes = aligned_bytes_](MessageHandle message) mutable {
ClientFragmentFrame frame;
// Construct frame header (flags, header_length and
// trailer_length will be added in serialization).
const uint32_t message_length = message->payload()->Length();
const uint32_t padding =
message_length % aligned_bytes == 0
? 0
: aligned_bytes - message_length % aligned_bytes;
GPR_ASSERT((message_length + padding) % aligned_bytes == 0);
frame.message = FragmentMessage(std::move(message), padding,
message_length);
return send_fragment(std::move(frame));
}),
[send_fragment]() mutable {
ClientFragmentFrame frame;
frame.end_of_stream = true;
return send_fragment(std::move(frame));
});
}
void ChaoticGoodClientTransport::StartCall(CallHandler call_handler) {
  // By the time StartCall is invoked the connection is already established,
  // so allocate a stream id and immediately begin the outbound send loop.
  auto outbound_loop = [this, call_handler]() mutable {
    const uint32_t stream_id = MakeStream(call_handler);
    return CallOutboundLoop(stream_id, call_handler);
  };
  call_handler.SpawnGuarded("outbound_loop", std::move(outbound_loop));
}
} // namespace chaotic_good

@ -20,6 +20,7 @@
#include <stdint.h>
#include <stdio.h>
#include <cstdint>
#include <initializer_list> // IWYU pragma: keep
#include <map>
#include <memory>
@ -28,6 +29,8 @@
#include <utility>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
@ -35,6 +38,7 @@
#include <grpc/event_engine/event_engine.h>
#include <grpc/event_engine/memory_allocator.h>
#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h"
#include "src/core/ext/transport/chaotic_good/frame.h"
#include "src/core/ext/transport/chaotic_good/frame_header.h"
#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h"
@ -61,178 +65,56 @@
namespace grpc_core {
namespace chaotic_good {
class ClientTransport {
class ChaoticGoodClientTransport final : public Transport,
public ClientTransport {
public:
ClientTransport(std::unique_ptr<PromiseEndpoint> control_endpoint,
std::unique_ptr<PromiseEndpoint> data_endpoint,
std::shared_ptr<grpc_event_engine::experimental::EventEngine>
event_engine);
~ClientTransport() {
if (writer_ != nullptr) {
writer_.reset();
}
if (reader_ != nullptr) {
reader_.reset();
}
}
void AbortWithError() {
// Mark transport as unavailable when the endpoint write/read failed.
// Close all the available pipes.
if (!outgoing_frames_.IsClosed()) {
outgoing_frames_.MarkClosed();
}
MutexLock lock(&mu_);
for (const auto& pair : stream_map_) {
if (!pair.second->IsClose()) {
pair.second->MarkClose();
}
}
}
auto AddStream(CallArgs call_args) {
// At this point, the connection is set up.
// Start sending data frames.
uint32_t stream_id;
InterActivityPipe<ServerFrame, server_frame_queue_size_> pipe_server_frames;
{
MutexLock lock(&mu_);
stream_id = next_stream_id_++;
stream_map_.insert(
std::pair<uint32_t,
std::shared_ptr<InterActivityPipe<
ServerFrame, server_frame_queue_size_>::Sender>>(
stream_id, std::make_shared<InterActivityPipe<
ServerFrame, server_frame_queue_size_>::Sender>(
std::move(pipe_server_frames.sender))));
}
return TrySeq(
TryJoin<absl::StatusOr>(
// Continuously send client frame with client to server messages.
ForEach(std::move(*call_args.client_to_server_messages),
[stream_id, initial_frame = true,
client_initial_metadata =
std::move(call_args.client_initial_metadata),
outgoing_frames = outgoing_frames_.MakeSender(),
this](MessageHandle result) mutable {
ClientFragmentFrame frame;
// Construct frame header (flags, header_length and
// trailer_length will be added in serialization).
uint32_t message_length = result->payload()->Length();
frame.stream_id = stream_id;
frame.message_padding = message_length % aligned_bytes;
frame.message = std::move(result);
if (initial_frame) {
// Send initial frame with client intial metadata.
frame.headers = std::move(client_initial_metadata);
initial_frame = false;
}
return TrySeq(
outgoing_frames.Send(ClientFrame(std::move(frame))),
[](bool success) -> absl::Status {
if (!success) {
// TODO(ladynana): propagate the actual error
// message from EventEngine.
return absl::UnavailableError(
"Transport closed due to endpoint write/read "
"failed.");
}
return absl::OkStatus();
});
}),
// Continuously receive server frames from endpoints and save
// results to call_args.
Loop([server_initial_metadata = call_args.server_initial_metadata,
server_to_client_messages =
call_args.server_to_client_messages,
receiver = std::move(pipe_server_frames.receiver)]() mutable {
return TrySeq(
// Receive incoming server frame.
receiver.Next(),
// Save incomming frame results to call_args.
[server_initial_metadata, server_to_client_messages](
absl::optional<ServerFrame> server_frame) mutable {
bool transport_closed = false;
ServerFragmentFrame frame;
if (!server_frame.has_value()) {
// Incoming server frame pipe is closed, which only
// happens when transport is aborted.
transport_closed = true;
} else {
frame = std::move(
absl::get<ServerFragmentFrame>(*server_frame));
};
bool has_headers = (frame.headers != nullptr);
bool has_message = (frame.message != nullptr);
bool has_trailers = (frame.trailers != nullptr);
return TrySeq(
If((!transport_closed) && has_headers,
[server_initial_metadata,
headers = std::move(frame.headers)]() mutable {
return server_initial_metadata->Push(
std::move(headers));
},
[] { return false; }),
If((!transport_closed) && has_message,
[server_to_client_messages,
message = std::move(frame.message)]() mutable {
return server_to_client_messages->Push(
std::move(message));
},
[] { return false; }),
If((!transport_closed) && has_trailers,
[trailers = std::move(frame.trailers)]() mutable
-> LoopCtl<ServerMetadataHandle> {
return std::move(trailers);
},
[transport_closed]()
-> LoopCtl<ServerMetadataHandle> {
if (transport_closed) {
// TODO(ladynana): propagate the actual error
// message from EventEngine.
return ServerMetadataFromStatus(
absl::UnavailableError(
"Transport closed due to endpoint "
"write/read failed."));
}
return Continue();
}));
});
})),
[](std::tuple<Empty, ServerMetadataHandle> ret) {
return std::move(std::get<1>(ret));
});
}
ChaoticGoodClientTransport(
std::unique_ptr<PromiseEndpoint> control_endpoint,
std::unique_ptr<PromiseEndpoint> data_endpoint,
std::shared_ptr<grpc_event_engine::experimental::EventEngine>
event_engine);
~ChaoticGoodClientTransport() override;
FilterStackTransport* filter_stack_transport() override { return nullptr; }
ClientTransport* client_transport() override { return this; }
ServerTransport* server_transport() override { return nullptr; }
absl::string_view GetTransportName() const override { return "chaotic_good"; }
void SetPollset(grpc_stream*, grpc_pollset*) override {}
void SetPollsetSet(grpc_stream*, grpc_pollset_set*) override {}
void PerformOp(grpc_transport_op*) override { Crash("unimplemented"); }
grpc_endpoint* GetEndpoint() override { return nullptr; }
void Orphan() override { delete this; }
void StartCall(CallHandler call_handler) override;
void AbortWithError();
private:
// Queue size of each stream pipe is set to 2, so that for each stream read it
// will queue at most 2 frames.
static const size_t kServerFrameQueueSize = 2;
using StreamMap = absl::flat_hash_map<uint32_t, CallHandler>;
uint32_t MakeStream(CallHandler call_handler);
absl::optional<CallHandler> LookupStream(uint32_t stream_id);
auto CallOutboundLoop(uint32_t stream_id, CallHandler call_handler);
auto OnTransportActivityDone();
auto TransportWriteLoop();
auto TransportReadLoop();
// Push one frame into a call
auto PushFrameIntoCall(ServerFragmentFrame frame, CallHandler call_handler);
// Max buffer is set to 4, so that for stream writes each time it will queue
// at most 2 frames.
MpscReceiver<ClientFrame> outgoing_frames_;
// Queue size of each stream pipe is set to 2, so that for each stream read it
// will queue at most 2 frames.
static const size_t server_frame_queue_size_ = 2;
ChaoticGoodTransport transport_;
// Assigned aligned bytes from setting frame.
size_t aligned_bytes = 64;
size_t aligned_bytes_ = 64;
Mutex mu_;
uint32_t next_stream_id_ ABSL_GUARDED_BY(mu_) = 1;
// Map of stream incoming server frames, key is stream_id.
std::map<uint32_t, std::shared_ptr<InterActivityPipe<
ServerFrame, server_frame_queue_size_>::Sender>>
stream_map_ ABSL_GUARDED_BY(mu_);
StreamMap stream_map_ ABSL_GUARDED_BY(mu_);
ActivityPtr writer_;
ActivityPtr reader_;
std::unique_ptr<PromiseEndpoint> control_endpoint_;
std::unique_ptr<PromiseEndpoint> data_endpoint_;
SliceBuffer control_endpoint_write_buffer_;
SliceBuffer data_endpoint_write_buffer_;
SliceBuffer control_endpoint_read_buffer_;
SliceBuffer data_endpoint_read_buffer_;
std::unique_ptr<HPackCompressor> hpack_compressor_;
std::unique_ptr<HPackParser> hpack_parser_;
std::shared_ptr<FrameHeader> frame_header_;
MemoryAllocator memory_allocator_;
ScopedArenaPtr arena_;
promise_detail::Context<Arena> context_;
// Use to synchronize writer_ and reader_ activity with outside activities;
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_;
};
} // namespace chaotic_good

@ -40,6 +40,10 @@
namespace grpc_core {
namespace chaotic_good {
namespace {
// Source of zero bytes used to pad message payloads out to the aligned
// length (see FrameSerializer::AddMessage).
const uint8_t kZeros[64] = {};
}
namespace {
const NoDestruct<Slice> kZeroSlice{[] {
// Frame header size is fixed to 24 bytes.
@ -50,53 +54,65 @@ const NoDestruct<Slice> kZeroSlice{[] {
class FrameSerializer {
public:
explicit FrameSerializer(FrameType frame_type, uint32_t stream_id,
uint32_t message_padding) {
output_.AppendIndexed(kZeroSlice->Copy());
explicit FrameSerializer(FrameType frame_type, uint32_t stream_id) {
output_.control.AppendIndexed(kZeroSlice->Copy());
header_.type = frame_type;
header_.stream_id = stream_id;
header_.message_padding = message_padding;
header_.flags.SetAll(false);
}
// If called, must be called before AddTrailers, Finish.
SliceBuffer& AddHeaders() {
header_.flags.set(0);
return output_;
return output_.control;
}
  // Records a message fragment: sets the message flag (flag bit 1), fills
  // in the header's length/padding fields, and copies the payload — plus
  // any zero padding — into the data-endpoint buffer.
  void AddMessage(const FragmentMessage& msg) {
    header_.flags.set(1);
    header_.message_length = msg.length;
    header_.message_padding = msg.padding;
    output_.data = msg.message->payload()->Copy();
    if (msg.padding != 0) {
      output_.data.Append(Slice::FromStaticBuffer(kZeros, msg.padding));
    }
  }
// If called, must be called before Finish.
SliceBuffer& AddTrailers() {
header_.flags.set(1);
header_.header_length = output_.Length() - FrameHeader::frame_header_size_;
return output_;
header_.flags.set(2);
header_.header_length =
output_.control.Length() - FrameHeader::frame_header_size_;
return output_.control;
}
SliceBuffer Finish() {
BufferPair Finish() {
// Calculate frame header_length or trailer_length if available.
if (header_.flags.is_set(1)) {
if (header_.flags.is_set(2)) {
// Header length is already known in AddTrailers().
header_.trailer_length = output_.Length() - header_.header_length -
header_.trailer_length = output_.control.Length() -
header_.header_length -
FrameHeader::frame_header_size_;
} else {
if (header_.flags.is_set(0)) {
// Calculate frame header length in Finish() since AddTrailers() isn't
// called.
header_.header_length =
output_.Length() - FrameHeader::frame_header_size_;
output_.control.Length() - FrameHeader::frame_header_size_;
}
}
header_.Serialize(
GRPC_SLICE_START_PTR(output_.c_slice_buffer()->slices[0]));
GRPC_SLICE_START_PTR(output_.control.c_slice_buffer()->slices[0]));
return std::move(output_);
}
private:
FrameHeader header_;
SliceBuffer output_;
BufferPair output_;
};
class FrameDeserializer {
public:
FrameDeserializer(const FrameHeader& header, SliceBuffer& input)
FrameDeserializer(const FrameHeader& header, BufferPair& input)
: header_(header), input_(input) {}
const FrameHeader& header() const { return header_; }
// If called, must be called before ReceiveTrailers, Finish.
@ -118,28 +134,27 @@ class FrameDeserializer {
private:
absl::StatusOr<SliceBuffer> Take(uint32_t length) {
if (length == 0) return SliceBuffer{};
if (input_.Length() < length) {
if (input_.control.Length() < length) {
return absl::InvalidArgumentError(
"Frame too short (insufficient payload)");
}
SliceBuffer out;
input_.MoveFirstNBytesIntoSliceBuffer(length, out);
input_.control.MoveFirstNBytesIntoSliceBuffer(length, out);
return std::move(out);
}
FrameHeader header_;
SliceBuffer& input_;
BufferPair& input_;
};
template <typename Metadata>
absl::StatusOr<Arena::PoolPtr<Metadata>> ReadMetadata(
HPackParser* parser, absl::StatusOr<SliceBuffer> maybe_slices,
uint32_t stream_id, bool is_header, bool is_client,
absl::BitGenRef bitsrc) {
uint32_t stream_id, bool is_header, bool is_client, absl::BitGenRef bitsrc,
Arena* arena) {
if (!maybe_slices.ok()) return maybe_slices.status();
auto& slices = *maybe_slices;
auto arena = GetContext<Arena>();
GPR_ASSERT(arena != nullptr);
Arena::PoolPtr<Metadata> metadata = arena->MakePooled<Metadata>(arena);
Arena::PoolPtr<Metadata> metadata = Arena::MakePooled<Metadata>(arena);
parser->BeginFrame(
metadata.get(), std::numeric_limits<uint32_t>::max(),
std::numeric_limits<uint32_t>::max(),
@ -161,20 +176,23 @@ absl::StatusOr<Arena::PoolPtr<Metadata>> ReadMetadata(
} // namespace
absl::Status SettingsFrame::Deserialize(HPackParser*, const FrameHeader& header,
absl::BitGenRef,
SliceBuffer& slice_buffer) {
absl::BitGenRef, Arena*,
BufferPair buffers) {
if (header.type != FrameType::kSettings) {
return absl::InvalidArgumentError("Expected settings frame");
}
if (header.flags.any()) {
return absl::InvalidArgumentError("Unexpected flags");
}
FrameDeserializer deserializer(header, slice_buffer);
if (buffers.data.Length() != 0) {
return absl::InvalidArgumentError("Unexpected data");
}
FrameDeserializer deserializer(header, buffers);
return deserializer.Finish();
}
SliceBuffer SettingsFrame::Serialize(HPackCompressor*) const {
FrameSerializer serializer(FrameType::kSettings, 0, 0);
BufferPair SettingsFrame::Serialize(HPackCompressor*) const {
FrameSerializer serializer(FrameType::kSettings, 0);
return serializer.Finish();
}
@ -183,19 +201,20 @@ std::string SettingsFrame::ToString() const { return "SettingsFrame{}"; }
absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser,
const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) {
Arena* arena,
BufferPair buffers) {
if (header.stream_id == 0) {
return absl::InvalidArgumentError("Expected non-zero stream id");
}
stream_id = header.stream_id;
message_padding = header.message_padding;
if (header.type != FrameType::kFragment) {
return absl::InvalidArgumentError("Expected fragment frame");
}
FrameDeserializer deserializer(header, slice_buffer);
FrameDeserializer deserializer(header, buffers);
if (header.flags.is_set(0)) {
auto r = ReadMetadata<ClientMetadata>(parser, deserializer.ReceiveHeaders(),
header.stream_id, true, true, bitsrc);
header.stream_id, true, true, bitsrc,
arena);
if (!r.ok()) return r.status();
if (r.value() != nullptr) {
headers = std::move(r.value());
@ -205,8 +224,17 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser,
"Unexpected non-zero header length", header.header_length));
}
if (header.flags.is_set(1)) {
message =
FragmentMessage{Arena::MakePooled<Message>(std::move(buffers.data), 0),
header.message_padding, header.message_length};
} else if (buffers.data.Length() != 0) {
return absl::InvalidArgumentError(absl::StrCat(
"Unexpected non-zero message length ", buffers.data.Length()));
}
if (header.flags.is_set(2)) {
if (header.trailer_length != 0) {
return absl::InvalidArgumentError("Unexpected trailer length");
return absl::InvalidArgumentError(
absl::StrCat("Unexpected trailer length ", header.trailer_length));
}
end_of_stream = true;
} else {
@ -215,42 +243,53 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser,
return deserializer.Finish();
}
SliceBuffer ClientFragmentFrame::Serialize(HPackCompressor* encoder) const {
BufferPair ClientFragmentFrame::Serialize(HPackCompressor* encoder) const {
GPR_ASSERT(stream_id != 0);
FrameSerializer serializer(FrameType::kFragment, stream_id, message_padding);
FrameSerializer serializer(FrameType::kFragment, stream_id);
if (headers.get() != nullptr) {
encoder->EncodeRawHeaders(*headers.get(), serializer.AddHeaders());
}
if (message.has_value()) {
serializer.AddMessage(message.value());
}
if (end_of_stream) {
serializer.AddTrailers();
}
return serializer.Finish();
}
std::string FragmentMessage::ToString() const {
  // Render this fragment for logging/debugging; the payload dump is only
  // included when a message handle is actually attached.
  std::string repr =
      absl::StrCat("FragmentMessage{length=", length, ", padding=", padding);
  if (message.get() != nullptr) {
    absl::StrAppend(&repr, ", message=", message->DebugString().c_str());
  }
  absl::StrAppend(&repr, "}");
  return repr;
}
std::string ClientFragmentFrame::ToString() const {
return absl::StrCat(
"ClientFragmentFrame{stream_id=", stream_id, ", headers=",
headers.get() != nullptr ? headers->DebugString().c_str() : "nullptr",
", message=",
message.get() != nullptr ? message->DebugString().c_str() : "nullptr",
", message_padding=", message_padding, ", end_of_stream=", end_of_stream,
"}");
", message=", message.has_value() ? message->ToString().c_str() : "none",
", end_of_stream=", end_of_stream, "}");
}
absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser,
const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) {
Arena* arena,
BufferPair buffers) {
if (header.stream_id == 0) {
return absl::InvalidArgumentError("Expected non-zero stream id");
}
stream_id = header.stream_id;
message_padding = header.message_padding;
FrameDeserializer deserializer(header, slice_buffer);
FrameDeserializer deserializer(header, buffers);
if (header.flags.is_set(0)) {
auto r =
ReadMetadata<ServerMetadata>(parser, deserializer.ReceiveHeaders(),
header.stream_id, true, false, bitsrc);
auto r = ReadMetadata<ServerMetadata>(parser, deserializer.ReceiveHeaders(),
header.stream_id, true, false, bitsrc,
arena);
if (!r.ok()) return r.status();
if (r.value() != nullptr) {
headers = std::move(r.value());
@ -260,9 +299,16 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser,
"Unexpected non-zero header length", header.header_length));
}
if (header.flags.is_set(1)) {
auto r =
ReadMetadata<ServerMetadata>(parser, deserializer.ReceiveTrailers(),
header.stream_id, false, false, bitsrc);
message.emplace(Arena::MakePooled<Message>(std::move(buffers.data), 0),
header.message_padding, header.message_length);
} else if (buffers.data.Length() != 0) {
return absl::InvalidArgumentError(absl::StrCat(
"Unexpected non-zero message length", buffers.data.Length()));
}
if (header.flags.is_set(2)) {
auto r = ReadMetadata<ServerMetadata>(
parser, deserializer.ReceiveTrailers(), header.stream_id, false, false,
bitsrc, arena);
if (!r.ok()) return r.status();
if (r.value() != nullptr) {
trailers = std::move(r.value());
@ -274,12 +320,15 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser,
return deserializer.Finish();
}
SliceBuffer ServerFragmentFrame::Serialize(HPackCompressor* encoder) const {
BufferPair ServerFragmentFrame::Serialize(HPackCompressor* encoder) const {
GPR_ASSERT(stream_id != 0);
FrameSerializer serializer(FrameType::kFragment, stream_id, message_padding);
FrameSerializer serializer(FrameType::kFragment, stream_id);
if (headers.get() != nullptr) {
encoder->EncodeRawHeaders(*headers.get(), serializer.AddHeaders());
}
if (message.has_value()) {
serializer.AddMessage(message.value());
}
if (trailers.get() != nullptr) {
encoder->EncodeRawHeaders(*trailers.get(), serializer.AddTrailers());
}
@ -290,16 +339,15 @@ std::string ServerFragmentFrame::ToString() const {
return absl::StrCat(
"ServerFragmentFrame{stream_id=", stream_id, ", headers=",
headers.get() != nullptr ? headers->DebugString().c_str() : "nullptr",
", message=",
message.get() != nullptr ? message->DebugString().c_str() : "nullptr",
", message_padding=", message_padding, ", trailers=",
", message=", message.has_value() ? message->ToString().c_str() : "none",
", trailers=",
trailers.get() != nullptr ? trailers->DebugString().c_str() : "nullptr",
"}");
}
absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header,
absl::BitGenRef,
SliceBuffer& slice_buffer) {
absl::BitGenRef, Arena*,
BufferPair buffers) {
if (header.type != FrameType::kCancel) {
return absl::InvalidArgumentError("Expected cancel frame");
}
@ -309,14 +357,17 @@ absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header,
if (header.stream_id == 0) {
return absl::InvalidArgumentError("Expected non-zero stream id");
}
FrameDeserializer deserializer(header, slice_buffer);
if (buffers.data.Length() != 0) {
return absl::InvalidArgumentError("Unexpected data");
}
FrameDeserializer deserializer(header, buffers);
stream_id = header.stream_id;
return deserializer.Finish();
}
SliceBuffer CancelFrame::Serialize(HPackCompressor*) const {
BufferPair CancelFrame::Serialize(HPackCompressor*) const {
GPR_ASSERT(stream_id != 0);
FrameSerializer serializer(FrameType::kCancel, stream_id, 0);
FrameSerializer serializer(FrameType::kCancel, stream_id);
return serializer.Finish();
}

@ -28,6 +28,7 @@
#include "src/core/ext/transport/chaotic_good/frame_header.h"
#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h"
#include "src/core/ext/transport/chttp2/transport/hpack_parser.h"
#include "src/core/lib/gprpp/match.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/transport/metadata_batch.h"
@ -36,20 +37,21 @@
namespace grpc_core {
namespace chaotic_good {
struct BufferPair {
SliceBuffer control;
SliceBuffer data;
};
class FrameInterface {
public:
virtual absl::Status Deserialize(HPackParser* parser,
const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) = 0;
virtual SliceBuffer Serialize(HPackCompressor* encoder) const = 0;
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) = 0;
virtual BufferPair Serialize(HPackCompressor* encoder) const = 0;
virtual std::string ToString() const = 0;
protected:
static bool EqVal(const Message& a, const Message& b) {
return a.payload()->JoinIntoString() == b.payload()->JoinIntoString() &&
a.flags() == b.flags();
}
static bool EqVal(const grpc_metadata_batch& a,
const grpc_metadata_batch& b) {
return a.DebugString() == b.DebugString();
@ -65,57 +67,75 @@ class FrameInterface {
struct SettingsFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) override;
SliceBuffer Serialize(HPackCompressor* encoder) const override;
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
bool operator==(const SettingsFrame&) const { return true; }
};
struct FragmentMessage {
FragmentMessage(MessageHandle message, uint32_t padding, uint32_t length)
: message(std::move(message)), padding(padding), length(length) {}
MessageHandle message;
uint32_t padding;
uint32_t length;
std::string ToString() const;
static bool EqVal(const Message& a, const Message& b) {
return a.payload()->JoinIntoString() == b.payload()->JoinIntoString() &&
a.flags() == b.flags();
}
bool operator==(const FragmentMessage& other) const {
return EqVal(*message, *other.message) && length == other.length;
}
};
struct ClientFragmentFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) override;
SliceBuffer Serialize(HPackCompressor* encoder) const override;
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
uint32_t stream_id;
ClientMetadataHandle headers;
MessageHandle message;
uint32_t message_padding;
absl::optional<FragmentMessage> message;
bool end_of_stream = false;
bool operator==(const ClientFragmentFrame& other) const {
return stream_id == other.stream_id && EqHdl(headers, other.headers) &&
end_of_stream == other.end_of_stream;
message == other.message && end_of_stream == other.end_of_stream;
}
};
struct ServerFragmentFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) override;
SliceBuffer Serialize(HPackCompressor* encoder) const override;
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
uint32_t stream_id;
ServerMetadataHandle headers;
MessageHandle message;
uint32_t message_padding;
absl::optional<FragmentMessage> message;
ServerMetadataHandle trailers;
bool operator==(const ServerFragmentFrame& other) const {
return stream_id == other.stream_id && EqHdl(headers, other.headers) &&
EqHdl(trailers, other.trailers);
message == other.message && EqHdl(trailers, other.trailers);
}
};
struct CancelFrame final : public FrameInterface {
absl::Status Deserialize(HPackParser* parser, const FrameHeader& header,
absl::BitGenRef bitsrc,
SliceBuffer& slice_buffer) override;
SliceBuffer Serialize(HPackCompressor* encoder) const override;
absl::BitGenRef bitsrc, Arena* arena,
BufferPair buffers) override;
BufferPair Serialize(HPackCompressor* encoder) const override;
std::string ToString() const override;
uint32_t stream_id;
@ -128,6 +148,19 @@ struct CancelFrame final : public FrameInterface {
using ClientFrame = absl::variant<ClientFragmentFrame, CancelFrame>;
using ServerFrame = absl::variant<ServerFragmentFrame>;
inline FrameInterface& GetFrameInterface(ClientFrame& frame) {
return MatchMutable(
&frame,
[](ClientFragmentFrame* frame) -> FrameInterface& { return *frame; },
[](CancelFrame* frame) -> FrameInterface& { return *frame; });
}
inline FrameInterface& GetFrameInterface(ServerFrame& frame) {
return MatchMutable(
&frame,
[](ServerFragmentFrame* frame) -> FrameInterface& { return *frame; });
}
} // namespace chaotic_good
} // namespace grpc_core

@ -46,7 +46,6 @@ void FrameHeader::Serialize(uint8_t* data) const {
WriteLittleEndianUint32(
static_cast<uint32_t>(type) | (flags.ToInt<uint32_t>() << 8), data);
if (flags.is_set(0)) GPR_ASSERT(header_length > 0);
if (flags.is_set(1)) GPR_ASSERT(trailer_length > 0);
WriteLittleEndianUint32(stream_id, data + 4);
WriteLittleEndianUint32(header_length, data + 8);
WriteLittleEndianUint32(message_length, data + 12);
@ -60,8 +59,8 @@ absl::StatusOr<FrameHeader> FrameHeader::Parse(const uint8_t* data) {
const uint32_t type_and_flags = ReadLittleEndianUint32(data);
header.type = static_cast<FrameType>(type_and_flags & 0xff);
const uint32_t flags = type_and_flags >> 8;
if (flags > 3) return absl::InvalidArgumentError("Invalid flags");
header.flags = BitSet<2>::FromInt(flags);
if (flags > 7) return absl::InvalidArgumentError("Invalid flags");
header.flags = BitSet<3>::FromInt(flags);
header.stream_id = ReadLittleEndianUint32(data + 4);
header.header_length = ReadLittleEndianUint32(data + 8);
if (header.flags.is_set(0) && header.header_length <= 0) {
@ -70,11 +69,11 @@ absl::StatusOr<FrameHeader> FrameHeader::Parse(const uint8_t* data) {
}
header.message_length = ReadLittleEndianUint32(data + 12);
header.message_padding = ReadLittleEndianUint32(data + 16);
header.trailer_length = ReadLittleEndianUint32(data + 20);
if (header.flags.is_set(1) && header.trailer_length <= 0) {
if (header.flags.is_set(1) && header.message_length <= 0) {
return absl::InvalidArgumentError(
absl::StrCat("Invalid trailer length", header.trailer_length));
absl::StrCat("Invalid message length: ", header.message_length));
}
header.trailer_length = ReadLittleEndianUint32(data + 20);
return header;
}

@ -36,7 +36,7 @@ enum class FrameType : uint8_t {
struct FrameHeader {
FrameType type = FrameType::kCancel;
BitSet<2> flags;
BitSet<3> flags;
uint32_t stream_id = 0;
uint32_t header_length = 0;
uint32_t message_length = 0;

@ -0,0 +1,332 @@
// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <grpc/support/port_platform.h>
#include "src/core/ext/transport/chaotic_good/server_transport.h"
#include <memory>
#include <string>
#include <tuple>
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/slice.h>
#include <grpc/support/log.h>
#include "src/core/ext/transport/chaotic_good/frame.h"
#include "src/core/ext/transport/chaotic_good/frame_header.h"
#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/event_engine_wakeup_scheduler.h"
#include "src/core/lib/promise/for_each.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/switch.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/transport/promise_endpoint.h"
namespace grpc_core {
namespace chaotic_good {
// Returns a promise that loops forever pulling outgoing frames from
// outgoing_frames_ and writing them to the underlying transport. The loop
// only terminates when TrySeq observes a failure (a closed MPSC queue or a
// failed endpoint write), which propagates out as the loop's status.
auto ChaoticGoodServerTransport::TransportWriteLoop() {
  return Loop([this] {
    return TrySeq(
        // Get next outgoing frame.
        outgoing_frames_.Next(),
        // Serialize and write it out.
        [this](ServerFrame client_frame) {
          return transport_.WriteFrame(GetFrameInterface(client_frame));
        },
        []() -> LoopCtl<absl::Status> {
          // The write failures will be caught in TrySeq and exit loop.
          // Therefore, only need to return Continue() in the last lambda
          // function.
          return Continue();
        });
  });
}
// Returns a promise that delivers the pieces of a client fragment frame into
// the call: initial metadata (if present), then the message (if present),
// then a FinishSends() when the client half-closed.
// NOTE: `headers` is captured by reference into the If branches; this is safe
// because If with a constant condition evaluates the chosen promise factory
// eagerly, before this function returns (see the comment on If in
// src/core/lib/promise/if.h).
auto ChaoticGoodServerTransport::PushFragmentIntoCall(
    CallInitiator call_initiator, ClientFragmentFrame frame) {
  auto& headers = frame.headers;
  return TrySeq(
      If(
          headers != nullptr,
          [call_initiator, &headers]() mutable {
            return call_initiator.PushClientInitialMetadata(std::move(headers));
          },
          []() -> StatusFlag { return Success{}; }),
      [call_initiator, message = std::move(frame.message)]() mutable {
        return If(
            message.has_value(),
            [&call_initiator, &message]() mutable {
              return call_initiator.PushMessage(std::move(message->message));
            },
            []() -> StatusFlag { return Success{}; });
      },
      [call_initiator,
       end_of_stream = frame.end_of_stream]() mutable -> StatusFlag {
        if (end_of_stream) call_initiator.FinishSends();
        return Success{};
      });
}
// If we resolved a live call and deserialization succeeded, spawn the
// fragment push onto the call's party (cancelling the call if the push
// fails) and map the resulting StatusFlag back to an absl::Status for the
// read loop. Otherwise just propagate `error` (which may be OkStatus when
// the stream simply wasn't found — TODO confirm that's intended).
// The reference captures into If's factories are safe because the condition
// is a constant, so the chosen factory runs before this function returns.
auto ChaoticGoodServerTransport::MaybePushFragmentIntoCall(
    absl::optional<CallInitiator> call_initiator, absl::Status error,
    ClientFragmentFrame frame) {
  return If(
      call_initiator.has_value() && error.ok(),
      [this, &call_initiator, &frame]() {
        return Map(
            call_initiator->SpawnWaitable(
                "push-fragment",
                [call_initiator, frame = std::move(frame), this]() mutable {
                  return call_initiator->CancelIfFails(
                      PushFragmentIntoCall(*call_initiator, std::move(frame)));
                }),
            [](StatusFlag status) { return StatusCast<absl::Status>(status); });
      },
      [error = std::move(error)]() { return error; });
}
// Returns a promise that drives the server->client half of one stream:
// send initial metadata, then every outgoing message (padded to
// aligned_bytes_), then trailing metadata.
// The outer combinator is Seq (not TrySeq) so that trailing metadata is
// still pulled and sent even if the initial-metadata/message phase failed —
// presumably so the client always sees trailers; TODO confirm.
auto ChaoticGoodServerTransport::CallOutboundLoop(
    uint32_t stream_id, CallInitiator call_initiator) {
  // Stamps the stream id onto a fragment and queues it on the transport's
  // outgoing MPSC; resolves to UnavailableError if the transport has closed.
  auto send_fragment = [stream_id,
                        outgoing_frames = outgoing_frames_.MakeSender()](
                           ServerFragmentFrame frame) mutable {
    frame.stream_id = stream_id;
    return Map(outgoing_frames.Send(std::move(frame)),
               [](bool success) -> absl::Status {
                 if (!success) {
                   // Failed to send outgoing frame.
                   return absl::UnavailableError("Transport closed.");
                 }
                 return absl::OkStatus();
               });
  };
  return Seq(
      TrySeq(
          // Wait for initial metadata then send it out.
          call_initiator.PullServerInitialMetadata(),
          [send_fragment](ServerMetadataHandle md) mutable {
            ServerFragmentFrame frame;
            frame.headers = std::move(md);
            return send_fragment(std::move(frame));
          },
          // Continuously send client frame with client to server messages.
          ForEach(OutgoingMessages(call_initiator),
                  [send_fragment, aligned_bytes = aligned_bytes_](
                      MessageHandle message) mutable {
                    ServerFragmentFrame frame;
                    // Construct frame header (flags, header_length and
                    // trailer_length will be added in serialization).
                    const uint32_t message_length =
                        message->payload()->Length();
                    // Pad the payload up to the next aligned_bytes boundary.
                    const uint32_t padding =
                        message_length % aligned_bytes == 0
                            ? 0
                            : aligned_bytes - message_length % aligned_bytes;
                    GPR_ASSERT((message_length + padding) % aligned_bytes == 0);
                    frame.message = FragmentMessage(std::move(message), padding,
                                                    message_length);
                    return send_fragment(std::move(frame));
                  })),
      // Always finish with trailing metadata, regardless of the phase above.
      call_initiator.PullServerTrailingMetadata(),
      [send_fragment](ServerMetadataHandle md) mutable {
        ServerFragmentFrame frame;
        frame.trailers = std::move(md);
        return send_fragment(std::move(frame));
      });
}
// Handles a fragment frame that opens a new stream: allocate an arena via
// the acceptor, deserialize the frame into it, and ask the acceptor to
// create the call from the client's initial metadata. On success the arena's
// ownership is released to CreateCall and the per-stream outbound loop is
// spawned on the new call's party. Any failure (deserialization or call
// creation) is carried into MaybePushFragmentIntoCall as `status`.
auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToNewCall(
    FrameHeader frame_header, BufferPair buffers) {
  ClientFragmentFrame fragment_frame;
  ScopedArenaPtr arena(acceptor_->CreateArena());
  absl::Status status = transport_.DeserializeFrame(
      frame_header, std::move(buffers), arena.get(), fragment_frame);
  absl::optional<CallInitiator> call_initiator;
  if (status.ok()) {
    auto create_call_result =
        acceptor_->CreateCall(*fragment_frame.headers, arena.release());
    if (create_call_result.ok()) {
      call_initiator.emplace(std::move(*create_call_result));
      call_initiator->SpawnGuarded(
          "server-write", [this, stream_id = frame_header.stream_id,
                           call_initiator = *call_initiator]() {
            return CallOutboundLoop(stream_id, call_initiator);
          });
    } else {
      status = create_call_result.status();
    }
  }
  return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status),
                                   std::move(fragment_frame));
}
// Handles a fragment frame for an already-open stream: look up the call in
// the stream map, deserialize into that call's arena (nullptr if the stream
// is unknown — DeserializeFrame is presumably tolerant of that; TODO
// confirm), and push the fragment into the call if one was found.
auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToExistingCall(
    FrameHeader frame_header, BufferPair buffers) {
  absl::optional<CallInitiator> call_initiator =
      LookupStream(frame_header.stream_id);
  Arena* arena = nullptr;
  if (call_initiator.has_value()) arena = call_initiator->arena();
  ClientFragmentFrame fragment_frame;
  absl::Status status = transport_.DeserializeFrame(
      frame_header, std::move(buffers), arena, fragment_frame);
  return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status),
                                   std::move(fragment_frame));
}
// Returns a promise that loops reading frames off the wire and dispatching
// them by type:
//  - kSettings is unexpected after handshake and fails the loop;
//  - kFragment goes to a new call if header flag 0 is set (per the wire
//    format, flag 0 indicates initial-metadata presence), otherwise to the
//    existing stream;
//  - kCancel extracts the stream from the map and cancels it on its party;
//  - anything else is an internal error.
// The loop exits when any dispatch returns a failing status (caught by
// TrySeq).
auto ChaoticGoodServerTransport::TransportReadLoop() {
  return Loop([this] {
    return TrySeq(
        transport_.ReadFrameBytes(),
        [this](std::tuple<FrameHeader, BufferPair> frame_bytes) {
          const auto& frame_header = std::get<0>(frame_bytes);
          auto& buffers = std::get<1>(frame_bytes);
          return Switch(
              frame_header.type,
              Case(FrameType::kSettings,
                   []() -> absl::Status {
                     return absl::InternalError("Unexpected settings frame");
                   }),
              Case(FrameType::kFragment,
                   [this, &frame_header, &buffers]() {
                     // If's constant condition evaluates the chosen branch
                     // eagerly, so moving `buffers` out of the enclosing
                     // lambda's scope here is safe.
                     return If(
                         frame_header.flags.is_set(0),
                         [this, &frame_header, &buffers]() {
                           return DeserializeAndPushFragmentToNewCall(
                               frame_header, std::move(buffers));
                         },
                         [this, &frame_header, &buffers]() {
                           return DeserializeAndPushFragmentToExistingCall(
                               frame_header, std::move(buffers));
                         });
                   }),
              Case(FrameType::kCancel,
                   [this, &frame_header]() {
                     absl::optional<CallInitiator> call_initiator =
                         ExtractStream(frame_header.stream_id);
                     return If(
                         call_initiator.has_value(),
                         [&call_initiator]() {
                           // Cancel on the call's own party; the waitable
                           // resolves once the cancel has been scheduled.
                           auto c = std::move(*call_initiator);
                           return c.SpawnWaitable("cancel", [c]() mutable {
                             c.Cancel();
                             return absl::OkStatus();
                           });
                         },
                         []() -> absl::Status {
                           // Cancel for a stream we don't know about.
                           return absl::InternalError(
                               "Unexpected cancel frame");
                         });
                   }),
              Default([frame_header]() {
                return absl::InternalError(
                    absl::StrCat("Unexpected frame type: ",
                                 static_cast<uint8_t>(frame_header.type)));
              }));
        },
        []() -> LoopCtl<absl::Status> { return Continue{}; });
  });
}
auto ChaoticGoodServerTransport::OnTransportActivityDone() {
return [this](absl::Status status) {
if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) {
this->AbortWithError();
}
};
}
// Constructs the transport over a control/data endpoint pair and immediately
// starts the write loop. The read loop is deferred until SetAcceptor() since
// incoming fragments need the acceptor to create calls.
// Note: event_engine_ is initialized before writer_ (declaration order in
// the header), and the scheduler takes its own copy of the shared_ptr.
ChaoticGoodServerTransport::ChaoticGoodServerTransport(
    const ChannelArgs& args, std::unique_ptr<PromiseEndpoint> control_endpoint,
    std::unique_ptr<PromiseEndpoint> data_endpoint,
    std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine)
    : outgoing_frames_(4),  // MPSC buffer size for queued outgoing frames.
      transport_(std::move(control_endpoint), std::move(data_endpoint)),
      allocator_(args.GetObject<ResourceQuota>()
                     ->memory_quota()
                     ->CreateMemoryAllocator("chaotic-good")),
      event_engine_(event_engine),
      writer_{MakeActivity(TransportWriteLoop(),
                           EventEngineWakeupScheduler(event_engine),
                           OnTransportActivityDone())},
      reader_{nullptr} {}
// Installs the acceptor (must be set exactly once, non-null) and only then
// starts the read loop — incoming fragments require acceptor_ to create
// calls, so the reader must not run before this point.
void ChaoticGoodServerTransport::SetAcceptor(Acceptor* acceptor) {
  GPR_ASSERT(acceptor_ == nullptr);
  GPR_ASSERT(acceptor != nullptr);
  acceptor_ = acceptor;
  reader_ = MakeActivity(TransportReadLoop(),
                         EventEngineWakeupScheduler(event_engine_),
                         OnTransportActivityDone());
}
// Tear down the transport by destroying the reader and writer activities.
// ActivityPtr::reset() is already a no-op on a null pointer, so the previous
// `if (ptr != nullptr)` guards were redundant and have been removed.
ChaoticGoodServerTransport::~ChaoticGoodServerTransport() {
  writer_.reset();
  reader_.reset();
}
// Shuts the transport down after an endpoint read/write failure: closes the
// outgoing frame queue, then cancels every live stream.
// The stream map is moved out under the lock and the lock is released before
// cancelling — Cancel is spawned on each call's own party, and doing that
// while holding mu_ could deadlock or violate lock ordering.
void ChaoticGoodServerTransport::AbortWithError() {
  // Mark transport as unavailable when the endpoint write/read failed.
  // Close all the available pipes.
  outgoing_frames_.MarkClosed();
  ReleasableMutexLock lock(&mu_);
  StreamMap stream_map = std::move(stream_map_);
  stream_map_.clear();  // moved-from map is valid but unspecified; make it empty.
  lock.Release();
  for (const auto& pair : stream_map) {
    auto call_initiator = pair.second;
    call_initiator.SpawnInfallible("cancel", [call_initiator]() mutable {
      call_initiator.Cancel();
      return Empty{};
    });
  }
}
// Returns a copy of the CallInitiator registered for `stream_id`, or
// nullopt when the stream is unknown. Thread-safe: takes mu_.
absl::optional<CallInitiator> ChaoticGoodServerTransport::LookupStream(
    uint32_t stream_id) {
  MutexLock lock(&mu_);
  const auto entry = stream_map_.find(stream_id);
  if (entry != stream_map_.end()) return entry->second;
  return absl::nullopt;
}
// Removes the CallInitiator registered for `stream_id` from the map and
// returns it, or nullopt when the stream is unknown. Thread-safe: takes mu_.
absl::optional<CallInitiator> ChaoticGoodServerTransport::ExtractStream(
    uint32_t stream_id) {
  MutexLock lock(&mu_);
  auto entry = stream_map_.find(stream_id);
  if (entry == stream_map_.end()) return absl::nullopt;
  // Move the initiator out before erasing the node that holds it.
  absl::optional<CallInitiator> extracted(std::move(entry->second));
  stream_map_.erase(entry);
  return extracted;
}
} // namespace chaotic_good
} // namespace grpc_core

@ -0,0 +1,145 @@
// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H
#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H
#include <grpc/support/port_platform.h>
#include <stdint.h>
#include <stdio.h>
#include <cstdint>
#include <initializer_list> // IWYU pragma: keep
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/slice.h>
#include <grpc/support/log.h>
#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h"
#include "src/core/ext/transport/chaotic_good/frame.h"
#include "src/core/ext/transport/chaotic_good/frame_header.h"
#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h"
#include "src/core/ext/transport/chttp2/transport/hpack_parser.h"
#include "src/core/lib/event_engine/default_event_engine.h" // IWYU pragma: keep
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/if.h"
#include "src/core/lib/promise/inter_activity_pipe.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/mpsc.h"
#include "src/core/lib/promise/party.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/promise/try_join.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/promise_endpoint.h"
#include "src/core/lib/transport/transport.h"
namespace grpc_core {
namespace chaotic_good {
// Server side of the chaotic-good transport: owns a control/data endpoint
// pair, a writer activity that serializes queued ServerFrames onto the wire,
// and a reader activity (started by SetAcceptor) that dispatches incoming
// client frames to per-stream calls.
class ChaoticGoodServerTransport final : public Transport,
                                         public ServerTransport {
 public:
  ChaoticGoodServerTransport(
      const ChannelArgs& args,
      std::unique_ptr<PromiseEndpoint> control_endpoint,
      std::unique_ptr<PromiseEndpoint> data_endpoint,
      std::shared_ptr<grpc_event_engine::experimental::EventEngine>
          event_engine);
  ~ChaoticGoodServerTransport() override;

  // Transport vtable: this is a promise-based server transport only.
  FilterStackTransport* filter_stack_transport() override { return nullptr; }
  ClientTransport* client_transport() override { return nullptr; }
  ServerTransport* server_transport() override { return this; }
  absl::string_view GetTransportName() const override { return "chaotic_good"; }
  void SetPollset(grpc_stream*, grpc_pollset*) override {}
  void SetPollsetSet(grpc_stream*, grpc_pollset_set*) override {}
  void PerformOp(grpc_transport_op*) override { Crash("unimplemented"); }
  grpc_endpoint* GetEndpoint() override { return nullptr; }
  void Orphan() override { delete this; }

  // Installs the call acceptor and starts the read loop (must be called
  // exactly once, before any frames can be processed).
  void SetAcceptor(Acceptor* acceptor) override;
  // Closes the outgoing queue and cancels all live streams.
  void AbortWithError();

 private:
  using StreamMap = absl::flat_hash_map<uint32_t, CallInitiator>;

  absl::Status NewStream(uint32_t stream_id, CallInitiator call_initiator);
  // Returns a copy of the stream's CallInitiator, or nullopt if unknown.
  absl::optional<CallInitiator> LookupStream(uint32_t stream_id);
  // Removes and returns the stream's CallInitiator, or nullopt if unknown.
  absl::optional<CallInitiator> ExtractStream(uint32_t stream_id);
  // Promise: server->client half of one stream (headers, messages, trailers).
  auto CallOutboundLoop(uint32_t stream_id, CallInitiator call_initiator);
  // Completion callback for reader/writer; aborts on unexpected failure.
  auto OnTransportActivityDone();
  // Promise: read + dispatch incoming frames forever.
  auto TransportReadLoop();
  // Promise: drain outgoing_frames_ onto the wire forever.
  auto TransportWriteLoop();
  // Read different parts of the server frame from control/data endpoints
  // based on frame header.
  // Resolves to a StatusOr<tuple<SliceBuffer, SliceBuffer>>
  auto ReadFrameBody(Slice read_buffer);
  void SendCancel(uint32_t stream_id, absl::Status why);
  auto DeserializeAndPushFragmentToNewCall(FrameHeader frame_header,
                                           BufferPair buffers);
  auto DeserializeAndPushFragmentToExistingCall(FrameHeader frame_header,
                                                BufferPair buffers);
  auto MaybePushFragmentIntoCall(absl::optional<CallInitiator> call_initiator,
                                 absl::Status error, ClientFragmentFrame frame);
  auto PushFragmentIntoCall(CallInitiator call_initiator,
                            ClientFragmentFrame frame);

  // Creates calls/arenas for new streams; set once via SetAcceptor.
  Acceptor* acceptor_ = nullptr;
  // Queue of frames waiting to be written by TransportWriteLoop.
  MpscReceiver<ServerFrame> outgoing_frames_;
  ChaoticGoodTransport transport_;
  // Assigned aligned bytes from setting frame.
  size_t aligned_bytes_ = 64;
  Mutex mu_;
  // Map of stream incoming server frames, key is stream_id.
  StreamMap stream_map_ ABSL_GUARDED_BY(mu_);
  grpc_event_engine::experimental::MemoryAllocator allocator_;
  std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_;
  ActivityPtr writer_;
  ActivityPtr reader_;  // Null until SetAcceptor is called.
};
} // namespace chaotic_good
} // namespace grpc_core
#endif // GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H

@ -36,8 +36,8 @@ class InprocServerTransport final : public RefCounted<InprocServerTransport>,
public Transport,
public ServerTransport {
public:
void SetAcceptFunction(AcceptFunction accept_function) override {
accept_ = std::move(accept_function);
void SetAcceptor(Acceptor* acceptor) override {
acceptor_ = acceptor;
ConnectionState expect = ConnectionState::kInitial;
state_.compare_exchange_strong(expect, ConnectionState::kReady,
std::memory_order_acq_rel,
@ -92,7 +92,7 @@ class InprocServerTransport final : public RefCounted<InprocServerTransport>,
case ConnectionState::kReady:
break;
}
return accept_(md);
return acceptor_->CreateCall(md, acceptor_->CreateArena());
}
private:
@ -100,7 +100,7 @@ class InprocServerTransport final : public RefCounted<InprocServerTransport>,
std::atomic<ConnectionState> state_{ConnectionState::kInitial};
std::atomic<bool> disconnecting_{false};
AcceptFunction accept_;
Acceptor* acceptor_;
absl::Status disconnect_error_;
Mutex state_tracker_mu_;
ConnectivityStateTracker state_tracker_ ABSL_GUARDED_BY(state_tracker_mu_){

@ -81,6 +81,15 @@ class DebugLocation {
};
#endif
// Bundles a value of type T with a DebugLocation. The constructor is
// deliberately implicit (hence the NOLINT) so callers can pass a bare T
// where a ValueWithDebugLocation is expected; the defaulted DebugLocation
// argument presumably captures the call site's file/line — TODO confirm
// against DebugLocation's default constructor.
template <typename T>
struct ValueWithDebugLocation {
  // NOLINTNEXTLINE
  ValueWithDebugLocation(T&& value, DebugLocation debug_location = {})
      : value(std::forward<T>(value)), debug_location(debug_location) {}
  T value;
  GPR_NO_UNIQUE_ADDRESS DebugLocation debug_location;
};
#define DEBUG_LOCATION ::grpc_core::DebugLocation(__FILE__, __LINE__)
} // namespace grpc_core

@ -45,6 +45,11 @@ inline absl::Status IntoStatus(absl::Status* status) {
// can participate in TrySeq as result types that affect control flow.
inline bool IsStatusOk(const absl::Status& status) { return status.ok(); }
// Overload of IsStatusOk for absl::StatusOr<T>: ok iff it holds a value.
template <typename T>
inline bool IsStatusOk(const absl::StatusOr<T>& status) {
  return status.ok();
}
template <typename To, typename From, typename SfinaeVoid = void>
struct StatusCastImpl;
@ -59,20 +64,52 @@ struct StatusCastImpl<To, const To&> {
};
template <typename T>
struct StatusCastImpl<absl::StatusOr<T>, absl::Status> {
static absl::StatusOr<T> Cast(absl::Status&& t) { return std::move(t); }
struct StatusCastImpl<absl::Status, absl::StatusOr<T>> {
static absl::Status Cast(absl::StatusOr<T>&& t) {
return std::move(t.status());
}
};
template <typename T>
struct StatusCastImpl<absl::StatusOr<T>, const absl::Status&> {
static absl::StatusOr<T> Cast(const absl::Status& t) { return t; }
struct StatusCastImpl<absl::Status, absl::StatusOr<T>&> {
static absl::Status Cast(const absl::StatusOr<T>& t) { return t.status(); }
};
template <typename T>
struct StatusCastImpl<absl::Status, const absl::StatusOr<T>&> {
static absl::Status Cast(const absl::StatusOr<T>& t) { return t.status(); }
};
// StatusCast<> allows casting from one status-bearing type to another,
// regardless of whether the status indicates success or failure.
// This means that we can go from StatusOr to Status safely, but not in the
// opposite direction.
// For cases where the status is guaranteed to be a failure (and hence not
// needing to preserve values) see FailureStatusCast<> below.
template <typename To, typename From>
To StatusCast(From&& from) {
return StatusCastImpl<To, From>::Cast(std::forward<From>(from));
}
template <typename To, typename From, typename SfinaeVoid = void>
struct FailureStatusCastImpl : public StatusCastImpl<To, From> {};
template <typename T>
struct FailureStatusCastImpl<absl::StatusOr<T>, absl::Status> {
static absl::StatusOr<T> Cast(absl::Status&& t) { return std::move(t); }
};
template <typename T>
struct FailureStatusCastImpl<absl::StatusOr<T>, const absl::Status&> {
static absl::StatusOr<T> Cast(const absl::Status& t) { return t; }
};
template <typename To, typename From>
To FailureStatusCast(From&& from) {
GPR_DEBUG_ASSERT(!IsStatusOk(from));
return FailureStatusCastImpl<To, From>::Cast(std::forward<From>(from));
}
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_STATUS_H

@ -33,7 +33,9 @@ class EventEngineWakeupScheduler {
explicit EventEngineWakeupScheduler(
std::shared_ptr<grpc_event_engine::experimental::EventEngine>
event_engine)
: event_engine_(std::move(event_engine)) {}
: event_engine_(std::move(event_engine)) {
GPR_ASSERT(event_engine_ != nullptr);
}
template <typename ActivityType>
class BoundScheduler

@ -192,6 +192,10 @@ class If<bool, T, F> {
// If it returns failure, returns failure for the entire combinator.
// If it returns true, evaluates the second promise.
// If it returns false, evaluates the third promise.
// If C is a constant, it's guaranteed that one of the promise factories
// if_true or if_false will be evaluated before returning from this function.
// This makes it safe to capture lambda arguments in the promise factory by
// reference.
template <typename C, typename T, typename F>
promise_detail::If<C, T, F> If(C condition, T if_true, F if_false) {
return promise_detail::If<C, T, F>(std::move(condition), std::move(if_true),

@ -113,9 +113,9 @@ class InterActivityPipe {
if (center_ != nullptr) center_->MarkClosed();
}
bool IsClose() { return center_->IsClosed(); }
bool IsClosed() { return center_->IsClosed(); }
void MarkClose() {
void MarkClosed() {
if (center_ != nullptr) center_->MarkClosed();
}
@ -146,6 +146,12 @@ class InterActivityPipe {
return [center = center_]() { return center->Next(); };
}
bool IsClose() { return center_->IsClosed(); }
void MarkClose() {
if (center_ != nullptr) center_->MarkClosed();
}
private:
RefCountedPtr<Center> center_;
};

@ -103,14 +103,12 @@ class Center : public RefCounted<Center<T>> {
// Mark that the receiver is closed.
void ReceiverClosed() {
MutexLock lock(&mu_);
ReleasableMutexLock lock(&mu_);
if (receiver_closed_) return;
receiver_closed_ = true;
}
// Return whether the receiver is closed.
bool IsClosed() {
MutexLock lock(&mu_);
return receiver_closed_;
auto wakeups = send_wakers_.TakeWakeupSet();
lock.Release();
wakeups.Wakeup();
}
private:
@ -131,8 +129,8 @@ class MpscReceiver;
template <typename T>
class MpscSender {
public:
MpscSender(const MpscSender&) = delete;
MpscSender& operator=(const MpscSender&) = delete;
MpscSender(const MpscSender&) = default;
MpscSender& operator=(const MpscSender&) = default;
MpscSender(MpscSender&&) noexcept = default;
MpscSender& operator=(MpscSender&&) noexcept = default;
@ -140,7 +138,10 @@ class MpscSender {
// Resolves to true if sent, false if the receiver was closed (and the value
// will never be successfully sent).
auto Send(T t) {
return [this, t = std::move(t)]() mutable { return center_->PollSend(t); };
return [center = center_, t = std::move(t)]() mutable -> Poll<bool> {
if (center == nullptr) return false;
return center->PollSend(t);
};
}
bool UnbufferedImmediateSend(T t) {
@ -170,7 +171,6 @@ class MpscReceiver {
~MpscReceiver() {
if (center_ != nullptr) center_->ReceiverClosed();
}
bool IsClosed() { return center_->IsClosed(); }
void MarkClosed() {
if (center_ != nullptr) center_->ReceiverClosed();
}

@ -95,6 +95,30 @@ struct StatusCastImpl<absl::Status, const StatusFlag&> {
}
};
// FailureStatusCastImpl specializations for StatusFlag -> StatusOr<T>.
// A StatusFlag carries no error details, so a failed flag maps to
// CancelledError; the cast is only valid on failure (debug-asserted).
// Three cv/ref variants are provided so FailureStatusCast<> resolves for
// value, lvalue, and const-lvalue arguments alike.
template <typename T>
struct FailureStatusCastImpl<absl::StatusOr<T>, StatusFlag> {
  static absl::StatusOr<T> Cast(StatusFlag flag) {
    GPR_DEBUG_ASSERT(!flag.ok());
    return absl::CancelledError();
  }
};
template <typename T>
struct FailureStatusCastImpl<absl::StatusOr<T>, StatusFlag&> {
  static absl::StatusOr<T> Cast(StatusFlag flag) {
    GPR_DEBUG_ASSERT(!flag.ok());
    return absl::CancelledError();
  }
};
template <typename T>
struct FailureStatusCastImpl<absl::StatusOr<T>, const StatusFlag&> {
  static absl::StatusOr<T> Cast(StatusFlag flag) {
    GPR_DEBUG_ASSERT(!flag.ok());
    return absl::CancelledError();
  }
};
// A value if an operation was successful, or a failure flag if not.
template <typename T>
class ValueOrFailure {

@ -75,16 +75,16 @@ struct TryJoinTraits {
}
template <typename R>
static R EarlyReturn(absl::Status x) {
return StatusCast<R>(std::move(x));
return FailureStatusCast<R>(std::move(x));
}
template <typename R>
static R EarlyReturn(StatusFlag x) {
return StatusCast<R>(x);
return FailureStatusCast<R>(x);
}
template <typename R, typename T>
static R EarlyReturn(const ValueOrFailure<T>& x) {
GPR_ASSERT(!x.ok());
return StatusCast<R>(Failure{});
return FailureStatusCast<R>(Failure{});
}
template <typename... A>
static auto FinalReturn(A&&... a) {

@ -76,7 +76,7 @@ struct TrySeqTraitsWithSfinae<absl::StatusOr<T>> {
}
template <typename R>
static R ReturnValue(absl::StatusOr<T>&& status) {
return StatusCast<R>(status.status());
return FailureStatusCast<R>(status.status());
}
template <typename F, typename Elem>
static auto CallSeqFactory(F& f, Elem&& elem, absl::StatusOr<T> value)
@ -86,11 +86,26 @@ struct TrySeqTraitsWithSfinae<absl::StatusOr<T>> {
template <typename Result, typename RunNext>
static Poll<Result> CheckResultAndRunNext(absl::StatusOr<T> prior,
RunNext run_next) {
if (!prior.ok()) return StatusCast<Result>(prior.status());
if (!prior.ok()) return FailureStatusCast<Result>(prior.status());
return run_next(std::move(prior));
}
};
// Opt-out switch for the generic (IsStatusOk-based) TrySeq traits: any type
// defaults to allowed, but absl::Status and absl::StatusOr<T> are excluded
// because they have dedicated TrySeqTraitsWithSfinae specializations and
// must not also match the generic ones.
template <typename T>
struct AllowGenericTrySeqTraits {
  static constexpr bool value = true;
};
template <>
struct AllowGenericTrySeqTraits<absl::Status> {
  static constexpr bool value = false;
};
template <typename T>
struct AllowGenericTrySeqTraits<absl::StatusOr<T>> {
  static constexpr bool value = false;
};
template <typename T, typename AnyType = void>
struct TakeValueExists {
static constexpr bool value = false;
@ -107,7 +122,7 @@ template <typename T>
struct TrySeqTraitsWithSfinae<
T, absl::enable_if_t<
std::is_same<decltype(IsStatusOk(std::declval<T>())), bool>::value &&
!TakeValueExists<T>::value,
!TakeValueExists<T>::value && AllowGenericTrySeqTraits<T>::value,
void>> {
using UnwrappedType = void;
using WrappedType = T;
@ -121,7 +136,7 @@ struct TrySeqTraitsWithSfinae<
}
template <typename R>
static R ReturnValue(T&& status) {
return StatusCast<R>(std::move(status));
return FailureStatusCast<R>(std::move(status));
}
template <typename Result, typename RunNext>
static Poll<Result> CheckResultAndRunNext(T prior, RunNext run_next) {
@ -133,7 +148,7 @@ template <typename T>
struct TrySeqTraitsWithSfinae<
T, absl::enable_if_t<
std::is_same<decltype(IsStatusOk(std::declval<T>())), bool>::value &&
TakeValueExists<T>::value,
TakeValueExists<T>::value && AllowGenericTrySeqTraits<T>::value,
void>> {
using UnwrappedType = decltype(TakeValue(std::declval<T>()));
using WrappedType = T;
@ -148,7 +163,7 @@ struct TrySeqTraitsWithSfinae<
template <typename R>
static R ReturnValue(T&& status) {
GPR_DEBUG_ASSERT(!IsStatusOk(status));
return StatusCast<R>(status.status());
return FailureStatusCast<R>(status.status());
}
template <typename Result, typename RunNext>
static Poll<Result> CheckResultAndRunNext(T prior, RunNext run_next) {
@ -170,7 +185,7 @@ struct TrySeqTraitsWithSfinae<absl::Status> {
}
template <typename R>
static R ReturnValue(absl::Status&& status) {
return StatusCast<R>(std::move(status));
return FailureStatusCast<R>(std::move(status));
}
template <typename Result, typename RunNext>
static Poll<Result> CheckResultAndRunNext(absl::Status prior,

@ -180,7 +180,7 @@ class Arena {
template <typename T, typename... Args>
T* New(Args&&... args) {
T* t = static_cast<T*>(Alloc(sizeof(T)));
Construct(t, std::forward<Args>(args)...);
new (t) T(std::forward<Args>(args)...);
return t;
}
@ -333,7 +333,7 @@ class Arena {
// value in Arena::PoolSizes, and so this may pessimize total
// arena size.
template <typename T, typename... Args>
PoolPtr<T> MakePooled(Args&&... args) {
static PoolPtr<T> MakePooled(Args&&... args) {
return PoolPtr<T>(new T(std::forward<Args>(args)...), PooledDeleter());
}

@ -50,6 +50,9 @@ namespace grpc_core {
class SliceBuffer {
public:
explicit SliceBuffer() { grpc_slice_buffer_init(&slice_buffer_); }
explicit SliceBuffer(Slice slice) : SliceBuffer() {
Append(std::move(slice));
}
SliceBuffer(const SliceBuffer& other) = delete;
SliceBuffer(SliceBuffer&& other) noexcept {
grpc_slice_buffer_init(&slice_buffer_);

@ -4063,16 +4063,13 @@ void ServerCallSpine::CommitBatch(const grpc_op* ops, size_t nops,
}
RefCountedPtr<CallSpineInterface> MakeServerCall(Server* server,
Channel* channel) {
const auto initial_size = channel->CallSizeEstimate();
global_stats().IncrementCallInitialSize(initial_size);
auto alloc = Arena::CreateWithAlloc(initial_size, sizeof(ServerCallSpine),
channel->allocator());
auto* call = new (alloc.second) ServerCallSpine(server, channel, alloc.first);
return RefCountedPtr<ServerCallSpine>(call);
Channel* channel,
Arena* arena) {
return RefCountedPtr<ServerCallSpine>(
arena->New<ServerCallSpine>(server, channel, arena));
}
#else
RefCountedPtr<CallSpineInterface> MakeServerCall(Server*, Channel*) {
RefCountedPtr<CallSpineInterface> MakeServerCall(Server*, Channel*, Arena*) {
Crash("not implemented");
}
#endif

@ -160,7 +160,8 @@ template <>
struct ContextType<CallContext> {};
RefCountedPtr<CallSpineInterface> MakeServerCall(Server* server,
Channel* channel);
Channel* channel,
Arena* arena);
} // namespace grpc_core

@ -51,6 +51,7 @@
#include "src/core/lib/channel/channel_trace.h"
#include "src/core/lib/channel/channelz.h"
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/debug/stats.h"
#include "src/core/lib/experiments/experiments.h"
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/crash.h"
@ -1297,6 +1298,20 @@ Server::ChannelData::~ChannelData() {
}
}
Arena* Server::ChannelData::CreateArena() {
const auto initial_size = channel_->CallSizeEstimate();
global_stats().IncrementCallInitialSize(initial_size);
return Arena::Create(initial_size, channel_->allocator());
}
absl::StatusOr<CallInitiator> Server::ChannelData::CreateCall(
ClientMetadata& client_initial_metadata, Arena* arena) {
SetRegisteredMethodOnMetadata(client_initial_metadata);
auto call = MakeServerCall(server_.get(), channel_.get(), arena);
InitCall(call);
return CallInitiator(std::move(call));
}
void Server::ChannelData::InitTransport(RefCountedPtr<Server> server,
RefCountedPtr<Channel> channel,
size_t cq_idx, Transport* transport,
@ -1329,13 +1344,7 @@ void Server::ChannelData::InitTransport(RefCountedPtr<Server> server,
}
if (transport->server_transport() != nullptr) {
++accept_stream_types;
transport->server_transport()->SetAcceptFunction(
[this](ClientMetadata& metadata) {
SetRegisteredMethodOnMetadata(metadata);
auto call = MakeServerCall(server_.get(), channel_.get());
InitCall(call);
return CallInitiator(std::move(call));
});
transport->server_transport()->SetAcceptor(this);
}
GPR_ASSERT(accept_stream_types == 1);
op->start_connectivity_watch = MakeOrphanable<ConnectivityWatcher>(this);

@ -218,7 +218,7 @@ class Server : public InternallyRefCounted<Server>,
class AllocatingRequestMatcherBatch;
class AllocatingRequestMatcherRegistered;
class ChannelData {
class ChannelData final : public ServerTransport::Acceptor {
public:
ChannelData() = default;
~ChannelData();
@ -241,6 +241,10 @@ class Server : public InternallyRefCounted<Server>,
grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory);
void InitCall(RefCountedPtr<CallSpineInterface> call);
Arena* CreateArena() override;
absl::StatusOr<CallInitiator> CreateCall(
ClientMetadata& client_initial_metadata, Arena* arena) override;
private:
class ConnectivityWatcher;

@ -69,24 +69,26 @@ class PromiseEndpoint {
auto Write(SliceBuffer data) {
// Assert previous write finishes.
GPR_ASSERT(!write_state_->complete.load(std::memory_order_relaxed));
// TODO(ladynana): Replace this with `SliceBufferCast<>` when it is
// available.
grpc_slice_buffer_swap(write_state_->buffer.c_slice_buffer(),
data.c_slice_buffer());
// If `Write()` returns true immediately, the callback will not be called.
// We still need to call our callback to pick up the result.
write_state_->waker = Activity::current()->MakeNonOwningWaker();
const bool completed = endpoint_->Write(
[write_state = write_state_](absl::Status status) {
write_state->Complete(std::move(status));
},
&write_state_->buffer, nullptr /* uses default arguments */);
bool completed;
if (data.Length() == 0) {
completed = true;
} else {
// TODO(ladynana): Replace this with `SliceBufferCast<>` when it is
// available.
grpc_slice_buffer_swap(write_state_->buffer.c_slice_buffer(),
data.c_slice_buffer());
// If `Write()` returns true immediately, the callback will not be called.
// We still need to call our callback to pick up the result.
write_state_->waker = Activity::current()->MakeNonOwningWaker();
completed = endpoint_->Write(
[write_state = write_state_](absl::Status status) {
write_state->Complete(std::move(status));
},
&write_state_->buffer, nullptr /* uses default arguments */);
if (completed) write_state_->waker = Waker();
}
return If(
completed,
[this]() {
write_state_->waker = Waker();
return []() { return absl::OkStatus(); };
},
completed, []() { return []() { return absl::OkStatus(); }; },
[this]() {
return [write_state = write_state_]() -> Poll<absl::Status> {
// If current write isn't finished return `Pending()`, else return

@ -291,9 +291,8 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
return call_initiator.SpawnWaitable(
"send_message",
[msg = std::move(msg), call_initiator]() mutable {
return call_initiator.CancelIfFails(Map(
call_initiator.PushMessage(std::move(msg)),
[](bool r) { return StatusFlag(r); }));
return call_initiator.CancelIfFails(
call_initiator.PushMessage(std::move(msg)));
});
});
});
@ -317,8 +316,7 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
"recv_message",
[msg = std::move(msg), call_handler]() mutable {
return call_handler.CancelIfFails(
Map(call_handler.PushMessage(std::move(msg)),
[](bool r) { return StatusFlag(r); }));
call_handler.PushMessage(std::move(msg)));
});
}),
ImmediateOkStatus())),
@ -334,4 +332,10 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator,
});
}
CallInitiatorAndHandler MakeCall(
grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena) {
auto spine = CallSpine::Create(event_engine, arena);
return {CallInitiator(spine), CallHandler(spine)};
}
} // namespace grpc_core

@ -258,6 +258,20 @@ class CallSpineInterface {
virtual Pipe<MessageHandle>& server_to_client_messages() = 0;
virtual Pipe<ServerMetadataHandle>& server_trailing_metadata() = 0;
virtual Latch<ServerMetadataHandle>& cancel_latch() = 0;
// Add a callback to be called when server trailing metadata is received.
void OnDone(absl::AnyInvocable<void()> fn) {
if (on_done_ == nullptr) {
on_done_ = std::move(fn);
return;
}
on_done_ = [first = std::move(fn), next = std::move(on_done_)]() mutable {
first();
next();
};
}
void CallOnDone() {
if (on_done_ != nullptr) std::exchange(on_done_, nullptr)();
}
virtual Party& party() = 0;
virtual void IncrementRefCount() = 0;
virtual void Unref() = 0;
@ -276,6 +290,11 @@ class CallSpineInterface {
auto& c = cancel_latch();
if (c.is_set()) return absl::nullopt;
c.Set(std::move(metadata));
CallOnDone();
client_initial_metadata().sender.CloseWithError();
server_initial_metadata().sender.CloseWithError();
client_to_server_messages().sender.CloseWithError();
server_to_client_messages().sender.CloseWithError();
return absl::nullopt;
}
@ -325,11 +344,18 @@ class CallSpineInterface {
}
});
}
private:
absl::AnyInvocable<void()> on_done_{nullptr};
};
class CallSpine final : public CallSpineInterface {
class CallSpine final : public CallSpineInterface, public Party {
public:
CallSpine() { Crash("unimplemented"); }
static RefCountedPtr<CallSpine> Create(
grpc_event_engine::experimental::EventEngine* event_engine,
Arena* arena) {
return RefCountedPtr<CallSpine>(arena->New<CallSpine>(event_engine, arena));
}
Pipe<ClientMetadataHandle>& client_initial_metadata() override {
return client_initial_metadata_;
@ -347,23 +373,57 @@ class CallSpine final : public CallSpineInterface {
return server_trailing_metadata_;
}
Latch<ServerMetadataHandle>& cancel_latch() override { return cancel_latch_; }
Party& party() override { Crash("unimplemented"); }
void IncrementRefCount() override { Crash("unimplemented"); }
void Unref() override { Crash("unimplemented"); }
Party& party() override { return *this; }
void IncrementRefCount() override { Party::IncrementRefCount(); }
void Unref() override { Party::Unref(); }
private:
friend class Arena;
CallSpine(grpc_event_engine::experimental::EventEngine* event_engine,
Arena* arena)
: Party(arena, 1), event_engine_(event_engine) {}
class ScopedContext : public ScopedActivity,
public promise_detail::Context<Arena> {
public:
explicit ScopedContext(CallSpine* spine)
: ScopedActivity(&spine->party()), Context<Arena>(spine->arena()) {}
};
bool RunParty() override {
ScopedContext context(this);
return Party::RunParty();
}
void PartyOver() override {
Arena* a = arena();
{
ScopedContext context(this);
CancelRemainingParticipants();
a->DestroyManagedNewObjects();
}
this->~CallSpine();
a->Destroy();
}
grpc_event_engine::experimental::EventEngine* event_engine() const override {
return event_engine_;
}
// Initial metadata from client to server
Pipe<ClientMetadataHandle> client_initial_metadata_;
Pipe<ClientMetadataHandle> client_initial_metadata_{arena()};
// Initial metadata from server to client
Pipe<ServerMetadataHandle> server_initial_metadata_;
Pipe<ServerMetadataHandle> server_initial_metadata_{arena()};
// Messages travelling from the application to the transport.
Pipe<MessageHandle> client_to_server_messages_;
Pipe<MessageHandle> client_to_server_messages_{arena()};
// Messages travelling from the transport to the application.
Pipe<MessageHandle> server_to_client_messages_;
Pipe<MessageHandle> server_to_client_messages_{arena()};
// Trailing metadata from server to client
Pipe<ServerMetadataHandle> server_trailing_metadata_;
Pipe<ServerMetadataHandle> server_trailing_metadata_{arena()};
// Latch that can be set to terminate the call
Latch<ServerMetadataHandle> cancel_latch_;
// Event engine associated with this call
grpc_event_engine::experimental::EventEngine* const event_engine_;
};
class CallInitiator {
@ -405,7 +465,14 @@ class CallInitiator {
auto PushMessage(MessageHandle message) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return spine_->client_to_server_messages().sender.Push(std::move(message));
return Map(
spine_->client_to_server_messages().sender.Push(std::move(message)),
[](bool r) { return StatusFlag(r); });
}
void FinishSends() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
spine_->client_to_server_messages().sender.Close();
}
template <typename Promise>
@ -413,6 +480,12 @@ class CallInitiator {
return spine_->CancelIfFails(std::move(promise));
}
void Cancel() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
std::ignore =
spine_->Cancel(ServerMetadataFromStatus(absl::CancelledError()));
}
template <typename PromiseFactory>
void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
spine_->SpawnGuarded(name, std::move(promise_factory));
@ -428,8 +501,10 @@ class CallInitiator {
return spine_->party().SpawnWaitable(name, std::move(promise_factory));
}
Arena* arena() { return spine_->party().arena(); }
private:
const RefCountedPtr<CallSpineInterface> spine_;
RefCountedPtr<CallSpineInterface> spine_;
};
class CallHandler {
@ -447,14 +522,16 @@ class CallHandler {
});
}
auto PushServerInitialMetadata(ClientMetadataHandle md) {
auto PushServerInitialMetadata(ServerMetadataHandle md) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Map(spine_->server_initial_metadata().sender.Push(std::move(md)),
[](bool ok) { return StatusFlag(ok); });
}
auto PushServerTrailingMetadata(ClientMetadataHandle md) {
auto PushServerTrailingMetadata(ServerMetadataHandle md) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
spine_->server_to_client_messages().sender.Close();
spine_->CallOnDone();
return Map(spine_->server_trailing_metadata().sender.Push(std::move(md)),
[](bool ok) { return StatusFlag(ok); });
}
@ -466,9 +543,18 @@ class CallHandler {
auto PushMessage(MessageHandle message) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return spine_->server_to_client_messages().sender.Push(std::move(message));
return Map(
spine_->server_to_client_messages().sender.Push(std::move(message)),
[](bool ok) { return StatusFlag(ok); });
}
void Cancel(ServerMetadataHandle status) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
std::ignore = spine_->Cancel(std::move(status));
}
void OnDone(absl::AnyInvocable<void()> fn) { spine_->OnDone(std::move(fn)); }
template <typename Promise>
auto CancelIfFails(Promise promise) {
return spine_->CancelIfFails(std::move(promise));
@ -489,8 +575,10 @@ class CallHandler {
return spine_->party().SpawnWaitable(name, std::move(promise_factory));
}
Arena* arena() { return spine_->party().arena(); }
private:
const RefCountedPtr<CallSpineInterface> spine_;
RefCountedPtr<CallSpineInterface> spine_;
};
struct CallInitiatorAndHandler {
@ -498,13 +586,16 @@ struct CallInitiatorAndHandler {
CallHandler handler;
};
CallInitiatorAndHandler MakeCall(
grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena);
template <typename CallHalf>
auto OutgoingMessages(CallHalf& h) {
auto OutgoingMessages(CallHalf h) {
struct Wrapper {
CallHalf& h;
CallHalf h;
auto Next() { return h.PullMessage(); }
};
return Wrapper{h};
return Wrapper{std::move(h)};
}
// Forward a call from `call_handler` to `call_initiator` (with initial metadata
@ -925,14 +1016,24 @@ class ClientTransport {
class ServerTransport {
public:
// AcceptFunction takes initial metadata for a new call and returns a
// CallInitiator object for it, for the transport to use to communicate with
// the CallHandler object passed to the application.
using AcceptFunction =
absl::AnyInvocable<absl::StatusOr<CallInitiator>(ClientMetadata&) const>;
// Acceptor helps transports create calls.
class Acceptor {
public:
// Returns an arena that can be used to allocate memory for initial metadata
// parsing, and later passed to CreateCall() as the underlying arena for
// that call.
virtual Arena* CreateArena() = 0;
// Create a call at the server (or fail)
// arena must have been previously allocated by CreateArena()
virtual absl::StatusOr<CallInitiator> CreateCall(
ClientMetadata& client_initial_metadata, Arena* arena) = 0;
protected:
~Acceptor() = default;
};
// Called once slightly after transport setup to register the accept function.
virtual void SetAcceptFunction(AcceptFunction accept_function) = 0;
virtual void SetAcceptor(Acceptor* acceptor) = 0;
protected:
~ServerTransport() = default;

@ -1,4 +1,6 @@
gens/
build/
grpc_root/
third_party/
*_pb2.py
*_pb2.pyi
*_pb2_grpc.py

@ -24,6 +24,7 @@ import types
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
@ -1054,6 +1055,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1074,6 +1076,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
target: bytes,
request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1082,6 +1085,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def _prepare(
self,
@ -1153,6 +1157,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
),
),
self._context,
self._registered_call_handle,
)
event = call.next_event()
_handle_event(event, state, self._response_deserializer)
@ -1221,6 +1226,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
(operations,),
event_handler,
self._context,
self._registered_call_handle,
)
return _MultiThreadedRendezvous(
state, call, self._response_deserializer, deadline
@ -1234,6 +1240,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1252,6 +1259,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
target: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
_registered_call_handle: Optional[int],
):
self._channel = channel
self._method = method
@ -1259,6 +1267,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( # pylint: disable=too-many-locals
self,
@ -1317,6 +1326,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
call_credentials,
operations_and_tags,
self._context,
self._registered_call_handle,
)
return _SingleThreadedRendezvous(
state, call, self._response_deserializer, deadline
@ -1331,6 +1341,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1351,6 +1362,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
target: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1359,6 +1371,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( # pylint: disable=too-many-locals
self,
@ -1408,6 +1421,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
operations,
_event_handler(state, self._response_deserializer),
self._context,
self._registered_call_handle,
)
return _MultiThreadedRendezvous(
state, call, self._response_deserializer, deadline
@ -1422,6 +1436,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1442,6 +1457,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
target: bytes,
request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1450,6 +1466,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def _blocking(
self,
@ -1482,6 +1499,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
augmented_metadata, initial_metadata_flags
),
self._context,
self._registered_call_handle,
)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer, None
@ -1572,6 +1590,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
),
event_handler,
self._context,
self._registered_call_handle,
)
_consume_request_iterator(
request_iterator,
@ -1593,6 +1612,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1611,8 +1631,9 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
managed_call: IntegratedCallFactory,
method: bytes,
target: bytes,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1621,6 +1642,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__(
self,
@ -1662,6 +1684,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
operations,
event_handler,
self._context,
self._registered_call_handle,
)
_consume_request_iterator(
request_iterator,
@ -1751,7 +1774,8 @@ def _channel_managed_call_management(state: _ChannelCallState):
credentials: Optional[cygrpc.CallCredentials],
operations: Sequence[Sequence[cygrpc.Operation]],
event_handler: UserTag,
context,
context: Any,
_registered_call_handle: Optional[int],
) -> cygrpc.IntegratedCall:
"""Creates a cygrpc.IntegratedCall.
@ -1768,6 +1792,8 @@ def _channel_managed_call_management(state: _ChannelCallState):
event_handler: A behavior to call to handle the events resultant from
the operations on the call.
context: Context object for distributed tracing.
_registered_call_handle: An int representing the call handle of the
method, or None if the method is not registered.
Returns:
A cygrpc.IntegratedCall with which to conduct an RPC.
"""
@ -1788,6 +1814,7 @@ def _channel_managed_call_management(state: _ChannelCallState):
credentials,
operations_and_tags,
context,
_registered_call_handle,
)
if state.managed_calls == 0:
state.managed_calls = 1
@ -2021,6 +2048,7 @@ class Channel(grpc.Channel):
_call_state: _ChannelCallState
_connectivity_state: _ChannelConnectivityState
_target: str
_registered_call_handles: Dict[str, int]
def __init__(
self,
@ -2055,6 +2083,22 @@ class Channel(grpc.Channel):
if cygrpc.g_gevent_activated:
cygrpc.gevent_increment_channel_count()
def _get_registered_call_handle(self, method: str) -> int:
"""
Get the registered call handle for a method.
This is a semi-private method. It is intended for use only by gRPC generated code.
This method is not thread-safe.
Args:
method: Required, the method name for the RPC.
Returns:
The registered call handle pointer in the form of a Python Long.
"""
return self._channel.get_registered_call_handle(_common.encode(method))
def _process_python_options(
self, python_options: Sequence[ChannelArgumentType]
) -> None:
@ -2078,12 +2122,17 @@ class Channel(grpc.Channel):
) -> None:
_unsubscribe(self._connectivity_state, callback)
# pylint: disable=arguments-differ
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryUnaryMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _UnaryUnaryMultiCallable(
self._channel,
_channel_managed_call_management(self._call_state),
@ -2091,14 +2140,20 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
# pylint: disable=arguments-differ
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryStreamMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
# NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC
# on a single Python thread results in an appreciable speed-up. However,
# due to slight differences in capability, the multi-threaded variant
@ -2110,6 +2165,7 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
else:
return _UnaryStreamMultiCallable(
@ -2119,14 +2175,20 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
# pylint: disable=arguments-differ
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamUnaryMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _StreamUnaryMultiCallable(
self._channel,
_channel_managed_call_management(self._call_state),
@ -2134,14 +2196,20 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
# pylint: disable=arguments-differ
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamStreamMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _StreamStreamMultiCallable(
self._channel,
_channel_managed_call_management(self._call_state),
@ -2149,6 +2217,7 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
def _unsubscribe_all(self) -> None:

@ -74,6 +74,13 @@ cdef class SegregatedCall:
cdef class Channel:
cdef _ChannelState _state
cdef dict _registered_call_handles
# TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this.
cdef tuple _arguments
cdef class CallHandle:
cdef void *c_call_handle
cdef object method

@ -101,6 +101,25 @@ cdef class _ChannelState:
self.connectivity_due = set()
self.closed_reason = None
cdef class CallHandle:
def __cinit__(self, _ChannelState channel_state, object method):
self.method = method
cpython.Py_INCREF(method)
# Note that since we always pass None for host, we set the
# second-to-last parameter of grpc_channel_register_call to a fixed
# NULL value.
self.c_call_handle = grpc_channel_register_call(
channel_state.c_channel, <const char *>method, NULL, NULL)
def __dealloc__(self):
cpython.Py_DECREF(self.method)
@property
def call_handle(self):
return cpython.PyLong_FromVoidPtr(self.c_call_handle)
cdef tuple _operate(grpc_call *c_call, object operations, object user_tag):
cdef grpc_call_error c_call_error
@ -199,7 +218,7 @@ cdef void _call(
grpc_completion_queue *c_completion_queue, on_success, int flags, method,
host, object deadline, CallCredentials credentials,
object operationses_and_user_tags, object metadata,
object context) except *:
object context, object registered_call_handle) except *:
"""Invokes an RPC.
Args:
@ -226,6 +245,8 @@ cdef void _call(
must be present in the first element of this value.
metadata: The metadata for this call.
context: Context object for distributed tracing.
registered_call_handle: An int representing the call handle of the method, or
None if the method is not registered.
"""
cdef grpc_slice method_slice
cdef grpc_slice host_slice
@ -242,10 +263,16 @@ cdef void _call(
else:
host_slice = _slice_from_bytes(host)
host_slice_ptr = &host_slice
call_state.c_call = grpc_channel_create_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
if registered_call_handle:
call_state.c_call = grpc_channel_create_registered_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, cpython.PyLong_AsVoidPtr(registered_call_handle),
_timespec_from_time(deadline), NULL)
else:
call_state.c_call = grpc_channel_create_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
grpc_slice_unref(method_slice)
if host_slice_ptr:
grpc_slice_unref(host_slice)
@ -309,7 +336,7 @@ cdef class IntegratedCall:
cdef IntegratedCall _integrated_call(
_ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags,
object context):
object context, object registered_call_handle):
call_state = _CallState()
def on_success(started_tags):
@ -318,7 +345,8 @@ cdef IntegratedCall _integrated_call(
_call(
state, call_state, state.c_call_completion_queue, on_success, flags,
method, host, deadline, credentials, operationses_and_user_tags, metadata, context)
method, host, deadline, credentials, operationses_and_user_tags,
metadata, context, registered_call_handle)
return IntegratedCall(state, call_state)
@ -371,7 +399,7 @@ cdef class SegregatedCall:
cdef SegregatedCall _segregated_call(
_ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags,
object context):
object context, object registered_call_handle):
cdef _CallState call_state = _CallState()
cdef SegregatedCall segregated_call
cdef grpc_completion_queue *c_completion_queue
@ -389,7 +417,7 @@ cdef SegregatedCall _segregated_call(
_call(
state, call_state, c_completion_queue, on_success, flags, method, host,
deadline, credentials, operationses_and_user_tags, metadata,
context)
context, registered_call_handle)
except:
_destroy_c_completion_queue(c_completion_queue)
raise
@ -486,6 +514,7 @@ cdef class Channel:
else grpc_insecure_credentials_create())
self._state.c_channel = grpc_channel_create(
<char *>target, c_channel_credentials, channel_args.c_args())
self._registered_call_handles = {}
grpc_channel_credentials_release(c_channel_credentials)
def target(self):
@ -499,10 +528,10 @@ cdef class Channel:
def integrated_call(
self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags,
object context = None):
object context = None, object registered_call_handle = None):
return _integrated_call(
self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags, context)
operationses_and_tags, context, registered_call_handle)
def next_call_event(self):
def on_success(tag):
@ -521,10 +550,10 @@ cdef class Channel:
def segregated_call(
self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags,
object context = None):
object context = None, object registered_call_handle = None):
return _segregated_call(
self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags, context)
operationses_and_tags, context, registered_call_handle)
def check_connectivity_state(self, bint try_to_connect):
with self._state.condition:
@ -543,3 +572,19 @@ cdef class Channel:
def close_on_fork(self, code, details):
_close(self, code, details, True)
def get_registered_call_handle(self, method):
"""
Get or registers a call handler for a method.
This method is not thread-safe.
Args:
method: Required, the method name for the RPC.
Returns:
The registered call handle pointer in the form of a Python Long.
"""
if method not in self._registered_call_handles.keys():
self._registered_call_handles[method] = CallHandle(self._state, method)
return self._registered_call_handles[method].call_handle

@ -433,6 +433,12 @@ cdef extern from "grpc/grpc.h":
grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask,
grpc_completion_queue *completion_queue, grpc_slice method,
const grpc_slice *host, gpr_timespec deadline, void *reserved) nogil
void *grpc_channel_register_call(
grpc_channel *channel, const char *method, const char *host, void *reserved) nogil
grpc_call *grpc_channel_create_registered_call(
grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask,
grpc_completion_queue *completion_queue, void* registered_call_handle,
gpr_timespec deadline, void *reserved) nogil
grpc_connectivity_state grpc_channel_check_connectivity_state(
grpc_channel *channel, int try_to_connect) nogil
void grpc_channel_watch_connectivity_state(

@ -684,57 +684,85 @@ class _Channel(grpc.Channel):
def unsubscribe(self, callback: Callable):
self._channel.unsubscribe(callback)
# pylint: disable=arguments-differ
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryUnaryMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.unary_unary(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
# pylint: disable=arguments-differ
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryStreamMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.unary_stream(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
# pylint: disable=arguments-differ
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamUnaryMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.stream_unary(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
# pylint: disable=arguments-differ
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamStreamMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.stream_stream(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
return _StreamStreamMultiCallable(thunk, method, self._interceptor)
else:

@ -159,7 +159,19 @@ class ChannelCache:
channel_credentials: Optional[grpc.ChannelCredentials],
insecure: bool,
compression: Optional[grpc.Compression],
) -> grpc.Channel:
method: str,
_registered_method: bool,
) -> Tuple[grpc.Channel, Optional[int]]:
"""Get a channel from cache or creates a new channel.
This method also takes care of register method for channel,
which means we'll register a new call handle if we're calling a
non-registered method for an existing channel.
Returns:
A tuple with two items. The first item is the channel, second item is
the call handle if the method is registered, None if it's not registered.
"""
if insecure and channel_credentials:
raise ValueError(
"The insecure option is mutually exclusive with "
@ -176,18 +188,25 @@ class ChannelCache:
key = (target, options, channel_credentials, compression)
with self._lock:
channel_data = self._mapping.get(key, None)
call_handle = None
if channel_data is not None:
channel = channel_data[0]
# Register a new call handle if we're calling a registered method for an
# existing channel and this method is not registered.
if _registered_method:
call_handle = channel._get_registered_call_handle(method)
self._mapping.pop(key)
self._mapping[key] = (
channel,
datetime.datetime.now() + _EVICTION_PERIOD,
)
return channel
return channel, call_handle
else:
channel = _create_channel(
target, options, channel_credentials, compression
)
if _registered_method:
call_handle = channel._get_registered_call_handle(method)
self._mapping[key] = (
channel,
datetime.datetime.now() + _EVICTION_PERIOD,
@ -197,7 +216,7 @@ class ChannelCache:
or len(self._mapping) >= _MAXIMUM_CHANNELS
):
self._condition.notify()
return channel
return channel, call_handle
def _test_only_channel_count(self) -> int:
with self._lock:
@ -205,6 +224,7 @@ class ChannelCache:
@experimental_api
# pylint: disable=too-many-locals
def unary_unary(
request: RequestType,
target: str,
@ -219,6 +239,7 @@ def unary_unary(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> ResponseType:
"""Invokes a unary-unary RPC without an explicitly specified channel.
@ -272,11 +293,17 @@ def unary_unary(
Returns:
The response to the RPC.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.unary_unary(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(
@ -289,6 +316,7 @@ def unary_unary(
@experimental_api
# pylint: disable=too-many-locals
def unary_stream(
request: RequestType,
target: str,
@ -303,6 +331,7 @@ def unary_stream(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> Iterator[ResponseType]:
"""Invokes a unary-stream RPC without an explicitly specified channel.
@ -355,11 +384,17 @@ def unary_stream(
Returns:
An iterator of responses.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.unary_stream(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(
@ -372,6 +407,7 @@ def unary_stream(
@experimental_api
# pylint: disable=too-many-locals
def stream_unary(
request_iterator: Iterator[RequestType],
target: str,
@ -386,6 +422,7 @@ def stream_unary(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> ResponseType:
"""Invokes a stream-unary RPC without an explicitly specified channel.
@ -438,11 +475,17 @@ def stream_unary(
Returns:
The response to the RPC.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.stream_unary(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(
@ -455,6 +498,7 @@ def stream_unary(
@experimental_api
# pylint: disable=too-many-locals
def stream_stream(
request_iterator: Iterator[RequestType],
target: str,
@ -469,6 +513,7 @@ def stream_stream(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> Iterator[ResponseType]:
"""Invokes a stream-stream RPC without an explicitly specified channel.
@ -521,11 +566,17 @@ def stream_stream(
Returns:
An iterator of responses.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.stream_stream(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(

@ -478,11 +478,20 @@ class Channel(_base_channel.Channel):
await self.wait_for_state_change(state)
state = self.get_state(try_to_connect=True)
# TODO(xuanwn): Implement this method after we have
# observability for Asyncio.
def _get_registered_call_handle(self, method: str) -> int:
pass
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryUnaryMultiCallable:
return UnaryUnaryMultiCallable(
self._channel,
@ -494,11 +503,15 @@ class Channel(_base_channel.Channel):
self._loop,
)
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(
self._channel,
@ -510,11 +523,15 @@ class Channel(_base_channel.Channel):
self._loop,
)
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(
self._channel,
@ -526,11 +543,15 @@ class Channel(_base_channel.Channel):
self._loop,
)
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(
self._channel,

@ -31,23 +31,42 @@ class TestingChannel(grpc_testing.Channel):
def unsubscribe(self, callback):
raise NotImplementedError()
def _get_registered_call_handle(self, method: str) -> int:
pass
def unary_unary(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.UnaryUnary(method, self._state)
def unary_stream(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.UnaryStream(method, self._state)
def stream_unary(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.StreamUnary(method, self._state)
def stream_stream(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.StreamStream(method, self._state)

@ -99,16 +99,20 @@ class ChannelzServicerTest(unittest.TestCase):
def _send_successful_unary_unary(self, idx):
_, r = (
self._pairs[idx]
.channel.unary_unary(_SUCCESSFUL_UNARY_UNARY)
.channel.unary_unary(
_SUCCESSFUL_UNARY_UNARY,
_registered_method=True,
)
.with_call(_REQUEST)
)
self.assertEqual(r.code(), grpc.StatusCode.OK)
def _send_failed_unary_unary(self, idx):
try:
self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call(
_REQUEST
)
self._pairs[idx].channel.unary_unary(
_FAILED_UNARY_UNARY,
_registered_method=True,
).with_call(_REQUEST)
except grpc.RpcError:
return
else:
@ -117,7 +121,10 @@ class ChannelzServicerTest(unittest.TestCase):
def _send_successful_stream_stream(self, idx):
response_iterator = (
self._pairs[idx]
.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)
.channel.stream_stream(
_SUCCESSFUL_STREAM_STREAM,
_registered_method=True,
)
.__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH))
)
cnt = 0

@ -92,7 +92,10 @@ class TestCsds(unittest.TestCase):
# Force the XdsClient to initialize and request a resource
with self.assertRaises(grpc.RpcError) as rpc_error:
dummy_channel.unary_unary("")(b"", wait_for_ready=False, timeout=1)
dummy_channel.unary_unary(
"",
_registered_method=True,
)(b"", wait_for_ready=False, timeout=1)
self.assertEqual(
grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code()
)

@ -543,6 +543,17 @@ class PythonPluginTest(unittest.TestCase):
)
service.server.stop(None)
def testRegisteredMethod(self):
"""Tests that we're setting _registered_call_handle when create call using generated stub."""
service = _CreateService()
self.assertTrue(service.stub.UnaryCall._registered_call_handle)
self.assertTrue(
service.stub.StreamingOutputCall._registered_call_handle
)
self.assertTrue(service.stub.StreamingInputCall._registered_call_handle)
self.assertTrue(service.stub.FullDuplexCall._registered_call_handle)
service.server.stop(None)
@unittest.skipIf(
sys.version_info[0] < 3 or sys.version_info[1] < 6,

@ -32,13 +32,16 @@ _TIMEOUT = 60 * 60 * 24
class GenericStub(object):
def __init__(self, channel):
self.UnaryCall = channel.unary_unary(
"/grpc.testing.BenchmarkService/UnaryCall"
"/grpc.testing.BenchmarkService/UnaryCall",
_registered_method=True,
)
self.StreamingFromServer = channel.unary_stream(
"/grpc.testing.BenchmarkService/StreamingFromServer"
"/grpc.testing.BenchmarkService/StreamingFromServer",
_registered_method=True,
)
self.StreamingCall = channel.stream_stream(
"/grpc.testing.BenchmarkService/StreamingCall"
"/grpc.testing.BenchmarkService/StreamingCall",
_registered_method=True,
)

@ -138,7 +138,10 @@ class StatusTest(unittest.TestCase):
self._channel.close()
def test_status_ok(self):
_, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST)
_, call = self._channel.unary_unary(
_STATUS_OK,
_registered_method=True,
).with_call(_REQUEST)
# Succeed RPC doesn't have status
status = rpc_status.from_call(call)
@ -146,7 +149,10 @@ class StatusTest(unittest.TestCase):
def test_status_not_ok(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST)
self._channel.unary_unary(
_STATUS_NOT_OK,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -156,7 +162,10 @@ class StatusTest(unittest.TestCase):
def test_error_details(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST)
self._channel.unary_unary(
_ERROR_DETAILS,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
status = rpc_status.from_call(rpc_error)
@ -173,7 +182,10 @@ class StatusTest(unittest.TestCase):
def test_code_message_validation(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST)
self._channel.unary_unary(
_INCONSISTENT,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND)
@ -182,7 +194,10 @@ class StatusTest(unittest.TestCase):
def test_invalid_code(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST)
self._channel.unary_unary(
_INVALID_CODE,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)
# Invalid status code exception raised during coversion

@ -107,7 +107,10 @@ class AbortTest(unittest.TestCase):
def test_abort(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT)(_REQUEST)
self._channel.unary_unary(
_ABORT,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -124,7 +127,10 @@ class AbortTest(unittest.TestCase):
# Servicer will abort() after creating a local ref to do_not_leak_me.
with self.assertRaises(grpc.RpcError):
self._channel.unary_unary(_ABORT)(_REQUEST)
self._channel.unary_unary(
_ABORT,
_registered_method=True,
)(_REQUEST)
# Server may still have a stack frame reference to the exception even
# after client sees error, so ensure server has shutdown.
@ -134,7 +140,10 @@ class AbortTest(unittest.TestCase):
def test_abort_with_status(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST)
self._channel.unary_unary(
_ABORT_WITH_STATUS,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -143,7 +152,10 @@ class AbortTest(unittest.TestCase):
def test_invalid_code(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INVALID_CODE)(_REQUEST)
self._channel.unary_unary(
_INVALID_CODE,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)

@ -78,7 +78,10 @@ class AuthContextTest(unittest.TestCase):
server.start()
with grpc.insecure_channel("localhost:%d" % port) as channel:
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
server.stop(None)
auth_data = pickle.loads(response)
@ -115,7 +118,10 @@ class AuthContextTest(unittest.TestCase):
channel_creds,
options=_PROPERTY_OPTIONS,
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
channel.close()
server.stop(None)
@ -161,7 +167,10 @@ class AuthContextTest(unittest.TestCase):
options=_PROPERTY_OPTIONS,
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
channel.close()
server.stop(None)
@ -180,7 +189,10 @@ class AuthContextTest(unittest.TestCase):
channel = grpc.secure_channel(
"localhost:{}".format(port), channel_creds, options=channel_options
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
auth_data = pickle.loads(response)
self.assertEqual(
expect_ssl_session_reused,

@ -123,7 +123,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_immediately_after_call_invocation(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe(())
response_iterator = multi_callable(request_iterator)
channel.close()
@ -133,7 +136,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_while_call_active(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
@ -146,7 +152,10 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel(
"localhost:{}".format(self._port)
) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
@ -158,7 +167,10 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel(
"localhost:{}".format(self._port)
) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterators = tuple(
_Pipe((b"abc",))
for _ in range(test_constants.THREAD_CONCURRENCY)
@ -176,7 +188,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_many_concurrent_closes(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
@ -203,10 +218,16 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel(
"localhost:{}".format(self._port)
) as channel:
stream_multi_callable = channel.stream_stream(_STREAM_URI)
stream_multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
endless_iterator = itertools.repeat(b"abc")
stream_response_iterator = stream_multi_callable(endless_iterator)
future = channel.unary_unary(_UNARY_URI).future(b"abc")
future = channel.unary_unary(
_UNARY_URI,
_registered_method=True,
).future(b"abc")
def on_done_callback(future):
raise Exception("This should not cause a deadlock.")

@ -237,7 +237,10 @@ def _get_compression_ratios(
def _unary_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_unary(_UNARY_UNARY)
multi_callable = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
response = multi_callable(message, **multicallable_kwargs)
if response != message:
raise RuntimeError(
@ -246,7 +249,10 @@ def _unary_unary_client(channel, multicallable_kwargs, message):
def _unary_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_stream(_UNARY_STREAM)
multi_callable = channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)
response_iterator = multi_callable(message, **multicallable_kwargs)
for response in response_iterator:
if response != message:
@ -256,7 +262,10 @@ def _unary_stream_client(channel, multicallable_kwargs, message):
def _stream_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_unary(_STREAM_UNARY)
multi_callable = channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)
requests = (_REQUEST for _ in range(_STREAM_LENGTH))
response = multi_callable(requests, **multicallable_kwargs)
if response != message:
@ -266,7 +275,10 @@ def _stream_unary_client(channel, multicallable_kwargs, message):
def _stream_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_stream(_STREAM_STREAM)
multi_callable = channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
request_prefix = str(0).encode("ascii") * 100
requests = (
request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH)

@ -116,7 +116,10 @@ class ContextVarsPropagationTest(unittest.TestCase):
local_credentials, call_credentials
)
with grpc.secure_channel(target, composite_credentials) as channel:
stub = channel.unary_unary(_UNARY_UNARY)
stub = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
response = stub(_REQUEST, wait_for_ready=True)
self.assertEqual(_REQUEST, response)
@ -142,7 +145,10 @@ class ContextVarsPropagationTest(unittest.TestCase):
with grpc.secure_channel(
target, composite_credentials
) as channel:
stub = channel.unary_unary(_UNARY_UNARY)
stub = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
wait_group.done()
wait_group.wait()
for i in range(_RPC_COUNT):

@ -55,7 +55,10 @@ class DNSResolverTest(unittest.TestCase):
"loopback46.unittest.grpc.io:%d" % self._port
) as channel:
self.assertEqual(
channel.unary_unary(_METHOD)(
channel.unary_unary(
_METHOD,
_registered_method=True,
)(
_REQUEST,
timeout=10,
),

@ -96,25 +96,33 @@ class EmptyMessageTest(unittest.TestCase):
self._channel.close()
def testUnaryUnary(self):
response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
self.assertEqual(_RESPONSE, response)
def testUnaryStream(self):
response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
response_iterator = self._channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)(_REQUEST)
self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
)
def testStreamUnary(self):
response = self._channel.stream_unary(_STREAM_UNARY)(
iter([_REQUEST] * test_constants.STREAM_LENGTH)
)
response = self._channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)(iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertEqual(_RESPONSE, response)
def testStreamStream(self):
response_iterator = self._channel.stream_stream(_STREAM_STREAM)(
iter([_REQUEST] * test_constants.STREAM_LENGTH)
)
response_iterator = self._channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)(iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
)

@ -73,7 +73,10 @@ class ErrorMessageEncodingTest(unittest.TestCase):
def testMessageEncoding(self):
for message in _UNICODE_ERROR_MESSAGES:
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
multi_callable = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
with self.assertRaises(grpc.RpcError) as cm:
multi_callable(message.encode("utf-8"))

@ -210,14 +210,20 @@ if __name__ == "__main__":
method = TEST_TO_METHOD[args.scenario]
if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
multi_callable = channel.unary_unary(method)
multi_callable = channel.unary_unary(
method,
_registered_method=True,
)
future = multi_callable.future(REQUEST)
result, call = multi_callable.with_call(REQUEST)
elif (
args.scenario == IN_FLIGHT_UNARY_STREAM_CALL
or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL
):
multi_callable = channel.unary_stream(method)
multi_callable = channel.unary_stream(
method,
_registered_method=True,
)
response_iterator = multi_callable(REQUEST)
for response in response_iterator:
pass
@ -225,7 +231,10 @@ if __name__ == "__main__":
args.scenario == IN_FLIGHT_STREAM_UNARY_CALL
or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL
):
multi_callable = channel.stream_unary(method)
multi_callable = channel.stream_unary(
method,
_registered_method=True,
)
future = multi_callable.future(infinite_request_iterator())
result, call = multi_callable.with_call(
iter([REQUEST] * test_constants.STREAM_LENGTH)
@ -234,7 +243,10 @@ if __name__ == "__main__":
args.scenario == IN_FLIGHT_STREAM_STREAM_CALL
or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL
):
multi_callable = channel.stream_stream(method)
multi_callable = channel.stream_stream(
method,
_registered_method=True,
)
response_iterator = multi_callable(infinite_request_iterator())
for response in response_iterator:
pass

@ -231,7 +231,7 @@ class _GenericHandler(grpc.GenericRpcHandler):
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(_UNARY_UNARY, _registered_method=True)
def _unary_stream_multi_callable(channel):
@ -239,6 +239,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -247,11 +248,12 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(_STREAM_STREAM, _registered_method=True)
class _ClientCallDetails(
@ -562,7 +564,7 @@ class InterceptorTest(unittest.TestCase):
self._record[:] = []
multi_callable = _unary_unary_multi_callable(channel)
multi_callable.with_call(
response, call = multi_callable.with_call(
request,
metadata=(
(

@ -32,7 +32,7 @@ _STREAM_STREAM = "/test/StreamStream"
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(_UNARY_UNARY, _registered_method=True)
def _unary_stream_multi_callable(channel):
@ -40,6 +40,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -48,11 +49,15 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
class InvalidMetadataTest(unittest.TestCase):

@ -219,7 +219,10 @@ class FailAfterFewIterationsCounter(object):
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
def _unary_stream_multi_callable(channel):
@ -227,6 +230,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -235,19 +239,29 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
def _defective_handler_multi_callable(channel):
return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER)
return channel.unary_unary(
_DEFECTIVE_GENERIC_RPC_HANDLER,
_registered_method=True,
)
def _defective_nested_exception_handler_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY_NESTED_EXCEPTION)
return channel.unary_unary(
_UNARY_UNARY_NESTED_EXCEPTION,
_registered_method=True,
)
class InvocationDefectsTest(unittest.TestCase):

@ -53,9 +53,10 @@ class LocalCredentialsTest(unittest.TestCase):
) as channel:
self.assertEqual(
b"abc",
channel.unary_unary("/test/method")(
b"abc", wait_for_ready=True
),
channel.unary_unary(
"/test/method",
_registered_method=True,
)(b"abc", wait_for_ready=True),
)
server.stop(None)
@ -77,9 +78,10 @@ class LocalCredentialsTest(unittest.TestCase):
with grpc.secure_channel(server_addr, channel_creds) as channel:
self.assertEqual(
b"abc",
channel.unary_unary("/test/method")(
b"abc", wait_for_ready=True
),
channel.unary_unary(
"/test/method",
_registered_method=True,
)(b"abc", wait_for_ready=True),
)
server.stop(None)

@ -207,45 +207,53 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._server.start()
self._channel = grpc.insecure_channel("localhost:{}".format(port))
unary_unary_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
)
self._unary_unary = self._channel.unary_unary(
"/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
),
unary_unary_method_name,
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
unary_stream_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_STREAM,
)
)
self._unary_stream = self._channel.unary_stream(
"/".join(
(
"",
_SERVICE,
_UNARY_STREAM,
)
),
unary_stream_method_name,
_registered_method=True,
)
stream_unary_method_name = "/".join(
(
"",
_SERVICE,
_STREAM_UNARY,
)
)
self._stream_unary = self._channel.stream_unary(
"/".join(
(
"",
_SERVICE,
_STREAM_UNARY,
)
),
stream_unary_method_name,
_registered_method=True,
)
stream_stream_method_name = "/".join(
(
"",
_SERVICE,
_STREAM_STREAM,
)
)
self._stream_stream = self._channel.stream_stream(
"/".join(
(
"",
_SERVICE,
_STREAM_STREAM,
)
),
stream_stream_method_name,
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
def tearDown(self):
@ -828,16 +836,18 @@ class InspectContextTest(unittest.TestCase):
self._server.start()
self._channel = grpc.insecure_channel("localhost:{}".format(port))
unary_unary_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
)
self._unary_unary = self._channel.unary_unary(
"/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
),
unary_unary_method_name,
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
def tearDown(self):

@ -110,7 +110,10 @@ def create_phony_channel():
def perform_unary_unary_call(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).__call__(
channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).__call__(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -118,7 +121,10 @@ def perform_unary_unary_call(channel, wait_for_ready=None):
def perform_unary_unary_with_call(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).with_call(
channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).with_call(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -126,7 +132,10 @@ def perform_unary_unary_with_call(channel, wait_for_ready=None):
def perform_unary_unary_future(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).future(
channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).future(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -134,7 +143,10 @@ def perform_unary_unary_future(channel, wait_for_ready=None):
def perform_unary_stream_call(channel, wait_for_ready=None):
response_iterator = channel.unary_stream(_UNARY_STREAM).__call__(
response_iterator = channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
).__call__(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -144,7 +156,10 @@ def perform_unary_stream_call(channel, wait_for_ready=None):
def perform_stream_unary_call(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).__call__(
channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).__call__(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -152,7 +167,10 @@ def perform_stream_unary_call(channel, wait_for_ready=None):
def perform_stream_unary_with_call(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).with_call(
channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -160,7 +178,10 @@ def perform_stream_unary_with_call(channel, wait_for_ready=None):
def perform_stream_unary_future(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).future(
channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).future(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -168,7 +189,9 @@ def perform_stream_unary_future(channel, wait_for_ready=None):
def perform_stream_stream_call(channel, wait_for_ready=None):
response_iterator = channel.stream_stream(_STREAM_STREAM).__call__(
response_iterator = channel.stream_stream(
_STREAM_STREAM, _registered_method=True
).__call__(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,

@ -195,7 +195,9 @@ class MetadataTest(unittest.TestCase):
self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
multi_callable = self._channel.unary_unary(
_UNARY_UNARY, _registered_method=True
)
unused_response, call = multi_callable.with_call(
_REQUEST, metadata=_INVOCATION_METADATA
)
@ -211,7 +213,9 @@ class MetadataTest(unittest.TestCase):
)
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
multi_callable = self._channel.unary_stream(
_UNARY_STREAM, _registered_method=True
)
call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
@ -227,7 +231,9 @@ class MetadataTest(unittest.TestCase):
)
def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY)
multi_callable = self._channel.stream_unary(
_STREAM_UNARY, _registered_method=True
)
unused_response, call = multi_callable.with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_INVOCATION_METADATA,
@ -244,7 +250,9 @@ class MetadataTest(unittest.TestCase):
)
def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM)
multi_callable = self._channel.stream_stream(
_STREAM_STREAM, _registered_method=True
)
call = multi_callable(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_INVOCATION_METADATA,

@ -52,7 +52,10 @@ class ReconnectTest(unittest.TestCase):
server.add_insecure_port(addr)
server.start()
channel = grpc.insecure_channel(addr)
multi_callable = channel.unary_unary(_UNARY_UNARY)
multi_callable = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
server.stop(None)
# By default, the channel connectivity is checked every 5s

@ -149,7 +149,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
multi_callable = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
futures = []
for _ in range(test_constants.THREAD_CONCURRENCY):
futures.append(multi_callable.future(_REQUEST))
@ -178,7 +181,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
multi_callable = self._channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)
calls = []
for _ in range(test_constants.THREAD_CONCURRENCY):
calls.append(multi_callable(_REQUEST))
@ -205,7 +211,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, response)
def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY)
multi_callable = self._channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)
futures = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY):
@ -236,7 +245,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, multi_callable(request))
def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM)
multi_callable = self._channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
calls = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY):

@ -277,7 +277,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
def unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
def unary_stream_multi_callable(channel):
@ -285,6 +288,7 @@ def unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -293,6 +297,7 @@ def unary_stream_non_blocking_multi_callable(channel):
_UNARY_STREAM_NON_BLOCKING,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -301,15 +306,22 @@ def stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
def stream_stream_non_blocking_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING)
return channel.stream_stream(
_STREAM_STREAM_NON_BLOCKING,
_registered_method=True,
)
class BaseRPCTest(object):

@ -81,7 +81,10 @@ def run_test(args):
thread.start()
port = port_queue.get()
channel = grpc.insecure_channel("localhost:%d" % port)
multi_callable = channel.unary_unary(FORK_EXIT)
multi_callable = channel.unary_unary(
FORK_EXIT,
_registered_method=True,
)
result, call = multi_callable.with_call(REQUEST, wait_for_ready=True)
os.wait()
else:

@ -77,7 +77,10 @@ class SSLSessionCacheTest(unittest.TestCase):
channel = grpc.secure_channel(
"localhost:{}".format(port), channel_creds, options=channel_options
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
auth_data = pickle.loads(response)
self.assertEqual(
expect_ssl_session_reused,

@ -53,7 +53,10 @@ def main_unary(server_target):
"""Initiate a unary RPC to be interrupted by a SIGINT."""
global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(UNARY_UNARY)
multicallable = channel.unary_unary(
UNARY_UNARY,
_registered_method=True,
)
signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = multicallable.future(
_MESSAGE, wait_for_ready=True
@ -67,9 +70,10 @@ def main_streaming(server_target):
global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel:
signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)(
_MESSAGE, wait_for_ready=True
)
per_process_rpc_future = channel.unary_stream(
UNARY_STREAM,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True)
for result in per_process_rpc_future:
pass
assert False, _ASSERTION_MESSAGE
@ -79,7 +83,10 @@ def main_unary_with_exception(server_target):
"""Initiate a unary RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target)
try:
channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True)
channel.unary_unary(
UNARY_UNARY,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True)
except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n")
sys.stderr.flush()
@ -92,9 +99,10 @@ def main_streaming_with_exception(server_target):
"""Initiate a streaming RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target)
try:
for _ in channel.unary_stream(UNARY_STREAM)(
_MESSAGE, wait_for_ready=True
):
for _ in channel.unary_stream(
UNARY_STREAM,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True):
pass
except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n")

@ -71,9 +71,10 @@ class XdsCredentialsTest(unittest.TestCase):
server_address, channel_creds, options=override_options
) as channel:
request = b"abc"
response = channel.unary_unary("/test/method")(
request, wait_for_ready=True
)
response = channel.unary_unary(
"/test/method",
_registered_method=True,
)(request, wait_for_ready=True)
self.assertEqual(response, request)
def test_xds_creds_fallback_insecure(self):
@ -89,9 +90,10 @@ class XdsCredentialsTest(unittest.TestCase):
channel_creds = grpc.xds_channel_credentials(channel_fallback_creds)
with grpc.secure_channel(server_address, channel_creds) as channel:
request = b"abc"
response = channel.unary_unary("/test/method")(
request, wait_for_ready=True
)
response = channel.unary_unary(
"/test/method",
_registered_method=True,
)(request, wait_for_ready=True)
self.assertEqual(response, request)
def test_start_xds_server(self):

@ -65,6 +65,7 @@ class CloseChannelTest(unittest.TestCase):
_UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
_registered_method=True,
)
greenlet = group.spawn(self._run_client, UnaryCallWithSleep)
# release loop so that greenlet can take control
@ -78,6 +79,7 @@ class CloseChannelTest(unittest.TestCase):
_UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
_registered_method=True,
)
greenlet = group.spawn(self._run_client, UnaryCallWithSleep)
# release loop so that greenlet can take control

@ -68,7 +68,10 @@ def _start_a_test_server():
def _perform_an_rpc(address):
channel = grpc.insecure_channel(address)
multicallable = channel.unary_unary(_TEST_METHOD)
multicallable = channel.unary_unary(
_TEST_METHOD,
_registered_method=True,
)
response = multicallable(_REQUEST)
assert _REQUEST == response

@ -193,6 +193,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
channel_credentials=grpc.experimental.insecure_channel_credentials(),
timeout=None,
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -205,6 +206,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials(),
timeout=None,
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -213,7 +215,10 @@ class SimpleStubsTest(unittest.TestCase):
target = f"localhost:{port}"
test_name = inspect.stack()[0][3]
args = (_REQUEST, target, _UNARY_UNARY)
kwargs = {"channel_credentials": grpc.local_channel_credentials()}
kwargs = {
"channel_credentials": grpc.local_channel_credentials(),
"_registered_method": True,
}
def _invoke(seed: str):
run_kwargs = dict(kwargs)
@ -230,6 +235,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
self.assert_eventually(
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
@ -250,6 +256,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
options=options,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
self.assert_eventually(
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
@ -265,6 +272,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_UNARY_STREAM,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
):
self.assertEqual(_REQUEST, response)
@ -280,6 +288,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_STREAM_UNARY,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -295,6 +304,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_STREAM_STREAM,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
):
self.assertEqual(_REQUEST, response)
@ -319,14 +329,22 @@ class SimpleStubsTest(unittest.TestCase):
with _server(server_creds) as port:
target = f"localhost:{port}"
response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, options=_property_options
_REQUEST,
target,
_UNARY_UNARY,
options=_property_options,
_registered_method=0,
)
def test_insecure_sugar(self):
with _server(None) as port:
target = f"localhost:{port}"
response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, insecure=True
_REQUEST,
target,
_UNARY_UNARY,
insecure=True,
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -340,14 +358,24 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
insecure=True,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
def test_default_wait_for_ready(self):
addr, port, sock = get_socket()
sock.close()
target = f"{addr}:{port}"
channel = grpc._simple_stubs.ChannelCache.get().get_channel(
target, (), None, True, None
(
channel,
unused_method_handle,
) = grpc._simple_stubs.ChannelCache.get().get_channel(
target=target,
options=(),
channel_credentials=None,
insecure=True,
compression=None,
method=_UNARY_UNARY,
_registered_method=True,
)
rpc_finished_event = threading.Event()
rpc_failed_event = threading.Event()
@ -376,7 +404,12 @@ class SimpleStubsTest(unittest.TestCase):
def _send_rpc():
try:
response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, timeout=None, insecure=True
_REQUEST,
target,
_UNARY_UNARY,
timeout=None,
insecure=True,
_registered_method=0,
)
rpc_finished_event.set()
except Exception as e:
@ -399,6 +432,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_BLACK_HOLE,
insecure=True,
_registered_method=0,
**invocation_args,
)
self.assertEqual(

@ -29,6 +29,8 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
@ -407,6 +409,21 @@ void ApiFuzzer::Tick() {
}
}
namespace {
// If there are more than 1K comma-delimited strings in target, remove
// the extra ones.
std::string SanitizeTargetUri(absl::string_view target) {
constexpr size_t kMaxCommaDelimitedStrings = 1000;
std::vector<absl::string_view> parts = absl::StrSplit(target, ',');
if (parts.size() > kMaxCommaDelimitedStrings) {
parts.resize(kMaxCommaDelimitedStrings);
}
return absl::StrJoin(parts, ",");
}
} // namespace
ApiFuzzer::Result ApiFuzzer::CreateChannel(
const api_fuzzer::CreateChannel& create_channel) {
if (channel_ != nullptr) return Result::kComplete;
@ -423,8 +440,9 @@ ApiFuzzer::Result ApiFuzzer::CreateChannel(
create_channel.has_channel_creds()
? ReadChannelCreds(create_channel.channel_creds())
: grpc_insecure_credentials_create();
channel_ = grpc_channel_create(create_channel.target().c_str(), creds,
args.ToC().get());
channel_ =
grpc_channel_create(SanitizeTargetUri(create_channel.target()).c_str(),
creds, args.ToC().get());
grpc_channel_credentials_release(creds);
}
GPR_ASSERT(channel_ != nullptr);

@ -95,6 +95,8 @@ TEST(MpscTest, SendingLotsOfThingsGivesPushback) {
EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true);
EXPECT_EQ(NowOrNever(sender.Send(MakePayload(2))), absl::nullopt);
activity1.Deactivate();
EXPECT_CALL(activity1, WakeupRequested());
}
TEST(MpscTest, ReceivingAfterBlockageWakesUp) {

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//bazel:grpc_build_system.bzl", "grpc_cc_test", "grpc_package")
load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer")
load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package")
load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer", "grpc_proto_fuzzer")
licenses(["notice"])
@ -22,6 +22,34 @@ grpc_package(
visibility = "tests",
)
grpc_cc_library(
name = "mock_promise_endpoint",
testonly = 1,
srcs = ["mock_promise_endpoint.cc"],
hdrs = ["mock_promise_endpoint.h"],
external_deps = ["gtest"],
deps = [
"//:grpc",
"//src/core:grpc_promise_endpoint",
],
)
grpc_cc_library(
name = "transport_test",
testonly = 1,
srcs = ["transport_test.cc"],
hdrs = ["transport_test.h"],
external_deps = ["gtest"],
deps = [
"//:iomgr_timer",
"//src/core:chaotic_good_frame",
"//src/core:memory_quota",
"//src/core:resource_quota",
"//test/core/event_engine/fuzzing_event_engine",
"//test/core/event_engine/fuzzing_event_engine:fuzzing_event_engine_proto",
],
)
grpc_cc_test(
name = "frame_header_test",
srcs = ["frame_header_test.cc"],
@ -54,7 +82,7 @@ grpc_cc_test(
deps = ["//src/core:chaotic_good_frame"],
)
grpc_fuzzer(
grpc_proto_fuzzer(
name = "frame_fuzzer",
srcs = ["frame_fuzzer.cc"],
corpus = "frame_fuzzer_corpus",
@ -63,7 +91,10 @@ grpc_fuzzer(
"absl/status:statusor",
],
language = "C++",
proto = "frame_fuzzer.proto",
tags = ["no_windows"],
uses_event_engine = False,
uses_polling = False,
deps = [
"//:exec_ctx",
"//:gpr",
@ -96,18 +127,46 @@ grpc_cc_test(
uses_event_engine = False,
uses_polling = False,
deps = [
"mock_promise_endpoint",
"transport_test",
"//:grpc",
"//:grpc_public_hdrs",
"//src/core:arena",
"//src/core:chaotic_good_client_transport",
"//src/core:if",
"//src/core:loop",
"//src/core:seq",
"//src/core:slice_buffer",
],
)
grpc_cc_test(
name = "client_transport_error_test",
srcs = ["client_transport_error_test.cc"],
external_deps = [
"absl/functional:any_invocable",
"absl/status",
"absl/status:statusor",
"absl/strings:str_format",
"absl/types:optional",
"gtest",
],
language = "C++",
uses_event_engine = False,
uses_polling = False,
deps = [
"//:grpc_public_hdrs",
"//:grpc_unsecure",
"//:iomgr_timer",
"//:ref_counted_ptr",
"//src/core:activity",
"//src/core:arena",
"//src/core:chaotic_good_client_transport",
"//src/core:event_engine_wakeup_scheduler",
"//src/core:grpc_promise_endpoint",
"//src/core:if",
"//src/core:join",
"//src/core:loop",
"//src/core:map",
"//src/core:memory_quota",
"//src/core:pipe",
"//src/core:resource_quota",
@ -120,8 +179,8 @@ grpc_cc_test(
)
grpc_cc_test(
name = "client_transport_error_test",
srcs = ["client_transport_error_test.cc"],
name = "server_transport_test",
srcs = ["server_transport_test.cc"],
external_deps = [
"absl/functional:any_invocable",
"absl/status",
@ -134,20 +193,15 @@ grpc_cc_test(
uses_event_engine = False,
uses_polling = False,
deps = [
"mock_promise_endpoint",
"transport_test",
"//:grpc",
"//:grpc_public_hdrs",
"//:iomgr_timer",
"//:ref_counted_ptr",
"//src/core:activity",
"//src/core:arena",
"//src/core:chaotic_good_client_transport",
"//src/core:event_engine_wakeup_scheduler",
"//src/core:grpc_promise_endpoint",
"//src/core:if",
"//src/core:join",
"//src/core:loop",
"//src/core:chaotic_good_server_transport",
"//src/core:memory_quota",
"//src/core:pipe",
"//src/core:resource_quota",
"//src/core:seq",
"//src/core:slice",

@ -12,37 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/status/status.h"
#include "src/core/ext/transport/chaotic_good/client_transport.h"
#include "src/core/lib/transport/promise_endpoint.h"
#include "src/core/lib/transport/transport.h"
// IWYU pragma: no_include <sys/socket.h>
#include <stddef.h>
#include <algorithm> // IWYU pragma: keep
#include <algorithm>
#include <memory>
#include <string> // IWYU pragma: keep
#include <string>
#include <tuple>
#include <utility>
#include <vector> // IWYU pragma: keep
#include <vector>
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h" // IWYU pragma: keep
#include "absl/strings/str_format.h" // IWYU pragma: keep
#include "absl/types/optional.h" // IWYU pragma: keep
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/event_engine/slice.h> // IWYU pragma: keep
#include <grpc/event_engine/slice.h>
#include <grpc/event_engine/slice_buffer.h>
#include <grpc/grpc.h>
#include <grpc/status.h> // IWYU pragma: keep
#include <grpc/status.h>
#include "src/core/ext/transport/chaotic_good/client_transport.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/timer_manager.h"
#include "src/core/lib/promise/activity.h"
@ -56,14 +50,16 @@
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/slice/slice_internal.h" // IWYU pragma: keep
#include "src/core/lib/transport/metadata_batch.h" // IWYU pragma: keep
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/promise_endpoint.h"
#include "src/core/lib/transport/transport.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
using testing::AtMost;
using testing::MockFunction;
using testing::Return;
using testing::Sequence;
using testing::StrictMock;
using testing::WithArgs;
@ -98,333 +94,308 @@ class MockEndpoint
GetLocalAddress, (), (const, override));
};
struct MockPromiseEndpoint {
StrictMock<MockEndpoint>* endpoint = new StrictMock<MockEndpoint>();
std::unique_ptr<PromiseEndpoint> promise_endpoint =
std::make_unique<PromiseEndpoint>(
std::unique_ptr<StrictMock<MockEndpoint>>(endpoint), SliceBuffer());
};
// Send messages from client to server.
auto SendClientToServerMessages(CallInitiator initiator, int num_messages) {
return Loop([initiator, num_messages]() mutable {
bool has_message = (num_messages > 0);
return If(
has_message,
Seq(initiator.PushMessage(GetContext<Arena>()->MakePooled<Message>()),
[&num_messages]() -> LoopCtl<absl::Status> {
--num_messages;
return Continue();
}),
[initiator]() mutable -> LoopCtl<absl::Status> {
initiator.FinishSends();
return absl::OkStatus();
});
});
}
ClientMetadataHandle TestInitialMetadata() {
auto md =
GetContext<Arena>()->MakePooled<ClientMetadata>(GetContext<Arena>());
md->Set(HttpPathMetadata(), Slice::FromStaticString("/test"));
return md;
}
class ClientTransportTest : public ::testing::Test {
public:
ClientTransportTest()
: control_endpoint_ptr_(new StrictMock<MockEndpoint>()),
data_endpoint_ptr_(new StrictMock<MockEndpoint>()),
memory_allocator_(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
"test")),
control_endpoint_(*control_endpoint_ptr_),
data_endpoint_(*data_endpoint_ptr_),
event_engine_(std::make_shared<
grpc_event_engine::experimental::FuzzingEventEngine>(
[]() {
grpc_timer_manager_set_threading(false);
grpc_event_engine::experimental::FuzzingEventEngine::Options
options;
return options;
}(),
fuzzing_event_engine::Actions())),
arena_(MakeScopedArena(initial_arena_size, &memory_allocator_)),
pipe_client_to_server_messages_(arena_.get()),
pipe_server_to_client_messages_(arena_.get()),
pipe_server_intial_metadata_(arena_.get()),
pipe_client_to_server_messages_second_(arena_.get()),
pipe_server_to_client_messages_second_(arena_.get()),
pipe_server_intial_metadata_second_(arena_.get()) {}
// Initial ClientTransport with read expecations
void InitialClientTransport() {
client_transport_ = std::make_unique<ClientTransport>(
std::make_unique<PromiseEndpoint>(
std::unique_ptr<MockEndpoint>(control_endpoint_ptr_),
SliceBuffer()),
std::make_unique<PromiseEndpoint>(
std::unique_ptr<MockEndpoint>(data_endpoint_ptr_), SliceBuffer()),
event_engine_);
}
// Send messages from client to server.
auto SendClientToServerMessages(
Pipe<MessageHandle>& pipe_client_to_server_messages,
int num_of_messages) {
return Loop([&pipe_client_to_server_messages, num_of_messages,
this]() mutable {
bool has_message = (num_of_messages > 0);
return If(
has_message,
Seq(pipe_client_to_server_messages.sender.Push(
arena_->MakePooled<Message>()),
[&num_of_messages]() -> LoopCtl<absl::Status> {
num_of_messages--;
return Continue();
}),
[&pipe_client_to_server_messages]() mutable -> LoopCtl<absl::Status> {
pipe_client_to_server_messages.sender.Close();
return absl::OkStatus();
});
});
}
// Add stream into client transport, and expect return trailers of
// "grpc-status:code".
auto AddStream(CallArgs args) {
return client_transport_->AddStream(std::move(args));
protected:
const std::shared_ptr<grpc_event_engine::experimental::FuzzingEventEngine>&
event_engine() {
return event_engine_;
}
MemoryAllocator* memory_allocator() { return &allocator_; }
private:
MockEndpoint* control_endpoint_ptr_;
MockEndpoint* data_endpoint_ptr_;
size_t initial_arena_size = 1024;
MemoryAllocator memory_allocator_;
protected:
MockEndpoint& control_endpoint_;
MockEndpoint& data_endpoint_;
std::shared_ptr<grpc_event_engine::experimental::FuzzingEventEngine>
event_engine_;
std::unique_ptr<ClientTransport> client_transport_;
ScopedArenaPtr arena_;
Pipe<MessageHandle> pipe_client_to_server_messages_;
Pipe<MessageHandle> pipe_server_to_client_messages_;
Pipe<ServerMetadataHandle> pipe_server_intial_metadata_;
// Added for mutliple streams tests.
Pipe<MessageHandle> pipe_client_to_server_messages_second_;
Pipe<MessageHandle> pipe_server_to_client_messages_second_;
Pipe<ServerMetadataHandle> pipe_server_intial_metadata_second_;
absl::AnyInvocable<void(absl::Status)> read_callback_;
Sequence control_endpoint_sequence_;
Sequence data_endpoint_sequence_;
// Added to verify received message payload.
const std::string message_ = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08};
event_engine_{
std::make_shared<grpc_event_engine::experimental::FuzzingEventEngine>(
[]() {
grpc_timer_manager_set_threading(false);
grpc_event_engine::experimental::FuzzingEventEngine::Options
options;
return options;
}(),
fuzzing_event_engine::Actions())};
MemoryAllocator allocator_ = MakeResourceQuota("test-quota")
->memory_quota()
->CreateMemoryAllocator("test-allocator");
};
TEST_F(ClientTransportTest, AddOneStreamWithWriteFailed) {
MockPromiseEndpoint control_endpoint;
MockPromiseEndpoint data_endpoint;
// Mock write failed and read is pending.
EXPECT_CALL(control_endpoint_, Write)
EXPECT_CALL(*control_endpoint.endpoint, Write)
.Times(AtMost(1))
.WillOnce(
WithArgs<0>([](absl::AnyInvocable<void(absl::Status)> on_write) {
on_write(absl::InternalError("control endpoint write failed."));
return false;
}));
EXPECT_CALL(data_endpoint_, Write)
EXPECT_CALL(*data_endpoint.endpoint, Write)
.Times(AtMost(1))
.WillOnce(
WithArgs<0>([](absl::AnyInvocable<void(absl::Status)> on_write) {
on_write(absl::InternalError("data endpoint write failed."));
return false;
}));
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence_)
.WillOnce(Return(false));
InitialClientTransport();
ClientMetadataHandle md;
auto args = CallArgs{std::move(md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages in client transport.
Join(
// Add first stream with call_args into client transport.
// Expect return trailers "grpc-status:unavailable".
AddStream(std::move(args)),
// Send messages to call_args.client_to_server_messages pipe,
// which will be eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_, 1)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<ServerMetadataHandle, absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
EXPECT_TRUE(std::get<1>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
EXPECT_CALL(*control_endpoint.endpoint, Read).WillOnce(Return(false));
auto transport = MakeOrphanable<ChaoticGoodClientTransport>(
std::move(control_endpoint.promise_endpoint),
std::move(data_endpoint.promise_endpoint), event_engine());
auto call =
MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
transport->StartCall(std::move(call.handler));
call.initiator.SpawnGuarded("test-send", [initiator =
call.initiator]() mutable {
return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
SendClientToServerMessages(initiator, 1));
});
StrictMock<MockFunction<void()>> on_done;
EXPECT_CALL(on_done, Call());
call.initiator.SpawnInfallible(
"test-read", [&on_done, initiator = call.initiator]() mutable {
return Seq(
initiator.PullServerInitialMetadata(),
[](ValueOrFailure<ServerMetadataHandle> md) {
EXPECT_FALSE(md.ok());
return Empty{};
},
initiator.PullServerTrailingMetadata(),
[&on_done](ServerMetadataHandle md) {
EXPECT_EQ(md->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
on_done.Call();
return Empty{};
});
});
// Wait until ClientTransport's internal activities to finish.
event_engine_->TickUntilIdle();
event_engine_->UnsetGlobalHooks();
event_engine()->TickUntilIdle();
event_engine()->UnsetGlobalHooks();
}
TEST_F(ClientTransportTest, AddOneStreamWithReadFailed) {
MockPromiseEndpoint control_endpoint;
MockPromiseEndpoint data_endpoint;
// Mock read failed.
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence_)
EXPECT_CALL(*control_endpoint.endpoint, Read)
.WillOnce(WithArgs<0>(
[](absl::AnyInvocable<void(absl::Status)> on_read) mutable {
on_read(absl::InternalError("control endpoint read failed."));
// Return false to mock EventEngine read not finish.
return false;
}));
InitialClientTransport();
ClientMetadataHandle md;
auto args = CallArgs{std::move(md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages in client transport.
Join(
// Add first stream with call_args into client transport.
// Expect return trailers "grpc-status:unavailable".
AddStream(std::move(args)),
// Send messages to call_args.client_to_server_messages pipe.
SendClientToServerMessages(pipe_client_to_server_messages_, 1)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<ServerMetadataHandle, absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
EXPECT_TRUE(std::get<1>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
auto transport = MakeOrphanable<ChaoticGoodClientTransport>(
std::move(control_endpoint.promise_endpoint),
std::move(data_endpoint.promise_endpoint), event_engine());
auto call =
MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
transport->StartCall(std::move(call.handler));
call.initiator.SpawnGuarded("test-send", [initiator =
call.initiator]() mutable {
return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
SendClientToServerMessages(initiator, 1));
});
StrictMock<MockFunction<void()>> on_done;
EXPECT_CALL(on_done, Call());
call.initiator.SpawnInfallible(
"test-read", [&on_done, initiator = call.initiator]() mutable {
return Seq(
initiator.PullServerInitialMetadata(),
[](ValueOrFailure<ServerMetadataHandle> md) {
EXPECT_FALSE(md.ok());
return Empty{};
},
initiator.PullServerTrailingMetadata(),
[&on_done](ServerMetadataHandle md) {
EXPECT_EQ(md->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
on_done.Call();
return Empty{};
});
});
// Wait until ClientTransport's internal activities to finish.
event_engine_->TickUntilIdle();
event_engine_->UnsetGlobalHooks();
event_engine()->TickUntilIdle();
event_engine()->UnsetGlobalHooks();
}
TEST_F(ClientTransportTest, AddMultipleStreamWithWriteFailed) {
// Mock write failed at first stream and second stream's write will fail too.
EXPECT_CALL(control_endpoint_, Write)
.Times(1)
MockPromiseEndpoint control_endpoint;
MockPromiseEndpoint data_endpoint;
EXPECT_CALL(*control_endpoint.endpoint, Write)
.Times(AtMost(1))
.WillRepeatedly(
WithArgs<0>([](absl::AnyInvocable<void(absl::Status)> on_write) {
on_write(absl::InternalError("control endpoint write failed."));
return false;
}));
EXPECT_CALL(data_endpoint_, Write)
.Times(1)
EXPECT_CALL(*data_endpoint.endpoint, Write)
.Times(AtMost(1))
.WillRepeatedly(
WithArgs<0>([](absl::AnyInvocable<void(absl::Status)> on_write) {
on_write(absl::InternalError("data endpoint write failed."));
return false;
}));
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence_)
.WillOnce(Return(false));
InitialClientTransport();
ClientMetadataHandle first_stream_md;
ClientMetadataHandle second_stream_md;
auto first_stream_args =
CallArgs{std::move(first_stream_md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
auto second_stream_args =
CallArgs{std::move(second_stream_md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_second_.sender,
&pipe_client_to_server_messages_second_.receiver,
&pipe_server_to_client_messages_second_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages from client transport.
Join(
// Add first stream with call_args into client transport.
// Expect return trailers "grpc-status:unavailable".
AddStream(std::move(first_stream_args)),
// Send messages to first stream's
// call_args.client_to_server_messages pipe.
SendClientToServerMessages(pipe_client_to_server_messages_, 1)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<ServerMetadataHandle, absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
EXPECT_TRUE(std::get<1>(ret).ok());
return absl::OkStatus();
},
Join(
// Add second stream with call_args into client transport.
// Expect return trailers "grpc-status:unavailable".
AddStream(std::move(second_stream_args)),
// Send messages to second stream's
// call_args.client_to_server_messages pipe.
SendClientToServerMessages(pipe_client_to_server_messages_second_,
1)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<ServerMetadataHandle, absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
EXPECT_TRUE(std::get<1>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
EXPECT_CALL(*control_endpoint.endpoint, Read).WillOnce(Return(false));
auto transport = MakeOrphanable<ChaoticGoodClientTransport>(
std::move(control_endpoint.promise_endpoint),
std::move(data_endpoint.promise_endpoint), event_engine());
auto call1 =
MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
transport->StartCall(std::move(call1.handler));
auto call2 =
MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
transport->StartCall(std::move(call2.handler));
call1.initiator.SpawnGuarded("test-send-1", [initiator =
call1.initiator]() mutable {
return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
SendClientToServerMessages(initiator, 1));
});
call2.initiator.SpawnGuarded("test-send-2", [initiator =
call2.initiator]() mutable {
return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
SendClientToServerMessages(initiator, 1));
});
StrictMock<MockFunction<void()>> on_done1;
EXPECT_CALL(on_done1, Call());
StrictMock<MockFunction<void()>> on_done2;
EXPECT_CALL(on_done2, Call());
call1.initiator.SpawnInfallible(
"test-read-1", [&on_done1, initiator = call1.initiator]() mutable {
return Seq(
initiator.PullServerInitialMetadata(),
[](ValueOrFailure<ServerMetadataHandle> md) {
EXPECT_FALSE(md.ok());
return Empty{};
},
initiator.PullServerTrailingMetadata(),
[&on_done1](ServerMetadataHandle md) {
EXPECT_EQ(md->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
on_done1.Call();
return Empty{};
});
});
call2.initiator.SpawnInfallible(
"test-read-2", [&on_done2, initiator = call2.initiator]() mutable {
return Seq(
initiator.PullServerInitialMetadata(),
[](ValueOrFailure<ServerMetadataHandle> md) {
EXPECT_FALSE(md.ok());
return Empty{};
},
initiator.PullServerTrailingMetadata(),
[&on_done2](ServerMetadataHandle md) {
EXPECT_EQ(md->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
on_done2.Call();
return Empty{};
});
});
// Wait until ClientTransport's internal activities to finish.
event_engine_->TickUntilIdle();
event_engine_->UnsetGlobalHooks();
event_engine()->TickUntilIdle();
event_engine()->UnsetGlobalHooks();
}
TEST_F(ClientTransportTest, AddMultipleStreamWithReadFailed) {
MockPromiseEndpoint control_endpoint;
MockPromiseEndpoint data_endpoint;
// Mock read failed at first stream, and second stream's write will fail too.
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence_)
EXPECT_CALL(*control_endpoint.endpoint, Read)
.WillOnce(WithArgs<0>(
[](absl::AnyInvocable<void(absl::Status)> on_read) mutable {
on_read(absl::InternalError("control endpoint read failed."));
// Return false to mock EventEngine read not finish.
return false;
}));
InitialClientTransport();
ClientMetadataHandle first_stream_md;
ClientMetadataHandle second_stream_md;
auto first_stream_args =
CallArgs{std::move(first_stream_md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
auto second_stream_args =
CallArgs{std::move(second_stream_md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_second_.sender,
&pipe_client_to_server_messages_second_.receiver,
&pipe_server_to_client_messages_second_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages from client transport.
Join(
// Add first stream with call_args into client transport.
AddStream(std::move(first_stream_args)),
// Send messages to first stream's
// call_args.client_to_server_messages pipe, which will be
// eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_, 1)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<ServerMetadataHandle, absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
EXPECT_TRUE(std::get<1>(ret).ok());
return absl::OkStatus();
},
Join(
// Add second stream with call_args into client transport.
AddStream(std::move(second_stream_args)),
// Send messages to second stream's
// call_args.client_to_server_messages pipe, which will be
// eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_second_,
1)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<ServerMetadataHandle, absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
EXPECT_TRUE(std::get<1>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
auto transport = MakeOrphanable<ChaoticGoodClientTransport>(
std::move(control_endpoint.promise_endpoint),
std::move(data_endpoint.promise_endpoint), event_engine());
auto call1 =
MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
transport->StartCall(std::move(call1.handler));
auto call2 =
MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
transport->StartCall(std::move(call2.handler));
call1.initiator.SpawnGuarded("test-send", [initiator =
call1.initiator]() mutable {
return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
SendClientToServerMessages(initiator, 1));
});
call2.initiator.SpawnGuarded("test-send", [initiator =
call2.initiator]() mutable {
return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
SendClientToServerMessages(initiator, 1));
});
StrictMock<MockFunction<void()>> on_done1;
EXPECT_CALL(on_done1, Call());
StrictMock<MockFunction<void()>> on_done2;
EXPECT_CALL(on_done2, Call());
call1.initiator.SpawnInfallible(
"test-read", [&on_done1, initiator = call1.initiator]() mutable {
return Seq(
initiator.PullServerInitialMetadata(),
[](ValueOrFailure<ServerMetadataHandle> md) {
EXPECT_FALSE(md.ok());
return Empty{};
},
initiator.PullServerTrailingMetadata(),
[&on_done1](ServerMetadataHandle md) {
EXPECT_EQ(md->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
on_done1.Call();
return Empty{};
});
});
call2.initiator.SpawnInfallible(
"test-read", [&on_done2, initiator = call2.initiator]() mutable {
return Seq(
initiator.PullServerInitialMetadata(),
[](ValueOrFailure<ServerMetadataHandle> md) {
EXPECT_FALSE(md.ok());
return Empty{};
},
initiator.PullServerTrailingMetadata(),
[&on_done2](ServerMetadataHandle md) {
EXPECT_EQ(md->get(GrpcStatusMetadata()).value(),
GRPC_STATUS_UNAVAILABLE);
on_done2.Call();
return Empty{};
});
});
// Wait until ClientTransport's internal activities to finish.
event_engine_->TickUntilIdle();
event_engine_->UnsetGlobalHooks();
event_engine()->TickUntilIdle();
event_engine()->UnsetGlobalHooks();
}
} // namespace testing

@ -14,461 +14,245 @@
#include "src/core/ext/transport/chaotic_good/client_transport.h"
// IWYU pragma: no_include <sys/socket.h>
#include <algorithm> // IWYU pragma: keep
#include <algorithm>
#include <cstdlib>
#include <initializer_list>
#include <memory>
#include <string> // IWYU pragma: keep
#include <string>
#include <tuple>
#include <vector> // IWYU pragma: keep
#include <vector>
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h" // IWYU pragma: keep
#include "absl/strings/str_format.h" // IWYU pragma: keep
#include "absl/types/optional.h" // IWYU pragma: keep
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/event_engine/slice.h> // IWYU pragma: keep
#include <grpc/event_engine/slice.h>
#include <grpc/event_engine/slice_buffer.h>
#include <grpc/grpc.h>
#include <grpc/status.h> // IWYU pragma: keep
#include <grpc/status.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/timer_manager.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/event_engine_wakeup_scheduler.h"
#include "src/core/lib/promise/if.h"
#include "src/core/lib/promise/join.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/slice/slice_internal.h" // IWYU pragma: keep
#include "src/core/lib/transport/metadata_batch.h" // IWYU pragma: keep
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "test/core/transport/chaotic_good/mock_promise_endpoint.h"
#include "test/core/transport/chaotic_good/transport_test.h"
using testing::MockFunction;
using testing::Return;
using testing::Sequence;
using testing::StrictMock;
using testing::WithArgs;
using EventEngineSlice = grpc_event_engine::experimental::Slice;
namespace grpc_core {
namespace chaotic_good {
namespace testing {
class MockEndpoint
: public grpc_event_engine::experimental::EventEngine::Endpoint {
public:
MOCK_METHOD(
bool, Read,
(absl::AnyInvocable<void(absl::Status)> on_read,
grpc_event_engine::experimental::SliceBuffer* buffer,
const grpc_event_engine::experimental::EventEngine::Endpoint::ReadArgs*
args),
(override));
// Encoded string of header ":path: /demo.Service/Step".
const uint8_t kPathDemoServiceStep[] = {
0x40, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f,
0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70};
MOCK_METHOD(
bool, Write,
(absl::AnyInvocable<void(absl::Status)> on_writable,
grpc_event_engine::experimental::SliceBuffer* data,
const grpc_event_engine::experimental::EventEngine::Endpoint::WriteArgs*
args),
(override));
// Encoded string of trailer "grpc-status: 0".
const uint8_t kGrpcStatus0[] = {0x10, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x73,
0x74, 0x61, 0x74, 0x75, 0x73, 0x01, 0x30};
MOCK_METHOD(
const grpc_event_engine::experimental::EventEngine::ResolvedAddress&,
GetPeerAddress, (), (const, override));
MOCK_METHOD(
const grpc_event_engine::experimental::EventEngine::ResolvedAddress&,
GetLocalAddress, (), (const, override));
};
ClientMetadataHandle TestInitialMetadata() {
auto md =
GetContext<Arena>()->MakePooled<ClientMetadata>(GetContext<Arena>());
md->Set(HttpPathMetadata(), Slice::FromStaticString("/demo.Service/Step"));
return md;
}
class ClientTransportTest : public ::testing::Test {
public:
ClientTransportTest()
: control_endpoint_ptr_(new StrictMock<MockEndpoint>()),
data_endpoint_ptr_(new StrictMock<MockEndpoint>()),
memory_allocator_(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
"test")),
control_endpoint_(*control_endpoint_ptr_),
data_endpoint_(*data_endpoint_ptr_),
event_engine_(std::make_shared<
grpc_event_engine::experimental::FuzzingEventEngine>(
[]() {
grpc_timer_manager_set_threading(false);
grpc_event_engine::experimental::FuzzingEventEngine::Options
options;
return options;
}(),
fuzzing_event_engine::Actions())),
arena_(MakeScopedArena(initial_arena_size, &memory_allocator_)),
pipe_client_to_server_messages_(arena_.get()),
pipe_server_to_client_messages_(arena_.get()),
pipe_server_intial_metadata_(arena_.get()),
pipe_client_to_server_messages_second_(arena_.get()),
pipe_server_to_client_messages_second_(arena_.get()),
pipe_server_intial_metadata_second_(arena_.get()) {}
// Expect how client transport will read from control/data endpoints with a
// test frame.
void AddReadExpectations(int num_of_streams) {
for (int i = 0; i < num_of_streams; i++) {
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence)
.WillOnce(WithArgs<0, 1>(
[this, i](absl::AnyInvocable<void(absl::Status)> on_read,
grpc_event_engine::experimental::SliceBuffer*
buffer) mutable {
// Construct test frame for EventEngine read: headers (15
// bytes), message(16 bytes), message padding (48 byte),
// trailers (15 bytes).
const std::string frame_header = {
static_cast<char>(0x80), // frame type = fragment
0x03, // flag = has header + has trailer
0x00,
0x00,
static_cast<char>(i + 1), // stream id = 1
0x00,
0x00,
0x00,
0x1a, // header length = 26
0x00,
0x00,
0x00,
0x08, // message length = 8
0x00,
0x00,
0x00,
0x38, // message padding =56
0x00,
0x00,
0x00,
0x0f, // trailer length = 15
0x00,
0x00,
0x00};
// Schedule mock_endpoint to read buffer.
grpc_event_engine::experimental::Slice slice(
grpc_slice_from_cpp_string(frame_header));
buffer->Append(std::move(slice));
// Execute read callback later to control when read starts.
if (i == 0) {
read_callback_ = std::move(on_read);
// Return false to mock EventEngine read not finish.
return false;
} else {
return true;
}
}));
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence)
.WillOnce(WithArgs<1>(
[](grpc_event_engine::experimental::SliceBuffer* buffer) {
// Encoded string of header ":path: /demo.Service/Step".
const std::string header = {
0x10, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f,
0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70};
// Encoded string of trailer "grpc-status: 0".
const std::string trailers = {0x10, 0x0b, 0x67, 0x72, 0x70,
0x63, 0x2d, 0x73, 0x74, 0x61,
0x74, 0x75, 0x73, 0x01, 0x30};
// Schedule mock_endpoint to read buffer.
grpc_event_engine::experimental::Slice slice(
grpc_slice_from_cpp_string(header + trailers));
buffer->Append(std::move(slice));
return true;
}));
}
EXPECT_CALL(control_endpoint_, Read)
.InSequence(control_endpoint_sequence)
.WillOnce(Return(false));
for (int i = 0; i < num_of_streams; i++) {
EXPECT_CALL(data_endpoint_, Read)
.InSequence(data_endpoint_sequence)
.WillOnce(WithArgs<1>(
[this](grpc_event_engine::experimental::SliceBuffer* buffer) {
const std::string message_padding = {
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
grpc_event_engine::experimental::Slice slice(
grpc_slice_from_cpp_string(message_padding + message_));
buffer->Append(std::move(slice));
return true;
}));
}
}
// Initial ClientTransport with read expecations
void InitialClientTransport(int num_of_streams) {
// Read expectaions need to be added before transport initialization since
// reader_ activity loop is started in ClientTransport initialization,
AddReadExpectations(num_of_streams);
client_transport_ = std::make_unique<ClientTransport>(
std::make_unique<PromiseEndpoint>(
std::unique_ptr<MockEndpoint>(control_endpoint_ptr_),
SliceBuffer()),
std::make_unique<PromiseEndpoint>(
std::unique_ptr<MockEndpoint>(data_endpoint_ptr_), SliceBuffer()),
event_engine_);
}
// Send messages from client to server.
auto SendClientToServerMessages(
Pipe<MessageHandle>& pipe_client_to_server_messages,
int num_of_messages) {
return Loop([&pipe_client_to_server_messages, num_of_messages,
this]() mutable {
bool has_message = (num_of_messages > 0);
return If(
has_message,
Seq(pipe_client_to_server_messages.sender.Push(
arena_->MakePooled<Message>()),
[&num_of_messages]() -> LoopCtl<absl::Status> {
num_of_messages--;
return Continue();
}),
[&pipe_client_to_server_messages]() mutable -> LoopCtl<absl::Status> {
pipe_client_to_server_messages.sender.Close();
return absl::OkStatus();
});
});
}
// Add stream into client transport, and expect return trailers of
// "grpc-status:code".
auto AddStream(CallArgs args, const grpc_status_code trailers) {
return Seq(client_transport_->AddStream(std::move(args)),
[trailers](ServerMetadataHandle ret) {
// AddStream will finish with server trailers:
// "grpc-status:code".
EXPECT_EQ(ret->get(GrpcStatusMetadata()).value(), trailers);
return trailers;
});
}
// Start read from control endpoints.
auto StartRead(const absl::Status& read_status) {
return [read_status, this] {
read_callback_(read_status);
return read_status;
};
}
// Receive messages from server to client.
auto ReceiveServerToClientMessages(
Pipe<ServerMetadataHandle>& pipe_server_intial_metadata,
Pipe<MessageHandle>& pipe_server_to_client_messages) {
return Seq(
// Receive server initial metadata.
Map(pipe_server_intial_metadata.receiver.Next(),
[](NextResult<ServerMetadataHandle> r) {
// Expect value: ":path: /demo.Service/Step"
EXPECT_TRUE(r.has_value());
EXPECT_EQ(
r.value()->get_pointer(HttpPathMetadata())->as_string_view(),
"/demo.Service/Step");
return absl::OkStatus();
}),
// Receive server to client messages.
Map(pipe_server_to_client_messages.receiver.Next(),
[this](NextResult<MessageHandle> r) {
EXPECT_TRUE(r.has_value());
EXPECT_EQ(r.value()->payload()->JoinIntoString(), message_);
return absl::OkStatus();
// Send messages from client to server.
auto SendClientToServerMessages(CallInitiator initiator, int num_messages) {
return Loop([initiator, num_messages, i = 0]() mutable {
bool has_message = (i < num_messages);
return If(
has_message,
Seq(initiator.PushMessage(GetContext<Arena>()->MakePooled<Message>(
SliceBuffer(Slice::FromCopiedString(std::to_string(i))), 0)),
[&i]() -> LoopCtl<absl::Status> {
++i;
return Continue();
}),
[&pipe_server_intial_metadata,
&pipe_server_to_client_messages]() mutable {
// Close pipes after receive message.
pipe_server_to_client_messages.sender.Close();
pipe_server_intial_metadata.sender.Close();
[initiator]() mutable -> LoopCtl<absl::Status> {
initiator.FinishSends();
return absl::OkStatus();
});
}
private:
MockEndpoint* control_endpoint_ptr_;
MockEndpoint* data_endpoint_ptr_;
size_t initial_arena_size = 1024;
MemoryAllocator memory_allocator_;
Sequence control_endpoint_sequence;
Sequence data_endpoint_sequence;
protected:
MockEndpoint& control_endpoint_;
MockEndpoint& data_endpoint_;
std::shared_ptr<grpc_event_engine::experimental::FuzzingEventEngine>
event_engine_;
std::unique_ptr<ClientTransport> client_transport_;
ScopedArenaPtr arena_;
Pipe<MessageHandle> pipe_client_to_server_messages_;
Pipe<MessageHandle> pipe_server_to_client_messages_;
Pipe<ServerMetadataHandle> pipe_server_intial_metadata_;
// Added for mutliple streams tests.
Pipe<MessageHandle> pipe_client_to_server_messages_second_;
Pipe<MessageHandle> pipe_server_to_client_messages_second_;
Pipe<ServerMetadataHandle> pipe_server_intial_metadata_second_;
absl::AnyInvocable<void(absl::Status)> read_callback_;
// Added to verify received message payload.
const std::string message_ = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08};
};
TEST_F(ClientTransportTest, AddOneStream) {
InitialClientTransport(1);
ClientMetadataHandle md;
auto args = CallArgs{std::move(md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
EXPECT_CALL(control_endpoint_, Write).WillOnce(Return(true));
EXPECT_CALL(data_endpoint_, Write).WillOnce(Return(true));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages in client transport.
Join(
// Add first stream with call_args into client transport.
AddStream(std::move(args), GRPC_STATUS_OK),
// Start read from control endpoints.
StartRead(absl::OkStatus()),
// Send messages to call_args.client_to_server_messages pipe,
// which will be eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_, 1),
// Receive messages from control/data endpoints.
ReceiveServerToClientMessages(pipe_server_intial_metadata_,
pipe_server_to_client_messages_)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<grpc_status_code, absl::Status, absl::Status,
absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK);
EXPECT_TRUE(std::get<1>(ret).ok());
EXPECT_TRUE(std::get<2>(ret).ok());
EXPECT_TRUE(std::get<3>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
// Wait until ClientTransport's internal activities to finish.
event_engine_->TickUntilIdle();
event_engine_->UnsetGlobalHooks();
});
}
TEST_F(ClientTransportTest, AddOneStreamMultipleMessages) {
InitialClientTransport(1);
ClientMetadataHandle md;
auto args = CallArgs{std::move(md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
EXPECT_CALL(control_endpoint_, Write).Times(3).WillRepeatedly(Return(true));
EXPECT_CALL(data_endpoint_, Write).Times(3).WillRepeatedly(Return(true));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages in client transport.
Join(
// Add first stream with call_args into client transport.
AddStream(std::move(args), GRPC_STATUS_OK),
// Start read from control endpoints.
StartRead(absl::OkStatus()),
// Send messages to call_args.client_to_server_messages pipe,
// which will be eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_, 3),
// Receive messages from control/data endpoints.
ReceiveServerToClientMessages(pipe_server_intial_metadata_,
pipe_server_to_client_messages_)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<grpc_status_code, absl::Status, absl::Status,
absl::Status>& ret) {
EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK);
EXPECT_TRUE(std::get<1>(ret).ok());
EXPECT_TRUE(std::get<2>(ret).ok());
EXPECT_TRUE(std::get<3>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
// One client stream round-trip through ChaoticGoodClientTransport:
// send initial metadata + one message, receive server initial metadata,
// one 8-byte message, and trailing metadata with status OK.
// Fix: removed stale diff residue calling the nonexistent fixture member
// `event_engine_` (duplicating the `event_engine()` accessor calls below).
TEST_F(TransportTest, AddOneStream) {
  MockPromiseEndpoint control_endpoint;
  MockPromiseEndpoint data_endpoint;
  // Queue one incoming server fragment frame on stream 1 carrying the
  // ":path: /demo.Service/Step" header and a "grpc-status: 0" trailer;
  // the 8-byte message payload arrives on the data endpoint.
  control_endpoint.ExpectRead(
      {SerializedFrameHeader(FrameType::kFragment, 7, 1, 26, 8, 56, 15),
       EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep,
                                          sizeof(kPathDemoServiceStep)),
       EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))},
      event_engine().get());
  data_endpoint.ExpectRead(
      {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr);
  // After the queued frames are consumed the next read never completes.
  EXPECT_CALL(*control_endpoint.endpoint, Read)
      .InSequence(control_endpoint.read_sequence)
      .WillOnce(Return(false));
  auto transport = MakeOrphanable<ChaoticGoodClientTransport>(
      std::move(control_endpoint.promise_endpoint),
      std::move(data_endpoint.promise_endpoint), event_engine());
  auto call =
      MakeCall(event_engine().get(), Arena::Create(1024, memory_allocator()));
  transport->StartCall(std::move(call.handler));
  StrictMock<MockFunction<void()>> on_done;
  EXPECT_CALL(on_done, Call());
  // Expected outgoing frames: client initial metadata, one message payload
  // ("0" padded to 64 bytes on the data endpoint), then a stream-end frame.
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 1, 1,
                             sizeof(kPathDemoServiceStep), 0, 0, 0),
       EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep,
                                          sizeof(kPathDemoServiceStep))},
      nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)},
      nullptr);
  data_endpoint.ExpectWrite(
      {EventEngineSlice::FromCopiedString("0"), Zeros(63)}, nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, 0)}, nullptr);
  call.initiator.SpawnGuarded("test-send", [initiator =
                                                call.initiator]() mutable {
    return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
                  SendClientToServerMessages(initiator, 1));
  });
  call.initiator.SpawnInfallible(
      "test-read", [&on_done, initiator = call.initiator]() mutable {
        return Seq(
            initiator.PullServerInitialMetadata(),
            [](ValueOrFailure<ServerMetadataHandle> md) {
              EXPECT_TRUE(md.ok());
              EXPECT_EQ(
                  md.value()->get_pointer(HttpPathMetadata())->as_string_view(),
                  "/demo.Service/Step");
              return Empty{};
            },
            initiator.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_TRUE(msg.has_value());
              EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678");
              return Empty{};
            },
            initiator.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_FALSE(msg.has_value());
              return Empty{};
            },
            initiator.PullServerTrailingMetadata(),
            [&on_done](ServerMetadataHandle md) {
              EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), GRPC_STATUS_OK);
              on_done.Call();
              return Empty{};
            });
      });
  // Wait for the transport's internal activities to finish.
  event_engine()->TickUntilIdle();
  event_engine()->UnsetGlobalHooks();
}
TEST_F(ClientTransportTest, AddMultipleStreamsMultipleMessages) {
InitialClientTransport(2);
ClientMetadataHandle first_stream_md;
ClientMetadataHandle second_stream_md;
auto first_stream_args =
CallArgs{std::move(first_stream_md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender};
auto second_stream_args =
CallArgs{std::move(second_stream_md),
ClientInitialMetadataOutstandingToken::Empty(),
nullptr,
&pipe_server_intial_metadata_second_.sender,
&pipe_client_to_server_messages_second_.receiver,
&pipe_server_to_client_messages_second_.sender};
StrictMock<MockFunction<void(absl::Status)>> on_done;
EXPECT_CALL(on_done, Call(absl::OkStatus()));
EXPECT_CALL(control_endpoint_, Write).Times(6).WillRepeatedly(Return(true));
EXPECT_CALL(data_endpoint_, Write).Times(6).WillRepeatedly(Return(true));
auto activity = MakeActivity(
Seq(
// Concurrently: write and read messages from client transport.
Join(
// Add first stream with call_args into client transport.
AddStream(std::move(first_stream_args), GRPC_STATUS_OK),
// Start read from control endpoints.
StartRead(absl::OkStatus()),
// Send messages to first stream's
// call_args.client_to_server_messages pipe, which will be
// eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_, 3),
// Receive first stream's messages from control/data endpoints.
ReceiveServerToClientMessages(pipe_server_intial_metadata_,
pipe_server_to_client_messages_)),
Join(
// Add second stream with call_args into client transport.
AddStream(std::move(second_stream_args), GRPC_STATUS_OK),
// Send messages to second stream's
// call_args.client_to_server_messages pipe, which will be
// eventually sent to control/data endpoints.
SendClientToServerMessages(pipe_client_to_server_messages_second_,
3),
// Receive second stream's messages from control/data endpoints.
ReceiveServerToClientMessages(
pipe_server_intial_metadata_second_,
pipe_server_to_client_messages_second_)),
// Once complete, verify successful sending and the received value.
[](const std::tuple<grpc_status_code, absl::Status, absl::Status>&
ret) {
EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK);
EXPECT_TRUE(std::get<1>(ret).ok());
EXPECT_TRUE(std::get<2>(ret).ok());
return absl::OkStatus();
}),
EventEngineWakeupScheduler(event_engine_),
[&on_done](absl::Status status) { on_done.Call(std::move(status)); });
// Like AddOneStream, but the client sends two messages and the server
// replies with two messages before trailing metadata.
// Fix: removed stale diff residue calling the nonexistent fixture member
// `event_engine_` (duplicating the `event_engine()` accessor calls below).
TEST_F(TransportTest, AddOneStreamMultipleMessages) {
  MockPromiseEndpoint control_endpoint;
  MockPromiseEndpoint data_endpoint;
  // Queue two incoming server fragment frames on stream 1: the first carries
  // initial metadata plus a message, the second a message plus trailers.
  control_endpoint.ExpectRead(
      {SerializedFrameHeader(FrameType::kFragment, 3, 1, 26, 8, 56, 0),
       EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep,
                                          sizeof(kPathDemoServiceStep))},
      event_engine().get());
  control_endpoint.ExpectRead(
      {SerializedFrameHeader(FrameType::kFragment, 6, 1, 0, 8, 56, 15),
       EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))},
      event_engine().get());
  data_endpoint.ExpectRead(
      {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr);
  data_endpoint.ExpectRead(
      {EventEngineSlice::FromCopiedString("87654321"), Zeros(56)}, nullptr);
  // After the queued frames are consumed the next read never completes.
  EXPECT_CALL(*control_endpoint.endpoint, Read)
      .InSequence(control_endpoint.read_sequence)
      .WillOnce(Return(false));
  auto transport = MakeOrphanable<ChaoticGoodClientTransport>(
      std::move(control_endpoint.promise_endpoint),
      std::move(data_endpoint.promise_endpoint), event_engine());
  auto call =
      MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator()));
  transport->StartCall(std::move(call.handler));
  StrictMock<MockFunction<void()>> on_done;
  EXPECT_CALL(on_done, Call());
  // Expected outgoing frames: client initial metadata, two message payloads
  // ("0" then "1", each padded to 64 bytes), then a stream-end frame.
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 1, 1,
                             sizeof(kPathDemoServiceStep), 0, 0, 0),
       EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep,
                                          sizeof(kPathDemoServiceStep))},
      nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)},
      nullptr);
  data_endpoint.ExpectWrite(
      {EventEngineSlice::FromCopiedString("0"), Zeros(63)}, nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)},
      nullptr);
  data_endpoint.ExpectWrite(
      {EventEngineSlice::FromCopiedString("1"), Zeros(63)}, nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, 0)}, nullptr);
  call.initiator.SpawnGuarded("test-send", [initiator =
                                                call.initiator]() mutable {
    return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()),
                  SendClientToServerMessages(initiator, 2));
  });
  call.initiator.SpawnInfallible(
      "test-read", [&on_done, initiator = call.initiator]() mutable {
        return Seq(
            initiator.PullServerInitialMetadata(),
            [](ValueOrFailure<ServerMetadataHandle> md) {
              EXPECT_TRUE(md.ok());
              EXPECT_EQ(
                  md.value()->get_pointer(HttpPathMetadata())->as_string_view(),
                  "/demo.Service/Step");
              return Empty{};
            },
            initiator.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_TRUE(msg.has_value());
              EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678");
              return Empty{};
            },
            initiator.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_TRUE(msg.has_value());
              EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "87654321");
              return Empty{};
            },
            initiator.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_FALSE(msg.has_value());
              return Empty{};
            },
            initiator.PullServerTrailingMetadata(),
            [&on_done](ServerMetadataHandle md) {
              EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), GRPC_STATUS_OK);
              on_done.Call();
              return Empty{};
            });
      });
  // Wait for the transport's internal activities to finish.
  event_engine()->TickUntilIdle();
  event_engine()->UnsetGlobalHooks();
}
} // namespace testing

@ -35,7 +35,9 @@
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/libfuzzer/libfuzzer_macro.h"
#include "test/core/promise/test_context.h"
#include "test/core/transport/chaotic_good/frame_fuzzer.pb.h"
bool squelch = false;
@ -51,10 +53,10 @@ template <typename T>
void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
HPackCompressor hpack_compressor;
auto serialized = input.Serialize(&hpack_compressor);
GPR_ASSERT(serialized.Length() >=
GPR_ASSERT(serialized.control.Length() >=
24); // Initial output buffer size is 64 byte.
uint8_t header_bytes[24];
serialized.MoveFirstNBytesIntoBuffer(24, header_bytes);
serialized.control.MoveFirstNBytesIntoBuffer(24, header_bytes);
auto header = FrameHeader::Parse(header_bytes);
if (!header.ok()) {
if (!squelch) {
@ -67,66 +69,69 @@ void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
T output;
HPackParser hpack_parser;
DeterministicBitGen bitgen;
auto deser = output.Deserialize(&hpack_parser, header.value(),
absl::BitGenRef(bitgen), serialized);
auto deser =
output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen),
GetContext<Arena>(), std::move(serialized));
GPR_ASSERT(deser.ok());
GPR_ASSERT(output == input);
}
template <typename T>
void FinishParseAndChecks(const FrameHeader& header, const uint8_t* data,
size_t size) {
void FinishParseAndChecks(const FrameHeader& header, BufferPair buffers) {
T parsed;
ExecCtx exec_ctx; // Initialized to get this_cpu() info in global_stat().
HPackParser hpack_parser;
SliceBuffer serialized;
serialized.Append(Slice::FromCopiedBuffer(data, size));
DeterministicBitGen bitgen;
auto deser = parsed.Deserialize(&hpack_parser, header,
absl::BitGenRef(bitgen), serialized);
auto deser =
parsed.Deserialize(&hpack_parser, header, absl::BitGenRef(bitgen),
GetContext<Arena>(), std::move(buffers));
if (!deser.ok()) return;
gpr_log(GPR_INFO, "Read frame: %s", parsed.ToString().c_str());
AssertRoundTrips(parsed, header.type);
}
// Fuzzer driver: parse a frame header from the control bytes, then run the
// type-appropriate deserialize + round-trip checks over the remaining bytes.
// Fix: resolved merged pre-/post-refactor diff lines — removed the stale
// `int Run(const uint8_t*, size_t)` preamble, the stray `size/data`
// adjustments, and `return 0;` statements inside this void function.
void Run(const frame_fuzzer::Test& test) {
  const uint8_t* control_data =
      reinterpret_cast<const uint8_t*>(test.control().data());
  size_t control_size = test.control().size();
  // A frame header is exactly 24 bytes; shorter input cannot parse.
  if (test.control().size() < 24) return;
  auto r = FrameHeader::Parse(control_data);
  if (!r.ok()) return;
  // The data buffer must match the header's declared message length.
  if (test.data().size() != r->message_length) return;
  gpr_log(GPR_INFO, "Read frame header: %s", r->ToString().c_str());
  // Skip past the header; the remainder of control is the metadata payload.
  control_data += 24;
  control_size -= 24;
  MemoryAllocator memory_allocator = MemoryAllocator(
      ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test"));
  auto arena = MakeScopedArena(1024, &memory_allocator);
  TestContext<Arena> ctx(arena.get());
  BufferPair buffers{
      SliceBuffer(Slice::FromCopiedBuffer(control_data, control_size)),
      SliceBuffer(
          Slice::FromCopiedBuffer(test.data().data(), test.data().size())),
  };
  switch (r->type) {
    default:
      return;  // We don't know how to parse this frame type.
    case FrameType::kSettings:
      FinishParseAndChecks<SettingsFrame>(*r, std::move(buffers));
      break;
    case FrameType::kFragment:
      if (test.is_server()) {
        FinishParseAndChecks<ServerFragmentFrame>(*r, std::move(buffers));
      } else {
        FinishParseAndChecks<ClientFragmentFrame>(*r, std::move(buffers));
      }
      break;
    case FrameType::kCancel:
      FinishParseAndChecks<CancelFrame>(*r, std::move(buffers));
      break;
  }
}
} // namespace chaotic_good
} // namespace grpc_core
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
return grpc_core::chaotic_good::Run(data, size);
DEFINE_PROTO_FUZZER(const frame_fuzzer::Test& test) {
grpc_core::chaotic_good::Run(test);
}

@ -0,0 +1,23 @@
// Copyright 2021 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package frame_fuzzer;
// One fuzzer input: a serialized chaotic-good frame split into its control
// and data channel portions.
message Test {
  // When true, parse the fragment as server-originated; otherwise as
  // client-originated.
  bool is_server = 1;
  // Control-channel bytes: 24-byte frame header followed by metadata payload.
  bytes control = 2;
  // Data-channel bytes: the message payload (length must match the header).
  bytes data = 3;
}

@ -36,7 +36,7 @@ absl::StatusOr<FrameHeader> Deserialize(std::vector<uint8_t> data) {
}
TEST(FrameHeaderTest, SimpleSerialize) {
EXPECT_EQ(Serialize(FrameHeader{FrameType::kCancel, BitSet<2>::FromInt(0),
EXPECT_EQ(Serialize(FrameHeader{FrameType::kCancel, BitSet<3>::FromInt(0),
0x01020304, 0x05060708, 0x090a0b0c,
0x00000034, 0x0d0e0f10}),
std::vector<uint8_t>({
@ -59,7 +59,7 @@ TEST(FrameHeaderTest, SimpleDeserialize) {
0x10, 0x0f, 0x0e, 0x0d // trailer_length
})),
absl::StatusOr<FrameHeader>(FrameHeader{
FrameType::kCancel, BitSet<2>::FromInt(0), 0x01020304,
FrameType::kCancel, BitSet<3>::FromInt(0), 0x01020304,
0x05060708, 0x090a0b0c, 0x00000034, 0x0d0e0f10}));
EXPECT_EQ(Deserialize(std::vector<uint8_t>({
0x81, 88, 88, 88, // type, flags
@ -75,19 +75,19 @@ TEST(FrameHeaderTest, SimpleDeserialize) {
// GetFrameLength() reports only the data-channel (message) byte count that
// matters for the given flags: header/trailer lengths contribute when their
// flag bits are set, message length never does.
// Fix: each EXPECT_EQ contained both the pre-change (BitSet<2>::FromInt(3))
// and post-change (BitSet<3>::FromInt(5)) argument lines from the diff view;
// resolved to the post-change three-bit flag set.
TEST(FrameHeaderTest, GetFrameLength) {
  EXPECT_EQ(
      (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 0, 0, 0})
          .GetFrameLength(),
      0);
  EXPECT_EQ(
      (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 14, 0, 0, 0})
          .GetFrameLength(),
      14);
  EXPECT_EQ((FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 14,
                         50, 0})
                .GetFrameLength(),
            0);
  EXPECT_EQ(
      (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 0, 0, 14})
          .GetFrameLength(),
      14);
}

@ -21,27 +21,38 @@
#include "absl/status/statusor.h"
#include "gtest/gtest.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
namespace grpc_core {
namespace chaotic_good {
namespace {
template <typename T>
void AssertRoundTrips(const T input, FrameType expected_frame_type) {
void AssertRoundTrips(const T& input, FrameType expected_frame_type) {
HPackCompressor hpack_compressor;
absl::BitGen bitgen;
auto serialized = input.Serialize(&hpack_compressor);
EXPECT_GE(serialized.Length(), 24);
GPR_ASSERT(serialized.control.Length() >=
24); // Initial output buffer size is 64 byte.
uint8_t header_bytes[24];
serialized.MoveFirstNBytesIntoBuffer(24, header_bytes);
serialized.control.MoveFirstNBytesIntoBuffer(24, header_bytes);
auto header = FrameHeader::Parse(header_bytes);
EXPECT_TRUE(header.ok()) << header.status();
EXPECT_EQ(header->type, expected_frame_type);
if (!header.ok()) {
Crash("Failed to parse header");
}
GPR_ASSERT(header->type == expected_frame_type);
T output;
HPackParser hpack_parser;
auto deser = output.Deserialize(&hpack_parser, header.value(),
absl::BitGenRef(bitgen), serialized);
EXPECT_TRUE(deser.ok()) << deser;
EXPECT_EQ(output, input);
absl::BitGen bitgen;
MemoryAllocator allocator = MakeResourceQuota("test-quota")
->memory_quota()
->CreateMemoryAllocator("test-allocator");
ScopedArenaPtr arena = MakeScopedArena(1024, &allocator);
auto deser =
output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen),
arena.get(), std::move(serialized));
GPR_ASSERT(deser.ok());
GPR_ASSERT(output == input);
}
TEST(FrameTest, SettingsFrameRoundTrips) {

@ -0,0 +1,89 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/core/transport/chaotic_good/mock_promise_endpoint.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/event_engine.h>
using EventEngineSlice = grpc_event_engine::experimental::Slice;
using grpc_event_engine::experimental::EventEngine;
using testing::WithArgs;
namespace grpc_core {
namespace chaotic_good {
namespace testing {
// Queues one expected Read() on the mock endpoint, in sequence with any
// previously queued reads. When the read fires, `slices_init` is appended to
// the caller's buffer. If `schedule_on_event_engine` is non-null the read
// completes asynchronously: Read() returns false and `on_read` is scheduled
// on that event engine. Otherwise it completes synchronously: Read() returns
// true and `on_read` is never invoked.
void MockPromiseEndpoint::ExpectRead(
    std::initializer_list<EventEngineSlice> slices_init,
    EventEngine* schedule_on_event_engine) {
  // Copy the slices up front: initializer_list elements are const, so they
  // cannot be moved out of the list, and the lambda below must own them.
  std::vector<EventEngineSlice> slices;
  for (auto&& slice : slices_init) slices.emplace_back(slice.Copy());
  EXPECT_CALL(*endpoint, Read)
      .InSequence(read_sequence)
      .WillOnce(WithArgs<0, 1>(
          [slices = std::move(slices), schedule_on_event_engine](
              absl::AnyInvocable<void(absl::Status)> on_read,
              grpc_event_engine::experimental::SliceBuffer* buffer) mutable {
            // Hand the queued slices to the reader.
            for (auto& slice : slices) {
              buffer->Append(std::move(slice));
            }
            if (schedule_on_event_engine != nullptr) {
              // Asynchronous completion path.
              schedule_on_event_engine->Run(
                  [on_read = std::move(on_read)]() mutable {
                    on_read(absl::OkStatus());
                  });
              return false;
            } else {
              // Synchronous completion path.
              return true;
            }
          }));
}
// Queues one expected Write() on the mock endpoint, in sequence with any
// previously queued writes. The bytes written must exactly equal the
// concatenation of `slices`. If `schedule_on_event_engine` is non-null the
// write completes asynchronously: Write() returns false and `on_writable` is
// scheduled on that event engine. Otherwise it completes synchronously:
// Write() returns true and `on_writable` is never invoked.
void MockPromiseEndpoint::ExpectWrite(
    std::initializer_list<EventEngineSlice> slices,
    EventEngine* schedule_on_event_engine) {
  // Flatten the expected slices into a single string for comparison.
  SliceBuffer expect;
  for (auto&& slice : slices) {
    expect.Append(grpc_event_engine::experimental::internal::SliceCast<Slice>(
        slice.Copy()));
  }
  EXPECT_CALL(*endpoint, Write)
      .InSequence(write_sequence)
      .WillOnce(WithArgs<0, 1>(
          [expect = expect.JoinIntoString(), schedule_on_event_engine](
              absl::AnyInvocable<void(absl::Status)> on_writable,
              grpc_event_engine::experimental::SliceBuffer* buffer) mutable {
            // Drain the caller's buffer and compare its contents against
            // the expected byte string.
            SliceBuffer tmp;
            grpc_slice_buffer_swap(buffer->c_slice_buffer(),
                                   tmp.c_slice_buffer());
            EXPECT_EQ(tmp.JoinIntoString(), expect);
            if (schedule_on_event_engine != nullptr) {
              // Asynchronous completion path.
              schedule_on_event_engine->Run(
                  [on_writable = std::move(on_writable)]() mutable {
                    on_writable(absl::OkStatus());
                  });
              return false;
            } else {
              // Synchronous completion path.
              return true;
            }
          }));
}
} // namespace testing
} // namespace chaotic_good
} // namespace grpc_core

@ -0,0 +1,77 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H
#define GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/event_engine.h>
#include "src/core/lib/transport/promise_endpoint.h"
namespace grpc_core {
namespace chaotic_good {
namespace testing {
// gMock implementation of the EventEngine Endpoint interface so tests can
// script endpoint reads/writes (see MockPromiseEndpoint::ExpectRead/Write).
class MockEndpoint
    : public grpc_event_engine::experimental::EventEngine::Endpoint {
 public:
  MOCK_METHOD(
      bool, Read,
      (absl::AnyInvocable<void(absl::Status)> on_read,
       grpc_event_engine::experimental::SliceBuffer* buffer,
       const grpc_event_engine::experimental::EventEngine::Endpoint::ReadArgs*
           args),
      (override));
  MOCK_METHOD(
      bool, Write,
      (absl::AnyInvocable<void(absl::Status)> on_writable,
       grpc_event_engine::experimental::SliceBuffer* data,
       const grpc_event_engine::experimental::EventEngine::Endpoint::WriteArgs*
           args),
      (override));
  MOCK_METHOD(
      const grpc_event_engine::experimental::EventEngine::ResolvedAddress&,
      GetPeerAddress, (), (const, override));
  MOCK_METHOD(
      const grpc_event_engine::experimental::EventEngine::ResolvedAddress&,
      GetLocalAddress, (), (const, override));
};
// Bundles a strict MockEndpoint with the PromiseEndpoint that wraps it, plus
// gMock Sequences so queued reads/writes fire in the order they were
// expected.
struct MockPromiseEndpoint {
  // Raw pointer retained for setting expectations; ownership is transferred
  // to promise_endpoint below.
  ::testing::StrictMock<MockEndpoint>* endpoint =
      new ::testing::StrictMock<MockEndpoint>();
  std::unique_ptr<PromiseEndpoint> promise_endpoint =
      std::make_unique<PromiseEndpoint>(
          std::unique_ptr<::testing::StrictMock<MockEndpoint>>(endpoint),
          SliceBuffer());
  // Orderings for queued read and write expectations, respectively.
  ::testing::Sequence read_sequence;
  ::testing::Sequence write_sequence;
  // Queue one expected Read()/Write(); see mock_promise_endpoint.cc for the
  // synchronous vs. event-engine-scheduled completion semantics.
  void ExpectRead(
      std::initializer_list<grpc_event_engine::experimental::Slice> slices_init,
      grpc_event_engine::experimental::EventEngine* schedule_on_event_engine);
  void ExpectWrite(
      std::initializer_list<grpc_event_engine::experimental::Slice> slices,
      grpc_event_engine::experimental::EventEngine* schedule_on_event_engine);
};
} // namespace testing
} // namespace chaotic_good
} // namespace grpc_core
#endif // GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H

@ -0,0 +1,198 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/core/ext/transport/chaotic_good/server_transport.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/event_engine/slice.h>
#include <grpc/event_engine/slice_buffer.h>
#include <grpc/grpc.h>
#include <grpc/status.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/timer_manager.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
#include "test/core/transport/chaotic_good/mock_promise_endpoint.h"
#include "test/core/transport/chaotic_good/transport_test.h"
using testing::_;
using testing::MockFunction;
using testing::Return;
using testing::StrictMock;
using testing::WithArgs;
using EventEngineSlice = grpc_event_engine::experimental::Slice;
namespace grpc_core {
namespace chaotic_good {
namespace testing {
// Encoded string of header ":path: /demo.Service/Step".
const uint8_t kPathDemoServiceStep[] = {
0x40, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f,
0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76,
0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70};
// Encoded string of trailer "grpc-status: 0".
const uint8_t kGrpcStatus0[] = {0x40, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x73,
0x74, 0x61, 0x74, 0x75, 0x73, 0x01, 0x30};
// Builds a metadata batch carrying ":path: /demo.Service/Step", pooled on
// the current promise context's arena.
ServerMetadataHandle TestInitialMetadata() {
  auto md =
      GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
  md->Set(HttpPathMetadata(), Slice::FromStaticString("/demo.Service/Step"));
  return md;
}
// Builds trailing metadata carrying "grpc-status: 0" (OK), pooled on the
// current promise context's arena.
ServerMetadataHandle TestTrailingMetadata() {
  auto md =
      GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
  md->Set(GrpcStatusMetadata(), GRPC_STATUS_OK);
  return md;
}
// gMock Acceptor so tests can intercept the server transport's requests to
// create an arena and a call for each incoming stream.
class MockAcceptor : public ServerTransport::Acceptor {
 public:
  virtual ~MockAcceptor() = default;
  MOCK_METHOD(Arena*, CreateArena, (), (override));
  MOCK_METHOD(absl::StatusOr<CallInitiator>, CreateCall,
              (ClientMetadata & client_initial_metadata, Arena* arena),
              (override));
};
// One server-side round-trip through ChaoticGoodServerTransport: accept an
// incoming stream carrying "/demo.Service/Step" with an 8-byte request
// payload, then reply with initial metadata, one message, and OK trailers.
TEST_F(TransportTest, ReadAndWriteOneMessage) {
  MockPromiseEndpoint control_endpoint;
  MockPromiseEndpoint data_endpoint;
  StrictMock<MockAcceptor> acceptor;
  auto transport = MakeOrphanable<ChaoticGoodServerTransport>(
      CoreConfiguration::Get()
          .channel_args_preconditioning()
          .PreconditionChannelArgs(nullptr),
      std::move(control_endpoint.promise_endpoint),
      std::move(data_endpoint.promise_endpoint), event_engine());
  // Once we set the acceptor, expect to read some frames.
  // We'll return a new request with a payload of "12345678".
  control_endpoint.ExpectRead(
      {SerializedFrameHeader(FrameType::kFragment, 7, 1, 26, 8, 56, 0),
       EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep,
                                          sizeof(kPathDemoServiceStep))},
      event_engine().get());
  data_endpoint.ExpectRead(
      {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr);
  // Once that's read we'll create a new call
  auto* call_arena = Arena::Create(1024, memory_allocator());
  CallInitiatorAndHandler call = MakeCall(event_engine().get(), call_arena);
  EXPECT_CALL(acceptor, CreateArena).WillOnce(Return(call_arena));
  // The transport must hand us the decoded client initial metadata.
  EXPECT_CALL(acceptor, CreateCall(_, call_arena))
      .WillOnce(WithArgs<0>([call_initiator = std::move(call.initiator)](
                                ClientMetadata& client_initial_metadata) {
        EXPECT_EQ(client_initial_metadata.get_pointer(HttpPathMetadata())
                      ->as_string_view(),
                  "/demo.Service/Step");
        return call_initiator;
      }));
  transport->SetAcceptor(&acceptor);
  StrictMock<MockFunction<void()>> on_done;
  EXPECT_CALL(on_done, Call());
  // After the queued frame is consumed the next read never completes.
  EXPECT_CALL(*control_endpoint.endpoint, Read)
      .InSequence(control_endpoint.read_sequence)
      .WillOnce(Return(false));
  // Expected outgoing frames: server initial metadata, one message payload
  // ("87654321" padded to 64 bytes on the data endpoint), then trailers.
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 1, 1,
                             sizeof(kPathDemoServiceStep), 0, 0, 0),
       EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep,
                                          sizeof(kPathDemoServiceStep))},
      nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 8, 56, 0)},
      nullptr);
  data_endpoint.ExpectWrite(
      {EventEngineSlice::FromCopiedString("87654321"), Zeros(56)}, nullptr);
  control_endpoint.ExpectWrite(
      {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0,
                             sizeof(kGrpcStatus0)),
       EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))},
      nullptr);
  // Drive the handler side of the call: pull the request, then push the
  // response and trailers.
  call.handler.SpawnInfallible(
      "test-io", [&on_done, handler = call.handler]() mutable {
        return Seq(
            handler.PullClientInitialMetadata(),
            [](ValueOrFailure<ServerMetadataHandle> md) {
              EXPECT_TRUE(md.ok());
              EXPECT_EQ(
                  md.value()->get_pointer(HttpPathMetadata())->as_string_view(),
                  "/demo.Service/Step");
              return Empty{};
            },
            handler.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_TRUE(msg.has_value());
              EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678");
              return Empty{};
            },
            handler.PullMessage(),
            [](NextResult<MessageHandle> msg) {
              EXPECT_FALSE(msg.has_value());
              return Empty{};
            },
            handler.PushServerInitialMetadata(TestInitialMetadata()),
            handler.PushMessage(Arena::MakePooled<Message>(
                SliceBuffer(Slice::FromCopiedString("87654321")), 0)),
            [handler]() mutable {
              return handler.PushServerTrailingMetadata(TestTrailingMetadata());
            },
            [&on_done]() mutable {
              on_done.Call();
              return Empty{};
            });
      });
  // Wait for the transport's internal activities to finish.
  event_engine()->TickUntilIdle();
  event_engine()->UnsetGlobalHooks();
}
} // namespace testing
} // namespace chaotic_good
} // namespace grpc_core
// Test entry point: initialize gRPC, run every registered test, then shut
// gRPC down before reporting the result.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  grpc_init();  // Must call to create default EventEngine.
  const int result = RUN_ALL_TESTS();
  grpc_shutdown();
  return result;
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save