From e5c446004f2c1d3e2f2e5f0963f99332c6c94bc4 Mon Sep 17 00:00:00 2001 From: Jan Tattermusch Date: Fri, 1 May 2015 11:12:34 -0700 Subject: [PATCH] Added basic support for cancellation --- .../Grpc.Core.Tests/ClientServerTest.cs | 219 +++++++++++++----- src/csharp/Grpc.Core/Calls.cs | 17 +- .../Grpc.Core/Internal/AsyncCallServer.cs | 6 + .../Grpc.Core/Internal/ServerCallHandler.cs | 83 +++++-- .../Grpc.IntegrationTesting/InteropClient.cs | 65 ++++++ .../InteropClientServerTest.cs | 12 +- 6 files changed, 322 insertions(+), 80 deletions(-) diff --git a/src/csharp/Grpc.Core.Tests/ClientServerTest.cs b/src/csharp/Grpc.Core.Tests/ClientServerTest.cs index 9e510e82a61..c91efe53b4c 100644 --- a/src/csharp/Grpc.Core.Tests/ClientServerTest.cs +++ b/src/csharp/Grpc.Core.Tests/ClientServerTest.cs @@ -47,23 +47,59 @@ namespace Grpc.Core.Tests const string Host = "localhost"; const string ServiceName = "/tests.Test"; - static readonly Method UnaryEchoStringMethod = new Method( + static readonly Method EchoMethod = new Method( MethodType.Unary, - "/tests.Test/UnaryEchoString", + "/tests.Test/Echo", + Marshallers.StringMarshaller, + Marshallers.StringMarshaller); + + static readonly Method ConcatAndEchoMethod = new Method( + MethodType.ClientStreaming, + "/tests.Test/ConcatAndEcho", + Marshallers.StringMarshaller, + Marshallers.StringMarshaller); + + static readonly Method NonexistentMethod = new Method( + MethodType.Unary, + "/tests.Test/NonexistentMethod", Marshallers.StringMarshaller, Marshallers.StringMarshaller); static readonly ServerServiceDefinition ServiceDefinition = ServerServiceDefinition.CreateBuilder(ServiceName) - .AddMethod(UnaryEchoStringMethod, HandleUnaryEchoString).Build(); + .AddMethod(EchoMethod, EchoHandler) + .AddMethod(ConcatAndEchoMethod, ConcatAndEchoHandler) + .Build(); + + Server server; + Channel channel; [TestFixtureSetUp] + public void InitClass() + { + GrpcEnvironment.Initialize(); + } + + [SetUp] public void Init() { GrpcEnvironment.Initialize(); + + server = new Server(); + server.AddServiceDefinition(ServiceDefinition); + int port = server.AddListeningPort(Host + ":0"); + server.Start(); + channel = new Channel(Host + ":" + port); } - [TestFixtureTearDown] + [TearDown] public void Cleanup() + { + channel.Dispose(); + server.ShutdownAsync().Wait(); + } + + [TestFixtureTearDown] + public void CleanupClass() { GrpcEnvironment.Shutdown(); } @@ -71,93 +107,160 @@ namespace Grpc.Core.Tests [Test] public void UnaryCall() { - var server = new Server(); - server.AddServiceDefinition(ServiceDefinition); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); - - using (Channel channel = new Channel(Host + ":" + port)) - { - var call = new Call(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); - Assert.AreEqual("ABC", Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken))); - } - - server.ShutdownAsync().Wait(); + var call = new Call(ServiceName, EchoMethod, channel, Metadata.Empty); + Assert.AreEqual("ABC", Calls.BlockingUnaryCall(call, "ABC", CancellationToken.None)); } [Test] - public void CallOnDisposedChannel() + public void UnaryCall_ServerHandlerThrows() { - var server = new Server(); - server.AddServiceDefinition(ServiceDefinition); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); - - Channel channel = new Channel(Host + ":" + port); - channel.Dispose(); - - var call = new Call(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); + var call = new Call(ServiceName, EchoMethod, channel, Metadata.Empty); try { - Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); - Assert.Fail(); + Calls.BlockingUnaryCall(call, "THROW", CancellationToken.None); + Assert.Fail(); } - catch (ObjectDisposedException e) + catch (RpcException e) { + Assert.AreEqual(StatusCode.Unknown, e.Status.StatusCode); } + } - server.ShutdownAsync().Wait(); + [Test] + public void AsyncUnaryCall() + { + var call = new Call(ServiceName, EchoMethod, channel, Metadata.Empty); + var result = Calls.AsyncUnaryCall(call, "ABC", CancellationToken.None).Result; + Assert.AreEqual("ABC", result); } [Test] - public void UnaryCallPerformance() + public void AsyncUnaryCall_ServerHandlerThrows() { - var server = new Server(); - server.AddServiceDefinition(ServiceDefinition); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); + Task.Run(async () => + { + var call = new Call(ServiceName, EchoMethod, channel, Metadata.Empty); + try + { + await Calls.AsyncUnaryCall(call, "THROW", CancellationToken.None); + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Unknown, e.Status.StatusCode); + } + }).Wait(); + } - using (Channel channel = new Channel(Host + ":" + port)) + [Test] + public void ClientStreamingCall() + { + Task.Run(async () => { - var call = new Call(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); - BenchmarkUtil.RunBenchmark(100, 100, - () => { Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); }); - } + var call = new Call(ServiceName, ConcatAndEchoMethod, channel, Metadata.Empty); + var callResult = Calls.AsyncClientStreamingCall(call, CancellationToken.None); - server.ShutdownAsync().Wait(); + await callResult.RequestStream.WriteAll(new string[] { "A", "B", "C" }); + Assert.AreEqual("ABC", await callResult.Result); + }).Wait(); } [Test] - public void UnknownMethodHandler() + public void ClientStreamingCall_ServerHandlerThrows() { - var server = new Server(); - server.AddServiceDefinition(ServerServiceDefinition.CreateBuilder(ServiceName).Build()); - int port = server.AddListeningPort(Host + ":0"); - server.Start(); + Task.Run(async () => + { + var call = new Call(ServiceName, ConcatAndEchoMethod, channel, Metadata.Empty); + var callResult = Calls.AsyncClientStreamingCall(call, CancellationToken.None); + // TODO(jtattermusch): if we send "A", "THROW", "C", server hangs. + await callResult.RequestStream.WriteAll(new string[] { "A", "B", "THROW" }); + + try + { + await callResult.Result; + } + catch(RpcException e) + { + Assert.AreEqual(StatusCode.Unknown, e.Status.StatusCode); + } + }).Wait(); + } - using (Channel channel = new Channel(Host + ":" + port)) + [Test] + public void ClientStreamingCall_CancelAfterBegin() + { + Task.Run(async () => { - var call = new Call(ServiceName, UnaryEchoStringMethod, channel, Metadata.Empty); + var call = new Call(ServiceName, ConcatAndEchoMethod, channel, Metadata.Empty); + + var cts = new CancellationTokenSource(); + var callResult = Calls.AsyncClientStreamingCall(call, cts.Token); + cts.Cancel(); + try { - Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); - Assert.Fail(); + await callResult.Result; } - catch (RpcException e) + catch(RpcException e) { - Assert.AreEqual(StatusCode.Unimplemented, e.Status.StatusCode); + Assert.AreEqual(StatusCode.Cancelled, e.Status.StatusCode); } + }).Wait(); + } + + [Test] + public void UnaryCall_DisposedChannel() + { + channel.Dispose(); + + var call = new Call(ServiceName, EchoMethod, channel, Metadata.Empty); + Assert.Throws(typeof(ObjectDisposedException), () => Calls.BlockingUnaryCall(call, "ABC", CancellationToken.None)); + } + + [Test] + public void UnaryCallPerformance() + { + var call = new Call(ServiceName, EchoMethod, channel, Metadata.Empty); + BenchmarkUtil.RunBenchmark(100, 100, + () => { Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); }); + } + + [Test] + public void UnknownMethodHandler() + { + var call = new Call(ServiceName, NonexistentMethod, channel, Metadata.Empty); + try + { + Calls.BlockingUnaryCall(call, "ABC", default(CancellationToken)); + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Unimplemented, e.Status.StatusCode); } + } - server.ShutdownAsync().Wait(); + private static async Task EchoHandler(string request) + { + if (request == "THROW") + { + throw new Exception("This was thrown on purpose by a test"); + } + return request; } - /// - /// Handler for unaryEchoString method. - /// - private static Task HandleUnaryEchoString(string request) + private static async Task ConcatAndEchoHandler(IAsyncStreamReader requestStream) { - return Task.FromResult(request); + string result = ""; + await requestStream.ForEach(async (request) => + { + if (request == "THROW") + { + throw new Exception("This was thrown on purpose by a test"); + } + result += request; + }); + return result; } } } diff --git a/src/csharp/Grpc.Core/Calls.cs b/src/csharp/Grpc.Core/Calls.cs index 9365ccd9fb6..c2397290fdb 100644 --- a/src/csharp/Grpc.Core/Calls.cs +++ b/src/csharp/Grpc.Core/Calls.cs @@ -46,6 +46,8 @@ namespace Grpc.Core public static TResponse BlockingUnaryCall(Call call, TRequest req, CancellationToken token) { var asyncCall = new AsyncCall(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); + // TODO(jtattermusch): this gives a race that cancellation can be requested before the call even starts. + RegisterCancellationCallback(asyncCall, token); return asyncCall.UnaryCall(call.Channel, call.Name, req, call.Headers); } @@ -53,7 +55,9 @@ namespace Grpc.Core { var asyncCall = new AsyncCall(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); - return await asyncCall.UnaryCallAsync(req, call.Headers); + var asyncResult = asyncCall.UnaryCallAsync(req, call.Headers); + RegisterCancellationCallback(asyncCall, token); + return await asyncResult; } public static AsyncServerStreamingCall AsyncServerStreamingCall(Call call, TRequest req, CancellationToken token) @@ -61,6 +65,7 @@ namespace Grpc.Core var asyncCall = new AsyncCall(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); asyncCall.StartServerStreamingCall(req, call.Headers); + RegisterCancellationCallback(asyncCall, token); var responseStream = new ClientResponseStream(asyncCall); return new AsyncServerStreamingCall(responseStream); } @@ -70,6 +75,7 @@ namespace Grpc.Core var asyncCall = new AsyncCall(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); var resultTask = asyncCall.ClientStreamingCallAsync(call.Headers); + RegisterCancellationCallback(asyncCall, token); var requestStream = new ClientRequestStream(asyncCall); return new AsyncClientStreamingCall(requestStream, resultTask); } @@ -79,11 +85,20 @@ namespace Grpc.Core var asyncCall = new AsyncCall(call.RequestMarshaller.Serializer, call.ResponseMarshaller.Deserializer); asyncCall.Initialize(call.Channel, GetCompletionQueue(), call.Name); asyncCall.StartDuplexStreamingCall(call.Headers); + RegisterCancellationCallback(asyncCall, token); var requestStream = new ClientRequestStream(asyncCall); var responseStream = new ClientResponseStream(asyncCall); return new AsyncDuplexStreamingCall(requestStream, responseStream); } + private static void RegisterCancellationCallback(AsyncCall asyncCall, CancellationToken token) + { + if (token.CanBeCanceled) + { + token.Register( () => asyncCall.Cancel() ); + } + } + /// /// Gets shared completion queue used for async calls. /// diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs index 25dc15bbc0b..449009336f9 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs @@ -121,6 +121,12 @@ namespace Grpc.Core.Internal { finished = true; + if (readCompletionDelegate == null) + { + // allow disposal of native call + readingDone = true; + } + ReleaseResourcesIfPossible(); } // TODO: handle error ... diff --git a/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs b/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs index 2ae924b4a88..0416eada348 100644 --- a/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs +++ b/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs @@ -66,13 +66,21 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream(asyncCall); var responseStream = new ServerResponseStream(asyncCall); - var request = await requestStream.ReadNext(); - // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. - Preconditions.CheckArgument(await requestStream.ReadNext() == null); - - var result = await handler(request); - await responseStream.Write(result); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + var request = await requestStream.ReadNext(); + // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. + Preconditions.CheckArgument(await requestStream.ReadNext() == null); + var result = await handler(request); + await responseStream.Write(result); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -99,12 +107,21 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream(asyncCall); var responseStream = new ServerResponseStream(asyncCall); - var request = await requestStream.ReadNext(); - // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. - Preconditions.CheckArgument(await requestStream.ReadNext() == null); - - await handler(request, responseStream); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + var request = await requestStream.ReadNext(); + // TODO(jtattermusch): we need to read the full stream so that native callhandle gets deallocated. + Preconditions.CheckArgument(await requestStream.ReadNext() == null); + + await handler(request, responseStream); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -131,9 +148,18 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream(asyncCall); var responseStream = new ServerResponseStream(asyncCall); - var result = await handler(requestStream); - await responseStream.Write(result); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + var result = await handler(requestStream); + await responseStream.Write(result); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -160,8 +186,17 @@ namespace Grpc.Core.Internal var requestStream = new ServerRequestStream(asyncCall); var responseStream = new ServerResponseStream(asyncCall); - await handler(requestStream, responseStream); - await responseStream.WriteStatus(Status.DefaultSuccess); + Status status = Status.DefaultSuccess; + try + { + await handler(requestStream, responseStream); + } + catch (Exception e) + { + Console.WriteLine("Exception occured in handler: " + e); + status = HandlerUtils.StatusFromException(e); + } + await responseStream.WriteStatus(status); await finishedTask; } } @@ -173,15 +208,25 @@ namespace Grpc.Core.Internal // We don't care about the payload type here. var asyncCall = new AsyncCallServer( (payload) => payload, (payload) => payload); - + asyncCall.Initialize(call); var finishedTask = asyncCall.ServerSideCallAsync(); var requestStream = new ServerRequestStream(asyncCall); var responseStream = new ServerResponseStream(asyncCall); + await responseStream.WriteStatus(new Status(StatusCode.Unimplemented, "No such method.")); // TODO(jtattermusch): if we don't read what client has sent, the server call never gets disposed. await requestStream.ToList(); await finishedTask; } } + + internal static class HandlerUtils + { + public static Status StatusFromException(Exception e) + { + // TODO(jtattermusch): what is the right status code here? + return new Status(StatusCode.Unknown, "Exception was thrown by handler."); + } + } } diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs index 573ab304522..440702d06f4 100644 --- a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs +++ b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs @@ -34,6 +34,7 @@ using System; using System.Collections.Generic; using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; using Google.ProtocolBuffers; @@ -166,6 +167,12 @@ namespace Grpc.IntegrationTesting case "compute_engine_creds": RunComputeEngineCreds(client); break; + case "cancel_after_begin": + RunCancelAfterBegin(client); + break; + case "cancel_after_first_response": + RunCancelAfterFirstResponse(client); + break; case "benchmark_empty_unary": RunBenchmarkEmptyUnary(client); break; @@ -351,6 +358,64 @@ namespace Grpc.IntegrationTesting Console.WriteLine("Passed!"); } + public static void RunCancelAfterBegin(TestServiceGrpc.ITestServiceClient client) + { + Task.Run(async () => + { + Console.WriteLine("running cancel_after_begin"); + + var cts = new CancellationTokenSource(); + var call = client.StreamingInputCall(cts.Token); + cts.Cancel(); + + try + { + var response = await call.Result; + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Cancelled, e.Status.StatusCode); + } + Console.WriteLine("Passed!"); + }).Wait(); + } + + public static void RunCancelAfterFirstResponse(TestServiceGrpc.ITestServiceClient client) + { + Task.Run(async () => + { + Console.WriteLine("running cancel_after_first_response"); + + var cts = new CancellationTokenSource(); + var call = client.FullDuplexCall(cts.Token); + + StreamingOutputCallResponse response; + + await call.RequestStream.Write(StreamingOutputCallRequest.CreateBuilder() + .SetResponseType(PayloadType.COMPRESSABLE) + .AddResponseParameters(ResponseParameters.CreateBuilder().SetSize(31415)) + .SetPayload(CreateZerosPayload(27182)).Build()); + + response = await call.ResponseStream.ReadNext(); + Assert.AreEqual(PayloadType.COMPRESSABLE, response.Payload.Type); + Assert.AreEqual(31415, response.Payload.Body.Length); + + cts.Cancel(); + + try + { + response = await call.ResponseStream.ReadNext(); + Assert.Fail(); + } + catch (RpcException e) + { + Assert.AreEqual(StatusCode.Cancelled, e.Status.StatusCode); + } + Console.WriteLine("Passed!"); + }).Wait(); + } + // This is not an official interop test, but it's useful. public static void RunBenchmarkEmptyUnary(TestServiceGrpc.ITestServiceClient client) { diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs b/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs index e929b76b5ed..45380227c25 100644 --- a/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs +++ b/src/csharp/Grpc.IntegrationTesting/InteropClientServerTest.cs @@ -114,8 +114,16 @@ namespace Grpc.IntegrationTesting InteropClient.RunEmptyStream(client); } - // TODO: add cancel_after_begin + [Test] + public void CancelAfterBegin() + { + InteropClient.RunCancelAfterBegin(client); + } - // TODO: add cancel_after_first_response + [Test] + public void CancelAfterFirstResponse() + { + InteropClient.RunCancelAfterFirstResponse(client); + } } }