Streaming API for callback servers

pull/17144/head
Vijay Pai 6 years ago
parent cdea58eb9a
commit 2a0c0d7ad6
  1. 8
      include/grpcpp/impl/codegen/byte_buffer.h
  2. 19
      include/grpcpp/impl/codegen/callback_common.h
  3. 774
      include/grpcpp/impl/codegen/server_callback.h
  4. 21
      include/grpcpp/impl/codegen/server_context.h
  5. 69
      src/compiler/cpp_generator.cc
  6. 7
      src/cpp/server/server_cc.cc
  7. 47
      src/cpp/server/server_context.cc
  8. 56
      test/cpp/codegen/compiler_test_golden
  9. 90
      test/cpp/end2end/end2end_test.cc
  10. 314
      test/cpp/end2end/test_service_impl.cc
  11. 21
      test/cpp/end2end/test_service_impl.h

@ -45,8 +45,10 @@ template <class ServiceType, class RequestType, class ResponseType>
class RpcMethodHandler;
template <class ServiceType, class RequestType, class ResponseType>
class ServerStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
template <class RequestType, class ResponseType>
class CallbackUnaryHandler;
template <class RequestType, class ResponseType>
class CallbackServerStreamingHandler;
template <StatusCode code>
class ErrorMethodHandler;
template <class R>
@ -156,8 +158,10 @@ class ByteBuffer final {
friend class internal::RpcMethodHandler;
template <class ServiceType, class RequestType, class ResponseType>
friend class internal::ServerStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
template <class RequestType, class ResponseType>
friend class internal::CallbackUnaryHandler;
template <class RequestType, class ResponseType>
friend class ::grpc::internal::CallbackServerStreamingHandler;
template <StatusCode code>
friend class internal::ErrorMethodHandler;
template <class R>

@ -32,6 +32,8 @@ namespace grpc {
namespace internal {
/// An exception-safe way of invoking a user-specified callback function
// TODO(vjpai): decide whether it is better for this to take a const lvalue
// parameter or an rvalue parameter, or if it even matters
template <class Func, class... Args>
void CatchingCallback(Func&& func, Args&&... args) {
#if GRPC_ALLOW_EXCEPTIONS
@ -45,6 +47,20 @@ void CatchingCallback(Func&& func, Args&&... args) {
#endif // GRPC_ALLOW_EXCEPTIONS
}
template <class ReturnType, class Func, class... Args>
ReturnType* CatchingReactorCreator(Func&& func, Args&&... args) {
#if GRPC_ALLOW_EXCEPTIONS
try {
return func(std::forward<Args>(args)...);
} catch (...) {
// fail the RPC, don't crash the library
return nullptr;
}
#else // GRPC_ALLOW_EXCEPTIONS
return func(std::forward<Args>(args)...);
#endif // GRPC_ALLOW_EXCEPTIONS
}
// The contract on these tags is that they are single-shot. They must be
// constructed and then fired at exactly one point. There is no expectation
// that they can be reused without reconstruction.
@ -185,8 +201,9 @@ class CallbackWithSuccessTag
void* ignored = ops_;
// Allow a "false" return value from FinalizeResult to silence the
// callback, just as it silences a CQ tag in the async cases
auto* ops = ops_;
bool do_callback = ops_->FinalizeResult(&ignored, &ok);
GPR_CODEGEN_ASSERT(ignored == ops_);
GPR_CODEGEN_ASSERT(ignored == ops);
if (do_callback) {
CatchingCallback(func_, ok);

@ -19,7 +19,9 @@
#ifndef GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H
#define GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H
#include <atomic>
#include <functional>
#include <type_traits>
#include <grpcpp/impl/codegen/call.h>
#include <grpcpp/impl/codegen/call_op_set.h>
@ -32,19 +34,33 @@
namespace grpc {
// forward declarations
// Declare base class of all reactors as internal
namespace internal {
template <class ServiceType, class RequestType, class ResponseType>
class CallbackUnaryHandler;
class ServerReactor {
public:
virtual ~ServerReactor() = default;
virtual void OnDone() {}
virtual void OnCancel() {}
};
} // namespace internal
namespace experimental {
// Forward declarations
template <class Request, class Response>
class ServerReadReactor;
template <class Request, class Response>
class ServerWriteReactor;
template <class Request, class Response>
class ServerBidiReactor;
// For unary RPCs, the exposed controller class is only an interface
// and the actual implementation is an internal class.
class ServerCallbackRpcController {
public:
virtual ~ServerCallbackRpcController() {}
virtual ~ServerCallbackRpcController() = default;
// The method handler must call this function when it is done so that
// the library knows to free its resources
@ -55,18 +71,193 @@ class ServerCallbackRpcController {
virtual void SendInitialMetadata(std::function<void(bool)>) = 0;
};
// NOTE: The actual streaming object classes are provided
// as API only to support mocking. There are no implementations of
// these class interfaces in the API.
template <class Request>
class ServerCallbackReader {
public:
virtual ~ServerCallbackReader() {}
virtual void Finish(Status s) = 0;
virtual void SendInitialMetadata() = 0;
virtual void Read(Request* msg) = 0;
protected:
template <class Response>
void BindReactor(ServerReadReactor<Request, Response>* reactor) {
reactor->BindReader(this);
}
};
template <class Response>
class ServerCallbackWriter {
public:
virtual ~ServerCallbackWriter() {}
virtual void Finish(Status s) = 0;
virtual void SendInitialMetadata() = 0;
virtual void Write(const Response* msg, WriteOptions options) = 0;
virtual void WriteAndFinish(const Response* msg, WriteOptions options,
Status s) {
// Default implementation that can/should be overridden
Write(msg, std::move(options));
Finish(std::move(s));
};
protected:
template <class Request>
void BindReactor(ServerWriteReactor<Request, Response>* reactor) {
reactor->BindWriter(this);
}
};
template <class Request, class Response>
class ServerCallbackReaderWriter {
public:
virtual ~ServerCallbackReaderWriter() {}
virtual void Finish(Status s) = 0;
virtual void SendInitialMetadata() = 0;
virtual void Read(Request* msg) = 0;
virtual void Write(const Response* msg, WriteOptions options) = 0;
virtual void WriteAndFinish(const Response* msg, WriteOptions options,
Status s) {
// Default implementation that can/should be overridden
Write(msg, std::move(options));
Finish(std::move(s));
};
protected:
void BindReactor(ServerBidiReactor<Request, Response>* reactor) {
reactor->BindStream(this);
}
};
// The following classes are reactors that are to be implemented
// by the user, returned as the result of the method handler for
// a callback method, and activated by the call to OnStarted
template <class Request, class Response>
class ServerBidiReactor : public internal::ServerReactor {
public:
~ServerBidiReactor() = default;
virtual void OnStarted(ServerContext*) {}
virtual void OnSendInitialMetadataDone(bool ok) {}
virtual void OnReadDone(bool ok) {}
virtual void OnWriteDone(bool ok) {}
void StartSendInitialMetadata() { stream_->SendInitialMetadata(); }
void StartRead(Request* msg) { stream_->Read(msg); }
void StartWrite(const Response* msg) { StartWrite(msg, WriteOptions()); }
void StartWrite(const Response* msg, WriteOptions options) {
stream_->Write(msg, std::move(options));
}
void StartWriteAndFinish(const Response* msg, WriteOptions options,
Status s) {
stream_->WriteAndFinish(msg, std::move(options), std::move(s));
}
void StartWriteLast(const Response* msg, WriteOptions options) {
StartWrite(msg, std::move(options.set_last_message()));
}
void Finish(Status s) { stream_->Finish(std::move(s)); }
private:
friend class ServerCallbackReaderWriter<Request, Response>;
void BindStream(ServerCallbackReaderWriter<Request, Response>* stream) {
stream_ = stream;
}
ServerCallbackReaderWriter<Request, Response>* stream_;
};
template <class Request, class Response>
class ServerReadReactor : public internal::ServerReactor {
public:
~ServerReadReactor() = default;
virtual void OnStarted(ServerContext*, Response* resp) {}
virtual void OnSendInitialMetadataDone(bool ok) {}
virtual void OnReadDone(bool ok) {}
void StartSendInitialMetadata() { reader_->SendInitialMetadata(); }
void StartRead(Request* msg) { reader_->Read(msg); }
void Finish(Status s) { reader_->Finish(std::move(s)); }
private:
friend class ServerCallbackReader<Request>;
void BindReader(ServerCallbackReader<Request>* reader) { reader_ = reader; }
ServerCallbackReader<Request>* reader_;
};
template <class Request, class Response>
class ServerWriteReactor : public internal::ServerReactor {
public:
~ServerWriteReactor() = default;
virtual void OnStarted(ServerContext*, const Request* req) {}
virtual void OnSendInitialMetadataDone(bool ok) {}
virtual void OnWriteDone(bool ok) {}
void StartSendInitialMetadata() { writer_->SendInitialMetadata(); }
void StartWrite(const Response* msg) { StartWrite(msg, WriteOptions()); }
void StartWrite(const Response* msg, WriteOptions options) {
writer_->Write(msg, std::move(options));
}
void StartWriteAndFinish(const Response* msg, WriteOptions options,
Status s) {
writer_->WriteAndFinish(msg, std::move(options), std::move(s));
}
void StartWriteLast(const Response* msg, WriteOptions options) {
StartWrite(msg, std::move(options.set_last_message()));
}
void Finish(Status s) { writer_->Finish(std::move(s)); }
private:
friend class ServerCallbackWriter<Response>;
void BindWriter(ServerCallbackWriter<Response>* writer) { writer_ = writer; }
ServerCallbackWriter<Response>* writer_;
};
} // namespace experimental
namespace internal {
template <class ServiceType, class RequestType, class ResponseType>
template <class Request, class Response>
class UnimplementedReadReactor
: public experimental::ServerReadReactor<Request, Response> {
public:
void OnDone() override { delete this; }
void OnStarted(ServerContext*, Response*) override {
this->Finish(Status(StatusCode::UNIMPLEMENTED, ""));
}
};
template <class Request, class Response>
class UnimplementedWriteReactor
: public experimental::ServerWriteReactor<Request, Response> {
public:
void OnDone() override { delete this; }
void OnStarted(ServerContext*, const Request*) override {
this->Finish(Status(StatusCode::UNIMPLEMENTED, ""));
}
};
template <class Request, class Response>
class UnimplementedBidiReactor
: public experimental::ServerBidiReactor<Request, Response> {
public:
void OnDone() override { delete this; }
void OnStarted(ServerContext*) override {
this->Finish(Status(StatusCode::UNIMPLEMENTED, ""));
}
};
template <class RequestType, class ResponseType>
class CallbackUnaryHandler : public MethodHandler {
public:
CallbackUnaryHandler(
std::function<void(ServerContext*, const RequestType*, ResponseType*,
experimental::ServerCallbackRpcController*)>
func,
ServiceType* service)
func)
: func_(func) {}
void RunHandler(const HandlerParameter& param) final {
// Arena allocate a controller structure (that includes request/response)
@ -81,9 +272,8 @@ class CallbackUnaryHandler : public MethodHandler {
if (status.ok()) {
// Call the actual function handler and expect the user to call finish
CatchingCallback(std::move(func_), param.server_context,
controller->request(), controller->response(),
controller);
CatchingCallback(func_, param.server_context, controller->request(),
controller->response(), controller);
} else {
// if deserialization failed, we need to fail the call
controller->Finish(status);
@ -117,79 +307,579 @@ class CallbackUnaryHandler : public MethodHandler {
: public experimental::ServerCallbackRpcController {
public:
void Finish(Status s) override {
finish_tag_.Set(
call_.call(),
[this](bool) {
grpc_call* call = call_.call();
auto call_requester = std::move(call_requester_);
this->~ServerCallbackRpcControllerImpl(); // explicitly call
// destructor
g_core_codegen_interface->grpc_call_unref(call);
call_requester();
},
&finish_buf_);
finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
&finish_ops_);
if (!ctx_->sent_initial_metadata_) {
finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_,
finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
finish_buf_.set_compression_level(ctx_->compression_level());
finish_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
// The response is dropped if the status is not OK.
if (s.ok()) {
finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_,
finish_buf_.SendMessage(resp_));
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_,
finish_ops_.SendMessage(resp_));
} else {
finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, s);
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
}
finish_buf_.set_core_cq_tag(&finish_tag_);
call_.PerformOps(&finish_buf_);
finish_ops_.set_core_cq_tag(&finish_tag_);
call_.PerformOps(&finish_ops_);
}
void SendInitialMetadata(std::function<void(bool)> f) override {
GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
meta_tag_.Set(call_.call(), std::move(f), &meta_buf_);
meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_,
callbacks_outstanding_++;
// TODO(vjpai): Consider taking f as a move-capture if we adopt C++14
// and if performance of this operation matters
meta_tag_.Set(call_.call(),
[this, f](bool ok) {
f(ok);
MaybeDone();
},
&meta_ops_);
meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
meta_buf_.set_compression_level(ctx_->compression_level());
meta_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
meta_buf_.set_core_cq_tag(&meta_tag_);
call_.PerformOps(&meta_buf_);
meta_ops_.set_core_cq_tag(&meta_tag_);
call_.PerformOps(&meta_ops_);
}
private:
template <class SrvType, class ReqType, class RespType>
friend class CallbackUnaryHandler;
friend class CallbackUnaryHandler<RequestType, ResponseType>;
ServerCallbackRpcControllerImpl(ServerContext* ctx, Call* call,
RequestType* req,
const RequestType* req,
std::function<void()> call_requester)
: ctx_(ctx),
call_(*call),
req_(req),
call_requester_(std::move(call_requester)) {}
call_requester_(std::move(call_requester)) {
ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, nullptr);
}
~ServerCallbackRpcControllerImpl() { req_->~RequestType(); }
RequestType* request() { return req_; }
const RequestType* request() { return req_; }
ResponseType* response() { return &resp_; }
CallOpSet<CallOpSendInitialMetadata> meta_buf_;
void MaybeDone() {
if (--callbacks_outstanding_ == 0) {
grpc_call* call = call_.call();
auto call_requester = std::move(call_requester_);
this->~ServerCallbackRpcControllerImpl(); // explicitly call destructor
g_core_codegen_interface->grpc_call_unref(call);
call_requester();
}
}
CallOpSet<CallOpSendInitialMetadata> meta_ops_;
CallbackWithSuccessTag meta_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
CallOpServerSendStatus>
finish_buf_;
finish_ops_;
CallbackWithSuccessTag finish_tag_;
ServerContext* ctx_;
Call call_;
RequestType* req_;
const RequestType* req_;
ResponseType resp_;
std::function<void()> call_requester_;
std::atomic_int callbacks_outstanding_{
2}; // reserve for Finish and CompletionOp
};
};
template <class RequestType, class ResponseType>
class CallbackClientStreamingHandler : public MethodHandler {
public:
CallbackClientStreamingHandler(
std::function<
experimental::ServerReadReactor<RequestType, ResponseType>*()>
func)
: func_(std::move(func)) {}
void RunHandler(const HandlerParameter& param) final {
// Arena allocate a reader structure (that includes response)
g_core_codegen_interface->grpc_call_ref(param.call->call());
experimental::ServerReadReactor<RequestType, ResponseType>* reactor =
param.status.ok()
? CatchingReactorCreator<
experimental::ServerReadReactor<RequestType, ResponseType>>(
func_)
: nullptr;
if (reactor == nullptr) {
// if deserialization or reactor creator failed, we need to fail the call
reactor = new UnimplementedReadReactor<RequestType, ResponseType>;
}
auto* reader = new (g_core_codegen_interface->grpc_call_arena_alloc(
param.call->call(), sizeof(ServerCallbackReaderImpl)))
ServerCallbackReaderImpl(param.server_context, param.call,
std::move(param.call_requester), reactor);
reader->BindReactor(reactor);
reactor->OnStarted(param.server_context, reader->response());
reader->MaybeDone();
}
private:
std::function<experimental::ServerReadReactor<RequestType, ResponseType>*()>
func_;
class ServerCallbackReaderImpl
: public experimental::ServerCallbackReader<RequestType> {
public:
void Finish(Status s) override {
finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
&finish_ops_);
if (!ctx_->sent_initial_metadata_) {
finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
finish_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
// The response is dropped if the status is not OK.
if (s.ok()) {
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_,
finish_ops_.SendMessage(resp_));
} else {
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
}
finish_ops_.set_core_cq_tag(&finish_tag_);
call_.PerformOps(&finish_ops_);
}
void SendInitialMetadata() override {
GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
callbacks_outstanding_++;
meta_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnSendInitialMetadataDone(ok);
MaybeDone();
},
&meta_ops_);
meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
meta_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
meta_ops_.set_core_cq_tag(&meta_tag_);
call_.PerformOps(&meta_ops_);
}
void Read(RequestType* req) override {
callbacks_outstanding_++;
read_ops_.RecvMessage(req);
call_.PerformOps(&read_ops_);
}
private:
friend class CallbackClientStreamingHandler<RequestType, ResponseType>;
ServerCallbackReaderImpl(
ServerContext* ctx, Call* call, std::function<void()> call_requester,
experimental::ServerReadReactor<RequestType, ResponseType>* reactor)
: ctx_(ctx),
call_(*call),
call_requester_(std::move(call_requester)),
reactor_(reactor) {
ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor);
read_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnReadDone(ok);
MaybeDone();
},
&read_ops_);
read_ops_.set_core_cq_tag(&read_tag_);
}
~ServerCallbackReaderImpl() {}
ResponseType* response() { return &resp_; }
void MaybeDone() {
if (--callbacks_outstanding_ == 0) {
reactor_->OnDone();
grpc_call* call = call_.call();
auto call_requester = std::move(call_requester_);
this->~ServerCallbackReaderImpl(); // explicitly call destructor
g_core_codegen_interface->grpc_call_unref(call);
call_requester();
}
}
CallOpSet<CallOpSendInitialMetadata> meta_ops_;
CallbackWithSuccessTag meta_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
CallOpServerSendStatus>
finish_ops_;
CallbackWithSuccessTag finish_tag_;
CallOpSet<CallOpRecvMessage<RequestType>> read_ops_;
CallbackWithSuccessTag read_tag_;
ServerContext* ctx_;
Call call_;
ResponseType resp_;
std::function<void()> call_requester_;
experimental::ServerReadReactor<RequestType, ResponseType>* reactor_;
std::atomic_int callbacks_outstanding_{
3}; // reserve for OnStarted, Finish, and CompletionOp
};
};
template <class RequestType, class ResponseType>
class CallbackServerStreamingHandler : public MethodHandler {
public:
CallbackServerStreamingHandler(
std::function<
experimental::ServerWriteReactor<RequestType, ResponseType>*()>
func)
: func_(std::move(func)) {}
void RunHandler(const HandlerParameter& param) final {
// Arena allocate a writer structure
g_core_codegen_interface->grpc_call_ref(param.call->call());
experimental::ServerWriteReactor<RequestType, ResponseType>* reactor =
param.status.ok()
? CatchingReactorCreator<
experimental::ServerWriteReactor<RequestType, ResponseType>>(
func_)
: nullptr;
if (reactor == nullptr) {
// if deserialization or reactor creator failed, we need to fail the call
reactor = new UnimplementedWriteReactor<RequestType, ResponseType>;
}
auto* writer = new (g_core_codegen_interface->grpc_call_arena_alloc(
param.call->call(), sizeof(ServerCallbackWriterImpl)))
ServerCallbackWriterImpl(param.server_context, param.call,
static_cast<RequestType*>(param.request),
std::move(param.call_requester), reactor);
writer->BindReactor(reactor);
reactor->OnStarted(param.server_context, writer->request());
writer->MaybeDone();
}
void* Deserialize(grpc_call* call, grpc_byte_buffer* req,
Status* status) final {
ByteBuffer buf;
buf.set_buffer(req);
auto* request = new (g_core_codegen_interface->grpc_call_arena_alloc(
call, sizeof(RequestType))) RequestType();
*status = SerializationTraits<RequestType>::Deserialize(&buf, request);
buf.Release();
if (status->ok()) {
return request;
}
request->~RequestType();
return nullptr;
}
private:
std::function<experimental::ServerWriteReactor<RequestType, ResponseType>*()>
func_;
class ServerCallbackWriterImpl
: public experimental::ServerCallbackWriter<ResponseType> {
public:
void Finish(Status s) override {
finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
&finish_ops_);
finish_ops_.set_core_cq_tag(&finish_tag_);
if (!ctx_->sent_initial_metadata_) {
finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
finish_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
call_.PerformOps(&finish_ops_);
}
void SendInitialMetadata() override {
GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
callbacks_outstanding_++;
meta_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnSendInitialMetadataDone(ok);
MaybeDone();
},
&meta_ops_);
meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
meta_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
meta_ops_.set_core_cq_tag(&meta_tag_);
call_.PerformOps(&meta_ops_);
}
void Write(const ResponseType* resp, WriteOptions options) override {
callbacks_outstanding_++;
if (options.is_last_message()) {
options.set_buffer_hint();
}
if (!ctx_->sent_initial_metadata_) {
write_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
write_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
// TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok());
call_.PerformOps(&write_ops_);
}
void WriteAndFinish(const ResponseType* resp, WriteOptions options,
Status s) override {
// This combines the write into the finish callback
// Don't send any message if the status is bad
if (s.ok()) {
// TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok());
}
Finish(std::move(s));
}
private:
friend class CallbackServerStreamingHandler<RequestType, ResponseType>;
ServerCallbackWriterImpl(
ServerContext* ctx, Call* call, const RequestType* req,
std::function<void()> call_requester,
experimental::ServerWriteReactor<RequestType, ResponseType>* reactor)
: ctx_(ctx),
call_(*call),
req_(req),
call_requester_(std::move(call_requester)),
reactor_(reactor) {
ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor);
write_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnWriteDone(ok);
MaybeDone();
},
&write_ops_);
write_ops_.set_core_cq_tag(&write_tag_);
}
~ServerCallbackWriterImpl() { req_->~RequestType(); }
const RequestType* request() { return req_; }
void MaybeDone() {
if (--callbacks_outstanding_ == 0) {
reactor_->OnDone();
grpc_call* call = call_.call();
auto call_requester = std::move(call_requester_);
this->~ServerCallbackWriterImpl(); // explicitly call destructor
g_core_codegen_interface->grpc_call_unref(call);
call_requester();
}
}
CallOpSet<CallOpSendInitialMetadata> meta_ops_;
CallbackWithSuccessTag meta_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
CallOpServerSendStatus>
finish_ops_;
CallbackWithSuccessTag finish_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage> write_ops_;
CallbackWithSuccessTag write_tag_;
ServerContext* ctx_;
Call call_;
const RequestType* req_;
std::function<void()> call_requester_;
experimental::ServerWriteReactor<RequestType, ResponseType>* reactor_;
std::atomic_int callbacks_outstanding_{
3}; // reserve for OnStarted, Finish, and CompletionOp
};
};
template <class RequestType, class ResponseType>
class CallbackBidiHandler : public MethodHandler {
public:
CallbackBidiHandler(
std::function<
experimental::ServerBidiReactor<RequestType, ResponseType>*()>
func)
: func_(std::move(func)) {}
void RunHandler(const HandlerParameter& param) final {
g_core_codegen_interface->grpc_call_ref(param.call->call());
experimental::ServerBidiReactor<RequestType, ResponseType>* reactor =
param.status.ok()
? CatchingReactorCreator<
experimental::ServerBidiReactor<RequestType, ResponseType>>(
func_)
: nullptr;
if (reactor == nullptr) {
// if deserialization or reactor creator failed, we need to fail the call
reactor = new UnimplementedBidiReactor<RequestType, ResponseType>;
}
auto* stream = new (g_core_codegen_interface->grpc_call_arena_alloc(
param.call->call(), sizeof(ServerCallbackReaderWriterImpl)))
ServerCallbackReaderWriterImpl(param.server_context, param.call,
std::move(param.call_requester),
reactor);
stream->BindReactor(reactor);
reactor->OnStarted(param.server_context);
stream->MaybeDone();
}
private:
std::function<experimental::ServerBidiReactor<RequestType, ResponseType>*()>
func_;
class ServerCallbackReaderWriterImpl
: public experimental::ServerCallbackReaderWriter<RequestType,
ResponseType> {
public:
void Finish(Status s) override {
finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); },
&finish_ops_);
finish_ops_.set_core_cq_tag(&finish_tag_);
if (!ctx_->sent_initial_metadata_) {
finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
finish_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
call_.PerformOps(&finish_ops_);
}
void SendInitialMetadata() override {
GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_);
callbacks_outstanding_++;
meta_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnSendInitialMetadataDone(ok);
MaybeDone();
},
&meta_ops_);
meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
meta_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
meta_ops_.set_core_cq_tag(&meta_tag_);
call_.PerformOps(&meta_ops_);
}
void Write(const ResponseType* resp, WriteOptions options) override {
callbacks_outstanding_++;
if (options.is_last_message()) {
options.set_buffer_hint();
}
if (!ctx_->sent_initial_metadata_) {
write_ops_.SendInitialMetadata(&ctx_->initial_metadata_,
ctx_->initial_metadata_flags());
if (ctx_->compression_level_set()) {
write_ops_.set_compression_level(ctx_->compression_level());
}
ctx_->sent_initial_metadata_ = true;
}
// TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok());
call_.PerformOps(&write_ops_);
}
void WriteAndFinish(const ResponseType* resp, WriteOptions options,
Status s) override {
// Don't send any message if the status is bad
if (s.ok()) {
// TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok());
}
Finish(std::move(s));
}
void Read(RequestType* req) override {
callbacks_outstanding_++;
read_ops_.RecvMessage(req);
call_.PerformOps(&read_ops_);
}
private:
friend class CallbackBidiHandler<RequestType, ResponseType>;
ServerCallbackReaderWriterImpl(
ServerContext* ctx, Call* call, std::function<void()> call_requester,
experimental::ServerBidiReactor<RequestType, ResponseType>* reactor)
: ctx_(ctx),
call_(*call),
call_requester_(std::move(call_requester)),
reactor_(reactor) {
ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor);
write_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnWriteDone(ok);
MaybeDone();
},
&write_ops_);
write_ops_.set_core_cq_tag(&write_tag_);
read_tag_.Set(call_.call(),
[this](bool ok) {
reactor_->OnReadDone(ok);
MaybeDone();
},
&read_ops_);
read_ops_.set_core_cq_tag(&read_tag_);
}
~ServerCallbackReaderWriterImpl() {}
void MaybeDone() {
if (--callbacks_outstanding_ == 0) {
reactor_->OnDone();
grpc_call* call = call_.call();
auto call_requester = std::move(call_requester_);
this->~ServerCallbackReaderWriterImpl(); // explicitly call destructor
g_core_codegen_interface->grpc_call_unref(call);
call_requester();
}
}
CallOpSet<CallOpSendInitialMetadata> meta_ops_;
CallbackWithSuccessTag meta_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage,
CallOpServerSendStatus>
finish_ops_;
CallbackWithSuccessTag finish_tag_;
CallOpSet<CallOpSendInitialMetadata, CallOpSendMessage> write_ops_;
CallbackWithSuccessTag write_tag_;
CallOpSet<CallOpRecvMessage<RequestType>> read_ops_;
CallbackWithSuccessTag read_tag_;
ServerContext* ctx_;
Call call_;
std::function<void()> call_requester_;
experimental::ServerBidiReactor<RequestType, ResponseType>* reactor_;
std::atomic_int callbacks_outstanding_{
3}; // reserve for OnStarted, Finish, and CompletionOp
};
};

@ -66,13 +66,20 @@ template <class ServiceType, class RequestType, class ResponseType>
class ServerStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
class BidiStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
template <class RequestType, class ResponseType>
class CallbackUnaryHandler;
template <class RequestType, class ResponseType>
class CallbackClientStreamingHandler;
template <class RequestType, class ResponseType>
class CallbackServerStreamingHandler;
template <class RequestType, class ResponseType>
class CallbackBidiHandler;
template <class Streamer, bool WriteNeeded>
class TemplatedBidiStreamingHandler;
template <StatusCode code>
class ErrorMethodHandler;
class Call;
class ServerReactor;
} // namespace internal
class CompletionQueue;
@ -270,8 +277,14 @@ class ServerContext {
friend class ::grpc::internal::ServerStreamingHandler;
template <class Streamer, bool WriteNeeded>
friend class ::grpc::internal::TemplatedBidiStreamingHandler;
template <class ServiceType, class RequestType, class ResponseType>
template <class RequestType, class ResponseType>
friend class ::grpc::internal::CallbackUnaryHandler;
template <class RequestType, class ResponseType>
friend class ::grpc::internal::CallbackClientStreamingHandler;
template <class RequestType, class ResponseType>
friend class ::grpc::internal::CallbackServerStreamingHandler;
template <class RequestType, class ResponseType>
friend class ::grpc::internal::CallbackBidiHandler;
template <StatusCode code>
friend class internal::ErrorMethodHandler;
friend class ::grpc::ClientContext;
@ -282,7 +295,9 @@ class ServerContext {
class CompletionOp;
void BeginCompletionOp(internal::Call* call, bool callback);
void BeginCompletionOp(internal::Call* call,
std::function<void(bool)> callback,
internal::ServerReactor* reactor);
/// Return the tag queued by BeginCompletionOp()
internal::CompletionQueueTag* GetCompletionOpTag();

@ -889,6 +889,11 @@ void PrintHeaderServerCallbackMethodsHelper(
" abort();\n"
" return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
"}\n");
printer->Print(*vars,
"virtual ::grpc::experimental::ServerReadReactor< "
"$RealRequest$, $RealResponse$>* $Method$() {\n"
" return new ::grpc::internal::UnimplementedReadReactor<\n"
" $RealRequest$, $RealResponse$>;}\n");
} else if (ServerOnlyStreaming(method)) {
printer->Print(
*vars,
@ -900,6 +905,11 @@ void PrintHeaderServerCallbackMethodsHelper(
" abort();\n"
" return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
"}\n");
printer->Print(*vars,
"virtual ::grpc::experimental::ServerWriteReactor< "
"$RealRequest$, $RealResponse$>* $Method$() {\n"
" return new ::grpc::internal::UnimplementedWriteReactor<\n"
" $RealRequest$, $RealResponse$>;}\n");
} else if (method->BidiStreaming()) {
printer->Print(
*vars,
@ -911,6 +921,11 @@ void PrintHeaderServerCallbackMethodsHelper(
" abort();\n"
" return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n"
"}\n");
printer->Print(*vars,
"virtual ::grpc::experimental::ServerBidiReactor< "
"$RealRequest$, $RealResponse$>* $Method$() {\n"
" return new ::grpc::internal::UnimplementedBidiReactor<\n"
" $RealRequest$, $RealResponse$>;}\n");
}
}
@ -939,22 +954,36 @@ void PrintHeaderServerMethodCallback(
*vars,
" ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
" new ::grpc::internal::CallbackUnaryHandler< "
"ExperimentalWithCallbackMethod_$Method$<BaseClass>, $RealRequest$, "
"$RealResponse$>(\n"
"$RealRequest$, $RealResponse$>(\n"
" [this](::grpc::ServerContext* context,\n"
" const $RealRequest$* request,\n"
" $RealResponse$* response,\n"
" ::grpc::experimental::ServerCallbackRpcController* "
"controller) {\n"
" this->$"
" return this->$"
"Method$(context, request, response, controller);\n"
" }, this));\n");
" }));\n");
} else if (ClientOnlyStreaming(method)) {
// TODO(vjpai): Add in code generation for all streaming methods
printer->Print(
*vars,
" ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
" new ::grpc::internal::CallbackClientStreamingHandler< "
"$RealRequest$, $RealResponse$>(\n"
" [this] { return this->$Method$(); }));\n");
} else if (ServerOnlyStreaming(method)) {
// TODO(vjpai): Add in code generation for all streaming methods
printer->Print(
*vars,
" ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
" new ::grpc::internal::CallbackServerStreamingHandler< "
"$RealRequest$, $RealResponse$>(\n"
" [this] { return this->$Method$(); }));\n");
} else if (method->BidiStreaming()) {
// TODO(vjpai): Add in code generation for all streaming methods
printer->Print(
*vars,
" ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n"
" new ::grpc::internal::CallbackBidiHandler< "
"$RealRequest$, $RealResponse$>(\n"
" [this] { return this->$Method$(); }));\n");
}
printer->Print(*vars, "}\n");
printer->Print(*vars,
@ -991,8 +1020,7 @@ void PrintHeaderServerMethodRawCallback(
*vars,
" ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
" new ::grpc::internal::CallbackUnaryHandler< "
"ExperimentalWithRawCallbackMethod_$Method$<BaseClass>, $RealRequest$, "
"$RealResponse$>(\n"
"$RealRequest$, $RealResponse$>(\n"
" [this](::grpc::ServerContext* context,\n"
" const $RealRequest$* request,\n"
" $RealResponse$* response,\n"
@ -1000,13 +1028,28 @@ void PrintHeaderServerMethodRawCallback(
"controller) {\n"
" this->$"
"Method$(context, request, response, controller);\n"
" }, this));\n");
" }));\n");
} else if (ClientOnlyStreaming(method)) {
// TODO(vjpai): Add in code generation for all streaming methods
printer->Print(
*vars,
" ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
" new ::grpc::internal::CallbackClientStreamingHandler< "
"$RealRequest$, $RealResponse$>(\n"
" [this] { return this->$Method$(); }));\n");
} else if (ServerOnlyStreaming(method)) {
// TODO(vjpai): Add in code generation for all streaming methods
printer->Print(
*vars,
" ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
" new ::grpc::internal::CallbackServerStreamingHandler< "
"$RealRequest$, $RealResponse$>(\n"
" [this] { return this->$Method$(); }));\n");
} else if (method->BidiStreaming()) {
// TODO(vjpai): Add in code generation for all streaming methods
printer->Print(
*vars,
" ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n"
" new ::grpc::internal::CallbackBidiHandler< "
"$RealRequest$, $RealResponse$>(\n"
" [this] { return this->$Method$(); }));\n");
}
printer->Print(*vars, "}\n");
printer->Print(*vars,

@ -291,7 +291,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
void ContinueRunAfterInterception() {
{
ctx_.BeginCompletionOp(&call_, false);
ctx_.BeginCompletionOp(&call_, nullptr, nullptr);
global_callbacks_->PreSynchronousRequest(&ctx_);
auto* handler = resources_ ? method_->handler()
: server_->resource_exhausted_handler_.get();
@ -456,7 +456,6 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
}
}
void ContinueRunAfterInterception() {
req_->ctx_.BeginCompletionOp(call_, true);
req_->method_->handler()->RunHandler(
internal::MethodHandler::HandlerParameter(
call_, &req_->ctx_, req_->request_, req_->request_status_,
@ -1018,7 +1017,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
}
}
if (*status && call_) {
context_->BeginCompletionOp(&call_wrapper_, false);
context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr);
}
*tag = tag_;
if (delete_on_finalize_) {
@ -1029,7 +1028,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
void ServerInterface::BaseAsyncRequest::
ContinueFinalizeResultAfterInterception() {
context_->BeginCompletionOp(&call_wrapper_, false);
context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr);
// Queue a tag which will be returned immediately
grpc_core::ExecCtx exec_ctx;
grpc_cq_begin_op(notification_cq_->cq(), this);

@ -17,6 +17,7 @@
*/
#include <grpcpp/server_context.h>
#include <grpcpp/support/server_callback.h>
#include <algorithm>
#include <mutex>
@ -41,8 +42,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
public:
// initial refs: one in the server context, one in the cq
// must ref the call before calling constructor and after deleting this
CompletionOp(internal::Call* call)
CompletionOp(internal::Call* call, internal::ServerReactor* reactor)
: call_(*call),
reactor_(reactor),
has_tag_(false),
tag_(nullptr),
core_cq_tag_(this),
@ -124,9 +126,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
return;
}
/* Start a dummy op so that we can return the tag */
GPR_CODEGEN_ASSERT(GRPC_CALL_OK ==
g_core_codegen_interface->grpc_call_start_batch(
call_.call(), nullptr, 0, this, nullptr));
GPR_CODEGEN_ASSERT(
GRPC_CALL_OK ==
grpc_call_start_batch(call_.call(), nullptr, 0, core_cq_tag_, nullptr));
}
private:
@ -136,13 +138,14 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
}
internal::Call call_;
internal::ServerReactor* reactor_;
bool has_tag_;
void* tag_;
void* core_cq_tag_;
std::mutex mu_;
int refs_;
bool finalized_;
int cancelled_;
int cancelled_; // This is an int (not bool) because it is passed to core
bool done_intercepting_;
internal::InterceptorBatchMethodsImpl interceptor_methods_;
};
@ -190,7 +193,16 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) {
}
finalized_ = true;
if (!*status) cancelled_ = 1;
// If for some reason the incoming status is false, mark that as a
// cancellation.
// TODO(vjpai): does this ever happen?
if (!*status) {
cancelled_ = 1;
}
if (cancelled_ && (reactor_ != nullptr)) {
reactor_->OnCancel();
}
/* Release the lock since we are going to be running through interceptors now
*/
lock.unlock();
@ -251,21 +263,25 @@ void ServerContext::Clear() {
initial_metadata_.clear();
trailing_metadata_.clear();
client_metadata_.Reset();
if (call_) {
grpc_call_unref(call_);
}
if (completion_op_) {
completion_op_->Unref();
completion_op_ = nullptr;
completion_tag_.Clear();
}
if (rpc_info_) {
rpc_info_->Unref();
rpc_info_ = nullptr;
}
if (call_) {
auto* call = call_;
call_ = nullptr;
grpc_call_unref(call);
}
// Don't need to clear out call_, completion_op_, or rpc_info_ because this is
// either called from destructor or just before Setup
}
void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) {
void ServerContext::BeginCompletionOp(internal::Call* call,
std::function<void(bool)> callback,
internal::ServerReactor* reactor) {
GPR_ASSERT(!completion_op_);
if (rpc_info_) {
rpc_info_->Ref();
@ -273,10 +289,11 @@ void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) {
grpc_call_ref(call->call());
completion_op_ =
new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp)))
CompletionOp(call);
if (callback) {
completion_tag_.Set(call->call(), nullptr, completion_op_);
CompletionOp(call, reactor);
if (callback != nullptr) {
completion_tag_.Set(call->call(), std::move(callback), completion_op_);
completion_op_->set_core_cq_tag(&completion_tag_);
completion_op_->set_tag(completion_op_);
} else if (has_notify_when_done_tag_) {
completion_op_->set_tag(async_notify_when_done_tag_);
}

@ -322,13 +322,13 @@ class ServiceA final {
public:
ExperimentalWithCallbackMethod_MethodA1() {
::grpc::Service::experimental().MarkMethodCallback(0,
new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodA1<BaseClass>, ::grpc::testing::Request, ::grpc::testing::Response>(
new ::grpc::internal::CallbackUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this](::grpc::ServerContext* context,
const ::grpc::testing::Request* request,
::grpc::testing::Response* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
this->MethodA1(context, request, response, controller);
}, this));
return this->MethodA1(context, request, response, controller);
}));
}
~ExperimentalWithCallbackMethod_MethodA1() override {
BaseClassMustBeDerivedFromService(this);
@ -346,6 +346,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithCallbackMethod_MethodA2() {
::grpc::Service::experimental().MarkMethodCallback(1,
new ::grpc::internal::CallbackClientStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this] { return this->MethodA2(); }));
}
~ExperimentalWithCallbackMethod_MethodA2() override {
BaseClassMustBeDerivedFromService(this);
@ -355,6 +358,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
virtual ::grpc::experimental::ServerReadReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA2() {
return new ::grpc::internal::UnimplementedReadReactor<
::grpc::testing::Request, ::grpc::testing::Response>;}
};
template <class BaseClass>
class ExperimentalWithCallbackMethod_MethodA3 : public BaseClass {
@ -362,6 +368,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithCallbackMethod_MethodA3() {
::grpc::Service::experimental().MarkMethodCallback(2,
new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this] { return this->MethodA3(); }));
}
~ExperimentalWithCallbackMethod_MethodA3() override {
BaseClassMustBeDerivedFromService(this);
@ -371,6 +380,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
virtual ::grpc::experimental::ServerWriteReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA3() {
return new ::grpc::internal::UnimplementedWriteReactor<
::grpc::testing::Request, ::grpc::testing::Response>;}
};
template <class BaseClass>
class ExperimentalWithCallbackMethod_MethodA4 : public BaseClass {
@ -378,6 +390,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithCallbackMethod_MethodA4() {
::grpc::Service::experimental().MarkMethodCallback(3,
new ::grpc::internal::CallbackBidiHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this] { return this->MethodA4(); }));
}
~ExperimentalWithCallbackMethod_MethodA4() override {
BaseClassMustBeDerivedFromService(this);
@ -387,6 +402,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
virtual ::grpc::experimental::ServerBidiReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4() {
return new ::grpc::internal::UnimplementedBidiReactor<
::grpc::testing::Request, ::grpc::testing::Response>;}
};
typedef ExperimentalWithCallbackMethod_MethodA1<ExperimentalWithCallbackMethod_MethodA2<ExperimentalWithCallbackMethod_MethodA3<ExperimentalWithCallbackMethod_MethodA4<Service > > > > ExperimentalCallbackService;
template <class BaseClass>
@ -544,13 +562,13 @@ class ServiceA final {
public:
ExperimentalWithRawCallbackMethod_MethodA1() {
::grpc::Service::experimental().MarkMethodRawCallback(0,
new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodA1<BaseClass>, ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
new ::grpc::internal::CallbackUnaryHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this](::grpc::ServerContext* context,
const ::grpc::ByteBuffer* request,
::grpc::ByteBuffer* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
this->MethodA1(context, request, response, controller);
}, this));
}));
}
~ExperimentalWithRawCallbackMethod_MethodA1() override {
BaseClassMustBeDerivedFromService(this);
@ -568,6 +586,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithRawCallbackMethod_MethodA2() {
::grpc::Service::experimental().MarkMethodRawCallback(1,
new ::grpc::internal::CallbackClientStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this] { return this->MethodA2(); }));
}
~ExperimentalWithRawCallbackMethod_MethodA2() override {
BaseClassMustBeDerivedFromService(this);
@ -577,6 +598,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
virtual ::grpc::experimental::ServerReadReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA2() {
return new ::grpc::internal::UnimplementedReadReactor<
::grpc::ByteBuffer, ::grpc::ByteBuffer>;}
};
template <class BaseClass>
class ExperimentalWithRawCallbackMethod_MethodA3 : public BaseClass {
@ -584,6 +608,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithRawCallbackMethod_MethodA3() {
::grpc::Service::experimental().MarkMethodRawCallback(2,
new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this] { return this->MethodA3(); }));
}
~ExperimentalWithRawCallbackMethod_MethodA3() override {
BaseClassMustBeDerivedFromService(this);
@ -593,6 +620,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
virtual ::grpc::experimental::ServerWriteReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA3() {
return new ::grpc::internal::UnimplementedWriteReactor<
::grpc::ByteBuffer, ::grpc::ByteBuffer>;}
};
template <class BaseClass>
class ExperimentalWithRawCallbackMethod_MethodA4 : public BaseClass {
@ -600,6 +630,9 @@ class ServiceA final {
void BaseClassMustBeDerivedFromService(const Service *service) {}
public:
ExperimentalWithRawCallbackMethod_MethodA4() {
::grpc::Service::experimental().MarkMethodRawCallback(3,
new ::grpc::internal::CallbackBidiHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this] { return this->MethodA4(); }));
}
~ExperimentalWithRawCallbackMethod_MethodA4() override {
BaseClassMustBeDerivedFromService(this);
@ -609,6 +642,9 @@ class ServiceA final {
abort();
return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
}
virtual ::grpc::experimental::ServerBidiReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA4() {
return new ::grpc::internal::UnimplementedBidiReactor<
::grpc::ByteBuffer, ::grpc::ByteBuffer>;}
};
template <class BaseClass>
class WithStreamedUnaryMethod_MethodA1 : public BaseClass {
@ -752,13 +788,13 @@ class ServiceB final {
public:
ExperimentalWithCallbackMethod_MethodB1() {
::grpc::Service::experimental().MarkMethodCallback(0,
new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodB1<BaseClass>, ::grpc::testing::Request, ::grpc::testing::Response>(
new ::grpc::internal::CallbackUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>(
[this](::grpc::ServerContext* context,
const ::grpc::testing::Request* request,
::grpc::testing::Response* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
this->MethodB1(context, request, response, controller);
}, this));
return this->MethodB1(context, request, response, controller);
}));
}
~ExperimentalWithCallbackMethod_MethodB1() override {
BaseClassMustBeDerivedFromService(this);
@ -815,13 +851,13 @@ class ServiceB final {
public:
ExperimentalWithRawCallbackMethod_MethodB1() {
::grpc::Service::experimental().MarkMethodRawCallback(0,
new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodB1<BaseClass>, ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
new ::grpc::internal::CallbackUnaryHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>(
[this](::grpc::ServerContext* context,
const ::grpc::ByteBuffer* request,
::grpc::ByteBuffer* response,
::grpc::experimental::ServerCallbackRpcController* controller) {
this->MethodB1(context, request, response, controller);
}, this));
}));
}
~ExperimentalWithRawCallbackMethod_MethodB1() override {
BaseClassMustBeDerivedFromService(this);

@ -196,16 +196,18 @@ class TestServiceImplDupPkg
class TestScenario {
public:
TestScenario(bool interceptors, bool proxy, bool inproc_stub,
const grpc::string& creds_type)
const grpc::string& creds_type, bool use_callback_server)
: use_interceptors(interceptors),
use_proxy(proxy),
inproc(inproc_stub),
credentials_type(creds_type) {}
credentials_type(creds_type),
callback_server(use_callback_server) {}
void Log() const;
bool use_interceptors;
bool use_proxy;
bool inproc;
const grpc::string credentials_type;
bool callback_server;
};
static std::ostream& operator<<(std::ostream& out,
@ -214,6 +216,8 @@ static std::ostream& operator<<(std::ostream& out,
<< (scenario.use_interceptors ? "true" : "false")
<< ", use_proxy=" << (scenario.use_proxy ? "true" : "false")
<< ", inproc=" << (scenario.inproc ? "true" : "false")
<< ", server_type="
<< (scenario.callback_server ? "callback" : "sync")
<< ", credentials='" << scenario.credentials_type << "'}";
}
@ -280,7 +284,11 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
builder.experimental().SetInterceptorCreators(std::move(creators));
}
builder.AddListeningPort(server_address_.str(), server_creds);
builder.RegisterService(&service_);
if (!GetParam().callback_server) {
builder.RegisterService(&service_);
} else {
builder.RegisterService(&callback_service_);
}
builder.RegisterService("foo.test.youtube.com", &special_service_);
builder.RegisterService(&dup_pkg_service_);
@ -362,6 +370,7 @@ class End2endTest : public ::testing::TestWithParam<TestScenario> {
std::ostringstream server_address_;
const int kMaxMessageSize_;
TestServiceImpl service_;
CallbackTestServiceImpl callback_service_;
TestServiceImpl special_service_;
TestServiceImplDupPkg dup_pkg_service_;
grpc::string user_agent_prefix_;
@ -1016,7 +1025,8 @@ TEST_P(End2endTest, DiffPackageServices) {
EXPECT_TRUE(s.ok());
}
void CancelRpc(ClientContext* context, int delay_us, TestServiceImpl* service) {
template <class ServiceType>
void CancelRpc(ClientContext* context, int delay_us, ServiceType* service) {
gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_micros(delay_us, GPR_TIMESPAN)));
while (!service->signal_client()) {
@ -1446,7 +1456,24 @@ TEST_P(ProxyEnd2endTest, ClientCancelsRpc) {
request.mutable_param()->set_client_cancel_after_us(kCancelDelayUs);
ClientContext context;
std::thread cancel_thread(CancelRpc, &context, kCancelDelayUs, &service_);
std::thread cancel_thread;
if (!GetParam().callback_server) {
cancel_thread = std::thread(
[&context, this](int delay) { CancelRpc(&context, delay, &service_); },
kCancelDelayUs);
// Note: the unusual pattern above (and below) is caused by a conflict
// between two sets of compiler expectations. clang allows const to be
// captured without mention, so there is no need to capture kCancelDelayUs
// (and indeed clang-tidy complains if you do so). OTOH, a Windows compiler
// in our tests requires an explicit capture even for const. We square this
// circle by passing the const value in as an argument to the lambda.
} else {
cancel_thread = std::thread(
[&context, this](int delay) {
CancelRpc(&context, delay, &callback_service_);
},
kCancelDelayUs);
}
Status s = stub_->Echo(&context, request, &response);
cancel_thread.join();
EXPECT_EQ(StatusCode::CANCELLED, s.error_code());
@ -1838,10 +1865,12 @@ TEST_P(ResourceQuotaEnd2endTest, SimpleRequest) {
EXPECT_TRUE(s.ok());
}
// TODO(vjpai): refactor arguments into a struct if it makes sense
std::vector<TestScenario> CreateTestScenarios(bool use_proxy,
bool test_insecure,
bool test_secure,
bool test_inproc) {
bool test_inproc,
bool test_callback_server) {
std::vector<TestScenario> scenarios;
std::vector<grpc::string> credentials_types;
if (test_secure) {
@ -1857,41 +1886,48 @@ std::vector<TestScenario> CreateTestScenarios(bool use_proxy,
if (test_insecure && insec_ok()) {
credentials_types.push_back(kInsecureCredentialsType);
}
// For now test callback server only with inproc
GPR_ASSERT(!credentials_types.empty());
for (const auto& cred : credentials_types) {
scenarios.emplace_back(false, false, false, cred);
scenarios.emplace_back(true, false, false, cred);
scenarios.emplace_back(false, false, false, cred, false);
scenarios.emplace_back(true, false, false, cred, false);
if (use_proxy) {
scenarios.emplace_back(false, true, false, cred);
scenarios.emplace_back(true, true, false, cred);
scenarios.emplace_back(false, true, false, cred, false);
scenarios.emplace_back(true, true, false, cred, false);
}
}
if (test_inproc && insec_ok()) {
scenarios.emplace_back(false, false, true, kInsecureCredentialsType);
scenarios.emplace_back(true, false, true, kInsecureCredentialsType);
scenarios.emplace_back(false, false, true, kInsecureCredentialsType, false);
scenarios.emplace_back(true, false, true, kInsecureCredentialsType, false);
if (test_callback_server) {
scenarios.emplace_back(false, false, true, kInsecureCredentialsType,
true);
scenarios.emplace_back(true, false, true, kInsecureCredentialsType, true);
}
}
return scenarios;
}
INSTANTIATE_TEST_CASE_P(End2end, End2endTest,
::testing::ValuesIn(CreateTestScenarios(false, true,
true, true)));
INSTANTIATE_TEST_CASE_P(
End2end, End2endTest,
::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest,
::testing::ValuesIn(CreateTestScenarios(false, true,
true, true)));
INSTANTIATE_TEST_CASE_P(
End2endServerTryCancel, End2endServerTryCancelTest,
::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true)));
INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(true, true,
true, true)));
INSTANTIATE_TEST_CASE_P(
ProxyEnd2end, ProxyEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(true, true, true, true, false)));
INSTANTIATE_TEST_CASE_P(SecureEnd2end, SecureEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(false, false,
true, false)));
INSTANTIATE_TEST_CASE_P(
SecureEnd2end, SecureEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(false, false, true, false, true)));
INSTANTIATE_TEST_CASE_P(ResourceQuotaEnd2end, ResourceQuotaEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(false, true,
true, true)));
INSTANTIATE_TEST_CASE_P(
ResourceQuotaEnd2end, ResourceQuotaEnd2endTest,
::testing::ValuesIn(CreateTestScenarios(false, true, true, true, false)));
} // namespace
} // namespace testing

@ -71,6 +71,46 @@ void CheckServerAuthContext(
}
} // namespace
namespace {
int GetIntValueFromMetadataHelper(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value) {
if (metadata.find(key) != metadata.end()) {
std::istringstream iss(ToString(metadata.find(key)->second));
iss >> default_value;
gpr_log(GPR_INFO, "%s : %d", key, default_value);
}
return default_value;
}
int GetIntValueFromMetadata(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value) {
return GetIntValueFromMetadataHelper(key, metadata, default_value);
}
void ServerTryCancel(ServerContext* context) {
EXPECT_FALSE(context->IsCancelled());
context->TryCancel();
gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
// Now wait until it's really canceled
while (!context->IsCancelled()) {
gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_micros(1000, GPR_TIMESPAN)));
}
}
void ServerTryCancelNonblocking(ServerContext* context) {
EXPECT_FALSE(context->IsCancelled());
context->TryCancel();
gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
}
} // namespace
Status TestServiceImpl::Echo(ServerContext* context, const EchoRequest* request,
EchoResponse* response) {
// A bit of sleep to make sure that short deadline tests fail
@ -195,6 +235,7 @@ void CallbackTestServiceImpl::EchoNonDelayed(
controller->Finish(Status(static_cast<StatusCode>(error.code()),
error.error_message(),
error.binary_error_details()));
return;
}
int server_try_cancel = GetIntValueFromMetadata(
kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
@ -254,7 +295,7 @@ void CallbackTestServiceImpl::EchoNonDelayed(
alarm_.experimental().Set(
gpr_time_add(
gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_micros(request->param().client_cancel_after_us(),
gpr_time_from_micros(request->param().server_cancel_after_us(),
GPR_TIMESPAN)),
[controller](bool) { controller->Finish(Status::CANCELLED); });
return;
@ -279,6 +320,7 @@ void CallbackTestServiceImpl::EchoNonDelayed(
request->param().debug_info().SerializeAsString();
context->AddTrailingMetadata(kDebugInfoTrailerKey, serialized_debug_info);
controller->Finish(Status::CANCELLED);
return;
}
}
if (request->has_param() &&
@ -325,7 +367,7 @@ Status TestServiceImpl::RequestStream(ServerContext* context,
std::thread* server_try_cancel_thd = nullptr;
if (server_try_cancel == CANCEL_DURING_PROCESSING) {
server_try_cancel_thd =
new std::thread(&TestServiceImpl::ServerTryCancel, this, context);
new std::thread([context] { ServerTryCancel(context); });
}
int num_msgs_read = 0;
@ -380,7 +422,7 @@ Status TestServiceImpl::ResponseStream(ServerContext* context,
std::thread* server_try_cancel_thd = nullptr;
if (server_try_cancel == CANCEL_DURING_PROCESSING) {
server_try_cancel_thd =
new std::thread(&TestServiceImpl::ServerTryCancel, this, context);
new std::thread([context] { ServerTryCancel(context); });
}
for (int i = 0; i < server_responses_to_send; i++) {
@ -431,7 +473,7 @@ Status TestServiceImpl::BidiStream(
std::thread* server_try_cancel_thd = nullptr;
if (server_try_cancel == CANCEL_DURING_PROCESSING) {
server_try_cancel_thd =
new std::thread(&TestServiceImpl::ServerTryCancel, this, context);
new std::thread([context] { ServerTryCancel(context); });
}
// kServerFinishAfterNReads suggests after how many reads, the server should
@ -465,44 +507,244 @@ Status TestServiceImpl::BidiStream(
return Status::OK;
}
namespace {
int GetIntValueFromMetadataHelper(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value) {
if (metadata.find(key) != metadata.end()) {
std::istringstream iss(ToString(metadata.find(key)->second));
iss >> default_value;
gpr_log(GPR_INFO, "%s : %d", key, default_value);
}
experimental::ServerReadReactor<EchoRequest, EchoResponse>*
CallbackTestServiceImpl::RequestStream() {
class Reactor : public ::grpc::experimental::ServerReadReactor<EchoRequest,
EchoResponse> {
public:
Reactor() {}
void OnStarted(ServerContext* context, EchoResponse* response) override {
ctx_ = context;
response_ = response;
// If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
// the server by calling ServerContext::TryCancel() depending on the
// value:
// CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
// reads any message from the client CANCEL_DURING_PROCESSING: The RPC
// is cancelled while the server is reading messages from the client
// CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
// all the messages from the client
server_try_cancel_ = GetIntValueFromMetadata(
kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
response_->set_message("");
if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
ServerTryCancelNonblocking(ctx_);
return;
}
return default_value;
}
}; // namespace
if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
ctx_->TryCancel();
// Don't wait for it here
}
int TestServiceImpl::GetIntValueFromMetadata(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value) {
return GetIntValueFromMetadataHelper(key, metadata, default_value);
StartRead(&request_);
}
void OnDone() override { delete this; }
void OnCancel() override { FinishOnce(Status::CANCELLED); }
void OnReadDone(bool ok) override {
if (ok) {
response_->mutable_message()->append(request_.message());
num_msgs_read_++;
StartRead(&request_);
} else {
gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_);
if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
// Let OnCancel recover this
return;
}
if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
ServerTryCancelNonblocking(ctx_);
return;
}
FinishOnce(Status::OK);
}
}
private:
void FinishOnce(const Status& s) {
std::lock_guard<std::mutex> l(finish_mu_);
if (!finished_) {
Finish(s);
finished_ = true;
}
}
ServerContext* ctx_;
EchoResponse* response_;
EchoRequest request_;
int num_msgs_read_{0};
int server_try_cancel_;
std::mutex finish_mu_;
bool finished_{false};
};
return new Reactor;
}
int CallbackTestServiceImpl::GetIntValueFromMetadata(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value) {
return GetIntValueFromMetadataHelper(key, metadata, default_value);
// Return 'kNumResponseStreamMsgs' messages.
// TODO(yangg) make it generic by adding a parameter into EchoRequest
experimental::ServerWriteReactor<EchoRequest, EchoResponse>*
CallbackTestServiceImpl::ResponseStream() {
class Reactor
: public ::grpc::experimental::ServerWriteReactor<EchoRequest,
EchoResponse> {
public:
Reactor() {}
void OnStarted(ServerContext* context,
const EchoRequest* request) override {
ctx_ = context;
request_ = request;
// If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
// the server by calling ServerContext::TryCancel() depending on the
// value:
// CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
// reads any message from the client CANCEL_DURING_PROCESSING: The RPC
// is cancelled while the server is reading messages from the client
// CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
// all the messages from the client
server_try_cancel_ = GetIntValueFromMetadata(
kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
server_coalescing_api_ = GetIntValueFromMetadata(
kServerUseCoalescingApi, context->client_metadata(), 0);
server_responses_to_send_ = GetIntValueFromMetadata(
kServerResponseStreamsToSend, context->client_metadata(),
kServerDefaultResponseStreamsToSend);
if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
ServerTryCancelNonblocking(ctx_);
return;
}
if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
ctx_->TryCancel();
}
if (num_msgs_sent_ < server_responses_to_send_) {
NextWrite();
}
}
void OnDone() override { delete this; }
void OnCancel() override { FinishOnce(Status::CANCELLED); }
void OnWriteDone(bool ok) override {
if (num_msgs_sent_ < server_responses_to_send_) {
NextWrite();
} else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
// Let OnCancel recover this
} else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
ServerTryCancelNonblocking(ctx_);
} else {
FinishOnce(Status::OK);
}
}
private:
void FinishOnce(const Status& s) {
std::lock_guard<std::mutex> l(finish_mu_);
if (!finished_) {
Finish(s);
finished_ = true;
}
}
void NextWrite() {
response_.set_message(request_->message() +
grpc::to_string(num_msgs_sent_));
if (num_msgs_sent_ == server_responses_to_send_ - 1 &&
server_coalescing_api_ != 0) {
num_msgs_sent_++;
StartWriteLast(&response_, WriteOptions());
} else {
num_msgs_sent_++;
StartWrite(&response_);
}
}
ServerContext* ctx_;
const EchoRequest* request_;
EchoResponse response_;
int num_msgs_sent_{0};
int server_try_cancel_;
int server_coalescing_api_;
int server_responses_to_send_;
std::mutex finish_mu_;
bool finished_{false};
};
return new Reactor;
}
void TestServiceImpl::ServerTryCancel(ServerContext* context) {
EXPECT_FALSE(context->IsCancelled());
context->TryCancel();
gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request");
// Now wait until it's really canceled
while (!context->IsCancelled()) {
gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
gpr_time_from_micros(1000, GPR_TIMESPAN)));
}
experimental::ServerBidiReactor<EchoRequest, EchoResponse>*
CallbackTestServiceImpl::BidiStream() {
class Reactor : public ::grpc::experimental::ServerBidiReactor<EchoRequest,
EchoResponse> {
public:
Reactor() {}
void OnStarted(ServerContext* context) override {
ctx_ = context;
// If 'server_try_cancel' is set in the metadata, the RPC is cancelled by
// the server by calling ServerContext::TryCancel() depending on the
// value:
// CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server
// reads any message from the client CANCEL_DURING_PROCESSING: The RPC
// is cancelled while the server is reading messages from the client
// CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads
// all the messages from the client
server_try_cancel_ = GetIntValueFromMetadata(
kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL);
server_write_last_ = GetIntValueFromMetadata(
kServerFinishAfterNReads, context->client_metadata(), 0);
if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) {
ServerTryCancelNonblocking(ctx_);
return;
}
if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
ctx_->TryCancel();
}
StartRead(&request_);
}
void OnDone() override { delete this; }
void OnCancel() override { FinishOnce(Status::CANCELLED); }
void OnReadDone(bool ok) override {
if (ok) {
num_msgs_read_++;
gpr_log(GPR_INFO, "recv msg %s", request_.message().c_str());
response_.set_message(request_.message());
if (num_msgs_read_ == server_write_last_) {
StartWriteLast(&response_, WriteOptions());
} else {
StartWrite(&response_);
}
} else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) {
// Let OnCancel handle this
} else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) {
ServerTryCancelNonblocking(ctx_);
} else {
FinishOnce(Status::OK);
}
}
void OnWriteDone(bool ok) override { StartRead(&request_); }
private:
void FinishOnce(const Status& s) {
std::lock_guard<std::mutex> l(finish_mu_);
if (!finished_) {
Finish(s);
finished_ = true;
}
}
ServerContext* ctx_;
EchoRequest request_;
EchoResponse response_;
int num_msgs_read_{0};
int server_try_cancel_;
int server_write_last_;
std::mutex finish_mu_;
bool finished_{false};
};
return new Reactor;
}
} // namespace testing

@ -72,13 +72,6 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service {
}
private:
int GetIntValueFromMetadata(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value);
void ServerTryCancel(ServerContext* context);
bool signal_client_;
std::mutex mu_;
std::unique_ptr<grpc::string> host_;
@ -95,6 +88,15 @@ class CallbackTestServiceImpl
EchoResponse* response,
experimental::ServerCallbackRpcController* controller) override;
experimental::ServerReadReactor<EchoRequest, EchoResponse>* RequestStream()
override;
experimental::ServerWriteReactor<EchoRequest, EchoResponse>* ResponseStream()
override;
experimental::ServerBidiReactor<EchoRequest, EchoResponse>* BidiStream()
override;
// Unimplemented is left unimplemented to test the returned error.
bool signal_client() {
std::unique_lock<std::mutex> lock(mu_);
@ -106,11 +108,6 @@ class CallbackTestServiceImpl
EchoResponse* response,
experimental::ServerCallbackRpcController* controller);
int GetIntValueFromMetadata(
const char* key,
const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
int default_value);
Alarm alarm_;
bool signal_client_;
std::mutex mu_;

Loading…
Cancel
Save