Modify verifyPeerCallback logic to use NativeCallbackDispatcher

pull/18591/head
Jan Tattermusch 6 years ago
parent f1199912e8
commit c835fedd4d
  1. 37
      src/csharp/Grpc.Core/ChannelCredentials.cs
  2. 12
      src/csharp/Grpc.Core/Internal/ChannelCredentialsSafeHandle.cs
  3. 6
      src/csharp/Grpc.Core/Internal/NativeMethods.Generated.cs
  4. 92
      src/csharp/Grpc.IntegrationTesting/SslCredentialsTest.cs
  5. 53
      src/csharp/ext/grpc_csharp_ext.c
  6. 2
      templates/src/csharp/Grpc.Core/Internal/native_methods.include

@ -127,8 +127,6 @@ namespace Grpc.Core
readonly string rootCertificates; readonly string rootCertificates;
readonly KeyCertificatePair keyCertificatePair; readonly KeyCertificatePair keyCertificatePair;
readonly VerifyPeerCallback verifyPeerCallback; readonly VerifyPeerCallback verifyPeerCallback;
readonly VerifyPeerCallbackInternal verifyPeerCallbackInternal;
readonly GCHandle gcHandle;
/// <summary> /// <summary>
/// Creates client-side SSL credentials loaded from /// Creates client-side SSL credentials loaded from
@ -168,12 +166,7 @@ namespace Grpc.Core
{ {
this.rootCertificates = rootCertificates; this.rootCertificates = rootCertificates;
this.keyCertificatePair = keyCertificatePair; this.keyCertificatePair = keyCertificatePair;
if (verifyPeerCallback != null)
{
this.verifyPeerCallback = verifyPeerCallback; this.verifyPeerCallback = verifyPeerCallback;
this.verifyPeerCallbackInternal = this.VerifyPeerCallbackHandler;
gcHandle = GCHandle.Alloc(verifyPeerCallbackInternal);
}
} }
/// <summary> /// <summary>
@ -207,14 +200,37 @@ namespace Grpc.Core
internal override ChannelCredentialsSafeHandle CreateNativeCredentials() internal override ChannelCredentialsSafeHandle CreateNativeCredentials()
{ {
return ChannelCredentialsSafeHandle.CreateSslCredentials(rootCertificates, keyCertificatePair, this.verifyPeerCallbackInternal); IntPtr verifyPeerCallbackTag = IntPtr.Zero;
if (verifyPeerCallback != null)
{
verifyPeerCallbackTag = new VerifyPeerCallbackRegistration(verifyPeerCallback).CallbackRegistration.Tag;
}
return ChannelCredentialsSafeHandle.CreateSslCredentials(rootCertificates, keyCertificatePair, verifyPeerCallbackTag);
}
private class VerifyPeerCallbackRegistration
{
readonly VerifyPeerCallback verifyPeerCallback;
readonly NativeCallbackRegistration callbackRegistration;
public VerifyPeerCallbackRegistration(VerifyPeerCallback verifyPeerCallback)
{
this.verifyPeerCallback = verifyPeerCallback;
this.callbackRegistration = NativeCallbackDispatcher.RegisterCallback(HandleUniversalCallback);
} }
private int VerifyPeerCallbackHandler(IntPtr host, IntPtr pem, IntPtr userData, bool isDestroy) public NativeCallbackRegistration CallbackRegistration => callbackRegistration;
private int HandleUniversalCallback(IntPtr arg0, IntPtr arg1, IntPtr arg2, IntPtr arg3, IntPtr arg4, IntPtr arg5)
{
return VerifyPeerCallbackHandler(arg0, arg1, arg2 != IntPtr.Zero);
}
private int VerifyPeerCallbackHandler(IntPtr host, IntPtr pem, bool isDestroy)
{ {
if (isDestroy) if (isDestroy)
{ {
this.gcHandle.Free(); this.callbackRegistration.Dispose();
return 0; return 0;
} }
@ -233,6 +249,7 @@ namespace Grpc.Core
} }
} }
} }
}
/// <summary> /// <summary>
/// Credentials that allow composing one <see cref="ChannelCredentials"/> object and /// Credentials that allow composing one <see cref="ChannelCredentials"/> object and

@ -20,12 +20,6 @@ using System.Threading.Tasks;
namespace Grpc.Core.Internal namespace Grpc.Core.Internal
{ {
internal delegate int VerifyPeerCallbackInternal(
IntPtr targetHost,
IntPtr targetPem,
IntPtr userData,
bool isDestroy);
/// <summary> /// <summary>
/// grpc_channel_credentials from <c>grpc/grpc_security.h</c> /// grpc_channel_credentials from <c>grpc/grpc_security.h</c>
/// </summary> /// </summary>
@ -44,15 +38,15 @@ namespace Grpc.Core.Internal
return creds; return creds;
} }
public static ChannelCredentialsSafeHandle CreateSslCredentials(string pemRootCerts, KeyCertificatePair keyCertPair, VerifyPeerCallbackInternal verifyPeerCallback) public static ChannelCredentialsSafeHandle CreateSslCredentials(string pemRootCerts, KeyCertificatePair keyCertPair, IntPtr verifyPeerCallbackTag)
{ {
if (keyCertPair != null) if (keyCertPair != null)
{ {
return Native.grpcsharp_ssl_credentials_create(pemRootCerts, keyCertPair.CertificateChain, keyCertPair.PrivateKey, verifyPeerCallback); return Native.grpcsharp_ssl_credentials_create(pemRootCerts, keyCertPair.CertificateChain, keyCertPair.PrivateKey, verifyPeerCallbackTag);
} }
else else
{ {
return Native.grpcsharp_ssl_credentials_create(pemRootCerts, null, null, verifyPeerCallback); return Native.grpcsharp_ssl_credentials_create(pemRootCerts, null, null, verifyPeerCallbackTag);
} }
} }

@ -482,7 +482,7 @@ namespace Grpc.Core.Internal
public delegate void grpcsharp_channel_args_set_integer_delegate(ChannelArgsSafeHandle args, UIntPtr index, string key, int value); public delegate void grpcsharp_channel_args_set_integer_delegate(ChannelArgsSafeHandle args, UIntPtr index, string key, int value);
public delegate void grpcsharp_channel_args_destroy_delegate(IntPtr args); public delegate void grpcsharp_channel_args_destroy_delegate(IntPtr args);
public delegate void grpcsharp_override_default_ssl_roots_delegate(string pemRootCerts); public delegate void grpcsharp_override_default_ssl_roots_delegate(string pemRootCerts);
public delegate ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create_delegate(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, VerifyPeerCallbackInternal verifyPeerCallback); public delegate ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create_delegate(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, IntPtr verifyPeerCallbackTag);
public delegate ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create_delegate(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds); public delegate ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create_delegate(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds);
public delegate void grpcsharp_channel_credentials_release_delegate(IntPtr credentials); public delegate void grpcsharp_channel_credentials_release_delegate(IntPtr credentials);
public delegate ChannelSafeHandle grpcsharp_insecure_channel_create_delegate(string target, ChannelArgsSafeHandle channelArgs); public delegate ChannelSafeHandle grpcsharp_insecure_channel_create_delegate(string target, ChannelArgsSafeHandle channelArgs);
@ -676,7 +676,7 @@ namespace Grpc.Core.Internal
public static extern void grpcsharp_override_default_ssl_roots(string pemRootCerts); public static extern void grpcsharp_override_default_ssl_roots(string pemRootCerts);
[DllImport(ImportName)] [DllImport(ImportName)]
public static extern ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, VerifyPeerCallbackInternal verifyPeerCallback); public static extern ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, IntPtr verifyPeerCallbackTag);
[DllImport(ImportName)] [DllImport(ImportName)]
public static extern ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds); public static extern ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds);
@ -972,7 +972,7 @@ namespace Grpc.Core.Internal
public static extern void grpcsharp_override_default_ssl_roots(string pemRootCerts); public static extern void grpcsharp_override_default_ssl_roots(string pemRootCerts);
[DllImport(ImportName)] [DllImport(ImportName)]
public static extern ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, VerifyPeerCallbackInternal verifyPeerCallback); public static extern ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, IntPtr verifyPeerCallbackTag);
[DllImport(ImportName)] [DllImport(ImportName)]
public static extern ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds); public static extern ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds);

@ -44,23 +44,18 @@ namespace Grpc.IntegrationTesting
string rootCert; string rootCert;
KeyCertificatePair keyCertPair; KeyCertificatePair keyCertPair;
string certChain;
List<ChannelOption> options;
bool isHostEqual;
bool isPemEqual;
public void InitClientAndServer(bool clientAddKeyCertPair, public void InitClientAndServer(bool clientAddKeyCertPair,
SslClientCertificateRequestType clientCertRequestType) SslClientCertificateRequestType clientCertRequestType,
VerifyPeerCallback verifyPeerCallback = null)
{ {
rootCert = File.ReadAllText(TestCredentials.ClientCertAuthorityPath); rootCert = File.ReadAllText(TestCredentials.ClientCertAuthorityPath);
certChain = File.ReadAllText(TestCredentials.ServerCertChainPath);
certChain = certChain.Replace("\r", string.Empty);
keyCertPair = new KeyCertificatePair( keyCertPair = new KeyCertificatePair(
certChain, File.ReadAllText(TestCredentials.ServerCertChainPath),
File.ReadAllText(TestCredentials.ServerPrivateKeyPath)); File.ReadAllText(TestCredentials.ServerPrivateKeyPath));
var serverCredentials = new SslServerCredentials(new[] { keyCertPair }, rootCert, clientCertRequestType); var serverCredentials = new SslServerCredentials(new[] { keyCertPair }, rootCert, clientCertRequestType);
var clientCredentials = clientAddKeyCertPair ? new SslCredentials(rootCert, keyCertPair, context => this.VerifyPeerCallback(context, true)) : new SslCredentials(rootCert); var clientCredentials = new SslCredentials(rootCert, clientAddKeyCertPair ? keyCertPair : null, verifyPeerCallback);
// Disable SO_REUSEPORT to prevent https://github.com/grpc/grpc/issues/10755 // Disable SO_REUSEPORT to prevent https://github.com/grpc/grpc/issues/10755
server = new Server(new[] { new ChannelOption(ChannelOptions.SoReuseport, 0) }) server = new Server(new[] { new ChannelOption(ChannelOptions.SoReuseport, 0) })
@ -70,7 +65,7 @@ namespace Grpc.IntegrationTesting
}; };
server.Start(); server.Start();
options = new List<ChannelOption> var options = new List<ChannelOption>
{ {
new ChannelOption(ChannelOptions.SslTargetNameOverride, TestCredentials.DefaultHostOverride) new ChannelOption(ChannelOptions.SslTargetNameOverride, TestCredentials.DefaultHostOverride)
}; };
@ -194,6 +189,52 @@ namespace Grpc.IntegrationTesting
Assert.Throws(typeof(ArgumentNullException), () => new SslServerCredentials(keyCertPairs, null, SslClientCertificateRequestType.RequestAndRequireAndVerify)); Assert.Throws(typeof(ArgumentNullException), () => new SslServerCredentials(keyCertPairs, null, SslClientCertificateRequestType.RequestAndRequireAndVerify));
} }
[Test]
public async Task VerifyPeerCallback_Accepted()
{
string targetNameFromCallback = null;
string peerPemFromCallback = null;
InitClientAndServer(
clientAddKeyCertPair: false,
clientCertRequestType: SslClientCertificateRequestType.DontRequest,
verifyPeerCallback: (ctx) =>
{
targetNameFromCallback = ctx.TargetName;
peerPemFromCallback = ctx.PeerPem;
return true;
});
await CheckAccepted(expectPeerAuthenticated: false);
Assert.AreEqual(TestCredentials.DefaultHostOverride, targetNameFromCallback);
var expectedServerPem = File.ReadAllText(TestCredentials.ServerCertChainPath).Replace("\r", "");
Assert.AreEqual(expectedServerPem, peerPemFromCallback);
}
[Test]
public void VerifyPeerCallback_CallbackThrows_Rejected()
{
InitClientAndServer(
clientAddKeyCertPair: false,
clientCertRequestType: SslClientCertificateRequestType.DontRequest,
verifyPeerCallback: (ctx) =>
{
throw new Exception("VerifyPeerCallback has thrown on purpose.");
});
CheckRejected();
}
[Test]
public void VerifyPeerCallback_Rejected()
{
InitClientAndServer(
clientAddKeyCertPair: false,
clientCertRequestType: SslClientCertificateRequestType.DontRequest,
verifyPeerCallback: (ctx) =>
{
return false;
});
CheckRejected();
}
private async Task CheckAccepted(bool expectPeerAuthenticated) private async Task CheckAccepted(bool expectPeerAuthenticated)
{ {
var call = client.UnaryCallAsync(new SimpleRequest { ResponseSize = 10 }); var call = client.UnaryCallAsync(new SimpleRequest { ResponseSize = 10 });
@ -216,37 +257,6 @@ namespace Grpc.IntegrationTesting
Assert.AreEqual(12345, response.AggregatedPayloadSize); Assert.AreEqual(12345, response.AggregatedPayloadSize);
} }
[Test]
public void VerifyPeerCallbackTest()
{
InitClientAndServer(true, SslClientCertificateRequestType.RequestAndRequireAndVerify);
// Force GC collection to verify that the VerifyPeerCallback is not collected. If
// it gets collected, this test will hang.
GC.Collect();
client.UnaryCall(new SimpleRequest { ResponseSize = 10 });
Assert.IsTrue(isHostEqual);
Assert.IsTrue(isPemEqual);
}
[Test]
public void VerifyPeerCallbackFailTest()
{
InitClientAndServer(true, SslClientCertificateRequestType.RequestAndRequireAndVerify);
var clientCredentials = new SslCredentials(rootCert, keyCertPair, context => this.VerifyPeerCallback(context, false));
var failingChannel = new Channel(Host, server.Ports.Single().BoundPort, clientCredentials, options);
var failingClient = new TestService.TestServiceClient(failingChannel);
Assert.Throws<RpcException>(() => failingClient.UnaryCall(new SimpleRequest { ResponseSize = 10 }));
}
private bool VerifyPeerCallback(VerifyPeerContext context, bool returnValue)
{
isHostEqual = TestCredentials.DefaultHostOverride == context.TargetHost;
isPemEqual = certChain == context.TargetPem;
return returnValue;
}
private class SslCredentialsTestServiceImpl : TestService.TestServiceBase private class SslCredentialsTestServiceImpl : TestService.TestServiceBase
{ {
public override Task<SimpleResponse> UnaryCall(SimpleRequest request, ServerCallContext context) public override Task<SimpleResponse> UnaryCall(SimpleRequest request, ServerCallContext context)

@ -901,6 +901,21 @@ grpcsharp_server_request_call(grpc_server* server, grpc_completion_queue* cq,
&(ctx->request_metadata), cq, cq, ctx); &(ctx->request_metadata), cq, cq, ctx);
} }
/* Native callback dispatcher */
typedef int(GPR_CALLTYPE* grpcsharp_native_callback_dispatcher_func)(
void* tag, void* arg0, void* arg1, void* arg2, void* arg3, void* arg4,
void* arg5);
static grpcsharp_native_callback_dispatcher_func native_callback_dispatcher =
NULL;
GPR_EXPORT void GPR_CALLTYPE grpcsharp_native_callback_dispatcher_init(
grpcsharp_native_callback_dispatcher_func func) {
GPR_ASSERT(func);
native_callback_dispatcher = func;
}
/* Security */ /* Security */
static char* default_pem_root_certs = NULL; static char* default_pem_root_certs = NULL;
@ -927,23 +942,18 @@ grpcsharp_override_default_ssl_roots(const char* pem_root_certs) {
grpc_set_ssl_roots_override_callback(override_ssl_roots_handler); grpc_set_ssl_roots_override_callback(override_ssl_roots_handler);
} }
typedef int(GPR_CALLTYPE* grpcsharp_verify_peer_func)(const char* target_host,
const char* target_pem,
void* userdata,
int32_t isDestroy);
static void grpcsharp_verify_peer_destroy_handler(void* userdata) { static void grpcsharp_verify_peer_destroy_handler(void* userdata) {
grpcsharp_verify_peer_func callback = native_callback_dispatcher(userdata, NULL,
(grpcsharp_verify_peer_func)(intptr_t)userdata; NULL, (void*)1, NULL,
callback(NULL, NULL, NULL, 1); NULL, NULL);
} }
static int grpcsharp_verify_peer_handler(const char* target_host, static int grpcsharp_verify_peer_handler(const char* target_host,
const char* target_pem, const char* target_pem,
void* userdata) { void* userdata) {
grpcsharp_verify_peer_func callback = return native_callback_dispatcher(userdata, (void*)target_host,
(grpcsharp_verify_peer_func)(intptr_t)userdata; (void*)target_pem, (void*)0, NULL,
return callback(target_host, target_pem, NULL, 0); NULL, NULL);
} }
@ -951,13 +961,13 @@ GPR_EXPORT grpc_channel_credentials* GPR_CALLTYPE
grpcsharp_ssl_credentials_create(const char* pem_root_certs, grpcsharp_ssl_credentials_create(const char* pem_root_certs,
const char* key_cert_pair_cert_chain, const char* key_cert_pair_cert_chain,
const char* key_cert_pair_private_key, const char* key_cert_pair_private_key,
grpcsharp_verify_peer_func verify_peer_func) { void* verify_peer_callback_tag) {
grpc_ssl_pem_key_cert_pair key_cert_pair; grpc_ssl_pem_key_cert_pair key_cert_pair;
verify_peer_options verify_options; verify_peer_options verify_options;
verify_peer_options* p_verify_options = NULL; verify_peer_options* p_verify_options = NULL;
if (verify_peer_func != NULL) { if (verify_peer_callback_tag != NULL) {
verify_options.verify_peer_callback_userdata = verify_options.verify_peer_callback_userdata =
(void*)(intptr_t)verify_peer_func; verify_peer_callback_tag;
verify_options.verify_peer_destruct = verify_options.verify_peer_destruct =
grpcsharp_verify_peer_destroy_handler; grpcsharp_verify_peer_destroy_handler;
verify_options.verify_peer_callback = grpcsharp_verify_peer_handler; verify_options.verify_peer_callback = grpcsharp_verify_peer_handler;
@ -1043,21 +1053,6 @@ grpcsharp_composite_call_credentials_create(grpc_call_credentials* creds1,
return grpc_composite_call_credentials_create(creds1, creds2, NULL); return grpc_composite_call_credentials_create(creds1, creds2, NULL);
} }
/* Native callback dispatcher */
typedef int(GPR_CALLTYPE* grpcsharp_native_callback_dispatcher_func)(
void* tag, void* arg0, void* arg1, void* arg2, void* arg3, void* arg4,
void* arg5);
static grpcsharp_native_callback_dispatcher_func native_callback_dispatcher =
NULL;
GPR_EXPORT void GPR_CALLTYPE grpcsharp_native_callback_dispatcher_init(
grpcsharp_native_callback_dispatcher_func func) {
GPR_ASSERT(func);
native_callback_dispatcher = func;
}
/* Metadata credentials plugin */ /* Metadata credentials plugin */
GPR_EXPORT void GPR_CALLTYPE grpcsharp_metadata_credentials_notify_from_plugin( GPR_EXPORT void GPR_CALLTYPE grpcsharp_metadata_credentials_notify_from_plugin(

@ -44,7 +44,7 @@ native_method_signatures = [
'void grpcsharp_channel_args_set_integer(ChannelArgsSafeHandle args, UIntPtr index, string key, int value)', 'void grpcsharp_channel_args_set_integer(ChannelArgsSafeHandle args, UIntPtr index, string key, int value)',
'void grpcsharp_channel_args_destroy(IntPtr args)', 'void grpcsharp_channel_args_destroy(IntPtr args)',
'void grpcsharp_override_default_ssl_roots(string pemRootCerts)', 'void grpcsharp_override_default_ssl_roots(string pemRootCerts)',
'ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, VerifyPeerCallbackInternal verifyPeerCallback)', 'ChannelCredentialsSafeHandle grpcsharp_ssl_credentials_create(string pemRootCerts, string keyCertPairCertChain, string keyCertPairPrivateKey, IntPtr verifyPeerCallbackTag)',
'ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds)', 'ChannelCredentialsSafeHandle grpcsharp_composite_channel_credentials_create(ChannelCredentialsSafeHandle channelCreds, CallCredentialsSafeHandle callCreds)',
'void grpcsharp_channel_credentials_release(IntPtr credentials)', 'void grpcsharp_channel_credentials_release(IntPtr credentials)',
'ChannelSafeHandle grpcsharp_insecure_channel_create(string target, ChannelArgsSafeHandle channelArgs)', 'ChannelSafeHandle grpcsharp_insecure_channel_create(string target, ChannelArgsSafeHandle channelArgs)',

Loading…
Cancel
Save