Pass channel args to ChannelData ctor and ChannelData to CallData ctor.

pull/6915/head
Mark D. Roth 9 years ago
parent 8a1a5976c7
commit c008b33c18
  1. 79
      include/grpc++/channel_filter.h
  2. 2
      src/cpp/common/channel_filter.cc
  3. 15
      test/cpp/end2end/filter_end2end_test.cc

@ -53,6 +53,20 @@
namespace grpc { namespace grpc {
// Represents channel data.
// Note: Must be copyable.
class ChannelData {
public:
virtual ~ChannelData() {}
virtual void StartTransportOp(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
grpc_transport_op *op);
protected:
explicit ChannelData(const grpc_channel_args&) {}
};
// Represents call data. // Represents call data.
// Note: Must be copyable. // Note: Must be copyable.
class CallData { class CallData {
@ -70,21 +84,7 @@ class CallData {
virtual char* GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem); virtual char* GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem);
protected: protected:
CallData() {} explicit CallData(const ChannelData&) {}
};
// Represents channel data.
// Note: Must be copyable.
class ChannelData {
public:
virtual ~ChannelData() {}
virtual void StartTransportOp(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
grpc_transport_op *op);
protected:
ChannelData() {}
}; };
namespace internal { namespace internal {
@ -93,13 +93,35 @@ namespace internal {
template<typename ChannelDataType, typename CallDataType> template<typename ChannelDataType, typename CallDataType>
class ChannelFilter { class ChannelFilter {
public: public:
static const size_t channel_data_size = sizeof(ChannelDataType);
static void InitChannelElement(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
grpc_channel_element_args *args) {
// Construct the object in the already-allocated memory.
new (elem->channel_data) ChannelDataType(*args->channel_args);
}
static void DestroyChannelElement(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {
reinterpret_cast<ChannelDataType*>(elem->channel_data)->~ChannelDataType();
}
static void StartTransportOp(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
grpc_transport_op *op) {
ChannelDataType* channel_data = (ChannelDataType*)elem->channel_data;
channel_data->StartTransportOp(exec_ctx, elem, op);
}
static const size_t call_data_size = sizeof(CallDataType); static const size_t call_data_size = sizeof(CallDataType);
static void InitCallElement( static void InitCallElement(
grpc_exec_ctx *exec_ctx, grpc_call_element *elem, grpc_exec_ctx *exec_ctx, grpc_call_element *elem,
grpc_call_element_args *args) { grpc_call_element_args *args) {
const ChannelDataType& channel_data = *(ChannelDataType*)elem->channel_data;
// Construct the object in the already-allocated memory. // Construct the object in the already-allocated memory.
new (elem->call_data) CallDataType(); new (elem->call_data) CallDataType(channel_data);
} }
static void DestroyCallElement( static void DestroyCallElement(
@ -127,33 +149,12 @@ class ChannelFilter {
CallDataType* call_data = (CallDataType*)elem->call_data; CallDataType* call_data = (CallDataType*)elem->call_data;
return call_data->GetPeer(exec_ctx, elem); return call_data->GetPeer(exec_ctx, elem);
} }
static const size_t channel_data_size = sizeof(ChannelDataType);
static void InitChannelElement(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
grpc_channel_element_args *args) {
// Construct the object in the already-allocated memory.
new (elem->channel_data) ChannelDataType();
}
static void DestroyChannelElement(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {
reinterpret_cast<ChannelDataType*>(elem->channel_data)->~ChannelDataType();
}
static void StartTransportOp(
grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
grpc_transport_op *op) {
ChannelDataType* channel_data = (ChannelDataType*)elem->channel_data;
channel_data->StartTransportOp(exec_ctx, elem, op);
}
}; };
struct FilterRecord { struct FilterRecord {
grpc_channel_stack_type stack_type; grpc_channel_stack_type stack_type;
int priority; int priority;
std::function<bool(const grpc_channel_args*)> include_filter; std::function<bool(const grpc_channel_args&)> include_filter;
grpc_channel_filter filter; grpc_channel_filter filter;
}; };
extern std::vector<FilterRecord>* channel_filters; extern std::vector<FilterRecord>* channel_filters;
@ -171,7 +172,7 @@ void ChannelFilterPluginShutdown();
template<typename ChannelDataType, typename CallDataType> template<typename ChannelDataType, typename CallDataType>
void RegisterChannelFilter( void RegisterChannelFilter(
const char* name, grpc_channel_stack_type stack_type, int priority, const char* name, grpc_channel_stack_type stack_type, int priority,
std::function<bool(const grpc_channel_args*)> include_filter) { std::function<bool(const grpc_channel_args&)> include_filter) {
// If we haven't been called before, initialize channel_filters and // If we haven't been called before, initialize channel_filters and
// call grpc_register_plugin(). // call grpc_register_plugin().
if (internal::channel_filters == nullptr) { if (internal::channel_filters == nullptr) {

@ -83,7 +83,7 @@ bool MaybeAddFilter(grpc_channel_stack_builder* builder, void* arg) {
if (filter.include_filter != nullptr) { if (filter.include_filter != nullptr) {
const grpc_channel_args *args = const grpc_channel_args *args =
grpc_channel_stack_builder_get_channel_arguments(builder); grpc_channel_stack_builder_get_channel_arguments(builder);
if (!filter.include_filter(args)) if (!filter.include_filter(*args))
return true; return true;
} }
return grpc_channel_stack_builder_prepend_filter( return grpc_channel_stack_builder_prepend_filter(

@ -95,9 +95,16 @@ int GetCounterValue() {
} // namespace } // namespace
class ChannelDataImpl : public ChannelData {
public:
explicit ChannelDataImpl(const grpc_channel_args& args) : ChannelData(args) {}
virtual ~ChannelDataImpl() {}
};
class CallDataImpl : public CallData { class CallDataImpl : public CallData {
public: public:
CallDataImpl() {} explicit CallDataImpl(const ChannelDataImpl& channel_data)
: CallData(channel_data) {}
virtual ~CallDataImpl() {} virtual ~CallDataImpl() {}
void StartTransportStreamOp( void StartTransportStreamOp(
@ -109,12 +116,6 @@ class CallDataImpl : public CallData {
} }
}; };
class ChannelDataImpl : public ChannelData {
public:
ChannelDataImpl() {}
virtual ~ChannelDataImpl() {}
};
class FilterEnd2endTest : public ::testing::Test { class FilterEnd2endTest : public ::testing::Test {
protected: protected:
FilterEnd2endTest() : server_host_("localhost") {} FilterEnd2endTest() : server_host_("localhost") {}

Loading…
Cancel
Save