diff --git a/test/cpp/interop/interop_client.cc b/test/cpp/interop/interop_client.cc index fd01ad06667..7af39422304 100644 --- a/test/cpp/interop/interop_client.cc +++ b/test/cpp/interop/interop_client.cc @@ -82,8 +82,46 @@ CompressionType GetInteropCompressionTypeFromCompressionAlgorithm( } } // namespace +InteropClient::ServiceStub::ServiceStub(std::shared_ptr channel, + bool new_stub_every_call) + : channel_(channel), new_stub_every_call_(new_stub_every_call) { + // If new_stub_every_call is false, then this is our chance to initialize + // stub_. (see Get()) + if (!new_stub_every_call) { + stub_ = TestService::NewStub(channel); + } +} + +TestService::Stub* InteropClient::ServiceStub::Get() { + if (new_stub_every_call_) { + stub_ = TestService::NewStub(channel_); + } + + return stub_.get(); +} + +void InteropClient::ServiceStub::Reset(std::shared_ptr channel) { + channel_ = channel; + + // Update stub_ as well. Note: If new_stub_every_call_ is true, we can set + // stub_ to nullptr since the next call to Get() will create a new stub + if (new_stub_every_call_) { + stub_.reset(nullptr); + } else { + stub_ = TestService::NewStub(channel); + } +} + +void InteropClient::Reset(std::shared_ptr channel) { + serviceStub_.Reset(channel); +} + InteropClient::InteropClient(std::shared_ptr channel) - : channel_(channel), stub_(TestService::NewStub(channel)) {} + : serviceStub_(channel, true) {} + +InteropClient::InteropClient(std::shared_ptr channel, + bool new_stub_every_test_case) + : serviceStub_(channel, new_stub_every_test_case) {} void InteropClient::AssertOkOrPrintErrorStatus(const Status& s) { if (s.ok()) { @@ -101,7 +139,7 @@ void InteropClient::DoEmpty() { Empty response = Empty::default_instance(); ClientContext context; - Status s = stub_->EmptyCall(&context, request, &response); + Status s = serviceStub_.Get()->EmptyCall(&context, request, &response); AssertOkOrPrintErrorStatus(s); gpr_log(GPR_INFO, "Empty rpc done."); @@ -110,7 +148,6 @@ void InteropClient::DoEmpty() { // Shared code to set large payload, make rpc and check response payload. void InteropClient::PerformLargeUnary(SimpleRequest* request, SimpleResponse* response) { - ClientContext context; InteropClientContextInspector inspector(context); // If the request doesn't already specify the response type, default to @@ -119,7 +156,7 @@ void InteropClient::PerformLargeUnary(SimpleRequest* request, grpc::string payload(kLargeRequestSize, '\0'); request->mutable_payload()->set_body(payload.c_str(), kLargeRequestSize); - Status s = stub_->UnaryCall(&context, *request, response); + Status s = serviceStub_.Get()->UnaryCall(&context, *request, response); // Compression related checks. GPR_ASSERT(request->response_compression() == @@ -188,7 +225,7 @@ void InteropClient::DoOauth2AuthToken(const grpc::string& username, ClientContext context; - Status s = stub_->UnaryCall(&context, request, &response); + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); AssertOkOrPrintErrorStatus(s); GPR_ASSERT(!response.username().empty()); @@ -212,7 +249,7 @@ void InteropClient::DoPerRpcCreds(const grpc::string& json_key) { context.set_credentials(creds); - Status s = stub_->UnaryCall(&context, request, &response); + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); AssertOkOrPrintErrorStatus(s); GPR_ASSERT(!response.username().empty()); @@ -271,7 +308,7 @@ void InteropClient::DoRequestStreaming() { StreamingInputCallResponse response; std::unique_ptr> stream( - stub_->StreamingInputCall(&context, &response)); + serviceStub_.Get()->StreamingInputCall(&context, &response)); int aggregated_payload_size = 0; for (unsigned int i = 0; i < request_stream_sizes.size(); ++i) { @@ -299,7 +336,7 @@ void InteropClient::DoResponseStreaming() { } StreamingOutputCallResponse response; std::unique_ptr> stream( - stub_->StreamingOutputCall(&context, request)); + serviceStub_.Get()->StreamingOutputCall(&context, request)); unsigned int i = 0; while (stream->Read(&response)) { @@ -314,7 +351,6 @@ void InteropClient::DoResponseStreaming() { } void InteropClient::DoResponseCompressedStreaming() { - const CompressionType compression_types[] = {NONE, GZIP, DEFLATE}; const PayloadType payload_types[] = {COMPRESSABLE, UNCOMPRESSABLE, RANDOM}; for (size_t i = 0; i < GPR_ARRAY_SIZE(payload_types); i++) { @@ -341,7 +377,7 @@ void InteropClient::DoResponseCompressedStreaming() { StreamingOutputCallResponse response; std::unique_ptr> stream( - stub_->StreamingOutputCall(&context, request)); + serviceStub_.Get()->StreamingOutputCall(&context, request)); size_t k = 0; while (stream->Read(&response)) { @@ -404,7 +440,7 @@ void InteropClient::DoResponseStreamingWithSlowConsumer() { } StreamingOutputCallResponse response; std::unique_ptr> stream( - stub_->StreamingOutputCall(&context, request)); + serviceStub_.Get()->StreamingOutputCall(&context, request)); int i = 0; while (stream->Read(&response)) { @@ -427,7 +463,7 @@ void InteropClient::DoHalfDuplex() { ClientContext context; std::unique_ptr> - stream(stub_->HalfDuplexCall(&context)); + stream(serviceStub_.Get()->HalfDuplexCall(&context)); StreamingOutputCallRequest request; ResponseParameters* response_parameter = request.add_response_parameters(); @@ -456,7 +492,7 @@ void InteropClient::DoPingPong() { ClientContext context; std::unique_ptr> - stream(stub_->FullDuplexCall(&context)); + stream(serviceStub_.Get()->FullDuplexCall(&context)); StreamingOutputCallRequest request; request.set_response_type(PayloadType::COMPRESSABLE); @@ -487,7 +523,7 @@ void InteropClient::DoCancelAfterBegin() { StreamingInputCallResponse response; std::unique_ptr> stream( - stub_->StreamingInputCall(&context, &response)); + serviceStub_.Get()->StreamingInputCall(&context, &response)); gpr_log(GPR_INFO, "Trying to cancel..."); context.TryCancel(); @@ -502,7 +538,7 @@ void InteropClient::DoCancelAfterFirstResponse() { ClientContext context; std::unique_ptr> - stream(stub_->FullDuplexCall(&context)); + stream(serviceStub_.Get()->FullDuplexCall(&context)); StreamingOutputCallRequest request; request.set_response_type(PayloadType::COMPRESSABLE); @@ -529,7 +565,7 @@ void InteropClient::DoTimeoutOnSleepingServer() { context.set_deadline(deadline); std::unique_ptr> - stream(stub_->FullDuplexCall(&context)); + stream(serviceStub_.Get()->FullDuplexCall(&context)); StreamingOutputCallRequest request; request.mutable_payload()->set_body(grpc::string(27182, '\0')); @@ -546,7 +582,7 @@ void InteropClient::DoEmptyStream() { ClientContext context; std::unique_ptr> - stream(stub_->FullDuplexCall(&context)); + stream(serviceStub_.Get()->FullDuplexCall(&context)); stream->WritesDone(); StreamingOutputCallResponse response; GPR_ASSERT(stream->Read(&response) == false); @@ -566,7 +602,7 @@ void InteropClient::DoStatusWithMessage() { grpc::string test_msg = "This is a test message"; requested_status->set_message(test_msg); - Status s = stub_->UnaryCall(&context, request, &response); + Status s = serviceStub_.Get()->UnaryCall(&context, request, &response); GPR_ASSERT(s.error_code() == grpc::StatusCode::UNKNOWN); GPR_ASSERT(s.error_message() == test_msg); diff --git a/test/cpp/interop/interop_client.h b/test/cpp/interop/interop_client.h index 1f13d5b971a..1bfb49d514a 100644 --- a/test/cpp/interop/interop_client.h +++ b/test/cpp/interop/interop_client.h @@ -47,9 +47,14 @@ namespace testing { class InteropClient { public: explicit InteropClient(std::shared_ptr channel); + explicit InteropClient( + std::shared_ptr channel, + bool new_stub_every_test_case); // If new_stub_every_test_case is true, + // a new TestService::Stub object is + // created for every test case below ~InteropClient() {} - void Reset(std::shared_ptr channel) { channel_ = channel; } + void Reset(std::shared_ptr channel); void DoEmpty(); void DoLargeUnary(); @@ -77,11 +82,26 @@ class InteropClient { void DoPerRpcCreds(const grpc::string& json_key); private: + class ServiceStub { + public: + // If new_stub_every_call = true, pointer to a new instance of + // TestServce::Stub is returned by Get() everytime it is called + ServiceStub(std::shared_ptr channel, bool new_stub_every_call); + + TestService::Stub* Get(); + + void Reset(std::shared_ptr channel); + + private: + std::unique_ptr stub_; + std::shared_ptr channel_; + bool new_stub_every_call_; // If true, a new stub is returned by every + // Get() call + }; + void PerformLargeUnary(SimpleRequest* request, SimpleResponse* response); void AssertOkOrPrintErrorStatus(const Status& s); - - std::shared_ptr channel_; - std::unique_ptr stub_; + ServiceStub serviceStub_; }; } // namespace testing diff --git a/test/cpp/interop/stress_interop_client.cc b/test/cpp/interop/stress_interop_client.cc index d1d0b0d4f31..5ade60057b4 100644 --- a/test/cpp/interop/stress_interop_client.cc +++ b/test/cpp/interop/stress_interop_client.cc @@ -92,7 +92,7 @@ StressTestInteropClient::StressTestInteropClient( // that won't work with InsecureCredentials() std::shared_ptr channel( CreateChannel(server_address, InsecureCredentials())); - interop_client_.reset(new InteropClient(channel)); + interop_client_.reset(new InteropClient(channel, false)); } void StressTestInteropClient::MainLoop() {