mirror of https://github.com/grpc/grpc.git
[call-v3] Interception chain (#36414)
Introduce the interception chain type.
Also introduces the real call-v3 call spine based atop CallFilters.
Closes #36414
COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/36414 from ctiller:interception-chain 90c8e96973
PiperOrigin-RevId: 627784183
pull/36402/head
parent
459abbec5a
commit
d52779da52
43 changed files with 2538 additions and 478 deletions
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,156 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/core/lib/transport/interception_chain.h" |
||||
|
||||
#include <cstddef> |
||||
|
||||
#include <grpc/support/port_platform.h> |
||||
|
||||
#include "src/core/lib/gprpp/match.h" |
||||
#include "src/core/lib/transport/call_destination.h" |
||||
#include "src/core/lib/transport/call_filters.h" |
||||
#include "src/core/lib/transport/call_spine.h" |
||||
#include "src/core/lib/transport/metadata.h" |
||||
|
||||
namespace grpc_core { |
||||
|
||||
std::atomic<size_t> InterceptionChainBuilder::next_filter_id_{0}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// HijackedCall
|
||||
|
||||
CallInitiator HijackedCall::MakeCall() { |
||||
auto metadata = Arena::MakePooled<ClientMetadata>(); |
||||
*metadata = metadata_->Copy(); |
||||
return MakeCallWithMetadata(std::move(metadata)); |
||||
} |
||||
|
||||
CallInitiator HijackedCall::MakeCallWithMetadata( |
||||
ClientMetadataHandle metadata) { |
||||
auto call = MakeCallPair(std::move(metadata), call_handler_.event_engine(), |
||||
call_handler_.arena(), nullptr, |
||||
call_handler_.legacy_context()); |
||||
destination_->StartCall(std::move(call.handler)); |
||||
return std::move(call.initiator); |
||||
} |
||||
|
||||
namespace { |
||||
class CallStarter final : public UnstartedCallDestination { |
||||
public: |
||||
CallStarter(RefCountedPtr<CallFilters::Stack> stack, |
||||
RefCountedPtr<CallDestination> destination) |
||||
: stack_(std::move(stack)), destination_(std::move(destination)) {} |
||||
|
||||
void Orphaned() override { |
||||
stack_.reset(); |
||||
destination_.reset(); |
||||
} |
||||
|
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
destination_->HandleCall(unstarted_call_handler.StartCall(stack_)); |
||||
} |
||||
|
||||
private: |
||||
RefCountedPtr<CallFilters::Stack> stack_; |
||||
RefCountedPtr<CallDestination> destination_; |
||||
}; |
||||
|
||||
class TerminalInterceptor final : public UnstartedCallDestination { |
||||
public: |
||||
explicit TerminalInterceptor( |
||||
RefCountedPtr<CallFilters::Stack> stack, |
||||
RefCountedPtr<UnstartedCallDestination> destination) |
||||
: stack_(std::move(stack)), destination_(std::move(destination)) {} |
||||
|
||||
void Orphaned() override { |
||||
stack_.reset(); |
||||
destination_.reset(); |
||||
} |
||||
|
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
unstarted_call_handler.SpawnGuarded( |
||||
"start_call", |
||||
Map(interception_chain_detail::HijackCall(unstarted_call_handler, |
||||
destination_, stack_), |
||||
[](ValueOrFailure<HijackedCall> hijacked_call) -> StatusFlag { |
||||
if (!hijacked_call.ok()) return Failure{}; |
||||
ForwardCall(hijacked_call.value().original_call_handler(), |
||||
hijacked_call.value().MakeLastCall()); |
||||
return Success{}; |
||||
})); |
||||
} |
||||
|
||||
private: |
||||
RefCountedPtr<CallFilters::Stack> stack_; |
||||
RefCountedPtr<UnstartedCallDestination> destination_; |
||||
}; |
||||
} // namespace
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// InterceptionChain::Builder
|
||||
|
||||
void InterceptionChainBuilder::AddInterceptor( |
||||
absl::StatusOr<RefCountedPtr<Interceptor>> interceptor) { |
||||
if (!status_.ok()) return; |
||||
if (!interceptor.ok()) { |
||||
status_ = interceptor.status(); |
||||
return; |
||||
} |
||||
(*interceptor)->filter_stack_ = MakeFilterStack(); |
||||
if (top_interceptor_ == nullptr) { |
||||
top_interceptor_ = std::move(*interceptor); |
||||
} else { |
||||
Interceptor* previous = top_interceptor_.get(); |
||||
while (previous->wrapped_destination_ != nullptr) { |
||||
previous = DownCast<Interceptor*>(previous->wrapped_destination_.get()); |
||||
} |
||||
previous->wrapped_destination_ = std::move(*interceptor); |
||||
} |
||||
} |
||||
|
||||
absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> |
||||
InterceptionChainBuilder::Build(FinalDestination final_destination) { |
||||
if (!status_.ok()) return status_; |
||||
// Build the final UnstartedCallDestination in the chain - what we do here
|
||||
// depends on both the type of the final destination and the filters we have
|
||||
// that haven't been captured into an Interceptor yet.
|
||||
RefCountedPtr<UnstartedCallDestination> terminator = Match( |
||||
final_destination, |
||||
[this](RefCountedPtr<UnstartedCallDestination> final_destination) |
||||
-> RefCountedPtr<UnstartedCallDestination> { |
||||
if (stack_builder_.has_value()) { |
||||
return MakeRefCounted<TerminalInterceptor>(MakeFilterStack(), |
||||
final_destination); |
||||
} |
||||
return final_destination; |
||||
}, |
||||
[this](RefCountedPtr<CallDestination> final_destination) |
||||
-> RefCountedPtr<UnstartedCallDestination> { |
||||
return MakeRefCounted<CallStarter>(MakeFilterStack(), |
||||
std::move(final_destination)); |
||||
}); |
||||
// Now append the terminator to the interceptor chain.
|
||||
if (top_interceptor_ == nullptr) { |
||||
return std::move(terminator); |
||||
} |
||||
Interceptor* previous = top_interceptor_.get(); |
||||
while (previous->wrapped_destination_ != nullptr) { |
||||
previous = DownCast<Interceptor*>(previous->wrapped_destination_.get()); |
||||
} |
||||
previous->wrapped_destination_ = std::move(terminator); |
||||
return std::move(top_interceptor_); |
||||
} |
||||
|
||||
} // namespace grpc_core
|
@ -0,0 +1,225 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H |
||||
#define GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H |
||||
|
||||
#include <memory> |
||||
#include <vector> |
||||
|
||||
#include <grpc/support/port_platform.h> |
||||
|
||||
#include "src/core/lib/gprpp/ref_counted.h" |
||||
#include "src/core/lib/transport/call_destination.h" |
||||
#include "src/core/lib/transport/call_filters.h" |
||||
#include "src/core/lib/transport/call_spine.h" |
||||
#include "src/core/lib/transport/metadata.h" |
||||
|
||||
namespace grpc_core { |
||||
|
||||
class InterceptionChainBuilder; |
||||
|
||||
// One hijacked call. Using this we can get access to the CallHandler for the
|
||||
// call object above us, the processed metadata from any filters/interceptors
|
||||
// above us, and also create new CallInterceptor objects that will be handled
|
||||
// below.
|
||||
class HijackedCall final { |
||||
public: |
||||
HijackedCall(ClientMetadataHandle metadata, |
||||
RefCountedPtr<UnstartedCallDestination> destination, |
||||
CallHandler call_handler) |
||||
: metadata_(std::move(metadata)), |
||||
destination_(std::move(destination)), |
||||
call_handler_(std::move(call_handler)) {} |
||||
|
||||
// Create a new call and pass it down the stack.
|
||||
// This can be called as many times as needed.
|
||||
CallInitiator MakeCall(); |
||||
// Per MakeCall(), but precludes creating further calls.
|
||||
// Allows us to optimize by not copying initial metadata.
|
||||
CallInitiator MakeLastCall() { |
||||
return MakeCallWithMetadata(std::move(metadata_)); |
||||
} |
||||
|
||||
CallHandler& original_call_handler() { return call_handler_; } |
||||
|
||||
ClientMetadata& client_metadata() { return *metadata_; } |
||||
|
||||
private: |
||||
CallInitiator MakeCallWithMetadata(ClientMetadataHandle metadata); |
||||
|
||||
ClientMetadataHandle metadata_; |
||||
RefCountedPtr<UnstartedCallDestination> destination_; |
||||
CallHandler call_handler_; |
||||
}; |
||||
|
||||
namespace interception_chain_detail { |
||||
|
||||
inline auto HijackCall(UnstartedCallHandler unstarted_call_handler, |
||||
RefCountedPtr<UnstartedCallDestination> destination, |
||||
RefCountedPtr<CallFilters::Stack> stack) { |
||||
auto call_handler = unstarted_call_handler.StartCall(stack); |
||||
return Map( |
||||
call_handler.PullClientInitialMetadata(), |
||||
[call_handler, |
||||
destination](ValueOrFailure<ClientMetadataHandle> metadata) mutable |
||||
-> ValueOrFailure<HijackedCall> { |
||||
if (!metadata.ok()) return Failure{}; |
||||
return HijackedCall(std::move(metadata.value()), std::move(destination), |
||||
std::move(call_handler)); |
||||
}); |
||||
} |
||||
|
||||
} // namespace interception_chain_detail
|
||||
|
||||
// A delegating UnstartedCallDestination for use as a hijacking filter.
|
||||
// Implementations may look at the unprocessed initial metadata
|
||||
// and decide to do one of two things:
|
||||
//
|
||||
// 1. It can hijack the call. Returns a HijackedCall object that can
|
||||
// be used to start new calls with the same metadata.
|
||||
//
|
||||
// 2. It can consume the call by calling `Consume`.
|
||||
//
|
||||
// Upon the StartCall call the UnstartedCallHandler will be from the last
|
||||
// *Interceptor* in the call chain (without having been processed by any
|
||||
// intervening filters) -- note that this is commonly not useful (not enough
|
||||
// guarantees), and so it's usually better to Hijack and examine the metadata.
|
||||
class Interceptor : public UnstartedCallDestination { |
||||
protected: |
||||
// Returns a promise that resolves to a HijackedCall instance.
|
||||
// Hijacking is the process of taking over a call and starting one or more new
|
||||
// ones.
|
||||
auto Hijack(UnstartedCallHandler unstarted_call_handler) { |
||||
return interception_chain_detail::HijackCall( |
||||
std::move(unstarted_call_handler), wrapped_destination_, filter_stack_); |
||||
} |
||||
|
||||
// Consume this call - it will not be passed on to any further filters.
|
||||
CallHandler Consume(UnstartedCallHandler unstarted_call_handler) { |
||||
return unstarted_call_handler.StartCall(filter_stack_); |
||||
} |
||||
|
||||
// TODO(ctiller): Consider a Passthrough() method that allows the call to be
|
||||
// passed on to the next filter in the chain without any interception by the
|
||||
// current filter.
|
||||
|
||||
private: |
||||
friend class InterceptionChainBuilder; |
||||
|
||||
RefCountedPtr<UnstartedCallDestination> wrapped_destination_; |
||||
RefCountedPtr<CallFilters::Stack> filter_stack_; |
||||
}; |
||||
|
||||
class InterceptionChainBuilder final { |
||||
public: |
||||
// The kind of destination that the chain will eventually call.
|
||||
// We can bottom out in various types depending on where we're intercepting:
|
||||
// - The top half of the client channel wants to terminate on a
|
||||
// UnstartedCallDestination (specifically the LB call destination).
|
||||
// - The bottom half of the client channel and the server code wants to
|
||||
// terminate on a ClientTransport - which unlike a
|
||||
// UnstartedCallDestination demands a started CallHandler.
|
||||
// There's some adaption code that's needed to start filters just prior
|
||||
// to the bottoming out, and some design considerations to make with that.
|
||||
// One way (that's not chosen here) would be to have the caller of the
|
||||
// Builder provide something that can build an adaptor
|
||||
// UnstartedCallDestination with parameters supplied by this builder - that
|
||||
// disperses the responsibility of building the adaptor to the caller, which
|
||||
// is not ideal - we might want to adjust the way this construct is built in
|
||||
// the future, and building is a builder responsibility.
|
||||
// Instead, we declare a relatively closed set of destinations here, and
|
||||
// hide the adaptors inside the builder at build time.
|
||||
using FinalDestination = |
||||
absl::variant<RefCountedPtr<UnstartedCallDestination>, |
||||
RefCountedPtr<CallDestination>>; |
||||
|
||||
explicit InterceptionChainBuilder(ChannelArgs args) |
||||
: args_(std::move(args)) {} |
||||
|
||||
// Add a filter with a `Call` class as an inner member.
|
||||
// Call class must be one compatible with the filters described in
|
||||
// call_filters.h.
|
||||
template <typename T> |
||||
absl::enable_if_t<sizeof(typename T::Call) != 0, InterceptionChainBuilder&> |
||||
Add() { |
||||
if (!status_.ok()) return *this; |
||||
auto filter = T::Create(args_, {FilterInstanceId(FilterTypeId<T>())}); |
||||
if (!filter.ok()) { |
||||
status_ = filter.status(); |
||||
return *this; |
||||
} |
||||
auto& sb = stack_builder(); |
||||
sb.Add(filter.value().get()); |
||||
sb.AddOwnedObject(std::move(filter.value())); |
||||
return *this; |
||||
}; |
||||
|
||||
// Add a filter that is an interceptor - one that can hijack calls.
|
||||
template <typename T> |
||||
absl::enable_if_t<std::is_base_of<Interceptor, T>::value, |
||||
InterceptionChainBuilder&> |
||||
Add() { |
||||
AddInterceptor(T::Create(args_, {FilterInstanceId(FilterTypeId<T>())})); |
||||
return *this; |
||||
}; |
||||
|
||||
// Add a filter that just mutates server trailing metadata.
|
||||
template <typename F> |
||||
void AddOnServerTrailingMetadata(F f) { |
||||
stack_builder().AddOnServerTrailingMetadata(std::move(f)); |
||||
} |
||||
|
||||
// Build this stack
|
||||
absl::StatusOr<RefCountedPtr<UnstartedCallDestination>> Build( |
||||
FinalDestination final_destination); |
||||
|
||||
const ChannelArgs& channel_args() const { return args_; } |
||||
|
||||
private: |
||||
CallFilters::StackBuilder& stack_builder() { |
||||
if (!stack_builder_.has_value()) stack_builder_.emplace(); |
||||
return *stack_builder_; |
||||
} |
||||
|
||||
RefCountedPtr<CallFilters::Stack> MakeFilterStack() { |
||||
auto stack = stack_builder().Build(); |
||||
stack_builder_.reset(); |
||||
return stack; |
||||
} |
||||
|
||||
template <typename T> |
||||
static size_t FilterTypeId() { |
||||
static const size_t id = |
||||
next_filter_id_.fetch_add(1, std::memory_order_relaxed); |
||||
return id; |
||||
} |
||||
|
||||
size_t FilterInstanceId(size_t filter_type) { |
||||
return filter_type_counts_[filter_type]++; |
||||
} |
||||
|
||||
void AddInterceptor(absl::StatusOr<RefCountedPtr<Interceptor>> interceptor); |
||||
|
||||
ChannelArgs args_; |
||||
absl::optional<CallFilters::StackBuilder> stack_builder_; |
||||
RefCountedPtr<Interceptor> top_interceptor_; |
||||
absl::Status status_; |
||||
std::map<size_t, size_t> filter_type_counts_; |
||||
static std::atomic<size_t> next_filter_id_; |
||||
}; |
||||
|
||||
} // namespace grpc_core
|
||||
|
||||
#endif // GRPC_SRC_CORE_LIB_TRANSPORT_INTERCEPTION_CHAIN_H
|
@ -0,0 +1,60 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H |
||||
#define GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H |
||||
|
||||
#include "gmock/gmock.h" |
||||
|
||||
// Various gmock matchers for Poll
|
||||
|
||||
namespace grpc_core { |
||||
|
||||
// Expect that a promise is still pending:
|
||||
// EXPECT_THAT(some_promise(), IsPending());
|
||||
MATCHER(IsPending, "") { |
||||
if (arg.ready()) { |
||||
*result_listener << "is ready"; |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
// Expect that a promise is ready:
|
||||
// EXPECT_THAT(some_promise(), IsReady());
|
||||
MATCHER(IsReady, "") { |
||||
if (arg.pending()) { |
||||
*result_listener << "is pending"; |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
// Expect that a promise is ready with a specific value:
|
||||
// EXPECT_THAT(some_promise(), IsReady(value));
|
||||
MATCHER_P(IsReady, value, "") { |
||||
if (arg.pending()) { |
||||
*result_listener << "is pending"; |
||||
return false; |
||||
} |
||||
if (arg.value() != value) { |
||||
*result_listener << "is " << ::testing::PrintToString(arg.value()); |
||||
return false; |
||||
} |
||||
return true; |
||||
} |
||||
|
||||
} // namespace grpc_core
|
||||
|
||||
#endif // GRPC_TEST_CORE_PROMISE_POLL_MATCHER_H
|
@ -0,0 +1,406 @@ |
||||
// Copyright 2024 gRPC authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/core/lib/transport/interception_chain.h" |
||||
|
||||
#include <memory> |
||||
|
||||
#include "gmock/gmock.h" |
||||
#include "gtest/gtest.h" |
||||
|
||||
#include <grpc/support/log.h> |
||||
|
||||
#include "src/core/lib/channel/promise_based_filter.h" |
||||
#include "src/core/lib/resource_quota/resource_quota.h" |
||||
#include "test/core/promise/poll_matcher.h" |
||||
|
||||
namespace grpc_core { |
||||
namespace { |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Mutate metadata by annotating that it passed through a filter "x"
|
||||
|
||||
void AnnotatePassedThrough(ClientMetadata& md, int x) { |
||||
md.Append(absl::StrCat("passed-through-", x), Slice::FromCopiedString("true"), |
||||
[](absl::string_view, const Slice&) { Crash("unreachable"); }); |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CreationLog helps us reason about filter creation order by logging a small
|
||||
// record of each filter's creation.
|
||||
|
||||
struct CreationLogEntry { |
||||
size_t filter_instance_id; |
||||
size_t type_tag; |
||||
|
||||
bool operator==(const CreationLogEntry& other) const { |
||||
return filter_instance_id == other.filter_instance_id && |
||||
type_tag == other.type_tag; |
||||
} |
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, |
||||
const CreationLogEntry& entry) { |
||||
return os << "{filter_instance_id=" << entry.filter_instance_id |
||||
<< ", type_tag=" << entry.type_tag << "}"; |
||||
} |
||||
}; |
||||
|
||||
struct CreationLog { |
||||
struct RawPointerChannelArgTag {}; |
||||
static absl::string_view ChannelArgName() { return "creation_log"; } |
||||
std::vector<CreationLogEntry> entries; |
||||
}; |
||||
|
||||
void MaybeLogCreation(const ChannelArgs& channel_args, |
||||
ChannelFilter::Args filter_args, size_t type_tag) { |
||||
auto* log = channel_args.GetPointer<CreationLog>("creation_log"); |
||||
if (log == nullptr) return; |
||||
log->entries.push_back(CreationLogEntry{filter_args.instance_id(), type_tag}); |
||||
} |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call filter
|
||||
|
||||
template <int I> |
||||
class TestFilter { |
||||
public: |
||||
class Call { |
||||
public: |
||||
void OnClientInitialMetadata(ClientMetadata& md) { |
||||
AnnotatePassedThrough(md, I); |
||||
} |
||||
static const NoInterceptor OnServerInitialMetadata; |
||||
static const NoInterceptor OnClientToServerMessage; |
||||
static const NoInterceptor OnServerToClientMessage; |
||||
static const NoInterceptor OnServerTrailingMetadata; |
||||
static const NoInterceptor OnFinalize; |
||||
}; |
||||
|
||||
static absl::StatusOr<std::unique_ptr<TestFilter<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return std::make_unique<TestFilter<I>>(); |
||||
} |
||||
|
||||
private: |
||||
std::unique_ptr<int> i_ = std::make_unique<int>(I); |
||||
}; |
||||
|
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnServerInitialMetadata; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnClientToServerMessage; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnServerToClientMessage; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnServerTrailingMetadata; |
||||
template <int I> |
||||
const NoInterceptor TestFilter<I>::Call::OnFinalize; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call filter that fails to instantiate
|
||||
|
||||
template <int I> |
||||
class FailsToInstantiateFilter { |
||||
public: |
||||
class Call { |
||||
public: |
||||
static const NoInterceptor OnClientInitialMetadata; |
||||
static const NoInterceptor OnServerInitialMetadata; |
||||
static const NoInterceptor OnClientToServerMessage; |
||||
static const NoInterceptor OnServerToClientMessage; |
||||
static const NoInterceptor OnServerTrailingMetadata; |
||||
static const NoInterceptor OnFinalize; |
||||
}; |
||||
|
||||
static absl::StatusOr<std::unique_ptr<FailsToInstantiateFilter<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return absl::InternalError(absl::StrCat("👊 failed to instantiate ", I)); |
||||
} |
||||
}; |
||||
|
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientInitialMetadata; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerInitialMetadata; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientToServerMessage; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerToClientMessage; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerTrailingMetadata; |
||||
template <int I> |
||||
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnFinalize; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call interceptor - consumes calls
|
||||
|
||||
template <int I> |
||||
class TestConsumingInterceptor final : public Interceptor { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
Consume(std::move(unstarted_call_handler)) |
||||
.PushServerTrailingMetadata( |
||||
ServerMetadataFromStatus(absl::InternalError("👊 consumed"))); |
||||
} |
||||
void Orphaned() override {} |
||||
static absl::StatusOr<RefCountedPtr<TestConsumingInterceptor<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return MakeRefCounted<TestConsumingInterceptor<I>>(); |
||||
} |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call interceptor - fails to instantiate
|
||||
|
||||
template <int I> |
||||
class TestFailingInterceptor final : public Interceptor { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
Crash("unreachable"); |
||||
} |
||||
void Orphaned() override {} |
||||
static absl::StatusOr<RefCountedPtr<TestFailingInterceptor<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return absl::InternalError(absl::StrCat("👊 failed to instantiate ", I)); |
||||
} |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test call interceptor - hijacks calls
|
||||
|
||||
template <int I> |
||||
class TestHijackingInterceptor final : public Interceptor { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
unstarted_call_handler.SpawnInfallible( |
||||
"hijack", [this, unstarted_call_handler]() mutable { |
||||
return Map(Hijack(std::move(unstarted_call_handler)), |
||||
[](ValueOrFailure<HijackedCall> hijacked_call) { |
||||
ForwardCall( |
||||
hijacked_call.value().original_call_handler(), |
||||
hijacked_call.value().MakeCall()); |
||||
return Empty{}; |
||||
}); |
||||
}); |
||||
} |
||||
void Orphaned() override {} |
||||
static absl::StatusOr<RefCountedPtr<TestHijackingInterceptor<I>>> Create( |
||||
const ChannelArgs& channel_args, ChannelFilter::Args filter_args) { |
||||
MaybeLogCreation(channel_args, filter_args, I); |
||||
return MakeRefCounted<TestHijackingInterceptor<I>>(); |
||||
} |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Test fixture
|
||||
|
||||
class InterceptionChainTest : public ::testing::Test { |
||||
protected: |
||||
InterceptionChainTest() {} |
||||
~InterceptionChainTest() override {} |
||||
|
||||
RefCountedPtr<UnstartedCallDestination> destination() { return destination_; } |
||||
|
||||
struct FinishedCall { |
||||
CallInitiator call; |
||||
ClientMetadataHandle client_metadata; |
||||
ServerMetadataHandle server_metadata; |
||||
}; |
||||
|
||||
// Run a call through a UnstartedCallDestination until it's complete.
|
||||
FinishedCall RunCall(UnstartedCallDestination* destination) { |
||||
auto* arena = call_arena_allocator_->MakeArena(); |
||||
auto call = MakeCallPair(Arena::MakePooled<ClientMetadata>(), nullptr, |
||||
arena, call_arena_allocator_, nullptr); |
||||
Poll<ServerMetadataHandle> trailing_md; |
||||
call.initiator.SpawnInfallible( |
||||
"run_call", [destination, &call, &trailing_md]() mutable { |
||||
gpr_log(GPR_INFO, "👊 start call"); |
||||
destination->StartCall(std::move(call.handler)); |
||||
return Map(call.initiator.PullServerTrailingMetadata(), |
||||
[&trailing_md](ServerMetadataHandle md) { |
||||
trailing_md = std::move(md); |
||||
return Empty{}; |
||||
}); |
||||
}); |
||||
EXPECT_THAT(trailing_md, IsReady()); |
||||
return FinishedCall{std::move(call.initiator), destination_->TakeMetadata(), |
||||
std::move(trailing_md.value())}; |
||||
} |
||||
|
||||
private: |
||||
class Destination final : public UnstartedCallDestination { |
||||
public: |
||||
void StartCall(UnstartedCallHandler unstarted_call_handler) override { |
||||
gpr_log(GPR_INFO, "👊 started call: metadata=%s", |
||||
unstarted_call_handler.UnprocessedClientInitialMetadata() |
||||
.DebugString() |
||||
.c_str()); |
||||
EXPECT_EQ(metadata_.get(), nullptr); |
||||
metadata_ = Arena::MakePooled<ClientMetadata>(); |
||||
*metadata_ = |
||||
unstarted_call_handler.UnprocessedClientInitialMetadata().Copy(); |
||||
unstarted_call_handler.PushServerTrailingMetadata( |
||||
ServerMetadataFromStatus(absl::InternalError("👊 cancelled"))); |
||||
} |
||||
|
||||
void Orphaned() override {} |
||||
|
||||
ClientMetadataHandle TakeMetadata() { return std::move(metadata_); } |
||||
|
||||
private: |
||||
ClientMetadataHandle metadata_; |
||||
}; |
||||
RefCountedPtr<Destination> destination_ = MakeRefCounted<Destination>(); |
||||
RefCountedPtr<CallArenaAllocator> call_arena_allocator_ = |
||||
MakeRefCounted<CallArenaAllocator>( |
||||
ResourceQuota::Default()->memory_quota()->CreateMemoryAllocator( |
||||
"test"), |
||||
1024); |
||||
}; |
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Tests begin
|
||||
|
||||
TEST_F(InterceptionChainTest, Empty) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()).Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 cancelled"); |
||||
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, Consumed) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestConsumingInterceptor<1>>() |
||||
.Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 consumed"); |
||||
EXPECT_EQ(finished_call.client_metadata, nullptr); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, Hijacked) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestHijackingInterceptor<1>>() |
||||
.Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 cancelled"); |
||||
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FiltersThenHijacked) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestHijackingInterceptor<2>>() |
||||
.Build(destination()); |
||||
ASSERT_TRUE(r.ok()) << r.status(); |
||||
auto finished_call = RunCall(r.value().get()); |
||||
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()), |
||||
GRPC_STATUS_INTERNAL); |
||||
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata()) |
||||
->as_string_view(), |
||||
"👊 cancelled"); |
||||
EXPECT_NE(finished_call.client_metadata, nullptr); |
||||
std::string backing; |
||||
EXPECT_EQ(finished_call.client_metadata->GetStringValue("passed-through-1", |
||||
&backing), |
||||
"true"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFailingInterceptor<1>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 1"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateInterceptor2) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestFailingInterceptor<2>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 2"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateFilter) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<FailsToInstantiateFilter<1>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 1"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, FailsToInstantiateFilter2) { |
||||
auto r = InterceptionChainBuilder(ChannelArgs()) |
||||
.Add<TestFilter<1>>() |
||||
.Add<FailsToInstantiateFilter<2>>() |
||||
.Build(destination()); |
||||
EXPECT_FALSE(r.ok()); |
||||
EXPECT_EQ(r.status().code(), absl::StatusCode::kInternal); |
||||
EXPECT_EQ(r.status().message(), "👊 failed to instantiate 2"); |
||||
} |
||||
|
||||
TEST_F(InterceptionChainTest, CreationOrderCorrect) { |
||||
CreationLog log; |
||||
auto r = InterceptionChainBuilder(ChannelArgs().SetObject(&log)) |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestFilter<2>>() |
||||
.Add<TestFilter<3>>() |
||||
.Add<TestConsumingInterceptor<4>>() |
||||
.Add<TestFilter<1>>() |
||||
.Add<TestFilter<2>>() |
||||
.Add<TestFilter<3>>() |
||||
.Add<TestConsumingInterceptor<4>>() |
||||
.Add<TestFilter<1>>() |
||||
.Build(destination()); |
||||
EXPECT_THAT(log.entries, ::testing::ElementsAre( |
||||
CreationLogEntry{0, 1}, CreationLogEntry{0, 2}, |
||||
CreationLogEntry{0, 3}, CreationLogEntry{0, 4}, |
||||
CreationLogEntry{1, 1}, CreationLogEntry{1, 2}, |
||||
CreationLogEntry{1, 3}, CreationLogEntry{1, 4}, |
||||
CreationLogEntry{2, 1})); |
||||
} |
||||
|
||||
} // namespace
|
||||
} // namespace grpc_core
|
||||
|
||||
int main(int argc, char** argv) { |
||||
::testing::InitGoogleTest(&argc, argv); |
||||
grpc_tracer_init(); |
||||
gpr_log_verbosity_init(); |
||||
return RUN_ALL_TESTS(); |
||||
} |
Loading…
Reference in new issue