Merge pull request #16554 from jtattermusch/csharp_dont_leak_when_call_init_fails

C#: avoid leaking resources when starting a call fails
reviewable/pr16055/r17^2
Jan Tattermusch 6 years ago committed by GitHub
commit d90d082ca2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 75
      src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
  2. 23
      src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs
  3. 6
      src/csharp/Grpc.Core/Channel.cs
  4. 93
      src/csharp/Grpc.Core/Internal/AsyncCall.cs
  5. 2
      src/csharp/Grpc.Core/Internal/AsyncCallBase.cs

@ -106,6 +106,42 @@ namespace Grpc.Core.Internal.Tests
AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal); AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal);
} }
[Test]
public void AsyncUnary_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCallAsync(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void AsyncUnary_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCallAsync("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void SyncUnary_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCall(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void SyncUnary_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCall("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test] [Test]
public void ClientStreaming_StreamingReadNotAllowed() public void ClientStreaming_StreamingReadNotAllowed()
{ {
@ -327,6 +363,15 @@ namespace Grpc.Core.Internal.Tests
AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Cancelled); AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Cancelled);
} }
[Test]
public void ClientStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.ClientStreamingCallAsync());
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test] [Test]
public void ServerStreaming_StreamingSendNotAllowed() public void ServerStreaming_StreamingSendNotAllowed()
{ {
@ -401,6 +446,27 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask3); AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask3);
} }
[Test]
public void ServerStreaming_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.StartServerStreamingCall(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext();
}
[Test]
public void ServerStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartServerStreamingCall("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test] [Test]
public void DuplexStreaming_NoRequestNoResponse_Success() public void DuplexStreaming_NoRequestNoResponse_Success()
{ {
@ -558,6 +624,15 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled); AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled);
} }
[Test]
public void DuplexStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartDuplexStreamingCall());
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
ClientSideStatus CreateClientSideStatus(StatusCode statusCode) ClientSideStatus CreateClientSideStatus(StatusCode statusCode)
{ {
return new ClientSideStatus(new Status(statusCode, ""), new Metadata()); return new ClientSideStatus(new Status(statusCode, ""), new Metadata());

@ -31,6 +31,7 @@ namespace Grpc.Core.Internal.Tests
/// </summary> /// </summary>
internal class FakeNativeCall : INativeCall internal class FakeNativeCall : INativeCall
{ {
private bool shouldStartCallFail;
public IUnaryResponseClientCallback UnaryResponseClientCallback public IUnaryResponseClientCallback UnaryResponseClientCallback
{ {
get; get;
@ -102,26 +103,31 @@ namespace Grpc.Core.Internal.Tests
public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
UnaryResponseClientCallback = callback; UnaryResponseClientCallback = callback;
} }
public void StartUnary(BatchContextSafeHandle ctx, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartUnary(BatchContextSafeHandle ctx, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
throw new NotImplementedException(); throw new NotImplementedException();
} }
public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
UnaryResponseClientCallback = callback; UnaryResponseClientCallback = callback;
} }
public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback; ReceivedStatusOnClientCallback = callback;
} }
public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback; ReceivedStatusOnClientCallback = callback;
} }
@ -165,5 +171,22 @@ namespace Grpc.Core.Internal.Tests
{ {
IsDisposed = true; IsDisposed = true;
} }
/// <summary>
/// Emulate CallSafeHandle.CheckOk() failure for all future attempts
/// to start a call.
/// </summary>
public void MakeStartCallFail()
{
shouldStartCallFail = true;
}
private void StartCallMaybeFail()
{
if (shouldStartCallFail)
{
throw new InvalidOperationException("Start call has failed.");
}
}
} }
} }

@ -297,6 +297,12 @@ namespace Grpc.Core
activeCallCounter.Decrement(); activeCallCounter.Decrement();
} }
// for testing only
internal long GetCallReferenceCount()
{
return activeCallCounter.Count;
}
private ChannelState GetConnectivityState(bool tryToConnect) private ChannelState GetConnectivityState(bool tryToConnect)
{ {
try try

@ -17,6 +17,7 @@
#endregion #endregion
using System; using System;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Grpc.Core.Logging; using Grpc.Core.Logging;
using Grpc.Core.Profiling; using Grpc.Core.Profiling;
@ -34,6 +35,8 @@ namespace Grpc.Core.Internal
readonly CallInvocationDetails<TRequest, TResponse> details; readonly CallInvocationDetails<TRequest, TResponse> details;
readonly INativeCall injectedNativeCall; // for testing readonly INativeCall injectedNativeCall; // for testing
bool registeredWithChannel;
// Dispose of to de-register cancellation token registration // Dispose of to de-register cancellation token registration
IDisposable cancellationTokenRegistration; IDisposable cancellationTokenRegistration;
@ -77,8 +80,9 @@ namespace Grpc.Core.Internal
using (profiler.NewScope("AsyncCall.UnaryCall")) using (profiler.NewScope("AsyncCall.UnaryCall"))
using (CompletionQueueSafeHandle cq = CompletionQueueSafeHandle.CreateSync()) using (CompletionQueueSafeHandle cq = CompletionQueueSafeHandle.CreateSync())
{ {
byte[] payload = UnsafeSerialize(msg); bool callStartedOk = false;
try
{
unaryResponseTcs = new TaskCompletionSource<TResponse>(); unaryResponseTcs = new TaskCompletionSource<TResponse>();
lock (myLock) lock (myLock)
@ -91,12 +95,16 @@ namespace Grpc.Core.Internal
readingDone = true; readingDone = true;
} }
byte[] payload = UnsafeSerialize(msg);
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{ {
var ctx = details.Channel.Environment.BatchContextPool.Lease(); var ctx = details.Channel.Environment.BatchContextPool.Lease();
try try
{ {
call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
var ev = cq.Pluck(ctx.Handle); var ev = cq.Pluck(ctx.Handle);
bool success = (ev.success != 0); bool success = (ev.success != 0);
try try
@ -116,6 +124,17 @@ namespace Grpc.Core.Internal
ctx.Recycle(); ctx.Recycle();
} }
} }
}
finally
{
if (!callStartedOk)
{
lock (myLock)
{
OnFailedToStartCallLocked();
}
}
}
// Once the blocking call returns, the result should be available synchronously. // Once the blocking call returns, the result should be available synchronously.
// Note that GetAwaiter().GetResult() doesn't wrap exceptions in AggregateException. // Note that GetAwaiter().GetResult() doesn't wrap exceptions in AggregateException.
@ -129,6 +148,9 @@ namespace Grpc.Core.Internal
public Task<TResponse> UnaryCallAsync(TRequest msg) public Task<TResponse> UnaryCallAsync(TRequest msg)
{ {
lock (myLock) lock (myLock)
{
bool callStartedOk = false;
try
{ {
GrpcPreconditions.CheckState(!started); GrpcPreconditions.CheckState(!started);
started = true; started = true;
@ -144,9 +166,19 @@ namespace Grpc.Core.Internal
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{ {
call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
} }
return unaryResponseTcs.Task; return unaryResponseTcs.Task;
} }
finally
{
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
}
} }
/// <summary> /// <summary>
@ -156,6 +188,9 @@ namespace Grpc.Core.Internal
public Task<TResponse> ClientStreamingCallAsync() public Task<TResponse> ClientStreamingCallAsync()
{ {
lock (myLock) lock (myLock)
{
bool callStartedOk = false;
try
{ {
GrpcPreconditions.CheckState(!started); GrpcPreconditions.CheckState(!started);
started = true; started = true;
@ -168,10 +203,19 @@ namespace Grpc.Core.Internal
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{ {
call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags); call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags);
callStartedOk = true;
} }
return unaryResponseTcs.Task; return unaryResponseTcs.Task;
} }
finally
{
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
}
} }
/// <summary> /// <summary>
@ -180,6 +224,9 @@ namespace Grpc.Core.Internal
public void StartServerStreamingCall(TRequest msg) public void StartServerStreamingCall(TRequest msg)
{ {
lock (myLock) lock (myLock)
{
bool callStartedOk = false;
try
{ {
GrpcPreconditions.CheckState(!started); GrpcPreconditions.CheckState(!started);
started = true; started = true;
@ -194,9 +241,18 @@ namespace Grpc.Core.Internal
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{ {
call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
} }
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback); call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
} }
finally
{
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
}
} }
/// <summary> /// <summary>
@ -206,6 +262,9 @@ namespace Grpc.Core.Internal
public void StartDuplexStreamingCall() public void StartDuplexStreamingCall()
{ {
lock (myLock) lock (myLock)
{
bool callStartedOk = false;
try
{ {
GrpcPreconditions.CheckState(!started); GrpcPreconditions.CheckState(!started);
started = true; started = true;
@ -216,9 +275,18 @@ namespace Grpc.Core.Internal
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{ {
call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags); call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags);
callStartedOk = true;
} }
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback); call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
} }
finally
{
if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
}
}
} }
/// <summary> /// <summary>
@ -326,8 +394,12 @@ namespace Grpc.Core.Internal
} }
protected override void OnAfterReleaseResourcesLocked() protected override void OnAfterReleaseResourcesLocked()
{
if (registeredWithChannel)
{ {
details.Channel.RemoveCallReference(this); details.Channel.RemoveCallReference(this);
registeredWithChannel = false;
}
} }
protected override void OnAfterReleaseResourcesUnlocked() protected override void OnAfterReleaseResourcesUnlocked()
@ -394,10 +466,27 @@ namespace Grpc.Core.Internal
var call = CreateNativeCall(cq); var call = CreateNativeCall(cq);
details.Channel.AddCallReference(this); details.Channel.AddCallReference(this);
registeredWithChannel = true;
InitializeInternal(call); InitializeInternal(call);
RegisterCancellationCallback(); RegisterCancellationCallback();
} }
private void OnFailedToStartCallLocked()
{
ReleaseResources();
// We need to execute the hook that disposes the cancellation token
// registration, but it cannot be done from under a lock.
// To make things simple, we just schedule the unregistering
// on a threadpool.
// - Once the native call is disposed, the Cancel() calls are ignored anyway
// - We don't care about the overhead as OnFailedToStartCallLocked() only happens
// when something goes very bad when initializing a call and that should
// never happen when gRPC is used correctly.
ThreadPool.QueueUserWorkItem((state) => OnAfterReleaseResourcesUnlocked());
}
private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq) private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq)
{ {
if (injectedNativeCall != null) if (injectedNativeCall != null)

@ -189,7 +189,7 @@ namespace Grpc.Core.Internal
/// </summary> /// </summary>
protected abstract Exception GetRpcExceptionClientOnly(); protected abstract Exception GetRpcExceptionClientOnly();
private void ReleaseResources() protected void ReleaseResources()
{ {
if (call != null) if (call != null)
{ {

Loading…
Cancel
Save