handle failures in async call initialization without leaks

pull/16554/head
Jan Tattermusch 6 years ago
parent cd74b357e1
commit b155c314f1
  1. 57
      src/csharp/Grpc.Core.Tests/Internal/AsyncCallTest.cs
  2. 22
      src/csharp/Grpc.Core.Tests/Internal/FakeNativeCall.cs
  3. 6
      src/csharp/Grpc.Core/Channel.cs
  4. 210
      src/csharp/Grpc.Core/Internal/AsyncCall.cs
  5. 2
      src/csharp/Grpc.Core/Internal/AsyncCallBase.cs

@ -106,6 +106,24 @@ namespace Grpc.Core.Internal.Tests
AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal); AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Internal);
} }
[Test]
public void AsyncUnary_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.UnaryCallAsync(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test]
public void AsyncUnary_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.UnaryCallAsync("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test] [Test]
public void ClientStreaming_StreamingReadNotAllowed() public void ClientStreaming_StreamingReadNotAllowed()
{ {
@ -327,6 +345,15 @@ namespace Grpc.Core.Internal.Tests
AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Cancelled); AssertUnaryResponseError(asyncCall, fakeCall, resultTask, StatusCode.Cancelled);
} }
[Test]
public void ClientStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.ClientStreamingCallAsync());
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test] [Test]
public void ServerStreaming_StreamingSendNotAllowed() public void ServerStreaming_StreamingSendNotAllowed()
{ {
@ -401,6 +428,27 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask3); AssertStreamingResponseSuccess(asyncCall, fakeCall, readTask3);
} }
[Test]
public void ServerStreaming_RequestSerializationExceptionDoesntLeakResources()
{
string nullRequest = null; // will throw when serializing
Assert.Throws(typeof(ArgumentNullException), () => asyncCall.StartServerStreamingCall(nullRequest));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
var responseStream = new ClientResponseStream<string, string>(asyncCall);
var readTask = responseStream.MoveNext();
}
[Test]
public void ServerStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartServerStreamingCall("request1"));
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
[Test] [Test]
public void DuplexStreaming_NoRequestNoResponse_Success() public void DuplexStreaming_NoRequestNoResponse_Success()
{ {
@ -558,6 +606,15 @@ namespace Grpc.Core.Internal.Tests
AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled); AssertStreamingResponseError(asyncCall, fakeCall, readTask2, StatusCode.Cancelled);
} }
[Test]
public void DuplexStreaming_StartCallFailureDoesntLeakResources()
{
fakeCall.MakeStartCallFail();
Assert.Throws(typeof(InvalidOperationException), () => asyncCall.StartDuplexStreamingCall());
Assert.AreEqual(0, channel.GetCallReferenceCount());
Assert.IsTrue(fakeCall.IsDisposed);
}
ClientSideStatus CreateClientSideStatus(StatusCode statusCode) ClientSideStatus CreateClientSideStatus(StatusCode statusCode)
{ {
return new ClientSideStatus(new Status(statusCode, ""), new Metadata()); return new ClientSideStatus(new Status(statusCode, ""), new Metadata());

@ -31,6 +31,7 @@ namespace Grpc.Core.Internal.Tests
/// </summary> /// </summary>
internal class FakeNativeCall : INativeCall internal class FakeNativeCall : INativeCall
{ {
private bool shouldStartCallFail;
public IUnaryResponseClientCallback UnaryResponseClientCallback public IUnaryResponseClientCallback UnaryResponseClientCallback
{ {
get; get;
@ -102,6 +103,7 @@ namespace Grpc.Core.Internal.Tests
public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartUnary(IUnaryResponseClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
UnaryResponseClientCallback = callback; UnaryResponseClientCallback = callback;
} }
@ -112,16 +114,19 @@ namespace Grpc.Core.Internal.Tests
public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartClientStreaming(IUnaryResponseClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
UnaryResponseClientCallback = callback; UnaryResponseClientCallback = callback;
} }
public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartServerStreaming(IReceivedStatusOnClientCallback callback, byte[] payload, WriteFlags writeFlags, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback; ReceivedStatusOnClientCallback = callback;
} }
public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags) public void StartDuplexStreaming(IReceivedStatusOnClientCallback callback, MetadataArraySafeHandle metadataArray, CallFlags callFlags)
{ {
StartCallMaybeFail();
ReceivedStatusOnClientCallback = callback; ReceivedStatusOnClientCallback = callback;
} }
@ -165,5 +170,22 @@ namespace Grpc.Core.Internal.Tests
{ {
IsDisposed = true; IsDisposed = true;
} }
/// <summary>
/// Emulate CallSafeHandle.CheckOk() failure for all future attempts
/// to start a call.
/// </summary>
public void MakeStartCallFail()
{
shouldStartCallFail = true;
}
private void StartCallMaybeFail()
{
if (shouldStartCallFail)
{
throw new InvalidOperationException("Start call has failed.");
}
}
} }
} }

@ -297,6 +297,12 @@ namespace Grpc.Core
activeCallCounter.Decrement(); activeCallCounter.Decrement();
} }
// for testing only
internal long GetCallReferenceCount()
{
return activeCallCounter.Count;
}
private ChannelState GetConnectivityState(bool tryToConnect) private ChannelState GetConnectivityState(bool tryToConnect)
{ {
try try

@ -17,6 +17,7 @@
#endregion #endregion
using System; using System;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Grpc.Core.Logging; using Grpc.Core.Logging;
using Grpc.Core.Profiling; using Grpc.Core.Profiling;
@ -34,6 +35,8 @@ namespace Grpc.Core.Internal
readonly CallInvocationDetails<TRequest, TResponse> details; readonly CallInvocationDetails<TRequest, TResponse> details;
readonly INativeCall injectedNativeCall; // for testing readonly INativeCall injectedNativeCall; // for testing
bool registeredWithChannel;
// Dispose of to de-register cancellation token registration // Dispose of to de-register cancellation token registration
IDisposable cancellationTokenRegistration; IDisposable cancellationTokenRegistration;
@ -79,42 +82,59 @@ namespace Grpc.Core.Internal
{ {
byte[] payload = UnsafeSerialize(msg); byte[] payload = UnsafeSerialize(msg);
unaryResponseTcs = new TaskCompletionSource<TResponse>(); bool callStartedOk = false;
try
lock (myLock)
{ {
GrpcPreconditions.CheckState(!started); unaryResponseTcs = new TaskCompletionSource<TResponse>();
started = true;
Initialize(cq);
halfcloseRequested = true; lock (myLock)
readingDone = true; {
} GrpcPreconditions.CheckState(!started);
started = true;
Initialize(cq);
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) halfcloseRequested = true;
{ readingDone = true;
var ctx = details.Channel.Environment.BatchContextPool.Lease(); }
try
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{ {
call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); var ctx = details.Channel.Environment.BatchContextPool.Lease();
var ev = cq.Pluck(ctx.Handle);
bool success = (ev.success != 0);
try try
{ {
using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch")) call.StartUnary(ctx, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
var ev = cq.Pluck(ctx.Handle);
bool success = (ev.success != 0);
try
{
using (profiler.NewScope("AsyncCall.UnaryCall.HandleBatch"))
{
HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata());
}
}
catch (Exception e)
{ {
HandleUnaryResponse(success, ctx.GetReceivedStatusOnClient(), ctx.GetReceivedMessage(), ctx.GetReceivedInitialMetadata()); Logger.Error(e, "Exception occurred while invoking completion delegate.");
} }
} }
catch (Exception e) finally
{ {
Logger.Error(e, "Exception occurred while invoking completion delegate."); ctx.Recycle();
} }
} }
finally }
catch (Exception)
{
if (!callStartedOk)
{ {
ctx.Recycle(); lock (myLock)
{
OnFailedToStartCallLocked();
}
} }
throw;
} }
// Once the blocking call returns, the result should be available synchronously. // Once the blocking call returns, the result should be available synchronously.
@ -130,22 +150,36 @@ namespace Grpc.Core.Internal
{ {
lock (myLock) lock (myLock)
{ {
GrpcPreconditions.CheckState(!started); bool callStartedOk = false;
started = true; try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue); Initialize(details.Channel.CompletionQueue);
halfcloseRequested = true; halfcloseRequested = true;
readingDone = true; readingDone = true;
byte[] payload = UnsafeSerialize(msg); byte[] payload = UnsafeSerialize(msg);
unaryResponseTcs = new TaskCompletionSource<TResponse>(); unaryResponseTcs = new TaskCompletionSource<TResponse>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
}
return unaryResponseTcs.Task;
}
catch (Exception)
{ {
call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
throw;
} }
return unaryResponseTcs.Task;
} }
} }
@ -157,20 +191,33 @@ namespace Grpc.Core.Internal
{ {
lock (myLock) lock (myLock)
{ {
GrpcPreconditions.CheckState(!started); bool callStartedOk = false;
started = true; try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue); Initialize(details.Channel.CompletionQueue);
readingDone = true; readingDone = true;
unaryResponseTcs = new TaskCompletionSource<TResponse>(); unaryResponseTcs = new TaskCompletionSource<TResponse>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags);
callStartedOk = true;
}
return unaryResponseTcs.Task;
}
catch (Exception)
{ {
call.StartClientStreaming(UnaryResponseClientCallback, metadataArray, details.Options.Flags); if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
throw;
} }
return unaryResponseTcs.Task;
} }
} }
@ -181,21 +228,34 @@ namespace Grpc.Core.Internal
{ {
lock (myLock) lock (myLock)
{ {
GrpcPreconditions.CheckState(!started); bool callStartedOk = false;
started = true; try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue); Initialize(details.Channel.CompletionQueue);
halfcloseRequested = true; halfcloseRequested = true;
byte[] payload = UnsafeSerialize(msg); byte[] payload = UnsafeSerialize(msg);
streamingResponseCallFinishedTcs = new TaskCompletionSource<object>(); streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags);
callStartedOk = true;
}
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
catch (Exception)
{ {
call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
throw;
} }
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
} }
} }
@ -207,17 +267,30 @@ namespace Grpc.Core.Internal
{ {
lock (myLock) lock (myLock)
{ {
GrpcPreconditions.CheckState(!started); bool callStartedOk = false;
started = true; try
{
GrpcPreconditions.CheckState(!started);
started = true;
Initialize(details.Channel.CompletionQueue); Initialize(details.Channel.CompletionQueue);
streamingResponseCallFinishedTcs = new TaskCompletionSource<object>(); streamingResponseCallFinishedTcs = new TaskCompletionSource<object>();
using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers))
{
call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags);
callStartedOk = true;
}
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
}
catch (Exception)
{ {
call.StartDuplexStreaming(ReceivedStatusOnClientCallback, metadataArray, details.Options.Flags); if (!callStartedOk)
{
OnFailedToStartCallLocked();
}
throw;
} }
call.StartReceiveInitialMetadata(ReceivedResponseHeadersCallback);
} }
} }
@ -327,7 +400,11 @@ namespace Grpc.Core.Internal
protected override void OnAfterReleaseResourcesLocked() protected override void OnAfterReleaseResourcesLocked()
{ {
details.Channel.RemoveCallReference(this); if (registeredWithChannel)
{
details.Channel.RemoveCallReference(this);
registeredWithChannel = false;
}
} }
protected override void OnAfterReleaseResourcesUnlocked() protected override void OnAfterReleaseResourcesUnlocked()
@ -394,10 +471,27 @@ namespace Grpc.Core.Internal
var call = CreateNativeCall(cq); var call = CreateNativeCall(cq);
details.Channel.AddCallReference(this); details.Channel.AddCallReference(this);
registeredWithChannel = true;
InitializeInternal(call); InitializeInternal(call);
RegisterCancellationCallback(); RegisterCancellationCallback();
} }
private void OnFailedToStartCallLocked()
{
ReleaseResources();
// We need to execute the hook that disposes the cancellation token
// registration, but it cannot be done from under a lock.
// To make things simple, we just schedule the unregistering
// on a threadpool.
// - Once the native call is disposed, the Cancel() calls are ignored anyway
// - We don't care about the overhead as OnFailedToStartCallLocked() only happens
// when something goes very bad when initializing a call and that should
// never happen when gRPC is used correctly.
ThreadPool.QueueUserWorkItem((state) => OnAfterReleaseResourcesUnlocked());
}
private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq) private INativeCall CreateNativeCall(CompletionQueueSafeHandle cq)
{ {
if (injectedNativeCall != null) if (injectedNativeCall != null)

@ -189,7 +189,7 @@ namespace Grpc.Core.Internal
/// </summary> /// </summary>
protected abstract Exception GetRpcExceptionClientOnly(); protected abstract Exception GetRpcExceptionClientOnly();
private void ReleaseResources() protected void ReleaseResources()
{ {
if (call != null) if (call != null)
{ {

Loading…
Cancel
Save