[RefCounted and friends] Fix type safety of ref-counted types.

Previously, `RefCountedPtr<>` and `WeakRefCountedPtr<>` incorrectly allowed
implicit casting of any type to any other type.  This hadn't caused a
problem until recently, but now that it has, we need to fix it.  I have
fixed this by changing these smart pointer types to allow type
conversions only when the type used is convertible to the type of the
smart pointer.  This means that if `Subclass` inherits from `Base`, then
we can set a `RefCountedPtr<BaseClass>` to a value of type
`RefCountedPtr<Subclass>`, but we cannot do the reverse.

We had been (ab)using this bug to make it more convenient to deal with
down-casting in subclasses of ref-counted types.  For example, because
`Resolver` inherits from `InternallyRefCounted<Resolver>`, calling
`Ref()` on a subclass of `Resolver` will return `RefCountedPtr<Resolver>`
rather than returning the subclass's type.  The ability to implicitly
convert to the subclass type made this a bit easier to deal with.  Now
that that ability is gone, we need a different way of dealing with that
problem.

I considered several ways of dealing with this, but none of them are
quite as ergonomic as I would ideally like.  For now, I've settled on
requiring callers to explicitly down-cast as needed, although I have
provided some utility functions to make this slightly easier:

- `RefCounted<>`, `InternallyRefCounted<>`, and `DualRefCounted<>` all
  provide a templated `RefAsSubclass<>()` method that will return a new
  ref as a subclass.  The type used with `RefAsSubclass()` must be a
  subclass of the type passed to `RefCounted<>`, `InternallyRefCounted<>`,
  or `DualRefCounted<>`.
- In addition, `DualRefCounted<>` provides a templated `WeakRefAsSubclass<T>()`
  method.  This is the same as `RefAsSubclass()`, except that it returns
  a weak ref instead of a strong ref.
- In `RefCountedPtr<>`, I have added a new `Ref()` method that takes
  debug tracing parameters.  This can be used instead of calling `Ref()`
  on the underlying object in cases where the caller already has a
  `RefCountedPtr<>` and is calling `Ref()` only to specify the debug
  tracing parameters.  Using this method on `RefCountedPtr<>` is more
  ergonomic, because the smart pointer is already using the right
  subclass, so no down-casting is needed.
- In `WeakRefCountedPtr<>`, I have added a new `WeakRef()` method that
  takes debug tracing parameters.  This is the same as the new `Ref()`
  method on `RefCountedPtr<>`.
- In both `RefCountedPtr<>` and `WeakRefCountedPtr<>`, I have added a
  templated `TakeAsSubclass<>()` method that takes the ref out of the
  smart pointer and returns a new smart pointer of the down-casted type.
  Just as with the `RefAsSubclass()` method above, the type used with
  `TakeAsSubclass()` must be a subclass of the type passed to
  `RefCountedPtr<>` or `WeakRefCountedPtr<>`.

Note that I have *not* provided an `AsSubclass<>()` variant of the
`RefIfNonZero()` methods.  Those methods are used relatively rarely, so
it's not as important for them to be quite so ergonomic.  Callers of
these methods that need to down-cast can use
`RefIfNonZero().TakeAsSubclass<>()`.

PiperOrigin-RevId: 592327447
pull/35257/head
Mark D. Roth 1 year ago committed by Copybara-Service
parent 85cede5967
commit 3e785d395d
  1. 13
      src/core/ext/filters/client_channel/client_channel.cc
  2. 2
      src/core/ext/filters/client_channel/global_subchannel_pool.cc
  3. 3
      src/core/ext/filters/client_channel/lb_policy/child_policy_handler.cc
  4. 31
      src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc
  5. 11
      src/core/ext/filters/client_channel/lb_policy/health_check_client.cc
  6. 11
      src/core/ext/filters/client_channel/lb_policy/oob_backend_metric.cc
  7. 16
      src/core/ext/filters/client_channel/lb_policy/outlier_detection/outlier_detection.cc
  8. 4
      src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc
  9. 5
      src/core/ext/filters/client_channel/lb_policy/priority/priority.cc
  10. 9
      src/core/ext/filters/client_channel/lb_policy/ring_hash/ring_hash.cc
  11. 25
      src/core/ext/filters/client_channel/lb_policy/rls/rls.cc
  12. 10
      src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc
  13. 30
      src/core/ext/filters/client_channel/lb_policy/weighted_round_robin/weighted_round_robin.cc
  14. 5
      src/core/ext/filters/client_channel/lb_policy/weighted_target/weighted_target.cc
  15. 24
      src/core/ext/filters/client_channel/lb_policy/xds/cds.cc
  16. 6
      src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_impl.cc
  17. 7
      src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_manager.cc
  18. 31
      src/core/ext/filters/client_channel/lb_policy/xds/xds_cluster_resolver.cc
  19. 33
      src/core/ext/filters/client_channel/lb_policy/xds/xds_override_host.cc
  20. 6
      src/core/ext/filters/client_channel/lb_policy/xds/xds_wrr_locality.cc
  21. 3
      src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc
  22. 4
      src/core/ext/filters/client_channel/resolver/dns/event_engine/event_engine_client_channel_resolver.cc
  23. 4
      src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc
  24. 4
      src/core/ext/filters/client_channel/resolver/google_c2p/google_c2p_resolver.cc
  25. 14
      src/core/ext/filters/client_channel/resolver/polling_resolver.cc
  26. 47
      src/core/ext/filters/client_channel/resolver/xds/xds_resolver.cc
  27. 4
      src/core/ext/transport/chttp2/client/chttp2_connector.cc
  28. 3
      src/core/ext/xds/certificate_provider_store.cc
  29. 4
      src/core/ext/xds/xds_client_grpc.cc
  30. 22
      src/core/ext/xds/xds_server_config_fetcher.cc
  31. 5
      src/core/ext/xds/xds_transport_grpc.cc
  32. 4
      src/core/lib/channel/channel_args.h
  33. 43
      src/core/lib/gprpp/dual_ref_counted.h
  34. 32
      src/core/lib/gprpp/orphanable.h
  35. 46
      src/core/lib/gprpp/ref_counted.h
  36. 97
      src/core/lib/gprpp/ref_counted_ptr.h
  37. 3
      src/core/lib/security/credentials/plugin/plugin_credentials.cc
  38. 7
      src/core/lib/security/security_connector/tls/tls_security_connector.cc
  39. 6
      src/core/lib/security/transport/client_auth_filter.cc
  40. 3
      src/core/lib/surface/server.cc
  41. 5
      src/core/lib/transport/connectivity_state.cc
  42. 9
      test/core/client_channel/lb_policy/xds_override_host_lb_config_parser_test.cc
  43. 18
      test/core/gprpp/dual_ref_counted_test.cc
  44. 14
      test/core/gprpp/orphanable_test.cc
  45. 47
      test/core/gprpp/ref_counted_ptr_test.cc
  46. 12
      test/core/gprpp/ref_counted_test.cc
  47. 2
      test/core/server_config_selector/server_config_selector_test.cc
  48. 4
      test/core/util/test_lb_policies.cc
  49. 2
      test/core/xds/file_watcher_certificate_provider_factory_test.cc
  50. 14
      test/core/xds/xds_bootstrap_test.cc
  51. 3
      test/core/xds/xds_client_test.cc
  52. 13
      test/core/xds/xds_transport_fake.cc
  53. 2
      test/cpp/interop/rpc_behavior_lb_policy.cc

@ -714,8 +714,9 @@ class ClientChannel::SubchannelWrapper : public SubchannelInterface {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(*chand_->work_serializer_) {
auto& watcher_wrapper = watcher_map_[watcher.get()];
GPR_ASSERT(watcher_wrapper == nullptr);
watcher_wrapper = new WatcherWrapper(std::move(watcher),
Ref(DEBUG_LOCATION, "WatcherWrapper"));
watcher_wrapper = new WatcherWrapper(
std::move(watcher),
RefAsSubclass<SubchannelWrapper>(DEBUG_LOCATION, "WatcherWrapper"));
subchannel_->WatchConnectivityState(
RefCountedPtr<Subchannel::ConnectivityStateWatcherInterface>(
watcher_wrapper));
@ -919,7 +920,8 @@ ClientChannel::ExternalConnectivityWatcher::ExternalConnectivityWatcher(
GPR_ASSERT(chand->external_watchers_[on_complete] == nullptr);
// Store a ref to the watcher in the external_watchers_ map.
chand->external_watchers_[on_complete] =
Ref(DEBUG_LOCATION, "AddWatcherToExternalWatchersMapLocked");
RefAsSubclass<ExternalConnectivityWatcher>(
DEBUG_LOCATION, "AddWatcherToExternalWatchersMapLocked");
}
// Pass the ref from creating the object to Start().
chand_->work_serializer_->Run(
@ -3421,7 +3423,8 @@ void ClientChannel::FilterBasedLoadBalancedCall::TryPick(bool was_queued) {
void ClientChannel::FilterBasedLoadBalancedCall::OnAddToQueueLocked() {
// Register call combiner cancellation callback.
lb_call_canceller_ = new LbQueuedCallCanceller(Ref());
lb_call_canceller_ =
new LbQueuedCallCanceller(RefAsSubclass<FilterBasedLoadBalancedCall>());
}
void ClientChannel::FilterBasedLoadBalancedCall::RetryPickLocked() {
@ -3510,7 +3513,7 @@ ClientChannel::PromiseBasedLoadBalancedCall::MakeCallPromise(
}
// Extract peer name from server initial metadata.
call_args.server_initial_metadata->InterceptAndMap(
[self = RefCountedPtr<PromiseBasedLoadBalancedCall>(lb_call->Ref())](
[self = lb_call->RefAsSubclass<PromiseBasedLoadBalancedCall>()](
ServerMetadataHandle metadata) {
if (self->call_attempt_tracer() != nullptr) {
self->call_attempt_tracer()->RecordReceivedInitialMetadata(

@ -28,7 +28,7 @@ namespace grpc_core {
RefCountedPtr<GlobalSubchannelPool> GlobalSubchannelPool::instance() {
static GlobalSubchannelPool* p = new GlobalSubchannelPool();
return p->Ref();
return p->RefAsSubclass<GlobalSubchannelPool>();
}
RefCountedPtr<Subchannel> GlobalSubchannelPool::RegisterSubchannel(

@ -272,7 +272,8 @@ void ChildPolicyHandler::ResetBackoffLocked() {
OrphanablePtr<LoadBalancingPolicy> ChildPolicyHandler::CreateChildPolicy(
absl::string_view child_policy_name, const ChannelArgs& args) {
Helper* helper = new Helper(Ref(DEBUG_LOCATION, "Helper"));
Helper* helper =
new Helper(RefAsSubclass<ChildPolicyHandler>(DEBUG_LOCATION, "Helper"));
LoadBalancingPolicy::Args lb_policy_args;
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.channel_control_helper =

@ -324,9 +324,8 @@ class GrpcLb : public LoadBalancingPolicy {
}
return;
}
WeakRefCountedPtr<SubchannelWrapper> self = WeakRef();
lb_policy_->work_serializer()->Run(
[self = std::move(self)]() {
[self = WeakRefAsSubclass<SubchannelWrapper>()]() {
if (!self->lb_policy_->shutting_down_) {
self->lb_policy_->CacheDeletedSubchannelLocked(
self->wrapped_subchannel());
@ -819,8 +818,8 @@ RefCountedPtr<SubchannelInterface> GrpcLb::Helper::CreateSubchannel(
return MakeRefCounted<SubchannelWrapper>(
parent()->channel_control_helper()->CreateSubchannel(
address, per_address_args, args),
parent()->Ref(DEBUG_LOCATION, "SubchannelWrapper"), std::move(lb_token),
std::move(client_stats));
parent()->RefAsSubclass<GrpcLb>(DEBUG_LOCATION, "SubchannelWrapper"),
std::move(lb_token), std::move(client_stats));
}
void GrpcLb::Helper::UpdateState(grpc_connectivity_state state,
@ -1558,7 +1557,7 @@ absl::Status GrpcLb::UpdateLocked(UpdateArgs args) {
gpr_log(GPR_INFO, "[grpclb %p] received update", this);
}
const bool is_initial_update = lb_channel_ == nullptr;
config_ = args.config;
config_ = args.config.TakeAsSubclass<GrpcLbConfig>();
GPR_ASSERT(config_ != nullptr);
args_ = std::move(args.args);
// Update fallback address list.
@ -1581,8 +1580,8 @@ absl::Status GrpcLb::UpdateLocked(UpdateArgs args) {
lb_fallback_timer_handle_ =
channel_control_helper()->GetEventEngine()->RunAfter(
fallback_at_startup_timeout_,
[self = static_cast<RefCountedPtr<GrpcLb>>(
Ref(DEBUG_LOCATION, "on_fallback_timer"))]() mutable {
[self = RefAsSubclass<GrpcLb>(DEBUG_LOCATION,
"on_fallback_timer")]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
auto self_ptr = self.get();
@ -1597,7 +1596,8 @@ absl::Status GrpcLb::UpdateLocked(UpdateArgs args) {
ClientChannel::GetFromChannel(Channel::FromC(lb_channel_));
GPR_ASSERT(client_channel != nullptr);
// Ref held by callback.
watcher_ = new StateWatcher(Ref(DEBUG_LOCATION, "StateWatcher"));
watcher_ =
new StateWatcher(RefAsSubclass<GrpcLb>(DEBUG_LOCATION, "StateWatcher"));
client_channel->AddConnectivityWatcher(
GRPC_CHANNEL_IDLE,
OrphanablePtr<AsyncConnectivityStateWatcherInterface>(watcher_));
@ -1640,11 +1640,10 @@ absl::Status GrpcLb::UpdateBalancerChannelLocked() {
// Set up channelz linkage.
channelz::ChannelNode* child_channelz_node =
grpc_channel_get_channelz_node(lb_channel_);
channelz::ChannelNode* parent_channelz_node =
args_.GetObject<channelz::ChannelNode>();
auto parent_channelz_node = args_.GetObjectRef<channelz::ChannelNode>();
if (child_channelz_node != nullptr && parent_channelz_node != nullptr) {
parent_channelz_node->AddChildChannel(child_channelz_node->uuid());
parent_channelz_node_ = parent_channelz_node->Ref();
parent_channelz_node_ = std::move(parent_channelz_node);
}
}
// Propagate updates to the LB channel (pick_first) through the fake
@ -1699,8 +1698,8 @@ void GrpcLb::StartBalancerCallRetryTimerLocked() {
lb_call_retry_timer_handle_ =
channel_control_helper()->GetEventEngine()->RunAfter(
timeout,
[self = static_cast<RefCountedPtr<GrpcLb>>(
Ref(DEBUG_LOCATION, "on_balancer_call_retry_timer"))]() mutable {
[self = RefAsSubclass<GrpcLb>(
DEBUG_LOCATION, "on_balancer_call_retry_timer")]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
auto self_ptr = self.get();
@ -1782,7 +1781,7 @@ OrphanablePtr<LoadBalancingPolicy> GrpcLb::CreateChildPolicyLocked(
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.args = args;
lb_policy_args.channel_control_helper =
std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
std::make_unique<Helper>(RefAsSubclass<GrpcLb>(DEBUG_LOCATION, "Helper"));
OrphanablePtr<LoadBalancingPolicy> lb_policy =
MakeOrphanable<ChildPolicyHandler>(std::move(lb_policy_args),
&grpc_lb_glb_trace);
@ -1867,8 +1866,8 @@ void GrpcLb::StartSubchannelCacheTimerLocked() {
subchannel_cache_timer_handle_ =
channel_control_helper()->GetEventEngine()->RunAfter(
cached_subchannels_.begin()->first - Timestamp::Now(),
[self = static_cast<RefCountedPtr<GrpcLb>>(
Ref(DEBUG_LOCATION, "OnSubchannelCacheTimer"))]() mutable {
[self = RefAsSubclass<GrpcLb>(DEBUG_LOCATION,
"OnSubchannelCacheTimer")]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
auto* self_ptr = self.get();

@ -356,7 +356,8 @@ void HealthProducer::Start(RefCountedPtr<Subchannel> subchannel) {
MutexLock lock(&mu_);
connected_subchannel_ = subchannel_->connected_subchannel();
}
auto connectivity_watcher = MakeRefCounted<ConnectivityWatcher>(WeakRef());
auto connectivity_watcher =
MakeRefCounted<ConnectivityWatcher>(WeakRefAsSubclass<HealthProducer>());
connectivity_watcher_ = connectivity_watcher.get();
subchannel_->WatchConnectivityState(std::move(connectivity_watcher));
}
@ -387,7 +388,8 @@ void HealthProducer::AddWatcher(
health_checkers_.emplace(*health_check_service_name, nullptr).first;
auto& health_checker = it->second;
if (health_checker == nullptr) {
health_checker = MakeOrphanable<HealthChecker>(WeakRef(), it->first);
health_checker = MakeOrphanable<HealthChecker>(
WeakRefAsSubclass<HealthProducer>(), it->first);
}
health_checker->AddWatcherLocked(watcher);
}
@ -456,7 +458,10 @@ void HealthWatcher::SetSubchannel(Subchannel* subchannel) {
subchannel->GetOrAddDataProducer(
HealthProducer::Type(),
[&](Subchannel::DataProducerInterface** producer) {
if (*producer != nullptr) producer_ = (*producer)->RefIfNonZero();
if (*producer != nullptr) {
producer_ =
(*producer)->RefIfNonZero().TakeAsSubclass<HealthProducer>();
}
if (producer_ == nullptr) {
producer_ = MakeRefCounted<HealthProducer>();
*producer = producer_.get();

@ -215,7 +215,8 @@ class OrcaProducer::OrcaStreamEventHandler
void OrcaProducer::Start(RefCountedPtr<Subchannel> subchannel) {
subchannel_ = std::move(subchannel);
connected_subchannel_ = subchannel_->connected_subchannel();
auto connectivity_watcher = MakeRefCounted<ConnectivityWatcher>(WeakRef());
auto connectivity_watcher =
MakeRefCounted<ConnectivityWatcher>(WeakRefAsSubclass<OrcaProducer>());
connectivity_watcher_ = connectivity_watcher.get();
subchannel_->WatchConnectivityState(std::move(connectivity_watcher));
}
@ -269,7 +270,8 @@ void OrcaProducer::MaybeStartStreamLocked() {
if (connected_subchannel_ == nullptr) return;
stream_client_ = MakeOrphanable<SubchannelStreamClient>(
connected_subchannel_, subchannel_->pollset_set(),
std::make_unique<OrcaStreamEventHandler>(WeakRef(), report_interval_),
std::make_unique<OrcaStreamEventHandler>(
WeakRefAsSubclass<OrcaProducer>(), report_interval_),
GRPC_TRACE_FLAG_ENABLED(grpc_orca_client_trace) ? "OrcaClient" : nullptr);
}
@ -310,7 +312,10 @@ void OrcaWatcher::SetSubchannel(Subchannel* subchannel) {
// If not, create a new one.
subchannel->GetOrAddDataProducer(
OrcaProducer::Type(), [&](Subchannel::DataProducerInterface** producer) {
if (*producer != nullptr) producer_ = (*producer)->RefIfNonZero();
if (*producer != nullptr) {
producer_ =
(*producer)->RefIfNonZero().TakeAsSubclass<OrcaProducer>();
}
if (producer_ == nullptr) {
producer_ = MakeRefCounted<OrcaProducer>();
*producer = producer_.get();

@ -151,9 +151,8 @@ class OutlierDetectionLb : public LoadBalancingPolicy {
}
return;
}
WeakRefCountedPtr<SubchannelWrapper> self = WeakRef();
work_serializer_->Run(
[self = std::move(self)]() {
[self = WeakRefAsSubclass<SubchannelWrapper>()]() {
if (self->subchannel_state_ != nullptr) {
self->subchannel_state_->RemoveSubchannel(self.get());
}
@ -624,7 +623,7 @@ absl::Status OutlierDetectionLb::UpdateLocked(UpdateArgs args) {
}
auto old_config = std::move(config_);
// Update config.
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<OutlierDetectionLbConfig>();
// Update outlier detection timer.
if (!config_->CountingEnabled()) {
// No need for timer. Cancel the current timer, if any.
@ -639,7 +638,8 @@ absl::Status OutlierDetectionLb::UpdateLocked(UpdateArgs args) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_outlier_detection_lb_trace)) {
gpr_log(GPR_INFO, "[outlier_detection_lb %p] starting timer", this);
}
ejection_timer_ = MakeOrphanable<EjectionTimer>(Ref(), Timestamp::Now());
ejection_timer_ = MakeOrphanable<EjectionTimer>(
RefAsSubclass<OutlierDetectionLb>(), Timestamp::Now());
for (const auto& p : endpoint_state_map_) {
p.second->RotateBucket(); // Reset call counters.
}
@ -654,8 +654,8 @@ absl::Status OutlierDetectionLb::UpdateLocked(UpdateArgs args) {
"[outlier_detection_lb %p] interval changed, replacing timer",
this);
}
ejection_timer_ =
MakeOrphanable<EjectionTimer>(Ref(), ejection_timer_->StartTime());
ejection_timer_ = MakeOrphanable<EjectionTimer>(
RefAsSubclass<OutlierDetectionLb>(), ejection_timer_->StartTime());
}
// Update subchannel and endpoint maps.
if (args.addresses.ok()) {
@ -783,8 +783,8 @@ OrphanablePtr<LoadBalancingPolicy> OutlierDetectionLb::CreateChildPolicyLocked(
LoadBalancingPolicy::Args lb_policy_args;
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.args = args;
lb_policy_args.channel_control_helper =
std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
lb_policy_args.channel_control_helper = std::make_unique<Helper>(
RefAsSubclass<OutlierDetectionLb>(DEBUG_LOCATION, "Helper"));
OrphanablePtr<LoadBalancingPolicy> lb_policy =
MakeOrphanable<ChildPolicyHandler>(std::move(lb_policy_args),
&grpc_outlier_detection_lb_trace);

@ -418,7 +418,7 @@ void PickFirst::AttemptToConnectUsingLatestUpdateArgsLocked() {
latest_pending_subchannel_list_.get());
}
latest_pending_subchannel_list_ = MakeOrphanable<SubchannelList>(
Ref(), addresses, latest_update_args_.args);
RefAsSubclass<PickFirst>(), addresses, latest_update_args_.args);
// Empty update or no valid subchannels. Put the channel in
// TRANSIENT_FAILURE and request re-resolution.
if (latest_pending_subchannel_list_->size() == 0) {
@ -1030,7 +1030,7 @@ void PickFirst::SubchannelList::SubchannelData::ProcessUnselectedReadyLocked() {
gpr_log(GPR_INFO, "[PF %p] starting health watch", p);
}
auto watcher = std::make_unique<HealthWatcher>(
p->Ref(DEBUG_LOCATION, "HealthWatcher"));
p->RefAsSubclass<PickFirst>(DEBUG_LOCATION, "HealthWatcher"));
p->health_watcher_ = watcher.get();
auto health_data_watcher = MakeHealthCheckWatcher(
p->work_serializer(), subchannel_list_->args_, std::move(watcher));

@ -335,7 +335,7 @@ absl::Status PriorityLb::UpdateLocked(UpdateArgs args) {
gpr_log(GPR_INFO, "[priority_lb %p] received update", this);
}
// Update config.
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<PriorityLbConfig>();
// Update args.
args_ = std::move(args.args);
// Update addresses.
@ -411,7 +411,8 @@ void PriorityLb::ChoosePriorityLocked() {
// Create child if needed.
if (child == nullptr) {
child = MakeOrphanable<ChildPriority>(
Ref(DEBUG_LOCATION, "ChildPriority"), child_name);
RefAsSubclass<PriorityLb>(DEBUG_LOCATION, "ChildPriority"),
child_name);
auto child_config = config_->children().find(child_name);
GPR_DEBUG_ASSERT(child_config != config_->children().end());
// TODO(roth): If the child reports a non-OK status with the

@ -346,7 +346,7 @@ RingHash::PickResult RingHash::Picker::Pick(PickArgs args) {
return endpoint_info.picker->Pick(args);
case GRPC_CHANNEL_IDLE:
new EndpointConnectionAttempter(
ring_hash_->Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"),
ring_hash_.Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"),
endpoint_info.endpoint);
ABSL_FALLTHROUGH_INTENDED;
case GRPC_CHANNEL_CONNECTING:
@ -677,8 +677,8 @@ absl::Status RingHash::UpdateLocked(UpdateArgs args) {
it->second->UpdateLocked(i);
endpoint_map.emplace(address_set, std::move(it->second));
} else {
endpoint_map.emplace(address_set,
MakeOrphanable<RingHashEndpoint>(Ref(), i));
endpoint_map.emplace(address_set, MakeOrphanable<RingHashEndpoint>(
RefAsSubclass<RingHash>(), i));
}
}
endpoint_map_ = std::move(endpoint_map);
@ -779,7 +779,8 @@ void RingHash::UpdateAggregatedConnectivityStateLocked(
// Note that we use our own picker regardless of connectivity state.
channel_control_helper()->UpdateState(
state, status,
MakeRefCounted<Picker>(Ref(DEBUG_LOCATION, "RingHashPicker")));
MakeRefCounted<Picker>(
RefAsSubclass<RingHash>(DEBUG_LOCATION, "RingHashPicker")));
// While the ring_hash policy is reporting TRANSIENT_FAILURE, it will
// not be getting any pick requests from the priority policy.
// However, because the ring_hash policy does not attempt to

@ -1282,7 +1282,7 @@ RlsLb::Cache::Entry::OnRlsResponseLocked(
auto it = lb_policy_->child_policy_map_.find(target);
if (it == lb_policy_->child_policy_map_.end()) {
auto new_child = MakeRefCounted<ChildPolicyWrapper>(
lb_policy_->Ref(DEBUG_LOCATION, "ChildPolicyWrapper"), target);
lb_policy_.Ref(DEBUG_LOCATION, "ChildPolicyWrapper"), target);
new_child->StartUpdate();
child_policies_to_finish_update.push_back(new_child.get());
new_child_policy_wrappers.emplace_back(std::move(new_child));
@ -1326,8 +1326,8 @@ RlsLb::Cache::Entry* RlsLb::Cache::FindOrInsert(const RequestKey& key) {
if (it == map_.end()) {
size_t entry_size = EntrySizeForKey(key);
MaybeShrinkSize(size_limit_ - std::min(size_limit_, entry_size));
Entry* entry =
new Entry(lb_policy_->Ref(DEBUG_LOCATION, "CacheEntry"), key);
Entry* entry = new Entry(
lb_policy_->RefAsSubclass<RlsLb>(DEBUG_LOCATION, "CacheEntry"), key);
map_.emplace(key, OrphanablePtr<Entry>(entry));
size_ += entry_size;
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace)) {
@ -1550,11 +1550,11 @@ RlsLb::RlsChannel::RlsChannel(RefCountedPtr<RlsLb> lb_policy)
// Set up channelz linkage.
channelz::ChannelNode* child_channelz_node =
grpc_channel_get_channelz_node(channel_);
channelz::ChannelNode* parent_channelz_node =
lb_policy_->channel_args_.GetObject<channelz::ChannelNode>();
auto parent_channelz_node =
lb_policy_->channel_args_.GetObjectRef<channelz::ChannelNode>();
if (child_channelz_node != nullptr && parent_channelz_node != nullptr) {
parent_channelz_node->AddChildChannel(child_channelz_node->uuid());
parent_channelz_node_ = parent_channelz_node->Ref();
parent_channelz_node_ = std::move(parent_channelz_node);
}
// Start connectivity watch.
ClientChannel* client_channel =
@ -1607,7 +1607,7 @@ void RlsLb::RlsChannel::StartRlsCall(const RequestKey& key,
}
lb_policy_->request_map_.emplace(
key, MakeOrphanable<RlsRequest>(
lb_policy_->Ref(DEBUG_LOCATION, "RlsRequest"), key,
lb_policy_.Ref(DEBUG_LOCATION, "RlsRequest"), key,
lb_policy_->rls_channel_->Ref(DEBUG_LOCATION, "RlsRequest"),
std::move(backoff_state), reason, std::move(stale_header_data)));
}
@ -1886,7 +1886,7 @@ absl::Status RlsLb::UpdateLocked(UpdateArgs args) {
update_in_progress_ = true;
// Swap out config.
RefCountedPtr<RlsLbConfig> old_config = std::move(config_);
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<RlsLbConfig>();
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_rls_trace) &&
(old_config == nullptr ||
old_config->child_policy_config() != config_->child_policy_config())) {
@ -1926,7 +1926,7 @@ absl::Status RlsLb::UpdateLocked(UpdateArgs args) {
gpr_log(GPR_INFO, "[rlslb %p] creating new default target", this);
}
default_child_policy_ = MakeRefCounted<ChildPolicyWrapper>(
Ref(DEBUG_LOCATION, "ChildPolicyWrapper"),
RefAsSubclass<RlsLb>(DEBUG_LOCATION, "ChildPolicyWrapper"),
config_->default_target());
created_default_child = true;
} else {
@ -1945,8 +1945,8 @@ absl::Status RlsLb::UpdateLocked(UpdateArgs args) {
// Swap out RLS channel if needed.
if (old_config == nullptr ||
config_->lookup_service() != old_config->lookup_service()) {
rls_channel_ =
MakeOrphanable<RlsChannel>(Ref(DEBUG_LOCATION, "RlsChannel"));
rls_channel_ = MakeOrphanable<RlsChannel>(
RefAsSubclass<RlsLb>(DEBUG_LOCATION, "RlsChannel"));
}
// Resize cache if needed.
if (old_config == nullptr ||
@ -2116,7 +2116,8 @@ void RlsLb::UpdatePickerLocked() {
status = absl::UnavailableError("no children available");
}
channel_control_helper()->UpdateState(
state, status, MakeRefCounted<Picker>(Ref(DEBUG_LOCATION, "Picker")));
state, status,
MakeRefCounted<Picker>(RefAsSubclass<RlsLb>(DEBUG_LOCATION, "Picker")));
}
//

@ -404,7 +404,8 @@ void OldRoundRobin::RoundRobinSubchannelList::
}
p->channel_control_helper()->UpdateState(
GRPC_CHANNEL_CONNECTING, absl::Status(),
MakeRefCounted<QueuePicker>(p->Ref(DEBUG_LOCATION, "QueuePicker")));
MakeRefCounted<QueuePicker>(
p->RefAsSubclass<OldRoundRobin>(DEBUG_LOCATION, "QueuePicker")));
} else if (num_transient_failure_ == num_subchannels()) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_round_robin_trace)) {
gpr_log(GPR_INFO,
@ -530,7 +531,7 @@ class RoundRobin : public LoadBalancingPolicy {
? "RoundRobinEndpointList"
: nullptr) {
Init(endpoints, args,
[&](RefCountedPtr<RoundRobinEndpointList> endpoint_list,
[&](RefCountedPtr<EndpointList> endpoint_list,
const EndpointAddresses& addresses, const ChannelArgs& args) {
return MakeOrphanable<RoundRobinEndpoint>(
std::move(endpoint_list), addresses, args,
@ -541,7 +542,7 @@ class RoundRobin : public LoadBalancingPolicy {
private:
class RoundRobinEndpoint : public Endpoint {
public:
RoundRobinEndpoint(RefCountedPtr<RoundRobinEndpointList> endpoint_list,
RoundRobinEndpoint(RefCountedPtr<EndpointList> endpoint_list,
const EndpointAddresses& addresses,
const ChannelArgs& args,
std::shared_ptr<WorkSerializer> work_serializer)
@ -708,7 +709,8 @@ absl::Status RoundRobin::UpdateLocked(UpdateArgs args) {
latest_pending_endpoint_list_.get());
}
latest_pending_endpoint_list_ = MakeOrphanable<RoundRobinEndpointList>(
Ref(DEBUG_LOCATION, "RoundRobinEndpointList"), addresses, args.args);
RefAsSubclass<RoundRobin>(DEBUG_LOCATION, "RoundRobinEndpointList"),
addresses, args.args);
// If the new list is empty, immediately promote it to
// endpoint_list_ and report TRANSIENT_FAILURE.
if (latest_pending_endpoint_list_->size() == 0) {

@ -610,10 +610,9 @@ void OldWeightedRoundRobin::Picker::BuildSchedulerAndStartTimerLocked() {
scheduler_ = std::move(scheduler);
}
// Start timer.
WeakRefCountedPtr<Picker> self = WeakRef();
timer_handle_ = wrr_->channel_control_helper()->GetEventEngine()->RunAfter(
config_->weight_update_period(),
[self = std::move(self),
[self = WeakRefAsSubclass<Picker>(),
work_serializer = wrr_->work_serializer()]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
@ -673,7 +672,7 @@ void OldWeightedRoundRobin::ResetBackoffLocked() {
absl::Status OldWeightedRoundRobin::UpdateLocked(UpdateArgs args) {
global_stats().IncrementWrrUpdates();
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<WeightedRoundRobinConfig>();
std::shared_ptr<EndpointAddressesIterator> addresses;
if (args.addresses.ok()) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) {
@ -757,8 +756,9 @@ OldWeightedRoundRobin::GetOrCreateWeight(const grpc_resolved_address& address) {
auto weight = it->second->RefIfNonZero();
if (weight != nullptr) return weight;
}
auto weight =
MakeRefCounted<AddressWeight>(Ref(DEBUG_LOCATION, "AddressWeight"), *key);
auto weight = MakeRefCounted<AddressWeight>(
RefAsSubclass<OldWeightedRoundRobin>(DEBUG_LOCATION, "AddressWeight"),
*key);
address_weight_map_.emplace(*key, weight.get());
return weight;
}
@ -833,7 +833,8 @@ void OldWeightedRoundRobin::WeightedRoundRobinSubchannelList::
}
p->channel_control_helper()->UpdateState(
GRPC_CHANNEL_READY, absl::Status(),
MakeRefCounted<Picker>(p->Ref(), this));
MakeRefCounted<Picker>(p->RefAsSubclass<OldWeightedRoundRobin>(),
this));
} else if (num_connecting_ > 0) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) {
gpr_log(GPR_INFO, "[WRR %p] reporting CONNECTING with subchannel list %p",
@ -1038,7 +1039,7 @@ class WeightedRoundRobin : public LoadBalancingPolicy {
public:
class WrrEndpoint : public Endpoint {
public:
WrrEndpoint(RefCountedPtr<WrrEndpointList> endpoint_list,
WrrEndpoint(RefCountedPtr<EndpointList> endpoint_list,
const EndpointAddresses& addresses, const ChannelArgs& args,
std::shared_ptr<WorkSerializer> work_serializer)
: Endpoint(std::move(endpoint_list)),
@ -1086,7 +1087,7 @@ class WeightedRoundRobin : public LoadBalancingPolicy {
? "WrrEndpointList"
: nullptr) {
Init(endpoints, args,
[&](RefCountedPtr<WrrEndpointList> endpoint_list,
[&](RefCountedPtr<EndpointList> endpoint_list,
const EndpointAddresses& addresses, const ChannelArgs& args) {
return MakeOrphanable<WrrEndpoint>(
std::move(endpoint_list), addresses, args,
@ -1452,10 +1453,9 @@ void WeightedRoundRobin::Picker::BuildSchedulerAndStartTimerLocked() {
gpr_log(GPR_INFO, "[WRR %p picker %p] scheduling timer for %s", wrr_.get(),
this, config_->weight_update_period().ToString().c_str());
}
WeakRefCountedPtr<Picker> self = WeakRef();
timer_handle_ = wrr_->channel_control_helper()->GetEventEngine()->RunAfter(
config_->weight_update_period(),
[self = std::move(self),
[self = WeakRefAsSubclass<Picker>(),
work_serializer = wrr_->work_serializer()]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
@ -1515,7 +1515,7 @@ void WeightedRoundRobin::ResetBackoffLocked() {
absl::Status WeightedRoundRobin::UpdateLocked(UpdateArgs args) {
global_stats().IncrementWrrUpdates();
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<WeightedRoundRobinConfig>();
std::shared_ptr<EndpointAddressesIterator> addresses;
if (args.addresses.ok()) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) {
@ -1560,8 +1560,8 @@ absl::Status WeightedRoundRobin::UpdateLocked(UpdateArgs args) {
gpr_log(GPR_INFO, "[WRR %p] replacing previous pending endpoint list %p",
this, latest_pending_endpoint_list_.get());
}
latest_pending_endpoint_list_ =
MakeOrphanable<WrrEndpointList>(Ref(), addresses.get(), args.args);
latest_pending_endpoint_list_ = MakeOrphanable<WrrEndpointList>(
RefAsSubclass<WeightedRoundRobin>(), addresses.get(), args.args);
// If the new list is empty, immediately promote it to
// endpoint_list_ and report TRANSIENT_FAILURE.
if (latest_pending_endpoint_list_->size() == 0) {
@ -1599,7 +1599,7 @@ WeightedRoundRobin::GetOrCreateWeight(
if (weight != nullptr) return weight;
}
auto weight = MakeRefCounted<EndpointWeight>(
Ref(DEBUG_LOCATION, "EndpointWeight"), key);
RefAsSubclass<WeightedRoundRobin>(DEBUG_LOCATION, "EndpointWeight"), key);
endpoint_weight_map_.emplace(key, weight.get());
return weight;
}
@ -1759,7 +1759,7 @@ void WeightedRoundRobin::WrrEndpointList::
}
wrr->channel_control_helper()->UpdateState(
GRPC_CHANNEL_READY, absl::Status(),
MakeRefCounted<Picker>(wrr->Ref(), this));
MakeRefCounted<Picker>(wrr->RefAsSubclass<WeightedRoundRobin>(), this));
} else if (num_connecting_ > 0) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_wrr_trace)) {
gpr_log(GPR_INFO, "[WRR %p] reporting CONNECTING with endpoint list %p",

@ -316,7 +316,7 @@ absl::Status WeightedTargetLb::UpdateLocked(UpdateArgs args) {
}
update_in_progress_ = true;
// Update config.
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<WeightedTargetLbConfig>();
// Deactivate the targets not in the new config.
for (const auto& p : targets_) {
const std::string& name = p.first;
@ -336,7 +336,8 @@ absl::Status WeightedTargetLb::UpdateLocked(UpdateArgs args) {
// Create child if it does not already exist.
if (target == nullptr) {
target = MakeOrphanable<WeightedChild>(
Ref(DEBUG_LOCATION, "WeightedChild"), name);
RefAsSubclass<WeightedTargetLb>(DEBUG_LOCATION, "WeightedChild"),
name);
}
absl::StatusOr<std::shared_ptr<EndpointAddressesIterator>> addresses;
if (address_map.ok()) {

@ -124,9 +124,8 @@ class CdsLb : public LoadBalancingPolicy {
void OnResourceChanged(
std::shared_ptr<const XdsClusterResource> cluster_data) override {
RefCountedPtr<ClusterWatcher> self = Ref();
parent_->work_serializer()->Run(
[self = std::move(self),
[self = RefAsSubclass<ClusterWatcher>(),
cluster_data = std::move(cluster_data)]() mutable {
self->parent_->OnClusterChanged(self->name_,
std::move(cluster_data));
@ -134,17 +133,16 @@ class CdsLb : public LoadBalancingPolicy {
DEBUG_LOCATION);
}
void OnError(absl::Status status) override {
RefCountedPtr<ClusterWatcher> self = Ref();
parent_->work_serializer()->Run(
[self = std::move(self), status = std::move(status)]() mutable {
[self = RefAsSubclass<ClusterWatcher>(),
status = std::move(status)]() mutable {
self->parent_->OnError(self->name_, std::move(status));
},
DEBUG_LOCATION);
}
void OnResourceDoesNotExist() override {
RefCountedPtr<ClusterWatcher> self = Ref();
parent_->work_serializer()->Run(
[self = std::move(self)]() {
[self = RefAsSubclass<ClusterWatcher>()]() {
self->parent_->OnResourceDoesNotExist(self->name_);
},
DEBUG_LOCATION);
@ -281,7 +279,7 @@ void CdsLb::ExitIdleLocked() {
absl::Status CdsLb::UpdateLocked(UpdateArgs args) {
// Update config.
auto old_config = std::move(config_);
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<CdsLbConfig>();
if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) {
gpr_log(GPR_INFO, "[cdslb %p] received update: cluster=%s", this,
config_->cluster().c_str());
@ -301,7 +299,8 @@ absl::Status CdsLb::UpdateLocked(UpdateArgs args) {
}
watchers_.clear();
}
auto watcher = MakeRefCounted<ClusterWatcher>(Ref(), config_->cluster());
auto watcher = MakeRefCounted<ClusterWatcher>(RefAsSubclass<CdsLb>(),
config_->cluster());
watchers_[config_->cluster()].watcher = watcher.get();
XdsClusterResourceType::StartWatch(xds_client_.get(), config_->cluster(),
std::move(watcher));
@ -330,7 +329,7 @@ absl::StatusOr<bool> CdsLb::GenerateDiscoveryMechanismForCluster(
auto& state = watchers_[name];
// Create a new watcher if needed.
if (state.watcher == nullptr) {
auto watcher = MakeRefCounted<ClusterWatcher>(Ref(), name);
auto watcher = MakeRefCounted<ClusterWatcher>(RefAsSubclass<CdsLb>(), name);
if (GRPC_TRACE_FLAG_ENABLED(grpc_cds_lb_trace)) {
gpr_log(GPR_INFO, "[cdslb %p] starting watch for cluster %s", this,
name.c_str());
@ -505,7 +504,8 @@ void CdsLb::OnClusterChanged(
LoadBalancingPolicy::Args args;
args.work_serializer = work_serializer();
args.args = args_;
args.channel_control_helper = std::make_unique<Helper>(Ref());
args.channel_control_helper =
std::make_unique<Helper>(RefAsSubclass<CdsLb>());
child_policy_ =
CoreConfiguration::Get()
.lb_policy_registry()
@ -596,7 +596,7 @@ absl::Status CdsLb::UpdateXdsCertificateProvider(
absl::string_view root_provider_cert_name =
cluster_data.common_tls_context.certificate_validation_context
.ca_certificate_provider_instance.certificate_name;
RefCountedPtr<XdsCertificateProvider> new_root_provider;
RefCountedPtr<grpc_tls_certificate_provider> new_root_provider;
if (!root_provider_instance_name.empty()) {
new_root_provider =
xds_client_->certificate_provider_store()
@ -620,7 +620,7 @@ absl::Status CdsLb::UpdateXdsCertificateProvider(
absl::string_view identity_provider_cert_name =
cluster_data.common_tls_context.tls_certificate_provider_instance
.certificate_name;
RefCountedPtr<XdsCertificateProvider> new_identity_provider;
RefCountedPtr<grpc_tls_certificate_provider> new_identity_provider;
if (!identity_provider_instance_name.empty()) {
new_identity_provider =
xds_client_->certificate_provider_store()

@ -476,7 +476,7 @@ absl::Status XdsClusterImplLb::UpdateLocked(UpdateArgs args) {
// Update config.
const bool is_initial_update = config_ == nullptr;
auto old_config = std::move(config_);
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<XdsClusterImplLbConfig>();
// On initial update, create drop stats.
if (is_initial_update) {
if (config_->lrs_load_reporting_server().has_value()) {
@ -550,8 +550,8 @@ OrphanablePtr<LoadBalancingPolicy> XdsClusterImplLb::CreateChildPolicyLocked(
LoadBalancingPolicy::Args lb_policy_args;
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.args = args;
lb_policy_args.channel_control_helper =
std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
lb_policy_args.channel_control_helper = std::make_unique<Helper>(
RefAsSubclass<XdsClusterImplLb>(DEBUG_LOCATION, "Helper"));
OrphanablePtr<LoadBalancingPolicy> lb_policy =
MakeOrphanable<ChildPolicyHandler>(std::move(lb_policy_args),
&grpc_xds_cluster_impl_lb_trace);

@ -283,7 +283,7 @@ absl::Status XdsClusterManagerLb::UpdateLocked(UpdateArgs args) {
}
update_in_progress_ = true;
// Update config.
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<XdsClusterManagerLbConfig>();
// Deactivate the children not in the new config.
for (const auto& p : children_) {
const std::string& name = p.first;
@ -299,8 +299,9 @@ absl::Status XdsClusterManagerLb::UpdateLocked(UpdateArgs args) {
const RefCountedPtr<LoadBalancingPolicy::Config>& config = p.second.config;
auto& child = children_[name];
if (child == nullptr) {
child = MakeOrphanable<ClusterChild>(Ref(DEBUG_LOCATION, "ClusterChild"),
name);
child = MakeOrphanable<ClusterChild>(
RefAsSubclass<XdsClusterManagerLb>(DEBUG_LOCATION, "ClusterChild"),
name);
}
absl::Status status =
child->UpdateLocked(config, args.addresses, args.args);

@ -217,25 +217,24 @@ class XdsClusterResolverLb : public LoadBalancingPolicy {
}
void OnResourceChanged(
std::shared_ptr<const XdsEndpointResource> update) override {
RefCountedPtr<EndpointWatcher> self = Ref();
discovery_mechanism_->parent()->work_serializer()->Run(
[self = std::move(self), update = std::move(update)]() mutable {
[self = RefAsSubclass<EndpointWatcher>(),
update = std::move(update)]() mutable {
self->OnResourceChangedHelper(std::move(update));
},
DEBUG_LOCATION);
}
void OnError(absl::Status status) override {
RefCountedPtr<EndpointWatcher> self = Ref();
discovery_mechanism_->parent()->work_serializer()->Run(
[self = std::move(self), status = std::move(status)]() mutable {
[self = RefAsSubclass<EndpointWatcher>(),
status = std::move(status)]() mutable {
self->OnErrorHelper(std::move(status));
},
DEBUG_LOCATION);
}
void OnResourceDoesNotExist() override {
RefCountedPtr<EndpointWatcher> self = Ref();
discovery_mechanism_->parent()->work_serializer()->Run(
[self = std::move(self)]() {
[self = RefAsSubclass<EndpointWatcher>()]() {
self->OnResourceDoesNotExistHelper();
},
DEBUG_LOCATION);
@ -424,8 +423,9 @@ void XdsClusterResolverLb::EdsDiscoveryMechanism::Start() {
":%p starting xds watch for %s",
parent(), index(), this, std::string(GetEdsResourceName()).c_str());
}
auto watcher = MakeRefCounted<EndpointWatcher>(
Ref(DEBUG_LOCATION, "EdsDiscoveryMechanism"));
auto watcher =
MakeRefCounted<EndpointWatcher>(RefAsSubclass<EdsDiscoveryMechanism>(
DEBUG_LOCATION, "EdsDiscoveryMechanism"));
watcher_ = watcher.get();
XdsEndpointResourceType::StartWatch(parent()->xds_client_.get(),
GetEdsResourceName(), std::move(watcher));
@ -463,7 +463,8 @@ void XdsClusterResolverLb::LogicalDNSDiscoveryMechanism::Start() {
target.c_str(), args, parent()->interested_parties(),
parent()->work_serializer(),
std::make_unique<ResolverResultHandler>(
Ref(DEBUG_LOCATION, "LogicalDNSDiscoveryMechanism")));
RefAsSubclass<LogicalDNSDiscoveryMechanism>(
DEBUG_LOCATION, "LogicalDNSDiscoveryMechanism")));
if (resolver_ == nullptr) {
parent()->OnResourceDoesNotExist(
index(),
@ -591,7 +592,7 @@ absl::Status XdsClusterResolverLb::UpdateLocked(UpdateArgs args) {
const bool is_initial_update = args_ == ChannelArgs();
// Update config.
auto old_config = std::move(config_);
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<XdsClusterResolverLbConfig>();
// Update args.
args_ = std::move(args.args);
// Update child policy if needed.
@ -604,13 +605,15 @@ absl::Status XdsClusterResolverLb::UpdateLocked(UpdateArgs args) {
if (config.type == XdsClusterResolverLbConfig::DiscoveryMechanism::
DiscoveryMechanismType::EDS) {
entry.discovery_mechanism = MakeOrphanable<EdsDiscoveryMechanism>(
Ref(DEBUG_LOCATION, "EdsDiscoveryMechanism"),
RefAsSubclass<XdsClusterResolverLb>(DEBUG_LOCATION,
"EdsDiscoveryMechanism"),
discovery_mechanisms_.size());
} else if (config.type == XdsClusterResolverLbConfig::DiscoveryMechanism::
DiscoveryMechanismType::LOGICAL_DNS) {
entry.discovery_mechanism =
MakeOrphanable<LogicalDNSDiscoveryMechanism>(
Ref(DEBUG_LOCATION, "LogicalDNSDiscoveryMechanism"),
RefAsSubclass<XdsClusterResolverLb>(
DEBUG_LOCATION, "LogicalDNSDiscoveryMechanism"),
discovery_mechanisms_.size());
} else {
GPR_ASSERT(0);
@ -1010,8 +1013,8 @@ XdsClusterResolverLb::CreateChildPolicyLocked(const ChannelArgs& args) {
LoadBalancingPolicy::Args lb_policy_args;
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.args = args;
lb_policy_args.channel_control_helper =
std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
lb_policy_args.channel_control_helper = std::make_unique<Helper>(
RefAsSubclass<XdsClusterResolverLb>(DEBUG_LOCATION, "Helper"));
OrphanablePtr<LoadBalancingPolicy> lb_policy =
CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
"priority_experimental", std::move(lb_policy_args));

@ -185,9 +185,9 @@ class XdsOverrideHostLb : public LoadBalancingPolicy {
void SetSubchannel(SubchannelWrapper* subchannel) {
if (eds_health_status_.status() == XdsHealthStatus::kDraining) {
subchannel_ = subchannel->Ref();
subchannel_ = subchannel->RefAsSubclass<SubchannelWrapper>();
} else {
subchannel_ = subchannel->WeakRef();
subchannel_ = subchannel->WeakRefAsSubclass<SubchannelWrapper>();
}
}
@ -210,9 +210,9 @@ class XdsOverrideHostLb : public LoadBalancingPolicy {
auto subchannel = GetSubchannel();
if (subchannel == nullptr) return;
if (eds_health_status_.status() == XdsHealthStatus::kDraining) {
subchannel_ = subchannel->Ref();
subchannel_ = subchannel->RefAsSubclass<SubchannelWrapper>();
} else {
subchannel_ = subchannel->WeakRef();
subchannel_ = subchannel->WeakRefAsSubclass<SubchannelWrapper>();
}
}
@ -362,7 +362,9 @@ XdsOverrideHostLb::Picker::PickOverridenHost(
RefCountedPtr<SubchannelWrapper> subchannel;
auto it = policy_->subchannel_map_.find(address);
if (it != policy_->subchannel_map_.end()) {
subchannel = it->second.GetSubchannel()->RefIfNonZero();
subchannel = it->second.GetSubchannel()
->RefIfNonZero()
.TakeAsSubclass<SubchannelWrapper>();
}
if (subchannel == nullptr) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_override_host_trace)) {
@ -540,7 +542,7 @@ absl::Status XdsOverrideHostLb::UpdateLocked(UpdateArgs args) {
}
auto old_config = std::move(config_);
// Update config.
config_ = std::move(args.config);
config_ = args.config.TakeAsSubclass<XdsOverrideHostLbConfig>();
if (config_ == nullptr) {
return absl::InvalidArgumentError("Missing policy config");
}
@ -575,8 +577,9 @@ absl::Status XdsOverrideHostLb::UpdateLocked(UpdateArgs args) {
void XdsOverrideHostLb::MaybeUpdatePickerLocked() {
if (picker_ != nullptr) {
auto xds_override_host_picker = MakeRefCounted<Picker>(
Ref(), picker_, config_->override_host_status_set());
auto xds_override_host_picker =
MakeRefCounted<Picker>(RefAsSubclass<XdsOverrideHostLb>(), picker_,
config_->override_host_status_set());
if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_xds_override_host_trace)) {
gpr_log(GPR_INFO,
"[xds_override_host_lb %p] updating connectivity: state=%s "
@ -594,8 +597,8 @@ OrphanablePtr<LoadBalancingPolicy> XdsOverrideHostLb::CreateChildPolicyLocked(
LoadBalancingPolicy::Args lb_policy_args;
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.args = args;
lb_policy_args.channel_control_helper =
std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
lb_policy_args.channel_control_helper = std::make_unique<Helper>(
RefAsSubclass<XdsOverrideHostLb>(DEBUG_LOCATION, "Helper"));
OrphanablePtr<LoadBalancingPolicy> lb_policy =
MakeOrphanable<ChildPolicyHandler>(std::move(lb_policy_args),
&grpc_lb_xds_override_host_trace);
@ -713,8 +716,8 @@ XdsOverrideHostLb::AdoptSubchannel(
const grpc_resolved_address& address,
RefCountedPtr<SubchannelInterface> subchannel) {
auto key = grpc_sockaddr_to_string(&address, /*normalize=*/false);
auto wrapper =
MakeRefCounted<SubchannelWrapper>(std::move(subchannel), Ref());
auto wrapper = MakeRefCounted<SubchannelWrapper>(
std::move(subchannel), RefAsSubclass<XdsOverrideHostLb>());
if (key.ok()) {
MutexLock lock(&subchannel_map_mu_);
auto it = subchannel_map_.find(*key);
@ -780,7 +783,8 @@ XdsOverrideHostLb::SubchannelWrapper::SubchannelWrapper(
RefCountedPtr<SubchannelInterface> subchannel,
RefCountedPtr<XdsOverrideHostLb> policy)
: DelegatingSubchannel(std::move(subchannel)), policy_(std::move(policy)) {
auto watcher = std::make_unique<ConnectivityStateWatcher>(WeakRef());
auto watcher = std::make_unique<ConnectivityStateWatcher>(
WeakRefAsSubclass<SubchannelWrapper>());
watcher_ = watcher.get();
wrapped_subchannel()->WatchConnectivityState(std::move(watcher));
}
@ -831,9 +835,8 @@ void XdsOverrideHostLb::SubchannelWrapper::Orphan() {
wrapped_subchannel()->CancelConnectivityStateWatch(watcher_);
return;
}
WeakRefCountedPtr<SubchannelWrapper> self = WeakRef();
policy_->work_serializer()->Run(
[self = std::move(self)]() {
[self = WeakRefAsSubclass<SubchannelWrapper>()]() {
self->key_.reset();
self->wrapped_subchannel()->CancelConnectivityStateWatch(
self->watcher_);

@ -165,7 +165,7 @@ absl::Status XdsWrrLocalityLb::UpdateLocked(UpdateArgs args) {
if (GRPC_TRACE_FLAG_ENABLED(grpc_xds_wrr_locality_lb_trace)) {
gpr_log(GPR_INFO, "[xds_wrr_locality_lb %p] Received update", this);
}
RefCountedPtr<XdsWrrLocalityLbConfig> config = std::move(args.config);
auto config = args.config.TakeAsSubclass<XdsWrrLocalityLbConfig>();
// Scan the addresses to find the weight for each locality.
std::map<std::string, uint32_t> locality_weights;
if (args.addresses.ok()) {
@ -252,8 +252,8 @@ OrphanablePtr<LoadBalancingPolicy> XdsWrrLocalityLb::CreateChildPolicyLocked(
LoadBalancingPolicy::Args lb_policy_args;
lb_policy_args.work_serializer = work_serializer();
lb_policy_args.args = args;
lb_policy_args.channel_control_helper =
std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
lb_policy_args.channel_control_helper = std::make_unique<Helper>(
RefAsSubclass<XdsWrrLocalityLb>(DEBUG_LOCATION, "Helper"));
auto lb_policy =
CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
"weighted_target_experimental", std::move(lb_policy_args));

@ -224,7 +224,8 @@ AresClientChannelDNSResolver::~AresClientChannelDNSResolver() {
OrphanablePtr<Orphanable> AresClientChannelDNSResolver::StartRequest() {
return MakeOrphanable<AresRequestWrapper>(
Ref(DEBUG_LOCATION, "dns-resolving"));
RefAsSubclass<AresClientChannelDNSResolver>(DEBUG_LOCATION,
"dns-resolving"));
}
void AresClientChannelDNSResolver::AresRequestWrapper::OnHostnameResolved(

@ -210,7 +210,9 @@ OrphanablePtr<Orphanable> EventEngineClientChannelDNSResolver::StartRequest() {
return nullptr;
}
return MakeOrphanable<EventEngineDNSRequestWrapper>(
Ref(DEBUG_LOCATION, "dns-resolving"), std::move(*dns_resolver));
RefAsSubclass<EventEngineClientChannelDNSResolver>(DEBUG_LOCATION,
"dns-resolving"),
std::move(*dns_resolver));
}
// ----------------------------------------------------------------------------

@ -85,7 +85,7 @@ FakeResolver::FakeResolver(ResolverArgs args)
response_generator_(
args.args.GetObjectRef<FakeResolverResponseGenerator>()) {
if (response_generator_ != nullptr) {
response_generator_->SetFakeResolver(Ref());
response_generator_->SetFakeResolver(RefAsSubclass<FakeResolver>());
}
}
@ -137,7 +137,7 @@ void FakeResolverResponseGenerator::SetResponseAndNotify(
if (notify_when_set != nullptr) notify_when_set->Notify();
return;
}
resolver = resolver_->Ref();
resolver = resolver_;
}
SendResultToResolver(std::move(resolver), std::move(result), notify_when_set);
}

@ -154,7 +154,7 @@ void GoogleCloud2ProdResolver::StartLocked() {
zone_query_ = MakeOrphanable<MetadataQuery>(
metadata_server_name_, std::string(MetadataQuery::kZoneAttribute),
&pollent_,
[resolver = static_cast<RefCountedPtr<GoogleCloud2ProdResolver>>(Ref())](
[resolver = RefAsSubclass<GoogleCloud2ProdResolver>()](
std::string /* attribute */,
absl::StatusOr<std::string> result) mutable {
resolver->work_serializer_->Run(
@ -168,7 +168,7 @@ void GoogleCloud2ProdResolver::StartLocked() {
ipv6_query_ = MakeOrphanable<MetadataQuery>(
metadata_server_name_, std::string(MetadataQuery::kIPv6Attribute),
&pollent_,
[resolver = static_cast<RefCountedPtr<GoogleCloud2ProdResolver>>(Ref())](
[resolver = RefAsSubclass<GoogleCloud2ProdResolver>()](
std::string /* attribute */,
absl::StatusOr<std::string> result) mutable {
resolver->work_serializer_->Run(

@ -105,10 +105,9 @@ void PollingResolver::ShutdownLocked() {
}
void PollingResolver::ScheduleNextResolutionTimer(const Duration& timeout) {
RefCountedPtr<PollingResolver> self = Ref();
next_resolution_timer_handle_ =
channel_args_.GetObject<EventEngine>()->RunAfter(
timeout, [self = std::move(self)]() mutable {
timeout, [self = RefAsSubclass<PollingResolver>()]() mutable {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
auto* self_ptr = self.get();
@ -174,12 +173,11 @@ void PollingResolver::OnRequestCompleteLocked(Result result) {
result.resolution_note.c_str());
}
GPR_ASSERT(result.result_health_callback == nullptr);
RefCountedPtr<PollingResolver> self =
Ref(DEBUG_LOCATION, "result_health_callback");
result.result_health_callback = [self =
std::move(self)](absl::Status status) {
self->GetResultStatus(std::move(status));
};
result.result_health_callback =
[self = RefAsSubclass<PollingResolver>(
DEBUG_LOCATION, "result_health_callback")](absl::Status status) {
self->GetResultStatus(std::move(status));
};
result_status_state_ = ResultStatusState::kResultHealthCallbackPending;
result_handler_->ReportResult(std::move(result));
}

@ -142,26 +142,25 @@ class XdsResolver : public Resolver {
: resolver_(std::move(resolver)) {}
void OnResourceChanged(
std::shared_ptr<const XdsListenerResource> listener) override {
RefCountedPtr<ListenerWatcher> self = Ref();
resolver_->work_serializer_->Run(
[self = std::move(self), listener = std::move(listener)]() mutable {
[self = RefAsSubclass<ListenerWatcher>(),
listener = std::move(listener)]() mutable {
self->resolver_->OnListenerUpdate(std::move(listener));
},
DEBUG_LOCATION);
}
void OnError(absl::Status status) override {
RefCountedPtr<ListenerWatcher> self = Ref();
resolver_->work_serializer_->Run(
[self = std::move(self), status = std::move(status)]() mutable {
[self = RefAsSubclass<ListenerWatcher>(),
status = std::move(status)]() mutable {
self->resolver_->OnError(self->resolver_->lds_resource_name_,
std::move(status));
},
DEBUG_LOCATION);
}
void OnResourceDoesNotExist() override {
RefCountedPtr<ListenerWatcher> self = Ref();
resolver_->work_serializer_->Run(
[self = std::move(self)]() {
[self = RefAsSubclass<ListenerWatcher>()]() {
self->resolver_->OnResourceDoesNotExist(
absl::StrCat(self->resolver_->lds_resource_name_,
": xDS listener resource does not exist"));
@ -180,9 +179,8 @@ class XdsResolver : public Resolver {
: resolver_(std::move(resolver)) {}
void OnResourceChanged(
std::shared_ptr<const XdsRouteConfigResource> route_config) override {
RefCountedPtr<RouteConfigWatcher> self = Ref();
resolver_->work_serializer_->Run(
[self = std::move(self),
[self = RefAsSubclass<RouteConfigWatcher>(),
route_config = std::move(route_config)]() mutable {
if (self != self->resolver_->route_config_watcher_) return;
self->resolver_->OnRouteConfigUpdate(std::move(route_config));
@ -190,9 +188,9 @@ class XdsResolver : public Resolver {
DEBUG_LOCATION);
}
void OnError(absl::Status status) override {
RefCountedPtr<RouteConfigWatcher> self = Ref();
resolver_->work_serializer_->Run(
[self = std::move(self), status = std::move(status)]() mutable {
[self = RefAsSubclass<RouteConfigWatcher>(),
status = std::move(status)]() mutable {
if (self != self->resolver_->route_config_watcher_) return;
self->resolver_->OnError(self->resolver_->route_config_name_,
std::move(status));
@ -200,9 +198,8 @@ class XdsResolver : public Resolver {
DEBUG_LOCATION);
}
void OnResourceDoesNotExist() override {
RefCountedPtr<RouteConfigWatcher> self = Ref();
resolver_->work_serializer_->Run(
[self = std::move(self)]() {
[self = RefAsSubclass<RouteConfigWatcher>()]() {
if (self != self->resolver_->route_config_watcher_) return;
self->resolver_->OnResourceDoesNotExist(absl::StrCat(
self->resolver_->route_config_name_,
@ -389,7 +386,8 @@ class XdsResolver : public Resolver {
absl::string_view cluster_name) {
auto it = cluster_ref_map_.find(cluster_name);
if (it == cluster_ref_map_.end()) {
auto cluster = MakeRefCounted<ClusterRef>(Ref(), cluster_name);
auto cluster = MakeRefCounted<ClusterRef>(RefAsSubclass<XdsResolver>(),
cluster_name);
cluster_ref_map_.emplace(cluster->cluster_name(), cluster->WeakRef());
return cluster;
}
@ -980,7 +978,7 @@ void XdsResolver::StartLocked() {
grpc_pollset_set_add_pollset_set(
static_cast<GrpcXdsClient*>(xds_client_.get())->interested_parties(),
interested_parties_);
auto watcher = MakeRefCounted<ListenerWatcher>(Ref());
auto watcher = MakeRefCounted<ListenerWatcher>(RefAsSubclass<XdsResolver>());
listener_watcher_ = watcher.get();
XdsListenerResourceType::StartWatch(xds_client_.get(), lds_resource_name_,
std::move(watcher));
@ -1042,7 +1040,8 @@ void XdsResolver::OnListenerUpdate(
}
// Start watch for the new RDS resource name.
route_config_name_ = rds_name;
auto watcher = MakeRefCounted<RouteConfigWatcher>(Ref());
auto watcher =
MakeRefCounted<RouteConfigWatcher>(RefAsSubclass<XdsResolver>());
route_config_watcher_ = watcher.get();
XdsRouteConfigResourceType::StartWatch(
xds_client_.get(), route_config_name_, std::move(watcher));
@ -1118,11 +1117,8 @@ void XdsResolver::OnError(absl::string_view context, absl::Status status) {
Result result;
result.addresses = status;
result.service_config = std::move(status);
// Need to explicitly convert to the right RefCountedPtr<> type for
// use with ChannelArgs::SetObject().
RefCountedPtr<GrpcXdsClient> xds_client =
xds_client_->Ref(DEBUG_LOCATION, "xds resolver result");
result.args = args_.SetObject(std::move(xds_client));
result.args =
args_.SetObject(xds_client_.Ref(DEBUG_LOCATION, "xds resolver result"));
result_handler_->ReportResult(std::move(result));
}
@ -1197,8 +1193,8 @@ void XdsResolver::GenerateResult() {
absl::UnavailableError(route_config_data.status().message()));
return;
}
auto config_selector =
MakeRefCounted<XdsConfigSelector>(Ref(), std::move(*route_config_data));
auto config_selector = MakeRefCounted<XdsConfigSelector>(
RefAsSubclass<XdsResolver>(), std::move(*route_config_data));
Result result;
result.addresses.emplace();
result.service_config = CreateServiceConfig();
@ -1208,12 +1204,9 @@ void XdsResolver::GenerateResult() {
? std::string((*result.service_config)->json_string()).c_str()
: result.service_config.status().ToString().c_str());
}
// Need to explicitly convert to the right RefCountedPtr<> type for
// use with ChannelArgs::SetObject().
RefCountedPtr<GrpcXdsClient> xds_client =
xds_client_->Ref(DEBUG_LOCATION, "xds resolver result");
result.args =
args_.SetObject(std::move(xds_client)).SetObject(config_selector);
args_.SetObject(xds_client_.Ref(DEBUG_LOCATION, "xds resolver result"))
.SetObject(config_selector);
result_handler_->ReportResult(std::move(result));
}

@ -177,9 +177,9 @@ void Chttp2Connector::OnHandshakeDone(void* arg, grpc_error_handle error) {
grpc_chttp2_transport_start_reading(self->result_->transport,
args->read_buffer,
&self->on_receive_settings_, nullptr);
RefCountedPtr<Chttp2Connector> cc = self->Ref();
self->timer_handle_ = self->event_engine_->RunAfter(
self->args_.deadline - Timestamp::Now(), [self = std::move(cc)] {
self->args_.deadline - Timestamp::Now(),
[self = self->RefAsSubclass<Chttp2Connector>()] {
ApplicationCallbackExecCtx callback_exec_ctx;
ExecCtx exec_ctx;
self->OnTimeout();

@ -106,7 +106,8 @@ CertificateProviderStore::CreateOrGetCertificateProvider(
certificate_providers_map_.insert({result->key(), result.get()});
}
} else {
result = it->second->RefIfNonZero();
result =
it->second->RefIfNonZero().TakeAsSubclass<CertificateProviderWrapper>();
if (result == nullptr) {
result = CreateCertificateProviderLocked(key);
it->second = result.get();

@ -155,7 +155,9 @@ absl::StatusOr<RefCountedPtr<GrpcXdsClient>> GrpcXdsClient::GetOrCreate(
MutexLock lock(g_mu);
if (g_xds_client != nullptr) {
auto xds_client = g_xds_client->RefIfNonZero(DEBUG_LOCATION, reason);
if (xds_client != nullptr) return xds_client;
if (xds_client != nullptr) {
return xds_client.TakeAsSubclass<GrpcXdsClient>();
}
}
// Find bootstrap contents.
auto bootstrap_contents = GetBootstrapContents(g_fallback_bootstrap_config);

@ -527,7 +527,7 @@ void XdsServerConfigFetcher::StartWatch(
std::unique_ptr<grpc_server_config_fetcher::WatcherInterface> watcher) {
grpc_server_config_fetcher::WatcherInterface* watcher_ptr = watcher.get();
auto listener_watcher = MakeRefCounted<ListenerWatcher>(
xds_client_->Ref(DEBUG_LOCATION, "ListenerWatcher"), std::move(watcher),
xds_client_.Ref(DEBUG_LOCATION, "ListenerWatcher"), std::move(watcher),
serving_status_notifier_, listening_address);
auto* listener_watcher_ptr = listener_watcher.get();
XdsListenerResourceType::StartWatch(
@ -595,7 +595,7 @@ void XdsServerConfigFetcher::ListenerWatcher::OnResourceChanged(
return;
}
auto new_filter_chain_match_manager = MakeRefCounted<FilterChainMatchManager>(
xds_client_->Ref(DEBUG_LOCATION, "FilterChainMatchManager"),
xds_client_.Ref(DEBUG_LOCATION, "FilterChainMatchManager"),
tcp_listener->filter_chain_map, tcp_listener->default_filter_chain);
MutexLock lock(&mu_);
if (filter_chain_match_manager_ == nullptr ||
@ -605,7 +605,8 @@ void XdsServerConfigFetcher::ListenerWatcher::OnResourceChanged(
filter_chain_match_manager_->default_filter_chain())) {
pending_filter_chain_match_manager_ =
std::move(new_filter_chain_match_manager);
pending_filter_chain_match_manager_->StartRdsWatch(Ref());
pending_filter_chain_match_manager_->StartRdsWatch(
RefAsSubclass<ListenerWatcher>());
}
}
@ -743,8 +744,8 @@ void XdsServerConfigFetcher::ListenerWatcher::FilterChainMatchManager::
MutexLock lock(&mu_);
for (const auto& resource_name : resource_names) {
++rds_resources_yet_to_fetch_;
auto route_config_watcher =
MakeRefCounted<RouteConfigWatcher>(resource_name, WeakRef());
auto route_config_watcher = MakeRefCounted<RouteConfigWatcher>(
resource_name, WeakRefAsSubclass<FilterChainMatchManager>());
rds_map_.emplace(resource_name, RdsUpdateState{route_config_watcher.get(),
absl::nullopt});
watchers_to_start.push_back(
@ -1122,8 +1123,8 @@ absl::StatusOr<ChannelArgs> XdsServerConfigFetcher::ListenerWatcher::
}
server_config_selector_provider =
MakeRefCounted<DynamicXdsServerConfigSelectorProvider>(
xds_client_->Ref(DEBUG_LOCATION,
"DynamicXdsServerConfigSelectorProvider"),
xds_client_.Ref(DEBUG_LOCATION,
"DynamicXdsServerConfigSelectorProvider"),
rds_name, std::move(initial_resource),
filter_chain->http_connection_manager.http_filters);
},
@ -1131,8 +1132,8 @@ absl::StatusOr<ChannelArgs> XdsServerConfigFetcher::ListenerWatcher::
[&](const std::shared_ptr<const XdsRouteConfigResource>& route_config) {
server_config_selector_provider =
MakeRefCounted<StaticXdsServerConfigSelectorProvider>(
xds_client_->Ref(DEBUG_LOCATION,
"StaticXdsServerConfigSelectorProvider"),
xds_client_.Ref(DEBUG_LOCATION,
"StaticXdsServerConfigSelectorProvider"),
route_config,
filter_chain->http_connection_manager.http_filters);
});
@ -1270,7 +1271,8 @@ XdsServerConfigFetcher::ListenerWatcher::FilterChainMatchManager::
// RouteConfigWatcher is being created here instead of in Watch() to avoid
// deadlocks from invoking XdsRouteConfigResourceType::StartWatch whilst in a
// critical region.
auto route_config_watcher = MakeRefCounted<RouteConfigWatcher>(WeakRef());
auto route_config_watcher = MakeRefCounted<RouteConfigWatcher>(
WeakRefAsSubclass<DynamicXdsServerConfigSelectorProvider>());
route_config_watcher_ = route_config_watcher.get();
XdsRouteConfigResourceType::StartWatch(xds_client_.get(), resource_name_,
std::move(route_config_watcher));

@ -321,8 +321,9 @@ GrpcXdsTransportFactory::GrpcXdsTransport::CreateStreamingCall(
const char* method,
std::unique_ptr<StreamingCall::EventHandler> event_handler) {
return MakeOrphanable<GrpcStreamingCall>(
factory_->Ref(DEBUG_LOCATION, "StreamingCall"), channel_, method,
std::move(event_handler));
factory_->RefAsSubclass<GrpcXdsTransportFactory>(DEBUG_LOCATION,
"StreamingCall"),
channel_, method, std::move(event_handler));
}
void GrpcXdsTransportFactory::GrpcXdsTransport::ResetBackoff() {

@ -233,12 +233,12 @@ struct GetObjectImpl<
static Result Get(StoredType p) { return p; };
static ReffedResult GetReffed(StoredType p) {
if (p == nullptr) return nullptr;
return p->Ref();
return p->template RefAsSubclass<T>();
};
static ReffedResult GetReffed(StoredType p, const DebugLocation& location,
const char* reason) {
if (p == nullptr) return nullptr;
return p->Ref(location, reason);
return p->template RefAsSubclass<T>(location, reason);
};
};

@ -47,19 +47,38 @@ namespace grpc_core {
template <typename Child>
class DualRefCounted : public Orphanable {
public:
// Not copyable nor movable.
DualRefCounted(const DualRefCounted&) = delete;
DualRefCounted& operator=(const DualRefCounted&) = delete;
~DualRefCounted() override = default;
GRPC_MUST_USE_RESULT RefCountedPtr<Child> Ref() {
IncrementRefCount();
return RefCountedPtr<Child>(static_cast<Child*>(this));
}
GRPC_MUST_USE_RESULT RefCountedPtr<Child> Ref(const DebugLocation& location,
const char* reason) {
IncrementRefCount(location, reason);
return RefCountedPtr<Child>(static_cast<Child*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefCountedPtr<Subclass> RefAsSubclass() {
IncrementRefCount();
return RefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefCountedPtr<Subclass> RefAsSubclass(const DebugLocation& location,
const char* reason) {
IncrementRefCount(location, reason);
return RefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
void Unref() {
// Convert strong ref to weak ref.
const uint64_t prev_ref_pair =
@ -120,7 +139,6 @@ class DualRefCounted : public Orphanable {
std::memory_order_acq_rel, std::memory_order_acquire));
return RefCountedPtr<Child>(static_cast<Child*>(this));
}
GRPC_MUST_USE_RESULT RefCountedPtr<Child> RefIfNonZero(
const DebugLocation& location, const char* reason) {
uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire);
@ -150,13 +168,28 @@ class DualRefCounted : public Orphanable {
IncrementWeakRefCount();
return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
}
GRPC_MUST_USE_RESULT WeakRefCountedPtr<Child> WeakRef(
const DebugLocation& location, const char* reason) {
IncrementWeakRefCount(location, reason);
return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
WeakRefCountedPtr<Subclass> WeakRefAsSubclass() {
IncrementWeakRefCount();
return WeakRefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
WeakRefCountedPtr<Subclass> WeakRefAsSubclass(const DebugLocation& location,
const char* reason) {
IncrementWeakRefCount(location, reason);
return WeakRefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
void WeakUnref() {
#ifndef NDEBUG
// Grab a copy of the trace flag before the atomic change, since we
@ -207,10 +240,6 @@ class DualRefCounted : public Orphanable {
}
}
// Not copyable nor movable.
DualRefCounted(const DualRefCounted&) = delete;
DualRefCounted& operator=(const DualRefCounted&) = delete;
protected:
// Note: Tracing is a no-op in non-debug builds.
explicit DualRefCounted(

@ -97,15 +97,20 @@ class InternallyRefCounted : public Orphanable {
return RefCountedPtr<Child>(static_cast<Child*>(this));
}
void Unref() {
if (GPR_UNLIKELY(refs_.Unref())) {
unref_behavior_(static_cast<Child*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefCountedPtr<Subclass> RefAsSubclass() {
IncrementRefCount();
return RefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
void Unref(const DebugLocation& location, const char* reason) {
if (GPR_UNLIKELY(refs_.Unref(location, reason))) {
unref_behavior_(static_cast<Child*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefCountedPtr<Subclass> RefAsSubclass(const DebugLocation& location,
const char* reason) {
IncrementRefCount(location, reason);
return RefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
GRPC_MUST_USE_RESULT RefCountedPtr<Child> RefIfNonZero() {
@ -119,6 +124,17 @@ class InternallyRefCounted : public Orphanable {
: nullptr);
}
void Unref() {
if (GPR_UNLIKELY(refs_.Unref())) {
unref_behavior_(static_cast<Child*>(this));
}
}
void Unref(const DebugLocation& location, const char* reason) {
if (GPR_UNLIKELY(refs_.Unref(location, reason))) {
unref_behavior_(static_cast<Child*>(this));
}
}
private:
void IncrementRefCount() { refs_.Ref(); }
void IncrementRefCount(const DebugLocation& location, const char* reason) {

@ -276,6 +276,10 @@ class RefCounted : public Impl {
public:
using RefCountedChildType = Child;
// Not copyable nor movable.
RefCounted(const RefCounted&) = delete;
RefCounted& operator=(const RefCounted&) = delete;
// Note: Depending on the Impl used, this dtor can be implicitly virtual.
~RefCounted() = default;
@ -301,19 +305,20 @@ class RefCounted : public Impl {
return RefCountedPtr<const Child>(static_cast<const Child*>(this));
}
// TODO(roth): Once all of our code is converted to C++ and can use
// RefCountedPtr<> instead of manual ref-counting, make this method
// private, since it will only be used by RefCountedPtr<>, which is a
// friend of this class.
void Unref() const {
if (GPR_UNLIKELY(refs_.Unref())) {
unref_behavior_(static_cast<const Child*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefCountedPtr<Subclass> RefAsSubclass() {
IncrementRefCount();
return RefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
void Unref(const DebugLocation& location, const char* reason) const {
if (GPR_UNLIKELY(refs_.Unref(location, reason))) {
unref_behavior_(static_cast<const Child*>(this));
}
template <
typename Subclass,
std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefCountedPtr<Subclass> RefAsSubclass(const DebugLocation& location,
const char* reason) {
IncrementRefCount(location, reason);
return RefCountedPtr<Subclass>(static_cast<Subclass*>(this));
}
// RefIfNonZero() for mutable types.
@ -340,9 +345,20 @@ class RefCounted : public Impl {
: nullptr);
}
// Not copyable nor movable.
RefCounted(const RefCounted&) = delete;
RefCounted& operator=(const RefCounted&) = delete;
// TODO(roth): Once all of our code is converted to C++ and can use
// RefCountedPtr<> instead of manual ref-counting, make this method
// private, since it will only be used by RefCountedPtr<>, which is a
// friend of this class.
void Unref() const {
if (GPR_UNLIKELY(refs_.Unref())) {
unref_behavior_(static_cast<const Child*>(this));
}
}
void Unref(const DebugLocation& location, const char* reason) const {
if (GPR_UNLIKELY(refs_.Unref(location, reason))) {
unref_behavior_(static_cast<const Child*>(this));
}
}
protected:
// Note: Tracing is a no-op on non-debug builds.

@ -43,7 +43,8 @@ class RefCountedPtr {
RefCountedPtr(std::nullptr_t) {}
// If value is non-null, we take ownership of a ref to it.
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
explicit RefCountedPtr(Y* value) : value_(value) {}
// Move ctors.
@ -51,7 +52,8 @@ class RefCountedPtr {
value_ = other.value_;
other.value_ = nullptr;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
// NOLINTNEXTLINE(google-explicit-constructor)
RefCountedPtr(RefCountedPtr<Y>&& other) noexcept {
value_ = static_cast<T*>(other.value_);
@ -63,7 +65,8 @@ class RefCountedPtr {
reset(std::exchange(other.value_, nullptr));
return *this;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
RefCountedPtr& operator=(RefCountedPtr<Y>&& other) noexcept {
reset(std::exchange(other.value_, nullptr));
return *this;
@ -74,7 +77,8 @@ class RefCountedPtr {
if (other.value_ != nullptr) other.value_->IncrementRefCount();
value_ = other.value_;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
// NOLINTNEXTLINE(google-explicit-constructor)
RefCountedPtr(const RefCountedPtr<Y>& other) {
static_assert(std::has_virtual_destructor<T>::value,
@ -92,7 +96,8 @@ class RefCountedPtr {
reset(other.value_);
return *this;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
RefCountedPtr& operator=(const RefCountedPtr<Y>& other) {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
@ -107,6 +112,12 @@ class RefCountedPtr {
if (value_ != nullptr) value_->Unref();
}
// An explicit copy method that supports ref-count tracing.
RefCountedPtr<T> Ref(const DebugLocation& location, const char* reason) {
if (value_ != nullptr) value_->IncrementRefCount(location, reason);
return RefCountedPtr<T>(value_);
}
void swap(RefCountedPtr& other) { std::swap(value_, other.value_); }
// If value is non-null, we take ownership of a ref to it.
@ -119,13 +130,15 @@ class RefCountedPtr {
T* old_value = std::exchange(value_, value);
if (old_value != nullptr) old_value->Unref(location, reason);
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
void reset(Y* value = nullptr) {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
reset(static_cast<T*>(value));
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
void reset(const DebugLocation& location, const char* reason,
Y* value = nullptr) {
static_assert(std::has_virtual_destructor<T>::value,
@ -143,24 +156,34 @@ class RefCountedPtr {
T& operator*() const { return *value_; }
T* operator->() const { return value_; }
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_base_of<T, Y>::value, bool> = true>
RefCountedPtr<Y> TakeAsSubclass() {
return RefCountedPtr<Y>(static_cast<Y*>(release()));
}
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator==(const RefCountedPtr<Y>& other) const {
return value_ == other.value_;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator==(const Y* other) const {
return value_ == other;
}
bool operator==(std::nullptr_t) const { return value_ == nullptr; }
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator!=(const RefCountedPtr<Y>& other) const {
return value_ != other.value_;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator!=(const Y* other) const {
return value_ != other;
}
@ -184,7 +207,8 @@ class WeakRefCountedPtr {
WeakRefCountedPtr(std::nullptr_t) {}
// If value is non-null, we take ownership of a ref to it.
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
explicit WeakRefCountedPtr(Y* value) {
value_ = value;
}
@ -194,7 +218,8 @@ class WeakRefCountedPtr {
value_ = other.value_;
other.value_ = nullptr;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
// NOLINTNEXTLINE(google-explicit-constructor)
WeakRefCountedPtr(WeakRefCountedPtr<Y>&& other) noexcept {
value_ = static_cast<T*>(other.value_);
@ -206,7 +231,8 @@ class WeakRefCountedPtr {
reset(std::exchange(other.value_, nullptr));
return *this;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
WeakRefCountedPtr& operator=(WeakRefCountedPtr<Y>&& other) noexcept {
reset(std::exchange(other.value_, nullptr));
return *this;
@ -217,7 +243,8 @@ class WeakRefCountedPtr {
if (other.value_ != nullptr) other.value_->IncrementWeakRefCount();
value_ = other.value_;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
// NOLINTNEXTLINE(google-explicit-constructor)
WeakRefCountedPtr(const WeakRefCountedPtr<Y>& other) {
static_assert(std::has_virtual_destructor<T>::value,
@ -235,7 +262,8 @@ class WeakRefCountedPtr {
reset(other.value_);
return *this;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
WeakRefCountedPtr& operator=(const WeakRefCountedPtr<Y>& other) {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
@ -250,6 +278,13 @@ class WeakRefCountedPtr {
if (value_ != nullptr) value_->WeakUnref();
}
// An explicit copy method that supports ref-count tracing.
WeakRefCountedPtr<T> WeakRef(const DebugLocation& location,
const char* reason) {
if (value_ != nullptr) value_->IncrementWeakRefCount(location, reason);
return WeakRefCountedPtr<T>(value_);
}
void swap(WeakRefCountedPtr& other) { std::swap(value_, other.value_); }
// If value is non-null, we take ownership of a ref to it.
@ -262,13 +297,15 @@ class WeakRefCountedPtr {
T* old_value = std::exchange(value_, value);
if (old_value != nullptr) old_value->WeakUnref(location, reason);
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
void reset(Y* value = nullptr) {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
reset(static_cast<T*>(value));
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
void reset(const DebugLocation& location, const char* reason,
Y* value = nullptr) {
static_assert(std::has_virtual_destructor<T>::value,
@ -280,35 +317,41 @@ class WeakRefCountedPtr {
// us to pass a ref to idiomatic C code that does not use WeakRefCountedPtr<>.
// Once all of our code has been converted to idiomatic C++, this
// method should go away.
T* release() {
T* value = value_;
value_ = nullptr;
return value;
}
T* release() { return std::exchange(value_, nullptr); }
T* get() const { return value_; }
T& operator*() const { return *value_; }
T* operator->() const { return value_; }
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_base_of<T, Y>::value, bool> = true>
WeakRefCountedPtr<Y> TakeAsSubclass() {
return WeakRefCountedPtr<Y>(static_cast<Y*>(release()));
}
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator==(const WeakRefCountedPtr<Y>& other) const {
return value_ == other.value_;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator==(const Y* other) const {
return value_ == other;
}
bool operator==(std::nullptr_t) const { return value_ == nullptr; }
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator!=(const WeakRefCountedPtr<Y>& other) const {
return value_ != other.value_;
}
template <typename Y>
template <typename Y,
std::enable_if_t<std::is_convertible<Y*, T*>::value, bool> = true>
bool operator!=(const Y* other) const {
return value_ != other;
}

@ -152,7 +152,8 @@ grpc_plugin_credentials::GetRequestMetadata(
// Create pending_request object.
auto request = grpc_core::MakeRefCounted<PendingRequest>(
Ref(), std::move(initial_metadata), args);
RefAsSubclass<grpc_plugin_credentials>(), std::move(initial_metadata),
args);
// Invoke the plugin. The callback holds a ref to us.
if (GRPC_TRACE_FLAG_ENABLED(grpc_plugin_credentials_trace)) {
gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin",

@ -379,7 +379,8 @@ void TlsChannelSecurityConnector::check_peer(
grpc_ssl_peer_to_auth_context(&peer, GRPC_TLS_TRANSPORT_SECURITY_TYPE);
GPR_ASSERT(options_->certificate_verifier() != nullptr);
auto* pending_request = new ChannelPendingVerifierRequest(
Ref(), on_peer_checked, peer, target_name);
RefAsSubclass<TlsChannelSecurityConnector>(), on_peer_checked, peer,
target_name);
{
MutexLock lock(&verifier_request_map_mu_);
pending_verifier_requests_.emplace(on_peer_checked, pending_request);
@ -653,8 +654,8 @@ void TlsServerSecurityConnector::check_peer(
*auth_context =
grpc_ssl_peer_to_auth_context(&peer, GRPC_TLS_TRANSPORT_SECURITY_TYPE);
if (options_->certificate_verifier() != nullptr) {
auto* pending_request =
new ServerPendingVerifierRequest(Ref(), on_peer_checked, peer);
auto* pending_request = new ServerPendingVerifierRequest(
RefAsSubclass<TlsServerSecurityConnector>(), on_peer_checked, peer);
{
MutexLock lock(&verifier_request_map_mu_);
pending_verifier_requests_.emplace(on_peer_checked, pending_request);

@ -216,10 +216,8 @@ absl::StatusOr<ClientAuthFilter> ClientAuthFilter::Create(
return absl::InvalidArgumentError(
"Auth context missing from client auth filter args");
}
return ClientAuthFilter(
static_cast<grpc_channel_security_connector*>(sc)->Ref(),
auth_context->Ref());
return ClientAuthFilter(sc->RefAsSubclass<grpc_channel_security_connector>(),
auth_context->Ref());
}
const grpc_channel_filter ClientAuthFilter::kFilter =

@ -815,7 +815,8 @@ void Server::AddListener(OrphanablePtr<ListenerInterface> listener) {
channelz::ListenSocketNode* listen_socket_node =
listener->channelz_listen_socket_node();
if (listen_socket_node != nullptr && channelz_node_ != nullptr) {
channelz_node_->AddChildListenSocket(listen_socket_node->Ref());
channelz_node_->AddChildListenSocket(
listen_socket_node->RefAsSubclass<channelz::ListenSocketNode>());
}
listeners_.emplace_back(std::move(listener));
}

@ -91,8 +91,9 @@ class AsyncConnectivityStateWatcherInterface::Notifier {
void AsyncConnectivityStateWatcherInterface::Notify(
grpc_connectivity_state state, const absl::Status& status) {
new Notifier(Ref(), state, status,
work_serializer_); // Deletes itself when done.
// Deletes itself when done.
new Notifier(RefAsSubclass<AsyncConnectivityStateWatcherInterface>(), state,
status, work_serializer_);
}
//

@ -65,7 +65,7 @@ TEST(XdsOverrideHostConfigParsingTest, ValidConfig) {
ASSERT_NE(lb_config, nullptr);
ASSERT_EQ(lb_config->name(), XdsOverrideHostLbConfig::Name());
auto override_host_lb_config =
static_cast<RefCountedPtr<XdsOverrideHostLbConfig>>(lb_config);
lb_config.TakeAsSubclass<XdsOverrideHostLbConfig>();
EXPECT_EQ(override_host_lb_config->override_host_status_set(),
XdsHealthStatusSet({
XdsHealthStatus(XdsHealthStatus::HealthStatus::kDraining),
@ -100,7 +100,7 @@ TEST(XdsOverrideHostConfigParsingTest, ValidConfigWithRR) {
ASSERT_NE(lb_config, nullptr);
ASSERT_EQ(lb_config->name(), XdsOverrideHostLbConfig::Name());
auto override_host_lb_config =
static_cast<RefCountedPtr<XdsOverrideHostLbConfig>>(lb_config);
lb_config.TakeAsSubclass<XdsOverrideHostLbConfig>();
ASSERT_NE(override_host_lb_config->child_config(), nullptr);
ASSERT_EQ(override_host_lb_config->child_config()->name(), "round_robin");
}
@ -132,7 +132,7 @@ TEST(XdsOverrideHostConfigParsingTest, ValidConfigNoDraining) {
ASSERT_NE(lb_config, nullptr);
ASSERT_EQ(lb_config->name(), XdsOverrideHostLbConfig::Name());
auto override_host_lb_config =
static_cast<RefCountedPtr<XdsOverrideHostLbConfig>>(lb_config);
lb_config.TakeAsSubclass<XdsOverrideHostLbConfig>();
EXPECT_EQ(override_host_lb_config->override_host_status_set(),
XdsHealthStatusSet(
{XdsHealthStatus(XdsHealthStatus::HealthStatus::kHealthy),
@ -161,8 +161,9 @@ TEST(XdsOverrideHostConfigParsingTest, ValidConfigNoOverrideHostStatuses) {
ASSERT_NE(global_config, nullptr);
auto lb_config = global_config->parsed_lb_config();
ASSERT_NE(lb_config, nullptr);
ASSERT_EQ(lb_config->name(), XdsOverrideHostLbConfig::Name());
auto override_host_lb_config =
static_cast<RefCountedPtr<XdsOverrideHostLbConfig>>(lb_config);
lb_config.TakeAsSubclass<XdsOverrideHostLbConfig>();
EXPECT_EQ(override_host_lb_config->override_host_status_set(),
XdsHealthStatusSet(
{XdsHealthStatus(XdsHealthStatus::HealthStatus::kHealthy),

@ -71,6 +71,24 @@ TEST(DualRefCounted, RefIfNonZero) {
foo->WeakUnref();
}
TEST(DualRefCounted, RefAndWeakRefAsSubclass) {
class Bar : public Foo {};
Foo* foo = new Bar();
RefCountedPtr<Bar> barp = foo->RefAsSubclass<Bar>();
barp.release();
barp = foo->RefAsSubclass<Bar>(DEBUG_LOCATION, "test");
barp.release();
WeakRefCountedPtr<Bar> weak_barp = foo->WeakRefAsSubclass<Bar>();
weak_barp.release();
weak_barp = foo->WeakRefAsSubclass<Bar>(DEBUG_LOCATION, "test");
weak_barp.release();
foo->WeakUnref();
foo->WeakUnref();
foo->Unref();
foo->Unref();
foo->Unref();
}
class FooWithTracing : public DualRefCounted<FooWithTracing> {
public:
FooWithTracing() : DualRefCounted("FooWithTracing") {}

@ -78,6 +78,20 @@ TEST(OrphanablePtr, InternallyRefCounted) {
bar->FinishWork();
}
TEST(OrphanablePtr, InternallyRefCountedRefAsSubclass) {
class Subclass : public Bar {
public:
void StartWork() { self_ref_ = RefAsSubclass<Subclass>(); }
void FinishWork() { self_ref_.reset(); }
private:
RefCountedPtr<Subclass> self_ref_;
};
auto bar = MakeOrphanable<Subclass>();
bar->StartWork();
bar->FinishWork();
}
class Baz : public InternallyRefCounted<Baz> {
public:
Baz() : Baz(0) {}

@ -194,6 +194,9 @@ TEST(RefCountedPtr, RefCountedWithTracing) {
RefCountedPtr<FooWithTracing> foo(new FooWithTracing());
RefCountedPtr<FooWithTracing> foo2 = foo->Ref(DEBUG_LOCATION, "foo");
foo2.release();
RefCountedPtr<FooWithTracing> foo3 = foo.Ref(DEBUG_LOCATION, "foo");
foo3.release();
foo->Unref(DEBUG_LOCATION, "foo");
foo->Unref(DEBUG_LOCATION, "foo");
}
@ -240,24 +243,27 @@ TEST(RefCountedPtr, EqualityWithSubclass) {
EXPECT_EQ(b, s);
}
void FunctionTakingBaseClass(RefCountedPtr<BaseClass> p) {
p.reset(); // To appease clang-tidy.
}
void FunctionTakingBaseClass(RefCountedPtr<BaseClass>) {}
TEST(RefCountedPtr, CanPassSubclassToFunctionExpectingBaseClass) {
RefCountedPtr<Subclass> p = MakeRefCounted<Subclass>();
FunctionTakingBaseClass(p);
}
void FunctionTakingSubclass(RefCountedPtr<Subclass> p) {
p.reset(); // To appease clang-tidy.
}
void FunctionTakingSubclass(RefCountedPtr<Subclass>) {}
TEST(RefCountedPtr, CanPassSubclassToFunctionExpectingSubclass) {
RefCountedPtr<Subclass> p = MakeRefCounted<Subclass>();
FunctionTakingSubclass(p);
}
TEST(RefCountedPtr, TakeAsSubclass) {
RefCountedPtr<BaseClass> p = MakeRefCounted<Subclass>();
auto s = p.TakeAsSubclass<Subclass>();
EXPECT_EQ(p.get(), nullptr);
EXPECT_NE(s.get(), nullptr);
}
//
// WeakRefCountedPtr<> tests
//
@ -437,6 +443,9 @@ TEST(WeakRefCountedPtr, RefCountedWithTracing) {
WeakRefCountedPtr<BarWithTracing> bar = bar_strong->WeakRef();
WeakRefCountedPtr<BarWithTracing> bar2 = bar->WeakRef(DEBUG_LOCATION, "bar");
bar2.release();
WeakRefCountedPtr<BarWithTracing> bar3 = bar.WeakRef(DEBUG_LOCATION, "bar");
bar3.release();
bar->WeakUnref(DEBUG_LOCATION, "bar");
bar->WeakUnref(DEBUG_LOCATION, "bar");
}
@ -466,7 +475,7 @@ TEST(WeakRefCountedPtr, CopyAssignFromWeakSubclass) {
RefCountedPtr<WeakSubclass> strong(new WeakSubclass());
WeakRefCountedPtr<WeakBaseClass> b;
EXPECT_EQ(nullptr, b.get());
WeakRefCountedPtr<WeakSubclass> s = strong->WeakRef();
WeakRefCountedPtr<WeakSubclass> s = strong->WeakRefAsSubclass<WeakSubclass>();
b = s;
EXPECT_NE(nullptr, b.get());
}
@ -475,7 +484,7 @@ TEST(WeakRefCountedPtr, MoveAssignFromWeakSubclass) {
RefCountedPtr<WeakSubclass> strong(new WeakSubclass());
WeakRefCountedPtr<WeakBaseClass> b;
EXPECT_EQ(nullptr, b.get());
WeakRefCountedPtr<WeakSubclass> s = strong->WeakRef();
WeakRefCountedPtr<WeakSubclass> s = strong->WeakRefAsSubclass<WeakSubclass>();
b = std::move(s);
EXPECT_NE(nullptr, b.get());
}
@ -484,7 +493,7 @@ TEST(WeakRefCountedPtr, ResetFromWeakSubclass) {
RefCountedPtr<WeakSubclass> strong(new WeakSubclass());
WeakRefCountedPtr<WeakBaseClass> b;
EXPECT_EQ(nullptr, b.get());
b.reset(strong->WeakRef().release());
b.reset(strong->WeakRefAsSubclass<WeakSubclass>().release());
EXPECT_NE(nullptr, b.get());
}
@ -494,26 +503,30 @@ TEST(WeakRefCountedPtr, EqualityWithWeakSubclass) {
EXPECT_EQ(b, strong.get());
}
void FunctionTakingWeakBaseClass(WeakRefCountedPtr<WeakBaseClass> p) {
p.reset(); // To appease clang-tidy.
}
void FunctionTakingWeakBaseClass(WeakRefCountedPtr<WeakBaseClass>) {}
TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakBaseClass) {
RefCountedPtr<WeakSubclass> strong(new WeakSubclass());
WeakRefCountedPtr<WeakSubclass> p = strong->WeakRef();
WeakRefCountedPtr<WeakSubclass> p = strong->WeakRefAsSubclass<WeakSubclass>();
FunctionTakingWeakBaseClass(p);
}
void FunctionTakingWeakSubclass(WeakRefCountedPtr<WeakSubclass> p) {
p.reset(); // To appease clang-tidy.
}
void FunctionTakingWeakSubclass(WeakRefCountedPtr<WeakSubclass>) {}
TEST(WeakRefCountedPtr, CanPassWeakSubclassToFunctionExpectingWeakSubclass) {
RefCountedPtr<WeakSubclass> strong(new WeakSubclass());
WeakRefCountedPtr<WeakSubclass> p = strong->WeakRef();
WeakRefCountedPtr<WeakSubclass> p = strong->WeakRefAsSubclass<WeakSubclass>();
FunctionTakingWeakSubclass(p);
}
TEST(WeakRefCountedPtr, TakeAsSubclass) {
RefCountedPtr<WeakBaseClass> strong = MakeRefCounted<WeakSubclass>();
WeakRefCountedPtr<WeakBaseClass> p = strong->WeakRef();
WeakRefCountedPtr<WeakSubclass> s = p.TakeAsSubclass<WeakSubclass>();
EXPECT_EQ(p.get(), nullptr);
EXPECT_NE(s.get(), nullptr);
}
//
// tests for absl hash integration
//

@ -64,6 +64,18 @@ TEST(RefCounted, Const) {
foo->Unref();
}
TEST(RefCounted, SubclassOfRefCountedType) {
class Bar : public Foo {};
Bar* bar = new Bar();
RefCountedPtr<Bar> barp = bar->RefAsSubclass<Bar>();
barp.release();
barp = bar->RefAsSubclass<Bar>(DEBUG_LOCATION, "whee");
barp.release();
bar->Unref();
bar->Unref();
bar->Unref();
}
class Value : public RefCounted<Value, PolymorphicRefCount, UnrefNoDelete> {
public:
Value(int value, std::set<std::unique_ptr<Value>>* registry) : value_(value) {

@ -44,7 +44,7 @@ class TestServerConfigSelectorProvider : public ServerConfigSelectorProvider {
// Test that ServerConfigSelectorProvider can be safely copied to channel args
// and destroyed
TEST(ServerConfigSelectorProviderTest, CopyChannelArgs) {
auto server_config_selector_provider =
RefCountedPtr<ServerConfigSelectorProvider> server_config_selector_provider =
MakeRefCounted<TestServerConfigSelectorProvider>();
auto args = ChannelArgs().SetObject(server_config_selector_provider);
EXPECT_EQ(server_config_selector_provider,

@ -527,7 +527,9 @@ class OobBackendMetricTestLoadBalancingPolicy
subchannel->AddDataWatcher(MakeOobBackendMetricWatcher(
Duration::Seconds(1),
std::make_unique<BackendMetricWatcher>(
EndpointAddresses(address, per_address_args), parent()->Ref())));
EndpointAddresses(address, per_address_args),
parent()
->RefAsSubclass<OobBackendMetricTestLoadBalancingPolicy>())));
return subchannel;
}
};

@ -51,7 +51,7 @@ ParseConfig(absl::string_view json_string) {
return errors.status(absl::StatusCode::kInvalidArgument,
"validation errors");
}
return std::move(config);
return config.TakeAsSubclass<FileWatcherCertificateProviderFactory::Config>();
}
TEST(FileWatcherConfigTest, Basic) {

@ -664,10 +664,9 @@ TEST(XdsBootstrapTest, CertificateProvidersFakePluginParsingSuccess) {
bootstrap->certificate_providers().at("fake_plugin");
ASSERT_EQ(fake_plugin.plugin_name, "fake");
ASSERT_EQ(fake_plugin.config->name(), "fake");
ASSERT_EQ(static_cast<RefCountedPtr<FakeCertificateProviderFactory::Config>>(
fake_plugin.config)
->value(),
10);
auto* config = static_cast<FakeCertificateProviderFactory::Config*>(
fake_plugin.config.get());
ASSERT_EQ(config->value(), 10);
}
TEST(XdsBootstrapTest, CertificateProvidersFakePluginEmptyConfig) {
@ -692,10 +691,9 @@ TEST(XdsBootstrapTest, CertificateProvidersFakePluginEmptyConfig) {
bootstrap->certificate_providers().at("fake_plugin");
ASSERT_EQ(fake_plugin.plugin_name, "fake");
ASSERT_EQ(fake_plugin.config->name(), "fake");
ASSERT_EQ(static_cast<RefCountedPtr<FakeCertificateProviderFactory::Config>>(
fake_plugin.config)
->value(),
0);
auto* config = static_cast<FakeCertificateProviderFactory::Config*>(
fake_plugin.config.get());
ASSERT_EQ(config->value(), 0);
}
TEST(XdsBootstrapTest, XdsServerToJsonAndParse) {

@ -578,7 +578,8 @@ class XdsClientTest : public ::testing::Test {
FakeXdsBootstrap::Builder bootstrap_builder = FakeXdsBootstrap::Builder(),
Duration resource_request_timeout = Duration::Seconds(15)) {
auto transport_factory = MakeOrphanable<FakeXdsTransportFactory>();
transport_factory_ = transport_factory->Ref();
transport_factory_ =
transport_factory->Ref().TakeAsSubclass<FakeXdsTransportFactory>();
xds_client_ = MakeRefCounted<XdsClient>(
bootstrap_builder.Build(), std::move(transport_factory),
grpc_event_engine::experimental::GetDefaultEventEngine(), "foo agent",

@ -216,10 +216,10 @@ OrphanablePtr<XdsTransportFactory::XdsTransport::StreamingCall>
FakeXdsTransportFactory::FakeXdsTransport::CreateStreamingCall(
const char* method,
std::unique_ptr<StreamingCall::EventHandler> event_handler) {
auto call = MakeOrphanable<FakeStreamingCall>(Ref(), method,
std::move(event_handler));
auto call = MakeOrphanable<FakeStreamingCall>(
RefAsSubclass<FakeXdsTransport>(), method, std::move(event_handler));
MutexLock lock(&mu_);
active_calls_[method] = call->Ref();
active_calls_[method] = call->Ref().TakeAsSubclass<FakeStreamingCall>();
cv_.Signal();
return call;
}
@ -240,9 +240,10 @@ FakeXdsTransportFactory::Create(
auto& entry = transport_map_[&server];
GPR_ASSERT(entry == nullptr);
auto transport = MakeOrphanable<FakeXdsTransport>(
Ref(), server, std::move(on_connectivity_failure),
auto_complete_messages_from_client_, abort_on_undrained_messages_);
entry = transport->Ref();
RefAsSubclass<FakeXdsTransportFactory>(), server,
std::move(on_connectivity_failure), auto_complete_messages_from_client_,
abort_on_undrained_messages_);
entry = transport->Ref().TakeAsSubclass<FakeXdsTransport>();
return transport;
}

@ -84,7 +84,7 @@ class RpcBehaviorLbPolicy : public LoadBalancingPolicy {
absl::string_view name() const override { return kRpcBehaviorLbPolicyName; }
absl::Status UpdateLocked(UpdateArgs args) override {
RefCountedPtr<RpcBehaviorLbPolicyConfig> config = std::move(args.config);
auto config = args.config.TakeAsSubclass<RpcBehaviorLbPolicyConfig>();
rpc_behavior_ = std::string(config->rpc_behavior());
// Use correct config for the delegate load balancing policy
auto delegate_config =

Loading…
Cancel
Save