diff --git a/src/csharp/Grpc.Core.Tests/ContextualMarshallerTest.cs b/src/csharp/Grpc.Core.Tests/ContextualMarshallerTest.cs index c3aee726f26..99f92e12628 100644 --- a/src/csharp/Grpc.Core.Tests/ContextualMarshallerTest.cs +++ b/src/csharp/Grpc.Core.Tests/ContextualMarshallerTest.cs @@ -52,6 +52,8 @@ namespace Grpc.Core.Tests } if (str == "SERIALIZE_TO_NULL") { + // TODO: test for not calling complete Complete() (that resulted in null payload before...) + // TODO: test for calling Complete(null byte array) return; } var bytes = System.Text.Encoding.UTF8.GetBytes(str); diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs index aefa58a0cee..4111c5347ce 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs @@ -95,10 +95,10 @@ namespace Grpc.Core.Internal readingDone = true; } - var payload = UnsafeSerialize(msg); - + using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope()) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) { + var payload = UnsafeSerialize(msg, serializationScope.Context); // do before metadata array? var ctx = details.Channel.Environment.BatchContextPool.Lease(); try { @@ -160,11 +160,14 @@ namespace Grpc.Core.Internal halfcloseRequested = true; readingDone = true; - var payload = UnsafeSerialize(msg); + //var payload = UnsafeSerialize(msg); unaryResponseTcs = new TaskCompletionSource(); + using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope()) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) { + var payload = UnsafeSerialize(msg, serializationScope.Context); // do before metadata array? + call.StartUnary(UnaryResponseClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); callStartedOk = true; } @@ -235,11 +238,15 @@ namespace Grpc.Core.Internal halfcloseRequested = true; - var payload = UnsafeSerialize(msg); + //var payload = UnsafeSerialize(msg); streamingResponseCallFinishedTcs = new TaskCompletionSource(); + + using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope()) using (var metadataArray = MetadataArraySafeHandle.Create(details.Options.Headers)) { + var payload = UnsafeSerialize(msg, serializationScope.Context); // do before metadata array? + call.StartServerStreaming(ReceivedStatusOnClientCallback, payload, GetWriteFlagsForCall(), metadataArray, details.Options.Flags); callStartedOk = true; } diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs index 07f9ecf23e9..bd4b0d81832 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs @@ -115,23 +115,25 @@ namespace Grpc.Core.Internal /// protected Task SendMessageInternalAsync(TWrite msg, WriteFlags writeFlags) { - var payload = UnsafeSerialize(msg); - - lock (myLock) + using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope()) { - GrpcPreconditions.CheckState(started); - var earlyResult = CheckSendAllowedOrEarlyResult(); - if (earlyResult != null) + var payload = UnsafeSerialize(msg, serializationScope.Context); // do before metadata array? + lock (myLock) { - return earlyResult; - } + GrpcPreconditions.CheckState(started); + var earlyResult = CheckSendAllowedOrEarlyResult(); + if (earlyResult != null) + { + return earlyResult; + } - call.StartSendMessage(SendCompletionCallback, payload, writeFlags, !initialMetadataSent); + call.StartSendMessage(SendCompletionCallback, payload, writeFlags, !initialMetadataSent); - initialMetadataSent = true; - streamingWritesCounter++; - streamingWriteTcs = new TaskCompletionSource(); - return streamingWriteTcs.Task; + initialMetadataSent = true; + streamingWritesCounter++; + streamingWriteTcs = new TaskCompletionSource(); + return streamingWriteTcs.Task; + } } } @@ -213,19 +215,11 @@ namespace Grpc.Core.Internal /// protected abstract Task CheckSendAllowedOrEarlyResult(); - protected SliceBufferSafeHandle UnsafeSerialize(TWrite msg) + // runs the serializer, propagating any exceptions being thrown without modifying them + protected SliceBufferSafeHandle UnsafeSerialize(TWrite msg, DefaultSerializationContext context) { - DefaultSerializationContext context = null; - try - { - context = DefaultSerializationContext.GetInitializedThreadLocal(); - serializer(msg, context); - return context.GetPayload(); - } - finally - { - context?.Reset(); - } + serializer(msg, context); + return context.GetPayload(); } protected Exception TryDeserialize(IBufferReader reader, out TRead msg) diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs index e1c3a215422..e0bb41e15dc 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs @@ -129,28 +129,31 @@ namespace Grpc.Core.Internal /// public Task SendStatusFromServerAsync(Status status, Metadata trailers, ResponseWithFlags? optionalWrite) { - var payload = optionalWrite.HasValue ? UnsafeSerialize(optionalWrite.Value.Response) : null; - var writeFlags = optionalWrite.HasValue ? optionalWrite.Value.WriteFlags : default(WriteFlags); - - lock (myLock) + using (var serializationScope = DefaultSerializationContext.GetInitializedThreadLocalScope()) { - GrpcPreconditions.CheckState(started); - GrpcPreconditions.CheckState(!disposed); - GrpcPreconditions.CheckState(!halfcloseRequested, "Can only send status from server once."); + var payload = optionalWrite.HasValue ? UnsafeSerialize(optionalWrite.Value.Response, serializationScope.Context) : null; + var writeFlags = optionalWrite.HasValue ? optionalWrite.Value.WriteFlags : default(WriteFlags); - using (var metadataArray = MetadataArraySafeHandle.Create(trailers)) - { - call.StartSendStatusFromServer(SendStatusFromServerCompletionCallback, status, metadataArray, !initialMetadataSent, - payload, writeFlags); - } - halfcloseRequested = true; - initialMetadataSent = true; - sendStatusFromServerTcs = new TaskCompletionSource(); - if (optionalWrite.HasValue) + lock (myLock) { - streamingWritesCounter++; + GrpcPreconditions.CheckState(started); + GrpcPreconditions.CheckState(!disposed); + GrpcPreconditions.CheckState(!halfcloseRequested, "Can only send status from server once."); + + using (var metadataArray = MetadataArraySafeHandle.Create(trailers)) + { + call.StartSendStatusFromServer(SendStatusFromServerCompletionCallback, status, metadataArray, !initialMetadataSent, + payload, writeFlags); + } + halfcloseRequested = true; + initialMetadataSent = true; + sendStatusFromServerTcs = new TaskCompletionSource(); + if (optionalWrite.HasValue) + { + streamingWritesCounter++; + } + return sendStatusFromServerTcs.Task; } - return sendStatusFromServerTcs.Task; } } diff --git a/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs b/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs index db5e78a608d..6bf9da56264 100644 --- a/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs +++ b/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs @@ -29,8 +29,7 @@ namespace Grpc.Core.Internal new ThreadLocal(() => new DefaultSerializationContext(), false); bool isComplete; - //byte[] payload; - SliceBufferSafeHandle sliceBuffer; + SliceBufferSafeHandle sliceBuffer = SliceBufferSafeHandle.Create(); public DefaultSerializationContext() { @@ -42,12 +41,10 @@ namespace Grpc.Core.Internal GrpcPreconditions.CheckState(!isComplete); this.isComplete = true; - GetBufferWriter(); var destSpan = sliceBuffer.GetSpan(payload.Length); payload.AsSpan().CopyTo(destSpan); sliceBuffer.Advance(payload.Length); sliceBuffer.Complete(); - //this.payload = payload; } /// @@ -55,11 +52,6 @@ namespace Grpc.Core.Internal /// public override IBufferWriter GetBufferWriter() { - if (sliceBuffer == null) - { - // TODO: avoid allocation.. - sliceBuffer = SliceBufferSafeHandle.Create(); - } return sliceBuffer; } @@ -81,17 +73,35 @@ namespace Grpc.Core.Internal public void Reset() { this.isComplete = false; - //this.payload = null; - this.sliceBuffer = null; // reset instead... + this.sliceBuffer.Reset(); } - public static DefaultSerializationContext GetInitializedThreadLocal() + // Get a cached thread local instance of deserialization context + // and wrap it in a disposable struct that allows easy resetting + // via "using" statement. + public static UsageScope GetInitializedThreadLocalScope() { var instance = threadLocalInstance.Value; - instance.Reset(); - return instance; + return new UsageScope(instance); } - + public struct UsageScope : IDisposable + { + readonly DefaultSerializationContext context; + + public UsageScope(DefaultSerializationContext context) + { + this.context = context; + } + + public DefaultSerializationContext Context => context; + + // TODO: add Serialize method... + + public void Dispose() + { + context.Reset(); + } + } } }