mirror of https://github.com/grpc/grpc.git
[promises] Filter unit test framework (#32110)
Built atop #31448 Offers a simple framework for testing filters. <!-- If you know who should review your pull request, please assign it to that person, otherwise the pull request would get assigned randomly. If your pull request is for a specific language, please add the appropriate lang label. --> --------- Co-authored-by: ctiller <ctiller@users.noreply.github.com>pull/32604/head
parent
db3daf567b
commit
2cd1501ca5
13 changed files with 1189 additions and 197 deletions
@ -0,0 +1,425 @@ |
||||
// Copyright 2023 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "test/core/filters/filter_test.h" |
||||
|
||||
#include <algorithm> |
||||
#include <chrono> |
||||
#include <memory> |
||||
#include <queue> |
||||
|
||||
#include "absl/memory/memory.h" |
||||
#include "absl/strings/str_cat.h" |
||||
#include "absl/strings/str_format.h" |
||||
#include "absl/types/optional.h" |
||||
#include "gtest/gtest.h" |
||||
|
||||
#include "src/core/lib/channel/context.h" |
||||
#include "src/core/lib/gprpp/crash.h" |
||||
#include "src/core/lib/iomgr/timer_manager.h" |
||||
#include "src/core/lib/promise/activity.h" |
||||
#include "src/core/lib/promise/arena_promise.h" |
||||
#include "src/core/lib/promise/context.h" |
||||
#include "src/core/lib/promise/detail/basic_seq.h" |
||||
#include "src/core/lib/promise/pipe.h" |
||||
#include "src/core/lib/promise/poll.h" |
||||
#include "src/core/lib/resource_quota/arena.h" |
||||
#include "src/core/lib/slice/slice.h" |
||||
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" |
||||
|
||||
namespace grpc_core { |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FilterTestBase::Call::Impl
|
||||
|
||||
class FilterTestBase::Call::Impl |
||||
: public std::enable_shared_from_this<FilterTestBase::Call::Impl> { |
||||
public: |
||||
Impl(Call* call, std::shared_ptr<Channel::Impl> channel) |
||||
: call_(call), channel_(std::move(channel)) {} |
||||
~Impl(); |
||||
|
||||
Arena* arena() { return arena_.get(); } |
||||
grpc_call_context_element* legacy_context() { return legacy_context_; } |
||||
const std::shared_ptr<Channel::Impl>& channel() const { return channel_; } |
||||
|
||||
void Start(ClientMetadataHandle md); |
||||
void ForwardServerInitialMetadata(ServerMetadataHandle md); |
||||
void ForwardMessageClientToServer(MessageHandle msg); |
||||
void ForwardMessageServerToClient(MessageHandle msg); |
||||
void FinishNextFilter(ServerMetadataHandle md); |
||||
|
||||
void StepLoop(); |
||||
|
||||
grpc_event_engine::experimental::EventEngine* event_engine() { |
||||
return channel_->test->event_engine(); |
||||
} |
||||
|
||||
Events& events() { return channel_->test->events; } |
||||
|
||||
private: |
||||
bool StepOnce(); |
||||
Poll<ServerMetadataHandle> PollNextFilter(); |
||||
void ForceWakeup(); |
||||
|
||||
Call* const call_; |
||||
std::shared_ptr<Channel::Impl> const channel_; |
||||
ScopedArenaPtr arena_{MakeScopedArena(channel_->initial_arena_size, |
||||
&channel_->memory_allocator)}; |
||||
absl::optional<ArenaPromise<ServerMetadataHandle>> promise_; |
||||
Poll<ServerMetadataHandle> poll_next_filter_result_; |
||||
Pipe<ServerMetadataHandle> pipe_server_initial_metadata_{arena_.get()}; |
||||
Pipe<MessageHandle> pipe_server_to_client_messages_{arena_.get()}; |
||||
Pipe<MessageHandle> pipe_client_to_server_messages_{arena_.get()}; |
||||
PipeSender<ServerMetadataHandle>* server_initial_metadata_sender_ = nullptr; |
||||
PipeSender<MessageHandle>* server_to_client_messages_sender_ = nullptr; |
||||
PipeReceiver<MessageHandle>* client_to_server_messages_receiver_ = nullptr; |
||||
absl::optional<PipeSender<ServerMetadataHandle>::PushType> |
||||
push_server_initial_metadata_; |
||||
absl::optional<PipeReceiverNextType<ServerMetadataHandle>> |
||||
next_server_initial_metadata_; |
||||
absl::optional<PipeSender<MessageHandle>::PushType> |
||||
push_server_to_client_messages_; |
||||
absl::optional<PipeReceiverNextType<MessageHandle>> |
||||
next_server_to_client_messages_; |
||||
absl::optional<PipeSender<MessageHandle>::PushType> |
||||
push_client_to_server_messages_; |
||||
absl::optional<PipeReceiverNextType<MessageHandle>> |
||||
next_client_to_server_messages_; |
||||
absl::optional<ServerMetadataHandle> forward_server_initial_metadata_; |
||||
std::queue<MessageHandle> forward_client_to_server_messages_; |
||||
std::queue<MessageHandle> forward_server_to_client_messages_; |
||||
// Contexts for various subsystems (security, tracing, ...).
|
||||
grpc_call_context_element legacy_context_[GRPC_CONTEXT_COUNT] = {}; |
||||
}; |
||||
|
||||
FilterTestBase::Call::Impl::~Impl() { |
||||
for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { |
||||
if (legacy_context_[i].destroy != nullptr) { |
||||
legacy_context_[i].destroy(legacy_context_[i].value); |
||||
} |
||||
} |
||||
} |
||||
|
||||
void FilterTestBase::Call::Impl::Start(ClientMetadataHandle md) { |
||||
EXPECT_EQ(promise_, absl::nullopt); |
||||
promise_ = channel_->filter->MakeCallPromise( |
||||
CallArgs{std::move(md), ClientInitialMetadataOutstandingToken::Empty(), |
||||
&pipe_server_initial_metadata_.sender, |
||||
&pipe_client_to_server_messages_.receiver, |
||||
&pipe_server_to_client_messages_.sender}, |
||||
[this](CallArgs args) -> ArenaPromise<ServerMetadataHandle> { |
||||
server_initial_metadata_sender_ = args.server_initial_metadata; |
||||
client_to_server_messages_receiver_ = args.client_to_server_messages; |
||||
server_to_client_messages_sender_ = args.server_to_client_messages; |
||||
next_server_initial_metadata_.emplace( |
||||
pipe_server_initial_metadata_.receiver.Next()); |
||||
events().Started(call_, *args.client_initial_metadata); |
||||
return [this]() { return PollNextFilter(); }; |
||||
}); |
||||
EXPECT_NE(promise_, absl::nullopt); |
||||
ForceWakeup(); |
||||
} |
||||
|
||||
Poll<ServerMetadataHandle> FilterTestBase::Call::Impl::PollNextFilter() { |
||||
return std::exchange(poll_next_filter_result_, Pending()); |
||||
} |
||||
|
||||
void FilterTestBase::Call::Impl::ForwardServerInitialMetadata( |
||||
ServerMetadataHandle md) { |
||||
EXPECT_FALSE(forward_server_initial_metadata_.has_value()); |
||||
forward_server_initial_metadata_ = std::move(md); |
||||
ForceWakeup(); |
||||
} |
||||
|
||||
void FilterTestBase::Call::Impl::ForwardMessageClientToServer( |
||||
MessageHandle msg) { |
||||
forward_client_to_server_messages_.push(std::move(msg)); |
||||
ForceWakeup(); |
||||
} |
||||
|
||||
void FilterTestBase::Call::Impl::ForwardMessageServerToClient( |
||||
MessageHandle msg) { |
||||
forward_server_to_client_messages_.push(std::move(msg)); |
||||
ForceWakeup(); |
||||
} |
||||
|
||||
void FilterTestBase::Call::Impl::FinishNextFilter(ServerMetadataHandle md) { |
||||
poll_next_filter_result_ = std::move(md); |
||||
ForceWakeup(); |
||||
} |
||||
|
||||
bool FilterTestBase::Call::Impl::StepOnce() { |
||||
if (!promise_.has_value()) return true; |
||||
|
||||
if (forward_server_initial_metadata_.has_value() && |
||||
!push_server_initial_metadata_.has_value()) { |
||||
push_server_initial_metadata_.emplace(server_initial_metadata_sender_->Push( |
||||
std::move(*forward_server_initial_metadata_))); |
||||
forward_server_initial_metadata_.reset(); |
||||
} |
||||
|
||||
if (push_server_initial_metadata_.has_value()) { |
||||
auto r = (*push_server_initial_metadata_)(); |
||||
if (r.ready()) push_server_initial_metadata_.reset(); |
||||
} |
||||
|
||||
if (next_server_initial_metadata_.has_value()) { |
||||
auto r = (*next_server_initial_metadata_)(); |
||||
if (auto* p = r.value_if_ready()) { |
||||
if (p->has_value()) { |
||||
events().ForwardedServerInitialMetadata(call_, *p->value()); |
||||
} |
||||
next_server_initial_metadata_.reset(); |
||||
} |
||||
} |
||||
|
||||
if (server_initial_metadata_sender_ != nullptr && |
||||
!next_server_initial_metadata_.has_value()) { |
||||
// We've finished sending server initial metadata, so we can
|
||||
// process server-to-client messages.
|
||||
if (!next_server_to_client_messages_.has_value()) { |
||||
next_server_to_client_messages_.emplace( |
||||
pipe_server_to_client_messages_.receiver.Next()); |
||||
} |
||||
|
||||
if (push_server_to_client_messages_.has_value()) { |
||||
auto r = (*push_server_to_client_messages_)(); |
||||
if (r.ready()) push_server_to_client_messages_.reset(); |
||||
} |
||||
|
||||
{ |
||||
auto r = (*next_server_to_client_messages_)(); |
||||
if (auto* p = r.value_if_ready()) { |
||||
if (p->has_value()) { |
||||
events().ForwardedMessageServerToClient(call_, *p->value()); |
||||
} |
||||
next_server_to_client_messages_.reset(); |
||||
Activity::current()->ForceImmediateRepoll(); |
||||
} |
||||
} |
||||
|
||||
if (!push_server_to_client_messages_.has_value() && |
||||
!forward_server_to_client_messages_.empty()) { |
||||
push_server_to_client_messages_.emplace( |
||||
server_to_client_messages_sender_->Push( |
||||
std::move(forward_server_to_client_messages_.front()))); |
||||
forward_server_to_client_messages_.pop(); |
||||
Activity::current()->ForceImmediateRepoll(); |
||||
} |
||||
} |
||||
|
||||
if (client_to_server_messages_receiver_ != nullptr) { |
||||
if (!next_client_to_server_messages_.has_value()) { |
||||
next_client_to_server_messages_.emplace( |
||||
client_to_server_messages_receiver_->Next()); |
||||
} |
||||
|
||||
if (push_client_to_server_messages_.has_value()) { |
||||
auto r = (*push_client_to_server_messages_)(); |
||||
if (r.ready()) push_client_to_server_messages_.reset(); |
||||
} |
||||
|
||||
{ |
||||
auto r = (*next_client_to_server_messages_)(); |
||||
if (auto* p = r.value_if_ready()) { |
||||
if (p->has_value()) { |
||||
events().ForwardedMessageClientToServer(call_, *p->value()); |
||||
} |
||||
next_client_to_server_messages_.reset(); |
||||
Activity::current()->ForceImmediateRepoll(); |
||||
} |
||||
} |
||||
|
||||
if (!push_client_to_server_messages_.has_value() && |
||||
!forward_client_to_server_messages_.empty()) { |
||||
push_client_to_server_messages_.emplace( |
||||
pipe_client_to_server_messages_.sender.Push( |
||||
std::move(forward_client_to_server_messages_.front()))); |
||||
forward_client_to_server_messages_.pop(); |
||||
Activity::current()->ForceImmediateRepoll(); |
||||
} |
||||
} |
||||
|
||||
auto r = (*promise_)(); |
||||
if (r.pending()) return false; |
||||
promise_.reset(); |
||||
events().Finished(call_, *r.value()); |
||||
return true; |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FilterTestBase::Call::ScopedContext
|
||||
|
||||
class FilterTestBase::Call::ScopedContext final |
||||
: public Activity, |
||||
public promise_detail::Context<Arena>, |
||||
public promise_detail::Context<grpc_call_context_element> { |
||||
private: |
||||
class TestWakeable final : public Wakeable { |
||||
public: |
||||
explicit TestWakeable(ScopedContext* ctx) |
||||
: tag_(ctx->DebugTag()), impl_(ctx->impl_) {} |
||||
void Wakeup(WakeupMask) override { |
||||
std::unique_ptr<TestWakeable> self(this); |
||||
auto impl = impl_.lock(); |
||||
if (impl == nullptr) return; |
||||
impl->event_engine()->Run([weak_impl = impl_]() { |
||||
auto impl = weak_impl.lock(); |
||||
if (impl != nullptr) impl->StepLoop(); |
||||
}); |
||||
} |
||||
void Drop(WakeupMask) override { delete this; } |
||||
std::string ActivityDebugTag(WakeupMask) const override { return tag_; } |
||||
|
||||
private: |
||||
const std::string tag_; |
||||
const std::weak_ptr<Impl> impl_; |
||||
}; |
||||
|
||||
public: |
||||
explicit ScopedContext(std::shared_ptr<Impl> impl) |
||||
: promise_detail::Context<Arena>(impl->arena()), |
||||
promise_detail::Context<grpc_call_context_element>( |
||||
impl->legacy_context()), |
||||
impl_(std::move(impl)) {} |
||||
|
||||
void Orphan() override { Crash("Orphan called on Call::ScopedContext"); } |
||||
void ForceImmediateRepoll(WakeupMask) override { repoll_ = true; } |
||||
Waker MakeOwningWaker() override { return Waker(new TestWakeable(this), 0); } |
||||
Waker MakeNonOwningWaker() override { |
||||
return Waker(new TestWakeable(this), 0); |
||||
} |
||||
std::string DebugTag() const override { |
||||
return absl::StrFormat("FILTER_TEST_CALL[%p]", impl_.get()); |
||||
} |
||||
|
||||
bool repoll() const { return repoll_; } |
||||
|
||||
private: |
||||
ScopedActivity scoped_activity_{this}; |
||||
const std::shared_ptr<Impl> impl_; |
||||
bool repoll_ = false; |
||||
}; |
||||
|
||||
void FilterTestBase::Call::Impl::StepLoop() { |
||||
for (;;) { |
||||
ScopedContext ctx(shared_from_this()); |
||||
if (!StepOnce() && ctx.repoll()) continue; |
||||
return; |
||||
} |
||||
} |
||||
|
||||
void FilterTestBase::Call::Impl::ForceWakeup() { |
||||
ScopedContext(shared_from_this()).MakeOwningWaker().Wakeup(); |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FilterTestBase::Call
|
||||
|
||||
FilterTestBase::Call::Call(const Channel& channel) |
||||
: impl_(std::make_unique<Impl>(this, channel.impl_)) {} |
||||
|
||||
FilterTestBase::Call::~Call() { ScopedContext x(std::move(impl_)); } |
||||
|
||||
ClientMetadataHandle FilterTestBase::Call::NewClientMetadata( |
||||
std::initializer_list<std::pair<absl::string_view, absl::string_view>> |
||||
init) { |
||||
auto md = impl_->arena()->MakePooled<ClientMetadata>(impl_->arena()); |
||||
for (auto& p : init) { |
||||
auto parsed = ClientMetadata::Parse( |
||||
p.first, Slice::FromCopiedString(p.second), |
||||
p.first.length() + p.second.length() + 32, |
||||
[p](absl::string_view, const Slice&) { |
||||
Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ", |
||||
p.second)); |
||||
}); |
||||
md->Set(parsed); |
||||
} |
||||
return md; |
||||
} |
||||
|
||||
ServerMetadataHandle FilterTestBase::Call::NewServerMetadata( |
||||
std::initializer_list<std::pair<absl::string_view, absl::string_view>> |
||||
init) { |
||||
auto md = impl_->arena()->MakePooled<ClientMetadata>(impl_->arena()); |
||||
for (auto& p : init) { |
||||
auto parsed = ServerMetadata::Parse( |
||||
p.first, Slice::FromCopiedString(p.second), |
||||
p.first.length() + p.second.length() + 32, |
||||
[p](absl::string_view, const Slice&) { |
||||
Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ", |
||||
p.second)); |
||||
}); |
||||
md->Set(parsed); |
||||
} |
||||
return md; |
||||
} |
||||
|
||||
MessageHandle FilterTestBase::Call::NewMessage(absl::string_view payload, |
||||
uint32_t flags) { |
||||
SliceBuffer buffer; |
||||
if (!payload.empty()) buffer.Append(Slice::FromCopiedString(payload)); |
||||
return impl_->arena()->MakePooled<Message>(std::move(buffer), flags); |
||||
} |
||||
|
||||
void FilterTestBase::Call::Start(ClientMetadataHandle md) { |
||||
ScopedContext ctx(impl_); |
||||
impl_->Start(std::move(md)); |
||||
} |
||||
|
||||
void FilterTestBase::Call::Cancel() { |
||||
ScopedContext ctx(impl_); |
||||
impl_ = absl::make_unique<Impl>(this, impl_->channel()); |
||||
} |
||||
|
||||
void FilterTestBase::Call::ForwardServerInitialMetadata( |
||||
ServerMetadataHandle md) { |
||||
impl_->ForwardServerInitialMetadata(std::move(md)); |
||||
} |
||||
|
||||
void FilterTestBase::Call::ForwardMessageClientToServer(MessageHandle msg) { |
||||
impl_->ForwardMessageClientToServer(std::move(msg)); |
||||
} |
||||
|
||||
void FilterTestBase::Call::ForwardMessageServerToClient(MessageHandle msg) { |
||||
impl_->ForwardMessageServerToClient(std::move(msg)); |
||||
} |
||||
|
||||
void FilterTestBase::Call::FinishNextFilter(ServerMetadataHandle md) { |
||||
impl_->FinishNextFilter(std::move(md)); |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FilterTestBase
|
||||
|
||||
FilterTestBase::FilterTestBase() |
||||
: event_engine_( |
||||
[]() { |
||||
grpc_timer_manager_set_threading(false); |
||||
grpc_event_engine::experimental::FuzzingEventEngine::Options |
||||
options; |
||||
options.final_tick_length = std::chrono::milliseconds(1); |
||||
return options; |
||||
}(), |
||||
fuzzing_event_engine::Actions()) {} |
||||
|
||||
FilterTestBase::~FilterTestBase() { event_engine_.UnsetGlobalHooks(); } |
||||
|
||||
void FilterTestBase::Step() { |
||||
event_engine_.TickUntilIdle(); |
||||
::testing::Mock::VerifyAndClearExpectations(&events); |
||||
} |
||||
|
||||
} // namespace grpc_core
|
@ -0,0 +1,225 @@ |
||||
// Copyright 2023 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_TEST_CORE_FILTERS_FILTER_TEST_H |
||||
#define GRPC_TEST_CORE_FILTERS_FILTER_TEST_H |
||||
|
||||
#include <stddef.h> |
||||
#include <stdint.h> |
||||
|
||||
#include <initializer_list> |
||||
#include <iosfwd> |
||||
#include <memory> |
||||
#include <ostream> |
||||
#include <string> |
||||
#include <utility> |
||||
|
||||
#include <gtest/gtest.h> |
||||
|
||||
#include "absl/status/status.h" |
||||
#include "absl/status/statusor.h" |
||||
#include "absl/strings/escaping.h" |
||||
#include "absl/strings/string_view.h" |
||||
#include "gmock/gmock.h" |
||||
|
||||
#include <grpc/event_engine/event_engine.h> |
||||
#include <grpc/event_engine/memory_allocator.h> |
||||
|
||||
#include "src/core/lib/channel/channel_args.h" |
||||
#include "src/core/lib/channel/promise_based_filter.h" |
||||
#include "src/core/lib/gprpp/ref_counted_ptr.h" |
||||
#include "src/core/lib/resource_quota/memory_quota.h" |
||||
#include "src/core/lib/resource_quota/resource_quota.h" |
||||
#include "src/core/lib/slice/slice_buffer.h" |
||||
#include "src/core/lib/transport/metadata_batch.h" |
||||
#include "src/core/lib/transport/transport.h" |
||||
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" |
||||
#include "test/core/filters/filter_test.h" |
||||
|
||||
// gmock matcher to ensure that metadata has a key/value pair.
|
||||
MATCHER_P2(HasMetadataKeyValue, key, value, "") { |
||||
std::string temp; |
||||
auto r = arg.GetStringValue(key, &temp); |
||||
return r == value; |
||||
} |
||||
|
||||
// gmock matcher to ensure that a message has a given set of flags.
|
||||
MATCHER_P(HasMessageFlags, value, "") { return arg.flags() == value; } |
||||
|
||||
MATCHER_P(HasMetadataResult, absl_status, "") { |
||||
auto status = arg.get(grpc_core::GrpcStatusMetadata()); |
||||
if (!status.has_value()) return false; |
||||
if (static_cast<absl::StatusCode>(status.value()) != absl_status.code()) { |
||||
return false; |
||||
} |
||||
auto* message = arg.get_pointer(grpc_core::GrpcMessageMetadata()); |
||||
if (message == nullptr) return absl_status.message().empty(); |
||||
return message->as_string_view() == absl_status.message(); |
||||
} |
||||
|
||||
// gmock matcher to ensure that a message has a given payload.
|
||||
MATCHER_P(HasMessagePayload, value, "") { |
||||
return arg.payload()->JoinIntoString() == value; |
||||
} |
||||
|
||||
namespace grpc_core { |
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, |
||||
const grpc_metadata_batch& md) { |
||||
return os << md.DebugString(); |
||||
} |
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const Message& msg) { |
||||
return os << "flags:" << msg.flags() |
||||
<< " payload:" << absl::CEscape(msg.payload()->JoinIntoString()); |
||||
} |
||||
|
||||
class FilterTestBase : public ::testing::Test { |
||||
public: |
||||
class Call; |
||||
|
||||
class Channel { |
||||
private: |
||||
struct Impl { |
||||
Impl(std::unique_ptr<ChannelFilter> filter, FilterTestBase* test) |
||||
: filter(std::move(filter)), test(test) {} |
||||
size_t initial_arena_size = 1024; |
||||
MemoryAllocator memory_allocator = |
||||
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( |
||||
"test"); |
||||
std::unique_ptr<ChannelFilter> filter; |
||||
FilterTestBase* const test; |
||||
}; |
||||
|
||||
public: |
||||
void set_initial_arena_size(size_t size) { |
||||
impl_->initial_arena_size = size; |
||||
} |
||||
|
||||
Call MakeCall(); |
||||
|
||||
private: |
||||
friend class FilterTestBase; |
||||
friend class Call; |
||||
|
||||
explicit Channel(std::unique_ptr<ChannelFilter> filter, |
||||
FilterTestBase* test) |
||||
: impl_(std::make_shared<Impl>(std::move(filter), test)) {} |
||||
|
||||
std::shared_ptr<Impl> impl_; |
||||
}; |
||||
|
||||
// One "call" outstanding against this filter.
|
||||
// In reality - this filter is the only thing in the call.
|
||||
// Provides mocks to trap events that happen on the call.
|
||||
class Call { |
||||
public: |
||||
explicit Call(const Channel& channel); |
||||
|
||||
Call(const Call&) = delete; |
||||
Call& operator=(const Call&) = delete; |
||||
|
||||
~Call(); |
||||
|
||||
// Construct client metadata in the arena of this call.
|
||||
// Optional argument is a list of key/value pairs to add to the metadata.
|
||||
ClientMetadataHandle NewClientMetadata( |
||||
std::initializer_list<std::pair<absl::string_view, absl::string_view>> |
||||
init = {}); |
||||
// Construct server metadata in the arena of this call.
|
||||
// Optional argument is a list of key/value pairs to add to the metadata.
|
||||
ServerMetadataHandle NewServerMetadata( |
||||
std::initializer_list<std::pair<absl::string_view, absl::string_view>> |
||||
init = {}); |
||||
// Construct a message in the arena of this call.
|
||||
MessageHandle NewMessage(absl::string_view payload = "", |
||||
uint32_t flags = 0); |
||||
|
||||
// Start the call.
|
||||
void Start(ClientMetadataHandle md); |
||||
// Cancel the call.
|
||||
void Cancel(); |
||||
// Forward server initial metadata through this filter.
|
||||
void ForwardServerInitialMetadata(ServerMetadataHandle md); |
||||
// Forward a message from client to server through this filter.
|
||||
void ForwardMessageClientToServer(MessageHandle msg); |
||||
// Forward a message from server to client through this filter.
|
||||
void ForwardMessageServerToClient(MessageHandle msg); |
||||
// Have the 'next' filter in the chain finish this call and return trailing
|
||||
// metadata.
|
||||
void FinishNextFilter(ServerMetadataHandle md); |
||||
|
||||
private: |
||||
friend class Channel; |
||||
class ScopedContext; |
||||
class Impl; |
||||
|
||||
std::shared_ptr<Impl> impl_; |
||||
}; |
||||
|
||||
struct Events { |
||||
// Mock to trap starting the next filter in the chain.
|
||||
MOCK_METHOD(void, Started, |
||||
(Call * call, const ClientMetadata& client_initial_metadata)); |
||||
// Mock to trap receiving server initial metadata in the next filter in the
|
||||
// chain.
|
||||
MOCK_METHOD(void, ForwardedServerInitialMetadata, |
||||
(Call * call, const ServerMetadata& server_initial_metadata)); |
||||
// Mock to trap seeing a message forward from client to server.
|
||||
MOCK_METHOD(void, ForwardedMessageClientToServer, |
||||
(Call * call, const Message& msg)); |
||||
// Mock to trap seeing a message forward from server to client.
|
||||
MOCK_METHOD(void, ForwardedMessageServerToClient, |
||||
(Call * call, const Message& msg)); |
||||
// Mock to trap seeing a call finish in the next filter in the chain.
|
||||
MOCK_METHOD(void, Finished, |
||||
(Call * call, const ServerMetadata& server_trailing_metadata)); |
||||
}; |
||||
|
||||
::testing::StrictMock<Events> events; |
||||
|
||||
protected: |
||||
FilterTestBase(); |
||||
~FilterTestBase() override; |
||||
absl::StatusOr<Channel> MakeChannel(std::unique_ptr<ChannelFilter> filter) { |
||||
return Channel(std::move(filter), this); |
||||
} |
||||
|
||||
grpc_event_engine::experimental::EventEngine* event_engine() { |
||||
return &event_engine_; |
||||
} |
||||
|
||||
void Step(); |
||||
|
||||
private: |
||||
grpc_event_engine::experimental::FuzzingEventEngine event_engine_; |
||||
}; |
||||
|
||||
template <typename Filter> |
||||
class FilterTest : public FilterTestBase { |
||||
public: |
||||
absl::StatusOr<Channel> MakeChannel(const ChannelArgs& args) { |
||||
auto filter = Filter::Create(args, ChannelFilter::Args()); |
||||
if (!filter.ok()) return filter.status(); |
||||
return FilterTestBase::MakeChannel( |
||||
std::make_unique<Filter>(std::move(*filter))); |
||||
} |
||||
}; |
||||
|
||||
} // namespace grpc_core
|
||||
|
||||
// Expect one of the events corresponding to the methods in FilterTest::Events.
|
||||
#define EXPECT_EVENT(event) EXPECT_CALL(events, event) |
||||
|
||||
#endif // GRPC_TEST_CORE_FILTERS_FILTER_TEST_H
|
@ -0,0 +1,253 @@ |
||||
// Copyright 2023 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "test/core/filters/filter_test.h" |
||||
|
||||
#include <functional> |
||||
#include <memory> |
||||
#include <type_traits> |
||||
#include <utility> |
||||
|
||||
#include "gmock/gmock.h" |
||||
#include "gtest/gtest.h" |
||||
|
||||
#include <grpc/compression.h> |
||||
#include <grpc/grpc.h> |
||||
|
||||
#include "src/core/lib/channel/promise_based_filter.h" |
||||
#include "src/core/lib/promise/activity.h" |
||||
#include "src/core/lib/promise/arena_promise.h" |
||||
#include "src/core/lib/promise/map.h" |
||||
#include "src/core/lib/promise/pipe.h" |
||||
#include "src/core/lib/promise/poll.h" |
||||
#include "src/core/lib/promise/seq.h" |
||||
#include "src/core/lib/slice/slice.h" |
||||
#include "src/core/lib/transport/metadata_batch.h" |
||||
#include "src/core/lib/transport/transport.h" |
||||
|
||||
using ::testing::_; |
||||
|
||||
namespace grpc_core { |
||||
namespace { |
||||
|
||||
class NoOpFilter final : public ChannelFilter { |
||||
public: |
||||
ArenaPromise<ServerMetadataHandle> MakeCallPromise( |
||||
CallArgs args, NextPromiseFactory next) override { |
||||
return next(std::move(args)); |
||||
} |
||||
|
||||
static absl::StatusOr<NoOpFilter> Create(const ChannelArgs&, |
||||
ChannelFilter::Args) { |
||||
return NoOpFilter(); |
||||
} |
||||
}; |
||||
using NoOpFilterTest = FilterTest<NoOpFilter>; |
||||
|
||||
class DelayStartFilter final : public ChannelFilter { |
||||
public: |
||||
ArenaPromise<ServerMetadataHandle> MakeCallPromise( |
||||
CallArgs args, NextPromiseFactory next) override { |
||||
return Seq( |
||||
[args = std::move(args), i = 10]() mutable -> Poll<CallArgs> { |
||||
--i; |
||||
if (i == 0) return std::move(args); |
||||
Activity::current()->ForceImmediateRepoll(); |
||||
return Pending{}; |
||||
}, |
||||
next); |
||||
} |
||||
|
||||
static absl::StatusOr<DelayStartFilter> Create(const ChannelArgs&, |
||||
ChannelFilter::Args) { |
||||
return DelayStartFilter(); |
||||
} |
||||
}; |
||||
using DelayStartFilterTest = FilterTest<DelayStartFilter>; |
||||
|
||||
class AddClientInitialMetadataFilter final : public ChannelFilter { |
||||
public: |
||||
ArenaPromise<ServerMetadataHandle> MakeCallPromise( |
||||
CallArgs args, NextPromiseFactory next) override { |
||||
args.client_initial_metadata->Set(HttpPathMetadata(), |
||||
Slice::FromCopiedString("foo.bar")); |
||||
return next(std::move(args)); |
||||
} |
||||
|
||||
static absl::StatusOr<AddClientInitialMetadataFilter> Create( |
||||
const ChannelArgs&, ChannelFilter::Args) { |
||||
return AddClientInitialMetadataFilter(); |
||||
} |
||||
}; |
||||
using AddClientInitialMetadataFilterTest = |
||||
FilterTest<AddClientInitialMetadataFilter>; |
||||
|
||||
class AddServerTrailingMetadataFilter final : public ChannelFilter { |
||||
public: |
||||
ArenaPromise<ServerMetadataHandle> MakeCallPromise( |
||||
CallArgs args, NextPromiseFactory next) override { |
||||
return Map(next(std::move(args)), [](ServerMetadataHandle handle) { |
||||
handle->Set(HttpStatusMetadata(), 420); |
||||
return handle; |
||||
}); |
||||
} |
||||
|
||||
static absl::StatusOr<AddServerTrailingMetadataFilter> Create( |
||||
const ChannelArgs&, ChannelFilter::Args) { |
||||
return AddServerTrailingMetadataFilter(); |
||||
} |
||||
}; |
||||
using AddServerTrailingMetadataFilterTest = |
||||
FilterTest<AddServerTrailingMetadataFilter>; |
||||
|
||||
class AddServerInitialMetadataFilter final : public ChannelFilter { |
||||
public: |
||||
ArenaPromise<ServerMetadataHandle> MakeCallPromise( |
||||
CallArgs args, NextPromiseFactory next) override { |
||||
args.server_initial_metadata->InterceptAndMap([](ServerMetadataHandle md) { |
||||
md->Set(GrpcEncodingMetadata(), GRPC_COMPRESS_GZIP); |
||||
return md; |
||||
}); |
||||
return next(std::move(args)); |
||||
} |
||||
|
||||
static absl::StatusOr<AddServerInitialMetadataFilter> Create( |
||||
const ChannelArgs&, ChannelFilter::Args) { |
||||
return AddServerInitialMetadataFilter(); |
||||
} |
||||
}; |
||||
using AddServerInitialMetadataFilterTest = |
||||
FilterTest<AddServerInitialMetadataFilter>; |
||||
|
||||
TEST_F(NoOpFilterTest, NoOp) {} |
||||
|
||||
TEST_F(NoOpFilterTest, MakeCall) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, MakeClientMetadata) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
auto md = call.NewClientMetadata({{":path", "foo.bar"}}); |
||||
EXPECT_EQ(md->get_pointer(HttpPathMetadata())->as_string_view(), "foo.bar"); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, MakeServerMetadata) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
auto md = call.NewServerMetadata({{":status", "200"}}); |
||||
EXPECT_EQ(md->get(HttpStatusMetadata()), HttpStatusMetadata::ValueType(200)); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, CanStart) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(DelayStartFilterTest, CanStartWithDelay) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, CanCancel) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.Cancel(); |
||||
} |
||||
|
||||
TEST_F(DelayStartFilterTest, CanCancelWithDelay) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.Cancel(); |
||||
} |
||||
|
||||
TEST_F(AddClientInitialMetadataFilterTest, CanSetClientInitialMetadata) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, HasMetadataKeyValue(":path", "foo.bar"))); |
||||
call.Start(call.NewClientMetadata()); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, CanFinish) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.FinishNextFilter(call.NewServerMetadata()); |
||||
EXPECT_EVENT(Finished(&call, _)); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(AddServerTrailingMetadataFilterTest, CanSetServerTrailingMetadata) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.FinishNextFilter(call.NewServerMetadata()); |
||||
EXPECT_EVENT(Finished(&call, HasMetadataKeyValue(":status", "420"))); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, CanProcessServerInitialMetadata) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.ForwardServerInitialMetadata(call.NewServerMetadata()); |
||||
EXPECT_EVENT(ForwardedServerInitialMetadata(&call, _)); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(AddServerInitialMetadataFilterTest, CanSetServerInitialMetadata) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.ForwardServerInitialMetadata(call.NewServerMetadata()); |
||||
EXPECT_EVENT(ForwardedServerInitialMetadata( |
||||
&call, HasMetadataKeyValue("grpc-encoding", "gzip"))); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, CanProcessClientToServerMessage) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.ForwardMessageClientToServer(call.NewMessage("abc")); |
||||
EXPECT_CALL(events, |
||||
ForwardedMessageClientToServer(&call, HasMessagePayload("abc"))); |
||||
Step(); |
||||
} |
||||
|
||||
TEST_F(NoOpFilterTest, CanProcessServerToClientMessage) { |
||||
Call call(MakeChannel(ChannelArgs()).value()); |
||||
EXPECT_EVENT(Started(&call, _)); |
||||
call.Start(call.NewClientMetadata()); |
||||
call.ForwardServerInitialMetadata(call.NewServerMetadata()); |
||||
call.ForwardMessageServerToClient(call.NewMessage("abc")); |
||||
EXPECT_EVENT(ForwardedServerInitialMetadata(&call, _)); |
||||
EXPECT_CALL(events, |
||||
ForwardedMessageServerToClient(&call, HasMessagePayload("abc"))); |
||||
Step(); |
||||
} |
||||
|
||||
} // namespace
|
||||
} // namespace grpc_core
|
||||
|
||||
int main(int argc, char** argv) { |
||||
::testing::InitGoogleTest(&argc, argv); |
||||
grpc_init(); |
||||
int r = RUN_ALL_TESTS(); |
||||
grpc_shutdown(); |
||||
return r; |
||||
} |
Loading…
Reference in new issue