From beac88ca56f4710e86668f2cbbd80e02e0607f9c Mon Sep 17 00:00:00 2001 From: David Garcia Quintas Date: Mon, 10 Aug 2015 13:39:52 -0700 Subject: [PATCH] Server: added the ability to disable compression algorithm --- include/grpc++/server.h | 3 +- include/grpc++/server_builder.h | 23 +++++++------ include/grpc/compression.h | 21 ++++++++++++ src/core/channel/channel_args.c | 27 +++++++-------- src/core/channel/channel_args.h | 15 +++++---- src/core/channel/compress_filter.c | 28 +++++++++++++++- src/core/compression/algorithm.c | 23 +++++++++++++ src/cpp/server/server.cc | 25 +++++++++----- src/cpp/server/server_builder.cc | 53 ++++++++++++++++++++++-------- 9 files changed, 164 insertions(+), 54 deletions(-) diff --git a/include/grpc++/server.h b/include/grpc++/server.h index 94ee0b6a4ac..07dbd7fd202 100644 --- a/include/grpc++/server.h +++ b/include/grpc++/server.h @@ -43,6 +43,7 @@ #include #include #include +#include struct grpc_server; @@ -81,7 +82,7 @@ class Server GRPC_FINAL : public GrpcLibrary, private CallHook { // ServerBuilder use only Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned, - int max_message_size); + int max_message_size, grpc_compression_options compression_options); // Register a service. This call does not take ownership of the service. // The service must exist for the lifetime of the Server instance. bool RegisterService(const grpc::string *host, RpcService* service); diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h index 44ee00eec9c..47efbb78346 100644 --- a/include/grpc++/server_builder.h +++ b/include/grpc++/server_builder.h @@ -37,6 +37,7 @@ #include #include +#include #include namespace grpc { @@ -59,24 +60,24 @@ class ServerBuilder { // The service must exist for the lifetime of the Server instance returned by // BuildAndStart(). // Matches requests with any :authority - void RegisterService(SynchronousService* service); + ServerBuilder& RegisterService(SynchronousService* service); // Register an asynchronous service. // This call does not take ownership of the service or completion queue. // The service and completion queuemust exist for the lifetime of the Server // instance returned by BuildAndStart(). // Matches requests with any :authority - void RegisterAsyncService(AsynchronousService* service); + ServerBuilder& RegisterAsyncService(AsynchronousService* service); // Register a generic service. // Matches requests with any :authority - void RegisterAsyncGenericService(AsyncGenericService* service); + ServerBuilder& RegisterAsyncGenericService(AsyncGenericService* service); // Register a service. This call does not take ownership of the service. // The service must exist for the lifetime of the Server instance returned by // BuildAndStart(). // Only matches requests with :authority \a host - void RegisterService(const grpc::string& host, + ServerBuilder& RegisterService(const grpc::string& host, SynchronousService* service); // Register an asynchronous service. @@ -84,22 +85,23 @@ class ServerBuilder { // The service and completion queuemust exist for the lifetime of the Server // instance returned by BuildAndStart(). // Only matches requests with :authority \a host - void RegisterAsyncService(const grpc::string& host, + ServerBuilder& RegisterAsyncService(const grpc::string& host, AsynchronousService* service); // Set max message size in bytes. - void SetMaxMessageSize(int max_message_size) { - max_message_size_ = max_message_size; - } + ServerBuilder& SetMaxMessageSize(int max_message_size); // Add a listening port. Can be called multiple times. - void AddListeningPort(const grpc::string& addr, + ServerBuilder& AddListeningPort(const grpc::string& addr, std::shared_ptr creds, int* selected_port = nullptr); // Set the thread pool used for running appliation rpc handlers. // Does not take ownership. - void SetThreadPool(ThreadPoolInterface* thread_pool); + ServerBuilder& SetThreadPool(ThreadPoolInterface* thread_pool); + + // Set the compression options to be used by the server. + ServerBuilder& SetCompressionOptions(const grpc_compression_options& options); // Add a completion queue for handling asynchronous services // Caller is required to keep this completion queue live until calling @@ -126,6 +128,7 @@ class ServerBuilder { }; int max_message_size_; + grpc_compression_options compression_options_; std::vector>> services_; std::vector>> async_services_; std::vector ports_; diff --git a/include/grpc/compression.h b/include/grpc/compression.h index a1a3236d3bf..82e326fe0ec 100644 --- a/include/grpc/compression.h +++ b/include/grpc/compression.h @@ -36,6 +36,8 @@ #include +#include + #ifdef __cplusplus extern "C" { #endif @@ -61,6 +63,11 @@ typedef enum { GRPC_COMPRESS_LEVEL_COUNT } grpc_compression_level; +typedef struct grpc_compression_options { + gpr_uint32 enabled_algorithms_bitset; /**< All algs are enabled by default */ + grpc_compression_algorithm default_compression_algorithm; /**< for channel */ +} grpc_compression_options; + /** Parses the first \a name_length bytes of \a name as a * grpc_compression_algorithm instance, updating \a algorithm. Returns 1 upon * success, 0 otherwise. */ @@ -84,6 +91,20 @@ grpc_compression_level grpc_compression_level_for_algorithm( grpc_compression_algorithm grpc_compression_algorithm_for_level( grpc_compression_level level); +void grpc_compression_options_init(grpc_compression_options *opts); + +/** Mark \a algorithm as enabled in \a opts. */ +void grpc_compression_options_enable_algorithm( + grpc_compression_options *opts, grpc_compression_algorithm algorithm); + +/** Mark \a algorithm as disabled in \a opts. */ +void grpc_compression_options_disable_algorithm( + grpc_compression_options *opts, grpc_compression_algorithm algorithm); + +/** Returns true if \a algorithm is marked as enabled in \a opts. */ +int grpc_compression_options_is_algorithm_enabled( + const grpc_compression_options *opts, grpc_compression_algorithm algorithm); + #ifdef __cplusplus } #endif diff --git a/src/core/channel/channel_args.c b/src/core/channel/channel_args.c index 10199f7719c..7d97b795531 100644 --- a/src/core/channel/channel_args.c +++ b/src/core/channel/channel_args.c @@ -148,16 +148,19 @@ grpc_channel_args *grpc_channel_args_set_compression_algorithm( return grpc_channel_args_copy_and_add(a, &tmp, 1); } +/** Returns the compression algorithm's enabled states bitset from \a a. If not + * found, return a biset will all algorithms enabled */ static gpr_uint32 find_compression_algorithm_states_bitset( const grpc_channel_args *a) { - size_t i; - gpr_uint32 states_bitset = 0; - if (a == NULL) return 0; - for (i = 0; i < a->num_args; ++i) { - if (a->args[i].type == GRPC_ARG_INTEGER && - !strcmp(GRPC_COMPRESSION_ALGORITHM_STATE_ARG, a->args[i].key)) { - states_bitset = a->args[i].value.integer; - break; + gpr_uint32 states_bitset = (1u << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1; + if (a != NULL) { + size_t i; + for (i = 0; i < a->num_args; ++i) { + if (a->args[i].type == GRPC_ARG_INTEGER && + !strcmp(GRPC_COMPRESSION_ALGORITHM_STATE_ARG, a->args[i].key)) { + states_bitset = a->args[i].value.integer; + break; + } } } return states_bitset; @@ -182,9 +185,7 @@ grpc_channel_args *grpc_channel_args_compression_algorithm_set_state( return grpc_channel_args_copy_and_add(a, &tmp, 1); } -int grpc_channel_args_compression_algorithm_get_state( - grpc_channel_args *a, - grpc_compression_algorithm algorithm) { - const gpr_uint32 states_bitset = find_compression_algorithm_states_bitset(a); - return GPR_BITGET(states_bitset, algorithm); +int grpc_channel_args_compression_algorithm_get_states( + const grpc_channel_args *a) { + return find_compression_algorithm_states_bitset(a); } diff --git a/src/core/channel/channel_args.h b/src/core/channel/channel_args.h index f1a75117af4..e557f9a9d92 100644 --- a/src/core/channel/channel_args.h +++ b/src/core/channel/channel_args.h @@ -68,17 +68,20 @@ grpc_channel_args *grpc_channel_args_set_compression_algorithm( grpc_channel_args *a, grpc_compression_algorithm algorithm); /** Sets the support for the given compression algorithm. By default, all - * compression algorithms are enabled. Disabling an algorithm set by - * grpc_channel_args_set_compression_algorithm disables compression altogether + * compression algorithms are enabled. It's an error to disable an algorithm set + * by grpc_channel_args_set_compression_algorithm. * */ grpc_channel_args *grpc_channel_args_compression_algorithm_set_state( grpc_channel_args *a, grpc_compression_algorithm algorithm, int enabled); -/** Returns the state (true for enabled, false for disabled) for \a algorithm */ -int grpc_channel_args_compression_algorithm_get_state( - grpc_channel_args *a, - grpc_compression_algorithm algorithm); +/** Returns the bitset representing the support state (true for enabled, false + * for disabled) for compression algorithms. + * + * The i-th bit of the returned bitset corresponds to the i-th entry in the + * grpc_compression_algorithm enum. */ +int grpc_channel_args_compression_algorithm_get_states( + const grpc_channel_args *a); #endif /* GRPC_INTERNAL_CORE_CHANNEL_CHANNEL_ARGS_H */ diff --git a/src/core/channel/compress_filter.c b/src/core/channel/compress_filter.c index 2fd4c8cae6c..065fe258dc8 100644 --- a/src/core/channel/compress_filter.c +++ b/src/core/channel/compress_filter.c @@ -70,6 +70,8 @@ typedef struct channel_data { grpc_mdelem *mdelem_accept_encoding; /** The default, channel-level, compression algorithm */ grpc_compression_algorithm default_compression_algorithm; + /** Compression options for the channel */ + grpc_compression_options compression_options; } channel_data; /** Compress \a slices in place using \a algorithm. Returns 1 if compression did @@ -102,7 +104,17 @@ static grpc_mdelem* compression_md_filter(void *user_data, grpc_mdelem *md) { const char *md_c_str = grpc_mdstr_as_c_string(md->value); if (!grpc_compression_algorithm_parse(md_c_str, strlen(md_c_str), &calld->compression_algorithm)) { - gpr_log(GPR_ERROR, "Invalid compression algorithm: '%s'. Ignoring.", + gpr_log(GPR_ERROR, + "Invalid compression algorithm: '%s' (unknown). Ignoring.", + md_c_str); + calld->compression_algorithm = GRPC_COMPRESS_NONE; + } + if (grpc_compression_options_is_algorithm_enabled( + &channeld->compression_options, calld->compression_algorithm) == 0) + { + gpr_log(GPR_ERROR, + "Invalid compression algorithm: '%s' (previously disabled). " + "Ignoring.", md_c_str); calld->compression_algorithm = GRPC_COMPRESS_NONE; } @@ -297,8 +309,17 @@ static void init_channel_elem(grpc_channel_element *elem, grpc_channel *master, char *accept_encoding_str; size_t accept_encoding_str_len; + grpc_compression_options_init(&channeld->compression_options); + channeld->compression_options.enabled_algorithms_bitset = + grpc_channel_args_compression_algorithm_get_states(args); + channeld->default_compression_algorithm = grpc_channel_args_get_compression_algorithm(args); + /* Make sure the default isn't disabled. */ + GPR_ASSERT(grpc_compression_options_is_algorithm_enabled( + &channeld->compression_options, channeld->default_compression_algorithm)); + channeld->compression_options.default_compression_algorithm = + channeld->default_compression_algorithm; channeld->mdstr_request_compression_algorithm_key = grpc_mdstr_from_string(mdctx, GRPC_COMPRESS_REQUEST_ALGORITHM_KEY, 0); @@ -311,6 +332,11 @@ static void init_channel_elem(grpc_channel_element *elem, grpc_channel *master, for (algo_idx = 0; algo_idx < GRPC_COMPRESS_ALGORITHMS_COUNT; ++algo_idx) { char *algorithm_name; + /* skip disabled algorithms */ + if (grpc_compression_options_is_algorithm_enabled( + &channeld->compression_options, algo_idx) == 0) { + continue; + } GPR_ASSERT(grpc_compression_algorithm_name(algo_idx, &algorithm_name) != 0); channeld->mdelem_compression_algorithms[algo_idx] = grpc_mdelem_from_metadata_strings( diff --git a/src/core/compression/algorithm.c b/src/core/compression/algorithm.c index dbf4721d13e..6514fcd26f6 100644 --- a/src/core/compression/algorithm.c +++ b/src/core/compression/algorithm.c @@ -33,7 +33,9 @@ #include #include + #include +#include int grpc_compression_algorithm_parse(const char* name, size_t name_length, grpc_compression_algorithm *algorithm) { @@ -102,3 +104,24 @@ grpc_compression_level grpc_compression_level_for_algorithm( } abort(); } + +void grpc_compression_options_init(grpc_compression_options *opts) { + opts->enabled_algorithms_bitset = (1u << GRPC_COMPRESS_ALGORITHMS_COUNT)-1; + opts->default_compression_algorithm = GRPC_COMPRESS_NONE; +} + +void grpc_compression_options_enable_algorithm( + grpc_compression_options *opts, grpc_compression_algorithm algorithm) { + GPR_BITSET(&opts->enabled_algorithms_bitset, algorithm); +} + +void grpc_compression_options_disable_algorithm( + grpc_compression_options *opts, grpc_compression_algorithm algorithm) { + GPR_BITCLEAR(&opts->enabled_algorithms_bitset, algorithm); +} + +int grpc_compression_options_is_algorithm_enabled( + const grpc_compression_options *opts, + grpc_compression_algorithm algorithm) { + return GPR_BITGET(opts->enabled_algorithms_bitset, algorithm); +} diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index ab87b22f5fb..6e576ab8b39 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -163,27 +163,34 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag { grpc_completion_queue* cq_; }; -static grpc_server* CreateServer(int max_message_size) { +static grpc_server* CreateServer( + int max_message_size, const grpc_compression_options& compression_options) { if (max_message_size > 0) { - grpc_arg arg; - arg.type = GRPC_ARG_INTEGER; - arg.key = const_cast(GRPC_ARG_MAX_MESSAGE_LENGTH); - arg.value.integer = max_message_size; - grpc_channel_args args = {1, &arg}; - return grpc_server_create(&args); + grpc_arg args[2]; + args[0].type = GRPC_ARG_INTEGER; + args[0].key = const_cast(GRPC_ARG_MAX_MESSAGE_LENGTH); + args[0].value.integer = max_message_size; + + args[1].type = GRPC_ARG_INTEGER; + args[1].key = const_cast(GRPC_COMPRESSION_ALGORITHM_STATE_ARG); + args[1].value.integer = compression_options.enabled_algorithms_bitset; + + grpc_channel_args channel_args = {2, args}; + return grpc_server_create(&channel_args); } else { return grpc_server_create(nullptr); } } Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned, - int max_message_size) + int max_message_size, + grpc_compression_options compression_options) : max_message_size_(max_message_size), started_(false), shutdown_(false), num_running_cb_(0), sync_methods_(new std::list), - server_(CreateServer(max_message_size)), + server_(CreateServer(max_message_size, compression_options)), thread_pool_(thread_pool), thread_pool_owned_(thread_pool_owned) { grpc_server_register_completion_queue(server_, cq_.cq()); diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index f723d4611ae..425b0521280 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -42,7 +42,9 @@ namespace grpc { ServerBuilder::ServerBuilder() - : max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) {} + : max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) { + grpc_compression_options_init(&compression_options_); +} std::unique_ptr ServerBuilder::AddCompletionQueue() { ServerCompletionQueue* cq = new ServerCompletionQueue(); @@ -50,44 +52,65 @@ std::unique_ptr ServerBuilder::AddCompletionQueue() { return std::unique_ptr(cq); } -void ServerBuilder::RegisterService(SynchronousService* service) { +ServerBuilder& ServerBuilder::RegisterService(SynchronousService* service) { services_.emplace_back(new NamedService(service->service())); + return *this; } -void ServerBuilder::RegisterAsyncService(AsynchronousService* service) { +ServerBuilder& ServerBuilder::RegisterAsyncService( + AsynchronousService* service) { async_services_.emplace_back(new NamedService(service)); + return *this; } -void ServerBuilder::RegisterService( +ServerBuilder& ServerBuilder::RegisterService( const grpc::string& addr, SynchronousService* service) { services_.emplace_back(new NamedService(addr, service->service())); + return *this; } -void ServerBuilder::RegisterAsyncService( +ServerBuilder& ServerBuilder::RegisterAsyncService( const grpc::string& addr, AsynchronousService* service) { - async_services_.emplace_back(new NamedService(addr, service)); + async_services_.emplace_back( + new NamedService(addr, service)); + return *this; } -void ServerBuilder::RegisterAsyncGenericService(AsyncGenericService* service) { +ServerBuilder& ServerBuilder::RegisterAsyncGenericService( + AsyncGenericService* service) { if (generic_service_) { gpr_log(GPR_ERROR, "Adding multiple AsyncGenericService is unsupported for now. " "Dropping the service %p", service); - return; + } else { + generic_service_ = service; } - generic_service_ = service; + return *this; +} + +ServerBuilder& ServerBuilder::SetMaxMessageSize(int max_message_size) { + max_message_size_ = max_message_size; + return *this; } -void ServerBuilder::AddListeningPort(const grpc::string& addr, +ServerBuilder& ServerBuilder::AddListeningPort(const grpc::string& addr, std::shared_ptr creds, int* selected_port) { Port port = {addr, creds, selected_port}; ports_.push_back(port); + return *this; } -void ServerBuilder::SetThreadPool(ThreadPoolInterface* thread_pool) { +ServerBuilder& ServerBuilder::SetThreadPool(ThreadPoolInterface* thread_pool) { thread_pool_ = thread_pool; + return *this; +} + +ServerBuilder& ServerBuilder::SetCompressionOptions( + const grpc_compression_options& options) { + compression_options_ = options; + return *this; } std::unique_ptr ServerBuilder::BuildAndStart() { @@ -100,8 +123,9 @@ std::unique_ptr ServerBuilder::BuildAndStart() { thread_pool_ = CreateDefaultThreadPool(); thread_pool_owned = true; } - std::unique_ptr server( - new Server(thread_pool_, thread_pool_owned, max_message_size_)); + std::unique_ptr server(new Server(thread_pool_, thread_pool_owned, + max_message_size_, + compression_options_)); for (auto cq = cqs_.begin(); cq != cqs_.end(); ++cq) { grpc_server_register_completion_queue(server->server_, (*cq)->cq()); } @@ -113,7 +137,8 @@ std::unique_ptr ServerBuilder::BuildAndStart() { } for (auto service = async_services_.begin(); service != async_services_.end(); service++) { - if (!server->RegisterAsyncService((*service)->host.get(), (*service)->service)) { + if (!server->RegisterAsyncService((*service)->host.get(), + (*service)->service)) { return nullptr; } }