|
|
|
@ -50,32 +50,32 @@ class InternalInterceptorBatchMethods |
|
|
|
|
|
|
|
|
|
virtual void SetRecvMessage(void* message) = 0; |
|
|
|
|
|
|
|
|
|
virtual void SetRecvInitialMetadata(internal::MetadataMap* map) = 0; |
|
|
|
|
virtual void SetRecvInitialMetadata(MetadataMap* map) = 0; |
|
|
|
|
|
|
|
|
|
virtual void SetRecvStatus(Status* status) = 0; |
|
|
|
|
|
|
|
|
|
virtual void SetRecvTrailingMetadata(internal::MetadataMap* map) = 0; |
|
|
|
|
virtual void SetRecvTrailingMetadata(MetadataMap* map) = 0; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
public: |
|
|
|
|
InterceptorBatchMethodsImpl() { |
|
|
|
|
for (auto i = 0; |
|
|
|
|
i < static_cast<int>( |
|
|
|
|
experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS); |
|
|
|
|
i++) { |
|
|
|
|
hooks_[i] = false; |
|
|
|
|
for (auto i = static_cast<experimental::InterceptionHookPoints>(0); |
|
|
|
|
i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; |
|
|
|
|
i = static_cast<experimental::InterceptionHookPoints>( |
|
|
|
|
static_cast<size_t>(i) + 1)) { |
|
|
|
|
hooks_[static_cast<size_t>(i)] = false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual ~InterceptorBatchMethodsImpl() {} |
|
|
|
|
~InterceptorBatchMethodsImpl() {} |
|
|
|
|
|
|
|
|
|
virtual bool QueryInterceptionHookPoint( |
|
|
|
|
bool QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints type) override { |
|
|
|
|
return hooks_[static_cast<int>(type)]; |
|
|
|
|
return hooks_[static_cast<size_t>(type)]; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void Proceed() override { /* fill this */ |
|
|
|
|
void Proceed() override { /* fill this */ |
|
|
|
|
if (call_->client_rpc_info() != nullptr) { |
|
|
|
|
return ProceedClient(); |
|
|
|
|
} |
|
|
|
@ -83,7 +83,7 @@ class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
ProceedServer(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void Hijack() override { |
|
|
|
|
void Hijack() override { |
|
|
|
|
// Only the client can hijack when sending down initial metadata
|
|
|
|
|
GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && |
|
|
|
|
call_->client_rpc_info() != nullptr); |
|
|
|
@ -91,99 +91,94 @@ class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); |
|
|
|
|
auto* rpc_info = call_->client_rpc_info(); |
|
|
|
|
rpc_info->hijacked_ = true; |
|
|
|
|
rpc_info->hijacked_interceptor_ = curr_iteration_; |
|
|
|
|
rpc_info->hijacked_interceptor_ = current_interceptor_index_; |
|
|
|
|
ClearHookPoints(); |
|
|
|
|
ops_->SetHijackingState(); |
|
|
|
|
ran_hijacking_interceptor_ = true; |
|
|
|
|
rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void AddInterceptionHookPoint( |
|
|
|
|
void AddInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints type) override { |
|
|
|
|
hooks_[static_cast<int>(type)] = true; |
|
|
|
|
hooks_[static_cast<size_t>(type)] = true; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual ByteBuffer* GetSendMessage() override { return send_message_; } |
|
|
|
|
ByteBuffer* GetSendMessage() override { return send_message_; } |
|
|
|
|
|
|
|
|
|
virtual std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() |
|
|
|
|
override { |
|
|
|
|
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { |
|
|
|
|
return send_initial_metadata_; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual Status GetSendStatus() override { |
|
|
|
|
Status GetSendStatus() override { |
|
|
|
|
return Status(static_cast<StatusCode>(*code_), *error_message_, |
|
|
|
|
*error_details_); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void ModifySendStatus(const Status& status) override { |
|
|
|
|
void ModifySendStatus(const Status& status) override { |
|
|
|
|
*code_ = static_cast<grpc_status_code>(status.error_code()); |
|
|
|
|
*error_details_ = status.error_details(); |
|
|
|
|
*error_message_ = status.error_message(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata() |
|
|
|
|
std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata() |
|
|
|
|
override { |
|
|
|
|
return send_trailing_metadata_; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void* GetRecvMessage() override { return recv_message_; } |
|
|
|
|
void* GetRecvMessage() override { return recv_message_; } |
|
|
|
|
|
|
|
|
|
virtual std::multimap<grpc::string_ref, grpc::string_ref>* |
|
|
|
|
GetRecvInitialMetadata() override { |
|
|
|
|
std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() |
|
|
|
|
override { |
|
|
|
|
return recv_initial_metadata_->map(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual Status* GetRecvStatus() override { return recv_status_; } |
|
|
|
|
Status* GetRecvStatus() override { return recv_status_; } |
|
|
|
|
|
|
|
|
|
virtual std::multimap<grpc::string_ref, grpc::string_ref>* |
|
|
|
|
GetRecvTrailingMetadata() override { |
|
|
|
|
std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() |
|
|
|
|
override { |
|
|
|
|
return recv_trailing_metadata_->map(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void SetSendMessage(ByteBuffer* buf) override { send_message_ = buf; } |
|
|
|
|
void SetSendMessage(ByteBuffer* buf) override { send_message_ = buf; } |
|
|
|
|
|
|
|
|
|
virtual void SetSendInitialMetadata( |
|
|
|
|
void SetSendInitialMetadata( |
|
|
|
|
std::multimap<grpc::string, grpc::string>* metadata) override { |
|
|
|
|
send_initial_metadata_ = metadata; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void SetSendStatus(grpc_status_code* code, |
|
|
|
|
grpc::string* error_details, |
|
|
|
|
grpc::string* error_message) override { |
|
|
|
|
void SetSendStatus(grpc_status_code* code, grpc::string* error_details, |
|
|
|
|
grpc::string* error_message) override { |
|
|
|
|
code_ = code; |
|
|
|
|
error_details_ = error_details; |
|
|
|
|
error_message_ = error_message; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void SetSendTrailingMetadata( |
|
|
|
|
void SetSendTrailingMetadata( |
|
|
|
|
std::multimap<grpc::string, grpc::string>* metadata) override { |
|
|
|
|
send_trailing_metadata_ = metadata; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void SetRecvMessage(void* message) override { |
|
|
|
|
recv_message_ = message; |
|
|
|
|
} |
|
|
|
|
void SetRecvMessage(void* message) override { recv_message_ = message; } |
|
|
|
|
|
|
|
|
|
virtual void SetRecvInitialMetadata(internal::MetadataMap* map) override { |
|
|
|
|
void SetRecvInitialMetadata(MetadataMap* map) override { |
|
|
|
|
recv_initial_metadata_ = map; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void SetRecvStatus(Status* status) override { recv_status_ = status; } |
|
|
|
|
void SetRecvStatus(Status* status) override { recv_status_ = status; } |
|
|
|
|
|
|
|
|
|
virtual void SetRecvTrailingMetadata(internal::MetadataMap* map) override { |
|
|
|
|
void SetRecvTrailingMetadata(MetadataMap* map) override { |
|
|
|
|
recv_trailing_metadata_ = map; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { |
|
|
|
|
std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { |
|
|
|
|
auto* info = call_->client_rpc_info(); |
|
|
|
|
if (info == nullptr) { |
|
|
|
|
return std::unique_ptr<ChannelInterface>(nullptr); |
|
|
|
|
} |
|
|
|
|
// The intercepted channel starts from the interceptor just after the
|
|
|
|
|
// current interceptor
|
|
|
|
|
return std::unique_ptr<ChannelInterface>(new internal::InterceptedChannel( |
|
|
|
|
reinterpret_cast<grpc::ChannelInterface*>(info->channel()), |
|
|
|
|
curr_iteration_ + 1)); |
|
|
|
|
return std::unique_ptr<ChannelInterface>(new InterceptedChannel( |
|
|
|
|
info->channel(), current_interceptor_index_ + 1)); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Clears all state
|
|
|
|
@ -256,60 +251,63 @@ class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
void RunClientInterceptors() { |
|
|
|
|
auto* rpc_info = call_->client_rpc_info(); |
|
|
|
|
if (!reverse_) { |
|
|
|
|
curr_iteration_ = 0; |
|
|
|
|
current_interceptor_index_ = 0; |
|
|
|
|
} else { |
|
|
|
|
if (rpc_info->hijacked_) { |
|
|
|
|
curr_iteration_ = rpc_info->hijacked_interceptor_; |
|
|
|
|
current_interceptor_index_ = rpc_info->hijacked_interceptor_; |
|
|
|
|
} else { |
|
|
|
|
curr_iteration_ = rpc_info->interceptors_.size() - 1; |
|
|
|
|
current_interceptor_index_ = rpc_info->interceptors_.size() - 1; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void RunServerInterceptors() { |
|
|
|
|
auto* rpc_info = call_->server_rpc_info(); |
|
|
|
|
if (!reverse_) { |
|
|
|
|
curr_iteration_ = 0; |
|
|
|
|
current_interceptor_index_ = 0; |
|
|
|
|
} else { |
|
|
|
|
curr_iteration_ = rpc_info->interceptors_.size() - 1; |
|
|
|
|
current_interceptor_index_ = rpc_info->interceptors_.size() - 1; |
|
|
|
|
} |
|
|
|
|
rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void ProceedClient() { |
|
|
|
|
auto* rpc_info = call_->client_rpc_info(); |
|
|
|
|
if (rpc_info->hijacked_ && !reverse_ && |
|
|
|
|
curr_iteration_ == rpc_info->hijacked_interceptor_ && |
|
|
|
|
static_cast<size_t>(current_interceptor_index_) == |
|
|
|
|
rpc_info->hijacked_interceptor_ && |
|
|
|
|
!ran_hijacking_interceptor_) { |
|
|
|
|
// We now need to provide hijacked recv ops to this interceptor
|
|
|
|
|
ClearHookPoints(); |
|
|
|
|
ops_->SetHijackingState(); |
|
|
|
|
ran_hijacking_interceptor_ = true; |
|
|
|
|
rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
return; |
|
|
|
|
} |
|
|
|
|
if (!reverse_) { |
|
|
|
|
curr_iteration_++; |
|
|
|
|
current_interceptor_index_++; |
|
|
|
|
// We are going down the stack of interceptors
|
|
|
|
|
if (curr_iteration_ < static_cast<long>(rpc_info->interceptors_.size())) { |
|
|
|
|
if (static_cast<size_t>(current_interceptor_index_) < |
|
|
|
|
rpc_info->interceptors_.size()) { |
|
|
|
|
if (rpc_info->hijacked_ && |
|
|
|
|
curr_iteration_ > rpc_info->hijacked_interceptor_) { |
|
|
|
|
static_cast<size_t>(current_interceptor_index_) > |
|
|
|
|
rpc_info->hijacked_interceptor_) { |
|
|
|
|
// This is a hijacked RPC and we are done with hijacking
|
|
|
|
|
ops_->ContinueFillOpsAfterInterception(); |
|
|
|
|
} else { |
|
|
|
|
rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
// we are done running all the interceptors without any hijacking
|
|
|
|
|
ops_->ContinueFillOpsAfterInterception(); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
curr_iteration_--; |
|
|
|
|
current_interceptor_index_--; |
|
|
|
|
// We are going up the stack of interceptors
|
|
|
|
|
if (curr_iteration_ >= 0) { |
|
|
|
|
if (current_interceptor_index_ >= 0) { |
|
|
|
|
// Continue running interceptors
|
|
|
|
|
rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} else { |
|
|
|
|
// we are done running all the interceptors without any hijacking
|
|
|
|
|
ops_->ContinueFinalizeResultAfterInterception(); |
|
|
|
@ -320,18 +318,19 @@ class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
void ProceedServer() { |
|
|
|
|
auto* rpc_info = call_->server_rpc_info(); |
|
|
|
|
if (!reverse_) { |
|
|
|
|
curr_iteration_++; |
|
|
|
|
if (curr_iteration_ < static_cast<long>(rpc_info->interceptors_.size())) { |
|
|
|
|
return rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
current_interceptor_index_++; |
|
|
|
|
if (static_cast<size_t>(current_interceptor_index_) < |
|
|
|
|
rpc_info->interceptors_.size()) { |
|
|
|
|
return rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} else if (ops_) { |
|
|
|
|
return ops_->ContinueFillOpsAfterInterception(); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
curr_iteration_--; |
|
|
|
|
current_interceptor_index_--; |
|
|
|
|
// We are going up the stack of interceptors
|
|
|
|
|
if (curr_iteration_ >= 0) { |
|
|
|
|
if (current_interceptor_index_ >= 0) { |
|
|
|
|
// Continue running interceptors
|
|
|
|
|
return rpc_info->RunInterceptor(this, curr_iteration_); |
|
|
|
|
return rpc_info->RunInterceptor(this, current_interceptor_index_); |
|
|
|
|
} else if (ops_) { |
|
|
|
|
return ops_->ContinueFinalizeResultAfterInterception(); |
|
|
|
|
} |
|
|
|
@ -341,20 +340,20 @@ class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
void ClearHookPoints() { |
|
|
|
|
for (auto i = 0; |
|
|
|
|
i < static_cast<int>( |
|
|
|
|
experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS); |
|
|
|
|
i++) { |
|
|
|
|
hooks_[i] = false; |
|
|
|
|
for (auto i = static_cast<experimental::InterceptionHookPoints>(0); |
|
|
|
|
i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; |
|
|
|
|
i = static_cast<experimental::InterceptionHookPoints>( |
|
|
|
|
static_cast<size_t>(i) + 1)) { |
|
|
|
|
hooks_[static_cast<size_t>(i)] = false; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
std::array<bool, |
|
|
|
|
static_cast<int>( |
|
|
|
|
static_cast<size_t>( |
|
|
|
|
experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> |
|
|
|
|
hooks_; |
|
|
|
|
|
|
|
|
|
int curr_iteration_ = 0; // Current iterator
|
|
|
|
|
long current_interceptor_index_ = 0; // Current iterator
|
|
|
|
|
bool reverse_ = false; |
|
|
|
|
bool ran_hijacking_interceptor_ = false; |
|
|
|
|
Call* call_ = |
|
|
|
@ -375,11 +374,11 @@ class InterceptorBatchMethodsImpl : public InternalInterceptorBatchMethods { |
|
|
|
|
|
|
|
|
|
void* recv_message_ = nullptr; |
|
|
|
|
|
|
|
|
|
internal::MetadataMap* recv_initial_metadata_ = nullptr; |
|
|
|
|
MetadataMap* recv_initial_metadata_ = nullptr; |
|
|
|
|
|
|
|
|
|
Status* recv_status_ = nullptr; |
|
|
|
|
|
|
|
|
|
internal::MetadataMap* recv_trailing_metadata_ = nullptr; |
|
|
|
|
MetadataMap* recv_trailing_metadata_ = nullptr; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
} // namespace internal
|
|
|
|
|