commit
c160ab231a
75 changed files with 2762 additions and 732 deletions
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,156 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/core/lib/transport/interception_chain.h" |
||||
|
||||
#include <cstddef> |
||||
|
||||
#include <grpc/support/port_platform.h> |
||||
|
||||
#include "src/core/lib/gprpp/match.h" |
||||
#include "src/core/lib/transport/call_destination.h" |
||||
#include "src/core/lib/transport/call_filters.h" |
||||
#include "src/core/lib/transport/call_spine.h" |
||||
#include "src/core/lib/transport/metadata.h" |
||||
|
||||
namespace grpc_core { |
||||
|
||||
std::atomic<size_t> InterceptionChainBuilder::next_filter_id_{0}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// HijackedCall
|
||||
|
||||
CallInitiator HijackedCall::MakeCall() { |
||||
auto metadata = Arena::MakePooled<ClientMetadata>(); |
||||
*metadata = metadata_->Copy(); |
||||
return MakeCallWithMetadata(std::move(metadata)); |
||||
} |
||||
|
||||
CallInitiator HijackedCall::MakeCallWithMetadata( |
||||
ClientMetadataHandle metadata) { |
||||
auto call = MakeCallPair(std::move(metadata), call_handler_.event_engine(), |
||||
call_handler_.arena(), nullptr, |
||||
call_handler_.legacy_context()); |
||||
destination_->StartCall(std::move(call.handler)); |
||||
return std::move(call.initiator); |
||||
} |
||||
|
||||
namespace { |
||||
class CallStarter final : public UnstartedCallDestination { |
||||
public: |
||||
CallStarter(RefCountedPtr<CallFilters::Stack> stack, |
||||
RefCountedPtr<CallDestination> destination) |
||||
: stack_(std::move(stack)), destination_(std::move(destination)) {} |
||||
|
||||
void Orphaned() override { |
||||
stack_.reset(); |
||||
destination_.reset(); |
||||
} |
||||
|
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
destination_->HandleCall(unstarted_call_handler.StartCall(stack_)); |
||||
} |
||||
|
||||
private: |
||||
RefCountedPtr<CallFilters::Stack> stack_; |
||||
RefCountedPtr<CallDestination> destination_; |
||||
}; |
||||
|
||||
class TerminalInterceptor final : public UnstartedCallDestination { |
||||
public: |
||||
explicit TerminalInterceptor( |
||||
RefCountedPtr<CallFilters::Stack> stack, |
||||
RefCountedPtr<UnstartedCallDestination> destination) |
||||
: stack_(std::move(stack)), destination_(std::move(destination)) {} |
||||
|
||||
void Orphaned() override { |
||||
stack_.reset(); |
||||
destination_.reset(); |
||||
} |
||||
|
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
unstarted_call_handler.SpawnGuarded( |
||||
"start_call", |
||||
Map(interception_chain_detail::HijackCall(unstarted_call_handler, |
||||
destination_, stack_), |
||||
[](ValueOrFailure<HijackedCall> hijacked_call) -> StatusFlag { |
||||
if (!hijacked_call.ok()) return Failure{}; |
||||
ForwardCall(hijacked_call.value().original_call_handler(), |
||||
hijacked_call.value().MakeLastCall()); |
||||
return Success{}; |
||||
})); |
||||
} |
||||
|
||||
private: |
||||
RefCountedPtr<CallFilters::Stack> stack_; |
||||
RefCountedPtr<UnstartedCallDestination> destination_; |
||||
}; |
||||
} // namespace
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// InterceptionChain::Builder
|
||||
|
||||
void InterceptionChainBuilder::AddInterceptor( |
||||
absl::StatusOr<RefCountedPtr<Interceptor>> interceptor) { |
||||
if (!status_.ok()) return; |
||||
if (!interceptor.ok()) { |
||||
status_ = interceptor.status(); |
||||
return; |
||||
} |
||||
(*interceptor)->filter_stack_ = MakeFilterStack(); |
||||
if (top_interceptor_ == nullptr) { |
||||
top_interceptor_ = std::move(*interceptor); |
||||
} else { |
||||
Interceptor* previous = top_interceptor_.get(); |
||||
while (previous->wrapped_destination_ != nullptr) { |
||||
previous = DownCast<Interceptor*>(previous->wrapped_destination_.get()); |
||||
} |
||||
previous->wrapped_destination_ = std::move(*interceptor); |
||||
} |
||||
} |
||||
|
||||
absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> |
||||
InterceptionChainBuilder::Build(FinalDestination final_destination) { |
||||
if (!status_.ok()) return status_; |
||||
// Build the final UnstartedCallDestination in the chain - what we do here
|
||||
// depends on both the type of the final destination and the filters we have
|
||||
// that haven't been captured into an Interceptor yet.
|
||||
RefCountedPtr<UnstartedCallDestination> terminator = Match( |
||||
final_destination, |
||||
[this](RefCountedPtr<UnstartedCallDestination> final_destination) |
||||
-> RefCountedPtr<UnstartedCallDestination> { |
||||
if (stack_builder_.has_value()) { |
||||
return MakeRefCounted<TerminalInterceptor>(MakeFilterStack(), |
||||
final_destination); |
||||
} |
||||
return final_destination; |
||||
}, |
||||
[this](RefCountedPtr<CallDestination> final_destination) |
||||
-> RefCountedPtr<UnstartedCallDestination> { |
||||
return MakeRefCounted<CallStarter>(MakeFilterStack(), |
||||
std::move(final_destination)); |
||||
}); |
||||
// Now append the terminator to the interceptor chain.
|
||||
if (top_interceptor_ == nullptr) { |
||||
return std::move(terminator); |
||||
} |
||||
Interceptor* previous = top_interceptor_.get(); |
||||
while (previous->wrapped_destination_ != nullptr) { |
||||
previous = DownCast<Interceptor*>(previous->wrapped_destination_.get()); |
||||
} |
||||
previous->wrapped_destination_ = std::move(terminator); |
||||
return std::move(top_interceptor_); |
||||
} |
||||
|
||||
} // namespace grpc_core
|
@ -0,0 +1,225 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H |
||||
#define GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H |
||||
|
||||
#include <memory> |
||||
#include <vector> |
||||
|
||||
#include <grpc/support/port_platform.h> |
||||
|
||||
#include "src/core/lib/gprpp/ref_counted.h" |
||||
#include "src/core/lib/transport/call_destination.h" |
||||
#include "src/core/lib/transport/call_filters.h" |
||||
#include "src/core/lib/transport/call_spine.h" |
||||
#include "src/core/lib/transport/metadata.h" |
||||
|
||||
namespace grpc_core { |
||||
|
||||
class InterceptionChainBuilder; |
||||
|
||||
// One hijacked call. Using this we can get access to the CallHandler for the
|
||||
// call object above us, the processed metadata from any filters/interceptors
|
||||
// above us, and also create new CallInterceptor objects that will be handled
|
||||
// below.
|
||||
class HijackedCall final { |
||||
public: |
||||
HijackedCall(ClientMetadataHandle metadata, |
||||
RefCountedPtr<UnstartedCallDestination> destination, |
||||
CallHandler call_handler) |
||||
: metadata_(std::move(metadata)), |
||||
destination_(std::move(destination)), |
||||
call_handler_(std::move(call_handler)) {} |
||||
|
||||
// Create a new call and pass it down the stack.
|
||||
// This can be called as many times as needed.
|
||||
CallInitiator MakeCall(); |
||||
// Per MakeCall(), but precludes creating further calls.
|
||||
// Allows us to optimize by not copying initial metadata.
|
||||
CallInitiator MakeLastCall() { |
||||
return MakeCallWithMetadata(std::move(metadata_)); |
||||
} |
||||
|
||||
CallHandler& original_call_handler() { return call_handler_; } |
||||
|
||||
ClientMetadata& client_metadata() { return *metadata_; } |
||||
|
||||
private: |
||||
CallInitiator MakeCallWithMetadata(ClientMetadataHandle metadata); |
||||
|
||||
ClientMetadataHandle metadata_; |
||||
RefCountedPtr<UnstartedCallDestination> destination_; |
||||
CallHandler call_handler_; |
||||
}; |
||||
|
||||
namespace interception_chain_detail { |
||||
|
||||
inline auto HijackCall(UnstartedCallHandler unstarted_call_handler, |
||||
RefCountedPtr<UnstartedCallDestination> destination, |
||||
RefCountedPtr<CallFilters::Stack> stack) { |
||||
auto call_handler = unstarted_call_handler.StartCall(stack); |
||||
return Map( |
||||
call_handler.PullClientInitialMetadata(), |
||||
[call_handler, |
||||
destination](ValueOrFailure<ClientMetadataHandle> metadata) mutable |
||||
-> ValueOrFailure<HijackedCall> { |
||||
if (!metadata.ok()) return Failure{}; |
||||
return HijackedCall(std::move(metadata.value()), std::move(destination), |
||||
std::move(call_handler)); |
||||
}); |
||||
} |
||||
|
||||
} // namespace interception_chain_detail
|
||||
|
||||
// A delegating UnstartedCallDestination for use as a hijacking filter.
|
||||
// Implementations may look at the unprocessed initial metadata
|
||||
// and decide to do one of two things:
|
||||
//
|
||||
// 1. It can hijack the call. Returns a HijackedCall object that can
|
||||
// be used to start new calls with the same metadata.
|
||||
//
|
||||
// 2. It can consume the call by calling `Consume`.
|
||||
//
|
||||
// Upon the StartCall call the UnstartedCallHandler will be from the last
|
||||
// *Interceptor* in the call chain (without having been processed by any
|
||||
// intervening filters) -- note that this is commonly not useful (not enough
|
||||
// guarantees), and so it's usually better to Hijack and examine the metadata.
|
||||
class Interceptor : public UnstartedCallDestination { |
||||
protected: |
||||
// Returns a promise that resolves to a HijackedCall instance.
|
||||
// Hijacking is the process of taking over a call and starting one or more new
|
||||
// ones.
|
||||
auto Hijack(UnstartedCallHandler unstarted_call_handler) { |
||||
return interception_chain_detail::HijackCall( |
||||
std::move(unstarted_call_handler), wrapped_destination_, filter_stack_); |
||||
} |
||||
|
||||
// Consume this call - it will not be passed on to any further filters.
|
||||
CallHandler Consume(UnstartedCallHandler unstarted_call_handler) { |
||||
return unstarted_call_handler.StartCall(filter_stack_); |
||||
} |
||||
|
||||
// TODO(ctiller): Consider a Passthrough() method that allows the call to be
|
||||
// passed on to the next filter in the chain without any interception by the
|
||||
// current filter.
|
||||
|
||||
private: |
||||
friend class InterceptionChainBuilder; |
||||
|
||||
RefCountedPtr<UnstartedCallDestination> wrapped_destination_; |
||||
RefCountedPtr<CallFilters::Stack> filter_stack_; |
||||
}; |
||||
|
||||
class InterceptionChainBuilder final { |
||||
public: |
||||
// The kind of destination that the chain will eventually call.
|
||||
// We can bottom out in various types depending on where we're intercepting:
|
||||
// - The top half of the client channel wants to terminate on a
|
||||
// UnstartedCallDestination (specifically the LB call destination).
|
||||
// - The bottom half of the client channel and the server code wants to
|
||||
// terminate on a ClientTransport - which unlike a
|
||||
// UnstartedCallDestination demands a started CallHandler.
|
||||
// There's some adaption code that's needed to start filters just prior
|
||||
// to the bottoming out, and some design considerations to make with that.
|
||||
// One way (that's not chosen here) would be to have the caller of the
|
||||
// Builder provide something that can build an adaptor
|
||||
// UnstartedCallDestination with parameters supplied by this builder - that
|
||||
// disperses the responsibility of building the adaptor to the caller, which
|
||||
// is not ideal - we might want to adjust the way this construct is built in
|
||||
// the future, and building is a builder responsibility.
|
||||
// Instead, we declare a relatively closed set of destinations here, and
|
||||
// hide the adaptors inside the builder at build time.
|
||||
using FinalDestination = |
||||
absl::variant<RefCountedPtr<UnstartedCallDestination>, |
||||
RefCountedPtr<CallDestination>>; |
||||
|
||||
explicit InterceptionChainBuilder(ChannelArgs args) |
||||
: args_(std::move(args)) {} |
||||
|
||||
// Add a filter with a `Call` class as an inner member.
|
||||
// Call class must be one compatible with the filters described in
|
||||
// call_filters.h.
|
||||
template <typename T> |
||||
absl::enable_if_t<sizeof(typename T::Call) != 0, InterceptionChainBuilder&> |
||||
Add() { |
||||
if (!status_.ok()) return *this; |
||||
auto filter = T::Create(args_, {FilterInstanceId(FilterTypeId<T>())}); |
||||
if (!filter.ok()) { |
||||
status_ = filter.status(); |
||||
return *this; |
||||
} |
||||
auto& sb = stack_builder(); |
||||
sb.Add(filter.value().get()); |
||||
sb.AddOwnedObject(std::move(filter.value())); |
||||
return *this; |
||||
}; |
||||
|
||||
// Add a filter that is an interceptor - one that can hijack calls.
|
||||
template <typename T> |
||||
absl::enable_if_t<std::is_base_of<Interceptor, T>::value, |
||||
InterceptionChainBuilder&> |
||||
Add() { |
||||
AddInterceptor(T::Create(args_, {FilterInstanceId(FilterTypeId<T>())})); |
||||
return *this; |
||||
}; |
||||
|
||||
// Add a filter that just mutates server trailing metadata.
|
||||
template <typename F> |
||||
void AddOnServerTrailingMetadata(F f) { |
||||
stack_builder().AddOnServerTrailingMetadata(std::move(f)); |
||||
} |
||||
|
||||
// Build this stack
|
||||
absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> Build( |
||||
FinalDestination final_destination); |
||||
|
||||
const ChannelArgs& channel_args() const { return args_; } |
||||
|
||||
private: |
||||
CallFilters::StackBuilder& stack_builder() { |
||||
if (!stack_builder_.has_value()) stack_builder_.emplace(); |
||||
return *stack_builder_; |
||||
} |
||||
|
||||
RefCountedPtr<CallFilters::Stack> MakeFilterStack() { |
||||
auto stack = stack_builder().Build(); |
||||
stack_builder_.reset(); |
||||
return stack; |
||||
} |
||||
|
||||
template <typename T> |
||||
static size_t FilterTypeId() { |
||||
static const size_t id = |
||||
next_filter_id_.fetch_add(1, std::memory_order_relaxed); |
||||
return id; |
||||
} |
||||
|
||||
size_t FilterInstanceId(size_t filter_type) { |
||||
return filter_type_counts_[filter_type]++; |
||||
} |
||||
|
||||
void AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor); |
||||
|
||||
ChannelArgs args_; |
||||
absl::optional<CallFilters::StackBuilder> stack_builder_; |
||||
RefCountedPtr<Interceptor> top_interceptor_; |
||||
absl::Status status_; |
||||
std::map<size_t, size_t> filter_type_counts_; |
||||
static std::atomic<size_t> next_filter_id_; |
||||
}; |
||||
|
||||
} // namespace grpc_core
|
||||
|
||||
#endif // GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H
|
@ -0,0 +1,60 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H |
||||
#define GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H |
||||
|
||||
#include "gmock/gmock.h" |
||||
|
||||
// Various gmock matchers for Poll
|
||||
|
||||
namespace grpc_core { |
||||
|
||||
// Expect that a promise is still pending:
|
||||
// EXPECT_THAT(some_promise(), IsPending());
|
||||
MATCHER(IsPending, "") { |
||||
if (arg.ready()) { |
||||
*result_listener << "is ready"; |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
// Expect that a promise is ready:
|
||||
// EXPECT_THAT(some_promise(), IsReady());
|
||||
MATCHER(IsReady, "") { |
||||
if (arg.pending()) { |
||||
*result_listener << "is pending"; |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
// Expect that a promise is ready with a specific value:
|
||||
// EXPECT_THAT(some_promise(), IsReady(value));
|
||||
MATCHER_P(IsReady, value, "") { |
||||
if (arg.pending()) { |
||||
*result_listener << "is pending"; |
||||
return false; |
||||
} |
||||
if (arg.value() != value) { |
||||
*result_listener << "is " << ::testing::PrintToString(arg.value()); |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
} // namespace grpc_core
|
||||
|
||||
#endif // GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H
|
@ -0,0 +1,406 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/core/lib/transport/interception_chain.h" |
||||
|
||||
#include <memory> |
||||
|
||||
#include "gmock/gmock.h" |
||||
#include "gtest/gtest.h" |
||||
|
||||
#include <grpc/support/log.h> |
||||
|
||||
#include "src/core/lib/channel/promise_based_filter.h" |
||||
#include "src/core/lib/resource_quota/resource_quota.h" |
||||
#include "test/core/promise/poll_matcher.h" |
||||
|
||||
namespace grpc_core { |
||||
namespace { |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Mutate metadata by annotating that it passed through a filter "x"
|
||||
|
||||
void AnnotatePassedThrough(ClientMetadata& md, int x) { |
||||
md.Append(absl::StrCat("passed-through-", x), Slice::FromCopiedString("true"), |
||||
[](absl::string_view, const Slice&) { Crash("unreachable"); }); |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CreationLog helps us reason about filter creation order by logging a small
|
||||
// record of each filter's creation.
|
||||
|
||||
struct CreationLogEntry { |
||||
size_t filter_instance_id; |
||||
size_t type_tag; |
||||
|
||||
bool operator==(const CreationLogEntry& other) const { |
||||
return filter_instance_id == other.filter_instance_id && |
||||
type_tag == other.type_tag; |
||||
} |
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, |
||||
const CreationLogEntry& entry) { |
||||
return os << "{filter_instance_id=" << entry.filter_instance_id |
||||
<< ", type_tag=" << entry.type_tag << "}"; |
||||
} |
||||
}; |
||||
|
||||
struct CreationLog { |
||||
struct RawPointerChannelArgTag {}; |
||||
static absl::string_view ChannelArgName() { return "creation_log"; } |
||||
std::vector<CreationLogEntry> entries; |
||||
}; |
||||
|
||||
void MaybeLogCreation(const ChannelArgs& channel_args, |
||||
ChannelFilter::Args filter_args, size_t type_tag) { |
||||
auto* log = channel_args.GetPointer<CreationLog>("creation_log"); |
||||
if (log == nullptr) return; |
||||
log->entries.push_back(CreationLogEntry{filter_args.instance_id(), type_tag}); |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call filter
|
||||
|
||||
template <int I> |
||||
class TestFilter { |
||||
public: |
||||
class Call { |
||||
public: |
||||
void OnClientInitialMetadata(ClientMetadata& md) { |
||||
AnnotatePassedThrough(md, I); |
||||
} |
||||
static const NoInterceptor OnServerInitialMetadata; |
||||
static const NoInterceptor OnClientToServerMessage; |
||||
static const NoInterceptor OnServerToClientMessage; |
||||
static const NoInterceptor OnServerTrailingMetadata; |
||||
static const NoInterceptor OnFinalize; |
||||
}; |
||||
|
||||
static absl::StatusOr<std::unique_ptr<TestFilter<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return std::make_unique<TestFilter<I>>(); |
||||
} |
||||
|
||||
private: |
||||
std::unique_ptr<int> i_ = std::make_unique<int>(I); |
||||
}; |
||||
|
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnServerInitialMetadata; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnClientToServerMessage; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnServerToClientMessage; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnServerTrailingMetadata; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnFinalize; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call filter that fails to instantiate
|
||||
|
||||
template <int I> |
||||
class FailsToInstantiateFilter { |
||||
public: |
||||
class Call { |
||||
public: |
||||
static const NoInterceptor OnClientInitialMetadata; |
||||
static const NoInterceptor OnServerInitialMetadata; |
||||
static const NoInterceptor OnClientToServerMessage; |
||||
static const NoInterceptor OnServerToClientMessage; |
||||
static const NoInterceptor OnServerTrailingMetadata; |
||||
static const NoInterceptor OnFinalize; |
||||
}; |
||||
|
||||
static absl::StatusOr<std::unique_ptr<FailsToInstantiateFilter<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return absl::InternalError(absl::StrCat("👊 failed to instantiate ", I)); |
||||
} |
||||
}; |
||||
|
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientInitialMetadata; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerInitialMetadata; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientToServerMessage; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerToClientMessage; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerTrailingMetadata; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnFinalize; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call interceptor - consumes calls
|
||||
|
||||
template <int I> |
||||
class TestConsumingInterceptor final : public Interceptor { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
Consume(std::move(unstarted_call_handler)) |
||||
.PushServerTrailingMetadata( |
||||
ServerMetadataFromStatus(absl::InternalError("👊 consumed"))); |
||||
} |
||||
void Orphaned() override {} |
||||
static absl::StatusOr<RefCountedPtr<TestConsumingInterceptor<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return MakeRefCounted<TestConsumingInterceptor<I>>(); |
||||
} |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call interceptor - fails to instantiate
|
||||
|
||||
template <int I> |
||||
class TestFailingInterceptor final : public Interceptor { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
Crash("unreachable"); |
||||
} |
||||
void Orphaned() override {} |
||||
static absl::StatusOr<RefCountedPtr<TestFailingInterceptor<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return absl::InternalError(absl::StrCat("👊 failed to instantiate ", I)); |
||||
} |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call interceptor - hijacks calls
|
||||
|
||||
template <int I> |
||||
class TestHijackingInterceptor final : public Interceptor { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
unstarted_call_handler.SpawnInfallible( |
||||
"hijack", [this, unstarted_call_handler]() mutable { |
||||
return Map(Hijack(std::move(unstarted_call_handler)), |
||||
[](ValueOrFailure<HijackedCall> hijacked_call) { |
||||
ForwardCall( |
||||
hijacked_call.value().original_call_handler(), |
||||
hijacked_call.value().MakeCall()); |
||||
return Empty{}; |
||||
}); |
||||
}); |
||||
} |
||||
void Orphaned() override {} |
||||
static absl::StatusOr<RefCountedPtr<TestHijackingInterceptor<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return MakeRefCounted<TestHijackingInterceptor<I>>(); |
||||
} |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test fixture
|
||||
|
||||
class InterceptionChainTest : public ::testing::Test { |
||||
protected: |
||||
InterceptionChainTest() {} |
||||
~InterceptionChainTest() override {} |
||||
|
||||
RefCountedPtr<UnstartedCallDestination> destination() { return destination_; } |
||||
|
||||
struct FinishedCall { |
||||
CallInitiator call; |
||||
ClientMetadataHandle client_metadata; |
||||
ServerMetadataHandle server_metadata; |
||||
}; |
||||
|
||||
// Run a call through a UnstartedCallDestination until it's complete.
|
||||
FinishedCall RunCall(UnstartedCallDestination* destination) { |
||||
auto* arena = call_arena_allocator_->MakeArena(); |
||||
auto call = MakeCallPair(Arena::MakePooled<ClientMetadata>(), nullptr, |
||||
arena, call_arena_allocator_, nullptr); |
||||
Poll<ServerMetadataHandle> trailing_md; |
||||
call.initiator.SpawnInfallible( |
||||
"run_call", [destination, &call, &trailing_md]() mutable { |
||||
gpr_log(GPR_INFO, "👊 start call"); |
||||
destination->StartCall(std::move(call.handler)); |
||||
return Map(call.initiator.PullServerTrailingMetadata(), |
||||
[&trailing_md](ServerMetadataHandle md) { |
||||
trailing_md = std::move(md); |
||||
return Empty{}; |
||||
}); |
||||
}); |
||||
EXPECT_THAT(trailing_md, IsReady()); |
||||
return FinishedCall{std::move(call.initiator), destination_->TakeMetadata(), |
||||
std::move(trailing_md.value())}; |
||||
} |
||||
|
||||
private: |
||||
class Destination final : public UnstartedCallDestination { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
gpr_log(GPR_INFO, "👊 started call: metadata=%s", |
||||
unstarted_call_handler.UnprocessedClientInitialMetadata() |
||||
.DebugString() |
||||
.c_str()); |
||||
EXPECT_EQ(metadata_.get(), nullptr); |
||||
metadata_ = Arena::MakePooled<ClientMetadata>(); |
||||
*metadata_ = |
||||
unstarted_call_handler.UnprocessedClientInitialMetadata().Copy(); |
||||
unstarted_call_handler.PushServerTrailingMetadata( |
||||
ServerMetadataFromStatus(absl::InternalError("👊 cancelled"))); |
||||
} |
||||
|
||||
void Orphaned() override {} |
||||
|
||||
ClientMetadataHandle TakeMetadata() { return std::move(metadata_); } |
||||
|
||||
private: |
||||
ClientMetadataHandle metadata_; |
||||
}; |
||||
RefCountedPtr<Destination> destination_ = MakeRefCounted<Destination>(); |
||||
RefCountedPtr<CallArenaAllocator> call_arena_allocator_ = |
||||
MakeRefCounted<CallArenaAllocator>( |
||||
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( |
||||
"test"), |
||||
1024); |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Tests begin
|
||||
|
||||
TEST_F(InterceptionChainTest, Empty) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()).Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 cancelled"); |
||||
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, Consumed) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestConsumingInterceptor<1>>() |
||||
.Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 consumed"); |
||||
EXPECT_EQ(finished_call.client_metadata, nullptr); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, Hijacked) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestHijackingInterceptor<1>>() |
||||
.Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 cancelled"); |
||||
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FiltersThenHijacked) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestHijackingInterceptor<2>>() |
||||
.Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 cancelled"); |
||||
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||
std::string backing; |
||||
EXPECT_EQ(finished_call.client_metadata->GetStringValue("passed-through-1", |
||||
&backing), |
||||
"true"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFailingInterceptor<1>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 1"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor2) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestFailingInterceptor<2>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 2"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateFilter) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<FailsToInstantiateFilter<1>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 1"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateFilter2) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFilter<1>>() |
||||
.Add<FailsToInstantiateFilter<2>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 2"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, CreationOrderCorrect) { |
||||
CreationLog log; |
||||
auto r = InterceptionChainBuilder(ChannelArgs().SetObject(&log)) |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestFilter<2>>() |
||||
.Add<TestFilter<3>>() |
||||
.Add<TestConsumingInterceptor<4>>() |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestFilter<2>>() |
||||
.Add<TestFilter<3>>() |
||||
.Add<TestConsumingInterceptor<4>>() |
||||
.Add<TestFilter<1>>() |
||||
.Build(destination()); |
||||
EXPECT_THAT(log.entries, ::testing::ElementsAre( |
||||
CreationLogEntry{0, 1}, CreationLogEntry{0, 2}, |
||||
CreationLogEntry{0, 3}, CreationLogEntry{0, 4}, |
||||
CreationLogEntry{1, 1}, CreationLogEntry{1, 2}, |
||||
CreationLogEntry{1, 3}, CreationLogEntry{1, 4}, |
||||
CreationLogEntry{2, 1})); |
||||
} |
||||
|
||||
} // namespace
|
||||
} // namespace grpc_core
|
||||
|
||||
int main(int argc, char** argv) { |
||||
::testing::InitGoogleTest(&argc, argv); |
||||
grpc_tracer_init(); |
||||
gpr_log_verbosity_init(); |
||||
return RUN_ALL_TESTS(); |
||||
} |
Loading…
Reference in new issue