diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs index 141af7760c8..1fa895ba711 100644 --- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs +++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs @@ -66,7 +66,7 @@ namespace Grpc.Core.Internal.Tests public void AsyncUnary_CompletionSuccess() { var resultTask = asyncCall.UnaryCallAsync("abc"); - fakeCall.UnaryResponseClientHandler(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()), new byte[] { 1, 2, 3 }); + fakeCall.UnaryResponseClientHandler(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()), new byte[] { 1, 2, 3 }, new Metadata()); Assert.IsTrue(resultTask.IsCompleted); Assert.IsTrue(fakeCall.IsDisposed); Assert.AreEqual(Status.DefaultSuccess, asyncCall.GetStatus()); @@ -76,7 +76,7 @@ namespace Grpc.Core.Internal.Tests public void AsyncUnary_CompletionFailure() { var resultTask = asyncCall.UnaryCallAsync("abc"); - fakeCall.UnaryResponseClientHandler(false, new ClientSideStatus(), null); + fakeCall.UnaryResponseClientHandler(false, new ClientSideStatus(new Status(StatusCode.Internal, ""), null), new byte[] { 1, 2, 3 }, new Metadata()); Assert.IsTrue(resultTask.IsCompleted); Assert.IsTrue(fakeCall.IsDisposed); diff --git a/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs b/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs index 706006702e5..8ad41af1b81 100644 --- a/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs +++ b/src/csharp/Grpc.Core.Tests/ResponseHeadersTest.cs @@ -73,6 +73,25 @@ namespace Grpc.Core.Tests server.ShutdownAsync().Wait(); } + [Test] + public async Task ResponseHeadersAsync_UnaryCall() + { + helper.UnaryHandler = new UnaryServerMethod(async (request, context) => + { + await context.WriteResponseHeadersAsync(headers); + return "PASS"; + }); + + var call = Calls.AsyncUnaryCall(helper.CreateUnaryCall(), ""); + var responseHeaders = await call.ResponseHeadersAsync; + + Assert.AreEqual(headers.Count, responseHeaders.Count); + Assert.AreEqual("ascii-header", responseHeaders[0].Key); + Assert.AreEqual("abcdefg", responseHeaders[0].Value); + + Assert.AreEqual("PASS", await call.ResponseAsync); + } + [Test] public void WriteResponseHeaders_NullNotAllowed() { diff --git a/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs b/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs index fb9b562c77b..dbaa3085c54 100644 --- a/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs +++ b/src/csharp/Grpc.Core/AsyncClientStreamingCall.cs @@ -44,14 +44,16 @@ namespace Grpc.Core { readonly IClientStreamWriter requestStream; readonly Task responseAsync; + readonly Task responseHeadersAsync; readonly Func getStatusFunc; readonly Func getTrailersFunc; readonly Action disposeAction; - public AsyncClientStreamingCall(IClientStreamWriter requestStream, Task responseAsync, Func getStatusFunc, Func getTrailersFunc, Action disposeAction) + public AsyncClientStreamingCall(IClientStreamWriter requestStream, Task responseAsync, Task responseHeadersAsync, Func getStatusFunc, Func getTrailersFunc, Action disposeAction) { this.requestStream = requestStream; this.responseAsync = responseAsync; + this.responseHeadersAsync = responseHeadersAsync; this.getStatusFunc = getStatusFunc; this.getTrailersFunc = getTrailersFunc; this.disposeAction = disposeAction; @@ -68,6 +70,17 @@ namespace Grpc.Core } } + /// + /// Asynchronous access to response headers. + /// + public Task ResponseHeadersAsync + { + get + { + return this.responseHeadersAsync; + } + } + /// /// Async stream to send streaming requests. /// diff --git a/src/csharp/Grpc.Core/AsyncUnaryCall.cs b/src/csharp/Grpc.Core/AsyncUnaryCall.cs index 224e3439160..154a17a33ef 100644 --- a/src/csharp/Grpc.Core/AsyncUnaryCall.cs +++ b/src/csharp/Grpc.Core/AsyncUnaryCall.cs @@ -43,13 +43,15 @@ namespace Grpc.Core public sealed class AsyncUnaryCall : IDisposable { readonly Task responseAsync; + readonly Task responseHeadersAsync; readonly Func getStatusFunc; readonly Func getTrailersFunc; readonly Action disposeAction; - public AsyncUnaryCall(Task responseAsync, Func getStatusFunc, Func getTrailersFunc, Action disposeAction) + public AsyncUnaryCall(Task responseAsync, Task responseHeadersAsync, Func getStatusFunc, Func getTrailersFunc, Action disposeAction) { this.responseAsync = responseAsync; + this.responseHeadersAsync = responseHeadersAsync; this.getStatusFunc = getStatusFunc; this.getTrailersFunc = getTrailersFunc; this.disposeAction = disposeAction; @@ -66,6 +68,17 @@ namespace Grpc.Core } } + /// + /// Asynchronous access to response headers. + /// + public Task ResponseHeadersAsync + { + get + { + return this.responseHeadersAsync; + } + } + /// /// Allows awaiting this object directly. /// diff --git a/src/csharp/Grpc.Core/Calls.cs b/src/csharp/Grpc.Core/Calls.cs index 7067456638a..ada3616aa4a 100644 --- a/src/csharp/Grpc.Core/Calls.cs +++ b/src/csharp/Grpc.Core/Calls.cs @@ -74,7 +74,7 @@ namespace Grpc.Core { var asyncCall = new AsyncCall(call); var asyncResult = asyncCall.UnaryCallAsync(req); - return new AsyncUnaryCall(asyncResult, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); + return new AsyncUnaryCall(asyncResult, asyncCall.ResponseHeadersAsync, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); } /// @@ -110,7 +110,7 @@ namespace Grpc.Core var asyncCall = new AsyncCall(call); var resultTask = asyncCall.ClientStreamingCallAsync(); var requestStream = new ClientRequestStream(asyncCall); - return new AsyncClientStreamingCall(requestStream, resultTask, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); + return new AsyncClientStreamingCall(requestStream, resultTask, asyncCall.ResponseHeadersAsync, asyncCall.GetStatus, asyncCall.GetTrailers, asyncCall.Cancel); } /// diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs index 30d60077f01..132b4264243 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs @@ -56,6 +56,9 @@ namespace Grpc.Core.Internal // Completion of a pending unary response if not null. TaskCompletionSource unaryResponseTcs; + // Response headers set here once received. + TaskCompletionSource responseHeadersTcs = new TaskCompletionSource(); + // Set after status is received. Used for both unary and streaming response calls. ClientSideStatus? finishedStatus; @@ -110,7 +113,7 @@ namespace Grpc.Core.Internal bool success = (ev.success != 0); try { - HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage()); + HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata()); } catch (Exception e) { @@ -257,6 +260,17 @@ namespace Grpc.Core.Internal } } + /// + /// Get the task that completes once response headers are received. + /// + public Task ResponseHeadersAsync + { + get + { + return responseHeadersTcs.Task; + } + } + /// /// Gets the resulting status if the call has already finished. /// Throws InvalidOperationException otherwise. @@ -371,7 +385,7 @@ namespace Grpc.Core.Internal /// /// Handler for unary response completion. /// - private void HandleUnaryResponse(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage) + private void HandleUnaryResponse(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders) { lock (myLock) { @@ -383,18 +397,13 @@ namespace Grpc.Core.Internal ReleaseResourcesIfPossible(); } - if (!success) - { - var internalError = new Status(StatusCode.Internal, "Internal error occured."); - finishedStatus = new ClientSideStatus(internalError, null); - unaryResponseTcs.SetException(new RpcException(internalError)); - return; - } + responseHeadersTcs.SetResult(responseHeaders); var status = receivedStatus.Status; - if (status.StatusCode != StatusCode.OK) + if (!success || status.StatusCode != StatusCode.OK) { + unaryResponseTcs.SetException(new RpcException(status)); return; } diff --git a/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs b/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs index e1466da65b3..ed6747ea93a 100644 --- a/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs +++ b/src/csharp/Grpc.Core/Internal/CallSafeHandle.cs @@ -112,7 +112,7 @@ namespace Grpc.Core.Internal public void StartUnary(UnaryResponseClientHandler callback, byte[] payload, MetadataArraySafeHandle metadataArray, WriteFlags writeFlags) { var ctx = BatchContextSafeHandle.Create(); - completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage())); + completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage(), context.GetReceivedInitialMetadata())); grpcsharp_call_start_unary(this, ctx, payload, new UIntPtr((ulong)payload.Length), metadataArray, writeFlags) .CheckOk(); } @@ -126,7 +126,7 @@ namespace Grpc.Core.Internal public void StartClientStreaming(UnaryResponseClientHandler callback, MetadataArraySafeHandle metadataArray) { var ctx = BatchContextSafeHandle.Create(); - completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage())); + completionRegistry.RegisterBatchCompletion(ctx, (success, context) => callback(success, context.GetReceivedStatusOnClient(), context.GetReceivedMessage(), context.GetReceivedInitialMetadata())); grpcsharp_call_start_client_streaming(this, ctx, metadataArray).CheckOk(); } diff --git a/src/csharp/Grpc.Core/Internal/INativeCall.cs b/src/csharp/Grpc.Core/Internal/INativeCall.cs index 42028e458cf..ef2e230ff8d 100644 --- a/src/csharp/Grpc.Core/Internal/INativeCall.cs +++ b/src/csharp/Grpc.Core/Internal/INativeCall.cs @@ -33,8 +33,9 @@ using System; namespace Grpc.Core.Internal { - internal delegate void UnaryResponseClientHandler(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage); + internal delegate void UnaryResponseClientHandler(bool success, ClientSideStatus receivedStatus, byte[] receivedMessage, Metadata responseHeaders); + // Received status for streaming response calls. internal delegate void ReceivedStatusOnClientHandler(bool success, ClientSideStatus receivedStatus); internal delegate void ReceivedMessageHandler(bool success, byte[] receivedMessage);