diff --git a/include/grpc++/call.h b/include/grpc++/call.h index 5aa96d33b96..94215bfa986 100644 --- a/include/grpc++/call.h +++ b/include/grpc++/call.h @@ -55,6 +55,8 @@ class ChannelInterface; class CallOpBuffer final : public CompletionQueueTag { public: + CallOpBuffer() : return_tag_(this) {} + void AddSendInitialMetadata(std::vector > *metadata); void AddSendMessage(const google::protobuf::Message &message); void AddRecvMessage(google::protobuf::Message *message); @@ -67,7 +69,10 @@ class CallOpBuffer final : public CompletionQueueTag { void FillOps(grpc_op *ops, size_t *nops); // Called by completion queue just prior to returning from Next() or Pluck() - FinalizeResultOutput FinalizeResult(bool status) override; + void FinalizeResult(void *tag, bool *status) override; + + private: + void *return_tag_; }; class CCallDeleter { @@ -80,7 +85,7 @@ class Call final { public: Call(grpc_call *call, ChannelInterface *channel, CompletionQueue *cq); - void PerformOps(CallOpBuffer *buffer, void *tag); + void PerformOps(CallOpBuffer *buffer); grpc_call *call() { return call_.get(); } CompletionQueue *cq() { return cq_; } diff --git a/include/grpc++/channel_interface.h b/include/grpc++/channel_interface.h index 79466c9fda0..c128a08a9f9 100644 --- a/include/grpc++/channel_interface.h +++ b/include/grpc++/channel_interface.h @@ -58,7 +58,7 @@ class ChannelInterface { virtual Call CreateCall(const RpcMethod &method, ClientContext *context, CompletionQueue *cq) = 0; - virtual void PerformOpsOnCall(CallOpBuffer *ops, void *tag, Call *call) = 0; + virtual void PerformOpsOnCall(CallOpBuffer *ops, Call *call) = 0; }; // Wrapper that begins an asynchronous unary call diff --git a/include/grpc++/completion_queue.h b/include/grpc++/completion_queue.h index 8033fd12058..641d599c7ea 100644 --- a/include/grpc++/completion_queue.h +++ b/include/grpc++/completion_queue.h @@ -38,21 +38,27 @@ struct grpc_completion_queue; namespace grpc { +template +class ClientReader; +template +class ClientWriter; +template +class ClientReaderWriter; +template +class ServerReader; +template +class ServerWriter; +template +class ServerReaderWriter; + class CompletionQueue; class CompletionQueueTag { public: - enum FinalizeResultOutput { - SUCCEED, - FAIL, - SWALLOW, - }; - - virtual FinalizeResultOutput FinalizeResult(bool status) = 0; - - private: - friend class CompletionQueue; - void *user_tag_; + // Called prior to returning from Next(), return value + // is the status of the operation (return status is the default thing + // to do) + virtual void FinalizeResult(void *tag, bool *status) = 0; }; // grpc_completion_queue wrapper class @@ -66,22 +72,6 @@ class CompletionQueue { // for destruction. bool Next(void **tag, bool *ok); - bool Pluck(void *tag); - - // Prepare a tag for the C api - // Given a tag we'd like to receive from Next, what tag should we pass - // down to the C api? - // Usage example: - // grpc_call_start_batch(..., cq.PrepareTagForC(tag)); - // Allows attaching some work to be executed before the original tag - // is returned. - // MUST be used for all events that could be surfaced through this - // wrapping API - void *PrepareTagForC(CompletionQueueTag *cq_tag, void *user_tag) { - cq_tag->user_tag_ = user_tag; - return cq_tag; - } - // Shutdown has to be called, and the CompletionQueue can only be // destructed when false is returned from Next(). void Shutdown(); @@ -89,6 +79,15 @@ class CompletionQueue { grpc_completion_queue* cq() { return cq_; } private: + template friend class ::grpc::ClientReader; + template friend class ::grpc::ClientWriter; + template friend class ::grpc::ClientReaderWriter; + template friend class ::grpc::ServerReader; + template friend class ::grpc::ServerWriter; + template friend class ::grpc::ServerReaderWriter; + + bool Pluck(CompletionQueueTag *tag); + grpc_completion_queue* cq_; // owned }; diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h index 22dc44efe40..ca32d60810d 100644 --- a/include/grpc++/stream.h +++ b/include/grpc++/stream.h @@ -123,23 +123,23 @@ class ClientReader final : public ClientStreamingInterface, CallOpBuffer buf; buf.AddSendMessage(request); buf.AddClientSendClose(); - call_.PerformOps(&buf, (void *)1); - cq_.Pluck((void *)1); + call_.PerformOps(&buf); + cq_.Pluck(&buf); } virtual bool Read(R *msg) override { CallOpBuffer buf; buf.AddRecvMessage(msg); - call_.PerformOps(&buf, (void *)2); - return cq_.Pluck((void *)2); + call_.PerformOps(&buf); + return cq_.Pluck(&buf); } virtual Status Finish() override { CallOpBuffer buf; Status status; buf.AddClientRecvStatus(&status); - call_.PerformOps(&buf, (void *)3); - GPR_ASSERT(cq_.Pluck((void *)3)); + call_.PerformOps(&buf); + GPR_ASSERT(cq_.Pluck(&buf)); return status; } @@ -162,15 +162,15 @@ class ClientWriter final : public ClientStreamingInterface, virtual bool Write(const W& msg) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_.PerformOps(&buf, (void *)2); - return cq_.Pluck((void *)2); + call_.PerformOps(&buf); + return cq_.Pluck(&buf); } virtual bool WritesDone() { CallOpBuffer buf; buf.AddClientSendClose(); - call_.PerformOps(&buf, (void *)3); - return cq_.Pluck((void *)3); + call_.PerformOps(&buf); + return cq_.Pluck(&buf); } // Read the final response and wait for the final status. @@ -179,8 +179,8 @@ class ClientWriter final : public ClientStreamingInterface, Status status; buf.AddRecvMessage(response_); buf.AddClientRecvStatus(&status); - call_.PerformOps(&buf, (void *)4); - GPR_ASSERT(cq_.Pluck((void *)4)); + call_.PerformOps(&buf); + GPR_ASSERT(cq_.Pluck(&buf)); return status; } @@ -204,30 +204,30 @@ class ClientReaderWriter final : public ClientStreamingInterface, virtual bool Read(R *msg) override { CallOpBuffer buf; buf.AddRecvMessage(msg); - call_.PerformOps(&buf, (void *)2); - return cq_.Pluck((void *)2); + call_.PerformOps(&buf); + return cq_.Pluck(&buf); } virtual bool Write(const W& msg) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_.PerformOps(&buf, (void *)3); - return cq_.Pluck((void *)3); + call_.PerformOps(&buf); + return cq_.Pluck(&buf); } virtual bool WritesDone() { CallOpBuffer buf; buf.AddClientSendClose(); - call_.PerformOps(&buf, (void *)4); - return cq_.Pluck((void *)4); + call_.PerformOps(&buf); + return cq_.Pluck(&buf); } virtual Status Finish() override { CallOpBuffer buf; Status status; buf.AddClientRecvStatus(&status); - call_.PerformOps(&buf, (void *)5); - GPR_ASSERT(cq_.Pluck((void *)5)); + call_.PerformOps(&buf); + GPR_ASSERT(cq_.Pluck(&buf)); return status; } @@ -244,8 +244,8 @@ class ServerReader final : public ReaderInterface { virtual bool Read(R* msg) override { CallOpBuffer buf; buf.AddRecvMessage(msg); - call_->PerformOps(&buf, (void *)2); - return call_->cq()->Pluck((void *)2); + call_->PerformOps(&buf); + return call_->cq()->Pluck(&buf); } private: @@ -260,8 +260,8 @@ class ServerWriter final : public WriterInterface { virtual bool Write(const W& msg) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_->PerformOps(&buf, (void *)2); - return call_->cq()->Pluck((void *)2); + call_->PerformOps(&buf); + return call_->cq()->Pluck(&buf); } private: @@ -278,15 +278,15 @@ class ServerReaderWriter final : public WriterInterface, virtual bool Read(R* msg) override { CallOpBuffer buf; buf.AddRecvMessage(msg); - call_->PerformOps(&buf, (void *)2); - return call_->cq()->Pluck((void *)2); + call_->PerformOps(&buf); + return call_->cq()->Pluck(&buf); } virtual bool Write(const W& msg) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_->PerformOps(&buf, (void *)3); - return call_->cq()->Pluck((void *)3); + call_->PerformOps(&buf); + return call_->cq()->Pluck(&buf); } private: @@ -333,19 +333,19 @@ class ClientAsyncReader final : public ClientAsyncStreamingInterface, CallOpBuffer buf; buf.AddSendMessage(request); buf.AddClientSendClose(); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void Read(R *msg, void* tag) override { CallOpBuffer buf; buf.AddRecvMessage(msg); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void Finish(Status* status, void* tag) override { CallOpBuffer buf; buf.AddClientRecvStatus(status); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } private: @@ -367,20 +367,20 @@ class ClientAsyncWriter final : public ClientAsyncStreamingInterface, virtual void Write(const W& msg, void* tag) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void WritesDone(void* tag) { CallOpBuffer buf; buf.AddClientSendClose(); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void Finish(Status* status, void* tag) override { CallOpBuffer buf; buf.AddRecvMessage(response_); buf.AddClientRecvStatus(status); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } private: @@ -402,25 +402,25 @@ class ClientAsyncReaderWriter final : public ClientAsyncStreamingInterface, virtual void Read(R *msg, void* tag) override { CallOpBuffer buf; buf.AddRecvMessage(msg); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void Write(const W& msg, void* tag) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void WritesDone(void* tag) { CallOpBuffer buf; buf.AddClientSendClose(); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } virtual void Finish(Status* status, void* tag) override { CallOpBuffer buf; buf.AddClientRecvStatus(status); - call_.PerformOps(&buf, tag); + call_.PerformOps(&buf); } private: @@ -437,7 +437,7 @@ class ServerAsyncResponseWriter final { virtual void Write(const W& msg, void* tag) override { CallOpBuffer buf; buf.AddSendMessage(msg); - call_->PerformOps(&buf, tag); + call_->PerformOps(&buf); } private: diff --git a/src/cpp/client/channel.cc b/src/cpp/client/channel.cc index 24bd7adaad4..a1539e4711d 100644 --- a/src/cpp/client/channel.cc +++ b/src/cpp/client/channel.cc @@ -87,14 +87,14 @@ Call Channel::CreateCall(const RpcMethod &method, ClientContext *context, return Call(c_call, this, cq); } -void Channel::PerformOpsOnCall(CallOpBuffer *buf, void *tag, Call *call) { +void Channel::PerformOpsOnCall(CallOpBuffer *buf, Call *call) { static const size_t MAX_OPS = 8; size_t nops = MAX_OPS; grpc_op ops[MAX_OPS]; buf->FillOps(ops, &nops); GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), ops, nops, - call->cq()->PrepareTagForC(buf, tag))); + buf)); } } // namespace grpc diff --git a/src/cpp/client/channel.h b/src/cpp/client/channel.h index 6cf222883c2..894f87698ce 100644 --- a/src/cpp/client/channel.h +++ b/src/cpp/client/channel.h @@ -59,7 +59,7 @@ class Channel final : public ChannelInterface { virtual Call CreateCall(const RpcMethod &method, ClientContext *context, CompletionQueue *cq) override; - virtual void PerformOpsOnCall(CallOpBuffer *ops, void *tag, + virtual void PerformOpsOnCall(CallOpBuffer *ops, Call *call) override; private: diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc index 9f8d9364b18..d3a9de3620e 100644 --- a/src/cpp/common/call.cc +++ b/src/cpp/common/call.cc @@ -36,8 +36,8 @@ namespace grpc { -void Call::PerformOps(CallOpBuffer* buffer, void* tag) { - channel_->PerformOpsOnCall(buffer, tag, this); +void Call::PerformOps(CallOpBuffer* buffer) { + channel_->PerformOpsOnCall(buffer, this); } } // namespace grpc diff --git a/src/cpp/common/completion_queue.cc b/src/cpp/common/completion_queue.cc index d7d616407e0..cbeda81a0bc 100644 --- a/src/cpp/common/completion_queue.cc +++ b/src/cpp/common/completion_queue.cc @@ -56,19 +56,26 @@ class EventDeleter { bool CompletionQueue::Next(void **tag, bool *ok) { std::unique_ptr ev; - while (true) { - ev.reset(grpc_completion_queue_next(cq_, gpr_inf_future)); - if (ev->type == GRPC_QUEUE_SHUTDOWN) { - return false; - } - auto cq_tag = static_cast(ev->tag); - switch (cq_tag->FinalizeResult(ev->data.op_complete == GRPC_OP_OK)) { - case CompletionQueueTag::SUCCEED: *ok = true; break; - case CompletionQueueTag::FAIL: *ok = false; break; - case CompletionQueueTag::SWALLOW: continue; - } - return true; + ev.reset(grpc_completion_queue_next(cq_, gpr_inf_future)); + if (ev->type == GRPC_QUEUE_SHUTDOWN) { + return false; } + auto cq_tag = static_cast(ev->tag); + *ok = ev->data.op_complete == GRPC_OP_OK; + *tag = cq_tag; + cq_tag->FinalizeResult(tag, ok); + return true; +} + +bool CompletionQueue::Pluck(CompletionQueueTag *tag) { + std::unique_ptr ev; + + ev.reset(grpc_completion_queue_pluck(cq_, tag, gpr_inf_future)); + bool ok = ev->data.op_complete == GRPC_OP_OK; + void *ignored = tag; + tag->FinalizeResult(&ignored, &ok); + GPR_ASSERT(ignored == tag); + return ok; } } // namespace grpc