diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs index ca7677c41f0..0ec2d848f00 100644 --- a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs @@ -71,17 +71,14 @@ namespace Grpc.Core.Interceptors.Tests var stringBuilder = new StringBuilder(); var callInvoker = helper.GetChannel().Intercept(metadata => { - metadata = metadata ?? new Metadata(); stringBuilder.Append("interceptor1"); return metadata; }).Intercept(metadata => { - metadata = metadata ?? new Metadata(); stringBuilder.Append("interceptor2"); return metadata; }).Intercept(metadata => { - metadata = metadata ?? new Metadata(); stringBuilder.Append("interceptor3"); return metadata; }); @@ -91,14 +88,14 @@ namespace Grpc.Core.Interceptors.Tests private class CountingInterceptor : GenericInterceptor { - protected override ClientCallArbitrator InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected override ClientCallHooks InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) { if (!clientStreaming) { return null; } int counter = 0; - return new ClientCallArbitrator + return new ClientCallHooks { OnRequestMessage = m => { counter++; return m; }, OnUnaryResponse = x => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker @@ -112,14 +109,14 @@ namespace Grpc.Core.Interceptors.Tests var helper = new MockServiceHelper(Host); helper.ClientStreamingHandler = new ClientStreamingServerMethod(async (requestStream, context) => { - string result = ""; - await requestStream.ForEachAsync((request) => + var stringBuilder = new StringBuilder(); + await requestStream.ForEachAsync(request => { - result += request; + stringBuilder.Append(request); return TaskUtils.CompletedTask; }); await Task.Delay(100); - return result; + return stringBuilder.ToString(); }); var callInvoker = helper.GetChannel().Intercept(new CountingInterceptor()); diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs index fbace51db50..57dd68b1ebf 100644 --- a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs @@ -44,10 +44,10 @@ namespace Grpc.Core.Interceptors.Tests this.header = new Metadata.Entry(key, value); } - protected override Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected override Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) { context.RequestHeaders.Add(header); - return Task.FromResult>(null); + return Task.FromResult>(null); } public Metadata.Entry Header @@ -87,10 +87,10 @@ namespace Grpc.Core.Interceptors.Tests this.action = action; } - protected override Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected override Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) { action(); - return Task.FromResult>(null); + return Task.FromResult>(null); } } diff --git a/src/csharp/Grpc.Core/ClientBase.cs b/src/csharp/Grpc.Core/ClientBase.cs index f978d084d98..4bb06ed87fb 100644 --- a/src/csharp/Grpc.Core/ClientBase.cs +++ b/src/csharp/Grpc.Core/ClientBase.cs @@ -161,12 +161,12 @@ namespace Grpc.Core this.interceptor = GrpcPreconditions.CheckNotNull(interceptor, "interceptor"); } - protected override ClientCallArbitrator InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected override ClientCallHooks InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) { var newHostAndCallOptions = interceptor(context.Method, context.Host, context.Options); - return new ClientCallArbitrator + return new ClientCallHooks { - Context = new ClientInterceptorContext(context.Method, newHostAndCallOptions.Item1, newHostAndCallOptions.Item2) + ContextOverride = new ClientInterceptorContext(context.Method, newHostAndCallOptions.Item1, newHostAndCallOptions.Item2) }; } } diff --git a/src/csharp/Grpc.Core/Interceptors/CallInvokerExtensions.cs b/src/csharp/Grpc.Core/Interceptors/CallInvokerExtensions.cs index f1835f6bd86..1c0831a242a 100644 --- a/src/csharp/Grpc.Core/Interceptors/CallInvokerExtensions.cs +++ b/src/csharp/Grpc.Core/Interceptors/CallInvokerExtensions.cs @@ -113,11 +113,12 @@ namespace Grpc.Core.Interceptors this.interceptor = GrpcPreconditions.CheckNotNull(interceptor, "interceptor"); } - protected override ClientCallArbitrator InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected override ClientCallHooks InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) { - return new ClientCallArbitrator + var metadata = context.Options.Headers ?? new Metadata(); + return new ClientCallHooks { - Context = new ClientInterceptorContext(context.Method, context.Host, context.Options.WithHeaders(interceptor(context.Options.Headers))) + ContextOverride = new ClientInterceptorContext(context.Method, context.Host, context.Options.WithHeaders(interceptor(metadata))), }; } } diff --git a/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs b/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs index ed90ded8897..7ee649e9b53 100644 --- a/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs +++ b/src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs @@ -27,27 +27,27 @@ namespace Grpc.Core.Interceptors /// Provides a base class for generic interceptor implementations that raises /// events and hooks to control the RPC lifecycle. /// - public abstract class GenericInterceptor : Interceptor + internal abstract class GenericInterceptor : Interceptor { /// /// Provides hooks through which an invocation should be intercepted. /// - public sealed class ClientCallArbitrator + public sealed class ClientCallHooks where TRequest : class where TResponse : class { - internal ClientCallArbitrator Freeze() + internal ClientCallHooks Freeze() { - return (ClientCallArbitrator)MemberwiseClone(); + return (ClientCallHooks)MemberwiseClone(); } /// /// Override the context for the outgoing invocation. /// - public ClientInterceptorContext Context { get; set; } + public ClientInterceptorContext? ContextOverride { get; set; } /// /// Override the request for the outgoing invocation for non-client-streaming invocations. /// - public TRequest UnaryRequest { get; set; } + public TRequest UnaryRequestOverride { get; set; } /// /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it. /// @@ -73,7 +73,7 @@ namespace Grpc.Core.Interceptors /// /// Intercepts an outgoing call from the client side. /// Derived classes that intend to intercept outgoing invocations from the client side should - /// override this and return the appropriate hooks in the form of a ClientCallArbitrator instance. + /// override this and return the appropriate hooks in the form of a ClientCallHooks instance. /// /// The context of the outgoing invocation. /// True if the invocation is client-streaming. @@ -82,10 +82,10 @@ namespace Grpc.Core.Interceptors /// Request message type for the current invocation. /// Response message type for the current invocation. /// - /// The derived class should return an instance of ClientCallArbitrator to control the trajectory + /// The derived class should return an instance of ClientCallHooks to control the trajectory /// as they see fit, or null if it does not intend to pursue the invocation any further. /// - protected virtual ClientCallArbitrator InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected virtual ClientCallHooks InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) where TRequest : class where TResponse : class { @@ -95,18 +95,18 @@ namespace Grpc.Core.Interceptors /// /// Provides hooks through which a server-side handler should be intercepted. /// - public sealed class ServerCallArbitrator + public sealed class ServerCallHooks where TRequest : class where TResponse : class { - internal ServerCallArbitrator Freeze() + internal ServerCallHooks Freeze() { - return (ServerCallArbitrator)MemberwiseClone(); + return (ServerCallHooks)MemberwiseClone(); } /// /// Override the request for the outgoing invocation for non-client-streaming invocations. /// - public TRequest UnaryRequest { get; set; } + public TRequest UnaryRequestOverride { get; set; } /// /// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it. /// @@ -132,7 +132,7 @@ namespace Grpc.Core.Interceptors /// /// Intercepts an incoming service handler invocation on the server side. /// Derived classes that intend to intercept incoming handlers on the server side should - /// override this and return the appropriate hooks in the form of a ServerCallArbitrator instance. + /// override this and return the appropriate hooks in the form of a ServerCallHooks instance. /// /// The context of the incoming invocation. /// True if the invocation is client-streaming. @@ -141,14 +141,14 @@ namespace Grpc.Core.Interceptors /// Request message type for the current invocation. /// Response message type for the current invocation. /// - /// The derived class should return an instance of ServerCallArbitrator to control the trajectory + /// The derived class should return an instance of ServerCallHooks to control the trajectory /// as they see fit, or null if it does not intend to pursue the invocation any further. /// - protected virtual Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) + protected virtual Task> InterceptHandler(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request) where TRequest : class where TResponse : class { - return Task.FromResult>(null); + return Task.FromResult>(null); } /// @@ -156,13 +156,13 @@ namespace Grpc.Core.Interceptors /// public override TResponse BlockingUnaryCall(TRequest request, ClientInterceptorContext context, BlockingUnaryCallContinuation continuation) { - var arbitrator = InterceptCall(context, false, false, request)?.Freeze(); - context = arbitrator?.Context ?? context; - request = arbitrator?.UnaryRequest ?? request; + var hooks = InterceptCall(context, false, false, request)?.Freeze(); + context = hooks?.ContextOverride ?? context; + request = hooks?.UnaryRequestOverride ?? request; var response = continuation(request, context); - if (arbitrator?.OnUnaryResponse != null) + if (hooks?.OnUnaryResponse != null) { - response = arbitrator.OnUnaryResponse(response); + response = hooks.OnUnaryResponse(response); } return response; } @@ -172,13 +172,13 @@ namespace Grpc.Core.Interceptors /// public override AsyncUnaryCall AsyncUnaryCall(TRequest request, ClientInterceptorContext context, AsyncUnaryCallContinuation continuation) { - var arbitrator = InterceptCall(context, false, false, request)?.Freeze(); - context = arbitrator?.Context ?? context; - request = arbitrator?.UnaryRequest ?? request; + var hooks = InterceptCall(context, false, false, request)?.Freeze(); + context = hooks?.ContextOverride ?? context; + request = hooks?.UnaryRequestOverride ?? request; var response = continuation(request, context); - if (arbitrator?.OnUnaryResponse != null) + if (hooks?.OnUnaryResponse != null) { - response = new AsyncUnaryCall(response.ResponseAsync.ContinueWith(unaryResponse => arbitrator.OnUnaryResponse(unaryResponse.Result)), + response = new AsyncUnaryCall(response.ResponseAsync.ContinueWith(unaryResponse => hooks.OnUnaryResponse(unaryResponse.Result)), response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } return response; @@ -189,14 +189,14 @@ namespace Grpc.Core.Interceptors /// public override AsyncServerStreamingCall AsyncServerStreamingCall(TRequest request, ClientInterceptorContext context, AsyncServerStreamingCallContinuation continuation) { - var arbitrator = InterceptCall(context, false, true, request)?.Freeze(); - context = arbitrator?.Context ?? context; - request = arbitrator?.UnaryRequest ?? request; + var hooks = InterceptCall(context, false, true, request)?.Freeze(); + context = hooks?.ContextOverride ?? context; + request = hooks?.UnaryRequestOverride ?? request; var response = continuation(request, context); - if (arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) + if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null) { response = new AsyncServerStreamingCall( - new WrappedAsyncStreamReader(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd), + new WrappedAsyncStreamReader(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd), response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } return response; @@ -207,20 +207,20 @@ namespace Grpc.Core.Interceptors /// public override AsyncClientStreamingCall AsyncClientStreamingCall(ClientInterceptorContext context, AsyncClientStreamingCallContinuation continuation) { - var arbitrator = InterceptCall(context, true, false, null)?.Freeze(); - context = arbitrator?.Context ?? context; + var hooks = InterceptCall(context, true, false, null)?.Freeze(); + context = hooks?.ContextOverride ?? context; var response = continuation(context); - if (arbitrator?.OnRequestMessage != null || arbitrator?.OnResponseStreamEnd != null || arbitrator?.OnUnaryResponse != null) + if (hooks?.OnRequestMessage != null || hooks?.OnResponseStreamEnd != null || hooks?.OnUnaryResponse != null) { var requestStream = response.RequestStream; - if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null) { - requestStream = new WrappedClientStreamWriter(response.RequestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + requestStream = new WrappedClientStreamWriter(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd); } var responseAsync = response.ResponseAsync; - if (arbitrator?.OnUnaryResponse != null) + if (hooks?.OnUnaryResponse != null) { - responseAsync = response.ResponseAsync.ContinueWith(unaryResponse => arbitrator.OnUnaryResponse(unaryResponse.Result)); + responseAsync = response.ResponseAsync.ContinueWith(unaryResponse => hooks.OnUnaryResponse(unaryResponse.Result)); } response = new AsyncClientStreamingCall(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } @@ -232,20 +232,20 @@ namespace Grpc.Core.Interceptors /// public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(ClientInterceptorContext context, AsyncDuplexStreamingCallContinuation continuation) { - var arbitrator = InterceptCall(context, true, true, null)?.Freeze(); - context = arbitrator?.Context ?? context; + var hooks = InterceptCall(context, true, true, null)?.Freeze(); + context = hooks?.ContextOverride ?? context; var response = continuation(context); - if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null || arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) + if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null) { var requestStream = response.RequestStream; - if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null) { - requestStream = new WrappedClientStreamWriter(response.RequestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + requestStream = new WrappedClientStreamWriter(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd); } var responseStream = response.ResponseStream; - if (arbitrator?.OnResponseMessage != null || arbitrator?.OnResponseStreamEnd != null) + if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null) { - responseStream = new WrappedAsyncStreamReader(response.ResponseStream, arbitrator.OnResponseMessage, arbitrator.OnResponseStreamEnd); + responseStream = new WrappedAsyncStreamReader(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd); } response = new AsyncDuplexStreamingCall(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); } @@ -259,14 +259,14 @@ namespace Grpc.Core.Interceptors /// Response message type for this method. public override async Task UnaryServerHandler(TRequest request, ServerCallContext context, UnaryServerMethod continuation) { - var arbitrator = (await InterceptHandler(context, false, false, request))?.Freeze(); - request = arbitrator?.UnaryRequest ?? request; + var hooks = (await InterceptHandler(context, false, false, request))?.Freeze(); + request = hooks?.UnaryRequestOverride ?? request; var response = await continuation(request, context); - if (arbitrator?.OnUnaryResponse != null) + if (hooks?.OnUnaryResponse != null) { - response = arbitrator.OnUnaryResponse(response); + response = hooks.OnUnaryResponse(response); } - arbitrator?.OnHandlerEnd(); + hooks?.OnHandlerEnd(); return response; } @@ -277,17 +277,17 @@ namespace Grpc.Core.Interceptors /// Response message type for this method. public override async Task ClientStreamingServerHandler(IAsyncStreamReader requestStream, ServerCallContext context, ClientStreamingServerMethod continuation) { - var arbitrator = (await InterceptHandler(context, true, false, null))?.Freeze(); - if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + var hooks = (await InterceptHandler(context, true, false, null))?.Freeze(); + if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null) { - requestStream = new WrappedAsyncStreamReader(requestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + requestStream = new WrappedAsyncStreamReader(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd); } var response = await continuation(requestStream, context); - if (arbitrator?.OnUnaryResponse != null) + if (hooks?.OnUnaryResponse != null) { - response = arbitrator.OnUnaryResponse(response); + response = hooks.OnUnaryResponse(response); } - arbitrator?.OnHandlerEnd(); + hooks?.OnHandlerEnd(); return response; } @@ -298,14 +298,14 @@ namespace Grpc.Core.Interceptors /// Response message type for this method. public override async Task ServerStreamingServerHandler(TRequest request, IServerStreamWriter responseStream, ServerCallContext context, ServerStreamingServerMethod continuation) { - var arbitrator = (await InterceptHandler(context, false, true, request))?.Freeze(); - request = arbitrator?.UnaryRequest ?? request; - if (arbitrator?.OnResponseMessage != null) + var hooks = (await InterceptHandler(context, false, true, request))?.Freeze(); + request = hooks?.UnaryRequestOverride ?? request; + if (hooks?.OnResponseMessage != null) { - responseStream = new WrappedAsyncStreamWriter(responseStream, arbitrator.OnResponseMessage); + responseStream = new WrappedAsyncStreamWriter(responseStream, hooks.OnResponseMessage); } await continuation(request, responseStream, context); - arbitrator?.OnHandlerEnd(); + hooks?.OnHandlerEnd(); } /// @@ -315,17 +315,17 @@ namespace Grpc.Core.Interceptors /// Response message type for this method. public override async Task DuplexStreamingServerHandler(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context, DuplexStreamingServerMethod continuation) { - var arbitrator = (await InterceptHandler(context, true, true, null))?.Freeze(); - if (arbitrator?.OnRequestMessage != null || arbitrator?.OnRequestStreamEnd != null) + var hooks = (await InterceptHandler(context, true, true, null))?.Freeze(); + if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null) { - requestStream = new WrappedAsyncStreamReader(requestStream, arbitrator.OnRequestMessage, arbitrator.OnRequestStreamEnd); + requestStream = new WrappedAsyncStreamReader(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd); } - if (arbitrator?.OnResponseMessage != null) + if (hooks?.OnResponseMessage != null) { - responseStream = new WrappedAsyncStreamWriter(responseStream, arbitrator.OnResponseMessage); + responseStream = new WrappedAsyncStreamWriter(responseStream, hooks.OnResponseMessage); } await continuation(requestStream, responseStream, context); - arbitrator?.OnHandlerEnd(); + hooks?.OnHandlerEnd(); } private class WrappedAsyncStreamReader : IAsyncStreamReader