fix use-after-free metadata corruption in C# when receiving response headers for streaming response calls (#27382)

pull/27408/head
Jan Tattermusch 3 years ago committed by GitHub
parent 30b7f09508
commit 2fc133b9be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 69
      src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
  2. 15
      src/csharp/Grpc.Core/Internal/AsyncCall.cs
  3. 3
      src/csharp/Grpc.Core/Internal/AsyncCallBase.cs

@ -410,6 +410,22 @@ namespace Grpc.Core.Internal.Tests
// try alternative order of completions
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
}
[Test]
public void ServerStreaming_NoResponse_Success3()
{
asyncCall.StartServerStreamingCall("request1");
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext();
// try alternative order of completions
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
}
@ -421,6 +437,9 @@ namespace Grpc.Core.Internal.Tests
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext();
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
fakeCall.ReceivedMessageCallback.OnReceivedMessage(false, CreateNullResponse()); // after a failed read, we rely on C core to deliver appropriate status code.
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, CreateClientSideStatus(StatusCode.Internal));
@ -433,6 +452,9 @@ namespace Grpc.Core.Internal.Tests
asyncCall.StartServerStreamingCall("request1");
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask1 = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateResponsePayload());
Assert.IsTrue(readTask1.Result);
@ -472,12 +494,15 @@ namespace Grpc.Core.Internal.Tests
}
[Test]
public void DuplexStreaming_NoRequestNoResponse_Success()
public void DuplexStreaming_NoRequestNoResponse_Success1()
{
asyncCall.StartDuplexStreamingCall();
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var writeTask1 = requestStream.CompleteAsync();
fakeCall.SendCompletionCallback.OnSendCompletion(true);
Assert.DoesNotThrowAsync(async () => await writeTask1);
@ -489,6 +514,27 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
}
[Test]
public void DuplexStreaming_NoRequestNoResponse_Success2()
{
asyncCall.StartDuplexStreamingCall();
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var writeTask1 = requestStream.CompleteAsync();
fakeCall.SendCompletionCallback.OnSendCompletion(true);
Assert.DoesNotThrowAsync(async () => await writeTask1);
var readTask = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask);
}
[Test]
public void DuplexStreaming_WriteAfterReceivingStatusThrowsRpcException()
{
@ -496,6 +542,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
@ -514,6 +563,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask = responseStream.MoveNext();
fakeCall.ReceivedMessageCallback.OnReceivedMessage(true, CreateNullResponse());
fakeCall.ReceivedStatusOnClientCallback.OnReceivedStatusOnClient(true, new ClientSideStatus(Status.DefaultSuccess, new Metadata()));
@ -530,6 +582,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var writeTask = requestStream.WriteAsync("request1");
fakeCall.SendCompletionCallback.OnSendCompletion(false);
@ -553,6 +608,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var writeTask = requestStream.WriteAsync("request1");
var readTask = responseStream.MoveNext();
@ -573,6 +631,9 @@ namespace Grpc.Core.Internal.Tests
var requestStream = new ClientRequestStream<string, string>(asyncCall);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
asyncCall.Cancel();
Assert.IsTrue(fakeCall.IsCancelled);
@ -592,6 +653,9 @@ namespace Grpc.Core.Internal.Tests
asyncCall.StartDuplexStreamingCall();
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
asyncCall.Cancel();
Assert.IsTrue(fakeCall.IsCancelled);
@ -613,6 +677,9 @@ namespace Grpc.Core.Internal.Tests
asyncCall.StartDuplexStreamingCall();
var responseStream = new ClientResponseStream<string, string>(asyncCall);
fakeCall.ReceivedResponseHeadersCallback.OnReceivedResponseHeaders(true, new Metadata());
Assert.AreEqual(0, asyncCall.ResponseHeadersAsync.Result.Count);
var readTask1 = responseStream.MoveNext(); // initiate the read before cancel request
asyncCall.Cancel();
Assert.IsTrue(fakeCall.IsCancelled);

@ -236,6 +236,7 @@ namespace Grpc.Core.Internal
Initialize(details.Channel.CompletionQueue);
halfcloseRequested = true;
receiveResponseHeadersPending = true;
using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope())
{
@ -272,6 +273,7 @@ namespace Grpc.Core.Internal
{
GrpcPreconditions.CheckState(!started);
started = true;
receiveResponseHeadersPending = true;
Initialize(details.Channel.CompletionQueue);
@ -531,6 +533,19 @@ namespace Grpc.Core.Internal
private void HandleReceivedResponseHeaders(bool success, Metadata responseHeaders)
{
// TODO(jtattermusch): handle success==false
bool releasedResources;
lock (myLock)
{
receiveResponseHeadersPending = false;
releasedResources = ReleaseResourcesIfPossible();
}
if (releasedResources)
{
OnAfterReleaseResourcesUnlocked();
}
responseHeadersTcs.SetResult(responseHeaders);
}

@ -62,6 +62,7 @@ namespace Grpc.Core.Internal
protected bool initialMetadataSent;
protected long streamingWritesCounter; // Number of streaming send operations started so far.
protected bool receiveResponseHeadersPending; // True if this is a call with streaming response and the recv_initial_metadata_on_client operation hasn't finished yet.
public AsyncCallBase(Action<TWrite, SerializationContext> serializer, Func<DeserializationContext, TRead> deserializer)
{
@ -171,7 +172,7 @@ namespace Grpc.Core.Internal
if (!disposed && call != null)
{
bool noMoreSendCompletions = streamingWriteTcs == null && (halfcloseRequested || cancelRequested || finished);
if (noMoreSendCompletions && readingDone && finished)
if (noMoreSendCompletions && readingDone && finished && !receiveResponseHeadersPending)
{
ReleaseResources();
return true;

Loading…
Cancel
Save