From 86522af60db2678ed522c95b7dbb9e86c7a5481c Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Mon, 8 Jan 2024 09:43:16 -0800 Subject: [PATCH] [chaotic-good] Client & server transport (#35400) Adapts work from https://github.com/grpc/grpc/pull/34728 and previous changes from @nanahpang, implements new v3 filter/transport interface, and brings up the core of the chaotic good transport. The next change will bring a more complete test suite (for this and inproc). Closes #35400 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/35400 from ctiller:v3-server 737ca5431a44b3ac8d65eb738fd11179ecca197b PiperOrigin-RevId: 596621152 --- CMakeLists.txt | 52 ++ build_autogenerated.yaml | 44 +- .../grpc/event_engine/internal/slice_cast.h | 12 + include/grpc/event_engine/slice.h | 5 + src/core/BUILD | 85 +++ .../chaotic_good/chaotic_good_transport.cc | 19 + .../chaotic_good/chaotic_good_transport.h | 111 ++++ .../chaotic_good/client_transport.cc | 325 +++++---- .../transport/chaotic_good/client_transport.h | 202 ++---- src/core/ext/transport/chaotic_good/frame.cc | 165 +++-- src/core/ext/transport/chaotic_good/frame.h | 83 ++- .../transport/chaotic_good/frame_header.cc | 11 +- .../ext/transport/chaotic_good/frame_header.h | 2 +- .../chaotic_good/server_transport.cc | 332 ++++++++++ .../transport/chaotic_good/server_transport.h | 145 ++++ .../ext/transport/inproc/inproc_transport.cc | 8 +- src/core/lib/gprpp/debug_location.h | 9 + src/core/lib/promise/detail/status.h | 45 +- .../promise/event_engine_wakeup_scheduler.h | 4 +- src/core/lib/promise/if.h | 4 + src/core/lib/promise/inter_activity_pipe.h | 10 +- src/core/lib/promise/mpsc.h | 22 +- src/core/lib/promise/status_flag.h | 24 + src/core/lib/promise/try_join.h | 6 +- src/core/lib/promise/try_seq.h | 29 +- src/core/lib/resource_quota/arena.h | 4 +- src/core/lib/slice/slice_buffer.h | 3 + src/core/lib/surface/call.cc | 13 +- src/core/lib/surface/call.h | 3 +- src/core/lib/surface/server.cc | 23 +- src/core/lib/surface/server.h | 6 +- src/core/lib/transport/promise_endpoint.h | 36 +- src/core/lib/transport/transport.cc | 14 +- src/core/lib/transport/transport.h | 151 ++++- test/core/promise/mpsc_test.cc | 2 + test/core/transport/chaotic_good/BUILD | 82 ++- .../client_transport_error_test.cc | 549 ++++++++-------- .../chaotic_good/client_transport_test.cc | 618 ++++++------------ .../transport/chaotic_good/frame_fuzzer.cc | 63 +- .../transport/chaotic_good/frame_fuzzer.proto | 23 + .../frame_fuzzer_corpus/5072496117219328 | Bin 26 -> 0 bytes .../frame_fuzzer_corpus/5691448031772672 | Bin 51 -> 0 bytes ...h-05c704327d21af2cc914de40e9d90d06f16ca0eb | Bin 74 -> 0 bytes ...h-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a | Bin 66 -> 0 bytes ...h-5a34978de8de6889ce913947a77f43f7cdea854c | Bin 180 -> 0 bytes ...h-608f798a51077a8cdc45b11f335c079a81339fbe | Bin 166 -> 0 bytes ...h-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff | Bin 161 -> 0 bytes ...h-7732ddd35a4deb8b7c9e462aaf8680986755e540 | Bin 79 -> 0 bytes ...h-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 | Bin 70 -> 0 bytes .../chaotic_good/frame_fuzzer_corpus/empty | 1 + .../chaotic_good/frame_header_test.cc | 12 +- .../core/transport/chaotic_good/frame_test.cc | 31 +- .../chaotic_good/mock_promise_endpoint.cc | 89 +++ .../chaotic_good/mock_promise_endpoint.h | 77 +++ .../chaotic_good/server_transport_test.cc | 198 ++++++ .../transport/chaotic_good/transport_test.cc | 60 ++ .../transport/chaotic_good/transport_test.h | 67 ++ test/core/transport/promise_endpoint_test.cc | 40 +- tools/run_tests/generated/tests.json | 24 + 59 files changed, 2708 insertions(+), 1235 deletions(-) create mode 100644 src/core/ext/transport/chaotic_good/chaotic_good_transport.cc create mode 100644 src/core/ext/transport/chaotic_good/chaotic_good_transport.h create mode 100644 src/core/ext/transport/chaotic_good/server_transport.cc create mode 100644 src/core/ext/transport/chaotic_good/server_transport.h create mode 100644 test/core/transport/chaotic_good/frame_fuzzer.proto delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328 delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672 delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540 delete mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 create mode 100644 test/core/transport/chaotic_good/frame_fuzzer_corpus/empty create mode 100644 test/core/transport/chaotic_good/mock_promise_endpoint.cc create mode 100644 test/core/transport/chaotic_good/mock_promise_endpoint.h create mode 100644 test/core/transport/chaotic_good/server_transport_test.cc create mode 100644 test/core/transport/chaotic_good/transport_test.cc create mode 100644 test/core/transport/chaotic_good/transport_test.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 17d1a7410a8..5d178f90137 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1326,6 +1326,7 @@ if(gRPC_BUILD_TESTS) endif() add_dependencies(buildtests_cxx server_streaming_test) add_dependencies(buildtests_cxx server_test) + add_dependencies(buildtests_cxx server_transport_test) add_dependencies(buildtests_cxx service_config_end2end_test) add_dependencies(buildtests_cxx service_config_test) add_dependencies(buildtests_cxx settings_timeout_test) @@ -9533,6 +9534,7 @@ add_executable(client_transport_error_test ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h + src/core/ext/transport/chaotic_good/chaotic_good_transport.cc src/core/ext/transport/chaotic_good/client_transport.cc src/core/ext/transport/chaotic_good/frame.cc src/core/ext/transport/chaotic_good/frame_header.cc @@ -9563,6 +9565,7 @@ target_include_directories(client_transport_error_test target_link_libraries(client_transport_error_test ${_gRPC_ALLTARGETS_LIBRARIES} gtest + grpc_unsecure ${_gRPC_PROTOBUF_LIBRARIES} grpc_test_util ) @@ -9576,12 +9579,15 @@ add_executable(client_transport_test ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h + src/core/ext/transport/chaotic_good/chaotic_good_transport.cc src/core/ext/transport/chaotic_good/client_transport.cc src/core/ext/transport/chaotic_good/frame.cc src/core/ext/transport/chaotic_good/frame_header.cc src/core/lib/transport/promise_endpoint.cc test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc test/core/transport/chaotic_good/client_transport_test.cc + test/core/transport/chaotic_good/mock_promise_endpoint.cc + test/core/transport/chaotic_good/transport_test.cc ) target_compile_features(client_transport_test PUBLIC cxx_std_14) target_include_directories(client_transport_test @@ -22380,6 +22386,52 @@ target_link_libraries(server_test ) +endif() +if(gRPC_BUILD_TESTS) + +add_executable(server_transport_test + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h + ${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h + src/core/ext/transport/chaotic_good/chaotic_good_transport.cc + src/core/ext/transport/chaotic_good/frame.cc + src/core/ext/transport/chaotic_good/frame_header.cc + src/core/ext/transport/chaotic_good/server_transport.cc + src/core/lib/transport/promise_endpoint.cc + test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc + test/core/transport/chaotic_good/mock_promise_endpoint.cc + test/core/transport/chaotic_good/server_transport_test.cc + test/core/transport/chaotic_good/transport_test.cc +) +target_compile_features(server_transport_test PUBLIC cxx_std_14) +target_include_directories(server_transport_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} + ${_gRPC_RE2_INCLUDE_DIR} + ${_gRPC_SSL_INCLUDE_DIR} + ${_gRPC_UPB_GENERATED_DIR} + ${_gRPC_UPB_GRPC_GENERATED_DIR} + ${_gRPC_UPB_INCLUDE_DIR} + ${_gRPC_XXHASH_INCLUDE_DIR} + ${_gRPC_ZLIB_INCLUDE_DIR} + third_party/googletest/googletest/include + third_party/googletest/googletest + third_party/googletest/googlemock/include + third_party/googletest/googlemock + ${_gRPC_PROTO_GENS_DIR} +) + +target_link_libraries(server_transport_test + ${_gRPC_ALLTARGETS_LIBRARIES} + gtest + ${_gRPC_PROTOBUF_LIBRARIES} + grpc_test_util +) + + endif() if(gRPC_BUILD_TESTS) diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index b1492ba7068..46cd87ce1d7 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -7646,6 +7646,7 @@ targets: build: test language: c++ headers: + - src/core/ext/transport/chaotic_good/chaotic_good_transport.h - src/core/ext/transport/chaotic_good/client_transport.h - src/core/ext/transport/chaotic_good/frame.h - src/core/ext/transport/chaotic_good/frame_header.h @@ -7658,6 +7659,7 @@ targets: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h src: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto + - src/core/ext/transport/chaotic_good/chaotic_good_transport.cc - src/core/ext/transport/chaotic_good/client_transport.cc - src/core/ext/transport/chaotic_good/frame.cc - src/core/ext/transport/chaotic_good/frame_header.cc @@ -7666,6 +7668,7 @@ targets: - test/core/transport/chaotic_good/client_transport_error_test.cc deps: - gtest + - grpc_unsecure - protobuf - grpc_test_util uses_polling: false @@ -7674,24 +7677,29 @@ targets: build: test language: c++ headers: + - src/core/ext/transport/chaotic_good/chaotic_good_transport.h - src/core/ext/transport/chaotic_good/client_transport.h - src/core/ext/transport/chaotic_good/frame.h - src/core/ext/transport/chaotic_good/frame_header.h - src/core/lib/promise/event_engine_wakeup_scheduler.h - src/core/lib/promise/inter_activity_pipe.h - - src/core/lib/promise/join.h - src/core/lib/promise/mpsc.h - src/core/lib/promise/wait_set.h - src/core/lib/transport/promise_endpoint.h - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h + - test/core/transport/chaotic_good/mock_promise_endpoint.h + - test/core/transport/chaotic_good/transport_test.h src: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto + - src/core/ext/transport/chaotic_good/chaotic_good_transport.cc - src/core/ext/transport/chaotic_good/client_transport.cc - src/core/ext/transport/chaotic_good/frame.cc - src/core/ext/transport/chaotic_good/frame_header.cc - src/core/lib/transport/promise_endpoint.cc - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc - test/core/transport/chaotic_good/client_transport_test.cc + - test/core/transport/chaotic_good/mock_promise_endpoint.cc + - test/core/transport/chaotic_good/transport_test.cc deps: - gtest - protobuf @@ -15568,6 +15576,40 @@ targets: deps: - gtest - grpc_test_util +- name: server_transport_test + gtest: true + build: test + language: c++ + headers: + - src/core/ext/transport/chaotic_good/chaotic_good_transport.h + - src/core/ext/transport/chaotic_good/frame.h + - src/core/ext/transport/chaotic_good/frame_header.h + - src/core/ext/transport/chaotic_good/server_transport.h + - src/core/lib/promise/event_engine_wakeup_scheduler.h + - src/core/lib/promise/inter_activity_pipe.h + - src/core/lib/promise/mpsc.h + - src/core/lib/promise/switch.h + - src/core/lib/promise/wait_set.h + - src/core/lib/transport/promise_endpoint.h + - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h + - test/core/transport/chaotic_good/mock_promise_endpoint.h + - test/core/transport/chaotic_good/transport_test.h + src: + - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto + - src/core/ext/transport/chaotic_good/chaotic_good_transport.cc + - src/core/ext/transport/chaotic_good/frame.cc + - src/core/ext/transport/chaotic_good/frame_header.cc + - src/core/ext/transport/chaotic_good/server_transport.cc + - src/core/lib/transport/promise_endpoint.cc + - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc + - test/core/transport/chaotic_good/mock_promise_endpoint.cc + - test/core/transport/chaotic_good/server_transport_test.cc + - test/core/transport/chaotic_good/transport_test.cc + deps: + - gtest + - protobuf + - grpc_test_util + uses_polling: false - name: service_config_end2end_test gtest: true build: test diff --git a/include/grpc/event_engine/internal/slice_cast.h b/include/grpc/event_engine/internal/slice_cast.h index 8bcca60a24a..3f9593464cf 100644 --- a/include/grpc/event_engine/internal/slice_cast.h +++ b/include/grpc/event_engine/internal/slice_cast.h @@ -60,6 +60,18 @@ Result& SliceCast(T& value, SliceCastable = {}) { return reinterpret_cast(value); } +// Cast to `Result&&` from `T&&` without any runtime checks. +// This is only valid if `sizeof(Result) == sizeof(T)`, and if `Result`, `T` are +// opted in as compatible via `SliceCastable`. +template +Result&& SliceCast(T&& value, SliceCastable = {}) { + // Insist upon sizes being equal to catch mismatches. + // We assume if sizes are opted in and sizes are equal then yes, these two + // types are expected to be layout compatible and actually appear to be. + static_assert(sizeof(Result) == sizeof(T), "size mismatch"); + return reinterpret_cast(value); +} + } // namespace internal } // namespace experimental } // namespace grpc_event_engine diff --git a/include/grpc/event_engine/slice.h b/include/grpc/event_engine/slice.h index 8d49f391600..ce7693f6489 100644 --- a/include/grpc/event_engine/slice.h +++ b/include/grpc/event_engine/slice.h @@ -169,6 +169,11 @@ struct CopyConstructors { return Out(grpc_slice_from_copied_buffer(p, len)); } + static Out FromCopiedBuffer(const uint8_t* p, size_t len) { + return Out( + grpc_slice_from_copied_buffer(reinterpret_cast(p), len)); + } + template static Out FromCopiedBuffer(const Buffer& buffer) { return FromCopiedBuffer(reinterpret_cast(buffer.data()), diff --git a/src/core/BUILD b/src/core/BUILD index fcb11452d18..8ee2f308b86 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -6213,6 +6213,7 @@ grpc_cc_library( "bitset", "chaotic_good_frame_header", "context", + "match", "no_destruct", "slice", "slice_buffer", @@ -6368,6 +6369,29 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "chaotic_good_transport", + srcs = [ + "ext/transport/chaotic_good/chaotic_good_transport.cc", + ], + hdrs = [ + "ext/transport/chaotic_good/chaotic_good_transport.h", + ], + external_deps = ["absl/random"], + language = "c++", + deps = [ + "chaotic_good_frame", + "chaotic_good_frame_header", + "grpc_promise_endpoint", + "if", + "try_join", + "try_seq", + "//:gpr_platform", + "//:hpack_encoder", + "//:promise", + ], +) + grpc_cc_library( name = "chaotic_good_client_transport", srcs = [ @@ -6378,6 +6402,7 @@ grpc_cc_library( ], external_deps = [ "absl/base:core_headers", + "absl/container:flat_hash_map", "absl/random", "absl/random:bit_gen_ref", "absl/status", @@ -6388,9 +6413,11 @@ grpc_cc_library( language = "c++", deps = [ "activity", + "all_ok", "arena", "chaotic_good_frame", "chaotic_good_frame_header", + "chaotic_good_transport", "context", "event_engine_wakeup_scheduler", "for_each", @@ -6398,6 +6425,7 @@ grpc_cc_library( "if", "inter_activity_pipe", "loop", + "map", "match", "memory_quota", "mpsc", @@ -6414,6 +6442,63 @@ grpc_cc_library( "//:grpc_base", "//:hpack_encoder", "//:hpack_parser", + "//:promise", + "//:ref_counted_ptr", + ], +) + +grpc_cc_library( + name = "chaotic_good_server_transport", + srcs = [ + "ext/transport/chaotic_good/server_transport.cc", + ], + hdrs = [ + "ext/transport/chaotic_good/server_transport.h", + ], + external_deps = [ + "absl/base:core_headers", + "absl/container:flat_hash_map", + "absl/functional:any_invocable", + "absl/random", + "absl/random:bit_gen_ref", + "absl/status", + "absl/status:statusor", + "absl/types:optional", + "absl/types:variant", + ], + language = "c++", + deps = [ + "1999", + "activity", + "arena", + "chaotic_good_frame", + "chaotic_good_frame_header", + "chaotic_good_transport", + "context", + "default_event_engine", + "event_engine_wakeup_scheduler", + "for_each", + "grpc_promise_endpoint", + "if", + "inter_activity_pipe", + "loop", + "memory_quota", + "mpsc", + "pipe", + "poll", + "resource_quota", + "seq", + "slice", + "slice_buffer", + "switch", + "try_join", + "try_seq", + "//:exec_ctx", + "//:gpr", + "//:gpr_platform", + "//:grpc_base", + "//:hpack_encoder", + "//:hpack_parser", "//:ref_counted_ptr", ], ) diff --git a/src/core/ext/transport/chaotic_good/chaotic_good_transport.cc b/src/core/ext/transport/chaotic_good/chaotic_good_transport.cc new file mode 100644 index 00000000000..163f994d35f --- /dev/null +++ b/src/core/ext/transport/chaotic_good/chaotic_good_transport.cc @@ -0,0 +1,19 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" + +namespace grpc_core {} // namespace grpc_core diff --git a/src/core/ext/transport/chaotic_good/chaotic_good_transport.h b/src/core/ext/transport/chaotic_good/chaotic_good_transport.h new file mode 100644 index 00000000000..1096486bbea --- /dev/null +++ b/src/core/ext/transport/chaotic_good/chaotic_good_transport.h @@ -0,0 +1,111 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H +#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H + +#include + +#include "absl/random/random.h" + +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/ext/transport/chaotic_good/frame_header.h" +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/lib/promise/if.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/transport/promise_endpoint.h" + +namespace grpc_core { +namespace chaotic_good { + +class ChaoticGoodTransport { + public: + ChaoticGoodTransport(std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint) + : control_endpoint_(std::move(control_endpoint)), + data_endpoint_(std::move(data_endpoint)) {} + + auto WriteFrame(const FrameInterface& frame) { + auto buffers = frame.Serialize(&encoder_); + return TryJoin( + control_endpoint_->Write(std::move(buffers.control)), + data_endpoint_->Write(std::move(buffers.data))); + } + + // Read frame header and payloads for control and data portions of one frame. + // Resolves to StatusOr>. + auto ReadFrameBytes() { + return TrySeq( + control_endpoint_->ReadSlice(FrameHeader::frame_header_size_), + [this](Slice read_buffer) { + auto frame_header = + FrameHeader::Parse(reinterpret_cast( + GRPC_SLICE_START_PTR(read_buffer.c_slice()))); + // Read header and trailers from control endpoint. + // Read message padding and message from data endpoint. + return If( + frame_header.ok(), + [this, &frame_header] { + const uint32_t message_padding = std::exchange( + last_message_padding_, frame_header->message_padding); + const uint32_t message_length = frame_header->message_length; + return Map( + TryJoin( + control_endpoint_->Read(frame_header->GetFrameLength()), + TrySeq(data_endpoint_->Read(message_padding), + [this, message_length]() { + return data_endpoint_->Read(message_length); + })), + [frame_header = *frame_header]( + absl::StatusOr> + buffers) + -> absl::StatusOr> { + if (!buffers.ok()) return buffers.status(); + return std::tuple( + frame_header, + BufferPair{std::move(std::get<0>(*buffers)), + std::move(std::get<1>(*buffers))}); + }); + }, + [&frame_header]() + -> absl::StatusOr> { + return frame_header.status(); + }); + }); + } + + absl::Status DeserializeFrame(FrameHeader header, BufferPair buffers, + Arena* arena, FrameInterface& frame) { + return frame.Deserialize(&parser_, header, bitgen_, arena, + std::move(buffers)); + } + + // Skip a frame, but correctly handle any hpack state updates. + void SkipFrame(FrameHeader, BufferPair) { Crash("not implemented"); } + + private: + const std::unique_ptr control_endpoint_; + const std::unique_ptr data_endpoint_; + uint32_t last_message_padding_ = 0; + HPackCompressor encoder_; + HPackParser parser_; + absl::BitGen bitgen_; +}; + +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_CHAOTIC_GOOD_TRANSPORT_H diff --git a/src/core/ext/transport/chaotic_good/client_transport.cc b/src/core/ext/transport/chaotic_good/client_transport.cc index 24f30687182..d4075bcad1b 100644 --- a/src/core/ext/transport/chaotic_good/client_transport.cc +++ b/src/core/ext/transport/chaotic_good/client_transport.cc @@ -17,9 +17,11 @@ #include "src/core/ext/transport/chaotic_good/client_transport.h" #include +#include #include #include #include +#include #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" @@ -36,9 +38,13 @@ #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/all_ok.h" #include "src/core/lib/promise/event_engine_wakeup_scheduler.h" #include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/promise.h" #include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice.h" @@ -49,59 +55,15 @@ namespace grpc_core { namespace chaotic_good { -ClientTransport::ClientTransport( - std::unique_ptr control_endpoint, - std::unique_ptr data_endpoint, - std::shared_ptr event_engine) - : outgoing_frames_(MpscReceiver(4)), - control_endpoint_(std::move(control_endpoint)), - data_endpoint_(std::move(data_endpoint)), - control_endpoint_write_buffer_(SliceBuffer()), - data_endpoint_write_buffer_(SliceBuffer()), - hpack_compressor_(std::make_unique()), - hpack_parser_(std::make_unique()), - memory_allocator_( - ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( - "client_transport")), - arena_(MakeScopedArena(1024, &memory_allocator_)), - context_(arena_.get()), - event_engine_(event_engine) { - auto write_loop = Loop([this] { +auto ChaoticGoodClientTransport::TransportWriteLoop() { + return Loop([this] { return TrySeq( // Get next outgoing frame. - this->outgoing_frames_.Next(), - // Construct data buffers that will be sent to the endpoints. + outgoing_frames_.Next(), + // Serialize and write it out. [this](ClientFrame client_frame) { - MatchMutable( - &client_frame, - [this](ClientFragmentFrame* frame) mutable { - control_endpoint_write_buffer_.Append( - frame->Serialize(hpack_compressor_.get())); - if (frame->message != nullptr) { - std::string message_padding(frame->message_padding, '0'); - Slice slice(grpc_slice_from_cpp_string(message_padding)); - // Append message padding to data_endpoint_buffer. - data_endpoint_write_buffer_.Append(std::move(slice)); - // Append message payload to data_endpoint_buffer. - frame->message->payload()->MoveFirstNBytesIntoSliceBuffer( - frame->message->payload()->Length(), - data_endpoint_write_buffer_); - } - }, - [this](CancelFrame* frame) mutable { - control_endpoint_write_buffer_.Append( - frame->Serialize(hpack_compressor_.get())); - }); - return absl::OkStatus(); + return transport_.WriteFrame(GetFrameInterface(client_frame)); }, - // Write buffers to corresponding endpoints concurrently. - [this]() { - return TryJoin( - control_endpoint_->Write( - std::move(control_endpoint_write_buffer_)), - data_endpoint_->Write(std::move(data_endpoint_write_buffer_))); - }, - // Finish writes to difference endpoints and continue the loop. []() -> LoopCtl { // The write failures will be caught in TrySeq and exit loop. // Therefore, only need to return Continue() in the last lambda @@ -109,78 +71,215 @@ ClientTransport::ClientTransport( return Continue(); }); }); - writer_ = MakeActivity( - // Continuously write next outgoing frames to promise endpoints. - std::move(write_loop), EventEngineWakeupScheduler(event_engine_), - [this](absl::Status status) { - if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { - this->AbortWithError(); - } +} + +absl::optional ChaoticGoodClientTransport::LookupStream( + uint32_t stream_id) { + MutexLock lock(&mu_); + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + return absl::nullopt; + } + return it->second; +} + +auto ChaoticGoodClientTransport::PushFrameIntoCall(ServerFragmentFrame frame, + CallHandler call_handler) { + auto& headers = frame.headers; + return TrySeq( + If( + headers != nullptr, + [call_handler, &headers]() mutable { + return call_handler.PushServerInitialMetadata(std::move(headers)); + }, + []() -> StatusFlag { return Success{}; }), + [call_handler, message = std::move(frame.message)]() mutable { + return If( + message.has_value(), + [&call_handler, &message]() mutable { + return call_handler.PushMessage(std::move(message->message)); + }, + []() -> StatusFlag { return Success{}; }); }, - // Hold Arena in activity for GetContext usage. - arena_.get()); - auto read_loop = Loop([this] { + [call_handler, trailers = std::move(frame.trailers)]() mutable { + return If( + trailers != nullptr, + [&call_handler, &trailers]() mutable { + return call_handler.PushServerTrailingMetadata( + std::move(trailers)); + }, + []() -> StatusFlag { return Success{}; }); + }); +} + +auto ChaoticGoodClientTransport::TransportReadLoop() { + return Loop([this] { return TrySeq( - // Read frame header from control endpoint. - // TODO(ladynana): remove memcpy in ReadSlice. - this->control_endpoint_->ReadSlice(FrameHeader::frame_header_size_), - // Read different parts of the server frame from control/data endpoints - // based on frame header. - [this](Slice read_buffer) mutable { - frame_header_ = std::make_shared( - FrameHeader::Parse( - reinterpret_cast( - GRPC_SLICE_START_PTR(read_buffer.c_slice()))) - .value()); - // Read header and trailers from control endpoint. - // Read message padding and message from data endpoint. - return TryJoin( - control_endpoint_->Read(frame_header_->GetFrameLength()), - data_endpoint_->Read(frame_header_->message_padding + - frame_header_->message_length)); + transport_.ReadFrameBytes(), + [](std::tuple frame_bytes) + -> absl::StatusOr> { + const auto& frame_header = std::get<0>(frame_bytes); + if (frame_header.type != FrameType::kFragment) { + return absl::InternalError( + absl::StrCat("Expected fragment frame, got ", + static_cast(frame_header.type))); + } + return frame_bytes; }, - // Construct and send the server frame to corresponding stream. - [this](std::tuple ret) mutable { - control_endpoint_read_buffer_ = std::move(std::get<0>(ret)); - // Discard message padding and only keep message in data read buffer. - std::get<1>(ret).MoveLastNBytesIntoSliceBuffer( - frame_header_->message_length, data_endpoint_read_buffer_); + [this](std::tuple frame_bytes) { + const auto& frame_header = std::get<0>(frame_bytes); + auto& buffers = std::get<1>(frame_bytes); + absl::optional call_handler = + LookupStream(frame_header.stream_id); ServerFragmentFrame frame; - // Initialized to get this_cpu() info in global_stat(). - ExecCtx exec_ctx; - // Deserialize frame from read buffer. - absl::BitGen bitgen; - auto status = frame.Deserialize(hpack_parser_.get(), *frame_header_, - absl::BitGenRef(bitgen), - control_endpoint_read_buffer_); - GPR_ASSERT(status.ok()); - // Move message into frame. - frame.message = arena_->MakePooled( - std::move(data_endpoint_read_buffer_), 0); - MutexLock lock(&mu_); - const uint32_t stream_id = frame_header_->stream_id; - return stream_map_[stream_id]->Push(ServerFrame(std::move(frame))); - }, - // Check if send frame to corresponding stream successfully. - [](bool ret) -> LoopCtl { - if (ret) { - // Send incoming frames successfully. - return Continue(); + absl::Status deserialize_status; + if (call_handler.has_value()) { + deserialize_status = transport_.DeserializeFrame( + frame_header, std::move(buffers), call_handler->arena(), frame); } else { - return absl::InternalError("Send incoming frames failed."); + // Stream not found, skip the frame. + transport_.SkipFrame(frame_header, std::move(buffers)); + deserialize_status = absl::OkStatus(); } - }); + return If( + deserialize_status.ok() && call_handler.has_value(), + [this, &frame, &call_handler]() { + return call_handler->SpawnWaitable( + "push-frame", [this, call_handler = *call_handler, + frame = std::move(frame)]() mutable { + return Map(call_handler.CancelIfFails(PushFrameIntoCall( + std::move(frame), call_handler)), + [](StatusFlag f) { + return StatusCast(f); + }); + }); + }, + [&deserialize_status]() -> absl::Status { + // Stream not found, nothing to do. + return std::move(deserialize_status); + }); + }, + []() -> LoopCtl { return Continue{}; }); }); - reader_ = MakeActivity( - // Continuously read next incoming frames from promise endpoints. - std::move(read_loop), EventEngineWakeupScheduler(event_engine_), - [this](absl::Status status) { - if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { - this->AbortWithError(); - } +} + +auto ChaoticGoodClientTransport::OnTransportActivityDone() { + return [this](absl::Status status) { + if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { + this->AbortWithError(); + } + }; +} + +ChaoticGoodClientTransport::ChaoticGoodClientTransport( + std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr event_engine) + : outgoing_frames_(4), + transport_(std::move(control_endpoint), std::move(data_endpoint)), + writer_{ + MakeActivity( + // Continuously write next outgoing frames to promise endpoints. + TransportWriteLoop(), EventEngineWakeupScheduler(event_engine), + OnTransportActivityDone()), + }, + reader_{MakeActivity( + // Continuously read next incoming frames from promise endpoints. + TransportReadLoop(), EventEngineWakeupScheduler(event_engine), + OnTransportActivityDone())} {} + +ChaoticGoodClientTransport::~ChaoticGoodClientTransport() { + if (writer_ != nullptr) { + writer_.reset(); + } + if (reader_ != nullptr) { + reader_.reset(); + } +} + +void ChaoticGoodClientTransport::AbortWithError() { + // Mark transport as unavailable when the endpoint write/read failed. + // Close all the available pipes. + outgoing_frames_.MarkClosed(); + ReleasableMutexLock lock(&mu_); + StreamMap stream_map = std::move(stream_map_); + stream_map_.clear(); + lock.Release(); + for (const auto& pair : stream_map) { + auto call_handler = pair.second; + call_handler.SpawnInfallible("cancel", [call_handler]() mutable { + call_handler.Cancel(ServerMetadataFromStatus( + absl::UnavailableError("Transport closed."))); + return Empty{}; + }); + } +} + +uint32_t ChaoticGoodClientTransport::MakeStream(CallHandler call_handler) { + ReleasableMutexLock lock(&mu_); + const uint32_t stream_id = next_stream_id_++; + stream_map_.emplace(stream_id, call_handler); + lock.Release(); + call_handler.OnDone([this, stream_id]() { + MutexLock lock(&mu_); + stream_map_.erase(stream_id); + }); + return stream_id; +} + +auto ChaoticGoodClientTransport::CallOutboundLoop(uint32_t stream_id, + CallHandler call_handler) { + auto send_fragment = [stream_id, + outgoing_frames = outgoing_frames_.MakeSender()]( + ClientFragmentFrame frame) mutable { + frame.stream_id = stream_id; + return Map(outgoing_frames.Send(std::move(frame)), + [](bool success) -> absl::Status { + if (!success) { + // Failed to send outgoing frame. + return absl::UnavailableError("Transport closed."); + } + return absl::OkStatus(); + }); + }; + return TrySeq( + // Wait for initial metadata then send it out. + call_handler.PullClientInitialMetadata(), + [send_fragment](ClientMetadataHandle md) mutable { + ClientFragmentFrame frame; + frame.headers = std::move(md); + return send_fragment(std::move(frame)); }, - // Hold Arena in activity for GetContext usage. - arena_.get()); + // Continuously send client frame with client to server messages. + ForEach(OutgoingMessages(call_handler), + [send_fragment, + aligned_bytes = aligned_bytes_](MessageHandle message) mutable { + ClientFragmentFrame frame; + // Construct frame header (flags, header_length and + // trailer_length will be added in serialization). + const uint32_t message_length = message->payload()->Length(); + const uint32_t padding = + message_length % aligned_bytes == 0 + ? 0 + : aligned_bytes - message_length % aligned_bytes; + GPR_ASSERT((message_length + padding) % aligned_bytes == 0); + frame.message = FragmentMessage(std::move(message), padding, + message_length); + return send_fragment(std::move(frame)); + }), + [send_fragment]() mutable { + ClientFragmentFrame frame; + frame.end_of_stream = true; + return send_fragment(std::move(frame)); + }); +} + +void ChaoticGoodClientTransport::StartCall(CallHandler call_handler) { + // At this point, the connection is set up. + // Start sending data frames. + call_handler.SpawnGuarded("outbound_loop", [this, call_handler]() mutable { + return CallOutboundLoop(MakeStream(call_handler), call_handler); + }); } } // namespace chaotic_good diff --git a/src/core/ext/transport/chaotic_good/client_transport.h b/src/core/ext/transport/chaotic_good/client_transport.h index 23ecdbfe84a..b8d515f9896 100644 --- a/src/core/ext/transport/chaotic_good/client_transport.h +++ b/src/core/ext/transport/chaotic_good/client_transport.h @@ -20,6 +20,7 @@ #include #include +#include #include // IWYU pragma: keep #include #include @@ -28,6 +29,8 @@ #include #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/random/random.h" #include "absl/status/status.h" #include "absl/types/optional.h" #include "absl/types/variant.h" @@ -35,6 +38,7 @@ #include #include +#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" #include "src/core/ext/transport/chaotic_good/frame.h" #include "src/core/ext/transport/chaotic_good/frame_header.h" #include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" @@ -61,178 +65,56 @@ namespace grpc_core { namespace chaotic_good { -class ClientTransport { +class ChaoticGoodClientTransport final : public Transport, + public ClientTransport { public: - ClientTransport(std::unique_ptr control_endpoint, - std::unique_ptr data_endpoint, - std::shared_ptr - event_engine); - ~ClientTransport() { - if (writer_ != nullptr) { - writer_.reset(); - } - if (reader_ != nullptr) { - reader_.reset(); - } - } - void AbortWithError() { - // Mark transport as unavailable when the endpoint write/read failed. - // Close all the available pipes. - if (!outgoing_frames_.IsClosed()) { - outgoing_frames_.MarkClosed(); - } - MutexLock lock(&mu_); - for (const auto& pair : stream_map_) { - if (!pair.second->IsClose()) { - pair.second->MarkClose(); - } - } - } - auto AddStream(CallArgs call_args) { - // At this point, the connection is set up. - // Start sending data frames. - uint32_t stream_id; - InterActivityPipe pipe_server_frames; - { - MutexLock lock(&mu_); - stream_id = next_stream_id_++; - stream_map_.insert( - std::pair::Sender>>( - stream_id, std::make_shared::Sender>( - std::move(pipe_server_frames.sender)))); - } - return TrySeq( - TryJoin( - // Continuously send client frame with client to server messages. - ForEach(std::move(*call_args.client_to_server_messages), - [stream_id, initial_frame = true, - client_initial_metadata = - std::move(call_args.client_initial_metadata), - outgoing_frames = outgoing_frames_.MakeSender(), - this](MessageHandle result) mutable { - ClientFragmentFrame frame; - // Construct frame header (flags, header_length and - // trailer_length will be added in serialization). - uint32_t message_length = result->payload()->Length(); - frame.stream_id = stream_id; - frame.message_padding = message_length % aligned_bytes; - frame.message = std::move(result); - if (initial_frame) { - // Send initial frame with client intial metadata. - frame.headers = std::move(client_initial_metadata); - initial_frame = false; - } - return TrySeq( - outgoing_frames.Send(ClientFrame(std::move(frame))), - [](bool success) -> absl::Status { - if (!success) { - // TODO(ladynana): propagate the actual error - // message from EventEngine. - return absl::UnavailableError( - "Transport closed due to endpoint write/read " - "failed."); - } - return absl::OkStatus(); - }); - }), - // Continuously receive server frames from endpoints and save - // results to call_args. - Loop([server_initial_metadata = call_args.server_initial_metadata, - server_to_client_messages = - call_args.server_to_client_messages, - receiver = std::move(pipe_server_frames.receiver)]() mutable { - return TrySeq( - // Receive incoming server frame. - receiver.Next(), - // Save incomming frame results to call_args. - [server_initial_metadata, server_to_client_messages]( - absl::optional server_frame) mutable { - bool transport_closed = false; - ServerFragmentFrame frame; - if (!server_frame.has_value()) { - // Incoming server frame pipe is closed, which only - // happens when transport is aborted. - transport_closed = true; - } else { - frame = std::move( - absl::get(*server_frame)); - }; - bool has_headers = (frame.headers != nullptr); - bool has_message = (frame.message != nullptr); - bool has_trailers = (frame.trailers != nullptr); - return TrySeq( - If((!transport_closed) && has_headers, - [server_initial_metadata, - headers = std::move(frame.headers)]() mutable { - return server_initial_metadata->Push( - std::move(headers)); - }, - [] { return false; }), - If((!transport_closed) && has_message, - [server_to_client_messages, - message = std::move(frame.message)]() mutable { - return server_to_client_messages->Push( - std::move(message)); - }, - [] { return false; }), - If((!transport_closed) && has_trailers, - [trailers = std::move(frame.trailers)]() mutable - -> LoopCtl { - return std::move(trailers); - }, - [transport_closed]() - -> LoopCtl { - if (transport_closed) { - // TODO(ladynana): propagate the actual error - // message from EventEngine. - return ServerMetadataFromStatus( - absl::UnavailableError( - "Transport closed due to endpoint " - "write/read failed.")); - } - return Continue(); - })); - }); - })), - [](std::tuple ret) { - return std::move(std::get<1>(ret)); - }); - } + ChaoticGoodClientTransport( + std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr + event_engine); + ~ChaoticGoodClientTransport() override; + + FilterStackTransport* filter_stack_transport() override { return nullptr; } + ClientTransport* client_transport() override { return this; } + ServerTransport* server_transport() override { return nullptr; } + absl::string_view GetTransportName() const override { return "chaotic_good"; } + void SetPollset(grpc_stream*, grpc_pollset*) override {} + void SetPollsetSet(grpc_stream*, grpc_pollset_set*) override {} + void PerformOp(grpc_transport_op*) override { Crash("unimplemented"); } + grpc_endpoint* GetEndpoint() override { return nullptr; } + void Orphan() override { delete this; } + + void StartCall(CallHandler call_handler) override; + void AbortWithError(); private: + // Queue size of each stream pipe is set to 2, so that for each stream read it + // will queue at most 2 frames. + static const size_t kServerFrameQueueSize = 2; + using StreamMap = absl::flat_hash_map; + + uint32_t MakeStream(CallHandler call_handler); + absl::optional LookupStream(uint32_t stream_id); + auto CallOutboundLoop(uint32_t stream_id, CallHandler call_handler); + auto OnTransportActivityDone(); + auto TransportWriteLoop(); + auto TransportReadLoop(); + // Push one frame into a call + auto PushFrameIntoCall(ServerFragmentFrame frame, CallHandler call_handler); + // Max buffer is set to 4, so that for stream writes each time it will queue // at most 2 frames. MpscReceiver outgoing_frames_; - // Queue size of each stream pipe is set to 2, so that for each stream read it - // will queue at most 2 frames. - static const size_t server_frame_queue_size_ = 2; + ChaoticGoodTransport transport_; // Assigned aligned bytes from setting frame. - size_t aligned_bytes = 64; + size_t aligned_bytes_ = 64; Mutex mu_; uint32_t next_stream_id_ ABSL_GUARDED_BY(mu_) = 1; // Map of stream incoming server frames, key is stream_id. - std::map::Sender>> - stream_map_ ABSL_GUARDED_BY(mu_); + StreamMap stream_map_ ABSL_GUARDED_BY(mu_); ActivityPtr writer_; ActivityPtr reader_; - std::unique_ptr control_endpoint_; - std::unique_ptr data_endpoint_; - SliceBuffer control_endpoint_write_buffer_; - SliceBuffer data_endpoint_write_buffer_; - SliceBuffer control_endpoint_read_buffer_; - SliceBuffer data_endpoint_read_buffer_; - std::unique_ptr hpack_compressor_; - std::unique_ptr hpack_parser_; - std::shared_ptr frame_header_; - MemoryAllocator memory_allocator_; - ScopedArenaPtr arena_; - promise_detail::Context context_; - // Use to synchronize writer_ and reader_ activity with outside activities; - std::shared_ptr event_engine_; }; } // namespace chaotic_good diff --git a/src/core/ext/transport/chaotic_good/frame.cc b/src/core/ext/transport/chaotic_good/frame.cc index f49fa4c4f3b..d2c0b3a4fc0 100644 --- a/src/core/ext/transport/chaotic_good/frame.cc +++ b/src/core/ext/transport/chaotic_good/frame.cc @@ -40,6 +40,10 @@ namespace grpc_core { namespace chaotic_good { +namespace { +const uint8_t kZeros[64] = {}; +} + namespace { const NoDestruct kZeroSlice{[] { // Frame header size is fixed to 24 bytes. @@ -50,53 +54,65 @@ const NoDestruct kZeroSlice{[] { class FrameSerializer { public: - explicit FrameSerializer(FrameType frame_type, uint32_t stream_id, - uint32_t message_padding) { - output_.AppendIndexed(kZeroSlice->Copy()); + explicit FrameSerializer(FrameType frame_type, uint32_t stream_id) { + output_.control.AppendIndexed(kZeroSlice->Copy()); header_.type = frame_type; header_.stream_id = stream_id; - header_.message_padding = message_padding; header_.flags.SetAll(false); } + // If called, must be called before AddTrailers, Finish. SliceBuffer& AddHeaders() { header_.flags.set(0); - return output_; + return output_.control; + } + + void AddMessage(const FragmentMessage& msg) { + header_.flags.set(1); + header_.message_length = msg.length; + header_.message_padding = msg.padding; + output_.data = msg.message->payload()->Copy(); + if (msg.padding != 0) { + output_.data.Append(Slice::FromStaticBuffer(kZeros, msg.padding)); + } } + // If called, must be called before Finish. SliceBuffer& AddTrailers() { - header_.flags.set(1); - header_.header_length = output_.Length() - FrameHeader::frame_header_size_; - return output_; + header_.flags.set(2); + header_.header_length = + output_.control.Length() - FrameHeader::frame_header_size_; + return output_.control; } - SliceBuffer Finish() { + BufferPair Finish() { // Calculate frame header_length or trailer_length if available. - if (header_.flags.is_set(1)) { + if (header_.flags.is_set(2)) { // Header length is already known in AddTrailers(). - header_.trailer_length = output_.Length() - header_.header_length - + header_.trailer_length = output_.control.Length() - + header_.header_length - FrameHeader::frame_header_size_; } else { if (header_.flags.is_set(0)) { // Calculate frame header length in Finish() since AddTrailers() isn't // called. header_.header_length = - output_.Length() - FrameHeader::frame_header_size_; + output_.control.Length() - FrameHeader::frame_header_size_; } } header_.Serialize( - GRPC_SLICE_START_PTR(output_.c_slice_buffer()->slices[0])); + GRPC_SLICE_START_PTR(output_.control.c_slice_buffer()->slices[0])); return std::move(output_); } private: FrameHeader header_; - SliceBuffer output_; + BufferPair output_; }; class FrameDeserializer { public: - FrameDeserializer(const FrameHeader& header, SliceBuffer& input) + FrameDeserializer(const FrameHeader& header, BufferPair& input) : header_(header), input_(input) {} const FrameHeader& header() const { return header_; } // If called, must be called before ReceiveTrailers, Finish. @@ -118,28 +134,27 @@ class FrameDeserializer { private: absl::StatusOr Take(uint32_t length) { if (length == 0) return SliceBuffer{}; - if (input_.Length() < length) { + if (input_.control.Length() < length) { return absl::InvalidArgumentError( "Frame too short (insufficient payload)"); } SliceBuffer out; - input_.MoveFirstNBytesIntoSliceBuffer(length, out); + input_.control.MoveFirstNBytesIntoSliceBuffer(length, out); return std::move(out); } FrameHeader header_; - SliceBuffer& input_; + BufferPair& input_; }; template absl::StatusOr> ReadMetadata( HPackParser* parser, absl::StatusOr maybe_slices, - uint32_t stream_id, bool is_header, bool is_client, - absl::BitGenRef bitsrc) { + uint32_t stream_id, bool is_header, bool is_client, absl::BitGenRef bitsrc, + Arena* arena) { if (!maybe_slices.ok()) return maybe_slices.status(); auto& slices = *maybe_slices; - auto arena = GetContext(); GPR_ASSERT(arena != nullptr); - Arena::PoolPtr metadata = arena->MakePooled(arena); + Arena::PoolPtr metadata = Arena::MakePooled(arena); parser->BeginFrame( metadata.get(), std::numeric_limits::max(), std::numeric_limits::max(), @@ -161,20 +176,23 @@ absl::StatusOr> ReadMetadata( } // namespace absl::Status SettingsFrame::Deserialize(HPackParser*, const FrameHeader& header, - absl::BitGenRef, - SliceBuffer& slice_buffer) { + absl::BitGenRef, Arena*, + BufferPair buffers) { if (header.type != FrameType::kSettings) { return absl::InvalidArgumentError("Expected settings frame"); } if (header.flags.any()) { return absl::InvalidArgumentError("Unexpected flags"); } - FrameDeserializer deserializer(header, slice_buffer); + if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError("Unexpected data"); + } + FrameDeserializer deserializer(header, buffers); return deserializer.Finish(); } -SliceBuffer SettingsFrame::Serialize(HPackCompressor*) const { - FrameSerializer serializer(FrameType::kSettings, 0, 0); +BufferPair SettingsFrame::Serialize(HPackCompressor*) const { + FrameSerializer serializer(FrameType::kSettings, 0); return serializer.Finish(); } @@ -183,19 +201,20 @@ std::string SettingsFrame::ToString() const { return "SettingsFrame{}"; } absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser, const FrameHeader& header, absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) { + Arena* arena, + BufferPair buffers) { if (header.stream_id == 0) { return absl::InvalidArgumentError("Expected non-zero stream id"); } stream_id = header.stream_id; - message_padding = header.message_padding; if (header.type != FrameType::kFragment) { return absl::InvalidArgumentError("Expected fragment frame"); } - FrameDeserializer deserializer(header, slice_buffer); + FrameDeserializer deserializer(header, buffers); if (header.flags.is_set(0)) { auto r = ReadMetadata(parser, deserializer.ReceiveHeaders(), - header.stream_id, true, true, bitsrc); + header.stream_id, true, true, bitsrc, + arena); if (!r.ok()) return r.status(); if (r.value() != nullptr) { headers = std::move(r.value()); @@ -205,8 +224,17 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser, "Unexpected non-zero header length", header.header_length)); } if (header.flags.is_set(1)) { + message = + FragmentMessage{Arena::MakePooled(std::move(buffers.data), 0), + header.message_padding, header.message_length}; + } else if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Unexpected non-zero message length ", buffers.data.Length())); + } + if (header.flags.is_set(2)) { if (header.trailer_length != 0) { - return absl::InvalidArgumentError("Unexpected trailer length"); + return absl::InvalidArgumentError( + absl::StrCat("Unexpected trailer length ", header.trailer_length)); } end_of_stream = true; } else { @@ -215,42 +243,53 @@ absl::Status ClientFragmentFrame::Deserialize(HPackParser* parser, return deserializer.Finish(); } -SliceBuffer ClientFragmentFrame::Serialize(HPackCompressor* encoder) const { +BufferPair ClientFragmentFrame::Serialize(HPackCompressor* encoder) const { GPR_ASSERT(stream_id != 0); - FrameSerializer serializer(FrameType::kFragment, stream_id, message_padding); + FrameSerializer serializer(FrameType::kFragment, stream_id); if (headers.get() != nullptr) { encoder->EncodeRawHeaders(*headers.get(), serializer.AddHeaders()); } + if (message.has_value()) { + serializer.AddMessage(message.value()); + } if (end_of_stream) { serializer.AddTrailers(); } return serializer.Finish(); } +std::string FragmentMessage::ToString() const { + std::string out = + absl::StrCat("FragmentMessage{length=", length, ", padding=", padding); + if (message.get() != nullptr) { + absl::StrAppend(&out, ", message=", message->DebugString().c_str()); + } + absl::StrAppend(&out, "}"); + return out; +} + std::string ClientFragmentFrame::ToString() const { return absl::StrCat( "ClientFragmentFrame{stream_id=", stream_id, ", headers=", headers.get() != nullptr ? headers->DebugString().c_str() : "nullptr", - ", message=", - message.get() != nullptr ? message->DebugString().c_str() : "nullptr", - ", message_padding=", message_padding, ", end_of_stream=", end_of_stream, - "}"); + ", message=", message.has_value() ? message->ToString().c_str() : "none", + ", end_of_stream=", end_of_stream, "}"); } absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser, const FrameHeader& header, absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) { + Arena* arena, + BufferPair buffers) { if (header.stream_id == 0) { return absl::InvalidArgumentError("Expected non-zero stream id"); } stream_id = header.stream_id; - message_padding = header.message_padding; - FrameDeserializer deserializer(header, slice_buffer); + FrameDeserializer deserializer(header, buffers); if (header.flags.is_set(0)) { - auto r = - ReadMetadata(parser, deserializer.ReceiveHeaders(), - header.stream_id, true, false, bitsrc); + auto r = ReadMetadata(parser, deserializer.ReceiveHeaders(), + header.stream_id, true, false, bitsrc, + arena); if (!r.ok()) return r.status(); if (r.value() != nullptr) { headers = std::move(r.value()); @@ -260,9 +299,16 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser, "Unexpected non-zero header length", header.header_length)); } if (header.flags.is_set(1)) { - auto r = - ReadMetadata(parser, deserializer.ReceiveTrailers(), - header.stream_id, false, false, bitsrc); + message.emplace(Arena::MakePooled(std::move(buffers.data), 0), + header.message_padding, header.message_length); + } else if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Unexpected non-zero message length", buffers.data.Length())); + } + if (header.flags.is_set(2)) { + auto r = ReadMetadata( + parser, deserializer.ReceiveTrailers(), header.stream_id, false, false, + bitsrc, arena); if (!r.ok()) return r.status(); if (r.value() != nullptr) { trailers = std::move(r.value()); @@ -274,12 +320,15 @@ absl::Status ServerFragmentFrame::Deserialize(HPackParser* parser, return deserializer.Finish(); } -SliceBuffer ServerFragmentFrame::Serialize(HPackCompressor* encoder) const { +BufferPair ServerFragmentFrame::Serialize(HPackCompressor* encoder) const { GPR_ASSERT(stream_id != 0); - FrameSerializer serializer(FrameType::kFragment, stream_id, message_padding); + FrameSerializer serializer(FrameType::kFragment, stream_id); if (headers.get() != nullptr) { encoder->EncodeRawHeaders(*headers.get(), serializer.AddHeaders()); } + if (message.has_value()) { + serializer.AddMessage(message.value()); + } if (trailers.get() != nullptr) { encoder->EncodeRawHeaders(*trailers.get(), serializer.AddTrailers()); } @@ -290,16 +339,15 @@ std::string ServerFragmentFrame::ToString() const { return absl::StrCat( "ServerFragmentFrame{stream_id=", stream_id, ", headers=", headers.get() != nullptr ? headers->DebugString().c_str() : "nullptr", - ", message=", - message.get() != nullptr ? message->DebugString().c_str() : "nullptr", - ", message_padding=", message_padding, ", trailers=", + ", message=", message.has_value() ? message->ToString().c_str() : "none", + ", trailers=", trailers.get() != nullptr ? trailers->DebugString().c_str() : "nullptr", "}"); } absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header, - absl::BitGenRef, - SliceBuffer& slice_buffer) { + absl::BitGenRef, Arena*, + BufferPair buffers) { if (header.type != FrameType::kCancel) { return absl::InvalidArgumentError("Expected cancel frame"); } @@ -309,14 +357,17 @@ absl::Status CancelFrame::Deserialize(HPackParser*, const FrameHeader& header, if (header.stream_id == 0) { return absl::InvalidArgumentError("Expected non-zero stream id"); } - FrameDeserializer deserializer(header, slice_buffer); + if (buffers.data.Length() != 0) { + return absl::InvalidArgumentError("Unexpected data"); + } + FrameDeserializer deserializer(header, buffers); stream_id = header.stream_id; return deserializer.Finish(); } -SliceBuffer CancelFrame::Serialize(HPackCompressor*) const { +BufferPair CancelFrame::Serialize(HPackCompressor*) const { GPR_ASSERT(stream_id != 0); - FrameSerializer serializer(FrameType::kCancel, stream_id, 0); + FrameSerializer serializer(FrameType::kCancel, stream_id); return serializer.Finish(); } diff --git a/src/core/ext/transport/chaotic_good/frame.h b/src/core/ext/transport/chaotic_good/frame.h index 529c89570c7..e7ccd6ee222 100644 --- a/src/core/ext/transport/chaotic_good/frame.h +++ b/src/core/ext/transport/chaotic_good/frame.h @@ -28,6 +28,7 @@ #include "src/core/ext/transport/chaotic_good/frame_header.h" #include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" #include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/lib/gprpp/match.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" #include "src/core/lib/transport/metadata_batch.h" @@ -36,20 +37,21 @@ namespace grpc_core { namespace chaotic_good { +struct BufferPair { + SliceBuffer control; + SliceBuffer data; +}; + class FrameInterface { public: virtual absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) = 0; - virtual SliceBuffer Serialize(HPackCompressor* encoder) const = 0; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) = 0; + virtual BufferPair Serialize(HPackCompressor* encoder) const = 0; virtual std::string ToString() const = 0; protected: - static bool EqVal(const Message& a, const Message& b) { - return a.payload()->JoinIntoString() == b.payload()->JoinIntoString() && - a.flags() == b.flags(); - } static bool EqVal(const grpc_metadata_batch& a, const grpc_metadata_batch& b) { return a.DebugString() == b.DebugString(); @@ -65,57 +67,75 @@ class FrameInterface { struct SettingsFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; bool operator==(const SettingsFrame&) const { return true; } }; +struct FragmentMessage { + FragmentMessage(MessageHandle message, uint32_t padding, uint32_t length) + : message(std::move(message)), padding(padding), length(length) {} + + MessageHandle message; + uint32_t padding; + uint32_t length; + + std::string ToString() const; + + static bool EqVal(const Message& a, const Message& b) { + return a.payload()->JoinIntoString() == b.payload()->JoinIntoString() && + a.flags() == b.flags(); + } + + bool operator==(const FragmentMessage& other) const { + return EqVal(*message, *other.message) && length == other.length; + } +}; + struct ClientFragmentFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; uint32_t stream_id; ClientMetadataHandle headers; - MessageHandle message; - uint32_t message_padding; + absl::optional message; bool end_of_stream = false; bool operator==(const ClientFragmentFrame& other) const { return stream_id == other.stream_id && EqHdl(headers, other.headers) && - end_of_stream == other.end_of_stream; + message == other.message && end_of_stream == other.end_of_stream; } }; struct ServerFragmentFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; uint32_t stream_id; ServerMetadataHandle headers; - MessageHandle message; - uint32_t message_padding; + absl::optional message; ServerMetadataHandle trailers; bool operator==(const ServerFragmentFrame& other) const { return stream_id == other.stream_id && EqHdl(headers, other.headers) && - EqHdl(trailers, other.trailers); + message == other.message && EqHdl(trailers, other.trailers); } }; struct CancelFrame final : public FrameInterface { absl::Status Deserialize(HPackParser* parser, const FrameHeader& header, - absl::BitGenRef bitsrc, - SliceBuffer& slice_buffer) override; - SliceBuffer Serialize(HPackCompressor* encoder) const override; + absl::BitGenRef bitsrc, Arena* arena, + BufferPair buffers) override; + BufferPair Serialize(HPackCompressor* encoder) const override; std::string ToString() const override; uint32_t stream_id; @@ -128,6 +148,19 @@ struct CancelFrame final : public FrameInterface { using ClientFrame = absl::variant; using ServerFrame = absl::variant; +inline FrameInterface& GetFrameInterface(ClientFrame& frame) { + return MatchMutable( + &frame, + [](ClientFragmentFrame* frame) -> FrameInterface& { return *frame; }, + [](CancelFrame* frame) -> FrameInterface& { return *frame; }); +} + +inline FrameInterface& GetFrameInterface(ServerFrame& frame) { + return MatchMutable( + &frame, + [](ServerFragmentFrame* frame) -> FrameInterface& { return *frame; }); +} + } // namespace chaotic_good } // namespace grpc_core diff --git a/src/core/ext/transport/chaotic_good/frame_header.cc b/src/core/ext/transport/chaotic_good/frame_header.cc index e39d6a34b58..06f9d146e66 100644 --- a/src/core/ext/transport/chaotic_good/frame_header.cc +++ b/src/core/ext/transport/chaotic_good/frame_header.cc @@ -46,7 +46,6 @@ void FrameHeader::Serialize(uint8_t* data) const { WriteLittleEndianUint32( static_cast(type) | (flags.ToInt() << 8), data); if (flags.is_set(0)) GPR_ASSERT(header_length > 0); - if (flags.is_set(1)) GPR_ASSERT(trailer_length > 0); WriteLittleEndianUint32(stream_id, data + 4); WriteLittleEndianUint32(header_length, data + 8); WriteLittleEndianUint32(message_length, data + 12); @@ -60,8 +59,8 @@ absl::StatusOr FrameHeader::Parse(const uint8_t* data) { const uint32_t type_and_flags = ReadLittleEndianUint32(data); header.type = static_cast(type_and_flags & 0xff); const uint32_t flags = type_and_flags >> 8; - if (flags > 3) return absl::InvalidArgumentError("Invalid flags"); - header.flags = BitSet<2>::FromInt(flags); + if (flags > 7) return absl::InvalidArgumentError("Invalid flags"); + header.flags = BitSet<3>::FromInt(flags); header.stream_id = ReadLittleEndianUint32(data + 4); header.header_length = ReadLittleEndianUint32(data + 8); if (header.flags.is_set(0) && header.header_length <= 0) { @@ -70,11 +69,11 @@ absl::StatusOr FrameHeader::Parse(const uint8_t* data) { } header.message_length = ReadLittleEndianUint32(data + 12); header.message_padding = ReadLittleEndianUint32(data + 16); - header.trailer_length = ReadLittleEndianUint32(data + 20); - if (header.flags.is_set(1) && header.trailer_length <= 0) { + if (header.flags.is_set(1) && header.message_length <= 0) { return absl::InvalidArgumentError( - absl::StrCat("Invalid trailer length", header.trailer_length)); + absl::StrCat("Invalid message length: ", header.message_length)); } + header.trailer_length = ReadLittleEndianUint32(data + 20); return header; } diff --git a/src/core/ext/transport/chaotic_good/frame_header.h b/src/core/ext/transport/chaotic_good/frame_header.h index fa236ed3342..773b44f26e3 100644 --- a/src/core/ext/transport/chaotic_good/frame_header.h +++ b/src/core/ext/transport/chaotic_good/frame_header.h @@ -36,7 +36,7 @@ enum class FrameType : uint8_t { struct FrameHeader { FrameType type = FrameType::kCancel; - BitSet<2> flags; + BitSet<3> flags; uint32_t stream_id = 0; uint32_t header_length = 0; uint32_t message_length = 0; diff --git a/src/core/ext/transport/chaotic_good/server_transport.cc b/src/core/ext/transport/chaotic_good/server_transport.cc new file mode 100644 index 00000000000..3d4387ac949 --- /dev/null +++ b/src/core/ext/transport/chaotic_good/server_transport.cc @@ -0,0 +1,332 @@ +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/ext/transport/chaotic_good/server_transport.h" + +#include +#include +#include + +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#include +#include +#include + +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/ext/transport/chaotic_good/frame_header.h" +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/exec_ctx.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/event_engine_wakeup_scheduler.h" +#include "src/core/lib/promise/for_each.h" +#include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/switch.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "src/core/lib/slice/slice.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/transport/promise_endpoint.h" + +namespace grpc_core { +namespace chaotic_good { + +auto ChaoticGoodServerTransport::TransportWriteLoop() { + return Loop([this] { + return TrySeq( + // Get next outgoing frame. + outgoing_frames_.Next(), + // Serialize and write it out. + [this](ServerFrame client_frame) { + return transport_.WriteFrame(GetFrameInterface(client_frame)); + }, + []() -> LoopCtl { + // The write failures will be caught in TrySeq and exit loop. + // Therefore, only need to return Continue() in the last lambda + // function. + return Continue(); + }); + }); +} + +auto ChaoticGoodServerTransport::PushFragmentIntoCall( + CallInitiator call_initiator, ClientFragmentFrame frame) { + auto& headers = frame.headers; + return TrySeq( + If( + headers != nullptr, + [call_initiator, &headers]() mutable { + return call_initiator.PushClientInitialMetadata(std::move(headers)); + }, + []() -> StatusFlag { return Success{}; }), + [call_initiator, message = std::move(frame.message)]() mutable { + return If( + message.has_value(), + [&call_initiator, &message]() mutable { + return call_initiator.PushMessage(std::move(message->message)); + }, + []() -> StatusFlag { return Success{}; }); + }, + [call_initiator, + end_of_stream = frame.end_of_stream]() mutable -> StatusFlag { + if (end_of_stream) call_initiator.FinishSends(); + return Success{}; + }); +} + +auto ChaoticGoodServerTransport::MaybePushFragmentIntoCall( + absl::optional call_initiator, absl::Status error, + ClientFragmentFrame frame) { + return If( + call_initiator.has_value() && error.ok(), + [this, &call_initiator, &frame]() { + return Map( + call_initiator->SpawnWaitable( + "push-fragment", + [call_initiator, frame = std::move(frame), this]() mutable { + return call_initiator->CancelIfFails( + PushFragmentIntoCall(*call_initiator, std::move(frame))); + }), + [](StatusFlag status) { return StatusCast(status); }); + }, + [error = std::move(error)]() { return error; }); +} + +auto ChaoticGoodServerTransport::CallOutboundLoop( + uint32_t stream_id, CallInitiator call_initiator) { + auto send_fragment = [stream_id, + outgoing_frames = outgoing_frames_.MakeSender()]( + ServerFragmentFrame frame) mutable { + frame.stream_id = stream_id; + return Map(outgoing_frames.Send(std::move(frame)), + [](bool success) -> absl::Status { + if (!success) { + // Failed to send outgoing frame. + return absl::UnavailableError("Transport closed."); + } + return absl::OkStatus(); + }); + }; + return Seq( + TrySeq( + // Wait for initial metadata then send it out. + call_initiator.PullServerInitialMetadata(), + [send_fragment](ServerMetadataHandle md) mutable { + ServerFragmentFrame frame; + frame.headers = std::move(md); + return send_fragment(std::move(frame)); + }, + // Continuously send client frame with client to server messages. + ForEach(OutgoingMessages(call_initiator), + [send_fragment, aligned_bytes = aligned_bytes_]( + MessageHandle message) mutable { + ServerFragmentFrame frame; + // Construct frame header (flags, header_length and + // trailer_length will be added in serialization). + const uint32_t message_length = + message->payload()->Length(); + const uint32_t padding = + message_length % aligned_bytes == 0 + ? 0 + : aligned_bytes - message_length % aligned_bytes; + GPR_ASSERT((message_length + padding) % aligned_bytes == 0); + frame.message = FragmentMessage(std::move(message), padding, + message_length); + return send_fragment(std::move(frame)); + })), + call_initiator.PullServerTrailingMetadata(), + [send_fragment](ServerMetadataHandle md) mutable { + ServerFragmentFrame frame; + frame.trailers = std::move(md); + return send_fragment(std::move(frame)); + }); +} + +auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToNewCall( + FrameHeader frame_header, BufferPair buffers) { + ClientFragmentFrame fragment_frame; + ScopedArenaPtr arena(acceptor_->CreateArena()); + absl::Status status = transport_.DeserializeFrame( + frame_header, std::move(buffers), arena.get(), fragment_frame); + absl::optional call_initiator; + if (status.ok()) { + auto create_call_result = + acceptor_->CreateCall(*fragment_frame.headers, arena.release()); + if (create_call_result.ok()) { + call_initiator.emplace(std::move(*create_call_result)); + call_initiator->SpawnGuarded( + "server-write", [this, stream_id = frame_header.stream_id, + call_initiator = *call_initiator]() { + return CallOutboundLoop(stream_id, call_initiator); + }); + } else { + status = create_call_result.status(); + } + } + return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status), + std::move(fragment_frame)); +} + +auto ChaoticGoodServerTransport::DeserializeAndPushFragmentToExistingCall( + FrameHeader frame_header, BufferPair buffers) { + absl::optional call_initiator = + LookupStream(frame_header.stream_id); + Arena* arena = nullptr; + if (call_initiator.has_value()) arena = call_initiator->arena(); + ClientFragmentFrame fragment_frame; + absl::Status status = transport_.DeserializeFrame( + frame_header, std::move(buffers), arena, fragment_frame); + return MaybePushFragmentIntoCall(std::move(call_initiator), std::move(status), + std::move(fragment_frame)); +} + +auto ChaoticGoodServerTransport::TransportReadLoop() { + return Loop([this] { + return TrySeq( + transport_.ReadFrameBytes(), + [this](std::tuple frame_bytes) { + const auto& frame_header = std::get<0>(frame_bytes); + auto& buffers = std::get<1>(frame_bytes); + return Switch( + frame_header.type, + Case(FrameType::kSettings, + []() -> absl::Status { + return absl::InternalError("Unexpected settings frame"); + }), + Case(FrameType::kFragment, + [this, &frame_header, &buffers]() { + return If( + frame_header.flags.is_set(0), + [this, &frame_header, &buffers]() { + return DeserializeAndPushFragmentToNewCall( + frame_header, std::move(buffers)); + }, + [this, &frame_header, &buffers]() { + return DeserializeAndPushFragmentToExistingCall( + frame_header, std::move(buffers)); + }); + }), + Case(FrameType::kCancel, + [this, &frame_header]() { + absl::optional call_initiator = + ExtractStream(frame_header.stream_id); + return If( + call_initiator.has_value(), + [&call_initiator]() { + auto c = std::move(*call_initiator); + return c.SpawnWaitable("cancel", [c]() mutable { + c.Cancel(); + return absl::OkStatus(); + }); + }, + []() -> absl::Status { + return absl::InternalError( + "Unexpected cancel frame"); + }); + }), + Default([frame_header]() { + return absl::InternalError( + absl::StrCat("Unexpected frame type: ", + static_cast(frame_header.type))); + })); + }, + []() -> LoopCtl { return Continue{}; }); + }); +} + +auto ChaoticGoodServerTransport::OnTransportActivityDone() { + return [this](absl::Status status) { + if (!(status.ok() || status.code() == absl::StatusCode::kCancelled)) { + this->AbortWithError(); + } + }; +} + +ChaoticGoodServerTransport::ChaoticGoodServerTransport( + const ChannelArgs& args, std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr event_engine) + : outgoing_frames_(4), + transport_(std::move(control_endpoint), std::move(data_endpoint)), + allocator_(args.GetObject() + ->memory_quota() + ->CreateMemoryAllocator("chaotic-good")), + event_engine_(event_engine), + writer_{MakeActivity(TransportWriteLoop(), + EventEngineWakeupScheduler(event_engine), + OnTransportActivityDone())}, + reader_{nullptr} {} + +void ChaoticGoodServerTransport::SetAcceptor(Acceptor* acceptor) { + GPR_ASSERT(acceptor_ == nullptr); + GPR_ASSERT(acceptor != nullptr); + acceptor_ = acceptor; + reader_ = MakeActivity(TransportReadLoop(), + EventEngineWakeupScheduler(event_engine_), + OnTransportActivityDone()); +} + +ChaoticGoodServerTransport::~ChaoticGoodServerTransport() { + if (writer_ != nullptr) { + writer_.reset(); + } + if (reader_ != nullptr) { + reader_.reset(); + } +} + +void ChaoticGoodServerTransport::AbortWithError() { + // Mark transport as unavailable when the endpoint write/read failed. + // Close all the available pipes. + outgoing_frames_.MarkClosed(); + ReleasableMutexLock lock(&mu_); + StreamMap stream_map = std::move(stream_map_); + stream_map_.clear(); + lock.Release(); + for (const auto& pair : stream_map) { + auto call_initiator = pair.second; + call_initiator.SpawnInfallible("cancel", [call_initiator]() mutable { + call_initiator.Cancel(); + return Empty{}; + }); + } +} + +absl::optional ChaoticGoodServerTransport::LookupStream( + uint32_t stream_id) { + MutexLock lock(&mu_); + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) return absl::nullopt; + return it->second; +} + +absl::optional ChaoticGoodServerTransport::ExtractStream( + uint32_t stream_id) { + MutexLock lock(&mu_); + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) return absl::nullopt; + auto r = std::move(it->second); + stream_map_.erase(it); + return std::move(r); +} + +} // namespace chaotic_good +} // namespace grpc_core diff --git a/src/core/ext/transport/chaotic_good/server_transport.h b/src/core/ext/transport/chaotic_good/server_transport.h new file mode 100644 index 00000000000..9ce92928385 --- /dev/null +++ b/src/core/ext/transport/chaotic_good/server_transport.h @@ -0,0 +1,145 @@ +// Copyright 2022 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H +#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H + +#include + +#include +#include + +#include +#include // IWYU pragma: keep +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/random/random.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" + +#include +#include +#include +#include + +#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/ext/transport/chaotic_good/frame_header.h" +#include "src/core/ext/transport/chttp2/transport/hpack_encoder.h" +#include "src/core/ext/transport/chttp2/transport/hpack_parser.h" +#include "src/core/lib/event_engine/default_event_engine.h" // IWYU pragma: keep +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/if.h" +#include "src/core/lib/promise/inter_activity_pipe.h" +#include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/mpsc.h" +#include "src/core/lib/promise/party.h" +#include "src/core/lib/promise/pipe.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/slice/slice.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/promise_endpoint.h" +#include "src/core/lib/transport/transport.h" + +namespace grpc_core { +namespace chaotic_good { + +class ChaoticGoodServerTransport final : public Transport, + public ServerTransport { + public: + ChaoticGoodServerTransport( + const ChannelArgs& args, + std::unique_ptr control_endpoint, + std::unique_ptr data_endpoint, + std::shared_ptr + event_engine); + ~ChaoticGoodServerTransport() override; + + FilterStackTransport* filter_stack_transport() override { return nullptr; } + ClientTransport* client_transport() override { return nullptr; } + ServerTransport* server_transport() override { return this; } + absl::string_view GetTransportName() const override { return "chaotic_good"; } + void SetPollset(grpc_stream*, grpc_pollset*) override {} + void SetPollsetSet(grpc_stream*, grpc_pollset_set*) override {} + void PerformOp(grpc_transport_op*) override { Crash("unimplemented"); } + grpc_endpoint* GetEndpoint() override { return nullptr; } + void Orphan() override { delete this; } + + void SetAcceptor(Acceptor* acceptor) override; + void AbortWithError(); + + private: + using StreamMap = absl::flat_hash_map; + + absl::Status NewStream(uint32_t stream_id, CallInitiator call_initiator); + absl::optional LookupStream(uint32_t stream_id); + absl::optional ExtractStream(uint32_t stream_id); + auto CallOutboundLoop(uint32_t stream_id, CallInitiator call_initiator); + auto OnTransportActivityDone(); + auto TransportReadLoop(); + auto TransportWriteLoop(); + // Read different parts of the server frame from control/data endpoints + // based on frame header. + // Resolves to a StatusOr> + auto ReadFrameBody(Slice read_buffer); + void SendCancel(uint32_t stream_id, absl::Status why); + auto DeserializeAndPushFragmentToNewCall(FrameHeader frame_header, + BufferPair buffers); + auto DeserializeAndPushFragmentToExistingCall(FrameHeader frame_header, + BufferPair buffers); + auto MaybePushFragmentIntoCall(absl::optional call_initiator, + absl::Status error, ClientFragmentFrame frame); + auto PushFragmentIntoCall(CallInitiator call_initiator, + ClientFragmentFrame frame); + + Acceptor* acceptor_ = nullptr; + MpscReceiver outgoing_frames_; + ChaoticGoodTransport transport_; + // Assigned aligned bytes from setting frame. + size_t aligned_bytes_ = 64; + Mutex mu_; + // Map of stream incoming server frames, key is stream_id. + StreamMap stream_map_ ABSL_GUARDED_BY(mu_); + grpc_event_engine::experimental::MemoryAllocator allocator_; + std::shared_ptr event_engine_; + ActivityPtr writer_; + ActivityPtr reader_; +}; + +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H \ No newline at end of file diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc index cc932d0f216..24fdd41d387 100644 --- a/src/core/ext/transport/inproc/inproc_transport.cc +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -36,8 +36,8 @@ class InprocServerTransport final : public RefCounted, public Transport, public ServerTransport { public: - void SetAcceptFunction(AcceptFunction accept_function) override { - accept_ = std::move(accept_function); + void SetAcceptor(Acceptor* acceptor) override { + acceptor_ = acceptor; ConnectionState expect = ConnectionState::kInitial; state_.compare_exchange_strong(expect, ConnectionState::kReady, std::memory_order_acq_rel, @@ -92,7 +92,7 @@ class InprocServerTransport final : public RefCounted, case ConnectionState::kReady: break; } - return accept_(md); + return acceptor_->CreateCall(md, acceptor_->CreateArena()); } private: @@ -100,7 +100,7 @@ class InprocServerTransport final : public RefCounted, std::atomic state_{ConnectionState::kInitial}; std::atomic disconnecting_{false}; - AcceptFunction accept_; + Acceptor* acceptor_; absl::Status disconnect_error_; Mutex state_tracker_mu_; ConnectivityStateTracker state_tracker_ ABSL_GUARDED_BY(state_tracker_mu_){ diff --git a/src/core/lib/gprpp/debug_location.h b/src/core/lib/gprpp/debug_location.h index 7e021fd9f78..c6c9b682869 100644 --- a/src/core/lib/gprpp/debug_location.h +++ b/src/core/lib/gprpp/debug_location.h @@ -81,6 +81,15 @@ class DebugLocation { }; #endif +template +struct ValueWithDebugLocation { + // NOLINTNEXTLINE + ValueWithDebugLocation(T&& value, DebugLocation debug_location = {}) + : value(std::forward(value)), debug_location(debug_location) {} + T value; + GPR_NO_UNIQUE_ADDRESS DebugLocation debug_location; +}; + #define DEBUG_LOCATION ::grpc_core::DebugLocation(__FILE__, __LINE__) } // namespace grpc_core diff --git a/src/core/lib/promise/detail/status.h b/src/core/lib/promise/detail/status.h index 1063f329193..bfc649e6c48 100644 --- a/src/core/lib/promise/detail/status.h +++ b/src/core/lib/promise/detail/status.h @@ -45,6 +45,11 @@ inline absl::Status IntoStatus(absl::Status* status) { // can participate in TrySeq as result types that affect control flow. inline bool IsStatusOk(const absl::Status& status) { return status.ok(); } +template +inline bool IsStatusOk(const absl::StatusOr& status) { + return status.ok(); +} + template struct StatusCastImpl; @@ -59,20 +64,52 @@ struct StatusCastImpl { }; template -struct StatusCastImpl, absl::Status> { - static absl::StatusOr Cast(absl::Status&& t) { return std::move(t); } +struct StatusCastImpl> { + static absl::Status Cast(absl::StatusOr&& t) { + return std::move(t.status()); + } }; template -struct StatusCastImpl, const absl::Status&> { - static absl::StatusOr Cast(const absl::Status& t) { return t; } +struct StatusCastImpl&> { + static absl::Status Cast(const absl::StatusOr& t) { return t.status(); } }; +template +struct StatusCastImpl&> { + static absl::Status Cast(const absl::StatusOr& t) { return t.status(); } +}; + +// StatusCast<> allows casting from one status-bearing type to another, +// regardless of whether the status indicates success or failure. +// This means that we can go from StatusOr to Status safely, but not in the +// opposite direction. +// For cases where the status is guaranteed to be a failure (and hence not +// needing to preserve values) see FailureStatusCast<> below. template To StatusCast(From&& from) { return StatusCastImpl::Cast(std::forward(from)); } +template +struct FailureStatusCastImpl : public StatusCastImpl {}; + +template +struct FailureStatusCastImpl, absl::Status> { + static absl::StatusOr Cast(absl::Status&& t) { return std::move(t); } +}; + +template +struct FailureStatusCastImpl, const absl::Status&> { + static absl::StatusOr Cast(const absl::Status& t) { return t; } +}; + +template +To FailureStatusCast(From&& from) { + GPR_DEBUG_ASSERT(!IsStatusOk(from)); + return FailureStatusCastImpl::Cast(std::forward(from)); +} + } // namespace grpc_core #endif // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_STATUS_H diff --git a/src/core/lib/promise/event_engine_wakeup_scheduler.h b/src/core/lib/promise/event_engine_wakeup_scheduler.h index 792ee9d4439..3e489c87fc6 100644 --- a/src/core/lib/promise/event_engine_wakeup_scheduler.h +++ b/src/core/lib/promise/event_engine_wakeup_scheduler.h @@ -33,7 +33,9 @@ class EventEngineWakeupScheduler { explicit EventEngineWakeupScheduler( std::shared_ptr event_engine) - : event_engine_(std::move(event_engine)) {} + : event_engine_(std::move(event_engine)) { + GPR_ASSERT(event_engine_ != nullptr); + } template class BoundScheduler diff --git a/src/core/lib/promise/if.h b/src/core/lib/promise/if.h index e659ad30ed8..2c81fd7e38a 100644 --- a/src/core/lib/promise/if.h +++ b/src/core/lib/promise/if.h @@ -192,6 +192,10 @@ class If { // If it returns failure, returns failure for the entire combinator. // If it returns true, evaluates the second promise. // If it returns false, evaluates the third promise. +// If C is a constant, it's guaranteed that one of the promise factories +// if_true or if_false will be evaluated before returning from this function. +// This makes it safe to capture lambda arguments in the promise factory by +// reference. template promise_detail::If If(C condition, T if_true, F if_false) { return promise_detail::If(std::move(condition), std::move(if_true), diff --git a/src/core/lib/promise/inter_activity_pipe.h b/src/core/lib/promise/inter_activity_pipe.h index a7594fb26a2..4578cbb3c62 100644 --- a/src/core/lib/promise/inter_activity_pipe.h +++ b/src/core/lib/promise/inter_activity_pipe.h @@ -113,9 +113,9 @@ class InterActivityPipe { if (center_ != nullptr) center_->MarkClosed(); } - bool IsClose() { return center_->IsClosed(); } + bool IsClosed() { return center_->IsClosed(); } - void MarkClose() { + void MarkClosed() { if (center_ != nullptr) center_->MarkClosed(); } @@ -146,6 +146,12 @@ class InterActivityPipe { return [center = center_]() { return center->Next(); }; } + bool IsClose() { return center_->IsClosed(); } + + void MarkClose() { + if (center_ != nullptr) center_->MarkClosed(); + } + private: RefCountedPtr
center_; }; diff --git a/src/core/lib/promise/mpsc.h b/src/core/lib/promise/mpsc.h index c12544282f6..8bbbfc4c8ec 100644 --- a/src/core/lib/promise/mpsc.h +++ b/src/core/lib/promise/mpsc.h @@ -103,14 +103,12 @@ class Center : public RefCounted> { // Mark that the receiver is closed. void ReceiverClosed() { - MutexLock lock(&mu_); + ReleasableMutexLock lock(&mu_); + if (receiver_closed_) return; receiver_closed_ = true; - } - - // Return whether the receiver is closed. - bool IsClosed() { - MutexLock lock(&mu_); - return receiver_closed_; + auto wakeups = send_wakers_.TakeWakeupSet(); + lock.Release(); + wakeups.Wakeup(); } private: @@ -131,8 +129,8 @@ class MpscReceiver; template class MpscSender { public: - MpscSender(const MpscSender&) = delete; - MpscSender& operator=(const MpscSender&) = delete; + MpscSender(const MpscSender&) = default; + MpscSender& operator=(const MpscSender&) = default; MpscSender(MpscSender&&) noexcept = default; MpscSender& operator=(MpscSender&&) noexcept = default; @@ -140,7 +138,10 @@ class MpscSender { // Resolves to true if sent, false if the receiver was closed (and the value // will never be successfully sent). auto Send(T t) { - return [this, t = std::move(t)]() mutable { return center_->PollSend(t); }; + return [center = center_, t = std::move(t)]() mutable -> Poll { + if (center == nullptr) return false; + return center->PollSend(t); + }; } bool UnbufferedImmediateSend(T t) { @@ -170,7 +171,6 @@ class MpscReceiver { ~MpscReceiver() { if (center_ != nullptr) center_->ReceiverClosed(); } - bool IsClosed() { return center_->IsClosed(); } void MarkClosed() { if (center_ != nullptr) center_->ReceiverClosed(); } diff --git a/src/core/lib/promise/status_flag.h b/src/core/lib/promise/status_flag.h index d9067509c32..c8c9c0ba41e 100644 --- a/src/core/lib/promise/status_flag.h +++ b/src/core/lib/promise/status_flag.h @@ -95,6 +95,30 @@ struct StatusCastImpl { } }; +template +struct FailureStatusCastImpl, StatusFlag> { + static absl::StatusOr Cast(StatusFlag flag) { + GPR_DEBUG_ASSERT(!flag.ok()); + return absl::CancelledError(); + } +}; + +template +struct FailureStatusCastImpl, StatusFlag&> { + static absl::StatusOr Cast(StatusFlag flag) { + GPR_DEBUG_ASSERT(!flag.ok()); + return absl::CancelledError(); + } +}; + +template +struct FailureStatusCastImpl, const StatusFlag&> { + static absl::StatusOr Cast(StatusFlag flag) { + GPR_DEBUG_ASSERT(!flag.ok()); + return absl::CancelledError(); + } +}; + // A value if an operation was successful, or a failure flag if not. template class ValueOrFailure { diff --git a/src/core/lib/promise/try_join.h b/src/core/lib/promise/try_join.h index be3354dc9b7..29cd06d5be4 100644 --- a/src/core/lib/promise/try_join.h +++ b/src/core/lib/promise/try_join.h @@ -75,16 +75,16 @@ struct TryJoinTraits { } template static R EarlyReturn(absl::Status x) { - return StatusCast(std::move(x)); + return FailureStatusCast(std::move(x)); } template static R EarlyReturn(StatusFlag x) { - return StatusCast(x); + return FailureStatusCast(x); } template static R EarlyReturn(const ValueOrFailure& x) { GPR_ASSERT(!x.ok()); - return StatusCast(Failure{}); + return FailureStatusCast(Failure{}); } template static auto FinalReturn(A&&... a) { diff --git a/src/core/lib/promise/try_seq.h b/src/core/lib/promise/try_seq.h index ca04904ab1d..ab9777d1afb 100644 --- a/src/core/lib/promise/try_seq.h +++ b/src/core/lib/promise/try_seq.h @@ -76,7 +76,7 @@ struct TrySeqTraitsWithSfinae> { } template static R ReturnValue(absl::StatusOr&& status) { - return StatusCast(status.status()); + return FailureStatusCast(status.status()); } template static auto CallSeqFactory(F& f, Elem&& elem, absl::StatusOr value) @@ -86,11 +86,26 @@ struct TrySeqTraitsWithSfinae> { template static Poll CheckResultAndRunNext(absl::StatusOr prior, RunNext run_next) { - if (!prior.ok()) return StatusCast(prior.status()); + if (!prior.ok()) return FailureStatusCast(prior.status()); return run_next(std::move(prior)); } }; +template +struct AllowGenericTrySeqTraits { + static constexpr bool value = true; +}; + +template <> +struct AllowGenericTrySeqTraits { + static constexpr bool value = false; +}; + +template +struct AllowGenericTrySeqTraits> { + static constexpr bool value = false; +}; + template struct TakeValueExists { static constexpr bool value = false; @@ -107,7 +122,7 @@ template struct TrySeqTraitsWithSfinae< T, absl::enable_if_t< std::is_same())), bool>::value && - !TakeValueExists::value, + !TakeValueExists::value && AllowGenericTrySeqTraits::value, void>> { using UnwrappedType = void; using WrappedType = T; @@ -121,7 +136,7 @@ struct TrySeqTraitsWithSfinae< } template static R ReturnValue(T&& status) { - return StatusCast(std::move(status)); + return FailureStatusCast(std::move(status)); } template static Poll CheckResultAndRunNext(T prior, RunNext run_next) { @@ -133,7 +148,7 @@ template struct TrySeqTraitsWithSfinae< T, absl::enable_if_t< std::is_same())), bool>::value && - TakeValueExists::value, + TakeValueExists::value && AllowGenericTrySeqTraits::value, void>> { using UnwrappedType = decltype(TakeValue(std::declval())); using WrappedType = T; @@ -148,7 +163,7 @@ struct TrySeqTraitsWithSfinae< template static R ReturnValue(T&& status) { GPR_DEBUG_ASSERT(!IsStatusOk(status)); - return StatusCast(status.status()); + return FailureStatusCast(status.status()); } template static Poll CheckResultAndRunNext(T prior, RunNext run_next) { @@ -170,7 +185,7 @@ struct TrySeqTraitsWithSfinae { } template static R ReturnValue(absl::Status&& status) { - return StatusCast(std::move(status)); + return FailureStatusCast(std::move(status)); } template static Poll CheckResultAndRunNext(absl::Status prior, diff --git a/src/core/lib/resource_quota/arena.h b/src/core/lib/resource_quota/arena.h index 9c0c812d4c3..edcab2caf87 100644 --- a/src/core/lib/resource_quota/arena.h +++ b/src/core/lib/resource_quota/arena.h @@ -180,7 +180,7 @@ class Arena { template T* New(Args&&... args) { T* t = static_cast(Alloc(sizeof(T))); - Construct(t, std::forward(args)...); + new (t) T(std::forward(args)...); return t; } @@ -333,7 +333,7 @@ class Arena { // value in Arena::PoolSizes, and so this may pessimize total // arena size. template - PoolPtr MakePooled(Args&&... args) { + static PoolPtr MakePooled(Args&&... args) { return PoolPtr(new T(std::forward(args)...), PooledDeleter()); } diff --git a/src/core/lib/slice/slice_buffer.h b/src/core/lib/slice/slice_buffer.h index 2626bd4a0e9..0c1bfe9d901 100644 --- a/src/core/lib/slice/slice_buffer.h +++ b/src/core/lib/slice/slice_buffer.h @@ -50,6 +50,9 @@ namespace grpc_core { class SliceBuffer { public: explicit SliceBuffer() { grpc_slice_buffer_init(&slice_buffer_); } + explicit SliceBuffer(Slice slice) : SliceBuffer() { + Append(std::move(slice)); + } SliceBuffer(const SliceBuffer& other) = delete; SliceBuffer(SliceBuffer&& other) noexcept { grpc_slice_buffer_init(&slice_buffer_); diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index d5dbc54be39..2702bedc0e7 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -4063,16 +4063,13 @@ void ServerCallSpine::CommitBatch(const grpc_op* ops, size_t nops, } RefCountedPtr MakeServerCall(Server* server, - Channel* channel) { - const auto initial_size = channel->CallSizeEstimate(); - global_stats().IncrementCallInitialSize(initial_size); - auto alloc = Arena::CreateWithAlloc(initial_size, sizeof(ServerCallSpine), - channel->allocator()); - auto* call = new (alloc.second) ServerCallSpine(server, channel, alloc.first); - return RefCountedPtr(call); + Channel* channel, + Arena* arena) { + return RefCountedPtr( + arena->New(server, channel, arena)); } #else -RefCountedPtr MakeServerCall(Server*, Channel*) { +RefCountedPtr MakeServerCall(Server*, Channel*, Arena*) { Crash("not implemented"); } #endif diff --git a/src/core/lib/surface/call.h b/src/core/lib/surface/call.h index 6653bb6a0dd..520cf13505c 100644 --- a/src/core/lib/surface/call.h +++ b/src/core/lib/surface/call.h @@ -160,7 +160,8 @@ template <> struct ContextType {}; RefCountedPtr MakeServerCall(Server* server, - Channel* channel); + Channel* channel, + Arena* arena); } // namespace grpc_core diff --git a/src/core/lib/surface/server.cc b/src/core/lib/surface/server.cc index 44b541a8593..e53a609cfb1 100644 --- a/src/core/lib/surface/server.cc +++ b/src/core/lib/surface/server.cc @@ -51,6 +51,7 @@ #include "src/core/lib/channel/channel_trace.h" #include "src/core/lib/channel/channelz.h" #include "src/core/lib/config/core_configuration.h" +#include "src/core/lib/debug/stats.h" #include "src/core/lib/experiments/experiments.h" #include "src/core/lib/gpr/useful.h" #include "src/core/lib/gprpp/crash.h" @@ -1297,6 +1298,20 @@ Server::ChannelData::~ChannelData() { } } +Arena* Server::ChannelData::CreateArena() { + const auto initial_size = channel_->CallSizeEstimate(); + global_stats().IncrementCallInitialSize(initial_size); + return Arena::Create(initial_size, channel_->allocator()); +} + +absl::StatusOr Server::ChannelData::CreateCall( + ClientMetadata& client_initial_metadata, Arena* arena) { + SetRegisteredMethodOnMetadata(client_initial_metadata); + auto call = MakeServerCall(server_.get(), channel_.get(), arena); + InitCall(call); + return CallInitiator(std::move(call)); +} + void Server::ChannelData::InitTransport(RefCountedPtr server, RefCountedPtr channel, size_t cq_idx, Transport* transport, @@ -1329,13 +1344,7 @@ void Server::ChannelData::InitTransport(RefCountedPtr server, } if (transport->server_transport() != nullptr) { ++accept_stream_types; - transport->server_transport()->SetAcceptFunction( - [this](ClientMetadata& metadata) { - SetRegisteredMethodOnMetadata(metadata); - auto call = MakeServerCall(server_.get(), channel_.get()); - InitCall(call); - return CallInitiator(std::move(call)); - }); + transport->server_transport()->SetAcceptor(this); } GPR_ASSERT(accept_stream_types == 1); op->start_connectivity_watch = MakeOrphanable(this); diff --git a/src/core/lib/surface/server.h b/src/core/lib/surface/server.h index 11ec7c68a45..4bb6fce3fae 100644 --- a/src/core/lib/surface/server.h +++ b/src/core/lib/surface/server.h @@ -218,7 +218,7 @@ class Server : public InternallyRefCounted, class AllocatingRequestMatcherBatch; class AllocatingRequestMatcherRegistered; - class ChannelData { + class ChannelData final : public ServerTransport::Acceptor { public: ChannelData() = default; ~ChannelData(); @@ -241,6 +241,10 @@ class Server : public InternallyRefCounted, grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory); void InitCall(RefCountedPtr call); + Arena* CreateArena() override; + absl::StatusOr CreateCall( + ClientMetadata& client_initial_metadata, Arena* arena) override; + private: class ConnectivityWatcher; diff --git a/src/core/lib/transport/promise_endpoint.h b/src/core/lib/transport/promise_endpoint.h index 2c2b3a2d37c..fbdc467cbb8 100644 --- a/src/core/lib/transport/promise_endpoint.h +++ b/src/core/lib/transport/promise_endpoint.h @@ -69,24 +69,26 @@ class PromiseEndpoint { auto Write(SliceBuffer data) { // Assert previous write finishes. GPR_ASSERT(!write_state_->complete.load(std::memory_order_relaxed)); - // TODO(ladynana): Replace this with `SliceBufferCast<>` when it is - // available. - grpc_slice_buffer_swap(write_state_->buffer.c_slice_buffer(), - data.c_slice_buffer()); - // If `Write()` returns true immediately, the callback will not be called. - // We still need to call our callback to pick up the result. - write_state_->waker = Activity::current()->MakeNonOwningWaker(); - const bool completed = endpoint_->Write( - [write_state = write_state_](absl::Status status) { - write_state->Complete(std::move(status)); - }, - &write_state_->buffer, nullptr /* uses default arguments */); + bool completed; + if (data.Length() == 0) { + completed = true; + } else { + // TODO(ladynana): Replace this with `SliceBufferCast<>` when it is + // available. + grpc_slice_buffer_swap(write_state_->buffer.c_slice_buffer(), + data.c_slice_buffer()); + // If `Write()` returns true immediately, the callback will not be called. + // We still need to call our callback to pick up the result. + write_state_->waker = Activity::current()->MakeNonOwningWaker(); + completed = endpoint_->Write( + [write_state = write_state_](absl::Status status) { + write_state->Complete(std::move(status)); + }, + &write_state_->buffer, nullptr /* uses default arguments */); + if (completed) write_state_->waker = Waker(); + } return If( - completed, - [this]() { - write_state_->waker = Waker(); - return []() { return absl::OkStatus(); }; - }, + completed, []() { return []() { return absl::OkStatus(); }; }, [this]() { return [write_state = write_state_]() -> Poll { // If current write isn't finished return `Pending()`, else return diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc index ab405804065..bad5c4b8590 100644 --- a/src/core/lib/transport/transport.cc +++ b/src/core/lib/transport/transport.cc @@ -291,9 +291,8 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, return call_initiator.SpawnWaitable( "send_message", [msg = std::move(msg), call_initiator]() mutable { - return call_initiator.CancelIfFails(Map( - call_initiator.PushMessage(std::move(msg)), - [](bool r) { return StatusFlag(r); })); + return call_initiator.CancelIfFails( + call_initiator.PushMessage(std::move(msg))); }); }); }); @@ -317,8 +316,7 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, "recv_message", [msg = std::move(msg), call_handler]() mutable { return call_handler.CancelIfFails( - Map(call_handler.PushMessage(std::move(msg)), - [](bool r) { return StatusFlag(r); })); + call_handler.PushMessage(std::move(msg))); }); }), ImmediateOkStatus())), @@ -334,4 +332,10 @@ void ForwardCall(CallHandler call_handler, CallInitiator call_initiator, }); } +CallInitiatorAndHandler MakeCall( + grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena) { + auto spine = CallSpine::Create(event_engine, arena); + return {CallInitiator(spine), CallHandler(spine)}; +} + } // namespace grpc_core diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index c9f138c8f09..5a7e09bdb43 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -258,6 +258,20 @@ class CallSpineInterface { virtual Pipe& server_to_client_messages() = 0; virtual Pipe& server_trailing_metadata() = 0; virtual Latch& cancel_latch() = 0; + // Add a callback to be called when server trailing metadata is received. + void OnDone(absl::AnyInvocable fn) { + if (on_done_ == nullptr) { + on_done_ = std::move(fn); + return; + } + on_done_ = [first = std::move(fn), next = std::move(on_done_)]() mutable { + first(); + next(); + }; + } + void CallOnDone() { + if (on_done_ != nullptr) std::exchange(on_done_, nullptr)(); + } virtual Party& party() = 0; virtual void IncrementRefCount() = 0; virtual void Unref() = 0; @@ -276,6 +290,11 @@ class CallSpineInterface { auto& c = cancel_latch(); if (c.is_set()) return absl::nullopt; c.Set(std::move(metadata)); + CallOnDone(); + client_initial_metadata().sender.CloseWithError(); + server_initial_metadata().sender.CloseWithError(); + client_to_server_messages().sender.CloseWithError(); + server_to_client_messages().sender.CloseWithError(); return absl::nullopt; } @@ -325,11 +344,18 @@ class CallSpineInterface { } }); } + + private: + absl::AnyInvocable on_done_{nullptr}; }; -class CallSpine final : public CallSpineInterface { +class CallSpine final : public CallSpineInterface, public Party { public: - CallSpine() { Crash("unimplemented"); } + static RefCountedPtr Create( + grpc_event_engine::experimental::EventEngine* event_engine, + Arena* arena) { + return RefCountedPtr(arena->New(event_engine, arena)); + } Pipe& client_initial_metadata() override { return client_initial_metadata_; @@ -347,23 +373,57 @@ class CallSpine final : public CallSpineInterface { return server_trailing_metadata_; } Latch& cancel_latch() override { return cancel_latch_; } - Party& party() override { Crash("unimplemented"); } - void IncrementRefCount() override { Crash("unimplemented"); } - void Unref() override { Crash("unimplemented"); } + Party& party() override { return *this; } + void IncrementRefCount() override { Party::IncrementRefCount(); } + void Unref() override { Party::Unref(); } private: + friend class Arena; + CallSpine(grpc_event_engine::experimental::EventEngine* event_engine, + Arena* arena) + : Party(arena, 1), event_engine_(event_engine) {} + + class ScopedContext : public ScopedActivity, + public promise_detail::Context { + public: + explicit ScopedContext(CallSpine* spine) + : ScopedActivity(&spine->party()), Context(spine->arena()) {} + }; + + bool RunParty() override { + ScopedContext context(this); + return Party::RunParty(); + } + + void PartyOver() override { + Arena* a = arena(); + { + ScopedContext context(this); + CancelRemainingParticipants(); + a->DestroyManagedNewObjects(); + } + this->~CallSpine(); + a->Destroy(); + } + + grpc_event_engine::experimental::EventEngine* event_engine() const override { + return event_engine_; + } + // Initial metadata from client to server - Pipe client_initial_metadata_; + Pipe client_initial_metadata_{arena()}; // Initial metadata from server to client - Pipe server_initial_metadata_; + Pipe server_initial_metadata_{arena()}; // Messages travelling from the application to the transport. - Pipe client_to_server_messages_; + Pipe client_to_server_messages_{arena()}; // Messages travelling from the transport to the application. - Pipe server_to_client_messages_; + Pipe server_to_client_messages_{arena()}; // Trailing metadata from server to client - Pipe server_trailing_metadata_; + Pipe server_trailing_metadata_{arena()}; // Latch that can be set to terminate the call Latch cancel_latch_; + // Event engine associated with this call + grpc_event_engine::experimental::EventEngine* const event_engine_; }; class CallInitiator { @@ -405,7 +465,14 @@ class CallInitiator { auto PushMessage(MessageHandle message) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); - return spine_->client_to_server_messages().sender.Push(std::move(message)); + return Map( + spine_->client_to_server_messages().sender.Push(std::move(message)), + [](bool r) { return StatusFlag(r); }); + } + + void FinishSends() { + GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + spine_->client_to_server_messages().sender.Close(); } template @@ -413,6 +480,12 @@ class CallInitiator { return spine_->CancelIfFails(std::move(promise)); } + void Cancel() { + GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + std::ignore = + spine_->Cancel(ServerMetadataFromStatus(absl::CancelledError())); + } + template void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) { spine_->SpawnGuarded(name, std::move(promise_factory)); @@ -428,8 +501,10 @@ class CallInitiator { return spine_->party().SpawnWaitable(name, std::move(promise_factory)); } + Arena* arena() { return spine_->party().arena(); } + private: - const RefCountedPtr spine_; + RefCountedPtr spine_; }; class CallHandler { @@ -447,14 +522,16 @@ class CallHandler { }); } - auto PushServerInitialMetadata(ClientMetadataHandle md) { + auto PushServerInitialMetadata(ServerMetadataHandle md) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); return Map(spine_->server_initial_metadata().sender.Push(std::move(md)), [](bool ok) { return StatusFlag(ok); }); } - auto PushServerTrailingMetadata(ClientMetadataHandle md) { + auto PushServerTrailingMetadata(ServerMetadataHandle md) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + spine_->server_to_client_messages().sender.Close(); + spine_->CallOnDone(); return Map(spine_->server_trailing_metadata().sender.Push(std::move(md)), [](bool ok) { return StatusFlag(ok); }); } @@ -466,9 +543,18 @@ class CallHandler { auto PushMessage(MessageHandle message) { GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); - return spine_->server_to_client_messages().sender.Push(std::move(message)); + return Map( + spine_->server_to_client_messages().sender.Push(std::move(message)), + [](bool ok) { return StatusFlag(ok); }); } + void Cancel(ServerMetadataHandle status) { + GPR_DEBUG_ASSERT(Activity::current() == &spine_->party()); + std::ignore = spine_->Cancel(std::move(status)); + } + + void OnDone(absl::AnyInvocable fn) { spine_->OnDone(std::move(fn)); } + template auto CancelIfFails(Promise promise) { return spine_->CancelIfFails(std::move(promise)); @@ -489,8 +575,10 @@ class CallHandler { return spine_->party().SpawnWaitable(name, std::move(promise_factory)); } + Arena* arena() { return spine_->party().arena(); } + private: - const RefCountedPtr spine_; + RefCountedPtr spine_; }; struct CallInitiatorAndHandler { @@ -498,13 +586,16 @@ struct CallInitiatorAndHandler { CallHandler handler; }; +CallInitiatorAndHandler MakeCall( + grpc_event_engine::experimental::EventEngine* event_engine, Arena* arena); + template -auto OutgoingMessages(CallHalf& h) { +auto OutgoingMessages(CallHalf h) { struct Wrapper { - CallHalf& h; + CallHalf h; auto Next() { return h.PullMessage(); } }; - return Wrapper{h}; + return Wrapper{std::move(h)}; } // Forward a call from `call_handler` to `call_initiator` (with initial metadata @@ -925,14 +1016,24 @@ class ClientTransport { class ServerTransport { public: - // AcceptFunction takes initial metadata for a new call and returns a - // CallInitiator object for it, for the transport to use to communicate with - // the CallHandler object passed to the application. - using AcceptFunction = - absl::AnyInvocable(ClientMetadata&) const>; + // Acceptor helps transports create calls. + class Acceptor { + public: + // Returns an arena that can be used to allocate memory for initial metadata + // parsing, and later passed to CreateCall() as the underlying arena for + // that call. + virtual Arena* CreateArena() = 0; + // Create a call at the server (or fail) + // arena must have been previously allocated by CreateArena() + virtual absl::StatusOr CreateCall( + ClientMetadata& client_initial_metadata, Arena* arena) = 0; + + protected: + ~Acceptor() = default; + }; // Called once slightly after transport setup to register the accept function. - virtual void SetAcceptFunction(AcceptFunction accept_function) = 0; + virtual void SetAcceptor(Acceptor* acceptor) = 0; protected: ~ServerTransport() = default; diff --git a/test/core/promise/mpsc_test.cc b/test/core/promise/mpsc_test.cc index 38baa68cb60..3d6e669a173 100644 --- a/test/core/promise/mpsc_test.cc +++ b/test/core/promise/mpsc_test.cc @@ -95,6 +95,8 @@ TEST(MpscTest, SendingLotsOfThingsGivesPushback) { EXPECT_EQ(NowOrNever(sender.Send(MakePayload(1))), true); EXPECT_EQ(NowOrNever(sender.Send(MakePayload(2))), absl::nullopt); activity1.Deactivate(); + + EXPECT_CALL(activity1, WakeupRequested()); } TEST(MpscTest, ReceivingAfterBlockageWakesUp) { diff --git a/test/core/transport/chaotic_good/BUILD b/test/core/transport/chaotic_good/BUILD index fc698aeec44..11daa792100 100644 --- a/test/core/transport/chaotic_good/BUILD +++ b/test/core/transport/chaotic_good/BUILD @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:grpc_build_system.bzl", "grpc_cc_test", "grpc_package") -load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer") +load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package") +load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer", "grpc_proto_fuzzer") licenses(["notice"]) @@ -22,6 +22,34 @@ grpc_package( visibility = "tests", ) +grpc_cc_library( + name = "mock_promise_endpoint", + testonly = 1, + srcs = ["mock_promise_endpoint.cc"], + hdrs = ["mock_promise_endpoint.h"], + external_deps = ["gtest"], + deps = [ + "//:grpc", + "//src/core:grpc_promise_endpoint", + ], +) + +grpc_cc_library( + name = "transport_test", + testonly = 1, + srcs = ["transport_test.cc"], + hdrs = ["transport_test.h"], + external_deps = ["gtest"], + deps = [ + "//:iomgr_timer", + "//src/core:chaotic_good_frame", + "//src/core:memory_quota", + "//src/core:resource_quota", + "//test/core/event_engine/fuzzing_event_engine", + "//test/core/event_engine/fuzzing_event_engine:fuzzing_event_engine_proto", + ], +) + grpc_cc_test( name = "frame_header_test", srcs = ["frame_header_test.cc"], @@ -54,7 +82,7 @@ grpc_cc_test( deps = ["//src/core:chaotic_good_frame"], ) -grpc_fuzzer( +grpc_proto_fuzzer( name = "frame_fuzzer", srcs = ["frame_fuzzer.cc"], corpus = "frame_fuzzer_corpus", @@ -63,7 +91,10 @@ grpc_fuzzer( "absl/status:statusor", ], language = "C++", + proto = "frame_fuzzer.proto", tags = ["no_windows"], + uses_event_engine = False, + uses_polling = False, deps = [ "//:exec_ctx", "//:gpr", @@ -96,18 +127,46 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ + "mock_promise_endpoint", + "transport_test", "//:grpc", "//:grpc_public_hdrs", + "//src/core:arena", + "//src/core:chaotic_good_client_transport", + "//src/core:if", + "//src/core:loop", + "//src/core:seq", + "//src/core:slice_buffer", + ], +) + +grpc_cc_test( + name = "client_transport_error_test", + srcs = ["client_transport_error_test.cc"], + external_deps = [ + "absl/functional:any_invocable", + "absl/status", + "absl/status:statusor", + "absl/strings:str_format", + "absl/types:optional", + "gtest", + ], + language = "C++", + uses_event_engine = False, + uses_polling = False, + deps = [ + "//:grpc_public_hdrs", + "//:grpc_unsecure", "//:iomgr_timer", "//:ref_counted_ptr", "//src/core:activity", "//src/core:arena", "//src/core:chaotic_good_client_transport", "//src/core:event_engine_wakeup_scheduler", + "//src/core:grpc_promise_endpoint", "//src/core:if", "//src/core:join", "//src/core:loop", - "//src/core:map", "//src/core:memory_quota", "//src/core:pipe", "//src/core:resource_quota", @@ -120,8 +179,8 @@ grpc_cc_test( ) grpc_cc_test( - name = "client_transport_error_test", - srcs = ["client_transport_error_test.cc"], + name = "server_transport_test", + srcs = ["server_transport_test.cc"], external_deps = [ "absl/functional:any_invocable", "absl/status", @@ -134,20 +193,15 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ + "mock_promise_endpoint", + "transport_test", "//:grpc", "//:grpc_public_hdrs", "//:iomgr_timer", "//:ref_counted_ptr", - "//src/core:activity", "//src/core:arena", - "//src/core:chaotic_good_client_transport", - "//src/core:event_engine_wakeup_scheduler", - "//src/core:grpc_promise_endpoint", - "//src/core:if", - "//src/core:join", - "//src/core:loop", + "//src/core:chaotic_good_server_transport", "//src/core:memory_quota", - "//src/core:pipe", "//src/core:resource_quota", "//src/core:seq", "//src/core:slice", diff --git a/test/core/transport/chaotic_good/client_transport_error_test.cc b/test/core/transport/chaotic_good/client_transport_error_test.cc index 3b30c4ca330..295e060b809 100644 --- a/test/core/transport/chaotic_good/client_transport_error_test.cc +++ b/test/core/transport/chaotic_good/client_transport_error_test.cc @@ -12,37 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/status/status.h" - -#include "src/core/ext/transport/chaotic_good/client_transport.h" -#include "src/core/lib/transport/promise_endpoint.h" -#include "src/core/lib/transport/transport.h" - -// IWYU pragma: no_include - #include -#include // IWYU pragma: keep +#include #include -#include // IWYU pragma: keep +#include #include #include -#include // IWYU pragma: keep +#include #include "absl/functional/any_invocable.h" -#include "absl/status/statusor.h" // IWYU pragma: keep -#include "absl/strings/str_format.h" // IWYU pragma: keep -#include "absl/types/optional.h" // IWYU pragma: keep +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include #include -#include // IWYU pragma: keep +#include #include #include -#include // IWYU pragma: keep +#include +#include "src/core/ext/transport/chaotic_good/client_transport.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/timer_manager.h" #include "src/core/lib/promise/activity.h" @@ -56,14 +50,16 @@ #include "src/core/lib/resource_quota/memory_quota.h" #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice_buffer.h" -#include "src/core/lib/slice/slice_internal.h" // IWYU pragma: keep -#include "src/core/lib/transport/metadata_batch.h" // IWYU pragma: keep +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/promise_endpoint.h" +#include "src/core/lib/transport/transport.h" #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" +using testing::AtMost; using testing::MockFunction; using testing::Return; -using testing::Sequence; using testing::StrictMock; using testing::WithArgs; @@ -98,333 +94,308 @@ class MockEndpoint GetLocalAddress, (), (const, override)); }; +struct MockPromiseEndpoint { + StrictMock* endpoint = new StrictMock(); + std::unique_ptr promise_endpoint = + std::make_unique( + std::unique_ptr>(endpoint), SliceBuffer()); +}; + +// Send messages from client to server. +auto SendClientToServerMessages(CallInitiator initiator, int num_messages) { + return Loop([initiator, num_messages]() mutable { + bool has_message = (num_messages > 0); + return If( + has_message, + Seq(initiator.PushMessage(GetContext()->MakePooled()), + [&num_messages]() -> LoopCtl { + --num_messages; + return Continue(); + }), + [initiator]() mutable -> LoopCtl { + initiator.FinishSends(); + return absl::OkStatus(); + }); + }); +} + +ClientMetadataHandle TestInitialMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(HttpPathMetadata(), Slice::FromStaticString("/test")); + return md; +} + class ClientTransportTest : public ::testing::Test { - public: - ClientTransportTest() - : control_endpoint_ptr_(new StrictMock()), - data_endpoint_ptr_(new StrictMock()), - memory_allocator_( - ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( - "test")), - control_endpoint_(*control_endpoint_ptr_), - data_endpoint_(*data_endpoint_ptr_), - event_engine_(std::make_shared< - grpc_event_engine::experimental::FuzzingEventEngine>( - []() { - grpc_timer_manager_set_threading(false); - grpc_event_engine::experimental::FuzzingEventEngine::Options - options; - return options; - }(), - fuzzing_event_engine::Actions())), - arena_(MakeScopedArena(initial_arena_size, &memory_allocator_)), - pipe_client_to_server_messages_(arena_.get()), - pipe_server_to_client_messages_(arena_.get()), - pipe_server_intial_metadata_(arena_.get()), - pipe_client_to_server_messages_second_(arena_.get()), - pipe_server_to_client_messages_second_(arena_.get()), - pipe_server_intial_metadata_second_(arena_.get()) {} - // Initial ClientTransport with read expecations - void InitialClientTransport() { - client_transport_ = std::make_unique( - std::make_unique( - std::unique_ptr(control_endpoint_ptr_), - SliceBuffer()), - std::make_unique( - std::unique_ptr(data_endpoint_ptr_), SliceBuffer()), - event_engine_); - } - // Send messages from client to server. - auto SendClientToServerMessages( - Pipe& pipe_client_to_server_messages, - int num_of_messages) { - return Loop([&pipe_client_to_server_messages, num_of_messages, - this]() mutable { - bool has_message = (num_of_messages > 0); - return If( - has_message, - Seq(pipe_client_to_server_messages.sender.Push( - arena_->MakePooled()), - [&num_of_messages]() -> LoopCtl { - num_of_messages--; - return Continue(); - }), - [&pipe_client_to_server_messages]() mutable -> LoopCtl { - pipe_client_to_server_messages.sender.Close(); - return absl::OkStatus(); - }); - }); - } - // Add stream into client transport, and expect return trailers of - // "grpc-status:code". - auto AddStream(CallArgs args) { - return client_transport_->AddStream(std::move(args)); + protected: + const std::shared_ptr& + event_engine() { + return event_engine_; } + MemoryAllocator* memory_allocator() { return &allocator_; } private: - MockEndpoint* control_endpoint_ptr_; - MockEndpoint* data_endpoint_ptr_; - size_t initial_arena_size = 1024; - MemoryAllocator memory_allocator_; - - protected: - MockEndpoint& control_endpoint_; - MockEndpoint& data_endpoint_; std::shared_ptr - event_engine_; - std::unique_ptr client_transport_; - ScopedArenaPtr arena_; - Pipe pipe_client_to_server_messages_; - Pipe pipe_server_to_client_messages_; - Pipe pipe_server_intial_metadata_; - // Added for mutliple streams tests. - Pipe pipe_client_to_server_messages_second_; - Pipe pipe_server_to_client_messages_second_; - Pipe pipe_server_intial_metadata_second_; - absl::AnyInvocable read_callback_; - Sequence control_endpoint_sequence_; - Sequence data_endpoint_sequence_; - // Added to verify received message payload. - const std::string message_ = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}; + event_engine_{ + std::make_shared( + []() { + grpc_timer_manager_set_threading(false); + grpc_event_engine::experimental::FuzzingEventEngine::Options + options; + return options; + }(), + fuzzing_event_engine::Actions())}; + MemoryAllocator allocator_ = MakeResourceQuota("test-quota") + ->memory_quota() + ->CreateMemoryAllocator("test-allocator"); }; TEST_F(ClientTransportTest, AddOneStreamWithWriteFailed) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; // Mock write failed and read is pending. - EXPECT_CALL(control_endpoint_, Write) + EXPECT_CALL(*control_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillOnce( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("control endpoint write failed.")); return false; })); - EXPECT_CALL(data_endpoint_, Write) + EXPECT_CALL(*data_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillOnce( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("data endpoint write failed.")); return false; })); - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) - .WillOnce(Return(false)); - InitialClientTransport(); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(args)), - // Send messages to call_args.client_to_server_messages pipe, - // which will be eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(*control_endpoint.endpoint, Read).WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call.handler)); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } TEST_F(ClientTransportTest, AddOneStreamWithReadFailed) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; // Mock read failed. - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) + EXPECT_CALL(*control_endpoint.endpoint, Read) .WillOnce(WithArgs<0>( [](absl::AnyInvocable on_read) mutable { on_read(absl::InternalError("control endpoint read failed.")); // Return false to mock EventEngine read not finish. return false; })); - InitialClientTransport(); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(args)), - // Send messages to call_args.client_to_server_messages pipe. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call.handler)); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } TEST_F(ClientTransportTest, AddMultipleStreamWithWriteFailed) { // Mock write failed at first stream and second stream's write will fail too. - EXPECT_CALL(control_endpoint_, Write) - .Times(1) + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + EXPECT_CALL(*control_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillRepeatedly( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("control endpoint write failed.")); return false; })); - EXPECT_CALL(data_endpoint_, Write) - .Times(1) + EXPECT_CALL(*data_endpoint.endpoint, Write) + .Times(AtMost(1)) .WillRepeatedly( WithArgs<0>([](absl::AnyInvocable on_write) { on_write(absl::InternalError("data endpoint write failed.")); return false; })); - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) - .WillOnce(Return(false)); - InitialClientTransport(); - ClientMetadataHandle first_stream_md; - ClientMetadataHandle second_stream_md; - auto first_stream_args = - CallArgs{std::move(first_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - auto second_stream_args = - CallArgs{std::move(second_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_second_.sender, - &pipe_client_to_server_messages_second_.receiver, - &pipe_server_to_client_messages_second_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages from client transport. - Join( - // Add first stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(first_stream_args)), - // Send messages to first stream's - // call_args.client_to_server_messages pipe. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }, - Join( - // Add second stream with call_args into client transport. - // Expect return trailers "grpc-status:unavailable". - AddStream(std::move(second_stream_args)), - // Send messages to second stream's - // call_args.client_to_server_messages pipe. - SendClientToServerMessages(pipe_client_to_server_messages_second_, - 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + EXPECT_CALL(*control_endpoint.endpoint, Read).WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call1 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call1.handler)); + auto call2 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call2.handler)); + call1.initiator.SpawnGuarded("test-send-1", [initiator = + call1.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + call2.initiator.SpawnGuarded("test-send-2", [initiator = + call2.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done1; + EXPECT_CALL(on_done1, Call()); + StrictMock> on_done2; + EXPECT_CALL(on_done2, Call()); + call1.initiator.SpawnInfallible( + "test-read-1", [&on_done1, initiator = call1.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done1](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done1.Call(); + return Empty{}; + }); + }); + call2.initiator.SpawnInfallible( + "test-read-2", [&on_done2, initiator = call2.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done2](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done2.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } TEST_F(ClientTransportTest, AddMultipleStreamWithReadFailed) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; // Mock read failed at first stream, and second stream's write will fail too. - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence_) + EXPECT_CALL(*control_endpoint.endpoint, Read) .WillOnce(WithArgs<0>( [](absl::AnyInvocable on_read) mutable { on_read(absl::InternalError("control endpoint read failed.")); // Return false to mock EventEngine read not finish. return false; })); - InitialClientTransport(); - ClientMetadataHandle first_stream_md; - ClientMetadataHandle second_stream_md; - auto first_stream_args = - CallArgs{std::move(first_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - auto second_stream_args = - CallArgs{std::move(second_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_second_.sender, - &pipe_client_to_server_messages_second_.receiver, - &pipe_server_to_client_messages_second_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages from client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(first_stream_args)), - // Send messages to first stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }, - Join( - // Add second stream with call_args into client transport. - AddStream(std::move(second_stream_args)), - // Send messages to second stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_second_, - 1)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret)->get(GrpcStatusMetadata()).value(), - GRPC_STATUS_UNAVAILABLE); - EXPECT_TRUE(std::get<1>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call1 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call1.handler)); + auto call2 = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call2.handler)); + call1.initiator.SpawnGuarded("test-send", [initiator = + call1.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + call2.initiator.SpawnGuarded("test-send", [initiator = + call2.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + StrictMock> on_done1; + EXPECT_CALL(on_done1, Call()); + StrictMock> on_done2; + EXPECT_CALL(on_done2, Call()); + call1.initiator.SpawnInfallible( + "test-read", [&on_done1, initiator = call1.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done1](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done1.Call(); + return Empty{}; + }); + }); + call2.initiator.SpawnInfallible( + "test-read", [&on_done2, initiator = call2.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_FALSE(md.ok()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done2](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), + GRPC_STATUS_UNAVAILABLE); + on_done2.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } } // namespace testing diff --git a/test/core/transport/chaotic_good/client_transport_test.cc b/test/core/transport/chaotic_good/client_transport_test.cc index 86bbf578c29..551c6fd5762 100644 --- a/test/core/transport/chaotic_good/client_transport_test.cc +++ b/test/core/transport/chaotic_good/client_transport_test.cc @@ -14,461 +14,245 @@ #include "src/core/ext/transport/chaotic_good/client_transport.h" -// IWYU pragma: no_include - -#include // IWYU pragma: keep +#include +#include +#include #include -#include // IWYU pragma: keep +#include #include -#include // IWYU pragma: keep +#include #include "absl/functional/any_invocable.h" -#include "absl/status/statusor.h" // IWYU pragma: keep -#include "absl/strings/str_format.h" // IWYU pragma: keep -#include "absl/types/optional.h" // IWYU pragma: keep +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include #include -#include // IWYU pragma: keep +#include #include #include -#include // IWYU pragma: keep +#include -#include "src/core/lib/gprpp/ref_counted_ptr.h" -#include "src/core/lib/iomgr/timer_manager.h" -#include "src/core/lib/promise/activity.h" -#include "src/core/lib/promise/event_engine_wakeup_scheduler.h" #include "src/core/lib/promise/if.h" -#include "src/core/lib/promise/join.h" #include "src/core/lib/promise/loop.h" -#include "src/core/lib/promise/map.h" -#include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/seq.h" #include "src/core/lib/resource_quota/arena.h" -#include "src/core/lib/resource_quota/memory_quota.h" -#include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice_buffer.h" -#include "src/core/lib/slice/slice_internal.h" // IWYU pragma: keep -#include "src/core/lib/transport/metadata_batch.h" // IWYU pragma: keep -#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" -#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "test/core/transport/chaotic_good/mock_promise_endpoint.h" +#include "test/core/transport/chaotic_good/transport_test.h" using testing::MockFunction; using testing::Return; -using testing::Sequence; using testing::StrictMock; -using testing::WithArgs; + +using EventEngineSlice = grpc_event_engine::experimental::Slice; namespace grpc_core { namespace chaotic_good { namespace testing { -class MockEndpoint - : public grpc_event_engine::experimental::EventEngine::Endpoint { - public: - MOCK_METHOD( - bool, Read, - (absl::AnyInvocable on_read, - grpc_event_engine::experimental::SliceBuffer* buffer, - const grpc_event_engine::experimental::EventEngine::Endpoint::ReadArgs* - args), - (override)); +// Encoded string of header ":path: /demo.Service/Step". +const uint8_t kPathDemoServiceStep[] = { + 0x40, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f, + 0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70}; - MOCK_METHOD( - bool, Write, - (absl::AnyInvocable on_writable, - grpc_event_engine::experimental::SliceBuffer* data, - const grpc_event_engine::experimental::EventEngine::Endpoint::WriteArgs* - args), - (override)); +// Encoded string of trailer "grpc-status: 0". +const uint8_t kGrpcStatus0[] = {0x10, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x01, 0x30}; - MOCK_METHOD( - const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, - GetPeerAddress, (), (const, override)); - MOCK_METHOD( - const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, - GetLocalAddress, (), (const, override)); -}; +ClientMetadataHandle TestInitialMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(HttpPathMetadata(), Slice::FromStaticString("/demo.Service/Step")); + return md; +} -class ClientTransportTest : public ::testing::Test { - public: - ClientTransportTest() - : control_endpoint_ptr_(new StrictMock()), - data_endpoint_ptr_(new StrictMock()), - memory_allocator_( - ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( - "test")), - control_endpoint_(*control_endpoint_ptr_), - data_endpoint_(*data_endpoint_ptr_), - event_engine_(std::make_shared< - grpc_event_engine::experimental::FuzzingEventEngine>( - []() { - grpc_timer_manager_set_threading(false); - grpc_event_engine::experimental::FuzzingEventEngine::Options - options; - return options; - }(), - fuzzing_event_engine::Actions())), - arena_(MakeScopedArena(initial_arena_size, &memory_allocator_)), - pipe_client_to_server_messages_(arena_.get()), - pipe_server_to_client_messages_(arena_.get()), - pipe_server_intial_metadata_(arena_.get()), - pipe_client_to_server_messages_second_(arena_.get()), - pipe_server_to_client_messages_second_(arena_.get()), - pipe_server_intial_metadata_second_(arena_.get()) {} - // Expect how client transport will read from control/data endpoints with a - // test frame. - void AddReadExpectations(int num_of_streams) { - for (int i = 0; i < num_of_streams; i++) { - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence) - .WillOnce(WithArgs<0, 1>( - [this, i](absl::AnyInvocable on_read, - grpc_event_engine::experimental::SliceBuffer* - buffer) mutable { - // Construct test frame for EventEngine read: headers (15 - // bytes), message(16 bytes), message padding (48 byte), - // trailers (15 bytes). - const std::string frame_header = { - static_cast(0x80), // frame type = fragment - 0x03, // flag = has header + has trailer - 0x00, - 0x00, - static_cast(i + 1), // stream id = 1 - 0x00, - 0x00, - 0x00, - 0x1a, // header length = 26 - 0x00, - 0x00, - 0x00, - 0x08, // message length = 8 - 0x00, - 0x00, - 0x00, - 0x38, // message padding =56 - 0x00, - 0x00, - 0x00, - 0x0f, // trailer length = 15 - 0x00, - 0x00, - 0x00}; - // Schedule mock_endpoint to read buffer. - grpc_event_engine::experimental::Slice slice( - grpc_slice_from_cpp_string(frame_header)); - buffer->Append(std::move(slice)); - // Execute read callback later to control when read starts. - if (i == 0) { - read_callback_ = std::move(on_read); - // Return false to mock EventEngine read not finish. - return false; - } else { - return true; - } - })); - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence) - .WillOnce(WithArgs<1>( - [](grpc_event_engine::experimental::SliceBuffer* buffer) { - // Encoded string of header ":path: /demo.Service/Step". - const std::string header = { - 0x10, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f, - 0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70}; - // Encoded string of trailer "grpc-status: 0". - const std::string trailers = {0x10, 0x0b, 0x67, 0x72, 0x70, - 0x63, 0x2d, 0x73, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x01, 0x30}; - // Schedule mock_endpoint to read buffer. - grpc_event_engine::experimental::Slice slice( - grpc_slice_from_cpp_string(header + trailers)); - buffer->Append(std::move(slice)); - return true; - })); - } - EXPECT_CALL(control_endpoint_, Read) - .InSequence(control_endpoint_sequence) - .WillOnce(Return(false)); - for (int i = 0; i < num_of_streams; i++) { - EXPECT_CALL(data_endpoint_, Read) - .InSequence(data_endpoint_sequence) - .WillOnce(WithArgs<1>( - [this](grpc_event_engine::experimental::SliceBuffer* buffer) { - const std::string message_padding = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; - grpc_event_engine::experimental::Slice slice( - grpc_slice_from_cpp_string(message_padding + message_)); - buffer->Append(std::move(slice)); - return true; - })); - } - } - // Initial ClientTransport with read expecations - void InitialClientTransport(int num_of_streams) { - // Read expectaions need to be added before transport initialization since - // reader_ activity loop is started in ClientTransport initialization, - AddReadExpectations(num_of_streams); - client_transport_ = std::make_unique( - std::make_unique( - std::unique_ptr(control_endpoint_ptr_), - SliceBuffer()), - std::make_unique( - std::unique_ptr(data_endpoint_ptr_), SliceBuffer()), - event_engine_); - } - // Send messages from client to server. - auto SendClientToServerMessages( - Pipe& pipe_client_to_server_messages, - int num_of_messages) { - return Loop([&pipe_client_to_server_messages, num_of_messages, - this]() mutable { - bool has_message = (num_of_messages > 0); - return If( - has_message, - Seq(pipe_client_to_server_messages.sender.Push( - arena_->MakePooled()), - [&num_of_messages]() -> LoopCtl { - num_of_messages--; - return Continue(); - }), - [&pipe_client_to_server_messages]() mutable -> LoopCtl { - pipe_client_to_server_messages.sender.Close(); - return absl::OkStatus(); - }); - }); - } - // Add stream into client transport, and expect return trailers of - // "grpc-status:code". - auto AddStream(CallArgs args, const grpc_status_code trailers) { - return Seq(client_transport_->AddStream(std::move(args)), - [trailers](ServerMetadataHandle ret) { - // AddStream will finish with server trailers: - // "grpc-status:code". - EXPECT_EQ(ret->get(GrpcStatusMetadata()).value(), trailers); - return trailers; - }); - } - // Start read from control endpoints. - auto StartRead(const absl::Status& read_status) { - return [read_status, this] { - read_callback_(read_status); - return read_status; - }; - } - // Receive messages from server to client. - auto ReceiveServerToClientMessages( - Pipe& pipe_server_intial_metadata, - Pipe& pipe_server_to_client_messages) { - return Seq( - // Receive server initial metadata. - Map(pipe_server_intial_metadata.receiver.Next(), - [](NextResult r) { - // Expect value: ":path: /demo.Service/Step" - EXPECT_TRUE(r.has_value()); - EXPECT_EQ( - r.value()->get_pointer(HttpPathMetadata())->as_string_view(), - "/demo.Service/Step"); - return absl::OkStatus(); - }), - // Receive server to client messages. - Map(pipe_server_to_client_messages.receiver.Next(), - [this](NextResult r) { - EXPECT_TRUE(r.has_value()); - EXPECT_EQ(r.value()->payload()->JoinIntoString(), message_); - return absl::OkStatus(); +// Send messages from client to server. +auto SendClientToServerMessages(CallInitiator initiator, int num_messages) { + return Loop([initiator, num_messages, i = 0]() mutable { + bool has_message = (i < num_messages); + return If( + has_message, + Seq(initiator.PushMessage(GetContext()->MakePooled( + SliceBuffer(Slice::FromCopiedString(std::to_string(i))), 0)), + [&i]() -> LoopCtl { + ++i; + return Continue(); }), - [&pipe_server_intial_metadata, - &pipe_server_to_client_messages]() mutable { - // Close pipes after receive message. - pipe_server_to_client_messages.sender.Close(); - pipe_server_intial_metadata.sender.Close(); + [initiator]() mutable -> LoopCtl { + initiator.FinishSends(); return absl::OkStatus(); }); - } - - private: - MockEndpoint* control_endpoint_ptr_; - MockEndpoint* data_endpoint_ptr_; - size_t initial_arena_size = 1024; - MemoryAllocator memory_allocator_; - Sequence control_endpoint_sequence; - Sequence data_endpoint_sequence; - - protected: - MockEndpoint& control_endpoint_; - MockEndpoint& data_endpoint_; - std::shared_ptr - event_engine_; - std::unique_ptr client_transport_; - ScopedArenaPtr arena_; - Pipe pipe_client_to_server_messages_; - Pipe pipe_server_to_client_messages_; - Pipe pipe_server_intial_metadata_; - // Added for mutliple streams tests. - Pipe pipe_client_to_server_messages_second_; - Pipe pipe_server_to_client_messages_second_; - Pipe pipe_server_intial_metadata_second_; - absl::AnyInvocable read_callback_; - // Added to verify received message payload. - const std::string message_ = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}; -}; - -TEST_F(ClientTransportTest, AddOneStream) { - InitialClientTransport(1); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - EXPECT_CALL(control_endpoint_, Write).WillOnce(Return(true)); - EXPECT_CALL(data_endpoint_, Write).WillOnce(Return(true)); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(args), GRPC_STATUS_OK), - // Start read from control endpoints. - StartRead(absl::OkStatus()), - // Send messages to call_args.client_to_server_messages pipe, - // which will be eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 1), - // Receive messages from control/data endpoints. - ReceiveServerToClientMessages(pipe_server_intial_metadata_, - pipe_server_to_client_messages_)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK); - EXPECT_TRUE(std::get<1>(ret).ok()); - EXPECT_TRUE(std::get<2>(ret).ok()); - EXPECT_TRUE(std::get<3>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); - // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + }); } -TEST_F(ClientTransportTest, AddOneStreamMultipleMessages) { - InitialClientTransport(1); - ClientMetadataHandle md; - auto args = CallArgs{std::move(md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - EXPECT_CALL(control_endpoint_, Write).Times(3).WillRepeatedly(Return(true)); - EXPECT_CALL(data_endpoint_, Write).Times(3).WillRepeatedly(Return(true)); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages in client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(args), GRPC_STATUS_OK), - // Start read from control endpoints. - StartRead(absl::OkStatus()), - // Send messages to call_args.client_to_server_messages pipe, - // which will be eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 3), - // Receive messages from control/data endpoints. - ReceiveServerToClientMessages(pipe_server_intial_metadata_, - pipe_server_to_client_messages_)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& ret) { - EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK); - EXPECT_TRUE(std::get<1>(ret).ok()); - EXPECT_TRUE(std::get<2>(ret).ok()); - EXPECT_TRUE(std::get<3>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +TEST_F(TransportTest, AddOneStream) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 7, 1, 26, 8, 56, 15), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep)), + EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))}, + event_engine().get()); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr); + EXPECT_CALL(*control_endpoint.endpoint, Read) + .InSequence(control_endpoint.read_sequence) + .WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(1024, memory_allocator())); + transport->StartCall(std::move(call.handler)); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 1, 1, + sizeof(kPathDemoServiceStep), 0, 0, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("0"), Zeros(63)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, 0)}, nullptr); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 1)); + }); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_TRUE(md.ok()); + EXPECT_EQ( + md.value()->get_pointer(HttpPathMetadata())->as_string_view(), + "/demo.Service/Step"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_FALSE(msg.has_value()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), GRPC_STATUS_OK); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } -TEST_F(ClientTransportTest, AddMultipleStreamsMultipleMessages) { - InitialClientTransport(2); - ClientMetadataHandle first_stream_md; - ClientMetadataHandle second_stream_md; - auto first_stream_args = - CallArgs{std::move(first_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_.sender, - &pipe_client_to_server_messages_.receiver, - &pipe_server_to_client_messages_.sender}; - auto second_stream_args = - CallArgs{std::move(second_stream_md), - ClientInitialMetadataOutstandingToken::Empty(), - nullptr, - &pipe_server_intial_metadata_second_.sender, - &pipe_client_to_server_messages_second_.receiver, - &pipe_server_to_client_messages_second_.sender}; - StrictMock> on_done; - EXPECT_CALL(on_done, Call(absl::OkStatus())); - EXPECT_CALL(control_endpoint_, Write).Times(6).WillRepeatedly(Return(true)); - EXPECT_CALL(data_endpoint_, Write).Times(6).WillRepeatedly(Return(true)); - auto activity = MakeActivity( - Seq( - // Concurrently: write and read messages from client transport. - Join( - // Add first stream with call_args into client transport. - AddStream(std::move(first_stream_args), GRPC_STATUS_OK), - // Start read from control endpoints. - StartRead(absl::OkStatus()), - // Send messages to first stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_, 3), - // Receive first stream's messages from control/data endpoints. - ReceiveServerToClientMessages(pipe_server_intial_metadata_, - pipe_server_to_client_messages_)), - Join( - // Add second stream with call_args into client transport. - AddStream(std::move(second_stream_args), GRPC_STATUS_OK), - // Send messages to second stream's - // call_args.client_to_server_messages pipe, which will be - // eventually sent to control/data endpoints. - SendClientToServerMessages(pipe_client_to_server_messages_second_, - 3), - // Receive second stream's messages from control/data endpoints. - ReceiveServerToClientMessages( - pipe_server_intial_metadata_second_, - pipe_server_to_client_messages_second_)), - // Once complete, verify successful sending and the received value. - [](const std::tuple& - ret) { - EXPECT_EQ(std::get<0>(ret), GRPC_STATUS_OK); - EXPECT_TRUE(std::get<1>(ret).ok()); - EXPECT_TRUE(std::get<2>(ret).ok()); - return absl::OkStatus(); - }), - EventEngineWakeupScheduler(event_engine_), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +TEST_F(TransportTest, AddOneStreamMultipleMessages) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 3, 1, 26, 8, 56, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + event_engine().get()); + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 6, 1, 0, 8, 56, 15), + EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))}, + event_engine().get()); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("87654321"), Zeros(56)}, nullptr); + EXPECT_CALL(*control_endpoint.endpoint, Read) + .InSequence(control_endpoint.read_sequence) + .WillOnce(Return(false)); + auto transport = MakeOrphanable( + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + auto call = + MakeCall(event_engine().get(), Arena::Create(8192, memory_allocator())); + transport->StartCall(std::move(call.handler)); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 1, 1, + sizeof(kPathDemoServiceStep), 0, 0, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("0"), Zeros(63)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 1, 63, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("1"), Zeros(63)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, 0)}, nullptr); + call.initiator.SpawnGuarded("test-send", [initiator = + call.initiator]() mutable { + return TrySeq(initiator.PushClientInitialMetadata(TestInitialMetadata()), + SendClientToServerMessages(initiator, 2)); + }); + call.initiator.SpawnInfallible( + "test-read", [&on_done, initiator = call.initiator]() mutable { + return Seq( + initiator.PullServerInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_TRUE(md.ok()); + EXPECT_EQ( + md.value()->get_pointer(HttpPathMetadata())->as_string_view(), + "/demo.Service/Step"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "87654321"); + return Empty{}; + }, + initiator.PullMessage(), + [](NextResult msg) { + EXPECT_FALSE(msg.has_value()); + return Empty{}; + }, + initiator.PullServerTrailingMetadata(), + [&on_done](ServerMetadataHandle md) { + EXPECT_EQ(md->get(GrpcStatusMetadata()).value(), GRPC_STATUS_OK); + on_done.Call(); + return Empty{}; + }); + }); // Wait until ClientTransport's internal activities to finish. - event_engine_->TickUntilIdle(); - event_engine_->UnsetGlobalHooks(); + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); } } // namespace testing diff --git a/test/core/transport/chaotic_good/frame_fuzzer.cc b/test/core/transport/chaotic_good/frame_fuzzer.cc index 57180ce1c20..03481560771 100644 --- a/test/core/transport/chaotic_good/frame_fuzzer.cc +++ b/test/core/transport/chaotic_good/frame_fuzzer.cc @@ -35,7 +35,9 @@ #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_buffer.h" +#include "src/libfuzzer/libfuzzer_macro.h" #include "test/core/promise/test_context.h" +#include "test/core/transport/chaotic_good/frame_fuzzer.pb.h" bool squelch = false; @@ -51,10 +53,10 @@ template void AssertRoundTrips(const T& input, FrameType expected_frame_type) { HPackCompressor hpack_compressor; auto serialized = input.Serialize(&hpack_compressor); - GPR_ASSERT(serialized.Length() >= + GPR_ASSERT(serialized.control.Length() >= 24); // Initial output buffer size is 64 byte. uint8_t header_bytes[24]; - serialized.MoveFirstNBytesIntoBuffer(24, header_bytes); + serialized.control.MoveFirstNBytesIntoBuffer(24, header_bytes); auto header = FrameHeader::Parse(header_bytes); if (!header.ok()) { if (!squelch) { @@ -67,66 +69,69 @@ void AssertRoundTrips(const T& input, FrameType expected_frame_type) { T output; HPackParser hpack_parser; DeterministicBitGen bitgen; - auto deser = output.Deserialize(&hpack_parser, header.value(), - absl::BitGenRef(bitgen), serialized); + auto deser = + output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen), + GetContext(), std::move(serialized)); GPR_ASSERT(deser.ok()); GPR_ASSERT(output == input); } template -void FinishParseAndChecks(const FrameHeader& header, const uint8_t* data, - size_t size) { +void FinishParseAndChecks(const FrameHeader& header, BufferPair buffers) { T parsed; ExecCtx exec_ctx; // Initialized to get this_cpu() info in global_stat(). HPackParser hpack_parser; - SliceBuffer serialized; - serialized.Append(Slice::FromCopiedBuffer(data, size)); DeterministicBitGen bitgen; - auto deser = parsed.Deserialize(&hpack_parser, header, - absl::BitGenRef(bitgen), serialized); + auto deser = + parsed.Deserialize(&hpack_parser, header, absl::BitGenRef(bitgen), + GetContext(), std::move(buffers)); if (!deser.ok()) return; gpr_log(GPR_INFO, "Read frame: %s", parsed.ToString().c_str()); AssertRoundTrips(parsed, header.type); } -int Run(const uint8_t* data, size_t size) { - if (size < 1) return 0; - const bool is_server = (data[0] & 1) != 0; - size--; - data++; - if (size < 24) return 0; - auto r = FrameHeader::Parse(data); - if (!r.ok()) return 0; +void Run(const frame_fuzzer::Test& test) { + const uint8_t* control_data = + reinterpret_cast(test.control().data()); + size_t control_size = test.control().size(); + if (test.control().size() < 24) return; + auto r = FrameHeader::Parse(control_data); + if (!r.ok()) return; + if (test.data().size() != r->message_length) return; gpr_log(GPR_INFO, "Read frame header: %s", r->ToString().c_str()); - size -= 24; - data += 24; + control_data += 24; + control_size -= 24; MemoryAllocator memory_allocator = MemoryAllocator( ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test")); auto arena = MakeScopedArena(1024, &memory_allocator); TestContext ctx(arena.get()); + BufferPair buffers{ + SliceBuffer(Slice::FromCopiedBuffer(control_data, control_size)), + SliceBuffer( + Slice::FromCopiedBuffer(test.data().data(), test.data().size())), + }; switch (r->type) { default: - return 0; // We don't know how to parse this frame type. + return; // We don't know how to parse this frame type. case FrameType::kSettings: - FinishParseAndChecks(*r, data, size); + FinishParseAndChecks(*r, std::move(buffers)); break; case FrameType::kFragment: - if (is_server) { - FinishParseAndChecks(*r, data, size); + if (test.is_server()) { + FinishParseAndChecks(*r, std::move(buffers)); } else { - FinishParseAndChecks(*r, data, size); + FinishParseAndChecks(*r, std::move(buffers)); } break; case FrameType::kCancel: - FinishParseAndChecks(*r, data, size); + FinishParseAndChecks(*r, std::move(buffers)); break; } - return 0; } } // namespace chaotic_good } // namespace grpc_core -extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - return grpc_core::chaotic_good::Run(data, size); +DEFINE_PROTO_FUZZER(const frame_fuzzer::Test& test) { + grpc_core::chaotic_good::Run(test); } diff --git a/test/core/transport/chaotic_good/frame_fuzzer.proto b/test/core/transport/chaotic_good/frame_fuzzer.proto new file mode 100644 index 00000000000..4ae8657e588 --- /dev/null +++ b/test/core/transport/chaotic_good/frame_fuzzer.proto @@ -0,0 +1,23 @@ +// Copyright 2021 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package frame_fuzzer; + +message Test { + bool is_server = 1; + bytes control = 2; + bytes data = 3; +} diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/5072496117219328 deleted file mode 100644 index 16d6e2f4fde2813246ca23877e80d539e834b3eb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 26 WcmcC+U}Ru0)K%dGQeXh&Lqq^0Hv&xn diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/5691448031772672 deleted file mode 100644 index 98e8a28385d868b52dc0209da655bed0b3deb36d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 51 zcmY#rU}Rv(Q+)JKiUAB(8UiU*2B4sDdQm}gPEvewPG)LeNqlihVo5PjyqEz1DVhwJ diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-05c704327d21af2cc914de40e9d90d06f16ca0eb deleted file mode 100644 index 0340f7dee017d1c99db70bcda61ea432dab08e71..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 74 ccmZQ*U}a!nbYQp!q{svR|J&OB|F82O0Ex&4LjV8( diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5015de8c7cafb0b0ebbbfd28c29aedd5dbfdc03a deleted file mode 100644 index 366c53025968158085863d58678bcd130f39ab81..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 66 ScmZQ*U}IolZ~$U5!G8cUK>@}9 diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-5a34978de8de6889ce913947a77f43f7cdea854c deleted file mode 100644 index 74cb18c8e938ead8bc408336ba7cf8f0700045e8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 180 kcmZQ*U}9ikaA3H`0i<+*_$&j10+1R^K#f7?KazX20O4l@H2?qr diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-608f798a51077a8cdc45b11f335c079a81339fbe deleted file mode 100644 index 2190f6bc1859843a84a8a46d0f0d1392e9995dbd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 166 mcmZQ*U}9ikbYQq9!2kgYP?~h=8v_Fi>ALXgc>EnA`yT*|4h5wE diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-6a002cb46eac21af4ab6fd74b61ff3ce26d96dff deleted file mode 100644 index 90739c7258095c3bd8026055b6e80dd9912dac00..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 161 scmZQ*U}9ikaA3H`0i<+*_$&j10+1pBI6&0?iJnY?CZd$8G3fjU09rQ&H2?qr diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-7732ddd35a4deb8b7c9e462aaf8680986755e540 deleted file mode 100644 index 4d14b159ba6b64ea200a8dd65c0f9bfabb1a1972..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 79 ccmZQ*U}a!nbYQr~$Uq7B|KHa3|9_qT0Fc@TKL7v# diff --git a/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 b/test/core/transport/chaotic_good/frame_fuzzer_corpus/crash-c171e98ebfe8b6485f9a4bea0b9cdfe683776675 deleted file mode 100644 index 8fa12bd3aeaf54fed89e4f4e2f08bf1bc370ae21..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 70 VcmZQ Deserialize(std::vector data) { } TEST(FrameHeaderTest, SimpleSerialize) { - EXPECT_EQ(Serialize(FrameHeader{FrameType::kCancel, BitSet<2>::FromInt(0), + EXPECT_EQ(Serialize(FrameHeader{FrameType::kCancel, BitSet<3>::FromInt(0), 0x01020304, 0x05060708, 0x090a0b0c, 0x00000034, 0x0d0e0f10}), std::vector({ @@ -59,7 +59,7 @@ TEST(FrameHeaderTest, SimpleDeserialize) { 0x10, 0x0f, 0x0e, 0x0d // trailer_length })), absl::StatusOr(FrameHeader{ - FrameType::kCancel, BitSet<2>::FromInt(0), 0x01020304, + FrameType::kCancel, BitSet<3>::FromInt(0), 0x01020304, 0x05060708, 0x090a0b0c, 0x00000034, 0x0d0e0f10})); EXPECT_EQ(Deserialize(std::vector({ 0x81, 88, 88, 88, // type, flags @@ -75,19 +75,19 @@ TEST(FrameHeaderTest, SimpleDeserialize) { TEST(FrameHeaderTest, GetFrameLength) { EXPECT_EQ( - (FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 0, 0, 0, 0}) + (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 0, 0, 0}) .GetFrameLength(), 0); EXPECT_EQ( - (FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 14, 0, 0, 0}) + (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 14, 0, 0, 0}) .GetFrameLength(), 14); - EXPECT_EQ((FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 0, 14, + EXPECT_EQ((FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 14, 50, 0}) .GetFrameLength(), 0); EXPECT_EQ( - (FrameHeader{FrameType::kFragment, BitSet<2>::FromInt(3), 1, 0, 0, 0, 14}) + (FrameHeader{FrameType::kFragment, BitSet<3>::FromInt(5), 1, 0, 0, 0, 14}) .GetFrameLength(), 14); } diff --git a/test/core/transport/chaotic_good/frame_test.cc b/test/core/transport/chaotic_good/frame_test.cc index 00908f75a6e..15153a09b8d 100644 --- a/test/core/transport/chaotic_good/frame_test.cc +++ b/test/core/transport/chaotic_good/frame_test.cc @@ -21,27 +21,38 @@ #include "absl/status/statusor.h" #include "gtest/gtest.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/resource_quota.h" + namespace grpc_core { namespace chaotic_good { namespace { template -void AssertRoundTrips(const T input, FrameType expected_frame_type) { +void AssertRoundTrips(const T& input, FrameType expected_frame_type) { HPackCompressor hpack_compressor; - absl::BitGen bitgen; auto serialized = input.Serialize(&hpack_compressor); - EXPECT_GE(serialized.Length(), 24); + GPR_ASSERT(serialized.control.Length() >= + 24); // Initial output buffer size is 64 byte. uint8_t header_bytes[24]; - serialized.MoveFirstNBytesIntoBuffer(24, header_bytes); + serialized.control.MoveFirstNBytesIntoBuffer(24, header_bytes); auto header = FrameHeader::Parse(header_bytes); - EXPECT_TRUE(header.ok()) << header.status(); - EXPECT_EQ(header->type, expected_frame_type); + if (!header.ok()) { + Crash("Failed to parse header"); + } + GPR_ASSERT(header->type == expected_frame_type); T output; HPackParser hpack_parser; - auto deser = output.Deserialize(&hpack_parser, header.value(), - absl::BitGenRef(bitgen), serialized); - EXPECT_TRUE(deser.ok()) << deser; - EXPECT_EQ(output, input); + absl::BitGen bitgen; + MemoryAllocator allocator = MakeResourceQuota("test-quota") + ->memory_quota() + ->CreateMemoryAllocator("test-allocator"); + ScopedArenaPtr arena = MakeScopedArena(1024, &allocator); + auto deser = + output.Deserialize(&hpack_parser, header.value(), absl::BitGenRef(bitgen), + arena.get(), std::move(serialized)); + GPR_ASSERT(deser.ok()); + GPR_ASSERT(output == input); } TEST(FrameTest, SettingsFrameRoundTrips) { diff --git a/test/core/transport/chaotic_good/mock_promise_endpoint.cc b/test/core/transport/chaotic_good/mock_promise_endpoint.cc new file mode 100644 index 00000000000..9ba96e75804 --- /dev/null +++ b/test/core/transport/chaotic_good/mock_promise_endpoint.cc @@ -0,0 +1,89 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/transport/chaotic_good/mock_promise_endpoint.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include + +using EventEngineSlice = grpc_event_engine::experimental::Slice; +using grpc_event_engine::experimental::EventEngine; + +using testing::WithArgs; + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +void MockPromiseEndpoint::ExpectRead( + std::initializer_list slices_init, + EventEngine* schedule_on_event_engine) { + std::vector slices; + for (auto&& slice : slices_init) slices.emplace_back(slice.Copy()); + EXPECT_CALL(*endpoint, Read) + .InSequence(read_sequence) + .WillOnce(WithArgs<0, 1>( + [slices = std::move(slices), schedule_on_event_engine]( + absl::AnyInvocable on_read, + grpc_event_engine::experimental::SliceBuffer* buffer) mutable { + for (auto& slice : slices) { + buffer->Append(std::move(slice)); + } + if (schedule_on_event_engine != nullptr) { + schedule_on_event_engine->Run( + [on_read = std::move(on_read)]() mutable { + on_read(absl::OkStatus()); + }); + return false; + } else { + return true; + } + })); +} + +void MockPromiseEndpoint::ExpectWrite( + std::initializer_list slices, + EventEngine* schedule_on_event_engine) { + SliceBuffer expect; + for (auto&& slice : slices) { + expect.Append(grpc_event_engine::experimental::internal::SliceCast( + slice.Copy())); + } + EXPECT_CALL(*endpoint, Write) + .InSequence(write_sequence) + .WillOnce(WithArgs<0, 1>( + [expect = expect.JoinIntoString(), schedule_on_event_engine]( + absl::AnyInvocable on_writable, + grpc_event_engine::experimental::SliceBuffer* buffer) mutable { + SliceBuffer tmp; + grpc_slice_buffer_swap(buffer->c_slice_buffer(), + tmp.c_slice_buffer()); + EXPECT_EQ(tmp.JoinIntoString(), expect); + if (schedule_on_event_engine != nullptr) { + schedule_on_event_engine->Run( + [on_writable = std::move(on_writable)]() mutable { + on_writable(absl::OkStatus()); + }); + return false; + } else { + return true; + } + })); +} + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core diff --git a/test/core/transport/chaotic_good/mock_promise_endpoint.h b/test/core/transport/chaotic_good/mock_promise_endpoint.h new file mode 100644 index 00000000000..c1534efb605 --- /dev/null +++ b/test/core/transport/chaotic_good/mock_promise_endpoint.h @@ -0,0 +1,77 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H +#define GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include + +#include "src/core/lib/transport/promise_endpoint.h" + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +class MockEndpoint + : public grpc_event_engine::experimental::EventEngine::Endpoint { + public: + MOCK_METHOD( + bool, Read, + (absl::AnyInvocable on_read, + grpc_event_engine::experimental::SliceBuffer* buffer, + const grpc_event_engine::experimental::EventEngine::Endpoint::ReadArgs* + args), + (override)); + + MOCK_METHOD( + bool, Write, + (absl::AnyInvocable on_writable, + grpc_event_engine::experimental::SliceBuffer* data, + const grpc_event_engine::experimental::EventEngine::Endpoint::WriteArgs* + args), + (override)); + + MOCK_METHOD( + const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, + GetPeerAddress, (), (const, override)); + MOCK_METHOD( + const grpc_event_engine::experimental::EventEngine::ResolvedAddress&, + GetLocalAddress, (), (const, override)); +}; + +struct MockPromiseEndpoint { + ::testing::StrictMock* endpoint = + new ::testing::StrictMock(); + std::unique_ptr promise_endpoint = + std::make_unique( + std::unique_ptr<::testing::StrictMock>(endpoint), + SliceBuffer()); + ::testing::Sequence read_sequence; + ::testing::Sequence write_sequence; + void ExpectRead( + std::initializer_list slices_init, + grpc_event_engine::experimental::EventEngine* schedule_on_event_engine); + void ExpectWrite( + std::initializer_list slices, + grpc_event_engine::experimental::EventEngine* schedule_on_event_engine); +}; + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_MOCK_PROMISE_ENDPOINT_H diff --git a/test/core/transport/chaotic_good/server_transport_test.cc b/test/core/transport/chaotic_good/server_transport_test.cc new file mode 100644 index 00000000000..05a32a21817 --- /dev/null +++ b/test/core/transport/chaotic_good/server_transport_test.cc @@ -0,0 +1,198 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/core/ext/transport/chaotic_good/server_transport.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include + +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" +#include "test/core/transport/chaotic_good/mock_promise_endpoint.h" +#include "test/core/transport/chaotic_good/transport_test.h" + +using testing::_; +using testing::MockFunction; +using testing::Return; +using testing::StrictMock; +using testing::WithArgs; + +using EventEngineSlice = grpc_event_engine::experimental::Slice; + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +// Encoded string of header ":path: /demo.Service/Step". +const uint8_t kPathDemoServiceStep[] = { + 0x40, 0x05, 0x3a, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2f, + 0x64, 0x65, 0x6d, 0x6f, 0x2e, 0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x2f, 0x53, 0x74, 0x65, 0x70}; + +// Encoded string of trailer "grpc-status: 0". +const uint8_t kGrpcStatus0[] = {0x40, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x01, 0x30}; + +ServerMetadataHandle TestInitialMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(HttpPathMetadata(), Slice::FromStaticString("/demo.Service/Step")); + return md; +} + +ServerMetadataHandle TestTrailingMetadata() { + auto md = + GetContext()->MakePooled(GetContext()); + md->Set(GrpcStatusMetadata(), GRPC_STATUS_OK); + return md; +} + +class MockAcceptor : public ServerTransport::Acceptor { + public: + virtual ~MockAcceptor() = default; + MOCK_METHOD(Arena*, CreateArena, (), (override)); + MOCK_METHOD(absl::StatusOr, CreateCall, + (ClientMetadata & client_initial_metadata, Arena* arena), + (override)); +}; + +TEST_F(TransportTest, ReadAndWriteOneMessage) { + MockPromiseEndpoint control_endpoint; + MockPromiseEndpoint data_endpoint; + StrictMock acceptor; + auto transport = MakeOrphanable( + CoreConfiguration::Get() + .channel_args_preconditioning() + .PreconditionChannelArgs(nullptr), + std::move(control_endpoint.promise_endpoint), + std::move(data_endpoint.promise_endpoint), event_engine()); + // Once we set the acceptor, expect to read some frames. + // We'll return a new request with a payload of "12345678". + control_endpoint.ExpectRead( + {SerializedFrameHeader(FrameType::kFragment, 7, 1, 26, 8, 56, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + event_engine().get()); + data_endpoint.ExpectRead( + {EventEngineSlice::FromCopiedString("12345678"), Zeros(56)}, nullptr); + // Once that's read we'll create a new call + auto* call_arena = Arena::Create(1024, memory_allocator()); + CallInitiatorAndHandler call = MakeCall(event_engine().get(), call_arena); + EXPECT_CALL(acceptor, CreateArena).WillOnce(Return(call_arena)); + EXPECT_CALL(acceptor, CreateCall(_, call_arena)) + .WillOnce(WithArgs<0>([call_initiator = std::move(call.initiator)]( + ClientMetadata& client_initial_metadata) { + EXPECT_EQ(client_initial_metadata.get_pointer(HttpPathMetadata()) + ->as_string_view(), + "/demo.Service/Step"); + return call_initiator; + })); + transport->SetAcceptor(&acceptor); + StrictMock> on_done; + EXPECT_CALL(on_done, Call()); + EXPECT_CALL(*control_endpoint.endpoint, Read) + .InSequence(control_endpoint.read_sequence) + .WillOnce(Return(false)); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 1, 1, + sizeof(kPathDemoServiceStep), 0, 0, 0), + EventEngineSlice::FromCopiedBuffer(kPathDemoServiceStep, + sizeof(kPathDemoServiceStep))}, + nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 2, 1, 0, 8, 56, 0)}, + nullptr); + data_endpoint.ExpectWrite( + {EventEngineSlice::FromCopiedString("87654321"), Zeros(56)}, nullptr); + control_endpoint.ExpectWrite( + {SerializedFrameHeader(FrameType::kFragment, 4, 1, 0, 0, 0, + sizeof(kGrpcStatus0)), + EventEngineSlice::FromCopiedBuffer(kGrpcStatus0, sizeof(kGrpcStatus0))}, + nullptr); + call.handler.SpawnInfallible( + "test-io", [&on_done, handler = call.handler]() mutable { + return Seq( + handler.PullClientInitialMetadata(), + [](ValueOrFailure md) { + EXPECT_TRUE(md.ok()); + EXPECT_EQ( + md.value()->get_pointer(HttpPathMetadata())->as_string_view(), + "/demo.Service/Step"); + return Empty{}; + }, + handler.PullMessage(), + [](NextResult msg) { + EXPECT_TRUE(msg.has_value()); + EXPECT_EQ(msg.value()->payload()->JoinIntoString(), "12345678"); + return Empty{}; + }, + handler.PullMessage(), + [](NextResult msg) { + EXPECT_FALSE(msg.has_value()); + return Empty{}; + }, + handler.PushServerInitialMetadata(TestInitialMetadata()), + handler.PushMessage(Arena::MakePooled( + SliceBuffer(Slice::FromCopiedString("87654321")), 0)), + [handler]() mutable { + return handler.PushServerTrailingMetadata(TestTrailingMetadata()); + }, + [&on_done]() mutable { + on_done.Call(); + return Empty{}; + }); + }); + // Wait until ClientTransport's internal activities to finish. + event_engine()->TickUntilIdle(); + event_engine()->UnsetGlobalHooks(); +} + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + // Must call to create default EventEngine. + grpc_init(); + int ret = RUN_ALL_TESTS(); + grpc_shutdown(); + return ret; +} diff --git a/test/core/transport/chaotic_good/transport_test.cc b/test/core/transport/chaotic_good/transport_test.cc new file mode 100644 index 00000000000..b43098fa7c7 --- /dev/null +++ b/test/core/transport/chaotic_good/transport_test.cc @@ -0,0 +1,60 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/core/transport/chaotic_good/transport_test.h" + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +grpc_event_engine::experimental::Slice SerializedFrameHeader( + FrameType type, uint8_t flags, uint32_t stream_id, uint32_t header_length, + uint32_t message_length, uint32_t message_padding, + uint32_t trailer_length) { + uint8_t buffer[24] = {static_cast(type), + flags, + 0, + 0, + static_cast(stream_id), + static_cast(stream_id >> 8), + static_cast(stream_id >> 16), + static_cast(stream_id >> 24), + static_cast(header_length), + static_cast(header_length >> 8), + static_cast(header_length >> 16), + static_cast(header_length >> 24), + static_cast(message_length), + static_cast(message_length >> 8), + static_cast(message_length >> 16), + static_cast(message_length >> 24), + static_cast(message_padding), + static_cast(message_padding >> 8), + static_cast(message_padding >> 16), + static_cast(message_padding >> 24), + static_cast(trailer_length), + static_cast(trailer_length >> 8), + static_cast(trailer_length >> 16), + static_cast(trailer_length >> 24)}; + return grpc_event_engine::experimental::Slice::FromCopiedBuffer(buffer, 24); +} + +grpc_event_engine::experimental::Slice Zeros(uint32_t length) { + std::string zeros(length, 0); + return grpc_event_engine::experimental::Slice::FromCopiedBuffer(zeros.data(), + length); +} + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core diff --git a/test/core/transport/chaotic_good/transport_test.h b/test/core/transport/chaotic_good/transport_test.h new file mode 100644 index 00000000000..e70158bb8cf --- /dev/null +++ b/test/core/transport/chaotic_good/transport_test.h @@ -0,0 +1,67 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_TRANSPORT_TEST_H +#define GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_TRANSPORT_TEST_H + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "src/core/ext/transport/chaotic_good/frame.h" +#include "src/core/lib/iomgr/timer_manager.h" +#include "src/core/lib/resource_quota/memory_quota.h" +#include "src/core/lib/resource_quota/resource_quota.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" +#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" + +namespace grpc_core { +namespace chaotic_good { +namespace testing { + +class TransportTest : public ::testing::Test { + protected: + const std::shared_ptr& + event_engine() { + return event_engine_; + } + + MemoryAllocator* memory_allocator() { return &allocator_; } + + private: + std::shared_ptr + event_engine_{ + std::make_shared( + []() { + grpc_timer_manager_set_threading(false); + grpc_event_engine::experimental::FuzzingEventEngine::Options + options; + return options; + }(), + fuzzing_event_engine::Actions())}; + MemoryAllocator allocator_ = MakeResourceQuota("test-quota") + ->memory_quota() + ->CreateMemoryAllocator("test-allocator"); +}; + +grpc_event_engine::experimental::Slice SerializedFrameHeader( + FrameType type, uint8_t flags, uint32_t stream_id, uint32_t header_length, + uint32_t message_length, uint32_t message_padding, uint32_t trailer_length); + +grpc_event_engine::experimental::Slice Zeros(uint32_t length); + +} // namespace testing +} // namespace chaotic_good +} // namespace grpc_core + +#endif // GRPC_TEST_CORE_TRANSPORT_CHAOTIC_GOOD_TRANSPORT_TEST_H diff --git a/test/core/transport/promise_endpoint_test.cc b/test/core/transport/promise_endpoint_test.cc index 760a2add6c7..e6ad3dd2713 100644 --- a/test/core/transport/promise_endpoint_test.cc +++ b/test/core/transport/promise_endpoint_test.cc @@ -524,6 +524,19 @@ TEST_F(PromiseEndpointTest, OneWriteSuccessful) { activity.Activate(); EXPECT_CALL(activity, WakeupRequested).Times(0); EXPECT_CALL(mock_endpoint_, Write).WillOnce(Return(true)); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); + auto poll = promise(); + ASSERT_TRUE(poll.ready()); + EXPECT_EQ(absl::OkStatus(), poll.value()); + activity.Deactivate(); +} + +TEST_F(PromiseEndpointTest, EmptyWriteIsNoOp) { + MockActivity activity; + activity.Activate(); + EXPECT_CALL(activity, WakeupRequested).Times(0); + EXPECT_CALL(mock_endpoint_, Write).Times(0); auto promise = promise_endpoint_->Write(SliceBuffer()); auto poll = promise(); ASSERT_TRUE(poll.ready()); @@ -541,7 +554,8 @@ TEST_F(PromiseEndpointTest, OneWriteFailed) { on_write(this->kDummyErrorStatus); return false; })); - auto promise = promise_endpoint_->Write(SliceBuffer()); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); auto poll = promise(); ASSERT_TRUE(poll.ready()); EXPECT_EQ(kDummyErrorStatus, poll.value()); @@ -564,7 +578,8 @@ TEST_F(PromiseEndpointTest, OnePendingWriteSuccessful) { // Return false to mock EventEngine write pending.. return false; })); - auto promise = promise_endpoint_->Write(SliceBuffer()); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); EXPECT_TRUE(promise().pending()); // Mock EventEngine write succeeds, and promise resolves. write_callback(absl::OkStatus()); @@ -586,7 +601,8 @@ TEST_F(PromiseEndpointTest, OnePendingWriteFailed) { // Return false to mock EventEngine write pending.. return false; })); - auto promise = promise_endpoint_->Write(SliceBuffer()); + auto promise = promise_endpoint_->Write( + SliceBuffer(Slice::FromCopiedString("hello world"))); EXPECT_TRUE(promise().pending()); write_callback(kDummyErrorStatus); auto poll = promise(); @@ -807,8 +823,10 @@ TEST_F(MultiplePromiseEndpointTest, JoinWritesSuccessful) { EXPECT_CALL(on_done, Call(absl::OkStatus())); auto activity = MakeActivity( [this] { - return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer()), - this->second_promise_endpoint_.Write(SliceBuffer())), + return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world"))), + this->second_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world")))), [](std::tuple ret) { // Both writes finish with `absl::OkStatus`. EXPECT_TRUE(std::get<0>(ret).ok()); @@ -832,8 +850,10 @@ TEST_F(MultiplePromiseEndpointTest, JoinOneWriteSuccessfulOneWriteFailed) { EXPECT_CALL(on_done, Call(kDummyErrorStatus)); auto activity = MakeActivity( [this] { - return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer()), - this->second_promise_endpoint_.Write(SliceBuffer())), + return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world"))), + this->second_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world")))), [this](std::tuple ret) { // One write finish with `absl::OkStatus` and the other // write fails. @@ -864,8 +884,10 @@ TEST_F(MultiplePromiseEndpointTest, JoinWritesFailed) { EXPECT_CALL(on_done, Call(kDummyErrorStatus)); auto activity = MakeActivity( [this] { - return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer()), - this->second_promise_endpoint_.Write(SliceBuffer())), + return Seq(Join(this->first_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world"))), + this->second_promise_endpoint_.Write(SliceBuffer( + Slice::FromCopiedString("hello world")))), [this](std::tuple ret) { // Both writes fail with errors. EXPECT_FALSE(std::get<0>(ret).ok()); diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 7981a98a90b..ae3fad4bc33 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -9071,6 +9071,30 @@ ], "uses_polling": true }, + { + "args": [], + "benchmark": false, + "ci_platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "cpu_cost": 1.0, + "exclude_configs": [], + "exclude_iomgrs": [], + "flaky": false, + "gtest": true, + "language": "c++", + "name": "server_transport_test", + "platforms": [ + "linux", + "mac", + "posix", + "windows" + ], + "uses_polling": false + }, { "args": [], "benchmark": false,