fix use-after-free metadata corruption in C# when receiving response headers for streaming response calls (#27382)

pull/27408/head
Jan Tattermusch 3 years ago committed by GitHub
parent 30b7f09508
commit 2fc133b9be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 69
      src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
  2. 15
      src/csharp/Grpc.Core/Internal/AsyncCall.cs
  3. 3
      src/csharp/Grpc.Core/Internal/AsyncCallBase.cs

@ -410,6 +410,22 @@ namespace Grpc.Core.Internal.Tests
// try alternative order of completions // try alternative order of completions
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata())); fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse()); fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
}
[Test]
public void ServerStreaming_NoResponse_Success3()
{
asyncCall.StartServerStreamingCall("request1");
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext();
// try alternative order of completions
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask); AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
} }
@ -421,6 +437,9 @@ namespace Grpc.Core.Internal.Tests
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext(); var readTask = responseStream.MoveNext();
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
fakeCall.ReceivedMessageCallback.OnReceivedMessage(false, CreateNullResponse()); // after a failed read, we rely on C core to deliver appropriate status code. fakeCall.ReceivedMessageCallback.OnReceivedMessage(false, CreateNullResponse()); // after a failed read, we rely on C core to deliver appropriate status code.
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, CreateClientSideStatus(StatusCode.Internal)); fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, CreateClientSideStatus(StatusCode.Internal));
@ -433,6 +452,9 @@ namespace Grpc.Core.Internal.Tests
asyncCall.StartServerStreamingCall("request1"); asyncCall.StartServerStreamingCall("request1");
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask1 = responseStream.MoveNext(); var readTask1 = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateResponsePayload()); fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateResponsePayload());
Assert.IsTrue(readTask1.Result); Assert.IsTrue(readTask1.Result);
@ -472,12 +494,15 @@ namespace Grpc.Core.Internal.Tests
} }
[Test] [Test]
public void DuplexStreaming_NoRequestNoResponse_Success() public void DuplexStreaming_NoRequestNoResponse_Success1()
{ {
asyncCall.StartDuplexStreamingCall(); asyncCall.StartDuplexStreamingCall();
var requestStream = new ClientRequestStream<string, string>(asyncCall); var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var writeTask1 = requestStream.CompleteAsync(); var writeTask1 = requestStream.CompleteAsync();
fakeCall.SendCompletionCallback.OnSendCompletion(true); fakeCall.SendCompletionCallback.OnSendCompletion(true);
Assert.DoesNotThrowAsync(async () => await writeTask1); Assert.DoesNotThrowAsync(async () => await writeTask1);
@ -489,6 +514,27 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask); AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
} }
[Test]
public void DuplexStreaming_NoRequestNoResponse_Success2()
{
asyncCall.StartDuplexStreamingCall();
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var writeTask1 = requestStream.CompleteAsync();
fakeCall.SendCompletionCallback.OnSendCompletion(true);
Assert.DoesNotThrowAsync(async () => await writeTask1);
var readTask = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
}
[Test] [Test]
public void DuplexStreaming_WriteAfterReceivingStatusThrowsRpcException() public void DuplexStreaming_WriteAfterReceivingStatusThrowsRpcException()
{ {
@ -496,6 +542,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall); var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask = responseStream.MoveNext(); var readTask = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse()); fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata())); fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
@ -514,6 +563,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall); var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask = responseStream.MoveNext(); var readTask = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse()); fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata())); fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
@ -530,6 +582,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall); var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var writeTask = requestStream.WriteAsync("request1"); var writeTask = requestStream.WriteAsync("request1");
fakeCall.SendCompletionCallback.OnSendCompletion(false); fakeCall.SendCompletionCallback.OnSendCompletion(false);
@ -553,6 +608,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall); var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var writeTask = requestStream.WriteAsync("request1"); var writeTask = requestStream.WriteAsync("request1");
var readTask = responseStream.MoveNext(); var readTask = responseStream.MoveNext();
@ -573,6 +631,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall); var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
asyncCall.Cancel(); asyncCall.Cancel();
Assert.IsTrue(fakeCall.IsCancelled); Assert.IsTrue(fakeCall.IsCancelled);
@ -592,6 +653,9 @@ namespace Grpc.Core.Internal.Tests
asyncCall.StartDuplexStreamingCall(); asyncCall.StartDuplexStreamingCall();
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
asyncCall.Cancel(); asyncCall.Cancel();
Assert.IsTrue(fakeCall.IsCancelled); Assert.IsTrue(fakeCall.IsCancelled);
@ -613,6 +677,9 @@ namespace Grpc.Core.Internal.Tests
asyncCall.StartDuplexStreamingCall(); asyncCall.StartDuplexStreamingCall();
var responseStream = new ClientResponseStream<string, string>(asyncCall); var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask1 = responseStream.MoveNext(); // initiate the read before cancel request var readTask1 = responseStream.MoveNext(); // initiate the read before cancel request
asyncCall.Cancel(); asyncCall.Cancel();
Assert.IsTrue(fakeCall.IsCancelled); Assert.IsTrue(fakeCall.IsCancelled);

@ -236,6 +236,7 @@ namespace Grpc.Core.Internal
Initialize(details.Channel.CompletionQueue); Initialize(details.Channel.CompletionQueue);
halfcloseRequested = true; halfcloseRequested = true;
receiveResponseHeadersPending = true;
using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope()) using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope())
{ {
@ -272,6 +273,7 @@ namespace Grpc.Core.Internal
{ {
GrpcPreconditions.CheckState(!started); GrpcPreconditions.CheckState(!started);
started = true; started = true;
receiveResponseHeadersPending = true;
Initialize(details.Channel.CompletionQueue); Initialize(details.Channel.CompletionQueue);
@ -531,6 +533,19 @@ namespace Grpc.Core.Internal
private void HandleReceivedResponseHeaders(bool success, Metadata responseHeaders) private void HandleReceivedResponseHeaders(bool success, Metadata responseHeaders)
{ {
// TODO(jtattermusch): handle success==false // TODO(jtattermusch): handle success==false
bool releasedResources;
lock (myLock)
{
receiveResponseHeadersPending = false;
releasedResources = ReleaseResourcesIfPossible();
}
if (releasedResources)
{
OnAfterReleaseResourcesUnlocked();
}
responseHeadersTcs.SetResult(responseHeaders); responseHeadersTcs.SetResult(responseHeaders);
} }

@ -62,6 +62,7 @@ namespace Grpc.Core.Internal
protected bool initialMetadataSent; protected bool initialMetadataSent;
protected long streamingWritesCounter; // Number of streaming send operations started so far. protected long streamingWritesCounter; // Number of streaming send operations started so far.
protected bool receiveResponseHeadersPending; // True if this is a call with streaming response and the recv_initial_metadata_on_client operation hasn't finished yet.
public AsyncCallBase(Action<TWrite, SerializationContext> serializer, Func<DeserializationContext, TRead> deserializer) public AsyncCallBase(Action<TWrite, SerializationContext> serializer, Func<DeserializationContext, TRead> deserializer)
{ {
@ -171,7 +172,7 @@ namespace Grpc.Core.Internal
if (!disposed && call != null) if (!disposed && call != null)
{ {
bool noMoreSendCompletions = streamingWriteTcs == null && (halfcloseRequested || cancelRequested || finished); bool noMoreSendCompletions = streamingWriteTcs == null && (halfcloseRequested || cancelRequested || finished);
if (noMoreSendCompletions && readingDone && finished) if (noMoreSendCompletions && readingDone && finished && !receiveResponseHeadersPending)
{ {
ReleaseResources(); ReleaseResources();
return true; return true;

Loading…
Cancel
Save