[promises] Filter unit test framework (#32110)

Built atop #31448 

Offers a simple framework for testing filters.

<!--

If you know who should review your pull request, please assign it to
that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the
appropriate
lang label.

-->

---------

Co-authored-by: ctiller <ctiller@users.noreply.github.com>
pull/32604/head
Craig Tiller 2 years ago committed by GitHub
parent db3daf567b
commit 2cd1501ca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 61
      CMakeLists.txt
  2. 32
      build_autogenerated.yaml
  3. 2
      src/core/lib/transport/metadata_batch.h
  4. 10
      test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
  5. 2
      test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
  6. 120
      test/core/filters/BUILD
  7. 117
      test/core/filters/client_auth_filter_test.cc
  8. 114
      test/core/filters/client_authority_filter_test.cc
  9. 425
      test/core/filters/filter_test.cc
  10. 225
      test/core/filters/filter_test.h
  11. 253
      test/core/filters/filter_test_test.cc
  12. 1
      tools/distrib/fix_build_deps.py
  13. 24
      tools/run_tests/generated/tests.json

61
CMakeLists.txt generated

@ -962,6 +962,7 @@ if(gRPC_BUILD_TESTS)
endif()
add_dependencies(buildtests_cxx file_watcher_certificate_provider_factory_test)
add_dependencies(buildtests_cxx filter_end2end_test)
add_dependencies(buildtests_cxx filter_test_test)
add_dependencies(buildtests_cxx flaky_network_test)
add_dependencies(buildtests_cxx flow_control_test)
add_dependencies(buildtests_cxx for_each_test)
@ -8437,7 +8438,13 @@ endif()
if(gRPC_BUILD_TESTS)
add_executable(client_auth_filter_test
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h
test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
test/core/filters/client_auth_filter_test.cc
test/core/filters/filter_test.cc
third_party/googletest/googletest/src/gtest-all.cc
third_party/googletest/googlemock/src/gmock-all.cc
)
@ -8466,7 +8473,7 @@ target_link_libraries(client_auth_filter_test
${_gRPC_PROTOBUF_LIBRARIES}
${_gRPC_ZLIB_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
grpc
grpc_test_util
)
@ -8474,7 +8481,13 @@ endif()
if(gRPC_BUILD_TESTS)
add_executable(client_authority_filter_test
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h
test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
test/core/filters/client_authority_filter_test.cc
test/core/filters/filter_test.cc
third_party/googletest/googletest/src/gtest-all.cc
third_party/googletest/googlemock/src/gmock-all.cc
)
@ -8503,7 +8516,7 @@ target_link_libraries(client_authority_filter_test
${_gRPC_PROTOBUF_LIBRARIES}
${_gRPC_ZLIB_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
grpc
grpc_test_util
)
@ -11011,6 +11024,50 @@ target_link_libraries(filter_end2end_test
)
endif()
if(gRPC_BUILD_TESTS)
add_executable(filter_test_test
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.cc
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h
${_gRPC_PROTO_GENS_DIR}/test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.grpc.pb.h
test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
test/core/filters/filter_test.cc
test/core/filters/filter_test_test.cc
third_party/googletest/googletest/src/gtest-all.cc
third_party/googletest/googlemock/src/gmock-all.cc
)
target_compile_features(filter_test_test PUBLIC cxx_std_14)
target_include_directories(filter_test_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/include
${_gRPC_ADDRESS_SORTING_INCLUDE_DIR}
${_gRPC_RE2_INCLUDE_DIR}
${_gRPC_SSL_INCLUDE_DIR}
${_gRPC_UPB_GENERATED_DIR}
${_gRPC_UPB_GRPC_GENERATED_DIR}
${_gRPC_UPB_INCLUDE_DIR}
${_gRPC_XXHASH_INCLUDE_DIR}
${_gRPC_ZLIB_INCLUDE_DIR}
third_party/googletest/googletest/include
third_party/googletest/googletest
third_party/googletest/googlemock/include
third_party/googletest/googlemock
${_gRPC_PROTO_GENS_DIR}
)
target_link_libraries(filter_test_test
${_gRPC_BASELIB_LIBRARIES}
${_gRPC_PROTOBUF_LIBRARIES}
${_gRPC_ZLIB_LIBRARIES}
${_gRPC_ALLTARGETS_LIBRARIES}
grpc_unsecure
grpc_test_util
)
endif()
if(gRPC_BUILD_TESTS)

@ -5918,22 +5918,30 @@ targets:
build: test
language: c++
headers:
- test/core/promise/test_context.h
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
- test/core/filters/filter_test.h
src:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
- test/core/filters/client_auth_filter_test.cc
- test/core/filters/filter_test.cc
deps:
- grpc
- grpc_test_util
uses_polling: false
- name: client_authority_filter_test
gtest: true
build: test
language: c++
headers:
- test/core/promise/test_context.h
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
- test/core/filters/filter_test.h
src:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
- test/core/filters/client_authority_filter_test.cc
- test/core/filters/filter_test.cc
deps:
- grpc
- grpc_test_util
uses_polling: false
- name: client_callback_end2end_test
gtest: true
@ -7074,6 +7082,22 @@ targets:
- test/cpp/end2end/filter_end2end_test.cc
deps:
- grpc++_test_util
- name: filter_test_test
gtest: true
build: test
language: c++
headers:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h
- test/core/filters/filter_test.h
src:
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto
- test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.cc
- test/core/filters/filter_test.cc
- test/core/filters/filter_test_test.cc
deps:
- grpc_unsecure
- grpc_test_util
uses_polling: false
- name: flaky_network_test
gtest: true
build: test

@ -1270,8 +1270,6 @@ class MetadataMap {
// Parse metadata from a key/value pair, and return an object representing
// that result.
// TODO(ctiller): key should probably be an absl::string_view.
// Once we don't care about interning anymore, make that change!
static ParsedMetadata<Derived> Parse(absl::string_view key, Slice value,
uint32_t transport_size,
MetadataParseErrorFn on_error) {

@ -175,6 +175,16 @@ void FuzzingEventEngine::Tick() {
}
}
void FuzzingEventEngine::TickUntilIdle() {
while (true) {
{
grpc_core::MutexLock lock(&*mu_);
if (tasks_by_id_.empty()) return;
}
Tick();
}
}
FuzzingEventEngine::Time FuzzingEventEngine::Now() {
grpc_core::MutexLock lock(&*mu_);
return now_;

@ -65,6 +65,8 @@ class FuzzingEventEngine : public EventEngine {
void FuzzingDone() ABSL_LOCKS_EXCLUDED(mu_);
// Increment time once and perform any scheduled work.
void Tick() ABSL_LOCKS_EXCLUDED(mu_);
// Repeatedly call Tick() until there is no more work to do.
void TickUntilIdle() ABSL_LOCKS_EXCLUDED(mu_);
absl::StatusOr<std::unique_ptr<Listener>> CreateListener(
Listener::AcceptCallback on_accept,

@ -12,39 +12,109 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//bazel:grpc_build_system.bzl", "grpc_cc_test", "grpc_package")
load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package")
load("//test/core/util:grpc_fuzzer.bzl", "grpc_proto_fuzzer")
licenses(["notice"])
grpc_package(name = "test/core/filters")
grpc_cc_library(
name = "filter_test",
srcs = ["filter_test.cc"],
hdrs = ["filter_test.h"],
external_deps = [
"absl/memory",
"absl/strings",
"absl/strings:str_format",
"absl/types:optional",
"absl/types:variant",
"gtest",
],
language = "c++",
tags = ["nofixdeps"], # until event engine tests are under fixbuilddeps
deps = [
"//:gpr",
"//:grpc",
"//:ref_counted_ptr",
"//src/core:activity",
"//src/core:arena",
"//src/core:arena_promise",
"//src/core:basic_seq",
"//src/core:context",
"//src/core:memory_quota",
"//src/core:pipe",
"//src/core:poll",
"//src/core:resource_quota",
"//src/core:slice",
"//src/core:slice_buffer",
"//test/core/event_engine/fuzzing_event_engine",
],
)
grpc_cc_test(
name = "filter_test_test",
srcs = ["filter_test_test.cc"],
external_deps = ["gtest"],
uses_event_engine = False,
uses_polling = False,
deps = [
"filter_test",
"//:grpc_unsecure",
"//src/core:activity",
"//src/core:arena_promise",
"//src/core:map",
"//src/core:pipe",
"//src/core:poll",
"//src/core:seq",
"//src/core:slice",
],
)
grpc_cc_test(
name = "client_auth_filter_test",
srcs = ["client_auth_filter_test.cc"],
external_deps = ["gtest"],
external_deps = [
"absl/status",
"absl/status:statusor",
"absl/strings",
"absl/types:optional",
"gtest",
],
language = "c++",
uses_event_engine = False,
uses_polling = False,
deps = [
"filter_test",
"//:grpc",
"//:grpc_public_hdrs",
"//:grpc_security_base",
"//:promise",
"//:ref_counted_ptr",
"//src/core:arena_promise",
"//src/core:channel_args",
"//test/core/promise:test_context",
"//src/core:grpc_fake_credentials",
"//src/core:unique_type_name",
"//src/core:useful",
],
)
grpc_cc_test(
name = "client_authority_filter_test",
srcs = ["client_authority_filter_test.cc"],
external_deps = ["gtest"],
external_deps = [
"absl/status",
"absl/strings",
"absl/types:optional",
"gtest",
],
language = "c++",
uses_event_engine = False,
uses_polling = False,
deps = [
"filter_test",
"//:grpc",
"//src/core:grpc_client_authority_filter",
"//test/core/promise:test_context",
],
)
@ -52,16 +122,52 @@ grpc_proto_fuzzer(
name = "filter_fuzzer",
srcs = ["filter_fuzzer.cc"],
corpus = "filter_fuzzer_corpus",
external_deps = [
"absl/base:core_headers",
"absl/status",
"absl/status:statusor",
"absl/strings",
"absl/types:optional",
],
language = "C++",
proto = "filter_fuzzer.proto",
tags = ["no_windows"],
uses_polling = False,
deps = [
"//:config",
"//:debug_location",
"//:exec_ctx",
"//:gpr",
"//:grpc",
"//:grpc_http_filters",
"//:grpc_public_hdrs",
"//:grpc_security_base",
"//:handshaker",
"//:iomgr_timer",
"//:ref_counted_ptr",
"//:tsi_base",
"//src/core:activity",
"//src/core:arena",
"//src/core:arena_promise",
"//src/core:channel_args",
"//src/core:channel_args_preconditioning",
"//src/core:channel_fwd",
"//src/core:channel_stack_type",
"//src/core:closure",
"//test/core/end2end:ssl_test_data",
"//test/core/util:grpc_test_util",
"//src/core:context",
"//src/core:env",
"//src/core:error",
"//src/core:grpc_authorization_base",
"//src/core:grpc_channel_idle_filter",
"//src/core:grpc_client_authority_filter",
"//src/core:iomgr_fwd",
"//src/core:memory_quota",
"//src/core:pipe",
"//src/core:poll",
"//src/core:resource_quota",
"//src/core:slice",
"//src/core:time",
"//src/core:transport_fwd",
"//src/core:useful",
],
)

@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include <memory>
#include <string>
#include <utility>
@ -22,35 +19,28 @@
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
#include <grpc/grpc_security_constants.h>
#include <grpc/status.h>
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/unique_type_name.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/fake/fake_credentials.h"
#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/security/transport/auth_filters.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
#include "test/core/promise/test_context.h"
#include "test/core/filters/filter_test.h"
// TODO(roth): Need to add a lot more tests here. I created this file
// as part of adding a feature, and I added tests only for the feature I
@ -60,7 +50,7 @@
namespace grpc_core {
namespace {
class ClientAuthFilterTest : public ::testing::Test {
class ClientAuthFilterTest : public FilterTest<ClientAuthFilter> {
protected:
class FailCallCreds : public grpc_call_credentials {
public:
@ -90,23 +80,10 @@ class ClientAuthFilterTest : public ::testing::Test {
};
ClientAuthFilterTest()
: memory_allocator_(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
"test")),
arena_(MakeScopedArena(1024, &memory_allocator_)),
initial_metadata_batch_(arena_.get()),
trailing_metadata_batch_(arena_.get()),
target_(Slice::FromStaticString("localhost:1234")),
channel_creds_(grpc_fake_transport_security_credentials_create()) {
initial_metadata_batch_.Set(HttpAuthorityMetadata(), target_.Ref());
}
: channel_creds_(grpc_fake_transport_security_credentials_create()) {}
~ClientAuthFilterTest() override {
for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
if (call_context_[i].destroy != nullptr) {
call_context_[i].destroy(call_context_[i].value);
}
}
Channel MakeChannelWithCallCredsResult(absl::Status status) {
return MakeChannel(MakeChannelArgs(std::move(status))).value();
}
ChannelArgs MakeChannelArgs(absl::Status status_for_call_creds) {
@ -115,7 +92,7 @@ class ClientAuthFilterTest : public ::testing::Test {
status_for_call_creds.ok()
? nullptr
: MakeRefCounted<FailCallCreds>(std::move(status_for_call_creds)),
std::string(target_.as_string_view()).c_str(), &args);
std::string(target()).c_str(), &args);
auto auth_context = MakeRefCounted<grpc_auth_context>(nullptr);
absl::string_view security_level = "TSI_SECURITY_NONE";
auth_context->add_property(GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME,
@ -124,13 +101,9 @@ class ClientAuthFilterTest : public ::testing::Test {
.SetObject(std::move(auth_context));
}
MemoryAllocator memory_allocator_;
ScopedArenaPtr arena_;
grpc_metadata_batch initial_metadata_batch_;
grpc_metadata_batch trailing_metadata_batch_;
Slice target_;
absl::string_view target() { return "localhost:1234"; }
RefCountedPtr<grpc_channel_credentials> channel_creds_;
grpc_call_context_element call_context_[GRPC_CONTEXT_COUNT];
};
TEST_F(ClientAuthFilterTest, CreateFailsWithoutRequiredChannelArgs) {
@ -139,72 +112,26 @@ TEST_F(ClientAuthFilterTest, CreateFailsWithoutRequiredChannelArgs) {
}
TEST_F(ClientAuthFilterTest, CreateSucceeds) {
auto filter = ClientAuthFilter::Create(MakeChannelArgs(absl::OkStatus()),
ChannelFilter::Args());
auto filter = MakeChannel(MakeChannelArgs(absl::OkStatus()));
EXPECT_TRUE(filter.ok()) << filter.status();
}
TEST_F(ClientAuthFilterTest, CallCredsFails) {
auto filter = ClientAuthFilter::Create(
MakeChannelArgs(absl::UnauthenticatedError("access denied")),
ChannelFilter::Args());
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena_.get());
TestContext<grpc_call_context_element> promise_call_context(call_context_);
auto promise = filter->MakeCallPromise(
CallArgs{ClientMetadataHandle(&initial_metadata_batch_,
Arena::PooledDeleter(nullptr)),
ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr,
nullptr},
[&](CallArgs /*call_args*/) {
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle(&trailing_metadata_batch_,
Arena::PooledDeleter(nullptr));
});
});
auto result = promise();
ServerMetadataHandle* server_metadata = result.value_if_ready();
ASSERT_TRUE(server_metadata != nullptr);
auto status_md = (*server_metadata)->get(GrpcStatusMetadata());
ASSERT_TRUE(status_md.has_value());
EXPECT_EQ(*status_md, GRPC_STATUS_UNAUTHENTICATED);
const Slice* message_md =
(*server_metadata)->get_pointer(GrpcMessageMetadata());
ASSERT_TRUE(message_md != nullptr);
EXPECT_EQ(message_md->as_string_view(), "access denied");
Call call(MakeChannelWithCallCredsResult(
absl::UnauthenticatedError("access denied")));
call.Start(call.NewClientMetadata({{":authority", target()}}));
EXPECT_EVENT(Finished(
&call, HasMetadataResult(absl::UnauthenticatedError("access denied"))));
Step();
}
TEST_F(ClientAuthFilterTest, RewritesInvalidStatusFromCallCreds) {
auto filter = ClientAuthFilter::Create(
MakeChannelArgs(absl::AbortedError("nope")), ChannelFilter::Args());
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena_.get());
TestContext<grpc_call_context_element> promise_call_context(call_context_);
auto promise = filter->MakeCallPromise(
CallArgs{ClientMetadataHandle(&initial_metadata_batch_,
Arena::PooledDeleter(nullptr)),
ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr,
nullptr},
[&](CallArgs /*call_args*/) {
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle(&trailing_metadata_batch_,
Arena::PooledDeleter(nullptr));
});
});
auto result = promise();
ServerMetadataHandle* server_metadata = result.value_if_ready();
ASSERT_TRUE(server_metadata != nullptr);
auto status_md = (*server_metadata)->get(GrpcStatusMetadata());
ASSERT_TRUE(status_md.has_value());
EXPECT_EQ(*status_md, GRPC_STATUS_INTERNAL);
const Slice* message_md =
(*server_metadata)->get_pointer(GrpcMessageMetadata());
ASSERT_TRUE(message_md != nullptr);
EXPECT_EQ(message_md->as_string_view(),
"Illegal status code from call credentials; original status: "
"ABORTED: nope");
Call call(MakeChannelWithCallCredsResult(absl::AbortedError("nope")));
call.Start(call.NewClientMetadata({{":authority", target()}}));
EXPECT_EVENT(Finished(&call, HasMetadataResult(absl::InternalError(
"Illegal status code from call credentials; "
"original status: ABORTED: nope"))));
Step();
}
} // namespace

@ -14,116 +14,56 @@
#include "src/core/ext/filters/http/client_authority_filter.h"
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/event_engine/memory_allocator.h>
#include <grpc/grpc.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "test/core/promise/test_context.h"
#include "test/core/filters/filter_test.h"
using ::testing::StrictMock;
namespace grpc_core {
namespace {
using ClientAuthorityFilterTest = FilterTest<ClientAuthorityFilter>;
ChannelArgs TestChannelArgs(absl::string_view default_authority) {
return ChannelArgs().Set(GRPC_ARG_DEFAULT_AUTHORITY, default_authority);
}
TEST(ClientAuthorityFilterTest, DefaultFails) {
EXPECT_FALSE(
ClientAuthorityFilter::Create(ChannelArgs(), ChannelFilter::Args()).ok());
TEST_F(ClientAuthorityFilterTest, DefaultFails) {
EXPECT_FALSE(MakeChannel(ChannelArgs()).ok());
}
TEST(ClientAuthorityFilterTest, WithArgSucceeds) {
EXPECT_EQ(ClientAuthorityFilter::Create(TestChannelArgs("foo.test.google.au"),
ChannelFilter::Args())
.status(),
TEST_F(ClientAuthorityFilterTest, WithArgSucceeds) {
EXPECT_EQ(MakeChannel(TestChannelArgs("foo.test.google.au")).status(),
absl::OkStatus());
}
TEST(ClientAuthorityFilterTest, NonStringArgFails) {
EXPECT_FALSE(ClientAuthorityFilter::Create(
ChannelArgs().Set(GRPC_ARG_DEFAULT_AUTHORITY, 123),
ChannelFilter::Args())
.ok());
TEST_F(ClientAuthorityFilterTest, NonStringArgFails) {
EXPECT_FALSE(
MakeChannel(ChannelArgs().Set(GRPC_ARG_DEFAULT_AUTHORITY, 123)).ok());
}
TEST(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) {
auto filter = *ClientAuthorityFilter::Create(
TestChannelArgs("foo.test.google.au"), ChannelFilter::Args());
MemoryAllocator memory_allocator = MemoryAllocator(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test"));
auto arena = MakeScopedArena(1024, &memory_allocator);
grpc_metadata_batch initial_metadata_batch(arena.get());
grpc_metadata_batch trailing_metadata_batch(arena.get());
bool seen = false;
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena.get());
auto promise = filter.MakeCallPromise(
CallArgs{ClientMetadataHandle(&initial_metadata_batch,
Arena::PooledDeleter(nullptr)),
ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr,
nullptr},
[&](CallArgs call_args) {
EXPECT_EQ(call_args.client_initial_metadata
->get_pointer(HttpAuthorityMetadata())
->as_string_view(),
"foo.test.google.au");
seen = true;
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle(&trailing_metadata_batch,
Arena::PooledDeleter(nullptr));
});
});
auto result = promise();
EXPECT_TRUE(result.ready());
EXPECT_TRUE(seen);
TEST_F(ClientAuthorityFilterTest, PromiseCompletesImmediatelyAndSetsAuthority) {
StrictMock<FilterTest::Call> call(
MakeChannel(TestChannelArgs("foo.test.google.au")).value());
EXPECT_EVENT(
Started(&call, HasMetadataKeyValue(":authority", "foo.test.google.au")));
call.Start(call.NewClientMetadata());
}
TEST(ClientAuthorityFilterTest,
PromiseCompletesImmediatelyAndDoesNotClobberAlreadySetsAuthority) {
auto filter = *ClientAuthorityFilter::Create(
TestChannelArgs("foo.test.google.au"), ChannelFilter::Args());
MemoryAllocator memory_allocator = MemoryAllocator(
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator("test"));
auto arena = MakeScopedArena(1024, &memory_allocator);
grpc_metadata_batch initial_metadata_batch(arena.get());
grpc_metadata_batch trailing_metadata_batch(arena.get());
initial_metadata_batch.Set(HttpAuthorityMetadata(),
Slice::FromStaticString("bar.test.google.au"));
bool seen = false;
// TODO(ctiller): use Activity here, once it's ready.
TestContext<Arena> context(arena.get());
auto promise = filter.MakeCallPromise(
CallArgs{ClientMetadataHandle(&initial_metadata_batch,
Arena::PooledDeleter(nullptr)),
ClientInitialMetadataOutstandingToken::Empty(), nullptr, nullptr,
nullptr},
[&](CallArgs call_args) {
EXPECT_EQ(call_args.client_initial_metadata
->get_pointer(HttpAuthorityMetadata())
->as_string_view(),
"bar.test.google.au");
seen = true;
return ArenaPromise<ServerMetadataHandle>(
[&]() -> Poll<ServerMetadataHandle> {
return ServerMetadataHandle(&trailing_metadata_batch,
Arena::PooledDeleter(nullptr));
});
});
auto result = promise();
EXPECT_TRUE(result.ready());
EXPECT_TRUE(seen);
TEST_F(ClientAuthorityFilterTest,
PromiseCompletesImmediatelyAndDoesNotSetAuthority) {
StrictMock<FilterTest::Call> call(
MakeChannel(TestChannelArgs("foo.test.google.au")).value());
EXPECT_EVENT(
Started(&call, HasMetadataKeyValue(":authority", "bar.test.google.au")));
call.Start(call.NewClientMetadata({{":authority", "bar.test.google.au"}}));
}
} // namespace

@ -0,0 +1,425 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/core/filters/filter_test.h"
#include <algorithm>
#include <chrono>
#include <memory>
#include <queue>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "gtest/gtest.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/iomgr/timer_manager.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/basic_seq.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
namespace grpc_core {
///////////////////////////////////////////////////////////////////////////////
// FilterTestBase::Call::Impl
class FilterTestBase::Call::Impl
: public std::enable_shared_from_this<FilterTestBase::Call::Impl> {
public:
Impl(Call* call, std::shared_ptr<Channel::Impl> channel)
: call_(call), channel_(std::move(channel)) {}
~Impl();
Arena* arena() { return arena_.get(); }
grpc_call_context_element* legacy_context() { return legacy_context_; }
const std::shared_ptr<Channel::Impl>& channel() const { return channel_; }
void Start(ClientMetadataHandle md);
void ForwardServerInitialMetadata(ServerMetadataHandle md);
void ForwardMessageClientToServer(MessageHandle msg);
void ForwardMessageServerToClient(MessageHandle msg);
void FinishNextFilter(ServerMetadataHandle md);
void StepLoop();
grpc_event_engine::experimental::EventEngine* event_engine() {
return channel_->test->event_engine();
}
Events& events() { return channel_->test->events; }
private:
bool StepOnce();
Poll<ServerMetadataHandle> PollNextFilter();
void ForceWakeup();
Call* const call_;
std::shared_ptr<Channel::Impl> const channel_;
ScopedArenaPtr arena_{MakeScopedArena(channel_->initial_arena_size,
&channel_->memory_allocator)};
absl::optional<ArenaPromise<ServerMetadataHandle>> promise_;
Poll<ServerMetadataHandle> poll_next_filter_result_;
Pipe<ServerMetadataHandle> pipe_server_initial_metadata_{arena_.get()};
Pipe<MessageHandle> pipe_server_to_client_messages_{arena_.get()};
Pipe<MessageHandle> pipe_client_to_server_messages_{arena_.get()};
PipeSender<ServerMetadataHandle>* server_initial_metadata_sender_ = nullptr;
PipeSender<MessageHandle>* server_to_client_messages_sender_ = nullptr;
PipeReceiver<MessageHandle>* client_to_server_messages_receiver_ = nullptr;
absl::optional<PipeSender<ServerMetadataHandle>::PushType>
push_server_initial_metadata_;
absl::optional<PipeReceiverNextType<ServerMetadataHandle>>
next_server_initial_metadata_;
absl::optional<PipeSender<MessageHandle>::PushType>
push_server_to_client_messages_;
absl::optional<PipeReceiverNextType<MessageHandle>>
next_server_to_client_messages_;
absl::optional<PipeSender<MessageHandle>::PushType>
push_client_to_server_messages_;
absl::optional<PipeReceiverNextType<MessageHandle>>
next_client_to_server_messages_;
absl::optional<ServerMetadataHandle> forward_server_initial_metadata_;
std::queue<MessageHandle> forward_client_to_server_messages_;
std::queue<MessageHandle> forward_server_to_client_messages_;
// Contexts for various subsystems (security, tracing, ...).
grpc_call_context_element legacy_context_[GRPC_CONTEXT_COUNT] = {};
};
FilterTestBase::Call::Impl::~Impl() {
for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
if (legacy_context_[i].destroy != nullptr) {
legacy_context_[i].destroy(legacy_context_[i].value);
}
}
}
void FilterTestBase::Call::Impl::Start(ClientMetadataHandle md) {
EXPECT_EQ(promise_, absl::nullopt);
promise_ = channel_->filter->MakeCallPromise(
CallArgs{std::move(md), ClientInitialMetadataOutstandingToken::Empty(),
&pipe_server_initial_metadata_.sender,
&pipe_client_to_server_messages_.receiver,
&pipe_server_to_client_messages_.sender},
[this](CallArgs args) -> ArenaPromise<ServerMetadataHandle> {
server_initial_metadata_sender_ = args.server_initial_metadata;
client_to_server_messages_receiver_ = args.client_to_server_messages;
server_to_client_messages_sender_ = args.server_to_client_messages;
next_server_initial_metadata_.emplace(
pipe_server_initial_metadata_.receiver.Next());
events().Started(call_, *args.client_initial_metadata);
return [this]() { return PollNextFilter(); };
});
EXPECT_NE(promise_, absl::nullopt);
ForceWakeup();
}
Poll<ServerMetadataHandle> FilterTestBase::Call::Impl::PollNextFilter() {
return std::exchange(poll_next_filter_result_, Pending());
}
void FilterTestBase::Call::Impl::ForwardServerInitialMetadata(
ServerMetadataHandle md) {
EXPECT_FALSE(forward_server_initial_metadata_.has_value());
forward_server_initial_metadata_ = std::move(md);
ForceWakeup();
}
void FilterTestBase::Call::Impl::ForwardMessageClientToServer(
MessageHandle msg) {
forward_client_to_server_messages_.push(std::move(msg));
ForceWakeup();
}
void FilterTestBase::Call::Impl::ForwardMessageServerToClient(
MessageHandle msg) {
forward_server_to_client_messages_.push(std::move(msg));
ForceWakeup();
}
void FilterTestBase::Call::Impl::FinishNextFilter(ServerMetadataHandle md) {
poll_next_filter_result_ = std::move(md);
ForceWakeup();
}
bool FilterTestBase::Call::Impl::StepOnce() {
if (!promise_.has_value()) return true;
if (forward_server_initial_metadata_.has_value() &&
!push_server_initial_metadata_.has_value()) {
push_server_initial_metadata_.emplace(server_initial_metadata_sender_->Push(
std::move(*forward_server_initial_metadata_)));
forward_server_initial_metadata_.reset();
}
if (push_server_initial_metadata_.has_value()) {
auto r = (*push_server_initial_metadata_)();
if (r.ready()) push_server_initial_metadata_.reset();
}
if (next_server_initial_metadata_.has_value()) {
auto r = (*next_server_initial_metadata_)();
if (auto* p = r.value_if_ready()) {
if (p->has_value()) {
events().ForwardedServerInitialMetadata(call_, *p->value());
}
next_server_initial_metadata_.reset();
}
}
if (server_initial_metadata_sender_ != nullptr &&
!next_server_initial_metadata_.has_value()) {
// We've finished sending server initial metadata, so we can
// process server-to-client messages.
if (!next_server_to_client_messages_.has_value()) {
next_server_to_client_messages_.emplace(
pipe_server_to_client_messages_.receiver.Next());
}
if (push_server_to_client_messages_.has_value()) {
auto r = (*push_server_to_client_messages_)();
if (r.ready()) push_server_to_client_messages_.reset();
}
{
auto r = (*next_server_to_client_messages_)();
if (auto* p = r.value_if_ready()) {
if (p->has_value()) {
events().ForwardedMessageServerToClient(call_, *p->value());
}
next_server_to_client_messages_.reset();
Activity::current()->ForceImmediateRepoll();
}
}
if (!push_server_to_client_messages_.has_value() &&
!forward_server_to_client_messages_.empty()) {
push_server_to_client_messages_.emplace(
server_to_client_messages_sender_->Push(
std::move(forward_server_to_client_messages_.front())));
forward_server_to_client_messages_.pop();
Activity::current()->ForceImmediateRepoll();
}
}
if (client_to_server_messages_receiver_ != nullptr) {
if (!next_client_to_server_messages_.has_value()) {
next_client_to_server_messages_.emplace(
client_to_server_messages_receiver_->Next());
}
if (push_client_to_server_messages_.has_value()) {
auto r = (*push_client_to_server_messages_)();
if (r.ready()) push_client_to_server_messages_.reset();
}
{
auto r = (*next_client_to_server_messages_)();
if (auto* p = r.value_if_ready()) {
if (p->has_value()) {
events().ForwardedMessageClientToServer(call_, *p->value());
}
next_client_to_server_messages_.reset();
Activity::current()->ForceImmediateRepoll();
}
}
if (!push_client_to_server_messages_.has_value() &&
!forward_client_to_server_messages_.empty()) {
push_client_to_server_messages_.emplace(
pipe_client_to_server_messages_.sender.Push(
std::move(forward_client_to_server_messages_.front())));
forward_client_to_server_messages_.pop();
Activity::current()->ForceImmediateRepoll();
}
}
auto r = (*promise_)();
if (r.pending()) return false;
promise_.reset();
events().Finished(call_, *r.value());
return true;
}
///////////////////////////////////////////////////////////////////////////////
// FilterTestBase::Call::ScopedContext
class FilterTestBase::Call::ScopedContext final
: public Activity,
public promise_detail::Context<Arena>,
public promise_detail::Context<grpc_call_context_element> {
private:
class TestWakeable final : public Wakeable {
public:
explicit TestWakeable(ScopedContext* ctx)
: tag_(ctx->DebugTag()), impl_(ctx->impl_) {}
void Wakeup(WakeupMask) override {
std::unique_ptr<TestWakeable> self(this);
auto impl = impl_.lock();
if (impl == nullptr) return;
impl->event_engine()->Run([weak_impl = impl_]() {
auto impl = weak_impl.lock();
if (impl != nullptr) impl->StepLoop();
});
}
void Drop(WakeupMask) override { delete this; }
std::string ActivityDebugTag(WakeupMask) const override { return tag_; }
private:
const std::string tag_;
const std::weak_ptr<Impl> impl_;
};
public:
explicit ScopedContext(std::shared_ptr<Impl> impl)
: promise_detail::Context<Arena>(impl->arena()),
promise_detail::Context<grpc_call_context_element>(
impl->legacy_context()),
impl_(std::move(impl)) {}
void Orphan() override { Crash("Orphan called on Call::ScopedContext"); }
void ForceImmediateRepoll(WakeupMask) override { repoll_ = true; }
Waker MakeOwningWaker() override { return Waker(new TestWakeable(this), 0); }
Waker MakeNonOwningWaker() override {
return Waker(new TestWakeable(this), 0);
}
std::string DebugTag() const override {
return absl::StrFormat("FILTER_TEST_CALL[%p]", impl_.get());
}
bool repoll() const { return repoll_; }
private:
ScopedActivity scoped_activity_{this};
const std::shared_ptr<Impl> impl_;
bool repoll_ = false;
};
void FilterTestBase::Call::Impl::StepLoop() {
for (;;) {
ScopedContext ctx(shared_from_this());
if (!StepOnce() && ctx.repoll()) continue;
return;
}
}
void FilterTestBase::Call::Impl::ForceWakeup() {
ScopedContext(shared_from_this()).MakeOwningWaker().Wakeup();
}
///////////////////////////////////////////////////////////////////////////////
// FilterTestBase::Call
FilterTestBase::Call::Call(const Channel& channel)
: impl_(std::make_unique<Impl>(this, channel.impl_)) {}
FilterTestBase::Call::~Call() { ScopedContext x(std::move(impl_)); }
ClientMetadataHandle FilterTestBase::Call::NewClientMetadata(
std::initializer_list<std::pair<absl::string_view, absl::string_view>>
init) {
auto md = impl_->arena()->MakePooled<ClientMetadata>(impl_->arena());
for (auto& p : init) {
auto parsed = ClientMetadata::Parse(
p.first, Slice::FromCopiedString(p.second),
p.first.length() + p.second.length() + 32,
[p](absl::string_view, const Slice&) {
Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ",
p.second));
});
md->Set(parsed);
}
return md;
}
ServerMetadataHandle FilterTestBase::Call::NewServerMetadata(
std::initializer_list<std::pair<absl::string_view, absl::string_view>>
init) {
auto md = impl_->arena()->MakePooled<ClientMetadata>(impl_->arena());
for (auto& p : init) {
auto parsed = ServerMetadata::Parse(
p.first, Slice::FromCopiedString(p.second),
p.first.length() + p.second.length() + 32,
[p](absl::string_view, const Slice&) {
Crash(absl::StrCat("Illegal metadata value: ", p.first, ": ",
p.second));
});
md->Set(parsed);
}
return md;
}
MessageHandle FilterTestBase::Call::NewMessage(absl::string_view payload,
uint32_t flags) {
SliceBuffer buffer;
if (!payload.empty()) buffer.Append(Slice::FromCopiedString(payload));
return impl_->arena()->MakePooled<Message>(std::move(buffer), flags);
}
void FilterTestBase::Call::Start(ClientMetadataHandle md) {
ScopedContext ctx(impl_);
impl_->Start(std::move(md));
}
void FilterTestBase::Call::Cancel() {
ScopedContext ctx(impl_);
impl_ = absl::make_unique<Impl>(this, impl_->channel());
}
void FilterTestBase::Call::ForwardServerInitialMetadata(
ServerMetadataHandle md) {
impl_->ForwardServerInitialMetadata(std::move(md));
}
void FilterTestBase::Call::ForwardMessageClientToServer(MessageHandle msg) {
impl_->ForwardMessageClientToServer(std::move(msg));
}
void FilterTestBase::Call::ForwardMessageServerToClient(MessageHandle msg) {
impl_->ForwardMessageServerToClient(std::move(msg));
}
void FilterTestBase::Call::FinishNextFilter(ServerMetadataHandle md) {
impl_->FinishNextFilter(std::move(md));
}
///////////////////////////////////////////////////////////////////////////////
// FilterTestBase
FilterTestBase::FilterTestBase()
: event_engine_(
[]() {
grpc_timer_manager_set_threading(false);
grpc_event_engine::experimental::FuzzingEventEngine::Options
options;
options.final_tick_length = std::chrono::milliseconds(1);
return options;
}(),
fuzzing_event_engine::Actions()) {}
FilterTestBase::~FilterTestBase() { event_engine_.UnsetGlobalHooks(); }
void FilterTestBase::Step() {
event_engine_.TickUntilIdle();
::testing::Mock::VerifyAndClearExpectations(&events);
}
} // namespace grpc_core

@ -0,0 +1,225 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_TEST_CORE_FILTERS_FILTER_TEST_H
#define GRPC_TEST_CORE_FILTERS_FILTER_TEST_H
#include <stddef.h>
#include <stdint.h>
#include <initializer_list>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/escaping.h"
#include "absl/strings/string_view.h"
#include "gmock/gmock.h"
#include <grpc/event_engine/event_engine.h>
#include <grpc/event_engine/memory_allocator.h>
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/resource_quota/memory_quota.h"
#include "src/core/lib/resource_quota/resource_quota.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
#include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h"
#include "test/core/filters/filter_test.h"
// gmock matcher to ensure that metadata has a key/value pair.
MATCHER_P2(HasMetadataKeyValue, key, value, "") {
std::string temp;
auto r = arg.GetStringValue(key, &temp);
return r == value;
}
// gmock matcher to ensure that a message has a given set of flags.
MATCHER_P(HasMessageFlags, value, "") { return arg.flags() == value; }
MATCHER_P(HasMetadataResult, absl_status, "") {
auto status = arg.get(grpc_core::GrpcStatusMetadata());
if (!status.has_value()) return false;
if (static_cast<absl::StatusCode>(status.value()) != absl_status.code()) {
return false;
}
auto* message = arg.get_pointer(grpc_core::GrpcMessageMetadata());
if (message == nullptr) return absl_status.message().empty();
return message->as_string_view() == absl_status.message();
}
// gmock matcher to ensure that a message has a given payload.
MATCHER_P(HasMessagePayload, value, "") {
return arg.payload()->JoinIntoString() == value;
}
namespace grpc_core {
inline std::ostream& operator<<(std::ostream& os,
const grpc_metadata_batch& md) {
return os << md.DebugString();
}
inline std::ostream& operator<<(std::ostream& os, const Message& msg) {
return os << "flags:" << msg.flags()
<< " payload:" << absl::CEscape(msg.payload()->JoinIntoString());
}
class FilterTestBase : public ::testing::Test {
public:
class Call;
class Channel {
private:
struct Impl {
Impl(std::unique_ptr<ChannelFilter> filter, FilterTestBase* test)
: filter(std::move(filter)), test(test) {}
size_t initial_arena_size = 1024;
MemoryAllocator memory_allocator =
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator(
"test");
std::unique_ptr<ChannelFilter> filter;
FilterTestBase* const test;
};
public:
void set_initial_arena_size(size_t size) {
impl_->initial_arena_size = size;
}
Call MakeCall();
private:
friend class FilterTestBase;
friend class Call;
explicit Channel(std::unique_ptr<ChannelFilter> filter,
FilterTestBase* test)
: impl_(std::make_shared<Impl>(std::move(filter), test)) {}
std::shared_ptr<Impl> impl_;
};
// One "call" outstanding against this filter.
// In reality - this filter is the only thing in the call.
// Provides mocks to trap events that happen on the call.
class Call {
public:
explicit Call(const Channel& channel);
Call(const Call&) = delete;
Call& operator=(const Call&) = delete;
~Call();
// Construct client metadata in the arena of this call.
// Optional argument is a list of key/value pairs to add to the metadata.
ClientMetadataHandle NewClientMetadata(
std::initializer_list<std::pair<absl::string_view, absl::string_view>>
init = {});
// Construct server metadata in the arena of this call.
// Optional argument is a list of key/value pairs to add to the metadata.
ServerMetadataHandle NewServerMetadata(
std::initializer_list<std::pair<absl::string_view, absl::string_view>>
init = {});
// Construct a message in the arena of this call.
MessageHandle NewMessage(absl::string_view payload = "",
uint32_t flags = 0);
// Start the call.
void Start(ClientMetadataHandle md);
// Cancel the call.
void Cancel();
// Forward server initial metadata through this filter.
void ForwardServerInitialMetadata(ServerMetadataHandle md);
// Forward a message from client to server through this filter.
void ForwardMessageClientToServer(MessageHandle msg);
// Forward a message from server to client through this filter.
void ForwardMessageServerToClient(MessageHandle msg);
// Have the 'next' filter in the chain finish this call and return trailing
// metadata.
void FinishNextFilter(ServerMetadataHandle md);
private:
friend class Channel;
class ScopedContext;
class Impl;
std::shared_ptr<Impl> impl_;
};
struct Events {
// Mock to trap starting the next filter in the chain.
MOCK_METHOD(void, Started,
(Call * call, const ClientMetadata& client_initial_metadata));
// Mock to trap receiving server initial metadata in the next filter in the
// chain.
MOCK_METHOD(void, ForwardedServerInitialMetadata,
(Call * call, const ServerMetadata& server_initial_metadata));
// Mock to trap seeing a message forward from client to server.
MOCK_METHOD(void, ForwardedMessageClientToServer,
(Call * call, const Message& msg));
// Mock to trap seeing a message forward from server to client.
MOCK_METHOD(void, ForwardedMessageServerToClient,
(Call * call, const Message& msg));
// Mock to trap seeing a call finish in the next filter in the chain.
MOCK_METHOD(void, Finished,
(Call * call, const ServerMetadata& server_trailing_metadata));
};
::testing::StrictMock<Events> events;
protected:
FilterTestBase();
~FilterTestBase() override;
absl::StatusOr<Channel> MakeChannel(std::unique_ptr<ChannelFilter> filter) {
return Channel(std::move(filter), this);
}
grpc_event_engine::experimental::EventEngine* event_engine() {
return &event_engine_;
}
void Step();
private:
grpc_event_engine::experimental::FuzzingEventEngine event_engine_;
};
template <typename Filter>
class FilterTest : public FilterTestBase {
public:
absl::StatusOr<Channel> MakeChannel(const ChannelArgs& args) {
auto filter = Filter::Create(args, ChannelFilter::Args());
if (!filter.ok()) return filter.status();
return FilterTestBase::MakeChannel(
std::make_unique<Filter>(std::move(*filter)));
}
};
} // namespace grpc_core
// Expect one of the events corresponding to the methods in FilterTest::Events.
#define EXPECT_EVENT(event) EXPECT_CALL(events, event)
#endif // GRPC_TEST_CORE_FILTERS_FILTER_TEST_H

@ -0,0 +1,253 @@
// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/core/filters/filter_test.h"
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <grpc/compression.h>
#include <grpc/grpc.h>
#include "src/core/lib/channel/promise_based_filter.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
using ::testing::_;
namespace grpc_core {
namespace {
class NoOpFilter final : public ChannelFilter {
public:
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs args, NextPromiseFactory next) override {
return next(std::move(args));
}
static absl::StatusOr<NoOpFilter> Create(const ChannelArgs&,
ChannelFilter::Args) {
return NoOpFilter();
}
};
using NoOpFilterTest = FilterTest<NoOpFilter>;
class DelayStartFilter final : public ChannelFilter {
public:
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs args, NextPromiseFactory next) override {
return Seq(
[args = std::move(args), i = 10]() mutable -> Poll<CallArgs> {
--i;
if (i == 0) return std::move(args);
Activity::current()->ForceImmediateRepoll();
return Pending{};
},
next);
}
static absl::StatusOr<DelayStartFilter> Create(const ChannelArgs&,
ChannelFilter::Args) {
return DelayStartFilter();
}
};
using DelayStartFilterTest = FilterTest<DelayStartFilter>;
class AddClientInitialMetadataFilter final : public ChannelFilter {
public:
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs args, NextPromiseFactory next) override {
args.client_initial_metadata->Set(HttpPathMetadata(),
Slice::FromCopiedString("foo.bar"));
return next(std::move(args));
}
static absl::StatusOr<AddClientInitialMetadataFilter> Create(
const ChannelArgs&, ChannelFilter::Args) {
return AddClientInitialMetadataFilter();
}
};
using AddClientInitialMetadataFilterTest =
FilterTest<AddClientInitialMetadataFilter>;
class AddServerTrailingMetadataFilter final : public ChannelFilter {
public:
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs args, NextPromiseFactory next) override {
return Map(next(std::move(args)), [](ServerMetadataHandle handle) {
handle->Set(HttpStatusMetadata(), 420);
return handle;
});
}
static absl::StatusOr<AddServerTrailingMetadataFilter> Create(
const ChannelArgs&, ChannelFilter::Args) {
return AddServerTrailingMetadataFilter();
}
};
using AddServerTrailingMetadataFilterTest =
FilterTest<AddServerTrailingMetadataFilter>;
class AddServerInitialMetadataFilter final : public ChannelFilter {
public:
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs args, NextPromiseFactory next) override {
args.server_initial_metadata->InterceptAndMap([](ServerMetadataHandle md) {
md->Set(GrpcEncodingMetadata(), GRPC_COMPRESS_GZIP);
return md;
});
return next(std::move(args));
}
static absl::StatusOr<AddServerInitialMetadataFilter> Create(
const ChannelArgs&, ChannelFilter::Args) {
return AddServerInitialMetadataFilter();
}
};
using AddServerInitialMetadataFilterTest =
FilterTest<AddServerInitialMetadataFilter>;
TEST_F(NoOpFilterTest, NoOp) {}
TEST_F(NoOpFilterTest, MakeCall) {
Call call(MakeChannel(ChannelArgs()).value());
}
TEST_F(NoOpFilterTest, MakeClientMetadata) {
Call call(MakeChannel(ChannelArgs()).value());
auto md = call.NewClientMetadata({{":path", "foo.bar"}});
EXPECT_EQ(md->get_pointer(HttpPathMetadata())->as_string_view(), "foo.bar");
}
TEST_F(NoOpFilterTest, MakeServerMetadata) {
Call call(MakeChannel(ChannelArgs()).value());
auto md = call.NewServerMetadata({{":status", "200"}});
EXPECT_EQ(md->get(HttpStatusMetadata()), HttpStatusMetadata::ValueType(200));
}
TEST_F(NoOpFilterTest, CanStart) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
Step();
}
TEST_F(DelayStartFilterTest, CanStartWithDelay) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
Step();
}
TEST_F(NoOpFilterTest, CanCancel) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.Cancel();
}
TEST_F(DelayStartFilterTest, CanCancelWithDelay) {
Call call(MakeChannel(ChannelArgs()).value());
call.Start(call.NewClientMetadata());
call.Cancel();
}
TEST_F(AddClientInitialMetadataFilterTest, CanSetClientInitialMetadata) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, HasMetadataKeyValue(":path", "foo.bar")));
call.Start(call.NewClientMetadata());
Step();
}
TEST_F(NoOpFilterTest, CanFinish) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.FinishNextFilter(call.NewServerMetadata());
EXPECT_EVENT(Finished(&call, _));
Step();
}
TEST_F(AddServerTrailingMetadataFilterTest, CanSetServerTrailingMetadata) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.FinishNextFilter(call.NewServerMetadata());
EXPECT_EVENT(Finished(&call, HasMetadataKeyValue(":status", "420")));
Step();
}
TEST_F(NoOpFilterTest, CanProcessServerInitialMetadata) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.ForwardServerInitialMetadata(call.NewServerMetadata());
EXPECT_EVENT(ForwardedServerInitialMetadata(&call, _));
Step();
}
TEST_F(AddServerInitialMetadataFilterTest, CanSetServerInitialMetadata) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.ForwardServerInitialMetadata(call.NewServerMetadata());
EXPECT_EVENT(ForwardedServerInitialMetadata(
&call, HasMetadataKeyValue("grpc-encoding", "gzip")));
Step();
}
TEST_F(NoOpFilterTest, CanProcessClientToServerMessage) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.ForwardMessageClientToServer(call.NewMessage("abc"));
EXPECT_CALL(events,
ForwardedMessageClientToServer(&call, HasMessagePayload("abc")));
Step();
}
TEST_F(NoOpFilterTest, CanProcessServerToClientMessage) {
Call call(MakeChannel(ChannelArgs()).value());
EXPECT_EVENT(Started(&call, _));
call.Start(call.NewClientMetadata());
call.ForwardServerInitialMetadata(call.NewServerMetadata());
call.ForwardMessageServerToClient(call.NewMessage("abc"));
EXPECT_EVENT(ForwardedServerInitialMetadata(&call, _));
EXPECT_CALL(events,
ForwardedMessageServerToClient(&call, HasMessagePayload("abc")));
Step();
}
} // namespace
} // namespace grpc_core
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
grpc_init();
int r = RUN_ALL_TESTS();
grpc_shutdown();
return r;
}

@ -446,6 +446,7 @@ for dirname in [
"test/core/util",
"test/core/end2end",
"test/core/event_engine",
"test/core/filters",
"test/core/promise",
"test/core/resource_quota",
"test/core/transport/chaotic_good",

@ -3035,6 +3035,30 @@
],
"uses_polling": true
},
{
"args": [],
"benchmark": false,
"ci_platforms": [
"linux",
"mac",
"posix",
"windows"
],
"cpu_cost": 1.0,
"exclude_configs": [],
"exclude_iomgrs": [],
"flaky": false,
"gtest": true,
"language": "c++",
"name": "filter_test_test",
"platforms": [
"linux",
"mac",
"posix",
"windows"
],
"uses_polling": false
},
{
"args": [],
"benchmark": false,

Loading…
Cancel
Save