diff --git a/BUILD b/BUILD index 044350b4d8c..b751f1d5fc5 100644 --- a/BUILD +++ b/BUILD @@ -701,6 +701,7 @@ grpc_cc_library( external_deps = [ "absl/base", "absl/base:core_headers", + "absl/functional:any_invocable", "absl/memory", "absl/random", "absl/status", @@ -1311,6 +1312,7 @@ grpc_cc_library( "//src/core:lib/transport/timeout_encoding.cc", "//src/core:lib/transport/transport.cc", "//src/core:lib/transport/transport_op_string.cc", + "//src/core:lib/transport/batch_builder.cc", ] + # TODO(vigneshbabu): remove these # These headers used to be vended by this target, but they have to be @@ -1402,6 +1404,7 @@ grpc_cc_library( "//src/core:lib/transport/timeout_encoding.h", "//src/core:lib/transport/transport.h", "//src/core:lib/transport/transport_impl.h", + "//src/core:lib/transport/batch_builder.h", ] + # TODO(vigneshbabu): remove these # These headers used to be vended by this target, but they have to be @@ -1458,6 +1461,7 @@ grpc_cc_library( "stats", "uri_parser", "work_serializer", + "//src/core:1999", "//src/core:activity", "//src/core:arena", "//src/core:arena_promise", @@ -1486,15 +1490,19 @@ grpc_cc_library( "//src/core:event_engine_trace", "//src/core:event_log", "//src/core:experiments", + "//src/core:for_each", "//src/core:gpr_atm", "//src/core:gpr_manual_constructor", "//src/core:gpr_spinlock", "//src/core:grpc_sockaddr", "//src/core:http2_errors", + "//src/core:if", "//src/core:init_internally", "//src/core:iomgr_fwd", "//src/core:iomgr_port", "//src/core:json", + "//src/core:latch", + "//src/core:loop", "//src/core:map", "//src/core:match", "//src/core:memory_quota", @@ -1506,10 +1514,12 @@ grpc_cc_library( "//src/core:pollset_set", "//src/core:posix_event_engine_base_hdrs", "//src/core:promise_status", + "//src/core:race", "//src/core:ref_counted", "//src/core:resolved_address", "//src/core:resource_quota", "//src/core:resource_quota_trace", + "//src/core:seq", "//src/core:slice", "//src/core:slice_buffer", "//src/core:slice_cast", @@ -2345,6 +2355,7 @@ grpc_cc_library( grpc_cc_library( name = "promise", external_deps = [ + "absl/functional:any_invocable", "absl/status", "absl/types:optional", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 371d2e2c149..44c28879312 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1070,7 +1070,6 @@ if(gRPC_BUILD_TESTS) add_dependencies(buildtests_cxx nonblocking_test) add_dependencies(buildtests_cxx notification_test) add_dependencies(buildtests_cxx num_external_connectivity_watchers_test) - add_dependencies(buildtests_cxx observable_test) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) add_dependencies(buildtests_cxx oracle_event_engine_posix_test) endif() @@ -1644,6 +1643,7 @@ target_link_libraries(gpr ${_gRPC_ALLTARGETS_LIBRARIES} absl::base absl::core_headers + absl::any_invocable absl::memory absl::random_random absl::status @@ -2317,6 +2317,7 @@ add_library(grpc src/core/lib/load_balancing/lb_policy_registry.cc src/core/lib/matchers/matchers.cc src/core/lib/promise/activity.cc + src/core/lib/promise/party.cc src/core/lib/promise/sleep.cc src/core/lib/promise/trace.cc src/core/lib/resolver/resolver.cc @@ -2420,6 +2421,7 @@ add_library(grpc src/core/lib/surface/server.cc src/core/lib/surface/validate_metadata.cc src/core/lib/surface/version.cc + src/core/lib/transport/batch_builder.cc src/core/lib/transport/bdp_estimator.cc src/core/lib/transport/connectivity_state.cc src/core/lib/transport/error_utils.cc @@ -2511,7 +2513,6 @@ target_link_libraries(grpc absl::flat_hash_map absl::flat_hash_set absl::inlined_vector - absl::any_invocable absl::bind_front absl::function_ref absl::hash @@ -3004,6 +3005,7 @@ add_library(grpc_unsecure src/core/lib/load_balancing/lb_policy.cc src/core/lib/load_balancing/lb_policy_registry.cc src/core/lib/promise/activity.cc + src/core/lib/promise/party.cc src/core/lib/promise/sleep.cc src/core/lib/promise/trace.cc src/core/lib/resolver/resolver.cc @@ -3076,6 +3078,7 @@ add_library(grpc_unsecure src/core/lib/surface/server.cc src/core/lib/surface/validate_metadata.cc src/core/lib/surface/version.cc + src/core/lib/transport/batch_builder.cc src/core/lib/transport/bdp_estimator.cc src/core/lib/transport/connectivity_state.cc src/core/lib/transport/error_utils.cc @@ -3143,7 +3146,6 @@ target_link_libraries(grpc_unsecure absl::flat_hash_map absl::flat_hash_set absl::inlined_vector - absl::any_invocable absl::bind_front absl::function_ref absl::hash @@ -4523,6 +4525,7 @@ add_library(grpc_authorization_provider src/core/lib/load_balancing/lb_policy_registry.cc src/core/lib/matchers/matchers.cc src/core/lib/promise/activity.cc + src/core/lib/promise/party.cc src/core/lib/promise/trace.cc src/core/lib/resolver/resolver.cc src/core/lib/resolver/resolver_registry.cc @@ -4593,6 +4596,7 @@ add_library(grpc_authorization_provider src/core/lib/surface/server.cc src/core/lib/surface/validate_metadata.cc src/core/lib/surface/version.cc + src/core/lib/transport/batch_builder.cc src/core/lib/transport/connectivity_state.cc src/core/lib/transport/error_utils.cc src/core/lib/transport/handshaker.cc @@ -4651,7 +4655,6 @@ target_link_libraries(grpc_authorization_provider absl::flat_hash_map absl::flat_hash_set absl::inlined_vector - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -5403,7 +5406,6 @@ if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_POSIX) ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::algorithm_container - absl::any_invocable absl::span ${_gRPC_BENCHMARK_LIBRARIES} gpr @@ -8350,7 +8352,6 @@ target_link_libraries(chunked_vector_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -9099,7 +9100,6 @@ target_link_libraries(common_closures_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::statusor gpr ) @@ -10128,7 +10128,6 @@ target_link_libraries(endpoint_config_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::type_traits absl::statusor gpr @@ -10638,7 +10637,6 @@ target_link_libraries(exec_ctx_wakeup_scheduler_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::hash absl::type_traits absl::statusor @@ -11121,7 +11119,6 @@ target_link_libraries(flow_control_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -11192,7 +11189,6 @@ target_link_libraries(for_each_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -11277,7 +11273,6 @@ target_link_libraries(forkable_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::statusor gpr ) @@ -11587,6 +11582,7 @@ add_executable(frame_test src/core/lib/load_balancing/lb_policy.cc src/core/lib/load_balancing/lb_policy_registry.cc src/core/lib/promise/activity.cc + src/core/lib/promise/party.cc src/core/lib/promise/trace.cc src/core/lib/resolver/resolver.cc src/core/lib/resolver/resolver_registry.cc @@ -11634,6 +11630,7 @@ add_executable(frame_test src/core/lib/surface/server.cc src/core/lib/surface/validate_metadata.cc src/core/lib/surface/version.cc + src/core/lib/transport/batch_builder.cc src/core/lib/transport/connectivity_state.cc src/core/lib/transport/error_utils.cc src/core/lib/transport/handshaker_registry.cc @@ -11678,7 +11675,6 @@ target_link_libraries(frame_test absl::flat_hash_map absl::flat_hash_set absl::inlined_vector - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -14127,7 +14123,6 @@ target_link_libraries(interceptor_list_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -14917,7 +14912,6 @@ target_link_libraries(map_pipe_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::function_ref absl::hash absl::type_traits @@ -15708,49 +15702,6 @@ target_link_libraries(num_external_connectivity_watchers_test ) -endif() -if(gRPC_BUILD_TESTS) - -add_executable(observable_test - src/core/lib/promise/activity.cc - test/core/promise/observable_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) -target_compile_features(observable_test PUBLIC cxx_std_14) -target_include_directories(observable_test - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/include - ${_gRPC_ADDRESS_SORTING_INCLUDE_DIR} - ${_gRPC_RE2_INCLUDE_DIR} - ${_gRPC_SSL_INCLUDE_DIR} - ${_gRPC_UPB_GENERATED_DIR} - ${_gRPC_UPB_GRPC_GENERATED_DIR} - ${_gRPC_UPB_INCLUDE_DIR} - ${_gRPC_XXHASH_INCLUDE_DIR} - ${_gRPC_ZLIB_INCLUDE_DIR} - third_party/googletest/googletest/include - third_party/googletest/googletest - third_party/googletest/googlemock/include - third_party/googletest/googlemock - ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(observable_test - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - absl::flat_hash_set - absl::hash - absl::type_traits - absl::statusor - absl::utility - gpr -) - - endif() if(gRPC_BUILD_TESTS) if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) @@ -16198,7 +16149,6 @@ endif() if(gRPC_BUILD_TESTS) add_executable(party_test - src/core/lib/promise/party.cc test/core/promise/party_test.cc third_party/googletest/googletest/src/gtest-all.cc third_party/googletest/googlemock/src/gmock-all.cc @@ -16318,7 +16268,6 @@ target_link_libraries(periodic_update_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::function_ref absl::hash absl::statusor @@ -19102,7 +19051,6 @@ target_link_libraries(slice_string_helpers_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::hash absl::statusor gpr @@ -19582,7 +19530,6 @@ target_link_libraries(static_stride_scheduler_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::span gpr ) @@ -20412,7 +20359,6 @@ target_link_libraries(test_core_event_engine_posix_timer_heap_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::statusor gpr ) @@ -20455,7 +20401,6 @@ target_link_libraries(test_core_event_engine_posix_timer_list_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::statusor gpr ) @@ -20504,7 +20449,6 @@ target_link_libraries(test_core_event_engine_slice_buffer_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::hash absl::statusor absl::utility @@ -20620,7 +20564,6 @@ target_link_libraries(test_core_gprpp_time_test ${_gRPC_PROTOBUF_LIBRARIES} ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} - absl::any_invocable absl::statusor gpr ) @@ -21112,7 +21055,6 @@ target_link_libraries(thread_pool_test ${_gRPC_ZLIB_LIBRARIES} ${_gRPC_ALLTARGETS_LIBRARIES} absl::flat_hash_set - absl::any_invocable absl::statusor gpr ) @@ -26153,7 +26095,7 @@ generate_pkgconfig( "gpr" "gRPC platform support library" "${gRPC_CORE_VERSION}" - "absl_base absl_cord absl_core_headers absl_memory absl_optional absl_random_random absl_status absl_str_format absl_strings absl_synchronization absl_time absl_variant" + "absl_any_invocable absl_base absl_cord absl_core_headers absl_memory absl_optional absl_random_random absl_status absl_str_format absl_strings absl_synchronization absl_time absl_variant" "" "-lgpr" "" diff --git a/Makefile b/Makefile index f5710676a33..a3560aa12a7 100644 --- a/Makefile +++ b/Makefile @@ -1561,6 +1561,7 @@ LIBGRPC_SRC = \ src/core/lib/load_balancing/lb_policy_registry.cc \ src/core/lib/matchers/matchers.cc \ src/core/lib/promise/activity.cc \ + src/core/lib/promise/party.cc \ src/core/lib/promise/sleep.cc \ src/core/lib/promise/trace.cc \ src/core/lib/resolver/resolver.cc \ @@ -1664,6 +1665,7 @@ LIBGRPC_SRC = \ src/core/lib/surface/server.cc \ src/core/lib/surface/validate_metadata.cc \ src/core/lib/surface/version.cc \ + src/core/lib/transport/batch_builder.cc \ src/core/lib/transport/bdp_estimator.cc \ src/core/lib/transport/connectivity_state.cc \ src/core/lib/transport/error_utils.cc \ @@ -2101,6 +2103,7 @@ LIBGRPC_UNSECURE_SRC = \ src/core/lib/load_balancing/lb_policy.cc \ src/core/lib/load_balancing/lb_policy_registry.cc \ src/core/lib/promise/activity.cc \ + src/core/lib/promise/party.cc \ src/core/lib/promise/sleep.cc \ src/core/lib/promise/trace.cc \ src/core/lib/resolver/resolver.cc \ @@ -2173,6 +2176,7 @@ LIBGRPC_UNSECURE_SRC = \ src/core/lib/surface/server.cc \ src/core/lib/surface/validate_metadata.cc \ src/core/lib/surface/version.cc \ + src/core/lib/transport/batch_builder.cc \ src/core/lib/transport/bdp_estimator.cc \ src/core/lib/transport/connectivity_state.cc \ src/core/lib/transport/error_utils.cc \ diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 8533b98e880..8d556f6c62b 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -251,6 +251,7 @@ libs: deps: - absl/base:base - absl/base:core_headers + - absl/functional:any_invocable - absl/memory:memory - absl/random:random - absl/status:status @@ -951,12 +952,13 @@ libs: - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/for_each.h - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - - src/core/lib/promise/intra_activity_waiter.h - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h + - src/core/lib/promise/party.h - src/core/lib/promise/pipe.h - src/core/lib/promise/poll.h - src/core/lib/promise/promise.h @@ -1060,6 +1062,7 @@ libs: - src/core/lib/surface/lame_client.h - src/core/lib/surface/server.h - src/core/lib/surface/validate_metadata.h + - src/core/lib/transport/batch_builder.h - src/core/lib/transport/bdp_estimator.h - src/core/lib/transport/connectivity_state.h - src/core/lib/transport/error_utils.h @@ -1711,6 +1714,7 @@ libs: - src/core/lib/load_balancing/lb_policy_registry.cc - src/core/lib/matchers/matchers.cc - src/core/lib/promise/activity.cc + - src/core/lib/promise/party.cc - src/core/lib/promise/sleep.cc - src/core/lib/promise/trace.cc - src/core/lib/resolver/resolver.cc @@ -1814,6 +1818,7 @@ libs: - src/core/lib/surface/server.cc - src/core/lib/surface/validate_metadata.cc - src/core/lib/surface/version.cc + - src/core/lib/transport/batch_builder.cc - src/core/lib/transport/bdp_estimator.cc - src/core/lib/transport/connectivity_state.cc - src/core/lib/transport/error_utils.cc @@ -1865,7 +1870,6 @@ libs: - absl/container:flat_hash_map - absl/container:flat_hash_set - absl/container:inlined_vector - - absl/functional:any_invocable - absl/functional:bind_front - absl/functional:function_ref - absl/hash:hash @@ -2292,12 +2296,13 @@ libs: - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/for_each.h - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - - src/core/lib/promise/intra_activity_waiter.h - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h + - src/core/lib/promise/party.h - src/core/lib/promise/pipe.h - src/core/lib/promise/poll.h - src/core/lib/promise/promise.h @@ -2372,6 +2377,7 @@ libs: - src/core/lib/surface/lame_client.h - src/core/lib/surface/server.h - src/core/lib/surface/validate_metadata.h + - src/core/lib/transport/batch_builder.h - src/core/lib/transport/bdp_estimator.h - src/core/lib/transport/connectivity_state.h - src/core/lib/transport/error_utils.h @@ -2665,6 +2671,7 @@ libs: - src/core/lib/load_balancing/lb_policy.cc - src/core/lib/load_balancing/lb_policy_registry.cc - src/core/lib/promise/activity.cc + - src/core/lib/promise/party.cc - src/core/lib/promise/sleep.cc - src/core/lib/promise/trace.cc - src/core/lib/resolver/resolver.cc @@ -2737,6 +2744,7 @@ libs: - src/core/lib/surface/server.cc - src/core/lib/surface/validate_metadata.cc - src/core/lib/surface/version.cc + - src/core/lib/transport/batch_builder.cc - src/core/lib/transport/bdp_estimator.cc - src/core/lib/transport/connectivity_state.cc - src/core/lib/transport/error_utils.cc @@ -2764,7 +2772,6 @@ libs: - absl/container:flat_hash_map - absl/container:flat_hash_set - absl/container:inlined_vector - - absl/functional:any_invocable - absl/functional:bind_front - absl/functional:function_ref - absl/hash:hash @@ -3755,11 +3762,13 @@ libs: - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/for_each.h - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h + - src/core/lib/promise/party.h - src/core/lib/promise/pipe.h - src/core/lib/promise/poll.h - src/core/lib/promise/promise.h @@ -3833,6 +3842,7 @@ libs: - src/core/lib/surface/lame_client.h - src/core/lib/surface/server.h - src/core/lib/surface/validate_metadata.h + - src/core/lib/transport/batch_builder.h - src/core/lib/transport/connectivity_state.h - src/core/lib/transport/error_utils.h - src/core/lib/transport/handshaker.h @@ -4010,6 +4020,7 @@ libs: - src/core/lib/load_balancing/lb_policy_registry.cc - src/core/lib/matchers/matchers.cc - src/core/lib/promise/activity.cc + - src/core/lib/promise/party.cc - src/core/lib/promise/trace.cc - src/core/lib/resolver/resolver.cc - src/core/lib/resolver/resolver_registry.cc @@ -4080,6 +4091,7 @@ libs: - src/core/lib/surface/server.cc - src/core/lib/surface/validate_metadata.cc - src/core/lib/surface/version.cc + - src/core/lib/transport/batch_builder.cc - src/core/lib/transport/connectivity_state.cc - src/core/lib/transport/error_utils.cc - src/core/lib/transport/handshaker.cc @@ -4099,7 +4111,6 @@ libs: - absl/container:flat_hash_map - absl/container:flat_hash_set - absl/container:inlined_vector - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -4381,7 +4392,6 @@ targets: - test/core/client_channel/lb_policy/static_stride_scheduler_benchmark.cc deps: - absl/algorithm:container - - absl/functional:any_invocable - absl/types:span - benchmark - gpr @@ -5866,7 +5876,6 @@ targets: - test/core/gprpp/chunked_vector_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -6154,7 +6163,6 @@ targets: src: - test/core/event_engine/common_closures_test.cc deps: - - absl/functional:any_invocable - absl/status:statusor - gpr - name: completion_queue_threading_test @@ -6585,7 +6593,6 @@ targets: - src/core/lib/surface/channel_stack_type.cc - test/core/event_engine/endpoint_config_test.cc deps: - - absl/functional:any_invocable - absl/meta:type_traits - absl/status:statusor - gpr @@ -6861,7 +6868,6 @@ targets: - src/core/lib/slice/slice_string_helpers.cc - test/core/promise/exec_ctx_wakeup_scheduler_test.cc deps: - - absl/functional:any_invocable - absl/hash:hash - absl/meta:type_traits - absl/status:statusor @@ -7168,7 +7174,6 @@ targets: - test/core/transport/chttp2/flow_control_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -7215,7 +7220,6 @@ targets: - src/core/lib/promise/for_each.h - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - - src/core/lib/promise/intra_activity_waiter.h - src/core/lib/promise/join.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h @@ -7267,7 +7271,6 @@ targets: - test/core/promise/for_each_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -7301,7 +7304,6 @@ targets: - test/core/event_engine/forkable_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/status:statusor - gpr - name: format_request_test @@ -7571,11 +7573,13 @@ targets: - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - src/core/lib/promise/exec_ctx_wakeup_scheduler.h + - src/core/lib/promise/for_each.h - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - - src/core/lib/promise/intra_activity_waiter.h + - src/core/lib/promise/latch.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h + - src/core/lib/promise/party.h - src/core/lib/promise/pipe.h - src/core/lib/promise/poll.h - src/core/lib/promise/promise.h @@ -7626,6 +7630,7 @@ targets: - src/core/lib/surface/lame_client.h - src/core/lib/surface/server.h - src/core/lib/surface/validate_metadata.h + - src/core/lib/transport/batch_builder.h - src/core/lib/transport/connectivity_state.h - src/core/lib/transport/error_utils.h - src/core/lib/transport/handshaker_factory.h @@ -7809,6 +7814,7 @@ targets: - src/core/lib/load_balancing/lb_policy.cc - src/core/lib/load_balancing/lb_policy_registry.cc - src/core/lib/promise/activity.cc + - src/core/lib/promise/party.cc - src/core/lib/promise/trace.cc - src/core/lib/resolver/resolver.cc - src/core/lib/resolver/resolver_registry.cc @@ -7856,6 +7862,7 @@ targets: - src/core/lib/surface/server.cc - src/core/lib/surface/validate_metadata.cc - src/core/lib/surface/version.cc + - src/core/lib/transport/batch_builder.cc - src/core/lib/transport/connectivity_state.cc - src/core/lib/transport/error_utils.cc - src/core/lib/transport/handshaker_registry.cc @@ -7873,7 +7880,6 @@ targets: - absl/container:flat_hash_map - absl/container:flat_hash_set - absl/container:inlined_vector - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -8973,7 +8979,6 @@ targets: - test/core/promise/interceptor_list_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -9187,7 +9192,6 @@ targets: - src/core/lib/promise/detail/promise_like.h - src/core/lib/promise/detail/status.h - src/core/lib/promise/detail/switch.h - - src/core/lib/promise/intra_activity_waiter.h - src/core/lib/promise/join.h - src/core/lib/promise/latch.h - src/core/lib/promise/poll.h @@ -9330,7 +9334,6 @@ targets: - src/core/lib/promise/for_each.h - src/core/lib/promise/if.h - src/core/lib/promise/interceptor_list.h - - src/core/lib/promise/intra_activity_waiter.h - src/core/lib/promise/join.h - src/core/lib/promise/loop.h - src/core/lib/promise/map.h @@ -9383,7 +9386,6 @@ targets: - test/core/promise/map_pipe_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/meta:type_traits @@ -9689,39 +9691,6 @@ targets: - test/core/surface/num_external_connectivity_watchers_test.cc deps: - grpc_test_util -- name: observable_test - gtest: true - build: test - language: c++ - headers: - - src/core/lib/gprpp/atomic_utils.h - - src/core/lib/gprpp/orphanable.h - - src/core/lib/gprpp/ref_counted.h - - src/core/lib/gprpp/ref_counted_ptr.h - - src/core/lib/promise/activity.h - - src/core/lib/promise/context.h - - src/core/lib/promise/detail/basic_seq.h - - src/core/lib/promise/detail/promise_factory.h - - src/core/lib/promise/detail/promise_like.h - - src/core/lib/promise/detail/status.h - - src/core/lib/promise/detail/switch.h - - src/core/lib/promise/observable.h - - src/core/lib/promise/poll.h - - src/core/lib/promise/promise.h - - src/core/lib/promise/seq.h - - src/core/lib/promise/wait_set.h - - test/core/promise/test_wakeup_schedulers.h - src: - - src/core/lib/promise/activity.cc - - test/core/promise/observable_test.cc - deps: - - absl/container:flat_hash_set - - absl/hash:hash - - absl/meta:type_traits - - absl/status:statusor - - absl/utility:utility - - gpr - uses_polling: false - name: oracle_event_engine_posix_test gtest: true build: test @@ -9888,10 +9857,8 @@ targets: gtest: true build: test language: c++ - headers: - - src/core/lib/promise/party.h + headers: [] src: - - src/core/lib/promise/party.cc - test/core/promise/party_test.cc deps: - grpc_unsecure @@ -9951,7 +9918,6 @@ targets: - src/core/lib/slice/slice_string_helpers.cc - test/core/resource_quota/periodic_update_test.cc deps: - - absl/functional:any_invocable - absl/functional:function_ref - absl/hash:hash - absl/status:statusor @@ -10130,7 +10096,6 @@ targets: - src/core/lib/promise/detail/promise_factory.h - src/core/lib/promise/detail/promise_like.h - src/core/lib/promise/poll.h - - src/core/lib/promise/promise.h src: - test/core/promise/promise_factory_test.cc deps: @@ -11113,7 +11078,6 @@ targets: - src/core/lib/slice/slice_string_helpers.cc - test/core/slice/slice_string_helpers_test.cc deps: - - absl/functional:any_invocable - absl/hash:hash - absl/status:statusor - gpr @@ -11299,7 +11263,6 @@ targets: - src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/static_stride_scheduler.cc - test/core/client_channel/lb_policy/static_stride_scheduler_test.cc deps: - - absl/functional:any_invocable - absl/types:span - gpr uses_polling: false @@ -11704,7 +11667,6 @@ targets: - src/core/lib/gprpp/time_averaged_stats.cc - test/core/event_engine/posix/timer_heap_test.cc deps: - - absl/functional:any_invocable - absl/status:statusor - gpr uses_polling: false @@ -11724,7 +11686,6 @@ targets: - src/core/lib/gprpp/time_averaged_stats.cc - test/core/event_engine/posix/timer_list_test.cc deps: - - absl/functional:any_invocable - absl/status:statusor - gpr uses_polling: false @@ -11756,7 +11717,6 @@ targets: - test/core/event_engine/slice_buffer_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/hash:hash - absl/status:statusor - absl/utility:utility @@ -11791,7 +11751,6 @@ targets: - src/core/lib/gprpp/time.cc - test/core/gprpp/time_test.cc deps: - - absl/functional:any_invocable - absl/status:statusor - gpr uses_polling: false @@ -12002,7 +11961,6 @@ targets: - test/core/event_engine/thread_pool_test.cc deps: - absl/container:flat_hash_set - - absl/functional:any_invocable - absl/status:statusor - gpr - name: thread_quota_test diff --git a/config.m4 b/config.m4 index 92612e073b6..410626e3323 100644 --- a/config.m4 +++ b/config.m4 @@ -686,6 +686,7 @@ if test "$PHP_GRPC" != "no"; then src/core/lib/load_balancing/lb_policy_registry.cc \ src/core/lib/matchers/matchers.cc \ src/core/lib/promise/activity.cc \ + src/core/lib/promise/party.cc \ src/core/lib/promise/sleep.cc \ src/core/lib/promise/trace.cc \ src/core/lib/resolver/resolver.cc \ @@ -789,6 +790,7 @@ if test "$PHP_GRPC" != "no"; then src/core/lib/surface/server.cc \ src/core/lib/surface/validate_metadata.cc \ src/core/lib/surface/version.cc \ + src/core/lib/transport/batch_builder.cc \ src/core/lib/transport/bdp_estimator.cc \ src/core/lib/transport/connectivity_state.cc \ src/core/lib/transport/error_utils.cc \ diff --git a/config.w32 b/config.w32 index 6da67e3c94c..d20340afe8a 100644 --- a/config.w32 +++ b/config.w32 @@ -652,6 +652,7 @@ if (PHP_GRPC != "no") { "src\\core\\lib\\load_balancing\\lb_policy_registry.cc " + "src\\core\\lib\\matchers\\matchers.cc " + "src\\core\\lib\\promise\\activity.cc " + + "src\\core\\lib\\promise\\party.cc " + "src\\core\\lib\\promise\\sleep.cc " + "src\\core\\lib\\promise\\trace.cc " + "src\\core\\lib\\resolver\\resolver.cc " + @@ -755,6 +756,7 @@ if (PHP_GRPC != "no") { "src\\core\\lib\\surface\\server.cc " + "src\\core\\lib\\surface\\validate_metadata.cc " + "src\\core\\lib\\surface\\version.cc " + + "src\\core\\lib\\transport\\batch_builder.cc " + "src\\core\\lib\\transport\\bdp_estimator.cc " + "src\\core\\lib\\transport\\connectivity_state.cc " + "src\\core\\lib\\transport\\error_utils.cc " + diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index d03714be789..baee59d5aaa 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -921,12 +921,13 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/for_each.h', 'src/core/lib/promise/if.h', 'src/core/lib/promise/interceptor_list.h', - 'src/core/lib/promise/intra_activity_waiter.h', 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', + 'src/core/lib/promise/party.h', 'src/core/lib/promise/pipe.h', 'src/core/lib/promise/poll.h', 'src/core/lib/promise/promise.h', @@ -1030,6 +1031,7 @@ Pod::Spec.new do |s| 'src/core/lib/surface/lame_client.h', 'src/core/lib/surface/server.h', 'src/core/lib/surface/validate_metadata.h', + 'src/core/lib/transport/batch_builder.h', 'src/core/lib/transport/bdp_estimator.h', 'src/core/lib/transport/connectivity_state.h', 'src/core/lib/transport/error_utils.h', @@ -1859,12 +1861,13 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/for_each.h', 'src/core/lib/promise/if.h', 'src/core/lib/promise/interceptor_list.h', - 'src/core/lib/promise/intra_activity_waiter.h', 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', + 'src/core/lib/promise/party.h', 'src/core/lib/promise/pipe.h', 'src/core/lib/promise/poll.h', 'src/core/lib/promise/promise.h', @@ -1968,6 +1971,7 @@ Pod::Spec.new do |s| 'src/core/lib/surface/lame_client.h', 'src/core/lib/surface/server.h', 'src/core/lib/surface/validate_metadata.h', + 'src/core/lib/transport/batch_builder.h', 'src/core/lib/transport/bdp_estimator.h', 'src/core/lib/transport/connectivity_state.h', 'src/core/lib/transport/error_utils.h', diff --git a/gRPC-Core.podspec b/gRPC-Core.podspec index 4ec4870a7e0..ef365e45e78 100644 --- a/gRPC-Core.podspec +++ b/gRPC-Core.podspec @@ -1496,12 +1496,14 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/for_each.h', 'src/core/lib/promise/if.h', 'src/core/lib/promise/interceptor_list.h', - 'src/core/lib/promise/intra_activity_waiter.h', 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', + 'src/core/lib/promise/party.cc', + 'src/core/lib/promise/party.h', 'src/core/lib/promise/pipe.h', 'src/core/lib/promise/poll.h', 'src/core/lib/promise/promise.h', @@ -1708,6 +1710,8 @@ Pod::Spec.new do |s| 'src/core/lib/surface/validate_metadata.cc', 'src/core/lib/surface/validate_metadata.h', 'src/core/lib/surface/version.cc', + 'src/core/lib/transport/batch_builder.cc', + 'src/core/lib/transport/batch_builder.h', 'src/core/lib/transport/bdp_estimator.cc', 'src/core/lib/transport/bdp_estimator.h', 'src/core/lib/transport/connectivity_state.cc', @@ -2546,12 +2550,13 @@ Pod::Spec.new do |s| 'src/core/lib/promise/detail/status.h', 'src/core/lib/promise/detail/switch.h', 'src/core/lib/promise/exec_ctx_wakeup_scheduler.h', + 'src/core/lib/promise/for_each.h', 'src/core/lib/promise/if.h', 'src/core/lib/promise/interceptor_list.h', - 'src/core/lib/promise/intra_activity_waiter.h', 'src/core/lib/promise/latch.h', 'src/core/lib/promise/loop.h', 'src/core/lib/promise/map.h', + 'src/core/lib/promise/party.h', 'src/core/lib/promise/pipe.h', 'src/core/lib/promise/poll.h', 'src/core/lib/promise/promise.h', @@ -2655,6 +2660,7 @@ Pod::Spec.new do |s| 'src/core/lib/surface/lame_client.h', 'src/core/lib/surface/server.h', 'src/core/lib/surface/validate_metadata.h', + 'src/core/lib/transport/batch_builder.h', 'src/core/lib/transport/bdp_estimator.h', 'src/core/lib/transport/connectivity_state.h', 'src/core/lib/transport/error_utils.h', diff --git a/grpc.gemspec b/grpc.gemspec index 9426def7f7c..6facb0a1066 100644 --- a/grpc.gemspec +++ b/grpc.gemspec @@ -1405,12 +1405,14 @@ Gem::Specification.new do |s| s.files += %w( src/core/lib/promise/detail/status.h ) s.files += %w( src/core/lib/promise/detail/switch.h ) s.files += %w( src/core/lib/promise/exec_ctx_wakeup_scheduler.h ) + s.files += %w( src/core/lib/promise/for_each.h ) s.files += %w( src/core/lib/promise/if.h ) s.files += %w( src/core/lib/promise/interceptor_list.h ) - s.files += %w( src/core/lib/promise/intra_activity_waiter.h ) s.files += %w( src/core/lib/promise/latch.h ) s.files += %w( src/core/lib/promise/loop.h ) s.files += %w( src/core/lib/promise/map.h ) + s.files += %w( src/core/lib/promise/party.cc ) + s.files += %w( src/core/lib/promise/party.h ) s.files += %w( src/core/lib/promise/pipe.h ) s.files += %w( src/core/lib/promise/poll.h ) s.files += %w( src/core/lib/promise/promise.h ) @@ -1617,6 +1619,8 @@ Gem::Specification.new do |s| s.files += %w( src/core/lib/surface/validate_metadata.cc ) s.files += %w( src/core/lib/surface/validate_metadata.h ) s.files += %w( src/core/lib/surface/version.cc ) + s.files += %w( src/core/lib/transport/batch_builder.cc ) + s.files += %w( src/core/lib/transport/batch_builder.h ) s.files += %w( src/core/lib/transport/bdp_estimator.cc ) s.files += %w( src/core/lib/transport/bdp_estimator.h ) s.files += %w( src/core/lib/transport/connectivity_state.cc ) diff --git a/grpc.gyp b/grpc.gyp index 9acf4bee5b8..c2d2ab0fcd0 100644 --- a/grpc.gyp +++ b/grpc.gyp @@ -293,6 +293,7 @@ 'dependencies': [ 'absl/base:base', 'absl/base:core_headers', + 'absl/functional:any_invocable', 'absl/memory:memory', 'absl/random:random', 'absl/status:status', @@ -359,7 +360,6 @@ 'absl/container:flat_hash_map', 'absl/container:flat_hash_set', 'absl/container:inlined_vector', - 'absl/functional:any_invocable', 'absl/functional:bind_front', 'absl/functional:function_ref', 'absl/hash:hash', @@ -974,6 +974,7 @@ 'src/core/lib/load_balancing/lb_policy_registry.cc', 'src/core/lib/matchers/matchers.cc', 'src/core/lib/promise/activity.cc', + 'src/core/lib/promise/party.cc', 'src/core/lib/promise/sleep.cc', 'src/core/lib/promise/trace.cc', 'src/core/lib/resolver/resolver.cc', @@ -1077,6 +1078,7 @@ 'src/core/lib/surface/server.cc', 'src/core/lib/surface/validate_metadata.cc', 'src/core/lib/surface/version.cc', + 'src/core/lib/transport/batch_builder.cc', 'src/core/lib/transport/bdp_estimator.cc', 'src/core/lib/transport/connectivity_state.cc', 'src/core/lib/transport/error_utils.cc', @@ -1176,7 +1178,6 @@ 'absl/container:flat_hash_map', 'absl/container:flat_hash_set', 'absl/container:inlined_vector', - 'absl/functional:any_invocable', 'absl/functional:bind_front', 'absl/functional:function_ref', 'absl/hash:hash', @@ -1456,6 +1457,7 @@ 'src/core/lib/load_balancing/lb_policy.cc', 'src/core/lib/load_balancing/lb_policy_registry.cc', 'src/core/lib/promise/activity.cc', + 'src/core/lib/promise/party.cc', 'src/core/lib/promise/sleep.cc', 'src/core/lib/promise/trace.cc', 'src/core/lib/resolver/resolver.cc', @@ -1528,6 +1530,7 @@ 'src/core/lib/surface/server.cc', 'src/core/lib/surface/validate_metadata.cc', 'src/core/lib/surface/version.cc', + 'src/core/lib/transport/batch_builder.cc', 'src/core/lib/transport/bdp_estimator.cc', 'src/core/lib/transport/connectivity_state.cc', 'src/core/lib/transport/error_utils.cc', @@ -1795,7 +1798,6 @@ 'absl/container:flat_hash_map', 'absl/container:flat_hash_set', 'absl/container:inlined_vector', - 'absl/functional:any_invocable', 'absl/functional:function_ref', 'absl/hash:hash', 'absl/meta:type_traits', @@ -1964,6 +1966,7 @@ 'src/core/lib/load_balancing/lb_policy_registry.cc', 'src/core/lib/matchers/matchers.cc', 'src/core/lib/promise/activity.cc', + 'src/core/lib/promise/party.cc', 'src/core/lib/promise/trace.cc', 'src/core/lib/resolver/resolver.cc', 'src/core/lib/resolver/resolver_registry.cc', @@ -2034,6 +2037,7 @@ 'src/core/lib/surface/server.cc', 'src/core/lib/surface/validate_metadata.cc', 'src/core/lib/surface/version.cc', + 'src/core/lib/transport/batch_builder.cc', 'src/core/lib/transport/connectivity_state.cc', 'src/core/lib/transport/error_utils.cc', 'src/core/lib/transport/handshaker.cc', diff --git a/package.xml b/package.xml index abbbad4d4d1..3846ae4eb4b 100644 --- a/package.xml +++ b/package.xml @@ -1387,12 +1387,14 @@ + - + + @@ -1599,6 +1601,8 @@ + + diff --git a/src/core/BUILD b/src/core/BUILD index b67320be39f..9adb0182026 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -413,7 +413,6 @@ grpc_cc_library( ], external_deps = [ "absl/base:core_headers", - "absl/container:inlined_vector", "absl/strings", "absl/strings:str_format", ], @@ -421,9 +420,15 @@ grpc_cc_library( deps = [ "activity", "arena", + "construct_destruct", + "context", + "promise_factory", "promise_trace", + "ref_counted", + "//:exec_ctx", "//:gpr", "//:grpc_trace", + "//:ref_counted_ptr", ], ) @@ -571,6 +576,7 @@ grpc_cc_library( "lib/promise/loop.h", ], deps = [ + "construct_destruct", "poll", "promise_factory", "//:gpr_platform", @@ -696,6 +702,7 @@ grpc_cc_library( external_deps = [ "absl/base:core_headers", "absl/status", + "absl/strings", "absl/strings:str_format", "absl/types:optional", ], @@ -708,6 +715,7 @@ grpc_cc_library( "construct_destruct", "context", "no_destruct", + "poll", "promise_factory", "promise_status", "//:gpr", @@ -761,19 +769,6 @@ grpc_cc_library( ], ) -grpc_cc_library( - name = "intra_activity_waiter", - language = "c++", - public_hdrs = [ - "lib/promise/intra_activity_waiter.h", - ], - deps = [ - "activity", - "poll", - "//:gpr_platform", - ], -) - grpc_cc_library( name = "latch", external_deps = ["absl/strings"], @@ -783,7 +778,6 @@ grpc_cc_library( ], deps = [ "activity", - "intra_activity_waiter", "poll", "promise_trace", "//:gpr", @@ -791,25 +785,6 @@ grpc_cc_library( ], ) -grpc_cc_library( - name = "observable", - external_deps = [ - "absl/base:core_headers", - "absl/types:optional", - ], - language = "c++", - public_hdrs = [ - "lib/promise/observable.h", - ], - deps = [ - "activity", - "poll", - "promise_like", - "wait_set", - "//:gpr", - ], -) - grpc_cc_library( name = "interceptor_list", hdrs = [ @@ -839,7 +814,6 @@ grpc_cc_library( "lib/promise/pipe.h", ], external_deps = [ - "absl/base:core_headers", "absl/strings", "absl/types:optional", "absl/types:variant", @@ -851,7 +825,6 @@ grpc_cc_library( "context", "if", "interceptor_list", - "intra_activity_waiter", "map", "poll", "promise_trace", @@ -3502,33 +3475,38 @@ grpc_cc_library( "ext/filters/message_size/message_size_filter.h", ], external_deps = [ - "absl/status", + "absl/status:statusor", "absl/strings", "absl/strings:str_format", "absl/types:optional", ], language = "c++", deps = [ + "activity", + "arena", + "arena_promise", "channel_args", "channel_fwd", "channel_init", "channel_stack_type", - "closure", - "error", + "context", "grpc_service_config", "json", "json_args", "json_object_loader", + "latch", + "poll", + "race", "service_config_parser", + "slice", "slice_buffer", - "status_helper", "validation_errors", "//:channel_stack_builder", "//:config", - "//:debug_location", "//:gpr", "//:grpc_base", "//:grpc_public_hdrs", + "//:grpc_trace", ], ) diff --git a/src/core/ext/filters/client_channel/client_channel.h b/src/core/ext/filters/client_channel/client_channel.h index bfc35c27b4d..957a4748c3b 100644 --- a/src/core/ext/filters/client_channel/client_channel.h +++ b/src/core/ext/filters/client_channel/client_channel.h @@ -363,7 +363,7 @@ class ClientChannel { // TODO(roth): As part of simplifying cancellation in the filter stack, // this should no longer need to be ref-counted. class ClientChannel::LoadBalancedCall - : public InternallyRefCounted { + : public InternallyRefCounted { public: LoadBalancedCall( ClientChannel* chand, grpc_call_context_element* call_context, diff --git a/src/core/ext/filters/client_channel/retry_filter.cc b/src/core/ext/filters/client_channel/retry_filter.cc index c52b6300208..a92a1837bbf 100644 --- a/src/core/ext/filters/client_channel/retry_filter.cc +++ b/src/core/ext/filters/client_channel/retry_filter.cc @@ -272,7 +272,7 @@ class RetryFilter::CallData { // We allocate one struct on the arena for each attempt at starting a // batch on a given LB call. class BatchData - : public RefCounted { + : public RefCounted { public: BatchData(RefCountedPtr call_attempt, int refcount, bool set_on_complete); @@ -649,7 +649,7 @@ class RetryFilter::CallData { // on_call_stack_destruction closure from the surface. class RetryFilter::CallData::CallStackDestructionBarrier : public RefCounted { + UnrefCallDtor> { public: CallStackDestructionBarrier() {} diff --git a/src/core/ext/filters/http/client/http_client_filter.cc b/src/core/ext/filters/http/client/http_client_filter.cc index 58f9a709024..89ff0356cce 100644 --- a/src/core/ext/filters/http/client/http_client_filter.cc +++ b/src/core/ext/filters/http/client/http_client_filter.cc @@ -133,13 +133,13 @@ ArenaPromise HttpClientFilter::MakeCallPromise( return std::move(md); }); - return Race(Map(next_promise_factory(std::move(call_args)), + return Race(initial_metadata_err->Wait(), + Map(next_promise_factory(std::move(call_args)), [](ServerMetadataHandle md) -> ServerMetadataHandle { auto r = CheckServerMetadata(md.get()); if (!r.ok()) return ServerMetadataFromStatus(r); return md; - }), - initial_metadata_err->Wait()); + })); } HttpClientFilter::HttpClientFilter(HttpSchemeMetadata::ValueType scheme, diff --git a/src/core/ext/filters/http/message_compress/compression_filter.cc b/src/core/ext/filters/http/message_compress/compression_filter.cc index a719f5b9132..aea371dbe6d 100644 --- a/src/core/ext/filters/http/message_compress/compression_filter.cc +++ b/src/core/ext/filters/http/message_compress/compression_filter.cc @@ -252,7 +252,7 @@ ArenaPromise ClientCompressionFilter::MakeCallPromise( return CompressMessage(std::move(message), compression_algorithm); }); auto* decompress_args = GetContext()->New( - DecompressArgs{GRPC_COMPRESS_NONE, absl::nullopt}); + DecompressArgs{GRPC_COMPRESS_ALGORITHMS_COUNT, absl::nullopt}); auto* decompress_err = GetContext()->New>(); call_args.server_initial_metadata->InterceptAndMap( @@ -273,8 +273,8 @@ ArenaPromise ClientCompressionFilter::MakeCallPromise( return std::move(*r); }); // Run the next filter, and race it with getting an error from decompression. - return Race(next_promise_factory(std::move(call_args)), - decompress_err->Wait()); + return Race(decompress_err->Wait(), + next_promise_factory(std::move(call_args))); } ArenaPromise ServerCompressionFilter::MakeCallPromise( @@ -288,7 +288,8 @@ ArenaPromise ServerCompressionFilter::MakeCallPromise( this](MessageHandle message) -> absl::optional { auto r = DecompressMessage(std::move(message), decompress_args); if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "DecompressMessage returned %s", + gpr_log(GPR_DEBUG, "%s[compression] DecompressMessage returned %s", + Activity::current()->DebugTag().c_str(), r.status().ToString().c_str()); } if (!r.ok()) { @@ -314,13 +315,9 @@ ArenaPromise ServerCompressionFilter::MakeCallPromise( this](MessageHandle message) -> absl::optional { return CompressMessage(std::move(message), *compression_algorithm); }); - // Concurrently: - // - call the next filter - // - decompress incoming messages - // - wait for initial metadata to be sent, and then commence compression of - // outgoing messages - return Race(next_promise_factory(std::move(call_args)), - decompress_err->Wait()); + // Run the next filter, and race it with getting an error from decompression. + return Race(decompress_err->Wait(), + next_promise_factory(std::move(call_args))); } } // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.cc b/src/core/ext/filters/message_size/message_size_filter.cc index 33ff178e5a9..6143239c0c0 100644 --- a/src/core/ext/filters/message_size/message_size_filter.cc +++ b/src/core/ext/filters/message_size/message_size_filter.cc @@ -18,10 +18,13 @@ #include "src/core/ext/filters/message_size/message_size_filter.h" +#include + +#include #include -#include +#include +#include -#include "absl/status/status.h" #include "absl/strings/str_format.h" #include @@ -32,21 +35,22 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/channel_stack_builder.h" #include "src/core/lib/config/core_configuration.h" -#include "src/core/lib/gprpp/debug_location.h" -#include "src/core/lib/gprpp/status_helper.h" -#include "src/core/lib/iomgr/call_combiner.h" -#include "src/core/lib/iomgr/closure.h" -#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/race.h" +#include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/service_config/service_config_call_data.h" +#include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/surface/call_trace.h" #include "src/core/lib/surface/channel_init.h" #include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" -static void recv_message_ready(void* user_data, grpc_error_handle error); -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error); - namespace grpc_core { // @@ -124,251 +128,164 @@ size_t MessageSizeParser::ParserIndex() { parser_name()); } -} // namespace grpc_core - -namespace { -struct channel_data { - grpc_core::MessageSizeParsedConfig limits; - const size_t service_config_parser_index{ - grpc_core::MessageSizeParser::ParserIndex()}; -}; +// +// MessageSizeFilter +// -struct call_data { - call_data(grpc_call_element* elem, const channel_data& chand, - const grpc_call_element_args& args) - : call_combiner(args.call_combiner), limits(chand.limits) { - GRPC_CLOSURE_INIT(&recv_message_ready, ::recv_message_ready, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, - ::recv_trailing_metadata_ready, elem, - grpc_schedule_on_exec_ctx); - // Get max sizes from channel data, then merge in per-method config values. - // Note: Per-method config is only available on the client, so we - // apply the max request size to the send limit and the max response - // size to the receive limit. - const grpc_core::MessageSizeParsedConfig* config_from_call_context = - grpc_core::MessageSizeParsedConfig::GetFromCallContext( - args.context, chand.service_config_parser_index); - if (config_from_call_context != nullptr) { - absl::optional max_send_size = limits.max_send_size(); - absl::optional max_recv_size = limits.max_recv_size(); - if (config_from_call_context->max_send_size().has_value() && - (!max_send_size.has_value() || - *config_from_call_context->max_send_size() < *max_send_size)) { - max_send_size = *config_from_call_context->max_send_size(); +const grpc_channel_filter ClientMessageSizeFilter::kFilter = + MakePromiseBasedFilter("message_size"); +const grpc_channel_filter ServerMessageSizeFilter::kFilter = + MakePromiseBasedFilter("message_size"); + +class MessageSizeFilter::CallBuilder { + private: + auto Interceptor(uint32_t max_length, bool is_send) { + return [max_length, is_send, + err = err_](MessageHandle msg) -> absl::optional { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d", + Activity::current()->DebugTag().c_str(), + is_send ? "send" : "recv", msg->payload()->Length(), + max_length); } - if (config_from_call_context->max_recv_size().has_value() && - (!max_recv_size.has_value() || - *config_from_call_context->max_recv_size() < *max_recv_size)) { - max_recv_size = *config_from_call_context->max_recv_size(); + if (msg->payload()->Length() > max_length) { + if (err->is_set()) return std::move(msg); + auto r = GetContext()->MakePooled( + GetContext()); + r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED); + r->Set(GrpcMessageMetadata(), + Slice::FromCopiedString( + absl::StrFormat("%s message larger than max (%u vs. %d)", + is_send ? "Sent" : "Received", + msg->payload()->Length(), max_length))); + err->Set(std::move(r)); + return absl::nullopt; } - limits = grpc_core::MessageSizeParsedConfig(max_send_size, max_recv_size); - } + return std::move(msg); + }; } - ~call_data() {} - - grpc_core::CallCombiner* call_combiner; - grpc_core::MessageSizeParsedConfig limits; - // Receive closures are chained: we inject this closure as the - // recv_message_ready up-call on transport_stream_op, and remember to - // call our next_recv_message_ready member after handling it. - grpc_closure recv_message_ready; - grpc_closure recv_trailing_metadata_ready; - // The error caused by a message that is too large, or absl::OkStatus() - grpc_error_handle error; - // Used by recv_message_ready. - absl::optional* recv_message = nullptr; - // Original recv_message_ready callback, invoked after our own. - grpc_closure* next_recv_message_ready = nullptr; - // Original recv_trailing_metadata callback, invoked after our own. - grpc_closure* original_recv_trailing_metadata_ready; - bool seen_recv_trailing_metadata = false; - grpc_error_handle recv_trailing_metadata_error; -}; - -} // namespace + public: + explicit CallBuilder(const MessageSizeParsedConfig& limits) + : limits_(limits) {} -// Callback invoked when we receive a message. Here we check the max -// receive message size. -static void recv_message_ready(void* user_data, grpc_error_handle error) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (calld->recv_message->has_value() && - calld->limits.max_recv_size().has_value() && - (*calld->recv_message)->Length() > - static_cast(*calld->limits.max_recv_size())) { - grpc_error_handle new_error = grpc_error_set_int( - GRPC_ERROR_CREATE(absl::StrFormat( - "Received message larger than max (%u vs. %d)", - (*calld->recv_message)->Length(), *calld->limits.max_recv_size())), - grpc_core::StatusIntProperty::kRpcStatus, - GRPC_STATUS_RESOURCE_EXHAUSTED); - error = grpc_error_add_child(error, new_error); - calld->error = error; + template + void AddSend(T* pipe_end) { + if (!limits_.max_send_size().has_value()) return; + pipe_end->InterceptAndMap(Interceptor(*limits_.max_send_size(), true)); } - // Invoke the next callback. - grpc_closure* closure = calld->next_recv_message_ready; - calld->next_recv_message_ready = nullptr; - if (calld->seen_recv_trailing_metadata) { - // We might potentially see another RECV_MESSAGE op. In that case, we do not - // want to run the recv_trailing_metadata_ready closure again. The newer - // RECV_MESSAGE op cannot cause any errors since the transport has already - // invoked the recv_trailing_metadata_ready closure and all further - // RECV_MESSAGE ops will get null payloads. - calld->seen_recv_trailing_metadata = false; - GRPC_CALL_COMBINER_START(calld->call_combiner, - &calld->recv_trailing_metadata_ready, - calld->recv_trailing_metadata_error, - "continue recv_trailing_metadata_ready"); + template + void AddRecv(T* pipe_end) { + if (!limits_.max_recv_size().has_value()) return; + pipe_end->InterceptAndMap(Interceptor(*limits_.max_recv_size(), false)); } - grpc_core::Closure::Run(DEBUG_LOCATION, closure, error); -} -// Callback invoked on completion of recv_trailing_metadata -// Notifies the recv_trailing_metadata batch of any message size failures -static void recv_trailing_metadata_ready(void* user_data, - grpc_error_handle error) { - grpc_call_element* elem = static_cast(user_data); - call_data* calld = static_cast(elem->call_data); - if (calld->next_recv_message_ready != nullptr) { - calld->seen_recv_trailing_metadata = true; - calld->recv_trailing_metadata_error = error; - GRPC_CALL_COMBINER_STOP(calld->call_combiner, - "deferring recv_trailing_metadata_ready until " - "after recv_message_ready"); - return; + ArenaPromise Run( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + return Race(err_->Wait(), next_promise_factory(std::move(call_args))); } - error = grpc_error_add_child(error, calld->error); - // Invoke the next callback. - grpc_core::Closure::Run(DEBUG_LOCATION, - calld->original_recv_trailing_metadata_ready, error); -} -// Start transport stream op. -static void message_size_start_transport_stream_op_batch( - grpc_call_element* elem, grpc_transport_stream_op_batch* op) { - call_data* calld = static_cast(elem->call_data); - // Check max send message size. - if (op->send_message && calld->limits.max_send_size().has_value() && - op->payload->send_message.send_message->Length() > - static_cast(*calld->limits.max_send_size())) { - grpc_transport_stream_op_batch_finish_with_failure( - op, - grpc_error_set_int(GRPC_ERROR_CREATE(absl::StrFormat( - "Sent message larger than max (%u vs. %d)", - op->payload->send_message.send_message->Length(), - *calld->limits.max_send_size())), - grpc_core::StatusIntProperty::kRpcStatus, - GRPC_STATUS_RESOURCE_EXHAUSTED), - calld->call_combiner); - return; - } - // Inject callback for receiving a message. - if (op->recv_message) { - calld->next_recv_message_ready = - op->payload->recv_message.recv_message_ready; - calld->recv_message = op->payload->recv_message.recv_message; - op->payload->recv_message.recv_message_ready = &calld->recv_message_ready; - } - // Inject callback for receiving trailing metadata. - if (op->recv_trailing_metadata) { - calld->original_recv_trailing_metadata_ready = - op->payload->recv_trailing_metadata.recv_trailing_metadata_ready; - op->payload->recv_trailing_metadata.recv_trailing_metadata_ready = - &calld->recv_trailing_metadata_ready; - } - // Chain to the next filter. - grpc_call_next_op(elem, op); -} + private: + Latch* const err_ = + GetContext()->ManagedNew>(); + MessageSizeParsedConfig limits_; +}; -// Constructor for call_data. -static grpc_error_handle message_size_init_call_elem( - grpc_call_element* elem, const grpc_call_element_args* args) { - channel_data* chand = static_cast(elem->channel_data); - new (elem->call_data) call_data(elem, *chand, *args); - return absl::OkStatus(); +absl::StatusOr ClientMessageSizeFilter::Create( + const ChannelArgs& args, ChannelFilter::Args) { + return ClientMessageSizeFilter(args); } -// Destructor for call_data. -static void message_size_destroy_call_elem( - grpc_call_element* elem, const grpc_call_final_info* /*final_info*/, - grpc_closure* /*ignored*/) { - call_data* calld = static_cast(elem->call_data); - calld->~call_data(); +absl::StatusOr ServerMessageSizeFilter::Create( + const ChannelArgs& args, ChannelFilter::Args) { + return ServerMessageSizeFilter(args); } -// Constructor for channel_data. -static grpc_error_handle message_size_init_channel_elem( - grpc_channel_element* elem, grpc_channel_element_args* args) { - GPR_ASSERT(!args->is_last); - channel_data* chand = static_cast(elem->channel_data); - new (chand) channel_data(); - chand->limits = grpc_core::MessageSizeParsedConfig::GetFromChannelArgs( - args->channel_args); - return absl::OkStatus(); -} +ArenaPromise ClientMessageSizeFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + // Get max sizes from channel data, then merge in per-method config values. + // Note: Per-method config is only available on the client, so we + // apply the max request size to the send limit and the max response + // size to the receive limit. + MessageSizeParsedConfig limits = this->limits(); + const MessageSizeParsedConfig* config_from_call_context = + MessageSizeParsedConfig::GetFromCallContext( + GetContext(), + service_config_parser_index_); + if (config_from_call_context != nullptr) { + absl::optional max_send_size = limits.max_send_size(); + absl::optional max_recv_size = limits.max_recv_size(); + if (config_from_call_context->max_send_size().has_value() && + (!max_send_size.has_value() || + *config_from_call_context->max_send_size() < *max_send_size)) { + max_send_size = *config_from_call_context->max_send_size(); + } + if (config_from_call_context->max_recv_size().has_value() && + (!max_recv_size.has_value() || + *config_from_call_context->max_recv_size() < *max_recv_size)) { + max_recv_size = *config_from_call_context->max_recv_size(); + } + limits = MessageSizeParsedConfig(max_send_size, max_recv_size); + } -// Destructor for channel_data. -static void message_size_destroy_channel_elem(grpc_channel_element* elem) { - channel_data* chand = static_cast(elem->channel_data); - chand->~channel_data(); + CallBuilder b(limits); + b.AddSend(call_args.client_to_server_messages); + b.AddRecv(call_args.server_to_client_messages); + return b.Run(std::move(call_args), std::move(next_promise_factory)); } -const grpc_channel_filter grpc_message_size_filter = { - message_size_start_transport_stream_op_batch, - nullptr, - grpc_channel_next_op, - sizeof(call_data), - message_size_init_call_elem, - grpc_call_stack_ignore_set_pollset_or_pollset_set, - message_size_destroy_call_elem, - sizeof(channel_data), - message_size_init_channel_elem, - grpc_channel_stack_no_post_init, - message_size_destroy_channel_elem, - grpc_channel_next_get_info, - "message_size"}; +ArenaPromise ServerMessageSizeFilter::MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) { + CallBuilder b(limits()); + b.AddSend(call_args.server_to_client_messages); + b.AddRecv(call_args.client_to_server_messages); + return b.Run(std::move(call_args), std::move(next_promise_factory)); +} +namespace { // Used for GRPC_CLIENT_SUBCHANNEL -static bool maybe_add_message_size_filter_subchannel( - grpc_core::ChannelStackBuilder* builder) { +bool MaybeAddMessageSizeFilterToSubchannel(ChannelStackBuilder* builder) { if (builder->channel_args().WantMinimalStack()) { return true; } - builder->PrependFilter(&grpc_message_size_filter); + builder->PrependFilter(&ClientMessageSizeFilter::kFilter); return true; } -// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the filter -// only if message size limits or service config is specified. -static bool maybe_add_message_size_filter( - grpc_core::ChannelStackBuilder* builder) { - auto channel_args = builder->channel_args(); - if (channel_args.WantMinimalStack()) { +// Used for GRPC_CLIENT_DIRECT_CHANNEL and GRPC_SERVER_CHANNEL. Adds the +// filter only if message size limits or service config is specified. +auto MaybeAddMessageSizeFilter(const grpc_channel_filter* filter) { + return [filter](ChannelStackBuilder* builder) { + auto channel_args = builder->channel_args(); + if (channel_args.WantMinimalStack()) { + return true; + } + MessageSizeParsedConfig limits = + MessageSizeParsedConfig::GetFromChannelArgs(channel_args); + const bool enable = + limits.max_send_size().has_value() || + limits.max_recv_size().has_value() || + channel_args.GetString(GRPC_ARG_SERVICE_CONFIG).has_value(); + if (enable) builder->PrependFilter(filter); return true; - } - grpc_core::MessageSizeParsedConfig limits = - grpc_core::MessageSizeParsedConfig::GetFromChannelArgs(channel_args); - const bool enable = - limits.max_send_size().has_value() || - limits.max_recv_size().has_value() || - channel_args.GetString(GRPC_ARG_SERVICE_CONFIG).has_value(); - if (enable) builder->PrependFilter(&grpc_message_size_filter); - return true; + }; } -namespace grpc_core { +} // namespace void RegisterMessageSizeFilter(CoreConfiguration::Builder* builder) { MessageSizeParser::Register(builder); - builder->channel_init()->RegisterStage( - GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - maybe_add_message_size_filter_subchannel); - builder->channel_init()->RegisterStage(GRPC_CLIENT_DIRECT_CHANNEL, - GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - maybe_add_message_size_filter); - builder->channel_init()->RegisterStage(GRPC_SERVER_CHANNEL, + builder->channel_init()->RegisterStage(GRPC_CLIENT_SUBCHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, - maybe_add_message_size_filter); + MaybeAddMessageSizeFilterToSubchannel); + builder->channel_init()->RegisterStage( + GRPC_CLIENT_DIRECT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + MaybeAddMessageSizeFilter(&ClientMessageSizeFilter::kFilter)); + builder->channel_init()->RegisterStage( + GRPC_SERVER_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY, + MaybeAddMessageSizeFilter(&ServerMessageSizeFilter::kFilter)); } } // namespace grpc_core diff --git a/src/core/ext/filters/message_size/message_size_filter.h b/src/core/ext/filters/message_size/message_size_filter.h index e47485a8950..75135a1b75e 100644 --- a/src/core/ext/filters/message_size/message_size_filter.h +++ b/src/core/ext/filters/message_size/message_size_filter.h @@ -24,21 +24,22 @@ #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" -#include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" +#include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/config/core_configuration.h" #include "src/core/lib/gprpp/validation_errors.h" #include "src/core/lib/json/json.h" #include "src/core/lib/json/json_args.h" #include "src/core/lib/json/json_object_loader.h" +#include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/service_config/service_config_parser.h" - -extern const grpc_channel_filter grpc_message_size_filter; +#include "src/core/lib/transport/transport.h" namespace grpc_core { @@ -85,6 +86,50 @@ class MessageSizeParser : public ServiceConfigParser::Parser { absl::optional GetMaxRecvSizeFromChannelArgs(const ChannelArgs& args); absl::optional GetMaxSendSizeFromChannelArgs(const ChannelArgs& args); +class MessageSizeFilter : public ChannelFilter { + protected: + explicit MessageSizeFilter(const ChannelArgs& args) + : limits_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {} + + class CallBuilder; + + const MessageSizeParsedConfig& limits() const { return limits_; } + + private: + MessageSizeParsedConfig limits_; +}; + +class ServerMessageSizeFilter final : public MessageSizeFilter { + public: + static const grpc_channel_filter kFilter; + + static absl::StatusOr Create( + const ChannelArgs& args, ChannelFilter::Args filter_args); + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; + + private: + using MessageSizeFilter::MessageSizeFilter; +}; + +class ClientMessageSizeFilter final : public MessageSizeFilter { + public: + static const grpc_channel_filter kFilter; + + static absl::StatusOr Create( + const ChannelArgs& args, ChannelFilter::Args filter_args); + + // Construct a promise for one call. + ArenaPromise MakeCallPromise( + CallArgs call_args, NextPromiseFactory next_promise_factory) override; + + private: + const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()}; + using MessageSizeFilter::MessageSizeFilter; +}; + } // namespace grpc_core #endif // GRPC_SRC_CORE_EXT_FILTERS_MESSAGE_SIZE_MESSAGE_SIZE_FILTER_H diff --git a/src/core/ext/transport/binder/transport/binder_transport.cc b/src/core/ext/transport/binder/transport/binder_transport.cc index 38ccbc6ef8f..0420e96b184 100644 --- a/src/core/ext/transport/binder/transport/binder_transport.cc +++ b/src/core/ext/transport/binder/transport/binder_transport.cc @@ -694,6 +694,7 @@ static grpc_endpoint* get_endpoint(grpc_transport*) { // See grpc_transport_vtable declaration for meaning of each field static const grpc_transport_vtable vtable = {sizeof(grpc_binder_stream), + false, "binder", init_stream, nullptr, diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index df7adaf667f..0b55f73f67c 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -1204,7 +1204,8 @@ void grpc_chttp2_complete_closure_step(grpc_chttp2_transport* t, grpc_chttp2_stream* s, grpc_closure** pclosure, grpc_error_handle error, - const char* desc) { + const char* desc, + grpc_core::DebugLocation whence) { grpc_closure* closure = *pclosure; *pclosure = nullptr; if (closure == nullptr) { @@ -1215,14 +1216,14 @@ void grpc_chttp2_complete_closure_step(grpc_chttp2_transport* t, gpr_log( GPR_INFO, "complete_closure_step: t=%p %p refs=%d flags=0x%04x desc=%s err=%s " - "write_state=%s", + "write_state=%s whence=%s:%d", t, closure, static_cast(closure->next_data.scratch / CLOSURE_BARRIER_FIRST_REF_BIT), static_cast(closure->next_data.scratch % CLOSURE_BARRIER_FIRST_REF_BIT), desc, grpc_core::StatusToString(error).c_str(), - write_state_name(t->write_state)); + write_state_name(t->write_state), whence.file(), whence.line()); } auto* tracer = CallTracerIfEnabled(s); @@ -3073,6 +3074,7 @@ static grpc_endpoint* chttp2_get_endpoint(grpc_transport* t) { } static const grpc_transport_vtable vtable = {sizeof(grpc_chttp2_stream), + false, "chttp2", init_stream, nullptr, diff --git a/src/core/ext/transport/chttp2/transport/internal.h b/src/core/ext/transport/chttp2/transport/internal.h index 8ad1a17a101..a89b003fcdb 100644 --- a/src/core/ext/transport/chttp2/transport/internal.h +++ b/src/core/ext/transport/chttp2/transport/internal.h @@ -709,7 +709,8 @@ void grpc_chttp2_complete_closure_step(grpc_chttp2_transport* t, grpc_chttp2_stream* s, grpc_closure** pclosure, grpc_error_handle error, - const char* desc); + const char* desc, + grpc_core::DebugLocation whence = {}); #define GRPC_HEADER_SIZE_IN_BYTES 5 #define MAX_SIZE_T (~(size_t)0) diff --git a/src/core/ext/transport/cronet/transport/cronet_transport.cc b/src/core/ext/transport/cronet/transport/cronet_transport.cc index 77cf88b17de..e1dfbb33b71 100644 --- a/src/core/ext/transport/cronet/transport/cronet_transport.cc +++ b/src/core/ext/transport/cronet/transport/cronet_transport.cc @@ -1462,6 +1462,7 @@ static void perform_op(grpc_transport* /*gt*/, grpc_transport_op* /*op*/) {} static const grpc_transport_vtable grpc_cronet_vtable = { sizeof(stream_obj), + false, "cronet_http", init_stream, nullptr, diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc index dc6b4804f0a..b4185e454c0 100644 --- a/src/core/ext/transport/inproc/inproc_transport.cc +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -408,7 +408,7 @@ void complete_if_batch_end_locked(inproc_stream* s, grpc_error_handle error, int is_rtm = static_cast(op == s->recv_trailing_md_op); if ((is_sm + is_stm + is_rim + is_rm + is_rtm) == 1) { - INPROC_LOG(GPR_INFO, "%s %p %p %s", msg, s, op, + INPROC_LOG(GPR_INFO, "%s %p %p %p %s", msg, s, op, op->on_complete, grpc_core::StatusToString(error).c_str()); grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, error); } @@ -697,8 +697,9 @@ void op_state_machine_locked(inproc_stream* s, grpc_error_handle error) { s->to_read_initial_md_filled = false; grpc_core::ExecCtx::Run( DEBUG_LOCATION, - s->recv_initial_md_op->payload->recv_initial_metadata - .recv_initial_metadata_ready, + std::exchange(s->recv_initial_md_op->payload->recv_initial_metadata + .recv_initial_metadata_ready, + nullptr), absl::OkStatus()); complete_if_batch_end_locked( s, absl::OkStatus(), s->recv_initial_md_op, @@ -766,6 +767,8 @@ void op_state_machine_locked(inproc_stream* s, grpc_error_handle error) { nullptr); s->to_read_trailing_md.Clear(); s->to_read_trailing_md_filled = false; + s->recv_trailing_md_op->payload->recv_trailing_metadata + .recv_trailing_metadata->Set(grpc_core::GrpcStatusFromWire(), true); // We should schedule the recv_trailing_md_op completion if // 1. this stream is the client-side @@ -906,8 +909,6 @@ bool cancel_stream_locked(inproc_stream* s, grpc_error_handle error) { return ret; } -void do_nothing(void* /*arg*/, grpc_error_handle /*error*/) {} - void perform_stream_op(grpc_transport* gt, grpc_stream* gs, grpc_transport_stream_op_batch* op) { INPROC_LOG(GPR_INFO, "perform_stream_op %p %p %p", gt, gs, op); @@ -933,8 +934,8 @@ void perform_stream_op(grpc_transport* gt, grpc_stream* gs, // completed). This can go away once we move to a new C++ closure API // that provides the ability to create a barrier closure. if (on_complete == nullptr) { - on_complete = GRPC_CLOSURE_INIT(&op->handler_private.closure, do_nothing, - nullptr, grpc_schedule_on_exec_ctx); + on_complete = op->on_complete = + grpc_core::NewClosure([](grpc_error_handle) {}); } if (op->cancel_stream) { @@ -1177,13 +1178,18 @@ void set_pollset_set(grpc_transport* /*gt*/, grpc_stream* /*gs*/, grpc_endpoint* get_endpoint(grpc_transport* /*t*/) { return nullptr; } -const grpc_transport_vtable inproc_vtable = { - sizeof(inproc_stream), "inproc", - init_stream, nullptr, - set_pollset, set_pollset_set, - perform_stream_op, perform_transport_op, - destroy_stream, destroy_transport, - get_endpoint}; +const grpc_transport_vtable inproc_vtable = {sizeof(inproc_stream), + true, + "inproc", + init_stream, + nullptr, + set_pollset, + set_pollset_set, + perform_stream_op, + perform_transport_op, + destroy_stream, + destroy_transport, + get_endpoint}; //****************************************************************************** // Main inproc transport functions diff --git a/src/core/lib/channel/connected_channel.cc b/src/core/lib/channel/connected_channel.cc index 40582f99e44..fa66027557f 100644 --- a/src/core/lib/channel/connected_channel.cc +++ b/src/core/lib/channel/connected_channel.cc @@ -21,21 +21,16 @@ #include "src/core/lib/channel/connected_channel.h" #include -#include -#include #include #include #include #include +#include #include -#include -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/variant.h" @@ -47,39 +42,48 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/channel_fwd.h" #include "src/core/lib/channel/channel_stack.h" -#include "src/core/lib/channel/context.h" #include "src/core/lib/debug/trace.h" #include "src/core/lib/experiments/experiments.h" #include "src/core/lib/gpr/alloc.h" #include "src/core/lib/gprpp/debug_location.h" -#include "src/core/lib/gprpp/match.h" #include "src/core/lib/gprpp/orphanable.h" -#include "src/core/lib/gprpp/status_helper.h" -#include "src/core/lib/gprpp/sync.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/call_combiner.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" -#include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/detail/basic_join.h" #include "src/core/lib/promise/detail/basic_seq.h" +#include "src/core/lib/promise/for_each.h" +#include "src/core/lib/promise/if.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/loop.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/party.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/promise.h" +#include "src/core/lib/promise/race.h" +#include "src/core/lib/promise/seq.h" +#include "src/core/lib/promise/try_join.h" +#include "src/core/lib/promise/try_seq.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_buffer.h" #include "src/core/lib/surface/call.h" #include "src/core/lib/surface/call_trace.h" #include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/batch_builder.h" +#include "src/core/lib/transport/error_utils.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" #include "src/core/lib/transport/transport_fwd.h" #include "src/core/lib/transport/transport_impl.h" -#define MAX_BUFFER_LENGTH 8192 - typedef struct connected_channel_channel_data { grpc_transport* transport; } channel_data; @@ -252,10 +256,24 @@ namespace { defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL) class ConnectedChannelStream : public Orphanable { public: + explicit ConnectedChannelStream(grpc_transport* transport) + : transport_(transport), stream_(nullptr, StreamDeleter(this)) { + GRPC_STREAM_REF_INIT( + &stream_refcount_, 1, + [](void* p, grpc_error_handle) { + static_cast(p)->BeginDestroy(); + }, + this, "ConnectedChannelStream"); + } + grpc_transport* transport() { return transport_; } grpc_closure* stream_destroyed_closure() { return &stream_destroyed_; } - void IncrementRefCount(const char* reason) { + BatchBuilder::Target batch_target() { + return BatchBuilder::Target{transport_, stream_.get(), &stream_refcount_}; + } + + void IncrementRefCount(const char* reason = "smartptr") { #ifndef NDEBUG grpc_stream_ref(&stream_refcount_, reason); #else @@ -264,7 +282,7 @@ class ConnectedChannelStream : public Orphanable { #endif } - void Unref(const char* reason) { + void Unref(const char* reason = "smartptr") { #ifndef NDEBUG grpc_stream_unref(&stream_refcount_, reason); #else @@ -273,235 +291,48 @@ class ConnectedChannelStream : public Orphanable { #endif } + RefCountedPtr InternalRef() { + IncrementRefCount("smartptr"); + return RefCountedPtr(this); + } + void Orphan() final { - bool finished; - { - MutexLock lock(mu()); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] DropStream: %s finished=%s", - Activity::current()->DebugTag().c_str(), - ActiveOpsString().c_str(), finished_ ? "true" : "false"); - } - finished = finished_; + bool finished = finished_.IsSet(); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Orphan stream, finished: %d", + party_->DebugTag().c_str(), finished); } // If we hadn't already observed the stream to be finished, we need to // cancel it at the transport. if (!finished) { - IncrementRefCount("shutdown client stream"); - auto* cancel_op = - GetContext()->New(); - cancel_op->cancel_stream = true; - cancel_op->payload = batch_payload(); - auto* s = stream(); - cancel_op->on_complete = NewClosure( - [this](grpc_error_handle) { Unref("shutdown client stream"); }); - batch_payload()->cancel_stream.cancel_error = absl::CancelledError(); - grpc_transport_perform_stream_op(transport(), s, cancel_op); + party_->Spawn( + "finish", + [self = InternalRef()]() { + if (!self->finished_.IsSet()) { + self->finished_.Set(); + } + return Empty{}; + }, + [](Empty) {}); + GetContext()->Cancel(batch_target(), + absl::CancelledError()); } - Unref("orphan client stream"); + Unref("orphan connected stream"); } - protected: - explicit ConnectedChannelStream(grpc_transport* transport) - : transport_(transport), stream_(nullptr, StreamDeleter(this)) { - call_context_->IncrementRefCount("connected_channel_stream"); - GRPC_STREAM_REF_INIT( - &stream_refcount_, 1, - [](void* p, grpc_error_handle) { - static_cast(p)->BeginDestroy(); - }, - this, "client_stream"); - } + // Returns a promise that implements the receive message loop. + auto RecvMessages(PipeSender* incoming_messages); + // Returns a promise that implements the send message loop. + auto SendMessages(PipeReceiver* outgoing_messages); - grpc_stream* stream() { return stream_.get(); } void SetStream(grpc_stream* stream) { stream_.reset(stream); } + grpc_stream* stream() { return stream_.get(); } grpc_stream_refcount* stream_refcount() { return &stream_refcount_; } - Mutex* mu() const ABSL_LOCK_RETURNED(mu_) { return &mu_; } - grpc_transport_stream_op_batch_payload* batch_payload() { - return &batch_payload_; - } - bool finished() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return finished_; } - void set_finished() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { finished_ = true; } - virtual std::string ActiveOpsString() const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; - - void SchedulePush(grpc_transport_stream_op_batch* batch) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - batch->is_traced = GetContext()->traced(); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s[connected] Push batch to transport: %s", - Activity::current()->DebugTag().c_str(), - grpc_transport_stream_op_batch_string(batch, false).c_str()); - } - if (push_batches_.empty()) { - IncrementRefCount("push"); - ExecCtx::Run(DEBUG_LOCATION, &push_, absl::OkStatus()); - } - push_batches_.push_back(batch); - } - void PollSendMessage(PipeReceiver* outgoing_messages, - ClientMetadataHandle* client_trailing_metadata) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (absl::holds_alternative(send_message_state_)) { - message_to_send_.reset(); - } - if (absl::holds_alternative(send_message_state_)) { - message_to_send_.reset(); - send_message_state_.emplace>( - outgoing_messages->Next()); - } - if (auto* next = absl::get_if>( - &send_message_state_)) { - auto r = (*next)(); - if (auto* p = r.value_if_ready()) { - memset(&send_message_, 0, sizeof(send_message_)); - send_message_.payload = batch_payload(); - send_message_.on_complete = &send_message_batch_done_; - // No value => half close from above. - if (p->has_value()) { - message_to_send_ = std::move(*p); - send_message_state_ = SendMessageToTransport{}; - send_message_.send_message = true; - batch_payload()->send_message.send_message = - (*message_to_send_)->payload(); - batch_payload()->send_message.flags = (*message_to_send_)->flags(); - } else { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] PollConnectedChannel: half close", - Activity::current()->DebugTag().c_str()); - } - GPR_ASSERT(!absl::holds_alternative(send_message_state_)); - send_message_state_ = Closed{}; - send_message_.send_trailing_metadata = true; - if (client_trailing_metadata != nullptr) { - *client_trailing_metadata = - GetContext()->MakePooled( - GetContext()); - batch_payload()->send_trailing_metadata.send_trailing_metadata = - client_trailing_metadata->get(); - batch_payload()->send_trailing_metadata.sent = nullptr; - } else { - return; // Skip rest of function for server - } - } - IncrementRefCount("send_message"); - send_message_waker_ = Activity::current()->MakeOwningWaker(); - SchedulePush(&send_message_); - } - } - } - - void PollRecvMessage(PipeSender*& incoming_messages) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (auto* pending = - absl::get_if(&recv_message_state_)) { - if (pending->received) { - if (pending->payload.has_value()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollRecvMessage: received payload of " - "%" PRIdPTR " bytes", - recv_message_waker_.ActivityDebugTag().c_str(), - pending->payload->Length()); - } - recv_message_state_ = - incoming_messages->Push(GetContext()->MakePooled( - std::move(*pending->payload), pending->flags)); - } else { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollRecvMessage: received no payload", - recv_message_waker_.ActivityDebugTag().c_str()); - } - recv_message_state_ = Closed{}; - std::exchange(incoming_messages, nullptr)->Close(); - } - } - } - if (absl::holds_alternative(recv_message_state_)) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] PollRecvMessage: requesting message", - Activity::current()->DebugTag().c_str()); - } - PushRecvMessage(); - } - if (auto* push = absl::get_if::PushType>( - &recv_message_state_)) { - auto r = (*push)(); - if (bool* result = r.value_if_ready()) { - if (*result) { - if (!finished_) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollRecvMessage: pushed message; " - "requesting next", - Activity::current()->DebugTag().c_str()); - } - PushRecvMessage(); - } else { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollRecvMessage: pushed message " - "and finished; " - "marking closed", - Activity::current()->DebugTag().c_str()); - } - recv_message_state_ = Closed{}; - std::exchange(incoming_messages, nullptr)->Close(); - } - } else { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollRecvMessage: failed to push " - "message; marking " - "closed", - Activity::current()->DebugTag().c_str()); - } - recv_message_state_ = Closed{}; - std::exchange(incoming_messages, nullptr)->Close(); - } - } - } - } - - std::string SendMessageString() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) { - return Match( - send_message_state_, [](Idle) -> std::string { return "IDLE"; }, - [](Closed) -> std::string { return "CLOSED"; }, - [](const PipeReceiverNextType&) -> std::string { - return "WAITING"; - }, - [](SendMessageToTransport) -> std::string { return "SENDING"; }); - } - - std::string RecvMessageString() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) { - return Match( - recv_message_state_, [](Idle) -> std::string { return "IDLE"; }, - [](Closed) -> std::string { return "CLOSED"; }, - [](const PendingReceiveMessage&) -> std::string { return "WAITING"; }, - [](const absl::optional& message) -> std::string { - return absl::StrCat( - "READY:", message.has_value() - ? absl::StrCat((*message)->payload()->Length(), "b") - : "EOS"); - }, - [](const PipeSender::PushType&) -> std::string { - return "PUSHING"; - }); - } - - bool IsPromiseReceiving() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) { - return absl::holds_alternative::PushType>( - recv_message_state_) || - absl::holds_alternative(recv_message_state_); - } + void set_finished() { finished_.Set(); } + auto WaitFinished() { return finished_.Wait(); } private: - struct SendMessageToTransport {}; - struct Idle {}; - struct Closed {}; - class StreamDeleter { public: explicit StreamDeleter(ConnectedChannelStream* impl) : impl_(impl) {} @@ -517,11 +348,7 @@ class ConnectedChannelStream : public Orphanable { using StreamPtr = std::unique_ptr; void StreamDestroyed() { - call_context_->RunInContext([this] { - auto* cc = call_context_; - this->~ConnectedChannelStream(); - cc->Unref("child_stream"); - }); + call_context_->RunInContext([this] { this->~ConnectedChannelStream(); }); } void BeginDestroy() { @@ -532,824 +359,434 @@ class ConnectedChannelStream : public Orphanable { } } - // Called from outside the activity to push work down to the transport. - void Push() { - PushBatches push_batches; - { - MutexLock lock(&mu_); - push_batches.swap(push_batches_); - } - for (auto* batch : push_batches) { - if (stream() != nullptr) { - grpc_transport_perform_stream_op(transport(), stream(), batch); - } else { - grpc_transport_stream_op_batch_finish_with_failure_from_transport( - batch, absl::CancelledError()); - } - } - Unref("push"); - } - - void SendMessageBatchDone(grpc_error_handle error) { - { - MutexLock lock(&mu_); - if (error != absl::OkStatus()) { - // Note that we're in error here, the call will be closed by the - // transport in a moment, and we'll return from the promise with an - // error - so we don't need to do any extra work to close out pipes or - // the like. - send_message_state_ = Closed{}; - } - if (!absl::holds_alternative(send_message_state_)) { - send_message_state_ = Idle{}; - } - send_message_waker_.Wakeup(); - } - Unref("send_message"); - } - - void RecvMessageBatchDone(grpc_error_handle error) { - { - MutexLock lock(mu()); - if (error != absl::OkStatus()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] RecvMessageBatchDone: error=%s", - recv_message_waker_.ActivityDebugTag().c_str(), - StatusToString(error).c_str()); - } - } else if (absl::holds_alternative(recv_message_state_)) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] RecvMessageBatchDone: already closed, " - "ignoring", - recv_message_waker_.ActivityDebugTag().c_str()); - } - } else { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] RecvMessageBatchDone: received message", - recv_message_waker_.ActivityDebugTag().c_str()); - } - auto pending = - absl::get_if(&recv_message_state_); - GPR_ASSERT(pending != nullptr); - GPR_ASSERT(pending->received == false); - pending->received = true; - } - recv_message_waker_.Wakeup(); - } - Unref("recv_message"); - } - - void PushRecvMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - recv_message_state_ = PendingReceiveMessage{}; - auto& pending_recv_message = - absl::get(recv_message_state_); - memset(&recv_message_, 0, sizeof(recv_message_)); - recv_message_.payload = batch_payload(); - recv_message_.on_complete = nullptr; - recv_message_.recv_message = true; - batch_payload()->recv_message.recv_message = &pending_recv_message.payload; - batch_payload()->recv_message.flags = &pending_recv_message.flags; - batch_payload()->recv_message.call_failed_before_recv_message = nullptr; - batch_payload()->recv_message.recv_message_ready = - &recv_message_batch_done_; - IncrementRefCount("recv_message"); - recv_message_waker_ = Activity::current()->MakeOwningWaker(); - SchedulePush(&recv_message_); - } - - mutable Mutex mu_; grpc_transport* const transport_; - CallContext* const call_context_{GetContext()}; + RefCountedPtr const call_context_{ + GetContext()->Ref()}; grpc_closure stream_destroyed_ = MakeMemberClosure( this, DEBUG_LOCATION); grpc_stream_refcount stream_refcount_; StreamPtr stream_; - using PushBatches = absl::InlinedVector; - PushBatches push_batches_ ABSL_GUARDED_BY(mu_); - grpc_closure push_ = - MakeMemberClosure( - this, DEBUG_LOCATION); - - NextResult message_to_send_ ABSL_GUARDED_BY(mu_); - absl::variant, - SendMessageToTransport> - send_message_state_ ABSL_GUARDED_BY(mu_); - grpc_transport_stream_op_batch send_message_; - grpc_closure send_message_batch_done_ = - MakeMemberClosure( - this, DEBUG_LOCATION); - - struct PendingReceiveMessage { - absl::optional payload; - uint32_t flags; - bool received = false; - }; - absl::variant::PushType> - recv_message_state_ ABSL_GUARDED_BY(mu_); - grpc_closure recv_message_batch_done_ = - MakeMemberClosure( - this, DEBUG_LOCATION); - grpc_transport_stream_op_batch recv_message_; - - Waker send_message_waker_ ABSL_GUARDED_BY(mu_); - Waker recv_message_waker_ ABSL_GUARDED_BY(mu_); - bool finished_ ABSL_GUARDED_BY(mu_) = false; - - grpc_transport_stream_op_batch_payload batch_payload_{ - GetContext()}; + Arena* arena_ = GetContext(); + Party* const party_ = static_cast(Activity::current()); + ExternallyObservableLatch finished_; }; -#endif - -#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL -class ClientStream : public ConnectedChannelStream { - public: - ClientStream(grpc_transport* transport, CallArgs call_args) - : ConnectedChannelStream(transport), - server_initial_metadata_pipe_(call_args.server_initial_metadata), - client_to_server_messages_(call_args.client_to_server_messages), - server_to_client_messages_(call_args.server_to_client_messages), - client_initial_metadata_(std::move(call_args.client_initial_metadata)) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] InitImpl: intitial_metadata=%s", - Activity::current()->DebugTag().c_str(), - client_initial_metadata_->DebugString().c_str()); - } - } - - Poll PollOnce() { - MutexLock lock(mu()); - GPR_ASSERT(!finished()); - - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] PollConnectedChannel: %s", - Activity::current()->DebugTag().c_str(), - ActiveOpsString().c_str()); - } - - if (!std::exchange(requested_metadata_, true)) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollConnectedChannel: requesting metadata", - Activity::current()->DebugTag().c_str()); - } - SetStream(static_cast( - GetContext()->Alloc(transport()->vtable->sizeof_stream))); - grpc_transport_init_stream(transport(), stream(), stream_refcount(), - nullptr, GetContext()); - grpc_transport_set_pops(transport(), stream(), - GetContext()->polling_entity()); - memset(&metadata_, 0, sizeof(metadata_)); - metadata_.send_initial_metadata = true; - metadata_.recv_initial_metadata = true; - metadata_.recv_trailing_metadata = true; - metadata_.payload = batch_payload(); - metadata_.on_complete = &metadata_batch_done_; - batch_payload()->send_initial_metadata.send_initial_metadata = - client_initial_metadata_.get(); - server_initial_metadata_ = - GetContext()->MakePooled(GetContext()); - batch_payload()->recv_initial_metadata.recv_initial_metadata = - server_initial_metadata_.get(); - batch_payload()->recv_initial_metadata.recv_initial_metadata_ready = - &recv_initial_metadata_ready_; - batch_payload()->recv_initial_metadata.trailing_metadata_available = - nullptr; - server_trailing_metadata_ = - GetContext()->MakePooled(GetContext()); - batch_payload()->recv_trailing_metadata.recv_trailing_metadata = - server_trailing_metadata_.get(); - batch_payload()->recv_trailing_metadata.collect_stats = - &GetContext()->call_stats()->transport_stream_stats; - batch_payload()->recv_trailing_metadata.recv_trailing_metadata_ready = - &recv_trailing_metadata_ready_; - IncrementRefCount("metadata_batch_done"); - IncrementRefCount("initial_metadata_ready"); - IncrementRefCount("trailing_metadata_ready"); - initial_metadata_waker_ = Activity::current()->MakeOwningWaker(); - trailing_metadata_waker_ = Activity::current()->MakeOwningWaker(); - SchedulePush(&metadata_); - } - if (server_initial_metadata_state_ == - ServerInitialMetadataState::kReceivedButNotPushed) { - server_initial_metadata_state_ = ServerInitialMetadataState::kPushing; - server_initial_metadata_push_promise_ = - server_initial_metadata_pipe_->Push( - std::move(server_initial_metadata_)); - } - if (server_initial_metadata_state_ == - ServerInitialMetadataState::kPushing) { - auto r = (*server_initial_metadata_push_promise_)(); - if (r.ready()) { - server_initial_metadata_state_ = ServerInitialMetadataState::kPushed; - server_initial_metadata_push_promise_.reset(); - } - } - PollSendMessage(client_to_server_messages_, &client_trailing_metadata_); - PollRecvMessage(server_to_client_messages_); - if (server_initial_metadata_state_ == ServerInitialMetadataState::kPushed && - !IsPromiseReceiving() && - std::exchange(queued_trailing_metadata_, false)) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[connected] PollConnectedChannel: finished request, " - "returning: {%s}; " - "active_ops: %s", - Activity::current()->DebugTag().c_str(), - server_trailing_metadata_->DebugString().c_str(), - ActiveOpsString().c_str()); - } - set_finished(); - return ServerMetadataHandle(std::move(server_trailing_metadata_)); - } - return Pending{}; - } - - void RecvInitialMetadataReady(grpc_error_handle error) { - GPR_ASSERT(error == absl::OkStatus()); - { - MutexLock lock(mu()); - server_initial_metadata_state_ = - ServerInitialMetadataState::kReceivedButNotPushed; - initial_metadata_waker_.Wakeup(); - } - Unref("initial_metadata_ready"); - } - - void RecvTrailingMetadataReady(grpc_error_handle error) { - GPR_ASSERT(error == absl::OkStatus()); - { - MutexLock lock(mu()); - queued_trailing_metadata_ = true; - if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, - "%s[connected] RecvTrailingMetadataReady: " - "queued_trailing_metadata_ " - "set to true; active_ops: %s", - trailing_metadata_waker_.ActivityDebugTag().c_str(), - ActiveOpsString().c_str()); - } - trailing_metadata_waker_.Wakeup(); - } - Unref("trailing_metadata_ready"); - } - - void MetadataBatchDone(grpc_error_handle error) { - GPR_ASSERT(error == absl::OkStatus()); - Unref("metadata_batch_done"); - } - private: - enum class ServerInitialMetadataState : uint8_t { - // Initial metadata has not been received from the server. - kNotReceived, - // Initial metadata has been received from the server via the transport, but - // has not yet been pushed onto the pipe to publish it up the call stack. - kReceivedButNotPushed, - // Initial metadata has been received from the server via the transport and - // has been pushed on the pipe to publish it up the call stack. - // It's still in the pipe and has not been removed by the call at the top - // yet. - kPushing, - // Initial metadata has been received from the server via the transport and - // has been pushed on the pipe to publish it up the call stack AND removed - // by the call at the top. - kPushed, - }; - - std::string ActiveOpsString() const override - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) { - std::vector ops; - if (finished()) ops.push_back("FINISHED"); - // Outstanding Operations on Transport - std::vector waiting; - if (initial_metadata_waker_ != Waker()) { - waiting.push_back("initial_metadata"); - } - if (trailing_metadata_waker_ != Waker()) { - waiting.push_back("trailing_metadata"); - } - if (!waiting.empty()) { - ops.push_back(absl::StrCat("waiting:", absl::StrJoin(waiting, ","))); - } - // Results from transport - std::vector queued; - if (server_initial_metadata_state_ == - ServerInitialMetadataState::kReceivedButNotPushed) { - queued.push_back("initial_metadata"); - } - if (queued_trailing_metadata_) queued.push_back("trailing_metadata"); - if (!queued.empty()) { - ops.push_back(absl::StrCat("queued:", absl::StrJoin(queued, ","))); - } - // Send message - std::string send_message_state = SendMessageString(); - if (send_message_state != "WAITING") { - ops.push_back(absl::StrCat("send_message:", send_message_state)); - } - // Receive message - std::string recv_message_state = RecvMessageString(); - if (recv_message_state != "IDLE") { - ops.push_back(absl::StrCat("recv_message:", recv_message_state)); - } - return absl::StrJoin(ops, " "); - } - - bool requested_metadata_ = false; - ServerInitialMetadataState server_initial_metadata_state_ - ABSL_GUARDED_BY(mu()) = ServerInitialMetadataState::kNotReceived; - bool queued_trailing_metadata_ ABSL_GUARDED_BY(mu()) = false; - Waker initial_metadata_waker_ ABSL_GUARDED_BY(mu()); - Waker trailing_metadata_waker_ ABSL_GUARDED_BY(mu()); - PipeSender* server_initial_metadata_pipe_; - PipeReceiver* client_to_server_messages_; - PipeSender* server_to_client_messages_; - grpc_closure recv_initial_metadata_ready_ = - MakeMemberClosure( - this, DEBUG_LOCATION); - grpc_closure recv_trailing_metadata_ready_ = - MakeMemberClosure( - this, DEBUG_LOCATION); - ClientMetadataHandle client_initial_metadata_; - ClientMetadataHandle client_trailing_metadata_; - ServerMetadataHandle server_initial_metadata_; - ServerMetadataHandle server_trailing_metadata_; - absl::optional::PushType> - server_initial_metadata_push_promise_; - grpc_transport_stream_op_batch metadata_; - grpc_closure metadata_batch_done_ = - MakeMemberClosure( - this, DEBUG_LOCATION); -}; - -class ClientConnectedCallPromise { - public: - ClientConnectedCallPromise(grpc_transport* transport, CallArgs call_args) - : impl_(GetContext()->New(transport, - std::move(call_args))) {} - - ClientConnectedCallPromise(const ClientConnectedCallPromise&) = delete; - ClientConnectedCallPromise& operator=(const ClientConnectedCallPromise&) = - delete; - ClientConnectedCallPromise(ClientConnectedCallPromise&& other) noexcept - : impl_(std::exchange(other.impl_, nullptr)) {} - ClientConnectedCallPromise& operator=( - ClientConnectedCallPromise&& other) noexcept { - impl_ = std::move(other.impl_); - return *this; - } - - static ArenaPromise Make(grpc_transport* transport, - CallArgs call_args, - NextPromiseFactory) { - return ClientConnectedCallPromise(transport, std::move(call_args)); - } +auto ConnectedChannelStream::RecvMessages( + PipeSender* incoming_messages) { + return Loop([self = InternalRef(), + incoming_messages = std::move(*incoming_messages)]() mutable { + return Seq( + GetContext()->ReceiveMessage(self->batch_target()), + [&incoming_messages]( + absl::StatusOr> status) mutable { + bool has_message = status.ok() && status->has_value(); + auto publish_message = [&incoming_messages, &status]() { + auto pending_message = std::move(**status); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, + "%s[connected] RecvMessage: received payload of %" PRIdPTR + " bytes", + Activity::current()->DebugTag().c_str(), + pending_message->payload()->Length()); + } + return Map(incoming_messages.Push(std::move(pending_message)), + [](bool ok) -> LoopCtl { + if (!ok) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, + "%s[connected] RecvMessage: failed to " + "push message towards the application", + Activity::current()->DebugTag().c_str()); + } + return absl::OkStatus(); + } + return Continue{}; + }); + }; + auto publish_close = [&status]() mutable { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, + "%s[connected] RecvMessage: reached end of stream with " + "status:%s", + Activity::current()->DebugTag().c_str(), + status.status().ToString().c_str()); + } + return Immediate(LoopCtl(status.status())); + }; + return If(has_message, std::move(publish_message), + std::move(publish_close)); + }); + }); +} - Poll operator()() { return impl_->PollOnce(); } +auto ConnectedChannelStream::SendMessages( + PipeReceiver* outgoing_messages) { + return ForEach(std::move(*outgoing_messages), + [self = InternalRef()](MessageHandle message) { + return GetContext()->SendMessage( + self->batch_target(), std::move(message)); + }); +} +#endif // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) || + // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL) - private: - OrphanablePtr impl_; -}; +#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL +ArenaPromise MakeClientCallPromise( + grpc_transport* transport, CallArgs call_args, NextPromiseFactory) { + OrphanablePtr stream( + GetContext()->New(transport)); + stream->SetStream(static_cast( + GetContext()->Alloc(transport->vtable->sizeof_stream))); + grpc_transport_init_stream(transport, stream->stream(), + stream->stream_refcount(), nullptr, + GetContext()); + grpc_transport_set_pops(transport, stream->stream(), + GetContext()->polling_entity()); + auto* party = static_cast(Activity::current()); + // Start a loop to send messages from client_to_server_messages to the + // transport. When the pipe closes and the loop completes, send a trailing + // metadata batch to close the stream. + party->Spawn( + "send_messages", + TrySeq(stream->SendMessages(call_args.client_to_server_messages), + [stream = stream->InternalRef()]() { + return GetContext()->SendClientTrailingMetadata( + stream->batch_target()); + }), + [](absl::Status) {}); + // Start a promise to receive server initial metadata and then forward it up + // through the receiving pipe. + auto server_initial_metadata = + GetContext()->MakePooled(GetContext()); + party->Spawn( + "recv_initial_metadata", + TrySeq(GetContext()->ReceiveServerInitialMetadata( + stream->batch_target()), + [pipe = call_args.server_initial_metadata]( + ServerMetadataHandle server_initial_metadata) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, + "%s[connected] Publish client initial metadata: %s", + Activity::current()->DebugTag().c_str(), + server_initial_metadata->DebugString().c_str()); + } + return Map(pipe->Push(std::move(server_initial_metadata)), + [](bool r) { + if (r) return absl::OkStatus(); + return absl::CancelledError(); + }); + }), + [](absl::Status) {}); + + // Build up the rest of the main call promise: + + // Create a promise that will send initial metadata and then signal completion + // of that via the token. + auto send_initial_metadata = Seq( + GetContext()->SendClientInitialMetadata( + stream->batch_target(), std::move(call_args.client_initial_metadata)), + [sent_initial_metadata_token = + std::move(call_args.client_initial_metadata_outstanding)]( + absl::Status status) mutable { + sent_initial_metadata_token.Complete(status.ok()); + return status; + }); + // Create a promise that will receive server trailing metadata. + // If this fails, we massage the error into metadata that we can report + // upwards. + auto server_trailing_metadata = + GetContext()->MakePooled(GetContext()); + auto recv_trailing_metadata = + Map(GetContext()->ReceiveServerTrailingMetadata( + stream->batch_target()), + [](absl::StatusOr status) mutable { + if (!status.ok()) { + auto server_trailing_metadata = + GetContext()->MakePooled( + GetContext()); + grpc_status_code status_code = GRPC_STATUS_UNKNOWN; + std::string message; + grpc_error_get_status(status.status(), Timestamp::InfFuture(), + &status_code, &message, nullptr, nullptr); + server_trailing_metadata->Set(GrpcStatusMetadata(), status_code); + server_trailing_metadata->Set(GrpcMessageMetadata(), + Slice::FromCopiedString(message)); + return server_trailing_metadata; + } else { + return std::move(*status); + } + }); + // Finally the main call promise. + // Concurrently: send initial metadata and receive messages, until BOTH + // complete (or one fails). + // Next: receive trailing metadata, and return that up the stack. + auto recv_messages = + stream->RecvMessages(call_args.server_to_client_messages); + return Map(TrySeq(TryJoin(std::move(send_initial_metadata), + std::move(recv_messages)), + std::move(recv_trailing_metadata)), + [stream = std::move(stream)](ServerMetadataHandle result) { + stream->set_finished(); + return result; + }); +} #endif #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL -class ServerStream final : public ConnectedChannelStream { - public: - ServerStream(grpc_transport* transport, - NextPromiseFactory next_promise_factory) - : ConnectedChannelStream(transport) { - SetStream(static_cast( - GetContext()->Alloc(transport->vtable->sizeof_stream))); - grpc_transport_init_stream( - transport, stream(), stream_refcount(), - GetContext()->server_call_context()->server_stream_data(), - GetContext()); - grpc_transport_set_pops(transport, stream(), - GetContext()->polling_entity()); - - // Fetch initial metadata - auto& gim = call_state_.emplace(this); - gim.recv_initial_metadata_ready_waker = - Activity::current()->MakeOwningWaker(); - memset(&gim.recv_initial_metadata, 0, sizeof(gim.recv_initial_metadata)); - gim.recv_initial_metadata.payload = batch_payload(); - gim.recv_initial_metadata.on_complete = nullptr; - gim.recv_initial_metadata.recv_initial_metadata = true; - gim.next_promise_factory = std::move(next_promise_factory); - batch_payload()->recv_initial_metadata.recv_initial_metadata = - gim.client_initial_metadata.get(); - batch_payload()->recv_initial_metadata.recv_initial_metadata_ready = - &gim.recv_initial_metadata_ready; - SchedulePush(&gim.recv_initial_metadata); - - // Fetch trailing metadata (to catch cancellations) - auto& gtm = - client_trailing_metadata_state_.emplace(); - gtm.recv_trailing_metadata_ready = - MakeMemberClosure(this); - memset(>m.recv_trailing_metadata, 0, sizeof(gtm.recv_trailing_metadata)); - gtm.recv_trailing_metadata.payload = batch_payload(); - gtm.recv_trailing_metadata.recv_trailing_metadata = true; - batch_payload()->recv_trailing_metadata.recv_trailing_metadata = - gtm.result.get(); - batch_payload()->recv_trailing_metadata.collect_stats = - &GetContext()->call_stats()->transport_stream_stats; - batch_payload()->recv_trailing_metadata.recv_trailing_metadata_ready = - >m.recv_trailing_metadata_ready; - SchedulePush(>m.recv_trailing_metadata); - gtm.waker = Activity::current()->MakeOwningWaker(); - } - - Poll PollOnce() { - MutexLock lock(mu()); - - auto poll_send_initial_metadata = [this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED( - mu()) { - if (auto* promise = - absl::get_if>( - &server_initial_metadata_)) { - auto r = (*promise)(); - if (auto* md = r.value_if_ready()) { - if (grpc_call_trace.enabled()) { - gpr_log( - GPR_INFO, "%s[connected] got initial metadata %s", - Activity::current()->DebugTag().c_str(), - (md->has_value() ? (**md)->DebugString() : "") - .c_str()); - } - memset(&send_initial_metadata_, 0, sizeof(send_initial_metadata_)); - send_initial_metadata_.send_initial_metadata = true; - send_initial_metadata_.payload = batch_payload(); - send_initial_metadata_.on_complete = &send_initial_metadata_done_; - batch_payload()->send_initial_metadata.send_initial_metadata = - server_initial_metadata_ - .emplace(std::move(**md)) - .get(); - SchedulePush(&send_initial_metadata_); - return true; - } else { - return false; - } - } else { - return true; - } - }; - - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] PollConnectedChannel: %s", - Activity::current()->DebugTag().c_str(), - ActiveOpsString().c_str()); - } - - poll_send_initial_metadata(); - - if (auto* p = absl::get_if( - &client_trailing_metadata_state_)) { - pipes_.client_to_server.sender.Close(); - if (!p->result.ok()) { - // client cancelled, we should cancel too - if (absl::holds_alternative(call_state_) || - absl::holds_alternative(call_state_) || - absl::holds_alternative(call_state_)) { - if (!absl::holds_alternative( - server_initial_metadata_)) { - // pretend we've sent initial metadata to stop that op from - // progressing if it's stuck somewhere above us in the stack - server_initial_metadata_.emplace(); - } - // cancel the call - this status will be returned to the server bottom - // promise - call_state_.emplace( - Complete{ServerMetadataFromStatus(p->result)}); - } - } - } - - if (auto* p = absl::get_if(&call_state_)) { - incoming_messages_ = &pipes_.client_to_server.sender; - auto promise = p->next_promise_factory(CallArgs{ - std::move(p->client_initial_metadata), - &pipes_.server_initial_metadata.sender, - &pipes_.client_to_server.receiver, &pipes_.server_to_client.sender}); - call_state_.emplace( - MessageLoop{&pipes_.server_to_client.receiver, std::move(promise)}); - server_initial_metadata_ - .emplace>( - pipes_.server_initial_metadata.receiver.Next()); - } - if (incoming_messages_ != nullptr) { - PollRecvMessage(incoming_messages_); - } - if (auto* p = absl::get_if(&call_state_)) { - if (absl::holds_alternative( - server_initial_metadata_)) { - PollSendMessage(p->outgoing_messages, nullptr); - } - auto poll = p->promise(); - if (auto* r = poll.value_if_ready()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[connected] got trailing metadata %s; %s", - Activity::current()->DebugTag().c_str(), - (*r)->DebugString().c_str(), ActiveOpsString().c_str()); - } - auto& completing = call_state_.emplace(); - completing.server_trailing_metadata = std::move(*r); - completing.on_complete = - MakeMemberClosure(this); - completing.waker = Activity::current()->MakeOwningWaker(); - auto& op = completing.send_trailing_metadata; - memset(&op, 0, sizeof(op)); - op.payload = batch_payload(); - op.on_complete = &completing.on_complete; - // If we've gotten initial server metadata, we can send trailing - // metadata. - // Otherwise we need to cancel the call. - // There could be an unlucky ordering, so we poll here to make sure. - if (poll_send_initial_metadata()) { - op.send_trailing_metadata = true; - batch_payload()->send_trailing_metadata.send_trailing_metadata = - completing.server_trailing_metadata.get(); - batch_payload()->send_trailing_metadata.sent = &completing.sent; - } else { - op.cancel_stream = true; - const auto status_code = - completing.server_trailing_metadata->get(GrpcStatusMetadata()) - .value_or(GRPC_STATUS_UNKNOWN); - batch_payload()->cancel_stream.cancel_error = grpc_error_set_int( - absl::Status(static_cast(status_code), - completing.server_trailing_metadata - ->GetOrCreatePointer(GrpcMessageMetadata()) - ->as_string_view()), - StatusIntProperty::kRpcStatus, status_code); - } - SchedulePush(&op); - } - } - if (auto* p = absl::get_if(&call_state_)) { - set_finished(); - return std::move(p->result); - } - return Pending{}; - } - - private: - // Call state: we've asked the transport for initial metadata and are - // waiting for it before proceeding. - struct GettingInitialMetadata { - explicit GettingInitialMetadata(ServerStream* stream) - : recv_initial_metadata_ready( - MakeMemberClosure( - stream)) {} - // The batch we're using to get initial metadata. - grpc_transport_stream_op_batch recv_initial_metadata; - // Waker to re-enter the activity once the transport returns. - Waker recv_initial_metadata_ready_waker; - // Initial metadata storage for the transport. - ClientMetadataHandle client_initial_metadata = - GetContext()->MakePooled(GetContext()); - // Closure for the transport to call when it's ready. - grpc_closure recv_initial_metadata_ready; - // Next promise factory to use once we have initial metadata. - NextPromiseFactory next_promise_factory; - }; - - // Call state: transport has returned initial metadata, we're waiting to - // re-enter the activity to process it. - struct GotInitialMetadata { - ClientMetadataHandle client_initial_metadata; - NextPromiseFactory next_promise_factory; - }; - - // Call state: we're sending/receiving messages and processing the filter - // stack. - struct MessageLoop { - PipeReceiver* outgoing_messages; - ArenaPromise promise; - }; - - // Call state: promise stack has returned trailing metadata, we're sending it - // to the transport to communicate. - struct Completing { - ServerMetadataHandle server_trailing_metadata; - grpc_transport_stream_op_batch send_trailing_metadata; - grpc_closure on_complete; - bool sent = false; - Waker waker; - }; - - // Call state: server metadata has been communicated to the transport and sent - // to the client. - // The metadata will be returned down to the server call to tick the - // cancellation bit or not on the originating batch. - struct Complete { - ServerMetadataHandle result; - }; - - // Trailing metadata state: we've asked the transport for trailing metadata - // and are waiting for it before proceeding. - struct WaitingForTrailingMetadata { - ClientMetadataHandle result = - GetContext()->MakePooled(GetContext()); - grpc_transport_stream_op_batch recv_trailing_metadata; - grpc_closure recv_trailing_metadata_ready; - Waker waker; - }; - - // We've received trailing metadata from the transport - which indicates reads - // are closed. - // We convert to an absl::Status here and use that to drive a decision to - // cancel the call (on error) or not. - struct GotClientHalfClose { - absl::Status result; - }; - - void RecvInitialMetadataReady(absl::Status status) { - MutexLock lock(mu()); - auto& getting = absl::get(call_state_); - auto waker = std::move(getting.recv_initial_metadata_ready_waker); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "%sGOT INITIAL METADATA: err=%s %s", - waker.ActivityDebugTag().c_str(), status.ToString().c_str(), - getting.client_initial_metadata->DebugString().c_str()); - } - GotInitialMetadata got{std::move(getting.client_initial_metadata), - std::move(getting.next_promise_factory)}; - call_state_.emplace(std::move(got)); - waker.Wakeup(); - } - - void SendTrailingMetadataDone(absl::Status result) { - MutexLock lock(mu()); - auto& completing = absl::get(call_state_); - auto md = std::move(completing.server_trailing_metadata); - auto waker = std::move(completing.waker); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "%sSEND TRAILING METADATA DONE: err=%s sent=%s %s", - waker.ActivityDebugTag().c_str(), result.ToString().c_str(), - completing.sent ? "true" : "false", md->DebugString().c_str()); - } - md->Set(GrpcStatusFromWire(), completing.sent); - if (!result.ok()) { - md->Clear(); - md->Set(GrpcStatusMetadata(), - static_cast(result.code())); - md->Set(GrpcMessageMetadata(), Slice::FromCopiedString(result.message())); - md->Set(GrpcStatusFromWire(), false); - } - call_state_.emplace(Complete{std::move(md)}); - waker.Wakeup(); - } - - std::string ActiveOpsString() const override - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) { - std::vector ops; - ops.push_back(absl::StrCat( - "call_state:", - Match( - call_state_, - [](const absl::monostate&) { return "absl::monostate"; }, - [](const GettingInitialMetadata&) { return "GETTING"; }, - [](const GotInitialMetadata&) { return "GOT"; }, - [](const MessageLoop&) { return "RUNNING"; }, - [](const Completing&) { return "COMPLETING"; }, - [](const Complete&) { return "COMPLETE"; }))); - ops.push_back( - absl::StrCat("client_trailing_metadata_state:", - Match( - client_trailing_metadata_state_, - [](const absl::monostate&) -> std::string { - return "absl::monostate"; - }, - [](const WaitingForTrailingMetadata&) -> std::string { - return "WAITING"; - }, - [](const GotClientHalfClose& got) -> std::string { - return absl::StrCat("GOT:", got.result.ToString()); - }))); - // Send initial metadata - ops.push_back(absl::StrCat( - "server_initial_metadata_state:", - Match( - server_initial_metadata_, - [](const absl::monostate&) { return "absl::monostate"; }, - [](const PipeReceiverNextType&) { - return "WAITING"; - }, - [](const ServerMetadataHandle&) { return "GOT"; }))); - // Send message - std::string send_message_state = SendMessageString(); - if (send_message_state != "WAITING") { - ops.push_back(absl::StrCat("send_message:", send_message_state)); - } - // Receive message - std::string recv_message_state = RecvMessageString(); - if (recv_message_state != "IDLE") { - ops.push_back(absl::StrCat("recv_message:", recv_message_state)); - } - return absl::StrJoin(ops, " "); - } - - void SendInitialMetadataDone() {} - - void RecvTrailingMetadataReady(absl::Status error) { - MutexLock lock(mu()); - auto& state = - absl::get(client_trailing_metadata_state_); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%sRecvTrailingMetadataReady: error:%s metadata:%s state:%s", - state.waker.ActivityDebugTag().c_str(), error.ToString().c_str(), - state.result->DebugString().c_str(), ActiveOpsString().c_str()); - } - auto waker = std::move(state.waker); - ServerMetadataHandle result = std::move(state.result); - if (error.ok()) { - auto* message = result->get_pointer(GrpcMessageMetadata()); - error = absl::Status( - static_cast( - result->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN)), - message == nullptr ? "" : message->as_string_view()); - } - client_trailing_metadata_state_.emplace( - GotClientHalfClose{error}); - waker.Wakeup(); - } - - struct Pipes { +ArenaPromise MakeServerCallPromise( + grpc_transport* transport, CallArgs, + NextPromiseFactory next_promise_factory) { + OrphanablePtr stream( + GetContext()->New(transport)); + + stream->SetStream(static_cast( + GetContext()->Alloc(transport->vtable->sizeof_stream))); + grpc_transport_init_stream( + transport, stream->stream(), stream->stream_refcount(), + GetContext()->server_call_context()->server_stream_data(), + GetContext()); + grpc_transport_set_pops(transport, stream->stream(), + GetContext()->polling_entity()); + + auto* party = static_cast(Activity::current()); + + // Arifacts we need for the lifetime of the call. + struct CallData { Pipe server_to_client; Pipe client_to_server; Pipe server_initial_metadata; + Latch failure_latch; + bool sent_initial_metadata = false; + bool sent_trailing_metadata = false; }; - - using CallState = - absl::variant; - CallState call_state_ ABSL_GUARDED_BY(mu()) = absl::monostate{}; - using ClientTrailingMetadataState = - absl::variant; - ClientTrailingMetadataState client_trailing_metadata_state_ - ABSL_GUARDED_BY(mu()) = absl::monostate{}; - absl::variant, - ServerMetadataHandle> - ABSL_GUARDED_BY(mu()) server_initial_metadata_ = absl::monostate{}; - PipeSender* incoming_messages_ = nullptr; - grpc_transport_stream_op_batch send_initial_metadata_; - grpc_closure send_initial_metadata_done_ = - MakeMemberClosure( - this); - Pipes pipes_ ABSL_GUARDED_BY(mu()); -}; - -class ServerConnectedCallPromise { - public: - ServerConnectedCallPromise(grpc_transport* transport, - NextPromiseFactory next_promise_factory) - : impl_(GetContext()->New( - transport, std::move(next_promise_factory))) {} - - ServerConnectedCallPromise(const ServerConnectedCallPromise&) = delete; - ServerConnectedCallPromise& operator=(const ServerConnectedCallPromise&) = - delete; - ServerConnectedCallPromise(ServerConnectedCallPromise&& other) noexcept - : impl_(std::exchange(other.impl_, nullptr)) {} - ServerConnectedCallPromise& operator=( - ServerConnectedCallPromise&& other) noexcept { - impl_ = std::move(other.impl_); - return *this; - } - - static ArenaPromise Make(grpc_transport* transport, - CallArgs, - NextPromiseFactory next) { - return ServerConnectedCallPromise(transport, std::move(next)); - } - - Poll operator()() { return impl_->PollOnce(); } - - private: - OrphanablePtr impl_; -}; + auto* call_data = GetContext()->ManagedNew(); + + auto server_to_client_empty = + call_data->server_to_client.receiver.AwaitEmpty(); + + // Create a promise that will receive client initial metadata, and then run + // the main stem of the call (calling next_promise_factory up through the + // filters). + // Race the main call with failure_latch, allowing us to forcefully complete + // the call in the case of a failure. + auto recv_initial_metadata_then_run_promise = + TrySeq(GetContext()->ReceiveClientInitialMetadata( + stream->batch_target()), + [next_promise_factory = std::move(next_promise_factory), + server_to_client_empty = std::move(server_to_client_empty), + call_data](ClientMetadataHandle client_initial_metadata) { + auto call_promise = next_promise_factory(CallArgs{ + std::move(client_initial_metadata), + ClientInitialMetadataOutstandingToken::Empty(), + &call_data->server_initial_metadata.sender, + &call_data->client_to_server.receiver, + &call_data->server_to_client.sender, + }); + return Race(call_data->failure_latch.Wait(), + [call_promise = std::move(call_promise), + server_to_client_empty = + std::move(server_to_client_empty)]() mutable + -> Poll { + // TODO(ctiller): this is deeply weird and we need + // to clean this up. + // + // The following few lines check to ensure that + // there's no message currently pending in the + // outgoing message queue, and if (and only if) + // that's true decides to poll the main promise to + // see if there's a result. + // + // This essentially introduces a polling priority + // scheme that makes the current promise structure + // work out the way we want when talking to + // transports. + // + // The problem is that transports are going to need + // to replicate this structure when they convert to + // promises, and that becomes troubling as we'll be + // replicating weird throughout the stack. + // + // Instead we likely need to change the way we're + // composing promises through the stack. + // + // Proposed is to change filters from a promise + // that takes ClientInitialMetadata and returns + // ServerTrailingMetadata with three pipes for + // ServerInitialMetadata and + // ClientToServerMessages, ServerToClientMessages. + // Instead we'll have five pipes, moving + // ClientInitialMetadata and ServerTrailingMetadata + // to pipes that can be intercepted. + // + // The effect of this change will be to cripple the + // things that can be done in a filter (but cripple + // in line with what most filters actually do). + // We'll likely need to add a `CallContext::Cancel` + // to allow filters to cancel a request, but this + // would also have the advantage of centralizing + // our cancellation machinery which seems like an + // additional win - with the net effect that the + // shape of the call gets made explicit at the top + // & bottom of the stack. + // + // There's a small set of filters (retry, this one, + // lame client, clinet channel) that terminate + // stacks and need a richer set of semantics, but + // that ends up being fine because we can spawn + // tasks in parties to handle those edge cases, and + // keep the majority of filters simple: they just + // call InterceptAndMap on a handful of filters at + // call initialization time and then proceed to + // actually filter. + // + // So that's the plan, why isn't it enacted here? + // + // Well, the plan ends up being easy to implement + // in the promise based world (I did a prototype on + // a branch in an afternoon). It's heinous to + // implement in promise_based_filter, and that code + // is load bearing for us at the time of writing. + // It's not worth delaying promises for a further N + // months (N ~ 6) to make that change. + // + // Instead, we'll move forward with this, get + // promise_based_filter out of the picture, and + // then during the mop-up phase for promises tweak + // the compute structure to move to the magical + // five pipes (I'm reminded of an old Onion + // article), and end up in a good happy place. + if (server_to_client_empty().pending()) { + return Pending{}; + } + return call_promise(); + }); + }); + + // Promise factory that accepts a ServerMetadataHandle, and sends it as the + // trailing metadata for this call. + auto send_trailing_metadata = + [call_data, stream = stream->InternalRef()]( + ServerMetadataHandle server_trailing_metadata) { + return GetContext()->SendServerTrailingMetadata( + stream->batch_target(), std::move(server_trailing_metadata), + !std::exchange(call_data->sent_initial_metadata, true)); + }; + + // Runs the receive message loop, either until all the messages + // are received or the server call is complete. + party->Spawn( + "recv_messages", + Race( + Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }), + Map(stream->RecvMessages(&call_data->client_to_server.sender), + [failure_latch = &call_data->failure_latch](absl::Status status) { + if (!status.ok() && !failure_latch->is_set()) { + failure_latch->Set(ServerMetadataFromStatus(status)); + } + return status; + })), + [](absl::Status) {}); + + // Run a promise that will send initial metadata (if that pipe sends some). + // And then run the send message loop until that completes. + + auto send_initial_metadata = Seq( + Race(Map(stream->WaitFinished(), + [](Empty) { return NextResult(true); }), + call_data->server_initial_metadata.receiver.Next()), + [call_data, stream = stream->InternalRef()]( + NextResult next_result) mutable { + auto md = !call_data->sent_initial_metadata && next_result.has_value() + ? std::move(next_result.value()) + : nullptr; + if (md != nullptr) { + call_data->sent_initial_metadata = true; + auto* party = static_cast(Activity::current()); + party->Spawn("connected/send_initial_metadata", + GetContext()->SendServerInitialMetadata( + stream->batch_target(), std::move(md)), + [](absl::Status) {}); + return Immediate(absl::OkStatus()); + } + return Immediate(absl::CancelledError()); + }); + party->Spawn( + "send_initial_metadata_then_messages", + Race(Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }), + TrySeq(std::move(send_initial_metadata), + stream->SendMessages(&call_data->server_to_client.receiver))), + [](absl::Status) {}); + + // Spawn a job to fetch the "client trailing metadata" - if this is OK then + // it's client done, otherwise it's a signal of cancellation from the client + // which we'll use failure_latch to signal. + + party->Spawn( + "recv_trailing_metadata", + Seq(GetContext()->ReceiveClientTrailingMetadata( + stream->batch_target()), + [failure_latch = &call_data->failure_latch]( + absl::StatusOr status) mutable { + if (grpc_call_trace.enabled()) { + gpr_log( + GPR_DEBUG, + "%s[connected] Got trailing metadata; status=%s metadata=%s", + Activity::current()->DebugTag().c_str(), + status.status().ToString().c_str(), + status.ok() ? (*status)->DebugString().c_str() : ""); + } + ClientMetadataHandle trailing_metadata; + if (status.ok()) { + trailing_metadata = std::move(*status); + } else { + trailing_metadata = + GetContext()->MakePooled( + GetContext()); + grpc_status_code status_code = GRPC_STATUS_UNKNOWN; + std::string message; + grpc_error_get_status(status.status(), Timestamp::InfFuture(), + &status_code, &message, nullptr, nullptr); + trailing_metadata->Set(GrpcStatusMetadata(), status_code); + trailing_metadata->Set(GrpcMessageMetadata(), + Slice::FromCopiedString(message)); + } + if (trailing_metadata->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { + if (!failure_latch->is_set()) { + failure_latch->Set(std::move(trailing_metadata)); + } + } + return Empty{}; + }), + [](Empty) {}); + + // Finally assemble the main call promise: + // Receive initial metadata from the client and start the promise up the + // filter stack. + // Upon completion, send trailing metadata to the client and then return it + // (allowing the call code to decide on what signalling to give the + // application). + + return Map(Seq(std::move(recv_initial_metadata_then_run_promise), + std::move(send_trailing_metadata)), + [stream = std::move(stream)](ServerMetadataHandle md) { + stream->set_finished(); + return md; + }); +} #endif template (*make_call_promise)( grpc_transport*, CallArgs, NextPromiseFactory)> grpc_channel_filter MakeConnectedFilter() { // Create a vtable that contains both the legacy call methods (for filter - // stack based calls) and the new promise based method for creating promise - // based calls (the latter iff make_call_promise != nullptr). - // In this way the filter can be inserted into either kind of channel stack, - // and only if all the filters in the stack are promise based will the call - // be promise based. + // stack based calls) and the new promise based method for creating + // promise based calls (the latter iff make_call_promise != nullptr). In + // this way the filter can be inserted into either kind of channel stack, + // and only if all the filters in the stack are promise based will the + // call be promise based. auto make_call_wrapper = +[](grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory next) { grpc_transport* transport = @@ -1367,12 +804,11 @@ grpc_channel_filter MakeConnectedFilter() { sizeof(channel_data), connected_channel_init_channel_elem, +[](grpc_channel_stack* channel_stack, grpc_channel_element* elem) { - // HACK(ctiller): increase call stack size for the channel to make space - // for channel data. We need a cleaner (but performant) way to do this, - // and I'm not sure what that is yet. - // This is only "safe" because call stacks place no additional data - // after the last call element, and the last call element MUST be the - // connected channel. + // HACK(ctiller): increase call stack size for the channel to make + // space for channel data. We need a cleaner (but performant) way to + // do this, and I'm not sure what that is yet. This is only "safe" + // because call stacks place no additional data after the last call + // element, and the last call element MUST be the connected channel. channel_stack->call_stack_size += grpc_transport_stream_size( static_cast(elem->channel_data)->transport); }, @@ -1392,7 +828,7 @@ const grpc_channel_filter kPromiseBasedTransportFilter = #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL const grpc_channel_filter kClientEmulatedFilter = - MakeConnectedFilter(); + MakeConnectedFilter(); #else const grpc_channel_filter kClientEmulatedFilter = MakeConnectedFilter(); @@ -1400,7 +836,7 @@ const grpc_channel_filter kClientEmulatedFilter = #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL const grpc_channel_filter kServerEmulatedFilter = - MakeConnectedFilter(); + MakeConnectedFilter(); #else const grpc_channel_filter kServerEmulatedFilter = MakeConnectedFilter(); @@ -1416,20 +852,20 @@ bool grpc_add_connected_filter(grpc_core::ChannelStackBuilder* builder) { // We can't know promise based call or not here (that decision needs the // collaboration of all of the filters on the channel, and we don't want // ordering constraints on when we add filters). - // We can know if this results in a promise based call how we'll create our - // promise (if indeed we can), and so that is the choice made here. + // We can know if this results in a promise based call how we'll create + // our promise (if indeed we can), and so that is the choice made here. if (t->vtable->make_call_promise != nullptr) { - // Option 1, and our ideal: the transport supports promise based calls, and - // so we simply use the transport directly. + // Option 1, and our ideal: the transport supports promise based calls, + // and so we simply use the transport directly. builder->AppendFilter(&grpc_core::kPromiseBasedTransportFilter); } else if (grpc_channel_stack_type_is_client(builder->channel_stack_type())) { - // Option 2: the transport does not support promise based calls, but we're - // on the client and so we have an implementation that we can use to convert - // to batches. + // Option 2: the transport does not support promise based calls, but + // we're on the client and so we have an implementation that we can use + // to convert to batches. builder->AppendFilter(&grpc_core::kClientEmulatedFilter); } else { - // Option 3: the transport does not support promise based calls, and we're - // on the server so we use the server filter. + // Option 3: the transport does not support promise based calls, and + // we're on the server so we use the server filter. builder->AppendFilter(&grpc_core::kServerEmulatedFilter); } return true; diff --git a/src/core/lib/channel/promise_based_filter.cc b/src/core/lib/channel/promise_based_filter.cc index d44a9920216..9082022770d 100644 --- a/src/core/lib/channel/promise_based_filter.cc +++ b/src/core/lib/channel/promise_based_filter.cc @@ -16,6 +16,8 @@ #include "src/core/lib/channel/promise_based_filter.h" +#include + #include #include #include @@ -52,7 +54,7 @@ class FakeActivity final : public Activity { explicit FakeActivity(Activity* wake_activity) : wake_activity_(wake_activity) {} void Orphan() override {} - void ForceImmediateRepoll() override {} + void ForceImmediateRepoll(WakeupMask) override {} Waker MakeOwningWaker() override { return wake_activity_->MakeOwningWaker(); } Waker MakeNonOwningWaker() override { return wake_activity_->MakeNonOwningWaker(); @@ -136,20 +138,22 @@ Waker BaseCallData::MakeNonOwningWaker() { return MakeOwningWaker(); } Waker BaseCallData::MakeOwningWaker() { GRPC_CALL_STACK_REF(call_stack_, "waker"); - return Waker(this, nullptr); + return Waker(this, 0); } -void BaseCallData::Wakeup(void*) { +void BaseCallData::Wakeup(WakeupMask) { auto wakeup = [](void* p, grpc_error_handle) { auto* self = static_cast(p); self->OnWakeup(); - self->Drop(nullptr); + self->Drop(0); }; auto* closure = GRPC_CLOSURE_CREATE(wakeup, this, nullptr); GRPC_CALL_COMBINER_START(call_combiner_, closure, absl::OkStatus(), "wakeup"); } -void BaseCallData::Drop(void*) { GRPC_CALL_STACK_UNREF(call_stack_, "waker"); } +void BaseCallData::Drop(WakeupMask) { + GRPC_CALL_STACK_UNREF(call_stack_, "waker"); +} std::string BaseCallData::LogTag() const { return absl::StrCat( @@ -217,7 +221,7 @@ void BaseCallData::CapturedBatch::ResumeWith(Flusher* releaser) { // refcnt==0 ==> cancelled if (grpc_trace_channel.enabled()) { gpr_log(GPR_INFO, "%sRESUME BATCH REQUEST CANCELLED", - Activity::current()->DebugTag().c_str()); + releaser->call()->DebugTag().c_str()); } return; } @@ -241,6 +245,10 @@ void BaseCallData::CapturedBatch::CancelWith(grpc_error_handle error, auto* batch = std::exchange(batch_, nullptr); GPR_ASSERT(batch != nullptr); uintptr_t& refcnt = *RefCountField(batch); + gpr_log(GPR_DEBUG, "%sCancelWith: %p refs=%" PRIdPTR " err=%s [%s]", + releaser->call()->DebugTag().c_str(), batch, refcnt, + error.ToString().c_str(), + grpc_transport_stream_op_batch_string(batch, false).c_str()); if (refcnt == 0) { // refcnt==0 ==> cancelled if (grpc_trace_channel.enabled()) { @@ -331,6 +339,8 @@ const char* BaseCallData::SendMessage::StateString(State state) { return "CANCELLED"; case State::kCancelledButNotYetPolled: return "CANCELLED_BUT_NOT_YET_POLLED"; + case State::kCancelledButNoStatus: + return "CANCELLED_BUT_NO_STATUS"; } return "UNKNOWN"; } @@ -355,6 +365,7 @@ void BaseCallData::SendMessage::StartOp(CapturedBatch batch) { Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); case State::kCancelled: case State::kCancelledButNotYetPolled: + case State::kCancelledButNoStatus: return; } batch_ = batch; @@ -382,6 +393,7 @@ void BaseCallData::SendMessage::GotPipe(T* pipe_end) { case State::kForwardedBatch: case State::kBatchCompleted: case State::kPushedToPipe: + case State::kCancelledButNoStatus: Crash(absl::StrFormat("ILLEGAL STATE: %s", StateString(state_))); case State::kCancelled: case State::kCancelledButNotYetPolled: @@ -397,6 +409,7 @@ bool BaseCallData::SendMessage::IsIdle() const { case State::kForwardedBatch: case State::kCancelled: case State::kCancelledButNotYetPolled: + case State::kCancelledButNoStatus: return true; case State::kGotBatchNoPipe: case State::kGotBatch: @@ -425,6 +438,7 @@ void BaseCallData::SendMessage::OnComplete(absl::Status status) { break; case State::kCancelled: case State::kCancelledButNotYetPolled: + case State::kCancelledButNoStatus: flusher.AddClosure(intercepted_on_complete_, status, "forward after cancel"); break; @@ -449,10 +463,14 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, case State::kCancelledButNotYetPolled: break; case State::kInitial: + state_ = State::kCancelled; + break; case State::kIdle: case State::kForwardedBatch: state_ = State::kCancelledButNotYetPolled; + if (base_->is_current()) base_->ForceImmediateRepoll(); break; + case State::kCancelledButNoStatus: case State::kGotBatchNoPipe: case State::kGotBatch: { std::string temp; @@ -471,6 +489,7 @@ void BaseCallData::SendMessage::Done(const ServerMetadata& metadata, push_.reset(); next_.reset(); state_ = State::kCancelledButNotYetPolled; + if (base_->is_current()) base_->ForceImmediateRepoll(); break; } } @@ -489,6 +508,7 @@ void BaseCallData::SendMessage::WakeInsideCombiner(Flusher* flusher, case State::kIdle: case State::kGotBatchNoPipe: case State::kCancelled: + case State::kCancelledButNoStatus: break; case State::kCancelledButNotYetPolled: interceptor()->Push()->Close(); @@ -530,13 +550,18 @@ void BaseCallData::SendMessage::WakeInsideCombiner(Flusher* flusher, "result.has_value=%s", base_->LogTag().c_str(), p->has_value() ? "true" : "false"); } - GPR_ASSERT(p->has_value()); - batch_->payload->send_message.send_message->Swap((**p)->payload()); - batch_->payload->send_message.flags = (**p)->flags(); - state_ = State::kForwardedBatch; - batch_.ResumeWith(flusher); - next_.reset(); - if ((*push_)().ready()) push_.reset(); + if (p->has_value()) { + batch_->payload->send_message.send_message->Swap((**p)->payload()); + batch_->payload->send_message.flags = (**p)->flags(); + state_ = State::kForwardedBatch; + batch_.ResumeWith(flusher); + next_.reset(); + if ((*push_)().ready()) push_.reset(); + } else { + state_ = State::kCancelledButNoStatus; + next_.reset(); + push_.reset(); + } } } break; case State::kForwardedBatch: @@ -1094,11 +1119,14 @@ class ClientCallData::PollContext { // Poll the promise once since we're waiting for it. Poll poll = self_->promise_(); if (grpc_trace_channel.enabled()) { - gpr_log(GPR_INFO, "%s ClientCallData.PollContext.Run: poll=%s", + gpr_log(GPR_INFO, "%s ClientCallData.PollContext.Run: poll=%s; %s", self_->LogTag().c_str(), - PollToString(poll, [](const ServerMetadataHandle& h) { - return h->DebugString(); - }).c_str()); + PollToString(poll, + [](const ServerMetadataHandle& h) { + return h->DebugString(); + }) + .c_str(), + self_->DebugString().c_str()); } if (auto* r = poll.value_if_ready()) { auto md = std::move(*r); @@ -1278,7 +1306,11 @@ ClientCallData::ClientCallData(grpc_call_element* elem, [args]() { return args->arena->New(args->arena); }, - [args]() { return args->arena->New(args->arena); }) { + [args]() { return args->arena->New(args->arena); }), + initial_metadata_outstanding_token_( + (flags & kFilterIsLast) != 0 + ? ClientInitialMetadataOutstandingToken::New(arena()) + : ClientInitialMetadataOutstandingToken::Empty()) { GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_, RecvTrailingMetadataReadyCallback, this, grpc_schedule_on_exec_ctx); @@ -1294,8 +1326,12 @@ ClientCallData::~ClientCallData() { } } +std::string ClientCallData::DebugTag() const { + return absl::StrFormat("PBF_CLIENT[%p]: [%s] ", this, elem()->filter->name); +} + // Activity implementation. -void ClientCallData::ForceImmediateRepoll() { +void ClientCallData::ForceImmediateRepoll(WakeupMask) { GPR_ASSERT(poll_ctx_ != nullptr); poll_ctx_->Repoll(); } @@ -1547,6 +1583,7 @@ void ClientCallData::StartPromise(Flusher* flusher) { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(send_initial_metadata_batch_->payload ->send_initial_metadata.send_initial_metadata), + std::move(initial_metadata_outstanding_token_), server_initial_metadata_pipe() == nullptr ? nullptr : &server_initial_metadata_pipe()->sender, @@ -1654,8 +1691,7 @@ ArenaPromise ClientCallData::MakeNextPromise( GPR_ASSERT(poll_ctx_ != nullptr); GPR_ASSERT(send_initial_state_ == SendInitialState::kQueued); send_initial_metadata_batch_->payload->send_initial_metadata - .send_initial_metadata = - UnwrapMetadata(std::move(call_args.client_initial_metadata)); + .send_initial_metadata = call_args.client_initial_metadata.get(); if (recv_initial_metadata_ != nullptr) { // Call args should contain a latch for receiving initial metadata. // It might be the one we passed in - in which case we know this filter @@ -1867,8 +1903,15 @@ struct ServerCallData::SendInitialMetadata { class ServerCallData::PollContext { public: - explicit PollContext(ServerCallData* self, Flusher* flusher) - : self_(self), flusher_(flusher) { + explicit PollContext(ServerCallData* self, Flusher* flusher, + DebugLocation created = DebugLocation()) + : self_(self), flusher_(flusher), created_(created) { + if (self_->poll_ctx_ != nullptr) { + Crash(absl::StrCat( + "PollContext: disallowed recursion. New: ", created_.file(), ":", + created_.line(), "; Old: ", self_->poll_ctx_->created_.file(), ":", + self_->poll_ctx_->created_.line())); + } GPR_ASSERT(self_->poll_ctx_ == nullptr); self_->poll_ctx_ = this; scoped_activity_.Init(self_); @@ -1914,6 +1957,7 @@ class ServerCallData::PollContext { Flusher* const flusher_; bool repoll_ = false; bool have_scoped_activity_; + GPR_NO_UNIQUE_ADDRESS DebugLocation created_; }; const char* ServerCallData::StateString(RecvInitialState state) { @@ -1973,11 +2017,18 @@ ServerCallData::~ServerCallData() { gpr_log(GPR_INFO, "%s ~ServerCallData %s", LogTag().c_str(), DebugString().c_str()); } + if (send_initial_metadata_ != nullptr) { + send_initial_metadata_->~SendInitialMetadata(); + } GPR_ASSERT(poll_ctx_ == nullptr); } +std::string ServerCallData::DebugTag() const { + return absl::StrFormat("PBF_SERVER[%p]: [%s] ", this, elem()->filter->name); +} + // Activity implementation. -void ServerCallData::ForceImmediateRepoll() { +void ServerCallData::ForceImmediateRepoll(WakeupMask) { GPR_ASSERT(poll_ctx_ != nullptr); poll_ctx_->Repoll(); } @@ -2083,7 +2134,10 @@ void ServerCallData::StartBatch(grpc_transport_stream_op_batch* b) { switch (send_trailing_state_) { case SendTrailingState::kInitial: send_trailing_metadata_batch_ = batch; - if (receive_message() != nullptr) { + if (receive_message() != nullptr && + batch->payload->send_trailing_metadata.send_trailing_metadata + ->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { receive_message()->Done( *batch->payload->send_trailing_metadata.send_trailing_metadata, &flusher); @@ -2140,9 +2194,12 @@ void ServerCallData::Completed(grpc_error_handle error, Flusher* flusher) { case SendTrailingState::kForwarded: send_trailing_state_ = SendTrailingState::kCancelled; if (!error.ok()) { + call_stack()->IncrementRefCount(); auto* batch = grpc_make_transport_stream_op( - NewClosure([call_combiner = call_combiner()](absl::Status) { + NewClosure([call_combiner = call_combiner(), + call_stack = call_stack()](absl::Status) { GRPC_CALL_COMBINER_STOP(call_combiner, "done-cancel"); + call_stack->Unref(); })); batch->cancel_stream = true; batch->payload->cancel_stream.cancel_error = error; @@ -2194,7 +2251,7 @@ void ServerCallData::Completed(grpc_error_handle error, Flusher* flusher) { ArenaPromise ServerCallData::MakeNextPromise( CallArgs call_args) { GPR_ASSERT(recv_initial_state_ == RecvInitialState::kComplete); - GPR_ASSERT(UnwrapMetadata(std::move(call_args.client_initial_metadata)) == + GPR_ASSERT(std::move(call_args.client_initial_metadata).get() == recv_initial_metadata_); forward_recv_initial_metadata_callback_ = true; if (send_initial_metadata_ != nullptr) { @@ -2316,6 +2373,7 @@ void ServerCallData::RecvInitialMetadataReady(grpc_error_handle error) { FakeActivity(this).Run([this, filter] { promise_ = filter->MakeCallPromise( CallArgs{WrapMetadata(recv_initial_metadata_), + ClientInitialMetadataOutstandingToken::Empty(), server_initial_metadata_pipe() == nullptr ? nullptr : &server_initial_metadata_pipe()->sender, @@ -2416,9 +2474,14 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { (send_trailing_metadata_batch_->send_message && send_message()->IsForwarded()))) { send_trailing_state_ = SendTrailingState::kQueued; - send_message()->Done(*send_trailing_metadata_batch_->payload - ->send_trailing_metadata.send_trailing_metadata, - flusher); + if (send_trailing_metadata_batch_->payload->send_trailing_metadata + .send_trailing_metadata->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) { + send_message()->Done( + *send_trailing_metadata_batch_->payload->send_trailing_metadata + .send_trailing_metadata, + flusher); + } } } if (receive_message() != nullptr) { @@ -2469,8 +2532,7 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { } if (auto* r = poll.value_if_ready()) { promise_ = ArenaPromise(); - auto* md = UnwrapMetadata(std::move(*r)); - bool destroy_md = true; + auto md = std::move(*r); if (send_message() != nullptr) { send_message()->Done(*md, flusher); } @@ -2482,11 +2544,9 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { case SendTrailingState::kQueuedButHaventClosedSends: case SendTrailingState::kQueued: { if (send_trailing_metadata_batch_->payload->send_trailing_metadata - .send_trailing_metadata != md) { + .send_trailing_metadata != md.get()) { *send_trailing_metadata_batch_->payload->send_trailing_metadata .send_trailing_metadata = std::move(*md); - } else { - destroy_md = false; } send_trailing_metadata_batch_.ResumeWith(flusher); send_trailing_state_ = SendTrailingState::kForwarded; @@ -2504,9 +2564,6 @@ void ServerCallData::WakeInsideCombiner(Flusher* flusher) { // Nothing to do. break; } - if (destroy_md) { - md->~grpc_metadata_batch(); - } } } if (std::exchange(forward_recv_initial_metadata_callback_, false)) { diff --git a/src/core/lib/channel/promise_based_filter.h b/src/core/lib/channel/promise_based_filter.h index e5ec0b2aed8..78d413d15e1 100644 --- a/src/core/lib/channel/promise_based_filter.h +++ b/src/core/lib/channel/promise_based_filter.h @@ -184,7 +184,7 @@ class BaseCallData : public Activity, private Wakeable { Waker MakeNonOwningWaker() final; Waker MakeOwningWaker() final; - std::string ActivityDebugTag(void*) const override { return DebugTag(); } + std::string ActivityDebugTag(WakeupMask) const override { return DebugTag(); } void Finalize(const grpc_call_final_info* final_info) { finalization_.Run(final_info); @@ -222,7 +222,11 @@ class BaseCallData : public Activity, private Wakeable { void Resume(grpc_transport_stream_op_batch* batch) { GPR_ASSERT(!call_->is_last()); - release_.push_back(batch); + if (batch->HasOp()) { + release_.push_back(batch); + } else if (batch->on_complete != nullptr) { + Complete(batch); + } } void Cancel(grpc_transport_stream_op_batch* batch, @@ -241,6 +245,8 @@ class BaseCallData : public Activity, private Wakeable { call_closures_.Add(closure, error, reason); } + BaseCallData* call() const { return call_; } + private: absl::InlinedVector release_; CallCombinerClosureList call_closures_; @@ -284,11 +290,6 @@ class BaseCallData : public Activity, private Wakeable { Arena::PooledDeleter(nullptr)); } - static grpc_metadata_batch* UnwrapMetadata( - Arena::PoolPtr p) { - return p.release(); - } - class ReceiveInterceptor final : public Interceptor { public: explicit ReceiveInterceptor(Arena* arena) : pipe_{arena} {} @@ -402,6 +403,8 @@ class BaseCallData : public Activity, private Wakeable { kCancelledButNotYetPolled, // We're done. kCancelled, + // We're done, but we haven't gotten a status yet + kCancelledButNoStatus, }; static const char* StateString(State); @@ -542,8 +545,8 @@ class BaseCallData : public Activity, private Wakeable { private: // Wakeable implementation. - void Wakeup(void*) final; - void Drop(void*) final; + void Wakeup(WakeupMask) final; + void Drop(WakeupMask) final; virtual void OnWakeup() = 0; @@ -569,10 +572,12 @@ class ClientCallData : public BaseCallData { ~ClientCallData() override; // Activity implementation. - void ForceImmediateRepoll() final; + void ForceImmediateRepoll(WakeupMask) final; // Handle one grpc_transport_stream_op_batch void StartBatch(grpc_transport_stream_op_batch* batch) override; + std::string DebugTag() const override; + private: // At what stage is our handling of send initial metadata? enum class SendInitialState { @@ -669,6 +674,8 @@ class ClientCallData : public BaseCallData { RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial; // Polling related data. Non-null if we're actively polling PollContext* poll_ctx_ = nullptr; + // Initial metadata outstanding token + ClientInitialMetadataOutstandingToken initial_metadata_outstanding_token_; }; class ServerCallData : public BaseCallData { @@ -678,10 +685,12 @@ class ServerCallData : public BaseCallData { ~ServerCallData() override; // Activity implementation. - void ForceImmediateRepoll() final; + void ForceImmediateRepoll(WakeupMask) final; // Handle one grpc_transport_stream_op_batch void StartBatch(grpc_transport_stream_op_batch* batch) override; + std::string DebugTag() const override; + protected: absl::string_view ClientOrServerString() const override { return "SVR"; } diff --git a/src/core/lib/gprpp/orphanable.h b/src/core/lib/gprpp/orphanable.h index b9b291317d1..a2f24cb54c2 100644 --- a/src/core/lib/gprpp/orphanable.h +++ b/src/core/lib/gprpp/orphanable.h @@ -69,7 +69,7 @@ inline OrphanablePtr MakeOrphanable(Args&&... args) { } // A type of Orphanable with internal ref-counting. -template +template class InternallyRefCounted : public Orphanable { public: // Not copyable nor movable. @@ -99,12 +99,12 @@ class InternallyRefCounted : public Orphanable { void Unref() { if (GPR_UNLIKELY(refs_.Unref())) { - internal::Delete(static_cast(this)); + unref_behavior_(static_cast(this)); } } void Unref(const DebugLocation& location, const char* reason) { if (GPR_UNLIKELY(refs_.Unref(location, reason))) { - internal::Delete(static_cast(this)); + unref_behavior_(static_cast(this)); } } @@ -115,6 +115,7 @@ class InternallyRefCounted : public Orphanable { } RefCount refs_; + GPR_NO_UNIQUE_ADDRESS UnrefBehavior unref_behavior_; }; } // namespace grpc_core diff --git a/src/core/lib/gprpp/ref_counted.h b/src/core/lib/gprpp/ref_counted.h index 066791929b0..96fe288ff6f 100644 --- a/src/core/lib/gprpp/ref_counted.h +++ b/src/core/lib/gprpp/ref_counted.h @@ -213,41 +213,34 @@ class NonPolymorphicRefCount { }; // Behavior of RefCounted<> upon ref count reaching 0. -enum UnrefBehavior { - // Default behavior: Delete the object. - kUnrefDelete, - // Do not delete the object upon unref. This is useful in cases where all - // existing objects must be tracked in a registry but the object's entry in - // the registry cannot be removed from the object's dtor due to - // synchronization issues. In this case, the registry can be cleaned up - // later by identifying entries for which RefIfNonZero() returns null. - kUnrefNoDelete, - // Call the object's dtor but do not delete it. This is useful for cases - // where the object is stored in memory allocated elsewhere (e.g., the call - // arena). - kUnrefCallDtor, -}; - -namespace internal { -template -class Delete; -template -class Delete { - public: - explicit Delete(T* t) { delete t; } +// Default behavior: Delete the object. +struct UnrefDelete { + template + void operator()(T* p) { + delete p; + } }; -template -class Delete { - public: - explicit Delete(T* /*t*/) {} + +// Do not delete the object upon unref. This is useful in cases where all +// existing objects must be tracked in a registry but the object's entry in +// the registry cannot be removed from the object's dtor due to +// synchronization issues. In this case, the registry can be cleaned up +// later by identifying entries for which RefIfNonZero() returns null. +struct UnrefNoDelete { + template + void operator()(T* /*p*/) {} }; -template -class Delete { - public: - explicit Delete(T* t) { t->~T(); } + +// Call the object's dtor but do not delete it. This is useful for cases +// where the object is stored in memory allocated elsewhere (e.g., the call +// arena). +struct UnrefCallDtor { + template + void operator()(T* p) { + p->~T(); + } }; -} // namespace internal // A base class for reference-counted objects. // New objects should be created via new and start with a refcount of 1. @@ -276,7 +269,7 @@ class Delete { // ch->Unref(); // template + typename UnrefBehavior = UnrefDelete> class RefCounted : public Impl { public: using RefCountedChildType = Child; @@ -301,12 +294,12 @@ class RefCounted : public Impl { // friend of this class. void Unref() { if (GPR_UNLIKELY(refs_.Unref())) { - internal::Delete(static_cast(this)); + unref_behavior_(static_cast(this)); } } void Unref(const DebugLocation& location, const char* reason) { if (GPR_UNLIKELY(refs_.Unref(location, reason))) { - internal::Delete(static_cast(this)); + unref_behavior_(static_cast(this)); } } @@ -331,6 +324,11 @@ class RefCounted : public Impl { intptr_t initial_refcount = 1) : refs_(initial_refcount, trace) {} + // Note: Tracing is a no-op on non-debug builds. + explicit RefCounted(UnrefBehavior b, const char* trace = nullptr, + intptr_t initial_refcount = 1) + : refs_(initial_refcount, trace), unref_behavior_(b) {} + private: // Allow RefCountedPtr<> to access IncrementRefCount(). template @@ -342,6 +340,7 @@ class RefCounted : public Impl { } RefCount refs_; + GPR_NO_UNIQUE_ADDRESS UnrefBehavior unref_behavior_; }; } // namespace grpc_core diff --git a/src/core/lib/gprpp/thd.h b/src/core/lib/gprpp/thd.h index 16a9188d793..a2d9101bce3 100644 --- a/src/core/lib/gprpp/thd.h +++ b/src/core/lib/gprpp/thd.h @@ -25,6 +25,11 @@ #include +#include +#include + +#include "absl/functional/any_invocable.h" + #include namespace grpc_core { @@ -86,6 +91,17 @@ class Thread { Thread(const char* thd_name, void (*thd_body)(void* arg), void* arg, bool* success = nullptr, const Options& options = Options()); + Thread(const char* thd_name, absl::AnyInvocable fn, + bool* success = nullptr, const Options& options = Options()) + : Thread( + thd_name, + [](void* p) { + std::unique_ptr> fn_from_p( + static_cast*>(p)); + (*fn_from_p)(); + }, + new absl::AnyInvocable(std::move(fn)), success, options) {} + /// Move constructor for thread. After this is called, the other thread /// no longer represents a living thread object Thread(Thread&& other) noexcept diff --git a/src/core/lib/iomgr/call_combiner.h b/src/core/lib/iomgr/call_combiner.h index 50aeb63c780..e314479413b 100644 --- a/src/core/lib/iomgr/call_combiner.h +++ b/src/core/lib/iomgr/call_combiner.h @@ -171,8 +171,8 @@ class CallCombinerClosureList { if (GRPC_TRACE_FLAG_ENABLED(grpc_call_combiner_trace)) { gpr_log(GPR_INFO, "CallCombinerClosureList executing closure while already " - "holding call_combiner %p: closure=%p error=%s reason=%s", - call_combiner, closures_[0].closure, + "holding call_combiner %p: closure=%s error=%s reason=%s", + call_combiner, closures_[0].closure->DebugString().c_str(), StatusToString(closures_[0].error).c_str(), closures_[0].reason); } // This will release the call combiner. diff --git a/src/core/lib/promise/activity.cc b/src/core/lib/promise/activity.cc index da009f940c0..b982b8400c9 100644 --- a/src/core/lib/promise/activity.cc +++ b/src/core/lib/promise/activity.cc @@ -19,8 +19,11 @@ #include #include +#include +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "src/core/lib/gprpp/atomic_utils.h" @@ -36,7 +39,9 @@ namespace promise_detail { /////////////////////////////////////////////////////////////////////////////// // HELPER TYPES -std::string Unwakeable::ActivityDebugTag(void*) const { return ""; } +std::string Unwakeable::ActivityDebugTag(WakeupMask) const { + return ""; +} // Weak handle to an Activity. // Handle can persist while Activity goes away. @@ -58,7 +63,7 @@ class FreestandingActivity::Handle final : public Wakeable { // Activity needs to wake up (if it still exists!) - wake it up, and drop the // ref that was kept for this handle. - void Wakeup(void*) override ABSL_LOCKS_EXCLUDED(mu_) { + void Wakeup(WakeupMask) override ABSL_LOCKS_EXCLUDED(mu_) { mu_.Lock(); // Note that activity refcount can drop to zero, but we could win the lock // against DropActivity, so we need to only increase activities refcount if @@ -68,7 +73,7 @@ class FreestandingActivity::Handle final : public Wakeable { mu_.Unlock(); // Activity still exists and we have a reference: wake it up, which will // drop the ref. - activity->Wakeup(nullptr); + activity->Wakeup(0); } else { // Could not get the activity - it's either gone or going. No need to wake // it up! @@ -78,9 +83,9 @@ class FreestandingActivity::Handle final : public Wakeable { Unref(); } - void Drop(void*) override { Unref(); } + void Drop(WakeupMask) override { Unref(); } - std::string ActivityDebugTag(void*) const override { + std::string ActivityDebugTag(WakeupMask) const override { MutexLock lock(&mu_); return activity_ == nullptr ? "" : activity_->DebugTag(); } @@ -124,7 +129,7 @@ void FreestandingActivity::DropHandle() { Waker FreestandingActivity::MakeNonOwningWaker() { mu_.AssertHeld(); - return Waker(RefHandle(), nullptr); + return Waker(RefHandle(), 0); } } // namespace promise_detail @@ -133,4 +138,15 @@ std::string Activity::DebugTag() const { return absl::StrFormat("ACTIVITY[%p]", this); } +/////////////////////////////////////////////////////////////////////////////// +// INTRA ACTIVITY WAKER IMPLEMENTATION + +std::string IntraActivityWaiter::DebugString() const { + std::vector bits; + for (size_t i = 0; i < 8 * sizeof(WakeupMask); i++) { + if (wakeups_ & (1 << i)) bits.push_back(i); + } + return absl::StrCat("{", absl::StrJoin(bits, ","), "}"); +} + } // namespace grpc_core diff --git a/src/core/lib/promise/activity.h b/src/core/lib/promise/activity.h index 67933de829a..8e198f0c118 100644 --- a/src/core/lib/promise/activity.h +++ b/src/core/lib/promise/activity.h @@ -38,24 +38,29 @@ #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/promise_factory.h" #include "src/core/lib/promise/detail/status.h" +#include "src/core/lib/promise/poll.h" namespace grpc_core { class Activity; +// WakeupMask is a bitfield representing which parts of an activity should be +// woken up. +using WakeupMask = uint16_t; + // A Wakeable object is used by queues to wake activities. class Wakeable { public: // Wake up the underlying activity. // After calling, this Wakeable cannot be used again. - // arg comes from the Waker object and allows one Wakeable instance to be used - // for multiple disjoint subparts of an Activity. - virtual void Wakeup(void* arg) = 0; + // WakeupMask comes from the activity that created this Wakeable and specifies + // the set of promises that should be awoken. + virtual void Wakeup(WakeupMask wakeup_mask) = 0; // Drop this wakeable without waking up the underlying activity. - virtual void Drop(void* arg) = 0; + virtual void Drop(WakeupMask wakeup_mask) = 0; // Return the underlying activity debug tag, or "" if not available. - virtual std::string ActivityDebugTag(void* arg) const = 0; + virtual std::string ActivityDebugTag(WakeupMask wakeup_mask) const = 0; protected: inline ~Wakeable() {} @@ -63,9 +68,9 @@ class Wakeable { namespace promise_detail { struct Unwakeable final : public Wakeable { - void Wakeup(void*) override {} - void Drop(void*) override {} - std::string ActivityDebugTag(void*) const override; + void Wakeup(WakeupMask) override {} + void Drop(WakeupMask) override {} + std::string ActivityDebugTag(WakeupMask) const override; }; static Unwakeable* unwakeable() { return NoDestructSingleton::Get(); @@ -76,8 +81,9 @@ static Unwakeable* unwakeable() { // This type is non-copyable but movable. class Waker { public: - Waker(Wakeable* wakeable, void* arg) : wakeable_and_arg_{wakeable, arg} {} - Waker() : Waker(promise_detail::unwakeable(), nullptr) {} + Waker(Wakeable* wakeable, WakeupMask wakeup_mask) + : wakeable_and_arg_{wakeable, wakeup_mask} {} + Waker() : Waker(promise_detail::unwakeable(), 0) {} ~Waker() { wakeable_and_arg_.Drop(); } Waker(const Waker&) = delete; Waker& operator=(const Waker&) = delete; @@ -93,7 +99,7 @@ class Waker { template friend H AbslHashValue(H h, const Waker& w) { return H::combine(H::combine(std::move(h), w.wakeable_and_arg_.wakeable), - w.wakeable_and_arg_.arg); + w.wakeable_and_arg_.wakeup_mask); } bool operator==(const Waker& other) const noexcept { @@ -116,27 +122,42 @@ class Waker { private: struct WakeableAndArg { Wakeable* wakeable; - void* arg; + WakeupMask wakeup_mask; - void Wakeup() { wakeable->Wakeup(arg); } - void Drop() { wakeable->Drop(arg); } + void Wakeup() { wakeable->Wakeup(wakeup_mask); } + void Drop() { wakeable->Drop(wakeup_mask); } std::string ActivityDebugTag() const { return wakeable == nullptr ? "" - : wakeable->ActivityDebugTag(arg); + : wakeable->ActivityDebugTag(wakeup_mask); } bool operator==(const WakeableAndArg& other) const noexcept { - return wakeable == other.wakeable && arg == other.arg; + return wakeable == other.wakeable && wakeup_mask == other.wakeup_mask; } }; WakeableAndArg Take() { - return std::exchange(wakeable_and_arg_, - {promise_detail::unwakeable(), nullptr}); + return std::exchange(wakeable_and_arg_, {promise_detail::unwakeable(), 0}); } WakeableAndArg wakeable_and_arg_; }; +// Helper type to track wakeups between objects in the same activity. +// Can be fairly fast as no ref counting or locking needs to occur. +class IntraActivityWaiter { + public: + // Register for wakeup, return Pending(). If state is not ready to proceed, + // Promises should bottom out here. + Pending pending(); + // Wake the activity + void Wake(); + + std::string DebugString() const; + + private: + WakeupMask wakeups_ = 0; +}; + // An Activity tracks execution of a single promise. // It executes the promise under a mutex. // When the promise stalls, it registers the containing activity to be woken up @@ -156,7 +177,13 @@ class Activity : public Orphanable { void ForceWakeup() { MakeOwningWaker().Wakeup(); } // Force the current activity to immediately repoll if it doesn't complete. - virtual void ForceImmediateRepoll() = 0; + virtual void ForceImmediateRepoll(WakeupMask mask) = 0; + // Legacy version of ForceImmediateRepoll() that uses the current participant. + // Will go away once Party gets merged with Activity. New usage is banned. + void ForceImmediateRepoll() { ForceImmediateRepoll(CurrentParticipant()); } + + // Return the current part of the activity as a bitmask + virtual WakeupMask CurrentParticipant() const { return 1; } // Return the current activity. // Additionally: @@ -284,7 +311,7 @@ class FreestandingActivity : public Activity, private Wakeable { public: Waker MakeOwningWaker() final { Ref(); - return Waker(this, nullptr); + return Waker(this, 0); } Waker MakeNonOwningWaker() final; @@ -293,7 +320,7 @@ class FreestandingActivity : public Activity, private Wakeable { Unref(); } - void ForceImmediateRepoll() final { + void ForceImmediateRepoll(WakeupMask) final { mu_.AssertHeld(); SetActionDuringRun(ActionDuringRun::kWakeup); } @@ -333,7 +360,7 @@ class FreestandingActivity : public Activity, private Wakeable { Mutex* mu() ABSL_LOCK_RETURNED(mu_) { return &mu_; } - std::string ActivityDebugTag(void*) const override { return DebugTag(); } + std::string ActivityDebugTag(WakeupMask) const override { return DebugTag(); } private: class Handle; @@ -467,7 +494,7 @@ class PromiseActivity final // the activity to an external threadpool to run. If the activity is already // running on this thread, a note is taken of such and the activity is // repolled if it doesn't complete. - void Wakeup(void*) final { + void Wakeup(WakeupMask) final { // If there is an active activity, but hey it's us, flag that and we'll loop // in RunLoop (that's calling from above here!). if (Activity::is_current()) { @@ -486,7 +513,7 @@ class PromiseActivity final } // Drop a wakeup - void Drop(void*) final { this->WakeupComplete(); } + void Drop(WakeupMask) final { this->WakeupComplete(); } // Notification that we're no longer executing - it's ok to destruct the // promise. @@ -593,6 +620,16 @@ ActivityPtr MakeActivity(Factory promise_factory, std::move(on_done), std::forward(contexts)...)); } +inline Pending IntraActivityWaiter::pending() { + wakeups_ |= Activity::current()->CurrentParticipant(); + return Pending(); +} + +inline void IntraActivityWaiter::Wake() { + if (wakeups_ == 0) return; + Activity::current()->ForceImmediateRepoll(std::exchange(wakeups_, 0)); +} + } // namespace grpc_core #endif // GRPC_SRC_CORE_LIB_PROMISE_ACTIVITY_H diff --git a/src/core/lib/promise/detail/promise_factory.h b/src/core/lib/promise/detail/promise_factory.h index 12b291a545b..adca5af7f08 100644 --- a/src/core/lib/promise/detail/promise_factory.h +++ b/src/core/lib/promise/detail/promise_factory.h @@ -17,6 +17,7 @@ #include +#include #include #include @@ -106,6 +107,9 @@ class Curried { private: GPR_NO_UNIQUE_ADDRESS F f_; GPR_NO_UNIQUE_ADDRESS Arg arg_; +#ifndef NDEBUG + std::unique_ptr asan_canary_ = std::make_unique(0); +#endif }; // Promote a callable(A) -> T | Poll to a PromiseFactory(A) -> Promise by diff --git a/src/core/lib/promise/if.h b/src/core/lib/promise/if.h index 74139715a8c..22956a2f72d 100644 --- a/src/core/lib/promise/if.h +++ b/src/core/lib/promise/if.h @@ -17,7 +17,9 @@ #include +#include #include +#include #include "absl/status/statusor.h" #include "absl/types/variant.h" @@ -162,6 +164,9 @@ class If { } Poll operator()() { +#ifndef NDEBUG + asan_canary_ = std::make_unique(1 + *asan_canary_); +#endif if (condition_) { return if_true_(); } else { @@ -175,6 +180,10 @@ class If { TruePromise if_true_; FalsePromise if_false_; }; + // Make failure to destruct show up in ASAN builds. +#ifndef NDEBUG + std::unique_ptr asan_canary_ = std::make_unique(0); +#endif }; } // namespace promise_detail diff --git a/src/core/lib/promise/interceptor_list.h b/src/core/lib/promise/interceptor_list.h index 546b46d074e..1e460b9c145 100644 --- a/src/core/lib/promise/interceptor_list.h +++ b/src/core/lib/promise/interceptor_list.h @@ -89,6 +89,10 @@ class InterceptorList { public: RunPromise(size_t memory_required, Map* factory, absl::optional value) { if (!value.has_value() || factory == nullptr) { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, + "InterceptorList::RunPromise[%p]: create immediate", this); + } is_immediately_resolved_ = true; Construct(&result_, std::move(value)); } else { @@ -96,10 +100,18 @@ class InterceptorList { Construct(&async_resolution_, memory_required); factory->MakePromise(std::move(*value), async_resolution_.space.get()); async_resolution_.current_factory = factory; + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, + "InterceptorList::RunPromise[%p]: create async; mem=%p", this, + async_resolution_.space.get()); + } } } ~RunPromise() { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, "InterceptorList::RunPromise[%p]: destroy", this); + } if (is_immediately_resolved_) { Destruct(&result_); } else { @@ -116,6 +128,10 @@ class InterceptorList { RunPromise(RunPromise&& other) noexcept : is_immediately_resolved_(other.is_immediately_resolved_) { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, "InterceptorList::RunPromise[%p]: move from %p", + this, &other); + } if (is_immediately_resolved_) { Construct(&result_, std::move(other.result_)); } else { @@ -127,7 +143,7 @@ class InterceptorList { Poll> operator()() { if (grpc_trace_promise_primitives.enabled()) { - gpr_log(GPR_DEBUG, "InterceptorList::RunPromise: %s", + gpr_log(GPR_DEBUG, "InterceptorList::RunPromise[%p]: %s", this, DebugString().c_str()); } if (is_immediately_resolved_) return std::move(result_); @@ -139,7 +155,12 @@ class InterceptorList { async_resolution_.space.get()); async_resolution_.current_factory = async_resolution_.current_factory->next(); - if (async_resolution_.current_factory == nullptr || !p->has_value()) { + if (!p->has_value()) async_resolution_.current_factory = nullptr; + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, "InterceptorList::RunPromise[%p]: %s", this, + DebugString().c_str()); + } + if (async_resolution_.current_factory == nullptr) { return std::move(*p); } async_resolution_.current_factory->MakePromise( diff --git a/src/core/lib/promise/intra_activity_waiter.h b/src/core/lib/promise/intra_activity_waiter.h deleted file mode 100644 index 736ec04ae7d..00000000000 --- a/src/core/lib/promise/intra_activity_waiter.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2021 gRPC authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GRPC_SRC_CORE_LIB_PROMISE_INTRA_ACTIVITY_WAITER_H -#define GRPC_SRC_CORE_LIB_PROMISE_INTRA_ACTIVITY_WAITER_H - -#include - -#include - -#include "src/core/lib/promise/activity.h" -#include "src/core/lib/promise/poll.h" - -namespace grpc_core { - -// Helper type to track wakeups between objects in the same activity. -// Can be fairly fast as no ref counting or locking needs to occur. -class IntraActivityWaiter { - public: - // Register for wakeup, return Pending(). If state is not ready to proceed, - // Promises should bottom out here. - Pending pending() { - waiting_ = true; - return Pending(); - } - // Wake the activity - void Wake() { - if (waiting_) { - waiting_ = false; - Activity::current()->ForceImmediateRepoll(); - } - } - - std::string DebugString() const { - return waiting_ ? "WAITING" : "NOT_WAITING"; - } - - private: - bool waiting_ = false; -}; - -} // namespace grpc_core - -#endif // GRPC_SRC_CORE_LIB_PROMISE_INTRA_ACTIVITY_WAITER_H diff --git a/src/core/lib/promise/latch.h b/src/core/lib/promise/latch.h index 305cf53ab6e..9d33fe7d280 100644 --- a/src/core/lib/promise/latch.h +++ b/src/core/lib/promise/latch.h @@ -19,6 +19,7 @@ #include +#include #include #include #include @@ -29,7 +30,6 @@ #include "src/core/lib/debug/trace.h" #include "src/core/lib/promise/activity.h" -#include "src/core/lib/promise/intra_activity_waiter.h" #include "src/core/lib/promise/poll.h" #include "src/core/lib/promise/trace.h" @@ -61,13 +61,14 @@ class Latch { } // Produce a promise to wait for a value from this latch. + // Moves the result out of the latch. auto Wait() { #ifndef NDEBUG has_had_waiters_ = true; #endif return [this]() -> Poll { if (grpc_trace_promise_primitives.enabled()) { - gpr_log(GPR_INFO, "%sPollWait %s", DebugTag().c_str(), + gpr_log(GPR_INFO, "%sWait %s", DebugTag().c_str(), StateString().c_str()); } if (has_value_) { @@ -78,6 +79,25 @@ class Latch { }; } + // Produce a promise to wait for a value from this latch. + // Copies the result out of the latch. + auto WaitAndCopy() { +#ifndef NDEBUG + has_had_waiters_ = true; +#endif + return [this]() -> Poll { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%sWaitAndCopy %s", DebugTag().c_str(), + StateString().c_str()); + } + if (has_value_) { + return value_; + } else { + return waiter_.pending(); + } + }; + } + // Set the value of the latch. Can only be called once. void Set(T value) { if (grpc_trace_promise_primitives.enabled()) { @@ -89,6 +109,8 @@ class Latch { waiter_.Wake(); } + bool is_set() const { return has_value_; } + private: std::string DebugTag() { return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", @@ -165,7 +187,7 @@ class Latch { private: std::string DebugTag() { - return absl::StrCat(Activity::current()->DebugTag(), " LATCH[0x", + return absl::StrCat(Activity::current()->DebugTag(), " LATCH(void)[0x", reinterpret_cast(this), "]: "); } @@ -183,6 +205,70 @@ class Latch { IntraActivityWaiter waiter_; }; +// A Latch that can have its value observed by outside threads, but only waited +// upon from inside a single activity. +template +class ExternallyObservableLatch; + +template <> +class ExternallyObservableLatch { + public: + ExternallyObservableLatch() = default; + ExternallyObservableLatch(const ExternallyObservableLatch&) = delete; + ExternallyObservableLatch& operator=(const ExternallyObservableLatch&) = + delete; + + // Produce a promise to wait for this latch. + auto Wait() { + return [this]() -> Poll { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%sPollWait %s", DebugTag().c_str(), + StateString().c_str()); + } + if (IsSet()) { + return Empty{}; + } else { + return waiter_.pending(); + } + }; + } + + // Set the latch. + void Set() { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%sSet %s", DebugTag().c_str(), StateString().c_str()); + } + is_set_.store(true, std::memory_order_relaxed); + waiter_.Wake(); + } + + bool IsSet() const { return is_set_.load(std::memory_order_relaxed); } + + void Reset() { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%sReset %s", DebugTag().c_str(), + StateString().c_str()); + } + is_set_.store(false, std::memory_order_relaxed); + } + + private: + std::string DebugTag() { + return absl::StrCat(Activity::current()->DebugTag(), " LATCH(void)[0x", + reinterpret_cast(this), "]: "); + } + + std::string StateString() { + return absl::StrCat( + "is_set:", is_set_.load(std::memory_order_relaxed) ? "true" : "false", + " waiter:", waiter_.DebugString()); + } + + // True if we have a value set, false otherwise. + std::atomic is_set_{false}; + IntraActivityWaiter waiter_; +}; + template using LatchWaitPromise = decltype(std::declval>().Wait()); diff --git a/src/core/lib/promise/loop.h b/src/core/lib/promise/loop.h index f0b3f713a9b..0833865ac51 100644 --- a/src/core/lib/promise/loop.h +++ b/src/core/lib/promise/loop.h @@ -17,14 +17,13 @@ #include -#include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/variant.h" +#include "src/core/lib/gprpp/construct_destruct.h" #include "src/core/lib/promise/detail/promise_factory.h" #include "src/core/lib/promise/poll.h" @@ -83,17 +82,21 @@ class Loop { public: using Result = typename LoopTraits::Result; - explicit Loop(F f) : factory_(std::move(f)), promise_(factory_.Make()) {} - ~Loop() { promise_.~PromiseType(); } + explicit Loop(F f) : factory_(std::move(f)) {} + ~Loop() { + if (started_) Destruct(&promise_); + } - Loop(Loop&& loop) noexcept - : factory_(std::move(loop.factory_)), - promise_(std::move(loop.promise_)) {} + Loop(Loop&& loop) noexcept : factory_(std::move(loop.factory_)) {} Loop(const Loop& loop) = delete; Loop& operator=(const Loop& loop) = delete; Poll operator()() { + if (!started_) { + started_ = true; + Construct(&promise_, factory_.Make()); + } while (true) { // Poll the inner promise. auto promise_result = promise_(); @@ -103,8 +106,8 @@ class Loop { // from our factory. auto lc = LoopTraits::ToLoopCtl(*p); if (absl::holds_alternative(lc)) { - promise_.~PromiseType(); - new (&promise_) PromiseType(factory_.Make()); + Destruct(&promise_); + Construct(&promise_, factory_.Make()); continue; } // - otherwise there's our result... return it out. @@ -121,6 +124,7 @@ class Loop { GPR_NO_UNIQUE_ADDRESS union { GPR_NO_UNIQUE_ADDRESS PromiseType promise_; }; + bool started_ = false; }; } // namespace promise_detail diff --git a/src/core/lib/promise/map.h b/src/core/lib/promise/map.h index a3088d9f6f0..44e19bb96ac 100644 --- a/src/core/lib/promise/map.h +++ b/src/core/lib/promise/map.h @@ -39,6 +39,13 @@ class Map { Map(Promise promise, Fn fn) : promise_(std::move(promise)), fn_(std::move(fn)) {} + Map(const Map&) = delete; + Map& operator=(const Map&) = delete; + // NOLINTNEXTLINE(performance-noexcept-move-constructor): clang6 bug + Map(Map&& other) = default; + // NOLINTNEXTLINE(performance-noexcept-move-constructor): clang6 bug + Map& operator=(Map&& other) = default; + using PromiseResult = typename PromiseLike::Result; using Result = RemoveCVRef()(std::declval()))>; diff --git a/src/core/lib/promise/observable.h b/src/core/lib/promise/observable.h deleted file mode 100644 index 3138d90dfa2..00000000000 --- a/src/core/lib/promise/observable.h +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2021 gRPC authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H -#define GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H - -#include - -#include - -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/types/optional.h" - -#include "src/core/lib/gprpp/sync.h" -#include "src/core/lib/promise/activity.h" -#include "src/core/lib/promise/detail/promise_like.h" -#include "src/core/lib/promise/poll.h" -#include "src/core/lib/promise/wait_set.h" - -namespace grpc_core { - -namespace promise_detail { - -using ObservableVersion = uint64_t; -static constexpr ObservableVersion kTombstoneVersion = - std::numeric_limits::max(); - -} // namespace promise_detail - -class WatchCommitter { - public: - void Commit() { version_seen_ = promise_detail::kTombstoneVersion; } - - protected: - promise_detail::ObservableVersion version_seen_ = 0; -}; - -namespace promise_detail { - -// Shared state between Observable and Observer. -template -class ObservableState { - public: - explicit ObservableState(absl::optional value) - : value_(std::move(value)) {} - - // Publish that we're closed. - void Close() { - mu_.Lock(); - version_ = kTombstoneVersion; - value_.reset(); - auto wakeup = waiters_.TakeWakeupSet(); - mu_.Unlock(); - wakeup.Wakeup(); - } - - // Synchronously publish a new value, and wake any waiters. - void Push(T value) { - mu_.Lock(); - version_++; - value_ = std::move(value); - auto wakeup = waiters_.TakeWakeupSet(); - mu_.Unlock(); - wakeup.Wakeup(); - } - - Poll> PollGet(ObservableVersion* version_seen) { - MutexLock lock(&mu_); - if (!Started()) return Pending(); - *version_seen = version_; - return value_; - } - - Poll> PollNext(ObservableVersion* version_seen) { - MutexLock lock(&mu_); - if (!NextValueReady(version_seen)) return Pending(); - return value_; - } - - Poll> PollWatch(ObservableVersion* version_seen) { - if (*version_seen == kTombstoneVersion) return Pending(); - - MutexLock lock(&mu_); - if (!NextValueReady(version_seen)) return Pending(); - // Watch needs to be woken up if the value changes even if it's ready now. - waiters_.AddPending(Activity::current()->MakeNonOwningWaker()); - return value_; - } - - private: - // Returns true if an initial value is set. - // If one is not set, add ourselves as pending to waiters_, and return false. - bool Started() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!value_.has_value()) { - if (version_ != kTombstoneVersion) { - // We allow initial no-value, which does not indicate closure. - waiters_.AddPending(Activity::current()->MakeNonOwningWaker()); - return false; - } - } - return true; - } - - // If no value is ready, add ourselves as pending to waiters_ and return - // false. - // If the next value is ready, update the last version seen and return true. - bool NextValueReady(ObservableVersion* version_seen) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!Started()) return false; - if (version_ == *version_seen) { - waiters_.AddPending(Activity::current()->MakeNonOwningWaker()); - return false; - } - *version_seen = version_; - return true; - } - - Mutex mu_; - WaitSet waiters_ ABSL_GUARDED_BY(mu_); - ObservableVersion version_ ABSL_GUARDED_BY(mu_) = 1; - absl::optional value_ ABSL_GUARDED_BY(mu_); -}; - -// Promise implementation for Observer::Get. -template -class ObservableGet { - public: - ObservableGet(ObservableVersion* version_seen, ObservableState* state) - : version_seen_(version_seen), state_(state) {} - - Poll> operator()() { - return state_->PollGet(version_seen_); - } - - private: - ObservableVersion* version_seen_; - ObservableState* state_; -}; - -// Promise implementation for Observer::Next. -template -class ObservableNext { - public: - ObservableNext(ObservableVersion* version_seen, ObservableState* state) - : version_seen_(version_seen), state_(state) {} - - Poll> operator()() { - return state_->PollNext(version_seen_); - } - - private: - ObservableVersion* version_seen_; - ObservableState* state_; -}; - -template -class ObservableWatch final : private WatchCommitter { - private: - using Promise = PromiseLike()( - std::declval(), std::declval()))>; - using Result = typename Promise::Result; - - public: - explicit ObservableWatch(F factory, std::shared_ptr> state) - : state_(std::move(state)), factory_(std::move(factory)) {} - ObservableWatch(const ObservableWatch&) = delete; - ObservableWatch& operator=(const ObservableWatch&) = delete; - ObservableWatch(ObservableWatch&& other) noexcept - : state_(std::move(other.state_)), - promise_(std::move(other.promise_)), - factory_(std::move(other.factory_)) {} - ObservableWatch& operator=(ObservableWatch&&) noexcept = default; - - Poll operator()() { - auto r = state_->PollWatch(&version_seen_); - if (auto* p = r.value_if_ready()) { - if (p->has_value()) { - promise_ = Promise(factory_(std::move(**p), this)); - } else { - promise_ = {}; - } - } - if (promise_.has_value()) { - return (*promise_)(); - } else { - return Pending(); - } - } - - private: - std::shared_ptr> state_; - absl::optional promise_; - F factory_; -}; - -} // namespace promise_detail - -template -class Observable; - -// Observer watches an Observable for updates. -// It can see either the latest value or wait for a new value, but is not -// guaranteed to see every value pushed to the Observable. -template -class Observer { - public: - Observer(const Observer&) = delete; - Observer& operator=(const Observer&) = delete; - Observer(Observer&& other) noexcept - : version_seen_(other.version_seen_), state_(std::move(other.state_)) {} - Observer& operator=(Observer&& other) noexcept { - version_seen_ = other.version_seen_; - state_ = std::move(other.state_); - return *this; - } - - // Return a promise that will produce an optional. - // If the Observable is still present, this will be a value T, but if the - // Observable has been closed, this will be nullopt. Borrows data from the - // Observer, so this value must stay valid until the promise is resolved. Only - // one Next, Get call is allowed to be outstanding at a time. - promise_detail::ObservableGet Get() { - return promise_detail::ObservableGet{&version_seen_, &*state_}; - } - - // Return a promise that will produce the next unseen value as an optional. - // If the Observable is still present, this will be a value T, but if the - // Observable has been closed, this will be nullopt. Borrows data from the - // Observer, so this value must stay valid until the promise is resolved. Only - // one Next, Get call is allowed to be outstanding at a time. - promise_detail::ObservableNext Next() { - return promise_detail::ObservableNext{&version_seen_, &*state_}; - } - - private: - using State = promise_detail::ObservableState; - friend class Observable; - explicit Observer(std::shared_ptr state) : state_(state) {} - promise_detail::ObservableVersion version_seen_ = 0; - std::shared_ptr state_; -}; - -// Observable models a single writer multiple reader broadcast channel. -// Readers can observe the latest value, or await a new latest value, but they -// are not guaranteed to observe every value. -template -class Observable { - public: - Observable() : state_(std::make_shared(absl::nullopt)) {} - explicit Observable(T value) - : state_(std::make_shared(std::move(value))) {} - ~Observable() { state_->Close(); } - Observable(const Observable&) = delete; - Observable& operator=(const Observable&) = delete; - - // Push a new value into the observable. - void Push(T value) { state_->Push(std::move(value)); } - - // Create a new Observer - which can pull the current state from this - // Observable. - Observer MakeObserver() { return Observer(state_); } - - // Create a new Watch - a promise that pushes state into the passed in promise - // factory. The promise factory takes two parameters - the current value and a - // commit token. If the commit token is used (the Commit function on it is - // called), then no further Watch updates are provided. - template - promise_detail::ObservableWatch Watch(F f) { - return promise_detail::ObservableWatch(std::move(f), state_); - } - - private: - using State = promise_detail::ObservableState; - std::shared_ptr state_; -}; - -} // namespace grpc_core - -#endif // GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H diff --git a/src/core/lib/promise/party.cc b/src/core/lib/promise/party.cc index 98c9ea23b0c..6c7e38011de 100644 --- a/src/core/lib/promise/party.cc +++ b/src/core/lib/promise/party.cc @@ -16,19 +16,11 @@ #include "src/core/lib/promise/party.h" -#include - -#include #include #include -#include -#include -#include #include "absl/base/thread_annotations.h" -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" #include @@ -37,8 +29,59 @@ #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/trace.h" +// #define GRPC_PARTY_MAXIMIZE_THREADS + +#ifdef GRPC_PARTY_MAXIMIZE_THREADS +#include "src/core/lib/gprpp/thd.h" // IWYU pragma: keep +#include "src/core/lib/iomgr/exec_ctx.h" // IWYU pragma: keep +#endif + namespace grpc_core { +/////////////////////////////////////////////////////////////////////////////// +// PartySyncUsingAtomics + +GRPC_MUST_USE_RESULT bool PartySyncUsingAtomics::RefIfNonZero() { + auto count = state_.load(std::memory_order_relaxed); + do { + // If zero, we are done (without an increment). If not, we must do a CAS + // to maintain the contract: do not increment the counter if it is already + // zero + if (count == 0) { + return false; + } + } while (!state_.compare_exchange_weak(count, count + kOneRef, + std::memory_order_acq_rel, + std::memory_order_relaxed)); + return true; +} + +bool PartySyncUsingAtomics::UnreffedLast() { + uint64_t prev_state = + state_.fetch_or(kDestroying | kLocked, std::memory_order_acq_rel); + return (prev_state & kLocked) == 0; +} + +bool PartySyncUsingAtomics::ScheduleWakeup(WakeupMask mask) { + // Or in the wakeup bit for the participant, AND the locked bit. + uint64_t prev_state = state_.fetch_or((mask & kWakeupMask) | kLocked, + std::memory_order_acq_rel); + // If the lock was not held now we hold it, so we need to run. + return ((prev_state & kLocked) == 0); +} + +/////////////////////////////////////////////////////////////////////////////// +// PartySyncUsingMutex + +bool PartySyncUsingMutex::ScheduleWakeup(WakeupMask mask) { + MutexLock lock(&mu_); + wakeups_ |= mask; + return !std::exchange(locked_, true); +} + +/////////////////////////////////////////////////////////////////////////////// +// Party::Handle + // Weak handle to a Party. // Handle can persist while Party goes away. class Party::Handle final : public Wakeable { @@ -59,7 +102,7 @@ class Party::Handle final : public Wakeable { // Activity needs to wake up (if it still exists!) - wake it up, and drop the // ref that was kept for this handle. - void Wakeup(void* arg) override ABSL_LOCKS_EXCLUDED(mu_) { + void Wakeup(WakeupMask wakeup_mask) override ABSL_LOCKS_EXCLUDED(mu_) { mu_.Lock(); // Note that activity refcount can drop to zero, but we could win the lock // against DropActivity, so we need to only increase activities refcount if @@ -69,7 +112,7 @@ class Party::Handle final : public Wakeable { mu_.Unlock(); // Activity still exists and we have a reference: wake it up, which will // drop the ref. - party->Wakeup(reinterpret_cast(arg)); + party->Wakeup(wakeup_mask); } else { // Could not get the activity - it's either gone or going. No need to wake // it up! @@ -79,9 +122,9 @@ class Party::Handle final : public Wakeable { Unref(); } - void Drop(void*) override { Unref(); } + void Drop(WakeupMask) override { Unref(); } - std::string ActivityDebugTag(void*) const override { + std::string ActivityDebugTag(WakeupMask) const override { MutexLock lock(&mu_); return party_ == nullptr ? "" : party_->DebugTag(); } @@ -116,206 +159,128 @@ Party::Participant::~Participant() { } } -Party::~Party() { - participants_.clear(); - arena_->Destroy(); -} - -void Party::Orphan() { Unref(); } - -void Party::Ref() { state_.fetch_add(kOneRef, std::memory_order_relaxed); } +Party::~Party() {} -bool Party::RefIfNonZero() { - auto count = state_.load(std::memory_order_relaxed); - do { - // If zero, we are done (without an increment). If not, we must do a CAS - // to maintain the contract: do not increment the counter if it is already - // zero - if (count == 0) { - return false; +void Party::CancelRemainingParticipants() { + ScopedActivity activity(this); + promise_detail::Context arena_ctx(arena_); + for (size_t i = 0; i < party_detail::kMaxParticipants; i++) { + if (auto* p = + participants_[i].exchange(nullptr, std::memory_order_acquire)) { + p->Destroy(); } - } while (!state_.compare_exchange_weak(count, count + kOneRef, - std::memory_order_acq_rel, - std::memory_order_relaxed)); - return true; -} - -void Party::Unref() { - auto prev = state_.fetch_sub(kOneRef, std::memory_order_acq_rel); - if (prev == kOneRef) { - delete this; } - GPR_DEBUG_ASSERT((prev & kRefMask) != 0); } -std::string Party::ActivityDebugTag(void* arg) const { - return absl::StrFormat("%s/%p", DebugTag(), arg); +std::string Party::ActivityDebugTag(WakeupMask wakeup_mask) const { + return absl::StrFormat("%s [parts:%x]", DebugTag(), wakeup_mask); } Waker Party::MakeOwningWaker() { GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling); - Ref(); - return Waker(this, reinterpret_cast(currently_polling_)); + IncrementRefCount(); + return Waker(this, 1u << currently_polling_); } Waker Party::MakeNonOwningWaker() { GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling); - return Waker(participants_[currently_polling_]->MakeNonOwningWakeable(this), - reinterpret_cast(currently_polling_)); + return Waker(participants_[currently_polling_] + .load(std::memory_order_relaxed) + ->MakeNonOwningWakeable(this), + 1u << currently_polling_); } -void Party::ForceImmediateRepoll() { - GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling); - // Or in the bit for the currently polling participant. - // Will be grabbed next round to force a repoll of this promise. - state_.fetch_or(1 << currently_polling_, std::memory_order_relaxed); +void Party::ForceImmediateRepoll(WakeupMask mask) { + GPR_DEBUG_ASSERT(is_current()); + sync_.ForceImmediateRepoll(mask); } -void Party::Run() { +void Party::RunLocked() { + auto body = [this]() { + if (RunParty()) { + ScopedActivity activity(this); + PartyOver(); + } + }; +#ifdef GRPC_PARTY_MAXIMIZE_THREADS + Thread thd( + "RunParty", + [body]() { + ApplicationCallbackExecCtx app_exec_ctx; + ExecCtx exec_ctx; + body(); + }, + nullptr, Thread::Options().set_joinable(false)); + thd.Start(); +#else + body(); +#endif +} + +bool Party::RunParty() { ScopedActivity activity(this); - uint64_t prev_state; - do { - // Grab the current state, and clear the wakeup bits & add flag. - prev_state = - state_.fetch_and(kRefMask | kLocked, std::memory_order_acquire); + promise_detail::Context arena_ctx(arena_); + return sync_.RunParty([this](int i) { + // If the participant is null, skip. + // This allows participants to complete whilst wakers still exist + // somewhere. + auto* participant = participants_[i].load(std::memory_order_acquire); + if (participant == nullptr) { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, "%s[party] wakeup %d already complete", + DebugTag().c_str(), i); + } + return false; + } + absl::string_view name; if (grpc_trace_promise_primitives.enabled()) { - gpr_log(GPR_DEBUG, "Party::Run(): prev_state=%s", - StateToString(prev_state).c_str()); + name = participant->name(); + gpr_log(GPR_DEBUG, "%s[%s] begin job %d", DebugTag().c_str(), + std::string(name).c_str(), i); } - // From the previous state, extract which participants we're to wakeup. - uint64_t wakeups = prev_state & kWakeupMask; - // If there were adds pending, drain them. - // We pass in wakeups here so that the new participants are polled - // immediately (draining will situate them). - if (prev_state & kAddsPending) DrainAdds(wakeups); - // Now update prev_state to be what we want the CAS to see below. - prev_state &= kRefMask | kLocked; - // For each wakeup bit... - for (size_t i = 0; wakeups != 0; i++, wakeups >>= 1) { - // If the bit is not set, skip. - if ((wakeups & 1) == 0) continue; - // If the participant is null, skip. - // This allows participants to complete whilst wakers still exist - // somewhere. - if (participants_[i] == nullptr) continue; - // Poll the participant. - currently_polling_ = i; - if (participants_[i]->Poll()) participants_[i].reset(); - currently_polling_ = kNotPolling; + // Poll the participant. + currently_polling_ = i; + bool done = participant->Poll(); + currently_polling_ = kNotPolling; + if (done) { + if (!name.empty()) { + gpr_log(GPR_DEBUG, "%s[%s] end poll and finish job %d", + DebugTag().c_str(), std::string(name).c_str(), i); + } + participants_[i].store(nullptr, std::memory_order_relaxed); + } else if (!name.empty()) { + gpr_log(GPR_DEBUG, "%s[%s] end poll", DebugTag().c_str(), + std::string(name).c_str()); } - // Try to CAS the state we expected to have (with no wakeups or adds) - // back to unlocked (by masking in only the ref mask - sans locked bit). - // If this succeeds then no wakeups were added, no adds were added, and we - // have successfully unlocked. - // Otherwise, we need to loop again. - // Note that if an owning waker is created or the weak cas spuriously fails - // we will also loop again, but in that case see no wakeups or adds and so - // will get back here fairly quickly. - // TODO(ctiller): consider mitigations for the accidental wakeup on owning - // waker creation case -- I currently expect this will be more expensive - // than this quick loop. - } while (!state_.compare_exchange_weak(prev_state, (prev_state & kRefMask), - std::memory_order_acq_rel, - std::memory_order_acquire)); -} - -void Party::DrainAdds(uint64_t& wakeups) { - // Grab the list of adds. - AddingParticipant* adding = - adding_.exchange(nullptr, std::memory_order_acquire); - // For each add, situate it and add it to the wakeup mask. - while (adding != nullptr) { - wakeups |= 1 << SituateNewParticipant(std::move(adding->participant)); - // Don't leak the add request. - delete std::exchange(adding, adding->next); - } + return done; + }); } -void Party::AddParticipant(Arena::PoolPtr participant) { - // Lock - auto prev_state = state_.fetch_or(kLocked, std::memory_order_acquire); - if (grpc_trace_promise_primitives.enabled()) { - gpr_log(GPR_DEBUG, "Party::AddParticipant(): prev_state=%s", - StateToString(prev_state).c_str()); - } - if ((prev_state & kLocked) == 0) { - // Lock acquired - state_.fetch_or(1 << SituateNewParticipant(std::move(participant)), - std::memory_order_relaxed); - Run(); - return; - } - // Already locked: add to the list of things to add - auto* add = new AddingParticipant{std::move(participant), nullptr}; - while (!adding_.compare_exchange_weak( - add->next, add, std::memory_order_acq_rel, std::memory_order_acquire)) { - } - // And signal that there are adds waiting. - // This needs to happen after the add above: Run() will examine this bit - // first, and then decide to drain the queue - so if the ordering was reversed - // it might examine the adds pending bit, and then observe no add to drain. - prev_state = - state_.fetch_or(kLocked | kAddsPending, std::memory_order_release); - if (grpc_trace_promise_primitives.enabled()) { - gpr_log(GPR_DEBUG, "Party::AddParticipant(): prev_state=%s", - StateToString(prev_state).c_str()); - } - if ((prev_state & kLocked) == 0) { - // We queued the add but the lock was released before we signalled that. - // We acquired the lock though, so now we can run. - Run(); - } -} - -size_t Party::SituateNewParticipant(Arena::PoolPtr participant) { - // First search for a free index in the participants array. - // If we find one, use it. - for (size_t i = 0; i < participants_.size(); i++) { - if (participants_[i] != nullptr) continue; - participants_[i] = std::move(participant); - return i; - } - - // Otherwise, add it to the end. - GPR_ASSERT(participants_.size() < kMaxParticipants); - participants_.emplace_back(std::move(participant)); - return participants_.size() - 1; +void Party::AddParticipants(Participant** participants, size_t count) { + bool run_party = sync_.AddParticipantsAndRef(count, [this, participants, + count](size_t* slots) { + for (size_t i = 0; i < count; i++) { + participants_[slots[i]].store(participants[i], std::memory_order_release); + } + }); + if (run_party) RunLocked(); + Unref(); } -void Party::ScheduleWakeup(uint64_t participant_index) { - // Or in the wakeup bit for the participant, AND the locked bit. - uint64_t prev_state = state_.fetch_or((1 << participant_index) | kLocked, - std::memory_order_acquire); - if (grpc_trace_promise_primitives.enabled()) { - gpr_log(GPR_DEBUG, "Party::ScheduleWakeup(%" PRIu64 "): prev_state=%s", - participant_index, StateToString(prev_state).c_str()); - } - // If the lock was not held now we hold it, so we need to run. - if ((prev_state & kLocked) == 0) Run(); +void Party::ScheduleWakeup(WakeupMask mask) { + if (sync_.ScheduleWakeup(mask)) RunLocked(); } -void Party::Wakeup(void* arg) { - ScheduleWakeup(reinterpret_cast(arg)); +void Party::Wakeup(WakeupMask wakeup_mask) { + ScheduleWakeup(wakeup_mask); Unref(); } -void Party::Drop(void*) { Unref(); } - -std::string Party::StateToString(uint64_t state) { - std::vector parts; - if (state & kLocked) parts.push_back("locked"); - if (state & kAddsPending) parts.push_back("adds_pending"); - parts.push_back( - absl::StrFormat("refs=%" PRIuPTR, (state & kRefMask) >> kRefShift)); - std::vector participants; - for (size_t i = 0; i < kMaxParticipants; i++) { - if ((state & (1 << i)) != 0) participants.push_back(i); - } - if (!participants.empty()) { - parts.push_back( - absl::StrFormat("wakeup=%s", absl::StrJoin(participants, ","))); - } - return absl::StrCat("{", absl::StrJoin(parts, " "), "}"); +void Party::Drop(WakeupMask) { Unref(); } + +void Party::PartyIsOver() { + ScopedActivity activity(this); + PartyOver(); } } // namespace grpc_core diff --git a/src/core/lib/promise/party.h b/src/core/lib/promise/party.h index 3032d05cd45..90e4f4f7481 100644 --- a/src/core/lib/promise/party.h +++ b/src/core/lib/promise/party.h @@ -24,39 +24,353 @@ #include #include -#include "absl/container/inlined_vector.h" +#include "absl/base/thread_annotations.h" +#include "absl/strings/string_view.h" +#include + +#include "src/core/lib/gprpp/construct_destruct.h" +#include "src/core/lib/gprpp/crash.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/sync.h" #include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/detail/promise_factory.h" #include "src/core/lib/resource_quota/arena.h" +// Two implementations of party synchronization are provided: one using a single +// atomic, the other using a mutex and a set of state variables. +// Originally the atomic implementation was implemented, but we found some race +// conditions on Arm that were not reported by our default TSAN implementation. +// The mutex implementation was added to see if it would fix the problem, and +// it did. Later we found the race condition, so there's no known reason to use +// the mutex version - however we keep it around as a just in case measure. +// There's a thought of fuzzing the two implementations against each other as +// a correctness check of both, but that's not implemented yet. + +#define GRPC_PARTY_SYNC_USING_ATOMICS +// #define GRPC_PARTY_SYNC_USING_MUTEX + +#if defined(GRPC_PARTY_SYNC_USING_ATOMICS) + \ + defined(GRPC_PARTY_SYNC_USING_MUTEX) != \ + 1 +#error Must define a party sync mechanism +#endif + namespace grpc_core { +namespace party_detail { + +// Number of bits reserved for wakeups gives us the maximum number of +// participants. +static constexpr size_t kMaxParticipants = 16; + +} // namespace party_detail + +class PartySyncUsingAtomics { + public: + explicit PartySyncUsingAtomics(size_t initial_refs) + : state_(kOneRef * initial_refs) {} + + void IncrementRefCount() { + state_.fetch_add(kOneRef, std::memory_order_relaxed); + } + GRPC_MUST_USE_RESULT bool RefIfNonZero(); + // Returns true if the ref count is now zero and the caller should call + // PartyOver + GRPC_MUST_USE_RESULT bool Unref() { + uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel); + if ((prev_state & kRefMask) == kOneRef) { + return UnreffedLast(); + } + return false; + } + void ForceImmediateRepoll(WakeupMask mask) { + // Or in the bit for the currently polling participant. + // Will be grabbed next round to force a repoll of this promise. + state_.fetch_or(mask, std::memory_order_relaxed); + } + + // Run the update loop: poll_one_participant is called with an integral index + // for the participant that should be polled. It should return true if the + // participant completed and should be removed from the allocated set. + template + GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) { + uint64_t prev_state; + do { + // Grab the current state, and clear the wakeup bits & add flag. + prev_state = state_.fetch_and(kRefMask | kLocked | kAllocatedMask, + std::memory_order_acquire); + GPR_ASSERT(prev_state & kLocked); + if (prev_state & kDestroying) return true; + // From the previous state, extract which participants we're to wakeup. + uint64_t wakeups = prev_state & kWakeupMask; + // Now update prev_state to be what we want the CAS to see below. + prev_state &= kRefMask | kLocked | kAllocatedMask; + // For each wakeup bit... + for (size_t i = 0; wakeups != 0; i++, wakeups >>= 1) { + // If the bit is not set, skip. + if ((wakeups & 1) == 0) continue; + if (poll_one_participant(i)) { + const uint64_t allocated_bit = (1u << i << kAllocatedShift); + prev_state &= ~allocated_bit; + state_.fetch_and(~allocated_bit, std::memory_order_release); + } + } + // Try to CAS the state we expected to have (with no wakeups or adds) + // back to unlocked (by masking in only the ref mask - sans locked bit). + // If this succeeds then no wakeups were added, no adds were added, and we + // have successfully unlocked. + // Otherwise, we need to loop again. + // Note that if an owning waker is created or the weak cas spuriously + // fails we will also loop again, but in that case see no wakeups or adds + // and so will get back here fairly quickly. + // TODO(ctiller): consider mitigations for the accidental wakeup on owning + // waker creation case -- I currently expect this will be more expensive + // than this quick loop. + } while (!state_.compare_exchange_weak( + prev_state, (prev_state & (kRefMask | kAllocatedMask)), + std::memory_order_acq_rel, std::memory_order_acquire)); + return false; + } + + // Add new participants to the party. Returns true if the caller should run + // the party. store is called with an array of indices of the new + // participants. Adds a ref that should be dropped by the caller after + // RunParty has been called (if that was required). + template + GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) { + uint64_t state = state_.load(std::memory_order_acquire); + uint64_t allocated; + + size_t slots[party_detail::kMaxParticipants]; + + // Find slots for each new participant, ordering them from lowest available + // slot upwards to ensure the same poll ordering as presentation ordering to + // this function. + WakeupMask wakeup_mask; + do { + wakeup_mask = 0; + allocated = (state & kAllocatedMask) >> kAllocatedShift; + size_t n = 0; + for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants; + bit++) { + if (allocated & (1 << bit)) continue; + wakeup_mask |= (1 << bit); + slots[n++] = bit; + allocated |= 1 << bit; + } + GPR_ASSERT(n == count); + // Try to allocate this slot and take a ref (atomically). + // Ref needs to be taken because once we store the participant it could be + // spuriously woken up and unref the party. + } while (!state_.compare_exchange_weak( + state, (state | (allocated << kAllocatedShift)) + kOneRef, + std::memory_order_acq_rel, std::memory_order_acquire)); + + store(slots); + + // Now we need to wake up the party. + state = state_.fetch_or(wakeup_mask | kLocked, std::memory_order_release); + + // If the party was already locked, we're done. + return ((state & kLocked) == 0); + } + + // Schedule a wakeup for the given participant. + // Returns true if the caller should run the party. + GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask); + + private: + bool UnreffedLast(); + + // State bits: + // The atomic state_ field is composed of the following: + // - 24 bits for ref counts + // 1 is owned by the party prior to Orphan() + // All others are owned by owning wakers + // - 1 bit to indicate whether the party is locked + // The first thread to set this owns the party until it is unlocked + // That thread will run the main loop until no further work needs to + // be done. + // - 1 bit to indicate whether there are participants waiting to be + // added + // - 16 bits, one per participant, indicating which participants have + // been + // woken up and should be polled next time the main loop runs. + + // clang-format off + // Bits used to store 16 bits of wakeups + static constexpr uint64_t kWakeupMask = 0x0000'0000'0000'ffff; + // Bits used to store 16 bits of allocated participant slots. + static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000; + // Bit indicating destruction has begun (refs went to zero) + static constexpr uint64_t kDestroying = 0x0000'0001'0000'0000; + // Bit indicating locked or not + static constexpr uint64_t kLocked = 0x0000'0008'0000'0000; + // Bits used to store 24 bits of ref counts + static constexpr uint64_t kRefMask = 0xffff'ff00'0000'0000; + // clang-format on + + // Shift to get from a participant mask to an allocated mask. + static constexpr size_t kAllocatedShift = 16; + // How far to shift to get the refcount + static constexpr size_t kRefShift = 40; + // One ref count + static constexpr uint64_t kOneRef = 1ull << kRefShift; + + std::atomic state_; +}; + +class PartySyncUsingMutex { + public: + explicit PartySyncUsingMutex(size_t initial_refs) : refs_(initial_refs) {} + + void IncrementRefCount() { refs_.Ref(); } + GRPC_MUST_USE_RESULT bool RefIfNonZero() { return refs_.RefIfNonZero(); } + GRPC_MUST_USE_RESULT bool Unref() { return refs_.Unref(); } + void ForceImmediateRepoll(WakeupMask mask) { + MutexLock lock(&mu_); + wakeups_ |= mask; + } + template + GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) { + WakeupMask freed = 0; + while (true) { + ReleasableMutexLock lock(&mu_); + GPR_ASSERT(locked_); + allocated_ &= ~std::exchange(freed, 0); + auto wakeup = std::exchange(wakeups_, 0); + if (wakeup == 0) { + locked_ = false; + return false; + } + lock.Release(); + for (size_t i = 0; wakeup != 0; i++, wakeup >>= 1) { + if ((wakeup & 1) == 0) continue; + if (poll_one_participant(i)) freed |= 1 << i; + } + } + } + + template + GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) { + IncrementRefCount(); + MutexLock lock(&mu_); + size_t slots[party_detail::kMaxParticipants]; + WakeupMask wakeup_mask = 0; + size_t n = 0; + for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants; + bit++) { + if (allocated_ & (1 << bit)) continue; + slots[n++] = bit; + wakeup_mask |= 1 << bit; + allocated_ |= 1 << bit; + } + GPR_ASSERT(n == count); + store(slots); + wakeups_ |= wakeup_mask; + return !std::exchange(locked_, true); + } + + GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask); + + private: + RefCount refs_; + Mutex mu_; + WakeupMask allocated_ ABSL_GUARDED_BY(mu_) = 0; + WakeupMask wakeups_ ABSL_GUARDED_BY(mu_) = 0; + bool locked_ ABSL_GUARDED_BY(mu_) = false; +}; + // A Party is an Activity with multiple participant promises. class Party : public Activity, private Wakeable { - public: - explicit Party(Arena* arena) : arena_(arena) {} + private: + // Non-owning wakeup handle. + class Handle; + + // One participant in the party. + class Participant { + public: + explicit Participant(absl::string_view name) : name_(name) {} + // Poll the participant. Return true if complete. + // Participant should take care of its own deallocation in this case. + virtual bool Poll() = 0; + + // Destroy the participant before finishing. + virtual void Destroy() = 0; + + // Return a Handle instance for this participant. + Wakeable* MakeNonOwningWakeable(Party* party); + + absl::string_view name() const { return name_; } + + protected: + ~Participant(); + private: + Handle* handle_ = nullptr; + absl::string_view name_; + }; + + public: Party(const Party&) = delete; Party& operator=(const Party&) = delete; - // Spawn one promise onto the arena. + // Spawn one promise into the party. // The promise will be polled until it is resolved, or until the party is shut // down. // The on_complete callback will be called with the result of the promise if // it completes. // A maximum of sixteen promises can be spawned onto a party. - template - void Spawn(Promise promise, OnComplete on_complete); + template + void Spawn(absl::string_view name, Factory promise_factory, + OnComplete on_complete); - void Orphan() final; + void Orphan() final { Crash("unused"); } // Activity implementation: not allowed to be overridden by derived types. - void ForceImmediateRepoll() final; + void ForceImmediateRepoll(WakeupMask mask) final; + WakeupMask CurrentParticipant() const final { + GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling); + return 1u << currently_polling_; + } Waker MakeOwningWaker() final; Waker MakeNonOwningWaker() final; - std::string ActivityDebugTag(void* arg) const final; + std::string ActivityDebugTag(WakeupMask wakeup_mask) const final; + + void IncrementRefCount() { sync_.IncrementRefCount(); } + void Unref() { + if (sync_.Unref()) PartyIsOver(); + } + RefCountedPtr Ref() { + IncrementRefCount(); + return RefCountedPtr(this); + } + + Arena* arena() const { return arena_; } + + class BulkSpawner { + public: + explicit BulkSpawner(Party* party) : party_(party) {} + ~BulkSpawner() { + party_->AddParticipants(participants_, num_participants_); + } + + template + void Spawn(absl::string_view name, Factory promise_factory, + OnComplete on_complete); + + private: + Party* const party_; + size_t num_participants_ = 0; + Participant* participants_[party_detail::kMaxParticipants]; + }; protected: + explicit Party(Arena* arena, size_t initial_refs) + : sync_(initial_refs), arena_(arena) {} ~Party() override; // Main run loop. Must be locked. @@ -64,128 +378,120 @@ class Party : public Activity, private Wakeable { // be done. // Derived types will likely want to override this to set up their // contexts before polling. - virtual void Run(); - - Arena* arena() const { return arena_; } - - private: - // Non-owning wakeup handle. - class Handle; - - // One participant in the party. - class Participant { - public: - virtual ~Participant(); - // Poll the participant. Return true if complete. - virtual bool Poll() = 0; + // Should not be called by derived types except as a tail call to the base + // class RunParty when overriding this method to add custom context. + // Returns true if the party is over. + virtual bool RunParty() GRPC_MUST_USE_RESULT; - // Return a Handle instance for this participant. - Wakeable* MakeNonOwningWakeable(Party* party); + bool RefIfNonZero() { return sync_.RefIfNonZero(); } - private: - Handle* handle_ = nullptr; - }; + // Destroy any remaining participants. + // Should be called by derived types in response to PartyOver. + // Needs to have normal context setup before calling. + void CancelRemainingParticipants(); + private: // Concrete implementation of a participant for some promise & oncomplete // type. - template + template class ParticipantImpl final : public Participant { + using Factory = promise_detail::OncePromiseFactory; + using Promise = typename Factory::Promise; + public: - ParticipantImpl(Promise promise, OnComplete on_complete) - : promise_(std::move(promise)), on_complete_(std::move(on_complete)) {} + ParticipantImpl(absl::string_view name, SuppliedFactory promise_factory, + OnComplete on_complete) + : Participant(name), on_complete_(std::move(on_complete)) { + Construct(&factory_, std::move(promise_factory)); + } + ~ParticipantImpl() { + if (!started_) { + Destruct(&factory_); + } else { + Destruct(&promise_); + } + } bool Poll() override { + if (!started_) { + auto p = factory_.Make(); + Destruct(&factory_); + Construct(&promise_, std::move(p)); + started_ = true; + } auto p = promise_(); if (auto* r = p.value_if_ready()) { on_complete_(std::move(*r)); + GetContext()->DeletePooled(this); return true; } return false; } + void Destroy() override { GetContext()->DeletePooled(this); } + private: - GPR_NO_UNIQUE_ADDRESS Promise promise_; + union { + GPR_NO_UNIQUE_ADDRESS Factory factory_; + GPR_NO_UNIQUE_ADDRESS Promise promise_; + }; GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_; + bool started_ = false; }; - // One participant that's been spawned, but has not yet made it into - // participants_. - // Since it's impossible to block on locking this type, we form a queue of - // participants waiting and drain that prior to polling. - struct AddingParticipant { - Arena::PoolPtr participant; - AddingParticipant* next; - }; + // Notification that the party has finished and this instance can be deleted. + // Derived types should arrange to call CancelRemainingParticipants during + // this sequence. + virtual void PartyOver() = 0; + + // Run the locked part of the party until it is unlocked. + void RunLocked(); + // Called in response to Unref() hitting zero - ultimately calls PartyOver, + // but needs to set some stuff up. + // Here so it gets compiled out of line. + void PartyIsOver(); // Wakeable implementation - void Wakeup(void* arg) final; - void Drop(void* arg) final; - - // Internal ref counting - void Ref(); - bool RefIfNonZero(); - void Unref(); - - // Organize to wake up one participant. - void ScheduleWakeup(uint64_t participant_index); - // Start adding a participant to the party. - // Backs Spawn() after type erasure. - void AddParticipant(Arena::PoolPtr participant); - // Drain the add queue. - void DrainAdds(uint64_t& wakeups); - // Take a new participant, and add it to the participants_ array. - // Returns the index of the participant in the array. - size_t SituateNewParticipant(Arena::PoolPtr new_participant); - - // Convert a state into a string. - static std::string StateToString(uint64_t state); + void Wakeup(WakeupMask wakeup_mask) final; + void Drop(WakeupMask wakeup_mask) final; + + // Organize to wake up some participants. + void ScheduleWakeup(WakeupMask mask); + // Add a participant (backs Spawn, after type erasure to ParticipantFactory). + void AddParticipants(Participant** participant, size_t count); // Sentinal value for currently_polling_ when no participant is being polled. static constexpr uint8_t kNotPolling = 255; - // State bits: - // The atomic state_ field is composed of the following: - // - 24 bits for ref counts - // 1 is owned by the party prior to Orphan() - // All others are owned by owning wakers - // - 1 bit to indicate whether the party is locked - // The first thread to set this owns the party until it is unlocked - // That thread will run the main loop until no further work needs to be - // done. - // - 1 bit to indicate whether there are participants waiting to be added - // - 16 bits, one per participant, indicating which participants have been - // woken up and should be polled next time the main loop runs. - - // clang-format off - // Bits used to store 16 bits of wakeups - static constexpr uint64_t kWakeupMask = 0x0000'0000'0000'ffff; - // Bit indicating locked or not - static constexpr uint64_t kLocked = 0x0000'0000'0100'0000; - // Bit indicating whether there are adds pending - static constexpr uint64_t kAddsPending = 0x0000'0000'1000'0000; - // Bits used to store 24 bits of ref counts - static constexpr uint64_t kRefMask = 0xffff'ff00'0000'0000; - // clang-format on - - // Number of bits reserved for wakeups gives us the maximum number of - // participants. - static constexpr size_t kMaxParticipants = 16; - // How far to shift to get the refcount - static constexpr size_t kRefShift = 40; - // One ref count - static constexpr uint64_t kOneRef = 1ull << kRefShift; +#ifdef GRPC_PARTY_SYNC_USING_ATOMICS + PartySyncUsingAtomics sync_; +#elif defined(GRPC_PARTY_SYNC_USING_MUTEX) + PartySyncUsingMutex sync_; +#else +#error No synchronization method defined +#endif Arena* const arena_; - absl::InlinedVector, 1> participants_; - std::atomic state_{kOneRef}; - std::atomic adding_{nullptr}; uint8_t currently_polling_ = kNotPolling; + // All current participants, using a tagged format. + // If the lower bit is unset, then this is a Participant*. + // If the lower bit is set, then this is a ParticipantFactory*. + std::atomic participants_[party_detail::kMaxParticipants] = {}; }; -template -void Party::Spawn(Promise promise, OnComplete on_complete) { - AddParticipant(arena_->MakePooled>( - std::move(promise), std::move(on_complete))); +template +void Party::BulkSpawner::Spawn(absl::string_view name, Factory promise_factory, + OnComplete on_complete) { + participants_[num_participants_++] = + party_->arena_->NewPooled>( + name, std::move(promise_factory), std::move(on_complete)); +} + +template +void Party::Spawn(absl::string_view name, Factory promise_factory, + OnComplete on_complete) { + BulkSpawner(this).Spawn(name, std::move(promise_factory), + std::move(on_complete)); } } // namespace grpc_core diff --git a/src/core/lib/promise/pipe.h b/src/core/lib/promise/pipe.h index c3ec5ad447c..e989b5cc9a7 100644 --- a/src/core/lib/promise/pipe.h +++ b/src/core/lib/promise/pipe.h @@ -25,7 +25,6 @@ #include #include -#include "absl/base/attributes.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/variant.h" @@ -39,7 +38,6 @@ #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/if.h" #include "src/core/lib/promise/interceptor_list.h" -#include "src/core/lib/promise/intra_activity_waiter.h" #include "src/core/lib/promise/map.h" #include "src/core/lib/promise/poll.h" #include "src/core/lib/promise/seq.h" @@ -160,9 +158,11 @@ class Center : public InterceptorList { case ValueState::kClosed: case ValueState::kReadyClosed: case ValueState::kCancelled: + case ValueState::kWaitingForAckAndClosed: return false; case ValueState::kReady: case ValueState::kAcked: + case ValueState::kWaitingForAck: return on_empty_.pending(); case ValueState::kEmpty: value_state_ = ValueState::kReady; @@ -180,11 +180,14 @@ class Center : public InterceptorList { GPR_DEBUG_ASSERT(refs_ != 0); switch (value_state_) { case ValueState::kClosed: - case ValueState::kReadyClosed: + return true; case ValueState::kCancelled: return false; case ValueState::kReady: + case ValueState::kReadyClosed: case ValueState::kEmpty: + case ValueState::kWaitingForAck: + case ValueState::kWaitingForAckAndClosed: return on_empty_.pending(); case ValueState::kAcked: value_state_ = ValueState::kEmpty; @@ -206,12 +209,14 @@ class Center : public InterceptorList { switch (value_state_) { case ValueState::kEmpty: case ValueState::kAcked: + case ValueState::kWaitingForAck: + case ValueState::kWaitingForAckAndClosed: return on_full_.pending(); case ValueState::kReadyClosed: - this->ResetInterceptorList(); - value_state_ = ValueState::kClosed; - ABSL_FALLTHROUGH_INTENDED; + value_state_ = ValueState::kWaitingForAckAndClosed; + return std::move(value_); case ValueState::kReady: + value_state_ = ValueState::kWaitingForAck; return std::move(value_); case ValueState::kClosed: case ValueState::kCancelled: @@ -220,18 +225,89 @@ class Center : public InterceptorList { GPR_UNREACHABLE_CODE(return absl::nullopt); } + // Check if the pipe is closed for sending (if there is a value still queued + // but the pipe is closed, reports closed). + Poll PollClosedForSender() { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%s", DebugOpString("PollClosedForSender").c_str()); + } + GPR_DEBUG_ASSERT(refs_ != 0); + switch (value_state_) { + case ValueState::kEmpty: + case ValueState::kAcked: + case ValueState::kReady: + case ValueState::kWaitingForAck: + return on_closed_.pending(); + case ValueState::kWaitingForAckAndClosed: + case ValueState::kReadyClosed: + case ValueState::kClosed: + return false; + case ValueState::kCancelled: + return true; + } + GPR_UNREACHABLE_CODE(return true); + } + + // Check if the pipe is closed for receiving (if there is a value still queued + // but the pipe is closed, reports open). + Poll PollClosedForReceiver() { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%s", DebugOpString("PollClosedForReceiver").c_str()); + } + GPR_DEBUG_ASSERT(refs_ != 0); + switch (value_state_) { + case ValueState::kEmpty: + case ValueState::kAcked: + case ValueState::kReady: + case ValueState::kReadyClosed: + case ValueState::kWaitingForAck: + case ValueState::kWaitingForAckAndClosed: + return on_closed_.pending(); + case ValueState::kClosed: + return false; + case ValueState::kCancelled: + return true; + } + GPR_UNREACHABLE_CODE(return true); + } + + Poll PollEmpty() { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_INFO, "%s", DebugOpString("PollEmpty").c_str()); + } + GPR_DEBUG_ASSERT(refs_ != 0); + switch (value_state_) { + case ValueState::kReady: + case ValueState::kReadyClosed: + return on_empty_.pending(); + case ValueState::kWaitingForAck: + case ValueState::kWaitingForAckAndClosed: + case ValueState::kAcked: + case ValueState::kEmpty: + case ValueState::kClosed: + case ValueState::kCancelled: + return Empty{}; + } + GPR_UNREACHABLE_CODE(return Empty{}); + } + void AckNext() { if (grpc_trace_promise_primitives.enabled()) { gpr_log(GPR_INFO, "%s", DebugOpString("AckNext").c_str()); } switch (value_state_) { case ValueState::kReady: + case ValueState::kWaitingForAck: value_state_ = ValueState::kAcked; on_empty_.Wake(); break; case ValueState::kReadyClosed: + case ValueState::kWaitingForAckAndClosed: this->ResetInterceptorList(); value_state_ = ValueState::kClosed; + on_closed_.Wake(); + on_empty_.Wake(); + on_full_.Wake(); break; case ValueState::kClosed: case ValueState::kCancelled: @@ -251,14 +327,22 @@ class Center : public InterceptorList { case ValueState::kAcked: this->ResetInterceptorList(); value_state_ = ValueState::kClosed; + on_empty_.Wake(); on_full_.Wake(); + on_closed_.Wake(); break; case ValueState::kReady: value_state_ = ValueState::kReadyClosed; + on_closed_.Wake(); + break; + case ValueState::kWaitingForAck: + value_state_ = ValueState::kWaitingForAckAndClosed; + on_closed_.Wake(); break; case ValueState::kReadyClosed: case ValueState::kClosed: case ValueState::kCancelled: + case ValueState::kWaitingForAckAndClosed: break; } } @@ -272,13 +356,15 @@ class Center : public InterceptorList { case ValueState::kAcked: case ValueState::kReady: case ValueState::kReadyClosed: + case ValueState::kWaitingForAck: + case ValueState::kWaitingForAckAndClosed: this->ResetInterceptorList(); value_state_ = ValueState::kCancelled; + on_empty_.Wake(); on_full_.Wake(); + on_closed_.Wake(); break; case ValueState::kClosed: - value_state_ = ValueState::kCancelled; - break; case ValueState::kCancelled: break; } @@ -305,6 +391,8 @@ class Center : public InterceptorList { kEmpty, // Value has been pushed but not acked, it's possible to receive. kReady, + // Value has been read and not acked, both send/receive blocked until ack. + kWaitingForAck, // Value has been received and acked, we can unblock senders and transition // to empty. kAcked, @@ -313,6 +401,9 @@ class Center : public InterceptorList { // Pipe is closed successfully, no more values can be sent // (but one value is queued and ready to be received) kReadyClosed, + // Pipe is closed successfully, no more values can be sent + // (but one value is queued and waiting to be acked) + kWaitingForAckAndClosed, // Pipe is closed unsuccessfully, no more values can be sent kCancelled, }; @@ -321,7 +412,8 @@ class Center : public InterceptorList { return absl::StrCat(DebugTag(), op, " refs=", refs_, " value_state=", ValueStateName(value_state_), " on_empty=", on_empty_.DebugString().c_str(), - " on_full=", on_full_.DebugString().c_str()); + " on_full=", on_full_.DebugString().c_str(), + " on_closed=", on_closed_.DebugString().c_str()); } static const char* ValueStateName(ValueState state) { @@ -336,6 +428,10 @@ class Center : public InterceptorList { return "Closed"; case ValueState::kReadyClosed: return "ReadyClosed"; + case ValueState::kWaitingForAck: + return "WaitingForAck"; + case ValueState::kWaitingForAckAndClosed: + return "WaitingForAckAndClosed"; case ValueState::kCancelled: return "Cancelled"; } @@ -349,6 +445,7 @@ class Center : public InterceptorList { ValueState value_state_; IntraActivityWaiter on_empty_; IntraActivityWaiter on_full_; + IntraActivityWaiter on_closed_; // Make failure to destruct show up in ASAN builds. #ifndef NDEBUG @@ -388,11 +485,25 @@ class PipeSender { // receiver is either closed or able to receive another message. PushType Push(T value); + // Return a promise that resolves when the receiver is closed. + // The resolved value is a bool - true if the pipe was cancelled, false if it + // was closed successfully. + // Checks closed from the senders perspective: that is, if there is a value in + // the pipe but the pipe is closed, reports closed. + auto AwaitClosed() { + return [center = center_]() { return center->PollClosedForSender(); }; + } + + // Interject PromiseFactory f into the pipeline. + // f will be called with the current value traversing the pipe, and should + // return a value to replace it with. + // Interjects at the Push end of the pipe. template void InterceptAndMap(Fn f, DebugLocation from = {}) { center_->PrependMap(std::move(f), from); } + // Per above, but calls cleanup_fn when the pipe is closed. template void InterceptAndMap(Fn f, OnHalfClose cleanup_fn, DebugLocation from = {}) { center_->PrependMapWithCleanup(std::move(f), std::move(cleanup_fn), from); @@ -409,6 +520,31 @@ class PipeSender { #endif }; +template +class PipeReceiver; + +namespace pipe_detail { + +// Implementation of PipeReceiver::Next promise. +template +class Next { + public: + Next(const Next&) = delete; + Next& operator=(const Next&) = delete; + Next(Next&& other) noexcept = default; + Next& operator=(Next&& other) noexcept = default; + + Poll> operator()() { return center_->Next(); } + + private: + friend class PipeReceiver; + explicit Next(RefCountedPtr> center) : center_(std::move(center)) {} + + RefCountedPtr> center_; +}; + +} // namespace pipe_detail + // Receive end of a Pipe. template class PipeReceiver { @@ -418,7 +554,7 @@ class PipeReceiver { PipeReceiver(PipeReceiver&& other) noexcept = default; PipeReceiver& operator=(PipeReceiver&& other) noexcept = default; ~PipeReceiver() { - if (center_ != nullptr) center_->MarkClosed(); + if (center_ != nullptr) center_->MarkCancelled(); } void Swap(PipeReceiver* other) { std::swap(center_, other->center_); } @@ -428,13 +564,55 @@ class PipeReceiver { // message was received, or no value if the other end of the pipe was closed. // Blocks the promise until the receiver is either closed or a message is // available. - auto Next(); + auto Next() { + return Seq( + pipe_detail::Next(center_->Ref()), + [center = center_->Ref()](absl::optional value) { + bool open = value.has_value(); + bool cancelled = center->cancelled(); + return If( + open, + [center = std::move(center), value = std::move(value)]() mutable { + auto run = center->Run(std::move(value)); + return Map(std::move(run), + [center = std::move(center)]( + absl::optional value) mutable { + if (value.has_value()) { + center->value() = std::move(*value); + return NextResult(std::move(center)); + } else { + center->MarkCancelled(); + return NextResult(true); + } + }); + }, + [cancelled]() { return NextResult(cancelled); }); + }); + } + // Return a promise that resolves when the receiver is closed. + // The resolved value is a bool - true if the pipe was cancelled, false if it + // was closed successfully. + // Checks closed from the receivers perspective: that is, if there is a value + // in the pipe but the pipe is closed, reports open until that value is read. + auto AwaitClosed() { + return [center = center_]() { return center->PollClosedForReceiver(); }; + } + + auto AwaitEmpty() { + return [center = center_]() { return center->PollEmpty(); }; + } + + // Interject PromiseFactory f into the pipeline. + // f will be called with the current value traversing the pipe, and should + // return a value to replace it with. + // Interjects at the Next end of the pipe. template void InterceptAndMap(Fn f, DebugLocation from = {}) { center_->AppendMap(std::move(f), from); } + // Per above, but calls cleanup_fn when the pipe is closed. template void InterceptAndMapWithHalfClose(Fn f, OnHalfClose cleanup_fn, DebugLocation from = {}) { @@ -459,12 +637,19 @@ template class Push { public: Push(const Push&) = delete; + Push& operator=(const Push&) = delete; Push(Push&& other) noexcept = default; Push& operator=(Push&& other) noexcept = default; Poll operator()() { - if (center_ == nullptr) return false; + if (center_ == nullptr) { + if (grpc_trace_promise_primitives.enabled()) { + gpr_log(GPR_DEBUG, "%s Pipe push has a null center", + Activity::current()->DebugTag().c_str()); + } + return false; + } if (auto* p = absl::get_if(&state_)) { auto r = center_->Push(p); if (auto* ok = r.value_if_ready()) { @@ -489,24 +674,6 @@ class Push { absl::variant state_; }; -// Implementation of PipeReceiver::Next promise. -template -class Next { - public: - Next(const Next&) = delete; - Next& operator=(const Next&) = delete; - Next(Next&& other) noexcept = default; - Next& operator=(Next&& other) noexcept = default; - - Poll> operator()() { return center_->Next(); } - - private: - friend class PipeReceiver; - explicit Next(RefCountedPtr> center) : center_(std::move(center)) {} - - RefCountedPtr> center_; -}; - } // namespace pipe_detail template @@ -515,33 +682,6 @@ pipe_detail::Push PipeSender::Push(T value) { std::move(value)); } -template -auto PipeReceiver::Next() { - return Seq( - pipe_detail::Next(center_->Ref()), - [center = center_->Ref()](absl::optional value) { - bool open = value.has_value(); - bool cancelled = center->cancelled(); - return If( - open, - [center = std::move(center), value = std::move(value)]() mutable { - auto run_interceptors = center->Run(std::move(value)); - return Map(std::move(run_interceptors), - [center = std::move(center)]( - absl::optional value) mutable { - if (value.has_value()) { - center->value() = std::move(*value); - return NextResult(std::move(center)); - } else { - center->MarkCancelled(); - return NextResult(true); - } - }); - }, - [cancelled]() { return NextResult(cancelled); }); - }); -} - template using PipeReceiverNextType = decltype(std::declval>().Next()); diff --git a/src/core/lib/promise/promise.h b/src/core/lib/promise/promise.h index d5683bd93b4..5da762f99eb 100644 --- a/src/core/lib/promise/promise.h +++ b/src/core/lib/promise/promise.h @@ -17,10 +17,10 @@ #include -#include #include #include +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/types/optional.h" @@ -33,7 +33,7 @@ namespace grpc_core { // Most of the time we just pass around the functor, but occasionally // it pays to have a type erased variant, which we define here. template -using Promise = std::function()>; +using Promise = absl::AnyInvocable()>; // Helper to execute a promise immediately and return either the result or // nothing. diff --git a/src/core/lib/resource_quota/arena.cc b/src/core/lib/resource_quota/arena.cc index 8811f8848e2..d37ef426490 100644 --- a/src/core/lib/resource_quota/arena.cc +++ b/src/core/lib/resource_quota/arena.cc @@ -54,6 +54,9 @@ Arena::~Arena() { gpr_free_aligned(z); z = prev_z; } +#ifdef GRPC_ARENA_TRACE_POOLED_ALLOCATIONS + gpr_log(GPR_ERROR, "DESTRUCT_ARENA %p", this); +#endif } Arena* Arena::Create(size_t initial_size, MemoryAllocator* memory_allocator) { @@ -71,7 +74,7 @@ std::pair Arena::CreateWithAlloc( return std::make_pair(new_arena, first_alloc); } -void Arena::Destroy() { +void Arena::DestroyManagedNewObjects() { ManagedNewObject* p; // Outer loop: clear the managed new object list. // We do this repeatedly in case a destructor ends up allocating something. @@ -82,6 +85,10 @@ void Arena::Destroy() { Destruct(std::exchange(p, p->next)); } } +} + +void Arena::Destroy() { + DestroyManagedNewObjects(); memory_allocator_->Release(total_allocated_.load(std::memory_order_relaxed)); this->~Arena(); gpr_free_aligned(this); @@ -114,7 +121,8 @@ void Arena::ManagedNewObject::Link(std::atomic* head) { } } -void* Arena::AllocPooled(size_t alloc_size, std::atomic* head) { +void* Arena::AllocPooled(size_t obj_size, size_t alloc_size, + std::atomic* head) { // ABA mitigation: // AllocPooled may be called by multiple threads, and to remove a node from // the free list we need to manipulate the next pointer, which may be done @@ -132,7 +140,11 @@ void* Arena::AllocPooled(size_t alloc_size, std::atomic* head) { FreePoolNode* p = head->exchange(nullptr, std::memory_order_acquire); // If there are no nodes in the free list, then go ahead and allocate from the // arena. - if (p == nullptr) return Alloc(alloc_size); + if (p == nullptr) { + void* r = Alloc(alloc_size); + TracePoolAlloc(obj_size, r); + return r; + } // We had a non-empty free list... but we own the *entire* free list. // We only want one node, so if there are extras we'd better give them back. if (p->next != nullptr) { @@ -151,10 +163,14 @@ void* Arena::AllocPooled(size_t alloc_size, std::atomic* head) { extra = next; } } + TracePoolAlloc(obj_size, p); return p; } void Arena::FreePooled(void* p, std::atomic* head) { + // May spuriously trace a free of an already freed object - see AllocPooled + // ABA mitigation. + TracePoolFree(p); FreePoolNode* node = static_cast(p); node->next = head->load(std::memory_order_acquire); while (!head->compare_exchange_weak( diff --git a/src/core/lib/resource_quota/arena.h b/src/core/lib/resource_quota/arena.h index b47985188b6..1dcb530243e 100644 --- a/src/core/lib/resource_quota/arena.h +++ b/src/core/lib/resource_quota/arena.h @@ -45,6 +45,9 @@ #include "src/core/lib/promise/context.h" #include "src/core/lib/resource_quota/memory_quota.h" +// #define GRPC_ARENA_POOLED_ALLOCATIONS_USE_MALLOC +// #define GRPC_ARENA_TRACE_POOLED_ALLOCATIONS + namespace grpc_core { namespace arena_detail { @@ -114,7 +117,9 @@ PoolAndSize ChoosePoolForAllocationSize( } // namespace arena_detail class Arena { - using PoolSizes = absl::integer_sequence; + // Selected pool sizes. + // How to tune: see tools/codegen/core/optimize_arena_pool_sizes.py + using PoolSizes = absl::integer_sequence; struct FreePoolNode { FreePoolNode* next; }; @@ -130,6 +135,13 @@ class Arena { size_t initial_size, size_t alloc_size, MemoryAllocator* memory_allocator); + // Destroy all `ManagedNew` allocated objects. + // Allows safe destruction of these objects even if they need context held by + // the arena. + // Idempotent. + // TODO(ctiller): eliminate ManagedNew. + void DestroyManagedNewObjects(); + // Destroy an arena. void Destroy(); @@ -170,6 +182,7 @@ class Arena { return &p->t; } +#ifndef GRPC_ARENA_POOLED_ALLOCATIONS_USE_MALLOC class PooledDeleter { public: explicit PooledDeleter(std::atomic* free_list) @@ -209,6 +222,7 @@ class Arena { &pools_[arena_detail::PoolFromObjectSize(PoolSizes())]; return PoolPtr( new (AllocPooled( + sizeof(T), arena_detail::AllocationSizeFromObjectSize(PoolSizes()), free_list)) T(std::forward(args)...), PooledDeleter(free_list)); @@ -229,12 +243,95 @@ class Arena { return PoolPtr(new (Alloc(where.alloc_size)) T[n], PooledDeleter(nullptr)); } else { - return PoolPtr( - new (AllocPooled(where.alloc_size, &pools_[where.pool_index])) T[n], - PooledDeleter(&pools_[where.pool_index])); + return PoolPtr(new (AllocPooled(where.alloc_size, where.alloc_size, + &pools_[where.pool_index])) T[n], + PooledDeleter(&pools_[where.pool_index])); + } + } + + // Like MakePooled, but with manual memory management. + // The caller is responsible for calling DeletePooled() on the returned + // pointer, and expected to call it with the same type T as was passed to this + // function (else the free list returned to the arena will be corrupted). + template + T* NewPooled(Args&&... args) { + auto* free_list = + &pools_[arena_detail::PoolFromObjectSize(PoolSizes())]; + return new (AllocPooled( + sizeof(T), + arena_detail::AllocationSizeFromObjectSize(PoolSizes()), + free_list)) T(std::forward(args)...); + } + + template + void DeletePooled(T* p) { + auto* free_list = + &pools_[arena_detail::PoolFromObjectSize(PoolSizes())]; + p->~T(); + FreePooled(p, free_list); + } +#else + class PooledDeleter { + public: + PooledDeleter() = default; + explicit PooledDeleter(std::nullptr_t) : delete_(false) {} + template + void operator()(T* p) { + // TODO(ctiller): promise based filter hijacks ownership of some pointers + // to make them appear as PoolPtr without really transferring ownership, + // by setting the arena to nullptr. + // This is a transitional hack and should be removed once promise based + // filter is removed. + if (delete_) delete p; } + + bool has_freelist() const { return delete_; } + + private: + bool delete_ = true; + }; + + template + using PoolPtr = std::unique_ptr; + + // Make a unique_ptr to T that is allocated from the arena. + // When the pointer is released, the memory may be reused for other + // MakePooled(.*) calls. + // CAUTION: The amount of memory allocated is rounded up to the nearest + // value in Arena::PoolSizes, and so this may pessimize total + // arena size. + template + PoolPtr MakePooled(Args&&... args) { + return PoolPtr(new T(std::forward(args)...), PooledDeleter()); + } + + // Make a unique_ptr to an array of T that is allocated from the arena. + // When the pointer is released, the memory may be reused for other + // MakePooled(.*) calls. + // One can use MakePooledArray to allocate a buffer of bytes. + // CAUTION: The amount of memory allocated is rounded up to the nearest + // value in Arena::PoolSizes, and so this may pessimize total + // arena size. + template + PoolPtr MakePooledArray(size_t n) { + return PoolPtr(new T[n], PooledDeleter()); } + // Like MakePooled, but with manual memory management. + // The caller is responsible for calling DeletePooled() on the returned + // pointer, and expected to call it with the same type T as was passed to this + // function (else the free list returned to the arena will be corrupted). + template + T* NewPooled(Args&&... args) { + return new T(std::forward(args)...); + } + + template + void DeletePooled(T* p) { + delete p; + } +#endif + private: struct Zone { Zone* prev; @@ -275,9 +372,24 @@ class Arena { void* AllocZone(size_t size); - void* AllocPooled(size_t alloc_size, std::atomic* head); + void* AllocPooled(size_t obj_size, size_t alloc_size, + std::atomic* head); static void FreePooled(void* p, std::atomic* head); + void TracePoolAlloc(size_t size, void* ptr) { + (void)size; + (void)ptr; +#ifdef GRPC_ARENA_TRACE_POOLED_ALLOCATIONS + gpr_log(GPR_ERROR, "ARENA %p ALLOC %" PRIdPTR " @ %p", this, size, ptr); +#endif + } + static void TracePoolFree(void* ptr) { + (void)ptr; +#ifdef GRPC_ARENA_TRACE_POOLED_ALLOCATIONS + gpr_log(GPR_ERROR, "FREE %p", ptr); +#endif + } + // Keep track of the total used size. We use this in our call sizing // hysteresis. std::atomic total_used_{0}; @@ -290,7 +402,9 @@ class Arena { // last zone; the zone list is reverse-walked during arena destruction only. std::atomic last_zone_{nullptr}; std::atomic managed_new_head_{nullptr}; +#ifndef GRPC_ARENA_POOLED_ALLOCATIONS_USE_MALLOC std::atomic pools_[PoolSizes::size()]{}; +#endif // The backing memory quota MemoryAllocator* const memory_allocator_; }; diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index f5938e99e6c..c67b97388ba 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -18,12 +18,12 @@ #include -#include - #include #include +#include #include #include +#include #include #include "absl/status/status.h" @@ -41,6 +41,7 @@ #include "src/core/lib/channel/channel_stack.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/channel/promise_based_filter.h" +#include "src/core/lib/debug/trace.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/status_helper.h" @@ -57,6 +58,7 @@ #include "src/core/lib/security/transport/auth_filters.h" // IWYU pragma: keep #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/surface/call_trace.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" @@ -120,12 +122,28 @@ class ServerAuthFilter::RunApplicationCode { // memory later RunApplicationCode(ServerAuthFilter* filter, CallArgs call_args) : state_(GetContext()->ManagedNew(std::move(call_args))) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_ERROR, + "%s[server-auth]: Delegate to application: filter=%p this=%p " + "auth_ctx=%p", + Activity::current()->DebugTag().c_str(), filter, this, + filter->auth_context_.get()); + } filter->server_credentials_->auth_metadata_processor().process( filter->server_credentials_->auth_metadata_processor().state, filter->auth_context_.get(), state_->md.metadata, state_->md.count, OnMdProcessingDone, state_); } + RunApplicationCode(const RunApplicationCode&) = delete; + RunApplicationCode& operator=(const RunApplicationCode&) = delete; + RunApplicationCode(RunApplicationCode&& other) noexcept + : state_(std::exchange(other.state_, nullptr)) {} + RunApplicationCode& operator=(RunApplicationCode&& other) noexcept { + state_ = std::exchange(other.state_, nullptr); + return *this; + } + Poll> operator()() { if (state_->done.load(std::memory_order_acquire)) { return Poll>(std::move(state_->call_args)); diff --git a/src/core/lib/slice/slice.cc b/src/core/lib/slice/slice.cc index 51ee3a83644..6180ef10e56 100644 --- a/src/core/lib/slice/slice.cc +++ b/src/core/lib/slice/slice.cc @@ -480,7 +480,7 @@ int grpc_slice_slice(grpc_slice haystack, grpc_slice needle) { } const uint8_t* last = haystack_bytes + haystack_len - needle_len; - for (const uint8_t* cur = haystack_bytes; cur != last; ++cur) { + for (const uint8_t* cur = haystack_bytes; cur <= last; ++cur) { if (0 == memcmp(cur, needle_bytes, needle_len)) { return static_cast(cur - haystack_bytes); } diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index f2730770915..3832b61c845 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -36,14 +36,12 @@ #include #include "absl/base/thread_annotations.h" -#include "absl/cleanup/cleanup.h" #include "absl/meta/type_traits.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "absl/types/variant.h" #include #include @@ -88,8 +86,13 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/basic_seq.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/party.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/promise/poll.h" +#include "src/core/lib/promise/race.h" +#include "src/core/lib/promise/seq.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" #include "src/core/lib/slice/slice_internal.h" @@ -99,6 +102,7 @@ #include "src/core/lib/surface/completion_queue.h" #include "src/core/lib/surface/server.h" #include "src/core/lib/surface/validate_metadata.h" +#include "src/core/lib/transport/batch_builder.h" #include "src/core/lib/transport/error_utils.h" #include "src/core/lib/transport/metadata_batch.h" #include "src/core/lib/transport/transport.h" @@ -137,11 +141,13 @@ class Call : public CppImplOf { virtual void InternalRef(const char* reason) = 0; virtual void InternalUnref(const char* reason) = 0; - virtual grpc_compression_algorithm test_only_compression_algorithm() = 0; - virtual uint32_t test_only_message_flags() = 0; - virtual uint32_t test_only_encodings_accepted_by_peer() = 0; - virtual grpc_compression_algorithm compression_for_level( - grpc_compression_level level) = 0; + grpc_compression_algorithm test_only_compression_algorithm() { + return incoming_compression_algorithm_; + } + uint32_t test_only_message_flags() { return test_only_last_message_flags_; } + CompressionAlgorithmSet encodings_accepted_by_peer() { + return encodings_accepted_by_peer_; + } // This should return nullptr for the promise stack (and alternative means // for that functionality be invented) @@ -216,6 +222,26 @@ class Call : public CppImplOf { void ClearPeerString() { SetPeerString(Slice(grpc_empty_slice())); } + // TODO(ctiller): cancel_func is for cancellation of the call - filter stack + // holds no mutexes here, promise stack does, and so locking is different. + // Remove this and cancel directly once promise conversion is done. + void ProcessIncomingInitialMetadata(grpc_metadata_batch& md); + // Fixup outgoing metadata before sending - adds compression, protects + // internal headers against external modification. + void PrepareOutgoingInitialMetadata(const grpc_op& op, + grpc_metadata_batch& md); + void NoteLastMessageFlags(uint32_t flags) { + test_only_last_message_flags_ = flags; + } + grpc_compression_algorithm incoming_compression_algorithm() const { + return incoming_compression_algorithm_; + } + + void HandleCompressionAlgorithmDisabled( + grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; + void HandleCompressionAlgorithmNotAccepted( + grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; + private: RefCountedPtr channel_; Arena* const arena_; @@ -225,11 +251,18 @@ class Call : public CppImplOf { const bool is_client_; // flag indicating that cancellation is inherited bool cancellation_is_inherited_ = false; + // Compression algorithm for *incoming* data + grpc_compression_algorithm incoming_compression_algorithm_ = + GRPC_COMPRESS_NONE; + // Supported encodings (compression algorithms), a bitset. + // Always support no compression. + CompressionAlgorithmSet encodings_accepted_by_peer_{GRPC_COMPRESS_NONE}; + uint32_t test_only_last_message_flags_ = 0; // Peer name is protected by a mutex because it can be accessed by the // application at the same moment as it is being set by the completion // of the recv_initial_metadata op. The mutex should be mostly uncontended. mutable Mutex peer_mu_; - Slice peer_string_ ABSL_GUARDED_BY(&peer_mu_); + Slice peer_string_; }; Call::ParentCall* Call::GetOrCreateParentCall() { @@ -324,9 +357,13 @@ void Call::MaybeUnpublishFromParent() { void Call::CancelWithStatus(grpc_status_code status, const char* description) { // copying 'description' is needed to ensure the grpc_call_cancel_with_status // guarantee that can be short-lived. + // TODO(ctiller): change to + // absl::Status(static_cast(status), description) + // (ie remove the set_int, set_str). CancelWithError(grpc_error_set_int( - grpc_error_set_str(GRPC_ERROR_CREATE(description), - StatusStrProperty::kGrpcMessage, description), + grpc_error_set_str( + absl::Status(static_cast(status), description), + StatusStrProperty::kGrpcMessage, description), StatusIntProperty::kRpcStatus, status)); } @@ -373,6 +410,92 @@ void Call::DeleteThis() { arena->Destroy(); } +void Call::PrepareOutgoingInitialMetadata(const grpc_op& op, + grpc_metadata_batch& md) { + // TODO(juanlishen): If the user has already specified a compression + // algorithm by setting the initial metadata with key of + // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that + // with the compression algorithm mapped from compression level. + // process compression level + grpc_compression_level effective_compression_level = GRPC_COMPRESS_LEVEL_NONE; + bool level_set = false; + if (op.data.send_initial_metadata.maybe_compression_level.is_set) { + effective_compression_level = + op.data.send_initial_metadata.maybe_compression_level.level; + level_set = true; + } else { + const grpc_compression_options copts = channel()->compression_options(); + if (copts.default_level.is_set) { + level_set = true; + effective_compression_level = copts.default_level.level; + } + } + // Currently, only server side supports compression level setting. + if (level_set && !is_client()) { + const grpc_compression_algorithm calgo = + encodings_accepted_by_peer().CompressionAlgorithmForLevel( + effective_compression_level); + // The following metadata will be checked and removed by the message + // compression filter. It will be used as the call's compression + // algorithm. + md.Set(GrpcInternalEncodingRequest(), calgo); + } + // Ignore any te metadata key value pairs specified. + md.Remove(TeMetadata()); +} + +void Call::ProcessIncomingInitialMetadata(grpc_metadata_batch& md) { + Slice* peer_string = md.get_pointer(PeerString()); + if (peer_string != nullptr) SetPeerString(peer_string->Ref()); + + incoming_compression_algorithm_ = + md.Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); + encodings_accepted_by_peer_ = + md.Take(GrpcAcceptEncodingMetadata()) + .value_or(CompressionAlgorithmSet{GRPC_COMPRESS_NONE}); + + const grpc_compression_options compression_options = + channel_->compression_options(); + const grpc_compression_algorithm compression_algorithm = + incoming_compression_algorithm_; + if (GPR_UNLIKELY(!CompressionAlgorithmSet::FromUint32( + compression_options.enabled_algorithms_bitset) + .IsSet(compression_algorithm))) { + // check if algorithm is supported by current channel config + HandleCompressionAlgorithmDisabled(compression_algorithm); + } + // GRPC_COMPRESS_NONE is always set. + GPR_DEBUG_ASSERT(encodings_accepted_by_peer_.IsSet(GRPC_COMPRESS_NONE)); + if (GPR_UNLIKELY(!encodings_accepted_by_peer_.IsSet(compression_algorithm))) { + if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { + HandleCompressionAlgorithmNotAccepted(compression_algorithm); + } + } +} + +void Call::HandleCompressionAlgorithmNotAccepted( + grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + gpr_log(GPR_ERROR, + "Compression algorithm ('%s') not present in the " + "accepted encodings (%s)", + algo_name, + std::string(encodings_accepted_by_peer_.ToString()).c_str()); +} + +void Call::HandleCompressionAlgorithmDisabled( + grpc_compression_algorithm compression_algorithm) { + const char* algo_name = nullptr; + grpc_compression_algorithm_name(compression_algorithm, &algo_name); + std::string error_msg = + absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); + gpr_log(GPR_ERROR, "%s", error_msg.c_str()); + CancelWithError(grpc_error_set_int(absl::UnimplementedError(error_msg), + StatusIntProperty::kRpcStatus, + GRPC_STATUS_UNIMPLEMENTED)); +} + /////////////////////////////////////////////////////////////////////////////// // FilterStackCall // To be removed once promise conversion is complete @@ -431,11 +554,6 @@ class FilterStackCall final : public Call { return context_[elem].value; } - grpc_compression_algorithm compression_for_level( - grpc_compression_level level) override { - return encodings_accepted_by_peer_.CompressionAlgorithmForLevel(level); - } - bool is_trailers_only() const override { bool result = is_trailers_only_; GPR_DEBUG_ASSERT(!result || recv_initial_metadata_.TransportSize() == 0); @@ -453,18 +571,6 @@ class FilterStackCall final : public Call { return authority_metadata->as_string_view(); } - grpc_compression_algorithm test_only_compression_algorithm() override { - return incoming_compression_algorithm_; - } - - uint32_t test_only_message_flags() override { - return test_only_last_message_flags_; - } - - uint32_t test_only_encodings_accepted_by_peer() override { - return encodings_accepted_by_peer_.ToLegacyBitmask(); - } - static size_t InitialSizeEstimate() { return sizeof(FilterStackCall) + sizeof(BatchControl) * kMaxConcurrentBatches; @@ -565,7 +671,6 @@ class FilterStackCall final : public Call { void FinishStep(PendingOp op); void ProcessDataAfterMetadata(); void ReceivingStreamReady(grpc_error_handle error); - void ValidateFilteredMetadata(); void ReceivingInitialMetadataReady(grpc_error_handle error); void ReceivingTrailingMetadataReady(grpc_error_handle error); void FinishBatch(grpc_error_handle error); @@ -590,10 +695,6 @@ class FilterStackCall final : public Call { grpc_closure* start_batch_closure); void SetFinalStatus(grpc_error_handle error); BatchControl* ReuseOrAllocateBatchControl(const grpc_op* ops); - void HandleCompressionAlgorithmDisabled( - grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; - void HandleCompressionAlgorithmNotAccepted( - grpc_compression_algorithm compression_algorithm) GPR_ATTRIBUTE_NOINLINE; bool PrepareApplicationMetadata(size_t count, grpc_metadata* metadata, bool is_trailing); void PublishAppMetadata(grpc_metadata_batch* b, bool is_trailing); @@ -637,13 +738,6 @@ class FilterStackCall final : public Call { // completed grpc_call_final_info final_info_; - // Compression algorithm for *incoming* data - grpc_compression_algorithm incoming_compression_algorithm_ = - GRPC_COMPRESS_NONE; - // Supported encodings (compression algorithms), a bitset. - // Always support no compression. - CompressionAlgorithmSet encodings_accepted_by_peer_{GRPC_COMPRESS_NONE}; - // Contexts for various subsystems (security, tracing, ...). grpc_call_context_element context_[GRPC_CONTEXT_COUNT] = {}; @@ -657,7 +751,6 @@ class FilterStackCall final : public Call { grpc_closure receiving_stream_ready_; grpc_closure receiving_initial_metadata_ready_; grpc_closure receiving_trailing_metadata_ready_; - uint32_t test_only_last_message_flags_ = 0; // Status about operation of call bool sent_server_trailing_metadata_ = false; gpr_atm cancelled_with_error_ = 0; @@ -1094,11 +1187,7 @@ void FilterStackCall::PublishAppMetadata(grpc_metadata_batch* b, } void FilterStackCall::RecvInitialFilter(grpc_metadata_batch* b) { - incoming_compression_algorithm_ = - b->Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); - encodings_accepted_by_peer_ = - b->Take(GrpcAcceptEncodingMetadata()) - .value_or(CompressionAlgorithmSet{GRPC_COMPRESS_NONE}); + ProcessIncomingInitialMetadata(*b); PublishAppMetadata(b, false); } @@ -1267,11 +1356,11 @@ void FilterStackCall::BatchControl::ProcessDataAfterMetadata() { call->receiving_message_ = false; FinishStep(PendingOp::kRecvMessage); } else { - call->test_only_last_message_flags_ = call->receiving_stream_flags_; + call->NoteLastMessageFlags(call->receiving_stream_flags_); if ((call->receiving_stream_flags_ & GRPC_WRITE_INTERNAL_COMPRESS) && - (call->incoming_compression_algorithm_ != GRPC_COMPRESS_NONE)) { + (call->incoming_compression_algorithm() != GRPC_COMPRESS_NONE)) { *call->receiving_buffer_ = grpc_raw_compressed_byte_buffer_create( - nullptr, 0, call->incoming_compression_algorithm_); + nullptr, 0, call->incoming_compression_algorithm()); } else { *call->receiving_buffer_ = grpc_raw_byte_buffer_create(nullptr, 0); } @@ -1312,50 +1401,6 @@ void FilterStackCall::BatchControl::ReceivingStreamReady( } } -void FilterStackCall::HandleCompressionAlgorithmDisabled( - grpc_compression_algorithm compression_algorithm) { - const char* algo_name = nullptr; - grpc_compression_algorithm_name(compression_algorithm, &algo_name); - std::string error_msg = - absl::StrFormat("Compression algorithm '%s' is disabled.", algo_name); - gpr_log(GPR_ERROR, "%s", error_msg.c_str()); - CancelWithStatus(GRPC_STATUS_UNIMPLEMENTED, error_msg.c_str()); -} - -void FilterStackCall::HandleCompressionAlgorithmNotAccepted( - grpc_compression_algorithm compression_algorithm) { - const char* algo_name = nullptr; - grpc_compression_algorithm_name(compression_algorithm, &algo_name); - gpr_log(GPR_ERROR, - "Compression algorithm ('%s') not present in the " - "accepted encodings (%s)", - algo_name, - std::string(encodings_accepted_by_peer_.ToString()).c_str()); -} - -void FilterStackCall::BatchControl::ValidateFilteredMetadata() { - FilterStackCall* call = call_; - - const grpc_compression_options compression_options = - call->channel()->compression_options(); - const grpc_compression_algorithm compression_algorithm = - call->incoming_compression_algorithm_; - if (GPR_UNLIKELY(!CompressionAlgorithmSet::FromUint32( - compression_options.enabled_algorithms_bitset) - .IsSet(compression_algorithm))) { - // check if algorithm is supported by current channel config - call->HandleCompressionAlgorithmDisabled(compression_algorithm); - } - // GRPC_COMPRESS_NONE is always set. - GPR_DEBUG_ASSERT(call->encodings_accepted_by_peer_.IsSet(GRPC_COMPRESS_NONE)); - if (GPR_UNLIKELY( - !call->encodings_accepted_by_peer_.IsSet(compression_algorithm))) { - if (GRPC_TRACE_FLAG_ENABLED(grpc_compression_trace)) { - call->HandleCompressionAlgorithmNotAccepted(compression_algorithm); - } - } -} - void FilterStackCall::BatchControl::ReceivingInitialMetadataReady( grpc_error_handle error) { FilterStackCall* call = call_; @@ -1366,12 +1411,6 @@ void FilterStackCall::BatchControl::ReceivingInitialMetadataReady( grpc_metadata_batch* md = &call->recv_initial_metadata_; call->RecvInitialFilter(md); - // TODO(ctiller): this could be moved into recv_initial_filter now - ValidateFilteredMetadata(); - - Slice* peer_string = md->get_pointer(PeerString()); - if (peer_string != nullptr) call->SetPeerString(peer_string->Ref()); - absl::optional deadline = md->get(GrpcTimeoutMetadata()); if (deadline.has_value() && !call->is_client()) { call_->set_send_deadline(*deadline); @@ -1521,36 +1560,6 @@ grpc_call_error FilterStackCall::StartBatch(const grpc_op* ops, size_t nops, error = GRPC_CALL_ERROR_TOO_MANY_OPERATIONS; goto done_with_error; } - // TODO(juanlishen): If the user has already specified a compression - // algorithm by setting the initial metadata with key of - // GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, we shouldn't override that - // with the compression algorithm mapped from compression level. - // process compression level - grpc_compression_level effective_compression_level = - GRPC_COMPRESS_LEVEL_NONE; - bool level_set = false; - if (op->data.send_initial_metadata.maybe_compression_level.is_set) { - effective_compression_level = - op->data.send_initial_metadata.maybe_compression_level.level; - level_set = true; - } else { - const grpc_compression_options copts = - channel()->compression_options(); - if (copts.default_level.is_set) { - level_set = true; - effective_compression_level = copts.default_level.level; - } - } - // Currently, only server side supports compression level setting. - if (level_set && !is_client()) { - const grpc_compression_algorithm calgo = - encodings_accepted_by_peer_.CompressionAlgorithmForLevel( - effective_compression_level); - // The following metadata will be checked and removed by the message - // compression filter. It will be used as the call's compression - // algorithm. - send_initial_metadata_.Set(GrpcInternalEncodingRequest(), calgo); - } if (op->data.send_initial_metadata.count > INT_MAX) { error = GRPC_CALL_ERROR_INVALID_METADATA; goto done_with_error; @@ -1563,8 +1572,7 @@ grpc_call_error FilterStackCall::StartBatch(const grpc_op* ops, size_t nops, error = GRPC_CALL_ERROR_INVALID_METADATA; goto done_with_error; } - // Ignore any te metadata key value pairs specified. - send_initial_metadata_.Remove(TeMetadata()); + PrepareOutgoingInitialMetadata(*op, send_initial_metadata_); // TODO(ctiller): just make these the same variable? if (is_client() && send_deadline() != Timestamp::InfFuture()) { send_initial_metadata_.Set(GrpcTimeoutMetadata(), send_deadline()); @@ -1941,8 +1949,7 @@ bool ValidateMetadata(size_t count, grpc_metadata* metadata) { // Will be folded into Call once the promise conversion is done class PromiseBasedCall : public Call, - public Activity, - public Wakeable, + public Party, public grpc_event_engine::experimental::EventEngine:: Closure /* for deadlines */ { public: @@ -1953,176 +1960,54 @@ class PromiseBasedCall : public Call, void (*destroy)(void* value)) override; void* ContextGet(grpc_context_index elem) const override; void SetCompletionQueue(grpc_completion_queue* cq) override; - void SetCompletionQueueLocked(grpc_completion_queue* cq) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void CancelWithError(absl::Status error) final ABSL_LOCKS_EXCLUDED(mu_) { - MutexLock lock(&mu_); - CancelWithErrorLocked(std::move(error)); - } - virtual void CancelWithErrorLocked(absl::Status error) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; - bool Completed() final ABSL_LOCKS_EXCLUDED(mu_) { - MutexLock lock(&mu_); - return completed_; - } - - void Orphan() final { - MutexLock lock(&mu_); - if (!completed_) { - CancelWithErrorLocked(absl::CancelledError("Call orphaned")); - } - } + bool Completed() final { return finished_.IsSet(); } // Implementation of call refcounting: move this to DualRefCounted once we // don't need to maintain FilterStackCall compatibility - void ExternalRef() final { - const uint64_t prev_ref_pair = - refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed); - if (grpc_call_refcount_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s EXTERNAL_REF: %d:%d->%d:%d", DebugTag().c_str(), - GetStrongRefs(prev_ref_pair), GetWeakRefs(prev_ref_pair), - GetStrongRefs(prev_ref_pair) + 1, GetWeakRefs(prev_ref_pair)); - } - } - void ExternalUnref() final { - const uint64_t prev_ref_pair = - refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel); - if (grpc_call_refcount_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s EXTERNAL_UNREF: %d:%d->%d:%d", DebugTag().c_str(), - GetStrongRefs(prev_ref_pair), GetWeakRefs(prev_ref_pair), - GetStrongRefs(prev_ref_pair) - 1, GetWeakRefs(prev_ref_pair) + 1); - } - const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); - if (GPR_UNLIKELY(strong_refs == 1)) { - Orphan(); - } - // Now drop the weak ref. - InternalUnref("external_ref"); - } + void ExternalRef() final { InternalRef("external"); } + void ExternalUnref() final { InternalUnref("external"); } void InternalRef(const char* reason) final { - uint64_t n = refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed); if (grpc_call_refcount_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s REF: %s %d:%d->%d:%d", DebugTag().c_str(), reason, - GetStrongRefs(n), GetWeakRefs(n), GetStrongRefs(n), - GetWeakRefs(n) + 1); + gpr_log(GPR_DEBUG, "INTERNAL_REF:%p:%s", this, reason); } + Party::IncrementRefCount(); } void InternalUnref(const char* reason) final { - const uint64_t prev_ref_pair = - refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel); if (grpc_call_refcount_trace.enabled()) { - gpr_log(GPR_DEBUG, "%s UNREF: %s %d:%d->%d:%d", DebugTag().c_str(), - reason, GetStrongRefs(prev_ref_pair), GetWeakRefs(prev_ref_pair), - GetStrongRefs(prev_ref_pair), GetWeakRefs(prev_ref_pair) - 1); - } - if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) { - DeleteThis(); + gpr_log(GPR_DEBUG, "INTERNAL_UNREF:%p:%s", this, reason); } + Party::Unref(); } - // Activity methods - void ForceImmediateRepoll() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) override; - Waker MakeOwningWaker() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { - InternalRef("wakeup"); -// If ASAN is defined, we leverage it to detect dropped Waker objects. -// Usually Waker must be destroyed or woken up, but (especially with arenas) -// it's not uncommon to create a Waker and then do neither. In that case it's -// incredibly fraught to diagnose where the dropped reference to this object was -// created. Instead, leverage ASAN and create a new object per expected wakeup. -// Now when we drop such an object ASAN will fail and we'll get a callstack to -// the creation of the waker in question. -#if defined(__has_feature) -#if __has_feature(address_sanitizer) -#define GRPC_CALL_USES_ASAN_WAKER - class AsanWaker final : public Wakeable { - public: - explicit AsanWaker(PromiseBasedCall* call) : call_(call) {} - - void Wakeup(void*) override { - call_->Wakeup(nullptr); - delete this; - } - - void Drop(void*) override { - call_->Drop(nullptr); - delete this; - } - - std::string ActivityDebugTag(void*) const override { - return call_->DebugTag(); - } - - private: - PromiseBasedCall* call_; - }; - return Waker(new AsanWaker(this), nullptr); -#endif -#endif -#ifndef GRPC_CALL_USES_ASAN_WAKER - return Waker(this, nullptr); -#endif - } - Waker MakeNonOwningWaker() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) override; - - // Wakeable methods - void Wakeup(void*) override { - channel()->event_engine()->Run([this] { - ApplicationCallbackExecCtx app_exec_ctx; - ExecCtx exec_ctx; - { - ScopedContext activity_context(this); - MutexLock lock(&mu_); - Update(); - } - InternalUnref("wakeup"); - }); - } - void Drop(void*) override { InternalUnref("wakeup"); } - void RunInContext(absl::AnyInvocable fn) { - if (Activity::current() == this) { - fn(); - } else { - InternalRef("in_context"); - channel()->event_engine()->Run([this, fn = std::move(fn)]() mutable { - ApplicationCallbackExecCtx app_exec_ctx; - ExecCtx exec_ctx; - { - ScopedContext activity_context(this); - MutexLock lock(&mu_); + Spawn( + "run_in_context", + [fn = std::move(fn)]() mutable { fn(); - Update(); - } - InternalUnref("in_context"); - }); - } - } - - grpc_compression_algorithm test_only_compression_algorithm() override { - abort(); - } - uint32_t test_only_message_flags() override { abort(); } - uint32_t test_only_encodings_accepted_by_peer() override { abort(); } - grpc_compression_algorithm compression_for_level( - grpc_compression_level) override { - abort(); + return Empty{}; + }, + [](Empty) {}); } // This should return nullptr for the promise stack (and alternative means // for that functionality be invented) grpc_call_stack* call_stack() override { return nullptr; } - void UpdateDeadline(Timestamp deadline); - void ResetDeadline(); + void UpdateDeadline(Timestamp deadline) ABSL_LOCKS_EXCLUDED(deadline_mu_); + void ResetDeadline() ABSL_LOCKS_EXCLUDED(deadline_mu_); // Implementation of EventEngine::Closure, called when deadline expires void Run() override; virtual ServerCallContext* server_call_context() { return nullptr; } + using Call::arena; + protected: class ScopedContext : public ScopedActivity, + public BatchBuilder, + public promise_detail::Context, public promise_detail::Context, public promise_detail::Context, public promise_detail::Context, @@ -2130,6 +2015,8 @@ class PromiseBasedCall : public Call, public: explicit ScopedContext(PromiseBasedCall* call) : ScopedActivity(call), + BatchBuilder(&call->batch_payload_), + promise_detail::Context(this), promise_detail::Context(call->arena()), promise_detail::Context(call->context_), promise_detail::Context(&call->call_context_), @@ -2163,8 +2050,12 @@ class PromiseBasedCall : public Call, }; ~PromiseBasedCall() override { - if (non_owning_wakeable_) non_owning_wakeable_->DropActivity(); if (cq_) GRPC_CQ_INTERNAL_UNREF(cq_, "bind"); + for (int i = 0; i < GRPC_CONTEXT_COUNT; i++) { + if (context_[i].destroy) { + context_[i].destroy(context_[i].value); + } + } } // Enumerates why a Completion is still pending @@ -2172,6 +2063,7 @@ class PromiseBasedCall : public Call, // We're in the midst of starting a batch of operations kStartingBatch = 0, // The following correspond with the batch operations from above + kSendInitialMetadata, kReceiveInitialMetadata, kReceiveStatusOnClient, kReceiveCloseOnServer = kReceiveStatusOnClient, @@ -2181,10 +2073,17 @@ class PromiseBasedCall : public Call, kSendCloseFromClient = kSendStatusFromServer, }; + bool RunParty() override { + ScopedContext ctx(this); + return Party::RunParty(); + } + const char* PendingOpString(PendingOp reason) const { switch (reason) { case PendingOp::kStartingBatch: return "StartingBatch"; + case PendingOp::kSendInitialMetadata: + return "SendInitialMetadata"; case PendingOp::kReceiveInitialMetadata: return "ReceiveInitialMetadata"; case PendingOp::kReceiveStatusOnClient: @@ -2199,56 +2098,47 @@ class PromiseBasedCall : public Call, return "Unknown"; } - static constexpr uint8_t PendingOpBit(PendingOp reason) { + static constexpr uint32_t PendingOpBit(PendingOp reason) { return 1 << static_cast(reason); } - Mutex* mu() const ABSL_LOCK_RETURNED(mu_) { return &mu_; } // Begin work on a completion, recording the tag/closure to notify. // Use the op selected in \a ops to determine the index to allocate into. // Starts the "StartingBatch" PendingOp immediately. // Assumes at least one operation in \a ops. - Completion StartCompletion(void* tag, bool is_closure, const grpc_op* ops) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + Completion StartCompletion(void* tag, bool is_closure, const grpc_op* ops); // Add one pending op to the completion, and return it. - Completion AddOpToCompletion(const Completion& completion, PendingOp reason) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + Completion AddOpToCompletion(const Completion& completion, PendingOp reason); // Stringify a completion std::string CompletionString(const Completion& completion) const { return completion.has_value() - ? absl::StrFormat( - "%d:tag=%p", static_cast(completion.index()), - completion_info_[completion.index()].pending.tag) + ? completion_info_[completion.index()].pending.ToString(this) : "no-completion"; } // Finish one op on the completion. Must have been previously been added. // The completion as a whole finishes when all pending ops finish. - void FinishOpOnCompletion(Completion* completion, PendingOp reason) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void FinishOpOnCompletion(Completion* completion, PendingOp reason); // Mark the completion as failed. Does not finish it. void FailCompletion(const Completion& completion, SourceLocation source_location = {}); - // Run the promise polling loop until it stalls. - void Update() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Update the promise state once. - virtual void UpdateOnce() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + // Mark the completion as infallible. Overrides FailCompletion to report + // success always. + void ForceCompletionSuccess(const Completion& completion); // Accept the stats from the context (call once we have proof the transport is // done with them). // Right now this means that promise based calls do not record correct stats // with census if they are cancelled. // TODO(ctiller): this should be remedied before promise based calls are // dexperimentalized. - void AcceptTransportStatsFromContext() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + void AcceptTransportStatsFromContext() { final_stats_ = *call_context_.call_stats(); } - grpc_completion_queue* cq() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return cq_; } + grpc_completion_queue* cq() { return cq_; } void CToMetadata(grpc_metadata* metadata, size_t count, grpc_metadata_batch* batch); - std::string ActivityDebugTag(void*) const override { return DebugTag(); } - // At the end of the call run any finalization actions. void RunFinalization(grpc_status_code status, const char* status_details) { grpc_call_final_info final_info; @@ -2277,158 +2167,163 @@ class PromiseBasedCall : public Call, } } - std::string PollStateDebugString() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return absl::StrCat(PresentAndCompletionText("outstanding_send", - outstanding_send_.has_value(), - send_message_completion_) - .c_str(), - PresentAndCompletionText("outstanding_recv", - outstanding_recv_.has_value(), - recv_message_completion_) - .c_str()); - } - + // Spawn a job that will first do FirstPromise then receive a message + template void StartRecvMessage(const grpc_op& op, const Completion& completion, - PipeReceiver* receiver) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void PollRecvMessage(grpc_compression_algorithm compression_algorithm) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void CancelRecvMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + FirstPromise first, + PipeReceiver* receiver, + Party::BulkSpawner& spawner); void StartSendMessage(const grpc_op& op, const Completion& completion, - PipeSender* sender) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - bool PollSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - void CancelSendMessage() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + PipeSender* sender, + Party::BulkSpawner& spawner); + + void set_completed() { finished_.Set(); } - bool completed() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return completed_; + // Returns a promise that resolves to Empty whenever the call is completed. + auto finished() { return finished_.Wait(); } + + // Returns a promise that resolves to Empty whenever there is no outstanding + // send operation + auto WaitForSendingStarted() { + return [this]() -> Poll { + int n = sends_queued_.load(std::memory_order_relaxed); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[call] WaitForSendingStarted n=%d", + DebugTag().c_str(), n); + } + if (n != 0) return waiting_for_queued_sends_.pending(); + return Empty{}; + }; + } + + // Mark that a send has been queued - blocks sending trailing metadata. + void QueueSend() { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[call] QueueSend", DebugTag().c_str()); + } + sends_queued_.fetch_add(1, std::memory_order_relaxed); } - void set_completed() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { completed_ = true; } - bool is_sending() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return outstanding_send_.has_value(); + // Mark that a send has been dequeued - allows sending trailing metadata once + // zero sends are queued. + void EnactSend() { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[call] EnactSend", DebugTag().c_str()); + } + if (1 == sends_queued_.fetch_sub(1, std::memory_order_relaxed)) { + waiting_for_queued_sends_.Wake(); + } } private: union CompletionInfo { + static constexpr uint32_t kOpFailed = 0x8000'0000u; + static constexpr uint32_t kOpForceSuccess = 0x4000'0000u; + CompletionInfo() {} + enum CompletionState { + kPending, + kSuccess, + kFailure, + }; struct Pending { - // Bitmask of PendingOps - uint8_t pending_op_bits; + // Bitmask of PendingOps at the bottom, and kOpFailed, kOpForceSuccess at + // the top. + std::atomic state; bool is_closure; - bool success; + // True if this completion was for a recv_message op. + // In that case if the completion as a whole fails we need to cleanup the + // returned message. + bool is_recv_message; void* tag; - } pending; - grpc_cq_completion completion; - }; - class NonOwningWakable final : public Wakeable { - public: - explicit NonOwningWakable(PromiseBasedCall* call) : call_(call) {} - - // Ref the Handle (not the activity). - void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); } - - // Activity is going away... drop its reference and sever the connection - // back. - void DropActivity() ABSL_LOCKS_EXCLUDED(mu_) { - auto unref = absl::MakeCleanup([this]() { Unref(); }); - MutexLock lock(&mu_); - GPR_ASSERT(call_ != nullptr); - call_ = nullptr; - } + void Start(bool is_closure, void* tag) { + this->is_closure = is_closure; + this->is_recv_message = false; + this->tag = tag; + state.store(PendingOpBit(PendingOp::kStartingBatch), + std::memory_order_release); + } - // Activity needs to wake up (if it still exists!) - wake it up, and drop - // the ref that was kept for this handle. - void Wakeup(void*) override ABSL_LOCKS_EXCLUDED(mu_) { - // Drop the ref to the handle at end of scope (we have one ref = one - // wakeup semantics). - auto unref = absl::MakeCleanup([this]() { Unref(); }); - ReleasableMutexLock lock(&mu_); - // Note that activity refcount can drop to zero, but we could win the lock - // against DropActivity, so we need to only increase activities refcount - // if it is non-zero. - PromiseBasedCall* call = call_; - if (call != nullptr && call->RefIfNonZero()) { - lock.Release(); - // Activity still exists and we have a reference: wake it up, which will - // drop the ref. - call->Wakeup(nullptr); + void AddPendingBit(PendingOp reason) { + if (reason == PendingOp::kReceiveMessage) is_recv_message = true; + auto prev = + state.fetch_or(PendingOpBit(reason), std::memory_order_relaxed); + GPR_ASSERT((prev & PendingOpBit(reason)) == 0); } - } - std::string ActivityDebugTag(void*) const override { - MutexLock lock(&mu_); - return call_ == nullptr ? "" : call_->DebugTag(); - } + CompletionState RemovePendingBit(PendingOp reason) { + const uint32_t mask = ~PendingOpBit(reason); + auto prev = state.fetch_and(mask, std::memory_order_acq_rel); + GPR_ASSERT((prev & PendingOpBit(reason)) != 0); + switch (prev & mask) { + case kOpFailed: + return kFailure; + case kOpFailed | kOpForceSuccess: + case kOpForceSuccess: + case 0: + return kSuccess; + default: + return kPending; + } + } - void Drop(void*) override { Unref(); } + void MarkFailed() { + state.fetch_or(kOpFailed, std::memory_order_relaxed); + } - private: - // Unref the Handle (not the activity). - void Unref() { - if (1 == refs_.fetch_sub(1, std::memory_order_acq_rel)) { - delete this; + void MarkForceSuccess() { + state.fetch_or(kOpForceSuccess, std::memory_order_relaxed); } - } - mutable Mutex mu_; - // We have two initial refs: one for the wakeup that this is created for, - // and will be dropped by Wakeup, and the other for the activity which is - // dropped by DropActivity. - std::atomic refs_{2}; - PromiseBasedCall* call_ ABSL_GUARDED_BY(mu_); + std::string ToString(const PromiseBasedCall* call) const { + auto state = this->state.load(std::memory_order_relaxed); + std::vector pending_ops; + for (size_t i = 0; i < 24; i++) { + if (state & (1u << i)) { + pending_ops.push_back( + call->PendingOpString(static_cast(i))); + } + } + return absl::StrFormat("{%s}%s:tag=%p", absl::StrJoin(pending_ops, ","), + (state & kOpForceSuccess) ? ":force-success" + : (state & kOpFailed) ? ":failed" + : ":success", + tag); + } + } pending; + grpc_cq_completion completion; }; - static void OnDestroy(void* arg, grpc_error_handle) { - auto* call = static_cast(arg); - ScopedContext context(call); - call->DeleteThis(); - } - - // First 32 bits are strong refs, next 32 bits are weak refs. - static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) { - return (static_cast(strong) << 32) + static_cast(weak); - } - static uint32_t GetStrongRefs(uint64_t ref_pair) { - return static_cast(ref_pair >> 32); - } - static uint32_t GetWeakRefs(uint64_t ref_pair) { - return static_cast(ref_pair & 0xffffffffu); - } - - bool RefIfNonZero() { - uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire); - do { - const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); - if (strong_refs == 0) return false; - } while (!refs_.compare_exchange_weak( - prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), - std::memory_order_acq_rel, std::memory_order_acquire)); - return true; + void PartyOver() override { + { + ScopedContext ctx(this); + CancelRemainingParticipants(); + arena()->DestroyManagedNewObjects(); + } + DeleteThis(); } - mutable Mutex mu_; - std::atomic refs_; CallContext call_context_{this}; - bool keep_polling_ ABSL_GUARDED_BY(mu()) = false; // Contexts for various subsystems (security, tracing, ...). grpc_call_context_element context_[GRPC_CONTEXT_COUNT] = {}; - grpc_completion_queue* cq_ ABSL_GUARDED_BY(mu_); - NonOwningWakable* non_owning_wakeable_ ABSL_GUARDED_BY(mu_) = nullptr; + grpc_completion_queue* cq_; CompletionInfo completion_info_[6]; grpc_call_stats final_stats_{}; CallFinalization finalization_; // Current deadline. - Timestamp deadline_ = Timestamp::InfFuture(); - grpc_event_engine::experimental::EventEngine::TaskHandle deadline_task_; - absl::optional::PushType> outstanding_send_ - ABSL_GUARDED_BY(mu_); - absl::optional> outstanding_recv_ - ABSL_GUARDED_BY(mu_); - grpc_byte_buffer** recv_message_ ABSL_GUARDED_BY(mu_) = nullptr; - Completion send_message_completion_ ABSL_GUARDED_BY(mu_); - Completion recv_message_completion_ ABSL_GUARDED_BY(mu_); - bool completed_ ABSL_GUARDED_BY(mu_) = false; + Mutex deadline_mu_; + Timestamp deadline_ ABSL_GUARDED_BY(deadline_mu_) = Timestamp::InfFuture(); + grpc_event_engine::experimental::EventEngine::TaskHandle ABSL_GUARDED_BY( + deadline_mu_) deadline_task_; + ExternallyObservableLatch finished_; + // Non-zero with an outstanding GRPC_OP_SEND_INITIAL_METADATA or + // GRPC_OP_SEND_MESSAGE (one count each), and 0 once those payloads have been + // pushed onto the outgoing pipe. + std::atomic sends_queued_{0}; + // Waiter for when sends_queued_ becomes 0. + IntraActivityWaiter waiting_for_queued_sends_; + grpc_byte_buffer** recv_message_ = nullptr; + grpc_transport_stream_op_batch_payload batch_payload_{context_}; }; template @@ -2448,7 +2343,7 @@ PromiseBasedCall::PromiseBasedCall(Arena* arena, uint32_t initial_external_refs, const grpc_call_create_args& args) : Call(arena, args.server_transport_data == nullptr, args.send_deadline, args.channel->Ref()), - refs_(MakeRefPair(initial_external_refs, 0)), + Party(arena, initial_external_refs), cq_(args.cq) { if (args.cq != nullptr) { GPR_ASSERT(args.pollset_set_alternative == nullptr && @@ -2464,15 +2359,6 @@ PromiseBasedCall::PromiseBasedCall(Arena* arena, uint32_t initial_external_refs, } } -Waker PromiseBasedCall::MakeNonOwningWaker() { - if (non_owning_wakeable_ == nullptr) { - non_owning_wakeable_ = new NonOwningWakable(this); - } else { - non_owning_wakeable_->Ref(); - } - return Waker(non_owning_wakeable_, nullptr); -} - void PromiseBasedCall::CToMetadata(grpc_metadata* metadata, size_t count, grpc_metadata_batch* b) { for (size_t i = 0; i < count; i++) { @@ -2507,15 +2393,14 @@ void* PromiseBasedCall::ContextGet(grpc_context_index elem) const { PromiseBasedCall::Completion PromiseBasedCall::StartCompletion( void* tag, bool is_closure, const grpc_op* ops) { Completion c(BatchSlotForOp(ops[0].op)); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] StartCompletion %s tag=%p", DebugTag().c_str(), - CompletionString(c).c_str(), tag); - } if (!is_closure) { grpc_cq_begin_op(cq(), tag); } - completion_info_[c.index()].pending = { - PendingOpBit(PendingOp::kStartingBatch), is_closure, true, tag}; + completion_info_[c.index()].pending.Start(is_closure, tag); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] StartCompletion %s", DebugTag().c_str(), + CompletionString(c).c_str()); + } return c; } @@ -2526,10 +2411,7 @@ PromiseBasedCall::Completion PromiseBasedCall::AddOpToCompletion( CompletionString(completion).c_str(), PendingOpString(reason)); } GPR_ASSERT(completion.has_value()); - auto& pending_op_bits = - completion_info_[completion.index()].pending.pending_op_bits; - GPR_ASSERT((pending_op_bits & PendingOpBit(reason)) == 0); - pending_op_bits |= PendingOpBit(reason); + completion_info_[completion.index()].pending.AddPendingBit(reason); return Completion(completion.index()); } @@ -2540,64 +2422,50 @@ void PromiseBasedCall::FailCompletion(const Completion& completion, "%s[call] FailCompletion %s", DebugTag().c_str(), CompletionString(completion).c_str()); } - completion_info_[completion.index()].pending.success = false; + completion_info_[completion.index()].pending.MarkFailed(); +} + +void PromiseBasedCall::ForceCompletionSuccess(const Completion& completion) { + completion_info_[completion.index()].pending.MarkForceSuccess(); } void PromiseBasedCall::FinishOpOnCompletion(Completion* completion, PendingOp reason) { if (grpc_call_trace.enabled()) { - auto pending_op_bits = - completion_info_[completion->index()].pending.pending_op_bits; - bool success = completion_info_[completion->index()].pending.success; - std::vector pending; - for (size_t i = 0; i < 8 * sizeof(pending_op_bits); i++) { - if (static_cast(i) == reason) continue; - if (pending_op_bits & (1 << i)) { - pending.push_back(PendingOpString(static_cast(i))); - } - } - gpr_log( - GPR_INFO, "%s[call] FinishOpOnCompletion tag:%p %s %s %s", - DebugTag().c_str(), completion_info_[completion->index()].pending.tag, - CompletionString(*completion).c_str(), PendingOpString(reason), - (pending.empty() - ? (success ? std::string("done") : std::string("failed")) - : absl::StrFormat("pending_ops={%s}", absl::StrJoin(pending, ","))) - .c_str()); + gpr_log(GPR_INFO, "%s[call] FinishOpOnCompletion completion:%s finish:%s", + DebugTag().c_str(), CompletionString(*completion).c_str(), + PendingOpString(reason)); } const uint8_t i = completion->TakeIndex(); GPR_ASSERT(i < GPR_ARRAY_SIZE(completion_info_)); CompletionInfo::Pending& pending = completion_info_[i].pending; - GPR_ASSERT(pending.pending_op_bits & PendingOpBit(reason)); - pending.pending_op_bits &= ~PendingOpBit(reason); - auto error = pending.success ? absl::OkStatus() : absl::CancelledError(); - if (pending.pending_op_bits == 0) { - if (pending.is_closure) { - ExecCtx::Run(DEBUG_LOCATION, static_cast(pending.tag), - error); - } else { - grpc_cq_end_op( - cq(), pending.tag, error, [](void*, grpc_cq_completion*) {}, nullptr, - &completion_info_[i].completion); - } + bool success; + switch (pending.RemovePendingBit(reason)) { + case CompletionInfo::kPending: + return; // Early out + case CompletionInfo::kSuccess: + success = true; + break; + case CompletionInfo::kFailure: + success = false; + break; + } + if (pending.is_recv_message && !success && *recv_message_ != nullptr) { + grpc_byte_buffer_destroy(*recv_message_); + *recv_message_ = nullptr; + } + auto error = success ? absl::OkStatus() : absl::CancelledError(); + if (pending.is_closure) { + ExecCtx::Run(DEBUG_LOCATION, static_cast(pending.tag), + error); + } else { + grpc_cq_end_op( + cq(), pending.tag, error, [](void*, grpc_cq_completion*) {}, nullptr, + &completion_info_[i].completion); } } -void PromiseBasedCall::Update() { - keep_polling_ = false; - do { - UpdateOnce(); - } while (std::exchange(keep_polling_, false)); -} - -void PromiseBasedCall::ForceImmediateRepoll() { keep_polling_ = true; } - void PromiseBasedCall::SetCompletionQueue(grpc_completion_queue* cq) { - MutexLock lock(&mu_); - SetCompletionQueueLocked(cq); -} - -void PromiseBasedCall::SetCompletionQueueLocked(grpc_completion_queue* cq) { cq_ = cq; GRPC_CQ_INTERNAL_REF(cq, "bind"); call_context_.pollent_ = @@ -2605,6 +2473,12 @@ void PromiseBasedCall::SetCompletionQueueLocked(grpc_completion_queue* cq) { } void PromiseBasedCall::UpdateDeadline(Timestamp deadline) { + MutexLock lock(&deadline_mu_); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[call] UpdateDeadline from=%s to=%s", + DebugTag().c_str(), deadline_.ToString().c_str(), + deadline.ToString().c_str()); + } if (deadline >= deadline_) return; auto* const event_engine = channel()->event_engine(); if (deadline_ != Timestamp::InfFuture()) { @@ -2612,10 +2486,12 @@ void PromiseBasedCall::UpdateDeadline(Timestamp deadline) { } else { InternalRef("deadline"); } - event_engine->RunAfter(deadline - Timestamp::Now(), this); + deadline_ = deadline; + deadline_task_ = event_engine->RunAfter(deadline - Timestamp::Now(), this); } void PromiseBasedCall::ResetDeadline() { + MutexLock lock(&deadline_mu_); if (deadline_ == Timestamp::InfFuture()) return; auto* const event_engine = channel()->event_engine(); if (!event_engine->Cancel(deadline_task_)) return; @@ -2632,117 +2508,88 @@ void PromiseBasedCall::Run() { void PromiseBasedCall::StartSendMessage(const grpc_op& op, const Completion& completion, - PipeSender* sender) { - GPR_ASSERT(!outstanding_send_.has_value()); - if (!completed_) { - send_message_completion_ = - AddOpToCompletion(completion, PendingOp::kSendMessage); - SliceBuffer send; - grpc_slice_buffer_swap( - &op.data.send_message.send_message->data.raw.slice_buffer, - send.c_slice_buffer()); - outstanding_send_.emplace(sender->Push( - GetContext()->MakePooled(std::move(send), op.flags))); - } else { - FailCompletion(completion); - } + PipeSender* sender, + Party::BulkSpawner& spawner) { + QueueSend(); + SliceBuffer send; + grpc_slice_buffer_swap( + &op.data.send_message.send_message->data.raw.slice_buffer, + send.c_slice_buffer()); + auto msg = arena()->MakePooled(std::move(send), op.flags); + spawner.Spawn( + "call_send_message", + [this, sender, msg = std::move(msg)]() mutable { + EnactSend(); + return sender->Push(std::move(msg)); + }, + [this, completion = AddOpToCompletion( + completion, PendingOp::kSendMessage)](bool result) mutable { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%sSendMessage completes %s", DebugTag().c_str(), + result ? "successfully" : "with failure"); + } + if (!result) FailCompletion(completion); + FinishOpOnCompletion(&completion, PendingOp::kSendMessage); + }); } -bool PromiseBasedCall::PollSendMessage() { - if (!outstanding_send_.has_value()) return true; - Poll r = (*outstanding_send_)(); - if (const bool* result = r.value_if_ready()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_DEBUG, "%sPollSendMessage completes %s", DebugTag().c_str(), - *result ? "successfully" : "with failure"); - } - if (!*result) { - FailCompletion(send_message_completion_); - return false; - } - FinishOpOnCompletion(&send_message_completion_, PendingOp::kSendMessage); - outstanding_send_.reset(); +template +void PromiseBasedCall::StartRecvMessage( + const grpc_op& op, const Completion& completion, + FirstPromiseFactory first_promise_factory, + PipeReceiver* receiver, Party::BulkSpawner& spawner) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] Start RecvMessage: %s", DebugTag().c_str(), + CompletionString(completion).c_str()); } - return true; -} - -void PromiseBasedCall::CancelSendMessage() { - if (!outstanding_send_.has_value()) return; - FinishOpOnCompletion(&send_message_completion_, PendingOp::kSendMessage); - outstanding_send_.reset(); -} - -void PromiseBasedCall::StartRecvMessage(const grpc_op& op, - const Completion& completion, - PipeReceiver* receiver) { - GPR_ASSERT(!outstanding_recv_.has_value()); recv_message_ = op.data.recv_message.recv_message; - recv_message_completion_ = - AddOpToCompletion(completion, PendingOp::kReceiveMessage); - outstanding_recv_.emplace(receiver->Next()); -} - -void PromiseBasedCall::PollRecvMessage( - grpc_compression_algorithm incoming_compression_algorithm) { - if (!outstanding_recv_.has_value()) return; - Poll> r = (*outstanding_recv_)(); - if (auto* result = r.value_if_ready()) { - outstanding_recv_.reset(); - if (result->has_value()) { - MessageHandle& message = **result; - if ((message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) && - (incoming_compression_algorithm != GRPC_COMPRESS_NONE)) { - *recv_message_ = grpc_raw_compressed_byte_buffer_create( - nullptr, 0, incoming_compression_algorithm); - } else { - *recv_message_ = grpc_raw_byte_buffer_create(nullptr, 0); - } - grpc_slice_buffer_move_into(message->payload()->c_slice_buffer(), - &(*recv_message_)->data.raw.slice_buffer); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv finishes: received " - "%" PRIdPTR " byte message", - DebugTag().c_str(), - (*recv_message_)->data.raw.slice_buffer.length); - } - } else if (result->cancelled()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv finishes: received " - "end-of-stream with error", - DebugTag().c_str()); - } - FailCompletion(recv_message_completion_); - *recv_message_ = nullptr; - } else { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[call] PollRecvMessage: outstanding_recv finishes: received " - "end-of-stream", - DebugTag().c_str()); - } - *recv_message_ = nullptr; - } - FinishOpOnCompletion(&recv_message_completion_, PendingOp::kReceiveMessage); - } else if (completed_) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, - "%s[call] UpdateOnce: outstanding_recv finishes: promise has " - "completed without queuing a message, forcing end-of-stream", - DebugTag().c_str()); - } - outstanding_recv_.reset(); - *recv_message_ = nullptr; - FinishOpOnCompletion(&recv_message_completion_, PendingOp::kReceiveMessage); - } -} - -void PromiseBasedCall::CancelRecvMessage() { - if (!outstanding_recv_.has_value()) return; - *recv_message_ = nullptr; - outstanding_recv_.reset(); - FinishOpOnCompletion(&recv_message_completion_, PendingOp::kReceiveMessage); + spawner.Spawn( + "call_recv_message", + [first_promise_factory = std::move(first_promise_factory), receiver]() { + return Seq(first_promise_factory(), receiver->Next()); + }, + [this, + completion = AddOpToCompletion(completion, PendingOp::kReceiveMessage)]( + NextResult result) mutable { + if (result.has_value()) { + MessageHandle& message = *result; + NoteLastMessageFlags(message->flags()); + if ((message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) && + (incoming_compression_algorithm() != GRPC_COMPRESS_NONE)) { + *recv_message_ = grpc_raw_compressed_byte_buffer_create( + nullptr, 0, incoming_compression_algorithm()); + } else { + *recv_message_ = grpc_raw_byte_buffer_create(nullptr, 0); + } + grpc_slice_buffer_move_into(message->payload()->c_slice_buffer(), + &(*recv_message_)->data.raw.slice_buffer); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, + "%s[call] RecvMessage: outstanding_recv " + "finishes: received %" PRIdPTR " byte message", + DebugTag().c_str(), + (*recv_message_)->data.raw.slice_buffer.length); + } + } else if (result.cancelled()) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, + "%s[call] RecvMessage: outstanding_recv " + "finishes: received end-of-stream with error", + DebugTag().c_str()); + } + FailCompletion(completion); + *recv_message_ = nullptr; + } else { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, + "%s[call] RecvMessage: outstanding_recv " + "finishes: received end-of-stream", + DebugTag().c_str()); + } + *recv_message_ = nullptr; + } + FinishOpOnCompletion(&completion, PendingOp::kReceiveMessage); + }); } /////////////////////////////////////////////////////////////////////////////// @@ -2811,24 +2658,40 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { ~ClientPromiseBasedCall() override { ScopedContext context(this); send_initial_metadata_.reset(); - recv_status_on_client_ = absl::monostate(); - promise_ = ArenaPromise(); - // Need to destroy the pipes under the ScopedContext above, so we move them - // out here and then allow the destructors to run at end of scope, but - // before context. + // Need to destroy the pipes under the ScopedContext above, so we + // move them out here and then allow the destructors to run at + // end of scope, but before context. auto c2s = std::move(client_to_server_messages_); auto s2c = std::move(server_to_client_messages_); auto sim = std::move(server_initial_metadata_); } - absl::string_view GetServerAuthority() const override { abort(); } - void CancelWithErrorLocked(grpc_error_handle error) override - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); - bool is_trailers_only() const override { - MutexLock lock(mu()); - return is_trailers_only_; + void CancelWithError(absl::Status error) override { + if (!started_.exchange(true, std::memory_order_relaxed)) { + // Initial metadata not sent yet, so we can just fail the call. + Spawn( + "cancel_before_initial_metadata", + [error = std::move(error), this]() { + server_to_client_messages_.sender.Close(); + Finish(ServerMetadataFromStatus(error)); + return Empty{}; + }, + [](Empty) {}); + } else { + Spawn( + "cancel_with_error", + [error = std::move(error), this]() { + if (!cancel_error_.is_set()) { + cancel_error_.Set(ServerMetadataFromStatus(error)); + } + return Empty{}; + }, + [](Empty) {}); + } } - bool failed_before_recv_message() const override { abort(); } + absl::string_view GetServerAuthority() const override { abort(); } + bool is_trailers_only() const override { return is_trailers_only_; } + bool failed_before_recv_message() const override { return false; } grpc_call_error StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, bool is_notify_tag_closure) override; @@ -2838,65 +2701,79 @@ class ClientPromiseBasedCall final : public PromiseBasedCall { } private: - // Poll the underlying promise (and sundry objects) once. - void UpdateOnce() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) override; // Finish the call with the given status/trailing metadata. - void Finish(ServerMetadataHandle trailing_metadata) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + void Finish(ServerMetadataHandle trailing_metadata); // Validate that a set of ops is valid for a client call. - grpc_call_error ValidateBatch(const grpc_op* ops, size_t nops) const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + grpc_call_error ValidateBatch(const grpc_op* ops, size_t nops) const; // Commit a valid batch of operations to be executed. void CommitBatch(const grpc_op* ops, size_t nops, - const Completion& completion) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + const Completion& completion); // Start the underlying promise. - void StartPromise(ClientMetadataHandle client_initial_metadata) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + void StartPromise(ClientMetadataHandle client_initial_metadata, + const Completion& completion, Party::BulkSpawner& spawner); + // Start receiving initial metadata + void StartRecvInitialMetadata(grpc_metadata_array* array, + const Completion& completion, + Party::BulkSpawner& spawner); + void StartRecvStatusOnClient( + const Completion& completion, + grpc_op::grpc_op_data::grpc_op_recv_status_on_client op_args, + Party::BulkSpawner& spawner); // Publish status out to the application. void PublishStatus( grpc_op::grpc_op_data::grpc_op_recv_status_on_client op_args, - ServerMetadataHandle trailing_metadata) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + ServerMetadataHandle trailing_metadata); // Publish server initial metadata out to the application. - void PublishInitialMetadata(ServerMetadata* metadata) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); - - ArenaPromise promise_ ABSL_GUARDED_BY(mu()); - Pipe server_initial_metadata_ ABSL_GUARDED_BY(mu()){ - arena()}; - Pipe client_to_server_messages_ ABSL_GUARDED_BY(mu()){arena()}; - Pipe server_to_client_messages_ ABSL_GUARDED_BY(mu()){arena()}; + void PublishInitialMetadata(ServerMetadata* metadata); ClientMetadataHandle send_initial_metadata_; - grpc_metadata_array* recv_initial_metadata_ ABSL_GUARDED_BY(mu()) = nullptr; - absl::variant - recv_status_on_client_ ABSL_GUARDED_BY(mu()); - absl::optional> - server_initial_metadata_ready_; - absl::optional incoming_compression_algorithm_; - Completion recv_initial_metadata_completion_ ABSL_GUARDED_BY(mu()); - Completion recv_status_on_client_completion_ ABSL_GUARDED_BY(mu()); - Completion close_send_completion_ ABSL_GUARDED_BY(mu()); - bool is_trailers_only_ ABSL_GUARDED_BY(mu()); + Pipe server_initial_metadata_{arena()}; + Latch server_trailing_metadata_; + Latch cancel_error_; + Pipe client_to_server_messages_{arena()}; + Pipe server_to_client_messages_{arena()}; + bool is_trailers_only_; + // True once the promise for the call is started. + // This corresponds to sending initial metadata, or cancelling before doing + // so. + // In the latter case real world code sometimes does not sent the initial + // metadata, and so gating based upon that does not work out. + std::atomic started_{false}; }; void ClientPromiseBasedCall::StartPromise( - ClientMetadataHandle client_initial_metadata) { - GPR_ASSERT(!promise_.has_value()); - promise_ = channel()->channel_stack()->MakeClientCallPromise(CallArgs{ - std::move(client_initial_metadata), - &server_initial_metadata_.sender, - &client_to_server_messages_.receiver, - &server_to_client_messages_.sender, - }); -} - -void ClientPromiseBasedCall::CancelWithErrorLocked(grpc_error_handle error) { - ScopedContext context(this); - Finish(ServerMetadataFromStatus(grpc_error_to_absl_status(error))); + ClientMetadataHandle client_initial_metadata, const Completion& completion, + Party::BulkSpawner& spawner) { + auto token = ClientInitialMetadataOutstandingToken::New(arena()); + spawner.Spawn( + "call_send_initial_metadata", token.Wait(), + [this, + completion = AddOpToCompletion( + completion, PendingOp::kSendInitialMetadata)](bool result) mutable { + if (!result) FailCompletion(completion); + FinishOpOnCompletion(&completion, PendingOp::kSendInitialMetadata); + }); + spawner.Spawn( + "client_promise", + [this, client_initial_metadata = std::move(client_initial_metadata), + token = std::move(token)]() mutable { + return Race( + cancel_error_.Wait(), + Map(channel()->channel_stack()->MakeClientCallPromise( + CallArgs{std::move(client_initial_metadata), + std::move(token), &server_initial_metadata_.sender, + &client_to_server_messages_.receiver, + &server_to_client_messages_.sender}), + [this](ServerMetadataHandle trailing_metadata) { + // If we're cancelled the transport doesn't get to return + // stats. + AcceptTransportStatsFromContext(); + return trailing_metadata; + })); + }, + [this](ServerMetadataHandle trailing_metadata) { + Finish(std::move(trailing_metadata)); + }); } grpc_call_error ClientPromiseBasedCall::ValidateBatch(const grpc_op* ops, @@ -2937,49 +2814,61 @@ grpc_call_error ClientPromiseBasedCall::ValidateBatch(const grpc_op* ops, void ClientPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, const Completion& completion) { + Party::BulkSpawner spawner(this); for (size_t op_idx = 0; op_idx < nops; op_idx++) { const grpc_op& op = ops[op_idx]; switch (op.op) { case GRPC_OP_SEND_INITIAL_METADATA: { - // compression not implemented - GPR_ASSERT( - !op.data.send_initial_metadata.maybe_compression_level.is_set); - if (!completed()) { - CToMetadata(op.data.send_initial_metadata.metadata, - op.data.send_initial_metadata.count, - send_initial_metadata_.get()); - StartPromise(std::move(send_initial_metadata_)); + if (started_.exchange(true, std::memory_order_relaxed)) break; + CToMetadata(op.data.send_initial_metadata.metadata, + op.data.send_initial_metadata.count, + send_initial_metadata_.get()); + PrepareOutgoingInitialMetadata(op, *send_initial_metadata_); + if (send_deadline() != Timestamp::InfFuture()) { + send_initial_metadata_->Set(GrpcTimeoutMetadata(), send_deadline()); } + send_initial_metadata_->Set( + WaitForReady(), + WaitForReady::ValueType{ + (op.flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) != 0, + (op.flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET) != 0}); + StartPromise(std::move(send_initial_metadata_), completion, spawner); } break; case GRPC_OP_RECV_INITIAL_METADATA: { - recv_initial_metadata_ = - op.data.recv_initial_metadata.recv_initial_metadata; - server_initial_metadata_ready_.emplace( - server_initial_metadata_.receiver.Next()); - recv_initial_metadata_completion_ = - AddOpToCompletion(completion, PendingOp::kReceiveInitialMetadata); + StartRecvInitialMetadata( + op.data.recv_initial_metadata.recv_initial_metadata, completion, + spawner); } break; case GRPC_OP_RECV_STATUS_ON_CLIENT: { - recv_status_on_client_completion_ = - AddOpToCompletion(completion, PendingOp::kReceiveStatusOnClient); - if (auto* finished_metadata = - absl::get_if(&recv_status_on_client_)) { - PublishStatus(op.data.recv_status_on_client, - std::move(*finished_metadata)); - } else { - recv_status_on_client_ = op.data.recv_status_on_client; - } + StartRecvStatusOnClient(completion, op.data.recv_status_on_client, + spawner); } break; case GRPC_OP_SEND_MESSAGE: - StartSendMessage(op, completion, &client_to_server_messages_.sender); + StartSendMessage(op, completion, &client_to_server_messages_.sender, + spawner); break; case GRPC_OP_RECV_MESSAGE: - StartRecvMessage(op, completion, &server_to_client_messages_.receiver); + StartRecvMessage( + op, completion, + [this]() { + return server_initial_metadata_.receiver.AwaitClosed(); + }, + &server_to_client_messages_.receiver, spawner); break; case GRPC_OP_SEND_CLOSE_FROM_CLIENT: - close_send_completion_ = - AddOpToCompletion(completion, PendingOp::kSendCloseFromClient); - GPR_ASSERT(close_send_completion_.has_value()); + spawner.Spawn( + "send_close_from_client", + [this]() { + client_to_server_messages_.sender.Close(); + return Empty{}; + }, + [this, + completion = AddOpToCompletion( + completion, PendingOp::kSendCloseFromClient)](Empty) mutable { + FinishOpOnCompletion(&completion, + PendingOp::kSendCloseFromClient); + }); break; case GRPC_OP_SEND_STATUS_FROM_SERVER: case GRPC_OP_RECV_CLOSE_ON_SERVER: @@ -2992,8 +2881,6 @@ grpc_call_error ClientPromiseBasedCall::StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, bool is_notify_tag_closure) { - MutexLock lock(mu()); - ScopedContext activity_context(this); if (nops == 0) { EndOpImmediately(cq(), notify_tag, is_notify_tag_closure); return GRPC_CALL_OK; @@ -3005,71 +2892,35 @@ grpc_call_error ClientPromiseBasedCall::StartBatch(const grpc_op* ops, Completion completion = StartCompletion(notify_tag, is_notify_tag_closure, ops); CommitBatch(ops, nops, completion); - Update(); FinishOpOnCompletion(&completion, PendingOp::kStartingBatch); return GRPC_CALL_OK; } -void ClientPromiseBasedCall::PublishInitialMetadata(ServerMetadata* metadata) { - incoming_compression_algorithm_ = - metadata->Take(GrpcEncodingMetadata()).value_or(GRPC_COMPRESS_NONE); - Slice* peer_string = metadata->get_pointer(PeerString()); - if (peer_string != nullptr) SetPeerString(peer_string->Ref()); - server_initial_metadata_ready_.reset(); - GPR_ASSERT(recv_initial_metadata_ != nullptr); - PublishMetadataArray(metadata, - std::exchange(recv_initial_metadata_, nullptr)); - FinishOpOnCompletion(&recv_initial_metadata_completion_, - PendingOp::kReceiveInitialMetadata); -} - -void ClientPromiseBasedCall::UpdateOnce() { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] UpdateOnce: %s%shas_promise=%s", - DebugTag().c_str(), - PresentAndCompletionText("server_initial_metadata_ready", - server_initial_metadata_ready_.has_value(), - recv_initial_metadata_completion_) - .c_str(), - PollStateDebugString().c_str(), - promise_.has_value() ? "true" : "false"); - } - if (server_initial_metadata_ready_.has_value()) { - Poll> r = - (*server_initial_metadata_ready_)(); - if (auto* server_initial_metadata = r.value_if_ready()) { - PublishInitialMetadata(server_initial_metadata->value().get()); - } else if (completed()) { - ServerMetadata no_metadata{GetContext()}; - PublishInitialMetadata(&no_metadata); - } - } - if (!PollSendMessage()) { - Finish(ServerMetadataFromStatus(absl::Status( - absl::StatusCode::kInternal, "Failed to send message to server"))); - } - if (!is_sending() && close_send_completion_.has_value()) { - client_to_server_messages_.sender.Close(); - FinishOpOnCompletion(&close_send_completion_, - PendingOp::kSendCloseFromClient); - } - if (promise_.has_value()) { - Poll r = promise_(); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] UpdateOnce: promise returns %s", - DebugTag().c_str(), - PollToString(r, [](const ServerMetadataHandle& h) { - return h->DebugString(); - }).c_str()); - } - if (auto* result = r.value_if_ready()) { - AcceptTransportStatsFromContext(); - Finish(std::move(*result)); - } - } - if (incoming_compression_algorithm_.has_value()) { - PollRecvMessage(*incoming_compression_algorithm_); - } +void ClientPromiseBasedCall::StartRecvInitialMetadata( + grpc_metadata_array* array, const Completion& completion, + Party::BulkSpawner& spawner) { + spawner.Spawn( + "recv_initial_metadata", + Race(server_initial_metadata_.receiver.Next(), + Map(finished(), + [](Empty) { return NextResult(true); })), + [this, array, + completion = + AddOpToCompletion(completion, PendingOp::kReceiveInitialMetadata)]( + NextResult next_metadata) mutable { + server_initial_metadata_.sender.Close(); + ServerMetadataHandle metadata; + if (next_metadata.has_value()) { + is_trailers_only_ = false; + metadata = std::move(next_metadata.value()); + } else { + is_trailers_only_ = true; + metadata = arena()->MakePooled(arena()); + } + ProcessIncomingInitialMetadata(*metadata); + PublishMetadataArray(metadata.get(), array); + FinishOpOnCompletion(&completion, PendingOp::kReceiveInitialMetadata); + }); } void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { @@ -3077,31 +2928,9 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { gpr_log(GPR_INFO, "%s[call] Finish: %s", DebugTag().c_str(), trailing_metadata->DebugString().c_str()); } - promise_ = ArenaPromise(); ResetDeadline(); set_completed(); - if (recv_initial_metadata_ != nullptr) { - ForceImmediateRepoll(); - } - const bool pending_initial_metadata = - server_initial_metadata_ready_.has_value(); - if (!pending_initial_metadata) { - server_initial_metadata_ready_.emplace( - server_initial_metadata_.receiver.Next()); - } - Poll> r = - (*server_initial_metadata_ready_)(); - server_initial_metadata_ready_.reset(); - if (auto* result = r.value_if_ready()) { - if (pending_initial_metadata) PublishInitialMetadata(result->value().get()); - is_trailers_only_ = false; - } else { - if (pending_initial_metadata) { - ServerMetadata no_metadata{GetContext()}; - PublishInitialMetadata(&no_metadata); - } - is_trailers_only_ = true; - } + client_to_server_messages_.sender.Close(); if (auto* channelz_channel = channel()->channelz_node()) { if (trailing_metadata->get(GrpcStatusMetadata()) .value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK) { @@ -3110,13 +2939,7 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) { channelz_channel->RecordCallFailed(); } } - if (auto* status_request = - absl::get_if( - &recv_status_on_client_)) { - PublishStatus(*status_request, std::move(trailing_metadata)); - } else { - recv_status_on_client_ = std::move(trailing_metadata); - } + server_trailing_metadata_.Set(std::move(trailing_metadata)); } namespace { @@ -3142,35 +2965,43 @@ std::string MakeErrorString(const ServerMetadata* trailing_metadata) { } } // namespace -void ClientPromiseBasedCall::PublishStatus( +void ClientPromiseBasedCall::StartRecvStatusOnClient( + const Completion& completion, grpc_op::grpc_op_data::grpc_op_recv_status_on_client op_args, - ServerMetadataHandle trailing_metadata) { - const grpc_status_code status = trailing_metadata->get(GrpcStatusMetadata()) - .value_or(GRPC_STATUS_UNKNOWN); - *op_args.status = status; - absl::string_view message_string; - if (Slice* message = trailing_metadata->get_pointer(GrpcMessageMetadata())) { - message_string = message->as_string_view(); - *op_args.status_details = message->Ref().TakeCSlice(); - } else { - *op_args.status_details = grpc_empty_slice(); - } - if (message_string.empty()) { - RunFinalization(status, nullptr); - } else { - std::string error_string(message_string); - RunFinalization(status, error_string.c_str()); - } - if (op_args.error_string != nullptr && status != GRPC_STATUS_OK) { - *op_args.error_string = - gpr_strdup(MakeErrorString(trailing_metadata.get()).c_str()); - } - PublishMetadataArray(trailing_metadata.get(), op_args.trailing_metadata); - // Clear state saying we have a RECV_STATUS_ON_CLIENT outstanding - // (so we don't call through twice) - recv_status_on_client_ = absl::monostate(); - FinishOpOnCompletion(&recv_status_on_client_completion_, - PendingOp::kReceiveStatusOnClient); + Party::BulkSpawner& spawner) { + ForceCompletionSuccess(completion); + spawner.Spawn( + "recv_status_on_client", server_trailing_metadata_.Wait(), + [this, op_args, + completion = + AddOpToCompletion(completion, PendingOp::kReceiveStatusOnClient)]( + ServerMetadataHandle trailing_metadata) mutable { + const grpc_status_code status = + trailing_metadata->get(GrpcStatusMetadata()) + .value_or(GRPC_STATUS_UNKNOWN); + *op_args.status = status; + absl::string_view message_string; + if (Slice* message = + trailing_metadata->get_pointer(GrpcMessageMetadata())) { + message_string = message->as_string_view(); + *op_args.status_details = message->Ref().TakeCSlice(); + } else { + *op_args.status_details = grpc_empty_slice(); + } + if (message_string.empty()) { + RunFinalization(status, nullptr); + } else { + std::string error_string(message_string); + RunFinalization(status, error_string.c_str()); + } + if (op_args.error_string != nullptr && status != GRPC_STATUS_OK) { + *op_args.error_string = + gpr_strdup(MakeErrorString(trailing_metadata.get()).c_str()); + } + PublishMetadataArray(trailing_metadata.get(), + op_args.trailing_metadata); + FinishOpOnCompletion(&completion, PendingOp::kReceiveStatusOnClient); + }); } #endif @@ -3183,19 +3014,18 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { public: ServerPromiseBasedCall(Arena* arena, grpc_call_create_args* args); - void CancelWithErrorLocked(grpc_error_handle) override - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + void CancelWithError(grpc_error_handle) override; grpc_call_error StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, bool is_notify_tag_closure) override; - bool failed_before_recv_message() const override { abort(); } + bool failed_before_recv_message() const override { return false; } bool is_trailers_only() const override { abort(); } absl::string_view GetServerAuthority() const override { return ""; } // Polling order for the server promise stack: // // │ ┌───────────────────────────────────────┐ - // │ │ ServerPromiseBasedCall::UpdateOnce ├──► Lifetime management, - // │ ├───────────────────────────────────────┤ signal call end to app + // │ │ ServerPromiseBasedCall ├──► Lifetime management + // │ ├───────────────────────────────────────┤ // │ │ ConnectedChannel ├─┐ // │ ├───────────────────────────────────────┤ └► Interactions with the // │ │ ... closest to transport filter │ transport - send/recv msgs @@ -3206,16 +3036,12 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { // │ ├───────────────────────────────────────┤ │ setup, publishing call to // │ │ Server::ChannelData::MakeCallPromise ├─┘ application // │ ├───────────────────────────────────────┤ - // │ │ ServerPromiseBasedCall::PollTopOfCall ├──► Application interactions, - // ▼ └───────────────────────────────────────┘ forwarding messages, - // Polling & sending trailing metadata + // │ │ MakeTopOfServerCallPromise ├──► Send trailing metadata + // ▼ └───────────────────────────────────────┘ + // Polling & // instantiation // order - void UpdateOnce() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()) override; - Poll PollTopOfCall() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); - std::string DebugTag() const override { return absl::StrFormat("SERVER_CALL[%p]: ", this); } @@ -3225,44 +3051,64 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { private: class RecvCloseOpCancelState { public: - // Request that receiver be filled in per grpc_op_recv_close_on_server. - // Returns true if the request can be fulfilled immediately. - // Returns false if the request will be fulfilled later. + // Request that receiver be filled in per + // grpc_op_recv_close_on_server. Returns true if the request can + // be fulfilled immediately. Returns false if the request will be + // fulfilled later. bool ReceiveCloseOnServerOpStarted(int* receiver) { - switch (state_) { - case kUnset: - state_ = reinterpret_cast(receiver); - return false; - case kFinishedWithFailure: - *receiver = 1; - return true; - case kFinishedWithSuccess: - *receiver = 0; - return true; - default: - abort(); // unreachable - } + uintptr_t state = state_.load(std::memory_order_acquire); + uintptr_t new_state; + do { + switch (state) { + case kUnset: + new_state = reinterpret_cast(receiver); + break; + case kFinishedWithFailure: + *receiver = 1; + return true; + case kFinishedWithSuccess: + *receiver = 0; + return true; + default: + Crash("Two threads offered ReceiveCloseOnServerOpStarted"); + } + } while (!state_.compare_exchange_weak(state, new_state, + std::memory_order_acq_rel, + std::memory_order_acquire)); + return false; } // Mark the call as having completed. - // Returns true if this finishes a previous RequestReceiveCloseOnServer. - bool CompleteCall(bool success) { - switch (state_) { - case kUnset: - state_ = success ? kFinishedWithSuccess : kFinishedWithFailure; - return false; - case kFinishedWithFailure: - case kFinishedWithSuccess: - abort(); // unreachable - default: - *reinterpret_cast(state_) = success ? 0 : 1; - state_ = success ? kFinishedWithSuccess : kFinishedWithFailure; - return true; - } + // Returns true if this finishes a previous + // RequestReceiveCloseOnServer. + bool CompleteCallWithCancelledSetTo(bool cancelled) { + uintptr_t state = state_.load(std::memory_order_acquire); + uintptr_t new_state; + bool r; + do { + switch (state) { + case kUnset: + new_state = cancelled ? kFinishedWithFailure : kFinishedWithSuccess; + r = false; + break; + case kFinishedWithFailure: + return false; + case kFinishedWithSuccess: + Crash("unreachable"); + default: + new_state = cancelled ? kFinishedWithFailure : kFinishedWithSuccess; + r = true; + } + } while (!state_.compare_exchange_weak(state, new_state, + std::memory_order_acq_rel, + std::memory_order_acquire)); + if (r) *reinterpret_cast(state) = cancelled ? 1 : 0; + return r; } std::string ToString() const { - switch (state_) { + auto state = state_.load(std::memory_order_relaxed); + switch (state) { case kUnset: return "Unset"; case kFinishedWithFailure: @@ -3271,7 +3117,7 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { return "FinishedWithSuccess"; default: return absl::StrFormat("WaitingForReceiver(%p)", - reinterpret_cast(state_)); + reinterpret_cast(state)); } } @@ -3279,37 +3125,28 @@ class ServerPromiseBasedCall final : public PromiseBasedCall { static constexpr uintptr_t kUnset = 0; static constexpr uintptr_t kFinishedWithFailure = 1; static constexpr uintptr_t kFinishedWithSuccess = 2; - // Holds one of kUnset, kFinishedWithFailure, or kFinishedWithSuccess - // OR an int* that wants to receive the final status. - uintptr_t state_ = kUnset; + // Holds one of kUnset, kFinishedWithFailure, or + // kFinishedWithSuccess OR an int* that wants to receive the + // final status. + std::atomic state_{kUnset}; }; grpc_call_error ValidateBatch(const grpc_op* ops, size_t nops) const; void CommitBatch(const grpc_op* ops, size_t nops, - const Completion& completion) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu()); + const Completion& completion); + void Finish(ServerMetadataHandle result); friend class ServerCallContext; ServerCallContext call_context_; Server* const server_; - ArenaPromise promise_ ABSL_GUARDED_BY(mu()); - PipeSender* server_to_client_messages_ ABSL_GUARDED_BY(mu()) = - nullptr; - PipeReceiver* client_to_server_messages_ - ABSL_GUARDED_BY(mu()) = nullptr; - using SendInitialMetadataState = - absl::variant*, - typename PipeSender::PushType>; - SendInitialMetadataState send_initial_metadata_state_ ABSL_GUARDED_BY(mu()) = - absl::monostate{}; - ServerMetadataHandle send_trailing_metadata_ ABSL_GUARDED_BY(mu()); - grpc_compression_algorithm incoming_compression_algorithm_ - ABSL_GUARDED_BY(mu()); - RecvCloseOpCancelState recv_close_op_cancel_state_ ABSL_GUARDED_BY(mu()); - Completion recv_close_completion_ ABSL_GUARDED_BY(mu()); - bool cancel_send_and_receive_ ABSL_GUARDED_BY(mu()) = false; - Completion send_status_from_server_completion_ ABSL_GUARDED_BY(mu()); - ClientMetadataHandle client_initial_metadata_ ABSL_GUARDED_BY(mu()); + PipeSender* server_initial_metadata_ = nullptr; + PipeSender* server_to_client_messages_ = nullptr; + PipeReceiver* client_to_server_messages_ = nullptr; + Latch send_trailing_metadata_; + RecvCloseOpCancelState recv_close_op_cancel_state_; + ClientMetadataHandle client_initial_metadata_; + Completion recv_close_completion_; + std::atomic cancelled_{false}; }; ServerPromiseBasedCall::ServerPromiseBasedCall(Arena* arena, @@ -3342,106 +3179,40 @@ ServerPromiseBasedCall::ServerPromiseBasedCall(Arena* arena, ContextSet(GRPC_CONTEXT_CALL_TRACER, server_call_tracer, nullptr); } } - MutexLock lock(mu()); ScopedContext activity_context(this); - promise_ = channel()->channel_stack()->MakeServerCallPromise( - CallArgs{nullptr, nullptr, nullptr, nullptr}); + Spawn("server_promise", + channel()->channel_stack()->MakeServerCallPromise( + CallArgs{nullptr, ClientInitialMetadataOutstandingToken::Empty(), + nullptr, nullptr, nullptr}), + [this](ServerMetadataHandle result) { Finish(std::move(result)); }); } -Poll ServerPromiseBasedCall::PollTopOfCall() { +void ServerPromiseBasedCall::Finish(ServerMetadataHandle result) { if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] PollTopOfCall: %s%s%s", DebugTag().c_str(), - cancel_send_and_receive_ ? "force-" : "", - send_trailing_metadata_ != nullptr - ? absl::StrCat("send-metadata:", - send_trailing_metadata_->DebugString(), " ") - .c_str() - : " ", - PollStateDebugString().c_str()); - } - - if (cancel_send_and_receive_) { - CancelSendMessage(); - CancelRecvMessage(); + gpr_log(GPR_INFO, "%s[call] Finish: recv_close_state:%s result:%s", + DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str(), + result->DebugString().c_str()); } - - PollSendMessage(); - PollRecvMessage(incoming_compression_algorithm_); - - if (!is_sending() && send_trailing_metadata_ != nullptr) { - server_to_client_messages_->Close(); - return std::move(send_trailing_metadata_); + if (recv_close_op_cancel_state_.CompleteCallWithCancelledSetTo( + result->get(GrpcCallWasCancelled()).value_or(true))) { + FinishOpOnCompletion(&recv_close_completion_, + PendingOp::kReceiveCloseOnServer); } - - return Pending{}; -} - -void ServerPromiseBasedCall::UpdateOnce() { - if (grpc_call_trace.enabled()) { - gpr_log( - GPR_INFO, "%s[call] UpdateOnce: recv_close:%s%s %s%shas_promise=%s", - DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str(), - recv_close_completion_.has_value() - ? absl::StrCat(":", CompletionString(recv_close_completion_)) - .c_str() - : "", - send_status_from_server_completion_.has_value() - ? absl::StrCat( - "send_status:", - CompletionString(send_status_from_server_completion_), " ") - .c_str() - : "", - PollStateDebugString().c_str(), - promise_.has_value() ? "true" : "false"); - } - if (auto* p = - absl::get_if::PushType>( - &send_initial_metadata_state_)) { - if ((*p)().ready()) { - send_initial_metadata_state_ = absl::monostate{}; - } + if (server_initial_metadata_ != nullptr) { + server_initial_metadata_->Close(); } - if (promise_.has_value()) { - auto r = promise_(); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] UpdateOnce: promise returns %s", - DebugTag().c_str(), - PollToString(r, [](const ServerMetadataHandle& h) { - return h->DebugString(); - }).c_str()); - } - if (auto* result = r.value_if_ready()) { - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] UpdateOnce: GotResult %s result:%s", - DebugTag().c_str(), - recv_close_op_cancel_state_.ToString().c_str(), - (*result)->DebugString().c_str()); - } - if (recv_close_op_cancel_state_.CompleteCall( - (*result)->get(GrpcStatusFromWire()).value_or(false))) { - FinishOpOnCompletion(&recv_close_completion_, - PendingOp::kReceiveCloseOnServer); - } - channelz::ServerNode* channelz_node = server_->channelz_node(); - if (channelz_node != nullptr) { - if ((*result) - ->get(GrpcStatusMetadata()) - .value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK) { - channelz_node->RecordCallSucceeded(); - } else { - channelz_node->RecordCallFailed(); - } - } - if (send_status_from_server_completion_.has_value()) { - FinishOpOnCompletion(&send_status_from_server_completion_, - PendingOp::kSendStatusFromServer); - } - CancelSendMessage(); - CancelRecvMessage(); - set_completed(); - promise_ = ArenaPromise(); + channelz::ServerNode* channelz_node = server_->channelz_node(); + if (channelz_node != nullptr) { + if (result->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN) == + GRPC_STATUS_OK) { + channelz_node->RecordCallSucceeded(); + } else { + channelz_node->RecordCallFailed(); } } + set_completed(); + ResetDeadline(); + PropagateCancellationToChildren(); } grpc_call_error ServerPromiseBasedCall::ValidateBatch(const grpc_op* ops, @@ -3482,56 +3253,92 @@ grpc_call_error ServerPromiseBasedCall::ValidateBatch(const grpc_op* ops, void ServerPromiseBasedCall::CommitBatch(const grpc_op* ops, size_t nops, const Completion& completion) { + Party::BulkSpawner spawner(this); for (size_t op_idx = 0; op_idx < nops; op_idx++) { const grpc_op& op = ops[op_idx]; switch (op.op) { case GRPC_OP_SEND_INITIAL_METADATA: { - // compression not implemented - GPR_ASSERT( - !op.data.send_initial_metadata.maybe_compression_level.is_set); - if (!completed()) { - auto metadata = arena()->MakePooled(arena()); - CToMetadata(op.data.send_initial_metadata.metadata, - op.data.send_initial_metadata.count, metadata.get()); - if (grpc_call_trace.enabled()) { - gpr_log(GPR_INFO, "%s[call] Send initial metadata", - DebugTag().c_str()); - } - auto* pipe = absl::get*>( - send_initial_metadata_state_); - send_initial_metadata_state_ = pipe->Push(std::move(metadata)); + auto metadata = arena()->MakePooled(arena()); + PrepareOutgoingInitialMetadata(op, *metadata); + CToMetadata(op.data.send_initial_metadata.metadata, + op.data.send_initial_metadata.count, metadata.get()); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_INFO, "%s[call] Send initial metadata", + DebugTag().c_str()); } + QueueSend(); + spawner.Spawn( + "call_send_initial_metadata", + [this, metadata = std::move(metadata)]() mutable { + EnactSend(); + return server_initial_metadata_->Push(std::move(metadata)); + }, + [this, + completion = AddOpToCompletion( + completion, PendingOp::kSendInitialMetadata)](bool r) mutable { + if (!r) FailCompletion(completion); + FinishOpOnCompletion(&completion, + PendingOp::kSendInitialMetadata); + }); } break; case GRPC_OP_SEND_MESSAGE: - StartSendMessage(op, completion, server_to_client_messages_); + StartSendMessage(op, completion, server_to_client_messages_, spawner); break; case GRPC_OP_RECV_MESSAGE: - StartRecvMessage(op, completion, client_to_server_messages_); + if (cancelled_.load(std::memory_order_relaxed)) { + FailCompletion(completion); + break; + } + StartRecvMessage( + op, completion, []() { return []() { return Empty{}; }; }, + client_to_server_messages_, spawner); break; - case GRPC_OP_SEND_STATUS_FROM_SERVER: - send_trailing_metadata_ = arena()->MakePooled(arena()); + case GRPC_OP_SEND_STATUS_FROM_SERVER: { + auto metadata = arena()->MakePooled(arena()); CToMetadata(op.data.send_status_from_server.trailing_metadata, op.data.send_status_from_server.trailing_metadata_count, - send_trailing_metadata_.get()); - send_trailing_metadata_->Set(GrpcStatusMetadata(), - op.data.send_status_from_server.status); + metadata.get()); + metadata->Set(GrpcStatusMetadata(), + op.data.send_status_from_server.status); if (auto* details = op.data.send_status_from_server.status_details) { - send_trailing_metadata_->Set(GrpcMessageMetadata(), - Slice(CSliceRef(*details))); + metadata->Set(GrpcMessageMetadata(), Slice(CSliceRef(*details))); } - send_status_from_server_completion_ = - AddOpToCompletion(completion, PendingOp::kSendStatusFromServer); - break; + spawner.Spawn( + "call_send_status_from_server", + [this, metadata = std::move(metadata)]() mutable { + bool r = true; + if (send_trailing_metadata_.is_set()) { + r = false; + } else { + send_trailing_metadata_.Set(std::move(metadata)); + } + return Map(WaitForSendingStarted(), [this, r](Empty) { + server_initial_metadata_->Close(); + server_to_client_messages_->Close(); + return r; + }); + }, + [this, completion = AddOpToCompletion( + completion, PendingOp::kSendStatusFromServer)]( + bool ok) mutable { + if (!ok) FailCompletion(completion); + FinishOpOnCompletion(&completion, + PendingOp::kSendStatusFromServer); + }); + } break; case GRPC_OP_RECV_CLOSE_ON_SERVER: if (grpc_call_trace.enabled()) { gpr_log(GPR_INFO, "%s[call] StartBatch: RecvClose %s", DebugTag().c_str(), recv_close_op_cancel_state_.ToString().c_str()); } - if (!recv_close_op_cancel_state_.ReceiveCloseOnServerOpStarted( + ForceCompletionSuccess(completion); + recv_close_completion_ = + AddOpToCompletion(completion, PendingOp::kReceiveCloseOnServer); + if (recv_close_op_cancel_state_.ReceiveCloseOnServerOpStarted( op.data.recv_close_on_server.cancelled)) { - recv_close_completion_ = - AddOpToCompletion(completion, PendingOp::kReceiveCloseOnServer); + FinishOpOnCompletion(&recv_close_completion_, + PendingOp::kReceiveCloseOnServer); } break; case GRPC_OP_RECV_STATUS_ON_CLIENT: @@ -3546,8 +3353,6 @@ grpc_call_error ServerPromiseBasedCall::StartBatch(const grpc_op* ops, size_t nops, void* notify_tag, bool is_notify_tag_closure) { - MutexLock lock(mu()); - ScopedContext activity_context(this); if (nops == 0) { EndOpImmediately(cq(), notify_tag, is_notify_tag_closure); return GRPC_CALL_OK; @@ -3559,18 +3364,30 @@ grpc_call_error ServerPromiseBasedCall::StartBatch(const grpc_op* ops, Completion completion = StartCompletion(notify_tag, is_notify_tag_closure, ops); CommitBatch(ops, nops, completion); - Update(); FinishOpOnCompletion(&completion, PendingOp::kStartingBatch); return GRPC_CALL_OK; } -void ServerPromiseBasedCall::CancelWithErrorLocked(absl::Status error) { - if (!promise_.has_value()) return; - cancel_send_and_receive_ = true; - send_trailing_metadata_ = ServerMetadataFromStatus(error, arena()); - ForceWakeup(); +void ServerPromiseBasedCall::CancelWithError(absl::Status error) { + cancelled_.store(true, std::memory_order_relaxed); + Spawn( + "cancel_with_error", + [this, error = std::move(error)]() { + if (!send_trailing_metadata_.is_set()) { + auto md = ServerMetadataFromStatus(error); + md->Set(GrpcCallWasCancelled(), true); + send_trailing_metadata_.Set(std::move(md)); + } + if (server_to_client_messages_ != nullptr) { + server_to_client_messages_->Close(); + } + if (server_initial_metadata_ != nullptr) { + server_initial_metadata_->Close(); + } + return Empty{}; + }, + [](Empty) {}); } - #endif #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL @@ -3579,24 +3396,19 @@ ServerCallContext::MakeTopOfServerCallPromise( CallArgs call_args, grpc_completion_queue* cq, grpc_metadata_array* publish_initial_metadata, absl::FunctionRef publish) { - call_->mu()->AssertHeld(); - call_->SetCompletionQueueLocked(cq); + call_->SetCompletionQueue(cq); call_->server_to_client_messages_ = call_args.server_to_client_messages; call_->client_to_server_messages_ = call_args.client_to_server_messages; - call_->send_initial_metadata_state_ = call_args.server_initial_metadata; - call_->incoming_compression_algorithm_ = - call_args.client_initial_metadata->get(GrpcEncodingMetadata()) - .value_or(GRPC_COMPRESS_NONE); + call_->server_initial_metadata_ = call_args.server_initial_metadata; call_->client_initial_metadata_ = std::move(call_args.client_initial_metadata); + call_->ProcessIncomingInitialMetadata(*call_->client_initial_metadata_); PublishMetadataArray(call_->client_initial_metadata_.get(), publish_initial_metadata); call_->ExternalRef(); publish(call_->c_ptr()); - return [this]() { - call_->mu()->AssertHeld(); - return call_->PollTopOfCall(); - }; + return Seq(call_->server_to_client_messages_->AwaitClosed(), + call_->send_trailing_metadata_.Wait()); } #else ArenaPromise @@ -3699,7 +3511,9 @@ uint32_t grpc_call_test_only_get_message_flags(grpc_call* call) { } uint32_t grpc_call_test_only_get_encodings_accepted_by_peer(grpc_call* call) { - return grpc_core::Call::FromC(call)->test_only_encodings_accepted_by_peer(); + return grpc_core::Call::FromC(call) + ->encodings_accepted_by_peer() + .ToLegacyBitmask(); } grpc_core::Arena* grpc_call_get_arena(grpc_call* call) { @@ -3748,7 +3562,9 @@ uint8_t grpc_call_is_client(grpc_call* call) { grpc_compression_algorithm grpc_call_compression_for_level( grpc_call* call, grpc_compression_level level) { - return grpc_core::Call::FromC(call)->compression_for_level(level); + return grpc_core::Call::FromC(call) + ->encodings_accepted_by_peer() + .CompressionAlgorithmForLevel(level); } bool grpc_call_is_trailers_only(const grpc_call* call) { diff --git a/src/core/lib/surface/call.h b/src/core/lib/surface/call.h index ea4d3937157..61176bbe932 100644 --- a/src/core/lib/surface/call.h +++ b/src/core/lib/surface/call.h @@ -119,6 +119,11 @@ class CallContext { // TODO(ctiller): remove this once transport APIs are promise based void Unref(const char* reason = "call_context"); + RefCountedPtr Ref() { + IncrementRefCount(); + return RefCountedPtr(this); + } + grpc_call_stats* call_stats() { return &call_stats_; } gpr_atm* peer_string_atm_ptr(); grpc_polling_entity* polling_entity() { return &pollent_; } diff --git a/src/core/lib/surface/lame_client.cc b/src/core/lib/surface/lame_client.cc index ecbc9eed098..7fbdf8e64b4 100644 --- a/src/core/lib/surface/lame_client.cc +++ b/src/core/lib/surface/lame_client.cc @@ -79,6 +79,7 @@ ArenaPromise LameClientFilter::MakeCallPromise( if (args.server_to_client_messages != nullptr) { args.server_to_client_messages->Close(); } + args.client_initial_metadata_outstanding.Complete(true); return Immediate(ServerMetadataFromStatus(error_)); } diff --git a/src/core/lib/transport/batch_builder.cc b/src/core/lib/transport/batch_builder.cc new file mode 100644 index 00000000000..06d8c0a72f9 --- /dev/null +++ b/src/core/lib/transport/batch_builder.cc @@ -0,0 +1,179 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "src/core/lib/transport/batch_builder.h" + +#include + +#include "src/core/lib/promise/poll.h" +#include "src/core/lib/slice/slice.h" +#include "src/core/lib/surface/call_trace.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/transport.h" +#include "src/core/lib/transport/transport_impl.h" + +namespace grpc_core { + +BatchBuilder::BatchBuilder(grpc_transport_stream_op_batch_payload* payload) + : payload_(payload) {} + +void BatchBuilder::PendingCompletion::CompletionCallback( + void* self, grpc_error_handle error) { + auto* pc = static_cast(self); + auto* party = pc->batch->party.get(); + if (grpc_call_trace.enabled()) { + gpr_log( + GPR_DEBUG, "%s[connected] Finish batch-component %s for %s: status=%s", + party->DebugTag().c_str(), std::string(pc->name()).c_str(), + grpc_transport_stream_op_batch_string(&pc->batch->batch, false).c_str(), + error.ToString().c_str()); + } + party->Spawn( + "batch-completion", + [pc, error = std::move(error)]() mutable { + RefCountedPtr batch = std::exchange(pc->batch, nullptr); + pc->done_latch.Set(std::move(error)); + return Empty{}; + }, + [](Empty) {}); +} + +BatchBuilder::PendingCompletion::PendingCompletion(RefCountedPtr batch) + : batch(std::move(batch)) { + GRPC_CLOSURE_INIT(&on_done_closure, CompletionCallback, this, nullptr); +} + +BatchBuilder::Batch::Batch(grpc_transport_stream_op_batch_payload* payload, + grpc_stream_refcount* stream_refcount) + : party(static_cast(Activity::current())->Ref()), + stream_refcount(stream_refcount) { + batch.payload = payload; + batch.is_traced = GetContext()->traced(); +#ifndef NDEBUG + grpc_stream_ref(stream_refcount, "pending-batch"); +#else + grpc_stream_ref(stream_refcount); +#endif +} + +BatchBuilder::Batch::~Batch() { + auto* arena = party->arena(); + if (pending_receive_message != nullptr) { + arena->DeletePooled(pending_receive_message); + } + if (pending_receive_initial_metadata != nullptr) { + arena->DeletePooled(pending_receive_initial_metadata); + } + if (pending_receive_trailing_metadata != nullptr) { + arena->DeletePooled(pending_receive_trailing_metadata); + } + if (pending_sends != nullptr) { + arena->DeletePooled(pending_sends); + } + if (batch.cancel_stream) { + arena->DeletePooled(batch.payload); + } +#ifndef NDEBUG + grpc_stream_unref(stream_refcount, "pending-batch"); +#else + grpc_stream_unref(stream_refcount); +#endif +} + +BatchBuilder::Batch* BatchBuilder::GetBatch(Target target) { + if (target_.has_value() && + (target_->stream != target.stream || + target.transport->vtable + ->hacky_disable_stream_op_batch_coalescing_in_connected_channel)) { + FlushBatch(); + } + if (!target_.has_value()) { + target_ = target; + batch_ = GetContext()->NewPooled(payload_, + target_->stream_refcount); + } + GPR_ASSERT(batch_ != nullptr); + return batch_; +} + +void BatchBuilder::FlushBatch() { + GPR_ASSERT(batch_ != nullptr); + GPR_ASSERT(target_.has_value()); + if (grpc_call_trace.enabled()) { + gpr_log( + GPR_DEBUG, "%s[connected] Perform transport stream op batch: %p %s", + batch_->party->DebugTag().c_str(), &batch_->batch, + grpc_transport_stream_op_batch_string(&batch_->batch, false).c_str()); + } + std::exchange(batch_, nullptr)->PerformWith(*target_); + target_.reset(); +} + +void BatchBuilder::Batch::PerformWith(Target target) { + grpc_transport_perform_stream_op(target.transport, target.stream, &batch); +} + +ServerMetadataHandle BatchBuilder::CompleteSendServerTrailingMetadata( + ServerMetadataHandle sent_metadata, absl::Status send_result, + bool actually_sent) { + if (!send_result.ok()) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, + "%s[connected] Send metadata failed with error: %s, " + "fabricating trailing metadata", + Activity::current()->DebugTag().c_str(), + send_result.ToString().c_str()); + } + sent_metadata->Clear(); + sent_metadata->Set(GrpcStatusMetadata(), + static_cast(send_result.code())); + sent_metadata->Set(GrpcMessageMetadata(), + Slice::FromCopiedString(send_result.message())); + sent_metadata->Set(GrpcCallWasCancelled(), true); + } + if (!sent_metadata->get(GrpcCallWasCancelled()).has_value()) { + if (grpc_call_trace.enabled()) { + gpr_log( + GPR_DEBUG, + "%s[connected] Tagging trailing metadata with " + "cancellation status from transport: %s", + Activity::current()->DebugTag().c_str(), + actually_sent ? "sent => not-cancelled" : "not-sent => cancelled"); + } + sent_metadata->Set(GrpcCallWasCancelled(), !actually_sent); + } + return sent_metadata; +} + +BatchBuilder::Batch* BatchBuilder::MakeCancel( + grpc_stream_refcount* stream_refcount, absl::Status status) { + auto* arena = GetContext(); + auto* payload = + arena->NewPooled(nullptr); + auto* batch = arena->NewPooled(payload, stream_refcount); + batch->batch.cancel_stream = true; + payload->cancel_stream.cancel_error = std::move(status); + return batch; +} + +void BatchBuilder::Cancel(Target target, absl::Status status) { + auto* batch = MakeCancel(target.stream_refcount, std::move(status)); + batch->batch.on_complete = NewClosure( + [batch](absl::Status) { batch->party->arena()->DeletePooled(batch); }); + batch->PerformWith(target); +} + +} // namespace grpc_core diff --git a/src/core/lib/transport/batch_builder.h b/src/core/lib/transport/batch_builder.h new file mode 100644 index 00000000000..5b0056aebe7 --- /dev/null +++ b/src/core/lib/transport/batch_builder.h @@ -0,0 +1,468 @@ +// Copyright 2023 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRPC_SRC_CORE_LIB_TRANSPORT_BATCH_BUILDER_H +#define GRPC_SRC_CORE_LIB_TRANSPORT_BATCH_BUILDER_H + +#include + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +#include +#include + +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/gprpp/status_helper.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/error.h" +#include "src/core/lib/promise/activity.h" +#include "src/core/lib/promise/context.h" +#include "src/core/lib/promise/latch.h" +#include "src/core/lib/promise/map.h" +#include "src/core/lib/promise/party.h" +#include "src/core/lib/resource_quota/arena.h" +#include "src/core/lib/slice/slice_buffer.h" +#include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/call_trace.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/transport.h" +#include "src/core/lib/transport/transport_fwd.h" + +namespace grpc_core { + +// Build up a transport stream op batch for a stream for a promise based +// connected channel. +// Offered as a context from Call, so that it can collect ALL the updates during +// a single party round, and then push them down to the transport as a single +// transaction. +class BatchBuilder { + public: + explicit BatchBuilder(grpc_transport_stream_op_batch_payload* payload); + ~BatchBuilder() { + if (batch_ != nullptr) FlushBatch(); + } + + struct Target { + grpc_transport* transport; + grpc_stream* stream; + grpc_stream_refcount* stream_refcount; + }; + + BatchBuilder(const BatchBuilder&) = delete; + BatchBuilder& operator=(const BatchBuilder&) = delete; + + // Returns a promise that will resolve to a Status when the send is completed. + auto SendMessage(Target target, MessageHandle message); + + // Returns a promise that will resolve to a Status when the send is completed. + auto SendClientInitialMetadata(Target target, ClientMetadataHandle metadata); + + // Returns a promise that will resolve to a Status when the send is completed. + auto SendClientTrailingMetadata(Target target); + + // Returns a promise that will resolve to a Status when the send is completed. + auto SendServerInitialMetadata(Target target, ServerMetadataHandle metadata); + + // Returns a promise that will resolve to a ServerMetadataHandle when the send + // is completed. + // + // If convert_to_cancellation is true, then the status will be converted to a + // cancellation batch instead of a trailing metadata op in a coalesced batch. + // + // This quirk exists as in the filter based stack upon which our transports + // were written if a trailing metadata op were sent it always needed to be + // paired with an initial op batch, and the transports would wait for the + // initial metadata batch to arrive (in case of reordering up the stack). + auto SendServerTrailingMetadata(Target target, ServerMetadataHandle metadata, + bool convert_to_cancellation); + + // Returns a promise that will resolve to a StatusOr> + // when a message is received. + // Error => non-ok status + // End of stream => Ok, nullopt (no message) + // Message => Ok, message + auto ReceiveMessage(Target target); + + // Returns a promise that will resolve to a StatusOr + // when the receive is complete. + auto ReceiveClientInitialMetadata(Target target); + + // Returns a promise that will resolve to a StatusOr + // when the receive is complete. + auto ReceiveClientTrailingMetadata(Target target); + + // Returns a promise that will resolve to a StatusOr + // when the receive is complete. + auto ReceiveServerInitialMetadata(Target target); + + // Returns a promise that will resolve to a StatusOr + // when the receive is complete. + auto ReceiveServerTrailingMetadata(Target target); + + // Send a cancellation: does not occupy the same payload, nor does it + // coalesce with other ops. + void Cancel(Target target, absl::Status status); + + private: + struct Batch; + + // Base pending operation + struct PendingCompletion { + explicit PendingCompletion(RefCountedPtr batch); + virtual absl::string_view name() const = 0; + static void CompletionCallback(void* self, grpc_error_handle error); + grpc_closure on_done_closure; + Latch done_latch; + RefCountedPtr batch; + + protected: + ~PendingCompletion() = default; + }; + + // A pending receive message. + struct PendingReceiveMessage final : public PendingCompletion { + using PendingCompletion::PendingCompletion; + + absl::string_view name() const override { return "receive_message"; } + + MessageHandle IntoMessageHandle() { + return GetContext()->MakePooled(std::move(*payload), + flags); + } + + absl::optional payload; + uint32_t flags; + }; + + // A pending receive metadata. + struct PendingReceiveMetadata : public PendingCompletion { + using PendingCompletion::PendingCompletion; + + Arena::PoolPtr metadata = + GetContext()->MakePooled( + GetContext()); + + protected: + ~PendingReceiveMetadata() = default; + }; + + struct PendingReceiveInitialMetadata final : public PendingReceiveMetadata { + using PendingReceiveMetadata::PendingReceiveMetadata; + absl::string_view name() const override { + return "receive_initial_metadata"; + } + }; + + struct PendingReceiveTrailingMetadata final : public PendingReceiveMetadata { + using PendingReceiveMetadata::PendingReceiveMetadata; + absl::string_view name() const override { + return "receive_trailing_metadata"; + } + }; + + // Pending sends in a batch + struct PendingSends final : public PendingCompletion { + using PendingCompletion::PendingCompletion; + + absl::string_view name() const override { return "sends"; } + + MessageHandle send_message; + Arena::PoolPtr send_initial_metadata; + Arena::PoolPtr send_trailing_metadata; + bool trailing_metadata_sent = false; + }; + + // One outstanding batch. + struct Batch final { + Batch(grpc_transport_stream_op_batch_payload* payload, + grpc_stream_refcount* stream_refcount); + ~Batch(); + Batch(const Batch&) = delete; + Batch& operator=(const Batch&) = delete; + void IncrementRefCount() { ++refs; } + void Unref() { + if (--refs == 0) party->arena()->DeletePooled(this); + } + RefCountedPtr Ref() { + IncrementRefCount(); + return RefCountedPtr(this); + } + // Get an initialized pending completion. + // There are four pending completions potentially contained within a batch. + // They can be rather large so we don't create all of them always. Instead, + // we dynamically create them on the arena as needed. + // This method either returns the existing completion in a batch if that + // completion has already been initialized, or it creates a new completion + // and returns that. + template + T* GetInitializedCompletion(T*(Batch::*field)) { + if (this->*field != nullptr) return this->*field; + this->*field = party->arena()->NewPooled(Ref()); + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Add batch closure for %s @ %s", + Activity::current()->DebugTag().c_str(), + std::string((this->*field)->name()).c_str(), + (this->*field)->on_done_closure.DebugString().c_str()); + } + return this->*field; + } + // grpc_transport_perform_stream_op on target.stream + void PerformWith(Target target); + // Take a promise, and return a promise that holds a ref on this batch until + // the promise completes or is cancelled. + template + auto RefUntil(P promise) { + return [self = Ref(), promise = std::move(promise)]() mutable { + return promise(); + }; + } + + grpc_transport_stream_op_batch batch; + PendingReceiveMessage* pending_receive_message = nullptr; + PendingReceiveInitialMetadata* pending_receive_initial_metadata = nullptr; + PendingReceiveTrailingMetadata* pending_receive_trailing_metadata = nullptr; + PendingSends* pending_sends = nullptr; + const RefCountedPtr party; + grpc_stream_refcount* const stream_refcount; + uint8_t refs = 0; + }; + + // Get a batch for the given target. + // Currently: if the current batch is for this target, return it - otherwise + // flush the batch and start a new one (and return that). + // This function may change in the future to allow multiple batches to be + // building at once (if that turns out to be useful for hedging). + Batch* GetBatch(Target target); + // Flush the current batch down to the transport. + void FlushBatch(); + // Create a cancel batch with its own payload. + Batch* MakeCancel(grpc_stream_refcount* stream_refcount, absl::Status status); + + // Note: we don't distinguish between client and server metadata here. + // At the time of writing they're both the same thing - and it's unclear + // whether we'll get to separate them prior to batches going away or not. + // So for now we claim YAGNI and just do the simplest possible implementation. + auto SendInitialMetadata(Target target, + Arena::PoolPtr md); + auto ReceiveInitialMetadata(Target target); + auto ReceiveTrailingMetadata(Target target); + + // Combine send status and server metadata into a final status to report back + // to the containing call. + static ServerMetadataHandle CompleteSendServerTrailingMetadata( + ServerMetadataHandle sent_metadata, absl::Status send_result, + bool actually_sent); + + grpc_transport_stream_op_batch_payload* const payload_; + absl::optional target_; + Batch* batch_ = nullptr; +}; + +inline auto BatchBuilder::SendMessage(Target target, MessageHandle message) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Queue send message: %s", + Activity::current()->DebugTag().c_str(), + message->DebugString().c_str()); + } + auto* batch = GetBatch(target); + auto* pc = batch->GetInitializedCompletion(&Batch::pending_sends); + batch->batch.on_complete = &pc->on_done_closure; + batch->batch.send_message = true; + payload_->send_message.send_message = message->payload(); + payload_->send_message.flags = message->flags(); + pc->send_message = std::move(message); + return batch->RefUntil(pc->done_latch.WaitAndCopy()); +} + +inline auto BatchBuilder::SendInitialMetadata( + Target target, Arena::PoolPtr md) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Queue send initial metadata: %s", + Activity::current()->DebugTag().c_str(), md->DebugString().c_str()); + } + auto* batch = GetBatch(target); + auto* pc = batch->GetInitializedCompletion(&Batch::pending_sends); + batch->batch.on_complete = &pc->on_done_closure; + batch->batch.send_initial_metadata = true; + payload_->send_initial_metadata.send_initial_metadata = md.get(); + pc->send_initial_metadata = std::move(md); + return batch->RefUntil(pc->done_latch.WaitAndCopy()); +} + +inline auto BatchBuilder::SendClientInitialMetadata( + Target target, ClientMetadataHandle metadata) { + return SendInitialMetadata(target, std::move(metadata)); +} + +inline auto BatchBuilder::SendClientTrailingMetadata(Target target) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Queue send trailing metadata", + Activity::current()->DebugTag().c_str()); + } + auto* batch = GetBatch(target); + auto* pc = batch->GetInitializedCompletion(&Batch::pending_sends); + batch->batch.on_complete = &pc->on_done_closure; + batch->batch.send_trailing_metadata = true; + auto metadata = + GetContext()->MakePooled(GetContext()); + payload_->send_trailing_metadata.send_trailing_metadata = metadata.get(); + payload_->send_trailing_metadata.sent = nullptr; + pc->send_trailing_metadata = std::move(metadata); + return batch->RefUntil(pc->done_latch.WaitAndCopy()); +} + +inline auto BatchBuilder::SendServerInitialMetadata( + Target target, ServerMetadataHandle metadata) { + return SendInitialMetadata(target, std::move(metadata)); +} + +inline auto BatchBuilder::SendServerTrailingMetadata( + Target target, ServerMetadataHandle metadata, + bool convert_to_cancellation) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] %s: %s", + Activity::current()->DebugTag().c_str(), + convert_to_cancellation ? "Send trailing metadata as cancellation" + : "Queue send trailing metadata", + metadata->DebugString().c_str()); + } + Batch* batch; + PendingSends* pc; + if (convert_to_cancellation) { + const auto status_code = + metadata->get(GrpcStatusMetadata()).value_or(GRPC_STATUS_UNKNOWN); + auto status = grpc_error_set_int( + absl::Status(static_cast(status_code), + metadata->GetOrCreatePointer(GrpcMessageMetadata()) + ->as_string_view()), + StatusIntProperty::kRpcStatus, status_code); + batch = MakeCancel(target.stream_refcount, std::move(status)); + pc = batch->GetInitializedCompletion(&Batch::pending_sends); + } else { + batch = GetBatch(target); + pc = batch->GetInitializedCompletion(&Batch::pending_sends); + batch->batch.send_trailing_metadata = true; + payload_->send_trailing_metadata.send_trailing_metadata = metadata.get(); + payload_->send_trailing_metadata.sent = &pc->trailing_metadata_sent; + } + batch->batch.on_complete = &pc->on_done_closure; + pc->send_trailing_metadata = std::move(metadata); + auto promise = batch->RefUntil( + Map(pc->done_latch.WaitAndCopy(), [pc](absl::Status status) { + return CompleteSendServerTrailingMetadata( + std::move(pc->send_trailing_metadata), std::move(status), + pc->trailing_metadata_sent); + })); + if (convert_to_cancellation) { + batch->PerformWith(target); + } + return promise; +} + +inline auto BatchBuilder::ReceiveMessage(Target target) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Queue receive message", + Activity::current()->DebugTag().c_str()); + } + auto* batch = GetBatch(target); + auto* pc = batch->GetInitializedCompletion(&Batch::pending_receive_message); + batch->batch.recv_message = true; + payload_->recv_message.recv_message_ready = &pc->on_done_closure; + payload_->recv_message.recv_message = &pc->payload; + payload_->recv_message.flags = &pc->flags; + return batch->RefUntil( + Map(pc->done_latch.Wait(), + [pc](absl::Status status) + -> absl::StatusOr> { + if (!status.ok()) return status; + if (!pc->payload.has_value()) return absl::nullopt; + return pc->IntoMessageHandle(); + })); +} + +inline auto BatchBuilder::ReceiveInitialMetadata(Target target) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Queue receive initial metadata", + Activity::current()->DebugTag().c_str()); + } + auto* batch = GetBatch(target); + auto* pc = + batch->GetInitializedCompletion(&Batch::pending_receive_initial_metadata); + batch->batch.recv_initial_metadata = true; + payload_->recv_initial_metadata.recv_initial_metadata_ready = + &pc->on_done_closure; + payload_->recv_initial_metadata.recv_initial_metadata = pc->metadata.get(); + return batch->RefUntil( + Map(pc->done_latch.Wait(), + [pc](absl::Status status) -> absl::StatusOr { + if (!status.ok()) return status; + return std::move(pc->metadata); + })); +} + +inline auto BatchBuilder::ReceiveClientInitialMetadata(Target target) { + return ReceiveInitialMetadata(target); +} + +inline auto BatchBuilder::ReceiveServerInitialMetadata(Target target) { + return ReceiveInitialMetadata(target); +} + +inline auto BatchBuilder::ReceiveTrailingMetadata(Target target) { + if (grpc_call_trace.enabled()) { + gpr_log(GPR_DEBUG, "%s[connected] Queue receive trailing metadata", + Activity::current()->DebugTag().c_str()); + } + auto* batch = GetBatch(target); + auto* pc = batch->GetInitializedCompletion( + &Batch::pending_receive_trailing_metadata); + batch->batch.recv_trailing_metadata = true; + payload_->recv_trailing_metadata.recv_trailing_metadata_ready = + &pc->on_done_closure; + payload_->recv_trailing_metadata.recv_trailing_metadata = pc->metadata.get(); + payload_->recv_trailing_metadata.collect_stats = + &GetContext()->call_stats()->transport_stream_stats; + return batch->RefUntil( + Map(pc->done_latch.Wait(), + [pc](absl::Status status) -> absl::StatusOr { + if (!status.ok()) return status; + return std::move(pc->metadata); + })); +} + +inline auto BatchBuilder::ReceiveClientTrailingMetadata(Target target) { + return ReceiveTrailingMetadata(target); +} + +inline auto BatchBuilder::ReceiveServerTrailingMetadata(Target target) { + return ReceiveTrailingMetadata(target); +} + +template <> +struct ContextType {}; + +} // namespace grpc_core + +#endif // GRPC_SRC_CORE_LIB_TRANSPORT_BATCH_BUILDER_H diff --git a/src/core/lib/transport/metadata_batch.h b/src/core/lib/transport/metadata_batch.h index 5b4f6d972f7..493a9f9d94b 100644 --- a/src/core/lib/transport/metadata_batch.h +++ b/src/core/lib/transport/metadata_batch.h @@ -441,6 +441,15 @@ struct GrpcStatusFromWire { static absl::string_view DisplayValue(bool x) { return x ? "true" : "false"; } }; +// Annotation to denote that this call qualifies for cancelled=1 for the +// RECV_CLOSE_ON_SERVER op +struct GrpcCallWasCancelled { + static absl::string_view DebugKey() { return "GrpcCallWasCancelled"; } + static constexpr bool kRepeatable = false; + using ValueType = bool; + static absl::string_view DisplayValue(bool x) { return x ? "true" : "false"; } +}; + // Annotation added by client surface code to denote wait-for-ready state struct WaitForReady { struct ValueType { @@ -1378,7 +1387,8 @@ using grpc_metadata_batch_base = grpc_core::MetadataMap< // Non-encodable things grpc_core::GrpcStreamNetworkState, grpc_core::PeerString, grpc_core::GrpcStatusContext, grpc_core::GrpcStatusFromWire, - grpc_core::WaitForReady, grpc_core::GrpcTrailersOnly>; + grpc_core::GrpcCallWasCancelled, grpc_core::WaitForReady, + grpc_core::GrpcTrailersOnly>; struct grpc_metadata_batch : public grpc_metadata_batch_base { using grpc_metadata_batch_base::grpc_metadata_batch_base; diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc index 6e2e48ac24c..055d298780f 100644 --- a/src/core/lib/transport/transport.cc +++ b/src/core/lib/transport/transport.cc @@ -26,13 +26,17 @@ #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include +#include #include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gpr/alloc.h" +#include "src/core/lib/gprpp/time.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/slice/slice.h" +#include "src/core/lib/transport/error_utils.h" #include "src/core/lib/transport/transport_impl.h" grpc_core::DebugOnlyTraceFlag grpc_trace_stream_refcount(false, @@ -271,11 +275,35 @@ namespace grpc_core { ServerMetadataHandle ServerMetadataFromStatus(const absl::Status& status, Arena* arena) { auto hdl = arena->MakePooled(arena); - hdl->Set(GrpcStatusMetadata(), static_cast(status.code())); + grpc_status_code code; + std::string message; + grpc_error_get_status(status, Timestamp::InfFuture(), &code, &message, + nullptr, nullptr); + hdl->Set(GrpcStatusMetadata(), code); if (!status.ok()) { - hdl->Set(GrpcMessageMetadata(), Slice::FromCopiedString(status.message())); + hdl->Set(GrpcMessageMetadata(), Slice::FromCopiedString(message)); } return hdl; } +std::string Message::DebugString() const { + std::string out = absl::StrCat(payload_.Length(), "b"); + auto flags = flags_; + auto explain = [&flags, &out](uint32_t flag, absl::string_view name) { + if (flags & flag) { + flags &= ~flag; + absl::StrAppend(&out, ":", name); + } + }; + explain(GRPC_WRITE_BUFFER_HINT, "write_buffer"); + explain(GRPC_WRITE_NO_COMPRESS, "no_compress"); + explain(GRPC_WRITE_THROUGH, "write_through"); + explain(GRPC_WRITE_INTERNAL_COMPRESS, "compress"); + explain(GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED, "was_compressed"); + if (flags != 0) { + absl::StrAppend(&out, ":huh=0x", absl::Hex(flags)); + } + return out; +} + } // namespace grpc_core diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index af5611e9a49..d74aa477aca 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -27,6 +27,7 @@ #include #include +#include #include #include "absl/status/status.h" @@ -53,6 +54,7 @@ #include "src/core/lib/promise/arena_promise.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/detail/status.h" +#include "src/core/lib/promise/latch.h" #include "src/core/lib/promise/pipe.h" #include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/slice/slice_buffer.h" @@ -105,6 +107,8 @@ class Message { SliceBuffer* payload() { return &payload_; } const SliceBuffer* payload() const { return &payload_; } + std::string DebugString() const; + private: SliceBuffer payload_; uint32_t flags_ = 0; @@ -143,11 +147,70 @@ struct StatusCastImpl { } }; +// Move only type that tracks call startup. +// Allows observation of when client_initial_metadata has been processed by the +// end of the local call stack. +// Interested observers can call Wait() to obtain a promise that will resolve +// when all local client_initial_metadata processing has completed. +// The result of this token is either true on successful completion, or false +// if the metadata was not sent. +// To set a successful completion, call Complete(true). For failure, call +// Complete(false). +// If Complete is not called, the destructor of a still held token will complete +// with failure. +// Transports should hold this token until client_initial_metadata has passed +// any flow control (eg MAX_CONCURRENT_STREAMS for http2). +class ClientInitialMetadataOutstandingToken { + public: + static ClientInitialMetadataOutstandingToken Empty() { + return ClientInitialMetadataOutstandingToken(); + } + static ClientInitialMetadataOutstandingToken New( + Arena* arena = GetContext()) { + ClientInitialMetadataOutstandingToken token; + token.latch_ = arena->New>(); + return token; + } + + ClientInitialMetadataOutstandingToken( + const ClientInitialMetadataOutstandingToken&) = delete; + ClientInitialMetadataOutstandingToken& operator=( + const ClientInitialMetadataOutstandingToken&) = delete; + ClientInitialMetadataOutstandingToken( + ClientInitialMetadataOutstandingToken&& other) noexcept + : latch_(std::exchange(other.latch_, nullptr)) {} + ClientInitialMetadataOutstandingToken& operator=( + ClientInitialMetadataOutstandingToken&& other) noexcept { + latch_ = std::exchange(other.latch_, nullptr); + return *this; + } + ~ClientInitialMetadataOutstandingToken() { + if (latch_ != nullptr) latch_->Set(false); + } + void Complete(bool success) { std::exchange(latch_, nullptr)->Set(success); } + + // Returns a promise that will resolve when this object (or its moved-from + // ancestor) is dropped. + auto Wait() { return latch_->Wait(); } + + private: + ClientInitialMetadataOutstandingToken() = default; + + Latch* latch_ = nullptr; +}; + +using ClientInitialMetadataOutstandingTokenWaitType = + decltype(std::declval().Wait()); + struct CallArgs { // Initial metadata from the client to the server. // During promise setup this can be manipulated by filters (and then // passed on to the next filter). ClientMetadataHandle client_initial_metadata; + // Token indicating that client_initial_metadata is still being processed. + // This should be moved around and only destroyed when the transport is + // satisfied that the metadata has passed any flow control measures it has. + ClientInitialMetadataOutstandingToken client_initial_metadata_outstanding; // Initial metadata from the server to the client. // Set once when it's available. // During promise setup filters can substitute their own latch for this @@ -330,6 +393,12 @@ struct grpc_transport_stream_op_batch { /// Is this stream traced bool is_traced : 1; + bool HasOp() const { + return send_initial_metadata || send_trailing_metadata || send_message || + recv_initial_metadata || recv_message || recv_trailing_metadata || + cancel_stream; + } + //************************************************************************** // remaining fields are initialized and used at the discretion of the // current handler of the op diff --git a/src/core/lib/transport/transport_impl.h b/src/core/lib/transport/transport_impl.h index d6d912260ab..0f5eecef47e 100644 --- a/src/core/lib/transport/transport_impl.h +++ b/src/core/lib/transport/transport_impl.h @@ -38,6 +38,13 @@ typedef struct grpc_transport_vtable { // layers and initialized by the transport size_t sizeof_stream; // = sizeof(transport stream) + // HACK: inproc does not handle stream op batch callbacks correctly (receive + // ops are required to complete prior to on_complete triggering). + // This flag is used to disable coalescing of batches in connected_channel for + // that specific transport. + // TODO(ctiller): This ought not be necessary once we have promises complete. + bool hacky_disable_stream_op_batch_coalescing_in_connected_channel; + // name of this transport implementation const char* name; diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 83b622e0b38..ccff0cb8be0 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -661,6 +661,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/load_balancing/lb_policy_registry.cc', 'src/core/lib/matchers/matchers.cc', 'src/core/lib/promise/activity.cc', + 'src/core/lib/promise/party.cc', 'src/core/lib/promise/sleep.cc', 'src/core/lib/promise/trace.cc', 'src/core/lib/resolver/resolver.cc', @@ -764,6 +765,7 @@ CORE_SOURCE_FILES = [ 'src/core/lib/surface/server.cc', 'src/core/lib/surface/validate_metadata.cc', 'src/core/lib/surface/version.cc', + 'src/core/lib/transport/batch_builder.cc', 'src/core/lib/transport/bdp_estimator.cc', 'src/core/lib/transport/connectivity_state.cc', 'src/core/lib/transport/error_utils.cc', diff --git a/test/core/end2end/cq_verifier.cc b/test/core/end2end/cq_verifier.cc index 16b7551d950..0be24f53d57 100644 --- a/test/core/end2end/cq_verifier.cc +++ b/test/core/end2end/cq_verifier.cc @@ -28,13 +28,16 @@ #include #include +#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include #include #include +#include #include #include #include @@ -120,6 +123,35 @@ int raw_byte_buffer_eq_slice(grpc_byte_buffer* rbb, grpc_slice b) { ok = GRPC_SLICE_LENGTH(a) == GRPC_SLICE_LENGTH(b) && 0 == memcmp(GRPC_SLICE_START_PTR(a), GRPC_SLICE_START_PTR(b), GRPC_SLICE_LENGTH(a)); + if (!ok) { + gpr_log(GPR_ERROR, + "SLICE MISMATCH: left_length=%" PRIuPTR " right_length=%" PRIuPTR, + GRPC_SLICE_LENGTH(a), GRPC_SLICE_LENGTH(b)); + std::string out; + const char* a_str = reinterpret_cast(GRPC_SLICE_START_PTR(a)); + const char* b_str = reinterpret_cast(GRPC_SLICE_START_PTR(b)); + for (size_t i = 0; i < std::max(GRPC_SLICE_LENGTH(a), GRPC_SLICE_LENGTH(b)); + i++) { + if (i >= GRPC_SLICE_LENGTH(a)) { + absl::StrAppend(&out, "\u001b[36m", // cyan + absl::CEscape(absl::string_view(&b_str[i], 1)), + "\u001b[0m"); + } else if (i >= GRPC_SLICE_LENGTH(b)) { + absl::StrAppend(&out, "\u001b[35m", // magenta + absl::CEscape(absl::string_view(&a_str[i], 1)), + "\u001b[0m"); + } else if (a_str[i] == b_str[i]) { + absl::StrAppend(&out, absl::CEscape(absl::string_view(&a_str[i], 1))); + } else { + absl::StrAppend(&out, "\u001b[31m", // red + absl::CEscape(absl::string_view(&a_str[i], 1)), + "\u001b[33m", // yellow + absl::CEscape(absl::string_view(&b_str[i], 1)), + "\u001b[0m"); + } + gpr_log(GPR_ERROR, "%s", out.c_str()); + } + } grpc_slice_unref(a); grpc_slice_unref(b); return ok; diff --git a/test/core/end2end/fixtures/h2_oauth2_common.h b/test/core/end2end/fixtures/h2_oauth2_common.h index ae9103df4b6..bd54b59095c 100644 --- a/test/core/end2end/fixtures/h2_oauth2_common.h +++ b/test/core/end2end/fixtures/h2_oauth2_common.h @@ -65,17 +65,13 @@ class Oauth2Fixture : public SecureFixture { return nullptr; } - static void process_oauth2_success(void*, grpc_auth_context* ctx, + static void process_oauth2_success(void*, grpc_auth_context*, const grpc_metadata* md, size_t md_count, grpc_process_auth_metadata_done_cb cb, void* user_data) { const grpc_metadata* oauth2 = find_metadata(md, md_count, "authorization", oauth2_md()); GPR_ASSERT(oauth2 != nullptr); - grpc_auth_context_add_cstring_property(ctx, client_identity_property_name(), - client_identity()); - GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - ctx, client_identity_property_name()) == 1); cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr); } diff --git a/test/core/end2end/fixtures/proxy.cc b/test/core/end2end/fixtures/proxy.cc index c1b7da1a446..8e99c5b73a7 100644 --- a/test/core/end2end/fixtures/proxy.cc +++ b/test/core/end2end/fixtures/proxy.cc @@ -210,7 +210,7 @@ static void on_p2s_sent_message(void* arg, int success) { grpc_op op; grpc_call_error err; - grpc_byte_buffer_destroy(pc->c2p_msg); + grpc_byte_buffer_destroy(std::exchange(pc->c2p_msg, nullptr)); if (!pc->proxy->shutdown && success) { op.op = GRPC_OP_RECV_MESSAGE; op.flags = 0; diff --git a/test/core/end2end/tests/filter_init_fails.cc b/test/core/end2end/tests/filter_init_fails.cc index ddae01f9f10..753d7cd7294 100644 --- a/test/core/end2end/tests/filter_init_fails.cc +++ b/test/core/end2end/tests/filter_init_fails.cc @@ -42,7 +42,10 @@ #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/error.h" +#include "src/core/lib/promise/arena_promise.h" +#include "src/core/lib/promise/promise.h" #include "src/core/lib/surface/channel_stack_type.h" +#include "src/core/lib/transport/transport.h" #include "test/core/end2end/cq_verifier.h" #include "test/core/end2end/end2end_tests.h" #include "test/core/util/test_config.h" @@ -397,12 +400,23 @@ static grpc_error_handle init_channel_elem( static void destroy_channel_elem(grpc_channel_element* /*elem*/) {} static const grpc_channel_filter test_filter = { - grpc_call_next_op, nullptr, - grpc_channel_next_op, 0, - init_call_elem, grpc_call_stack_ignore_set_pollset_or_pollset_set, - destroy_call_elem, 0, - init_channel_elem, grpc_channel_stack_no_post_init, - destroy_channel_elem, grpc_channel_next_get_info, + grpc_call_next_op, + [](grpc_channel_element*, grpc_core::CallArgs, + grpc_core::NextPromiseFactory) + -> grpc_core::ArenaPromise { + return grpc_core::Immediate(grpc_core::ServerMetadataFromStatus( + absl::PermissionDeniedError("access denied"))); + }, + grpc_channel_next_op, + 0, + init_call_elem, + grpc_call_stack_ignore_set_pollset_or_pollset_set, + destroy_call_elem, + 0, + init_channel_elem, + grpc_channel_stack_no_post_init, + destroy_channel_elem, + grpc_channel_next_get_info, "filter_init_fails"}; //****************************************************************************** diff --git a/test/core/end2end/tests/max_message_length.cc b/test/core/end2end/tests/max_message_length.cc index 3127db53928..e57cc949cfe 100644 --- a/test/core/end2end/tests/max_message_length.cc +++ b/test/core/end2end/tests/max_message_length.cc @@ -82,6 +82,9 @@ static void test_max_message_length_on_request( grpc_status_code status; grpc_call_error error; grpc_slice details; + grpc_slice expect_in_details = grpc_slice_from_copied_string( + send_limit ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 5)"); int was_cancelled = 2; grpc_channel_args* client_args = nullptr; @@ -220,13 +223,10 @@ static void test_max_message_length_on_request( done: GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); - GPR_ASSERT( - grpc_slice_str_cmp( - details, send_limit - ? "Sent message larger than max (11 vs. 5)" - : "Received message larger than max (11 vs. 5)") == 0); + GPR_ASSERT(grpc_slice_slice(details, expect_in_details) >= 0); grpc_slice_unref(details); + grpc_slice_unref(expect_in_details); grpc_metadata_array_destroy(&initial_metadata_recv); grpc_metadata_array_destroy(&trailing_metadata_recv); grpc_metadata_array_destroy(&request_metadata_recv); @@ -265,6 +265,9 @@ static void test_max_message_length_on_response( grpc_status_code status; grpc_call_error error; grpc_slice details; + grpc_slice expect_in_details = grpc_slice_from_copied_string( + send_limit ? "Sent message larger than max (11 vs. 5)" + : "Received message larger than max (11 vs. 5)"); int was_cancelled = 2; grpc_channel_args* client_args = nullptr; @@ -404,13 +407,10 @@ static void test_max_message_length_on_response( GPR_ASSERT(0 == grpc_slice_str_cmp(call_details.method, "/service/method")); GPR_ASSERT(status == GRPC_STATUS_RESOURCE_EXHAUSTED); - GPR_ASSERT( - grpc_slice_str_cmp( - details, send_limit - ? "Sent message larger than max (11 vs. 5)" - : "Received message larger than max (11 vs. 5)") == 0); + GPR_ASSERT(grpc_slice_slice(details, expect_in_details) >= 0); grpc_slice_unref(details); + grpc_slice_unref(expect_in_details); grpc_metadata_array_destroy(&initial_metadata_recv); grpc_metadata_array_destroy(&trailing_metadata_recv); grpc_metadata_array_destroy(&request_metadata_recv); diff --git a/test/core/filters/client_auth_filter_test.cc b/test/core/filters/client_auth_filter_test.cc index 12f0e253ac4..e0711c453b6 100644 --- a/test/core/filters/client_auth_filter_test.cc +++ b/test/core/filters/client_auth_filter_test.cc @@ -154,7 +154,8 @@ TEST_F(ClientAuthFilterTest, CallCredsFails) { auto promise = filter->MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch_, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs /*call_args*/) { return ArenaPromise( [&]() -> Poll { @@ -183,7 +184,8 @@ TEST_F(ClientAuthFilterTest, RewritesInvalidStatusFromCallCreds) { auto promise = filter->MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch_, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs /*call_args*/) { return ArenaPromise( [&]() -> Poll { diff --git a/test/core/filters/client_authority_filter_test.cc b/test/core/filters/client_authority_filter_test.cc index ae86f71e413..df2656e637f 100644 --- a/test/core/filters/client_authority_filter_test.cc +++ b/test/core/filters/client_authority_filter_test.cc @@ -71,7 +71,8 @@ TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) { auto promise = filter.MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs call_args) { EXPECT_EQ(call_args.client_initial_metadata ->get_pointer(HttpAuthorityMetadata()) @@ -106,7 +107,8 @@ TEST(ClientAuthorityFilterTest, auto promise = filter.MakeCallPromise( CallArgs{ClientMetadataHandle(&initial_metadata_batch, Arena::PooledDeleter(nullptr)), - nullptr, nullptr, nullptr}, + ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr, + nullptr}, [&](CallArgs call_args) { EXPECT_EQ(call_args.client_initial_metadata ->get_pointer(HttpAuthorityMetadata()) diff --git a/test/core/filters/filter_fuzzer.cc b/test/core/filters/filter_fuzzer.cc index fe0ccda0625..ebf985351bc 100644 --- a/test/core/filters/filter_fuzzer.cc +++ b/test/core/filters/filter_fuzzer.cc @@ -110,6 +110,8 @@ namespace { const grpc_transport_vtable kFakeTransportVTable = { // sizeof_stream 0, + // hacky_disable_stream_op_batch_coalescing_in_connected_channel + false, // name "fake_transport", // init_stream @@ -402,16 +404,16 @@ class MainLoop { public: WakeCall(MainLoop* main_loop, uint32_t id) : main_loop_(main_loop), id_(id) {} - void Wakeup(void*) override { + void Wakeup(WakeupMask) override { for (const uint32_t already : main_loop_->wakeups_) { if (already == id_) return; } main_loop_->wakeups_.push_back(id_); delete this; } - void Drop(void*) override { delete this; } + void Drop(WakeupMask) override { delete this; } - std::string ActivityDebugTag(void*) const override { + std::string ActivityDebugTag(WakeupMask) const override { return "WakeCall(" + std::to_string(id_) + ")"; } @@ -476,6 +478,7 @@ class MainLoop { auto* server_initial_metadata = arena_->New>(); CallArgs call_args{std::move(*LoadMetadata(client_initial_metadata, &client_initial_metadata_)), + ClientInitialMetadataOutstandingToken::Empty(), &server_initial_metadata->sender, nullptr, nullptr}; if (is_client) { promise_ = main_loop_->channel_stack_->MakeClientCallPromise( @@ -524,9 +527,9 @@ class MainLoop { } void Orphan() override { abort(); } - void ForceImmediateRepoll() override { context_->set_continue(); } + void ForceImmediateRepoll(WakeupMask) override { context_->set_continue(); } Waker MakeOwningWaker() override { - return Waker(new WakeCall(main_loop_, id_), nullptr); + return Waker(new WakeCall(main_loop_, id_), 0); } Waker MakeNonOwningWaker() override { return MakeOwningWaker(); } diff --git a/test/core/gprpp/ref_counted_test.cc b/test/core/gprpp/ref_counted_test.cc index 990acf243fa..0d58d1338e9 100644 --- a/test/core/gprpp/ref_counted_test.cc +++ b/test/core/gprpp/ref_counted_test.cc @@ -53,7 +53,7 @@ TEST(RefCounted, ExtraRef) { foo->Unref(); } -class Value : public RefCounted { +class Value : public RefCounted { public: Value(int value, std::set>* registry) : value_(value) { registry->emplace(this); @@ -108,7 +108,7 @@ TEST(RefCounted, NoDeleteUponUnref) { class ValueInExternalAllocation : public RefCounted { + UnrefCallDtor> { public: explicit ValueInExternalAllocation(int value) : value_(value) {} diff --git a/test/core/gprpp/thd_test.cc b/test/core/gprpp/thd_test.cc index 7561965fd98..be10ca1e8cd 100644 --- a/test/core/gprpp/thd_test.cc +++ b/test/core/gprpp/thd_test.cc @@ -20,6 +20,8 @@ #include "src/core/lib/gprpp/thd.h" +#include + #include "gtest/gtest.h" #include @@ -49,7 +51,7 @@ static void thd_body1(void* v) { } // Test that we can create a number of threads, wait for them, and join them. -static void test1(void) { +TEST(ThreadTest, CanCreateWaitAndJoin) { grpc_core::Thread thds[NUM_THREADS]; struct test t; gpr_mu_init(&t.mu); @@ -76,7 +78,7 @@ static void test1(void) { static void thd_body2(void* /*v*/) {} // Test that we can create a number of threads and join them. -static void test2(void) { +TEST(ThreadTest, CanCreateSomeAndJoinThem) { grpc_core::Thread thds[NUM_THREADS]; for (auto& th : thds) { bool ok; @@ -89,11 +91,23 @@ static void test2(void) { } } -// ------------------------------------------------- - -TEST(ThdTest, MainTest) { - test1(); - test2(); +// Test that we can create a thread with an AnyInvocable. +TEST(ThreadTest, CanCreateWithAnyInvocable) { + grpc_core::Thread thds[NUM_THREADS]; + std::atomic count_run{0}; + for (auto& th : thds) { + bool ok; + th = grpc_core::Thread( + "grpc_thread_body2_test", + [&count_run]() { count_run.fetch_add(1, std::memory_order_relaxed); }, + &ok); + ASSERT_TRUE(ok); + th.Start(); + } + for (auto& th : thds) { + th.Join(); + } + EXPECT_EQ(count_run.load(std::memory_order_relaxed), NUM_THREADS); } int main(int argc, char** argv) { diff --git a/test/core/promise/BUILD b/test/core/promise/BUILD index 313276e62bf..baaade68659 100644 --- a/test/core/promise/BUILD +++ b/test/core/promise/BUILD @@ -127,7 +127,10 @@ grpc_cc_test( # is. name = "promise_map_test", srcs = ["map_test.cc"], - external_deps = ["gtest"], + external_deps = [ + "absl/functional:any_invocable", + "gtest", + ], language = "c++", tags = ["promise_test"], uses_event_engine = False, @@ -164,7 +167,6 @@ grpc_cc_test( uses_event_engine = False, uses_polling = False, deps = [ - "//:promise", "//src/core:poll", "//src/core:promise_factory", ], @@ -307,25 +309,6 @@ grpc_cc_test( ], ) -grpc_cc_test( - name = "observable_test", - srcs = ["observable_test.cc"], - external_deps = [ - "absl/status", - "gtest", - ], - language = "c++", - tags = ["promise_test"], - uses_event_engine = False, - uses_polling = False, - deps = [ - "test_wakeup_schedulers", - "//:promise", - "//src/core:observable", - "//src/core:seq", - ], -) - grpc_cc_test( name = "for_each_test", srcs = ["for_each_test.cc"], @@ -385,6 +368,7 @@ grpc_cc_test( name = "pipe_test", srcs = ["pipe_test.cc"], external_deps = [ + "absl/functional:function_ref", "absl/status", "gtest", ], @@ -394,6 +378,7 @@ grpc_cc_test( uses_polling = False, deps = [ "test_wakeup_schedulers", + "//:gpr", "//:grpc", "//:ref_counted_ptr", "//src/core:activity", @@ -432,6 +417,7 @@ grpc_proto_fuzzer( srcs = ["promise_fuzzer.cc"], corpus = "promise_fuzzer_corpus", external_deps = [ + "absl/functional:any_invocable", "absl/status", "absl/types:optional", ], @@ -526,7 +512,6 @@ grpc_cc_test( "//:exec_ctx", "//:gpr", "//:grpc_unsecure", - "//:orphanable", "//:ref_counted_ptr", "//src/core:1999", "//src/core:context", diff --git a/test/core/promise/if_test.cc b/test/core/promise/if_test.cc index 0a2ccc5cbe1..9965fa0b893 100644 --- a/test/core/promise/if_test.cc +++ b/test/core/promise/if_test.cc @@ -14,8 +14,6 @@ #include "src/core/lib/promise/if.h" -#include - #include "gtest/gtest.h" namespace grpc_core { diff --git a/test/core/promise/latch_test.cc b/test/core/promise/latch_test.cc index 07ab6ade5f6..3b18efc48c9 100644 --- a/test/core/promise/latch_test.cc +++ b/test/core/promise/latch_test.cc @@ -52,6 +52,33 @@ TEST(LatchTest, Works) { [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); } +TEST(LatchTest, WaitAndCopyWorks) { + Latch latch; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&latch] { + return Seq(Join(latch.WaitAndCopy(), latch.WaitAndCopy(), + [&latch]() { + latch.Set( + "Once a jolly swagman camped by a billabong, " + "under the shade of a coolibah tree."); + return true; + }), + [](std::tuple result) { + EXPECT_EQ(std::get<0>(result), + "Once a jolly swagman camped by a billabong, " + "under the shade of a coolibah tree."); + EXPECT_EQ(std::get<1>(result), + "Once a jolly swagman camped by a billabong, " + "under the shade of a coolibah tree."); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + TEST(LatchTest, Void) { Latch latch; StrictMock> on_done; @@ -69,6 +96,23 @@ TEST(LatchTest, Void) { [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); } +TEST(LatchTest, ExternallyObservableVoid) { + ExternallyObservableLatch latch; + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [&latch] { + return Seq(Join(latch.Wait(), + [&latch]() { + latch.Set(); + return true; + }), + [](std::tuple) { return absl::OkStatus(); }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); +} + } // namespace grpc_core int main(int argc, char** argv) { diff --git a/test/core/promise/loop_test.cc b/test/core/promise/loop_test.cc index 36beb0047e5..30fa4d1dd95 100644 --- a/test/core/promise/loop_test.cc +++ b/test/core/promise/loop_test.cc @@ -14,6 +14,8 @@ #include "src/core/lib/promise/loop.h" +#include + #include "gtest/gtest.h" #include "src/core/lib/promise/detail/basic_seq.h" @@ -49,6 +51,20 @@ TEST(LoopTest, LoopOfSeq) { EXPECT_EQ(x, Poll(42)); } +TEST(LoopTest, CanAccessFactoryLambdaVariables) { + int i = 0; + auto x = Loop([p = &i]() { + return [q = &p]() -> Poll> { + ++**q; + return Pending{}; + }; + }); + auto y = std::move(x); + auto z = std::move(y); + z(); + EXPECT_EQ(i, 1); +} + } // namespace grpc_core int main(int argc, char** argv) { diff --git a/test/core/promise/map_test.cc b/test/core/promise/map_test.cc index d266654f89f..eacc66674d3 100644 --- a/test/core/promise/map_test.cc +++ b/test/core/promise/map_test.cc @@ -14,8 +14,7 @@ #include "src/core/lib/promise/map.h" -#include - +#include "absl/functional/any_invocable.h" #include "gtest/gtest.h" #include "src/core/lib/promise/promise.h" diff --git a/test/core/promise/mpsc_test.cc b/test/core/promise/mpsc_test.cc index 8001793cf98..0a724cd9409 100644 --- a/test/core/promise/mpsc_test.cc +++ b/test/core/promise/mpsc_test.cc @@ -36,14 +36,14 @@ class MockActivity : public Activity, public Wakeable { public: MOCK_METHOD(void, WakeupRequested, ()); - void ForceImmediateRepoll() override { WakeupRequested(); } + void ForceImmediateRepoll(WakeupMask) override { WakeupRequested(); } void Orphan() override {} - Waker MakeOwningWaker() override { return Waker(this, nullptr); } - Waker MakeNonOwningWaker() override { return Waker(this, nullptr); } - void Wakeup(void*) override { WakeupRequested(); } - void Drop(void*) override {} + Waker MakeOwningWaker() override { return Waker(this, 0); } + Waker MakeNonOwningWaker() override { return Waker(this, 0); } + void Wakeup(WakeupMask) override { WakeupRequested(); } + void Drop(WakeupMask) override {} std::string DebugTag() const override { return "MockActivity"; } - std::string ActivityDebugTag(void*) const override { return DebugTag(); } + std::string ActivityDebugTag(WakeupMask) const override { return DebugTag(); } void Activate() { if (scoped_activity_ != nullptr) return; diff --git a/test/core/promise/observable_test.cc b/test/core/promise/observable_test.cc deleted file mode 100644 index c4bb92572e5..00000000000 --- a/test/core/promise/observable_test.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2021 gRPC authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "src/core/lib/promise/observable.h" - -#include - -#include "absl/status/status.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -#include "src/core/lib/promise/promise.h" -#include "src/core/lib/promise/seq.h" -#include "test/core/promise/test_wakeup_schedulers.h" - -using testing::MockFunction; -using testing::StrictMock; - -namespace grpc_core { - -// A simple Barrier type: stalls progress until it is 'cleared'. -class Barrier { - public: - struct Result {}; - - Promise Wait() { - return [this]() -> Poll { - MutexLock lock(&mu_); - if (cleared_) { - return Result{}; - } else { - return wait_set_.AddPending(Activity::current()->MakeOwningWaker()); - } - }; - } - - void Clear() { - mu_.Lock(); - cleared_ = true; - auto wakeup = wait_set_.TakeWakeupSet(); - mu_.Unlock(); - wakeup.Wakeup(); - } - - private: - Mutex mu_; - WaitSet wait_set_ ABSL_GUARDED_BY(mu_); - bool cleared_ ABSL_GUARDED_BY(mu_) = false; -}; - -TEST(ObservableTest, CanPushAndGet) { - StrictMock> on_done; - Observable observable; - auto observer = observable.MakeObserver(); - auto activity = MakeActivity( - [&observer]() { - return Seq(observer.Get(), [](absl::optional i) { - return i == 42 ? absl::OkStatus() : absl::UnknownError("expected 42"); - }); - }, - InlineWakeupScheduler(), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); - EXPECT_CALL(on_done, Call(absl::OkStatus())); - observable.Push(42); -} - -TEST(ObservableTest, CanNext) { - StrictMock> on_done; - Observable observable; - auto observer = observable.MakeObserver(); - auto activity = MakeActivity( - [&observer]() { - return Seq( - observer.Get(), - [&observer](absl::optional i) { - EXPECT_EQ(i, 42); - return observer.Next(); - }, - [](absl::optional i) { - return i == 1 ? absl::OkStatus() - : absl::UnknownError("expected 1"); - }); - }, - InlineWakeupScheduler(), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); - observable.Push(42); - EXPECT_CALL(on_done, Call(absl::OkStatus())); - observable.Push(1); -} - -TEST(ObservableTest, CanWatch) { - StrictMock> on_done; - Observable observable; - Barrier barrier; - auto activity = MakeActivity( - [&observable, &barrier]() { - return observable.Watch( - [&barrier](int x, - WatchCommitter* committer) -> Promise { - if (x == 3) { - committer->Commit(); - return Seq(barrier.Wait(), Immediate(absl::OkStatus())); - } else { - return Never(); - } - }); - }, - InlineWakeupScheduler(), - [&on_done](absl::Status status) { on_done.Call(std::move(status)); }); - observable.Push(1); - observable.Push(2); - observable.Push(3); - observable.Push(4); - EXPECT_CALL(on_done, Call(absl::OkStatus())); - barrier.Clear(); -} - -} // namespace grpc_core - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/test/core/promise/party_test.cc b/test/core/promise/party_test.cc index bd0a6c288bc..ad6339be2eb 100644 --- a/test/core/promise/party_test.cc +++ b/test/core/promise/party_test.cc @@ -14,7 +14,10 @@ #include "src/core/lib/promise/party.h" +#include + #include +#include #include #include #include @@ -28,7 +31,6 @@ #include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/notification.h" -#include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/gprpp/time.h" @@ -42,21 +44,216 @@ namespace grpc_core { +/////////////////////////////////////////////////////////////////////////////// +// PartySyncTest + +template +class PartySyncTest : public ::testing::Test {}; + +using PartySyncTypes = + ::testing::Types; +TYPED_TEST_SUITE(PartySyncTest, PartySyncTypes); + +TYPED_TEST(PartySyncTest, NoOp) { TypeParam sync(1); } + +TYPED_TEST(PartySyncTest, RefAndUnref) { + Notification half_way; + TypeParam sync(1); + std::thread thread1([&] { + for (int i = 0; i < 1000000; i++) { + sync.IncrementRefCount(); + } + half_way.Notify(); + for (int i = 0; i < 1000000; i++) { + sync.IncrementRefCount(); + } + for (int i = 0; i < 2000000; i++) { + EXPECT_FALSE(sync.Unref()); + } + }); + half_way.WaitForNotification(); + for (int i = 0; i < 2000000; i++) { + sync.IncrementRefCount(); + } + for (int i = 0; i < 2000000; i++) { + EXPECT_FALSE(sync.Unref()); + } + thread1.join(); + EXPECT_TRUE(sync.Unref()); +} + +TYPED_TEST(PartySyncTest, AddAndRemoveParticipant) { + TypeParam sync(1); + std::vector threads; + std::atomic*> participants[party_detail::kMaxParticipants] = + {}; + threads.reserve(8); + for (int i = 0; i < 8; i++) { + threads.emplace_back([&] { + for (int i = 0; i < 100000; i++) { + auto done = std::make_unique>(false); + int slot = -1; + bool run = sync.AddParticipantsAndRef(1, [&](size_t* idxs) { + slot = idxs[0]; + participants[slot].store(done.get(), std::memory_order_release); + }); + EXPECT_NE(slot, -1); + if (run) { + bool run_any = false; + bool run_me = false; + EXPECT_FALSE(sync.RunParty([&](int slot) { + run_any = true; + std::atomic* participant = + participants[slot].exchange(nullptr, std::memory_order_acquire); + if (participant == done.get()) run_me = true; + if (participant == nullptr) { + gpr_log(GPR_ERROR, + "Participant was null (spurious wakeup observed)"); + return false; + } + participant->store(true, std::memory_order_release); + return true; + })); + EXPECT_TRUE(run_any); + EXPECT_TRUE(run_me); + } + EXPECT_FALSE(sync.Unref()); + while (!done->load(std::memory_order_acquire)) { + } + } + }); + } + for (auto& thread : threads) { + thread.join(); + } + EXPECT_TRUE(sync.Unref()); +} + +TYPED_TEST(PartySyncTest, AddAndRemoveTwoParticipants) { + TypeParam sync(1); + std::vector threads; + std::atomic*> participants[party_detail::kMaxParticipants] = + {}; + threads.reserve(8); + for (int i = 0; i < 4; i++) { + threads.emplace_back([&] { + for (int i = 0; i < 100000; i++) { + auto done = std::make_unique>(2); + int slots[2] = {-1, -1}; + bool run = sync.AddParticipantsAndRef(2, [&](size_t* idxs) { + for (int i = 0; i < 2; i++) { + slots[i] = idxs[i]; + participants[slots[i]].store(done.get(), std::memory_order_release); + } + }); + EXPECT_NE(slots[0], -1); + EXPECT_NE(slots[1], -1); + EXPECT_GT(slots[1], slots[0]); + if (run) { + bool run_any = false; + int run_me = 0; + EXPECT_FALSE(sync.RunParty([&](int slot) { + run_any = true; + std::atomic* participant = + participants[slot].exchange(nullptr, std::memory_order_acquire); + if (participant == done.get()) run_me++; + if (participant == nullptr) { + gpr_log(GPR_ERROR, + "Participant was null (spurious wakeup observed)"); + return false; + } + participant->fetch_sub(1, std::memory_order_release); + return true; + })); + EXPECT_TRUE(run_any); + EXPECT_EQ(run_me, 2); + } + EXPECT_FALSE(sync.Unref()); + while (done->load(std::memory_order_acquire) != 0) { + } + } + }); + } + for (auto& thread : threads) { + thread.join(); + } + EXPECT_TRUE(sync.Unref()); +} + +TYPED_TEST(PartySyncTest, UnrefWhileRunning) { + std::vector trials; + std::atomic delete_paths_taken[3] = {{0}, {0}, {0}}; + trials.reserve(100); + for (int i = 0; i < 100; i++) { + trials.emplace_back([&delete_paths_taken] { + TypeParam sync(1); + int delete_path = -1; + EXPECT_TRUE(sync.AddParticipantsAndRef( + 1, [](size_t* slots) { EXPECT_EQ(slots[0], 0); })); + std::thread run_party([&] { + if (sync.RunParty([&sync, n = 0](int slot) mutable { + EXPECT_EQ(slot, 0); + ++n; + if (n < 10) { + sync.ForceImmediateRepoll(1); + return false; + } + return true; + })) { + delete_path = 0; + } + }); + std::thread unref([&] { + if (sync.Unref()) delete_path = 1; + }); + if (sync.Unref()) delete_path = 2; + run_party.join(); + unref.join(); + EXPECT_GE(delete_path, 0); + delete_paths_taken[delete_path].fetch_add(1, std::memory_order_relaxed); + }); + } + for (auto& trial : trials) { + trial.join(); + } + fprintf(stderr, "DELETE_PATHS: RunParty:%d AsyncUnref:%d SyncUnref:%d\n", + delete_paths_taken[0].load(), delete_paths_taken[1].load(), + delete_paths_taken[2].load()); +} + +/////////////////////////////////////////////////////////////////////////////// +// PartyTest + class AllocatorOwner { protected: + ~AllocatorOwner() { arena_->Destroy(); } MemoryAllocator memory_allocator_ = MemoryAllocator( ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test")); + Arena* arena_ = Arena::Create(1024, &memory_allocator_); }; class TestParty final : public AllocatorOwner, public Party { public: - TestParty() : Party(Arena::Create(1024, &memory_allocator_)) {} + TestParty() : Party(AllocatorOwner::arena_, 1) {} + ~TestParty() override {} std::string DebugTag() const override { return "TestParty"; } - void Run() override { + using Party::IncrementRefCount; + using Party::Unref; + + bool RunParty() override { promise_detail::Context ee_ctx(ee_.get()); - Party::Run(); + return Party::RunParty(); + } + + void PartyOver() override { + { + promise_detail::Context + ee_ctx(ee_.get()); + CancelRemainingParticipants(); + } + delete this; } private: @@ -68,34 +265,39 @@ class PartyTest : public ::testing::Test { protected: }; -TEST_F(PartyTest, Noop) { auto party = MakeOrphanable(); } +TEST_F(PartyTest, Noop) { auto party = MakeRefCounted(); } TEST_F(PartyTest, CanSpawnAndRun) { - auto party = MakeOrphanable(); - bool done = false; + auto party = MakeRefCounted(); + Notification n; party->Spawn( + "TestSpawn", [i = 10]() mutable -> Poll { EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); + gpr_log(GPR_DEBUG, "i=%d", i); + GPR_ASSERT(i > 0); Activity::current()->ForceImmediateRepoll(); --i; if (i == 0) return 42; return Pending{}; }, - [&done](int x) { + [&n](int x) { EXPECT_EQ(x, 42); - done = true; + n.Notify(); }); - EXPECT_TRUE(done); + n.WaitForNotification(); } TEST_F(PartyTest, CanSpawnFromSpawn) { - auto party = MakeOrphanable(); - bool done1 = false; - bool done2 = false; + auto party = MakeRefCounted(); + Notification n1; + Notification n2; party->Spawn( - [party = party.get(), &done2]() -> Poll { + "TestSpawn", + [party, &n2]() -> Poll { EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); party->Spawn( + "TestSpawnInner", [i = 10]() mutable -> Poll { EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); Activity::current()->ForceImmediateRepoll(); @@ -103,147 +305,166 @@ TEST_F(PartyTest, CanSpawnFromSpawn) { if (i == 0) return 42; return Pending{}; }, - [&done2](int x) { + [&n2](int x) { EXPECT_EQ(x, 42); - done2 = true; + n2.Notify(); }); return 1234; }, - [&done1](int x) { + [&n1](int x) { EXPECT_EQ(x, 1234); - done1 = true; + n1.Notify(); }); - EXPECT_TRUE(done1); - EXPECT_TRUE(done2); + n1.WaitForNotification(); + n2.WaitForNotification(); } TEST_F(PartyTest, CanWakeupWithOwningWaker) { - auto party = MakeOrphanable(); - bool done = false; + auto party = MakeRefCounted(); + Notification n[10]; + Notification complete; Waker waker; party->Spawn( - [i = 10, &waker]() mutable -> Poll { + "TestSpawn", + [i = 0, &waker, &n]() mutable -> Poll { EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); waker = Activity::current()->MakeOwningWaker(); - --i; - if (i == 0) return 42; + n[i].Notify(); + i++; + if (i == 10) return 42; return Pending{}; }, - [&done](int x) { + [&complete](int x) { EXPECT_EQ(x, 42); - done = true; + complete.Notify(); }); - for (int i = 0; i < 9; i++) { - EXPECT_FALSE(done) << i; + for (int i = 0; i < 10; i++) { + n[i].WaitForNotification(); waker.Wakeup(); } - EXPECT_TRUE(done); + complete.WaitForNotification(); } TEST_F(PartyTest, CanWakeupWithNonOwningWaker) { - auto party = MakeOrphanable(); - bool done = false; + auto party = MakeRefCounted(); + Notification n[10]; + Notification complete; Waker waker; party->Spawn( - [i = 10, &waker]() mutable -> Poll { + "TestSpawn", + [i = 10, &waker, &n]() mutable -> Poll { EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); waker = Activity::current()->MakeNonOwningWaker(); --i; + n[9 - i].Notify(); if (i == 0) return 42; return Pending{}; }, - [&done](int x) { + [&complete](int x) { EXPECT_EQ(x, 42); - done = true; + complete.Notify(); }); for (int i = 0; i < 9; i++) { - EXPECT_FALSE(done) << i; + n[i].WaitForNotification(); + EXPECT_FALSE(n[i + 1].HasBeenNotified()); waker.Wakeup(); } - EXPECT_TRUE(done); + complete.WaitForNotification(); } TEST_F(PartyTest, CanWakeupWithNonOwningWakerAfterOrphaning) { - auto party = MakeOrphanable(); - bool done = false; + auto party = MakeRefCounted(); + Notification set_waker; Waker waker; party->Spawn( - [i = 10, &waker]() mutable -> Poll { + "TestSpawn", + [&waker, &set_waker]() mutable -> Poll { + EXPECT_FALSE(set_waker.HasBeenNotified()); EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); waker = Activity::current()->MakeNonOwningWaker(); - --i; - if (i == 0) return 42; + set_waker.Notify(); return Pending{}; }, - [&done](int x) { - EXPECT_EQ(x, 42); - done = true; - }); + [](int) { Crash("unreachable"); }); + set_waker.WaitForNotification(); party.reset(); - EXPECT_FALSE(done); EXPECT_FALSE(waker.is_unwakeable()); waker.Wakeup(); EXPECT_TRUE(waker.is_unwakeable()); - EXPECT_FALSE(done); } TEST_F(PartyTest, CanDropNonOwningWakeAfterOrphaning) { - auto party = MakeOrphanable(); - bool done = false; + auto party = MakeRefCounted(); + Notification set_waker; std::unique_ptr waker; party->Spawn( - [i = 10, &waker]() mutable -> Poll { + "TestSpawn", + [&waker, &set_waker]() mutable -> Poll { + EXPECT_FALSE(set_waker.HasBeenNotified()); EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); waker = std::make_unique(Activity::current()->MakeNonOwningWaker()); - --i; - if (i == 0) return 42; + set_waker.Notify(); return Pending{}; }, - [&done](int x) { - EXPECT_EQ(x, 42); - done = true; - }); + [](int) { Crash("unreachable"); }); + set_waker.WaitForNotification(); party.reset(); EXPECT_NE(waker, nullptr); waker.reset(); - EXPECT_FALSE(done); } TEST_F(PartyTest, CanWakeupNonOwningOrphanedWakerWithNoEffect) { - auto party = MakeOrphanable(); - bool done = false; + auto party = MakeRefCounted(); + Notification set_waker; Waker waker; party->Spawn( - [i = 10, &waker]() mutable -> Poll { + "TestSpawn", + [&waker, &set_waker]() mutable -> Poll { + EXPECT_FALSE(set_waker.HasBeenNotified()); EXPECT_EQ(Activity::current()->DebugTag(), "TestParty"); waker = Activity::current()->MakeNonOwningWaker(); - --i; - if (i == 0) return 42; + set_waker.Notify(); return Pending{}; }, - [&done](int x) { - EXPECT_EQ(x, 42); - done = true; - }); - EXPECT_FALSE(done); + [](int) { Crash("unreachable"); }); + set_waker.WaitForNotification(); EXPECT_FALSE(waker.is_unwakeable()); party.reset(); waker.Wakeup(); - EXPECT_FALSE(done); EXPECT_TRUE(waker.is_unwakeable()); } +TEST_F(PartyTest, CanBulkSpawn) { + auto party = MakeRefCounted(); + Notification n1; + Notification n2; + { + Party::BulkSpawner spawner(party.get()); + spawner.Spawn( + "spawn1", []() { return Empty{}; }, [&n1](Empty) { n1.Notify(); }); + spawner.Spawn( + "spawn2", []() { return Empty{}; }, [&n2](Empty) { n2.Notify(); }); + for (int i = 0; i < 5000; i++) { + EXPECT_FALSE(n1.HasBeenNotified()); + EXPECT_FALSE(n2.HasBeenNotified()); + } + } + n1.WaitForNotification(); + n2.WaitForNotification(); +} + TEST_F(PartyTest, ThreadStressTest) { - auto party = MakeOrphanable(); + auto party = MakeRefCounted(); std::vector threads; - threads.reserve(16); - for (int i = 0; i < 16; i++) { - threads.emplace_back([party = party.get()]() { + threads.reserve(8); + for (int i = 0; i < 8; i++) { + threads.emplace_back([party]() { for (int i = 0; i < 100; i++) { ExecCtx ctx; // needed for Sleep Notification promise_complete; - party->Spawn(Seq(Sleep(Timestamp::Now() + Duration::Milliseconds(10)), + party->Spawn("TestSpawn", + Seq(Sleep(Timestamp::Now() + Duration::Milliseconds(10)), []() -> Poll { return 42; }), [&promise_complete](int i) { EXPECT_EQ(i, 42); @@ -298,16 +519,17 @@ class PromiseNotification { }; TEST_F(PartyTest, ThreadStressTestWithOwningWaker) { - auto party = MakeOrphanable(); + auto party = MakeRefCounted(); std::vector threads; - threads.reserve(16); - for (int i = 0; i < 16; i++) { - threads.emplace_back([party = party.get()]() { + threads.reserve(8); + for (int i = 0; i < 8; i++) { + threads.emplace_back([party]() { for (int i = 0; i < 100; i++) { ExecCtx ctx; // needed for Sleep PromiseNotification promise_start(true); Notification promise_complete; - party->Spawn(Seq(promise_start.Wait(), + party->Spawn("TestSpawn", + Seq(promise_start.Wait(), Sleep(Timestamp::Now() + Duration::Milliseconds(10)), []() -> Poll { return 42; }), [&promise_complete](int i) { @@ -325,16 +547,17 @@ TEST_F(PartyTest, ThreadStressTestWithOwningWaker) { } TEST_F(PartyTest, ThreadStressTestWithNonOwningWaker) { - auto party = MakeOrphanable(); + auto party = MakeRefCounted(); std::vector threads; - threads.reserve(16); - for (int i = 0; i < 16; i++) { - threads.emplace_back([party = party.get()]() { + threads.reserve(8); + for (int i = 0; i < 8; i++) { + threads.emplace_back([party]() { for (int i = 0; i < 100; i++) { ExecCtx ctx; // needed for Sleep PromiseNotification promise_start(false); Notification promise_complete; - party->Spawn(Seq(promise_start.Wait(), + party->Spawn("TestSpawn", + Seq(promise_start.Wait(), Sleep(Timestamp::Now() + Duration::Milliseconds(10)), []() -> Poll { return 42; }), [&promise_complete](int i) { @@ -352,15 +575,16 @@ TEST_F(PartyTest, ThreadStressTestWithNonOwningWaker) { } TEST_F(PartyTest, ThreadStressTestWithOwningWakerNoSleep) { - auto party = MakeOrphanable(); + auto party = MakeRefCounted(); std::vector threads; - threads.reserve(16); - for (int i = 0; i < 16; i++) { - threads.emplace_back([party = party.get()]() { + threads.reserve(8); + for (int i = 0; i < 8; i++) { + threads.emplace_back([party]() { for (int i = 0; i < 10000; i++) { PromiseNotification promise_start(true); Notification promise_complete; party->Spawn( + "TestSpawn", Seq(promise_start.Wait(), []() -> Poll { return 42; }), [&promise_complete](int i) { EXPECT_EQ(i, 42); @@ -377,15 +601,16 @@ TEST_F(PartyTest, ThreadStressTestWithOwningWakerNoSleep) { } TEST_F(PartyTest, ThreadStressTestWithNonOwningWakerNoSleep) { - auto party = MakeOrphanable(); + auto party = MakeRefCounted(); std::vector threads; - threads.reserve(16); - for (int i = 0; i < 16; i++) { - threads.emplace_back([party = party.get()]() { + threads.reserve(8); + for (int i = 0; i < 8; i++) { + threads.emplace_back([party]() { for (int i = 0; i < 10000; i++) { PromiseNotification promise_start(false); Notification promise_complete; party->Spawn( + "TestSpawn", Seq(promise_start.Wait(), []() -> Poll { return 42; }), [&promise_complete](int i) { EXPECT_EQ(i, 42); @@ -402,20 +627,22 @@ TEST_F(PartyTest, ThreadStressTestWithNonOwningWakerNoSleep) { } TEST_F(PartyTest, ThreadStressTestWithInnerSpawn) { - auto party = MakeOrphanable(); + auto party = MakeRefCounted(); std::vector threads; threads.reserve(8); for (int i = 0; i < 8; i++) { - threads.emplace_back([party = party.get()]() { + threads.emplace_back([party]() { for (int i = 0; i < 100; i++) { ExecCtx ctx; // needed for Sleep PromiseNotification inner_start(true); PromiseNotification inner_complete(false); Notification promise_complete; party->Spawn( + "TestSpawn", Seq( [party, &inner_start, &inner_complete]() -> Poll { - party->Spawn(Seq(inner_start.Wait(), []() { return 0; }), + party->Spawn("TestSpawnInner", + Seq(inner_start.Wait(), []() { return 0; }), [&inner_complete](int i) { EXPECT_EQ(i, 0); inner_complete.Notify(); diff --git a/test/core/promise/pipe_test.cc b/test/core/promise/pipe_test.cc index e78b4acfe1f..4613c40274d 100644 --- a/test/core/promise/pipe_test.cc +++ b/test/core/promise/pipe_test.cc @@ -19,6 +19,7 @@ #include #include +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -26,6 +27,7 @@ #include #include +#include "src/core/lib/gprpp/crash.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/detail/basic_join.h" @@ -381,6 +383,58 @@ TEST_F(PipeTest, CanFlowControlThroughManyStages) { ASSERT_TRUE(*done); } +TEST_F(PipeTest, AwaitClosedWorks) { + StrictMock> on_done; + EXPECT_CALL(on_done, Call(absl::OkStatus())); + MakeActivity( + [] { + auto* pipe = GetContext()->ManagedNew>(); + pipe->sender.InterceptAndMap([](int value) { return value + 1; }); + return Seq( + // Concurrently: + // - wait for closed on both ends + // - close the sender, which will signal the receiver to return an + // end-of-stream. + Join(pipe->receiver.AwaitClosed(), pipe->sender.AwaitClosed(), + [pipe]() mutable { + pipe->sender.Close(); + return absl::OkStatus(); + }), + // Verify we received end-of-stream and closed the sender. + [](std::tuple result) { + EXPECT_FALSE(std::get<0>(result)); + EXPECT_FALSE(std::get<1>(result)); + EXPECT_EQ(std::get<2>(result), absl::OkStatus()); + return absl::OkStatus(); + }); + }, + NoWakeupScheduler(), + [&on_done](absl::Status status) { on_done.Call(std::move(status)); }, + MakeScopedArena(1024, &memory_allocator_)); +} + +class FakeActivity final : public Activity { + public: + void Orphan() override {} + void ForceImmediateRepoll(WakeupMask) override {} + Waker MakeOwningWaker() override { Crash("Not implemented"); } + Waker MakeNonOwningWaker() override { Crash("Not implemented"); } + void Run(absl::FunctionRef f) { + ScopedActivity activity(this); + f(); + } +}; + +TEST_F(PipeTest, PollAckWaitsForReadyClosed) { + FakeActivity().Run([]() { + pipe_detail::Center c; + int i = 1; + EXPECT_EQ(c.Push(&i), Poll(true)); + c.MarkClosed(); + EXPECT_EQ(c.PollAck(), Poll(Pending{})); + }); +} + } // namespace grpc_core int main(int argc, char** argv) { diff --git a/test/core/promise/promise_factory_test.cc b/test/core/promise/promise_factory_test.cc index d822bf1f004..3690d78e5cc 100644 --- a/test/core/promise/promise_factory_test.cc +++ b/test/core/promise/promise_factory_test.cc @@ -14,13 +14,10 @@ #include "src/core/lib/promise/detail/promise_factory.h" -#include - #include "absl/functional/bind_front.h" #include "gtest/gtest.h" #include "src/core/lib/promise/poll.h" -#include "src/core/lib/promise/promise.h" namespace grpc_core { namespace promise_detail { @@ -43,13 +40,12 @@ TEST(AdaptorTest, FactoryFromPromise) { return Poll(Poll(42)); }).Make()(), Poll(42)); - EXPECT_EQ(MakeOnceFactory(Promise([]() { - return Poll(Poll(42)); - })).Make()(), - Poll(42)); - EXPECT_EQ(MakeRepeatedFactory(Promise([]() { + EXPECT_EQ( + MakeOnceFactory([]() { return Poll(Poll(42)); }).Make()(), + Poll(42)); + EXPECT_EQ(MakeRepeatedFactory([]() { return Poll(Poll(42)); - })).Make()(), + }).Make()(), Poll(42)); } diff --git a/test/core/promise/promise_fuzzer.cc b/test/core/promise/promise_fuzzer.cc index c777e948065..cd60e1a6a9f 100644 --- a/test/core/promise/promise_fuzzer.cc +++ b/test/core/promise/promise_fuzzer.cc @@ -19,6 +19,7 @@ #include #include +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/types/optional.h" diff --git a/test/core/resource_quota/arena_test.cc b/test/core/resource_quota/arena_test.cc index 4667ab578df..665364305f3 100644 --- a/test/core/resource_quota/arena_test.cc +++ b/test/core/resource_quota/arena_test.cc @@ -22,6 +22,7 @@ #include #include +#include #include #include #include diff --git a/test/cpp/microbenchmarks/bm_call_create.cc b/test/cpp/microbenchmarks/bm_call_create.cc index 1f9d142518a..cdb948c24b2 100644 --- a/test/cpp/microbenchmarks/bm_call_create.cc +++ b/test/cpp/microbenchmarks/bm_call_create.cc @@ -418,17 +418,10 @@ void Destroy(grpc_transport* /*self*/) {} // implementation of grpc_transport_get_endpoint grpc_endpoint* GetEndpoint(grpc_transport* /*self*/) { return nullptr; } -static const grpc_transport_vtable phony_transport_vtable = {0, - "phony_http2", - InitStream, - nullptr, - SetPollset, - SetPollsetSet, - PerformStreamOp, - PerformOp, - DestroyStream, - Destroy, - GetEndpoint}; +static const grpc_transport_vtable phony_transport_vtable = { + 0, false, "phony_http2", InitStream, + nullptr, SetPollset, SetPollsetSet, PerformStreamOp, + PerformOp, DestroyStream, Destroy, GetEndpoint}; static grpc_transport phony_transport = {&phony_transport_vtable}; diff --git a/tools/codegen/core/optimize_arena_pool_sizes.py b/tools/codegen/core/optimize_arena_pool_sizes.py new file mode 100755 index 00000000000..bfae1c6217a --- /dev/null +++ b/tools/codegen/core/optimize_arena_pool_sizes.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 + +# Copyright 2023 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# USAGE: +# Run some tests with the GRPC_ARENA_TRACE_POOLED_ALLOCATIONS #define turned on. +# Capture the output to a text file. +# Invoke this program with that as an argument, and let it work its magic. + +import collections +import heapq +import random +import re +import sys + +# A single allocation, negative size => free +Allocation = collections.namedtuple('Allocation', 'size ptr') +Active = collections.namedtuple('Active', 'id size') + +# Read through all the captures, and build up scrubbed traces +arenas = [] +building = collections.defaultdict(list) +active = {} +biggest = 0 +smallest = 1024 +sizes = set() +for filename in sys.argv[1:]: + for line in open(filename): + m = re.search(r'ARENA 0x([0-9a-f]+) ALLOC ([0-9]+) @ 0x([0-9a-f]+)', + line) + if m: + size = int(m.group(2)) + if size > biggest: + biggest = size + if size < smallest: + smallest = size + active[m.group(3)] = Active(m.group(1), size) + building[m.group(1)].append(size) + sizes.add(size) + m = re.search(r'FREE 0x([0-9a-f]+)', line) + if m: + # We may have spurious frees, so make sure there's an outstanding allocation + last = active.pop(m.group(1), None) + if last is not None: + building[last.id].append(-last.size) + m = re.search(r'DESTRUCT_ARENA 0x([0-9a-f]+)', line) + if m: + trace = building.pop(m.group(1), None) + if trace: + arenas.append(trace) + + +# Given a list of pool sizes, return which bucket an allocation should go into +def bucket(pool_sizes, size): + for bucket in sorted(pool_sizes): + if abs(size) <= bucket: + return bucket + + +# Given a list of pool sizes, determine the total outstanding bytes in the arena for once trace +def outstanding_bytes(pool_sizes, trace): + free_list = collections.defaultdict(int) + allocated = 0 + for size in trace: + b = bucket(pool_sizes, size) + if size < 0: + free_list[b] += 1 + else: + if free_list[b] > 0: + free_list[b] -= 1 + else: + allocated += b + return allocated + len(pool_sizes) * 8 + + +# Given a list of pool sizes, determine the maximum outstanding bytes for any seen trace +def measure(pool_sizes): + max_outstanding = 0 + for trace in arenas: + max_outstanding = max(max_outstanding, + outstanding_bytes(pool_sizes, trace)) + return max_outstanding + + +ALWAYS_INCLUDE = 1024 +best = [ALWAYS_INCLUDE, biggest] +best_measure = measure(best) + +testq = [] +step = 0 + + +def add(l): + global testq, best_measure, best + m = measure(l) + if m < best_measure: + best_measure = m + best = l + if l[-1] == smallest: + return + heapq.heappush(testq, (m, l)) + + +add(best) + +while testq: + top = heapq.heappop(testq)[1] + m = measure(top) + step += 1 + if step % 1000 == 0: + print("iter %d; pending=%d; top=%r/%d" % + (step, len(testq), top, measure(top))) + for i in sizes: + if i >= top[-1]: + continue + add(top + [i]) + +print("SAW SIZES: %r" % sorted(list(sizes))) +print("BEST: %r" % list(reversed(best))) +print("BEST MEASURE: %d" % best_measure) diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index e3e1cf38446..baee5f661fe 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -2400,12 +2400,14 @@ src/core/lib/promise/detail/promise_like.h \ src/core/lib/promise/detail/status.h \ src/core/lib/promise/detail/switch.h \ src/core/lib/promise/exec_ctx_wakeup_scheduler.h \ +src/core/lib/promise/for_each.h \ src/core/lib/promise/if.h \ src/core/lib/promise/interceptor_list.h \ -src/core/lib/promise/intra_activity_waiter.h \ src/core/lib/promise/latch.h \ src/core/lib/promise/loop.h \ src/core/lib/promise/map.h \ +src/core/lib/promise/party.cc \ +src/core/lib/promise/party.h \ src/core/lib/promise/pipe.h \ src/core/lib/promise/poll.h \ src/core/lib/promise/promise.h \ @@ -2612,6 +2614,8 @@ src/core/lib/surface/server.h \ src/core/lib/surface/validate_metadata.cc \ src/core/lib/surface/validate_metadata.h \ src/core/lib/surface/version.cc \ +src/core/lib/transport/batch_builder.cc \ +src/core/lib/transport/batch_builder.h \ src/core/lib/transport/bdp_estimator.cc \ src/core/lib/transport/bdp_estimator.h \ src/core/lib/transport/connectivity_state.cc \ diff --git a/tools/doxygen/Doxyfile.core.internal b/tools/doxygen/Doxyfile.core.internal index 66ee1646416..0c024d57711 100644 --- a/tools/doxygen/Doxyfile.core.internal +++ b/tools/doxygen/Doxyfile.core.internal @@ -2181,12 +2181,14 @@ src/core/lib/promise/detail/promise_like.h \ src/core/lib/promise/detail/status.h \ src/core/lib/promise/detail/switch.h \ src/core/lib/promise/exec_ctx_wakeup_scheduler.h \ +src/core/lib/promise/for_each.h \ src/core/lib/promise/if.h \ src/core/lib/promise/interceptor_list.h \ -src/core/lib/promise/intra_activity_waiter.h \ src/core/lib/promise/latch.h \ src/core/lib/promise/loop.h \ src/core/lib/promise/map.h \ +src/core/lib/promise/party.cc \ +src/core/lib/promise/party.h \ src/core/lib/promise/pipe.h \ src/core/lib/promise/poll.h \ src/core/lib/promise/promise.h \ @@ -2395,6 +2397,8 @@ src/core/lib/surface/validate_metadata.cc \ src/core/lib/surface/validate_metadata.h \ src/core/lib/surface/version.cc \ src/core/lib/transport/README.md \ +src/core/lib/transport/batch_builder.cc \ +src/core/lib/transport/batch_builder.h \ src/core/lib/transport/bdp_estimator.cc \ src/core/lib/transport/bdp_estimator.h \ src/core/lib/transport/connectivity_state.cc \ diff --git a/tools/run_tests/generated/tests.json b/tools/run_tests/generated/tests.json index 7ddbd1c1614..dc58db5868f 100644 --- a/tools/run_tests/generated/tests.json +++ b/tools/run_tests/generated/tests.json @@ -5067,30 +5067,6 @@ ], "uses_polling": true }, - { - "args": [], - "benchmark": false, - "ci_platforms": [ - "linux", - "mac", - "posix", - "windows" - ], - "cpu_cost": 1.0, - "exclude_configs": [], - "exclude_iomgrs": [], - "flaky": false, - "gtest": true, - "language": "c++", - "name": "observable_test", - "platforms": [ - "linux", - "mac", - "posix", - "windows" - ], - "uses_polling": false - }, { "args": [], "benchmark": false,