diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs index 0ec2d848f00..d7c01d08ac4 100644 --- a/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ClientInterceptorTest.cs @@ -58,6 +58,22 @@ namespace Grpc.Core.Interceptors.Tests Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), "")); } + private class CallbackInterceptor : GenericInterceptor + { + readonly Action callback; + + public CallbackInterceptor(Action callback) + { + this.callback = callback; + } + + protected override ClientCallHooks InterceptCall(ClientInterceptorContext context, bool clientStreaming, bool serverStreaming, TRequest request) + { + callback(); + return null; + } + } + [Test] public void CheckInterceptorOrderInClientInterceptors() { @@ -69,11 +85,13 @@ namespace Grpc.Core.Interceptors.Tests var server = helper.GetServer(); server.Start(); var stringBuilder = new StringBuilder(); - var callInvoker = helper.GetChannel().Intercept(metadata => - { + var callInvoker = helper.GetChannel().Intercept(metadata => { stringBuilder.Append("interceptor1"); return metadata; - }).Intercept(metadata => + }).Intercept(new CallbackInterceptor(() => stringBuilder.Append("array1")), + new CallbackInterceptor(() => stringBuilder.Append("array2")), + new CallbackInterceptor(() => stringBuilder.Append("array3"))) + .Intercept(metadata => { stringBuilder.Append("interceptor2"); return metadata; @@ -83,7 +101,21 @@ namespace Grpc.Core.Interceptors.Tests return metadata; }); Assert.AreEqual("PASS", callInvoker.BlockingUnaryCall(new Method(MethodType.Unary, MockServiceHelper.ServiceName, "Unary", Marshallers.StringMarshaller, Marshallers.StringMarshaller), Host, new CallOptions(), "")); - Assert.AreEqual("interceptor3interceptor2interceptor1", stringBuilder.ToString()); + Assert.AreEqual("interceptor3interceptor2array1array2array3interceptor1", stringBuilder.ToString()); + } + + [Test] + public void CheckNullInterceptorRegistrationFails() + { + var helper = new MockServiceHelper(Host); + helper.UnaryHandler = new UnaryServerMethod((request, context) => + { + return Task.FromResult("PASS"); + }); + Assert.Throws(() => helper.GetChannel().Intercept(default(Interceptor))); + Assert.Throws(() => helper.GetChannel().Intercept(new[]{default(Interceptor)})); + Assert.Throws(() => helper.GetChannel().Intercept(new[]{new CallbackInterceptor(()=>{}), null})); + Assert.Throws(() => helper.GetChannel().Intercept(default(Interceptor[]))); } private class CountingInterceptor : GenericInterceptor diff --git a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs index 57dd68b1ebf..c0957a2b422 100644 --- a/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs +++ b/src/csharp/Grpc.Core.Tests/Interceptors/ServerInterceptorTest.cs @@ -50,13 +50,7 @@ namespace Grpc.Core.Interceptors.Tests return Task.FromResult>(null); } - public Metadata.Entry Header - { - get - { - return header; - } - } + public Metadata.Entry Header => header; } [Test] @@ -81,7 +75,6 @@ namespace Grpc.Core.Interceptors.Tests { readonly Action action; - public ArbitraryActionInterceptor(Action action) { this.action = action; @@ -105,13 +98,30 @@ namespace Grpc.Core.Interceptors.Tests var stringBuilder = new StringBuilder(); helper.ServiceDefinition = helper.ServiceDefinition .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("A"))) - .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("B"))) + .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("B1")), + new ArbitraryActionInterceptor(() => stringBuilder.Append("B2")), + new ArbitraryActionInterceptor(() => stringBuilder.Append("B3"))) .Intercept(new ArbitraryActionInterceptor(() => stringBuilder.Append("C"))); var server = helper.GetServer(); server.Start(); var channel = helper.GetChannel(); Assert.AreEqual("PASS", Calls.BlockingUnaryCall(helper.CreateUnaryCall(), "")); - Assert.AreEqual("CBA", stringBuilder.ToString()); + Assert.AreEqual("CB1B2B3A", stringBuilder.ToString()); + } + + [Test] + public void CheckNullInterceptorRegistrationFails() + { + var helper = new MockServiceHelper(Host); + helper.UnaryHandler = new UnaryServerMethod((request, context) => + { + return Task.FromResult("PASS"); + }); + var sd = helper.ServiceDefinition; + Assert.Throws(() => sd.Intercept(default(Interceptor))); + Assert.Throws(() => sd.Intercept(new[]{default(Interceptor)})); + Assert.Throws(() => sd.Intercept(new[]{new ArbitraryActionInterceptor(()=>{}), null})); + Assert.Throws(() => sd.Intercept(default(Interceptor[]))); } } }