diff --git a/CMakeLists.txt b/CMakeLists.txt index 17d1a7410a8..5d178f90137 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1326,6 +1326,7 @@ if(gRPC_BUILD_TESTS) endif() add_dependencies(buildtests_cxx server_streaming_test) add_dependencies(buildtests_cxx server_test) + add_dependencies(buildtests_cxx server_transport_test) add_dependencies(buildtests_cxx service_config_end2end_test) add_dependencies(buildtests_cxx service_config_test) add_dependencies(buildtests_cxx settings_timeout_test) @@ -9533,6 +9534,7 @@ add_executable(client_transport_error_test ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h + src/core/ext/transport/chaotic_good/chaotic_good_transport.cc src/core/ext/transport/chaotic_good/client_transport.cc src/core/ext/transport/chaotic_good/frame.cc src/core/ext/transport/chaotic_good/frame_header.cc @@ -9563,6 +9565,7 @@ target_include_directories(client_transport_error_test target_link_libraries(client_transport_error_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + grpc_unsecure ${_gRPC_PROTOBUF_LIBRARIES} grpc_test_util ) @@ -9576,12 +9579,15 @@ add_executable(client_transport_test ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h + src/core/ext/transport/chaotic_good/chaotic_good_transport.cc src/core/ext/transport/chaotic_good/client_transport.cc src/core/ext/transport/chaotic_good/frame.cc src/core/ext/transport/chaotic_good/frame_header.cc src/core/lib/transport/promise_endpoint.cc test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc test/core/transport/chaotic_good/client_transport_test.cc + test/core/transport/chaotic_good/mock_promise_endpoint.cc + test/core/transport/chaotic_good/transport_test.cc ) target_compile_features(client_transport_test PUBLIC cxx_std_14) target_include_directories(client_transport_test @@ -22380,6 +22386,52 @@ target_link_libraries(server_test ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(server_transport_test + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h + src/core/ext/transport/chaotic_good/chaotic_good_transport.cc + src/core/ext/transport/chaotic_good/frame.cc + src/core/ext/transport/chaotic_good/frame_header.cc + src/core/ext/transport/chaotic_good/server_transport.cc + src/core/lib/transport/promise_endpoint.cc + test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc + test/core/transport/chaotic_good/mock_promise_endpoint.cc + test/core/transport/chaotic_good/server_transport_test.cc + test/core/transport/chaotic_good/transport_test.cc +) +target_compile_features(server_transport_test PUBLIC cxx_std_14) +target_include_directories(server_transport_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(server_transport_test + ${_gRPC_ALLTARGETS_LIBRARIES} + gtest + ${_gRPC_PROTOBUF_LIBRARIES} + grpc_test_util +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index b1492ba7068..46cd87ce1d7 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -7646,6 +7646,7 @@ targets: build: test language: c++ headers: + - src/core/ext/transport/chaotic_good/chaotic_good_transport.h - src/core/ext/transport/chaotic_good/client_transport.h - src/core/ext/transport/chaotic_good/frame.h - src/core/ext/transport/chaotic_good/frame_header.h @@ -7658,6 +7659,7 @@ targets: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h src: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto + - src/core/ext/transport/chaotic_good/chaotic_good_transport.cc - src/core/ext/transport/chaotic_good/client_transport.cc - src/core/ext/transport/chaotic_good/frame.cc - src/core/ext/transport/chaotic_good/frame_header.cc @@ -7666,6 +7668,7 @@ targets: - test/core/transport/chaotic_good/client_transport_error_test.cc deps: - gtest + - grpc_unsecure - protobuf - grpc_test_util uses_polling: false @@ -7674,24 +7677,29 @@ targets: build: test language: c++ headers: + - src/core/ext/transport/chaotic_good/chaotic_good_transport.h - src/core/ext/transport/chaotic_good/client_transport.h - src/core/ext/transport/chaotic_good/frame.h - src/core/ext/transport/chaotic_good/frame_header.h - src/core/lib/promise/event_engine_wakeup_scheduler.h - src/core/lib/promise/inter_activity_pipe.h - - src/core/lib/promise/join.h - src/core/lib/promise/mpsc.h - src/core/lib/promise/wait_set.h - src/core/lib/transport/promise_endpoint.h - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h + - test/core/transport/chaotic_good/mock_promise_endpoint.h + - test/core/transport/chaotic_good/transport_test.h src: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto + - src/core/ext/transport/chaotic_good/chaotic_good_transport.cc - src/core/ext/transport/chaotic_good/client_transport.cc - src/core/ext/transport/chaotic_good/frame.cc - src/core/ext/transport/chaotic_good/frame_header.cc - src/core/lib/transport/promise_endpoint.cc - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc - test/core/transport/chaotic_good/client_transport_test.cc + - test/core/transport/chaotic_good/mock_promise_endpoint.cc + - test/core/transport/chaotic_good/transport_test.cc deps: - gtest - protobuf @@ -15568,6 +15576,40 @@ targets: deps: - gtest - grpc_test_util +- name: server_transport_test + gtest: true + build: test + language: c++ + headers: + - src/core/ext/transport/chaotic_good/chaotic_good_transport.h + - src/core/ext/transport/chaotic_good/frame.h + - src/core/ext/transport/chaotic_good/frame_header.h + - src/core/ext/transport/chaotic_good/server_transport.h + - src/core/lib/promise/event_engine_wakeup_scheduler.h + - src/core/lib/promise/inter_activity_pipe.h + - src/core/lib/promise/mpsc.h + - src/core/lib/promise/switch.h + - src/core/lib/promise/wait_set.h + - src/core/lib/transport/promise_endpoint.h + - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h + - test/core/transport/chaotic_good/mock_promise_endpoint.h + - test/core/transport/chaotic_good/transport_test.h + src: + - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto + - src/core/ext/transport/chaotic_good/chaotic_good_transport.cc + - src/core/ext/transport/chaotic_good/frame.cc + - src/core/ext/transport/chaotic_good/frame_header.cc + - src/core/ext/transport/chaotic_good/server_transport.cc + - src/core/lib/transport/promise_endpoint.cc + - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc + - test/core/transport/chaotic_good/mock_promise_endpoint.cc + - test/core/transport/chaotic_good/server_transport_test.cc + - test/core/transport/chaotic_good/transport_test.cc + deps: + - gtest + - protobuf + - grpc_test_util + uses_polling: false - name: service_config_end2end_test gtest: true build: test diff --git a/examples/python/helloworld/helloworld_pb2.py b/examples/python/helloworld/helloworld_pb2.py index f5b4f2d27dc..6ee31ad94a2 100644 --- a/examples/python/helloworld/helloworld_pb2.py +++ b/examples/python/helloworld/helloworld_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: helloworld.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,18 +14,18 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2\xe4\x01\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x12K\n\x13SayHelloStreamReply\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x30\x01\x12L\n\x12SayHelloBidiStream\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00(\x01\x30\x01\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW' - _HELLOREQUEST._serialized_start=32 - _HELLOREQUEST._serialized_end=60 - _HELLOREPLY._serialized_start=62 - _HELLOREPLY._serialized_end=91 - _GREETER._serialized_start=93 - _GREETER._serialized_end=166 + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW' + _globals['_HELLOREQUEST']._serialized_start=32 + _globals['_HELLOREQUEST']._serialized_end=60 + _globals['_HELLOREPLY']._serialized_start=62 + _globals['_HELLOREPLY']._serialized_end=91 + _globals['_GREETER']._serialized_start=94 + _globals['_GREETER']._serialized_end=322 # @@protoc_insertion_point(module_scope) diff --git a/examples/python/helloworld/helloworld_pb2.pyi b/examples/python/helloworld/helloworld_pb2.pyi index 8c4b5b22805..bf0bd395ad5 100644 --- a/examples/python/helloworld/helloworld_pb2.pyi +++ b/examples/python/helloworld/helloworld_pb2.pyi @@ -4,14 +4,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor -class HelloReply(_message.Message): - __slots__ = ["message"] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - message: str - def __init__(self, message: _Optional[str] = ...) -> None: ... - class HelloRequest(_message.Message): - __slots__ = ["name"] + __slots__ = ("name",) NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... + +class HelloReply(_message.Message): + __slots__ = ("message",) + MESSAGE_FIELD_NUMBER: _ClassVar[int] + message: str + def __init__(self, message: _Optional[str] = ...) -> None: ... diff --git a/examples/python/helloworld/helloworld_pb2_grpc.py b/examples/python/helloworld/helloworld_pb2_grpc.py index 47c186976e1..68bcfef1756 100644 --- a/examples/python/helloworld/helloworld_pb2_grpc.py +++ b/examples/python/helloworld/helloworld_pb2_grpc.py @@ -19,7 +19,17 @@ class GreeterStub(object): '/helloworld.Greeter/SayHello', request_serializer=helloworld__pb2.HelloRequest.SerializeToString, response_deserializer=helloworld__pb2.HelloReply.FromString, - ) + _registered_method=True) + self.SayHelloStreamReply = channel.unary_stream( + '/helloworld.Greeter/SayHelloStreamReply', + request_serializer=helloworld__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__pb2.HelloReply.FromString, + _registered_method=True) + self.SayHelloBidiStream = channel.stream_stream( + '/helloworld.Greeter/SayHelloBidiStream', + request_serializer=helloworld__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__pb2.HelloReply.FromString, + _registered_method=True) class GreeterServicer(object): @@ -33,6 +43,18 @@ class GreeterServicer(object): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SayHelloStreamReply(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SayHelloBidiStream(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_GreeterServicer_to_server(servicer, server): rpc_method_handlers = { @@ -41,6 +63,16 @@ def add_GreeterServicer_to_server(servicer, server): request_deserializer=helloworld__pb2.HelloRequest.FromString, response_serializer=helloworld__pb2.HelloReply.SerializeToString, ), + 'SayHelloStreamReply': grpc.unary_stream_rpc_method_handler( + servicer.SayHelloStreamReply, + request_deserializer=helloworld__pb2.HelloRequest.FromString, + response_serializer=helloworld__pb2.HelloReply.SerializeToString, + ), + 'SayHelloBidiStream': grpc.stream_stream_rpc_method_handler( + servicer.SayHelloBidiStream, + request_deserializer=helloworld__pb2.HelloRequest.FromString, + response_serializer=helloworld__pb2.HelloReply.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'helloworld.Greeter', rpc_method_handlers) @@ -63,8 +95,72 @@ class Greeter(object): wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/helloworld.Greeter/SayHello', + return grpc.experimental.unary_unary( + request, + target, + '/helloworld.Greeter/SayHello', + helloworld__pb2.HelloRequest.SerializeToString, + helloworld__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloStreamReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/helloworld.Greeter/SayHelloStreamReply', + helloworld__pb2.HelloRequest.SerializeToString, + helloworld__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloBidiStream(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/helloworld.Greeter/SayHelloBidiStream', helloworld__pb2.HelloRequest.SerializeToString, helloworld__pb2.HelloReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/include/grpc/event_engine/internal/slice_cast.h b/include/grpc/event_engine/internal/slice_cast.h index 8bcca60a24a..3f9593464cf 100644 --- a/include/grpc/event_engine/internal/slice_cast.h +++ b/include/grpc/event_engine/internal/slice_cast.h @@ -60,6 +60,18 @@ Result& SliceCast(T& value, SliceCastable = {}) { return reinterpret_cast(value); } +// Cast to `Result&&` from `T&&` without any runtime checks. +// This is only valid if `sizeof(Result) == sizeof(T)`, and if `Result`, `T` are +// opted in as compatible via `SliceCastable`. +template +Result&& SliceCast(T&& value, SliceCastable = {}) { + // Insist upon sizes being equal to catch mismatches. + // We assume if sizes are opted in and sizes are equal then yes, these two + // types are expected to be layout compatible and actually appear to be. + static_assert(sizeof(Result) == sizeof(T), "size mismatch"); + return reinterpret_cast(value); +} + } // namespace internal } // namespace experimental } // namespace grpc_event_engine diff --git a/include/grpc/event_engine/slice.h b/include/grpc/event_engine/slice.h index 8d49f391600..ce7693f6489 100644 --- a/include/grpc/event_engine/slice.h +++ b/include/grpc/event_engine/slice.h @@ -169,6 +169,11 @@ struct CopyConstructors { return Out(grpc_slice_from_copied_buffer(p, len)); } + static Out FromCopiedBuffer(const uint8_t* p, size_t len) { + return Out( + grpc_slice_from_copied_buffer(reinterpret_cast(p), len)); + } + template static Out FromCopiedBuffer(const Buffer& buffer) { return FromCopiedBuffer(reinterpret_cast(buffer.data()), diff --git a/src/compiler/python_generator.cc b/src/compiler/python_generator.cc index 753fe1c8887..b7a3115bce0 100644 --- a/src/compiler/python_generator.cc +++ b/src/compiler/python_generator.cc @@ -467,7 +467,7 @@ bool PrivateGenerator::PrintStub( out->Print( method_dict, "response_deserializer=$ResponseModuleAndClass$.FromString,\n"); - out->Print(")\n"); + out->Print("_registered_method=True)\n"); } } } @@ -642,22 +642,27 @@ bool PrivateGenerator::PrintServiceClass( args_dict["ArityMethodName"] = arity_method_name; args_dict["PackageQualifiedService"] = package_qualified_service_name; args_dict["Method"] = method->name(); - out->Print(args_dict, - "return " - "grpc.experimental.$ArityMethodName$($RequestParameter$, " - "target, '/$PackageQualifiedService$/$Method$',\n"); + out->Print(args_dict, "return grpc.experimental.$ArityMethodName$(\n"); { IndentScope continuation_indent(out); StringMap serializer_dict; + out->Print(args_dict, "$RequestParameter$,\n"); + out->Print("target,\n"); + out->Print(args_dict, "'/$PackageQualifiedService$/$Method$',\n"); serializer_dict["RequestModuleAndClass"] = request_module_and_class; serializer_dict["ResponseModuleAndClass"] = response_module_and_class; out->Print(serializer_dict, "$RequestModuleAndClass$.SerializeToString,\n"); out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n"); - out->Print("options, channel_credentials,\n"); - out->Print( - "insecure, call_credentials, compression, wait_for_ready, " - "timeout, metadata)\n"); + out->Print("options,\n"); + out->Print("channel_credentials,\n"); + out->Print("insecure,\n"); + out->Print("call_credentials,\n"); + out->Print("compression,\n"); + out->Print("wait_for_ready,\n"); + out->Print("timeout,\n"); + out->Print("metadata,\n"); + out->Print("_registered_method=True)\n"); } } } diff --git a/src/core/BUILD b/src/core/BUILD index fcb11452d18..8ee2f308b86 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -6213,6 +6213,7 @@ grpc_cc_library( "bitset", "chaotic_good_frame_header", "context", + "match", "no_destruct", "slice", "slice_buffer", @@ -6368,6 +6369,29 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "chaotic_good_transport", + srcs = [ + "ext/transport/chaotic_good/chaotic_good_transport.cc", + ], + hdrs = [ + "ext/transport/chaotic_good/chaotic_good_transport.h", + ], + external_deps = ["absl/random"], + language = "c++", + deps = [ + "chaotic_good_frame", + "chaotic_good_frame_header", + "grpc_promise_endpoint", + "if", + "try_join", + "try_seq", + "//:gpr_platform", + "//:hpack_encoder", + "//:promise", + ], +) + grpc_cc_library( name = "chaotic_good_client_transport", srcs = [ @@ -6378,6 +6402,7 @@ grpc_cc_library( ], external_deps = [ "absl/base:core_headers", + "absl/container:flat_hash_map", "absl/random", "absl/random:bit_gen_ref", "absl/status", @@ -6388,9 +6413,11 @@ grpc_cc_library( language = "c++", deps = [ "activity", + "all_ok", "arena", "chaotic_good_frame", "chaotic_good_frame_header", + "chaotic_good_transport", "context", "event_engine_wakeup_scheduler", "for_each", @@ -6398,6 +6425,7 @@ grpc_cc_library( "if", "inter_activity_pipe", "loop", + "map", "match", "memory_quota", "mpsc", @@ -6414,6 +6442,63 @@ grpc_cc_library( "//:grpc_base", "//:hpack_encoder", "//:hpack_parser", + "//:promise", + "//:ref_counted_ptr", + ], +) + +grpc_cc_library( + name = "chaotic_good_server_transport", + srcs = [ + "ext/transport/chaotic_good/server_transport.cc", + ], + hdrs = [ + "ext/transport/chaotic_good/server_transport.h", + ], + external_deps = [ + "absl/base:core_headers", + "absl/container:flat_hash_map", + "absl/functional:any_invocable", + "absl/random", + "absl/random:bit_gen_ref", + "absl/status", + "absl/status:statusor", + "absl/types:optional", + "absl/types:variant", + ], + language = "c++", + deps = [ + "1999", + "activity", + "arena", + "chaotic_good_frame", + "chaotic_good_frame_header", + "chaotic_good_transport", + "context", + "default_event_engine", + "event_engine_wakeup_scheduler", + "for_each", + "grpc_promise_endpoint", + "if", + "inter_activity_pipe", + "loop", + "memory_quota", + "mpsc", + "pipe", + "poll", + "resource_quota", + "seq", + "slice", + "slice_buffer", + "switch", + "try_join", + "try_seq", + "//:exec_ctx", + "//:gpr", + "//:gpr_platform", + "//:grpc_base", + "//:hpack_encoder", + "//:hpack_parser", "//:ref_counted_ptr", ], ) diff --git a/src/core/ext/transport/chaotic_good/chaotic_good_transport.cc b/src/core/ext/transport/chaotic_good/chaotic_good_transport.cc new file mode 100644 index 00000000000..163f994d35f --- /dev/null +++ b/src/core/ext/transport/chaotic_good/chaotic_good_transport.cc @@ -0,0 +1,19 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" + +namespace grpc_core {} // namespace grpc_core diff --git a/src/core/ext/transport/chaotic_good/chaotic_good_transport.h b/src/core/ext/transport/chaotic_good/chaotic_good_transport.h new file mode 100644 index 00000000000..1096486bbea --- /dev/null +++ b/src/core/ext/transport/chaotic_good/chaotic_good_transport.h @@ -0,0 +1,111 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H +#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H + +#include + +#include "absl/random/random.h" + +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/ext/transport/chaotic_good/frame_header.h" +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/lib/promise/if.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/transport/promise_endpoint.h" + +namespace grpc_core { +namespace chaotic_good { + +class ChaoticGoodTransport { + public: + ChaoticGoodTransport(std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint) + : control_endpoint_(std::move(control_endpoint)), + data_endpoint_(std::move(data_endpoint)) {} + + auto WriteFrame(const FrameInterface& frame) { + auto buffers = frame.Serialize(&encoder_); + return TryJoin( + control_endpoint_->Write(std::move(buffers.control)), + data_endpoint_->Write(std::move(buffers.data))); + } + + // Read frame header and payloads for control and data portions of one frame. + // Resolves to StatusOr>. + auto ReadFrameBytes() { + return TrySeq( + control_endpoint_->ReadSlice(FrameHeader::frame_header_size_), + [this](Slice read_buffer) { + auto frame_header = + FrameHeader::Parse(reinterpret_cast( + GRPC_SLICE_START_PTR(read_buffer.c_slice()))); + // Read header and trailers from control endpoint. + // Read message padding and message from data endpoint. + return If( + frame_header.ok(), + [this, &frame_header] { + const uint32_t message_padding = std::exchange( + last_message_padding_, frame_header->message_padding); + const uint32_t message_length = frame_header->message_length; + return Map( + TryJoin( + control_endpoint_->Read(frame_header->GetFrameLength()), + TrySeq(data_endpoint_->Read(message_padding), + [this, message_length]() { + return data_endpoint_->Read(message_length); + })), + [frame_header = *frame_header]( + absl::StatusOr> + buffers) + -> absl::StatusOr> { + if (!buffers.ok()) return buffers.status(); + return std::tuple( + frame_header, + BufferPair{std::move(std::get<0>(*buffers)), + std::move(std::get<1>(*buffers))}); + }); + }, + [&frame_header]() + -> absl::StatusOr> { + return frame_header.status(); + }); + }); + } + + absl::Status DeserializeFrame(FrameHeader header, BufferPair buffers, + Arena* arena, FrameInterface& frame) { + return frame.Deserialize(&parser_, header, bitgen_, arena, + std::move(buffers)); + } + + // Skip a frame, but correctly handle any hpack state updates. + void SkipFrame(FrameHeader, BufferPair) { Crash("not implemented"); } + + private: + const std::unique_ptr control_endpoint_; + const std::unique_ptr data_endpoint_; + uint32_t last_message_padding_ = 0; + HPackCompressor encoder_; + HPackParser parser_; + absl::BitGen bitgen_; +}; + +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H diff --git a/src/core/ext/transport/chaotic_good/client_transport.cc b/src/core/ext/transport/chaotic_good/client_transport.cc index 24f30687182..d4075bcad1b 100644 --- a/src/core/ext/transport/chaotic_good/client_transport.cc +++ b/src/core/ext/transport/chaotic_good/client_transport.cc @@ -17,9 +17,11 @@ #include "src/core/ext/transport/chaotic_good/client_transport.h" #include +#include #include #include #include +#include #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" @@ -36,9 +38,13 @@ #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/all_ok.h" #include "src/core/lib/promise/event_engine_wakeup_scheduler.h" #include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/promise.h" #include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice.h" @@ -49,59 +55,15 @@ namespace grpc_core { namespace chaotic_good { -ClientTransport::ClientTransport( - std::unique_ptr control_endpoint, - std::unique_ptr data_endpoint, - std::shared_ptr event_engine) - : outgoing_frames_(MpscReceiver(4)), - control_endpoint_(std::move(control_endpoint)), - data_endpoint_(std::move(data_endpoint)), - control_endpoint_write_buffer_(SliceBuffer()), - data_endpoint_write_buffer_(SliceBuffer()), - hpack_compressor_(std::make_unique()), - hpack_parser_(std::make_unique()), - memory_allocator_( - ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( - "client_transport")), - arena_(MakeScopedArena(1024, &memory_allocator_)), - context_(arena_.get()), - event_engine_(event_engine) { - auto write_loop = Loop([this] { +auto ChaoticGoodClientTransport::TransportWriteLoop() { + return Loop([this] { return TrySeq( // Get next outgoing frame. - this->outgoing_frames_.Next(), - // Construct data buffers that will be sent to the endpoints. + outgoing_frames_.Next(), + // Serialize and write it out. [this](ClientFrame client_frame) { - MatchMutable( - &client_frame, - [this](ClientFragmentFrame* frame) mutable { - control_endpoint_write_buffer_.Append( - frame->Serialize(hpack_compressor_.get())); - if (frame->message != nullptr) { - std::string message_padding(frame->message_padding, '0'); - Slice slice(grpc_slice_from_cpp_string(message_padding)); - // Append message padding to data_endpoint_buffer. - data_endpoint_write_buffer_.Append(std::move(slice)); - // Append message payload to data_endpoint_buffer. - frame->message->payload()->MoveFirstNBytesIntoSliceBuffer( - frame->message->payload()->Length(), - data_endpoint_write_buffer_); - } - }, - [this](CancelFrame* frame) mutable { - control_endpoint_write_buffer_.Append( - frame->Serialize(hpack_compressor_.get())); - }); - return absl::OkStatus(); + return transport_.WriteFrame(GetFrameInterface(client_frame)); }, - // Write buffers to corresponding endpoints concurrently. - [this]() { - return TryJoin( - control_endpoint_->Write( - std::move(control_endpoint_write_buffer_)), - data_endpoint_->Write(std::move(data_endpoint_write_buffer_))); - }, - // Finish writes to difference endpoints and continue the loop. []() -> LoopCtl { // The write failures will be caught in TrySeq and exit loop. // Therefore, only need to return Continue() in the last lambda @@ -109,78 +71,215 @@ ClientTransport::ClientTransport( return Continue(); }); }); - writer_ = MakeActivity( - // Continuously write next outgoing frames to promise endpoints. - std::move(write_loop), EventEngineWakeupScheduler(event_engine_), - [this](absl::Status status) { - if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { - this->AbortWithError(); - } +} + +absl::optional ChaoticGoodClientTransport::LookupStream( + uint32_t stream_id) { + MutexLock lock(&mu_); + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + return absl::nullopt; + } + return it->second; +} + +auto ChaoticGoodClientTransport::PushFrameIntoCall(ServerFragmentFrame frame, + CallHandler call_handler) { + auto& headers = frame.headers; + return TrySeq( + If( + headers != nullptr, + [call_handler, &headers]() mutable { + return call_handler.PushServerInitialMetadata(std::move(headers)); + }, + []() -> StatusFlag { return Success{}; }), + [call_handler, message = std::move(frame.message)]() mutable { + return If( + message.has_value(), + [&call_handler, &message]() mutable { + return call_handler.PushMessage(std::move(message->message)); + }, + []() -> StatusFlag { return Success{}; }); }, - // Hold Arena in activity for GetContext usage. - arena_.get()); - auto read_loop = Loop([this] { + [call_handler, trailers = std::move(frame.trailers)]() mutable { + return If( + trailers != nullptr, + [&call_handler, &trailers]() mutable { + return call_handler.PushServerTrailingMetadata( + std::move(trailers)); + }, + []() -> StatusFlag { return Success{}; }); + }); +} + +auto ChaoticGoodClientTransport::TransportReadLoop() { + return Loop([this] { return TrySeq( - // Read frame header from control endpoint. - // TODO(ladynana): remove memcpy in ReadSlice. - this->control_endpoint_->ReadSlice(FrameHeader::frame_header_size_), - // Read different parts of the server frame from control/data endpoints - // based on frame header. - [this](Slice read_buffer) mutable { - frame_header_ = std::make_shared( - FrameHeader::Parse( - reinterpret_cast( - GRPC_SLICE_START_PTR(read_buffer.c_slice()))) - .value()); - // Read header and trailers from control endpoint. - // Read message padding and message from data endpoint. - return TryJoin( - control_endpoint_->Read(frame_header_->GetFrameLength()), - data_endpoint_->Read(frame_header_->message_padding + - frame_header_->message_length)); + transport_.ReadFrameBytes(), + [](std::tuple frame_bytes) + -> absl::StatusOr> { + const auto& frame_header = std::get<0>(frame_bytes); + if (frame_header.type != FrameType::kFragment) { + return absl::InternalError( + absl::StrCat("Expected fragment frame, got ", + static_cast(frame_header.type))); + } + return frame_bytes; }, - // Construct and send the server frame to corresponding stream. - [this](std::tuple ret) mutable { - control_endpoint_read_buffer_ = std::move(std::get<0>(ret)); - // Discard message padding and only keep message in data read buffer. - std::get<1>(ret).MoveLastNBytesIntoSliceBuffer( - frame_header_->message_length, data_endpoint_read_buffer_); + [this](std::tuple frame_bytes) { + const auto& frame_header = std::get<0>(frame_bytes); + auto& buffers = std::get<1>(frame_bytes); + absl::optional call_handler = + LookupStream(frame_header.stream_id); ServerFragmentFrame frame; - // Initialized to get this_cpu() info in global_stat(). - ExecCtx exec_ctx; - // Deserialize frame from read buffer. - absl::BitGen bitgen; - auto status = frame.Deserialize(hpack_parser_.get(), *frame_header_, - absl::BitGenRef(bitgen), - control_endpoint_read_buffer_); - GPR_ASSERT(status.ok()); - // Move message into frame. - frame.message = arena_->MakePooled( - std::move(data_endpoint_read_buffer_), 0); - MutexLock lock(&mu_); - const uint32_t stream_id = frame_header_->stream_id; - return stream_map_[stream_id]->Push(ServerFrame(std::move(frame))); - }, - // Check if send frame to corresponding stream successfully. - [](bool ret) -> LoopCtl { - if (ret) { - // Send incoming frames successfully. - return Continue(); + absl::Status deserialize_status; + if (call_handler.has_value()) { + deserialize_status = transport_.DeserializeFrame( + frame_header, std::move(buffers), call_handler->arena(), frame); } else { - return absl::InternalError("Send incoming frames failed."); + // Stream not found, skip the frame. + transport_.SkipFrame(frame_header, std::move(buffers)); + deserialize_status = absl::OkStatus(); } - }); + return If( + deserialize_status.ok() && call_handler.has_value(), + [this, &frame, &call_handler]() { + return call_handler->SpawnWaitable( + "push-frame", [this, call_handler = *call_handler, + frame = std::move(frame)]() mutable { + return Map(call_handler.CancelIfFails(PushFrameIntoCall( + std::move(frame), call_handler)), + [](StatusFlag f) { + return StatusCast(f); + }); + }); + }, + [&deserialize_status]() -> absl::Status { + // Stream not found, nothing to do. + return std::move(deserialize_status); + }); + }, + []() -> LoopCtl { return Continue{}; }); }); - reader_ = MakeActivity( - // Continuously read next incoming frames from promise endpoints. - std::move(read_loop), EventEngineWakeupScheduler(event_engine_), - [this](absl::Status status) { - if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { - this->AbortWithError(); - } +} + +auto ChaoticGoodClientTransport::OnTransportActivityDone() { + return [this](absl::Status status) { + if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { + this->AbortWithError(); + } + }; +} + +ChaoticGoodClientTransport::ChaoticGoodClientTransport( + std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr event_engine) + : outgoing_frames_(4), + transport_(std::move(control_endpoint), std::move(data_endpoint)), + writer_{ + MakeActivity( + // Continuously write next outgoing frames to promise endpoints. + TransportWriteLoop(), EventEngineWakeupScheduler(event_engine), + OnTransportActivityDone()), + }, + reader_{MakeActivity( + // Continuously read next incoming frames from promise endpoints. + TransportReadLoop(), EventEngineWakeupScheduler(event_engine), + OnTransportActivityDone())} {} + +ChaoticGoodClientTransport::~ChaoticGoodClientTransport() { + if (writer_ != nullptr) { + writer_.reset(); + } + if (reader_ != nullptr) { + reader_.reset(); + } +} + +void ChaoticGoodClientTransport::AbortWithError() { + // Mark transport as unavailable when the endpoint write/read failed. + // Close all the available pipes. + outgoing_frames_.MarkClosed(); + ReleasableMutexLock lock(&mu_); + StreamMap stream_map = std::move(stream_map_); + stream_map_.clear(); + lock.Release(); + for (const auto& pair : stream_map) { + auto call_handler = pair.second; + call_handler.SpawnInfallible("cancel", [call_handler]() mutable { + call_handler.Cancel(ServerMetadataFromStatus( + absl::UnavailableError("Transport closed."))); + return Empty{}; + }); + } +} + +uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) { + ReleasableMutexLock lock(&mu_); + const uint32_t stream_id = next_stream_id_++; + stream_map_.emplace(stream_id, call_handler); + lock.Release(); + call_handler.OnDone([this, stream_id]() { + MutexLock lock(&mu_); + stream_map_.erase(stream_id); + }); + return stream_id; +} + +auto ChaoticGoodClientTransport::CallOutboundLoop(uint32_t stream_id, + CallHandler call_handler) { + auto send_fragment = [stream_id, + outgoing_frames = outgoing_frames_.MakeSender()]( + ClientFragmentFrame frame) mutable { + frame.stream_id = stream_id; + return Map(outgoing_frames.Send(std::move(frame)), + [](bool success) -> absl::Status { + if (!success) { + // Failed to send outgoing frame. + return absl::UnavailableError("Transport closed."); + } + return absl::OkStatus(); + }); + }; + return TrySeq( + // Wait for initial metadata then send it out. + call_handler.PullClientInitialMetadata(), + [send_fragment](ClientMetadataHandle md) mutable { + ClientFragmentFrame frame; + frame.headers = std::move(md); + return send_fragment(std::move(frame)); }, - // Hold Arena in activity for GetContext usage. - arena_.get()); + // Continuously send client frame with client to server messages. + ForEach(OutgoingMessages(call_handler), + [send_fragment, + aligned_bytes = aligned_bytes_](MessageHandle message) mutable { + ClientFragmentFrame frame; + // Construct frame header (flags, header_length and + // trailer_length will be added in serialization). + const uint32_t message_length = message->payload()->Length(); + const uint32_t padding = + message_length % aligned_bytes == 0 + ? 0 + : aligned_bytes - message_length % aligned_bytes; + GPR_ASSERT((message_length + padding) % aligned_bytes == 0); + frame.message = FragmentMessage(std::move(message), padding, + message_length); + return send_fragment(std::move(frame)); + }), + [send_fragment]() mutable { + ClientFragmentFrame frame; + frame.end_of_stream = true; + return send_fragment(std::move(frame)); + }); +} + +void ChaoticGoodClientTransport::StartCall(CallHandler call_handler) { + // At this point, the connection is set up. + // Start sending data frames. + call_handler.SpawnGuarded("outbound_loop", [this, call_handler]() mutable { + return CallOutboundLoop(MakeStream(call_handler), call_handler); + }); } } // namespace chaotic_good diff --git a/src/core/ext/transport/chaotic_good/client_transport.h b/src/core/ext/transport/chaotic_good/client_transport.h index 23ecdbfe84a..b8d515f9896 100644 --- a/src/core/ext/transport/chaotic_good/client_transport.h +++ b/src/core/ext/transport/chaotic_good/client_transport.h @@ -20,6 +20,7 @@ #include #include +#include #include // IWYU pragma: keep #include #include @@ -28,6 +29,8 @@ #include #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/random.h" #include "absl/status/status.h" #include "absl/types/optional.h" #include "absl/types/variant.h" @@ -35,6 +38,7 @@ #include #include +#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" #include "src/core/ext/transport/chaotic_good/frame.h" #include "src/core/ext/transport/chaotic_good/frame_header.h" #include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" @@ -61,178 +65,56 @@ namespace grpc_core { namespace chaotic_good { -class ClientTransport { +class ChaoticGoodClientTransport final : public Transport, + public ClientTransport { public: - ClientTransport(std::unique_ptr control_endpoint, - std::unique_ptr data_endpoint, - std::shared_ptr - event_engine); - ~ClientTransport() { - if (writer_ != nullptr) { - writer_.reset(); - } - if (reader_ != nullptr) { - reader_.reset(); - } - } - void AbortWithError() { - // Mark transport as unavailable when the endpoint write/read failed. - // Close all the available pipes. - if (!outgoing_frames_.IsClosed()) { - outgoing_frames_.MarkClosed(); - } - MutexLock lock(&mu_); - for (const auto& pair : stream_map_) { - if (!pair.second->IsClose()) { - pair.second->MarkClose(); - } - } - } - auto AddStream(CallArgs call_args) { - // At this point, the connection is set up. - // Start sending data frames. - uint32_t stream_id; - InterActivityPipe pipe_server_frames; - { - MutexLock lock(&mu_); - stream_id = next_stream_id_++; - stream_map_.insert( - std::pair::Sender>>( - stream_id, std::make_shared::Sender>( - std::move(pipe_server_frames.sender)))); - } - return TrySeq( - TryJoin( - // Continuously send client frame with client to server messages. - ForEach(std::move(*call_args.client_to_server_messages), - [stream_id, initial_frame = true, - client_initial_metadata = - std::move(call_args.client_initial_metadata), - outgoing_frames = outgoing_frames_.MakeSender(), - this](MessageHandle result) mutable { - ClientFragmentFrame frame; - // Construct frame header (flags, header_length and - // trailer_length will be added in serialization). - uint32_t message_length = result->payload()->Length(); - frame.stream_id = stream_id; - frame.message_padding = message_length % aligned_bytes; - frame.message = std::move(result); - if (initial_frame) { - // Send initial frame with client intial metadata. - frame.headers = std::move(client_initial_metadata); - initial_frame = false; - } - return TrySeq( - outgoing_frames.Send(ClientFrame(std::move(frame))), - [](bool success) -> absl::Status { - if (!success) { - // TODO(ladynana): propagate the actual error - // message from EventEngine. - return absl::UnavailableError( - "Transport closed due to endpoint write/read " - "failed."); - } - return absl::OkStatus(); - }); - }), - // Continuously receive server frames from endpoints and save - // results to call_args. - Loop([server_initial_metadata = call_args.server_initial_metadata, - server_to_client_messages = - call_args.server_to_client_messages, - receiver = std::move(pipe_server_frames.receiver)]() mutable { - return TrySeq( - // Receive incoming server frame. - receiver.Next(), - // Save incomming frame results to call_args. - [server_initial_metadata, server_to_client_messages]( - absl::optional server_frame) mutable { - bool transport_closed = false; - ServerFragmentFrame frame; - if (!server_frame.has_value()) { - // Incoming server frame pipe is closed, which only - // happens when transport is aborted. - transport_closed = true; - } else { - frame = std::move( - absl::get(*server_frame)); - }; - bool has_headers = (frame.headers != nullptr); - bool has_message = (frame.message != nullptr); - bool has_trailers = (frame.trailers != nullptr); - return TrySeq( - If((!transport_closed) && has_headers, - [server_initial_metadata, - headers = std::move(frame.headers)]() mutable { - return server_initial_metadata->Push( - std::move(headers)); - }, - [] { return false; }), - If((!transport_closed) && has_message, - [server_to_client_messages, - message = std::move(frame.message)]() mutable { - return server_to_client_messages->Push( - std::move(message)); - }, - [] { return false; }), - If((!transport_closed) && has_trailers, - [trailers = std::move(frame.trailers)]() mutable - -> LoopCtl { - return std::move(trailers); - }, - [transport_closed]() - -> LoopCtl { - if (transport_closed) { - // TODO(ladynana): propagate the actual error - // message from EventEngine. - return ServerMetadataFromStatus( - absl::UnavailableError( - "Transport closed due to endpoint " - "write/read failed.")); - } - return Continue(); - })); - }); - })), - [](std::tuple ret) { - return std::move(std::get<1>(ret)); - }); - } + ChaoticGoodClientTransport( + std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr + event_engine); + ~ChaoticGoodClientTransport() override; + + FilterStackTransport* filter_stack_transport() override { return nullptr; } + ClientTransport* client_transport() override { return this; } + ServerTransport* server_transport() override { return nullptr; } + absl::string_view GetTransportName() const override { return "chaotic_good"; } + void SetPollset(grpc_stream*, grpc_pollset*) override {} + void SetPollsetSet(grpc_stream*, grpc_pollset_set*) override {} + void PerformOp(grpc_transport_op*) override { Crash("unimplemented"); } + grpc_endpoint* GetEndpoint() override { return nullptr; } + void Orphan() override { delete this; } + + void StartCall(CallHandler call_handler) override; + void AbortWithError(); private: + // Queue size of each stream pipe is set to 2, so that for each stream read it + // will queue at most 2 frames. + static const size_t kServerFrameQueueSize = 2; + using StreamMap = absl::flat_hash_map; + + uint32_t MakeStream(CallHandler call_handler); + absl::optional LookupStream(uint32_t stream_id); + auto CallOutboundLoop(uint32_t stream_id, CallHandler call_handler); + auto OnTransportActivityDone(); + auto TransportWriteLoop(); + auto TransportReadLoop(); + // Push one frame into a call + auto PushFrameIntoCall(ServerFragmentFrame frame, CallHandler call_handler); + // Max buffer is set to 4, so that for stream writes each time it will queue // at most 2 frames. MpscReceiver outgoing_frames_; - // Queue size of each stream pipe is set to 2, so that for each stream read it - // will queue at most 2 frames. - static const size_t server_frame_queue_size_ = 2; + ChaoticGoodTransport transport_; // Assigned aligned bytes from setting frame. - size_t aligned_bytes = 64; + size_t aligned_bytes_ = 64; Mutex mu_; uint32_t next_stream_id_ ABSL_GUARDED_BY(mu_) = 1; // Map of stream incoming server frames, key is stream_id. - std::map::Sender>> - stream_map_ ABSL_GUARDED_BY(mu_); + StreamMap stream_map_ ABSL_GUARDED_BY(mu_); ActivityPtr writer_; ActivityPtr reader_; - std::unique_ptr control_endpoint_; - std::unique_ptr data_endpoint_; - SliceBuffer control_endpoint_write_buffer_; - SliceBuffer data_endpoint_write_buffer_; - SliceBuffer control_endpoint_read_buffer_; - SliceBuffer data_endpoint_read_buffer_; - std::unique_ptr hpack_compressor_; - std::unique_ptr hpack_parser_; - std::shared_ptr frame_header_; - MemoryAllocator memory_allocator_; - ScopedArenaPtr arena_; - promise_detail::Context context_; - // Use to synchronize writer_ and reader_ activity with outside activities; - std::shared_ptr event_engine_; }; } // namespace chaotic_good diff --git a/src/core/ext/transport/chaotic_good/frame.cc b/src/core/ext/transport/chaotic_good/frame.cc index f49fa4c4f3b..d2c0b3a4fc0 100644 --- a/src/core/ext/transport/chaotic_good/frame.cc +++ b/src/core/ext/transport/chaotic_good/frame.cc @@ -40,6 +40,10 @@ namespace grpc_core { namespace chaotic_good { +namespace { +const uint8_t kZeros[64] = {}; +} + namespace { const NoDestruct kZeroSlice{[] { // Frame header size is fixed to 24 bytes. @@ -50,53 +54,65 @@ const NoDestruct kZeroSlice{[] { class FrameSerializer { public: - explicit FrameSerializer(FrameType frame_type, uint32_t stream_id, - uint32_t message_padding) { - output_.AppendIndexed(kZeroSlice->Copy()); + explicit FrameSerializer(FrameType frame_type, uint32_t stream_id) { + output_.control.AppendIndexed(kZeroSlice->Copy()); header_.type = frame_type; header_.stream_id = stream_id; - header_.message_padding = message_padding; header_.flags.SetAll(false); } + // If called, must be called before AddTrailers, Finish. SliceBuffer& AddHeaders() { header_.flags.set(0); - return output_; + return output_.control; + } + + void AddMessage(const FragmentMessage& msg) { + header_.flags.set(1); + header_.message_length = msg.length; + header_.message_padding = msg.padding; + output_.data = msg.message->payload()->Copy(); + if (msg.padding != 0) { + output_.data.Append(Slice::FromStaticBuffer(kZeros, msg.padding)); + } } + // If called, must be called before Finish. SliceBuffer& AddTrailers() { - header_.flags.set(1); - header_.header_length = output_.Length() - FrameHeader::frame_header_size_; - return output_; + header_.flags.set(2); + header_.header_length = + output_.control.Length() - FrameHeader::frame_header_size_; + return output_.control; } - SliceBuffer Finish() { + BufferPair Finish() { // Calculate frame header_length or trailer_length if available. - if (header_.flags.is_set(1)) { + if (header_.flags.is_set(2)) { // Header length is already known in AddTrailers(). - header_.trailer_length = output_.Length() - header_.header_length - + header_.trailer_length = output_.control.Length() - + header_.header_length - FrameHeader::frame_header_size_; } else { if (header_.flags.is_set(0)) { // Calculate frame header length in Finish() since AddTrailers() isn't // called. header_.header_length = - output_.Length() - FrameHeader::frame_header_size_; + output_.control.Length() - FrameHeader::frame_header_size_; } } header_.Serialize( - GRPC_SLICE_START_PTR(output_.c_slice_buffer()->slices[0])); + GRPC_SLICE_START_PTR(output_.control.c_slice_buffer()->slices[0])); return std::move(output_); } private: FrameHeader header_; - SliceBuffer output_; + BufferPair output_; }; class FrameDeserializer { public: - FrameDeserializer(const FrameHeader& header, SliceBuffer& input) + FrameDeserializer(const FrameHeader& header, BufferPair& input) : header_(header), input_(input) {} const FrameHeader& header() const { return header_; } // If called, must be called before ReceiveTrailers, Finish. @@ -118,28 +134,27 @@ class FrameDeserializer { private: absl::StatusOr Take(uint32_t length) { if (length == 0) return SliceBuffer{}; - if (input_.Length() < length) { + if (input_.control.Length() < length) { return absl::InvalidArgumentError( "Frame too short (insufficient payload)"); } SliceBuffer out; - input_.MoveFirstNBytesIntoSliceBuffer(length, out); + input_.control.MoveFirstNBytesIntoSliceBuffer(length, out); return std::move(out); } FrameHeader header_; - SliceBuffer& input_; + BufferPair& input_; }; template absl::StatusOr> ReadMetadata( HPackParser* parser, absl::StatusOr maybe_slices, - uint32_t stream_id, bool is_header, bool is_client, - absl::BitGenRef bitsrc) { + uint32_t stream_id, bool is_header, bool is_client, absl::BitGenRef bitsrc, + Arena* arena) { if (!maybe_slices.ok()) return maybe_slices.status(); auto& slices = *maybe_slices; - auto arena = GetContext(); GPR_ASSERT(arena != nullptr); - Arena::PoolPtr metadata = arena->MakePooled(arena); + Arena::PoolPtr metadata = Arena::MakePooled(arena); parser->BeginFrame( metadata.get(), std::numeric_limits::max(), std::numeric_limits::max(), @@ -161,20 +176,23 @@ absl::StatusOr> ReadMetadata( } // namespace absl::Status SettingsFrame::Deserialize(HPackParser*, const FrameHeader& header, - absl::BitGenRef, - SliceBuffer& slice_buffer) { + absl::BitGenRef, Arena*, + BufferPair buffers) { if (header.type != FrameType::kSettings) { return absl::InvalidArgumentError("Expected settings frame"); } if (header.flags.any()) { return absl::InvalidArgumentError("Unexpected flags"); } - FrameDeserializer deserializer(header, slice_buffer); + if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError("Unexpected data"); + } + FrameDeserializer deserializer(header, buffers); return deserializer.Finish(); } -SliceBuffer SettingsFrame::Serialize(HPackCompressor*) const { - FrameSerializer serializer(FrameType::kSettings, 0, 0); +BufferPair SettingsFrame::Serialize(HPackCompressor*) const { + FrameSerializer serializer(FrameType::kSettings, 0); return serializer.Finish(); } @@ -183,19 +201,20 @@ std::string SettingsFrame::ToString() const { return "SettingsFrame{}"; } absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser, const FrameHeader& header, absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) { + Arena* arena, + BufferPair buffers) { if (header.stream_id == 0) { return absl::InvalidArgumentError("Expected non-zero stream id"); } stream_id = header.stream_id; - message_padding = header.message_padding; if (header.type != FrameType::kFragment) { return absl::InvalidArgumentError("Expected fragment frame"); } - FrameDeserializer deserializer(header, slice_buffer); + FrameDeserializer deserializer(header, buffers); if (header.flags.is_set(0)) { auto r = ReadMetadata(parser, deserializer.ReceiveHeaders(), - header.stream_id, true, true, bitsrc); + header.stream_id, true, true, bitsrc, + arena); if (!r.ok()) return r.status(); if (r.value() != nullptr) { headers = std::move(r.value()); @@ -205,8 +224,17 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser, "Unexpected non-zero header length", header.header_length)); } if (header.flags.is_set(1)) { + message = + FragmentMessage{Arena::MakePooled(std::move(buffers.data), 0), + header.message_padding, header.message_length}; + } else if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Unexpected non-zero message length ", buffers.data.Length())); + } + if (header.flags.is_set(2)) { if (header.trailer_length != 0) { - return absl::InvalidArgumentError("Unexpected trailer length"); + return absl::InvalidArgumentError( + absl::StrCat("Unexpected trailer length ", header.trailer_length)); } end_of_stream = true; } else { @@ -215,42 +243,53 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser, return deserializer.Finish(); } -SliceBuffer ClientFragmentFrame::Serialize(HPackCompressor* encoder) const { +BufferPair ClientFragmentFrame::Serialize(HPackCompressor* encoder) const { GPR_ASSERT(stream_id != 0); - FrameSerializer serializer(FrameType::kFragment, stream_id, message_padding); + FrameSerializer serializer(FrameType::kFragment, stream_id); if (headers.get() != nullptr) { encoder->EncodeRawHeaders(*headers.get(), serializer.AddHeaders()); } + if (message.has_value()) { + serializer.AddMessage(message.value()); + } if (end_of_stream) { serializer.AddTrailers(); } return serializer.Finish(); } +std::string FragmentMessage::ToString() const { + std::string out = + absl::StrCat("FragmentMessage{length=", length, ", padding=", padding); + if (message.get() != nullptr) { + absl::StrAppend(&out, ", message=", message->DebugString().c_str()); + } + absl::StrAppend(&out, "}"); + return out; +} + std::string ClientFragmentFrame::ToString() const { return absl::StrCat( "ClientFragmentFrame{stream_id=", stream_id, ", headers=", headers.get() != nullptr ? headers->DebugString().c_str() : "nullptr", - ", message=", - message.get() != nullptr ? message->DebugString().c_str() : "nullptr", - ", message_padding=", message_padding, ", end_of_stream=", end_of_stream, - "}"); + ", message=", message.has_value() ? message->ToString().c_str() : "none", + ", end_of_stream=", end_of_stream, "}"); } absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser, const FrameHeader& header, absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) { + Arena* arena, + BufferPair buffers) { if (header.stream_id == 0) { return absl::InvalidArgumentError("Expected non-zero stream id"); } stream_id = header.stream_id; - message_padding = header.message_padding; - FrameDeserializer deserializer(header, slice_buffer); + FrameDeserializer deserializer(header, buffers); if (header.flags.is_set(0)) { - auto r = - ReadMetadata(parser, deserializer.ReceiveHeaders(), - header.stream_id, true, false, bitsrc); + auto r = ReadMetadata(parser, deserializer.ReceiveHeaders(), + header.stream_id, true, false, bitsrc, + arena); if (!r.ok()) return r.status(); if (r.value() != nullptr) { headers = std::move(r.value()); @@ -260,9 +299,16 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser, "Unexpected non-zero header length", header.header_length)); } if (header.flags.is_set(1)) { - auto r = - ReadMetadata(parser, deserializer.ReceiveTrailers(), - header.stream_id, false, false, bitsrc); + message.emplace(Arena::MakePooled(std::move(buffers.data), 0), + header.message_padding, header.message_length); + } else if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Unexpected non-zero message length", buffers.data.Length())); + } + if (header.flags.is_set(2)) { + auto r = ReadMetadata( + parser, deserializer.ReceiveTrailers(), header.stream_id, false, false, + bitsrc, arena); if (!r.ok()) return r.status(); if (r.value() != nullptr) { trailers = std::move(r.value()); @@ -274,12 +320,15 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser, return deserializer.Finish(); } -SliceBuffer ServerFragmentFrame::Serialize(HPackCompressor* encoder) const { +BufferPair ServerFragmentFrame::Serialize(HPackCompressor* encoder) const { GPR_ASSERT(stream_id != 0); - FrameSerializer serializer(FrameType::kFragment, stream_id, message_padding); + FrameSerializer serializer(FrameType::kFragment, stream_id); if (headers.get() != nullptr) { encoder->EncodeRawHeaders(*headers.get(), serializer.AddHeaders()); } + if (message.has_value()) { + serializer.AddMessage(message.value()); + } if (trailers.get() != nullptr) { encoder->EncodeRawHeaders(*trailers.get(), serializer.AddTrailers()); } @@ -290,16 +339,15 @@ std::string ServerFragmentFrame::ToString() const { return absl::StrCat( "ServerFragmentFrame{stream_id=", stream_id, ", headers=", headers.get() != nullptr ? headers->DebugString().c_str() : "nullptr", - ", message=", - message.get() != nullptr ? message->DebugString().c_str() : "nullptr", - ", message_padding=", message_padding, ", trailers=", + ", message=", message.has_value() ? message->ToString().c_str() : "none", + ", trailers=", trailers.get() != nullptr ? trailers->DebugString().c_str() : "nullptr", "}"); } absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header, - absl::BitGenRef, - SliceBuffer& slice_buffer) { + absl::BitGenRef, Arena*, + BufferPair buffers) { if (header.type != FrameType::kCancel) { return absl::InvalidArgumentError("Expected cancel frame"); } @@ -309,14 +357,17 @@ absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header, if (header.stream_id == 0) { return absl::InvalidArgumentError("Expected non-zero stream id"); } - FrameDeserializer deserializer(header, slice_buffer); + if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError("Unexpected data"); + } + FrameDeserializer deserializer(header, buffers); stream_id = header.stream_id; return deserializer.Finish(); } -SliceBuffer CancelFrame::Serialize(HPackCompressor*) const { +BufferPair CancelFrame::Serialize(HPackCompressor*) const { GPR_ASSERT(stream_id != 0); - FrameSerializer serializer(FrameType::kCancel, stream_id, 0); + FrameSerializer serializer(FrameType::kCancel, stream_id); return serializer.Finish(); } diff --git a/src/core/ext/transport/chaotic_good/frame.h b/src/core/ext/transport/chaotic_good/frame.h index 529c89570c7..e7ccd6ee222 100644 --- a/src/core/ext/transport/chaotic_good/frame.h +++ b/src/core/ext/transport/chaotic_good/frame.h @@ -28,6 +28,7 @@ #include "src/core/ext/transport/chaotic_good/frame_header.h" #include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" #include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/lib/gprpp/match.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" #include "src/core/lib/transport/metadata_batch.h" @@ -36,20 +37,21 @@ namespace grpc_core { namespace chaotic_good { +struct BufferPair { + SliceBuffer control; + SliceBuffer data; +}; + class FrameInterface { public: virtual absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) = 0; - virtual SliceBuffer Serialize(HPackCompressor* encoder) const = 0; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) = 0; + virtual BufferPair Serialize(HPackCompressor* encoder) const = 0; virtual std::string ToString() const = 0; protected: - static bool EqVal(const Message& a, const Message& b) { - return a.payload()->JoinIntoString() == b.payload()->JoinIntoString() && - a.flags() == b.flags(); - } static bool EqVal(const grpc_metadata_batch& a, const grpc_metadata_batch& b) { return a.DebugString() == b.DebugString(); @@ -65,57 +67,75 @@ class FrameInterface { struct SettingsFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; bool operator==(const SettingsFrame&) const { return true; } }; +struct FragmentMessage { + FragmentMessage(MessageHandle message, uint32_t padding, uint32_t length) + : message(std::move(message)), padding(padding), length(length) {} + + MessageHandle message; + uint32_t padding; + uint32_t length; + + std::string ToString() const; + + static bool EqVal(const Message& a, const Message& b) { + return a.payload()->JoinIntoString() == b.payload()->JoinIntoString() && + a.flags() == b.flags(); + } + + bool operator==(const FragmentMessage& other) const { + return EqVal(*message, *other.message) && length == other.length; + } +}; + struct ClientFragmentFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; uint32_t stream_id; ClientMetadataHandle headers; - MessageHandle message; - uint32_t message_padding; + absl::optional message; bool end_of_stream = false; bool operator==(const ClientFragmentFrame& other) const { return stream_id == other.stream_id && EqHdl(headers, other.headers) && - end_of_stream == other.end_of_stream; + message == other.message && end_of_stream == other.end_of_stream; } }; struct ServerFragmentFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; uint32_t stream_id; ServerMetadataHandle headers; - MessageHandle message; - uint32_t message_padding; + absl::optional message; ServerMetadataHandle trailers; bool operator==(const ServerFragmentFrame& other) const { return stream_id == other.stream_id && EqHdl(headers, other.headers) && - EqHdl(trailers, other.trailers); + message == other.message && EqHdl(trailers, other.trailers); } }; struct CancelFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; uint32_t stream_id; @@ -128,6 +148,19 @@ struct CancelFrame final : public FrameInterface { using ClientFrame = absl::variant; using ServerFrame = absl::variant; +inline FrameInterface& GetFrameInterface(ClientFrame& frame) { + return MatchMutable( + &frame, + [](ClientFragmentFrame* frame) -> FrameInterface& { return *frame; }, + [](CancelFrame* frame) -> FrameInterface& { return *frame; }); +} + +inline FrameInterface& GetFrameInterface(ServerFrame& frame) { + return MatchMutable( + &frame, + [](ServerFragmentFrame* frame) -> FrameInterface& { return *frame; }); +} + } // namespace chaotic_good } // namespace grpc_core diff --git a/src/core/ext/transport/chaotic_good/frame_header.cc b/src/core/ext/transport/chaotic_good/frame_header.cc index e39d6a34b58..06f9d146e66 100644 --- a/src/core/ext/transport/chaotic_good/frame_header.cc +++ b/src/core/ext/transport/chaotic_good/frame_header.cc @@ -46,7 +46,6 @@ void FrameHeader::Serialize(uint8_t* data) const { WriteLittleEndianUint32( static_cast(type) | (flags.ToInt() << 8), data); if (flags.is_set(0)) GPR_ASSERT(header_length > 0); - if (flags.is_set(1)) GPR_ASSERT(trailer_length > 0); WriteLittleEndianUint32(stream_id, data + 4); WriteLittleEndianUint32(header_length, data + 8); WriteLittleEndianUint32(message_length, data + 12); @@ -60,8 +59,8 @@ absl::StatusOr FrameHeader::Parse(const uint8_t* data) { const uint32_t type_and_flags = ReadLittleEndianUint32(data); header.type = static_cast(type_and_flags & 0xff); const uint32_t flags = type_and_flags >> 8; - if (flags > 3) return absl::InvalidArgumentError("Invalid flags"); - header.flags = BitSet<2>::FromInt(flags); + if (flags > 7) return absl::InvalidArgumentError("Invalid flags"); + header.flags = BitSet<3>::FromInt(flags); header.stream_id = ReadLittleEndianUint32(data + 4); header.header_length = ReadLittleEndianUint32(data + 8); if (header.flags.is_set(0) && header.header_length <= 0) { @@ -70,11 +69,11 @@ absl::StatusOr FrameHeader::Parse(const uint8_t* data) { } header.message_length = ReadLittleEndianUint32(data + 12); header.message_padding = ReadLittleEndianUint32(data + 16); - header.trailer_length = ReadLittleEndianUint32(data + 20); - if (header.flags.is_set(1) && header.trailer_length <= 0) { + if (header.flags.is_set(1) && header.message_length <= 0) { return absl::InvalidArgumentError( - absl::StrCat("Invalid trailer length", header.trailer_length)); + absl::StrCat("Invalid message length: ", header.message_length)); } + header.trailer_length = ReadLittleEndianUint32(data + 20); return header; } diff --git a/src/core/ext/transport/chaotic_good/frame_header.h b/src/core/ext/transport/chaotic_good/frame_header.h index fa236ed3342..773b44f26e3 100644 --- a/src/core/ext/transport/chaotic_good/frame_header.h +++ b/src/core/ext/transport/chaotic_good/frame_header.h @@ -36,7 +36,7 @@ enum class FrameType : uint8_t { struct FrameHeader { FrameType type = FrameType::kCancel; - BitSet<2> flags; + BitSet<3> flags; uint32_t stream_id = 0; uint32_t header_length = 0; uint32_t message_length = 0; diff --git a/src/core/ext/transport/chaotic_good/server_transport.cc b/src/core/ext/transport/chaotic_good/server_transport.cc new file mode 100644 index 00000000000..3d4387ac949 --- /dev/null +++ b/src/core/ext/transport/chaotic_good/server_transport.cc @@ -0,0 +1,332 @@ +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/chaotic_good/server_transport.h" + +#include +#include +#include + +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#include +#include +#include + +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/ext/transport/chaotic_good/frame_header.h" +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/event_engine_wakeup_scheduler.h" +#include "src/core/lib/promise/for_each.h" +#include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/switch.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "src/core/lib/slice/slice.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/transport/promise_endpoint.h" + +namespace grpc_core { +namespace chaotic_good { + +auto ChaoticGoodServerTransport::TransportWriteLoop() { + return Loop([this] { + return TrySeq( + // Get next outgoing frame. + outgoing_frames_.Next(), + // Serialize and write it out. + [this](ServerFrame client_frame) { + return transport_.WriteFrame(GetFrameInterface(client_frame)); + }, + []() -> LoopCtl { + // The write failures will be caught in TrySeq and exit loop. + // Therefore, only need to return Continue() in the last lambda + // function. + return Continue(); + }); + }); +} + +auto ChaoticGoodServerTransport::PushFragmentIntoCall( + CallInitiator call_initiator, ClientFragmentFrame frame) { + auto& headers = frame.headers; + return TrySeq( + If( + headers != nullptr, + [call_initiator, &headers]() mutable { + return call_initiator.PushClientInitialMetadata(std::move(headers)); + }, + []() -> StatusFlag { return Success{}; }), + [call_initiator, message = std::move(frame.message)]() mutable { + return If( + message.has_value(), + [&call_initiator, &message]() mutable { + return call_initiator.PushMessage(std::move(message->message)); + }, + []() -> StatusFlag { return Success{}; }); + }, + [call_initiator, + end_of_stream = frame.end_of_stream]() mutable -> StatusFlag { + if (end_of_stream) call_initiator.FinishSends(); + return Success{}; + }); +} + +auto ChaoticGoodServerTransport::MaybePushFragmentIntoCall( + absl::optional call_initiator, absl::Status error, + ClientFragmentFrame frame) { + return If( + call_initiator.has_value() && error.ok(), + [this, &call_initiator, &frame]() { + return Map( + call_initiator->SpawnWaitable( + "push-fragment", + [call_initiator, frame = std::move(frame), this]() mutable { + return call_initiator->CancelIfFails( + PushFragmentIntoCall(*call_initiator, std::move(frame))); + }), + [](StatusFlag status) { return StatusCast(status); }); + }, + [error = std::move(error)]() { return error; }); +} + +auto ChaoticGoodServerTransport::CallOutboundLoop( + uint32_t stream_id, CallInitiator call_initiator) { + auto send_fragment = [stream_id, + outgoing_frames = outgoing_frames_.MakeSender()]( + ServerFragmentFrame frame) mutable { + frame.stream_id = stream_id; + return Map(outgoing_frames.Send(std::move(frame)), + [](bool success) -> absl::Status { + if (!success) { + // Failed to send outgoing frame. + return absl::UnavailableError("Transport closed."); + } + return absl::OkStatus(); + }); + }; + return Seq( + TrySeq( + // Wait for initial metadata then send it out. + call_initiator.PullServerInitialMetadata(), + [send_fragment](ServerMetadataHandle md) mutable { + ServerFragmentFrame frame; + frame.headers = std::move(md); + return send_fragment(std::move(frame)); + }, + // Continuously send client frame with client to server messages. + ForEach(OutgoingMessages(call_initiator), + [send_fragment, aligned_bytes = aligned_bytes_]( + MessageHandle message) mutable { + ServerFragmentFrame frame; + // Construct frame header (flags, header_length and + // trailer_length will be added in serialization). + const uint32_t message_length = + message->payload()->Length(); + const uint32_t padding = + message_length % aligned_bytes == 0 + ? 0 + : aligned_bytes - message_length % aligned_bytes; + GPR_ASSERT((message_length + padding) % aligned_bytes == 0); + frame.message = FragmentMessage(std::move(message), padding, + message_length); + return send_fragment(std::move(frame)); + })), + call_initiator.PullServerTrailingMetadata(), + [send_fragment](ServerMetadataHandle md) mutable { + ServerFragmentFrame frame; + frame.trailers = std::move(md); + return send_fragment(std::move(frame)); + }); +} + +auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToNewCall( + FrameHeader frame_header, BufferPair buffers) { + ClientFragmentFrame fragment_frame; + ScopedArenaPtr arena(acceptor_->CreateArena()); + absl::Status status = transport_.DeserializeFrame( + frame_header, std::move(buffers), arena.get(), fragment_frame); + absl::optional call_initiator; + if (status.ok()) { + auto create_call_result = + acceptor_->CreateCall(*fragment_frame.headers, arena.release()); + if (create_call_result.ok()) { + call_initiator.emplace(std::move(*create_call_result)); + call_initiator->SpawnGuarded( + "server-write", [this, stream_id = frame_header.stream_id, + call_initiator = *call_initiator]() { + return CallOutboundLoop(stream_id, call_initiator); + }); + } else { + status = create_call_result.status(); + } + } + return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status), + std::move(fragment_frame)); +} + +auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToExistingCall( + FrameHeader frame_header, BufferPair buffers) { + absl::optional call_initiator = + LookupStream(frame_header.stream_id); + Arena* arena = nullptr; + if (call_initiator.has_value()) arena = call_initiator->arena(); + ClientFragmentFrame fragment_frame; + absl::Status status = transport_.DeserializeFrame( + frame_header, std::move(buffers), arena, fragment_frame); + return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status), + std::move(fragment_frame)); +} + +auto ChaoticGoodServerTransport::TransportReadLoop() { + return Loop([this] { + return TrySeq( + transport_.ReadFrameBytes(), + [this](std::tuple frame_bytes) { + const auto& frame_header = std::get<0>(frame_bytes); + auto& buffers = std::get<1>(frame_bytes); + return Switch( + frame_header.type, + Case(FrameType::kSettings, + []() -> absl::Status { + return absl::InternalError("Unexpected settings frame"); + }), + Case(FrameType::kFragment, + [this, &frame_header, &buffers]() { + return If( + frame_header.flags.is_set(0), + [this, &frame_header, &buffers]() { + return DeserializeAndPushFragmentToNewCall( + frame_header, std::move(buffers)); + }, + [this, &frame_header, &buffers]() { + return DeserializeAndPushFragmentToExistingCall( + frame_header, std::move(buffers)); + }); + }), + Case(FrameType::kCancel, + [this, &frame_header]() { + absl::optional call_initiator = + ExtractStream(frame_header.stream_id); + return If( + call_initiator.has_value(), + [&call_initiator]() { + auto c = std::move(*call_initiator); + return c.SpawnWaitable("cancel", [c]() mutable { + c.Cancel(); + return absl::OkStatus(); + }); + }, + []() -> absl::Status { + return absl::InternalError( + "Unexpected cancel frame"); + }); + }), + Default([frame_header]() { + return absl::InternalError( + absl::StrCat("Unexpected frame type: ", + static_cast(frame_header.type))); + })); + }, + []() -> LoopCtl { return Continue{}; }); + }); +} + +auto ChaoticGoodServerTransport::OnTransportActivityDone() { + return [this](absl::Status status) { + if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { + this->AbortWithError(); + } + }; +} + +ChaoticGoodServerTransport::ChaoticGoodServerTransport( + const ChannelArgs& args, std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr event_engine) + : outgoing_frames_(4), + transport_(std::move(control_endpoint), std::move(data_endpoint)), + allocator_(args.GetObject() + ->memory_quota() + ->CreateMemoryAllocator("chaotic-good")), + event_engine_(event_engine), + writer_{MakeActivity(TransportWriteLoop(), + EventEngineWakeupScheduler(event_engine), + OnTransportActivityDone())}, + reader_{nullptr} {} + +void ChaoticGoodServerTransport::SetAcceptor(Acceptor* acceptor) { + GPR_ASSERT(acceptor_ == nullptr); + GPR_ASSERT(acceptor != nullptr); + acceptor_ = acceptor; + reader_ = MakeActivity(TransportReadLoop(), + EventEngineWakeupScheduler(event_engine_), + OnTransportActivityDone()); +} + +ChaoticGoodServerTransport::~ChaoticGoodServerTransport() { + if (writer_ != nullptr) { + writer_.reset(); + } + if (reader_ != nullptr) { + reader_.reset(); + } +} + +void ChaoticGoodServerTransport::AbortWithError() { + // Mark transport as unavailable when the endpoint write/read failed. + // Close all the available pipes. + outgoing_frames_.MarkClosed(); + ReleasableMutexLock lock(&mu_); + StreamMap stream_map = std::move(stream_map_); + stream_map_.clear(); + lock.Release(); + for (const auto& pair : stream_map) { + auto call_initiator = pair.second; + call_initiator.SpawnInfallible("cancel", [call_initiator]() mutable { + call_initiator.Cancel(); + return Empty{}; + }); + } +} + +absl::optional ChaoticGoodServerTransport::LookupStream( + uint32_t stream_id) { + MutexLock lock(&mu_); + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) return absl::nullopt; + return it->second; +} + +absl::optional ChaoticGoodServerTransport::ExtractStream( + uint32_t stream_id) { + MutexLock lock(&mu_); + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) return absl::nullopt; + auto r = std::move(it->second); + stream_map_.erase(it); + return std::move(r); +} + +} // namespace chaotic_good +} // namespace grpc_core diff --git a/src/core/ext/transport/chaotic_good/server_transport.h b/src/core/ext/transport/chaotic_good/server_transport.h new file mode 100644 index 00000000000..9ce92928385 --- /dev/null +++ b/src/core/ext/transport/chaotic_good/server_transport.h @@ -0,0 +1,145 @@ +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H +#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H + +#include + +#include +#include + +#include +#include // IWYU pragma: keep +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" + +#include +#include +#include +#include + +#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/ext/transport/chaotic_good/frame_header.h" +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/lib/event_engine/default_event_engine.h" // IWYU pragma: keep +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/if.h" +#include "src/core/lib/promise/inter_activity_pipe.h" +#include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/mpsc.h" +#include "src/core/lib/promise/party.h" +#include "src/core/lib/promise/pipe.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/slice/slice.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/promise_endpoint.h" +#include "src/core/lib/transport/transport.h" + +namespace grpc_core { +namespace chaotic_good { + +class ChaoticGoodServerTransport final : public Transport, + public ServerTransport { + public: + ChaoticGoodServerTransport( + const ChannelArgs& args, + std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr + event_engine); + ~ChaoticGoodServerTransport() override; + + FilterStackTransport* filter_stack_transport() override { return nullptr; } + ClientTransport* client_transport() override { return nullptr; } + ServerTransport* server_transport() override { return this; } + absl::string_view GetTransportName() const override { return "chaotic_good"; } + void SetPollset(grpc_stream*, grpc_pollset*) override {} + void SetPollsetSet(grpc_stream*, grpc_pollset_set*) override {} + void PerformOp(grpc_transport_op*) override { Crash("unimplemented"); } + grpc_endpoint* GetEndpoint() override { return nullptr; } + void Orphan() override { delete this; } + + void SetAcceptor(Acceptor* acceptor) override; + void AbortWithError(); + + private: + using StreamMap = absl::flat_hash_map; + + absl::Status NewStream(uint32_t stream_id, CallInitiator call_initiator); + absl::optional LookupStream(uint32_t stream_id); + absl::optional ExtractStream(uint32_t stream_id); + auto CallOutboundLoop(uint32_t stream_id, CallInitiator call_initiator); + auto OnTransportActivityDone(); + auto TransportReadLoop(); + auto TransportWriteLoop(); + // Read different parts of the server frame from control/data endpoints + // based on frame header. + // Resolves to a StatusOr> + auto ReadFrameBody(Slice read_buffer); + void SendCancel(uint32_t stream_id, absl::Status why); + auto DeserializeAndPushFragmentToNewCall(FrameHeader frame_header, + BufferPair buffers); + auto DeserializeAndPushFragmentToExistingCall(FrameHeader frame_header, + BufferPair buffers); + auto MaybePushFragmentIntoCall(absl::optional call_initiator, + absl::Status error, ClientFragmentFrame frame); + auto PushFragmentIntoCall(CallInitiator call_initiator, + ClientFragmentFrame frame); + + Acceptor* acceptor_ = nullptr; + MpscReceiver outgoing_frames_; + ChaoticGoodTransport transport_; + // Assigned aligned bytes from setting frame. + size_t aligned_bytes_ = 64; + Mutex mu_; + // Map of stream incoming server frames, key is stream_id. + StreamMap stream_map_ ABSL_GUARDED_BY(mu_); + grpc_event_engine::experimental::MemoryAllocator allocator_; + std::shared_ptr event_engine_; + ActivityPtr writer_; + ActivityPtr reader_; +}; + +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H \ No newline at end of file diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc index cc932d0f216..24fdd41d387 100644 --- a/src/core/ext/transport/inproc/inproc_transport.cc +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -36,8 +36,8 @@ class InprocServerTransport final : public RefCounted, public Transport, public ServerTransport { public: - void SetAcceptFunction(AcceptFunction accept_function) override { - accept_ = std::move(accept_function); + void SetAcceptor(Acceptor* acceptor) override { + acceptor_ = acceptor; ConnectionState expect = ConnectionState::kInitial; state_.compare_exchange_strong(expect, ConnectionState::kReady, std::memory_order_acq_rel, @@ -92,7 +92,7 @@ class InprocServerTransport final : public RefCounted, case ConnectionState::kReady: break; } - return accept_(md); + return acceptor_->CreateCall(md, acceptor_->CreateArena()); } private: @@ -100,7 +100,7 @@ class InprocServerTransport final : public RefCounted, std::atomic state_{ConnectionState::kInitial}; std::atomic disconnecting_{false}; - AcceptFunction accept_; + Acceptor* acceptor_; absl::Status disconnect_error_; Mutex state_tracker_mu_; ConnectivityStateTracker state_tracker_ ABSL_GUARDED_BY(state_tracker_mu_){ diff --git a/src/core/lib/gprpp/debug_location.h b/src/core/lib/gprpp/debug_location.h index 7e021fd9f78..c6c9b682869 100644 --- a/src/core/lib/gprpp/debug_location.h +++ b/src/core/lib/gprpp/debug_location.h @@ -81,6 +81,15 @@ class DebugLocation { }; #endif +template +struct ValueWithDebugLocation { + // NOLINTNEXTLINE + ValueWithDebugLocation(T&& value, DebugLocation debug_location = {}) + : value(std::forward(value)), debug_location(debug_location) {} + T value; + GPR_NO_UNIQUE_ADDRESS DebugLocation debug_location; +}; + #define DEBUG_LOCATION ::grpc_core::DebugLocation(__FILE__, __LINE__) } // namespace grpc_core diff --git a/src/core/lib/promise/detail/status.h b/src/core/lib/promise/detail/status.h index 1063f329193..bfc649e6c48 100644 --- a/src/core/lib/promise/detail/status.h +++ b/src/core/lib/promise/detail/status.h @@ -45,6 +45,11 @@ inline absl::Status IntoStatus(absl::Status* status) { // can participate in TrySeq as result types that affect control flow. inline bool IsStatusOk(const absl::Status& status) { return status.ok(); } +template +inline bool IsStatusOk(const absl::StatusOr& status) { + return status.ok(); +} + template struct StatusCastImpl; @@ -59,20 +64,52 @@ struct StatusCastImpl { }; template -struct StatusCastImpl, absl::Status> { - static absl::StatusOr Cast(absl::Status&& t) { return std::move(t); } +struct StatusCastImpl> { + static absl::Status Cast(absl::StatusOr&& t) { + return std::move(t.status()); + } }; template -struct StatusCastImpl, const absl::Status&> { - static absl::StatusOr Cast(const absl::Status& t) { return t; } +struct StatusCastImpl&> { + static absl::Status Cast(const absl::StatusOr& t) { return t.status(); } }; +template +struct StatusCastImpl&> { + static absl::Status Cast(const absl::StatusOr& t) { return t.status(); } +}; + +// StatusCast<> allows casting from one status-bearing type to another, +// regardless of whether the status indicates success or failure. +// This means that we can go from StatusOr to Status safely, but not in the +// opposite direction. +// For cases where the status is guaranteed to be a failure (and hence not +// needing to preserve values) see FailureStatusCast<> below. template To StatusCast(From&& from) { return StatusCastImpl::Cast(std::forward(from)); } +template +struct FailureStatusCastImpl : public StatusCastImpl {}; + +template +struct FailureStatusCastImpl, absl::Status> { + static absl::StatusOr Cast(absl::Status&& t) { return std::move(t); } +}; + +template +struct FailureStatusCastImpl, const absl::Status&> { + static absl::StatusOr Cast(const absl::Status& t) { return t; } +}; + +template +To FailureStatusCast(From&& from) { + GPR_DEBUG_ASSERT(!IsStatusOk(from)); + return FailureStatusCastImpl::Cast(std::forward(from)); +} + } // namespace grpc_core #endif // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_STATUS_H diff --git a/src/core/lib/promise/event_engine_wakeup_scheduler.h b/src/core/lib/promise/event_engine_wakeup_scheduler.h index 792ee9d4439..3e489c87fc6 100644 --- a/src/core/lib/promise/event_engine_wakeup_scheduler.h +++ b/src/core/lib/promise/event_engine_wakeup_scheduler.h @@ -33,7 +33,9 @@ class EventEngineWakeupScheduler { explicit EventEngineWakeupScheduler( std::shared_ptr event_engine) - : event_engine_(std::move(event_engine)) {} + : event_engine_(std::move(event_engine)) { + GPR_ASSERT(event_engine_ != nullptr); + } template class BoundScheduler diff --git a/src/core/lib/promise/if.h b/src/core/lib/promise/if.h index e659ad30ed8..2c81fd7e38a 100644 --- a/src/core/lib/promise/if.h +++ b/src/core/lib/promise/if.h @@ -192,6 +192,10 @@ class If { // If it returns failure, returns failure for the entire combinator. // If it returns true, evaluates the second promise. // If it returns false, evaluates the third promise. +// If C is a constant, it's guaranteed that one of the promise factories +// if_true or if_false will be evaluated before returning from this function. +// This makes it safe to capture lambda arguments in the promise factory by +// reference. template promise_detail::If If(C condition, T if_true, F if_false) { return promise_detail::If(std::move(condition), std::move(if_true), diff --git a/src/core/lib/promise/inter_activity_pipe.h b/src/core/lib/promise/inter_activity_pipe.h index a7594fb26a2..4578cbb3c62 100644 --- a/src/core/lib/promise/inter_activity_pipe.h +++ b/src/core/lib/promise/inter_activity_pipe.h @@ -113,9 +113,9 @@ class InterActivityPipe { if (center_ != nullptr) center_->MarkClosed(); } - bool IsClose() { return center_->IsClosed(); } + bool IsClosed() { return center_->IsClosed(); } - void MarkClose() { + void MarkClosed() { if (center_ != nullptr) center_->MarkClosed(); } @@ -146,6 +146,12 @@ class InterActivityPipe { return [center = center_]() { return center->Next(); }; } + bool IsClose() { return center_->IsClosed(); } + + void MarkClose() { + if (center_ != nullptr) center_->MarkClosed(); + } + private: RefCountedPtr
center_; }; diff --git a/src/core/lib/promise/mpsc.h b/src/core/lib/promise/mpsc.h index c12544282f6..8bbbfc4c8ec 100644 --- a/src/core/lib/promise/mpsc.h +++ b/src/core/lib/promise/mpsc.h @@ -103,14 +103,12 @@ class Center : public RefCounted> { // Mark that the receiver is closed. void ReceiverClosed() { - MutexLock lock(&mu_); + ReleasableMutexLock lock(&mu_); + if (receiver_closed_) return; receiver_closed_ = true; - } - - // Return whether the receiver is closed. - bool IsClosed() { - MutexLock lock(&mu_); - return receiver_closed_; + auto wakeups = send_wakers_.TakeWakeupSet(); + lock.Release(); + wakeups.Wakeup(); } private: @@ -131,8 +129,8 @@ class MpscReceiver; template class MpscSender { public: - MpscSender(const MpscSender&) = delete; - MpscSender& operator=(const MpscSender&) = delete; + MpscSender(const MpscSender&) = default; + MpscSender& operator=(const MpscSender&) = default; MpscSender(MpscSender&&) noexcept = default; MpscSender& operator=(MpscSender&&) noexcept = default; @@ -140,7 +138,10 @@ class MpscSender { // Resolves to true if sent, false if the receiver was closed (and the value // will never be successfully sent). auto Send(T t) { - return [this, t = std::move(t)]() mutable { return center_->PollSend(t); }; + return [center = center_, t = std::move(t)]() mutable -> Poll { + if (center == nullptr) return false; + return center->PollSend(t); + }; } bool UnbufferedImmediateSend(T t) { @@ -170,7 +171,6 @@ class MpscReceiver { ~MpscReceiver() { if (center_ != nullptr) center_->ReceiverClosed(); } - bool IsClosed() { return center_->IsClosed(); } void MarkClosed() { if (center_ != nullptr) center_->ReceiverClosed(); } diff --git a/src/core/lib/promise/status_flag.h b/src/core/lib/promise/status_flag.h index d9067509c32..c8c9c0ba41e 100644 --- a/src/core/lib/promise/status_flag.h +++ b/src/core/lib/promise/status_flag.h @@ -95,6 +95,30 @@ struct StatusCastImpl { } }; +template +struct FailureStatusCastImpl, StatusFlag> { + static absl::StatusOr Cast(StatusFlag flag) { + GPR_DEBUG_ASSERT(!flag.ok()); + return absl::CancelledError(); + } +}; + +template +struct FailureStatusCastImpl, StatusFlag&> { + static absl::StatusOr Cast(StatusFlag flag) { + GPR_DEBUG_ASSERT(!flag.ok()); + return absl::CancelledError(); + } +}; + +template +struct FailureStatusCastImpl, const StatusFlag&> { + static absl::StatusOr Cast(StatusFlag flag) { + GPR_DEBUG_ASSERT(!flag.ok()); + return absl::CancelledError(); + } +}; + // A value if an operation was successful, or a failure flag if not. template class ValueOrFailure { diff --git a/src/core/lib/promise/try_join.h b/src/core/lib/promise/try_join.h index be3354dc9b7..29cd06d5be4 100644 --- a/src/core/lib/promise/try_join.h +++ b/src/core/lib/promise/try_join.h @@ -75,16 +75,16 @@ struct TryJoinTraits { } template static R EarlyReturn(absl::Status x) { - return StatusCast(std::move(x)); + return FailureStatusCast(std::move(x)); } template static R EarlyReturn(StatusFlag x) { - return StatusCast(x); + return FailureStatusCast(x); } template static R EarlyReturn(const ValueOrFailure& x) { GPR_ASSERT(!x.ok()); - return StatusCast(Failure{}); + return FailureStatusCast(Failure{}); } template static auto FinalReturn(A&&... a) { diff --git a/src/core/lib/promise/try_seq.h b/src/core/lib/promise/try_seq.h index ca04904ab1d..ab9777d1afb 100644 --- a/src/core/lib/promise/try_seq.h +++ b/src/core/lib/promise/try_seq.h @@ -76,7 +76,7 @@ struct TrySeqTraitsWithSfinae> { } template static R ReturnValue(absl::StatusOr&& status) { - return StatusCast(status.status()); + return FailureStatusCast(status.status()); } template static auto CallSeqFactory(F& f, Elem&& elem, absl::StatusOr value) @@ -86,11 +86,26 @@ struct TrySeqTraitsWithSfinae> { template static Poll CheckResultAndRunNext(absl::StatusOr prior, RunNext run_next) { - if (!prior.ok()) return StatusCast(prior.status()); + if (!prior.ok()) return FailureStatusCast(prior.status()); return run_next(std::move(prior)); } }; +template +struct AllowGenericTrySeqTraits { + static constexpr bool value = true; +}; + +template <> +struct AllowGenericTrySeqTraits { + static constexpr bool value = false; +}; + +template +struct AllowGenericTrySeqTraits> { + static constexpr bool value = false; +}; + template struct TakeValueExists { static constexpr bool value = false; @@ -107,7 +122,7 @@ template struct TrySeqTraitsWithSfinae< T, absl::enable_if_t< std::is_same())), bool>::value && - !TakeValueExists::value, + !TakeValueExists::value && AllowGenericTrySeqTraits::value, void>> { using UnwrappedType = void; using WrappedType = T; @@ -121,7 +136,7 @@ struct TrySeqTraitsWithSfinae< } template static R ReturnValue(T&& status) { - return StatusCast(std::move(status)); + return FailureStatusCast(std::move(status)); } template static Poll CheckResultAndRunNext(T prior, RunNext run_next) { @@ -133,7 +148,7 @@ template struct TrySeqTraitsWithSfinae< T, absl::enable_if_t< std::is_same())), bool>::value && - TakeValueExists::value, + TakeValueExists::value && AllowGenericTrySeqTraits::value, void>> { using UnwrappedType = decltype(TakeValue(std::declval())); using WrappedType = T; @@ -148,7 +163,7 @@ struct TrySeqTraitsWithSfinae< template static R ReturnValue(T&& status) { GPR_DEBUG_ASSERT(!IsStatusOk(status)); - return StatusCast(status.status()); + return FailureStatusCast(status.status()); } template static Poll CheckResultAndRunNext(T prior, RunNext run_next) { @@ -170,7 +185,7 @@ struct TrySeqTraitsWithSfinae { } template static R ReturnValue(absl::Status&& status) { - return StatusCast(std::move(status)); + return FailureStatusCast(std::move(status)); } template static Poll CheckResultAndRunNext(absl::Status prior, diff --git a/src/core/lib/resource_quota/arena.h b/src/core/lib/resource_quota/arena.h index 9c0c812d4c3..edcab2caf87 100644 --- a/src/core/lib/resource_quota/arena.h +++ b/src/core/lib/resource_quota/arena.h @@ -180,7 +180,7 @@ class Arena { template T* New(Args&&... args) { T* t = static_cast(Alloc(sizeof(T))); - Construct(t, std::forward(args)...); + new (t) T(std::forward(args)...); return t; } @@ -333,7 +333,7 @@ class Arena { // value in Arena::PoolSizes, and so this may pessimize total // arena size. template - PoolPtr MakePooled(Args&&... args) { + static PoolPtr MakePooled(Args&&... args) { return PoolPtr(new T(std::forward(args)...), PooledDeleter()); } diff --git a/src/core/lib/slice/slice_buffer.h b/src/core/lib/slice/slice_buffer.h index 2626bd4a0e9..0c1bfe9d901 100644 --- a/src/core/lib/slice/slice_buffer.h +++ b/src/core/lib/slice/slice_buffer.h @@ -50,6 +50,9 @@ namespace grpc_core { class SliceBuffer { public: explicit SliceBuffer() { grpc_slice_buffer_init(&slice_buffer_); } + explicit SliceBuffer(Slice slice) : SliceBuffer() { + Append(std::move(slice)); + } SliceBuffer(const SliceBuffer& other) = delete; SliceBuffer(SliceBuffer&& other) noexcept { grpc_slice_buffer_init(&slice_buffer_); diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index d5dbc54be39..2702bedc0e7 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -4063,16 +4063,13 @@ void ServerCallSpine::CommitBatch(const grpc_op* ops, size_t nops, } RefCountedPtr MakeServerCall(Server* server, - Channel* channel) { - const auto initial_size = channel->CallSizeEstimate(); - global_stats().IncrementCallInitialSize(initial_size); - auto alloc = Arena::CreateWithAlloc(initial_size, sizeof(ServerCallSpine), - channel->allocator()); - auto* call = new (alloc.second) ServerCallSpine(server, channel, alloc.first); - return RefCountedPtr(call); + Channel* channel, + Arena* arena) { + return RefCountedPtr( + arena->New(server, channel, arena)); } #else -RefCountedPtr MakeServerCall(Server*, Channel*) { +RefCountedPtr MakeServerCall(Server*, Channel*, Arena*) { Crash("not implemented"); } #endif diff --git a/src/core/lib/surface/call.h b/src/core/lib/surface/call.h index 6653bb6a0dd..520cf13505c 100644 --- a/src/core/lib/surface/call.h +++ b/src/core/lib/surface/call.h @@ -160,7 +160,8 @@ template <> struct ContextType {}; RefCountedPtr MakeServerCall(Server* server, - Channel* channel); + Channel* channel, + Arena* arena); } // namespace grpc_core diff --git a/src/core/lib/surface/server.cc b/src/core/lib/surface/server.cc index 44b541a8593..e53a609cfb1 100644 --- a/src/core/lib/surface/server.cc +++ b/src/core/lib/surface/server.cc @@ -51,6 +51,7 @@ #include "src/core/lib/channel/channel_trace.h" #include "src/core/lib/channel/channelz.h" #include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/debug/stats.h" #include "src/core/lib/experiments/experiments.h" #include "src/core/lib/gpr/useful.h" #include "src/core/lib/gprpp/crash.h" @@ -1297,6 +1298,20 @@ Server::ChannelData::~ChannelData() { } } +Arena* Server::ChannelData::CreateArena() { + const auto initial_size = channel_->CallSizeEstimate(); + global_stats().IncrementCallInitialSize(initial_size); + return Arena::Create(initial_size, channel_->allocator()); +} + +absl::StatusOr Server::ChannelData::CreateCall( + ClientMetadata& client_initial_metadata, Arena* arena) { + SetRegisteredMethodOnMetadata(client_initial_metadata); + auto call = MakeServerCall(server_.get(), channel_.get(), arena); + InitCall(call); + return CallInitiator(std::move(call)); +} + void Server::ChannelData::InitTransport(RefCountedPtr server, RefCountedPtr channel, size_t cq_idx, Transport* transport, @@ -1329,13 +1344,7 @@ void Server::ChannelData::InitTransport(RefCountedPtr server, } if (transport->server_transport() != nullptr) { ++accept_stream_types; - transport->server_transport()->SetAcceptFunction( - [this](ClientMetadata& metadata) { - SetRegisteredMethodOnMetadata(metadata); - auto call = MakeServerCall(server_.get(), channel_.get()); - InitCall(call); - return CallInitiator(std::move(call)); - }); + transport->server_transport()->SetAcceptor(this); } GPR_ASSERT(accept_stream_types == 1); op->start_connectivity_watch = MakeOrphanable(this); diff --git a/src/core/lib/surface/server.h b/src/core/lib/surface/server.h index 11ec7c68a45..4bb6fce3fae 100644 --- a/src/core/lib/surface/server.h +++ b/src/core/lib/surface/server.h @@ -218,7 +218,7 @@ class Server : public InternallyRefCounted, class AllocatingRequestMatcherBatch; class AllocatingRequestMatcherRegistered; - class ChannelData { + class ChannelData final : public ServerTransport::Acceptor { public: ChannelData() = default; ~ChannelData(); @@ -241,6 +241,10 @@ class Server : public InternallyRefCounted, grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory); void InitCall(RefCountedPtr call); + Arena* CreateArena() override; + absl::StatusOr CreateCall( + ClientMetadata& client_initial_metadata, Arena* arena) override; + private: class ConnectivityWatcher; diff --git a/src/core/lib/transport/promise_endpoint.h b/src/core/lib/transport/promise_endpoint.h index 2c2b3a2d37c..fbdc467cbb8 100644 --- a/src/core/lib/transport/promise_endpoint.h +++ b/src/core/lib/transport/promise_endpoint.h @@ -69,24 +69,26 @@ class PromiseEndpoint { auto Write(SliceBuffer data) { // Assert previous write finishes. GPR_ASSERT(!write_state_->complete.load(std::memory_order_relaxed)); - // TODO(ladynana): Replace this with `SliceBufferCast<>` when it is - // available. - grpc_slice_buffer_swap(write_state_->buffer.c_slice_buffer(), - data.c_slice_buffer()); - // If `Write()` returns true immediately, the callback will not be called. - // We still need to call our callback to pick up the result. - write_state_->waker = Activity::current()->MakeNonOwningWaker(); - const bool completed = endpoint_->Write( - [write_state = write_state_](absl::Status status) { - write_state->Complete(std::move(status)); - }, - &write_state_->buffer, nullptr /* uses default arguments */); + bool completed; + if (data.Length() == 0) { + completed = true; + } else { + // TODO(ladynana): Replace this with `SliceBufferCast<>` when it is + // available. + grpc_slice_buffer_swap(write_state_->buffer.c_slice_buffer(), + data.c_slice_buffer()); + // If `Write()` returns true immediately, the callback will not be called. + // We still need to call our callback to pick up the result. + write_state_->waker = Activity::current()->MakeNonOwningWaker(); + completed = endpoint_->Write( + [write_state = write_state_](absl::Status status) { + write_state->Complete(std::move(status)); + }, + &write_state_->buffer, nullptr /* uses default arguments */); + if (completed) write_state_->waker = Waker(); + } return If( - completed, - [this]() { - write_state_->waker = Waker(); - return []() { return absl::OkStatus(); }; - }, + completed, []() { return []() { return absl::OkStatus(); }; }, [this]() { return [write_state = write_state_]() -> Poll { // If current write isn't finished return `Pending()`, else return diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc index ab405804065..bad5c4b8590 100644 --- a/src/core/lib/transport/transport.cc +++ b/src/core/lib/transport/transport.cc @@ -291,9 +291,8 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, return call_initiator.SpawnWaitable( "send_message", [msg = std::move(msg), call_initiator]() mutable { - return call_initiator.CancelIfFails(Map( - call_initiator.PushMessage(std::move(msg)), - [](bool r) { return StatusFlag(r); })); + return call_initiator.CancelIfFails( + call_initiator.PushMessage(std::move(msg))); }); }); }); @@ -317,8 +316,7 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, "recv_message", [msg = std::move(msg), call_handler]() mutable { return call_handler.CancelIfFails( - Map(call_handler.PushMessage(std::move(msg)), - [](bool r) { return StatusFlag(r); })); + call_handler.PushMessage(std::move(msg))); }); }), ImmediateOkStatus())), @@ -334,4 +332,10 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, }); } +CallInitiatorAndHandler MakeCall( + grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena) { + auto spine = CallSpine::Create(event_engine, arena); + return {CallInitiator(spine), CallHandler(spine)}; +} + } // namespace grpc_core diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index c9f138c8f09..5a7e09bdb43 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -258,6 +258,20 @@ class CallSpineInterface { virtual Pipe& server_to_client_messages() = 0; virtual Pipe& server_trailing_metadata() = 0; virtual Latch& cancel_latch() = 0; + // Add a callback to be called when server trailing metadata is received. + void OnDone(absl::AnyInvocable fn) { + if (on_done_ == nullptr) { + on_done_ = std::move(fn); + return; + } + on_done_ = [first = std::move(fn), next = std::move(on_done_)]() mutable { + first(); + next(); + }; + } + void CallOnDone() { + if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(); + } virtual Party& party() = 0; virtual void IncrementRefCount() = 0; virtual void Unref() = 0; @@ -276,6 +290,11 @@ class CallSpineInterface { auto& c = cancel_latch(); if (c.is_set()) return absl::nullopt; c.Set(std::move(metadata)); + CallOnDone(); + client_initial_metadata().sender.CloseWithError(); + server_initial_metadata().sender.CloseWithError(); + client_to_server_messages().sender.CloseWithError(); + server_to_client_messages().sender.CloseWithError(); return absl::nullopt; } @@ -325,11 +344,18 @@ class CallSpineInterface { } }); } + + private: + absl::AnyInvocable on_done_{nullptr}; }; -class CallSpine final : public CallSpineInterface { +class CallSpine final : public CallSpineInterface, public Party { public: - CallSpine() { Crash("unimplemented"); } + static RefCountedPtr Create( + grpc_event_engine::experimental::EventEngine* event_engine, + Arena* arena) { + return RefCountedPtr(arena->New(event_engine, arena)); + } Pipe& client_initial_metadata() override { return client_initial_metadata_; @@ -347,23 +373,57 @@ class CallSpine final : public CallSpineInterface { return server_trailing_metadata_; } Latch& cancel_latch() override { return cancel_latch_; } - Party& party() override { Crash("unimplemented"); } - void IncrementRefCount() override { Crash("unimplemented"); } - void Unref() override { Crash("unimplemented"); } + Party& party() override { return *this; } + void IncrementRefCount() override { Party::IncrementRefCount(); } + void Unref() override { Party::Unref(); } private: + friend class Arena; + CallSpine(grpc_event_engine::experimental::EventEngine* event_engine, + Arena* arena) + : Party(arena, 1), event_engine_(event_engine) {} + + class ScopedContext : public ScopedActivity, + public promise_detail::Context { + public: + explicit ScopedContext(CallSpine* spine) + : ScopedActivity(&spine->party()), Context(spine->arena()) {} + }; + + bool RunParty() override { + ScopedContext context(this); + return Party::RunParty(); + } + + void PartyOver() override { + Arena* a = arena(); + { + ScopedContext context(this); + CancelRemainingParticipants(); + a->DestroyManagedNewObjects(); + } + this->~CallSpine(); + a->Destroy(); + } + + grpc_event_engine::experimental::EventEngine* event_engine() const override { + return event_engine_; + } + // Initial metadata from client to server - Pipe client_initial_metadata_; + Pipe client_initial_metadata_{arena()}; // Initial metadata from server to client - Pipe server_initial_metadata_; + Pipe server_initial_metadata_{arena()}; // Messages travelling from the application to the transport. - Pipe client_to_server_messages_; + Pipe client_to_server_messages_{arena()}; // Messages travelling from the transport to the application. - Pipe server_to_client_messages_; + Pipe server_to_client_messages_{arena()}; // Trailing metadata from server to client - Pipe server_trailing_metadata_; + Pipe server_trailing_metadata_{arena()}; // Latch that can be set to terminate the call Latch cancel_latch_; + // Event engine associated with this call + grpc_event_engine::experimental::EventEngine* const event_engine_; }; class CallInitiator { @@ -405,7 +465,14 @@ class CallInitiator { auto PushMessage(MessageHandle message) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); - return spine_->client_to_server_messages().sender.Push(std::move(message)); + return Map( + spine_->client_to_server_messages().sender.Push(std::move(message)), + [](bool r) { return StatusFlag(r); }); + } + + void FinishSends() { + GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + spine_->client_to_server_messages().sender.Close(); } template @@ -413,6 +480,12 @@ class CallInitiator { return spine_->CancelIfFails(std::move(promise)); } + void Cancel() { + GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + std::ignore = + spine_->Cancel(ServerMetadataFromStatus(absl::CancelledError())); + } + template void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) { spine_->SpawnGuarded(name, std::move(promise_factory)); @@ -428,8 +501,10 @@ class CallInitiator { return spine_->party().SpawnWaitable(name, std::move(promise_factory)); } + Arena* arena() { return spine_->party().arena(); } + private: - const RefCountedPtr spine_; + RefCountedPtr spine_; }; class CallHandler { @@ -447,14 +522,16 @@ class CallHandler { }); } - auto PushServerInitialMetadata(ClientMetadataHandle md) { + auto PushServerInitialMetadata(ServerMetadataHandle md) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); return Map(spine_->server_initial_metadata().sender.Push(std::move(md)), [](bool ok) { return StatusFlag(ok); }); } - auto PushServerTrailingMetadata(ClientMetadataHandle md) { + auto PushServerTrailingMetadata(ServerMetadataHandle md) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + spine_->server_to_client_messages().sender.Close(); + spine_->CallOnDone(); return Map(spine_->server_trailing_metadata().sender.Push(std::move(md)), [](bool ok) { return StatusFlag(ok); }); } @@ -466,9 +543,18 @@ class CallHandler { auto PushMessage(MessageHandle message) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); - return spine_->server_to_client_messages().sender.Push(std::move(message)); + return Map( + spine_->server_to_client_messages().sender.Push(std::move(message)), + [](bool ok) { return StatusFlag(ok); }); } + void Cancel(ServerMetadataHandle status) { + GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + std::ignore = spine_->Cancel(std::move(status)); + } + + void OnDone(absl::AnyInvocable fn) { spine_->OnDone(std::move(fn)); } + template auto CancelIfFails(Promise promise) { return spine_->CancelIfFails(std::move(promise)); @@ -489,8 +575,10 @@ class CallHandler { return spine_->party().SpawnWaitable(name, std::move(promise_factory)); } + Arena* arena() { return spine_->party().arena(); } + private: - const RefCountedPtr spine_; + RefCountedPtr spine_; }; struct CallInitiatorAndHandler { @@ -498,13 +586,16 @@ struct CallInitiatorAndHandler { CallHandler handler; }; +CallInitiatorAndHandler MakeCall( + grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena); + template -auto OutgoingMessages(CallHalf& h) { +auto OutgoingMessages(CallHalf h) { struct Wrapper { - CallHalf& h; + CallHalf h; auto Next() { return h.PullMessage(); } }; - return Wrapper{h}; + return Wrapper{std::move(h)}; } // Forward a call from `call_handler` to `call_initiator` (with initial metadata @@ -925,14 +1016,24 @@ class ClientTransport { class ServerTransport { public: - // AcceptFunction takes initial metadata for a new call and returns a - // CallInitiator object for it, for the transport to use to communicate with - // the CallHandler object passed to the application. - using AcceptFunction = - absl::AnyInvocable(ClientMetadata&) const>; + // Acceptor helps transports create calls. + class Acceptor { + public: + // Returns an arena that can be used to allocate memory for initial metadata + // parsing, and later passed to CreateCall() as the underlying arena for + // that call. + virtual Arena* CreateArena() = 0; + // Create a call at the server (or fail) + // arena must have been previously allocated by CreateArena() + virtual absl::StatusOr CreateCall( + ClientMetadata& client_initial_metadata, Arena* arena) = 0; + + protected: + ~Acceptor() = default; + }; // Called once slightly after transport setup to register the accept function. - virtual void SetAcceptFunction(AcceptFunction accept_function) = 0; + virtual void SetAcceptor(Acceptor* acceptor) = 0; protected: ~ServerTransport() = default; diff --git a/src/python/.gitignore b/src/python/.gitignore index 095ab8bbae1..61363e8cb8d 100644 --- a/src/python/.gitignore +++ b/src/python/.gitignore @@ -1,4 +1,6 @@ -gens/ +build/ +grpc_root/ +third_party/ *_pb2.py *_pb2.pyi *_pb2_grpc.py diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 79c85a1b94c..bf29982ca65 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -24,6 +24,7 @@ import types from typing import ( Any, Callable, + Dict, Iterator, List, Optional, @@ -1054,6 +1055,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1074,6 +1076,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): target: bytes, request_serializer: Optional[SerializingFunction], response_deserializer: Optional[DeserializingFunction], + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1082,6 +1085,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def _prepare( self, @@ -1153,6 +1157,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): ), ), self._context, + self._registered_call_handle, ) event = call.next_event() _handle_event(event, state, self._response_deserializer) @@ -1221,6 +1226,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): (operations,), event_handler, self._context, + self._registered_call_handle, ) return _MultiThreadedRendezvous( state, call, self._response_deserializer, deadline @@ -1234,6 +1240,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1252,6 +1259,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): target: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, + _registered_call_handle: Optional[int], ): self._channel = channel self._method = method @@ -1259,6 +1267,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def __call__( # pylint: disable=too-many-locals self, @@ -1317,6 +1326,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): call_credentials, operations_and_tags, self._context, + self._registered_call_handle, ) return _SingleThreadedRendezvous( state, call, self._response_deserializer, deadline @@ -1331,6 +1341,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1351,6 +1362,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): target: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1359,6 +1371,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def __call__( # pylint: disable=too-many-locals self, @@ -1408,6 +1421,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): operations, _event_handler(state, self._response_deserializer), self._context, + self._registered_call_handle, ) return _MultiThreadedRendezvous( state, call, self._response_deserializer, deadline @@ -1422,6 +1436,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1442,6 +1457,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): target: bytes, request_serializer: Optional[SerializingFunction], response_deserializer: Optional[DeserializingFunction], + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1450,6 +1466,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def _blocking( self, @@ -1482,6 +1499,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): augmented_metadata, initial_metadata_flags ), self._context, + self._registered_call_handle, ) _consume_request_iterator( request_iterator, state, call, self._request_serializer, None @@ -1572,6 +1590,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): ), event_handler, self._context, + self._registered_call_handle, ) _consume_request_iterator( request_iterator, @@ -1593,6 +1612,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1611,8 +1631,9 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): managed_call: IntegratedCallFactory, method: bytes, target: bytes, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction], + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1621,6 +1642,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def __call__( self, @@ -1662,6 +1684,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): operations, event_handler, self._context, + self._registered_call_handle, ) _consume_request_iterator( request_iterator, @@ -1751,7 +1774,8 @@ def _channel_managed_call_management(state: _ChannelCallState): credentials: Optional[cygrpc.CallCredentials], operations: Sequence[Sequence[cygrpc.Operation]], event_handler: UserTag, - context, + context: Any, + _registered_call_handle: Optional[int], ) -> cygrpc.IntegratedCall: """Creates a cygrpc.IntegratedCall. @@ -1768,6 +1792,8 @@ def _channel_managed_call_management(state: _ChannelCallState): event_handler: A behavior to call to handle the events resultant from the operations on the call. context: Context object for distributed tracing. + _registered_call_handle: An int representing the call handle of the + method, or None if the method is not registered. Returns: A cygrpc.IntegratedCall with which to conduct an RPC. """ @@ -1788,6 +1814,7 @@ def _channel_managed_call_management(state: _ChannelCallState): credentials, operations_and_tags, context, + _registered_call_handle, ) if state.managed_calls == 0: state.managed_calls = 1 @@ -2021,6 +2048,7 @@ class Channel(grpc.Channel): _call_state: _ChannelCallState _connectivity_state: _ChannelConnectivityState _target: str + _registered_call_handles: Dict[str, int] def __init__( self, @@ -2055,6 +2083,22 @@ class Channel(grpc.Channel): if cygrpc.g_gevent_activated: cygrpc.gevent_increment_channel_count() + def _get_registered_call_handle(self, method: str) -> int: + """ + Get the registered call handle for a method. + + This is a semi-private method. It is intended for use only by gRPC generated code. + + This method is not thread-safe. + + Args: + method: Required, the method name for the RPC. + + Returns: + The registered call handle pointer in the form of a Python Long. + """ + return self._channel.get_registered_call_handle(_common.encode(method)) + def _process_python_options( self, python_options: Sequence[ChannelArgumentType] ) -> None: @@ -2078,12 +2122,17 @@ class Channel(grpc.Channel): ) -> None: _unsubscribe(self._connectivity_state, callback) + # pylint: disable=arguments-differ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryUnaryMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) return _UnaryUnaryMultiCallable( self._channel, _channel_managed_call_management(self._call_state), @@ -2091,14 +2140,20 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) + # pylint: disable=arguments-differ def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryStreamMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC # on a single Python thread results in an appreciable speed-up. However, # due to slight differences in capability, the multi-threaded variant @@ -2110,6 +2165,7 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) else: return _UnaryStreamMultiCallable( @@ -2119,14 +2175,20 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) + # pylint: disable=arguments-differ def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamUnaryMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) return _StreamUnaryMultiCallable( self._channel, _channel_managed_call_management(self._call_state), @@ -2134,14 +2196,20 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) + # pylint: disable=arguments-differ def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamStreamMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) return _StreamStreamMultiCallable( self._channel, _channel_managed_call_management(self._call_state), @@ -2149,6 +2217,7 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) def _unsubscribe_all(self) -> None: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index 6e5416a9e31..96d03e181b9 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -74,6 +74,13 @@ cdef class SegregatedCall: cdef class Channel: cdef _ChannelState _state + cdef dict _registered_call_handles # TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this. cdef tuple _arguments + + +cdef class CallHandle: + + cdef void *c_call_handle + cdef object method diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index f6db36ebde1..dde3b166789 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -101,6 +101,25 @@ cdef class _ChannelState: self.connectivity_due = set() self.closed_reason = None +cdef class CallHandle: + + def __cinit__(self, _ChannelState channel_state, object method): + self.method = method + cpython.Py_INCREF(method) + # Note that since we always pass None for host, we set the + # second-to-last parameter of grpc_channel_register_call to a fixed + # NULL value. + self.c_call_handle = grpc_channel_register_call( + channel_state.c_channel, method, NULL, NULL) + + def __dealloc__(self): + cpython.Py_DECREF(self.method) + + @property + def call_handle(self): + return cpython.PyLong_FromVoidPtr(self.c_call_handle) + + cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): cdef grpc_call_error c_call_error @@ -199,7 +218,7 @@ cdef void _call( grpc_completion_queue *c_completion_queue, on_success, int flags, method, host, object deadline, CallCredentials credentials, object operationses_and_user_tags, object metadata, - object context) except *: + object context, object registered_call_handle) except *: """Invokes an RPC. Args: @@ -226,6 +245,8 @@ cdef void _call( must be present in the first element of this value. metadata: The metadata for this call. context: Context object for distributed tracing. + registered_call_handle: An int representing the call handle of the method, or + None if the method is not registered. """ cdef grpc_slice method_slice cdef grpc_slice host_slice @@ -242,10 +263,16 @@ cdef void _call( else: host_slice = _slice_from_bytes(host) host_slice_ptr = &host_slice - call_state.c_call = grpc_channel_create_call( - channel_state.c_channel, NULL, flags, - c_completion_queue, method_slice, host_slice_ptr, - _timespec_from_time(deadline), NULL) + if registered_call_handle: + call_state.c_call = grpc_channel_create_registered_call( + channel_state.c_channel, NULL, flags, + c_completion_queue, cpython.PyLong_AsVoidPtr(registered_call_handle), + _timespec_from_time(deadline), NULL) + else: + call_state.c_call = grpc_channel_create_call( + channel_state.c_channel, NULL, flags, + c_completion_queue, method_slice, host_slice_ptr, + _timespec_from_time(deadline), NULL) grpc_slice_unref(method_slice) if host_slice_ptr: grpc_slice_unref(host_slice) @@ -309,7 +336,7 @@ cdef class IntegratedCall: cdef IntegratedCall _integrated_call( _ChannelState state, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_user_tags, - object context): + object context, object registered_call_handle): call_state = _CallState() def on_success(started_tags): @@ -318,7 +345,8 @@ cdef IntegratedCall _integrated_call( _call( state, call_state, state.c_call_completion_queue, on_success, flags, - method, host, deadline, credentials, operationses_and_user_tags, metadata, context) + method, host, deadline, credentials, operationses_and_user_tags, + metadata, context, registered_call_handle) return IntegratedCall(state, call_state) @@ -371,7 +399,7 @@ cdef class SegregatedCall: cdef SegregatedCall _segregated_call( _ChannelState state, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_user_tags, - object context): + object context, object registered_call_handle): cdef _CallState call_state = _CallState() cdef SegregatedCall segregated_call cdef grpc_completion_queue *c_completion_queue @@ -389,7 +417,7 @@ cdef SegregatedCall _segregated_call( _call( state, call_state, c_completion_queue, on_success, flags, method, host, deadline, credentials, operationses_and_user_tags, metadata, - context) + context, registered_call_handle) except: _destroy_c_completion_queue(c_completion_queue) raise @@ -486,6 +514,7 @@ cdef class Channel: else grpc_insecure_credentials_create()) self._state.c_channel = grpc_channel_create( target, c_channel_credentials, channel_args.c_args()) + self._registered_call_handles = {} grpc_channel_credentials_release(c_channel_credentials) def target(self): @@ -499,10 +528,10 @@ cdef class Channel: def integrated_call( self, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_tags, - object context = None): + object context = None, object registered_call_handle = None): return _integrated_call( self._state, flags, method, host, deadline, metadata, credentials, - operationses_and_tags, context) + operationses_and_tags, context, registered_call_handle) def next_call_event(self): def on_success(tag): @@ -521,10 +550,10 @@ cdef class Channel: def segregated_call( self, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_tags, - object context = None): + object context = None, object registered_call_handle = None): return _segregated_call( self._state, flags, method, host, deadline, metadata, credentials, - operationses_and_tags, context) + operationses_and_tags, context, registered_call_handle) def check_connectivity_state(self, bint try_to_connect): with self._state.condition: @@ -543,3 +572,19 @@ cdef class Channel: def close_on_fork(self, code, details): _close(self, code, details, True) + + def get_registered_call_handle(self, method): + """ + Get or registers a call handler for a method. + + This method is not thread-safe. + + Args: + method: Required, the method name for the RPC. + + Returns: + The registered call handle pointer in the form of a Python Long. + """ + if method not in self._registered_call_handles.keys(): + self._registered_call_handles[method] = CallHandle(self._state, method) + return self._registered_call_handles[method].call_handle diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index e1bc87d4abd..29149e9893a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -433,6 +433,12 @@ cdef extern from "grpc/grpc.h": grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, grpc_completion_queue *completion_queue, grpc_slice method, const grpc_slice *host, gpr_timespec deadline, void *reserved) nogil + void *grpc_channel_register_call( + grpc_channel *channel, const char *method, const char *host, void *reserved) nogil + grpc_call *grpc_channel_create_registered_call( + grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, + grpc_completion_queue *completion_queue, void* registered_call_handle, + gpr_timespec deadline, void *reserved) nogil grpc_connectivity_state grpc_channel_check_connectivity_state( grpc_channel *channel, int try_to_connect) nogil void grpc_channel_watch_connectivity_state( diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 36bce4e3ba5..94abafebaa6 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -684,57 +684,85 @@ class _Channel(grpc.Channel): def unsubscribe(self, callback: Callable): self._channel.unsubscribe(callback) + # pylint: disable=arguments-differ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryUnaryMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.unary_unary( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): return _UnaryUnaryMultiCallable(thunk, method, self._interceptor) else: return thunk(method) + # pylint: disable=arguments-differ def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryStreamMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.unary_stream( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): return _UnaryStreamMultiCallable(thunk, method, self._interceptor) else: return thunk(method) + # pylint: disable=arguments-differ def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamUnaryMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.stream_unary( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): return _StreamUnaryMultiCallable(thunk, method, self._interceptor) else: return thunk(method) + # pylint: disable=arguments-differ def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamStreamMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.stream_stream( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): return _StreamStreamMultiCallable(thunk, method, self._interceptor) else: diff --git a/src/python/grpcio/grpc/_simple_stubs.py b/src/python/grpcio/grpc/_simple_stubs.py index 7772860957b..3e88670aa08 100644 --- a/src/python/grpcio/grpc/_simple_stubs.py +++ b/src/python/grpcio/grpc/_simple_stubs.py @@ -159,7 +159,19 @@ class ChannelCache: channel_credentials: Optional[grpc.ChannelCredentials], insecure: bool, compression: Optional[grpc.Compression], - ) -> grpc.Channel: + method: str, + _registered_method: bool, + ) -> Tuple[grpc.Channel, Optional[int]]: + """Get a channel from cache or creates a new channel. + + This method also takes care of register method for channel, + which means we'll register a new call handle if we're calling a + non-registered method for an existing channel. + + Returns: + A tuple with two items. The first item is the channel, second item is + the call handle if the method is registered, None if it's not registered. + """ if insecure and channel_credentials: raise ValueError( "The insecure option is mutually exclusive with " @@ -176,18 +188,25 @@ class ChannelCache: key = (target, options, channel_credentials, compression) with self._lock: channel_data = self._mapping.get(key, None) + call_handle = None if channel_data is not None: channel = channel_data[0] + # Register a new call handle if we're calling a registered method for an + # existing channel and this method is not registered. + if _registered_method: + call_handle = channel._get_registered_call_handle(method) self._mapping.pop(key) self._mapping[key] = ( channel, datetime.datetime.now() + _EVICTION_PERIOD, ) - return channel + return channel, call_handle else: channel = _create_channel( target, options, channel_credentials, compression ) + if _registered_method: + call_handle = channel._get_registered_call_handle(method) self._mapping[key] = ( channel, datetime.datetime.now() + _EVICTION_PERIOD, @@ -197,7 +216,7 @@ class ChannelCache: or len(self._mapping) >= _MAXIMUM_CHANNELS ): self._condition.notify() - return channel + return channel, call_handle def _test_only_channel_count(self) -> int: with self._lock: @@ -205,6 +224,7 @@ class ChannelCache: @experimental_api +# pylint: disable=too-many-locals def unary_unary( request: RequestType, target: str, @@ -219,6 +239,7 @@ def unary_unary( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> ResponseType: """Invokes a unary-unary RPC without an explicitly specified channel. @@ -272,11 +293,17 @@ def unary_unary( Returns: The response to the RPC. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.unary_unary( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( @@ -289,6 +316,7 @@ def unary_unary( @experimental_api +# pylint: disable=too-many-locals def unary_stream( request: RequestType, target: str, @@ -303,6 +331,7 @@ def unary_stream( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> Iterator[ResponseType]: """Invokes a unary-stream RPC without an explicitly specified channel. @@ -355,11 +384,17 @@ def unary_stream( Returns: An iterator of responses. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.unary_stream( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( @@ -372,6 +407,7 @@ def unary_stream( @experimental_api +# pylint: disable=too-many-locals def stream_unary( request_iterator: Iterator[RequestType], target: str, @@ -386,6 +422,7 @@ def stream_unary( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> ResponseType: """Invokes a stream-unary RPC without an explicitly specified channel. @@ -438,11 +475,17 @@ def stream_unary( Returns: The response to the RPC. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.stream_unary( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( @@ -455,6 +498,7 @@ def stream_unary( @experimental_api +# pylint: disable=too-many-locals def stream_stream( request_iterator: Iterator[RequestType], target: str, @@ -469,6 +513,7 @@ def stream_stream( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> Iterator[ResponseType]: """Invokes a stream-stream RPC without an explicitly specified channel. @@ -521,11 +566,17 @@ def stream_stream( Returns: An iterator of responses. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.stream_stream( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( diff --git a/src/python/grpcio/grpc/aio/_channel.py b/src/python/grpcio/grpc/aio/_channel.py index bea64c27fa3..ea4de20965a 100644 --- a/src/python/grpcio/grpc/aio/_channel.py +++ b/src/python/grpcio/grpc/aio/_channel.py @@ -478,11 +478,20 @@ class Channel(_base_channel.Channel): await self.wait_for_state_change(state) state = self.get_state(try_to_connect=True) + # TODO(xuanwn): Implement this method after we have + # observability for Asyncio. + def _get_registered_call_handle(self, method: str) -> int: + pass + + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> UnaryUnaryMultiCallable: return UnaryUnaryMultiCallable( self._channel, @@ -494,11 +503,15 @@ class Channel(_base_channel.Channel): self._loop, ) + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> UnaryStreamMultiCallable: return UnaryStreamMultiCallable( self._channel, @@ -510,11 +523,15 @@ class Channel(_base_channel.Channel): self._loop, ) + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> StreamUnaryMultiCallable: return StreamUnaryMultiCallable( self._channel, @@ -526,11 +543,15 @@ class Channel(_base_channel.Channel): self._loop, ) + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> StreamStreamMultiCallable: return StreamStreamMultiCallable( self._channel, diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py index 170533f63ea..3f12e1f4df8 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py @@ -31,23 +31,42 @@ class TestingChannel(grpc_testing.Channel): def unsubscribe(self, callback): raise NotImplementedError() + def _get_registered_call_handle(self, method: str) -> int: + pass + def unary_unary( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.UnaryUnary(method, self._state) def unary_stream( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.UnaryStream(method, self._state) def stream_unary( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.StreamUnary(method, self._state) def stream_stream( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.StreamStream(method, self._state) diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py index 78333fc62c7..2379ab59a3c 100644 --- a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py +++ b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -99,16 +99,20 @@ class ChannelzServicerTest(unittest.TestCase): def _send_successful_unary_unary(self, idx): _, r = ( self._pairs[idx] - .channel.unary_unary(_SUCCESSFUL_UNARY_UNARY) + .channel.unary_unary( + _SUCCESSFUL_UNARY_UNARY, + _registered_method=True, + ) .with_call(_REQUEST) ) self.assertEqual(r.code(), grpc.StatusCode.OK) def _send_failed_unary_unary(self, idx): try: - self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call( - _REQUEST - ) + self._pairs[idx].channel.unary_unary( + _FAILED_UNARY_UNARY, + _registered_method=True, + ).with_call(_REQUEST) except grpc.RpcError: return else: @@ -117,7 +121,10 @@ class ChannelzServicerTest(unittest.TestCase): def _send_successful_stream_stream(self, idx): response_iterator = ( self._pairs[idx] - .channel.stream_stream(_SUCCESSFUL_STREAM_STREAM) + .channel.stream_stream( + _SUCCESSFUL_STREAM_STREAM, + _registered_method=True, + ) .__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH)) ) cnt = 0 diff --git a/src/python/grpcio_tests/tests/csds/csds_test.py b/src/python/grpcio_tests/tests/csds/csds_test.py index c58cb5943d2..15bf4e8a49b 100644 --- a/src/python/grpcio_tests/tests/csds/csds_test.py +++ b/src/python/grpcio_tests/tests/csds/csds_test.py @@ -92,7 +92,10 @@ class TestCsds(unittest.TestCase): # Force the XdsClient to initialize and request a resource with self.assertRaises(grpc.RpcError) as rpc_error: - dummy_channel.unary_unary("")(b"", wait_for_ready=False, timeout=1) + dummy_channel.unary_unary( + "", + _registered_method=True, + )(b"", wait_for_ready=False, timeout=1) self.assertEqual( grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code() ) diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py index 9a6538cca35..c6e41b3ae8c 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py @@ -543,6 +543,17 @@ class PythonPluginTest(unittest.TestCase): ) service.server.stop(None) + def testRegisteredMethod(self): + """Tests that we're setting _registered_call_handle when create call using generated stub.""" + service = _CreateService() + self.assertTrue(service.stub.UnaryCall._registered_call_handle) + self.assertTrue( + service.stub.StreamingOutputCall._registered_call_handle + ) + self.assertTrue(service.stub.StreamingInputCall._registered_call_handle) + self.assertTrue(service.stub.FullDuplexCall._registered_call_handle) + service.server.stop(None) + @unittest.skipIf( sys.version_info[0] < 3 or sys.version_info[1] < 6, diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index e5aafc4142f..9310f8a8c61 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -32,13 +32,16 @@ _TIMEOUT = 60 * 60 * 24 class GenericStub(object): def __init__(self, channel): self.UnaryCall = channel.unary_unary( - "/grpc.testing.BenchmarkService/UnaryCall" + "/grpc.testing.BenchmarkService/UnaryCall", + _registered_method=True, ) self.StreamingFromServer = channel.unary_stream( - "/grpc.testing.BenchmarkService/StreamingFromServer" + "/grpc.testing.BenchmarkService/StreamingFromServer", + _registered_method=True, ) self.StreamingCall = channel.stream_stream( - "/grpc.testing.BenchmarkService/StreamingCall" + "/grpc.testing.BenchmarkService/StreamingCall", + _registered_method=True, ) diff --git a/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/src/python/grpcio_tests/tests/status/_grpc_status_test.py index 2573e961f18..031bdbe4d53 100644 --- a/src/python/grpcio_tests/tests/status/_grpc_status_test.py +++ b/src/python/grpcio_tests/tests/status/_grpc_status_test.py @@ -138,7 +138,10 @@ class StatusTest(unittest.TestCase): self._channel.close() def test_status_ok(self): - _, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST) + _, call = self._channel.unary_unary( + _STATUS_OK, + _registered_method=True, + ).with_call(_REQUEST) # Succeed RPC doesn't have status status = rpc_status.from_call(call) @@ -146,7 +149,10 @@ class StatusTest(unittest.TestCase): def test_status_not_ok(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST) + self._channel.unary_unary( + _STATUS_NOT_OK, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) @@ -156,7 +162,10 @@ class StatusTest(unittest.TestCase): def test_error_details(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST) + self._channel.unary_unary( + _ERROR_DETAILS, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception status = rpc_status.from_call(rpc_error) @@ -173,7 +182,10 @@ class StatusTest(unittest.TestCase): def test_code_message_validation(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST) + self._channel.unary_unary( + _INCONSISTENT, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) @@ -182,7 +194,10 @@ class StatusTest(unittest.TestCase): def test_invalid_code(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST) + self._channel.unary_unary( + _INVALID_CODE, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) # Invalid status code exception raised during coversion diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py index 731bb741bec..46f48bd1cae 100644 --- a/src/python/grpcio_tests/tests/unit/_abort_test.py +++ b/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -107,7 +107,10 @@ class AbortTest(unittest.TestCase): def test_abort(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_ABORT)(_REQUEST) + self._channel.unary_unary( + _ABORT, + _registered_method=True, + )(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) @@ -124,7 +127,10 @@ class AbortTest(unittest.TestCase): # Servicer will abort() after creating a local ref to do_not_leak_me. with self.assertRaises(grpc.RpcError): - self._channel.unary_unary(_ABORT)(_REQUEST) + self._channel.unary_unary( + _ABORT, + _registered_method=True, + )(_REQUEST) # Server may still have a stack frame reference to the exception even # after client sees error, so ensure server has shutdown. @@ -134,7 +140,10 @@ class AbortTest(unittest.TestCase): def test_abort_with_status(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST) + self._channel.unary_unary( + _ABORT_WITH_STATUS, + _registered_method=True, + )(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) @@ -143,7 +152,10 @@ class AbortTest(unittest.TestCase): def test_invalid_code(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_INVALID_CODE)(_REQUEST) + self._channel.unary_unary( + _INVALID_CODE, + _registered_method=True, + )(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index 039c908c3e5..0e5e2017fba 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -78,7 +78,10 @@ class AuthContextTest(unittest.TestCase): server.start() with grpc.insecure_channel("localhost:%d" % port) as channel: - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) server.stop(None) auth_data = pickle.loads(response) @@ -115,7 +118,10 @@ class AuthContextTest(unittest.TestCase): channel_creds, options=_PROPERTY_OPTIONS, ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) channel.close() server.stop(None) @@ -161,7 +167,10 @@ class AuthContextTest(unittest.TestCase): options=_PROPERTY_OPTIONS, ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) channel.close() server.stop(None) @@ -180,7 +189,10 @@ class AuthContextTest(unittest.TestCase): channel = grpc.secure_channel( "localhost:{}".format(port), channel_creds, options=channel_options ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) auth_data = pickle.loads(response) self.assertEqual( expect_ssl_session_reused, diff --git a/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/src/python/grpcio_tests/tests/unit/_channel_close_test.py index 4e5f215af89..b2ba4e7c887 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_close_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -123,7 +123,10 @@ class ChannelCloseTest(unittest.TestCase): def test_close_immediately_after_call_invocation(self): channel = grpc.insecure_channel("localhost:{}".format(self._port)) - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe(()) response_iterator = multi_callable(request_iterator) channel.close() @@ -133,7 +136,10 @@ class ChannelCloseTest(unittest.TestCase): def test_close_while_call_active(self): channel = grpc.insecure_channel("localhost:{}".format(self._port)) - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -146,7 +152,10 @@ class ChannelCloseTest(unittest.TestCase): with grpc.insecure_channel( "localhost:{}".format(self._port) ) as channel: # pylint: disable=bad-continuation - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -158,7 +167,10 @@ class ChannelCloseTest(unittest.TestCase): with grpc.insecure_channel( "localhost:{}".format(self._port) ) as channel: # pylint: disable=bad-continuation - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterators = tuple( _Pipe((b"abc",)) for _ in range(test_constants.THREAD_CONCURRENCY) @@ -176,7 +188,10 @@ class ChannelCloseTest(unittest.TestCase): def test_many_concurrent_closes(self): channel = grpc.insecure_channel("localhost:{}".format(self._port)) - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -203,10 +218,16 @@ class ChannelCloseTest(unittest.TestCase): with grpc.insecure_channel( "localhost:{}".format(self._port) ) as channel: - stream_multi_callable = channel.stream_stream(_STREAM_URI) + stream_multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) endless_iterator = itertools.repeat(b"abc") stream_response_iterator = stream_multi_callable(endless_iterator) - future = channel.unary_unary(_UNARY_URI).future(b"abc") + future = channel.unary_unary( + _UNARY_URI, + _registered_method=True, + ).future(b"abc") def on_done_callback(future): raise Exception("This should not cause a deadlock.") diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index be2a528ea9b..9fdeca030d8 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -237,7 +237,10 @@ def _get_compression_ratios( def _unary_unary_client(channel, multicallable_kwargs, message): - multi_callable = channel.unary_unary(_UNARY_UNARY) + multi_callable = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) response = multi_callable(message, **multicallable_kwargs) if response != message: raise RuntimeError( @@ -246,7 +249,10 @@ def _unary_unary_client(channel, multicallable_kwargs, message): def _unary_stream_client(channel, multicallable_kwargs, message): - multi_callable = channel.unary_stream(_UNARY_STREAM) + multi_callable = channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + ) response_iterator = multi_callable(message, **multicallable_kwargs) for response in response_iterator: if response != message: @@ -256,7 +262,10 @@ def _unary_stream_client(channel, multicallable_kwargs, message): def _stream_unary_client(channel, multicallable_kwargs, message): - multi_callable = channel.stream_unary(_STREAM_UNARY) + multi_callable = channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ) requests = (_REQUEST for _ in range(_STREAM_LENGTH)) response = multi_callable(requests, **multicallable_kwargs) if response != message: @@ -266,7 +275,10 @@ def _stream_unary_client(channel, multicallable_kwargs, message): def _stream_stream_client(channel, multicallable_kwargs, message): - multi_callable = channel.stream_stream(_STREAM_STREAM) + multi_callable = channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) request_prefix = str(0).encode("ascii") * 100 requests = ( request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH) diff --git a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py index 6f3b601ceb2..5a23d5dc69e 100644 --- a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py +++ b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -116,7 +116,10 @@ class ContextVarsPropagationTest(unittest.TestCase): local_credentials, call_credentials ) with grpc.secure_channel(target, composite_credentials) as channel: - stub = channel.unary_unary(_UNARY_UNARY) + stub = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) response = stub(_REQUEST, wait_for_ready=True) self.assertEqual(_REQUEST, response) @@ -142,7 +145,10 @@ class ContextVarsPropagationTest(unittest.TestCase): with grpc.secure_channel( target, composite_credentials ) as channel: - stub = channel.unary_unary(_UNARY_UNARY) + stub = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) wait_group.done() wait_group.wait() for i in range(_RPC_COUNT): diff --git a/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py b/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py index 62a95a02135..bcd7e6da849 100644 --- a/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py +++ b/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py @@ -55,7 +55,10 @@ class DNSResolverTest(unittest.TestCase): "loopback46.unittest.grpc.io:%d" % self._port ) as channel: self.assertEqual( - channel.unary_unary(_METHOD)( + channel.unary_unary( + _METHOD, + _registered_method=True, + )( _REQUEST, timeout=10, ), diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py index e2dc1594202..a303aa8b3e9 100644 --- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py +++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -96,25 +96,33 @@ class EmptyMessageTest(unittest.TestCase): self._channel.close() def testUnaryUnary(self): - response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = self._channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) self.assertEqual(_RESPONSE, response) def testUnaryStream(self): - response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST) + response_iterator = self._channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + )(_REQUEST) self.assertSequenceEqual( [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) ) def testStreamUnary(self): - response = self._channel.stream_unary(_STREAM_UNARY)( - iter([_REQUEST] * test_constants.STREAM_LENGTH) - ) + response = self._channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + )(iter([_REQUEST] * test_constants.STREAM_LENGTH)) self.assertEqual(_RESPONSE, response) def testStreamStream(self): - response_iterator = self._channel.stream_stream(_STREAM_STREAM)( - iter([_REQUEST] * test_constants.STREAM_LENGTH) - ) + response_iterator = self._channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + )(iter([_REQUEST] * test_constants.STREAM_LENGTH)) self.assertSequenceEqual( [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) ) diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py index 4f07477fac9..334b7ed5a3a 100644 --- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py +++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -73,7 +73,10 @@ class ErrorMessageEncodingTest(unittest.TestCase): def testMessageEncoding(self): for message in _UNICODE_ERROR_MESSAGES: - multi_callable = self._channel.unary_unary(_UNARY_UNARY) + multi_callable = self._channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) with self.assertRaises(grpc.RpcError) as cm: multi_callable(message.encode("utf-8")) diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py index c1f9816df08..1b7e2e5ac84 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py @@ -210,14 +210,20 @@ if __name__ == "__main__": method = TEST_TO_METHOD[args.scenario] if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: - multi_callable = channel.unary_unary(method) + multi_callable = channel.unary_unary( + method, + _registered_method=True, + ) future = multi_callable.future(REQUEST) result, call = multi_callable.with_call(REQUEST) elif ( args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL ): - multi_callable = channel.unary_stream(method) + multi_callable = channel.unary_stream( + method, + _registered_method=True, + ) response_iterator = multi_callable(REQUEST) for response in response_iterator: pass @@ -225,7 +231,10 @@ if __name__ == "__main__": args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL ): - multi_callable = channel.stream_unary(method) + multi_callable = channel.stream_unary( + method, + _registered_method=True, + ) future = multi_callable.future(infinite_request_iterator()) result, call = multi_callable.with_call( iter([REQUEST] * test_constants.STREAM_LENGTH) @@ -234,7 +243,10 @@ if __name__ == "__main__": args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL ): - multi_callable = channel.stream_stream(method) + multi_callable = channel.stream_stream( + method, + _registered_method=True, + ) response_iterator = multi_callable(infinite_request_iterator()) for response in response_iterator: pass diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 9bbff1f6bee..72e299b5887 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -231,7 +231,7 @@ class _GenericHandler(grpc.GenericRpcHandler): def _unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary(_UNARY_UNARY, _registered_method=True) def _unary_stream_multi_callable(channel): @@ -239,6 +239,7 @@ def _unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -247,11 +248,12 @@ def _stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def _stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream(_STREAM_STREAM, _registered_method=True) class _ClientCallDetails( @@ -562,7 +564,7 @@ class InterceptorTest(unittest.TestCase): self._record[:] = [] multi_callable = _unary_unary_multi_callable(channel) - multi_callable.with_call( + response, call = multi_callable.with_call( request, metadata=( ( diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index a19966131c5..58d1e589fde 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -32,7 +32,7 @@ _STREAM_STREAM = "/test/StreamStream" def _unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary(_UNARY_UNARY, _registered_method=True) def _unary_stream_multi_callable(channel): @@ -40,6 +40,7 @@ def _unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -48,11 +49,15 @@ def _stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def _stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) class InvalidMetadataTest(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py index b22ab016593..cb903e7b812 100644 --- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py +++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -219,7 +219,10 @@ class FailAfterFewIterationsCounter(object): def _unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) def _unary_stream_multi_callable(channel): @@ -227,6 +230,7 @@ def _unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -235,19 +239,29 @@ def _stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def _stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) def _defective_handler_multi_callable(channel): - return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER) + return channel.unary_unary( + _DEFECTIVE_GENERIC_RPC_HANDLER, + _registered_method=True, + ) def _defective_nested_exception_handler_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY_NESTED_EXCEPTION) + return channel.unary_unary( + _UNARY_UNARY_NESTED_EXCEPTION, + _registered_method=True, + ) class InvocationDefectsTest(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/_local_credentials_test.py b/src/python/grpcio_tests/tests/unit/_local_credentials_test.py index 165f6ca16eb..9c5b425eaf7 100644 --- a/src/python/grpcio_tests/tests/unit/_local_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_local_credentials_test.py @@ -53,9 +53,10 @@ class LocalCredentialsTest(unittest.TestCase): ) as channel: self.assertEqual( b"abc", - channel.unary_unary("/test/method")( - b"abc", wait_for_ready=True - ), + channel.unary_unary( + "/test/method", + _registered_method=True, + )(b"abc", wait_for_ready=True), ) server.stop(None) @@ -77,9 +78,10 @@ class LocalCredentialsTest(unittest.TestCase): with grpc.secure_channel(server_addr, channel_creds) as channel: self.assertEqual( b"abc", - channel.unary_unary("/test/method")( - b"abc", wait_for_ready=True - ), + channel.unary_unary( + "/test/method", + _registered_method=True, + )(b"abc", wait_for_ready=True), ) server.stop(None) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 3c530058dc3..320deb7e5f6 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -207,45 +207,53 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._server.start() self._channel = grpc.insecure_channel("localhost:{}".format(port)) + unary_unary_method_name = "/".join( + ( + "", + _SERVICE, + _UNARY_UNARY, + ) + ) self._unary_unary = self._channel.unary_unary( - "/".join( - ( - "", - _SERVICE, - _UNARY_UNARY, - ) - ), + unary_unary_method_name, request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, + _registered_method=True, + ) + unary_stream_method_name = "/".join( + ( + "", + _SERVICE, + _UNARY_STREAM, + ) ) self._unary_stream = self._channel.unary_stream( - "/".join( - ( - "", - _SERVICE, - _UNARY_STREAM, - ) - ), + unary_stream_method_name, + _registered_method=True, + ) + stream_unary_method_name = "/".join( + ( + "", + _SERVICE, + _STREAM_UNARY, + ) ) self._stream_unary = self._channel.stream_unary( - "/".join( - ( - "", - _SERVICE, - _STREAM_UNARY, - ) - ), + stream_unary_method_name, + _registered_method=True, + ) + stream_stream_method_name = "/".join( + ( + "", + _SERVICE, + _STREAM_STREAM, + ) ) self._stream_stream = self._channel.stream_stream( - "/".join( - ( - "", - _SERVICE, - _STREAM_STREAM, - ) - ), + stream_stream_method_name, request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, + _registered_method=True, ) def tearDown(self): @@ -828,16 +836,18 @@ class InspectContextTest(unittest.TestCase): self._server.start() self._channel = grpc.insecure_channel("localhost:{}".format(port)) + unary_unary_method_name = "/".join( + ( + "", + _SERVICE, + _UNARY_UNARY, + ) + ) self._unary_unary = self._channel.unary_unary( - "/".join( - ( - "", - _SERVICE, - _UNARY_UNARY, - ) - ), + unary_unary_method_name, request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, + _registered_method=True, ) def tearDown(self): diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index a67a496860f..2cd9ad9bd89 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -110,7 +110,10 @@ def create_phony_channel(): def perform_unary_unary_call(channel, wait_for_ready=None): - channel.unary_unary(_UNARY_UNARY).__call__( + channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ).__call__( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -118,7 +121,10 @@ def perform_unary_unary_call(channel, wait_for_ready=None): def perform_unary_unary_with_call(channel, wait_for_ready=None): - channel.unary_unary(_UNARY_UNARY).with_call( + channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ).with_call( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -126,7 +132,10 @@ def perform_unary_unary_with_call(channel, wait_for_ready=None): def perform_unary_unary_future(channel, wait_for_ready=None): - channel.unary_unary(_UNARY_UNARY).future( + channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ).future( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -134,7 +143,10 @@ def perform_unary_unary_future(channel, wait_for_ready=None): def perform_unary_stream_call(channel, wait_for_ready=None): - response_iterator = channel.unary_stream(_UNARY_STREAM).__call__( + response_iterator = channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + ).__call__( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -144,7 +156,10 @@ def perform_unary_stream_call(channel, wait_for_ready=None): def perform_stream_unary_call(channel, wait_for_ready=None): - channel.stream_unary(_STREAM_UNARY).__call__( + channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ).__call__( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -152,7 +167,10 @@ def perform_stream_unary_call(channel, wait_for_ready=None): def perform_stream_unary_with_call(channel, wait_for_ready=None): - channel.stream_unary(_STREAM_UNARY).with_call( + channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ).with_call( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -160,7 +178,10 @@ def perform_stream_unary_with_call(channel, wait_for_ready=None): def perform_stream_unary_future(channel, wait_for_ready=None): - channel.stream_unary(_STREAM_UNARY).future( + channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ).future( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -168,7 +189,9 @@ def perform_stream_unary_future(channel, wait_for_ready=None): def perform_stream_stream_call(channel, wait_for_ready=None): - response_iterator = channel.stream_stream(_STREAM_STREAM).__call__( + response_iterator = channel.stream_stream( + _STREAM_STREAM, _registered_method=True + ).__call__( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index 7110177fa18..b9b7502972c 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -195,7 +195,9 @@ class MetadataTest(unittest.TestCase): self._channel.close() def testUnaryUnary(self): - multi_callable = self._channel.unary_unary(_UNARY_UNARY) + multi_callable = self._channel.unary_unary( + _UNARY_UNARY, _registered_method=True + ) unused_response, call = multi_callable.with_call( _REQUEST, metadata=_INVOCATION_METADATA ) @@ -211,7 +213,9 @@ class MetadataTest(unittest.TestCase): ) def testUnaryStream(self): - multi_callable = self._channel.unary_stream(_UNARY_STREAM) + multi_callable = self._channel.unary_stream( + _UNARY_STREAM, _registered_method=True + ) call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA) self.assertTrue( test_common.metadata_transmitted( @@ -227,7 +231,9 @@ class MetadataTest(unittest.TestCase): ) def testStreamUnary(self): - multi_callable = self._channel.stream_unary(_STREAM_UNARY) + multi_callable = self._channel.stream_unary( + _STREAM_UNARY, _registered_method=True + ) unused_response, call = multi_callable.with_call( iter([_REQUEST] * test_constants.STREAM_LENGTH), metadata=_INVOCATION_METADATA, @@ -244,7 +250,9 @@ class MetadataTest(unittest.TestCase): ) def testStreamStream(self): - multi_callable = self._channel.stream_stream(_STREAM_STREAM) + multi_callable = self._channel.stream_stream( + _STREAM_STREAM, _registered_method=True + ) call = multi_callable( iter([_REQUEST] * test_constants.STREAM_LENGTH), metadata=_INVOCATION_METADATA, diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index d412533251c..8d4fadaa3da 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -52,7 +52,10 @@ class ReconnectTest(unittest.TestCase): server.add_insecure_port(addr) server.start() channel = grpc.insecure_channel(addr) - multi_callable = channel.unary_unary(_UNARY_UNARY) + multi_callable = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) server.stop(None) # By default, the channel connectivity is checked every 5s diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py index 3fc04f06a18..93e72c84f07 100644 --- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py +++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -149,7 +149,10 @@ class ResourceExhaustedTest(unittest.TestCase): self._channel.close() def testUnaryUnary(self): - multi_callable = self._channel.unary_unary(_UNARY_UNARY) + multi_callable = self._channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) futures = [] for _ in range(test_constants.THREAD_CONCURRENCY): futures.append(multi_callable.future(_REQUEST)) @@ -178,7 +181,10 @@ class ResourceExhaustedTest(unittest.TestCase): self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) def testUnaryStream(self): - multi_callable = self._channel.unary_stream(_UNARY_STREAM) + multi_callable = self._channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + ) calls = [] for _ in range(test_constants.THREAD_CONCURRENCY): calls.append(multi_callable(_REQUEST)) @@ -205,7 +211,10 @@ class ResourceExhaustedTest(unittest.TestCase): self.assertEqual(_RESPONSE, response) def testStreamUnary(self): - multi_callable = self._channel.stream_unary(_STREAM_UNARY) + multi_callable = self._channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ) futures = [] request = iter([_REQUEST] * test_constants.STREAM_LENGTH) for _ in range(test_constants.THREAD_CONCURRENCY): @@ -236,7 +245,10 @@ class ResourceExhaustedTest(unittest.TestCase): self.assertEqual(_RESPONSE, multi_callable(request)) def testStreamStream(self): - multi_callable = self._channel.stream_stream(_STREAM_STREAM) + multi_callable = self._channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) calls = [] request = iter([_REQUEST] * test_constants.STREAM_LENGTH) for _ in range(test_constants.THREAD_CONCURRENCY): diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py b/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py index 1027be1c677..d85eb30c8e7 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py @@ -277,7 +277,10 @@ class _GenericHandler(grpc.GenericRpcHandler): def unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) def unary_stream_multi_callable(channel): @@ -285,6 +288,7 @@ def unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -293,6 +297,7 @@ def unary_stream_non_blocking_multi_callable(channel): _UNARY_STREAM_NON_BLOCKING, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -301,15 +306,22 @@ def stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) def stream_stream_non_blocking_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING) + return channel.stream_stream( + _STREAM_STREAM_NON_BLOCKING, + _registered_method=True, + ) class BaseRPCTest(object): diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py index 9190f108f7b..34d51fd72ec 100644 --- a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -81,7 +81,10 @@ def run_test(args): thread.start() port = port_queue.get() channel = grpc.insecure_channel("localhost:%d" % port) - multi_callable = channel.unary_unary(FORK_EXIT) + multi_callable = channel.unary_unary( + FORK_EXIT, + _registered_method=True, + ) result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) os.wait() else: diff --git a/src/python/grpcio_tests/tests/unit/_session_cache_test.py b/src/python/grpcio_tests/tests/unit/_session_cache_test.py index acf671d1ca9..e4bc64f7d11 100644 --- a/src/python/grpcio_tests/tests/unit/_session_cache_test.py +++ b/src/python/grpcio_tests/tests/unit/_session_cache_test.py @@ -77,7 +77,10 @@ class SSLSessionCacheTest(unittest.TestCase): channel = grpc.secure_channel( "localhost:{}".format(port), channel_creds, options=channel_options ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) auth_data = pickle.loads(response) self.assertEqual( expect_ssl_session_reused, diff --git a/src/python/grpcio_tests/tests/unit/_signal_client.py b/src/python/grpcio_tests/tests/unit/_signal_client.py index 56563c20075..34c3da0c933 100644 --- a/src/python/grpcio_tests/tests/unit/_signal_client.py +++ b/src/python/grpcio_tests/tests/unit/_signal_client.py @@ -53,7 +53,10 @@ def main_unary(server_target): """Initiate a unary RPC to be interrupted by a SIGINT.""" global per_process_rpc_future # pylint: disable=global-statement with grpc.insecure_channel(server_target) as channel: - multicallable = channel.unary_unary(UNARY_UNARY) + multicallable = channel.unary_unary( + UNARY_UNARY, + _registered_method=True, + ) signal.signal(signal.SIGINT, handle_sigint) per_process_rpc_future = multicallable.future( _MESSAGE, wait_for_ready=True @@ -67,9 +70,10 @@ def main_streaming(server_target): global per_process_rpc_future # pylint: disable=global-statement with grpc.insecure_channel(server_target) as channel: signal.signal(signal.SIGINT, handle_sigint) - per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( - _MESSAGE, wait_for_ready=True - ) + per_process_rpc_future = channel.unary_stream( + UNARY_STREAM, + _registered_method=True, + )(_MESSAGE, wait_for_ready=True) for result in per_process_rpc_future: pass assert False, _ASSERTION_MESSAGE @@ -79,7 +83,10 @@ def main_unary_with_exception(server_target): """Initiate a unary RPC with a signal handler that will raise.""" channel = grpc.insecure_channel(server_target) try: - channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True) + channel.unary_unary( + UNARY_UNARY, + _registered_method=True, + )(_MESSAGE, wait_for_ready=True) except KeyboardInterrupt: sys.stderr.write("Running signal handler.\n") sys.stderr.flush() @@ -92,9 +99,10 @@ def main_streaming_with_exception(server_target): """Initiate a streaming RPC with a signal handler that will raise.""" channel = grpc.insecure_channel(server_target) try: - for _ in channel.unary_stream(UNARY_STREAM)( - _MESSAGE, wait_for_ready=True - ): + for _ in channel.unary_stream( + UNARY_STREAM, + _registered_method=True, + )(_MESSAGE, wait_for_ready=True): pass except KeyboardInterrupt: sys.stderr.write("Running signal handler.\n") diff --git a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py index 977d564888d..6d8b2b6a041 100644 --- a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py @@ -71,9 +71,10 @@ class XdsCredentialsTest(unittest.TestCase): server_address, channel_creds, options=override_options ) as channel: request = b"abc" - response = channel.unary_unary("/test/method")( - request, wait_for_ready=True - ) + response = channel.unary_unary( + "/test/method", + _registered_method=True, + )(request, wait_for_ready=True) self.assertEqual(response, request) def test_xds_creds_fallback_insecure(self): @@ -89,9 +90,10 @@ class XdsCredentialsTest(unittest.TestCase): channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) with grpc.secure_channel(server_address, channel_creds) as channel: request = b"abc" - response = channel.unary_unary("/test/method")( - request, wait_for_ready=True - ) + response = channel.unary_unary( + "/test/method", + _registered_method=True, + )(request, wait_for_ready=True) self.assertEqual(response, request) def test_start_xds_server(self): diff --git a/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py b/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py index 47fdb2c22e7..f09c47734e0 100644 --- a/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py @@ -65,6 +65,7 @@ class CloseChannelTest(unittest.TestCase): _UNARY_CALL_METHOD_WITH_SLEEP, request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString, + _registered_method=True, ) greenlet = group.spawn(self._run_client, UnaryCallWithSleep) # release loop so that greenlet can take control @@ -78,6 +79,7 @@ class CloseChannelTest(unittest.TestCase): _UNARY_CALL_METHOD_WITH_SLEEP, request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString, + _registered_method=True, ) greenlet = group.spawn(self._run_client, UnaryCallWithSleep) # release loop so that greenlet can take control diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py index c917bd10521..06fa6466014 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py @@ -68,7 +68,10 @@ def _start_a_test_server(): def _perform_an_rpc(address): channel = grpc.insecure_channel(address) - multicallable = channel.unary_unary(_TEST_METHOD) + multicallable = channel.unary_unary( + _TEST_METHOD, + _registered_method=True, + ) response = multicallable(_REQUEST) assert _REQUEST == response diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py index 771097936f6..adcc1299e9c 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py @@ -193,6 +193,7 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, channel_credentials=grpc.experimental.insecure_channel_credentials(), timeout=None, + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -205,6 +206,7 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, channel_credentials=grpc.local_channel_credentials(), timeout=None, + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -213,7 +215,10 @@ class SimpleStubsTest(unittest.TestCase): target = f"localhost:{port}" test_name = inspect.stack()[0][3] args = (_REQUEST, target, _UNARY_UNARY) - kwargs = {"channel_credentials": grpc.local_channel_credentials()} + kwargs = { + "channel_credentials": grpc.local_channel_credentials(), + "_registered_method": True, + } def _invoke(seed: str): run_kwargs = dict(kwargs) @@ -230,6 +235,7 @@ class SimpleStubsTest(unittest.TestCase): target, _UNARY_UNARY, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) self.assert_eventually( lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() @@ -250,6 +256,7 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, options=options, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) self.assert_eventually( lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() @@ -265,6 +272,7 @@ class SimpleStubsTest(unittest.TestCase): target, _UNARY_STREAM, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ): self.assertEqual(_REQUEST, response) @@ -280,6 +288,7 @@ class SimpleStubsTest(unittest.TestCase): target, _STREAM_UNARY, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -295,6 +304,7 @@ class SimpleStubsTest(unittest.TestCase): target, _STREAM_STREAM, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ): self.assertEqual(_REQUEST, response) @@ -319,14 +329,22 @@ class SimpleStubsTest(unittest.TestCase): with _server(server_creds) as port: target = f"localhost:{port}" response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, options=_property_options + _REQUEST, + target, + _UNARY_UNARY, + options=_property_options, + _registered_method=0, ) def test_insecure_sugar(self): with _server(None) as port: target = f"localhost:{port}" response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, insecure=True + _REQUEST, + target, + _UNARY_UNARY, + insecure=True, + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -340,14 +358,24 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, insecure=True, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) def test_default_wait_for_ready(self): addr, port, sock = get_socket() sock.close() target = f"{addr}:{port}" - channel = grpc._simple_stubs.ChannelCache.get().get_channel( - target, (), None, True, None + ( + channel, + unused_method_handle, + ) = grpc._simple_stubs.ChannelCache.get().get_channel( + target=target, + options=(), + channel_credentials=None, + insecure=True, + compression=None, + method=_UNARY_UNARY, + _registered_method=True, ) rpc_finished_event = threading.Event() rpc_failed_event = threading.Event() @@ -376,7 +404,12 @@ class SimpleStubsTest(unittest.TestCase): def _send_rpc(): try: response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, timeout=None, insecure=True + _REQUEST, + target, + _UNARY_UNARY, + timeout=None, + insecure=True, + _registered_method=0, ) rpc_finished_event.set() except Exception as e: @@ -399,6 +432,7 @@ class SimpleStubsTest(unittest.TestCase): target, _BLACK_HOLE, insecure=True, + _registered_method=0, **invocation_args, ) self.assertEqual( diff --git a/test/core/end2end/fuzzers/api_fuzzer.cc b/test/core/end2end/fuzzers/api_fuzzer.cc index d58e143b996..3cec7990f42 100644 --- a/test/core/end2end/fuzzers/api_fuzzer.cc +++ b/test/core/end2end/fuzzers/api_fuzzer.cc @@ -29,6 +29,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -407,6 +409,21 @@ void ApiFuzzer::Tick() { } } +namespace { + +// If there are more than 1K comma-delimited strings in target, remove +// the extra ones. +std::string SanitizeTargetUri(absl::string_view target) { + constexpr size_t kMaxCommaDelimitedStrings = 1000; + std::vector parts = absl::StrSplit(target, ','); + if (parts.size() > kMaxCommaDelimitedStrings) { + parts.resize(kMaxCommaDelimitedStrings); + } + return absl::StrJoin(parts, ","); +} + +} // namespace + ApiFuzzer::Result ApiFuzzer::CreateChannel( const api_fuzzer::CreateChannel& create_channel) { if (channel_ != nullptr) return Result::kComplete; @@ -423,8 +440,9 @@ ApiFuzzer::Result ApiFuzzer::CreateChannel( create_channel.has_channel_creds() ? ReadChannelCreds(create_channel.channel_creds()) : grpc_insecure_credentials_create(); - channel_ = grpc_channel_create(create_channel.target().c_str(), creds, - args.ToC().get()); + channel_ = + grpc_channel_create(SanitizeTargetUri(create_channel.target()).c_str(), + creds, args.ToC().get()); grpc_channel_credentials_release(creds); } GPR_ASSERT(channel_ != nullptr); diff --git a/test/core/promise/mpsc_test.cc b/test/core/promise/mpsc_test.cc index 38baa68cb60..3d6e669a173 100644 --- a/test/core/promise/mpsc_test.cc +++ b/test/core/promise/mpsc_test.cc @@ -95,6 +95,8 @@ TEST(MpscTest, SendingLotsOfThingsGivesPushback) { EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true); EXPECT_EQ(NowOrNever(sender.Send(MakePayload(2))), absl::nullopt); activity1.Deactivate(); + + EXPECT_CALL(activity1, WakeupRequested()); } TEST(MpscTest, ReceivingAfterBlockageWakesUp) { diff --git a/test/core/transport/chaotic_good/BUILD b/test/core/transport/chaotic_good/BUILD index fc698aeec44..11daa792100 100644 --- a/test/core/transport/chaotic_good/BUILD +++ b/test/core/transport/chaotic_good/BUILD @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:grpc_build_system.bzl", "grpc_cc_test", "grpc_package") -load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer") +load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package") +load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer", "grpc_proto_fuzzer") licenses(["notice"]) @@ -22,6 +22,34 @@ grpc_package( visibility = "tests", ) +grpc_cc_library( + name = "mock_promise_endpoint", + testonly = 1, + srcs = ["mock_promise_endpoint.cc"], + hdrs = ["mock_promise_endpoint.h"], + external_deps = ["gtest"], + deps = [ + "//:grpc", + "//src/core:grpc_promise_endpoint", + ], +) + +grpc_cc_library( + name = "transport_test", + testonly = 1, + srcs = ["transport_test.cc"], + hdrs = ["transport_test.h"], + external_deps = ["gtest"], + deps = [ + "//:iomgr_timer", + "//src/core:chaotic_good_frame", + "//src/core:memory_quota", + "//src/core:resource_quota", + "//test/core/event_engine/fuzzing_event_engine", + "//test/core/event_engine/fuzzing_event_engine:fuzzing_event_engine_proto", + ], +) + grpc_cc_test( name = "frame_header_test", srcs = ["frame_header_test.cc"], @@ -54,7 +82,7 @@ grpc_cc_test( deps = ["//src/core:chaotic_good_frame"], ) -grpc_fuzzer( +grpc_proto_fuzzer( name = "frame_fuzzer", srcs = ["frame_fuzzer.cc"], corpus = "frame_fuzzer_corpus", @@ -63,7 +91,10 @@ grpc_fuzzer( "absl/status:statusor", ], language = "C++", + proto = "frame_fuzzer.proto", tags = ["no_windows"], + uses_event_engine = False, + uses_polling = False, deps = [ "//:exec_ctx", "//:gpr", @@ -96,18 +127,46 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ + "mock_promise_endpoint", + "transport_test", "//:grpc", "//:grpc_public_hdrs", + "//src/core:arena", + "//src/core:chaotic_good_client_transport", + "//src/core:if", + "//src/core:loop", + "//src/core:seq", + "//src/core:slice_buffer", + ], +) + +grpc_cc_test( + name = "client_transport_error_test", + srcs = ["client_transport_error_test.cc"], + external_deps = [ + "absl/functional:any_invocable", + "absl/status", + "absl/status:statusor", + "absl/strings:str_format", + "absl/types:optional", + "gtest", + ], + language = "C++", + uses_event_engine = False, + uses_polling = False, + deps = [ + "//:grpc_public_hdrs", + "//:grpc_unsecure", "//:iomgr_timer", "//:ref_counted_ptr", "//src/core:activity", "//src/core:arena", "//src/core:chaotic_good_client_transport", "//src/core:event_engine_wakeup_scheduler", + "//src/core:grpc_promise_endpoint", "//src/core:if", "//src/core:join", "//src/core:loop", - "//src/core:map", "//src/core:memory_quota", "//src/core:pipe", "//src/core:resource_quota", @@ -120,8 +179,8 @@ grpc_cc_test( ) grpc_cc_test( - name = "client_transport_error_test", - srcs = ["client_transport_error_test.cc"], + name = "server_transport_test", + srcs = ["server_transport_test.cc"], external_deps = [ "absl/functional:any_invocable", "absl/status", @@ -134,20 +193,15 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ + "mock_promise_endpoint", + "transport_test", "//:grpc", "//:grpc_public_hdrs", "//:iomgr_timer", "//:ref_counted_ptr", - "//src/core:activity", "//src/core:arena", - "//src/core:chaotic_good_client_transport", - "//src/core:event_engine_wakeup_scheduler", - "//src/core:grpc_promise_endpoint", - "//src/core:if", - "//src/core:join", - "//src/core:loop", + "//src/core:chaotic_good_server_transport", "//src/core:memory_quota", - "//src/core:pipe", "//src/core:resource_quota", "//src/core:seq", "//src/core:slice", diff --git a/test/core/transport/chaotic_good/client_transport_error_test.cc b/test/core/transport/chaotic_good/client_transport_error_test.cc index 3b30c4ca330..295e060b809 100644 --- a/test/core/transport/chaotic_good/client_transport_error_test.cc +++ b/test/core/transport/chaotic_good/client_transport_error_test.cc @@ -12,37 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/status/status.h" - -#include "src/core/ext/transport/chaotic_good/client_transport.h" -#include "src/core/lib/transport/promise_endpoint.h" -#include "src/core/lib/transport/transport.h" - -// IWYU pragma: no_include - #include -#include // IWYU pragma: keep +#include #include -#include // IWYU pragma: keep +#include #include #include -#include // IWYU pragma: keep +#include #include "absl/functional/any_invocable.h" -#include "absl/status/statusor.h" // IWYU pragma: keep -#include "absl/strings/str_format.h" // IWYU pragma: keep -#include "absl/types/optional.h" // IWYU pragma: keep +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include #include -#include // IWYU pragma: keep +#include #include #include -#include // IWYU pragma: keep +#include +#include "src/core/ext/transport/chaotic_good/client_transport.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/timer_manager.h" #include "src/core/lib/promise/activity.h" @@ -56,14 +50,16 @@ #include "src/core/lib/resource_quota/memory_quota.h" #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice_buffer.h" -#include "src/core/lib/slice/slice_internal.h" // IWYU pragma: keep -#include "src/core/lib/transport/metadata_batch.h" // IWYU pragma: keep +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/promise_endpoint.h" +#include "src/core/lib/transport/transport.h" #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" +using testing::AtMost; using testing::MockFunction; using testing::Return; -using testing::Sequence; using testing::StrictMock; using testing::WithArgs; @@ -98,333 +94,308 @@ class MockEndpoint GetLocalAddress, (), (const, override)); }; +struct MockPromiseEndpoint { + StrictMock* endpoint = new StrictMock(); + std::unique_ptr promise_endpoint = + std::make_unique( + std::unique_ptr>(endpoint), SliceBuffer()); +}; + +// Send messages from client to server. +auto SendClientToServerMessages(CallInitiator initiator, int num_messages) { + return Loop([initiator, num_messages]() mutable { + bool has_message = (num_messages > 0); + return If( + has_message, + Seq(initiator.PushMessage(GetContext()->MakePooled()), + [&num_messages]() -> LoopCtl { + --num_messages; + return Continue(); + }), + [initiator]() mutable -> LoopCtl { + initiator.FinishSends(); + return absl::OkStatus(); + }); + }); +} + +ClientMetadataHandle TestInitialMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(HttpPathMetadata(), Slice::FromStaticString("/test")); + return md; +} + class ClientTransportTest : public ::testing::Test { - public: - ClientTransportTest() - : control_endpoint_ptr_(new StrictMock()), - data_endpoint_ptr_(new StrictMock()), - memory_allocator_( - ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( - "test")), - control_endpoint_(*control_endpoint_ptr_), - data_endpoint_(*data_endpoint_ptr_), - event_engine_(std::make_shared< - grpc_event_engine::experimental::FuzzingEventEngine>( - []() { - grpc_timer_manager_set_threading(false); - grpc_event_engine::experimental::FuzzingEventEngine::Options - options; - return options; - }(), - fuzzing_event_engine::Actions())), - arena_(MakeScopedArena(initial_arena_size, &memory_allocator_)), - pipe_client_to_server_messages_(arena_.get()), - pipe_server_to_client_messages_(arena_.get()), - pipe_server_intial_metadata_(arena_.get()), - pipe_client_to_server_messages_second_(arena_.get()), - pipe_server_to_client_messages_second_(arena_.get()), - pipe_server_intial_metadata_second_(arena_.get()) {} - // Initial ClientTransport with read expecations - void InitialClientTransport() { - client_transport_ = std::make_unique( - std::make_unique( - std::unique_ptr(control_endpoint_ptr_), - SliceBuffer()), - std::make_unique( - std::unique_ptr(data_endpoint_ptr_), SliceBuffer()), - event_engine_); - } - // Send messages from client to server. - auto SendClientToServerMessages( - Pipe& pipe_client_to_server_messages, - int num_of_messages) { - return Loop([&pipe_client_to_server_messages, num_of_messages, - this]() mutable { - bool has_message = (num_of_messages > 0); - return If( - has_message, - Seq(pipe_client_to_server_messages.sender.Push( - arena_->MakePooled()), - [&num_of_messages]() -> LoopCtl { - num_of_messages--; - return Continue(); - }), - [&pipe_client_to_server_messages]() mutable -> LoopCtl { - pipe_client_to_server_messages.sender.Close(); - return absl::OkStatus(); - }); - }); - } - // Add stream into client transport, and expect return trailers of - // "grpc-status:code". - auto AddStream(CallArgs args) { - return client_transport_->AddStream(std::move(args)); + protected: + const std::shared_ptr& + event_engine() { + return event_engine_; } + MemoryAllocator* memory_allocator() { return &allocator_; } private: - MockEndpoint* control_endpoint_ptr_; - MockEndpoint* data_endpoint_ptr_; - size_t initial_arena_size = 1024; - MemoryAllocator memory_allocator_; - - protected: - MockEndpoint& control_endpoint_; - MockEndpoint& data_endpoint_; std::shared_ptr - event_engine_; - std::unique_ptr client_transport_; - ScopedArenaPtr arena_; - Pipe pipe_client_to_server_messages_; - Pipe pipe_server_to_client_messages_; - Pipe pipe_server_intial_metadata_; - // Added for mutliple streams tests. - Pipe pipe_client_to_server_messages_second_; - Pipe pipe_server_to_client_messages_second_; - Pipe pipe_server_intial_metadata_second_; - absl::AnyInvocable read_callback_; - Sequence control_endpoint_sequence_; - Sequence data_endpoint_sequence_; - // Added to verify received message payload. - const std::string message_ = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}; + event_engine_{ + std::make_shared( + []() { + grpc_timer_manager_set_threading(false); + grpc_event_engine::experimental::FuzzingEventEngine::Options + options; + return options; + }(), + fuzzing_event_engine::Actions())}; + MemoryAllocator allocator_ = MakeResourceQuota("test-quota") + ->memory_quota() + ->CreateMemoryAllocator("test-allocator"); }; TEST_F(ClientTransportTest, AddOneStreamWithWriteFailed) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; // Mock write failed and read is pending. - EXPECT_CALL(control_endpoint_, Write) + EXPECT_CALL(*control_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillOnce( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("control endpoint write failed.")); return false; })); - EXPECT_CALL(data_endpoint_, Write) + EXPECT_CALL(*data_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillOnce( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("data endpoint write failed.")); return false; })); - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) - .WillOnce(Return(false)); - InitialClientTransport(); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(args)), - // Send messages to call_args.client_to_server_messages pipe, - // which will be eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(*control_endpoint.endpoint, Read).WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call.handler)); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } TEST_F(ClientTransportTest, AddOneStreamWithReadFailed) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; // Mock read failed. - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) + EXPECT_CALL(*control_endpoint.endpoint, Read) .WillOnce(WithArgs<0>( [](absl::AnyInvocable on_read) mutable { on_read(absl::InternalError("control endpoint read failed.")); // Return false to mock EventEngine read not finish. return false; })); - InitialClientTransport(); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(args)), - // Send messages to call_args.client_to_server_messages pipe. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call.handler)); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } TEST_F(ClientTransportTest, AddMultipleStreamWithWriteFailed) { // Mock write failed at first stream and second stream's write will fail too. - EXPECT_CALL(control_endpoint_, Write) - .Times(1) + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + EXPECT_CALL(*control_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillRepeatedly( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("control endpoint write failed.")); return false; })); - EXPECT_CALL(data_endpoint_, Write) - .Times(1) + EXPECT_CALL(*data_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillRepeatedly( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("data endpoint write failed.")); return false; })); - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) - .WillOnce(Return(false)); - InitialClientTransport(); - ClientMetadataHandle first_stream_md; - ClientMetadataHandle second_stream_md; - auto first_stream_args = - CallArgs{std::move(first_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - auto second_stream_args = - CallArgs{std::move(second_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_second_.sender, - &pipe_client_to_server_messages_second_.receiver, - &pipe_server_to_client_messages_second_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages from client transport. - Join( - // Add first stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(first_stream_args)), - // Send messages to first stream's - // call_args.client_to_server_messages pipe. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }, - Join( - // Add second stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(second_stream_args)), - // Send messages to second stream's - // call_args.client_to_server_messages pipe. - SendClientToServerMessages(pipe_client_to_server_messages_second_, - 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(*control_endpoint.endpoint, Read).WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call1 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call1.handler)); + auto call2 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call2.handler)); + call1.initiator.SpawnGuarded("test-send-1", [initiator = + call1.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + call2.initiator.SpawnGuarded("test-send-2", [initiator = + call2.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done1; + EXPECT_CALL(on_done1, Call()); + StrictMock> on_done2; + EXPECT_CALL(on_done2, Call()); + call1.initiator.SpawnInfallible( + "test-read-1", [&on_done1, initiator = call1.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done1](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done1.Call(); + return Empty{}; + }); + }); + call2.initiator.SpawnInfallible( + "test-read-2", [&on_done2, initiator = call2.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done2](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done2.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } TEST_F(ClientTransportTest, AddMultipleStreamWithReadFailed) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; // Mock read failed at first stream, and second stream's write will fail too. - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) + EXPECT_CALL(*control_endpoint.endpoint, Read) .WillOnce(WithArgs<0>( [](absl::AnyInvocable on_read) mutable { on_read(absl::InternalError("control endpoint read failed.")); // Return false to mock EventEngine read not finish. return false; })); - InitialClientTransport(); - ClientMetadataHandle first_stream_md; - ClientMetadataHandle second_stream_md; - auto first_stream_args = - CallArgs{std::move(first_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - auto second_stream_args = - CallArgs{std::move(second_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_second_.sender, - &pipe_client_to_server_messages_second_.receiver, - &pipe_server_to_client_messages_second_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages from client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(first_stream_args)), - // Send messages to first stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }, - Join( - // Add second stream with call_args into client transport. - AddStream(std::move(second_stream_args)), - // Send messages to second stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_second_, - 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call1 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call1.handler)); + auto call2 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call2.handler)); + call1.initiator.SpawnGuarded("test-send", [initiator = + call1.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + call2.initiator.SpawnGuarded("test-send", [initiator = + call2.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done1; + EXPECT_CALL(on_done1, Call()); + StrictMock> on_done2; + EXPECT_CALL(on_done2, Call()); + call1.initiator.SpawnInfallible( + "test-read", [&on_done1, initiator = call1.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done1](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done1.Call(); + return Empty{}; + }); + }); + call2.initiator.SpawnInfallible( + "test-read", [&on_done2, initiator = call2.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done2](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done2.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } } // namespace testing diff --git a/test/core/transport/chaotic_good/client_transport_test.cc b/test/core/transport/chaotic_good/client_transport_test.cc index 86bbf578c29..551c6fd5762 100644 --- a/test/core/transport/chaotic_good/client_transport_test.cc +++ b/test/core/transport/chaotic_good/client_transport_test.cc @@ -14,461 +14,245 @@ #include "src/core/ext/transport/chaotic_good/client_transport.h" -// IWYU pragma: no_include - -#include // IWYU pragma: keep +#include +#include +#include #include -#include // IWYU pragma: keep +#include #include -#include // IWYU pragma: keep +#include #include "absl/functional/any_invocable.h" -#include "absl/status/statusor.h" // IWYU pragma: keep -#include "absl/strings/str_format.h" // IWYU pragma: keep -#include "absl/types/optional.h" // IWYU pragma: keep +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include #include -#include // IWYU pragma: keep +#include #include #include -#include // IWYU pragma: keep +#include -#include "src/core/lib/gprpp/ref_counted_ptr.h" -#include "src/core/lib/iomgr/timer_manager.h" -#include "src/core/lib/promise/activity.h" -#include "src/core/lib/promise/event_engine_wakeup_scheduler.h" #include "src/core/lib/promise/if.h" -#include "src/core/lib/promise/join.h" #include "src/core/lib/promise/loop.h" -#include "src/core/lib/promise/map.h" -#include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/seq.h" #include "src/core/lib/resource_quota/arena.h" -#include "src/core/lib/resource_quota/memory_quota.h" -#include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice_buffer.h" -#include "src/core/lib/slice/slice_internal.h" // IWYU pragma: keep -#include "src/core/lib/transport/metadata_batch.h" // IWYU pragma: keep -#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" -#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "test/core/transport/chaotic_good/mock_promise_endpoint.h" +#include "test/core/transport/chaotic_good/transport_test.h" using testing::MockFunction; using testing::Return; -using testing::Sequence; using testing::StrictMock; -using testing::WithArgs; + +using EventEngineSlice = grpc_event_engine::experimental::Slice; namespace grpc_core { namespace chaotic_good { namespace testing { -class MockEndpoint - : public grpc_event_engine::experimental::EventEngine::Endpoint { - public: - MOCK_METHOD( - bool, Read, - (absl::AnyInvocable on_read, - grpc_event_engine::experimental::SliceBuffer* buffer, - const grpc_event_engine::experimental::EventEngine::Endpoint::ReadArgs* - args), - (override)); +// Encoded string of header ":path: /demo.Service/Step". +const uint8_t kPathDemoServiceStep[] = { + 0x40, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f, + 0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70}; - MOCK_METHOD( - bool, Write, - (absl::AnyInvocable on_writable, - grpc_event_engine::experimental::SliceBuffer* data, - const grpc_event_engine::experimental::EventEngine::Endpoint::WriteArgs* - args), - (override)); +// Encoded string of trailer "grpc-status: 0". +const uint8_t kGrpcStatus0[] = {0x10, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x01, 0x30}; - MOCK_METHOD( - const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, - GetPeerAddress, (), (const, override)); - MOCK_METHOD( - const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, - GetLocalAddress, (), (const, override)); -}; +ClientMetadataHandle TestInitialMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(HttpPathMetadata(), Slice::FromStaticString("/demo.Service/Step")); + return md; +} -class ClientTransportTest : public ::testing::Test { - public: - ClientTransportTest() - : control_endpoint_ptr_(new StrictMock()), - data_endpoint_ptr_(new StrictMock()), - memory_allocator_( - ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( - "test")), - control_endpoint_(*control_endpoint_ptr_), - data_endpoint_(*data_endpoint_ptr_), - event_engine_(std::make_shared< - grpc_event_engine::experimental::FuzzingEventEngine>( - []() { - grpc_timer_manager_set_threading(false); - grpc_event_engine::experimental::FuzzingEventEngine::Options - options; - return options; - }(), - fuzzing_event_engine::Actions())), - arena_(MakeScopedArena(initial_arena_size, &memory_allocator_)), - pipe_client_to_server_messages_(arena_.get()), - pipe_server_to_client_messages_(arena_.get()), - pipe_server_intial_metadata_(arena_.get()), - pipe_client_to_server_messages_second_(arena_.get()), - pipe_server_to_client_messages_second_(arena_.get()), - pipe_server_intial_metadata_second_(arena_.get()) {} - // Expect how client transport will read from control/data endpoints with a - // test frame. - void AddReadExpectations(int num_of_streams) { - for (int i = 0; i < num_of_streams; i++) { - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence) - .WillOnce(WithArgs<0, 1>( - [this, i](absl::AnyInvocable on_read, - grpc_event_engine::experimental::SliceBuffer* - buffer) mutable { - // Construct test frame for EventEngine read: headers (15 - // bytes), message(16 bytes), message padding (48 byte), - // trailers (15 bytes). - const std::string frame_header = { - static_cast(0x80), // frame type = fragment - 0x03, // flag = has header + has trailer - 0x00, - 0x00, - static_cast(i + 1), // stream id = 1 - 0x00, - 0x00, - 0x00, - 0x1a, // header length = 26 - 0x00, - 0x00, - 0x00, - 0x08, // message length = 8 - 0x00, - 0x00, - 0x00, - 0x38, // message padding =56 - 0x00, - 0x00, - 0x00, - 0x0f, // trailer length = 15 - 0x00, - 0x00, - 0x00}; - // Schedule mock_endpoint to read buffer. - grpc_event_engine::experimental::Slice slice( - grpc_slice_from_cpp_string(frame_header)); - buffer->Append(std::move(slice)); - // Execute read callback later to control when read starts. - if (i == 0) { - read_callback_ = std::move(on_read); - // Return false to mock EventEngine read not finish. - return false; - } else { - return true; - } - })); - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence) - .WillOnce(WithArgs<1>( - [](grpc_event_engine::experimental::SliceBuffer* buffer) { - // Encoded string of header ":path: /demo.Service/Step". - const std::string header = { - 0x10, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f, - 0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70}; - // Encoded string of trailer "grpc-status: 0". - const std::string trailers = {0x10, 0x0b, 0x67, 0x72, 0x70, - 0x63, 0x2d, 0x73, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x01, 0x30}; - // Schedule mock_endpoint to read buffer. - grpc_event_engine::experimental::Slice slice( - grpc_slice_from_cpp_string(header + trailers)); - buffer->Append(std::move(slice)); - return true; - })); - } - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence) - .WillOnce(Return(false)); - for (int i = 0; i < num_of_streams; i++) { - EXPECT_CALL(data_endpoint_, Read) - .InSequence(data_endpoint_sequence) - .WillOnce(WithArgs<1>( - [this](grpc_event_engine::experimental::SliceBuffer* buffer) { - const std::string message_padding = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; - grpc_event_engine::experimental::Slice slice( - grpc_slice_from_cpp_string(message_padding + message_)); - buffer->Append(std::move(slice)); - return true; - })); - } - } - // Initial ClientTransport with read expecations - void InitialClientTransport(int num_of_streams) { - // Read expectaions need to be added before transport initialization since - // reader_ activity loop is started in ClientTransport initialization, - AddReadExpectations(num_of_streams); - client_transport_ = std::make_unique( - std::make_unique( - std::unique_ptr(control_endpoint_ptr_), - SliceBuffer()), - std::make_unique( - std::unique_ptr(data_endpoint_ptr_), SliceBuffer()), - event_engine_); - } - // Send messages from client to server. - auto SendClientToServerMessages( - Pipe& pipe_client_to_server_messages, - int num_of_messages) { - return Loop([&pipe_client_to_server_messages, num_of_messages, - this]() mutable { - bool has_message = (num_of_messages > 0); - return If( - has_message, - Seq(pipe_client_to_server_messages.sender.Push( - arena_->MakePooled()), - [&num_of_messages]() -> LoopCtl { - num_of_messages--; - return Continue(); - }), - [&pipe_client_to_server_messages]() mutable -> LoopCtl { - pipe_client_to_server_messages.sender.Close(); - return absl::OkStatus(); - }); - }); - } - // Add stream into client transport, and expect return trailers of - // "grpc-status:code". - auto AddStream(CallArgs args, const grpc_status_code trailers) { - return Seq(client_transport_->AddStream(std::move(args)), - [trailers](ServerMetadataHandle ret) { - // AddStream will finish with server trailers: - // "grpc-status:code". - EXPECT_EQ(ret->get(GrpcStatusMetadata()).value(), trailers); - return trailers; - }); - } - // Start read from control endpoints. - auto StartRead(const absl::Status& read_status) { - return [read_status, this] { - read_callback_(read_status); - return read_status; - }; - } - // Receive messages from server to client. - auto ReceiveServerToClientMessages( - Pipe& pipe_server_intial_metadata, - Pipe& pipe_server_to_client_messages) { - return Seq( - // Receive server initial metadata. - Map(pipe_server_intial_metadata.receiver.Next(), - [](NextResult r) { - // Expect value: ":path: /demo.Service/Step" - EXPECT_TRUE(r.has_value()); - EXPECT_EQ( - r.value()->get_pointer(HttpPathMetadata())->as_string_view(), - "/demo.Service/Step"); - return absl::OkStatus(); - }), - // Receive server to client messages. - Map(pipe_server_to_client_messages.receiver.Next(), - [this](NextResult r) { - EXPECT_TRUE(r.has_value()); - EXPECT_EQ(r.value()->payload()->JoinIntoString(), message_); - return absl::OkStatus(); +// Send messages from client to server. +auto SendClientToServerMessages(CallInitiator initiator, int num_messages) { + return Loop([initiator, num_messages, i = 0]() mutable { + bool has_message = (i < num_messages); + return If( + has_message, + Seq(initiator.PushMessage(GetContext()->MakePooled( + SliceBuffer(Slice::FromCopiedString(std::to_string(i))), 0)), + [&i]() -> LoopCtl { + ++i; + return Continue(); }), - [&pipe_server_intial_metadata, - &pipe_server_to_client_messages]() mutable { - // Close pipes after receive message. - pipe_server_to_client_messages.sender.Close(); - pipe_server_intial_metadata.sender.Close(); + [initiator]() mutable -> LoopCtl { + initiator.FinishSends(); return absl::OkStatus(); }); - } - - private: - MockEndpoint* control_endpoint_ptr_; - MockEndpoint* data_endpoint_ptr_; - size_t initial_arena_size = 1024; - MemoryAllocator memory_allocator_; - Sequence control_endpoint_sequence; - Sequence data_endpoint_sequence; - - protected: - MockEndpoint& control_endpoint_; - MockEndpoint& data_endpoint_; - std::shared_ptr - event_engine_; - std::unique_ptr client_transport_; - ScopedArenaPtr arena_; - Pipe pipe_client_to_server_messages_; - Pipe pipe_server_to_client_messages_; - Pipe pipe_server_intial_metadata_; - // Added for mutliple streams tests. - Pipe pipe_client_to_server_messages_second_; - Pipe pipe_server_to_client_messages_second_; - Pipe pipe_server_intial_metadata_second_; - absl::AnyInvocable read_callback_; - // Added to verify received message payload. - const std::string message_ = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}; -}; - -TEST_F(ClientTransportTest, AddOneStream) { - InitialClientTransport(1); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - EXPECT_CALL(control_endpoint_, Write).WillOnce(Return(true)); - EXPECT_CALL(data_endpoint_, Write).WillOnce(Return(true)); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(args), GRPC_STATUS_OK), - // Start read from control endpoints. - StartRead(absl::OkStatus()), - // Send messages to call_args.client_to_server_messages pipe, - // which will be eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 1), - // Receive messages from control/data endpoints. - ReceiveServerToClientMessages(pipe_server_intial_metadata_, - pipe_server_to_client_messages_)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK); - EXPECT_TRUE(std::get<1>(ret).ok()); - EXPECT_TRUE(std::get<2>(ret).ok()); - EXPECT_TRUE(std::get<3>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); - // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + }); } -TEST_F(ClientTransportTest, AddOneStreamMultipleMessages) { - InitialClientTransport(1); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - EXPECT_CALL(control_endpoint_, Write).Times(3).WillRepeatedly(Return(true)); - EXPECT_CALL(data_endpoint_, Write).Times(3).WillRepeatedly(Return(true)); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(args), GRPC_STATUS_OK), - // Start read from control endpoints. - StartRead(absl::OkStatus()), - // Send messages to call_args.client_to_server_messages pipe, - // which will be eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 3), - // Receive messages from control/data endpoints. - ReceiveServerToClientMessages(pipe_server_intial_metadata_, - pipe_server_to_client_messages_)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK); - EXPECT_TRUE(std::get<1>(ret).ok()); - EXPECT_TRUE(std::get<2>(ret).ok()); - EXPECT_TRUE(std::get<3>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +TEST_F(TransportTest, AddOneStream) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 7, 1, 26, 8, 56, 15), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep)), + EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))}, + event_engine().get()); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr); + EXPECT_CALL(*control_endpoint.endpoint, Read) + .InSequence(control_endpoint.read_sequence) + .WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(1024, memory_allocator())); + transport->StartCall(std::move(call.handler)); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 1, 1, + sizeof(kPathDemoServiceStep), 0, 0, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("0"), Zeros(63)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, 0)}, nullptr); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_TRUE(md.ok()); + EXPECT_EQ( + md.value()->get_pointer(HttpPathMetadata())->as_string_view(), + "/demo.Service/Step"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_FALSE(msg.has_value()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), GRPC_STATUS_OK); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } -TEST_F(ClientTransportTest, AddMultipleStreamsMultipleMessages) { - InitialClientTransport(2); - ClientMetadataHandle first_stream_md; - ClientMetadataHandle second_stream_md; - auto first_stream_args = - CallArgs{std::move(first_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - auto second_stream_args = - CallArgs{std::move(second_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_second_.sender, - &pipe_client_to_server_messages_second_.receiver, - &pipe_server_to_client_messages_second_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - EXPECT_CALL(control_endpoint_, Write).Times(6).WillRepeatedly(Return(true)); - EXPECT_CALL(data_endpoint_, Write).Times(6).WillRepeatedly(Return(true)); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages from client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(first_stream_args), GRPC_STATUS_OK), - // Start read from control endpoints. - StartRead(absl::OkStatus()), - // Send messages to first stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 3), - // Receive first stream's messages from control/data endpoints. - ReceiveServerToClientMessages(pipe_server_intial_metadata_, - pipe_server_to_client_messages_)), - Join( - // Add second stream with call_args into client transport. - AddStream(std::move(second_stream_args), GRPC_STATUS_OK), - // Send messages to second stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_second_, - 3), - // Receive second stream's messages from control/data endpoints. - ReceiveServerToClientMessages( - pipe_server_intial_metadata_second_, - pipe_server_to_client_messages_second_)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& - ret) { - EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK); - EXPECT_TRUE(std::get<1>(ret).ok()); - EXPECT_TRUE(std::get<2>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +TEST_F(TransportTest, AddOneStreamMultipleMessages) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 3, 1, 26, 8, 56, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + event_engine().get()); + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 6, 1, 0, 8, 56, 15), + EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))}, + event_engine().get()); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("87654321"), Zeros(56)}, nullptr); + EXPECT_CALL(*control_endpoint.endpoint, Read) + .InSequence(control_endpoint.read_sequence) + .WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call.handler)); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 1, 1, + sizeof(kPathDemoServiceStep), 0, 0, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("0"), Zeros(63)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("1"), Zeros(63)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, 0)}, nullptr); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 2)); + }); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_TRUE(md.ok()); + EXPECT_EQ( + md.value()->get_pointer(HttpPathMetadata())->as_string_view(), + "/demo.Service/Step"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "87654321"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_FALSE(msg.has_value()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), GRPC_STATUS_OK); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } } // namespace testing diff --git a/test/core/transport/chaotic_good/frame_fuzzer.cc b/test/core/transport/chaotic_good/frame_fuzzer.cc index 57180ce1c20..03481560771 100644 --- a/test/core/transport/chaotic_good/frame_fuzzer.cc +++ b/test/core/transport/chaotic_good/frame_fuzzer.cc @@ -35,7 +35,9 @@ #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_buffer.h" +#include "src/libfuzzer/libfuzzer_macro.h" #include "test/core/promise/test_context.h" +#include "test/core/transport/chaotic_good/frame_fuzzer.pb.h" bool squelch = false; @@ -51,10 +53,10 @@ template void AssertRoundTrips(const T& input, FrameType expected_frame_type) { HPackCompressor hpack_compressor; auto serialized = input.Serialize(&hpack_compressor); - GPR_ASSERT(serialized.Length() >= + GPR_ASSERT(serialized.control.Length() >= 24); // Initial output buffer size is 64 byte. uint8_t header_bytes[24]; - serialized.MoveFirstNBytesIntoBuffer(24, header_bytes); + serialized.control.MoveFirstNBytesIntoBuffer(24, header_bytes); auto header = FrameHeader::Parse(header_bytes); if (!header.ok()) { if (!squelch) { @@ -67,66 +69,69 @@ void AssertRoundTrips(const T& input, FrameType expected_frame_type) { T output; HPackParser hpack_parser; DeterministicBitGen bitgen; - auto deser = output.Deserialize(&hpack_parser, header.value(), - absl::BitGenRef(bitgen), serialized); + auto deser = + output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen), + GetContext(), std::move(serialized)); GPR_ASSERT(deser.ok()); GPR_ASSERT(output == input); } template -void FinishParseAndChecks(const FrameHeader& header, const uint8_t* data, - size_t size) { +void FinishParseAndChecks(const FrameHeader& header, BufferPair buffers) { T parsed; ExecCtx exec_ctx; // Initialized to get this_cpu() info in global_stat(). HPackParser hpack_parser; - SliceBuffer serialized; - serialized.Append(Slice::FromCopiedBuffer(data, size)); DeterministicBitGen bitgen; - auto deser = parsed.Deserialize(&hpack_parser, header, - absl::BitGenRef(bitgen), serialized); + auto deser = + parsed.Deserialize(&hpack_parser, header, absl::BitGenRef(bitgen), + GetContext(), std::move(buffers)); if (!deser.ok()) return; gpr_log(GPR_INFO, "Read frame: %s", parsed.ToString().c_str()); AssertRoundTrips(parsed, header.type); } -int Run(const uint8_t* data, size_t size) { - if (size < 1) return 0; - const bool is_server = (data[0] & 1) != 0; - size--; - data++; - if (size < 24) return 0; - auto r = FrameHeader::Parse(data); - if (!r.ok()) return 0; +void Run(const frame_fuzzer::Test& test) { + const uint8_t* control_data = + reinterpret_cast(test.control().data()); + size_t control_size = test.control().size(); + if (test.control().size() < 24) return; + auto r = FrameHeader::Parse(control_data); + if (!r.ok()) return; + if (test.data().size() != r->message_length) return; gpr_log(GPR_INFO, "Read frame header: %s", r->ToString().c_str()); - size -= 24; - data += 24; + control_data += 24; + control_size -= 24; MemoryAllocator memory_allocator = MemoryAllocator( ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test")); auto arena = MakeScopedArena(1024, &memory_allocator); TestContext ctx(arena.get()); + BufferPair buffers{ + SliceBuffer(Slice::FromCopiedBuffer(control_data, control_size)), + SliceBuffer( + Slice::FromCopiedBuffer(test.data().data(), test.data().size())), + }; switch (r->type) { default: - return 0; // We don't know how to parse this frame type. + return; // We don't know how to parse this frame type. case FrameType::kSettings: - FinishParseAndChecks(*r, data, size); + FinishParseAndChecks(*r, std::move(buffers)); break; case FrameType::kFragment: - if (is_server) { - FinishParseAndChecks(*r, data, size); + if (test.is_server()) { + FinishParseAndChecks(*r, std::move(buffers)); } else { - FinishParseAndChecks(*r, data, size); + FinishParseAndChecks(*r, std::move(buffers)); } break; case FrameType::kCancel: - FinishParseAndChecks(*r, data, size); + FinishParseAndChecks(*r, std::move(buffers)); break; } - return 0; } } // namespace chaotic_good } // namespace grpc_core -extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - return grpc_core::chaotic_good::Run(data, size); +DEFINE_PROTO_FUZZER(const frame_fuzzer::Test& test) { + grpc_core::chaotic_good::Run(test); } diff --git a/test/core/transport/chaotic_good/frame_fuzzer.proto b/test/core/transport/chaotic_good/frame_fuzzer.proto new file mode 100644 index 00000000000..4ae8657e588 --- /dev/null +++ b/test/core/transport/chaotic_good/frame_fuzzer.proto @@ -0,0 +1,23 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package frame_fuzzer; + +message Test { + bool is_server = 1; + bytes control = 2; + bytes data = 3; +} diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328 deleted file mode 100644 index 16d6e2f4fde..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328 and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672 deleted file mode 100644 index 98e8a28385d..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672 and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb deleted file mode 100644 index 0340f7dee01..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a deleted file mode 100644 index 366c5302596..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c deleted file mode 100644 index 74cb18c8e93..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe deleted file mode 100644 index 2190f6bc185..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff deleted file mode 100644 index 90739c72580..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540 deleted file mode 100644 index 4d14b159ba6..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540 and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 deleted file mode 100644 index 8fa12bd3aea..00000000000 Binary files a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 and /dev/null differ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/empty b/test/core/transport/chaotic_good/frame_fuzzer_corpus/empty new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/test/core/transport/chaotic_good/frame_fuzzer_corpus/empty @@ -0,0 +1 @@ + diff --git a/test/core/transport/chaotic_good/frame_header_test.cc b/test/core/transport/chaotic_good/frame_header_test.cc index a382315fa14..929a74b1df2 100644 --- a/test/core/transport/chaotic_good/frame_header_test.cc +++ b/test/core/transport/chaotic_good/frame_header_test.cc @@ -36,7 +36,7 @@ absl::StatusOr Deserialize(std::vector data) { } TEST(FrameHeaderTest, SimpleSerialize) { - EXPECT_EQ(Serialize(FrameHeader{FrameType::kCancel, BitSet<2>::FromInt(0), + EXPECT_EQ(Serialize(FrameHeader{FrameType::kCancel, BitSet<3>::FromInt(0), 0x01020304, 0x05060708, 0x090a0b0c, 0x00000034, 0x0d0e0f10}), std::vector({ @@ -59,7 +59,7 @@ TEST(FrameHeaderTest, SimpleDeserialize) { 0x10, 0x0f, 0x0e, 0x0d // trailer_length })), absl::StatusOr(FrameHeader{ - FrameType::kCancel, BitSet<2>::FromInt(0), 0x01020304, + FrameType::kCancel, BitSet<3>::FromInt(0), 0x01020304, 0x05060708, 0x090a0b0c, 0x00000034, 0x0d0e0f10})); EXPECT_EQ(Deserialize(std::vector({ 0x81, 88, 88, 88, // type, flags @@ -75,19 +75,19 @@ TEST(FrameHeaderTest, SimpleDeserialize) { TEST(FrameHeaderTest, GetFrameLength) { EXPECT_EQ( - (FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 0, 0, 0, 0}) + (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 0, 0, 0}) .GetFrameLength(), 0); EXPECT_EQ( - (FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 14, 0, 0, 0}) + (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 14, 0, 0, 0}) .GetFrameLength(), 14); - EXPECT_EQ((FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 0, 14, + EXPECT_EQ((FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 14, 50, 0}) .GetFrameLength(), 0); EXPECT_EQ( - (FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 0, 0, 0, 14}) + (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 0, 0, 14}) .GetFrameLength(), 14); } diff --git a/test/core/transport/chaotic_good/frame_test.cc b/test/core/transport/chaotic_good/frame_test.cc index 00908f75a6e..15153a09b8d 100644 --- a/test/core/transport/chaotic_good/frame_test.cc +++ b/test/core/transport/chaotic_good/frame_test.cc @@ -21,27 +21,38 @@ #include "absl/status/statusor.h" #include "gtest/gtest.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/resource_quota.h" + namespace grpc_core { namespace chaotic_good { namespace { template -void AssertRoundTrips(const T input, FrameType expected_frame_type) { +void AssertRoundTrips(const T& input, FrameType expected_frame_type) { HPackCompressor hpack_compressor; - absl::BitGen bitgen; auto serialized = input.Serialize(&hpack_compressor); - EXPECT_GE(serialized.Length(), 24); + GPR_ASSERT(serialized.control.Length() >= + 24); // Initial output buffer size is 64 byte. uint8_t header_bytes[24]; - serialized.MoveFirstNBytesIntoBuffer(24, header_bytes); + serialized.control.MoveFirstNBytesIntoBuffer(24, header_bytes); auto header = FrameHeader::Parse(header_bytes); - EXPECT_TRUE(header.ok()) << header.status(); - EXPECT_EQ(header->type, expected_frame_type); + if (!header.ok()) { + Crash("Failed to parse header"); + } + GPR_ASSERT(header->type == expected_frame_type); T output; HPackParser hpack_parser; - auto deser = output.Deserialize(&hpack_parser, header.value(), - absl::BitGenRef(bitgen), serialized); - EXPECT_TRUE(deser.ok()) << deser; - EXPECT_EQ(output, input); + absl::BitGen bitgen; + MemoryAllocator allocator = MakeResourceQuota("test-quota") + ->memory_quota() + ->CreateMemoryAllocator("test-allocator"); + ScopedArenaPtr arena = MakeScopedArena(1024, &allocator); + auto deser = + output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen), + arena.get(), std::move(serialized)); + GPR_ASSERT(deser.ok()); + GPR_ASSERT(output == input); } TEST(FrameTest, SettingsFrameRoundTrips) { diff --git a/test/core/transport/chaotic_good/mock_promise_endpoint.cc b/test/core/transport/chaotic_good/mock_promise_endpoint.cc new file mode 100644 index 00000000000..9ba96e75804 --- /dev/null +++ b/test/core/transport/chaotic_good/mock_promise_endpoint.cc @@ -0,0 +1,89 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/transport/chaotic_good/mock_promise_endpoint.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include + +using EventEngineSlice = grpc_event_engine::experimental::Slice; +using grpc_event_engine::experimental::EventEngine; + +using testing::WithArgs; + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +void MockPromiseEndpoint::ExpectRead( + std::initializer_list slices_init, + EventEngine* schedule_on_event_engine) { + std::vector slices; + for (auto&& slice : slices_init) slices.emplace_back(slice.Copy()); + EXPECT_CALL(*endpoint, Read) + .InSequence(read_sequence) + .WillOnce(WithArgs<0, 1>( + [slices = std::move(slices), schedule_on_event_engine]( + absl::AnyInvocable on_read, + grpc_event_engine::experimental::SliceBuffer* buffer) mutable { + for (auto& slice : slices) { + buffer->Append(std::move(slice)); + } + if (schedule_on_event_engine != nullptr) { + schedule_on_event_engine->Run( + [on_read = std::move(on_read)]() mutable { + on_read(absl::OkStatus()); + }); + return false; + } else { + return true; + } + })); +} + +void MockPromiseEndpoint::ExpectWrite( + std::initializer_list slices, + EventEngine* schedule_on_event_engine) { + SliceBuffer expect; + for (auto&& slice : slices) { + expect.Append(grpc_event_engine::experimental::internal::SliceCast( + slice.Copy())); + } + EXPECT_CALL(*endpoint, Write) + .InSequence(write_sequence) + .WillOnce(WithArgs<0, 1>( + [expect = expect.JoinIntoString(), schedule_on_event_engine]( + absl::AnyInvocable on_writable, + grpc_event_engine::experimental::SliceBuffer* buffer) mutable { + SliceBuffer tmp; + grpc_slice_buffer_swap(buffer->c_slice_buffer(), + tmp.c_slice_buffer()); + EXPECT_EQ(tmp.JoinIntoString(), expect); + if (schedule_on_event_engine != nullptr) { + schedule_on_event_engine->Run( + [on_writable = std::move(on_writable)]() mutable { + on_writable(absl::OkStatus()); + }); + return false; + } else { + return true; + } + })); +} + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core diff --git a/test/core/transport/chaotic_good/mock_promise_endpoint.h b/test/core/transport/chaotic_good/mock_promise_endpoint.h new file mode 100644 index 00000000000..c1534efb605 --- /dev/null +++ b/test/core/transport/chaotic_good/mock_promise_endpoint.h @@ -0,0 +1,77 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H +#define GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include + +#include "src/core/lib/transport/promise_endpoint.h" + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +class MockEndpoint + : public grpc_event_engine::experimental::EventEngine::Endpoint { + public: + MOCK_METHOD( + bool, Read, + (absl::AnyInvocable on_read, + grpc_event_engine::experimental::SliceBuffer* buffer, + const grpc_event_engine::experimental::EventEngine::Endpoint::ReadArgs* + args), + (override)); + + MOCK_METHOD( + bool, Write, + (absl::AnyInvocable on_writable, + grpc_event_engine::experimental::SliceBuffer* data, + const grpc_event_engine::experimental::EventEngine::Endpoint::WriteArgs* + args), + (override)); + + MOCK_METHOD( + const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, + GetPeerAddress, (), (const, override)); + MOCK_METHOD( + const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, + GetLocalAddress, (), (const, override)); +}; + +struct MockPromiseEndpoint { + ::testing::StrictMock* endpoint = + new ::testing::StrictMock(); + std::unique_ptr promise_endpoint = + std::make_unique( + std::unique_ptr<::testing::StrictMock>(endpoint), + SliceBuffer()); + ::testing::Sequence read_sequence; + ::testing::Sequence write_sequence; + void ExpectRead( + std::initializer_list slices_init, + grpc_event_engine::experimental::EventEngine* schedule_on_event_engine); + void ExpectWrite( + std::initializer_list slices, + grpc_event_engine::experimental::EventEngine* schedule_on_event_engine); +}; + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H diff --git a/test/core/transport/chaotic_good/server_transport_test.cc b/test/core/transport/chaotic_good/server_transport_test.cc new file mode 100644 index 00000000000..05a32a21817 --- /dev/null +++ b/test/core/transport/chaotic_good/server_transport_test.cc @@ -0,0 +1,198 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/ext/transport/chaotic_good/server_transport.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" +#include "test/core/transport/chaotic_good/mock_promise_endpoint.h" +#include "test/core/transport/chaotic_good/transport_test.h" + +using testing::_; +using testing::MockFunction; +using testing::Return; +using testing::StrictMock; +using testing::WithArgs; + +using EventEngineSlice = grpc_event_engine::experimental::Slice; + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +// Encoded string of header ":path: /demo.Service/Step". +const uint8_t kPathDemoServiceStep[] = { + 0x40, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f, + 0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70}; + +// Encoded string of trailer "grpc-status: 0". +const uint8_t kGrpcStatus0[] = {0x40, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x01, 0x30}; + +ServerMetadataHandle TestInitialMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(HttpPathMetadata(), Slice::FromStaticString("/demo.Service/Step")); + return md; +} + +ServerMetadataHandle TestTrailingMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(GrpcStatusMetadata(), GRPC_STATUS_OK); + return md; +} + +class MockAcceptor : public ServerTransport::Acceptor { + public: + virtual ~MockAcceptor() = default; + MOCK_METHOD(Arena*, CreateArena, (), (override)); + MOCK_METHOD(absl::StatusOr, CreateCall, + (ClientMetadata & client_initial_metadata, Arena* arena), + (override)); +}; + +TEST_F(TransportTest, ReadAndWriteOneMessage) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + StrictMock acceptor; + auto transport = MakeOrphanable( + CoreConfiguration::Get() + .channel_args_preconditioning() + .PreconditionChannelArgs(nullptr), + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + // Once we set the acceptor, expect to read some frames. + // We'll return a new request with a payload of "12345678". + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 7, 1, 26, 8, 56, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + event_engine().get()); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr); + // Once that's read we'll create a new call + auto* call_arena = Arena::Create(1024, memory_allocator()); + CallInitiatorAndHandler call = MakeCall(event_engine().get(), call_arena); + EXPECT_CALL(acceptor, CreateArena).WillOnce(Return(call_arena)); + EXPECT_CALL(acceptor, CreateCall(_, call_arena)) + .WillOnce(WithArgs<0>([call_initiator = std::move(call.initiator)]( + ClientMetadata& client_initial_metadata) { + EXPECT_EQ(client_initial_metadata.get_pointer(HttpPathMetadata()) + ->as_string_view(), + "/demo.Service/Step"); + return call_initiator; + })); + transport->SetAcceptor(&acceptor); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + EXPECT_CALL(*control_endpoint.endpoint, Read) + .InSequence(control_endpoint.read_sequence) + .WillOnce(Return(false)); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 1, 1, + sizeof(kPathDemoServiceStep), 0, 0, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 8, 56, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("87654321"), Zeros(56)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, + sizeof(kGrpcStatus0)), + EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))}, + nullptr); + call.handler.SpawnInfallible( + "test-io", [&on_done, handler = call.handler]() mutable { + return Seq( + handler.PullClientInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_TRUE(md.ok()); + EXPECT_EQ( + md.value()->get_pointer(HttpPathMetadata())->as_string_view(), + "/demo.Service/Step"); + return Empty{}; + }, + handler.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678"); + return Empty{}; + }, + handler.PullMessage(), + [](NextResult msg) { + EXPECT_FALSE(msg.has_value()); + return Empty{}; + }, + handler.PushServerInitialMetadata(TestInitialMetadata()), + handler.PushMessage(Arena::MakePooled( + SliceBuffer(Slice::FromCopiedString("87654321")), 0)), + [handler]() mutable { + return handler.PushServerTrailingMetadata(TestTrailingMetadata()); + }, + [&on_done]() mutable { + on_done.Call(); + return Empty{}; + }); + }); + // Wait until ClientTransport's internal activities to finish. + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); +} + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + // Must call to create default EventEngine. + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/transport/chaotic_good/transport_test.cc b/test/core/transport/chaotic_good/transport_test.cc new file mode 100644 index 00000000000..b43098fa7c7 --- /dev/null +++ b/test/core/transport/chaotic_good/transport_test.cc @@ -0,0 +1,60 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/transport/chaotic_good/transport_test.h" + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +grpc_event_engine::experimental::Slice SerializedFrameHeader( + FrameType type, uint8_t flags, uint32_t stream_id, uint32_t header_length, + uint32_t message_length, uint32_t message_padding, + uint32_t trailer_length) { + uint8_t buffer[24] = {static_cast(type), + flags, + 0, + 0, + static_cast(stream_id), + static_cast(stream_id >> 8), + static_cast(stream_id >> 16), + static_cast(stream_id >> 24), + static_cast(header_length), + static_cast(header_length >> 8), + static_cast(header_length >> 16), + static_cast(header_length >> 24), + static_cast(message_length), + static_cast(message_length >> 8), + static_cast(message_length >> 16), + static_cast(message_length >> 24), + static_cast(message_padding), + static_cast(message_padding >> 8), + static_cast(message_padding >> 16), + static_cast(message_padding >> 24), + static_cast(trailer_length), + static_cast(trailer_length >> 8), + static_cast(trailer_length >> 16), + static_cast(trailer_length >> 24)}; + return grpc_event_engine::experimental::Slice::FromCopiedBuffer(buffer, 24); +} + +grpc_event_engine::experimental::Slice Zeros(uint32_t length) { + std::string zeros(length, 0); + return grpc_event_engine::experimental::Slice::FromCopiedBuffer(zeros.data(), + length); +} + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core diff --git a/test/core/transport/chaotic_good/transport_test.h b/test/core/transport/chaotic_good/transport_test.h new file mode 100644 index 00000000000..e70158bb8cf --- /dev/null +++ b/test/core/transport/chaotic_good/transport_test.h @@ -0,0 +1,67 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_TRANSPORT_TEST_H +#define GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_TRANSPORT_TEST_H + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +class TransportTest : public ::testing::Test { + protected: + const std::shared_ptr& + event_engine() { + return event_engine_; + } + + MemoryAllocator* memory_allocator() { return &allocator_; } + + private: + std::shared_ptr + event_engine_{ + std::make_shared( + []() { + grpc_timer_manager_set_threading(false); + grpc_event_engine::experimental::FuzzingEventEngine::Options + options; + return options; + }(), + fuzzing_event_engine::Actions())}; + MemoryAllocator allocator_ = MakeResourceQuota("test-quota") + ->memory_quota() + ->CreateMemoryAllocator("test-allocator"); +}; + +grpc_event_engine::experimental::Slice SerializedFrameHeader( + FrameType type, uint8_t flags, uint32_t stream_id, uint32_t header_length, + uint32_t message_length, uint32_t message_padding, uint32_t trailer_length); + +grpc_event_engine::experimental::Slice Zeros(uint32_t length); + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_TRANSPORT_TEST_H diff --git a/test/core/transport/promise_endpoint_test.cc b/test/core/transport/promise_endpoint_test.cc index 760a2add6c7..e6ad3dd2713 100644 --- a/test/core/transport/promise_endpoint_test.cc +++ b/test/core/transport/promise_endpoint_test.cc @@ -524,6 +524,19 @@ TEST_F(PromiseEndpointTest, OneWriteSuccessful) { activity.Activate(); EXPECT_CALL(activity, WakeupRequested).Times(0); EXPECT_CALL(mock_endpoint_, Write).WillOnce(Return(true)); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); + auto poll = promise(); + ASSERT_TRUE(poll.ready()); + EXPECT_EQ(absl::OkStatus(), poll.value()); + activity.Deactivate(); +} + +TEST_F(PromiseEndpointTest, EmptyWriteIsNoOp) { + MockActivity activity; + activity.Activate(); + EXPECT_CALL(activity, WakeupRequested).Times(0); + EXPECT_CALL(mock_endpoint_, Write).Times(0); auto promise = promise_endpoint_->Write(SliceBuffer()); auto poll = promise(); ASSERT_TRUE(poll.ready()); @@ -541,7 +554,8 @@ TEST_F(PromiseEndpointTest, OneWriteFailed) { on_write(this->kDummyErrorStatus); return false; })); - auto promise = promise_endpoint_->Write(SliceBuffer()); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); auto poll = promise(); ASSERT_TRUE(poll.ready()); EXPECT_EQ(kDummyErrorStatus, poll.value()); @@ -564,7 +578,8 @@ TEST_F(PromiseEndpointTest, OnePendingWriteSuccessful) { // Return false to mock EventEngine write pending.. return false; })); - auto promise = promise_endpoint_->Write(SliceBuffer()); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); EXPECT_TRUE(promise().pending()); // Mock EventEngine write succeeds, and promise resolves. write_callback(absl::OkStatus()); @@ -586,7 +601,8 @@ TEST_F(PromiseEndpointTest, OnePendingWriteFailed) { // Return false to mock EventEngine write pending.. return false; })); - auto promise = promise_endpoint_->Write(SliceBuffer()); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); EXPECT_TRUE(promise().pending()); write_callback(kDummyErrorStatus); auto poll = promise(); @@ -807,8 +823,10 @@ TEST_F(MultiplePromiseEndpointTest, JoinWritesSuccessful) { EXPECT_CALL(on_done, Call(absl::OkStatus())); auto activity = MakeActivity( [this] { - return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer()), - this->second_promise_endpoint_.Write(SliceBuffer())), + return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world"))), + this->second_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world")))), [](std::tuple ret) { // Both writes finish with `absl::OkStatus`. EXPECT_TRUE(std::get<0>(ret).ok()); @@ -832,8 +850,10 @@ TEST_F(MultiplePromiseEndpointTest, JoinOneWriteSuccessfulOneWriteFailed) { EXPECT_CALL(on_done, Call(kDummyErrorStatus)); auto activity = MakeActivity( [this] { - return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer()), - this->second_promise_endpoint_.Write(SliceBuffer())), + return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world"))), + this->second_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world")))), [this](std::tuple ret) { // One write finish with `absl::OkStatus` and the other // write fails. @@ -864,8 +884,10 @@ TEST_F(MultiplePromiseEndpointTest, JoinWritesFailed) { EXPECT_CALL(on_done, Call(kDummyErrorStatus)); auto activity = MakeActivity( [this] { - return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer()), - this->second_promise_endpoint_.Write(SliceBuffer())), + return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world"))), + this->second_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world")))), [this](std::tuple ret) { // Both writes fail with errors. EXPECT_FALSE(std::get<0>(ret).ok()); diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 7981a98a90b..ae3fad4bc33 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -9071,6 +9071,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "server_transport_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,