diff --git a/test/core/filters/filter_test.cc b/test/core/filters/filter_test.cc index f6715839f59..8b18254961d 100644 --- a/test/core/filters/filter_test.cc +++ b/test/core/filters/filter_test.cc @@ -24,6 +24,7 @@ #include "absl/types/optional.h" #include "gtest/gtest.h" +#include "src/core/lib/channel/call_finalization.h" #include "src/core/lib/channel/context.h" #include "src/core/lib/gprpp/crash.h" #include "src/core/lib/iomgr/timer_manager.h" @@ -52,6 +53,7 @@ class FilterTestBase::Call::Impl Arena* arena() { return arena_.get(); } grpc_call_context_element* legacy_context() { return legacy_context_; } const std::shared_ptr& channel() const { return channel_; } + CallFinalization* call_finalization() { return &call_finalization_; } void Start(ClientMetadataHandle md); void ForwardServerInitialMetadata(ServerMetadataHandle md); @@ -76,6 +78,8 @@ class FilterTestBase::Call::Impl std::shared_ptr const channel_; ScopedArenaPtr arena_{MakeScopedArena(channel_->initial_arena_size, &channel_->memory_allocator)}; + bool run_call_finalization_ = false; + CallFinalization call_finalization_; absl::optional> promise_; Poll poll_next_filter_result_; Pipe pipe_server_initial_metadata_{arena_.get()}; @@ -104,6 +108,9 @@ class FilterTestBase::Call::Impl }; FilterTestBase::Call::Impl::~Impl() { + if (!run_call_finalization_) { + call_finalization_.Run(nullptr); + } for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { if (legacy_context_[i].destroy != nullptr) { legacy_context_[i].destroy(legacy_context_[i].value); @@ -264,7 +271,8 @@ bool FilterTestBase::Call::Impl::StepOnce() { class FilterTestBase::Call::ScopedContext final : public Activity, public promise_detail::Context, - public promise_detail::Context { + public promise_detail::Context, + public promise_detail::Context { private: class TestWakeable final : public Wakeable { public: @@ -293,6 +301,7 @@ class FilterTestBase::Call::ScopedContext final : promise_detail::Context(impl->arena()), promise_detail::Context( impl->legacy_context()), + promise_detail::Context(impl->call_finalization()), impl_(std::move(impl)) {} void Orphan() override { Crash("Orphan called on Call::ScopedContext"); } @@ -333,6 +342,8 @@ FilterTestBase::Call::Call(const Channel& channel) FilterTestBase::Call::~Call() { ScopedContext x(std::move(impl_)); } +Arena* FilterTestBase::Call::arena() { return impl_->arena(); } + ClientMetadataHandle FilterTestBase::Call::NewClientMetadata( std::initializer_list> init) { diff --git a/test/core/filters/filter_test.h b/test/core/filters/filter_test.h index af73a51fc6a..74acd71b5fa 100644 --- a/test/core/filters/filter_test.h +++ b/test/core/filters/filter_test.h @@ -39,6 +39,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/promise_based_filter.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/resource_quota/arena.h" #include "src/core/lib/resource_quota/memory_quota.h" #include "src/core/lib/resource_quota/resource_quota.h" #include "src/core/lib/slice/slice_buffer.h" @@ -54,6 +55,12 @@ MATCHER_P2(HasMetadataKeyValue, key, value, "") { return r == value; } +// gmock matcher to ensure that metadata does not include a key/value pair. +MATCHER_P(LacksMetadataKey, key, "") { + std::string temp; + return !arg.GetStringValue(key, &temp).has_value(); +} + // gmock matcher to ensure that a message has a given set of flags. MATCHER_P(HasMessageFlags, value, "") { return arg.flags() == value; } @@ -109,14 +116,17 @@ class FilterTestBase : public ::testing::Test { Call MakeCall(); - private: - friend class FilterTestBase; - friend class Call; - + protected: explicit Channel(std::unique_ptr filter, FilterTestBase* test) : impl_(std::make_shared(std::move(filter), test)) {} + ChannelFilter* filter_ptr() { return impl_->filter.get(); } + + private: + friend class FilterTestBase; + friend class Call; + std::shared_ptr impl_; }; @@ -160,6 +170,8 @@ class FilterTestBase : public ::testing::Test { // metadata. void FinishNextFilter(ServerMetadataHandle md); + Arena* arena(); + private: friend class Channel; class ScopedContext; @@ -192,9 +204,6 @@ class FilterTestBase : public ::testing::Test { protected: FilterTestBase(); ~FilterTestBase() override; - absl::StatusOr MakeChannel(std::unique_ptr filter) { - return Channel(std::move(filter), this); - } grpc_event_engine::experimental::EventEngine* event_engine() { return &event_engine_; @@ -209,11 +218,19 @@ class FilterTestBase : public ::testing::Test { template class FilterTest : public FilterTestBase { public: + class Channel : public FilterTestBase::Channel { + public: + Filter* filter() { return static_cast(filter_ptr()); } + + private: + friend class FilterTest; + using FilterTestBase::Channel::Channel; + }; + absl::StatusOr MakeChannel(const ChannelArgs& args) { auto filter = Filter::Create(args, ChannelFilter::Args()); if (!filter.ok()) return filter.status(); - return FilterTestBase::MakeChannel( - std::make_unique(std::move(*filter))); + return Channel(std::make_unique(std::move(*filter)), this); } };