commit
6e71b9cbf2
67 changed files with 2714 additions and 600 deletions
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,156 @@ |
|||||||
|
// Copyright 2024 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "src/core/lib/transport/interception_chain.h" |
||||||
|
|
||||||
|
#include <cstddef> |
||||||
|
|
||||||
|
#include <grpc/support/port_platform.h> |
||||||
|
|
||||||
|
#include "src/core/lib/gprpp/match.h" |
||||||
|
#include "src/core/lib/transport/call_destination.h" |
||||||
|
#include "src/core/lib/transport/call_filters.h" |
||||||
|
#include "src/core/lib/transport/call_spine.h" |
||||||
|
#include "src/core/lib/transport/metadata.h" |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
|
||||||
|
std::atomic<size_t> InterceptionChainBuilder::next_filter_id_{0}; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// HijackedCall
|
||||||
|
|
||||||
|
CallInitiator HijackedCall::MakeCall() { |
||||||
|
auto metadata = Arena::MakePooled<ClientMetadata>(); |
||||||
|
*metadata = metadata_->Copy(); |
||||||
|
return MakeCallWithMetadata(std::move(metadata)); |
||||||
|
} |
||||||
|
|
||||||
|
CallInitiator HijackedCall::MakeCallWithMetadata( |
||||||
|
ClientMetadataHandle metadata) { |
||||||
|
auto call = MakeCallPair(std::move(metadata), call_handler_.event_engine(), |
||||||
|
call_handler_.arena(), nullptr, |
||||||
|
call_handler_.legacy_context()); |
||||||
|
destination_->StartCall(std::move(call.handler)); |
||||||
|
return std::move(call.initiator); |
||||||
|
} |
||||||
|
|
||||||
|
namespace { |
||||||
|
class CallStarter final : public UnstartedCallDestination { |
||||||
|
public: |
||||||
|
CallStarter(RefCountedPtr<CallFilters::Stack> stack, |
||||||
|
RefCountedPtr<CallDestination> destination) |
||||||
|
: stack_(std::move(stack)), destination_(std::move(destination)) {} |
||||||
|
|
||||||
|
void Orphaned() override { |
||||||
|
stack_.reset(); |
||||||
|
destination_.reset(); |
||||||
|
} |
||||||
|
|
||||||
|
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||||
|
destination_->HandleCall(unstarted_call_handler.StartCall(stack_)); |
||||||
|
} |
||||||
|
|
||||||
|
private: |
||||||
|
RefCountedPtr<CallFilters::Stack> stack_; |
||||||
|
RefCountedPtr<CallDestination> destination_; |
||||||
|
}; |
||||||
|
|
||||||
|
class TerminalInterceptor final : public UnstartedCallDestination { |
||||||
|
public: |
||||||
|
explicit TerminalInterceptor( |
||||||
|
RefCountedPtr<CallFilters::Stack> stack, |
||||||
|
RefCountedPtr<UnstartedCallDestination> destination) |
||||||
|
: stack_(std::move(stack)), destination_(std::move(destination)) {} |
||||||
|
|
||||||
|
void Orphaned() override { |
||||||
|
stack_.reset(); |
||||||
|
destination_.reset(); |
||||||
|
} |
||||||
|
|
||||||
|
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||||
|
unstarted_call_handler.SpawnGuarded( |
||||||
|
"start_call", |
||||||
|
Map(interception_chain_detail::HijackCall(unstarted_call_handler, |
||||||
|
destination_, stack_), |
||||||
|
[](ValueOrFailure<HijackedCall> hijacked_call) -> StatusFlag { |
||||||
|
if (!hijacked_call.ok()) return Failure{}; |
||||||
|
ForwardCall(hijacked_call.value().original_call_handler(), |
||||||
|
hijacked_call.value().MakeLastCall()); |
||||||
|
return Success{}; |
||||||
|
})); |
||||||
|
} |
||||||
|
|
||||||
|
private: |
||||||
|
RefCountedPtr<CallFilters::Stack> stack_; |
||||||
|
RefCountedPtr<UnstartedCallDestination> destination_; |
||||||
|
}; |
||||||
|
} // namespace
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// InterceptionChain::Builder
|
||||||
|
|
||||||
|
void InterceptionChainBuilder::AddInterceptor( |
||||||
|
absl::StatusOr<RefCountedPtr<Interceptor>> interceptor) { |
||||||
|
if (!status_.ok()) return; |
||||||
|
if (!interceptor.ok()) { |
||||||
|
status_ = interceptor.status(); |
||||||
|
return; |
||||||
|
} |
||||||
|
(*interceptor)->filter_stack_ = MakeFilterStack(); |
||||||
|
if (top_interceptor_ == nullptr) { |
||||||
|
top_interceptor_ = std::move(*interceptor); |
||||||
|
} else { |
||||||
|
Interceptor* previous = top_interceptor_.get(); |
||||||
|
while (previous->wrapped_destination_ != nullptr) { |
||||||
|
previous = DownCast<Interceptor*>(previous->wrapped_destination_.get()); |
||||||
|
} |
||||||
|
previous->wrapped_destination_ = std::move(*interceptor); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> |
||||||
|
InterceptionChainBuilder::Build(FinalDestination final_destination) { |
||||||
|
if (!status_.ok()) return status_; |
||||||
|
// Build the final UnstartedCallDestination in the chain - what we do here
|
||||||
|
// depends on both the type of the final destination and the filters we have
|
||||||
|
// that haven't been captured into an Interceptor yet.
|
||||||
|
RefCountedPtr<UnstartedCallDestination> terminator = Match( |
||||||
|
final_destination, |
||||||
|
[this](RefCountedPtr<UnstartedCallDestination> final_destination) |
||||||
|
-> RefCountedPtr<UnstartedCallDestination> { |
||||||
|
if (stack_builder_.has_value()) { |
||||||
|
return MakeRefCounted<TerminalInterceptor>(MakeFilterStack(), |
||||||
|
final_destination); |
||||||
|
} |
||||||
|
return final_destination; |
||||||
|
}, |
||||||
|
[this](RefCountedPtr<CallDestination> final_destination) |
||||||
|
-> RefCountedPtr<UnstartedCallDestination> { |
||||||
|
return MakeRefCounted<CallStarter>(MakeFilterStack(), |
||||||
|
std::move(final_destination)); |
||||||
|
}); |
||||||
|
// Now append the terminator to the interceptor chain.
|
||||||
|
if (top_interceptor_ == nullptr) { |
||||||
|
return std::move(terminator); |
||||||
|
} |
||||||
|
Interceptor* previous = top_interceptor_.get(); |
||||||
|
while (previous->wrapped_destination_ != nullptr) { |
||||||
|
previous = DownCast<Interceptor*>(previous->wrapped_destination_.get()); |
||||||
|
} |
||||||
|
previous->wrapped_destination_ = std::move(terminator); |
||||||
|
return std::move(top_interceptor_); |
||||||
|
} |
||||||
|
|
||||||
|
} // namespace grpc_core
|
@ -0,0 +1,225 @@ |
|||||||
|
// Copyright 2024 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H |
||||||
|
#define GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H |
||||||
|
|
||||||
|
#include <memory> |
||||||
|
#include <vector> |
||||||
|
|
||||||
|
#include <grpc/support/port_platform.h> |
||||||
|
|
||||||
|
#include "src/core/lib/gprpp/ref_counted.h" |
||||||
|
#include "src/core/lib/transport/call_destination.h" |
||||||
|
#include "src/core/lib/transport/call_filters.h" |
||||||
|
#include "src/core/lib/transport/call_spine.h" |
||||||
|
#include "src/core/lib/transport/metadata.h" |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
|
||||||
|
class InterceptionChainBuilder; |
||||||
|
|
||||||
|
// One hijacked call. Using this we can get access to the CallHandler for the
|
||||||
|
// call object above us, the processed metadata from any filters/interceptors
|
||||||
|
// above us, and also create new CallInterceptor objects that will be handled
|
||||||
|
// below.
|
||||||
|
class HijackedCall final { |
||||||
|
public: |
||||||
|
HijackedCall(ClientMetadataHandle metadata, |
||||||
|
RefCountedPtr<UnstartedCallDestination> destination, |
||||||
|
CallHandler call_handler) |
||||||
|
: metadata_(std::move(metadata)), |
||||||
|
destination_(std::move(destination)), |
||||||
|
call_handler_(std::move(call_handler)) {} |
||||||
|
|
||||||
|
// Create a new call and pass it down the stack.
|
||||||
|
// This can be called as many times as needed.
|
||||||
|
CallInitiator MakeCall(); |
||||||
|
// Per MakeCall(), but precludes creating further calls.
|
||||||
|
// Allows us to optimize by not copying initial metadata.
|
||||||
|
CallInitiator MakeLastCall() { |
||||||
|
return MakeCallWithMetadata(std::move(metadata_)); |
||||||
|
} |
||||||
|
|
||||||
|
CallHandler& original_call_handler() { return call_handler_; } |
||||||
|
|
||||||
|
ClientMetadata& client_metadata() { return *metadata_; } |
||||||
|
|
||||||
|
private: |
||||||
|
CallInitiator MakeCallWithMetadata(ClientMetadataHandle metadata); |
||||||
|
|
||||||
|
ClientMetadataHandle metadata_; |
||||||
|
RefCountedPtr<UnstartedCallDestination> destination_; |
||||||
|
CallHandler call_handler_; |
||||||
|
}; |
||||||
|
|
||||||
|
namespace interception_chain_detail { |
||||||
|
|
||||||
|
inline auto HijackCall(UnstartedCallHandler unstarted_call_handler, |
||||||
|
RefCountedPtr<UnstartedCallDestination> destination, |
||||||
|
RefCountedPtr<CallFilters::Stack> stack) { |
||||||
|
auto call_handler = unstarted_call_handler.StartCall(stack); |
||||||
|
return Map( |
||||||
|
call_handler.PullClientInitialMetadata(), |
||||||
|
[call_handler, |
||||||
|
destination](ValueOrFailure<ClientMetadataHandle> metadata) mutable |
||||||
|
-> ValueOrFailure<HijackedCall> { |
||||||
|
if (!metadata.ok()) return Failure{}; |
||||||
|
return HijackedCall(std::move(metadata.value()), std::move(destination), |
||||||
|
std::move(call_handler)); |
||||||
|
}); |
||||||
|
} |
||||||
|
|
||||||
|
} // namespace interception_chain_detail
|
||||||
|
|
||||||
|
// A delegating UnstartedCallDestination for use as a hijacking filter.
|
||||||
|
// Implementations may look at the unprocessed initial metadata
|
||||||
|
// and decide to do one of two things:
|
||||||
|
//
|
||||||
|
// 1. It can hijack the call. Returns a HijackedCall object that can
|
||||||
|
// be used to start new calls with the same metadata.
|
||||||
|
//
|
||||||
|
// 2. It can consume the call by calling `Consume`.
|
||||||
|
//
|
||||||
|
// Upon the StartCall call the UnstartedCallHandler will be from the last
|
||||||
|
// *Interceptor* in the call chain (without having been processed by any
|
||||||
|
// intervening filters) -- note that this is commonly not useful (not enough
|
||||||
|
// guarantees), and so it's usually better to Hijack and examine the metadata.
|
||||||
|
class Interceptor : public UnstartedCallDestination { |
||||||
|
protected: |
||||||
|
// Returns a promise that resolves to a HijackedCall instance.
|
||||||
|
// Hijacking is the process of taking over a call and starting one or more new
|
||||||
|
// ones.
|
||||||
|
auto Hijack(UnstartedCallHandler unstarted_call_handler) { |
||||||
|
return interception_chain_detail::HijackCall( |
||||||
|
std::move(unstarted_call_handler), wrapped_destination_, filter_stack_); |
||||||
|
} |
||||||
|
|
||||||
|
// Consume this call - it will not be passed on to any further filters.
|
||||||
|
CallHandler Consume(UnstartedCallHandler unstarted_call_handler) { |
||||||
|
return unstarted_call_handler.StartCall(filter_stack_); |
||||||
|
} |
||||||
|
|
||||||
|
// TODO(ctiller): Consider a Passthrough() method that allows the call to be
|
||||||
|
// passed on to the next filter in the chain without any interception by the
|
||||||
|
// current filter.
|
||||||
|
|
||||||
|
private: |
||||||
|
friend class InterceptionChainBuilder; |
||||||
|
|
||||||
|
RefCountedPtr<UnstartedCallDestination> wrapped_destination_; |
||||||
|
RefCountedPtr<CallFilters::Stack> filter_stack_; |
||||||
|
}; |
||||||
|
|
||||||
|
class InterceptionChainBuilder final { |
||||||
|
public: |
||||||
|
// The kind of destination that the chain will eventually call.
|
||||||
|
// We can bottom out in various types depending on where we're intercepting:
|
||||||
|
// - The top half of the client channel wants to terminate on a
|
||||||
|
// UnstartedCallDestination (specifically the LB call destination).
|
||||||
|
// - The bottom half of the client channel and the server code wants to
|
||||||
|
// terminate on a ClientTransport - which unlike a
|
||||||
|
// UnstartedCallDestination demands a started CallHandler.
|
||||||
|
// There's some adaption code that's needed to start filters just prior
|
||||||
|
// to the bottoming out, and some design considerations to make with that.
|
||||||
|
// One way (that's not chosen here) would be to have the caller of the
|
||||||
|
// Builder provide something that can build an adaptor
|
||||||
|
// UnstartedCallDestination with parameters supplied by this builder - that
|
||||||
|
// disperses the responsibility of building the adaptor to the caller, which
|
||||||
|
// is not ideal - we might want to adjust the way this construct is built in
|
||||||
|
// the future, and building is a builder responsibility.
|
||||||
|
// Instead, we declare a relatively closed set of destinations here, and
|
||||||
|
// hide the adaptors inside the builder at build time.
|
||||||
|
using FinalDestination = |
||||||
|
absl::variant<RefCountedPtr<UnstartedCallDestination>, |
||||||
|
RefCountedPtr<CallDestination>>; |
||||||
|
|
||||||
|
explicit InterceptionChainBuilder(ChannelArgs args) |
||||||
|
: args_(std::move(args)) {} |
||||||
|
|
||||||
|
// Add a filter with a `Call` class as an inner member.
|
||||||
|
// Call class must be one compatible with the filters described in
|
||||||
|
// call_filters.h.
|
||||||
|
template <typename T> |
||||||
|
absl::enable_if_t<sizeof(typename T::Call) != 0, InterceptionChainBuilder&> |
||||||
|
Add() { |
||||||
|
if (!status_.ok()) return *this; |
||||||
|
auto filter = T::Create(args_, {FilterInstanceId(FilterTypeId<T>())}); |
||||||
|
if (!filter.ok()) { |
||||||
|
status_ = filter.status(); |
||||||
|
return *this; |
||||||
|
} |
||||||
|
auto& sb = stack_builder(); |
||||||
|
sb.Add(filter.value().get()); |
||||||
|
sb.AddOwnedObject(std::move(filter.value())); |
||||||
|
return *this; |
||||||
|
}; |
||||||
|
|
||||||
|
// Add a filter that is an interceptor - one that can hijack calls.
|
||||||
|
template <typename T> |
||||||
|
absl::enable_if_t<std::is_base_of<Interceptor, T>::value, |
||||||
|
InterceptionChainBuilder&> |
||||||
|
Add() { |
||||||
|
AddInterceptor(T::Create(args_, {FilterInstanceId(FilterTypeId<T>())})); |
||||||
|
return *this; |
||||||
|
}; |
||||||
|
|
||||||
|
// Add a filter that just mutates server trailing metadata.
|
||||||
|
template <typename F> |
||||||
|
void AddOnServerTrailingMetadata(F f) { |
||||||
|
stack_builder().AddOnServerTrailingMetadata(std::move(f)); |
||||||
|
} |
||||||
|
|
||||||
|
// Build this stack
|
||||||
|
absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> Build( |
||||||
|
FinalDestination final_destination); |
||||||
|
|
||||||
|
const ChannelArgs& channel_args() const { return args_; } |
||||||
|
|
||||||
|
private: |
||||||
|
CallFilters::StackBuilder& stack_builder() { |
||||||
|
if (!stack_builder_.has_value()) stack_builder_.emplace(); |
||||||
|
return *stack_builder_; |
||||||
|
} |
||||||
|
|
||||||
|
RefCountedPtr<CallFilters::Stack> MakeFilterStack() { |
||||||
|
auto stack = stack_builder().Build(); |
||||||
|
stack_builder_.reset(); |
||||||
|
return stack; |
||||||
|
} |
||||||
|
|
||||||
|
template <typename T> |
||||||
|
static size_t FilterTypeId() { |
||||||
|
static const size_t id = |
||||||
|
next_filter_id_.fetch_add(1, std::memory_order_relaxed); |
||||||
|
return id; |
||||||
|
} |
||||||
|
|
||||||
|
size_t FilterInstanceId(size_t filter_type) { |
||||||
|
return filter_type_counts_[filter_type]++; |
||||||
|
} |
||||||
|
|
||||||
|
void AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor); |
||||||
|
|
||||||
|
ChannelArgs args_; |
||||||
|
absl::optional<CallFilters::StackBuilder> stack_builder_; |
||||||
|
RefCountedPtr<Interceptor> top_interceptor_; |
||||||
|
absl::Status status_; |
||||||
|
std::map<size_t, size_t> filter_type_counts_; |
||||||
|
static std::atomic<size_t> next_filter_id_; |
||||||
|
}; |
||||||
|
|
||||||
|
} // namespace grpc_core
|
||||||
|
|
||||||
|
#endif // GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H
|
@ -0,0 +1,60 @@ |
|||||||
|
// Copyright 2024 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H |
||||||
|
#define GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H |
||||||
|
|
||||||
|
#include "gmock/gmock.h" |
||||||
|
|
||||||
|
// Various gmock matchers for Poll
|
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
|
||||||
|
// Expect that a promise is still pending:
|
||||||
|
// EXPECT_THAT(some_promise(), IsPending());
|
||||||
|
MATCHER(IsPending, "") { |
||||||
|
if (arg.ready()) { |
||||||
|
*result_listener << "is ready"; |
||||||
|
return false; |
||||||
|
} |
||||||
|
return true; |
||||||
|
} |
||||||
|
|
||||||
|
// Expect that a promise is ready:
|
||||||
|
// EXPECT_THAT(some_promise(), IsReady());
|
||||||
|
MATCHER(IsReady, "") { |
||||||
|
if (arg.pending()) { |
||||||
|
*result_listener << "is pending"; |
||||||
|
return false; |
||||||
|
} |
||||||
|
return true; |
||||||
|
} |
||||||
|
|
||||||
|
// Expect that a promise is ready with a specific value:
|
||||||
|
// EXPECT_THAT(some_promise(), IsReady(value));
|
||||||
|
MATCHER_P(IsReady, value, "") { |
||||||
|
if (arg.pending()) { |
||||||
|
*result_listener << "is pending"; |
||||||
|
return false; |
||||||
|
} |
||||||
|
if (arg.value() != value) { |
||||||
|
*result_listener << "is " << ::testing::PrintToString(arg.value()); |
||||||
|
return false; |
||||||
|
} |
||||||
|
return true; |
||||||
|
} |
||||||
|
|
||||||
|
} // namespace grpc_core
|
||||||
|
|
||||||
|
#endif // GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H
|
@ -0,0 +1,406 @@ |
|||||||
|
// Copyright 2024 gRPC authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "src/core/lib/transport/interception_chain.h" |
||||||
|
|
||||||
|
#include <memory> |
||||||
|
|
||||||
|
#include "gmock/gmock.h" |
||||||
|
#include "gtest/gtest.h" |
||||||
|
|
||||||
|
#include <grpc/support/log.h> |
||||||
|
|
||||||
|
#include "src/core/lib/channel/promise_based_filter.h" |
||||||
|
#include "src/core/lib/resource_quota/resource_quota.h" |
||||||
|
#include "test/core/promise/poll_matcher.h" |
||||||
|
|
||||||
|
namespace grpc_core { |
||||||
|
namespace { |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Mutate metadata by annotating that it passed through a filter "x"
|
||||||
|
|
||||||
|
void AnnotatePassedThrough(ClientMetadata& md, int x) { |
||||||
|
md.Append(absl::StrCat("passed-through-", x), Slice::FromCopiedString("true"), |
||||||
|
[](absl::string_view, const Slice&) { Crash("unreachable"); }); |
||||||
|
} |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// CreationLog helps us reason about filter creation order by logging a small
|
||||||
|
// record of each filter's creation.
|
||||||
|
|
||||||
|
struct CreationLogEntry { |
||||||
|
size_t filter_instance_id; |
||||||
|
size_t type_tag; |
||||||
|
|
||||||
|
bool operator==(const CreationLogEntry& other) const { |
||||||
|
return filter_instance_id == other.filter_instance_id && |
||||||
|
type_tag == other.type_tag; |
||||||
|
} |
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, |
||||||
|
const CreationLogEntry& entry) { |
||||||
|
return os << "{filter_instance_id=" << entry.filter_instance_id |
||||||
|
<< ", type_tag=" << entry.type_tag << "}"; |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
struct CreationLog { |
||||||
|
struct RawPointerChannelArgTag {}; |
||||||
|
static absl::string_view ChannelArgName() { return "creation_log"; } |
||||||
|
std::vector<CreationLogEntry> entries; |
||||||
|
}; |
||||||
|
|
||||||
|
void MaybeLogCreation(const ChannelArgs& channel_args, |
||||||
|
ChannelFilter::Args filter_args, size_t type_tag) { |
||||||
|
auto* log = channel_args.GetPointer<CreationLog>("creation_log"); |
||||||
|
if (log == nullptr) return; |
||||||
|
log->entries.push_back(CreationLogEntry{filter_args.instance_id(), type_tag}); |
||||||
|
} |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Test call filter
|
||||||
|
|
||||||
|
template <int I> |
||||||
|
class TestFilter { |
||||||
|
public: |
||||||
|
class Call { |
||||||
|
public: |
||||||
|
void OnClientInitialMetadata(ClientMetadata& md) { |
||||||
|
AnnotatePassedThrough(md, I); |
||||||
|
} |
||||||
|
static const NoInterceptor OnServerInitialMetadata; |
||||||
|
static const NoInterceptor OnClientToServerMessage; |
||||||
|
static const NoInterceptor OnServerToClientMessage; |
||||||
|
static const NoInterceptor OnServerTrailingMetadata; |
||||||
|
static const NoInterceptor OnFinalize; |
||||||
|
}; |
||||||
|
|
||||||
|
static absl::StatusOr<std::unique_ptr<TestFilter<I>>> Create( |
||||||
|
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||||
|
MaybeLogCreation(channel_args, filter_args, I); |
||||||
|
return std::make_unique<TestFilter<I>>(); |
||||||
|
} |
||||||
|
|
||||||
|
private: |
||||||
|
std::unique_ptr<int> i_ = std::make_unique<int>(I); |
||||||
|
}; |
||||||
|
|
||||||
|
template <int I> |
||||||
|
const NoInterceptor TestFilter<I>::Call::OnServerInitialMetadata; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor TestFilter<I>::Call::OnClientToServerMessage; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor TestFilter<I>::Call::OnServerToClientMessage; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor TestFilter<I>::Call::OnServerTrailingMetadata; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor TestFilter<I>::Call::OnFinalize; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Test call filter that fails to instantiate
|
||||||
|
|
||||||
|
template <int I> |
||||||
|
class FailsToInstantiateFilter { |
||||||
|
public: |
||||||
|
class Call { |
||||||
|
public: |
||||||
|
static const NoInterceptor OnClientInitialMetadata; |
||||||
|
static const NoInterceptor OnServerInitialMetadata; |
||||||
|
static const NoInterceptor OnClientToServerMessage; |
||||||
|
static const NoInterceptor OnServerToClientMessage; |
||||||
|
static const NoInterceptor OnServerTrailingMetadata; |
||||||
|
static const NoInterceptor OnFinalize; |
||||||
|
}; |
||||||
|
|
||||||
|
static absl::StatusOr<std::unique_ptr<FailsToInstantiateFilter<I>>> Create( |
||||||
|
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||||
|
MaybeLogCreation(channel_args, filter_args, I); |
||||||
|
return absl::InternalError(absl::StrCat("👊 failed to instantiate ", I)); |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
template <int I> |
||||||
|
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientInitialMetadata; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerInitialMetadata; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientToServerMessage; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerToClientMessage; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerTrailingMetadata; |
||||||
|
template <int I> |
||||||
|
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnFinalize; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Test call interceptor - consumes calls
|
||||||
|
|
||||||
|
template <int I> |
||||||
|
class TestConsumingInterceptor final : public Interceptor { |
||||||
|
public: |
||||||
|
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||||
|
Consume(std::move(unstarted_call_handler)) |
||||||
|
.PushServerTrailingMetadata( |
||||||
|
ServerMetadataFromStatus(absl::InternalError("👊 consumed"))); |
||||||
|
} |
||||||
|
void Orphaned() override {} |
||||||
|
static absl::StatusOr<RefCountedPtr<TestConsumingInterceptor<I>>> Create( |
||||||
|
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||||
|
MaybeLogCreation(channel_args, filter_args, I); |
||||||
|
return MakeRefCounted<TestConsumingInterceptor<I>>(); |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Test call interceptor - fails to instantiate
|
||||||
|
|
||||||
|
template <int I> |
||||||
|
class TestFailingInterceptor final : public Interceptor { |
||||||
|
public: |
||||||
|
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||||
|
Crash("unreachable"); |
||||||
|
} |
||||||
|
void Orphaned() override {} |
||||||
|
static absl::StatusOr<RefCountedPtr<TestFailingInterceptor<I>>> Create( |
||||||
|
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||||
|
MaybeLogCreation(channel_args, filter_args, I); |
||||||
|
return absl::InternalError(absl::StrCat("👊 failed to instantiate ", I)); |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Test call interceptor - hijacks calls
|
||||||
|
|
||||||
|
template <int I> |
||||||
|
class TestHijackingInterceptor final : public Interceptor { |
||||||
|
public: |
||||||
|
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||||
|
unstarted_call_handler.SpawnInfallible( |
||||||
|
"hijack", [this, unstarted_call_handler]() mutable { |
||||||
|
return Map(Hijack(std::move(unstarted_call_handler)), |
||||||
|
[](ValueOrFailure<HijackedCall> hijacked_call) { |
||||||
|
ForwardCall( |
||||||
|
hijacked_call.value().original_call_handler(), |
||||||
|
hijacked_call.value().MakeCall()); |
||||||
|
return Empty{}; |
||||||
|
}); |
||||||
|
}); |
||||||
|
} |
||||||
|
void Orphaned() override {} |
||||||
|
static absl::StatusOr<RefCountedPtr<TestHijackingInterceptor<I>>> Create( |
||||||
|
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||||
|
MaybeLogCreation(channel_args, filter_args, I); |
||||||
|
return MakeRefCounted<TestHijackingInterceptor<I>>(); |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Test fixture
|
||||||
|
|
||||||
|
class InterceptionChainTest : public ::testing::Test { |
||||||
|
protected: |
||||||
|
InterceptionChainTest() {} |
||||||
|
~InterceptionChainTest() override {} |
||||||
|
|
||||||
|
RefCountedPtr<UnstartedCallDestination> destination() { return destination_; } |
||||||
|
|
||||||
|
struct FinishedCall { |
||||||
|
CallInitiator call; |
||||||
|
ClientMetadataHandle client_metadata; |
||||||
|
ServerMetadataHandle server_metadata; |
||||||
|
}; |
||||||
|
|
||||||
|
// Run a call through a UnstartedCallDestination until it's complete.
|
||||||
|
FinishedCall RunCall(UnstartedCallDestination* destination) { |
||||||
|
auto* arena = call_arena_allocator_->MakeArena(); |
||||||
|
auto call = MakeCallPair(Arena::MakePooled<ClientMetadata>(), nullptr, |
||||||
|
arena, call_arena_allocator_, nullptr); |
||||||
|
Poll<ServerMetadataHandle> trailing_md; |
||||||
|
call.initiator.SpawnInfallible( |
||||||
|
"run_call", [destination, &call, &trailing_md]() mutable { |
||||||
|
gpr_log(GPR_INFO, "👊 start call"); |
||||||
|
destination->StartCall(std::move(call.handler)); |
||||||
|
return Map(call.initiator.PullServerTrailingMetadata(), |
||||||
|
[&trailing_md](ServerMetadataHandle md) { |
||||||
|
trailing_md = std::move(md); |
||||||
|
return Empty{}; |
||||||
|
}); |
||||||
|
}); |
||||||
|
EXPECT_THAT(trailing_md, IsReady()); |
||||||
|
return FinishedCall{std::move(call.initiator), destination_->TakeMetadata(), |
||||||
|
std::move(trailing_md.value())}; |
||||||
|
} |
||||||
|
|
||||||
|
private: |
||||||
|
class Destination final : public UnstartedCallDestination { |
||||||
|
public: |
||||||
|
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||||
|
gpr_log(GPR_INFO, "👊 started call: metadata=%s", |
||||||
|
unstarted_call_handler.UnprocessedClientInitialMetadata() |
||||||
|
.DebugString() |
||||||
|
.c_str()); |
||||||
|
EXPECT_EQ(metadata_.get(), nullptr); |
||||||
|
metadata_ = Arena::MakePooled<ClientMetadata>(); |
||||||
|
*metadata_ = |
||||||
|
unstarted_call_handler.UnprocessedClientInitialMetadata().Copy(); |
||||||
|
unstarted_call_handler.PushServerTrailingMetadata( |
||||||
|
ServerMetadataFromStatus(absl::InternalError("👊 cancelled"))); |
||||||
|
} |
||||||
|
|
||||||
|
void Orphaned() override {} |
||||||
|
|
||||||
|
ClientMetadataHandle TakeMetadata() { return std::move(metadata_); } |
||||||
|
|
||||||
|
private: |
||||||
|
ClientMetadataHandle metadata_; |
||||||
|
}; |
||||||
|
RefCountedPtr<Destination> destination_ = MakeRefCounted<Destination>(); |
||||||
|
RefCountedPtr<CallArenaAllocator> call_arena_allocator_ = |
||||||
|
MakeRefCounted<CallArenaAllocator>( |
||||||
|
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( |
||||||
|
"test"), |
||||||
|
1024); |
||||||
|
}; |
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Tests begin
|
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, Empty) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()).Build(destination()); |
||||||
|
ASSERT_TRUE(r.ok()) << r.status(); |
||||||
|
auto finished_call = RunCall(r.value().get()); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||||
|
GRPC_STATUS_INTERNAL); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||||
|
->as_string_view(), |
||||||
|
"👊 cancelled"); |
||||||
|
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, Consumed) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<TestConsumingInterceptor<1>>() |
||||||
|
.Build(destination()); |
||||||
|
ASSERT_TRUE(r.ok()) << r.status(); |
||||||
|
auto finished_call = RunCall(r.value().get()); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||||
|
GRPC_STATUS_INTERNAL); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||||
|
->as_string_view(), |
||||||
|
"👊 consumed"); |
||||||
|
EXPECT_EQ(finished_call.client_metadata, nullptr); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, Hijacked) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<TestHijackingInterceptor<1>>() |
||||||
|
.Build(destination()); |
||||||
|
ASSERT_TRUE(r.ok()) << r.status(); |
||||||
|
auto finished_call = RunCall(r.value().get()); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||||
|
GRPC_STATUS_INTERNAL); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||||
|
->as_string_view(), |
||||||
|
"👊 cancelled"); |
||||||
|
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, FiltersThenHijacked) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<TestFilter<1>>() |
||||||
|
.Add<TestHijackingInterceptor<2>>() |
||||||
|
.Build(destination()); |
||||||
|
ASSERT_TRUE(r.ok()) << r.status(); |
||||||
|
auto finished_call = RunCall(r.value().get()); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||||
|
GRPC_STATUS_INTERNAL); |
||||||
|
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||||
|
->as_string_view(), |
||||||
|
"👊 cancelled"); |
||||||
|
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||||
|
std::string backing; |
||||||
|
EXPECT_EQ(finished_call.client_metadata->GetStringValue("passed-through-1", |
||||||
|
&backing), |
||||||
|
"true"); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<TestFailingInterceptor<1>>() |
||||||
|
.Build(destination()); |
||||||
|
EXPECT_FALSE(r.ok()); |
||||||
|
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||||
|
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 1"); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor2) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<TestFilter<1>>() |
||||||
|
.Add<TestFailingInterceptor<2>>() |
||||||
|
.Build(destination()); |
||||||
|
EXPECT_FALSE(r.ok()); |
||||||
|
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||||
|
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 2"); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, FailsToInstantiateFilter) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<FailsToInstantiateFilter<1>>() |
||||||
|
.Build(destination()); |
||||||
|
EXPECT_FALSE(r.ok()); |
||||||
|
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||||
|
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 1"); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, FailsToInstantiateFilter2) { |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||||
|
.Add<TestFilter<1>>() |
||||||
|
.Add<FailsToInstantiateFilter<2>>() |
||||||
|
.Build(destination()); |
||||||
|
EXPECT_FALSE(r.ok()); |
||||||
|
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||||
|
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 2"); |
||||||
|
} |
||||||
|
|
||||||
|
TEST_F(InterceptionChainTest, CreationOrderCorrect) { |
||||||
|
CreationLog log; |
||||||
|
auto r = InterceptionChainBuilder(ChannelArgs().SetObject(&log)) |
||||||
|
.Add<TestFilter<1>>() |
||||||
|
.Add<TestFilter<2>>() |
||||||
|
.Add<TestFilter<3>>() |
||||||
|
.Add<TestConsumingInterceptor<4>>() |
||||||
|
.Add<TestFilter<1>>() |
||||||
|
.Add<TestFilter<2>>() |
||||||
|
.Add<TestFilter<3>>() |
||||||
|
.Add<TestConsumingInterceptor<4>>() |
||||||
|
.Add<TestFilter<1>>() |
||||||
|
.Build(destination()); |
||||||
|
EXPECT_THAT(log.entries, ::testing::ElementsAre( |
||||||
|
CreationLogEntry{0, 1}, CreationLogEntry{0, 2}, |
||||||
|
CreationLogEntry{0, 3}, CreationLogEntry{0, 4}, |
||||||
|
CreationLogEntry{1, 1}, CreationLogEntry{1, 2}, |
||||||
|
CreationLogEntry{1, 3}, CreationLogEntry{1, 4}, |
||||||
|
CreationLogEntry{2, 1})); |
||||||
|
} |
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace grpc_core
|
||||||
|
|
||||||
|
int main(int argc, char** argv) { |
||||||
|
::testing::InitGoogleTest(&argc, argv); |
||||||
|
grpc_tracer_init(); |
||||||
|
gpr_log_verbosity_init(); |
||||||
|
return RUN_ALL_TESTS(); |
||||||
|
} |
Loading…
Reference in new issue