Eliminate GenericInterceptor to simplify this PR

pull/12613/head
Mehrdad Afshari 7 years ago
parent 074b802c9f
commit a7c1b6251c
  1. 130
      src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs
  2. 93
      src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs
  3. 46
      src/csharp/Grpc.Core/ClientBase.cs
  4. 38
      src/csharp/Grpc.Core/Interceptors/CallInvokerExtensions.cs
  5. 449
      src/csharp/Grpc.Core/Interceptors/GenericInterceptor.cs

@ -58,22 +58,6 @@ namespace Grpc.Core.Interceptors.Tests
Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method<string, string>(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), ""));
}
private class CallbackInterceptor : GenericInterceptor
{
readonly Action callback;
public CallbackInterceptor(Action callback)
{
this.callback = callback;
}
protected override ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
{
callback();
return null;
}
}
[Test]
public void CheckInterceptorOrderInClientInterceptors()
{
@ -118,23 +102,6 @@ namespace Grpc.Core.Interceptors.Tests
Assert.Throws<ArgumentNullException>(() => helper.GetChannel().Intercept(default(Interceptor[])));
}
private class CountingInterceptor : GenericInterceptor
{
protected override ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
{
if (!clientStreaming)
{
return null;
}
int counter = 0;
return new ClientCallHooks<TRequest, TResponse>
{
OnRequestMessage = m => { counter++; return m; },
OnUnaryResponse = x => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker
};
}
}
[Test]
public async Task CountNumberOfRequestsInClientInterceptors()
{
@ -151,7 +118,7 @@ namespace Grpc.Core.Interceptors.Tests
return stringBuilder.ToString();
});
var callInvoker = helper.GetChannel().Intercept(new CountingInterceptor());
var callInvoker = helper.GetChannel().Intercept(new ClientStreamingCountingInterceptor());
var server = helper.GetServer();
server.Start();
@ -162,5 +129,100 @@ namespace Grpc.Core.Interceptors.Tests
Assert.AreEqual(StatusCode.OK, call.GetStatus().StatusCode);
Assert.IsNotNull(call.GetTrailers());
}
private class CallbackInterceptor : Interceptor
{
readonly Action callback;
public CallbackInterceptor(Action callback)
{
this.callback = GrpcPreconditions.CheckNotNull(callback, nameof(callback));
}
public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
{
callback();
return continuation(request, context);
}
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
{
callback();
return continuation(request, context);
}
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
callback();
return continuation(request, context);
}
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
callback();
return continuation(context);
}
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
{
callback();
return continuation(context);
}
}
private class ClientStreamingCountingInterceptor : Interceptor
{
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
var response = continuation(context);
int counter = 0;
var requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream,
message => { counter++; return message; }, null);
var responseAsync = response.ResponseAsync.ContinueWith(
unaryResponse => (TResponse)(object)counter.ToString() // Cast to object first is needed to satisfy the type-checker
);
return new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}
}
private class WrappedClientStreamWriter<T> : IClientStreamWriter<T>
{
readonly IClientStreamWriter<T> writer;
readonly Func<T, T> onMessage;
readonly Action onResponseStreamEnd;
public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onResponseStreamEnd)
{
this.writer = writer;
this.onMessage = onMessage;
this.onResponseStreamEnd = onResponseStreamEnd;
}
public Task CompleteAsync()
{
if (onResponseStreamEnd != null)
{
return writer.CompleteAsync().ContinueWith(x => onResponseStreamEnd());
}
return writer.CompleteAsync();
}
public Task WriteAsync(T message)
{
if (onMessage != null)
{
message = onMessage(message);
}
return writer.WriteAsync(message);
}
public WriteOptions WriteOptions
{
get
{
return writer.WriteOptions;
}
set
{
writer.WriteOptions = value;
}
}
}
}
}

@ -35,33 +35,17 @@ namespace Grpc.Core.Interceptors.Tests
{
const string Host = "127.0.0.1";
private class AddRequestHeaderServerInterceptor : GenericInterceptor
{
readonly Metadata.Entry header;
public AddRequestHeaderServerInterceptor(string key, string value)
{
this.header = new Metadata.Entry(key, value);
}
protected override Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request)
{
context.RequestHeaders.Add(header);
return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null);
}
public Metadata.Entry Header => header;
}
[Test]
public void AddRequestHeaderInServerInterceptor()
{
var helper = new MockServiceHelper(Host);
var interceptor = new AddRequestHeaderServerInterceptor("x-interceptor", "hello world");
const string MetadataKey = "x-interceptor";
const string MetadataValue = "hello world";
var interceptor = new ServerCallContextInterceptor(ctx => ctx.RequestHeaders.Add(new Metadata.Entry(MetadataKey, MetadataValue)));
helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
{
var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == interceptor.Header.Key)).Value;
Assert.AreEqual(interceptorHeader, interceptor.Header.Value);
var interceptorHeader = context.RequestHeaders.Last(m => (m.Key == MetadataKey)).Value;
Assert.AreEqual(interceptorHeader, MetadataValue);
return Task.FromResult("PASS");
});
helper.ServiceDefinition = helper.ServiceDefinition.Intercept(interceptor);
@ -71,22 +55,6 @@ namespace Grpc.Core.Interceptors.Tests
Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), ""));
}
private class ArbitraryActionInterceptor : GenericInterceptor
{
readonly Action action;
public ArbitraryActionInterceptor(Action action)
{
this.action = action;
}
protected override Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request)
{
action();
return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null);
}
}
[Test]
public void VerifyInterceptorOrdering()
{
@ -97,11 +65,11 @@ namespace Grpc.Core.Interceptors.Tests
});
var stringBuilder = new StringBuilder();
helper.ServiceDefinition = helper.ServiceDefinition
.Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("A")))
.Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("B1")),
new ArbitraryActionInterceptor(() => stringBuilder.Append("B2")),
new ArbitraryActionInterceptor(() => stringBuilder.Append("B3")))
.Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("C")));
.Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("A")))
.Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("B1")),
new ServerCallContextInterceptor(ctx => stringBuilder.Append("B2")),
new ServerCallContextInterceptor(ctx => stringBuilder.Append("B3")))
.Intercept(new ServerCallContextInterceptor(ctx => stringBuilder.Append("C")));
var server = helper.GetServer();
server.Start();
var channel = helper.GetChannel();
@ -113,15 +81,46 @@ namespace Grpc.Core.Interceptors.Tests
public void CheckNullInterceptorRegistrationFails()
{
var helper = new MockServiceHelper(Host);
helper.UnaryHandler = new UnaryServerMethod<string, string>((request, context) =>
{
return Task.FromResult("PASS");
});
var sd = helper.ServiceDefinition;
Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor)));
Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{default(Interceptor)}));
Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{new ArbitraryActionInterceptor(()=>{}), null}));
Assert.Throws<ArgumentNullException>(() => sd.Intercept(new[]{new ServerCallContextInterceptor(ctx=>{}), null}));
Assert.Throws<ArgumentNullException>(() => sd.Intercept(default(Interceptor[])));
}
private class ServerCallContextInterceptor : Interceptor
{
readonly Action<ServerCallContext> interceptor;
public ServerCallContextInterceptor(Action<ServerCallContext> interceptor)
{
GrpcPreconditions.CheckNotNull(interceptor, nameof(interceptor));
this.interceptor = interceptor;
}
public override Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation)
{
interceptor(context);
return continuation(request, context);
}
public override Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation)
{
interceptor(context);
return continuation(requestStream, context);
}
public override Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation)
{
interceptor(context);
return continuation(request, responseStream, context);
}
public override Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation)
{
interceptor(context);
return continuation(requestStream, responseStream, context);
}
}
}
}

@ -149,25 +149,49 @@ namespace Grpc.Core
/// </summary>
protected internal class ClientBaseConfiguration
{
private class ClientHeaderInterceptor : GenericInterceptor
private class ClientBaseConfigurationInterceptor : Interceptor
{
readonly Func<IMethod, string, CallOptions, Tuple<string, CallOptions>> interceptor;
/// <summary>
/// Creates a new instance of ClientHeaderInterceptor given the specified header interceptor function.
/// Creates a new instance of ClientBaseConfigurationInterceptor given the specified header and host interceptor function.
/// </summary>
public ClientHeaderInterceptor(Func<IMethod, string, CallOptions, Tuple<string, CallOptions>> interceptor)
public ClientBaseConfigurationInterceptor(Func<IMethod, string, CallOptions, Tuple<string, CallOptions>> interceptor)
{
this.interceptor = GrpcPreconditions.CheckNotNull(interceptor, "interceptor");
this.interceptor = GrpcPreconditions.CheckNotNull(interceptor, nameof(interceptor));
}
protected override ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
private ClientInterceptorContext<TRequest, TResponse> GetNewContext<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context)
where TRequest : class
where TResponse : class
{
var newHostAndCallOptions = interceptor(context.Method, context.Host, context.Options);
return new ClientCallHooks<TRequest, TResponse>
{
ContextOverride = new ClientInterceptorContext<TRequest, TResponse>(context.Method, newHostAndCallOptions.Item1, newHostAndCallOptions.Item2)
};
return new ClientInterceptorContext<TRequest, TResponse>(context.Method, newHostAndCallOptions.Item1, newHostAndCallOptions.Item2);
}
public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, GetNewContext(context));
}
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, GetNewContext(context));
}
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, GetNewContext(context));
}
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
return continuation(GetNewContext(context));
}
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
{
return continuation(GetNewContext(context));
}
}
@ -182,12 +206,12 @@ namespace Grpc.Core
internal CallInvoker CreateDecoratedCallInvoker()
{
return undecoratedCallInvoker.Intercept(new ClientHeaderInterceptor((method, host, options) => Tuple.Create(this.host, options)));
return undecoratedCallInvoker.Intercept(new ClientBaseConfigurationInterceptor((method, host, options) => Tuple.Create(this.host, options)));
}
internal ClientBaseConfiguration WithHost(string host)
{
GrpcPreconditions.CheckNotNull(host, "host");
GrpcPreconditions.CheckNotNull(host, nameof(host));
return new ClientBaseConfiguration(this.undecoratedCallInvoker, host);
}
}

@ -64,7 +64,7 @@ namespace Grpc.Core.Interceptors
/// </remarks>
public static CallInvoker Intercept(this CallInvoker invoker, params Interceptor[] interceptors)
{
GrpcPreconditions.CheckNotNull(invoker, nameof(invoker);
GrpcPreconditions.CheckNotNull(invoker, nameof(invoker));
GrpcPreconditions.CheckNotNull(interceptors, nameof(interceptors));
foreach (var interceptor in interceptors.Reverse())
@ -95,7 +95,7 @@ namespace Grpc.Core.Interceptors
return new InterceptingCallInvoker(invoker, new MetadataInterceptor(interceptor));
}
private class MetadataInterceptor : GenericInterceptor
private class MetadataInterceptor : Interceptor
{
readonly Func<Metadata, Metadata> interceptor;
@ -107,13 +107,37 @@ namespace Grpc.Core.Interceptors
this.interceptor = GrpcPreconditions.CheckNotNull(interceptor, nameof(interceptor));
}
protected override ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
private ClientInterceptorContext<TRequest, TResponse> GetNewContext<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context)
where TRequest : class
where TResponse : class
{
var metadata = context.Options.Headers ?? new Metadata();
return new ClientCallHooks<TRequest, TResponse>
{
ContextOverride = new ClientInterceptorContext<TRequest, TResponse>(context.Method, context.Host, context.Options.WithHeaders(interceptor(metadata))),
};
return new ClientInterceptorContext<TRequest, TResponse>(context.Method, context.Host, context.Options.WithHeaders(interceptor(metadata)));
}
public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, GetNewContext(context));
}
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, GetNewContext(context));
}
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, GetNewContext(context));
}
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
return continuation(GetNewContext(context));
}
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
{
return continuation(GetNewContext(context));
}
}
}

@ -1,449 +0,0 @@
#region Copyright notice and license
// Copyright 2018 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion
using System;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core.Internal;
namespace Grpc.Core.Interceptors
{
/// <summary>
/// Provides a base class for generic interceptor implementations that raises
/// events and hooks to control the RPC lifecycle.
/// </summary>
internal abstract class GenericInterceptor : Interceptor
{
/// <summary>
/// Provides hooks through which an invocation should be intercepted.
/// </summary>
public sealed class ClientCallHooks<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
internal ClientCallHooks<TRequest, TResponse> Freeze()
{
return (ClientCallHooks<TRequest, TResponse>)MemberwiseClone();
}
/// <summary>
/// Override the context for the outgoing invocation.
/// </summary>
public ClientInterceptorContext<TRequest, TResponse>? ContextOverride { get; set; }
/// <summary>
/// Override the request for the outgoing invocation for non-client-streaming invocations.
/// </summary>
public TRequest UnaryRequestOverride { get; set; }
/// <summary>
/// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it.
/// </summary>
public Func<TResponse, TResponse> OnUnaryResponse { get; set; }
/// <summary>
/// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message.
/// </summary>
public Func<TRequest, TRequest> OnRequestMessage { get; set; }
/// <summary>
/// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message.
/// </summary>
public Func<TResponse, TResponse> OnResponseMessage { get; set; }
/// <summary>
/// Callback that gets invoked when response stream is finished.
/// </summary>
public Action OnResponseStreamEnd { get; set; }
/// <summary>
/// Callback that gets invoked when request stream is finished.
/// </summary>
public Action OnRequestStreamEnd { get; set; }
}
/// <summary>
/// Intercepts an outgoing call from the client side.
/// Derived classes that intend to intercept outgoing invocations from the client side should
/// override this and return the appropriate hooks in the form of a ClientCallHooks instance.
/// </summary>
/// <param name="context">The context of the outgoing invocation.</param>
/// <param name="clientStreaming">True if the invocation is client-streaming.</param>
/// <param name="serverStreaming">True if the invocation is server-streaming.</param>
/// <param name="request">The request message for client-unary invocations, null otherwise.</param>
/// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
/// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
/// <returns>
/// The derived class should return an instance of ClientCallHooks to control the trajectory
/// as they see fit, or null if it does not intend to pursue the invocation any further.
/// </returns>
protected virtual ClientCallHooks<TRequest, TResponse> InterceptCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, bool clientStreaming, bool serverStreaming, TRequest request)
where TRequest : class
where TResponse : class
{
return null;
}
/// <summary>
/// Provides hooks through which a server-side handler should be intercepted.
/// </summary>
public sealed class ServerCallHooks<TRequest, TResponse>
where TRequest : class
where TResponse : class
{
internal ServerCallHooks<TRequest, TResponse> Freeze()
{
return (ServerCallHooks<TRequest, TResponse>)MemberwiseClone();
}
/// <summary>
/// Override the request for the outgoing invocation for non-client-streaming invocations.
/// </summary>
public TRequest UnaryRequestOverride { get; set; }
/// <summary>
/// Delegate that intercepts a response from a non-server-streaming invocation and optionally overrides it.
/// </summary>
public Func<TResponse, TResponse> OnUnaryResponse { get; set; }
/// <summary>
/// Delegate that intercepts each request message for a client-streaming invocation and optionally overrides each message.
/// </summary>
public Func<TRequest, TRequest> OnRequestMessage { get; set; }
/// <summary>
/// Delegate that intercepts each response message for a server-streaming invocation and optionally overrides each message.
/// </summary>
public Func<TResponse, TResponse> OnResponseMessage { get; set; }
/// <summary>
/// Callback that gets invoked when handler is finished executing.
/// </summary>
public Action OnHandlerEnd { get; set; }
/// <summary>
/// Callback that gets invoked when request stream is finished.
/// </summary>
public Action OnRequestStreamEnd { get; set; }
}
/// <summary>
/// Intercepts an incoming service handler invocation on the server side.
/// Derived classes that intend to intercept incoming handlers on the server side should
/// override this and return the appropriate hooks in the form of a ServerCallHooks instance.
/// </summary>
/// <param name="context">The context of the incoming invocation.</param>
/// <param name="clientStreaming">True if the invocation is client-streaming.</param>
/// <param name="serverStreaming">True if the invocation is server-streaming.</param>
/// <param name="request">The request message for client-unary invocations, null otherwise.</param>
/// <typeparam name="TRequest">Request message type for the current invocation.</typeparam>
/// <typeparam name="TResponse">Response message type for the current invocation.</typeparam>
/// <returns>
/// The derived class should return an instance of ServerCallHooks to control the trajectory
/// as they see fit, or null if it does not intend to pursue the invocation any further.
/// </returns>
protected virtual Task<ServerCallHooks<TRequest, TResponse>> InterceptHandler<TRequest, TResponse>(ServerCallContext context, bool clientStreaming, bool serverStreaming, TRequest request)
where TRequest : class
where TResponse : class
{
return Task.FromResult<ServerCallHooks<TRequest, TResponse>>(null);
}
/// <summary>
/// Intercepts a blocking invocation of a simple remote call and dispatches the events accordingly.
/// </summary>
public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
{
var hooks = InterceptCall(context, false, false, request)?.Freeze();
context = hooks?.ContextOverride ?? context;
request = hooks?.UnaryRequestOverride ?? request;
var response = continuation(request, context);
if (hooks?.OnUnaryResponse != null)
{
response = hooks.OnUnaryResponse(response);
}
return response;
}
/// <summary>
/// Intercepts an asynchronous invocation of a simple remote call and dispatches the events accordingly.
/// </summary>
public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
{
var hooks = InterceptCall(context, false, false, request)?.Freeze();
context = hooks?.ContextOverride ?? context;
request = hooks?.UnaryRequestOverride ?? request;
var response = continuation(request, context);
if (hooks?.OnUnaryResponse != null)
{
response = new AsyncUnaryCall<TResponse>(response.ResponseAsync.ContinueWith(unaryResponse => hooks.OnUnaryResponse(unaryResponse.Result)),
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}
return response;
}
/// <summary>
/// Intercepts an asynchronous invocation of a streaming remote call and dispatches the events accordingly.
/// </summary>
public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
var hooks = InterceptCall(context, false, true, request)?.Freeze();
context = hooks?.ContextOverride ?? context;
request = hooks?.UnaryRequestOverride ?? request;
var response = continuation(request, context);
if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
{
response = new AsyncServerStreamingCall<TResponse>(
new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd),
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}
return response;
}
/// <summary>
/// Intercepts an asynchronous invocation of a client streaming call and dispatches the events accordingly.
/// </summary>
public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
var hooks = InterceptCall(context, true, false, null)?.Freeze();
context = hooks?.ContextOverride ?? context;
var response = continuation(context);
if (hooks?.OnRequestMessage != null || hooks?.OnResponseStreamEnd != null || hooks?.OnUnaryResponse != null)
{
var requestStream = response.RequestStream;
if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
{
requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
}
var responseAsync = response.ResponseAsync;
if (hooks?.OnUnaryResponse != null)
{
responseAsync = response.ResponseAsync.ContinueWith(unaryResponse => hooks.OnUnaryResponse(unaryResponse.Result));
}
response = new AsyncClientStreamingCall<TRequest, TResponse>(requestStream, responseAsync, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}
return response;
}
/// <summary>
/// Intercepts an asynchronous invocation of a duplex streaming call and dispatches the events accordingly.
/// </summary>
public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
{
var hooks = InterceptCall(context, true, true, null)?.Freeze();
context = hooks?.ContextOverride ?? context;
var response = continuation(context);
if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null || hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
{
var requestStream = response.RequestStream;
if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
{
requestStream = new WrappedClientStreamWriter<TRequest>(response.RequestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
}
var responseStream = response.ResponseStream;
if (hooks?.OnResponseMessage != null || hooks?.OnResponseStreamEnd != null)
{
responseStream = new WrappedAsyncStreamReader<TResponse>(response.ResponseStream, hooks.OnResponseMessage, hooks.OnResponseStreamEnd);
}
response = new AsyncDuplexStreamingCall<TRequest, TResponse>(requestStream, responseStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}
return response;
}
/// <summary>
/// Server-side handler for intercepting unary calls.
/// </summary>
/// <typeparam name="TRequest">Request message type for this method.</typeparam>
/// <typeparam name="TResponse">Response message type for this method.</typeparam>
public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(TRequest request, ServerCallContext context, UnaryServerMethod<TRequest, TResponse> continuation)
{
var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, false, request))?.Freeze();
request = hooks?.UnaryRequestOverride ?? request;
var response = await continuation(request, context);
if (hooks?.OnUnaryResponse != null)
{
response = hooks.OnUnaryResponse(response);
}
hooks?.OnHandlerEnd();
return response;
}
/// <summary>
/// Server-side handler for intercepting client streaming call.
/// </summary>
/// <typeparam name="TRequest">Request message type for this method.</typeparam>
/// <typeparam name="TResponse">Response message type for this method.</typeparam>
public override async Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, ServerCallContext context, ClientStreamingServerMethod<TRequest, TResponse> continuation)
{
var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, false, null))?.Freeze();
if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
{
requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
}
var response = await continuation(requestStream, context);
if (hooks?.OnUnaryResponse != null)
{
response = hooks.OnUnaryResponse(response);
}
hooks?.OnHandlerEnd();
return response;
}
/// <summary>
/// Server-side handler for intercepting server streaming calls.
/// </summary>
/// <typeparam name="TRequest">Request message type for this method.</typeparam>
/// <typeparam name="TResponse">Response message type for this method.</typeparam>
public override async Task ServerStreamingServerHandler<TRequest, TResponse>(TRequest request, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, ServerStreamingServerMethod<TRequest, TResponse> continuation)
{
var hooks = (await InterceptHandler<TRequest, TResponse>(context, false, true, request))?.Freeze();
request = hooks?.UnaryRequestOverride ?? request;
if (hooks?.OnResponseMessage != null)
{
responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
}
await continuation(request, responseStream, context);
hooks?.OnHandlerEnd();
}
/// <summary>
/// Server-side handler for intercepting bidi streaming calls.
/// </summary>
/// <typeparam name="TRequest">Request message type for this method.</typeparam>
/// <typeparam name="TResponse">Response message type for this method.</typeparam>
public override async Task DuplexStreamingServerHandler<TRequest, TResponse>(IAsyncStreamReader<TRequest> requestStream, IServerStreamWriter<TResponse> responseStream, ServerCallContext context, DuplexStreamingServerMethod<TRequest, TResponse> continuation)
{
var hooks = (await InterceptHandler<TRequest, TResponse>(context, true, true, null))?.Freeze();
if (hooks?.OnRequestMessage != null || hooks?.OnRequestStreamEnd != null)
{
requestStream = new WrappedAsyncStreamReader<TRequest>(requestStream, hooks.OnRequestMessage, hooks.OnRequestStreamEnd);
}
if (hooks?.OnResponseMessage != null)
{
responseStream = new WrappedAsyncStreamWriter<TResponse>(responseStream, hooks.OnResponseMessage);
}
await continuation(requestStream, responseStream, context);
hooks?.OnHandlerEnd();
}
private class WrappedAsyncStreamReader<T> : IAsyncStreamReader<T>
{
readonly IAsyncStreamReader<T> reader;
readonly Func<T, T> onMessage;
readonly Action onStreamEnd;
public WrappedAsyncStreamReader(IAsyncStreamReader<T> reader, Func<T, T> onMessage, Action onStreamEnd)
{
this.reader = reader;
this.onMessage = onMessage;
this.onStreamEnd = onStreamEnd;
}
public void Dispose() => ((IDisposable)reader).Dispose();
private T current;
public T Current
{
get
{
if (current == null)
{
throw new InvalidOperationException("No current element is available.");
}
return current;
}
}
public async Task<bool> MoveNext(CancellationToken token)
{
if (await reader.MoveNext(token))
{
var current = reader.Current;
if (onMessage != null)
{
var mappedValue = onMessage(current);
if (mappedValue != null)
{
current = mappedValue;
}
}
this.current = current;
return true;
}
onStreamEnd?.Invoke();
return false;
}
}
private class WrappedClientStreamWriter<T> : IClientStreamWriter<T>
{
readonly IClientStreamWriter<T> writer;
readonly Func<T, T> onMessage;
readonly Action onResponseStreamEnd;
public WrappedClientStreamWriter(IClientStreamWriter<T> writer, Func<T, T> onMessage, Action onResponseStreamEnd)
{
this.writer = writer;
this.onMessage = onMessage;
this.onResponseStreamEnd = onResponseStreamEnd;
}
public Task CompleteAsync()
{
if (onResponseStreamEnd != null)
{
return writer.CompleteAsync().ContinueWith(x => onResponseStreamEnd());
}
return writer.CompleteAsync();
}
public Task WriteAsync(T message)
{
if (onMessage != null)
{
message = onMessage(message);
}
return writer.WriteAsync(message);
}
public WriteOptions WriteOptions
{
get
{
return writer.WriteOptions;
}
set
{
writer.WriteOptions = value;
}
}
}
private class WrappedAsyncStreamWriter<T> : IServerStreamWriter<T>
{
readonly IAsyncStreamWriter<T> writer;
readonly Func<T, T> onMessage;
public WrappedAsyncStreamWriter(IAsyncStreamWriter<T> writer, Func<T, T> onMessage)
{
this.writer = writer;
this.onMessage = onMessage;
}
public Task WriteAsync(T message)
{
if (onMessage != null)
{
message = onMessage(message);
}
return writer.WriteAsync(message);
}
public WriteOptions WriteOptions
{
get
{
return writer.WriteOptions;
}
set
{
writer.WriteOptions = value;
}
}
}
}
}
Loading…
Cancel
Save