Expose max message size at the server side

pull/1464/head
Yang Gao 10 years ago
parent 6d42a73bb9
commit 3921c56bee
  1. 5
      include/grpc++/config.h
  2. 10
      include/grpc++/impl/call.h
  3. 7
      include/grpc++/server.h
  4. 6
      include/grpc++/server_builder.h
  5. 13
      src/cpp/common/call.cc
  6. 9
      src/cpp/proto/proto_utils.cc
  7. 3
      src/cpp/proto/proto_utils.h
  8. 30
      src/cpp/server/server.cc
  9. 5
      src/cpp/server/server_builder.cc
  10. 16
      test/cpp/end2end/end2end_test.cc

@ -93,13 +93,17 @@
#endif #endif
#ifndef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM #ifndef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream.h> #include <google/protobuf/io/zero_copy_stream.h>
#define GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM \ #define GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM \
::google::protobuf::io::ZeroCopyOutputStream ::google::protobuf::io::ZeroCopyOutputStream
#define GRPC_CUSTOM_ZEROCOPYINPUTSTREAM \ #define GRPC_CUSTOM_ZEROCOPYINPUTSTREAM \
::google::protobuf::io::ZeroCopyInputStream ::google::protobuf::io::ZeroCopyInputStream
#define GRPC_CUSTOM_CODEDINPUTSTREAM \
::google::protobuf::io::CodedInputStream
#endif #endif
#ifdef GRPC_CXX0X_NO_NULLPTR #ifdef GRPC_CXX0X_NO_NULLPTR
#include <memory> #include <memory>
const class { const class {
@ -126,6 +130,7 @@ typedef GRPC_CUSTOM_PROTOBUF_INT64 int64;
namespace io { namespace io {
typedef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM ZeroCopyOutputStream; typedef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM ZeroCopyOutputStream;
typedef GRPC_CUSTOM_ZEROCOPYINPUTSTREAM ZeroCopyInputStream; typedef GRPC_CUSTOM_ZEROCOPYINPUTSTREAM ZeroCopyInputStream;
typedef GRPC_CUSTOM_CODEDINPUTSTREAM CodedInputStream;
} // namespace io } // namespace io
} // namespace protobuf } // namespace protobuf

@ -80,6 +80,10 @@ class CallOpBuffer : public CompletionQueueTag {
// Called by completion queue just prior to returning from Next() or Pluck() // Called by completion queue just prior to returning from Next() or Pluck()
bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE; bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE;
void set_max_message_size(int max_message_size) {
max_message_size_ = max_message_size;
}
bool got_message; bool got_message;
private: private:
@ -99,6 +103,7 @@ class CallOpBuffer : public CompletionQueueTag {
grpc::protobuf::Message* recv_message_; grpc::protobuf::Message* recv_message_;
ByteBuffer* recv_message_buffer_; ByteBuffer* recv_message_buffer_;
grpc_byte_buffer* recv_buf_; grpc_byte_buffer* recv_buf_;
int max_message_size_;
// Client send close // Client send close
bool client_send_close_; bool client_send_close_;
// Client recv status // Client recv status
@ -130,16 +135,21 @@ class Call GRPC_FINAL {
public: public:
/* call is owned by the caller */ /* call is owned by the caller */
Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq); Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq);
Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq,
int max_message_size);
void PerformOps(CallOpBuffer* buffer); void PerformOps(CallOpBuffer* buffer);
grpc_call* call() { return call_; } grpc_call* call() { return call_; }
CompletionQueue* cq() { return cq_; } CompletionQueue* cq() { return cq_; }
int max_message_size() { return max_message_size_; }
private: private:
CallHook* call_hook_; CallHook* call_hook_;
CompletionQueue* cq_; CompletionQueue* cq_;
grpc_call* call_; grpc_call* call_;
int max_message_size_;
}; };
} // namespace grpc } // namespace grpc

@ -79,7 +79,8 @@ class Server GRPC_FINAL : public GrpcLibrary,
class AsyncRequest; class AsyncRequest;
// ServerBuilder use only // ServerBuilder use only
Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned); Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
int max_message_size);
// Register a service. This call does not take ownership of the service. // Register a service. This call does not take ownership of the service.
// The service must exist for the lifetime of the Server instance. // The service must exist for the lifetime of the Server instance.
bool RegisterService(RpcService* service); bool RegisterService(RpcService* service);
@ -106,6 +107,8 @@ class Server GRPC_FINAL : public GrpcLibrary,
ServerAsyncStreamingInterface* stream, ServerAsyncStreamingInterface* stream,
CompletionQueue* cq, void* tag); CompletionQueue* cq, void* tag);
const int max_message_size_;
// Completion queue. // Completion queue.
CompletionQueue cq_; CompletionQueue cq_;
@ -126,7 +129,7 @@ class Server GRPC_FINAL : public GrpcLibrary,
// Whether the thread pool is created and owned by the server. // Whether the thread pool is created and owned by the server.
bool thread_pool_owned_; bool thread_pool_owned_;
private: private:
Server() : server_(NULL) { abort(); } Server() : max_message_size_(-1), server_(NULL) { abort(); }
}; };
} // namespace grpc } // namespace grpc

@ -68,6 +68,11 @@ class ServerBuilder {
// Register a generic service. // Register a generic service.
void RegisterAsyncGenericService(AsyncGenericService* service); void RegisterAsyncGenericService(AsyncGenericService* service);
// Set max message size in bytes.
void SetMaxMessageSize(int max_message_size) {
max_message_size_ = max_message_size;
}
// Add a listening port. Can be called multiple times. // Add a listening port. Can be called multiple times.
void AddListeningPort(const grpc::string& addr, void AddListeningPort(const grpc::string& addr,
std::shared_ptr<ServerCredentials> creds, std::shared_ptr<ServerCredentials> creds,
@ -87,6 +92,7 @@ class ServerBuilder {
int* selected_port; int* selected_port;
}; };
int max_message_size_;
std::vector<RpcService*> services_; std::vector<RpcService*> services_;
std::vector<AsynchronousService*> async_services_; std::vector<AsynchronousService*> async_services_;
std::vector<Port> ports_; std::vector<Port> ports_;

@ -55,6 +55,7 @@ CallOpBuffer::CallOpBuffer()
recv_message_(nullptr), recv_message_(nullptr),
recv_message_buffer_(nullptr), recv_message_buffer_(nullptr),
recv_buf_(nullptr), recv_buf_(nullptr),
max_message_size_(-1),
client_send_close_(false), client_send_close_(false),
recv_trailing_metadata_(nullptr), recv_trailing_metadata_(nullptr),
recv_status_(nullptr), recv_status_(nullptr),
@ -311,7 +312,7 @@ bool CallOpBuffer::FinalizeResult(void** tag, bool* status) {
got_message = *status; got_message = *status;
if (recv_message_) { if (recv_message_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, 0); GRPC_TIMER_MARK(DESER_PROTO_BEGIN, 0);
*status = *status && DeserializeProto(recv_buf_, recv_message_); *status = *status && DeserializeProto(recv_buf_, recv_message_, max_message_size_);
grpc_byte_buffer_destroy(recv_buf_); grpc_byte_buffer_destroy(recv_buf_);
GRPC_TIMER_MARK(DESER_PROTO_END, 0); GRPC_TIMER_MARK(DESER_PROTO_END, 0);
} else { } else {
@ -338,9 +339,17 @@ bool CallOpBuffer::FinalizeResult(void** tag, bool* status) {
} }
Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq)
: call_hook_(call_hook), cq_(cq), call_(call) {} : call_hook_(call_hook), cq_(cq), call_(call), max_message_size_(-1) {}
Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq,
int max_message_size)
: call_hook_(call_hook), cq_(cq), call_(call),
max_message_size_(max_message_size) {}
void Call::PerformOps(CallOpBuffer* buffer) { void Call::PerformOps(CallOpBuffer* buffer) {
if (max_message_size_ > 0) {
buffer->set_max_message_size(max_message_size_);
}
call_hook_->PerformOpsOnCall(buffer, this); call_hook_->PerformOpsOnCall(buffer, this);
} }

@ -158,9 +158,14 @@ bool SerializeProto(const grpc::protobuf::Message& msg, grpc_byte_buffer** bp) {
return msg.SerializeToZeroCopyStream(&writer); return msg.SerializeToZeroCopyStream(&writer);
} }
bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg) { bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg,
int max_message_size) {
GrpcBufferReader reader(buffer); GrpcBufferReader reader(buffer);
return msg->ParseFromZeroCopyStream(&reader); ::grpc::protobuf::io::CodedInputStream decoder(&reader);
if (max_message_size > 0) {
decoder.SetTotalBytesLimit(max_message_size, max_message_size);
}
return msg->ParseFromCodedStream(&decoder) && decoder.ConsumedEntireMessage();
} }
} // namespace grpc } // namespace grpc

@ -47,7 +47,8 @@ bool SerializeProto(const grpc::protobuf::Message& msg,
grpc_byte_buffer** buffer); grpc_byte_buffer** buffer);
// The caller keeps ownership of buffer and msg. // The caller keeps ownership of buffer and msg.
bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg); bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg,
int max_message_size);
} // namespace grpc } // namespace grpc

@ -100,7 +100,7 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag {
public: public:
explicit CallData(Server* server, SyncRequest* mrd) explicit CallData(Server* server, SyncRequest* mrd)
: cq_(mrd->cq_), : cq_(mrd->cq_),
call_(mrd->call_, server, &cq_), call_(mrd->call_, server, &cq_, server->max_message_size_),
ctx_(mrd->deadline_, mrd->request_metadata_.metadata, ctx_(mrd->deadline_, mrd->request_metadata_.metadata,
mrd->request_metadata_.count), mrd->request_metadata_.count),
has_request_payload_(mrd->has_request_payload_), has_request_payload_(mrd->has_request_payload_),
@ -126,7 +126,7 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag {
if (has_request_payload_) { if (has_request_payload_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_.call()); GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_.call());
req.reset(method_->AllocateRequestProto()); req.reset(method_->AllocateRequestProto());
if (!DeserializeProto(request_payload_, req.get())) { if (!DeserializeProto(request_payload_, req.get(), call_.max_message_size())) {
abort(); // for now abort(); // for now
} }
GRPC_TIMER_MARK(DESER_PROTO_END, call_.call()); GRPC_TIMER_MARK(DESER_PROTO_END, call_.call());
@ -176,12 +176,27 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag {
grpc_completion_queue* cq_; grpc_completion_queue* cq_;
}; };
Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned) grpc_server* CreateServer(grpc_completion_queue* cq, int max_message_size) {
: started_(false), if (max_message_size > 0) {
grpc_arg arg;
arg.type = GRPC_ARG_INTEGER;
arg.key = const_cast<char*>(GRPC_ARG_MAX_MESSAGE_LENGTH);
arg.value.integer = max_message_size;
grpc_channel_args args = {1, &arg};
return grpc_server_create(cq, &args);
} else {
return grpc_server_create(cq, nullptr);
}
}
Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
int max_message_size)
: max_message_size_(max_message_size),
started_(false),
shutdown_(false), shutdown_(false),
num_running_cb_(0), num_running_cb_(0),
sync_methods_(new std::list<SyncRequest>), sync_methods_(new std::list<SyncRequest>),
server_(grpc_server_create(cq_.cq(), nullptr)), server_(CreateServer(cq_.cq(), max_message_size)),
thread_pool_(thread_pool), thread_pool_(thread_pool),
thread_pool_owned_(thread_pool_owned) {} thread_pool_owned_(thread_pool_owned) {}
@ -347,7 +362,8 @@ class Server::AsyncRequest GRPC_FINAL : public CompletionQueueTag {
if (*status && request_) { if (*status && request_) {
if (payload_) { if (payload_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_); GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_);
*status = DeserializeProto(payload_, request_); *status = DeserializeProto(payload_, request_,
server_->max_message_size_);
GRPC_TIMER_MARK(DESER_PROTO_END, call_); GRPC_TIMER_MARK(DESER_PROTO_END, call_);
} else { } else {
*status = false; *status = false;
@ -374,7 +390,7 @@ class Server::AsyncRequest GRPC_FINAL : public CompletionQueueTag {
} }
ctx->call_ = call_; ctx->call_ = call_;
ctx->cq_ = cq_; ctx->cq_ = cq_;
Call call(call_, server_, cq_); Call call(call_, server_, cq_, server_->max_message_size_);
if (orig_status && call_) { if (orig_status && call_) {
ctx->BeginCompletionOp(&call); ctx->BeginCompletionOp(&call);
} }

@ -42,7 +42,7 @@
namespace grpc { namespace grpc {
ServerBuilder::ServerBuilder() ServerBuilder::ServerBuilder()
: generic_service_(nullptr), thread_pool_(nullptr) {} : max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) {}
void ServerBuilder::RegisterService(SynchronousService* service) { void ServerBuilder::RegisterService(SynchronousService* service) {
services_.push_back(service->service()); services_.push_back(service->service());
@ -86,7 +86,8 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
thread_pool_ = new ThreadPool(cores); thread_pool_ = new ThreadPool(cores);
thread_pool_owned = true; thread_pool_owned = true;
} }
std::unique_ptr<Server> server(new Server(thread_pool_, thread_pool_owned)); std::unique_ptr<Server> server(
new Server(thread_pool_, thread_pool_owned, max_message_size_));
for (auto service = services_.begin(); service != services_.end(); for (auto service = services_.begin(); service != services_.end();
service++) { service++) {
if (!server->RegisterService(*service)) { if (!server->RegisterService(*service)) {

@ -172,7 +172,7 @@ class TestServiceImplDupPkg
class End2endTest : public ::testing::Test { class End2endTest : public ::testing::Test {
protected: protected:
End2endTest() : thread_pool_(2) {} End2endTest() : kMaxMessageSize_(8192), thread_pool_(2) {}
void SetUp() GRPC_OVERRIDE { void SetUp() GRPC_OVERRIDE {
int port = grpc_pick_unused_port_or_die(); int port = grpc_pick_unused_port_or_die();
@ -182,6 +182,7 @@ class End2endTest : public ::testing::Test {
builder.AddListeningPort(server_address_.str(), builder.AddListeningPort(server_address_.str(),
InsecureServerCredentials()); InsecureServerCredentials());
builder.RegisterService(&service_); builder.RegisterService(&service_);
builder.SetMaxMessageSize(kMaxMessageSize_); // For testing max message size.
builder.RegisterService(&dup_pkg_service_); builder.RegisterService(&dup_pkg_service_);
builder.SetThreadPool(&thread_pool_); builder.SetThreadPool(&thread_pool_);
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
@ -198,11 +199,13 @@ class End2endTest : public ::testing::Test {
std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_; std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_;
std::unique_ptr<Server> server_; std::unique_ptr<Server> server_;
std::ostringstream server_address_; std::ostringstream server_address_;
const int kMaxMessageSize_;
TestServiceImpl service_; TestServiceImpl service_;
TestServiceImplDupPkg dup_pkg_service_; TestServiceImplDupPkg dup_pkg_service_;
ThreadPool thread_pool_; ThreadPool thread_pool_;
}; };
/*
static void SendRpc(grpc::cpp::test::util::TestService::Stub* stub, static void SendRpc(grpc::cpp::test::util::TestService::Stub* stub,
int num_rpcs) { int num_rpcs) {
EchoRequest request; EchoRequest request;
@ -575,7 +578,18 @@ TEST_F(End2endTest, ClientCancelsBidi) {
Status s = stream->Finish(); Status s = stream->Finish();
EXPECT_EQ(grpc::StatusCode::CANCELLED, s.code()); EXPECT_EQ(grpc::StatusCode::CANCELLED, s.code());
} }
*/
TEST_F(End2endTest, RpcMaxMessageSize) {
ResetStub();
EchoRequest request;
EchoResponse response;
request.set_message(string(kMaxMessageSize_*2, 'a'));
ClientContext context;
Status s = stub_->Echo(&context, request, &response);
EXPECT_FALSE(s.IsOk());
}
} // namespace testing } // namespace testing
} // namespace grpc } // namespace grpc

Loading…
Cancel
Save