[promises] Client channel promise conversion (#33210)

<!--

If you know who should review your pull request, please assign it to
that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the
appropriate
lang label.

-->

---------

Co-authored-by: Mark D. Roth <roth@google.com>
Co-authored-by: markdroth <markdroth@users.noreply.github.com>
Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/34211/head
Craig Tiller 1 year ago committed by GitHub
parent a749d07acf
commit 79a983472c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      BUILD
  2. 5
      CMakeLists.txt
  3. 5
      build_autogenerated.yaml
  4. 353
      src/core/ext/filters/client_channel/client_channel.cc
  5. 60
      src/core/ext/filters/client_channel/client_channel.h
  6. 16
      src/core/ext/filters/client_channel/client_channel_plugin.cc
  7. 2
      src/core/ext/filters/client_channel/dynamic_filters.h
  8. 32
      src/core/ext/filters/client_channel/subchannel.cc
  9. 3
      src/core/ext/filters/client_channel/subchannel.h
  10. 63
      src/core/lib/gprpp/ref_counted_ptr.h
  11. 10
      src/core/lib/iomgr/polling_entity.cc
  12. 2
      src/core/lib/iomgr/polling_entity.h
  13. 1
      src/core/lib/promise/latch.h
  14. 14
      src/core/lib/promise/pipe.h
  15. 76
      src/core/lib/surface/call.cc
  16. 3
      src/core/lib/surface/call.h
  17. 10
      test/core/end2end/tests/filter_causes_close.cc
  18. 2
      test/core/end2end/tests/server_streaming.cc
  19. 1
      test/core/gprpp/BUILD
  20. 60
      test/core/gprpp/ref_counted_ptr_test.cc
  21. 2
      test/cpp/microbenchmarks/bm_call_create.cc

10
BUILD

@ -2532,6 +2532,7 @@ grpc_cc_library(
grpc_cc_library(
name = "ref_counted_ptr",
external_deps = ["absl/hash"],
language = "c++",
public_hdrs = ["//src/core:lib/gprpp/ref_counted_ptr.h"],
visibility = ["@grpc:ref_counted_ptr"],
@ -3038,6 +3039,7 @@ grpc_cc_library(
"legacy_context",
"orphanable",
"parse_address",
"promise",
"protobuf_duration_upb",
"ref_counted_ptr",
"server_address",
@ -3047,8 +3049,10 @@ grpc_cc_library(
"work_serializer",
"xds_orca_service_upb",
"xds_orca_upb",
"//src/core:activity",
"//src/core:arena",
"//src/core:arena_promise",
"//src/core:cancel_callback",
"//src/core:channel_args",
"//src/core:channel_fwd",
"//src/core:channel_init",
@ -3070,15 +3074,20 @@ grpc_cc_library(
"//src/core:json_args",
"//src/core:json_channel_args",
"//src/core:json_object_loader",
"//src/core:latch",
"//src/core:lb_policy",
"//src/core:lb_policy_registry",
"//src/core:map",
"//src/core:memory_quota",
"//src/core:pipe",
"//src/core:poll",
"//src/core:pollset_set",
"//src/core:proxy_mapper",
"//src/core:proxy_mapper_registry",
"//src/core:ref_counted",
"//src/core:resolved_address",
"//src/core:resource_quota",
"//src/core:seq",
"//src/core:service_config_parser",
"//src/core:slice",
"//src/core:slice_buffer",
@ -3088,6 +3097,7 @@ grpc_cc_library(
"//src/core:subchannel_interface",
"//src/core:time",
"//src/core:transport_fwd",
"//src/core:try_seq",
"//src/core:unique_type_name",
"//src/core:useful",
"//src/core:validation_errors",

5
CMakeLists.txt generated

@ -6534,6 +6534,7 @@ target_include_directories(avl_test
target_link_libraries(avl_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
absl::hash
gpr
)
@ -10629,6 +10630,7 @@ target_include_directories(endpoint_config_test
target_link_libraries(endpoint_config_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
absl::hash
absl::type_traits
absl::statusor
gpr
@ -15099,6 +15101,7 @@ target_include_directories(latch_test
target_link_libraries(latch_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
absl::hash
absl::type_traits
absl::statusor
gpr
@ -24338,6 +24341,7 @@ target_include_directories(thread_quota_test
target_link_libraries(thread_quota_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
absl::hash
gpr
)
@ -25389,6 +25393,7 @@ target_include_directories(wait_for_callback_test
target_link_libraries(wait_for_callback_test
${_gRPC_ALLTARGETS_LIBRARIES}
gtest
absl::hash
absl::type_traits
absl::statusor
gpr

@ -5338,6 +5338,7 @@ targets:
- test/core/avl/avl_test.cc
deps:
- gtest
- absl/hash:hash
- gpr
uses_polling: false
- name: aws_request_signer_test
@ -7774,6 +7775,7 @@ targets:
- test/core/event_engine/endpoint_config_test.cc
deps:
- gtest
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- gpr
@ -10460,6 +10462,7 @@ targets:
- test/core/promise/latch_test.cc
deps:
- gtest
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- gpr
@ -16259,6 +16262,7 @@ targets:
- test/core/resource_quota/thread_quota_test.cc
deps:
- gtest
- absl/hash:hash
- gpr
uses_polling: false
- name: thread_stress_test
@ -16798,6 +16802,7 @@ targets:
- test/core/promise/wait_for_callback_test.cc
deps:
- gtest
- absl/hash:hash
- absl/meta:type_traits
- absl/status:statusor
- gpr

@ -82,6 +82,14 @@
#include "src/core/lib/json/json.h"
#include "src/core/lib/load_balancing/lb_policy_registry.h"
#include "src/core/lib/load_balancing/subchannel_interface.h"
#include "src/core/lib/promise/cancel_callback.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/latch.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resolver/resolver_registry.h"
#include "src/core/lib/resolver/server_address.h"
#include "src/core/lib/security/credentials/credentials.h"
@ -89,6 +97,7 @@
#include "src/core/lib/service_config/service_config_impl.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/surface/call.h"
#include "src/core/lib/surface/channel.h"
#include "src/core/lib/transport/connectivity_state.h"
#include "src/core/lib/transport/error_utils.h"
@ -146,7 +155,7 @@ class ClientChannel::CallData {
// Accessors for data stored in the subclass.
virtual ClientChannel* chand() const = 0;
virtual Arena* arena() const = 0;
virtual grpc_polling_entity* pollent() const = 0;
virtual grpc_polling_entity* pollent() = 0;
virtual grpc_metadata_batch* send_initial_metadata() = 0;
virtual grpc_call_context_element* call_context() const = 0;
@ -205,7 +214,7 @@ class ClientChannel::FilterBasedCallData : public ClientChannel::CallData {
return static_cast<ClientChannel*>(elem()->channel_data);
}
Arena* arena() const override { return deadline_state_.arena; }
grpc_polling_entity* pollent() const override { return pollent_; }
grpc_polling_entity* pollent() override { return pollent_; }
grpc_metadata_batch* send_initial_metadata() override {
return pending_batches_[0]
->payload->send_initial_metadata.send_initial_metadata;
@ -298,11 +307,105 @@ class ClientChannel::FilterBasedCallData : public ClientChannel::CallData {
grpc_error_handle cancel_error_;
};
class ClientChannel::PromiseBasedCallData : public ClientChannel::CallData {
public:
explicit PromiseBasedCallData(ClientChannel* chand) : chand_(chand) {}
ArenaPromise<absl::StatusOr<CallArgs>> MakeNameResolutionPromise(
CallArgs call_args) {
pollent_ = NowOrNever(call_args.polling_entity->WaitAndCopy()).value();
client_initial_metadata_ = std::move(call_args.client_initial_metadata);
// If we're still in IDLE, we need to start resolving.
if (GPR_UNLIKELY(chand_->CheckConnectivityState(false) ==
GRPC_CHANNEL_IDLE)) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) {
gpr_log(GPR_INFO, "chand=%p calld=%p: %striggering exit idle", chand_,
this, Activity::current()->DebugTag().c_str());
}
// Bounce into the control plane work serializer to start resolving.
GRPC_CHANNEL_STACK_REF(chand_->owning_stack_, "ExitIdle");
chand_->work_serializer_->Run(
[chand = chand_]()
ABSL_EXCLUSIVE_LOCKS_REQUIRED(*chand_->work_serializer_) {
chand->CheckConnectivityState(/*try_to_connect=*/true);
GRPC_CHANNEL_STACK_UNREF(chand->owning_stack_, "ExitIdle");
},
DEBUG_LOCATION);
}
return [this, call_args = std::move(
call_args)]() mutable -> Poll<absl::StatusOr<CallArgs>> {
auto result = CheckResolution(was_queued_);
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) {
gpr_log(GPR_INFO, "chand=%p calld=%p: %sCheckResolution returns %s",
chand_, this, Activity::current()->DebugTag().c_str(),
result.has_value() ? result->ToString().c_str() : "Pending");
}
if (!result.has_value()) {
waker_ = Activity::current()->MakeNonOwningWaker();
was_queued_ = true;
return Pending{};
}
if (!result->ok()) return *result;
call_args.client_initial_metadata = std::move(client_initial_metadata_);
return std::move(call_args);
};
}
private:
ClientChannel* chand() const override { return chand_; }
Arena* arena() const override { return GetContext<Arena>(); }
grpc_polling_entity* pollent() override { return &pollent_; }
grpc_metadata_batch* send_initial_metadata() override {
return client_initial_metadata_.get();
}
grpc_call_context_element* call_context() const override {
return GetContext<grpc_call_context_element>();
}
void RetryCheckResolutionLocked() override {
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) {
gpr_log(GPR_INFO, "chand=%p calld=%p: RetryCheckResolutionLocked()",
chand_, this);
}
waker_.WakeupAsync();
}
void ResetDeadline(Duration timeout) override {
CallContext* call_context = GetContext<CallContext>();
const Timestamp per_method_deadline =
Timestamp::FromCycleCounterRoundUp(call_context->call_start_time()) +
timeout;
call_context->UpdateDeadline(per_method_deadline);
}
ClientChannel* chand_;
grpc_polling_entity pollent_;
ClientMetadataHandle client_initial_metadata_;
bool was_queued_ = false;
Waker waker_;
};
//
// Filter vtable
//
const grpc_channel_filter ClientChannel::kFilterVtable = {
const grpc_channel_filter ClientChannel::kFilterVtableWithPromises = {
ClientChannel::FilterBasedCallData::StartTransportStreamOpBatch,
ClientChannel::MakeCallPromise,
ClientChannel::StartTransportOp,
sizeof(ClientChannel::FilterBasedCallData),
ClientChannel::FilterBasedCallData::Init,
ClientChannel::FilterBasedCallData::SetPollent,
ClientChannel::FilterBasedCallData::Destroy,
sizeof(ClientChannel),
ClientChannel::Init,
grpc_channel_stack_no_post_init,
ClientChannel::Destroy,
ClientChannel::GetChannelInfo,
"client-channel",
};
const grpc_channel_filter ClientChannel::kFilterVtableWithoutPromises = {
ClientChannel::FilterBasedCallData::StartTransportStreamOpBatch,
nullptr,
ClientChannel::StartTransportOp,
@ -324,6 +427,12 @@ const grpc_channel_filter ClientChannel::kFilterVtable = {
namespace {
ClientChannelServiceConfigCallData* GetServiceConfigCallData(
grpc_call_context_element* context) {
return static_cast<ClientChannelServiceConfigCallData*>(
context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value);
}
class DynamicTerminationFilter {
public:
class CallData;
@ -349,6 +458,19 @@ class DynamicTerminationFilter {
static void GetChannelInfo(grpc_channel_element* /*elem*/,
const grpc_channel_info* /*info*/) {}
static ArenaPromise<ServerMetadataHandle> MakeCallPromise(
grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory) {
auto* chand = static_cast<DynamicTerminationFilter*>(elem->channel_data);
return chand->chand_->CreateLoadBalancedCallPromise(
std::move(call_args),
[]() {
auto* service_config_call_data =
GetServiceConfigCallData(GetContext<grpc_call_context_element>());
service_config_call_data->Commit();
},
/*is_transparent_retry=*/false);
}
private:
explicit DynamicTerminationFilter(const ChannelArgs& args)
: chand_(args.GetObject<ClientChannel>()) {}
@ -397,8 +519,7 @@ class DynamicTerminationFilter::CallData {
/*start_time=*/0, calld->deadline_,
calld->arena_, calld->call_combiner_};
auto* service_config_call_data =
static_cast<ClientChannelServiceConfigCallData*>(
calld->call_context_[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value);
GetServiceConfigCallData(calld->call_context_);
calld->lb_call_ = client_channel->CreateLoadBalancedCall(
args, pollent, nullptr,
[service_config_call_data]() { service_config_call_data->Commit(); },
@ -433,7 +554,7 @@ class DynamicTerminationFilter::CallData {
const grpc_channel_filter DynamicTerminationFilter::kFilterVtable = {
DynamicTerminationFilter::CallData::StartTransportStreamOpBatch,
nullptr,
DynamicTerminationFilter::MakeCallPromise,
DynamicTerminationFilter::StartTransportOp,
sizeof(DynamicTerminationFilter::CallData),
DynamicTerminationFilter::CallData::Init,
@ -1013,14 +1134,18 @@ class ClientChannel::ClientChannelControlHelper
ClientChannel* ClientChannel::GetFromChannel(Channel* channel) {
grpc_channel_element* elem =
grpc_channel_stack_last_element(channel->channel_stack());
if (elem->filter != &kFilterVtable) return nullptr;
if (elem->filter != &kFilterVtableWithPromises &&
elem->filter != &kFilterVtableWithoutPromises) {
return nullptr;
}
return static_cast<ClientChannel*>(elem->channel_data);
}
grpc_error_handle ClientChannel::Init(grpc_channel_element* elem,
grpc_channel_element_args* args) {
GPR_ASSERT(args->is_last);
GPR_ASSERT(elem->filter == &kFilterVtable);
GPR_ASSERT(elem->filter == &kFilterVtableWithPromises ||
elem->filter == &kFilterVtableWithoutPromises);
grpc_error_handle error;
new (elem->channel_data) ClientChannel(args, &error);
return error;
@ -1136,6 +1261,21 @@ ClientChannel::~ClientChannel() {
grpc_pollset_set_destroy(interested_parties_);
}
ArenaPromise<ServerMetadataHandle> ClientChannel::MakeCallPromise(
grpc_channel_element* elem, CallArgs call_args, NextPromiseFactory) {
auto* chand = static_cast<ClientChannel*>(elem->channel_data);
// TODO(roth): Is this the right lifetime story for calld?
auto* calld = GetContext<Arena>()->ManagedNew<PromiseBasedCallData>(chand);
return TrySeq(
// Name resolution.
calld->MakeNameResolutionPromise(std::move(call_args)),
// Dynamic filter stack.
[calld](CallArgs call_args) mutable {
return calld->dynamic_filters()->channel_stack()->MakeClientCallPromise(
std::move(call_args));
});
}
OrphanablePtr<ClientChannel::FilterBasedLoadBalancedCall>
ClientChannel::CreateLoadBalancedCall(
const grpc_call_element_args& args, grpc_polling_entity* pollent,
@ -1147,6 +1287,16 @@ ClientChannel::CreateLoadBalancedCall(
std::move(on_commit), is_transparent_retry));
}
ArenaPromise<ServerMetadataHandle> ClientChannel::CreateLoadBalancedCallPromise(
CallArgs call_args, absl::AnyInvocable<void()> on_commit,
bool is_transparent_retry) {
OrphanablePtr<PromiseBasedLoadBalancedCall> lb_call(
GetContext<Arena>()->New<PromiseBasedLoadBalancedCall>(
this, std::move(on_commit), is_transparent_retry));
auto* call_ptr = lb_call.get();
return call_ptr->MakeCallPromise(std::move(call_args), std::move(lb_call));
}
ChannelArgs ClientChannel::MakeSubchannelArgs(
const ChannelArgs& channel_args, const ChannelArgs& address_args,
const RefCountedPtr<SubchannelPoolInterface>& subchannel_pool,
@ -1610,7 +1760,7 @@ void ClientChannel::UpdateStateAndPickerLocked(
MutexLock lock(&lb_mu_);
picker_.swap(picker);
// Reprocess queued picks.
for (LoadBalancedCall* call : lb_queued_calls_) {
for (auto& call : lb_queued_calls_) {
call->RemoveCallFromLbQueuedCallsLocked();
call->RetryPickLocked();
}
@ -1840,8 +1990,10 @@ void ClientChannel::CallData::RemoveCallFromResolverQueuedCallsLocked() {
void ClientChannel::CallData::AddCallToResolverQueuedCallsLocked() {
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) {
gpr_log(GPR_INFO, "chand=%p calld=%p: adding to resolver queued picks list",
chand(), this);
gpr_log(
GPR_INFO,
"chand=%p calld=%p: adding to resolver queued picks list; pollent=%s",
chand(), this, grpc_polling_entity_string(pollent()).c_str());
}
// Add call's pollent to channel's interested_parties, so that I/O
// can be done under the call's CQ.
@ -2351,8 +2503,7 @@ void ClientChannel::FilterBasedCallData::
auto* calld = static_cast<FilterBasedCallData*>(arg);
auto* chand = calld->chand();
auto* service_config_call_data =
static_cast<ClientChannelServiceConfigCallData*>(
calld->call_context()[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value);
GetServiceConfigCallData(calld->call_context());
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_call_trace)) {
gpr_log(GPR_INFO,
"chand=%p calld=%p: got recv_trailing_metadata_ready: error=%s "
@ -2470,8 +2621,8 @@ class ClientChannel::LoadBalancedCall::Metadata
ServiceConfigCallData::CallAttributeInterface*
ClientChannel::LoadBalancedCall::LbCallState::GetCallAttribute(
UniqueTypeName type) const {
auto* service_config_call_data = static_cast<ServiceConfigCallData*>(
lb_call_->call_context()[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value);
auto* service_config_call_data =
GetServiceConfigCallData(lb_call_->call_context());
return service_config_call_data->GetCallAttribute(type);
}
@ -2559,16 +2710,6 @@ ClientChannel::LoadBalancedCall::~LoadBalancedCall() {
}
}
void ClientChannel::LoadBalancedCall::Orphan() {
// Compute latency and report it to the tracer.
if (call_attempt_tracer() != nullptr) {
gpr_timespec latency =
gpr_cycle_counter_sub(gpr_get_cycle_counter(), lb_call_start_time_);
call_attempt_tracer()->RecordEnd(latency);
}
Unref();
}
void ClientChannel::LoadBalancedCall::RecordCallCompletion(
absl::Status status, grpc_metadata_batch* recv_trailing_metadata,
grpc_transport_stream_stats* transport_stream_stats,
@ -2590,6 +2731,15 @@ void ClientChannel::LoadBalancedCall::RecordCallCompletion(
}
}
void ClientChannel::LoadBalancedCall::RecordLatency() {
// Compute latency and report it to the tracer.
if (call_attempt_tracer() != nullptr) {
gpr_timespec latency =
gpr_cycle_counter_sub(gpr_get_cycle_counter(), lb_call_start_time_);
call_attempt_tracer()->RecordEnd(latency);
}
}
void ClientChannel::LoadBalancedCall::RemoveCallFromLbQueuedCallsLocked() {
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) {
gpr_log(GPR_INFO, "chand=%p lb_call=%p: removing from queued picks list",
@ -2614,7 +2764,7 @@ void ClientChannel::LoadBalancedCall::AddCallToLbQueuedCallsLocked() {
grpc_polling_entity_add_to_pollset_set(pollent(),
chand_->interested_parties_);
// Add to queue.
chand_->lb_queued_calls_.insert(this);
chand_->lb_queued_calls_.insert(Ref());
OnAddToQueueLocked();
}
@ -2813,6 +2963,7 @@ void ClientChannel::FilterBasedLoadBalancedCall::Orphan() {
RecordCallCompletion(absl::CancelledError("call cancelled"), nullptr,
nullptr, "");
}
RecordLatency();
// Delegate to parent.
LoadBalancedCall::Orphan();
}
@ -3153,7 +3304,7 @@ class ClientChannel::FilterBasedLoadBalancedCall::LbQueuedCallCanceller {
// Remove pick from list of queued picks.
lb_call->RemoveCallFromLbQueuedCallsLocked();
// Remove from queued picks list.
chand->lb_queued_calls_.erase(lb_call);
chand->lb_queued_calls_.erase(self->lb_call_);
// Fail pending batches on the call.
lb_call->PendingBatchesFail(error,
YieldCallCombinerIfPendingBatchesFound);
@ -3243,4 +3394,152 @@ void ClientChannel::FilterBasedLoadBalancedCall::CreateSubchannelCall() {
}
}
//
// ClientChannel::PromiseBasedLoadBalancedCall
//
ClientChannel::PromiseBasedLoadBalancedCall::PromiseBasedLoadBalancedCall(
ClientChannel* chand, absl::AnyInvocable<void()> on_commit,
bool is_transparent_retry)
: LoadBalancedCall(chand, GetContext<grpc_call_context_element>(),
std::move(on_commit), is_transparent_retry) {}
ArenaPromise<ServerMetadataHandle>
ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise(
CallArgs call_args, OrphanablePtr<PromiseBasedLoadBalancedCall> lb_call) {
pollent_ = NowOrNever(call_args.polling_entity->WaitAndCopy()).value();
// Record ops in tracer.
if (call_attempt_tracer() != nullptr) {
call_attempt_tracer()->RecordSendInitialMetadata(
call_args.client_initial_metadata.get());
// TODO(ctiller): Find a way to do this without registering a no-op mapper.
call_args.client_to_server_messages->InterceptAndMapWithHalfClose(
[](MessageHandle message) { return message; }, // No-op.
[this]() {
// TODO(roth): Change CallTracer API to not pass metadata
// batch to this method, since the batch is always empty.
grpc_metadata_batch metadata(GetContext<Arena>());
call_attempt_tracer()->RecordSendTrailingMetadata(&metadata);
});
}
// Extract peer name from server initial metadata.
call_args.server_initial_metadata->InterceptAndMap(
[this](ServerMetadataHandle metadata) {
if (call_attempt_tracer() != nullptr) {
call_attempt_tracer()->RecordReceivedInitialMetadata(metadata.get());
}
Slice* peer_string = metadata->get_pointer(PeerString());
if (peer_string != nullptr) peer_string_ = peer_string->Ref();
return metadata;
});
client_initial_metadata_ = std::move(call_args.client_initial_metadata);
return OnCancel(
Map(TrySeq(
// LB pick.
[this]() -> Poll<absl::Status> {
auto result = PickSubchannel(was_queued_);
if (GRPC_TRACE_FLAG_ENABLED(
grpc_client_channel_lb_call_trace)) {
gpr_log(GPR_INFO,
"chand=%p lb_call=%p: %sPickSubchannel() returns %s",
chand(), this,
Activity::current()->DebugTag().c_str(),
result.has_value() ? result->ToString().c_str()
: "Pending");
}
if (result == absl::nullopt) return Pending{};
return std::move(*result);
},
[this, call_args = std::move(call_args)]() mutable
-> ArenaPromise<ServerMetadataHandle> {
call_args.client_initial_metadata =
std::move(client_initial_metadata_);
return connected_subchannel()->MakeCallPromise(
std::move(call_args));
}),
// Record call completion.
[this](ServerMetadataHandle metadata) {
if (call_attempt_tracer() != nullptr ||
lb_subchannel_call_tracker() != nullptr) {
absl::Status status;
grpc_status_code code = metadata->get(GrpcStatusMetadata())
.value_or(GRPC_STATUS_UNKNOWN);
if (code != GRPC_STATUS_OK) {
absl::string_view message;
if (const auto* grpc_message =
metadata->get_pointer(GrpcMessageMetadata())) {
message = grpc_message->as_string_view();
}
status =
absl::Status(static_cast<absl::StatusCode>(code), message);
}
RecordCallCompletion(status, metadata.get(),
&GetContext<CallContext>()
->call_stats()
->transport_stream_stats,
peer_string_.as_string_view());
}
RecordLatency();
return metadata;
}),
[lb_call = std::move(lb_call)]() {
// If the waker is pending, then we need to remove ourself from
// the list of queued LB calls.
if (!lb_call->waker_.is_unwakeable()) {
MutexLock lock(&lb_call->chand()->lb_mu_);
lb_call->Commit();
// Remove pick from list of queued picks.
lb_call->RemoveCallFromLbQueuedCallsLocked();
// Remove from queued picks list.
lb_call->chand()->lb_queued_calls_.erase(lb_call.get());
}
// TODO(ctiller): We don't have access to the call's actual status
// here, so we just assume CANCELLED. We could change this to use
// CallFinalization instead of OnCancel() so that we can get the
// actual status. But we should also have access to the trailing
// metadata, which we don't have in either case. Ultimately, we
// need a better story for code that needs to run at the end of a
// call in both cancellation and non-cancellation cases that needs
// access to server trailing metadata and the call's real status.
if (lb_call->call_attempt_tracer() != nullptr) {
lb_call->call_attempt_tracer()->RecordCancel(
absl::CancelledError("call cancelled"));
}
if (lb_call->call_attempt_tracer() != nullptr ||
lb_call->lb_subchannel_call_tracker() != nullptr) {
// If we were cancelled without recording call completion, then
// record call completion here, as best we can. We assume status
// CANCELLED in this case.
lb_call->RecordCallCompletion(absl::CancelledError("call cancelled"),
nullptr, nullptr, "");
}
});
}
Arena* ClientChannel::PromiseBasedLoadBalancedCall::arena() const {
return GetContext<Arena>();
}
grpc_call_context_element*
ClientChannel::PromiseBasedLoadBalancedCall::call_context() const {
return GetContext<grpc_call_context_element>();
}
grpc_metadata_batch*
ClientChannel::PromiseBasedLoadBalancedCall::send_initial_metadata() const {
return client_initial_metadata_.get();
}
void ClientChannel::PromiseBasedLoadBalancedCall::OnAddToQueueLocked() {
waker_ = Activity::current()->MakeNonOwningWaker();
was_queued_ = true;
}
void ClientChannel::PromiseBasedLoadBalancedCall::RetryPickLocked() {
if (GRPC_TRACE_FLAG_ENABLED(grpc_client_channel_lb_call_trace)) {
gpr_log(GPR_INFO, "chand=%p lb_call=%p: RetryPickLocked()", chand(), this);
}
waker_.WakeupAsync();
}
} // namespace grpc_core

@ -62,6 +62,8 @@
#include "src/core/lib/iomgr/iomgr_fwd.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/load_balancing/lb_policy.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/resolver/resolver.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/service_config/service_config.h"
@ -102,10 +104,12 @@ namespace grpc_core {
class ClientChannel {
public:
static const grpc_channel_filter kFilterVtable;
static const grpc_channel_filter kFilterVtableWithPromises;
static const grpc_channel_filter kFilterVtableWithoutPromises;
class LoadBalancedCall;
class FilterBasedLoadBalancedCall;
class PromiseBasedLoadBalancedCall;
// Flag that this object gets stored in channel args as a raw pointer.
struct RawPointerChannelArgTag {};
@ -115,6 +119,10 @@ class ClientChannel {
// is not a client channel.
static ClientChannel* GetFromChannel(Channel* channel);
static ArenaPromise<ServerMetadataHandle> MakeCallPromise(
grpc_channel_element* elem, CallArgs call_args,
NextPromiseFactory next_promise_factory);
grpc_connectivity_state CheckConnectivityState(bool try_to_connect);
// Starts a one-time connectivity state watch. When the channel's state
@ -164,6 +172,10 @@ class ClientChannel {
grpc_closure* on_call_destruction_complete,
absl::AnyInvocable<void()> on_commit, bool is_transparent_retry);
ArenaPromise<ServerMetadataHandle> CreateLoadBalancedCallPromise(
CallArgs call_args, absl::AnyInvocable<void()> on_commit,
bool is_transparent_retry);
// Exposed for testing only.
static ChannelArgs MakeSubchannelArgs(
const ChannelArgs& channel_args, const ChannelArgs& address_args,
@ -173,6 +185,7 @@ class ClientChannel {
private:
class CallData;
class FilterBasedCallData;
class PromiseBasedCallData;
class ResolverResultHandler;
class SubchannelWrapper;
class ClientChannelControlHelper;
@ -315,8 +328,10 @@ class ClientChannel {
mutable Mutex lb_mu_;
RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker_
ABSL_GUARDED_BY(lb_mu_);
absl::flat_hash_set<LoadBalancedCall*> lb_queued_calls_
ABSL_GUARDED_BY(lb_mu_);
absl::flat_hash_set<RefCountedPtr<LoadBalancedCall>,
RefCountedPtrHash<LoadBalancedCall>,
RefCountedPtrEq<LoadBalancedCall>>
lb_queued_calls_ ABSL_GUARDED_BY(lb_mu_);
//
// Fields used in the control plane. Guarded by work_serializer.
@ -377,7 +392,7 @@ class ClientChannel::LoadBalancedCall
bool is_transparent_retry);
~LoadBalancedCall() override;
void Orphan() override;
void Orphan() override { Unref(); }
// Called by channel when removing a call from the list of queued calls.
void RemoveCallFromLbQueuedCallsLocked()
@ -394,7 +409,6 @@ class ClientChannel::LoadBalancedCall
return static_cast<ClientCallTracer::CallAttemptTracer*>(
call_context()[GRPC_CONTEXT_CALL_TRACER].value);
}
gpr_cycle_counter lb_call_start_time() const { return lb_call_start_time_; }
ConnectedSubchannel* connected_subchannel() const {
return connected_subchannel_.get();
}
@ -425,6 +439,8 @@ class ClientChannel::LoadBalancedCall
grpc_transport_stream_stats* transport_stream_stats,
absl::string_view peer_address);
void RecordLatency();
private:
class LbCallState;
class Metadata;
@ -432,7 +448,7 @@ class ClientChannel::LoadBalancedCall
virtual Arena* arena() const = 0;
virtual grpc_call_context_element* call_context() const = 0;
virtual grpc_polling_entity* pollent() const = 0;
virtual grpc_polling_entity* pollent() = 0;
virtual grpc_metadata_batch* send_initial_metadata() const = 0;
// Helper function for performing an LB pick with a specified picker.
@ -445,7 +461,7 @@ class ClientChannel::LoadBalancedCall
// Called when adding the call to the LB queue.
virtual void OnAddToQueueLocked()
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::lb_mu_) {}
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::lb_mu_) = 0;
ClientChannel* chand_;
@ -496,7 +512,7 @@ class ClientChannel::FilterBasedLoadBalancedCall
grpc_call_context_element* call_context() const override {
return call_context_;
}
grpc_polling_entity* pollent() const override { return pollent_; }
grpc_polling_entity* pollent() override { return pollent_; }
grpc_metadata_batch* send_initial_metadata() const override {
return pending_batches_[0]
->payload->send_initial_metadata.send_initial_metadata;
@ -590,6 +606,34 @@ class ClientChannel::FilterBasedLoadBalancedCall
grpc_transport_stream_op_batch* pending_batches_[MAX_PENDING_BATCHES] = {};
};
class ClientChannel::PromiseBasedLoadBalancedCall
: public ClientChannel::LoadBalancedCall {
public:
PromiseBasedLoadBalancedCall(ClientChannel* chand,
absl::AnyInvocable<void()> on_commit,
bool is_transparent_retry);
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, OrphanablePtr<PromiseBasedLoadBalancedCall> lb_call);
private:
Arena* arena() const override;
grpc_call_context_element* call_context() const override;
grpc_polling_entity* pollent() override { return &pollent_; }
grpc_metadata_batch* send_initial_metadata() const override;
void RetryPickLocked() override;
void OnAddToQueueLocked() override
ABSL_EXCLUSIVE_LOCKS_REQUIRED(&ClientChannel::lb_mu_);
grpc_polling_entity pollent_;
ClientMetadataHandle client_initial_metadata_;
Waker waker_;
bool was_queued_ = false;
Slice peer_string_;
};
} // namespace grpc_core
#endif // GRPC_SRC_CORE_EXT_FILTERS_CLIENT_CHANNEL_CLIENT_CHANNEL_H

@ -18,9 +18,14 @@
#include <grpc/support/port_platform.h>
#include "absl/types/optional.h"
#include <grpc/impl/channel_arg_names.h>
#include "src/core/ext/filters/client_channel/client_channel.h"
#include "src/core/ext/filters/client_channel/client_channel_service_config.h"
#include "src/core/ext/filters/client_channel/retry_service_config.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_stack_builder.h"
#include "src/core/lib/config/core_configuration.h"
#include "src/core/lib/surface/channel_init.h"
@ -28,13 +33,22 @@
namespace grpc_core {
namespace {
bool IsEverythingBelowClientChannelPromiseSafe(const ChannelArgs& args) {
return !args.GetBool(GRPC_ARG_ENABLE_RETRIES).value_or(true);
}
} // namespace
void BuildClientChannelConfiguration(CoreConfiguration::Builder* builder) {
internal::ClientChannelServiceConfigParser::Register(builder);
internal::RetryServiceConfigParser::Register(builder);
builder->channel_init()->RegisterStage(
GRPC_CLIENT_CHANNEL, GRPC_CHANNEL_INIT_BUILTIN_PRIORITY,
[](ChannelStackBuilder* builder) {
builder->AppendFilter(&ClientChannel::kFilterVtable);
builder->AppendFilter(
IsEverythingBelowClientChannelPromiseSafe(builder->channel_args())
? &ClientChannel::kFilterVtableWithPromises
: &ClientChannel::kFilterVtableWithoutPromises);
return true;
});
}

@ -99,6 +99,8 @@ class DynamicFilters : public RefCounted<DynamicFilters> {
RefCountedPtr<Call> CreateCall(Call::Args args, grpc_error_handle* error);
grpc_channel_stack* channel_stack() const { return channel_stack_.get(); }
private:
RefCountedPtr<grpc_channel_stack> channel_stack_;
};

@ -59,6 +59,8 @@
#include "src/core/lib/handshaker/proxy_mapper_registry.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/iomgr/pollset_set.h"
#include "src/core/lib/promise/cancel_callback.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/lib/surface/channel_init.h"
#include "src/core/lib/surface/channel_stack_type.h"
@ -133,6 +135,36 @@ size_t ConnectedSubchannel::GetInitialCallSizeEstimate() const {
channel_stack_->call_stack_size;
}
ArenaPromise<ServerMetadataHandle> ConnectedSubchannel::MakeCallPromise(
CallArgs call_args) {
// If not using channelz, we just need to call the channel stack.
if (channelz_subchannel() == nullptr) {
return channel_stack_->MakeClientCallPromise(std::move(call_args));
}
// Otherwise, we need to wrap the channel stack promise with code that
// handles the channelz updates.
return OnCancel(
Seq(channel_stack_->MakeClientCallPromise(std::move(call_args)),
[self = Ref()](ServerMetadataHandle metadata) {
channelz::SubchannelNode* channelz_subchannel =
self->channelz_subchannel();
GPR_ASSERT(channelz_subchannel != nullptr);
if (metadata->get(GrpcStatusMetadata())
.value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) {
channelz_subchannel->RecordCallFailed();
} else {
channelz_subchannel->RecordCallSucceeded();
}
return metadata;
}),
[self = Ref()]() {
channelz::SubchannelNode* channelz_subchannel =
self->channelz_subchannel();
GPR_ASSERT(channelz_subchannel != nullptr);
channelz_subchannel->RecordCallFailed();
});
}
//
// SubchannelCall
//

@ -54,6 +54,7 @@
#include "src/core/lib/iomgr/iomgr_fwd.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/iomgr/resolved_address.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/transport/connectivity_state.h"
@ -84,6 +85,8 @@ class ConnectedSubchannel : public RefCounted<ConnectedSubchannel> {
size_t GetInitialCallSizeEstimate() const;
ArenaPromise<ServerMetadataHandle> MakeCallPromise(CallArgs call_args);
private:
grpc_channel_stack* channel_stack_;
ChannelArgs args_;

@ -21,10 +21,14 @@
#include <grpc/support/port_platform.h>
#include <stddef.h>
#include <iosfwd>
#include <type_traits>
#include <utility>
#include "absl/hash/hash.h"
#include "src/core/lib/gprpp/debug_location.h"
namespace grpc_core {
@ -333,6 +337,65 @@ bool operator<(const WeakRefCountedPtr<T>& p1, const WeakRefCountedPtr<T>& p2) {
return p1.get() < p2.get();
}
//
// absl::Hash integration
//
template <typename H, typename T>
H AbslHashValue(H h, const RefCountedPtr<T>& p) {
return H::combine(std::move(h), p.get());
}
template <typename H, typename T>
H AbslHashValue(H h, const WeakRefCountedPtr<T>& p) {
return H::combine(std::move(h), p.get());
}
// Heterogenous lookup support.
template <typename T>
struct RefCountedPtrHash {
using is_transparent = void;
size_t operator()(const RefCountedPtr<T>& p) const {
return absl::Hash<RefCountedPtr<T>>{}(p);
}
size_t operator()(const WeakRefCountedPtr<T>& p) const {
return absl::Hash<WeakRefCountedPtr<T>>{}(p);
}
size_t operator()(T* p) const { return absl::Hash<T*>{}(p); }
};
template <typename T>
struct RefCountedPtrEq {
using is_transparent = void;
bool operator()(const RefCountedPtr<T>& p1,
const RefCountedPtr<T>& p2) const {
return p1 == p2;
}
bool operator()(const WeakRefCountedPtr<T>& p1,
const WeakRefCountedPtr<T>& p2) const {
return p1 == p2;
}
bool operator()(const RefCountedPtr<T>& p1,
const WeakRefCountedPtr<T>& p2) const {
return p1 == p2.get();
}
bool operator()(const WeakRefCountedPtr<T>& p1,
const RefCountedPtr<T>& p2) const {
return p1 == p2.get();
}
bool operator()(const RefCountedPtr<T>& p1, const T* p2) const {
return p1 == p2;
}
bool operator()(const WeakRefCountedPtr<T>& p1, const T* p2) const {
return p1 == p2;
}
bool operator()(const T* p1, const RefCountedPtr<T>& p2) const {
return p2 == p1;
}
bool operator()(const T* p1, const WeakRefCountedPtr<T>& p2) const {
return p2 == p1;
}
};
} // namespace grpc_core
#endif // GRPC_SRC_CORE_LIB_GPRPP_REF_COUNTED_PTR_H

@ -98,3 +98,13 @@ void grpc_polling_entity_del_from_pollset_set(grpc_polling_entity* pollent,
absl::StrFormat("Invalid grpc_polling_entity tag '%d'", pollent->tag));
}
}
std::string grpc_polling_entity_string(grpc_polling_entity* pollent) {
if (pollent->tag == GRPC_POLLS_POLLSET) {
return absl::StrFormat("pollset:%p", pollent->pollent.pollset);
} else if (pollent->tag == GRPC_POLLS_POLLSET_SET) {
return absl::StrFormat("pollset_set:%p", pollent->pollent.pollset_set);
} else {
return absl::StrFormat("invalid_tag:%d", pollent->tag);
}
}

@ -66,6 +66,8 @@ void grpc_polling_entity_add_to_pollset_set(grpc_polling_entity* pollent,
void grpc_polling_entity_del_from_pollset_set(grpc_polling_entity* pollent,
grpc_pollset_set* pss_dst);
std::string grpc_polling_entity_string(grpc_polling_entity* pollent);
namespace grpc_core {
template <>
struct ContextType<grpc_polling_entity> {};

@ -44,6 +44,7 @@ class Latch {
public:
Latch() = default;
Latch(const Latch&) = delete;
explicit Latch(T value) : value_(std::move(value)), has_value_(true) {}
Latch& operator=(const Latch&) = delete;
Latch(Latch&& other) noexcept
: value_(std::move(other.value_)), has_value_(other.has_value_) {

@ -541,7 +541,9 @@ class Next {
Next(Next&& other) noexcept = default;
Next& operator=(Next&& other) noexcept = default;
Poll<absl::optional<T>> operator()() { return center_->Next(); }
Poll<absl::optional<T>> operator()() {
return center_ == nullptr ? absl::nullopt : center_->Next();
}
private:
friend class PipeReceiver<T>;
@ -572,17 +574,15 @@ class PipeReceiver {
// Blocks the promise until the receiver is either closed or a message is
// available.
auto Next() {
return Seq(
pipe_detail::Next<T>(center_->Ref()),
[center = center_->Ref()](absl::optional<T> value) {
return Seq(pipe_detail::Next<T>(center_), [center = center_](
absl::optional<T> value) {
bool open = value.has_value();
bool cancelled = center->cancelled();
bool cancelled = center == nullptr ? true : center->cancelled();
return If(
open,
[center = std::move(center), value = std::move(value)]() mutable {
auto run = center->Run(std::move(value));
return Map(std::move(run),
[center = std::move(center)](
return Map(std::move(run), [center = std::move(center)](
absl::optional<T> value) mutable {
if (value.has_value()) {
center->value() = std::move(*value);

@ -1980,10 +1980,21 @@ class PromiseBasedCall : public Call,
void SetCompletionQueue(grpc_completion_queue* cq) override;
bool Completed() final { return finished_.IsSet(); }
virtual void OrphanCall() = 0;
// Implementation of call refcounting: move this to DualRefCounted once we
// don't need to maintain FilterStackCall compatibility
void ExternalRef() final { InternalRef("external"); }
void ExternalUnref() final { InternalUnref("external"); }
void ExternalRef() final {
if (external_refs_.fetch_add(1, std::memory_order_relaxed) == 0) {
InternalRef("external");
}
}
void ExternalUnref() final {
if (external_refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
OrphanCall();
InternalUnref("external");
}
}
void InternalRef(const char* reason) final {
if (grpc_call_refcount_trace.enabled()) {
gpr_log(GPR_DEBUG, "INTERNAL_REF:%p:%s", this, reason);
@ -2346,14 +2357,16 @@ class PromiseBasedCall : public Call,
}
CallContext call_context_{this};
// Double refcounted for now: party owns the internal refcount, we track the
// external refcount. Figure out a better scheme post-promise conversion.
std::atomic<size_t> external_refs_;
// Contexts for various subsystems (security, tracing, ...).
grpc_call_context_element context_[GRPC_CONTEXT_COUNT] = {};
grpc_completion_queue* cq_;
CompletionInfo completion_info_[6];
grpc_call_stats final_stats_{};
Slice final_message_;
grpc_status_code final_status_;
grpc_status_code final_status_ = GRPC_STATUS_UNKNOWN;
CallFinalization finalization_;
// Current deadline.
Mutex deadline_mu_;
@ -2391,7 +2404,8 @@ PromiseBasedCall::PromiseBasedCall(Arena* arena, uint32_t initial_external_refs,
const grpc_call_create_args& args)
: Call(arena, args.server_transport_data == nullptr, args.send_deadline,
args.channel->Ref()),
Party(arena, initial_external_refs),
Party(arena, initial_external_refs != 0 ? 1 : 0),
external_refs_(initial_external_refs),
cq_(args.cq) {
if (args.cq != nullptr) {
GRPC_CQ_INTERNAL_REF(args.cq, "bind");
@ -2684,18 +2698,20 @@ void PublishMetadataArray(grpc_metadata_batch* md, grpc_metadata_array* array,
class ClientPromiseBasedCall final : public PromiseBasedCall {
public:
ClientPromiseBasedCall(Arena* arena, grpc_call_create_args* args)
: PromiseBasedCall(arena, 1, *args) {
: PromiseBasedCall(arena, 1, *args),
polling_entity_(
args->cq != nullptr
? grpc_polling_entity_create_from_pollset(
grpc_cq_pollset(args->cq))
: (args->pollset_set_alternative != nullptr
? grpc_polling_entity_create_from_pollset_set(
args->pollset_set_alternative)
: grpc_polling_entity{})) {
global_stats().IncrementClientCallsCreated();
if (args->cq != nullptr) {
GPR_ASSERT(args->pollset_set_alternative == nullptr &&
"Only one of 'cq' and 'pollset_set_alternative' should be "
"non-nullptr.");
polling_entity_.Set(
grpc_polling_entity_create_from_pollset(grpc_cq_pollset(args->cq)));
}
if (args->pollset_set_alternative != nullptr) {
polling_entity_.Set(grpc_polling_entity_create_from_pollset_set(
args->pollset_set_alternative));
}
ScopedContext context(this);
send_initial_metadata_ =
@ -2711,8 +2727,18 @@ class ClientPromiseBasedCall final : public PromiseBasedCall {
if (args->send_deadline != Timestamp::InfFuture()) {
UpdateDeadline(args->send_deadline);
}
Call* parent = Call::FromC(args->parent);
if (parent != nullptr) {
auto parent_status = InitParent(parent, args->propagation_mask);
if (!parent_status.ok()) {
CancelWithError(std::move(parent_status));
}
PublishToParent(parent);
}
}
void OrphanCall() override { MaybeUnpublishFromParent(); }
~ClientPromiseBasedCall() override {
ScopedContext context(this);
send_initial_metadata_.reset();
@ -2740,7 +2766,9 @@ class ClientPromiseBasedCall final : public PromiseBasedCall {
"cancel_with_error",
[error = std::move(error), this]() {
if (!cancel_error_.is_set()) {
cancel_error_.Set(ServerMetadataFromStatus(error));
auto md = ServerMetadataFromStatus(error);
md->Set(GrpcCallWasCancelled(), true);
cancel_error_.Set(std::move(md));
}
return Empty{};
},
@ -2790,7 +2818,7 @@ class ClientPromiseBasedCall final : public PromiseBasedCall {
Latch<grpc_polling_entity> polling_entity_;
Pipe<MessageHandle> client_to_server_messages_{arena()};
Pipe<MessageHandle> server_to_client_messages_{arena()};
bool is_trailers_only_;
bool is_trailers_only_ = false;
// True once the promise for the call is started.
// This corresponds to sending initial metadata, or cancelling before doing
// so.
@ -2868,7 +2896,9 @@ grpc_call_error ClientPromiseBasedCall::ValidateBatch(const grpc_op* ops,
case GRPC_OP_SEND_STATUS_FROM_SERVER:
return GRPC_CALL_ERROR_NOT_ON_CLIENT;
}
if (got_ops.is_set(op.op)) return GRPC_CALL_ERROR_TOO_MANY_OPERATIONS;
if (got_ops.is_set(op.op)) {
return GRPC_CALL_ERROR_TOO_MANY_OPERATIONS;
}
got_ops.set(op.op);
}
return GRPC_CALL_OK;
@ -2972,9 +3002,16 @@ void ClientPromiseBasedCall::StartRecvInitialMetadata(
NextResult<ServerMetadataHandle> next_metadata) mutable {
server_initial_metadata_.sender.Close();
ServerMetadataHandle metadata;
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%s[call] RecvTrailingMetadata: %s",
DebugTag().c_str(),
next_metadata.has_value()
? next_metadata.value()->DebugString().c_str()
: "null");
}
if (next_metadata.has_value()) {
is_trailers_only_ = false;
metadata = std::move(next_metadata.value());
is_trailers_only_ = metadata->get(GrpcTrailersOnly()).value_or(false);
} else {
is_trailers_only_ = true;
metadata = arena()->MakePooled<ServerMetadata>(arena());
@ -2993,7 +3030,11 @@ void ClientPromiseBasedCall::Finish(ServerMetadataHandle trailing_metadata) {
}
ResetDeadline();
set_completed();
client_to_server_messages_.sender.Close();
client_to_server_messages_.sender.CloseWithError();
client_to_server_messages_.receiver.CloseWithError();
if (trailing_metadata->get(GrpcCallWasCancelled()).value_or(false)) {
server_to_client_messages_.receiver.CloseWithError();
}
if (auto* channelz_channel = channel()->channelz_node()) {
if (trailing_metadata->get(GrpcStatusMetadata())
.value_or(GRPC_STATUS_UNKNOWN) == GRPC_STATUS_OK) {
@ -3071,6 +3112,7 @@ class ServerPromiseBasedCall final : public PromiseBasedCall {
public:
ServerPromiseBasedCall(Arena* arena, grpc_call_create_args* args);
void OrphanCall() override {}
void CancelWithError(grpc_error_handle) override;
grpc_call_error StartBatch(const grpc_op* ops, size_t nops, void* notify_tag,
bool is_notify_tag_closure) override;

@ -38,6 +38,7 @@
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gpr/time_precise.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/closure.h"
@ -126,6 +127,7 @@ class CallContext {
grpc_call_stats* call_stats() { return &call_stats_; }
gpr_atm* peer_string_atm_ptr();
gpr_cycle_counter call_start_time() { return start_time_; }
ServerCallContext* server_call_context();
@ -139,6 +141,7 @@ class CallContext {
// TODO(ctiller): remove this once transport APIs are promise based and we
// don't need refcounting here.
PromiseBasedCall* const call_;
gpr_cycle_counter start_time_ = gpr_get_cycle_counter();
// Is this call traced?
bool traced_ = false;
};

@ -18,6 +18,8 @@
#include <stdint.h>
#include <memory>
#include "absl/status/status.h"
#include "gtest/gtest.h"
@ -32,6 +34,8 @@
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/surface/channel_stack_type.h"
#include "src/core/lib/transport/transport.h"
#include "test/core/end2end/end2end_tests.h"
@ -92,7 +96,11 @@ void destroy_channel_elem(grpc_channel_element* /*elem*/) {}
const grpc_channel_filter test_filter = {
start_transport_stream_op_batch,
nullptr,
[](grpc_channel_element*, CallArgs,
NextPromiseFactory) -> ArenaPromise<ServerMetadataHandle> {
return Immediate(ServerMetadataFromStatus(
absl::PermissionDeniedError("Failure that's not preventable.")));
},
grpc_channel_next_op,
sizeof(call_data),
init_call_elem,

@ -67,6 +67,8 @@ void ServerStreaming(CoreEnd2endTest& test, int num_messages) {
test.Expect(104, true);
test.Step();
gpr_log(GPR_DEBUG, "SEEN_STATUS:%d", seen_status);
// Client keeps reading messages till it gets the status
int num_messages_received = 0;
while (true) {

@ -200,6 +200,7 @@ grpc_cc_test(
name = "ref_counted_ptr_test",
srcs = ["ref_counted_ptr_test.cc"],
external_deps = [
"absl/container:flat_hash_set",
"gtest",
],
language = "C++",

@ -18,6 +18,7 @@
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "absl/container/flat_hash_set.h"
#include "gtest/gtest.h"
#include <grpc/support/log.h>
@ -511,6 +512,65 @@ TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakSubclass) {
FunctionTakingWeakSubclass(p);
}
//
// tests for absl hash integration
//
TEST(AbslHashIntegration, RefCountedPtr) {
absl::flat_hash_set<RefCountedPtr<Foo>> set;
auto p = MakeRefCounted<Foo>(5);
set.insert(p);
auto it = set.find(p);
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, p);
}
TEST(AbslHashIntegration, WeakRefCountedPtr) {
absl::flat_hash_set<WeakRefCountedPtr<Bar>> set;
auto p = MakeRefCounted<Bar>(5);
auto q = p->WeakRef();
set.insert(q);
auto it = set.find(q);
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, q);
}
TEST(AbslHashIntegration, RefCountedPtrHeterogenousLookup) {
absl::flat_hash_set<RefCountedPtr<Bar>, RefCountedPtrHash<Bar>,
RefCountedPtrEq<Bar>>
set;
auto p = MakeRefCounted<Bar>(5);
set.insert(p);
auto it = set.find(p);
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, p);
auto q = p->WeakRef();
it = set.find(q);
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, p);
it = set.find(p.get());
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, p);
}
TEST(AbslHashIntegration, WeakRefCountedPtrHeterogenousLookup) {
absl::flat_hash_set<WeakRefCountedPtr<Bar>, RefCountedPtrHash<Bar>,
RefCountedPtrEq<Bar>>
set;
auto p = MakeRefCounted<Bar>(5);
auto q = p->WeakRef();
set.insert(q);
auto it = set.find(q);
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, q);
it = set.find(p);
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, q);
it = set.find(p.get());
ASSERT_NE(it, set.end());
EXPECT_EQ(*it, q);
}
} // namespace
} // namespace testing
} // namespace grpc_core

@ -570,7 +570,7 @@ BENCHMARK_TEMPLATE(BM_IsolatedFilter, NoFilter, NoOp);
typedef Fixture<&phony_filter::phony_filter, 0> PhonyFilter;
BENCHMARK_TEMPLATE(BM_IsolatedFilter, PhonyFilter, NoOp);
BENCHMARK_TEMPLATE(BM_IsolatedFilter, PhonyFilter, SendEmptyMetadata);
typedef Fixture<&grpc_core::ClientChannel::kFilterVtable, 0>
typedef Fixture<&grpc_core::ClientChannel::kFilterVtableWithoutPromises, 0>
ClientChannelFilter;
BENCHMARK_TEMPLATE(BM_IsolatedFilter, ClientChannelFilter, NoOp);
typedef Fixture<&grpc_core::ClientCompressionFilter::kFilter, CHECKS_NOT_LAST>

Loading…
Cancel
Save