Expose max message size at the server side

pull/1464/head
Yang Gao 10 years ago
parent 6d42a73bb9
commit 3921c56bee
  1. 5
      include/grpc++/config.h
  2. 10
      include/grpc++/impl/call.h
  3. 7
      include/grpc++/server.h
  4. 6
      include/grpc++/server_builder.h
  5. 13
      src/cpp/common/call.cc
  6. 9
      src/cpp/proto/proto_utils.cc
  7. 3
      src/cpp/proto/proto_utils.h
  8. 30
      src/cpp/server/server.cc
  9. 5
      src/cpp/server/server_builder.cc
  10. 16
      test/cpp/end2end/end2end_test.cc

@ -93,13 +93,17 @@
#endif
#ifndef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream.h>
#define GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM \
::google::protobuf::io::ZeroCopyOutputStream
#define GRPC_CUSTOM_ZEROCOPYINPUTSTREAM \
::google::protobuf::io::ZeroCopyInputStream
#define GRPC_CUSTOM_CODEDINPUTSTREAM \
::google::protobuf::io::CodedInputStream
#endif
#ifdef GRPC_CXX0X_NO_NULLPTR
#include <memory>
const class {
@ -126,6 +130,7 @@ typedef GRPC_CUSTOM_PROTOBUF_INT64 int64;
namespace io {
typedef GRPC_CUSTOM_ZEROCOPYOUTPUTSTREAM ZeroCopyOutputStream;
typedef GRPC_CUSTOM_ZEROCOPYINPUTSTREAM ZeroCopyInputStream;
typedef GRPC_CUSTOM_CODEDINPUTSTREAM CodedInputStream;
} // namespace io
} // namespace protobuf

@ -80,6 +80,10 @@ class CallOpBuffer : public CompletionQueueTag {
// Called by completion queue just prior to returning from Next() or Pluck()
bool FinalizeResult(void** tag, bool* status) GRPC_OVERRIDE;
void set_max_message_size(int max_message_size) {
max_message_size_ = max_message_size;
}
bool got_message;
private:
@ -99,6 +103,7 @@ class CallOpBuffer : public CompletionQueueTag {
grpc::protobuf::Message* recv_message_;
ByteBuffer* recv_message_buffer_;
grpc_byte_buffer* recv_buf_;
int max_message_size_;
// Client send close
bool client_send_close_;
// Client recv status
@ -130,16 +135,21 @@ class Call GRPC_FINAL {
public:
/* call is owned by the caller */
Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq);
Call(grpc_call* call, CallHook* call_hook_, CompletionQueue* cq,
int max_message_size);
void PerformOps(CallOpBuffer* buffer);
grpc_call* call() { return call_; }
CompletionQueue* cq() { return cq_; }
int max_message_size() { return max_message_size_; }
private:
CallHook* call_hook_;
CompletionQueue* cq_;
grpc_call* call_;
int max_message_size_;
};
} // namespace grpc

@ -79,7 +79,8 @@ class Server GRPC_FINAL : public GrpcLibrary,
class AsyncRequest;
// ServerBuilder use only
Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned);
Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
int max_message_size);
// Register a service. This call does not take ownership of the service.
// The service must exist for the lifetime of the Server instance.
bool RegisterService(RpcService* service);
@ -106,6 +107,8 @@ class Server GRPC_FINAL : public GrpcLibrary,
ServerAsyncStreamingInterface* stream,
CompletionQueue* cq, void* tag);
const int max_message_size_;
// Completion queue.
CompletionQueue cq_;
@ -126,7 +129,7 @@ class Server GRPC_FINAL : public GrpcLibrary,
// Whether the thread pool is created and owned by the server.
bool thread_pool_owned_;
private:
Server() : server_(NULL) { abort(); }
Server() : max_message_size_(-1), server_(NULL) { abort(); }
};
} // namespace grpc

@ -68,6 +68,11 @@ class ServerBuilder {
// Register a generic service.
void RegisterAsyncGenericService(AsyncGenericService* service);
// Set max message size in bytes.
void SetMaxMessageSize(int max_message_size) {
max_message_size_ = max_message_size;
}
// Add a listening port. Can be called multiple times.
void AddListeningPort(const grpc::string& addr,
std::shared_ptr<ServerCredentials> creds,
@ -87,6 +92,7 @@ class ServerBuilder {
int* selected_port;
};
int max_message_size_;
std::vector<RpcService*> services_;
std::vector<AsynchronousService*> async_services_;
std::vector<Port> ports_;

@ -55,6 +55,7 @@ CallOpBuffer::CallOpBuffer()
recv_message_(nullptr),
recv_message_buffer_(nullptr),
recv_buf_(nullptr),
max_message_size_(-1),
client_send_close_(false),
recv_trailing_metadata_(nullptr),
recv_status_(nullptr),
@ -311,7 +312,7 @@ bool CallOpBuffer::FinalizeResult(void** tag, bool* status) {
got_message = *status;
if (recv_message_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, 0);
*status = *status && DeserializeProto(recv_buf_, recv_message_);
*status = *status && DeserializeProto(recv_buf_, recv_message_, max_message_size_);
grpc_byte_buffer_destroy(recv_buf_);
GRPC_TIMER_MARK(DESER_PROTO_END, 0);
} else {
@ -338,9 +339,17 @@ bool CallOpBuffer::FinalizeResult(void** tag, bool* status) {
}
Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq)
: call_hook_(call_hook), cq_(cq), call_(call) {}
: call_hook_(call_hook), cq_(cq), call_(call), max_message_size_(-1) {}
Call::Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq,
int max_message_size)
: call_hook_(call_hook), cq_(cq), call_(call),
max_message_size_(max_message_size) {}
void Call::PerformOps(CallOpBuffer* buffer) {
if (max_message_size_ > 0) {
buffer->set_max_message_size(max_message_size_);
}
call_hook_->PerformOpsOnCall(buffer, this);
}

@ -158,9 +158,14 @@ bool SerializeProto(const grpc::protobuf::Message& msg, grpc_byte_buffer** bp) {
return msg.SerializeToZeroCopyStream(&writer);
}
bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg) {
bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg,
int max_message_size) {
GrpcBufferReader reader(buffer);
return msg->ParseFromZeroCopyStream(&reader);
::grpc::protobuf::io::CodedInputStream decoder(&reader);
if (max_message_size > 0) {
decoder.SetTotalBytesLimit(max_message_size, max_message_size);
}
return msg->ParseFromCodedStream(&decoder) && decoder.ConsumedEntireMessage();
}
} // namespace grpc

@ -47,7 +47,8 @@ bool SerializeProto(const grpc::protobuf::Message& msg,
grpc_byte_buffer** buffer);
// The caller keeps ownership of buffer and msg.
bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg);
bool DeserializeProto(grpc_byte_buffer* buffer, grpc::protobuf::Message* msg,
int max_message_size);
} // namespace grpc

@ -100,7 +100,7 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag {
public:
explicit CallData(Server* server, SyncRequest* mrd)
: cq_(mrd->cq_),
call_(mrd->call_, server, &cq_),
call_(mrd->call_, server, &cq_, server->max_message_size_),
ctx_(mrd->deadline_, mrd->request_metadata_.metadata,
mrd->request_metadata_.count),
has_request_payload_(mrd->has_request_payload_),
@ -126,7 +126,7 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag {
if (has_request_payload_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_.call());
req.reset(method_->AllocateRequestProto());
if (!DeserializeProto(request_payload_, req.get())) {
if (!DeserializeProto(request_payload_, req.get(), call_.max_message_size())) {
abort(); // for now
}
GRPC_TIMER_MARK(DESER_PROTO_END, call_.call());
@ -176,12 +176,27 @@ class Server::SyncRequest GRPC_FINAL : public CompletionQueueTag {
grpc_completion_queue* cq_;
};
Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned)
: started_(false),
grpc_server* CreateServer(grpc_completion_queue* cq, int max_message_size) {
if (max_message_size > 0) {
grpc_arg arg;
arg.type = GRPC_ARG_INTEGER;
arg.key = const_cast<char*>(GRPC_ARG_MAX_MESSAGE_LENGTH);
arg.value.integer = max_message_size;
grpc_channel_args args = {1, &arg};
return grpc_server_create(cq, &args);
} else {
return grpc_server_create(cq, nullptr);
}
}
Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned,
int max_message_size)
: max_message_size_(max_message_size),
started_(false),
shutdown_(false),
num_running_cb_(0),
sync_methods_(new std::list<SyncRequest>),
server_(grpc_server_create(cq_.cq(), nullptr)),
server_(CreateServer(cq_.cq(), max_message_size)),
thread_pool_(thread_pool),
thread_pool_owned_(thread_pool_owned) {}
@ -347,7 +362,8 @@ class Server::AsyncRequest GRPC_FINAL : public CompletionQueueTag {
if (*status && request_) {
if (payload_) {
GRPC_TIMER_MARK(DESER_PROTO_BEGIN, call_);
*status = DeserializeProto(payload_, request_);
*status = DeserializeProto(payload_, request_,
server_->max_message_size_);
GRPC_TIMER_MARK(DESER_PROTO_END, call_);
} else {
*status = false;
@ -374,7 +390,7 @@ class Server::AsyncRequest GRPC_FINAL : public CompletionQueueTag {
}
ctx->call_ = call_;
ctx->cq_ = cq_;
Call call(call_, server_, cq_);
Call call(call_, server_, cq_, server_->max_message_size_);
if (orig_status && call_) {
ctx->BeginCompletionOp(&call);
}

@ -42,7 +42,7 @@
namespace grpc {
ServerBuilder::ServerBuilder()
: generic_service_(nullptr), thread_pool_(nullptr) {}
: max_message_size_(-1), generic_service_(nullptr), thread_pool_(nullptr) {}
void ServerBuilder::RegisterService(SynchronousService* service) {
services_.push_back(service->service());
@ -86,7 +86,8 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
thread_pool_ = new ThreadPool(cores);
thread_pool_owned = true;
}
std::unique_ptr<Server> server(new Server(thread_pool_, thread_pool_owned));
std::unique_ptr<Server> server(
new Server(thread_pool_, thread_pool_owned, max_message_size_));
for (auto service = services_.begin(); service != services_.end();
service++) {
if (!server->RegisterService(*service)) {

@ -172,7 +172,7 @@ class TestServiceImplDupPkg
class End2endTest : public ::testing::Test {
protected:
End2endTest() : thread_pool_(2) {}
End2endTest() : kMaxMessageSize_(8192), thread_pool_(2) {}
void SetUp() GRPC_OVERRIDE {
int port = grpc_pick_unused_port_or_die();
@ -182,6 +182,7 @@ class End2endTest : public ::testing::Test {
builder.AddListeningPort(server_address_.str(),
InsecureServerCredentials());
builder.RegisterService(&service_);
builder.SetMaxMessageSize(kMaxMessageSize_); // For testing max message size.
builder.RegisterService(&dup_pkg_service_);
builder.SetThreadPool(&thread_pool_);
server_ = builder.BuildAndStart();
@ -198,11 +199,13 @@ class End2endTest : public ::testing::Test {
std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_;
std::unique_ptr<Server> server_;
std::ostringstream server_address_;
const int kMaxMessageSize_;
TestServiceImpl service_;
TestServiceImplDupPkg dup_pkg_service_;
ThreadPool thread_pool_;
};
/*
static void SendRpc(grpc::cpp::test::util::TestService::Stub* stub,
int num_rpcs) {
EchoRequest request;
@ -575,7 +578,18 @@ TEST_F(End2endTest, ClientCancelsBidi) {
Status s = stream->Finish();
EXPECT_EQ(grpc::StatusCode::CANCELLED, s.code());
}
*/
TEST_F(End2endTest, RpcMaxMessageSize) {
ResetStub();
EchoRequest request;
EchoResponse response;
request.set_message(string(kMaxMessageSize_*2, 'a'));
ClientContext context;
Status s = stub_->Echo(&context, request, &response);
EXPECT_FALSE(s.IsOk());
}
} // namespace testing
} // namespace grpc

Loading…
Cancel
Save