Merge pull request #16554 from jtattermusch/csharp_dont_leak_when_call_init_fails

C#: avoid leaking resources when starting a call fails
reviewable/pr16055/r17^2
Jan Tattermusch 6 years ago committed by GitHub
commit d90d082ca2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 75
      src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
  2. 23
      src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs
  3. 6
      src/csharp/Grpc.Core/Channel.cs
  4. 207
      src/csharp/Grpc.Core/Internal/AsyncCall.cs
  5. 2
      src/csharp/Grpc.Core/Internal/AsyncCallBase.cs

@ -106,6 +106,42 @@ namespace Grpc.Core.Internal.Tests
AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal);
}
[Test]
public void AsyncUnary_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCallAsync(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void AsyncUnary_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCallAsync("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void SyncUnary_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCall(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void SyncUnary_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCall("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void ClientStreaming_StreamingReadNotAllowed()
{
@ -327,6 +363,15 @@ namespace Grpc.Core.Internal.Tests
AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Cancelled);
}
[Test]
public void ClientStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.ClientStreamingCallAsync());
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void ServerStreaming_StreamingSendNotAllowed()
{
@ -401,6 +446,27 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask3);
}
[Test]
public void ServerStreaming_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.StartServerStreamingCall(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext();
}
[Test]
public void ServerStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartServerStreamingCall("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void DuplexStreaming_NoRequestNoResponse_Success()
{
@ -558,6 +624,15 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled);
}
[Test]
public void DuplexStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartDuplexStreamingCall());
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
ClientSideStatus CreateClientSideStatus(StatusCode statusCode)
{
return new ClientSideStatus(new Status(statusCode, ""), new Metadata());

@ -31,6 +31,7 @@ namespace Grpc.Core.Internal.Tests
/// </summary>
internal class FakeNativeCall : INativeCall
{
private bool shouldStartCallFail;
public IUnaryResponseClientCallback UnaryResponseClientCallback
{
get;
@ -102,26 +103,31 @@ namespace Grpc.Core.Internal.Tests
public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
StartCallMaybeFail();
UnaryResponseClientCallback = callback;
}
public void StartUnary(BatchContextSafeHandle ctx, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
StartCallMaybeFail();
throw new NotImplementedException();
}
public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
StartCallMaybeFail();
UnaryResponseClientCallback = callback;
}
public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback;
}
public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{
StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback;
}
@ -165,5 +171,22 @@ namespace Grpc.Core.Internal.Tests
{
IsDisposed = true;
}
/// <summary>
/// Emulate CallSafeHandle.CheckOk() failure for all future attempts
/// to start a call.
/// </summary>
public void MakeStartCallFail()
{
shouldStartCallFail = true;
}
private void StartCallMaybeFail()
{
if (shouldStartCallFail)
{
throw new InvalidOperationException("Start call has failed.");
}
}
}
}

@ -297,6 +297,12 @@ namespace Grpc.Core
activeCallCounter.Decrement();
}
// for testing only
internal long GetCallReferenceCount()
{
return activeCallCounter.Count;
}
private ChannelState GetConnectivityState(bool tryToConnect)
{
try

@ -17,6 +17,7 @@
#endregion
using System;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Logging;
using Grpc.Core.Profiling;
@ -34,6 +35,8 @@ namespace Grpc.Core.Internal
readonly CallInvocationDetails<TRequest, TResponse> details;
readonly INativeCall injectedNativeCall; // for testing
bool registeredWithChannel;
// Dispose of to de-register cancellation token registration
IDisposable cancellationTokenRegistration;
@ -77,43 +80,59 @@ namespace Grpc.Core.Internal
using (profiler.NewScope("AsyncCall.UnaryCall"))
using (CompletionQueueSafeHandle cq = CompletionQueueSafeHandle.CreateSync())
{
byte[] payload = UnsafeSerialize(msg);
bool callStartedOk = false;
try
{
unaryResponseTcs = new TaskCompletionSource<TResponse>();
unaryResponseTcs = new TaskCompletionSource<TResponse>();
lock (myLock)
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(cq);
lock (myLock)
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(cq);
halfcloseRequested = true;
readingDone = true;
}
halfcloseRequested = true;
readingDone = true;
}
byte[] payload = UnsafeSerialize(msg);
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
var ctx = details.Channel.Environment.BatchContextPool.Lease();
try
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
var ev = cq.Pluck(ctx.Handle);
bool success = (ev.success != 0);
var ctx = details.Channel.Environment.BatchContextPool.Lease();
try
{
using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch"))
call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
var ev = cq.Pluck(ctx.Handle);
bool success = (ev.success != 0);
try
{
using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch"))
{
HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata());
}
}
catch (Exception e)
{
HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata());
Logger.Error(e, "Exception occurred while invoking completion delegate.");
}
}
catch (Exception e)
finally
{
Logger.Error(e, "Exception occurred while invoking completion delegate.");
ctx.Recycle();
}
}
finally
}
finally
{
if (!callStartedOk)
{
ctx.Recycle();
lock (myLock)
{
OnFailedToStartCallLocked();
}
}
}
@ -130,22 +149,35 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
GrpcPreconditions.CheckState(!started);
started = true;
bool callStartedOk = false;
try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue);
Initialize(details.Channel.CompletionQueue);
halfcloseRequested = true;
readingDone = true;
halfcloseRequested = true;
readingDone = true;
byte[] payload = UnsafeSerialize(msg);
byte[] payload = UnsafeSerialize(msg);
unaryResponseTcs = new TaskCompletionSource<TResponse>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
}
unaryResponseTcs = new TaskCompletionSource<TResponse>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
return unaryResponseTcs.Task;
}
finally
{
call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
return unaryResponseTcs.Task;
}
}
@ -157,20 +189,32 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
GrpcPreconditions.CheckState(!started);
started = true;
bool callStartedOk = false;
try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue);
Initialize(details.Channel.CompletionQueue);
readingDone = true;
readingDone = true;
unaryResponseTcs = new TaskCompletionSource<TResponse>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags);
callStartedOk = true;
}
unaryResponseTcs = new TaskCompletionSource<TResponse>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
return unaryResponseTcs.Task;
}
finally
{
call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags);
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
return unaryResponseTcs.Task;
}
}
@ -181,21 +225,33 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
GrpcPreconditions.CheckState(!started);
started = true;
bool callStartedOk = false;
try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue);
Initialize(details.Channel.CompletionQueue);
halfcloseRequested = true;
halfcloseRequested = true;
byte[] payload = UnsafeSerialize(msg);
byte[] payload = UnsafeSerialize(msg);
streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
}
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
finally
{
call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
}
@ -207,17 +263,29 @@ namespace Grpc.Core.Internal
{
lock (myLock)
{
GrpcPreconditions.CheckState(!started);
started = true;
bool callStartedOk = false;
try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue);
Initialize(details.Channel.CompletionQueue);
streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags);
callStartedOk = true;
}
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
finally
{
call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags);
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
}
@ -327,7 +395,11 @@ namespace Grpc.Core.Internal
protected override void OnAfterReleaseResourcesLocked()
{
details.Channel.RemoveCallReference(this);
if (registeredWithChannel)
{
details.Channel.RemoveCallReference(this);
registeredWithChannel = false;
}
}
protected override void OnAfterReleaseResourcesUnlocked()
@ -394,10 +466,27 @@ namespace Grpc.Core.Internal
var call = CreateNativeCall(cq);
details.Channel.AddCallReference(this);
registeredWithChannel = true;
InitializeInternal(call);
RegisterCancellationCallback();
}
private void OnFailedToStartCallLocked()
{
ReleaseResources();
// We need to execute the hook that disposes the cancellation token
// registration, but it cannot be done from under a lock.
// To make things simple, we just schedule the unregistering
// on a threadpool.
// - Once the native call is disposed, the Cancel() calls are ignored anyway
// - We don't care about the overhead as OnFailedToStartCallLocked() only happens
// when something goes very bad when initializing a call and that should
// never happen when gRPC is used correctly.
ThreadPool.QueueUserWorkItem((state) => OnAfterReleaseResourcesUnlocked());
}
private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq)
{
if (injectedNativeCall != null)

@ -189,7 +189,7 @@ namespace Grpc.Core.Internal
/// </summary>
protected abstract Exception GetRpcExceptionClientOnly();
private void ReleaseResources()
protected void ReleaseResources()
{
if (call != null)
{

Loading…
Cancel
Save