diff --git a/include/grpc++/config.h b/include/grpc++/config.h index 0f3d69289f3..b6c17056211 100644 --- a/include/grpc++/config.h +++ b/include/grpc++/config.h @@ -93,13 +93,17 @@ #endif #ifndef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM +#include #include #define GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM \ ::google::protobuf::io::ZeroCopyOutputStream #define GRPC_CUSTOM_ZEROCOPYINPUTSTREAM \ ::google::protobuf::io::ZeroCopyInputStream +#define GRPC_CUSTOM_CODEDINPUTSTREAM \ + ::google::protobuf::io::CodedInputStream #endif + #ifdef GRPC_CXX0X_NO_NULLPTR #include const class { @@ -126,6 +130,7 @@ typedef GRPC_CUSTOM_PROTOBUF_INT64 int64; namespace io { typedef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM ZeroCopyOutputStream; typedef GRPC_CUSTOM_ZEROCOPYINPUTSTREAM ZeroCopyInputStream; +typedef GRPC_CUSTOM_CODEDINPUTSTREAM CodedInputStream; } // namespace io } // namespace protobuf diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h index b14c41dfa63..d76ef61dd24 100644 --- a/include/grpc++/impl/call.h +++ b/include/grpc++/impl/call.h @@ -80,6 +80,10 @@ class CallOpBuffer : public CompletionQueueTag { // Called by completion queue just prior to returning from Next() or Pluck() bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE; + void set_max_message_size(int max_message_size) { + max_message_size_ = max_message_size; + } + bool got_message; private: @@ -99,6 +103,7 @@ class CallOpBuffer : public CompletionQueueTag { grpc::protobuf::Message* recv_message_; ByteBuffer* recv_message_buffer_; grpc_byte_buffer* recv_buf_; + int max_message_size_; // Client send close bool client_send_close_; // Client recv status @@ -130,16 +135,21 @@ class Call GRPC_FINAL { public: /* call is owned by the caller */ Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq); + Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq, + int max_message_size); void PerformOps(CallOpBuffer* buffer); grpc_call* call() { return call_; } CompletionQueue* cq() { return cq_; } + int max_message_size() { return max_message_size_; } + private: CallHook* call_hook_; CompletionQueue* cq_; grpc_call* call_; + int max_message_size_; }; } // namespace grpc diff --git a/include/grpc++/server.h b/include/grpc++/server.h index c6864747023..b2b9044dcab 100644 --- a/include/grpc++/server.h +++ b/include/grpc++/server.h @@ -79,7 +79,8 @@ class Server GRPC_FINAL : public GrpcLibrary, class AsyncRequest; // ServerBuilder use only - Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned); + Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned, + int max_message_size); // Register a service. This call does not take ownership of the service. // The service must exist for the lifetime of the Server instance. bool RegisterService(RpcService* service); @@ -106,6 +107,8 @@ class Server GRPC_FINAL : public GrpcLibrary, ServerAsyncStreamingInterface* stream, CompletionQueue* cq, void* tag); + const int max_message_size_; + // Completion queue. CompletionQueue cq_; @@ -126,7 +129,7 @@ class Server GRPC_FINAL : public GrpcLibrary, // Whether the thread pool is created and owned by the server. bool thread_pool_owned_; private: - Server() : server_(NULL) { abort(); } + Server() : max_message_size_(-1), server_(NULL) { abort(); } }; } // namespace grpc diff --git a/include/grpc++/server_builder.h b/include/grpc++/server_builder.h index 9a9932ebe01..7155c7fd462 100644 --- a/include/grpc++/server_builder.h +++ b/include/grpc++/server_builder.h @@ -68,6 +68,11 @@ class ServerBuilder { // Register a generic service. void RegisterAsyncGenericService(AsyncGenericService* service); + // Set max message size in bytes. + void SetMaxMessageSize(int max_message_size) { + max_message_size_ = max_message_size; + } + // Add a listening port. Can be called multiple times. void AddListeningPort(const grpc::string& addr, std::shared_ptr creds, @@ -87,6 +92,7 @@ class ServerBuilder { int* selected_port; }; + int max_message_size_; std::vector services_; std::vector async_services_; std::vector ports_; diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc index 9878133331d..25609a77598 100644 --- a/src/cpp/common/call.cc +++ b/src/cpp/common/call.cc @@ -55,6 +55,7 @@ CallOpBuffer::CallOpBuffer() recv_message_(nullptr), recv_message_buffer_(nullptr), recv_buf_(nullptr), + max_message_size_(-1), client_send_close_(false), recv_trailing_metadata_(nullptr), recv_status_(nullptr), @@ -311,7 +312,7 @@ bool CallOpBuffer::FinalizeResult(void** tag, bool* status) { got_message = *status; if (recv_message_) { GRPC_TIMER_MARK(DESER_PROTO_BEGIN, 0); - *status = *status && DeserializeProto(recv_buf_, recv_message_); + *status = *status && DeserializeProto(recv_buf_, recv_message_, max_message_size_); grpc_byte_buffer_destroy(recv_buf_); GRPC_TIMER_MARK(DESER_PROTO_END, 0); } else { @@ -338,9 +339,17 @@ bool CallOpBuffer::FinalizeResult(void** tag, bool* status) { } Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) - : call_hook_(call_hook), cq_(cq), call_(call) {} + : call_hook_(call_hook), cq_(cq), call_(call), max_message_size_(-1) {} + +Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + int max_message_size) + : call_hook_(call_hook), cq_(cq), call_(call), + max_message_size_(max_message_size) {} void Call::PerformOps(CallOpBuffer* buffer) { + if (max_message_size_ > 0) { + buffer->set_max_message_size(max_message_size_); + } call_hook_->PerformOpsOnCall(buffer, this); } diff --git a/src/cpp/proto/proto_utils.cc b/src/cpp/proto/proto_utils.cc index b8de2ea1735..8ab536aab8f 100644 --- a/src/cpp/proto/proto_utils.cc +++ b/src/cpp/proto/proto_utils.cc @@ -158,9 +158,14 @@ bool SerializeProto(const grpc::protobuf::Message& msg, grpc_byte_buffer** bp) { return msg.SerializeToZeroCopyStream(&writer); } -bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg) { +bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg, + int max_message_size) { GrpcBufferReader reader(buffer); - return msg->ParseFromZeroCopyStream(&reader); + ::grpc::protobuf::io::CodedInputStream decoder(&reader); + if (max_message_size > 0) { + decoder.SetTotalBytesLimit(max_message_size, max_message_size); + } + return msg->ParseFromCodedStream(&decoder) && decoder.ConsumedEntireMessage(); } } // namespace grpc diff --git a/src/cpp/proto/proto_utils.h b/src/cpp/proto/proto_utils.h index bc60dc99296..67a775b3ca5 100644 --- a/src/cpp/proto/proto_utils.h +++ b/src/cpp/proto/proto_utils.h @@ -47,7 +47,8 @@ bool SerializeProto(const grpc::protobuf::Message& msg, grpc_byte_buffer** buffer); // The caller keeps ownership of buffer and msg. -bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg); +bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg, + int max_message_size); } // namespace grpc diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index 4694a3a7ff5..d8f8ab4b947 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -100,7 +100,7 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag { public: explicit CallData(Server* server, SyncRequest* mrd) : cq_(mrd->cq_), - call_(mrd->call_, server, &cq_), + call_(mrd->call_, server, &cq_, server->max_message_size_), ctx_(mrd->deadline_, mrd->request_metadata_.metadata, mrd->request_metadata_.count), has_request_payload_(mrd->has_request_payload_), @@ -126,7 +126,7 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag { if (has_request_payload_) { GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_.call()); req.reset(method_->AllocateRequestProto()); - if (!DeserializeProto(request_payload_, req.get())) { + if (!DeserializeProto(request_payload_, req.get(), call_.max_message_size())) { abort(); // for now } GRPC_TIMER_MARK(DESER_PROTO_END, call_.call()); @@ -176,12 +176,27 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag { grpc_completion_queue* cq_; }; -Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned) - : started_(false), +grpc_server* CreateServer(grpc_completion_queue* cq, int max_message_size) { + if (max_message_size > 0) { + grpc_arg arg; + arg.type = GRPC_ARG_INTEGER; + arg.key = const_cast(GRPC_ARG_MAX_MESSAGE_LENGTH); + arg.value.integer = max_message_size; + grpc_channel_args args = {1, &arg}; + return grpc_server_create(cq, &args); + } else { + return grpc_server_create(cq, nullptr); + } +} + +Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned, + int max_message_size) + : max_message_size_(max_message_size), + started_(false), shutdown_(false), num_running_cb_(0), sync_methods_(new std::list), - server_(grpc_server_create(cq_.cq(), nullptr)), + server_(CreateServer(cq_.cq(), max_message_size)), thread_pool_(thread_pool), thread_pool_owned_(thread_pool_owned) {} @@ -347,7 +362,8 @@ class Server::AsyncRequest GRPC_FINAL : public CompletionQueueTag { if (*status && request_) { if (payload_) { GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_); - *status = DeserializeProto(payload_, request_); + *status = DeserializeProto(payload_, request_, + server_->max_message_size_); GRPC_TIMER_MARK(DESER_PROTO_END, call_); } else { *status = false; @@ -374,7 +390,7 @@ class Server::AsyncRequest GRPC_FINAL : public CompletionQueueTag { } ctx->call_ = call_; ctx->cq_ = cq_; - Call call(call_, server_, cq_); + Call call(call_, server_, cq_, server_->max_message_size_); if (orig_status && call_) { ctx->BeginCompletionOp(&call); } diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index 81cb0e6724b..e48d1eeb426 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -42,7 +42,7 @@ namespace grpc { ServerBuilder::ServerBuilder() - : generic_service_(nullptr), thread_pool_(nullptr) {} + : max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) {} void ServerBuilder::RegisterService(SynchronousService* service) { services_.push_back(service->service()); @@ -86,7 +86,8 @@ std::unique_ptr ServerBuilder::BuildAndStart() { thread_pool_ = new ThreadPool(cores); thread_pool_owned = true; } - std::unique_ptr server(new Server(thread_pool_, thread_pool_owned)); + std::unique_ptr server( + new Server(thread_pool_, thread_pool_owned, max_message_size_)); for (auto service = services_.begin(); service != services_.end(); service++) { if (!server->RegisterService(*service)) { diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 5e89490ecb0..77c45d0cf91 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -172,7 +172,7 @@ class TestServiceImplDupPkg class End2endTest : public ::testing::Test { protected: - End2endTest() : thread_pool_(2) {} + End2endTest() : kMaxMessageSize_(8192), thread_pool_(2) {} void SetUp() GRPC_OVERRIDE { int port = grpc_pick_unused_port_or_die(); @@ -182,6 +182,7 @@ class End2endTest : public ::testing::Test { builder.AddListeningPort(server_address_.str(), InsecureServerCredentials()); builder.RegisterService(&service_); + builder.SetMaxMessageSize(kMaxMessageSize_); // For testing max message size. builder.RegisterService(&dup_pkg_service_); builder.SetThreadPool(&thread_pool_); server_ = builder.BuildAndStart(); @@ -198,11 +199,13 @@ class End2endTest : public ::testing::Test { std::unique_ptr stub_; std::unique_ptr server_; std::ostringstream server_address_; + const int kMaxMessageSize_; TestServiceImpl service_; TestServiceImplDupPkg dup_pkg_service_; ThreadPool thread_pool_; }; +/* static void SendRpc(grpc::cpp::test::util::TestService::Stub* stub, int num_rpcs) { EchoRequest request; @@ -575,7 +578,18 @@ TEST_F(End2endTest, ClientCancelsBidi) { Status s = stream->Finish(); EXPECT_EQ(grpc::StatusCode::CANCELLED, s.code()); } +*/ + +TEST_F(End2endTest, RpcMaxMessageSize) { + ResetStub(); + EchoRequest request; + EchoResponse response; + request.set_message(string(kMaxMessageSize_*2, 'a')); + ClientContext context; + Status s = stub_->Echo(&context, request, &response); + EXPECT_FALSE(s.IsOk()); +} } // namespace testing } // namespace grpc