Move security credentials, connectors, and auth context to C++

This is to use `grpc_core::RefCount` to improve performnace.
This commit also replaces explicit C vtables, with C++ vtable
with its own compile time assertions and performance benefits.
It also makes use of `RefCountedPtr` wherever possible.
pull/17291/head
Soheil Hassas Yeganeh 6 years ago
parent 9e9cae7839
commit 9decf48632
  1. 10
      src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc
  2. 10
      src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc
  3. 17
      src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc
  4. 19
      src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc
  5. 8
      src/core/lib/gprpp/ref_counted_ptr.h
  6. 195
      src/core/lib/http/httpcli_security_connector.cc
  7. 10
      src/core/lib/http/parser.h
  8. 183
      src/core/lib/security/context/security_context.cc
  9. 94
      src/core/lib/security/context/security_context.h
  10. 84
      src/core/lib/security/credentials/alts/alts_credentials.cc
  11. 47
      src/core/lib/security/credentials/alts/alts_credentials.h
  12. 297
      src/core/lib/security/credentials/composite/composite_credentials.cc
  13. 111
      src/core/lib/security/credentials/composite/composite_credentials.h
  14. 160
      src/core/lib/security/credentials/credentials.cc
  15. 214
      src/core/lib/security/credentials/credentials.h
  16. 117
      src/core/lib/security/credentials/fake/fake_credentials.cc
  17. 28
      src/core/lib/security/credentials/fake/fake_credentials.h
  18. 83
      src/core/lib/security/credentials/google_default/google_default_credentials.cc
  19. 33
      src/core/lib/security/credentials/google_default/google_default_credentials.h
  20. 62
      src/core/lib/security/credentials/iam/iam_credentials.cc
  21. 22
      src/core/lib/security/credentials/iam/iam_credentials.h
  22. 129
      src/core/lib/security/credentials/jwt/jwt_credentials.cc
  23. 39
      src/core/lib/security/credentials/jwt/jwt_credentials.h
  24. 51
      src/core/lib/security/credentials/local/local_credentials.cc
  25. 43
      src/core/lib/security/credentials/local/local_credentials.h
  26. 279
      src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
  27. 103
      src/core/lib/security/credentials/oauth2/oauth2_credentials.h
  28. 136
      src/core/lib/security/credentials/plugin/plugin_credentials.cc
  29. 57
      src/core/lib/security/credentials/plugin/plugin_credentials.h
  30. 149
      src/core/lib/security/credentials/ssl/ssl_credentials.cc
  31. 73
      src/core/lib/security/credentials/ssl/ssl_credentials.h
  32. 329
      src/core/lib/security/security_connector/alts/alts_security_connector.cc
  33. 22
      src/core/lib/security/security_connector/alts/alts_security_connector.h
  34. 424
      src/core/lib/security/security_connector/fake/fake_security_connector.cc
  35. 15
      src/core/lib/security/security_connector/fake/fake_security_connector.h
  36. 278
      src/core/lib/security/security_connector/local/local_security_connector.cc
  37. 19
      src/core/lib/security/security_connector/local/local_security_connector.h
  38. 165
      src/core/lib/security/security_connector/security_connector.cc
  39. 206
      src/core/lib/security/security_connector/security_connector.h
  40. 718
      src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
  41. 26
      src/core/lib/security/security_connector/ssl/ssl_security_connector.h
  42. 22
      src/core/lib/security/security_connector/ssl_utils.cc
  43. 4
      src/core/lib/security/security_connector/ssl_utils.h
  44. 100
      src/core/lib/security/transport/client_auth_filter.cc
  45. 147
      src/core/lib/security/transport/security_handshaker.cc
  46. 28
      src/core/lib/security/transport/server_auth_filter.cc
  47. 6
      src/cpp/client/secure_credentials.cc
  48. 9
      src/cpp/client/secure_credentials.h
  49. 38
      src/cpp/common/secure_auth_context.cc
  50. 11
      src/cpp/common/secure_auth_context.h
  51. 5
      src/cpp/common/secure_create_auth_context.cc
  52. 2
      src/cpp/server/secure_server_credentials.cc
  53. 41
      test/core/security/alts_security_connector_test.cc
  54. 116
      test/core/security/auth_context_test.cc
  55. 232
      test/core/security/credentials_test.cc
  56. 5
      test/core/security/oauth2_utils.cc
  57. 9
      test/core/security/print_google_default_creds_token.cc
  58. 95
      test/core/security/security_connector_test.cc
  59. 11
      test/core/security/ssl_server_fuzzer.cc
  60. 2
      test/core/surface/secure_channel_create_test.cc
  61. 17
      test/cpp/common/auth_property_iterator_test.cc
  62. 14
      test/cpp/common/secure_auth_context_test.cc
  63. 4
      test/cpp/end2end/grpclb_end2end_test.cc

@ -88,22 +88,18 @@ grpc_channel_args* grpc_lb_policy_grpclb_modify_lb_channel_args(
// bearer token credentials.
grpc_channel_credentials* channel_credentials =
grpc_channel_credentials_find_in_args(args);
grpc_channel_credentials* creds_sans_call_creds = nullptr;
grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds;
if (channel_credentials != nullptr) {
creds_sans_call_creds =
grpc_channel_credentials_duplicate_without_call_credentials(
channel_credentials);
channel_credentials->duplicate_without_call_credentials();
GPR_ASSERT(creds_sans_call_creds != nullptr);
args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS;
args_to_add[num_args_to_add++] =
grpc_channel_credentials_to_arg(creds_sans_call_creds);
grpc_channel_credentials_to_arg(creds_sans_call_creds.get());
}
grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove(
args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add);
// Clean up.
grpc_channel_args_destroy(args);
if (creds_sans_call_creds != nullptr) {
grpc_channel_credentials_unref(creds_sans_call_creds);
}
return result;
}

@ -87,22 +87,18 @@ grpc_channel_args* grpc_lb_policy_xds_modify_lb_channel_args(
// bearer token credentials.
grpc_channel_credentials* channel_credentials =
grpc_channel_credentials_find_in_args(args);
grpc_channel_credentials* creds_sans_call_creds = nullptr;
grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds;
if (channel_credentials != nullptr) {
creds_sans_call_creds =
grpc_channel_credentials_duplicate_without_call_credentials(
channel_credentials);
channel_credentials->duplicate_without_call_credentials();
GPR_ASSERT(creds_sans_call_creds != nullptr);
args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS;
args_to_add[num_args_to_add++] =
grpc_channel_credentials_to_arg(creds_sans_call_creds);
grpc_channel_credentials_to_arg(creds_sans_call_creds.get());
}
grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove(
args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add);
// Clean up.
grpc_channel_args_destroy(args);
if (creds_sans_call_creds != nullptr) {
grpc_channel_credentials_unref(creds_sans_call_creds);
}
return result;
}

@ -110,14 +110,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args(
grpc_channel_args* args_with_authority =
grpc_channel_args_copy_and_add(args->args, args_to_add, num_args_to_add);
grpc_uri_destroy(server_uri);
grpc_channel_security_connector* subchannel_security_connector = nullptr;
// Create the security connector using the credentials and target name.
grpc_channel_args* new_args_from_connector = nullptr;
const grpc_security_status security_status =
grpc_channel_credentials_create_security_connector(
channel_credentials, authority.get(), args_with_authority,
&subchannel_security_connector, &new_args_from_connector);
if (security_status != GRPC_SECURITY_OK) {
grpc_core::RefCountedPtr<grpc_channel_security_connector>
subchannel_security_connector =
channel_credentials->create_security_connector(
/*call_creds=*/nullptr, authority.get(), args_with_authority,
&new_args_from_connector);
if (subchannel_security_connector == nullptr) {
gpr_log(GPR_ERROR,
"Failed to create secure subchannel for secure name '%s'",
authority.get());
@ -125,15 +125,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args(
return nullptr;
}
grpc_arg new_security_connector_arg =
grpc_security_connector_to_arg(&subchannel_security_connector->base);
grpc_security_connector_to_arg(subchannel_security_connector.get());
grpc_channel_args* new_args = grpc_channel_args_copy_and_add(
new_args_from_connector != nullptr ? new_args_from_connector
: args_with_authority,
&new_security_connector_arg, 1);
GRPC_SECURITY_CONNECTOR_UNREF(&subchannel_security_connector->base,
"lb_channel_create");
subchannel_security_connector.reset(DEBUG_LOCATION, "lb_channel_create");
if (new_args_from_connector != nullptr) {
grpc_channel_args_destroy(new_args_from_connector);
}

@ -31,6 +31,7 @@
#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/surface/api_trace.h"
@ -40,9 +41,8 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr,
grpc_server_credentials* creds) {
grpc_core::ExecCtx exec_ctx;
grpc_error* err = GRPC_ERROR_NONE;
grpc_server_security_connector* sc = nullptr;
grpc_core::RefCountedPtr<grpc_server_security_connector> sc;
int port_num = 0;
grpc_security_status status;
grpc_channel_args* args = nullptr;
GRPC_API_TRACE(
"grpc_server_add_secure_http2_port("
@ -54,30 +54,27 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr,
"No credentials specified for secure server port (creds==NULL)");
goto done;
}
status = grpc_server_credentials_create_security_connector(creds, &sc);
if (status != GRPC_SECURITY_OK) {
sc = creds->create_security_connector();
if (sc == nullptr) {
char* msg;
gpr_asprintf(&msg,
"Unable to create secure server with credentials of type %s.",
creds->type);
err = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg),
GRPC_ERROR_INT_SECURITY_STATUS, status);
creds->type());
err = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg);
goto done;
}
// Create channel args.
grpc_arg args_to_add[2];
args_to_add[0] = grpc_server_credentials_to_arg(creds);
args_to_add[1] = grpc_security_connector_to_arg(&sc->base);
args_to_add[1] = grpc_security_connector_to_arg(sc.get());
args =
grpc_channel_args_copy_and_add(grpc_server_get_channel_args(server),
args_to_add, GPR_ARRAY_SIZE(args_to_add));
// Add server port.
err = grpc_chttp2_server_add_port(server, addr, args, &port_num);
done:
if (sc != nullptr) {
GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "server");
}
sc.reset(DEBUG_LOCATION, "server");
if (err != GRPC_ERROR_NONE) {
const char* msg = grpc_error_string(err);

@ -50,7 +50,7 @@ class RefCountedPtr {
}
template <typename Y>
RefCountedPtr(RefCountedPtr<Y>&& other) {
value_ = other.value_;
value_ = static_cast<T*>(other.value_);
other.value_ = nullptr;
}
@ -77,7 +77,7 @@ class RefCountedPtr {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
if (other.value_ != nullptr) other.value_->IncrementRefCount();
value_ = other.value_;
value_ = static_cast<T*>(other.value_);
}
// Copy assignment.
@ -118,7 +118,7 @@ class RefCountedPtr {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
if (value_ != nullptr) value_->Unref();
value_ = value;
value_ = static_cast<T*>(value);
}
template <typename Y>
void reset(const DebugLocation& location, const char* reason,
@ -126,7 +126,7 @@ class RefCountedPtr {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
if (value_ != nullptr) value_->Unref(location, reason);
value_ = value;
value_ = static_cast<T*>(value);
}
// TODO(roth): This method exists solely as a transition mechanism to allow

@ -29,119 +29,125 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/handshaker_registry.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/security_connector/ssl_utils.h"
#include "src/core/lib/security/transport/security_handshaker.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/tsi/ssl_transport_security.h"
typedef struct {
grpc_channel_security_connector base;
tsi_ssl_client_handshaker_factory* handshaker_factory;
char* secure_peer_name;
} grpc_httpcli_ssl_channel_security_connector;
static void httpcli_ssl_destroy(grpc_security_connector* sc) {
grpc_httpcli_ssl_channel_security_connector* c =
reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc);
if (c->handshaker_factory != nullptr) {
tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory);
c->handshaker_factory = nullptr;
class grpc_httpcli_ssl_channel_security_connector final
: public grpc_channel_security_connector {
public:
explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name)
: grpc_channel_security_connector(
/*url_scheme=*/nullptr,
/*channel_creds=*/nullptr,
/*request_metadata_creds=*/nullptr),
secure_peer_name_(secure_peer_name) {}
~grpc_httpcli_ssl_channel_security_connector() override {
if (handshaker_factory_ != nullptr) {
tsi_ssl_client_handshaker_factory_unref(handshaker_factory_);
}
if (secure_peer_name_ != nullptr) {
gpr_free(secure_peer_name_);
}
}
tsi_result InitHandshakerFactory(const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store) {
tsi_ssl_client_handshaker_options options;
memset(&options, 0, sizeof(options));
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
return tsi_create_ssl_client_handshaker_factory_with_options(
&options, &handshaker_factory_);
}
if (c->secure_peer_name != nullptr) gpr_free(c->secure_peer_name);
gpr_free(sc);
}
static void httpcli_ssl_add_handshakers(grpc_channel_security_connector* sc,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
grpc_httpcli_ssl_channel_security_connector* c =
reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc);
tsi_handshaker* handshaker = nullptr;
if (c->handshaker_factory != nullptr) {
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
c->handshaker_factory, c->secure_peer_name, &handshaker);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override {
tsi_handshaker* handshaker = nullptr;
if (handshaker_factory_ != nullptr) {
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
handshaker_factory_, secure_peer_name_, &handshaker);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
}
}
grpc_handshake_manager_add(
handshake_mgr, grpc_security_handshaker_create(handshaker, this));
}
grpc_handshake_manager_add(
handshake_mgr, grpc_security_handshaker_create(handshaker, &sc->base));
}
static void httpcli_ssl_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
grpc_httpcli_ssl_channel_security_connector* c =
reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc);
grpc_error* error = GRPC_ERROR_NONE;
/* Check the peer name. */
if (c->secure_peer_name != nullptr &&
!tsi_ssl_peer_matches_name(&peer, c->secure_peer_name)) {
char* msg;
gpr_asprintf(&msg, "Peer name %s is not in peer certificate",
c->secure_peer_name);
error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg);
tsi_ssl_client_handshaker_factory* handshaker_factory() const {
return handshaker_factory_;
}
GRPC_CLOSURE_SCHED(on_peer_checked, error);
tsi_peer_destruct(&peer);
}
static int httpcli_ssl_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_httpcli_ssl_channel_security_connector* c1 =
reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc1);
grpc_httpcli_ssl_channel_security_connector* c2 =
reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc2);
return strcmp(c1->secure_peer_name, c2->secure_peer_name);
}
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* /*auth_context*/,
grpc_closure* on_peer_checked) override {
grpc_error* error = GRPC_ERROR_NONE;
/* Check the peer name. */
if (secure_peer_name_ != nullptr &&
!tsi_ssl_peer_matches_name(&peer, secure_peer_name_)) {
char* msg;
gpr_asprintf(&msg, "Peer name %s is not in peer certificate",
secure_peer_name_);
error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg);
}
GRPC_CLOSURE_SCHED(on_peer_checked, error);
tsi_peer_destruct(&peer);
}
static grpc_security_connector_vtable httpcli_ssl_vtable = {
httpcli_ssl_destroy, httpcli_ssl_check_peer, httpcli_ssl_cmp};
int cmp(const grpc_security_connector* other_sc) const override {
auto* other =
reinterpret_cast<const grpc_httpcli_ssl_channel_security_connector*>(
other_sc);
return strcmp(secure_peer_name_, other->secure_peer_name_);
}
static grpc_security_status httpcli_ssl_channel_security_connector_create(
const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store,
const char* secure_peer_name, grpc_channel_security_connector** sc) {
tsi_result result = TSI_OK;
grpc_httpcli_ssl_channel_security_connector* c;
bool check_call_host(const char* host, grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) override {
*error = GRPC_ERROR_NONE;
return true;
}
if (secure_peer_name != nullptr && pem_root_certs == nullptr) {
gpr_log(GPR_ERROR,
"Cannot assert a secure peer name without a trust root.");
return GRPC_SECURITY_ERROR;
void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) override {
GRPC_ERROR_UNREF(error);
}
c = static_cast<grpc_httpcli_ssl_channel_security_connector*>(
gpr_zalloc(sizeof(grpc_httpcli_ssl_channel_security_connector)));
const char* secure_peer_name() const { return secure_peer_name_; }
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.vtable = &httpcli_ssl_vtable;
if (secure_peer_name != nullptr) {
c->secure_peer_name = gpr_strdup(secure_peer_name);
private:
tsi_ssl_client_handshaker_factory* handshaker_factory_ = nullptr;
char* secure_peer_name_;
};
static grpc_core::RefCountedPtr<grpc_channel_security_connector>
httpcli_ssl_channel_security_connector_create(
const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store,
const char* secure_peer_name) {
if (secure_peer_name != nullptr && pem_root_certs == nullptr) {
gpr_log(GPR_ERROR,
"Cannot assert a secure peer name without a trust root.");
return nullptr;
}
tsi_ssl_client_handshaker_options options;
memset(&options, 0, sizeof(options));
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
result = tsi_create_ssl_client_handshaker_factory_with_options(
&options, &c->handshaker_factory);
grpc_core::RefCountedPtr<grpc_httpcli_ssl_channel_security_connector> c =
grpc_core::MakeRefCounted<grpc_httpcli_ssl_channel_security_connector>(
secure_peer_name == nullptr ? nullptr : gpr_strdup(secure_peer_name));
tsi_result result = c->InitHandshakerFactory(pem_root_certs, root_store);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
httpcli_ssl_destroy(&c->base.base);
*sc = nullptr;
return GRPC_SECURITY_ERROR;
return nullptr;
}
// We don't actually need a channel credentials object in this case,
// but we set it to a non-nullptr address so that we don't trigger
// assertions in grpc_channel_security_connector_cmp().
c->base.channel_creds = (grpc_channel_credentials*)1;
c->base.add_handshakers = httpcli_ssl_add_handshakers;
*sc = &c->base;
return GRPC_SECURITY_OK;
return c;
}
/* handshaker */
@ -186,10 +192,11 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
}
c->func = on_done;
c->arg = arg;
grpc_channel_security_connector* sc = nullptr;
GPR_ASSERT(httpcli_ssl_channel_security_connector_create(
pem_root_certs, root_store, host, &sc) == GRPC_SECURITY_OK);
grpc_arg channel_arg = grpc_security_connector_to_arg(&sc->base);
grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
httpcli_ssl_channel_security_connector_create(pem_root_certs, root_store,
host);
GPR_ASSERT(sc != nullptr);
grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get());
grpc_channel_args args = {1, &channel_arg};
c->handshake_mgr = grpc_handshake_manager_create();
grpc_handshakers_add(HANDSHAKER_CLIENT, &args,
@ -197,7 +204,7 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
grpc_handshake_manager_do_handshake(
c->handshake_mgr, tcp, nullptr /* channel_args */, deadline,
nullptr /* acceptor */, on_handshake_done, c /* user_data */);
GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "httpcli");
sc.reset(DEBUG_LOCATION, "httpcli");
}
const grpc_httpcli_handshaker grpc_httpcli_ssl = {"https", ssl_handshake};

@ -70,13 +70,13 @@ typedef struct grpc_http_request {
/* A response */
typedef struct grpc_http_response {
/* HTTP status code */
int status;
int status = 0;
/* Headers: count and key/values */
size_t hdr_count;
grpc_http_header* hdrs;
size_t hdr_count = 0;
grpc_http_header* hdrs = nullptr;
/* Body: length and contents; contents are NOT null-terminated */
size_t body_length;
char* body;
size_t body_length = 0;
char* body = nullptr;
} grpc_http_response;
typedef struct {

@ -23,6 +23,8 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gpr/arena.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/surface/api_trace.h"
#include "src/core/lib/surface/call.h"
@ -50,13 +52,11 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call,
ctx = static_cast<grpc_client_security_context*>(
grpc_call_context_get(call, GRPC_CONTEXT_SECURITY));
if (ctx == nullptr) {
ctx = grpc_client_security_context_create(grpc_call_get_arena(call));
ctx->creds = grpc_call_credentials_ref(creds);
ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds);
grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx,
grpc_client_security_context_destroy);
} else {
grpc_call_credentials_unref(ctx->creds);
ctx->creds = grpc_call_credentials_ref(creds);
ctx->creds = creds != nullptr ? creds->Ref() : nullptr;
}
return GRPC_CALL_OK;
@ -66,33 +66,45 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) {
void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY);
GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call));
if (sec_ctx == nullptr) return nullptr;
return grpc_call_is_client(call)
? GRPC_AUTH_CONTEXT_REF(
((grpc_client_security_context*)sec_ctx)->auth_context,
"grpc_call_auth_context client")
: GRPC_AUTH_CONTEXT_REF(
((grpc_server_security_context*)sec_ctx)->auth_context,
"grpc_call_auth_context server");
if (grpc_call_is_client(call)) {
auto* sc = static_cast<grpc_client_security_context*>(sec_ctx);
if (sc->auth_context == nullptr) {
return nullptr;
} else {
return sc->auth_context
->Ref(DEBUG_LOCATION, "grpc_call_auth_context client")
.release();
}
} else {
auto* sc = static_cast<grpc_server_security_context*>(sec_ctx);
if (sc->auth_context == nullptr) {
return nullptr;
} else {
return sc->auth_context
->Ref(DEBUG_LOCATION, "grpc_call_auth_context server")
.release();
}
}
}
void grpc_auth_context_release(grpc_auth_context* context) {
GRPC_API_TRACE("grpc_auth_context_release(context=%p)", 1, (context));
GRPC_AUTH_CONTEXT_UNREF(context, "grpc_auth_context_unref");
if (context == nullptr) return;
context->Unref(DEBUG_LOCATION, "grpc_auth_context_unref");
}
/* --- grpc_client_security_context --- */
grpc_client_security_context::~grpc_client_security_context() {
grpc_call_credentials_unref(creds);
GRPC_AUTH_CONTEXT_UNREF(auth_context, "client_security_context");
auth_context.reset(DEBUG_LOCATION, "client_security_context");
if (extension.instance != nullptr && extension.destroy != nullptr) {
extension.destroy(extension.instance);
}
}
grpc_client_security_context* grpc_client_security_context_create(
gpr_arena* arena) {
gpr_arena* arena, grpc_call_credentials* creds) {
return new (gpr_arena_alloc(arena, sizeof(grpc_client_security_context)))
grpc_client_security_context();
grpc_client_security_context(creds != nullptr ? creds->Ref() : nullptr);
}
void grpc_client_security_context_destroy(void* ctx) {
@ -104,7 +116,7 @@ void grpc_client_security_context_destroy(void* ctx) {
/* --- grpc_server_security_context --- */
grpc_server_security_context::~grpc_server_security_context() {
GRPC_AUTH_CONTEXT_UNREF(auth_context, "server_security_context");
auth_context.reset(DEBUG_LOCATION, "server_security_context");
if (extension.instance != nullptr && extension.destroy != nullptr) {
extension.destroy(extension.instance);
}
@ -126,69 +138,11 @@ void grpc_server_security_context_destroy(void* ctx) {
static grpc_auth_property_iterator empty_iterator = {nullptr, 0, nullptr};
grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained) {
grpc_auth_context* ctx =
static_cast<grpc_auth_context*>(gpr_zalloc(sizeof(grpc_auth_context)));
gpr_ref_init(&ctx->refcount, 1);
if (chained != nullptr) {
ctx->chained = GRPC_AUTH_CONTEXT_REF(chained, "chained");
ctx->peer_identity_property_name =
ctx->chained->peer_identity_property_name;
}
return ctx;
}
#ifndef NDEBUG
grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx,
const char* file, int line,
const char* reason) {
if (ctx == nullptr) return nullptr;
if (grpc_trace_auth_context_refcount.enabled()) {
gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count);
gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
"AUTH_CONTEXT:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val,
val + 1, reason);
}
#else
grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx) {
if (ctx == nullptr) return nullptr;
#endif
gpr_ref(&ctx->refcount);
return ctx;
}
#ifndef NDEBUG
void grpc_auth_context_unref(grpc_auth_context* ctx, const char* file, int line,
const char* reason) {
if (ctx == nullptr) return;
if (grpc_trace_auth_context_refcount.enabled()) {
gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count);
gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
"AUTH_CONTEXT:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val,
val - 1, reason);
}
#else
void grpc_auth_context_unref(grpc_auth_context* ctx) {
if (ctx == nullptr) return;
#endif
if (gpr_unref(&ctx->refcount)) {
size_t i;
GRPC_AUTH_CONTEXT_UNREF(ctx->chained, "chained");
if (ctx->properties.array != nullptr) {
for (i = 0; i < ctx->properties.count; i++) {
grpc_auth_property_reset(&ctx->properties.array[i]);
}
gpr_free(ctx->properties.array);
}
gpr_free(ctx);
}
}
const char* grpc_auth_context_peer_identity_property_name(
const grpc_auth_context* ctx) {
GRPC_API_TRACE("grpc_auth_context_peer_identity_property_name(ctx=%p)", 1,
(ctx));
return ctx->peer_identity_property_name;
return ctx->peer_identity_property_name();
}
int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx,
@ -204,13 +158,13 @@ int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx,
name != nullptr ? name : "NULL");
return 0;
}
ctx->peer_identity_property_name = prop->name;
ctx->set_peer_identity_property_name(prop->name);
return 1;
}
int grpc_auth_context_peer_is_authenticated(const grpc_auth_context* ctx) {
GRPC_API_TRACE("grpc_auth_context_peer_is_authenticated(ctx=%p)", 1, (ctx));
return ctx->peer_identity_property_name == nullptr ? 0 : 1;
return ctx->is_authenticated();
}
grpc_auth_property_iterator grpc_auth_context_property_iterator(
@ -226,16 +180,17 @@ const grpc_auth_property* grpc_auth_property_iterator_next(
grpc_auth_property_iterator* it) {
GRPC_API_TRACE("grpc_auth_property_iterator_next(it=%p)", 1, (it));
if (it == nullptr || it->ctx == nullptr) return nullptr;
while (it->index == it->ctx->properties.count) {
if (it->ctx->chained == nullptr) return nullptr;
it->ctx = it->ctx->chained;
while (it->index == it->ctx->properties().count) {
if (it->ctx->chained() == nullptr) return nullptr;
it->ctx = it->ctx->chained();
it->index = 0;
}
if (it->name == nullptr) {
return &it->ctx->properties.array[it->index++];
return &it->ctx->properties().array[it->index++];
} else {
while (it->index < it->ctx->properties.count) {
const grpc_auth_property* prop = &it->ctx->properties.array[it->index++];
while (it->index < it->ctx->properties().count) {
const grpc_auth_property* prop =
&it->ctx->properties().array[it->index++];
GPR_ASSERT(prop->name != nullptr);
if (strcmp(it->name, prop->name) == 0) {
return prop;
@ -262,49 +217,56 @@ grpc_auth_property_iterator grpc_auth_context_peer_identity(
GRPC_API_TRACE("grpc_auth_context_peer_identity(ctx=%p)", 1, (ctx));
if (ctx == nullptr) return empty_iterator;
return grpc_auth_context_find_properties_by_name(
ctx, ctx->peer_identity_property_name);
ctx, ctx->peer_identity_property_name());
}
static void ensure_auth_context_capacity(grpc_auth_context* ctx) {
if (ctx->properties.count == ctx->properties.capacity) {
ctx->properties.capacity =
GPR_MAX(ctx->properties.capacity + 8, ctx->properties.capacity * 2);
ctx->properties.array = static_cast<grpc_auth_property*>(
gpr_realloc(ctx->properties.array,
ctx->properties.capacity * sizeof(grpc_auth_property)));
void grpc_auth_context::ensure_capacity() {
if (properties_.count == properties_.capacity) {
properties_.capacity =
GPR_MAX(properties_.capacity + 8, properties_.capacity * 2);
properties_.array = static_cast<grpc_auth_property*>(gpr_realloc(
properties_.array, properties_.capacity * sizeof(grpc_auth_property)));
}
}
void grpc_auth_context::add_property(const char* name, const char* value,
size_t value_length) {
ensure_capacity();
grpc_auth_property* prop = &properties_.array[properties_.count++];
prop->name = gpr_strdup(name);
prop->value = static_cast<char*>(gpr_malloc(value_length + 1));
memcpy(prop->value, value, value_length);
prop->value[value_length] = '\0';
prop->value_length = value_length;
}
void grpc_auth_context_add_property(grpc_auth_context* ctx, const char* name,
const char* value, size_t value_length) {
grpc_auth_property* prop;
GRPC_API_TRACE(
"grpc_auth_context_add_property(ctx=%p, name=%s, value=%*.*s, "
"value_length=%lu)",
6,
(ctx, name, (int)value_length, (int)value_length, value,
(unsigned long)value_length));
ensure_auth_context_capacity(ctx);
prop = &ctx->properties.array[ctx->properties.count++];
ctx->add_property(name, value, value_length);
}
void grpc_auth_context::add_cstring_property(const char* name,
const char* value) {
ensure_capacity();
grpc_auth_property* prop = &properties_.array[properties_.count++];
prop->name = gpr_strdup(name);
prop->value = static_cast<char*>(gpr_malloc(value_length + 1));
memcpy(prop->value, value, value_length);
prop->value[value_length] = '\0';
prop->value_length = value_length;
prop->value = gpr_strdup(value);
prop->value_length = strlen(value);
}
void grpc_auth_context_add_cstring_property(grpc_auth_context* ctx,
const char* name,
const char* value) {
grpc_auth_property* prop;
GRPC_API_TRACE(
"grpc_auth_context_add_cstring_property(ctx=%p, name=%s, value=%s)", 3,
(ctx, name, value));
ensure_auth_context_capacity(ctx);
prop = &ctx->properties.array[ctx->properties.count++];
prop->name = gpr_strdup(name);
prop->value = gpr_strdup(value);
prop->value_length = strlen(value);
ctx->add_cstring_property(name, value);
}
void grpc_auth_property_reset(grpc_auth_property* property) {
@ -314,12 +276,17 @@ void grpc_auth_property_reset(grpc_auth_property* property) {
}
static void auth_context_pointer_arg_destroy(void* p) {
GRPC_AUTH_CONTEXT_UNREF((grpc_auth_context*)p, "auth_context_pointer_arg");
if (p != nullptr) {
static_cast<grpc_auth_context*>(p)->Unref(DEBUG_LOCATION,
"auth_context_pointer_arg");
}
}
static void* auth_context_pointer_arg_copy(void* p) {
return GRPC_AUTH_CONTEXT_REF((grpc_auth_context*)p,
"auth_context_pointer_arg");
auto* ctx = static_cast<grpc_auth_context*>(p);
return ctx == nullptr
? nullptr
: ctx->Ref(DEBUG_LOCATION, "auth_context_pointer_arg").release();
}
static int auth_context_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); }

@ -21,6 +21,8 @@
#include <grpc/support/port_platform.h>
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/security/credentials/credentials.h"
@ -40,39 +42,59 @@ struct grpc_auth_property_array {
size_t capacity = 0;
};
struct grpc_auth_context {
grpc_auth_context() { gpr_ref_init(&refcount, 0); }
void grpc_auth_property_reset(grpc_auth_property* property);
struct grpc_auth_context* chained = nullptr;
grpc_auth_property_array properties;
gpr_refcount refcount;
const char* peer_identity_property_name = nullptr;
grpc_pollset* pollset = nullptr;
// This type is forward declared as a C struct and we cannot define it as a
// class. Otherwise, compiler will complain about type mismatch due to
// -Wmismatched-tags.
struct grpc_auth_context
: public grpc_core::RefCounted<grpc_auth_context,
grpc_core::NonPolymorphicRefCount> {
public:
explicit grpc_auth_context(
grpc_core::RefCountedPtr<grpc_auth_context> chained)
: grpc_core::RefCounted<grpc_auth_context,
grpc_core::NonPolymorphicRefCount>(
&grpc_trace_auth_context_refcount),
chained_(std::move(chained)) {
if (chained_ != nullptr) {
peer_identity_property_name_ = chained_->peer_identity_property_name_;
}
}
~grpc_auth_context() {
chained_.reset(DEBUG_LOCATION, "chained");
if (properties_.array != nullptr) {
for (size_t i = 0; i < properties_.count; i++) {
grpc_auth_property_reset(&properties_.array[i]);
}
gpr_free(properties_.array);
}
}
const grpc_auth_context* chained() const { return chained_.get(); }
const grpc_auth_property_array& properties() const { return properties_; }
bool is_authenticated() const {
return peer_identity_property_name_ != nullptr;
}
const char* peer_identity_property_name() const {
return peer_identity_property_name_;
}
void set_peer_identity_property_name(const char* name) {
peer_identity_property_name_ = name;
}
void ensure_capacity();
void add_property(const char* name, const char* value, size_t value_length);
void add_cstring_property(const char* name, const char* value);
private:
grpc_core::RefCountedPtr<grpc_auth_context> chained_;
grpc_auth_property_array properties_;
const char* peer_identity_property_name_ = nullptr;
};
/* Creation. */
grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained);
/* Refcounting. */
#ifndef NDEBUG
#define GRPC_AUTH_CONTEXT_REF(p, r) \
grpc_auth_context_ref((p), __FILE__, __LINE__, (r))
#define GRPC_AUTH_CONTEXT_UNREF(p, r) \
grpc_auth_context_unref((p), __FILE__, __LINE__, (r))
grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy,
const char* file, int line,
const char* reason);
void grpc_auth_context_unref(grpc_auth_context* policy, const char* file,
int line, const char* reason);
#else
#define GRPC_AUTH_CONTEXT_REF(p, r) grpc_auth_context_ref((p))
#define GRPC_AUTH_CONTEXT_UNREF(p, r) grpc_auth_context_unref((p))
grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy);
void grpc_auth_context_unref(grpc_auth_context* policy);
#endif
void grpc_auth_property_reset(grpc_auth_property* property);
/* --- grpc_security_context_extension ---
Extension to the security context that may be set in a filter and accessed
@ -88,16 +110,18 @@ struct grpc_security_context_extension {
Internal client-side security context. */
struct grpc_client_security_context {
grpc_client_security_context() = default;
explicit grpc_client_security_context(
grpc_core::RefCountedPtr<grpc_call_credentials> creds)
: creds(std::move(creds)) {}
~grpc_client_security_context();
grpc_call_credentials* creds = nullptr;
grpc_auth_context* auth_context = nullptr;
grpc_core::RefCountedPtr<grpc_call_credentials> creds;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_security_context_extension extension;
};
grpc_client_security_context* grpc_client_security_context_create(
gpr_arena* arena);
gpr_arena* arena, grpc_call_credentials* creds);
void grpc_client_security_context_destroy(void* ctx);
/* --- grpc_server_security_context ---
@ -108,7 +132,7 @@ struct grpc_server_security_context {
grpc_server_security_context() = default;
~grpc_server_security_context();
grpc_auth_context* auth_context = nullptr;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_security_context_extension extension;
};

@ -33,40 +33,47 @@
#define GRPC_CREDENTIALS_TYPE_ALTS "Alts"
#define GRPC_ALTS_HANDSHAKER_SERVICE_URL "metadata.google.internal:8080"
static void alts_credentials_destruct(grpc_channel_credentials* creds) {
grpc_alts_credentials* alts_creds =
reinterpret_cast<grpc_alts_credentials*>(creds);
grpc_alts_credentials_options_destroy(alts_creds->options);
gpr_free(alts_creds->handshaker_service_url);
}
static void alts_server_credentials_destruct(grpc_server_credentials* creds) {
grpc_alts_server_credentials* alts_creds =
reinterpret_cast<grpc_alts_server_credentials*>(creds);
grpc_alts_credentials_options_destroy(alts_creds->options);
gpr_free(alts_creds->handshaker_service_url);
grpc_alts_credentials::grpc_alts_credentials(
const grpc_alts_credentials_options* options,
const char* handshaker_service_url)
: grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_ALTS),
options_(grpc_alts_credentials_options_copy(options)),
handshaker_service_url_(handshaker_service_url == nullptr
? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
: gpr_strdup(handshaker_service_url)) {}
grpc_alts_credentials::~grpc_alts_credentials() {
grpc_alts_credentials_options_destroy(options_);
gpr_free(handshaker_service_url_);
}
static grpc_security_status alts_create_security_connector(
grpc_channel_credentials* creds,
grpc_call_credentials* request_metadata_creds, const char* target_name,
const grpc_channel_args* args, grpc_channel_security_connector** sc,
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_alts_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target_name, const grpc_channel_args* args,
grpc_channel_args** new_args) {
return grpc_alts_channel_security_connector_create(
creds, request_metadata_creds, target_name, sc);
this->Ref(), std::move(call_creds), target_name);
}
static grpc_security_status alts_server_create_security_connector(
grpc_server_credentials* creds, grpc_server_security_connector** sc) {
return grpc_alts_server_security_connector_create(creds, sc);
grpc_alts_server_credentials::grpc_alts_server_credentials(
const grpc_alts_credentials_options* options,
const char* handshaker_service_url)
: grpc_server_credentials(GRPC_CREDENTIALS_TYPE_ALTS),
options_(grpc_alts_credentials_options_copy(options)),
handshaker_service_url_(handshaker_service_url == nullptr
? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
: gpr_strdup(handshaker_service_url)) {}
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_alts_server_credentials::create_security_connector() {
return grpc_alts_server_security_connector_create(this->Ref());
}
static const grpc_channel_credentials_vtable alts_credentials_vtable = {
alts_credentials_destruct, alts_create_security_connector,
/*duplicate_without_call_credentials=*/nullptr};
static const grpc_server_credentials_vtable alts_server_credentials_vtable = {
alts_server_credentials_destruct, alts_server_create_security_connector};
grpc_alts_server_credentials::~grpc_alts_server_credentials() {
grpc_alts_credentials_options_destroy(options_);
gpr_free(handshaker_service_url_);
}
grpc_channel_credentials* grpc_alts_credentials_create_customized(
const grpc_alts_credentials_options* options,
@ -74,17 +81,7 @@ grpc_channel_credentials* grpc_alts_credentials_create_customized(
if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) {
return nullptr;
}
auto creds = static_cast<grpc_alts_credentials*>(
gpr_zalloc(sizeof(grpc_alts_credentials)));
creds->options = grpc_alts_credentials_options_copy(options);
creds->handshaker_service_url =
handshaker_service_url == nullptr
? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
: gpr_strdup(handshaker_service_url);
creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS;
creds->base.vtable = &alts_credentials_vtable;
gpr_ref_init(&creds->base.refcount, 1);
return &creds->base;
return grpc_core::New<grpc_alts_credentials>(options, handshaker_service_url);
}
grpc_server_credentials* grpc_alts_server_credentials_create_customized(
@ -93,17 +90,8 @@ grpc_server_credentials* grpc_alts_server_credentials_create_customized(
if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) {
return nullptr;
}
auto creds = static_cast<grpc_alts_server_credentials*>(
gpr_zalloc(sizeof(grpc_alts_server_credentials)));
creds->options = grpc_alts_credentials_options_copy(options);
creds->handshaker_service_url =
handshaker_service_url == nullptr
? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
: gpr_strdup(handshaker_service_url);
creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS;
creds->base.vtable = &alts_server_credentials_vtable;
gpr_ref_init(&creds->base.refcount, 1);
return &creds->base;
return grpc_core::New<grpc_alts_server_credentials>(options,
handshaker_service_url);
}
grpc_channel_credentials* grpc_alts_credentials_create(

@ -27,18 +27,45 @@
#include "src/core/lib/security/credentials/credentials.h"
/* Main struct for grpc ALTS channel credential. */
typedef struct grpc_alts_credentials {
grpc_channel_credentials base;
grpc_alts_credentials_options* options;
char* handshaker_service_url;
} grpc_alts_credentials;
class grpc_alts_credentials final : public grpc_channel_credentials {
public:
grpc_alts_credentials(const grpc_alts_credentials_options* options,
const char* handshaker_service_url);
~grpc_alts_credentials() override;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target_name, const grpc_channel_args* args,
grpc_channel_args** new_args) override;
const grpc_alts_credentials_options* options() const { return options_; }
grpc_alts_credentials_options* mutable_options() { return options_; }
const char* handshaker_service_url() const { return handshaker_service_url_; }
private:
grpc_alts_credentials_options* options_;
char* handshaker_service_url_;
};
/* Main struct for grpc ALTS server credential. */
typedef struct grpc_alts_server_credentials {
grpc_server_credentials base;
grpc_alts_credentials_options* options;
char* handshaker_service_url;
} grpc_alts_server_credentials;
class grpc_alts_server_credentials final : public grpc_server_credentials {
public:
grpc_alts_server_credentials(const grpc_alts_credentials_options* options,
const char* handshaker_service_url);
~grpc_alts_server_credentials() override;
grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector() override;
const grpc_alts_credentials_options* options() const { return options_; }
grpc_alts_credentials_options* mutable_options() { return options_; }
const char* handshaker_service_url() const { return handshaker_service_url_; }
private:
grpc_alts_credentials_options* options_;
char* handshaker_service_url_;
};
/**
* This method creates an ALTS channel credential object with customized

@ -20,8 +20,10 @@
#include "src/core/lib/security/credentials/composite/composite_credentials.h"
#include <string.h>
#include <cstring>
#include <new>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/surface/api_trace.h"
@ -31,36 +33,83 @@
/* -- Composite call credentials. -- */
typedef struct {
static void composite_call_metadata_cb(void* arg, grpc_error* error);
grpc_call_credentials_array::~grpc_call_credentials_array() {
for (size_t i = 0; i < num_creds_; ++i) {
creds_array_[i].~RefCountedPtr<grpc_call_credentials>();
}
if (creds_array_ != nullptr) {
gpr_free(creds_array_);
}
}
grpc_call_credentials_array::grpc_call_credentials_array(
const grpc_call_credentials_array& that)
: num_creds_(that.num_creds_) {
reserve(that.capacity_);
for (size_t i = 0; i < num_creds_; ++i) {
new (&creds_array_[i])
grpc_core::RefCountedPtr<grpc_call_credentials>(that.creds_array_[i]);
}
}
void grpc_call_credentials_array::reserve(size_t capacity) {
if (capacity_ >= capacity) {
return;
}
grpc_core::RefCountedPtr<grpc_call_credentials>* new_arr =
static_cast<grpc_core::RefCountedPtr<grpc_call_credentials>*>(gpr_malloc(
sizeof(grpc_core::RefCountedPtr<grpc_call_credentials>) * capacity));
if (creds_array_ != nullptr) {
for (size_t i = 0; i < num_creds_; ++i) {
new (&new_arr[i]) grpc_core::RefCountedPtr<grpc_call_credentials>(
std::move(creds_array_[i]));
creds_array_[i].~RefCountedPtr<grpc_call_credentials>();
}
gpr_free(creds_array_);
}
creds_array_ = new_arr;
capacity_ = capacity;
}
namespace {
struct grpc_composite_call_credentials_metadata_context {
grpc_composite_call_credentials_metadata_context(
grpc_composite_call_credentials* composite_creds,
grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata)
: composite_creds(composite_creds),
pollent(pollent),
auth_md_context(auth_md_context),
md_array(md_array),
on_request_metadata(on_request_metadata) {
GRPC_CLOSURE_INIT(&internal_on_request_metadata, composite_call_metadata_cb,
this, grpc_schedule_on_exec_ctx);
}
grpc_composite_call_credentials* composite_creds;
size_t creds_index;
size_t creds_index = 0;
grpc_polling_entity* pollent;
grpc_auth_metadata_context auth_md_context;
grpc_credentials_mdelem_array* md_array;
grpc_closure* on_request_metadata;
grpc_closure internal_on_request_metadata;
} grpc_composite_call_credentials_metadata_context;
static void composite_call_destruct(grpc_call_credentials* creds) {
grpc_composite_call_credentials* c =
reinterpret_cast<grpc_composite_call_credentials*>(creds);
for (size_t i = 0; i < c->inner.num_creds; i++) {
grpc_call_credentials_unref(c->inner.creds_array[i]);
}
gpr_free(c->inner.creds_array);
}
};
} // namespace
static void composite_call_metadata_cb(void* arg, grpc_error* error) {
grpc_composite_call_credentials_metadata_context* ctx =
static_cast<grpc_composite_call_credentials_metadata_context*>(arg);
if (error == GRPC_ERROR_NONE) {
const grpc_call_credentials_array& inner = ctx->composite_creds->inner();
/* See if we need to get some more metadata. */
if (ctx->creds_index < ctx->composite_creds->inner.num_creds) {
grpc_call_credentials* inner_creds =
ctx->composite_creds->inner.creds_array[ctx->creds_index++];
if (grpc_call_credentials_get_request_metadata(
inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array,
&ctx->internal_on_request_metadata, &error)) {
if (ctx->creds_index < inner.size()) {
if (inner.get(ctx->creds_index++)
->get_request_metadata(
ctx->pollent, ctx->auth_md_context, ctx->md_array,
&ctx->internal_on_request_metadata, &error)) {
// Synchronous response, so call ourselves recursively.
composite_call_metadata_cb(arg, error);
GRPC_ERROR_UNREF(error);
@ -73,76 +122,86 @@ static void composite_call_metadata_cb(void* arg, grpc_error* error) {
gpr_free(ctx);
}
static bool composite_call_get_request_metadata(
grpc_call_credentials* creds, grpc_polling_entity* pollent,
grpc_auth_metadata_context auth_md_context,
bool grpc_composite_call_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_composite_call_credentials* c =
reinterpret_cast<grpc_composite_call_credentials*>(creds);
grpc_composite_call_credentials_metadata_context* ctx;
ctx = static_cast<grpc_composite_call_credentials_metadata_context*>(
gpr_zalloc(sizeof(grpc_composite_call_credentials_metadata_context)));
ctx->composite_creds = c;
ctx->pollent = pollent;
ctx->auth_md_context = auth_md_context;
ctx->md_array = md_array;
ctx->on_request_metadata = on_request_metadata;
GRPC_CLOSURE_INIT(&ctx->internal_on_request_metadata,
composite_call_metadata_cb, ctx, grpc_schedule_on_exec_ctx);
ctx = grpc_core::New<grpc_composite_call_credentials_metadata_context>(
this, pollent, auth_md_context, md_array, on_request_metadata);
bool synchronous = true;
while (ctx->creds_index < ctx->composite_creds->inner.num_creds) {
grpc_call_credentials* inner_creds =
ctx->composite_creds->inner.creds_array[ctx->creds_index++];
if (grpc_call_credentials_get_request_metadata(
inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array,
&ctx->internal_on_request_metadata, error)) {
const grpc_call_credentials_array& inner = ctx->composite_creds->inner();
while (ctx->creds_index < inner.size()) {
if (inner.get(ctx->creds_index++)
->get_request_metadata(ctx->pollent, ctx->auth_md_context,
ctx->md_array,
&ctx->internal_on_request_metadata, error)) {
if (*error != GRPC_ERROR_NONE) break;
} else {
synchronous = false; // Async return.
break;
}
}
if (synchronous) gpr_free(ctx);
if (synchronous) grpc_core::Delete(ctx);
return synchronous;
}
static void composite_call_cancel_get_request_metadata(
grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
grpc_composite_call_credentials* c =
reinterpret_cast<grpc_composite_call_credentials*>(creds);
for (size_t i = 0; i < c->inner.num_creds; ++i) {
grpc_call_credentials_cancel_get_request_metadata(
c->inner.creds_array[i], md_array, GRPC_ERROR_REF(error));
void grpc_composite_call_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
for (size_t i = 0; i < inner_.size(); ++i) {
inner_.get(i)->cancel_get_request_metadata(md_array, GRPC_ERROR_REF(error));
}
GRPC_ERROR_UNREF(error);
}
static grpc_call_credentials_vtable composite_call_credentials_vtable = {
composite_call_destruct, composite_call_get_request_metadata,
composite_call_cancel_get_request_metadata};
static size_t get_creds_array_size(const grpc_call_credentials* creds,
bool is_composite) {
return is_composite
? static_cast<const grpc_composite_call_credentials*>(creds)
->inner()
.size()
: 1;
}
static grpc_call_credentials_array get_creds_array(
grpc_call_credentials** creds_addr) {
grpc_call_credentials_array result;
grpc_call_credentials* creds = *creds_addr;
result.creds_array = creds_addr;
result.num_creds = 1;
if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) {
result = *grpc_composite_call_credentials_get_credentials(creds);
void grpc_composite_call_credentials::push_to_inner(
grpc_core::RefCountedPtr<grpc_call_credentials> creds, bool is_composite) {
if (!is_composite) {
inner_.push_back(std::move(creds));
return;
}
return result;
auto composite_creds =
static_cast<grpc_composite_call_credentials*>(creds.get());
for (size_t i = 0; i < composite_creds->inner().size(); ++i) {
inner_.push_back(std::move(composite_creds->inner_.get_mutable(i)));
}
}
grpc_composite_call_credentials::grpc_composite_call_credentials(
grpc_core::RefCountedPtr<grpc_call_credentials> creds1,
grpc_core::RefCountedPtr<grpc_call_credentials> creds2)
: grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) {
const bool creds1_is_composite =
strcmp(creds1->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0;
const bool creds2_is_composite =
strcmp(creds2->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0;
const size_t size = get_creds_array_size(creds1.get(), creds1_is_composite) +
get_creds_array_size(creds2.get(), creds2_is_composite);
inner_.reserve(size);
push_to_inner(std::move(creds1), creds1_is_composite);
push_to_inner(std::move(creds2), creds2_is_composite);
}
static grpc_core::RefCountedPtr<grpc_call_credentials>
composite_call_credentials_create(
grpc_core::RefCountedPtr<grpc_call_credentials> creds1,
grpc_core::RefCountedPtr<grpc_call_credentials> creds2) {
return grpc_core::MakeRefCounted<grpc_composite_call_credentials>(
std::move(creds1), std::move(creds2));
}
grpc_call_credentials* grpc_composite_call_credentials_create(
grpc_call_credentials* creds1, grpc_call_credentials* creds2,
void* reserved) {
size_t i;
size_t creds_array_byte_size;
grpc_call_credentials_array creds1_array;
grpc_call_credentials_array creds2_array;
grpc_composite_call_credentials* c;
GRPC_API_TRACE(
"grpc_composite_call_credentials_create(creds1=%p, creds2=%p, "
"reserved=%p)",
@ -150,120 +209,40 @@ grpc_call_credentials* grpc_composite_call_credentials_create(
GPR_ASSERT(reserved == nullptr);
GPR_ASSERT(creds1 != nullptr);
GPR_ASSERT(creds2 != nullptr);
c = static_cast<grpc_composite_call_credentials*>(
gpr_zalloc(sizeof(grpc_composite_call_credentials)));
c->base.type = GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE;
c->base.vtable = &composite_call_credentials_vtable;
gpr_ref_init(&c->base.refcount, 1);
creds1_array = get_creds_array(&creds1);
creds2_array = get_creds_array(&creds2);
c->inner.num_creds = creds1_array.num_creds + creds2_array.num_creds;
creds_array_byte_size = c->inner.num_creds * sizeof(grpc_call_credentials*);
c->inner.creds_array =
static_cast<grpc_call_credentials**>(gpr_zalloc(creds_array_byte_size));
for (i = 0; i < creds1_array.num_creds; i++) {
grpc_call_credentials* cur_creds = creds1_array.creds_array[i];
c->inner.creds_array[i] = grpc_call_credentials_ref(cur_creds);
}
for (i = 0; i < creds2_array.num_creds; i++) {
grpc_call_credentials* cur_creds = creds2_array.creds_array[i];
c->inner.creds_array[i + creds1_array.num_creds] =
grpc_call_credentials_ref(cur_creds);
}
return &c->base;
}
const grpc_call_credentials_array*
grpc_composite_call_credentials_get_credentials(grpc_call_credentials* creds) {
const grpc_composite_call_credentials* c =
reinterpret_cast<const grpc_composite_call_credentials*>(creds);
GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0);
return &c->inner;
}
grpc_call_credentials* grpc_credentials_contains_type(
grpc_call_credentials* creds, const char* type,
grpc_call_credentials** composite_creds) {
size_t i;
if (strcmp(creds->type, type) == 0) {
if (composite_creds != nullptr) *composite_creds = nullptr;
return creds;
} else if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) {
const grpc_call_credentials_array* inner_creds_array =
grpc_composite_call_credentials_get_credentials(creds);
for (i = 0; i < inner_creds_array->num_creds; i++) {
if (strcmp(type, inner_creds_array->creds_array[i]->type) == 0) {
if (composite_creds != nullptr) *composite_creds = creds;
return inner_creds_array->creds_array[i];
}
}
}
return nullptr;
return composite_call_credentials_create(creds1->Ref(), creds2->Ref())
.release();
}
/* -- Composite channel credentials. -- */
static void composite_channel_destruct(grpc_channel_credentials* creds) {
grpc_composite_channel_credentials* c =
reinterpret_cast<grpc_composite_channel_credentials*>(creds);
grpc_channel_credentials_unref(c->inner_creds);
grpc_call_credentials_unref(c->call_creds);
}
static grpc_security_status composite_channel_create_security_connector(
grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_composite_channel_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
grpc_composite_channel_credentials* c =
reinterpret_cast<grpc_composite_channel_credentials*>(creds);
grpc_security_status status = GRPC_SECURITY_ERROR;
GPR_ASSERT(c->inner_creds != nullptr && c->call_creds != nullptr &&
c->inner_creds->vtable != nullptr &&
c->inner_creds->vtable->create_security_connector != nullptr);
grpc_channel_args** new_args) {
GPR_ASSERT(inner_creds_ != nullptr && call_creds_ != nullptr);
/* If we are passed a call_creds, create a call composite to pass it
downstream. */
if (call_creds != nullptr) {
grpc_call_credentials* composite_call_creds =
grpc_composite_call_credentials_create(c->call_creds, call_creds,
nullptr);
status = c->inner_creds->vtable->create_security_connector(
c->inner_creds, composite_call_creds, target, args, sc, new_args);
grpc_call_credentials_unref(composite_call_creds);
return inner_creds_->create_security_connector(
composite_call_credentials_create(call_creds_, std::move(call_creds)),
target, args, new_args);
} else {
status = c->inner_creds->vtable->create_security_connector(
c->inner_creds, c->call_creds, target, args, sc, new_args);
return inner_creds_->create_security_connector(call_creds_, target, args,
new_args);
}
return status;
}
static grpc_channel_credentials*
composite_channel_duplicate_without_call_credentials(
grpc_channel_credentials* creds) {
grpc_composite_channel_credentials* c =
reinterpret_cast<grpc_composite_channel_credentials*>(creds);
return grpc_channel_credentials_ref(c->inner_creds);
}
static grpc_channel_credentials_vtable composite_channel_credentials_vtable = {
composite_channel_destruct, composite_channel_create_security_connector,
composite_channel_duplicate_without_call_credentials};
grpc_channel_credentials* grpc_composite_channel_credentials_create(
grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds,
void* reserved) {
grpc_composite_channel_credentials* c =
static_cast<grpc_composite_channel_credentials*>(gpr_zalloc(sizeof(*c)));
GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr &&
reserved == nullptr);
GRPC_API_TRACE(
"grpc_composite_channel_credentials_create(channel_creds=%p, "
"call_creds=%p, reserved=%p)",
3, (channel_creds, call_creds, reserved));
c->base.type = channel_creds->type;
c->base.vtable = &composite_channel_credentials_vtable;
gpr_ref_init(&c->base.refcount, 1);
c->inner_creds = grpc_channel_credentials_ref(channel_creds);
c->call_creds = grpc_call_credentials_ref(call_creds);
return &c->base;
return grpc_core::New<grpc_composite_channel_credentials>(
channel_creds->Ref(), call_creds->Ref());
}

@ -21,39 +21,104 @@
#include <grpc/support/port_platform.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/credentials.h"
typedef struct {
grpc_call_credentials** creds_array;
size_t num_creds;
} grpc_call_credentials_array;
// TODO(soheil): Replace this with InlinedVector once #16032 is resolved.
class grpc_call_credentials_array {
public:
grpc_call_credentials_array() = default;
grpc_call_credentials_array(const grpc_call_credentials_array& that);
const grpc_call_credentials_array*
grpc_composite_call_credentials_get_credentials(
grpc_call_credentials* composite_creds);
~grpc_call_credentials_array();
/* Returns creds if creds is of the specified type or the inner creds of the
specified type (if found), if the creds is of type COMPOSITE.
If composite_creds is not NULL, *composite_creds will point to creds if of
type COMPOSITE in case of success. */
grpc_call_credentials* grpc_credentials_contains_type(
grpc_call_credentials* creds, const char* type,
grpc_call_credentials** composite_creds);
void reserve(size_t capacity);
// Must reserve before pushing any data.
void push_back(grpc_core::RefCountedPtr<grpc_call_credentials> cred) {
GPR_DEBUG_ASSERT(capacity_ > num_creds_);
new (&creds_array_[num_creds_++])
grpc_core::RefCountedPtr<grpc_call_credentials>(std::move(cred));
}
const grpc_core::RefCountedPtr<grpc_call_credentials>& get(size_t i) const {
GPR_DEBUG_ASSERT(i < num_creds_);
return creds_array_[i];
}
grpc_core::RefCountedPtr<grpc_call_credentials>& get_mutable(size_t i) {
GPR_DEBUG_ASSERT(i < num_creds_);
return creds_array_[i];
}
size_t size() const { return num_creds_; }
private:
grpc_core::RefCountedPtr<grpc_call_credentials>* creds_array_ = nullptr;
size_t num_creds_ = 0;
size_t capacity_ = 0;
};
/* -- Composite channel credentials. -- */
typedef struct {
grpc_channel_credentials base;
grpc_channel_credentials* inner_creds;
grpc_call_credentials* call_creds;
} grpc_composite_channel_credentials;
class grpc_composite_channel_credentials : public grpc_channel_credentials {
public:
grpc_composite_channel_credentials(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds)
: grpc_channel_credentials(channel_creds->type()),
inner_creds_(std::move(channel_creds)),
call_creds_(std::move(call_creds)) {}
~grpc_composite_channel_credentials() override = default;
grpc_core::RefCountedPtr<grpc_channel_credentials>
duplicate_without_call_credentials() override {
return inner_creds_;
}
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_args** new_args) override;
const grpc_channel_credentials* inner_creds() const {
return inner_creds_.get();
}
const grpc_call_credentials* call_creds() const { return call_creds_.get(); }
grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); }
private:
grpc_core::RefCountedPtr<grpc_channel_credentials> inner_creds_;
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds_;
};
/* -- Composite call credentials. -- */
typedef struct {
grpc_call_credentials base;
grpc_call_credentials_array inner;
} grpc_composite_call_credentials;
class grpc_composite_call_credentials : public grpc_call_credentials {
public:
grpc_composite_call_credentials(
grpc_core::RefCountedPtr<grpc_call_credentials> creds1,
grpc_core::RefCountedPtr<grpc_call_credentials> creds2);
~grpc_composite_call_credentials() override = default;
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
const grpc_call_credentials_array& inner() const { return inner_; }
private:
void push_to_inner(grpc_core::RefCountedPtr<grpc_call_credentials> creds,
bool is_composite);
grpc_call_credentials_array inner_;
};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_COMPOSITE_COMPOSITE_CREDENTIALS_H \
*/

@ -39,120 +39,24 @@
/* -- Common. -- */
grpc_credentials_metadata_request* grpc_credentials_metadata_request_create(
grpc_call_credentials* creds) {
grpc_credentials_metadata_request* r =
static_cast<grpc_credentials_metadata_request*>(
gpr_zalloc(sizeof(grpc_credentials_metadata_request)));
r->creds = grpc_call_credentials_ref(creds);
return r;
}
void grpc_credentials_metadata_request_destroy(
grpc_credentials_metadata_request* r) {
grpc_call_credentials_unref(r->creds);
grpc_http_response_destroy(&r->response);
gpr_free(r);
}
grpc_channel_credentials* grpc_channel_credentials_ref(
grpc_channel_credentials* creds) {
if (creds == nullptr) return nullptr;
gpr_ref(&creds->refcount);
return creds;
}
void grpc_channel_credentials_unref(grpc_channel_credentials* creds) {
if (creds == nullptr) return;
if (gpr_unref(&creds->refcount)) {
if (creds->vtable->destruct != nullptr) {
creds->vtable->destruct(creds);
}
gpr_free(creds);
}
}
void grpc_channel_credentials_release(grpc_channel_credentials* creds) {
GRPC_API_TRACE("grpc_channel_credentials_release(creds=%p)", 1, (creds));
grpc_core::ExecCtx exec_ctx;
grpc_channel_credentials_unref(creds);
}
grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds) {
if (creds == nullptr) return nullptr;
gpr_ref(&creds->refcount);
return creds;
}
void grpc_call_credentials_unref(grpc_call_credentials* creds) {
if (creds == nullptr) return;
if (gpr_unref(&creds->refcount)) {
if (creds->vtable->destruct != nullptr) {
creds->vtable->destruct(creds);
}
gpr_free(creds);
}
if (creds) creds->Unref();
}
void grpc_call_credentials_release(grpc_call_credentials* creds) {
GRPC_API_TRACE("grpc_call_credentials_release(creds=%p)", 1, (creds));
grpc_core::ExecCtx exec_ctx;
grpc_call_credentials_unref(creds);
}
bool grpc_call_credentials_get_request_metadata(
grpc_call_credentials* creds, grpc_polling_entity* pollent,
grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata, grpc_error** error) {
if (creds == nullptr || creds->vtable->get_request_metadata == nullptr) {
return true;
}
return creds->vtable->get_request_metadata(creds, pollent, context, md_array,
on_request_metadata, error);
}
void grpc_call_credentials_cancel_get_request_metadata(
grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
if (creds == nullptr ||
creds->vtable->cancel_get_request_metadata == nullptr) {
return;
}
creds->vtable->cancel_get_request_metadata(creds, md_array, error);
}
grpc_security_status grpc_channel_credentials_create_security_connector(
grpc_channel_credentials* channel_creds, const char* target,
const grpc_channel_args* args, grpc_channel_security_connector** sc,
grpc_channel_args** new_args) {
*new_args = nullptr;
if (channel_creds == nullptr) {
return GRPC_SECURITY_ERROR;
}
GPR_ASSERT(channel_creds->vtable->create_security_connector != nullptr);
return channel_creds->vtable->create_security_connector(
channel_creds, nullptr, target, args, sc, new_args);
}
grpc_channel_credentials*
grpc_channel_credentials_duplicate_without_call_credentials(
grpc_channel_credentials* channel_creds) {
if (channel_creds != nullptr && channel_creds->vtable != nullptr &&
channel_creds->vtable->duplicate_without_call_credentials != nullptr) {
return channel_creds->vtable->duplicate_without_call_credentials(
channel_creds);
} else {
return grpc_channel_credentials_ref(channel_creds);
}
if (creds) creds->Unref();
}
static void credentials_pointer_arg_destroy(void* p) {
grpc_channel_credentials_unref(static_cast<grpc_channel_credentials*>(p));
static_cast<grpc_channel_credentials*>(p)->Unref();
}
static void* credentials_pointer_arg_copy(void* p) {
return grpc_channel_credentials_ref(
static_cast<grpc_channel_credentials*>(p));
return static_cast<grpc_channel_credentials*>(p)->Ref().release();
}
static int credentials_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); }
@ -191,63 +95,35 @@ grpc_channel_credentials* grpc_channel_credentials_find_in_args(
return nullptr;
}
grpc_server_credentials* grpc_server_credentials_ref(
grpc_server_credentials* creds) {
if (creds == nullptr) return nullptr;
gpr_ref(&creds->refcount);
return creds;
}
void grpc_server_credentials_unref(grpc_server_credentials* creds) {
if (creds == nullptr) return;
if (gpr_unref(&creds->refcount)) {
if (creds->vtable->destruct != nullptr) {
creds->vtable->destruct(creds);
}
if (creds->processor.destroy != nullptr &&
creds->processor.state != nullptr) {
creds->processor.destroy(creds->processor.state);
}
gpr_free(creds);
}
}
void grpc_server_credentials_release(grpc_server_credentials* creds) {
GRPC_API_TRACE("grpc_server_credentials_release(creds=%p)", 1, (creds));
grpc_core::ExecCtx exec_ctx;
grpc_server_credentials_unref(creds);
if (creds) creds->Unref();
}
grpc_security_status grpc_server_credentials_create_security_connector(
grpc_server_credentials* creds, grpc_server_security_connector** sc) {
if (creds == nullptr || creds->vtable->create_security_connector == nullptr) {
gpr_log(GPR_ERROR, "Server credentials cannot create security context.");
return GRPC_SECURITY_ERROR;
}
return creds->vtable->create_security_connector(creds, sc);
}
void grpc_server_credentials_set_auth_metadata_processor(
grpc_server_credentials* creds, grpc_auth_metadata_processor processor) {
void grpc_server_credentials::set_auth_metadata_processor(
const grpc_auth_metadata_processor& processor) {
GRPC_API_TRACE(
"grpc_server_credentials_set_auth_metadata_processor("
"creds=%p, "
"processor=grpc_auth_metadata_processor { process: %p, state: %p })",
3, (creds, (void*)(intptr_t)processor.process, processor.state));
if (creds == nullptr) return;
if (creds->processor.destroy != nullptr &&
creds->processor.state != nullptr) {
creds->processor.destroy(creds->processor.state);
}
creds->processor = processor;
3, (this, (void*)(intptr_t)processor.process, processor.state));
DestroyProcessor();
processor_ = processor;
}
void grpc_server_credentials_set_auth_metadata_processor(
grpc_server_credentials* creds, grpc_auth_metadata_processor processor) {
GPR_DEBUG_ASSERT(creds != nullptr);
creds->set_auth_metadata_processor(processor);
}
static void server_credentials_pointer_arg_destroy(void* p) {
grpc_server_credentials_unref(static_cast<grpc_server_credentials*>(p));
static_cast<grpc_server_credentials*>(p)->Unref();
}
static void* server_credentials_pointer_arg_copy(void* p) {
return grpc_server_credentials_ref(static_cast<grpc_server_credentials*>(p));
return static_cast<grpc_server_credentials*>(p)->Ref().release();
}
static int server_credentials_pointer_cmp(void* a, void* b) {

@ -26,6 +26,7 @@
#include <grpc/support/sync.h>
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/http/httpcli.h"
#include "src/core/lib/http/parser.h"
#include "src/core/lib/iomgr/polling_entity.h"
@ -90,44 +91,46 @@ void grpc_override_well_known_credentials_path_getter(
#define GRPC_ARG_CHANNEL_CREDENTIALS "grpc.channel_credentials"
typedef struct {
void (*destruct)(grpc_channel_credentials* c);
grpc_security_status (*create_security_connector)(
grpc_channel_credentials* c, grpc_call_credentials* call_creds,
// This type is forward declared as a C struct and we cannot define it as a
// class. Otherwise, compiler will complain about type mismatch due to
// -Wmismatched-tags.
struct grpc_channel_credentials
: grpc_core::RefCounted<grpc_channel_credentials> {
public:
explicit grpc_channel_credentials(const char* type) : type_(type) {}
virtual ~grpc_channel_credentials() = default;
// Creates a security connector for the channel. May also create new channel
// args for the channel to be used in place of the passed in const args if
// returned non NULL. In that case the caller is responsible for destroying
// new_args after channel creation.
virtual grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args);
grpc_channel_credentials* (*duplicate_without_call_credentials)(
grpc_channel_credentials* c);
} grpc_channel_credentials_vtable;
struct grpc_channel_credentials {
const grpc_channel_credentials_vtable* vtable;
const char* type;
gpr_refcount refcount;
grpc_channel_args** new_args) {
// Tell clang-tidy that call_creds cannot be passed as const-ref.
call_creds.reset();
GRPC_ABSTRACT;
}
// Creates a version of the channel credentials without any attached call
// credentials. This can be used in order to open a channel to a non-trusted
// gRPC load balancer.
virtual grpc_core::RefCountedPtr<grpc_channel_credentials>
duplicate_without_call_credentials() {
// By default we just increment the refcount.
return Ref();
}
const char* type() const { return type_; }
GRPC_ABSTRACT_BASE_CLASS
private:
const char* type_;
};
grpc_channel_credentials* grpc_channel_credentials_ref(
grpc_channel_credentials* creds);
void grpc_channel_credentials_unref(grpc_channel_credentials* creds);
/* Creates a security connector for the channel. May also create new channel
args for the channel to be used in place of the passed in const args if
returned non NULL. In that case the caller is responsible for destroying
new_args after channel creation. */
grpc_security_status grpc_channel_credentials_create_security_connector(
grpc_channel_credentials* creds, const char* target,
const grpc_channel_args* args, grpc_channel_security_connector** sc,
grpc_channel_args** new_args);
/* Creates a version of the channel credentials without any attached call
credentials. This can be used in order to open a channel to a non-trusted
gRPC load balancer. */
grpc_channel_credentials*
grpc_channel_credentials_duplicate_without_call_credentials(
grpc_channel_credentials* creds);
/* Util to encapsulate the channel credentials in a channel arg. */
grpc_arg grpc_channel_credentials_to_arg(grpc_channel_credentials* credentials);
@ -158,44 +161,39 @@ void grpc_credentials_mdelem_array_destroy(grpc_credentials_mdelem_array* list);
/* --- grpc_call_credentials. --- */
typedef struct {
void (*destruct)(grpc_call_credentials* c);
bool (*get_request_metadata)(grpc_call_credentials* c,
grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error);
void (*cancel_get_request_metadata)(grpc_call_credentials* c,
grpc_credentials_mdelem_array* md_array,
grpc_error* error);
} grpc_call_credentials_vtable;
struct grpc_call_credentials {
const grpc_call_credentials_vtable* vtable;
const char* type;
gpr_refcount refcount;
// This type is forward declared as a C struct and we cannot define it as a
// class. Otherwise, compiler will complain about type mismatch due to
// -Wmismatched-tags.
struct grpc_call_credentials
: public grpc_core::RefCounted<grpc_call_credentials> {
public:
explicit grpc_call_credentials(const char* type) : type_(type) {}
virtual ~grpc_call_credentials() = default;
// Returns true if completed synchronously, in which case \a error will
// be set to indicate the result. Otherwise, \a on_request_metadata will
// be invoked asynchronously when complete. \a md_array will be populated
// with the resulting metadata once complete.
virtual bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) GRPC_ABSTRACT;
// Cancels a pending asynchronous operation started by
// grpc_call_credentials_get_request_metadata() with the corresponding
// value of \a md_array.
virtual void cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) GRPC_ABSTRACT;
const char* type() const { return type_; }
GRPC_ABSTRACT_BASE_CLASS
private:
const char* type_;
};
grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds);
void grpc_call_credentials_unref(grpc_call_credentials* creds);
/// Returns true if completed synchronously, in which case \a error will
/// be set to indicate the result. Otherwise, \a on_request_metadata will
/// be invoked asynchronously when complete. \a md_array will be populated
/// with the resulting metadata once complete.
bool grpc_call_credentials_get_request_metadata(
grpc_call_credentials* creds, grpc_polling_entity* pollent,
grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata, grpc_error** error);
/// Cancels a pending asynchronous operation started by
/// grpc_call_credentials_get_request_metadata() with the corresponding
/// value of \a md_array.
void grpc_call_credentials_cancel_get_request_metadata(
grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
grpc_error* error);
/* Metadata-only credentials with the specified key and value where
asynchronicity can be simulated for testing. */
grpc_call_credentials* grpc_md_only_test_credentials_create(
@ -203,26 +201,40 @@ grpc_call_credentials* grpc_md_only_test_credentials_create(
/* --- grpc_server_credentials. --- */
typedef struct {
void (*destruct)(grpc_server_credentials* c);
grpc_security_status (*create_security_connector)(
grpc_server_credentials* c, grpc_server_security_connector** sc);
} grpc_server_credentials_vtable;
struct grpc_server_credentials {
const grpc_server_credentials_vtable* vtable;
const char* type;
gpr_refcount refcount;
grpc_auth_metadata_processor processor;
};
// This type is forward declared as a C struct and we cannot define it as a
// class. Otherwise, compiler will complain about type mismatch due to
// -Wmismatched-tags.
struct grpc_server_credentials
: public grpc_core::RefCounted<grpc_server_credentials> {
public:
explicit grpc_server_credentials(const char* type) : type_(type) {}
grpc_security_status grpc_server_credentials_create_security_connector(
grpc_server_credentials* creds, grpc_server_security_connector** sc);
virtual ~grpc_server_credentials() { DestroyProcessor(); }
grpc_server_credentials* grpc_server_credentials_ref(
grpc_server_credentials* creds);
virtual grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector() GRPC_ABSTRACT;
void grpc_server_credentials_unref(grpc_server_credentials* creds);
const char* type() const { return type_; }
const grpc_auth_metadata_processor& auth_metadata_processor() const {
return processor_;
}
void set_auth_metadata_processor(
const grpc_auth_metadata_processor& processor);
GRPC_ABSTRACT_BASE_CLASS
private:
void DestroyProcessor() {
if (processor_.destroy != nullptr && processor_.state != nullptr) {
processor_.destroy(processor_.state);
}
}
const char* type_;
grpc_auth_metadata_processor processor_ =
grpc_auth_metadata_processor(); // Zero-initialize the C struct.
};
#define GRPC_SERVER_CREDENTIALS_ARG "grpc.server_credentials"
@ -233,15 +245,27 @@ grpc_server_credentials* grpc_find_server_credentials_in_args(
/* -- Credentials Metadata Request. -- */
typedef struct {
grpc_call_credentials* creds;
struct grpc_credentials_metadata_request {
explicit grpc_credentials_metadata_request(
grpc_core::RefCountedPtr<grpc_call_credentials> creds)
: creds(std::move(creds)) {}
~grpc_credentials_metadata_request() {
grpc_http_response_destroy(&response);
}
grpc_core::RefCountedPtr<grpc_call_credentials> creds;
grpc_http_response response;
} grpc_credentials_metadata_request;
};
grpc_credentials_metadata_request* grpc_credentials_metadata_request_create(
grpc_call_credentials* creds);
inline grpc_credentials_metadata_request*
grpc_credentials_metadata_request_create(
grpc_core::RefCountedPtr<grpc_call_credentials> creds) {
return grpc_core::New<grpc_credentials_metadata_request>(std::move(creds));
}
void grpc_credentials_metadata_request_destroy(
grpc_credentials_metadata_request* r);
inline void grpc_credentials_metadata_request_destroy(
grpc_credentials_metadata_request* r) {
grpc_core::Delete(r);
}
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_CREDENTIALS_H */

@ -33,49 +33,45 @@
/* -- Fake transport security credentials. -- */
static grpc_security_status fake_transport_security_create_security_connector(
grpc_channel_credentials* c, grpc_call_credentials* call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
*sc =
grpc_fake_channel_security_connector_create(c, call_creds, target, args);
return GRPC_SECURITY_OK;
}
static grpc_security_status
fake_transport_security_server_create_security_connector(
grpc_server_credentials* c, grpc_server_security_connector** sc) {
*sc = grpc_fake_server_security_connector_create(c);
return GRPC_SECURITY_OK;
}
namespace {
class grpc_fake_channel_credentials final : public grpc_channel_credentials {
public:
grpc_fake_channel_credentials()
: grpc_channel_credentials(
GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {}
~grpc_fake_channel_credentials() override = default;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_args** new_args) override {
return grpc_fake_channel_security_connector_create(
this->Ref(), std::move(call_creds), target, args);
}
};
class grpc_fake_server_credentials final : public grpc_server_credentials {
public:
grpc_fake_server_credentials()
: grpc_server_credentials(
GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {}
~grpc_fake_server_credentials() override = default;
grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector() override {
return grpc_fake_server_security_connector_create(this->Ref());
}
};
} // namespace
static grpc_channel_credentials_vtable
fake_transport_security_credentials_vtable = {
nullptr, fake_transport_security_create_security_connector, nullptr};
static grpc_server_credentials_vtable
fake_transport_security_server_credentials_vtable = {
nullptr, fake_transport_security_server_create_security_connector};
grpc_channel_credentials* grpc_fake_transport_security_credentials_create(
void) {
grpc_channel_credentials* c = static_cast<grpc_channel_credentials*>(
gpr_zalloc(sizeof(grpc_channel_credentials)));
c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY;
c->vtable = &fake_transport_security_credentials_vtable;
gpr_ref_init(&c->refcount, 1);
return c;
grpc_channel_credentials* grpc_fake_transport_security_credentials_create() {
return grpc_core::New<grpc_fake_channel_credentials>();
}
grpc_server_credentials* grpc_fake_transport_security_server_credentials_create(
void) {
grpc_server_credentials* c = static_cast<grpc_server_credentials*>(
gpr_malloc(sizeof(grpc_server_credentials)));
memset(c, 0, sizeof(grpc_server_credentials));
c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY;
gpr_ref_init(&c->refcount, 1);
c->vtable = &fake_transport_security_server_credentials_vtable;
return c;
grpc_server_credentials*
grpc_fake_transport_security_server_credentials_create() {
return grpc_core::New<grpc_fake_server_credentials>();
}
grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) {
@ -92,46 +88,25 @@ const char* grpc_fake_transport_get_expected_targets(
/* -- Metadata-only test credentials. -- */
static void md_only_test_destruct(grpc_call_credentials* creds) {
grpc_md_only_test_credentials* c =
reinterpret_cast<grpc_md_only_test_credentials*>(creds);
GRPC_MDELEM_UNREF(c->md);
}
static bool md_only_test_get_request_metadata(
grpc_call_credentials* creds, grpc_polling_entity* pollent,
grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata, grpc_error** error) {
grpc_md_only_test_credentials* c =
reinterpret_cast<grpc_md_only_test_credentials*>(creds);
grpc_credentials_mdelem_array_add(md_array, c->md);
if (c->is_async) {
bool grpc_md_only_test_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_credentials_mdelem_array_add(md_array, md_);
if (is_async_) {
GRPC_CLOSURE_SCHED(on_request_metadata, GRPC_ERROR_NONE);
return false;
}
return true;
}
static void md_only_test_cancel_get_request_metadata(
grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
void grpc_md_only_test_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
static grpc_call_credentials_vtable md_only_test_vtable = {
md_only_test_destruct, md_only_test_get_request_metadata,
md_only_test_cancel_get_request_metadata};
grpc_call_credentials* grpc_md_only_test_credentials_create(
const char* md_key, const char* md_value, bool is_async) {
grpc_md_only_test_credentials* c =
static_cast<grpc_md_only_test_credentials*>(
gpr_zalloc(sizeof(grpc_md_only_test_credentials)));
c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2;
c->base.vtable = &md_only_test_vtable;
gpr_ref_init(&c->base.refcount, 1);
c->md = grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key),
grpc_slice_from_copied_string(md_value));
c->is_async = is_async;
return &c->base;
return grpc_core::New<grpc_md_only_test_credentials>(md_key, md_value,
is_async);
}

@ -55,10 +55,28 @@ const char* grpc_fake_transport_get_expected_targets(
/* -- Metadata-only Test credentials. -- */
typedef struct {
grpc_call_credentials base;
grpc_mdelem md;
bool is_async;
} grpc_md_only_test_credentials;
class grpc_md_only_test_credentials : public grpc_call_credentials {
public:
grpc_md_only_test_credentials(const char* md_key, const char* md_value,
bool is_async)
: grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2),
md_(grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key),
grpc_slice_from_copied_string(md_value))),
is_async_(is_async) {}
~grpc_md_only_test_credentials() override { GRPC_MDELEM_UNREF(md_); }
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
private:
grpc_mdelem md_;
bool is_async_;
};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_FAKE_FAKE_CREDENTIALS_H */

@ -30,6 +30,7 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gpr/env.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/http/httpcli.h"
#include "src/core/lib/http/parser.h"
#include "src/core/lib/iomgr/load_file.h"
@ -72,20 +73,11 @@ typedef struct {
grpc_http_response response;
} metadata_server_detector;
static void google_default_credentials_destruct(
grpc_channel_credentials* creds) {
grpc_google_default_channel_credentials* c =
reinterpret_cast<grpc_google_default_channel_credentials*>(creds);
grpc_channel_credentials_unref(c->alts_creds);
grpc_channel_credentials_unref(c->ssl_creds);
}
static grpc_security_status google_default_create_security_connector(
grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_google_default_channel_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
grpc_google_default_channel_credentials* c =
reinterpret_cast<grpc_google_default_channel_credentials*>(creds);
grpc_channel_args** new_args) {
bool is_grpclb_load_balancer = grpc_channel_arg_get_bool(
grpc_channel_args_find(args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER),
false);
@ -95,22 +87,22 @@ static grpc_security_status google_default_create_security_connector(
false);
bool use_alts =
is_grpclb_load_balancer || is_backend_from_grpclb_load_balancer;
grpc_security_status status = GRPC_SECURITY_ERROR;
/* Return failure if ALTS is selected but not running on GCE. */
if (use_alts && !g_is_on_gce) {
gpr_log(GPR_ERROR, "ALTS is selected, but not running on GCE.");
goto end;
return nullptr;
}
status = use_alts ? c->alts_creds->vtable->create_security_connector(
c->alts_creds, call_creds, target, args, sc, new_args)
: c->ssl_creds->vtable->create_security_connector(
c->ssl_creds, call_creds, target, args, sc, new_args);
/* grpclb-specific channel args are removed from the channel args set
* to ensure backends and fallback adresses will have the same set of channel
* args. By doing that, it guarantees the connections to backends will not be
* torn down and re-connected when switching in and out of fallback mode.
*/
end:
grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
use_alts ? alts_creds_->create_security_connector(call_creds, target,
args, new_args)
: ssl_creds_->create_security_connector(call_creds, target, args,
new_args);
/* grpclb-specific channel args are removed from the channel args set
* to ensure backends and fallback adresses will have the same set of channel
* args. By doing that, it guarantees the connections to backends will not be
* torn down and re-connected when switching in and out of fallback mode.
*/
if (use_alts) {
static const char* args_to_remove[] = {
GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER,
@ -119,13 +111,9 @@ end:
*new_args = grpc_channel_args_copy_and_add_and_remove(
args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), nullptr, 0);
}
return status;
return sc;
}
static grpc_channel_credentials_vtable google_default_credentials_vtable = {
google_default_credentials_destruct,
google_default_create_security_connector, nullptr};
static void on_metadata_server_detection_http_response(void* user_data,
grpc_error* error) {
metadata_server_detector* detector =
@ -215,11 +203,11 @@ static int is_metadata_server_reachable() {
/* Takes ownership of creds_path if not NULL. */
static grpc_error* create_default_creds_from_path(
char* creds_path, grpc_call_credentials** creds) {
char* creds_path, grpc_core::RefCountedPtr<grpc_call_credentials>* creds) {
grpc_json* json = nullptr;
grpc_auth_json_key key;
grpc_auth_refresh_token token;
grpc_call_credentials* result = nullptr;
grpc_core::RefCountedPtr<grpc_call_credentials> result;
grpc_slice creds_data = grpc_empty_slice();
grpc_error* error = GRPC_ERROR_NONE;
if (creds_path == nullptr) {
@ -276,9 +264,9 @@ end:
return error;
}
grpc_channel_credentials* grpc_google_default_credentials_create(void) {
grpc_channel_credentials* grpc_google_default_credentials_create() {
grpc_channel_credentials* result = nullptr;
grpc_call_credentials* call_creds = nullptr;
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds;
grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Failed to create Google credentials");
grpc_error* err;
@ -316,7 +304,8 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) {
gpr_mu_unlock(&g_state_mu);
if (g_metadata_server_available) {
call_creds = grpc_google_compute_engine_credentials_create(nullptr);
call_creds = grpc_core::RefCountedPtr<grpc_call_credentials>(
grpc_google_compute_engine_credentials_create(nullptr));
if (call_creds == nullptr) {
error = grpc_error_add_child(
error, GRPC_ERROR_CREATE_FROM_STATIC_STRING(
@ -327,23 +316,23 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) {
end:
if (call_creds != nullptr) {
/* Create google default credentials. */
auto creds = static_cast<grpc_google_default_channel_credentials*>(
gpr_zalloc(sizeof(grpc_google_default_channel_credentials)));
creds->base.vtable = &google_default_credentials_vtable;
creds->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT;
gpr_ref_init(&creds->base.refcount, 1);
creds->ssl_creds =
grpc_channel_credentials* ssl_creds =
grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr);
GPR_ASSERT(creds->ssl_creds != nullptr);
GPR_ASSERT(ssl_creds != nullptr);
grpc_alts_credentials_options* options =
grpc_alts_credentials_client_options_create();
creds->alts_creds = grpc_alts_credentials_create(options);
grpc_channel_credentials* alts_creds =
grpc_alts_credentials_create(options);
grpc_alts_credentials_options_destroy(options);
result = grpc_composite_channel_credentials_create(&creds->base, call_creds,
nullptr);
auto creds =
grpc_core::MakeRefCounted<grpc_google_default_channel_credentials>(
alts_creds != nullptr ? alts_creds->Ref() : nullptr,
ssl_creds != nullptr ? ssl_creds->Ref() : nullptr);
if (ssl_creds) ssl_creds->Unref();
if (alts_creds) alts_creds->Unref();
result = grpc_composite_channel_credentials_create(
creds.get(), call_creds.get(), nullptr);
GPR_ASSERT(result != nullptr);
grpc_channel_credentials_unref(&creds->base);
grpc_call_credentials_unref(call_creds);
} else {
gpr_log(GPR_ERROR, "Could not create google default credentials: %s",
grpc_error_string(error));

@ -21,6 +21,7 @@
#include <grpc/support/port_platform.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/credentials.h"
#define GRPC_GOOGLE_CLOUD_SDK_CONFIG_DIRECTORY "gcloud"
@ -39,11 +40,33 @@
"/" GRPC_GOOGLE_WELL_KNOWN_CREDENTIALS_FILE
#endif
typedef struct {
grpc_channel_credentials base;
grpc_channel_credentials* alts_creds;
grpc_channel_credentials* ssl_creds;
} grpc_google_default_channel_credentials;
class grpc_google_default_channel_credentials
: public grpc_channel_credentials {
public:
grpc_google_default_channel_credentials(
grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds,
grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds)
: grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT),
alts_creds_(std::move(alts_creds)),
ssl_creds_(std::move(ssl_creds)) {}
~grpc_google_default_channel_credentials() override = default;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_args** new_args) override;
const grpc_channel_credentials* alts_creds() const {
return alts_creds_.get();
}
const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); }
private:
grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds_;
grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds_;
};
namespace grpc_core {
namespace internal {

@ -22,6 +22,7 @@
#include <string.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/surface/api_trace.h"
#include <grpc/support/alloc.h>
@ -29,32 +30,37 @@
#include <grpc/support/string_util.h>
#include <grpc/support/sync.h>
static void iam_destruct(grpc_call_credentials* creds) {
grpc_google_iam_credentials* c =
reinterpret_cast<grpc_google_iam_credentials*>(creds);
grpc_credentials_mdelem_array_destroy(&c->md_array);
grpc_google_iam_credentials::~grpc_google_iam_credentials() {
grpc_credentials_mdelem_array_destroy(&md_array_);
}
static bool iam_get_request_metadata(grpc_call_credentials* creds,
grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_google_iam_credentials* c =
reinterpret_cast<grpc_google_iam_credentials*>(creds);
grpc_credentials_mdelem_array_append(md_array, &c->md_array);
bool grpc_google_iam_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_credentials_mdelem_array_append(md_array, &md_array_);
return true;
}
static void iam_cancel_get_request_metadata(
grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
void grpc_google_iam_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
static grpc_call_credentials_vtable iam_vtable = {
iam_destruct, iam_get_request_metadata, iam_cancel_get_request_metadata};
grpc_google_iam_credentials::grpc_google_iam_credentials(
const char* token, const char* authority_selector)
: grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_IAM) {
grpc_mdelem md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY),
grpc_slice_from_copied_string(token));
grpc_credentials_mdelem_array_add(&md_array_, md);
GRPC_MDELEM_UNREF(md);
md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY),
grpc_slice_from_copied_string(authority_selector));
grpc_credentials_mdelem_array_add(&md_array_, md);
GRPC_MDELEM_UNREF(md);
}
grpc_call_credentials* grpc_google_iam_credentials_create(
const char* token, const char* authority_selector, void* reserved) {
@ -66,21 +72,7 @@ grpc_call_credentials* grpc_google_iam_credentials_create(
GPR_ASSERT(reserved == nullptr);
GPR_ASSERT(token != nullptr);
GPR_ASSERT(authority_selector != nullptr);
grpc_google_iam_credentials* c =
static_cast<grpc_google_iam_credentials*>(gpr_zalloc(sizeof(*c)));
c->base.type = GRPC_CALL_CREDENTIALS_TYPE_IAM;
c->base.vtable = &iam_vtable;
gpr_ref_init(&c->base.refcount, 1);
grpc_mdelem md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY),
grpc_slice_from_copied_string(token));
grpc_credentials_mdelem_array_add(&c->md_array, md);
GRPC_MDELEM_UNREF(md);
md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY),
grpc_slice_from_copied_string(authority_selector));
grpc_credentials_mdelem_array_add(&c->md_array, md);
GRPC_MDELEM_UNREF(md);
return &c->base;
return grpc_core::MakeRefCounted<grpc_google_iam_credentials>(
token, authority_selector)
.release();
}

@ -23,9 +23,23 @@
#include "src/core/lib/security/credentials/credentials.h"
typedef struct {
grpc_call_credentials base;
grpc_credentials_mdelem_array md_array;
} grpc_google_iam_credentials;
class grpc_google_iam_credentials : public grpc_call_credentials {
public:
grpc_google_iam_credentials(const char* token,
const char* authority_selector);
~grpc_google_iam_credentials() override;
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
private:
grpc_credentials_mdelem_array md_array_;
};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_IAM_IAM_CREDENTIALS_H */

@ -23,6 +23,8 @@
#include <inttypes.h>
#include <string.h>
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/surface/api_trace.h"
#include <grpc/support/alloc.h>
@ -30,71 +32,66 @@
#include <grpc/support/string_util.h>
#include <grpc/support/sync.h>
static void jwt_reset_cache(grpc_service_account_jwt_access_credentials* c) {
GRPC_MDELEM_UNREF(c->cached.jwt_md);
c->cached.jwt_md = GRPC_MDNULL;
if (c->cached.service_url != nullptr) {
gpr_free(c->cached.service_url);
c->cached.service_url = nullptr;
void grpc_service_account_jwt_access_credentials::reset_cache() {
GRPC_MDELEM_UNREF(cached_.jwt_md);
cached_.jwt_md = GRPC_MDNULL;
if (cached_.service_url != nullptr) {
gpr_free(cached_.service_url);
cached_.service_url = nullptr;
}
c->cached.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME);
cached_.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME);
}
static void jwt_destruct(grpc_call_credentials* creds) {
grpc_service_account_jwt_access_credentials* c =
reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds);
grpc_auth_json_key_destruct(&c->key);
jwt_reset_cache(c);
gpr_mu_destroy(&c->cache_mu);
grpc_service_account_jwt_access_credentials::
~grpc_service_account_jwt_access_credentials() {
grpc_auth_json_key_destruct(&key_);
reset_cache();
gpr_mu_destroy(&cache_mu_);
}
static bool jwt_get_request_metadata(grpc_call_credentials* creds,
grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_service_account_jwt_access_credentials* c =
reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds);
bool grpc_service_account_jwt_access_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
gpr_timespec refresh_threshold = gpr_time_from_seconds(
GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN);
/* See if we can return a cached jwt. */
grpc_mdelem jwt_md = GRPC_MDNULL;
{
gpr_mu_lock(&c->cache_mu);
if (c->cached.service_url != nullptr &&
strcmp(c->cached.service_url, context.service_url) == 0 &&
!GRPC_MDISNULL(c->cached.jwt_md) &&
(gpr_time_cmp(gpr_time_sub(c->cached.jwt_expiration,
gpr_now(GPR_CLOCK_REALTIME)),
refresh_threshold) > 0)) {
jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md);
gpr_mu_lock(&cache_mu_);
if (cached_.service_url != nullptr &&
strcmp(cached_.service_url, context.service_url) == 0 &&
!GRPC_MDISNULL(cached_.jwt_md) &&
(gpr_time_cmp(
gpr_time_sub(cached_.jwt_expiration, gpr_now(GPR_CLOCK_REALTIME)),
refresh_threshold) > 0)) {
jwt_md = GRPC_MDELEM_REF(cached_.jwt_md);
}
gpr_mu_unlock(&c->cache_mu);
gpr_mu_unlock(&cache_mu_);
}
if (GRPC_MDISNULL(jwt_md)) {
char* jwt = nullptr;
/* Generate a new jwt. */
gpr_mu_lock(&c->cache_mu);
jwt_reset_cache(c);
jwt = grpc_jwt_encode_and_sign(&c->key, context.service_url,
c->jwt_lifetime, nullptr);
gpr_mu_lock(&cache_mu_);
reset_cache();
jwt = grpc_jwt_encode_and_sign(&key_, context.service_url, jwt_lifetime_,
nullptr);
if (jwt != nullptr) {
char* md_value;
gpr_asprintf(&md_value, "Bearer %s", jwt);
gpr_free(jwt);
c->cached.jwt_expiration =
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c->jwt_lifetime);
c->cached.service_url = gpr_strdup(context.service_url);
c->cached.jwt_md = grpc_mdelem_from_slices(
cached_.jwt_expiration =
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), jwt_lifetime_);
cached_.service_url = gpr_strdup(context.service_url);
cached_.jwt_md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY),
grpc_slice_from_copied_string(md_value));
gpr_free(md_value);
jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md);
jwt_md = GRPC_MDELEM_REF(cached_.jwt_md);
}
gpr_mu_unlock(&c->cache_mu);
gpr_mu_unlock(&cache_mu_);
}
if (!GRPC_MDISNULL(jwt_md)) {
@ -106,29 +103,15 @@ static bool jwt_get_request_metadata(grpc_call_credentials* creds,
return true;
}
static void jwt_cancel_get_request_metadata(
grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
void grpc_service_account_jwt_access_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
static grpc_call_credentials_vtable jwt_vtable = {
jwt_destruct, jwt_get_request_metadata, jwt_cancel_get_request_metadata};
grpc_call_credentials*
grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
grpc_auth_json_key key, gpr_timespec token_lifetime) {
grpc_service_account_jwt_access_credentials* c;
if (!grpc_auth_json_key_is_valid(&key)) {
gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation");
return nullptr;
}
c = static_cast<grpc_service_account_jwt_access_credentials*>(
gpr_zalloc(sizeof(grpc_service_account_jwt_access_credentials)));
c->base.type = GRPC_CALL_CREDENTIALS_TYPE_JWT;
gpr_ref_init(&c->base.refcount, 1);
c->base.vtable = &jwt_vtable;
c->key = key;
grpc_service_account_jwt_access_credentials::
grpc_service_account_jwt_access_credentials(grpc_auth_json_key key,
gpr_timespec token_lifetime)
: grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_JWT), key_(key) {
gpr_timespec max_token_lifetime = grpc_max_auth_token_lifetime();
if (gpr_time_cmp(token_lifetime, max_token_lifetime) > 0) {
gpr_log(GPR_INFO,
@ -136,10 +119,20 @@ grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
static_cast<int>(max_token_lifetime.tv_sec));
token_lifetime = grpc_max_auth_token_lifetime();
}
c->jwt_lifetime = token_lifetime;
gpr_mu_init(&c->cache_mu);
jwt_reset_cache(c);
return &c->base;
jwt_lifetime_ = token_lifetime;
gpr_mu_init(&cache_mu_);
reset_cache();
}
grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
grpc_auth_json_key key, gpr_timespec token_lifetime) {
if (!grpc_auth_json_key_is_valid(&key)) {
gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation");
return nullptr;
}
return grpc_core::MakeRefCounted<grpc_service_account_jwt_access_credentials>(
key, token_lifetime);
}
static char* redact_private_key(const char* json_key) {
@ -182,9 +175,7 @@ grpc_call_credentials* grpc_service_account_jwt_access_credentials_create(
}
GPR_ASSERT(reserved == nullptr);
grpc_core::ExecCtx exec_ctx;
grpc_call_credentials* creds =
grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
grpc_auth_json_key_create_from_string(json_key), token_lifetime);
return creds;
return grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
grpc_auth_json_key_create_from_string(json_key), token_lifetime)
.release();
}

@ -24,25 +24,44 @@
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/jwt/json_token.h"
typedef struct {
grpc_call_credentials base;
class grpc_service_account_jwt_access_credentials
: public grpc_call_credentials {
public:
grpc_service_account_jwt_access_credentials(grpc_auth_json_key key,
gpr_timespec token_lifetime);
~grpc_service_account_jwt_access_credentials() override;
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; }
const grpc_auth_json_key& key() const { return key_; }
private:
void reset_cache();
// Have a simple cache for now with just 1 entry. We could have a map based on
// the service_url for a more sophisticated one.
gpr_mu cache_mu;
gpr_mu cache_mu_;
struct {
grpc_mdelem jwt_md;
char* service_url;
grpc_mdelem jwt_md = GRPC_MDNULL;
char* service_url = nullptr;
gpr_timespec jwt_expiration;
} cached;
} cached_;
grpc_auth_json_key key;
gpr_timespec jwt_lifetime;
} grpc_service_account_jwt_access_credentials;
grpc_auth_json_key key_;
gpr_timespec jwt_lifetime_;
};
// Private constructor for jwt credentials from an already parsed json key.
// Takes ownership of the key.
grpc_call_credentials*
grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
grpc_auth_json_key key, gpr_timespec token_lifetime);

@ -29,49 +29,36 @@
#define GRPC_CREDENTIALS_TYPE_LOCAL "Local"
static void local_credentials_destruct(grpc_channel_credentials* creds) {}
static void local_server_credentials_destruct(grpc_server_credentials* creds) {}
static grpc_security_status local_create_security_connector(
grpc_channel_credentials* creds,
grpc_call_credentials* request_metadata_creds, const char* target_name,
const grpc_channel_args* args, grpc_channel_security_connector** sc,
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_local_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target_name, const grpc_channel_args* args,
grpc_channel_args** new_args) {
return grpc_local_channel_security_connector_create(
creds, request_metadata_creds, args, target_name, sc);
this->Ref(), std::move(request_metadata_creds), args, target_name);
}
static grpc_security_status local_server_create_security_connector(
grpc_server_credentials* creds, grpc_server_security_connector** sc) {
return grpc_local_server_security_connector_create(creds, sc);
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_local_server_credentials::create_security_connector() {
return grpc_local_server_security_connector_create(this->Ref());
}
static const grpc_channel_credentials_vtable local_credentials_vtable = {
local_credentials_destruct, local_create_security_connector,
/*duplicate_without_call_credentials=*/nullptr};
static const grpc_server_credentials_vtable local_server_credentials_vtable = {
local_server_credentials_destruct, local_server_create_security_connector};
grpc_local_credentials::grpc_local_credentials(
grpc_local_connect_type connect_type)
: grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_LOCAL),
connect_type_(connect_type) {}
grpc_channel_credentials* grpc_local_credentials_create(
grpc_local_connect_type connect_type) {
auto creds = static_cast<grpc_local_credentials*>(
gpr_zalloc(sizeof(grpc_local_credentials)));
creds->connect_type = connect_type;
creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL;
creds->base.vtable = &local_credentials_vtable;
gpr_ref_init(&creds->base.refcount, 1);
return &creds->base;
return grpc_core::New<grpc_local_credentials>(connect_type);
}
grpc_local_server_credentials::grpc_local_server_credentials(
grpc_local_connect_type connect_type)
: grpc_server_credentials(GRPC_CREDENTIALS_TYPE_LOCAL),
connect_type_(connect_type) {}
grpc_server_credentials* grpc_local_server_credentials_create(
grpc_local_connect_type connect_type) {
auto creds = static_cast<grpc_local_server_credentials*>(
gpr_zalloc(sizeof(grpc_local_server_credentials)));
creds->connect_type = connect_type;
creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL;
creds->base.vtable = &local_server_credentials_vtable;
gpr_ref_init(&creds->base.refcount, 1);
return &creds->base;
return grpc_core::New<grpc_local_server_credentials>(connect_type);
}

@ -25,16 +25,37 @@
#include "src/core/lib/security/credentials/credentials.h"
/* Main struct for grpc local channel credential. */
typedef struct grpc_local_credentials {
grpc_channel_credentials base;
grpc_local_connect_type connect_type;
} grpc_local_credentials;
/* Main struct for grpc local server credential. */
typedef struct grpc_local_server_credentials {
grpc_server_credentials base;
grpc_local_connect_type connect_type;
} grpc_local_server_credentials;
/* Main class for grpc local channel credential. */
class grpc_local_credentials final : public grpc_channel_credentials {
public:
explicit grpc_local_credentials(grpc_local_connect_type connect_type);
~grpc_local_credentials() override = default;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target_name, const grpc_channel_args* args,
grpc_channel_args** new_args) override;
grpc_local_connect_type connect_type() const { return connect_type_; }
private:
grpc_local_connect_type connect_type_;
};
/* Main class for grpc local server credential. */
class grpc_local_server_credentials final : public grpc_server_credentials {
public:
explicit grpc_local_server_credentials(grpc_local_connect_type connect_type);
~grpc_local_server_credentials() override = default;
grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector() override;
grpc_local_connect_type connect_type() const { return connect_type_; }
private:
grpc_local_connect_type connect_type_;
};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_LOCAL_LOCAL_CREDENTIALS_H */

@ -22,6 +22,7 @@
#include <string.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/util/json_util.h"
#include "src/core/lib/surface/api_trace.h"
@ -105,13 +106,12 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token) {
// Oauth2 Token Fetcher credentials.
//
static void oauth2_token_fetcher_destruct(grpc_call_credentials* creds) {
grpc_oauth2_token_fetcher_credentials* c =
reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds);
GRPC_MDELEM_UNREF(c->access_token_md);
gpr_mu_destroy(&c->mu);
grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&c->pollent));
grpc_httpcli_context_destroy(&c->httpcli_context);
grpc_oauth2_token_fetcher_credentials::
~grpc_oauth2_token_fetcher_credentials() {
GRPC_MDELEM_UNREF(access_token_md_);
gpr_mu_destroy(&mu_);
grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_));
grpc_httpcli_context_destroy(&httpcli_context_);
}
grpc_credentials_status
@ -209,25 +209,29 @@ static void on_oauth2_token_fetcher_http_response(void* user_data,
grpc_credentials_metadata_request* r =
static_cast<grpc_credentials_metadata_request*>(user_data);
grpc_oauth2_token_fetcher_credentials* c =
reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds);
reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds.get());
c->on_http_response(r, error);
}
void grpc_oauth2_token_fetcher_credentials::on_http_response(
grpc_credentials_metadata_request* r, grpc_error* error) {
grpc_mdelem access_token_md = GRPC_MDNULL;
grpc_millis token_lifetime;
grpc_credentials_status status =
grpc_oauth2_token_fetcher_credentials_parse_server_response(
&r->response, &access_token_md, &token_lifetime);
// Update cache and grab list of pending requests.
gpr_mu_lock(&c->mu);
c->token_fetch_pending = false;
c->access_token_md = GRPC_MDELEM_REF(access_token_md);
c->token_expiration =
gpr_mu_lock(&mu_);
token_fetch_pending_ = false;
access_token_md_ = GRPC_MDELEM_REF(access_token_md);
token_expiration_ =
status == GRPC_CREDENTIALS_OK
? gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
gpr_time_from_millis(token_lifetime, GPR_TIMESPAN))
: gpr_inf_past(GPR_CLOCK_MONOTONIC);
grpc_oauth2_pending_get_request_metadata* pending_request =
c->pending_requests;
c->pending_requests = nullptr;
gpr_mu_unlock(&c->mu);
grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_;
pending_requests_ = nullptr;
gpr_mu_unlock(&mu_);
// Invoke callbacks for all pending requests.
while (pending_request != nullptr) {
if (status == GRPC_CREDENTIALS_OK) {
@ -239,42 +243,40 @@ static void on_oauth2_token_fetcher_http_response(void* user_data,
}
GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, error);
grpc_polling_entity_del_from_pollset_set(
pending_request->pollent, grpc_polling_entity_pollset_set(&c->pollent));
pending_request->pollent, grpc_polling_entity_pollset_set(&pollent_));
grpc_oauth2_pending_get_request_metadata* prev = pending_request;
pending_request = pending_request->next;
gpr_free(prev);
}
GRPC_MDELEM_UNREF(access_token_md);
grpc_call_credentials_unref(r->creds);
Unref();
grpc_credentials_metadata_request_destroy(r);
}
static bool oauth2_token_fetcher_get_request_metadata(
grpc_call_credentials* creds, grpc_polling_entity* pollent,
grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata, grpc_error** error) {
grpc_oauth2_token_fetcher_credentials* c =
reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds);
bool grpc_oauth2_token_fetcher_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
// Check if we can use the cached token.
grpc_millis refresh_threshold =
GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS * GPR_MS_PER_SEC;
grpc_mdelem cached_access_token_md = GRPC_MDNULL;
gpr_mu_lock(&c->mu);
if (!GRPC_MDISNULL(c->access_token_md) &&
gpr_mu_lock(&mu_);
if (!GRPC_MDISNULL(access_token_md_) &&
gpr_time_cmp(
gpr_time_sub(c->token_expiration, gpr_now(GPR_CLOCK_MONOTONIC)),
gpr_time_sub(token_expiration_, gpr_now(GPR_CLOCK_MONOTONIC)),
gpr_time_from_seconds(GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS,
GPR_TIMESPAN)) > 0) {
cached_access_token_md = GRPC_MDELEM_REF(c->access_token_md);
cached_access_token_md = GRPC_MDELEM_REF(access_token_md_);
}
if (!GRPC_MDISNULL(cached_access_token_md)) {
gpr_mu_unlock(&c->mu);
gpr_mu_unlock(&mu_);
grpc_credentials_mdelem_array_add(md_array, cached_access_token_md);
GRPC_MDELEM_UNREF(cached_access_token_md);
return true;
}
// Couldn't get the token from the cache.
// Add request to c->pending_requests and start a new fetch if needed.
// Add request to pending_requests_ and start a new fetch if needed.
grpc_oauth2_pending_get_request_metadata* pending_request =
static_cast<grpc_oauth2_pending_get_request_metadata*>(
gpr_malloc(sizeof(*pending_request)));
@ -282,41 +284,37 @@ static bool oauth2_token_fetcher_get_request_metadata(
pending_request->on_request_metadata = on_request_metadata;
pending_request->pollent = pollent;
grpc_polling_entity_add_to_pollset_set(
pollent, grpc_polling_entity_pollset_set(&c->pollent));
pending_request->next = c->pending_requests;
c->pending_requests = pending_request;
pollent, grpc_polling_entity_pollset_set(&pollent_));
pending_request->next = pending_requests_;
pending_requests_ = pending_request;
bool start_fetch = false;
if (!c->token_fetch_pending) {
c->token_fetch_pending = true;
if (!token_fetch_pending_) {
token_fetch_pending_ = true;
start_fetch = true;
}
gpr_mu_unlock(&c->mu);
gpr_mu_unlock(&mu_);
if (start_fetch) {
grpc_call_credentials_ref(creds);
c->fetch_func(grpc_credentials_metadata_request_create(creds),
&c->httpcli_context, &c->pollent,
on_oauth2_token_fetcher_http_response,
grpc_core::ExecCtx::Get()->Now() + refresh_threshold);
Ref().release();
fetch_oauth2(grpc_credentials_metadata_request_create(this->Ref()),
&httpcli_context_, &pollent_,
on_oauth2_token_fetcher_http_response,
grpc_core::ExecCtx::Get()->Now() + refresh_threshold);
}
return false;
}
static void oauth2_token_fetcher_cancel_get_request_metadata(
grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
grpc_oauth2_token_fetcher_credentials* c =
reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds);
gpr_mu_lock(&c->mu);
void grpc_oauth2_token_fetcher_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
gpr_mu_lock(&mu_);
grpc_oauth2_pending_get_request_metadata* prev = nullptr;
grpc_oauth2_pending_get_request_metadata* pending_request =
c->pending_requests;
grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_;
while (pending_request != nullptr) {
if (pending_request->md_array == md_array) {
// Remove matching pending request from the list.
if (prev != nullptr) {
prev->next = pending_request->next;
} else {
c->pending_requests = pending_request->next;
pending_requests_ = pending_request->next;
}
// Invoke the callback immediately with an error.
GRPC_CLOSURE_SCHED(pending_request->on_request_metadata,
@ -327,96 +325,89 @@ static void oauth2_token_fetcher_cancel_get_request_metadata(
prev = pending_request;
pending_request = pending_request->next;
}
gpr_mu_unlock(&c->mu);
gpr_mu_unlock(&mu_);
GRPC_ERROR_UNREF(error);
}
static void init_oauth2_token_fetcher(grpc_oauth2_token_fetcher_credentials* c,
grpc_fetch_oauth2_func fetch_func) {
memset(c, 0, sizeof(grpc_oauth2_token_fetcher_credentials));
c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2;
gpr_ref_init(&c->base.refcount, 1);
gpr_mu_init(&c->mu);
c->token_expiration = gpr_inf_past(GPR_CLOCK_MONOTONIC);
c->fetch_func = fetch_func;
c->pollent =
grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create());
grpc_httpcli_context_init(&c->httpcli_context);
grpc_oauth2_token_fetcher_credentials::grpc_oauth2_token_fetcher_credentials()
: grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2),
token_expiration_(gpr_inf_past(GPR_CLOCK_MONOTONIC)),
pollent_(grpc_polling_entity_create_from_pollset_set(
grpc_pollset_set_create())) {
gpr_mu_init(&mu_);
grpc_httpcli_context_init(&httpcli_context_);
}
//
// Google Compute Engine credentials.
//
static grpc_call_credentials_vtable compute_engine_vtable = {
oauth2_token_fetcher_destruct, oauth2_token_fetcher_get_request_metadata,
oauth2_token_fetcher_cancel_get_request_metadata};
namespace {
class grpc_compute_engine_token_fetcher_credentials
: public grpc_oauth2_token_fetcher_credentials {
public:
grpc_compute_engine_token_fetcher_credentials() = default;
~grpc_compute_engine_token_fetcher_credentials() override = default;
protected:
void fetch_oauth2(grpc_credentials_metadata_request* metadata_req,
grpc_httpcli_context* http_context,
grpc_polling_entity* pollent,
grpc_iomgr_cb_func response_cb,
grpc_millis deadline) override {
grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"};
grpc_httpcli_request request;
memset(&request, 0, sizeof(grpc_httpcli_request));
request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST;
request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH;
request.http.hdr_count = 1;
request.http.hdrs = &header;
/* TODO(ctiller): Carry the resource_quota in ctx and share it with the host
channel. This would allow us to cancel an authentication query when under
extreme memory pressure. */
grpc_resource_quota* resource_quota =
grpc_resource_quota_create("oauth2_credentials");
grpc_httpcli_get(http_context, pollent, resource_quota, &request, deadline,
GRPC_CLOSURE_CREATE(response_cb, metadata_req,
grpc_schedule_on_exec_ctx),
&metadata_req->response);
grpc_resource_quota_unref_internal(resource_quota);
}
};
static void compute_engine_fetch_oauth2(
grpc_credentials_metadata_request* metadata_req,
grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent,
grpc_iomgr_cb_func response_cb, grpc_millis deadline) {
grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"};
grpc_httpcli_request request;
memset(&request, 0, sizeof(grpc_httpcli_request));
request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST;
request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH;
request.http.hdr_count = 1;
request.http.hdrs = &header;
/* TODO(ctiller): Carry the resource_quota in ctx and share it with the host
channel. This would allow us to cancel an authentication query when under
extreme memory pressure. */
grpc_resource_quota* resource_quota =
grpc_resource_quota_create("oauth2_credentials");
grpc_httpcli_get(
httpcli_context, pollent, resource_quota, &request, deadline,
GRPC_CLOSURE_CREATE(response_cb, metadata_req, grpc_schedule_on_exec_ctx),
&metadata_req->response);
grpc_resource_quota_unref_internal(resource_quota);
}
} // namespace
grpc_call_credentials* grpc_google_compute_engine_credentials_create(
void* reserved) {
grpc_oauth2_token_fetcher_credentials* c =
static_cast<grpc_oauth2_token_fetcher_credentials*>(
gpr_malloc(sizeof(grpc_oauth2_token_fetcher_credentials)));
GRPC_API_TRACE("grpc_compute_engine_credentials_create(reserved=%p)", 1,
(reserved));
GPR_ASSERT(reserved == nullptr);
init_oauth2_token_fetcher(c, compute_engine_fetch_oauth2);
c->base.vtable = &compute_engine_vtable;
return &c->base;
return grpc_core::MakeRefCounted<
grpc_compute_engine_token_fetcher_credentials>()
.release();
}
//
// Google Refresh Token credentials.
//
static void refresh_token_destruct(grpc_call_credentials* creds) {
grpc_google_refresh_token_credentials* c =
reinterpret_cast<grpc_google_refresh_token_credentials*>(creds);
grpc_auth_refresh_token_destruct(&c->refresh_token);
oauth2_token_fetcher_destruct(&c->base.base);
grpc_google_refresh_token_credentials::
~grpc_google_refresh_token_credentials() {
grpc_auth_refresh_token_destruct(&refresh_token_);
}
static grpc_call_credentials_vtable refresh_token_vtable = {
refresh_token_destruct, oauth2_token_fetcher_get_request_metadata,
oauth2_token_fetcher_cancel_get_request_metadata};
static void refresh_token_fetch_oauth2(
void grpc_google_refresh_token_credentials::fetch_oauth2(
grpc_credentials_metadata_request* metadata_req,
grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent,
grpc_iomgr_cb_func response_cb, grpc_millis deadline) {
grpc_google_refresh_token_credentials* c =
reinterpret_cast<grpc_google_refresh_token_credentials*>(
metadata_req->creds);
grpc_http_header header = {(char*)"Content-Type",
(char*)"application/x-www-form-urlencoded"};
grpc_httpcli_request request;
char* body = nullptr;
gpr_asprintf(&body, GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING,
c->refresh_token.client_id, c->refresh_token.client_secret,
c->refresh_token.refresh_token);
refresh_token_.client_id, refresh_token_.client_secret,
refresh_token_.refresh_token);
memset(&request, 0, sizeof(grpc_httpcli_request));
request.host = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_HOST;
request.http.path = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH;
@ -437,20 +428,19 @@ static void refresh_token_fetch_oauth2(
gpr_free(body);
}
grpc_call_credentials*
grpc_google_refresh_token_credentials::grpc_google_refresh_token_credentials(
grpc_auth_refresh_token refresh_token)
: refresh_token_(refresh_token) {}
grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_refresh_token_credentials_create_from_auth_refresh_token(
grpc_auth_refresh_token refresh_token) {
grpc_google_refresh_token_credentials* c;
if (!grpc_auth_refresh_token_is_valid(&refresh_token)) {
gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation");
return nullptr;
}
c = static_cast<grpc_google_refresh_token_credentials*>(
gpr_zalloc(sizeof(grpc_google_refresh_token_credentials)));
init_oauth2_token_fetcher(&c->base, refresh_token_fetch_oauth2);
c->base.base.vtable = &refresh_token_vtable;
c->refresh_token = refresh_token;
return &c->base.base;
return grpc_core::MakeRefCounted<grpc_google_refresh_token_credentials>(
refresh_token);
}
static char* create_loggable_refresh_token(grpc_auth_refresh_token* token) {
@ -478,59 +468,50 @@ grpc_call_credentials* grpc_google_refresh_token_credentials_create(
gpr_free(loggable_token);
}
GPR_ASSERT(reserved == nullptr);
return grpc_refresh_token_credentials_create_from_auth_refresh_token(token);
return grpc_refresh_token_credentials_create_from_auth_refresh_token(token)
.release();
}
//
// Oauth2 Access Token credentials.
//
static void access_token_destruct(grpc_call_credentials* creds) {
grpc_access_token_credentials* c =
reinterpret_cast<grpc_access_token_credentials*>(creds);
GRPC_MDELEM_UNREF(c->access_token_md);
grpc_access_token_credentials::~grpc_access_token_credentials() {
GRPC_MDELEM_UNREF(access_token_md_);
}
static bool access_token_get_request_metadata(
grpc_call_credentials* creds, grpc_polling_entity* pollent,
grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata, grpc_error** error) {
grpc_access_token_credentials* c =
reinterpret_cast<grpc_access_token_credentials*>(creds);
grpc_credentials_mdelem_array_add(md_array, c->access_token_md);
bool grpc_access_token_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_credentials_mdelem_array_add(md_array, access_token_md_);
return true;
}
static void access_token_cancel_get_request_metadata(
grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
void grpc_access_token_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
static grpc_call_credentials_vtable access_token_vtable = {
access_token_destruct, access_token_get_request_metadata,
access_token_cancel_get_request_metadata};
grpc_access_token_credentials::grpc_access_token_credentials(
const char* access_token)
: grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) {
char* token_md_value;
gpr_asprintf(&token_md_value, "Bearer %s", access_token);
grpc_core::ExecCtx exec_ctx;
access_token_md_ = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY),
grpc_slice_from_copied_string(token_md_value));
gpr_free(token_md_value);
}
grpc_call_credentials* grpc_access_token_credentials_create(
const char* access_token, void* reserved) {
grpc_access_token_credentials* c =
static_cast<grpc_access_token_credentials*>(
gpr_zalloc(sizeof(grpc_access_token_credentials)));
GRPC_API_TRACE(
"grpc_access_token_credentials_create(access_token=<redacted>, "
"reserved=%p)",
1, (reserved));
GPR_ASSERT(reserved == nullptr);
c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2;
c->base.vtable = &access_token_vtable;
gpr_ref_init(&c->base.refcount, 1);
char* token_md_value;
gpr_asprintf(&token_md_value, "Bearer %s", access_token);
grpc_core::ExecCtx exec_ctx;
c->access_token_md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY),
grpc_slice_from_copied_string(token_md_value));
gpr_free(token_md_value);
return &c->base;
return grpc_core::MakeRefCounted<grpc_access_token_credentials>(access_token)
.release();
}

@ -54,46 +54,91 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token);
// This object is a base for credentials that need to acquire an oauth2 token
// from an http service.
typedef void (*grpc_fetch_oauth2_func)(grpc_credentials_metadata_request* req,
grpc_httpcli_context* http_context,
grpc_polling_entity* pollent,
grpc_iomgr_cb_func cb,
grpc_millis deadline);
typedef struct grpc_oauth2_pending_get_request_metadata {
struct grpc_oauth2_pending_get_request_metadata {
grpc_credentials_mdelem_array* md_array;
grpc_closure* on_request_metadata;
grpc_polling_entity* pollent;
struct grpc_oauth2_pending_get_request_metadata* next;
} grpc_oauth2_pending_get_request_metadata;
typedef struct {
grpc_call_credentials base;
gpr_mu mu;
grpc_mdelem access_token_md;
gpr_timespec token_expiration;
bool token_fetch_pending;
grpc_oauth2_pending_get_request_metadata* pending_requests;
grpc_httpcli_context httpcli_context;
grpc_fetch_oauth2_func fetch_func;
grpc_polling_entity pollent;
} grpc_oauth2_token_fetcher_credentials;
};
class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials {
public:
grpc_oauth2_token_fetcher_credentials();
~grpc_oauth2_token_fetcher_credentials() override;
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
void on_http_response(grpc_credentials_metadata_request* r,
grpc_error* error);
GRPC_ABSTRACT_BASE_CLASS
protected:
virtual void fetch_oauth2(grpc_credentials_metadata_request* req,
grpc_httpcli_context* httpcli_context,
grpc_polling_entity* pollent, grpc_iomgr_cb_func cb,
grpc_millis deadline) GRPC_ABSTRACT;
private:
gpr_mu mu_;
grpc_mdelem access_token_md_ = GRPC_MDNULL;
gpr_timespec token_expiration_;
bool token_fetch_pending_ = false;
grpc_oauth2_pending_get_request_metadata* pending_requests_ = nullptr;
grpc_httpcli_context httpcli_context_;
grpc_polling_entity pollent_;
};
// Google refresh token credentials.
typedef struct {
grpc_oauth2_token_fetcher_credentials base;
grpc_auth_refresh_token refresh_token;
} grpc_google_refresh_token_credentials;
class grpc_google_refresh_token_credentials final
: public grpc_oauth2_token_fetcher_credentials {
public:
grpc_google_refresh_token_credentials(grpc_auth_refresh_token refresh_token);
~grpc_google_refresh_token_credentials() override;
const grpc_auth_refresh_token& refresh_token() const {
return refresh_token_;
}
protected:
void fetch_oauth2(grpc_credentials_metadata_request* req,
grpc_httpcli_context* httpcli_context,
grpc_polling_entity* pollent, grpc_iomgr_cb_func cb,
grpc_millis deadline) override;
private:
grpc_auth_refresh_token refresh_token_;
};
// Access token credentials.
typedef struct {
grpc_call_credentials base;
grpc_mdelem access_token_md;
} grpc_access_token_credentials;
class grpc_access_token_credentials final : public grpc_call_credentials {
public:
grpc_access_token_credentials(const char* access_token);
~grpc_access_token_credentials() override;
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
private:
grpc_mdelem access_token_md_;
};
// Private constructor for refresh token credentials from an already parsed
// refresh token. Takes ownership of the refresh token.
grpc_call_credentials*
grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_refresh_token_credentials_create_from_auth_refresh_token(
grpc_auth_refresh_token token);

@ -35,20 +35,17 @@
grpc_core::TraceFlag grpc_plugin_credentials_trace(false, "plugin_credentials");
static void plugin_destruct(grpc_call_credentials* creds) {
grpc_plugin_credentials* c =
reinterpret_cast<grpc_plugin_credentials*>(creds);
gpr_mu_destroy(&c->mu);
if (c->plugin.state != nullptr && c->plugin.destroy != nullptr) {
c->plugin.destroy(c->plugin.state);
grpc_plugin_credentials::~grpc_plugin_credentials() {
gpr_mu_destroy(&mu_);
if (plugin_.state != nullptr && plugin_.destroy != nullptr) {
plugin_.destroy(plugin_.state);
}
}
static void pending_request_remove_locked(
grpc_plugin_credentials* c,
grpc_plugin_credentials_pending_request* pending_request) {
void grpc_plugin_credentials::pending_request_remove_locked(
pending_request* pending_request) {
if (pending_request->prev == nullptr) {
c->pending_requests = pending_request->next;
pending_requests_ = pending_request->next;
} else {
pending_request->prev->next = pending_request->next;
}
@ -62,17 +59,17 @@ static void pending_request_remove_locked(
// cancelled out from under us.
// When this returns, r->cancelled indicates whether the request was
// cancelled before completion.
static void pending_request_complete(
grpc_plugin_credentials_pending_request* r) {
gpr_mu_lock(&r->creds->mu);
if (!r->cancelled) pending_request_remove_locked(r->creds, r);
gpr_mu_unlock(&r->creds->mu);
void grpc_plugin_credentials::pending_request_complete(pending_request* r) {
GPR_DEBUG_ASSERT(r->creds == this);
gpr_mu_lock(&mu_);
if (!r->cancelled) pending_request_remove_locked(r);
gpr_mu_unlock(&mu_);
// Ref to credentials not needed anymore.
grpc_call_credentials_unref(&r->creds->base);
Unref();
}
static grpc_error* process_plugin_result(
grpc_plugin_credentials_pending_request* r, const grpc_metadata* md,
grpc_plugin_credentials::pending_request* r, const grpc_metadata* md,
size_t num_md, grpc_status_code status, const char* error_details) {
grpc_error* error = GRPC_ERROR_NONE;
if (status != GRPC_STATUS_OK) {
@ -119,8 +116,8 @@ static void plugin_md_request_metadata_ready(void* request,
/* called from application code */
grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_FINISHED |
GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP);
grpc_plugin_credentials_pending_request* r =
static_cast<grpc_plugin_credentials_pending_request*>(request);
grpc_plugin_credentials::pending_request* r =
static_cast<grpc_plugin_credentials::pending_request*>(request);
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p: plugin returned "
@ -128,7 +125,7 @@ static void plugin_md_request_metadata_ready(void* request,
r->creds, r);
}
// Remove request from pending list if not previously cancelled.
pending_request_complete(r);
r->creds->pending_request_complete(r);
// If it has not been cancelled, process it.
if (!r->cancelled) {
grpc_error* error =
@ -143,65 +140,59 @@ static void plugin_md_request_metadata_ready(void* request,
gpr_free(r);
}
static bool plugin_get_request_metadata(grpc_call_credentials* creds,
grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) {
grpc_plugin_credentials* c =
reinterpret_cast<grpc_plugin_credentials*>(creds);
bool grpc_plugin_credentials::get_request_metadata(
grpc_polling_entity* pollent, grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
bool retval = true; // Synchronous return.
if (c->plugin.get_metadata != nullptr) {
if (plugin_.get_metadata != nullptr) {
// Create pending_request object.
grpc_plugin_credentials_pending_request* pending_request =
static_cast<grpc_plugin_credentials_pending_request*>(
gpr_zalloc(sizeof(*pending_request)));
pending_request->creds = c;
pending_request->md_array = md_array;
pending_request->on_request_metadata = on_request_metadata;
pending_request* request =
static_cast<pending_request*>(gpr_zalloc(sizeof(*request)));
request->creds = this;
request->md_array = md_array;
request->on_request_metadata = on_request_metadata;
// Add it to the pending list.
gpr_mu_lock(&c->mu);
if (c->pending_requests != nullptr) {
c->pending_requests->prev = pending_request;
gpr_mu_lock(&mu_);
if (pending_requests_ != nullptr) {
pending_requests_->prev = request;
}
pending_request->next = c->pending_requests;
c->pending_requests = pending_request;
gpr_mu_unlock(&c->mu);
request->next = pending_requests_;
pending_requests_ = request;
gpr_mu_unlock(&mu_);
// Invoke the plugin. The callback holds a ref to us.
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin",
c, pending_request);
this, request);
}
grpc_call_credentials_ref(creds);
Ref().release();
grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX];
size_t num_creds_md = 0;
grpc_status_code status = GRPC_STATUS_OK;
const char* error_details = nullptr;
if (!c->plugin.get_metadata(c->plugin.state, context,
plugin_md_request_metadata_ready,
pending_request, creds_md, &num_creds_md,
&status, &error_details)) {
if (!plugin_.get_metadata(
plugin_.state, context, plugin_md_request_metadata_ready, request,
creds_md, &num_creds_md, &status, &error_details)) {
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p: plugin will return "
"asynchronously",
c, pending_request);
this, request);
}
return false; // Asynchronous return.
}
// Returned synchronously.
// Remove request from pending list if not previously cancelled.
pending_request_complete(pending_request);
request->creds->pending_request_complete(request);
// If the request was cancelled, the error will have been returned
// asynchronously by plugin_cancel_get_request_metadata(), so return
// false. Otherwise, process the result.
if (pending_request->cancelled) {
if (request->cancelled) {
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p was cancelled, error "
"will be returned asynchronously",
c, pending_request);
this, request);
}
retval = false;
} else {
@ -209,10 +200,10 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds,
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p: plugin returned "
"synchronously",
c, pending_request);
this, request);
}
*error = process_plugin_result(pending_request, creds_md, num_creds_md,
status, error_details);
*error = process_plugin_result(request, creds_md, num_creds_md, status,
error_details);
}
// Clean up.
for (size_t i = 0; i < num_creds_md; ++i) {
@ -220,51 +211,42 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds,
grpc_slice_unref_internal(creds_md[i].value);
}
gpr_free((void*)error_details);
gpr_free(pending_request);
gpr_free(request);
}
return retval;
}
static void plugin_cancel_get_request_metadata(
grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
grpc_error* error) {
grpc_plugin_credentials* c =
reinterpret_cast<grpc_plugin_credentials*>(creds);
gpr_mu_lock(&c->mu);
for (grpc_plugin_credentials_pending_request* pending_request =
c->pending_requests;
void grpc_plugin_credentials::cancel_get_request_metadata(
grpc_credentials_mdelem_array* md_array, grpc_error* error) {
gpr_mu_lock(&mu_);
for (pending_request* pending_request = pending_requests_;
pending_request != nullptr; pending_request = pending_request->next) {
if (pending_request->md_array == md_array) {
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", c,
gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", this,
pending_request);
}
pending_request->cancelled = true;
GRPC_CLOSURE_SCHED(pending_request->on_request_metadata,
GRPC_ERROR_REF(error));
pending_request_remove_locked(c, pending_request);
pending_request_remove_locked(pending_request);
break;
}
}
gpr_mu_unlock(&c->mu);
gpr_mu_unlock(&mu_);
GRPC_ERROR_UNREF(error);
}
static grpc_call_credentials_vtable plugin_vtable = {
plugin_destruct, plugin_get_request_metadata,
plugin_cancel_get_request_metadata};
grpc_plugin_credentials::grpc_plugin_credentials(
grpc_metadata_credentials_plugin plugin)
: grpc_call_credentials(plugin.type), plugin_(plugin) {
gpr_mu_init(&mu_);
}
grpc_call_credentials* grpc_metadata_credentials_create_from_plugin(
grpc_metadata_credentials_plugin plugin, void* reserved) {
grpc_plugin_credentials* c =
static_cast<grpc_plugin_credentials*>(gpr_zalloc(sizeof(*c)));
GRPC_API_TRACE("grpc_metadata_credentials_create_from_plugin(reserved=%p)", 1,
(reserved));
GPR_ASSERT(reserved == nullptr);
c->base.type = plugin.type;
c->base.vtable = &plugin_vtable;
gpr_ref_init(&c->base.refcount, 1);
c->plugin = plugin;
gpr_mu_init(&c->mu);
return &c->base;
return grpc_core::New<grpc_plugin_credentials>(plugin);
}

@ -25,22 +25,45 @@
extern grpc_core::TraceFlag grpc_plugin_credentials_trace;
struct grpc_plugin_credentials;
typedef struct grpc_plugin_credentials_pending_request {
bool cancelled;
struct grpc_plugin_credentials* creds;
grpc_credentials_mdelem_array* md_array;
grpc_closure* on_request_metadata;
struct grpc_plugin_credentials_pending_request* prev;
struct grpc_plugin_credentials_pending_request* next;
} grpc_plugin_credentials_pending_request;
typedef struct grpc_plugin_credentials {
grpc_call_credentials base;
grpc_metadata_credentials_plugin plugin;
gpr_mu mu;
grpc_plugin_credentials_pending_request* pending_requests;
} grpc_plugin_credentials;
// This type is forward declared as a C struct and we cannot define it as a
// class. Otherwise, compiler will complain about type mismatch due to
// -Wmismatched-tags.
struct grpc_plugin_credentials final : public grpc_call_credentials {
public:
struct pending_request {
bool cancelled;
struct grpc_plugin_credentials* creds;
grpc_credentials_mdelem_array* md_array;
grpc_closure* on_request_metadata;
struct pending_request* prev;
struct pending_request* next;
};
explicit grpc_plugin_credentials(grpc_metadata_credentials_plugin plugin);
~grpc_plugin_credentials() override;
bool get_request_metadata(grpc_polling_entity* pollent,
grpc_auth_metadata_context context,
grpc_credentials_mdelem_array* md_array,
grpc_closure* on_request_metadata,
grpc_error** error) override;
void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
grpc_error* error) override;
// Checks if the request has been cancelled.
// If not, removes it from the pending list, so that it cannot be
// cancelled out from under us.
// When this returns, r->cancelled indicates whether the request was
// cancelled before completion.
void pending_request_complete(pending_request* r);
private:
void pending_request_remove_locked(pending_request* pending_request);
grpc_metadata_credentials_plugin plugin_;
gpr_mu mu_;
pending_request* pending_requests_ = nullptr;
};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_PLUGIN_PLUGIN_CREDENTIALS_H */

@ -44,22 +44,27 @@ void grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_ssl_pem_key_cert_pair* kp,
gpr_free(kp);
}
static void ssl_destruct(grpc_channel_credentials* creds) {
grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds);
gpr_free(c->config.pem_root_certs);
grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pair, 1);
if (c->config.verify_options.verify_peer_destruct != nullptr) {
c->config.verify_options.verify_peer_destruct(
c->config.verify_options.verify_peer_callback_userdata);
grpc_ssl_credentials::grpc_ssl_credentials(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options)
: grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) {
build_config(pem_root_certs, pem_key_cert_pair, verify_options);
}
grpc_ssl_credentials::~grpc_ssl_credentials() {
gpr_free(config_.pem_root_certs);
grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pair, 1);
if (config_.verify_options.verify_peer_destruct != nullptr) {
config_.verify_options.verify_peer_destruct(
config_.verify_options.verify_peer_callback_userdata);
}
}
static grpc_security_status ssl_create_security_connector(
grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_ssl_credentials::create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds);
grpc_security_status status = GRPC_SECURITY_OK;
grpc_channel_args** new_args) {
const char* overridden_target_name = nullptr;
tsi_ssl_session_cache* ssl_session_cache = nullptr;
for (size_t i = 0; args && i < args->num_args; i++) {
@ -74,52 +79,47 @@ static grpc_security_status ssl_create_security_connector(
static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p);
}
}
status = grpc_ssl_channel_security_connector_create(
creds, call_creds, &c->config, target, overridden_target_name,
ssl_session_cache, sc);
if (status != GRPC_SECURITY_OK) {
return status;
grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
grpc_ssl_channel_security_connector_create(
this->Ref(), std::move(call_creds), &config_, target,
overridden_target_name, ssl_session_cache);
if (sc == nullptr) {
return sc;
}
grpc_arg new_arg = grpc_channel_arg_string_create(
(char*)GRPC_ARG_HTTP2_SCHEME, (char*)"https");
*new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1);
return status;
return sc;
}
static grpc_channel_credentials_vtable ssl_vtable = {
ssl_destruct, ssl_create_security_connector, nullptr};
static void ssl_build_config(const char* pem_root_certs,
grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options,
grpc_ssl_config* config) {
if (pem_root_certs != nullptr) {
config->pem_root_certs = gpr_strdup(pem_root_certs);
}
void grpc_ssl_credentials::build_config(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options) {
config_.pem_root_certs = gpr_strdup(pem_root_certs);
if (pem_key_cert_pair != nullptr) {
GPR_ASSERT(pem_key_cert_pair->private_key != nullptr);
GPR_ASSERT(pem_key_cert_pair->cert_chain != nullptr);
config->pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>(
config_.pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>(
gpr_zalloc(sizeof(tsi_ssl_pem_key_cert_pair)));
config->pem_key_cert_pair->cert_chain =
config_.pem_key_cert_pair->cert_chain =
gpr_strdup(pem_key_cert_pair->cert_chain);
config->pem_key_cert_pair->private_key =
config_.pem_key_cert_pair->private_key =
gpr_strdup(pem_key_cert_pair->private_key);
} else {
config_.pem_key_cert_pair = nullptr;
}
if (verify_options != nullptr) {
memcpy(&config->verify_options, verify_options,
memcpy(&config_.verify_options, verify_options,
sizeof(verify_peer_options));
} else {
// Otherwise set all options to default values
memset(&config->verify_options, 0, sizeof(verify_peer_options));
memset(&config_.verify_options, 0, sizeof(verify_peer_options));
}
}
grpc_channel_credentials* grpc_ssl_credentials_create(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options, void* reserved) {
grpc_ssl_credentials* c = static_cast<grpc_ssl_credentials*>(
gpr_zalloc(sizeof(grpc_ssl_credentials)));
GRPC_API_TRACE(
"grpc_ssl_credentials_create(pem_root_certs=%s, "
"pem_key_cert_pair=%p, "
@ -127,12 +127,9 @@ grpc_channel_credentials* grpc_ssl_credentials_create(
"reserved=%p)",
4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved));
GPR_ASSERT(reserved == nullptr);
c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL;
c->base.vtable = &ssl_vtable;
gpr_ref_init(&c->base.refcount, 1);
ssl_build_config(pem_root_certs, pem_key_cert_pair, verify_options,
&c->config);
return &c->base;
return grpc_core::New<grpc_ssl_credentials>(pem_root_certs, pem_key_cert_pair,
verify_options);
}
//
@ -145,21 +142,29 @@ struct grpc_ssl_server_credentials_options {
grpc_ssl_server_certificate_config_fetcher* certificate_config_fetcher;
};
static void ssl_server_destruct(grpc_server_credentials* creds) {
grpc_ssl_server_credentials* c =
reinterpret_cast<grpc_ssl_server_credentials*>(creds);
grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pairs,
c->config.num_key_cert_pairs);
gpr_free(c->config.pem_root_certs);
grpc_ssl_server_credentials::grpc_ssl_server_credentials(
const grpc_ssl_server_credentials_options& options)
: grpc_server_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) {
if (options.certificate_config_fetcher != nullptr) {
config_.client_certificate_request = options.client_certificate_request;
certificate_config_fetcher_ = *options.certificate_config_fetcher;
} else {
build_config(options.certificate_config->pem_root_certs,
options.certificate_config->pem_key_cert_pairs,
options.certificate_config->num_key_cert_pairs,
options.client_certificate_request);
}
}
static grpc_security_status ssl_server_create_security_connector(
grpc_server_credentials* creds, grpc_server_security_connector** sc) {
return grpc_ssl_server_security_connector_create(creds, sc);
grpc_ssl_server_credentials::~grpc_ssl_server_credentials() {
grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pairs,
config_.num_key_cert_pairs);
gpr_free(config_.pem_root_certs);
}
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_ssl_server_credentials::create_security_connector() {
return grpc_ssl_server_security_connector_create(this->Ref());
}
static grpc_server_credentials_vtable ssl_server_vtable = {
ssl_server_destruct, ssl_server_create_security_connector};
tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs(
const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs,
@ -179,18 +184,15 @@ tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs(
return tsi_pairs;
}
static void ssl_build_server_config(
void grpc_ssl_server_credentials::build_config(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs,
size_t num_key_cert_pairs,
grpc_ssl_client_certificate_request_type client_certificate_request,
grpc_ssl_server_config* config) {
config->client_certificate_request = client_certificate_request;
if (pem_root_certs != nullptr) {
config->pem_root_certs = gpr_strdup(pem_root_certs);
}
config->pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
grpc_ssl_client_certificate_request_type client_certificate_request) {
config_.client_certificate_request = client_certificate_request;
config_.pem_root_certs = gpr_strdup(pem_root_certs);
config_.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
pem_key_cert_pairs, num_key_cert_pairs);
config->num_key_cert_pairs = num_key_cert_pairs;
config_.num_key_cert_pairs = num_key_cert_pairs;
}
grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create(
@ -200,9 +202,7 @@ grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create(
grpc_ssl_server_certificate_config* config =
static_cast<grpc_ssl_server_certificate_config*>(
gpr_zalloc(sizeof(grpc_ssl_server_certificate_config)));
if (pem_root_certs != nullptr) {
config->pem_root_certs = gpr_strdup(pem_root_certs);
}
config->pem_root_certs = gpr_strdup(pem_root_certs);
if (num_key_cert_pairs > 0) {
GPR_ASSERT(pem_key_cert_pairs != nullptr);
config->pem_key_cert_pairs = static_cast<grpc_ssl_pem_key_cert_pair*>(
@ -311,7 +311,6 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_ex(
grpc_server_credentials* grpc_ssl_server_credentials_create_with_options(
grpc_ssl_server_credentials_options* options) {
grpc_server_credentials* retval = nullptr;
grpc_ssl_server_credentials* c = nullptr;
if (options == nullptr) {
gpr_log(GPR_ERROR,
@ -331,23 +330,7 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_with_options(
goto done;
}
c = static_cast<grpc_ssl_server_credentials*>(
gpr_zalloc(sizeof(grpc_ssl_server_credentials)));
c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL;
gpr_ref_init(&c->base.refcount, 1);
c->base.vtable = &ssl_server_vtable;
if (options->certificate_config_fetcher != nullptr) {
c->config.client_certificate_request = options->client_certificate_request;
c->certificate_config_fetcher = *options->certificate_config_fetcher;
} else {
ssl_build_server_config(options->certificate_config->pem_root_certs,
options->certificate_config->pem_key_cert_pairs,
options->certificate_config->num_key_cert_pairs,
options->client_certificate_request, &c->config);
}
retval = &c->base;
retval = grpc_core::New<grpc_ssl_server_credentials>(*options);
done:
grpc_ssl_server_credentials_options_destroy(options);

@ -24,27 +24,70 @@
#include "src/core/lib/security/security_connector/ssl/ssl_security_connector.h"
typedef struct {
grpc_channel_credentials base;
grpc_ssl_config config;
} grpc_ssl_credentials;
class grpc_ssl_credentials : public grpc_channel_credentials {
public:
grpc_ssl_credentials(const char* pem_root_certs,
grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options);
~grpc_ssl_credentials() override;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_args** new_args) override;
private:
void build_config(const char* pem_root_certs,
grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options);
grpc_ssl_config config_;
};
struct grpc_ssl_server_certificate_config {
grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs;
size_t num_key_cert_pairs;
char* pem_root_certs;
grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr;
size_t num_key_cert_pairs = 0;
char* pem_root_certs = nullptr;
};
typedef struct {
grpc_ssl_server_certificate_config_callback cb;
struct grpc_ssl_server_certificate_config_fetcher {
grpc_ssl_server_certificate_config_callback cb = nullptr;
void* user_data;
} grpc_ssl_server_certificate_config_fetcher;
};
class grpc_ssl_server_credentials final : public grpc_server_credentials {
public:
grpc_ssl_server_credentials(
const grpc_ssl_server_credentials_options& options);
~grpc_ssl_server_credentials() override;
typedef struct {
grpc_server_credentials base;
grpc_ssl_server_config config;
grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher;
} grpc_ssl_server_credentials;
grpc_core::RefCountedPtr<grpc_server_security_connector>
create_security_connector() override;
bool has_cert_config_fetcher() const {
return certificate_config_fetcher_.cb != nullptr;
}
grpc_ssl_certificate_config_reload_status FetchCertConfig(
grpc_ssl_server_certificate_config** config) {
GPR_DEBUG_ASSERT(has_cert_config_fetcher());
return certificate_config_fetcher_.cb(certificate_config_fetcher_.user_data,
config);
}
const grpc_ssl_server_config& config() const { return config_; }
private:
void build_config(
const char* pem_root_certs,
grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs,
grpc_ssl_client_certificate_request_type client_certificate_request);
grpc_ssl_server_config config_;
grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher_;
};
tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs(
const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs,

@ -28,6 +28,7 @@
#include <grpc/support/log.h>
#include <grpc/support/string_util.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/alts/alts_credentials.h"
#include "src/core/lib/security/transport/security_handshaker.h"
#include "src/core/lib/slice/slice_internal.h"
@ -35,64 +36,9 @@
#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
#include "src/core/tsi/transport_security.h"
typedef struct {
grpc_channel_security_connector base;
char* target_name;
} grpc_alts_channel_security_connector;
namespace {
typedef struct {
grpc_server_security_connector base;
} grpc_alts_server_security_connector;
static void alts_channel_destroy(grpc_security_connector* sc) {
if (sc == nullptr) {
return;
}
auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
grpc_call_credentials_unref(c->base.request_metadata_creds);
grpc_channel_credentials_unref(c->base.channel_creds);
gpr_free(c->target_name);
gpr_free(sc);
}
static void alts_server_destroy(grpc_security_connector* sc) {
if (sc == nullptr) {
return;
}
auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc);
grpc_server_credentials_unref(c->base.server_creds);
gpr_free(sc);
}
static void alts_channel_add_handshakers(
grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) {
tsi_handshaker* handshaker = nullptr;
auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
grpc_alts_credentials* creds =
reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds);
GPR_ASSERT(alts_tsi_handshaker_create(
creds->options, c->target_name, creds->handshaker_service_url,
true, interested_parties, &handshaker) == TSI_OK);
grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
handshaker, &sc->base));
}
static void alts_server_add_handshakers(
grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) {
tsi_handshaker* handshaker = nullptr;
auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc);
grpc_alts_server_credentials* creds =
reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds);
GPR_ASSERT(alts_tsi_handshaker_create(
creds->options, nullptr, creds->handshaker_service_url, false,
interested_parties, &handshaker) == TSI_OK);
grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
handshaker, &sc->base));
}
static void alts_set_rpc_protocol_versions(
void alts_set_rpc_protocol_versions(
grpc_gcp_rpc_protocol_versions* rpc_versions) {
grpc_gcp_rpc_protocol_versions_set_max(rpc_versions,
GRPC_PROTOCOL_VERSION_MAX_MAJOR,
@ -102,17 +48,131 @@ static void alts_set_rpc_protocol_versions(
GRPC_PROTOCOL_VERSION_MIN_MINOR);
}
void atls_check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) {
*auth_context =
grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer);
tsi_peer_destruct(&peer);
grpc_error* error =
*auth_context != nullptr
? GRPC_ERROR_NONE
: GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Could not get ALTS auth context from TSI peer");
GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
class grpc_alts_channel_security_connector final
: public grpc_channel_security_connector {
public:
grpc_alts_channel_security_connector(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target_name)
: grpc_channel_security_connector(/*url_scheme=*/nullptr,
std::move(channel_creds),
std::move(request_metadata_creds)),
target_name_(gpr_strdup(target_name)) {
grpc_alts_credentials* creds =
static_cast<grpc_alts_credentials*>(mutable_channel_creds());
alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
}
~grpc_alts_channel_security_connector() override { gpr_free(target_name_); }
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr;
const grpc_alts_credentials* creds =
static_cast<const grpc_alts_credentials*>(channel_creds());
GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_,
creds->handshaker_service_url(), true,
interested_parties,
&handshaker) == TSI_OK);
grpc_handshake_manager_add(
handshake_manager, grpc_security_handshaker_create(handshaker, this));
}
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
atls_check_peer(peer, auth_context, on_peer_checked);
}
int cmp(const grpc_security_connector* other_sc) const override {
auto* other =
reinterpret_cast<const grpc_alts_channel_security_connector*>(other_sc);
int c = channel_security_connector_cmp(other);
if (c != 0) return c;
return strcmp(target_name_, other->target_name_);
}
bool check_call_host(const char* host, grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) override {
if (host == nullptr || strcmp(host, target_name_) != 0) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"ALTS call host does not match target name");
}
return true;
}
void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) override {
GRPC_ERROR_UNREF(error);
}
private:
char* target_name_;
};
class grpc_alts_server_security_connector final
: public grpc_server_security_connector {
public:
grpc_alts_server_security_connector(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
: grpc_server_security_connector(/*url_scheme=*/nullptr,
std::move(server_creds)) {
grpc_alts_server_credentials* creds =
reinterpret_cast<grpc_alts_server_credentials*>(mutable_server_creds());
alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
}
~grpc_alts_server_security_connector() override = default;
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr;
const grpc_alts_server_credentials* creds =
static_cast<const grpc_alts_server_credentials*>(server_creds());
GPR_ASSERT(alts_tsi_handshaker_create(
creds->options(), nullptr, creds->handshaker_service_url(),
false, interested_parties, &handshaker) == TSI_OK);
grpc_handshake_manager_add(
handshake_manager, grpc_security_handshaker_create(handshaker, this));
}
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
atls_check_peer(peer, auth_context, on_peer_checked);
}
int cmp(const grpc_security_connector* other) const override {
return server_security_connector_cmp(
static_cast<const grpc_server_security_connector*>(other));
}
};
} // namespace
namespace grpc_core {
namespace internal {
grpc_security_status grpc_alts_auth_context_from_tsi_peer(
const tsi_peer* peer, grpc_auth_context** ctx) {
if (peer == nullptr || ctx == nullptr) {
grpc_core::RefCountedPtr<grpc_auth_context>
grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) {
if (peer == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to grpc_alts_auth_context_from_tsi_peer()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
*ctx = nullptr;
/* Validate certificate type. */
const tsi_peer_property* cert_type_prop =
tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY);
@ -120,14 +180,14 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE,
cert_type_prop->value.length) != 0) {
gpr_log(GPR_ERROR, "Invalid or missing certificate type property.");
return GRPC_SECURITY_ERROR;
return nullptr;
}
/* Validate RPC protocol versions. */
const tsi_peer_property* rpc_versions_prop =
tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS);
if (rpc_versions_prop == nullptr) {
gpr_log(GPR_ERROR, "Missing rpc protocol versions property.");
return GRPC_SECURITY_ERROR;
return nullptr;
}
grpc_gcp_rpc_protocol_versions local_versions, peer_versions;
alts_set_rpc_protocol_versions(&local_versions);
@ -138,19 +198,19 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
grpc_slice_unref_internal(slice);
if (!decode_result) {
gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions.");
return GRPC_SECURITY_ERROR;
return nullptr;
}
/* TODO: Pass highest common rpc protocol version to grpc caller. */
bool check_result = grpc_gcp_rpc_protocol_versions_check(
&local_versions, &peer_versions, nullptr);
if (!check_result) {
gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions.");
return GRPC_SECURITY_ERROR;
return nullptr;
}
/* Create auth context. */
*ctx = grpc_auth_context_create(nullptr);
auto ctx = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
*ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_ALTS_TRANSPORT_SECURITY_TYPE);
size_t i = 0;
for (i = 0; i < peer->property_count; i++) {
@ -158,132 +218,47 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
/* Add service account to auth context. */
if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) {
grpc_auth_context_add_property(
*ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, tsi_prop->value.data,
tsi_prop->value.length);
ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY,
tsi_prop->value.data, tsi_prop->value.length);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
*ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1);
ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1);
}
}
if (!grpc_auth_context_peer_is_authenticated(*ctx)) {
if (!grpc_auth_context_peer_is_authenticated(ctx.get())) {
gpr_log(GPR_ERROR, "Invalid unauthenticated peer.");
GRPC_AUTH_CONTEXT_UNREF(*ctx, "test");
*ctx = nullptr;
return GRPC_SECURITY_ERROR;
ctx.reset(DEBUG_LOCATION, "test");
return nullptr;
}
return GRPC_SECURITY_OK;
return ctx;
}
} // namespace internal
} // namespace grpc_core
static void alts_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
grpc_security_status status;
status = grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(
&peer, auth_context);
tsi_peer_destruct(&peer);
grpc_error* error =
status == GRPC_SECURITY_OK
? GRPC_ERROR_NONE
: GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Could not get ALTS auth context from TSI peer");
GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
static int alts_channel_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_alts_channel_security_connector* c1 =
reinterpret_cast<grpc_alts_channel_security_connector*>(sc1);
grpc_alts_channel_security_connector* c2 =
reinterpret_cast<grpc_alts_channel_security_connector*>(sc2);
int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
if (c != 0) return c;
return strcmp(c1->target_name, c2->target_name);
}
static int alts_server_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_alts_server_security_connector* c1 =
reinterpret_cast<grpc_alts_server_security_connector*>(sc1);
grpc_alts_server_security_connector* c2 =
reinterpret_cast<grpc_alts_server_security_connector*>(sc2);
return grpc_server_security_connector_cmp(&c1->base, &c2->base);
}
static grpc_security_connector_vtable alts_channel_vtable = {
alts_channel_destroy, alts_check_peer, alts_channel_cmp};
static grpc_security_connector_vtable alts_server_vtable = {
alts_server_destroy, alts_check_peer, alts_server_cmp};
static bool alts_check_call_host(grpc_channel_security_connector* sc,
const char* host,
grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) {
grpc_alts_channel_security_connector* alts_sc =
reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
if (host == nullptr || alts_sc == nullptr ||
strcmp(host, alts_sc->target_name) != 0) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"ALTS call host does not match target name");
}
return true;
}
static void alts_cancel_check_call_host(grpc_channel_security_connector* sc,
grpc_closure* on_call_host_checked,
grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
grpc_security_status grpc_alts_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds, const char* target_name,
grpc_channel_security_connector** sc) {
if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) {
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_alts_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target_name) {
if (channel_creds == nullptr || target_name == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_alts_channel_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
auto c = static_cast<grpc_alts_channel_security_connector*>(
gpr_zalloc(sizeof(grpc_alts_channel_security_connector)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.vtable = &alts_channel_vtable;
c->base.add_handshakers = alts_channel_add_handshakers;
c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
c->base.request_metadata_creds =
grpc_call_credentials_ref(request_metadata_creds);
c->base.check_call_host = alts_check_call_host;
c->base.cancel_check_call_host = alts_cancel_check_call_host;
grpc_alts_credentials* creds =
reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds);
alts_set_rpc_protocol_versions(&creds->options->rpc_versions);
c->target_name = gpr_strdup(target_name);
*sc = &c->base;
return GRPC_SECURITY_OK;
return grpc_core::MakeRefCounted<grpc_alts_channel_security_connector>(
std::move(channel_creds), std::move(request_metadata_creds), target_name);
}
grpc_security_status grpc_alts_server_security_connector_create(
grpc_server_credentials* server_creds,
grpc_server_security_connector** sc) {
if (server_creds == nullptr || sc == nullptr) {
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_alts_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
if (server_creds == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_alts_server_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
auto c = static_cast<grpc_alts_server_security_connector*>(
gpr_zalloc(sizeof(grpc_alts_server_security_connector)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.vtable = &alts_server_vtable;
c->base.server_creds = grpc_server_credentials_ref(server_creds);
c->base.add_handshakers = alts_server_add_handshakers;
grpc_alts_server_credentials* creds =
reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds);
alts_set_rpc_protocol_versions(&creds->options->rpc_versions);
*sc = &c->base;
return GRPC_SECURITY_OK;
return grpc_core::MakeRefCounted<grpc_alts_server_security_connector>(
std::move(server_creds));
}

@ -36,12 +36,13 @@
* - sc: address of ALTS channel security connector instance to be returned from
* the method.
*
* It returns GRPC_SECURITY_OK on success, and an error stauts code on failure.
* It returns nullptr on failure.
*/
grpc_security_status grpc_alts_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds, const char* target_name,
grpc_channel_security_connector** sc);
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_alts_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target_name);
/**
* This method creates an ALTS server security connector.
@ -50,17 +51,18 @@ grpc_security_status grpc_alts_channel_security_connector_create(
* - sc: address of ALTS server security connector instance to be returned from
* the method.
*
* It returns GRPC_SECURITY_OK on success, and an error status code on failure.
* It returns nullptr on failure.
*/
grpc_security_status grpc_alts_server_security_connector_create(
grpc_server_credentials* server_creds, grpc_server_security_connector** sc);
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_alts_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
namespace grpc_core {
namespace internal {
/* Exposed only for testing. */
grpc_security_status grpc_alts_auth_context_from_tsi_peer(
const tsi_peer* peer, grpc_auth_context** ctx);
grpc_core::RefCountedPtr<grpc_auth_context>
grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer);
} // namespace internal
} // namespace grpc_core

@ -31,6 +31,7 @@
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gpr/host_port.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/fake/fake_credentials.h"
@ -38,91 +39,183 @@
#include "src/core/lib/security/transport/target_authority_table.h"
#include "src/core/tsi/fake_transport_security.h"
typedef struct {
grpc_channel_security_connector base;
char* target;
char* expected_targets;
bool is_lb_channel;
char* target_name_override;
} grpc_fake_channel_security_connector;
namespace {
class grpc_fake_channel_security_connector final
: public grpc_channel_security_connector {
public:
grpc_fake_channel_security_connector(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target, const grpc_channel_args* args)
: grpc_channel_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME,
std::move(channel_creds),
std::move(request_metadata_creds)),
target_(gpr_strdup(target)),
expected_targets_(
gpr_strdup(grpc_fake_transport_get_expected_targets(args))),
is_lb_channel_(grpc_core::FindTargetAuthorityTableInArgs(args) !=
nullptr) {
const grpc_arg* target_name_override_arg =
grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG);
if (target_name_override_arg != nullptr) {
target_name_override_ =
gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg));
} else {
target_name_override_ = nullptr;
}
}
static void fake_channel_destroy(grpc_security_connector* sc) {
grpc_fake_channel_security_connector* c =
reinterpret_cast<grpc_fake_channel_security_connector*>(sc);
grpc_call_credentials_unref(c->base.request_metadata_creds);
gpr_free(c->target);
gpr_free(c->expected_targets);
gpr_free(c->target_name_override);
gpr_free(c);
}
~grpc_fake_channel_security_connector() override {
gpr_free(target_);
gpr_free(expected_targets_);
if (target_name_override_ != nullptr) gpr_free(target_name_override_);
}
static void fake_server_destroy(grpc_security_connector* sc) { gpr_free(sc); }
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override;
static bool fake_check_target(const char* target_type, const char* target,
const char* set_str) {
GPR_ASSERT(target_type != nullptr);
GPR_ASSERT(target != nullptr);
char** set = nullptr;
size_t set_size = 0;
gpr_string_split(set_str, ",", &set, &set_size);
bool found = false;
for (size_t i = 0; i < set_size; ++i) {
if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true;
int cmp(const grpc_security_connector* other_sc) const override {
auto* other =
reinterpret_cast<const grpc_fake_channel_security_connector*>(other_sc);
int c = channel_security_connector_cmp(other);
if (c != 0) return c;
c = strcmp(target_, other->target_);
if (c != 0) return c;
if (expected_targets_ == nullptr || other->expected_targets_ == nullptr) {
c = GPR_ICMP(expected_targets_, other->expected_targets_);
} else {
c = strcmp(expected_targets_, other->expected_targets_);
}
if (c != 0) return c;
return GPR_ICMP(is_lb_channel_, other->is_lb_channel_);
}
for (size_t i = 0; i < set_size; ++i) {
gpr_free(set[i]);
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override {
grpc_handshake_manager_add(
handshake_mgr,
grpc_security_handshaker_create(
tsi_create_fake_handshaker(/*is_client=*/true), this));
}
gpr_free(set);
return found;
}
static void fake_secure_name_check(const char* target,
const char* expected_targets,
bool is_lb_channel) {
if (expected_targets == nullptr) return;
char** lbs_and_backends = nullptr;
size_t lbs_and_backends_size = 0;
bool success = false;
gpr_string_split(expected_targets, ";", &lbs_and_backends,
&lbs_and_backends_size);
if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) {
gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'",
expected_targets);
goto done;
bool check_call_host(const char* host, grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) override {
char* authority_hostname = nullptr;
char* authority_ignored_port = nullptr;
char* target_hostname = nullptr;
char* target_ignored_port = nullptr;
gpr_split_host_port(host, &authority_hostname, &authority_ignored_port);
gpr_split_host_port(target_, &target_hostname, &target_ignored_port);
if (target_name_override_ != nullptr) {
char* fake_security_target_name_override_hostname = nullptr;
char* fake_security_target_name_override_ignored_port = nullptr;
gpr_split_host_port(target_name_override_,
&fake_security_target_name_override_hostname,
&fake_security_target_name_override_ignored_port);
if (strcmp(authority_hostname,
fake_security_target_name_override_hostname) != 0) {
gpr_log(GPR_ERROR,
"Authority (host) '%s' != Fake Security Target override '%s'",
host, fake_security_target_name_override_hostname);
abort();
}
gpr_free(fake_security_target_name_override_hostname);
gpr_free(fake_security_target_name_override_ignored_port);
} else if (strcmp(authority_hostname, target_hostname) != 0) {
gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'",
authority_hostname, target_hostname);
abort();
}
gpr_free(authority_hostname);
gpr_free(authority_ignored_port);
gpr_free(target_hostname);
gpr_free(target_ignored_port);
return true;
}
if (is_lb_channel) {
if (lbs_and_backends_size != 2) {
gpr_log(GPR_ERROR,
"Invalid expected targets arg value: '%s'. Expectations for LB "
"channels must be of the form 'be1,be2,be3,...;lb1,lb2,...",
expected_targets);
goto done;
void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) override {
GRPC_ERROR_UNREF(error);
}
char* target() const { return target_; }
char* expected_targets() const { return expected_targets_; }
bool is_lb_channel() const { return is_lb_channel_; }
char* target_name_override() const { return target_name_override_; }
private:
bool fake_check_target(const char* target_type, const char* target,
const char* set_str) const {
GPR_ASSERT(target_type != nullptr);
GPR_ASSERT(target != nullptr);
char** set = nullptr;
size_t set_size = 0;
gpr_string_split(set_str, ",", &set, &set_size);
bool found = false;
for (size_t i = 0; i < set_size; ++i) {
if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true;
}
if (!fake_check_target("LB", target, lbs_and_backends[1])) {
gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'",
target, lbs_and_backends[1]);
goto done;
for (size_t i = 0; i < set_size; ++i) {
gpr_free(set[i]);
}
success = true;
} else {
if (!fake_check_target("Backend", target, lbs_and_backends[0])) {
gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'",
target, lbs_and_backends[0]);
gpr_free(set);
return found;
}
void fake_secure_name_check() const {
if (expected_targets_ == nullptr) return;
char** lbs_and_backends = nullptr;
size_t lbs_and_backends_size = 0;
bool success = false;
gpr_string_split(expected_targets_, ";", &lbs_and_backends,
&lbs_and_backends_size);
if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) {
gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'",
expected_targets_);
goto done;
}
success = true;
}
done:
for (size_t i = 0; i < lbs_and_backends_size; ++i) {
gpr_free(lbs_and_backends[i]);
if (is_lb_channel_) {
if (lbs_and_backends_size != 2) {
gpr_log(GPR_ERROR,
"Invalid expected targets arg value: '%s'. Expectations for LB "
"channels must be of the form 'be1,be2,be3,...;lb1,lb2,...",
expected_targets_);
goto done;
}
if (!fake_check_target("LB", target_, lbs_and_backends[1])) {
gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'",
target_, lbs_and_backends[1]);
goto done;
}
success = true;
} else {
if (!fake_check_target("Backend", target_, lbs_and_backends[0])) {
gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'",
target_, lbs_and_backends[0]);
goto done;
}
success = true;
}
done:
for (size_t i = 0; i < lbs_and_backends_size; ++i) {
gpr_free(lbs_and_backends[i]);
}
gpr_free(lbs_and_backends);
if (!success) abort();
}
gpr_free(lbs_and_backends);
if (!success) abort();
}
static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
char* target_;
char* expected_targets_;
bool is_lb_channel_;
char* target_name_override_;
};
static void fake_check_peer(
grpc_security_connector* sc, tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) {
const char* prop_name;
grpc_error* error = GRPC_ERROR_NONE;
*auth_context = nullptr;
@ -147,164 +240,65 @@ static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer,
"Invalid value for cert type property.");
goto end;
}
*auth_context = grpc_auth_context_create(nullptr);
*auth_context = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
*auth_context, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
auth_context->get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_FAKE_TRANSPORT_SECURITY_TYPE);
end:
GRPC_CLOSURE_SCHED(on_peer_checked, error);
tsi_peer_destruct(&peer);
}
static void fake_channel_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
fake_check_peer(sc, peer, auth_context, on_peer_checked);
grpc_fake_channel_security_connector* c =
reinterpret_cast<grpc_fake_channel_security_connector*>(sc);
fake_secure_name_check(c->target, c->expected_targets, c->is_lb_channel);
void grpc_fake_channel_security_connector::check_peer(
tsi_peer peer, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) {
fake_check_peer(this, peer, auth_context, on_peer_checked);
fake_secure_name_check();
}
static void fake_server_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
fake_check_peer(sc, peer, auth_context, on_peer_checked);
}
class grpc_fake_server_security_connector
: public grpc_server_security_connector {
public:
grpc_fake_server_security_connector(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
: grpc_server_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME,
std::move(server_creds)) {}
~grpc_fake_server_security_connector() override = default;
static int fake_channel_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_fake_channel_security_connector* c1 =
reinterpret_cast<grpc_fake_channel_security_connector*>(sc1);
grpc_fake_channel_security_connector* c2 =
reinterpret_cast<grpc_fake_channel_security_connector*>(sc2);
int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
if (c != 0) return c;
c = strcmp(c1->target, c2->target);
if (c != 0) return c;
if (c1->expected_targets == nullptr || c2->expected_targets == nullptr) {
c = GPR_ICMP(c1->expected_targets, c2->expected_targets);
} else {
c = strcmp(c1->expected_targets, c2->expected_targets);
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
fake_check_peer(this, peer, auth_context, on_peer_checked);
}
if (c != 0) return c;
return GPR_ICMP(c1->is_lb_channel, c2->is_lb_channel);
}
static int fake_server_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
return grpc_server_security_connector_cmp(
reinterpret_cast<grpc_server_security_connector*>(sc1),
reinterpret_cast<grpc_server_security_connector*>(sc2));
}
static bool fake_channel_check_call_host(grpc_channel_security_connector* sc,
const char* host,
grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) {
grpc_fake_channel_security_connector* c =
reinterpret_cast<grpc_fake_channel_security_connector*>(sc);
char* authority_hostname = nullptr;
char* authority_ignored_port = nullptr;
char* target_hostname = nullptr;
char* target_ignored_port = nullptr;
gpr_split_host_port(host, &authority_hostname, &authority_ignored_port);
gpr_split_host_port(c->target, &target_hostname, &target_ignored_port);
if (c->target_name_override != nullptr) {
char* fake_security_target_name_override_hostname = nullptr;
char* fake_security_target_name_override_ignored_port = nullptr;
gpr_split_host_port(c->target_name_override,
&fake_security_target_name_override_hostname,
&fake_security_target_name_override_ignored_port);
if (strcmp(authority_hostname,
fake_security_target_name_override_hostname) != 0) {
gpr_log(GPR_ERROR,
"Authority (host) '%s' != Fake Security Target override '%s'",
host, fake_security_target_name_override_hostname);
abort();
}
gpr_free(fake_security_target_name_override_hostname);
gpr_free(fake_security_target_name_override_ignored_port);
} else if (strcmp(authority_hostname, target_hostname) != 0) {
gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'",
authority_hostname, target_hostname);
abort();
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override {
grpc_handshake_manager_add(
handshake_mgr,
grpc_security_handshaker_create(
tsi_create_fake_handshaker(/*=is_client*/ false), this));
}
gpr_free(authority_hostname);
gpr_free(authority_ignored_port);
gpr_free(target_hostname);
gpr_free(target_ignored_port);
return true;
}
static void fake_channel_cancel_check_call_host(
grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
static void fake_channel_add_handshakers(
grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
grpc_handshake_manager_add(
handshake_mgr,
grpc_security_handshaker_create(
tsi_create_fake_handshaker(true /* is_client */), &sc->base));
}
static void fake_server_add_handshakers(grpc_server_security_connector* sc,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
grpc_handshake_manager_add(
handshake_mgr,
grpc_security_handshaker_create(
tsi_create_fake_handshaker(false /* is_client */), &sc->base));
}
static grpc_security_connector_vtable fake_channel_vtable = {
fake_channel_destroy, fake_channel_check_peer, fake_channel_cmp};
static grpc_security_connector_vtable fake_server_vtable = {
fake_server_destroy, fake_server_check_peer, fake_server_cmp};
grpc_channel_security_connector* grpc_fake_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds, const char* target,
const grpc_channel_args* args) {
grpc_fake_channel_security_connector* c =
static_cast<grpc_fake_channel_security_connector*>(
gpr_zalloc(sizeof(*c)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME;
c->base.base.vtable = &fake_channel_vtable;
c->base.channel_creds = channel_creds;
c->base.request_metadata_creds =
grpc_call_credentials_ref(request_metadata_creds);
c->base.check_call_host = fake_channel_check_call_host;
c->base.cancel_check_call_host = fake_channel_cancel_check_call_host;
c->base.add_handshakers = fake_channel_add_handshakers;
c->target = gpr_strdup(target);
const char* expected_targets = grpc_fake_transport_get_expected_targets(args);
c->expected_targets = gpr_strdup(expected_targets);
c->is_lb_channel = grpc_core::FindTargetAuthorityTableInArgs(args) != nullptr;
const grpc_arg* target_name_override_arg =
grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG);
if (target_name_override_arg != nullptr) {
c->target_name_override =
gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg));
int cmp(const grpc_security_connector* other) const override {
return server_security_connector_cmp(
static_cast<const grpc_server_security_connector*>(other));
}
return &c->base;
};
} // namespace
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_fake_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target, const grpc_channel_args* args) {
return grpc_core::MakeRefCounted<grpc_fake_channel_security_connector>(
std::move(channel_creds), std::move(request_metadata_creds), target,
args);
}
grpc_server_security_connector* grpc_fake_server_security_connector_create(
grpc_server_credentials* server_creds) {
grpc_server_security_connector* c =
static_cast<grpc_server_security_connector*>(
gpr_zalloc(sizeof(grpc_server_security_connector)));
gpr_ref_init(&c->base.refcount, 1);
c->base.vtable = &fake_server_vtable;
c->base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME;
c->server_creds = server_creds;
c->add_handshakers = fake_server_add_handshakers;
return c;
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_fake_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
return grpc_core::MakeRefCounted<grpc_fake_server_security_connector>(
std::move(server_creds));
}

@ -24,19 +24,22 @@
#include <grpc/grpc_security.h>
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/security_connector/security_connector.h"
#define GRPC_FAKE_SECURITY_URL_SCHEME "http+fake_security"
/* Creates a fake connector that emulates real channel security. */
grpc_channel_security_connector* grpc_fake_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds, const char* target,
const grpc_channel_args* args);
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_fake_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target, const grpc_channel_args* args);
/* Creates a fake connector that emulates real server security. */
grpc_server_security_connector* grpc_fake_server_security_connector_create(
grpc_server_credentials* server_creds);
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_fake_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_FAKE_FAKE_SECURITY_CONNECTOR_H \
*/

@ -30,6 +30,7 @@
#include "src/core/ext/filters/client_channel/client_channel.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/security/credentials/local/local_credentials.h"
#include "src/core/lib/security/transport/security_handshaker.h"
@ -39,153 +40,145 @@
#define GRPC_UDS_URL_SCHEME "unix"
#define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local"
typedef struct {
grpc_channel_security_connector base;
char* target_name;
} grpc_local_channel_security_connector;
namespace {
typedef struct {
grpc_server_security_connector base;
} grpc_local_server_security_connector;
static void local_channel_destroy(grpc_security_connector* sc) {
if (sc == nullptr) {
return;
}
auto c = reinterpret_cast<grpc_local_channel_security_connector*>(sc);
grpc_call_credentials_unref(c->base.request_metadata_creds);
grpc_channel_credentials_unref(c->base.channel_creds);
gpr_free(c->target_name);
gpr_free(sc);
}
static void local_server_destroy(grpc_security_connector* sc) {
if (sc == nullptr) {
return;
}
auto c = reinterpret_cast<grpc_local_server_security_connector*>(sc);
grpc_server_credentials_unref(c->base.server_creds);
gpr_free(sc);
}
static void local_channel_add_handshakers(
grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) {
tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) ==
TSI_OK);
grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
handshaker, &sc->base));
}
static void local_server_add_handshakers(
grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) {
tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, &handshaker) ==
TSI_OK);
grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
handshaker, &sc->base));
}
static int local_channel_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_local_channel_security_connector* c1 =
reinterpret_cast<grpc_local_channel_security_connector*>(sc1);
grpc_local_channel_security_connector* c2 =
reinterpret_cast<grpc_local_channel_security_connector*>(sc2);
int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
if (c != 0) return c;
return strcmp(c1->target_name, c2->target_name);
}
static int local_server_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_local_server_security_connector* c1 =
reinterpret_cast<grpc_local_server_security_connector*>(sc1);
grpc_local_server_security_connector* c2 =
reinterpret_cast<grpc_local_server_security_connector*>(sc2);
return grpc_server_security_connector_cmp(&c1->base, &c2->base);
}
static grpc_security_status local_auth_context_create(grpc_auth_context** ctx) {
if (ctx == nullptr) {
gpr_log(GPR_ERROR, "Invalid arguments to local_auth_context_create()");
return GRPC_SECURITY_ERROR;
}
grpc_core::RefCountedPtr<grpc_auth_context> local_auth_context_create() {
/* Create auth context. */
*ctx = grpc_auth_context_create(nullptr);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
*ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_LOCAL_TRANSPORT_SECURITY_TYPE);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
*ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1);
return GRPC_SECURITY_OK;
ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1);
return ctx;
}
static void local_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
grpc_security_status status;
void local_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) {
/* Create an auth context which is necessary to pass the santiy check in
* {client, server}_auth_filter that verifies if the peer's auth context is
* {client, server}_auth_filter that verifies if the pepp's auth context is
* obtained during handshakes. The auth context is only checked for its
* existence and not actually used.
*/
status = local_auth_context_create(auth_context);
grpc_error* error = status == GRPC_SECURITY_OK
*auth_context = local_auth_context_create();
grpc_error* error = *auth_context != nullptr
? GRPC_ERROR_NONE
: GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Could not create local auth context");
GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
static grpc_security_connector_vtable local_channel_vtable = {
local_channel_destroy, local_check_peer, local_channel_cmp};
static grpc_security_connector_vtable local_server_vtable = {
local_server_destroy, local_check_peer, local_server_cmp};
static bool local_check_call_host(grpc_channel_security_connector* sc,
const char* host,
grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) {
grpc_local_channel_security_connector* local_sc =
reinterpret_cast<grpc_local_channel_security_connector*>(sc);
if (host == nullptr || local_sc == nullptr ||
strcmp(host, local_sc->target_name) != 0) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"local call host does not match target name");
class grpc_local_channel_security_connector final
: public grpc_channel_security_connector {
public:
grpc_local_channel_security_connector(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const char* target_name)
: grpc_channel_security_connector(GRPC_UDS_URL_SCHEME,
std::move(channel_creds),
std::move(request_metadata_creds)),
target_name_(gpr_strdup(target_name)) {}
~grpc_local_channel_security_connector() override { gpr_free(target_name_); }
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) ==
TSI_OK);
grpc_handshake_manager_add(
handshake_manager, grpc_security_handshaker_create(handshaker, this));
}
return true;
}
static void local_cancel_check_call_host(grpc_channel_security_connector* sc,
grpc_closure* on_call_host_checked,
grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
int cmp(const grpc_security_connector* other_sc) const override {
auto* other =
reinterpret_cast<const grpc_local_channel_security_connector*>(
other_sc);
int c = channel_security_connector_cmp(other);
if (c != 0) return c;
return strcmp(target_name_, other->target_name_);
}
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
local_check_peer(this, peer, auth_context, on_peer_checked);
}
bool check_call_host(const char* host, grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) override {
if (host == nullptr || strcmp(host, target_name_) != 0) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"local call host does not match target name");
}
return true;
}
grpc_security_status grpc_local_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds,
const grpc_channel_args* args, const char* target_name,
grpc_channel_security_connector** sc) {
if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) {
void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) override {
GRPC_ERROR_UNREF(error);
}
const char* target_name() const { return target_name_; }
private:
char* target_name_;
};
class grpc_local_server_security_connector final
: public grpc_server_security_connector {
public:
grpc_local_server_security_connector(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
: grpc_server_security_connector(GRPC_UDS_URL_SCHEME,
std::move(server_creds)) {}
~grpc_local_server_security_connector() override = default;
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_manager) override {
tsi_handshaker* handshaker = nullptr;
GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */,
&handshaker) == TSI_OK);
grpc_handshake_manager_add(
handshake_manager, grpc_security_handshaker_create(handshaker, this));
}
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
local_check_peer(this, peer, auth_context, on_peer_checked);
}
int cmp(const grpc_security_connector* other) const override {
return server_security_connector_cmp(
static_cast<const grpc_server_security_connector*>(other));
}
};
} // namespace
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_local_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_channel_args* args, const char* target_name) {
if (channel_creds == nullptr || target_name == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_local_channel_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
// Check if local_connect_type is UDS. Only UDS is supported for now.
grpc_local_credentials* creds =
reinterpret_cast<grpc_local_credentials*>(channel_creds);
if (creds->connect_type != UDS) {
static_cast<grpc_local_credentials*>(channel_creds.get());
if (creds->connect_type() != UDS) {
gpr_log(GPR_ERROR,
"Invalid local channel type to "
"grpc_local_channel_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
// Check if target_name is a valid UDS address.
const grpc_arg* server_uri_arg =
@ -196,51 +189,30 @@ grpc_security_status grpc_local_channel_security_connector_create(
gpr_log(GPR_ERROR,
"Invalid target_name to "
"grpc_local_channel_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
auto c = static_cast<grpc_local_channel_security_connector*>(
gpr_zalloc(sizeof(grpc_local_channel_security_connector)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.vtable = &local_channel_vtable;
c->base.add_handshakers = local_channel_add_handshakers;
c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
c->base.request_metadata_creds =
grpc_call_credentials_ref(request_metadata_creds);
c->base.check_call_host = local_check_call_host;
c->base.cancel_check_call_host = local_cancel_check_call_host;
c->base.base.url_scheme =
creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr;
c->target_name = gpr_strdup(target_name);
*sc = &c->base;
return GRPC_SECURITY_OK;
return grpc_core::MakeRefCounted<grpc_local_channel_security_connector>(
channel_creds, request_metadata_creds, target_name);
}
grpc_security_status grpc_local_server_security_connector_create(
grpc_server_credentials* server_creds,
grpc_server_security_connector** sc) {
if (server_creds == nullptr || sc == nullptr) {
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_local_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
if (server_creds == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_local_server_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
// Check if local_connect_type is UDS. Only UDS is supported for now.
grpc_local_server_credentials* creds =
reinterpret_cast<grpc_local_server_credentials*>(server_creds);
if (creds->connect_type != UDS) {
const grpc_local_server_credentials* creds =
static_cast<const grpc_local_server_credentials*>(server_creds.get());
if (creds->connect_type() != UDS) {
gpr_log(GPR_ERROR,
"Invalid local server type to "
"grpc_local_server_security_connector_create()");
return GRPC_SECURITY_ERROR;
return nullptr;
}
auto c = static_cast<grpc_local_server_security_connector*>(
gpr_zalloc(sizeof(grpc_local_server_security_connector)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.vtable = &local_server_vtable;
c->base.server_creds = grpc_server_credentials_ref(server_creds);
c->base.base.url_scheme =
creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr;
c->base.add_handshakers = local_server_add_handshakers;
*sc = &c->base;
return GRPC_SECURITY_OK;
return grpc_core::MakeRefCounted<grpc_local_server_security_connector>(
std::move(server_creds));
}

@ -34,13 +34,13 @@
* - sc: address of local channel security connector instance to be returned
* from the method.
*
* It returns GRPC_SECURITY_OK on success, and an error stauts code on failure.
* It returns nullptr on failure.
*/
grpc_security_status grpc_local_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds,
const grpc_channel_args* args, const char* target_name,
grpc_channel_security_connector** sc);
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_local_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_channel_args* args, const char* target_name);
/**
* This method creates a local server security connector.
@ -49,10 +49,11 @@ grpc_security_status grpc_local_channel_security_connector_create(
* - sc: address of local server security connector instance to be returned from
* the method.
*
* It returns GRPC_SECURITY_OK on success, and an error status code on failure.
* It returns nullptr on failure.
*/
grpc_security_status grpc_local_server_security_connector_create(
grpc_server_credentials* server_creds, grpc_server_security_connector** sc);
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_local_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_LOCAL_LOCAL_SECURITY_CONNECTOR_H \
*/

@ -35,150 +35,67 @@
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/security_connector/load_system_roots.h"
#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/security/transport/security_handshaker.h"
grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount(
false, "security_connector_refcount");
void grpc_channel_security_connector_add_handshakers(
grpc_channel_security_connector* connector,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
if (connector != nullptr) {
connector->add_handshakers(connector, interested_parties, handshake_mgr);
}
}
void grpc_server_security_connector_add_handshakers(
grpc_server_security_connector* connector,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
if (connector != nullptr) {
connector->add_handshakers(connector, interested_parties, handshake_mgr);
}
}
void grpc_security_connector_check_peer(grpc_security_connector* sc,
tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
if (sc == nullptr) {
GRPC_CLOSURE_SCHED(on_peer_checked,
GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"cannot check peer -- no security connector"));
tsi_peer_destruct(&peer);
} else {
sc->vtable->check_peer(sc, peer, auth_context, on_peer_checked);
}
}
int grpc_security_connector_cmp(grpc_security_connector* sc,
grpc_security_connector* other) {
grpc_server_security_connector::grpc_server_security_connector(
const char* url_scheme,
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
: grpc_security_connector(url_scheme),
server_creds_(std::move(server_creds)) {}
grpc_channel_security_connector::grpc_channel_security_connector(
const char* url_scheme,
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds)
: grpc_security_connector(url_scheme),
channel_creds_(std::move(channel_creds)),
request_metadata_creds_(std::move(request_metadata_creds)) {}
grpc_channel_security_connector::~grpc_channel_security_connector() {}
int grpc_security_connector_cmp(const grpc_security_connector* sc,
const grpc_security_connector* other) {
if (sc == nullptr || other == nullptr) return GPR_ICMP(sc, other);
int c = GPR_ICMP(sc->vtable, other->vtable);
if (c != 0) return c;
return sc->vtable->cmp(sc, other);
return sc->cmp(other);
}
int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1,
grpc_channel_security_connector* sc2) {
GPR_ASSERT(sc1->channel_creds != nullptr);
GPR_ASSERT(sc2->channel_creds != nullptr);
int c = GPR_ICMP(sc1->channel_creds, sc2->channel_creds);
if (c != 0) return c;
c = GPR_ICMP(sc1->request_metadata_creds, sc2->request_metadata_creds);
if (c != 0) return c;
c = GPR_ICMP((void*)sc1->check_call_host, (void*)sc2->check_call_host);
if (c != 0) return c;
c = GPR_ICMP((void*)sc1->cancel_check_call_host,
(void*)sc2->cancel_check_call_host);
int grpc_channel_security_connector::channel_security_connector_cmp(
const grpc_channel_security_connector* other) const {
const grpc_channel_security_connector* other_sc =
static_cast<const grpc_channel_security_connector*>(other);
GPR_ASSERT(channel_creds() != nullptr);
GPR_ASSERT(other_sc->channel_creds() != nullptr);
int c = GPR_ICMP(channel_creds(), other_sc->channel_creds());
if (c != 0) return c;
return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers);
return GPR_ICMP(request_metadata_creds(), other_sc->request_metadata_creds());
}
int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1,
grpc_server_security_connector* sc2) {
GPR_ASSERT(sc1->server_creds != nullptr);
GPR_ASSERT(sc2->server_creds != nullptr);
int c = GPR_ICMP(sc1->server_creds, sc2->server_creds);
if (c != 0) return c;
return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers);
}
bool grpc_channel_security_connector_check_call_host(
grpc_channel_security_connector* sc, const char* host,
grpc_auth_context* auth_context, grpc_closure* on_call_host_checked,
grpc_error** error) {
if (sc == nullptr || sc->check_call_host == nullptr) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"cannot check call host -- no security connector");
return true;
}
return sc->check_call_host(sc, host, auth_context, on_call_host_checked,
error);
}
void grpc_channel_security_connector_cancel_check_call_host(
grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
grpc_error* error) {
if (sc == nullptr || sc->cancel_check_call_host == nullptr) {
GRPC_ERROR_UNREF(error);
return;
}
sc->cancel_check_call_host(sc, on_call_host_checked, error);
}
#ifndef NDEBUG
grpc_security_connector* grpc_security_connector_ref(
grpc_security_connector* sc, const char* file, int line,
const char* reason) {
if (sc == nullptr) return nullptr;
if (grpc_trace_security_connector_refcount.enabled()) {
gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count);
gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
"SECURITY_CONNECTOR:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", sc,
val, val + 1, reason);
}
#else
grpc_security_connector* grpc_security_connector_ref(
grpc_security_connector* sc) {
if (sc == nullptr) return nullptr;
#endif
gpr_ref(&sc->refcount);
return sc;
}
#ifndef NDEBUG
void grpc_security_connector_unref(grpc_security_connector* sc,
const char* file, int line,
const char* reason) {
if (sc == nullptr) return;
if (grpc_trace_security_connector_refcount.enabled()) {
gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count);
gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
"SECURITY_CONNECTOR:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", sc,
val, val - 1, reason);
}
#else
void grpc_security_connector_unref(grpc_security_connector* sc) {
if (sc == nullptr) return;
#endif
if (gpr_unref(&sc->refcount)) sc->vtable->destroy(sc);
int grpc_server_security_connector::server_security_connector_cmp(
const grpc_server_security_connector* other) const {
const grpc_server_security_connector* other_sc =
static_cast<const grpc_server_security_connector*>(other);
GPR_ASSERT(server_creds() != nullptr);
GPR_ASSERT(other_sc->server_creds() != nullptr);
return GPR_ICMP(server_creds(), other_sc->server_creds());
}
static void connector_arg_destroy(void* p) {
GRPC_SECURITY_CONNECTOR_UNREF((grpc_security_connector*)p,
"connector_arg_destroy");
static_cast<grpc_security_connector*>(p)->Unref(DEBUG_LOCATION,
"connector_arg_destroy");
}
static void* connector_arg_copy(void* p) {
return GRPC_SECURITY_CONNECTOR_REF((grpc_security_connector*)p,
"connector_arg_copy");
return static_cast<grpc_security_connector*>(p)
->Ref(DEBUG_LOCATION, "connector_arg_copy")
.release();
}
static int connector_cmp(void* a, void* b) {
return grpc_security_connector_cmp(static_cast<grpc_security_connector*>(a),
static_cast<grpc_security_connector*>(b));
return static_cast<grpc_security_connector*>(a)->cmp(
static_cast<grpc_security_connector*>(b));
}
static const grpc_arg_pointer_vtable connector_arg_vtable = {

@ -26,6 +26,7 @@
#include <grpc/grpc_security.h>
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/iomgr/endpoint.h"
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/iomgr/tcp_server.h"
@ -34,8 +35,6 @@
extern grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount;
/* --- status enum. --- */
typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status;
/* --- security_connector object. ---
@ -43,54 +42,33 @@ typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status;
A security connector object represents away to configure the underlying
transport security mechanism and check the resulting trusted peer. */
typedef struct grpc_security_connector grpc_security_connector;
#define GRPC_ARG_SECURITY_CONNECTOR "grpc.security_connector"
typedef struct {
void (*destroy)(grpc_security_connector* sc);
void (*check_peer)(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked);
int (*cmp)(grpc_security_connector* sc, grpc_security_connector* other);
} grpc_security_connector_vtable;
struct grpc_security_connector {
const grpc_security_connector_vtable* vtable;
gpr_refcount refcount;
const char* url_scheme;
};
class grpc_security_connector
: public grpc_core::RefCounted<grpc_security_connector> {
public:
explicit grpc_security_connector(const char* url_scheme)
: grpc_core::RefCounted<grpc_security_connector>(
&grpc_trace_security_connector_refcount),
url_scheme_(url_scheme) {}
virtual ~grpc_security_connector() = default;
/* Check the peer. Callee takes ownership of the peer object.
When done, sets *auth_context and invokes on_peer_checked. */
virtual void check_peer(
tsi_peer peer, grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) GRPC_ABSTRACT;
/* Compares two security connectors. */
virtual int cmp(const grpc_security_connector* other) const GRPC_ABSTRACT;
const char* url_scheme() const { return url_scheme_; }
/* Refcounting. */
#ifndef NDEBUG
#define GRPC_SECURITY_CONNECTOR_REF(p, r) \
grpc_security_connector_ref((p), __FILE__, __LINE__, (r))
#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) \
grpc_security_connector_unref((p), __FILE__, __LINE__, (r))
grpc_security_connector* grpc_security_connector_ref(
grpc_security_connector* policy, const char* file, int line,
const char* reason);
void grpc_security_connector_unref(grpc_security_connector* policy,
const char* file, int line,
const char* reason);
#else
#define GRPC_SECURITY_CONNECTOR_REF(p, r) grpc_security_connector_ref((p))
#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) grpc_security_connector_unref((p))
grpc_security_connector* grpc_security_connector_ref(
grpc_security_connector* policy);
void grpc_security_connector_unref(grpc_security_connector* policy);
#endif
/* Check the peer. Callee takes ownership of the peer object.
When done, sets *auth_context and invokes on_peer_checked. */
void grpc_security_connector_check_peer(grpc_security_connector* sc,
tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked);
/* Compares two security connectors. */
int grpc_security_connector_cmp(grpc_security_connector* sc,
grpc_security_connector* other);
GRPC_ABSTRACT_BASE_CLASS
private:
const char* url_scheme_;
};
/* Util to encapsulate the connector in a channel arg. */
grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc);
@ -107,71 +85,89 @@ grpc_security_connector* grpc_security_connector_find_in_args(
A channel security connector object represents a way to configure the
underlying transport security mechanism on the client side. */
typedef struct grpc_channel_security_connector grpc_channel_security_connector;
struct grpc_channel_security_connector {
grpc_security_connector base;
grpc_channel_credentials* channel_creds;
grpc_call_credentials* request_metadata_creds;
bool (*check_call_host)(grpc_channel_security_connector* sc, const char* host,
grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error);
void (*cancel_check_call_host)(grpc_channel_security_connector* sc,
grpc_closure* on_call_host_checked,
grpc_error* error);
void (*add_handshakers)(grpc_channel_security_connector* sc,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr);
class grpc_channel_security_connector : public grpc_security_connector {
public:
grpc_channel_security_connector(
const char* url_scheme,
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds);
~grpc_channel_security_connector() override;
/// Checks that the host that will be set for a call is acceptable.
/// Returns true if completed synchronously, in which case \a error will
/// be set to indicate the result. Otherwise, \a on_call_host_checked
/// will be invoked when complete.
virtual bool check_call_host(const char* host,
grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) GRPC_ABSTRACT;
/// Cancels a pending asychronous call to
/// grpc_channel_security_connector_check_call_host() with
/// \a on_call_host_checked as its callback.
virtual void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) GRPC_ABSTRACT;
/// Registers handshakers with \a handshake_mgr.
virtual void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr)
GRPC_ABSTRACT;
const grpc_channel_credentials* channel_creds() const {
return channel_creds_.get();
}
grpc_channel_credentials* mutable_channel_creds() {
return channel_creds_.get();
}
const grpc_call_credentials* request_metadata_creds() const {
return request_metadata_creds_.get();
}
grpc_call_credentials* mutable_request_metadata_creds() {
return request_metadata_creds_.get();
}
GRPC_ABSTRACT_BASE_CLASS
protected:
// Helper methods to be used in subclasses.
int channel_security_connector_cmp(
const grpc_channel_security_connector* other) const;
private:
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds_;
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds_;
};
/// A helper function for use in grpc_security_connector_cmp() implementations.
int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1,
grpc_channel_security_connector* sc2);
/// Checks that the host that will be set for a call is acceptable.
/// Returns true if completed synchronously, in which case \a error will
/// be set to indicate the result. Otherwise, \a on_call_host_checked
/// will be invoked when complete.
bool grpc_channel_security_connector_check_call_host(
grpc_channel_security_connector* sc, const char* host,
grpc_auth_context* auth_context, grpc_closure* on_call_host_checked,
grpc_error** error);
/// Cancels a pending asychronous call to
/// grpc_channel_security_connector_check_call_host() with
/// \a on_call_host_checked as its callback.
void grpc_channel_security_connector_cancel_check_call_host(
grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
grpc_error* error);
/* Registers handshakers with \a handshake_mgr. */
void grpc_channel_security_connector_add_handshakers(
grpc_channel_security_connector* connector,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr);
/* --- server_security_connector object. ---
A server security connector object represents a way to configure the
underlying transport security mechanism on the server side. */
typedef struct grpc_server_security_connector grpc_server_security_connector;
struct grpc_server_security_connector {
grpc_security_connector base;
grpc_server_credentials* server_creds;
void (*add_handshakers)(grpc_server_security_connector* sc,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr);
class grpc_server_security_connector : public grpc_security_connector {
public:
grpc_server_security_connector(
const char* url_scheme,
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
~grpc_server_security_connector() override = default;
virtual void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr)
GRPC_ABSTRACT;
const grpc_server_credentials* server_creds() const {
return server_creds_.get();
}
grpc_server_credentials* mutable_server_creds() {
return server_creds_.get();
}
GRPC_ABSTRACT_BASE_CLASS
protected:
// Helper methods to be used in subclasses.
int server_security_connector_cmp(
const grpc_server_security_connector* other) const;
private:
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds_;
};
/// A helper function for use in grpc_security_connector_cmp() implementations.
int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1,
grpc_server_security_connector* sc2);
void grpc_server_security_connector_add_handshakers(
grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SECURITY_CONNECTOR_H */

@ -30,6 +30,7 @@
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gpr/host_port.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/ssl/ssl_credentials.h"
@ -39,172 +40,10 @@
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security.h"
typedef struct {
grpc_channel_security_connector base;
tsi_ssl_client_handshaker_factory* client_handshaker_factory;
char* target_name;
char* overridden_target_name;
const verify_peer_options* verify_options;
} grpc_ssl_channel_security_connector;
typedef struct {
grpc_server_security_connector base;
tsi_ssl_server_handshaker_factory* server_handshaker_factory;
} grpc_ssl_server_security_connector;
static bool server_connector_has_cert_config_fetcher(
grpc_ssl_server_security_connector* c) {
GPR_ASSERT(c != nullptr);
grpc_ssl_server_credentials* server_creds =
reinterpret_cast<grpc_ssl_server_credentials*>(c->base.server_creds);
GPR_ASSERT(server_creds != nullptr);
return server_creds->certificate_config_fetcher.cb != nullptr;
}
static void ssl_channel_destroy(grpc_security_connector* sc) {
grpc_ssl_channel_security_connector* c =
reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
grpc_channel_credentials_unref(c->base.channel_creds);
grpc_call_credentials_unref(c->base.request_metadata_creds);
tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory);
c->client_handshaker_factory = nullptr;
if (c->target_name != nullptr) gpr_free(c->target_name);
if (c->overridden_target_name != nullptr) gpr_free(c->overridden_target_name);
gpr_free(sc);
}
static void ssl_server_destroy(grpc_security_connector* sc) {
grpc_ssl_server_security_connector* c =
reinterpret_cast<grpc_ssl_server_security_connector*>(sc);
grpc_server_credentials_unref(c->base.server_creds);
tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory);
c->server_handshaker_factory = nullptr;
gpr_free(sc);
}
static void ssl_channel_add_handshakers(grpc_channel_security_connector* sc,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
grpc_ssl_channel_security_connector* c =
reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
// Instantiate TSI handshaker.
tsi_handshaker* tsi_hs = nullptr;
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
c->client_handshaker_factory,
c->overridden_target_name != nullptr ? c->overridden_target_name
: c->target_name,
&tsi_hs);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
return;
}
// Create handshakers.
grpc_handshake_manager_add(
handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base));
}
/* Attempts to replace the server_handshaker_factory with a new factory using
* the provided grpc_ssl_server_certificate_config. Should new factory creation
* fail, the existing factory will not be replaced. Returns true on success (new
* factory created). */
static bool try_replace_server_handshaker_factory(
grpc_ssl_server_security_connector* sc,
const grpc_ssl_server_certificate_config* config) {
if (config == nullptr) {
gpr_log(GPR_ERROR,
"Server certificate config callback returned invalid (NULL) "
"config.");
return false;
}
gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config);
size_t num_alpn_protocols = 0;
const char** alpn_protocol_strings =
grpc_fill_alpn_protocol_strings(&num_alpn_protocols);
tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
config->pem_key_cert_pairs, config->num_key_cert_pairs);
tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr;
grpc_ssl_server_credentials* server_creds =
reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds);
tsi_result result = tsi_create_ssl_server_handshaker_factory_ex(
cert_pairs, config->num_key_cert_pairs, config->pem_root_certs,
grpc_get_tsi_client_certificate_request_type(
server_creds->config.client_certificate_request),
grpc_get_ssl_cipher_suites(), alpn_protocol_strings,
static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory);
gpr_free(cert_pairs);
gpr_free((void*)alpn_protocol_strings);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return false;
}
tsi_ssl_server_handshaker_factory_unref(sc->server_handshaker_factory);
sc->server_handshaker_factory = new_handshaker_factory;
return true;
}
/* Attempts to fetch the server certificate config if a callback is available.
* Current certificate config will continue to be used if the callback returns
* an error. Returns true if new credentials were sucessfully loaded. */
static bool try_fetch_ssl_server_credentials(
grpc_ssl_server_security_connector* sc) {
grpc_ssl_server_certificate_config* certificate_config = nullptr;
bool status;
GPR_ASSERT(sc != nullptr);
if (!server_connector_has_cert_config_fetcher(sc)) return false;
grpc_ssl_server_credentials* server_creds =
reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds);
grpc_ssl_certificate_config_reload_status cb_result =
server_creds->certificate_config_fetcher.cb(
server_creds->certificate_config_fetcher.user_data,
&certificate_config);
if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) {
gpr_log(GPR_DEBUG, "No change in SSL server credentials.");
status = false;
} else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) {
status = try_replace_server_handshaker_factory(sc, certificate_config);
} else {
// Log error, continue using previously-loaded credentials.
gpr_log(GPR_ERROR,
"Failed fetching new server credentials, continuing to "
"use previously-loaded credentials.");
status = false;
}
if (certificate_config != nullptr) {
grpc_ssl_server_certificate_config_destroy(certificate_config);
}
return status;
}
static void ssl_server_add_handshakers(grpc_server_security_connector* sc,
grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) {
grpc_ssl_server_security_connector* c =
reinterpret_cast<grpc_ssl_server_security_connector*>(sc);
// Instantiate TSI handshaker.
try_fetch_ssl_server_credentials(c);
tsi_handshaker* tsi_hs = nullptr;
tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker(
c->server_handshaker_factory, &tsi_hs);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
return;
}
// Create handshakers.
grpc_handshake_manager_add(
handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base));
}
static grpc_error* ssl_check_peer(grpc_security_connector* sc,
const char* peer_name, const tsi_peer* peer,
grpc_auth_context** auth_context) {
namespace {
grpc_error* ssl_check_peer(
const char* peer_name, const tsi_peer* peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context) {
#if TSI_OPENSSL_ALPN_SUPPORT
/* Check the ALPN if ALPN is supported. */
const tsi_peer_property* p =
@ -230,245 +69,384 @@ static grpc_error* ssl_check_peer(grpc_security_connector* sc,
return GRPC_ERROR_NONE;
}
static void ssl_channel_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
grpc_ssl_channel_security_connector* c =
reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
const char* target_name = c->overridden_target_name != nullptr
? c->overridden_target_name
: c->target_name;
grpc_error* error = ssl_check_peer(sc, target_name, &peer, auth_context);
if (error == GRPC_ERROR_NONE &&
c->verify_options->verify_peer_callback != nullptr) {
const tsi_peer_property* p =
tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY);
if (p == nullptr) {
error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Cannot check peer: missing pem cert property.");
} else {
char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1));
memcpy(peer_pem, p->value.data, p->value.length);
peer_pem[p->value.length] = '\0';
int callback_status = c->verify_options->verify_peer_callback(
target_name, peer_pem,
c->verify_options->verify_peer_callback_userdata);
gpr_free(peer_pem);
if (callback_status) {
char* msg;
gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)",
callback_status);
error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg);
}
}
class grpc_ssl_channel_security_connector final
: public grpc_channel_security_connector {
public:
grpc_ssl_channel_security_connector(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name)
: grpc_channel_security_connector(GRPC_SSL_URL_SCHEME,
std::move(channel_creds),
std::move(request_metadata_creds)),
overridden_target_name_(overridden_target_name == nullptr
? nullptr
: gpr_strdup(overridden_target_name)),
verify_options_(&config->verify_options) {
char* port;
gpr_split_host_port(target_name, &target_name_, &port);
gpr_free(port);
}
GRPC_CLOSURE_SCHED(on_peer_checked, error);
tsi_peer_destruct(&peer);
}
static void ssl_server_check_peer(grpc_security_connector* sc, tsi_peer peer,
grpc_auth_context** auth_context,
grpc_closure* on_peer_checked) {
grpc_error* error = ssl_check_peer(sc, nullptr, &peer, auth_context);
tsi_peer_destruct(&peer);
GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
~grpc_ssl_channel_security_connector() override {
tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_);
if (target_name_ != nullptr) gpr_free(target_name_);
if (overridden_target_name_ != nullptr) gpr_free(overridden_target_name_);
}
static int ssl_channel_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
grpc_ssl_channel_security_connector* c1 =
reinterpret_cast<grpc_ssl_channel_security_connector*>(sc1);
grpc_ssl_channel_security_connector* c2 =
reinterpret_cast<grpc_ssl_channel_security_connector*>(sc2);
int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
if (c != 0) return c;
c = strcmp(c1->target_name, c2->target_name);
if (c != 0) return c;
return (c1->overridden_target_name == nullptr ||
c2->overridden_target_name == nullptr)
? GPR_ICMP(c1->overridden_target_name, c2->overridden_target_name)
: strcmp(c1->overridden_target_name, c2->overridden_target_name);
}
grpc_security_status InitializeHandshakerFactory(
const grpc_ssl_config* config, const char* pem_root_certs,
const tsi_ssl_root_certs_store* root_store,
tsi_ssl_session_cache* ssl_session_cache) {
bool has_key_cert_pair =
config->pem_key_cert_pair != nullptr &&
config->pem_key_cert_pair->private_key != nullptr &&
config->pem_key_cert_pair->cert_chain != nullptr;
tsi_ssl_client_handshaker_options options;
memset(&options, 0, sizeof(options));
GPR_DEBUG_ASSERT(pem_root_certs != nullptr);
options.pem_root_certs = pem_root_certs;
options.root_store = root_store;
options.alpn_protocols =
grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
if (has_key_cert_pair) {
options.pem_key_cert_pair = config->pem_key_cert_pair;
}
options.cipher_suites = grpc_get_ssl_cipher_suites();
options.session_cache = ssl_session_cache;
const tsi_result result =
tsi_create_ssl_client_handshaker_factory_with_options(
&options, &client_handshaker_factory_);
gpr_free((void*)options.alpn_protocols);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return GRPC_SECURITY_ERROR;
}
return GRPC_SECURITY_OK;
}
static int ssl_server_cmp(grpc_security_connector* sc1,
grpc_security_connector* sc2) {
return grpc_server_security_connector_cmp(
reinterpret_cast<grpc_server_security_connector*>(sc1),
reinterpret_cast<grpc_server_security_connector*>(sc2));
}
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override {
// Instantiate TSI handshaker.
tsi_handshaker* tsi_hs = nullptr;
tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
client_handshaker_factory_,
overridden_target_name_ != nullptr ? overridden_target_name_
: target_name_,
&tsi_hs);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
return;
}
// Create handshakers.
grpc_handshake_manager_add(handshake_mgr,
grpc_security_handshaker_create(tsi_hs, this));
}
static bool ssl_channel_check_call_host(grpc_channel_security_connector* sc,
const char* host,
grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) {
grpc_ssl_channel_security_connector* c =
reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
grpc_security_status status = GRPC_SECURITY_ERROR;
tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context);
if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK;
/* If the target name was overridden, then the original target_name was
'checked' transitively during the previous peer check at the end of the
handshake. */
if (c->overridden_target_name != nullptr &&
strcmp(host, c->target_name) == 0) {
status = GRPC_SECURITY_OK;
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
const char* target_name = overridden_target_name_ != nullptr
? overridden_target_name_
: target_name_;
grpc_error* error = ssl_check_peer(target_name, &peer, auth_context);
if (error == GRPC_ERROR_NONE &&
verify_options_->verify_peer_callback != nullptr) {
const tsi_peer_property* p =
tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY);
if (p == nullptr) {
error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Cannot check peer: missing pem cert property.");
} else {
char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1));
memcpy(peer_pem, p->value.data, p->value.length);
peer_pem[p->value.length] = '\0';
int callback_status = verify_options_->verify_peer_callback(
target_name, peer_pem,
verify_options_->verify_peer_callback_userdata);
gpr_free(peer_pem);
if (callback_status) {
char* msg;
gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)",
callback_status);
error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg);
}
}
}
GRPC_CLOSURE_SCHED(on_peer_checked, error);
tsi_peer_destruct(&peer);
}
if (status != GRPC_SECURITY_OK) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"call host does not match SSL server name");
int cmp(const grpc_security_connector* other_sc) const override {
auto* other =
reinterpret_cast<const grpc_ssl_channel_security_connector*>(other_sc);
int c = channel_security_connector_cmp(other);
if (c != 0) return c;
c = strcmp(target_name_, other->target_name_);
if (c != 0) return c;
return (overridden_target_name_ == nullptr ||
other->overridden_target_name_ == nullptr)
? GPR_ICMP(overridden_target_name_,
other->overridden_target_name_)
: strcmp(overridden_target_name_,
other->overridden_target_name_);
}
grpc_shallow_peer_destruct(&peer);
return true;
}
static void ssl_channel_cancel_check_call_host(
grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
bool check_call_host(const char* host, grpc_auth_context* auth_context,
grpc_closure* on_call_host_checked,
grpc_error** error) override {
grpc_security_status status = GRPC_SECURITY_ERROR;
tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context);
if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK;
/* If the target name was overridden, then the original target_name was
'checked' transitively during the previous peer check at the end of the
handshake. */
if (overridden_target_name_ != nullptr && strcmp(host, target_name_) == 0) {
status = GRPC_SECURITY_OK;
}
if (status != GRPC_SECURITY_OK) {
*error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"call host does not match SSL server name");
}
grpc_shallow_peer_destruct(&peer);
return true;
}
static grpc_security_connector_vtable ssl_channel_vtable = {
ssl_channel_destroy, ssl_channel_check_peer, ssl_channel_cmp};
void cancel_check_call_host(grpc_closure* on_call_host_checked,
grpc_error* error) override {
GRPC_ERROR_UNREF(error);
}
static grpc_security_connector_vtable ssl_server_vtable = {
ssl_server_destroy, ssl_server_check_peer, ssl_server_cmp};
private:
tsi_ssl_client_handshaker_factory* client_handshaker_factory_;
char* target_name_;
char* overridden_target_name_;
const verify_peer_options* verify_options_;
};
class grpc_ssl_server_security_connector
: public grpc_server_security_connector {
public:
grpc_ssl_server_security_connector(
grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
: grpc_server_security_connector(GRPC_SSL_URL_SCHEME,
std::move(server_creds)) {}
~grpc_ssl_server_security_connector() override {
tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_);
}
grpc_security_status grpc_ssl_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache,
grpc_channel_security_connector** sc) {
tsi_result result = TSI_OK;
grpc_ssl_channel_security_connector* c;
char* port;
bool has_key_cert_pair;
tsi_ssl_client_handshaker_options options;
memset(&options, 0, sizeof(options));
options.alpn_protocols =
grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
bool has_cert_config_fetcher() const {
return static_cast<const grpc_ssl_server_credentials*>(server_creds())
->has_cert_config_fetcher();
}
if (config == nullptr || target_name == nullptr) {
gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name.");
goto error;
const tsi_ssl_server_handshaker_factory* server_handshaker_factory() const {
return server_handshaker_factory_;
}
if (config->pem_root_certs == nullptr) {
// Use default root certificates.
options.pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts();
options.root_store = grpc_core::DefaultSslRootStore::GetRootStore();
if (options.pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, "Could not get default pem root certs.");
goto error;
grpc_security_status InitializeHandshakerFactory() {
if (has_cert_config_fetcher()) {
// Load initial credentials from certificate_config_fetcher:
if (!try_fetch_ssl_server_credentials()) {
gpr_log(GPR_ERROR,
"Failed loading SSL server credentials from fetcher.");
return GRPC_SECURITY_ERROR;
}
} else {
auto* server_credentials =
static_cast<const grpc_ssl_server_credentials*>(server_creds());
size_t num_alpn_protocols = 0;
const char** alpn_protocol_strings =
grpc_fill_alpn_protocol_strings(&num_alpn_protocols);
const tsi_result result = tsi_create_ssl_server_handshaker_factory_ex(
server_credentials->config().pem_key_cert_pairs,
server_credentials->config().num_key_cert_pairs,
server_credentials->config().pem_root_certs,
grpc_get_tsi_client_certificate_request_type(
server_credentials->config().client_certificate_request),
grpc_get_ssl_cipher_suites(), alpn_protocol_strings,
static_cast<uint16_t>(num_alpn_protocols),
&server_handshaker_factory_);
gpr_free((void*)alpn_protocol_strings);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
return GRPC_SECURITY_ERROR;
}
}
} else {
options.pem_root_certs = config->pem_root_certs;
}
c = static_cast<grpc_ssl_channel_security_connector*>(
gpr_zalloc(sizeof(grpc_ssl_channel_security_connector)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.vtable = &ssl_channel_vtable;
c->base.base.url_scheme = GRPC_SSL_URL_SCHEME;
c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
c->base.request_metadata_creds =
grpc_call_credentials_ref(request_metadata_creds);
c->base.check_call_host = ssl_channel_check_call_host;
c->base.cancel_check_call_host = ssl_channel_cancel_check_call_host;
c->base.add_handshakers = ssl_channel_add_handshakers;
gpr_split_host_port(target_name, &c->target_name, &port);
gpr_free(port);
if (overridden_target_name != nullptr) {
c->overridden_target_name = gpr_strdup(overridden_target_name);
return GRPC_SECURITY_OK;
}
c->verify_options = &config->verify_options;
has_key_cert_pair = config->pem_key_cert_pair != nullptr &&
config->pem_key_cert_pair->private_key != nullptr &&
config->pem_key_cert_pair->cert_chain != nullptr;
if (has_key_cert_pair) {
options.pem_key_cert_pair = config->pem_key_cert_pair;
void add_handshakers(grpc_pollset_set* interested_parties,
grpc_handshake_manager* handshake_mgr) override {
// Instantiate TSI handshaker.
try_fetch_ssl_server_credentials();
tsi_handshaker* tsi_hs = nullptr;
tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker(
server_handshaker_factory_, &tsi_hs);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
tsi_result_to_string(result));
return;
}
// Create handshakers.
grpc_handshake_manager_add(handshake_mgr,
grpc_security_handshaker_create(tsi_hs, this));
}
options.cipher_suites = grpc_get_ssl_cipher_suites();
options.session_cache = ssl_session_cache;
result = tsi_create_ssl_client_handshaker_factory_with_options(
&options, &c->client_handshaker_factory);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
ssl_channel_destroy(&c->base.base);
*sc = nullptr;
goto error;
void check_peer(tsi_peer peer,
grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
grpc_closure* on_peer_checked) override {
grpc_error* error = ssl_check_peer(nullptr, &peer, auth_context);
tsi_peer_destruct(&peer);
GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
*sc = &c->base;
gpr_free((void*)options.alpn_protocols);
return GRPC_SECURITY_OK;
error:
gpr_free((void*)options.alpn_protocols);
return GRPC_SECURITY_ERROR;
}
int cmp(const grpc_security_connector* other) const override {
return server_security_connector_cmp(
static_cast<const grpc_server_security_connector*>(other));
}
static grpc_ssl_server_security_connector*
grpc_ssl_server_security_connector_initialize(
grpc_server_credentials* server_creds) {
grpc_ssl_server_security_connector* c =
static_cast<grpc_ssl_server_security_connector*>(
gpr_zalloc(sizeof(grpc_ssl_server_security_connector)));
gpr_ref_init(&c->base.base.refcount, 1);
c->base.base.url_scheme = GRPC_SSL_URL_SCHEME;
c->base.base.vtable = &ssl_server_vtable;
c->base.add_handshakers = ssl_server_add_handshakers;
c->base.server_creds = grpc_server_credentials_ref(server_creds);
return c;
}
private:
/* Attempts to fetch the server certificate config if a callback is available.
* Current certificate config will continue to be used if the callback returns
* an error. Returns true if new credentials were sucessfully loaded. */
bool try_fetch_ssl_server_credentials() {
grpc_ssl_server_certificate_config* certificate_config = nullptr;
bool status;
if (!has_cert_config_fetcher()) return false;
grpc_ssl_server_credentials* server_creds =
static_cast<grpc_ssl_server_credentials*>(this->mutable_server_creds());
grpc_ssl_certificate_config_reload_status cb_result =
server_creds->FetchCertConfig(&certificate_config);
if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) {
gpr_log(GPR_DEBUG, "No change in SSL server credentials.");
status = false;
} else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) {
status = try_replace_server_handshaker_factory(certificate_config);
} else {
// Log error, continue using previously-loaded credentials.
gpr_log(GPR_ERROR,
"Failed fetching new server credentials, continuing to "
"use previously-loaded credentials.");
status = false;
}
grpc_security_status grpc_ssl_server_security_connector_create(
grpc_server_credentials* gsc, grpc_server_security_connector** sc) {
tsi_result result = TSI_OK;
grpc_ssl_server_credentials* server_credentials =
reinterpret_cast<grpc_ssl_server_credentials*>(gsc);
grpc_security_status retval = GRPC_SECURITY_OK;
if (certificate_config != nullptr) {
grpc_ssl_server_certificate_config_destroy(certificate_config);
}
return status;
}
GPR_ASSERT(server_credentials != nullptr);
GPR_ASSERT(sc != nullptr);
grpc_ssl_server_security_connector* c =
grpc_ssl_server_security_connector_initialize(gsc);
if (server_connector_has_cert_config_fetcher(c)) {
// Load initial credentials from certificate_config_fetcher:
if (!try_fetch_ssl_server_credentials(c)) {
gpr_log(GPR_ERROR, "Failed loading SSL server credentials from fetcher.");
retval = GRPC_SECURITY_ERROR;
/* Attempts to replace the server_handshaker_factory with a new factory using
* the provided grpc_ssl_server_certificate_config. Should new factory
* creation fail, the existing factory will not be replaced. Returns true on
* success (new factory created). */
bool try_replace_server_handshaker_factory(
const grpc_ssl_server_certificate_config* config) {
if (config == nullptr) {
gpr_log(GPR_ERROR,
"Server certificate config callback returned invalid (NULL) "
"config.");
return false;
}
} else {
gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config);
size_t num_alpn_protocols = 0;
const char** alpn_protocol_strings =
grpc_fill_alpn_protocol_strings(&num_alpn_protocols);
result = tsi_create_ssl_server_handshaker_factory_ex(
server_credentials->config.pem_key_cert_pairs,
server_credentials->config.num_key_cert_pairs,
server_credentials->config.pem_root_certs,
tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
config->pem_key_cert_pairs, config->num_key_cert_pairs);
tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr;
const grpc_ssl_server_credentials* server_creds =
static_cast<const grpc_ssl_server_credentials*>(this->server_creds());
GPR_DEBUG_ASSERT(config->pem_root_certs != nullptr);
tsi_result result = tsi_create_ssl_server_handshaker_factory_ex(
cert_pairs, config->num_key_cert_pairs, config->pem_root_certs,
grpc_get_tsi_client_certificate_request_type(
server_credentials->config.client_certificate_request),
server_creds->config().client_certificate_request),
grpc_get_ssl_cipher_suites(), alpn_protocol_strings,
static_cast<uint16_t>(num_alpn_protocols),
&c->server_handshaker_factory);
static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory);
gpr_free(cert_pairs);
gpr_free((void*)alpn_protocol_strings);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
retval = GRPC_SECURITY_ERROR;
return false;
}
set_server_handshaker_factory(new_handshaker_factory);
return true;
}
void set_server_handshaker_factory(
tsi_ssl_server_handshaker_factory* new_factory) {
if (server_handshaker_factory_) {
tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_);
}
server_handshaker_factory_ = new_factory;
}
tsi_ssl_server_handshaker_factory* server_handshaker_factory_ = nullptr;
};
} // namespace
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_ssl_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache) {
if (config == nullptr || target_name == nullptr) {
gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name.");
return nullptr;
}
if (retval == GRPC_SECURITY_OK) {
*sc = &c->base;
const char* pem_root_certs;
const tsi_ssl_root_certs_store* root_store;
if (config->pem_root_certs == nullptr) {
// Use default root certificates.
pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts();
if (pem_root_certs == nullptr) {
gpr_log(GPR_ERROR, "Could not get default pem root certs.");
return nullptr;
}
root_store = grpc_core::DefaultSslRootStore::GetRootStore();
} else {
if (c != nullptr) ssl_server_destroy(&c->base.base);
if (sc != nullptr) *sc = nullptr;
pem_root_certs = config->pem_root_certs;
root_store = nullptr;
}
grpc_core::RefCountedPtr<grpc_ssl_channel_security_connector> c =
grpc_core::MakeRefCounted<grpc_ssl_channel_security_connector>(
std::move(channel_creds), std::move(request_metadata_creds), config,
target_name, overridden_target_name);
const grpc_security_status result = c->InitializeHandshakerFactory(
config, pem_root_certs, root_store, ssl_session_cache);
if (result != GRPC_SECURITY_OK) {
return nullptr;
}
return retval;
return c;
}
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_ssl_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials) {
GPR_ASSERT(server_credentials != nullptr);
grpc_core::RefCountedPtr<grpc_ssl_server_security_connector> c =
grpc_core::MakeRefCounted<grpc_ssl_server_security_connector>(
std::move(server_credentials));
const grpc_security_status retval = c->InitializeHandshakerFactory();
if (retval != GRPC_SECURITY_OK) {
return nullptr;
}
return c;
}

@ -25,6 +25,7 @@
#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security_interface.h"
@ -47,20 +48,21 @@ typedef struct {
This function returns GRPC_SECURITY_OK in case of success or a
specific error code otherwise.
*/
grpc_security_status grpc_ssl_channel_security_connector_create(
grpc_channel_credentials* channel_creds,
grpc_call_credentials* request_metadata_creds,
grpc_core::RefCountedPtr<grpc_channel_security_connector>
grpc_ssl_channel_security_connector_create(
grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
tsi_ssl_session_cache* ssl_session_cache,
grpc_channel_security_connector** sc);
tsi_ssl_session_cache* ssl_session_cache);
/* Config for ssl servers. */
typedef struct {
tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs;
size_t num_key_cert_pairs;
char* pem_root_certs;
grpc_ssl_client_certificate_request_type client_certificate_request;
tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr;
size_t num_key_cert_pairs = 0;
char* pem_root_certs = nullptr;
grpc_ssl_client_certificate_request_type client_certificate_request =
GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE;
} grpc_ssl_server_config;
/* Creates an SSL server_security_connector.
@ -69,9 +71,9 @@ typedef struct {
This function returns GRPC_SECURITY_OK in case of success or a
specific error code otherwise.
*/
grpc_security_status grpc_ssl_server_security_connector_create(
grpc_server_credentials* server_credentials,
grpc_server_security_connector** sc);
grpc_core::RefCountedPtr<grpc_server_security_connector>
grpc_ssl_server_security_connector_create(
grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SSL_SSL_SECURITY_CONNECTOR_H \
*/

@ -30,6 +30,7 @@
#include "src/core/lib/gpr/env.h"
#include "src/core/lib/gpr/host_port.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/load_file.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/security_connector/load_system_roots.h"
@ -141,16 +142,17 @@ int grpc_ssl_host_matches_name(const tsi_peer* peer, const char* peer_name) {
return r;
}
grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) {
grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context(
const tsi_peer* peer) {
size_t i;
grpc_auth_context* ctx = nullptr;
const char* peer_identity_property_name = nullptr;
/* The caller has checked the certificate type property. */
GPR_ASSERT(peer->property_count >= 1);
ctx = grpc_auth_context_create(nullptr);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_SSL_TRANSPORT_SECURITY_TYPE);
for (i = 0; i < peer->property_count; i++) {
const tsi_peer_property* prop = &peer->properties[i];
@ -160,24 +162,26 @@ grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) {
if (peer_identity_property_name == nullptr) {
peer_identity_property_name = GRPC_X509_CN_PROPERTY_NAME;
}
grpc_auth_context_add_property(ctx, GRPC_X509_CN_PROPERTY_NAME,
grpc_auth_context_add_property(ctx.get(), GRPC_X509_CN_PROPERTY_NAME,
prop->value.data, prop->value.length);
} else if (strcmp(prop->name,
TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
peer_identity_property_name = GRPC_X509_SAN_PROPERTY_NAME;
grpc_auth_context_add_property(ctx, GRPC_X509_SAN_PROPERTY_NAME,
grpc_auth_context_add_property(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME,
prop->value.data, prop->value.length);
} else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) {
grpc_auth_context_add_property(ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME,
grpc_auth_context_add_property(ctx.get(),
GRPC_X509_PEM_CERT_PROPERTY_NAME,
prop->value.data, prop->value.length);
} else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) {
grpc_auth_context_add_property(ctx, GRPC_SSL_SESSION_REUSED_PROPERTY,
grpc_auth_context_add_property(ctx.get(),
GRPC_SSL_SESSION_REUSED_PROPERTY,
prop->value.data, prop->value.length);
}
}
if (peer_identity_property_name != nullptr) {
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
ctx, peer_identity_property_name) == 1);
ctx.get(), peer_identity_property_name) == 1);
}
return ctx;
}

@ -26,6 +26,7 @@
#include <grpc/grpc_security.h>
#include <grpc/slice_buffer.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security_interface.h"
@ -47,7 +48,8 @@ grpc_get_tsi_client_certificate_request_type(
const char** grpc_fill_alpn_protocol_strings(size_t* num_alpn_protocols);
/* Exposed for testing only. */
grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer);
grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context(
const tsi_peer* peer);
tsi_peer grpc_shallow_peer_from_ssl_auth_context(
const grpc_auth_context* auth_context);
void grpc_shallow_peer_destruct(tsi_peer* peer);

@ -55,7 +55,7 @@ struct call_data {
// that the memory is not initialized.
void destroy() {
grpc_credentials_mdelem_array_destroy(&md_array);
grpc_call_credentials_unref(creds);
creds.reset();
grpc_slice_unref_internal(host);
grpc_slice_unref_internal(method);
grpc_auth_metadata_context_reset(&auth_md_context);
@ -64,7 +64,7 @@ struct call_data {
gpr_arena* arena;
grpc_call_stack* owning_call;
grpc_call_combiner* call_combiner;
grpc_call_credentials* creds = nullptr;
grpc_core::RefCountedPtr<grpc_call_credentials> creds;
grpc_slice host = grpc_empty_slice();
grpc_slice method = grpc_empty_slice();
/* pollset{_set} bound to this call; if we need to make external
@ -83,8 +83,18 @@ struct call_data {
/* We can have a per-channel credentials. */
struct channel_data {
grpc_channel_security_connector* security_connector;
grpc_auth_context* auth_context;
channel_data(grpc_channel_security_connector* security_connector,
grpc_auth_context* auth_context)
: security_connector(
security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")),
auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {}
~channel_data() {
security_connector.reset(DEBUG_LOCATION, "client_auth_filter");
auth_context.reset(DEBUG_LOCATION, "client_auth_filter");
}
grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
};
} // namespace
@ -98,10 +108,11 @@ void grpc_auth_metadata_context_reset(
gpr_free(const_cast<char*>(auth_md_context->method_name));
auth_md_context->method_name = nullptr;
}
GRPC_AUTH_CONTEXT_UNREF(
(grpc_auth_context*)auth_md_context->channel_auth_context,
"grpc_auth_metadata_context");
auth_md_context->channel_auth_context = nullptr;
if (auth_md_context->channel_auth_context != nullptr) {
const_cast<grpc_auth_context*>(auth_md_context->channel_auth_context)
->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context");
auth_md_context->channel_auth_context = nullptr;
}
}
static void add_error(grpc_error** combined, grpc_error* error) {
@ -175,7 +186,10 @@ void grpc_auth_metadata_context_build(
auth_md_context->service_url = service_url;
auth_md_context->method_name = method_name;
auth_md_context->channel_auth_context =
GRPC_AUTH_CONTEXT_REF(auth_context, "grpc_auth_metadata_context");
auth_context == nullptr
? nullptr
: auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context")
.release();
gpr_free(service);
gpr_free(host_and_port);
}
@ -184,8 +198,8 @@ static void cancel_get_request_metadata(void* arg, grpc_error* error) {
grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
call_data* calld = static_cast<call_data*>(elem->call_data);
if (error != GRPC_ERROR_NONE) {
grpc_call_credentials_cancel_get_request_metadata(
calld->creds, &calld->md_array, GRPC_ERROR_REF(error));
calld->creds->cancel_get_request_metadata(&calld->md_array,
GRPC_ERROR_REF(error));
}
}
@ -197,7 +211,7 @@ static void send_security_metadata(grpc_call_element* elem,
static_cast<grpc_client_security_context*>(
batch->payload->context[GRPC_CONTEXT_SECURITY].value);
grpc_call_credentials* channel_call_creds =
chand->security_connector->request_metadata_creds;
chand->security_connector->mutable_request_metadata_creds();
int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr);
if (channel_call_creds == nullptr && !call_creds_has_md) {
@ -207,8 +221,9 @@ static void send_security_metadata(grpc_call_element* elem,
}
if (channel_call_creds != nullptr && call_creds_has_md) {
calld->creds = grpc_composite_call_credentials_create(channel_call_creds,
ctx->creds, nullptr);
calld->creds = grpc_core::RefCountedPtr<grpc_call_credentials>(
grpc_composite_call_credentials_create(channel_call_creds,
ctx->creds.get(), nullptr));
if (calld->creds == nullptr) {
grpc_transport_stream_op_batch_finish_with_failure(
batch,
@ -220,22 +235,22 @@ static void send_security_metadata(grpc_call_element* elem,
return;
}
} else {
calld->creds = grpc_call_credentials_ref(
call_creds_has_md ? ctx->creds : channel_call_creds);
calld->creds =
call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref();
}
grpc_auth_metadata_context_build(
chand->security_connector->base.url_scheme, calld->host, calld->method,
chand->auth_context, &calld->auth_md_context);
chand->security_connector->url_scheme(), calld->host, calld->method,
chand->auth_context.get(), &calld->auth_md_context);
GPR_ASSERT(calld->pollent != nullptr);
GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata");
GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata,
batch, grpc_schedule_on_exec_ctx);
grpc_error* error = GRPC_ERROR_NONE;
if (grpc_call_credentials_get_request_metadata(
calld->creds, calld->pollent, calld->auth_md_context,
&calld->md_array, &calld->async_result_closure, &error)) {
if (calld->creds->get_request_metadata(
calld->pollent, calld->auth_md_context, &calld->md_array,
&calld->async_result_closure, &error)) {
// Synchronous return; invoke on_credentials_metadata() directly.
on_credentials_metadata(batch, error);
GRPC_ERROR_UNREF(error);
@ -279,9 +294,8 @@ static void cancel_check_call_host(void* arg, grpc_error* error) {
call_data* calld = static_cast<call_data*>(elem->call_data);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
if (error != GRPC_ERROR_NONE) {
grpc_channel_security_connector_cancel_check_call_host(
chand->security_connector, &calld->async_result_closure,
GRPC_ERROR_REF(error));
chand->security_connector->cancel_check_call_host(
&calld->async_result_closure, GRPC_ERROR_REF(error));
}
}
@ -299,16 +313,16 @@ static void auth_start_transport_stream_op_batch(
GPR_ASSERT(batch->payload->context != nullptr);
if (batch->payload->context[GRPC_CONTEXT_SECURITY].value == nullptr) {
batch->payload->context[GRPC_CONTEXT_SECURITY].value =
grpc_client_security_context_create(calld->arena);
grpc_client_security_context_create(calld->arena, /*creds=*/nullptr);
batch->payload->context[GRPC_CONTEXT_SECURITY].destroy =
grpc_client_security_context_destroy;
}
grpc_client_security_context* sec_ctx =
static_cast<grpc_client_security_context*>(
batch->payload->context[GRPC_CONTEXT_SECURITY].value);
GRPC_AUTH_CONTEXT_UNREF(sec_ctx->auth_context, "client auth filter");
sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter");
sec_ctx->auth_context =
GRPC_AUTH_CONTEXT_REF(chand->auth_context, "client_auth_filter");
chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter");
}
if (batch->send_initial_metadata) {
@ -327,8 +341,8 @@ static void auth_start_transport_stream_op_batch(
grpc_schedule_on_exec_ctx);
char* call_host = grpc_slice_to_c_string(calld->host);
grpc_error* error = GRPC_ERROR_NONE;
if (grpc_channel_security_connector_check_call_host(
chand->security_connector, call_host, chand->auth_context,
if (chand->security_connector->check_call_host(
call_host, chand->auth_context.get(),
&calld->async_result_closure, &error)) {
// Synchronous return; invoke on_host_checked() directly.
on_host_checked(batch, error);
@ -374,6 +388,10 @@ static void destroy_call_elem(grpc_call_element* elem,
/* Constructor for channel_data */
static grpc_error* init_channel_elem(grpc_channel_element* elem,
grpc_channel_element_args* args) {
/* The first and the last filters tend to be implemented differently to
handle the case that there's no 'next' filter to call on the up or down
path */
GPR_ASSERT(!args->is_last);
grpc_security_connector* sc =
grpc_security_connector_find_in_args(args->channel_args);
if (sc == nullptr) {
@ -386,33 +404,15 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem,
return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Auth context missing from client auth filter args");
}
/* grab pointers to our data from the channel element */
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
/* The first and the last filters tend to be implemented differently to
handle the case that there's no 'next' filter to call on the up or down
path */
GPR_ASSERT(!args->is_last);
/* initialize members */
chand->security_connector =
reinterpret_cast<grpc_channel_security_connector*>(
GRPC_SECURITY_CONNECTOR_REF(sc, "client_auth_filter"));
chand->auth_context =
GRPC_AUTH_CONTEXT_REF(auth_context, "client_auth_filter");
new (elem->channel_data) channel_data(
static_cast<grpc_channel_security_connector*>(sc), auth_context);
return GRPC_ERROR_NONE;
}
/* Destructor for channel data */
static void destroy_channel_elem(grpc_channel_element* elem) {
/* grab pointers to our data from the channel element */
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
grpc_channel_security_connector* sc = chand->security_connector;
if (sc != nullptr) {
GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "client_auth_filter");
}
GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "client_auth_filter");
chand->~channel_data();
}
const grpc_channel_filter grpc_client_auth_filter = {

@ -30,6 +30,7 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/channel/handshaker_registry.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/transport/secure_endpoint.h"
#include "src/core/lib/security/transport/tsi_error.h"
@ -38,34 +39,62 @@
#define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
typedef struct {
namespace {
struct security_handshaker {
security_handshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector);
~security_handshaker() {
gpr_mu_destroy(&mu);
tsi_handshaker_destroy(handshaker);
tsi_handshaker_result_destroy(handshaker_result);
if (endpoint_to_destroy != nullptr) {
grpc_endpoint_destroy(endpoint_to_destroy);
}
if (read_buffer_to_destroy != nullptr) {
grpc_slice_buffer_destroy_internal(read_buffer_to_destroy);
gpr_free(read_buffer_to_destroy);
}
gpr_free(handshake_buffer);
grpc_slice_buffer_destroy_internal(&outgoing);
auth_context.reset(DEBUG_LOCATION, "handshake");
connector.reset(DEBUG_LOCATION, "handshake");
}
void Ref() { refs.Ref(); }
void Unref() {
if (refs.Unref()) {
grpc_core::Delete(this);
}
}
grpc_handshaker base;
// State set at creation time.
tsi_handshaker* handshaker;
grpc_security_connector* connector;
grpc_core::RefCountedPtr<grpc_security_connector> connector;
gpr_mu mu;
gpr_refcount refs;
grpc_core::RefCount refs;
bool shutdown;
bool shutdown = false;
// Endpoint and read buffer to destroy after a shutdown.
grpc_endpoint* endpoint_to_destroy;
grpc_slice_buffer* read_buffer_to_destroy;
grpc_endpoint* endpoint_to_destroy = nullptr;
grpc_slice_buffer* read_buffer_to_destroy = nullptr;
// State saved while performing the handshake.
grpc_handshaker_args* args;
grpc_closure* on_handshake_done;
grpc_handshaker_args* args = nullptr;
grpc_closure* on_handshake_done = nullptr;
unsigned char* handshake_buffer;
size_t handshake_buffer_size;
unsigned char* handshake_buffer;
grpc_slice_buffer outgoing;
grpc_closure on_handshake_data_sent_to_peer;
grpc_closure on_handshake_data_received_from_peer;
grpc_closure on_peer_checked;
grpc_auth_context* auth_context;
tsi_handshaker_result* handshaker_result;
} security_handshaker;
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
tsi_handshaker_result* handshaker_result = nullptr;
};
} // namespace
static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) {
size_t bytes_in_read_buffer = h->args->read_buffer->length;
@ -85,26 +114,6 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) {
return bytes_in_read_buffer;
}
static void security_handshaker_unref(security_handshaker* h) {
if (gpr_unref(&h->refs)) {
gpr_mu_destroy(&h->mu);
tsi_handshaker_destroy(h->handshaker);
tsi_handshaker_result_destroy(h->handshaker_result);
if (h->endpoint_to_destroy != nullptr) {
grpc_endpoint_destroy(h->endpoint_to_destroy);
}
if (h->read_buffer_to_destroy != nullptr) {
grpc_slice_buffer_destroy_internal(h->read_buffer_to_destroy);
gpr_free(h->read_buffer_to_destroy);
}
gpr_free(h->handshake_buffer);
grpc_slice_buffer_destroy_internal(&h->outgoing);
GRPC_AUTH_CONTEXT_UNREF(h->auth_context, "handshake");
GRPC_SECURITY_CONNECTOR_UNREF(h->connector, "handshake");
gpr_free(h);
}
}
// Set args fields to NULL, saving the endpoint and read buffer for
// later destruction.
static void cleanup_args_for_failure_locked(security_handshaker* h) {
@ -194,7 +203,7 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) {
tsi_handshaker_result_destroy(h->handshaker_result);
h->handshaker_result = nullptr;
// Add auth context to channel args.
grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context);
grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get());
grpc_channel_args* tmp_args = h->args->args;
h->args->args =
grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
@ -211,7 +220,7 @@ static void on_peer_checked(void* arg, grpc_error* error) {
gpr_mu_lock(&h->mu);
on_peer_checked_inner(h, error);
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
}
static grpc_error* check_peer_locked(security_handshaker* h) {
@ -222,8 +231,7 @@ static grpc_error* check_peer_locked(security_handshaker* h) {
return grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result);
}
grpc_security_connector_check_peer(h->connector, peer, &h->auth_context,
&h->on_peer_checked);
h->connector->check_peer(peer, &h->auth_context, &h->on_peer_checked);
return GRPC_ERROR_NONE;
}
@ -281,7 +289,7 @@ static void on_handshake_next_done_grpc_wrapper(
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
} else {
gpr_mu_unlock(&h->mu);
}
@ -317,7 +325,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) {
h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Handshake read failed", &error, 1));
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
return;
}
// Copy all slices received.
@ -329,7 +337,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) {
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
} else {
gpr_mu_unlock(&h->mu);
}
@ -343,7 +351,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) {
h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Handshake write failed", &error, 1));
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
return;
}
// We may be done.
@ -355,7 +363,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) {
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
return;
}
}
@ -368,7 +376,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) {
static void security_handshaker_destroy(grpc_handshaker* handshaker) {
security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker);
security_handshaker_unref(h);
h->Unref();
}
static void security_handshaker_shutdown(grpc_handshaker* handshaker,
@ -393,14 +401,14 @@ static void security_handshaker_do_handshake(grpc_handshaker* handshaker,
gpr_mu_lock(&h->mu);
h->args = args;
h->on_handshake_done = on_handshake_done;
gpr_ref(&h->refs);
h->Ref();
size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h);
grpc_error* error =
do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size);
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
security_handshaker_unref(h);
h->Unref();
return;
}
gpr_mu_unlock(&h->mu);
@ -410,27 +418,32 @@ static const grpc_handshaker_vtable security_handshaker_vtable = {
security_handshaker_destroy, security_handshaker_shutdown,
security_handshaker_do_handshake, "security"};
static grpc_handshaker* security_handshaker_create(
tsi_handshaker* handshaker, grpc_security_connector* connector) {
security_handshaker* h = static_cast<security_handshaker*>(
gpr_zalloc(sizeof(security_handshaker)));
grpc_handshaker_init(&security_handshaker_vtable, &h->base);
h->handshaker = handshaker;
h->connector = GRPC_SECURITY_CONNECTOR_REF(connector, "handshake");
gpr_mu_init(&h->mu);
gpr_ref_init(&h->refs, 1);
h->handshake_buffer_size = GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE;
h->handshake_buffer =
static_cast<uint8_t*>(gpr_malloc(h->handshake_buffer_size));
GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer,
on_handshake_data_sent_to_peer, h,
namespace {
security_handshaker::security_handshaker(tsi_handshaker* handshaker,
grpc_security_connector* connector)
: handshaker(handshaker),
connector(connector->Ref(DEBUG_LOCATION, "handshake")),
handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
handshake_buffer(
static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size))) {
grpc_handshaker_init(&security_handshaker_vtable, &base);
gpr_mu_init(&mu);
grpc_slice_buffer_init(&outgoing);
GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer,
::on_handshake_data_sent_to_peer, this,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer,
on_handshake_data_received_from_peer, h,
GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer,
::on_handshake_data_received_from_peer, this,
grpc_schedule_on_exec_ctx);
GRPC_CLOSURE_INIT(&h->on_peer_checked, on_peer_checked, h,
GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this,
grpc_schedule_on_exec_ctx);
grpc_slice_buffer_init(&h->outgoing);
}
} // namespace
static grpc_handshaker* security_handshaker_create(
tsi_handshaker* handshaker, grpc_security_connector* connector) {
security_handshaker* h =
grpc_core::New<security_handshaker>(handshaker, connector);
return &h->base;
}
@ -477,8 +490,9 @@ static void client_handshaker_factory_add_handshakers(
grpc_channel_security_connector* security_connector =
reinterpret_cast<grpc_channel_security_connector*>(
grpc_security_connector_find_in_args(args));
grpc_channel_security_connector_add_handshakers(
security_connector, interested_parties, handshake_mgr);
if (security_connector) {
security_connector->add_handshakers(interested_parties, handshake_mgr);
}
}
static void server_handshaker_factory_add_handshakers(
@ -488,8 +502,9 @@ static void server_handshaker_factory_add_handshakers(
grpc_server_security_connector* security_connector =
reinterpret_cast<grpc_server_security_connector*>(
grpc_security_connector_find_in_args(args));
grpc_server_security_connector_add_handshakers(
security_connector, interested_parties, handshake_mgr);
if (security_connector) {
security_connector->add_handshakers(interested_parties, handshake_mgr);
}
}
static void handshaker_factory_destroy(

@ -39,8 +39,12 @@ enum async_state {
};
struct channel_data {
grpc_auth_context* auth_context;
grpc_server_credentials* creds;
channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds)
: auth_context(auth_context->Ref()), creds(creds->Ref()) {}
~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); }
grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_core::RefCountedPtr<grpc_server_credentials> creds;
};
struct call_data {
@ -58,7 +62,7 @@ struct call_data {
grpc_server_security_context_create(args.arena);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
server_ctx->auth_context =
GRPC_AUTH_CONTEXT_REF(chand->auth_context, "server_auth_filter");
chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter");
if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) {
args.context[GRPC_CONTEXT_SECURITY].destroy(
args.context[GRPC_CONTEXT_SECURITY].value);
@ -208,7 +212,8 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) {
call_data* calld = static_cast<call_data*>(elem->call_data);
grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch;
if (error == GRPC_ERROR_NONE) {
if (chand->creds != nullptr && chand->creds->processor.process != nullptr) {
if (chand->creds != nullptr &&
chand->creds->auth_metadata_processor().process != nullptr) {
// We're calling out to the application, so we need to make sure
// to drop the call combiner early if we get cancelled.
GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem,
@ -218,9 +223,10 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) {
GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata");
calld->md = metadata_batch_to_md_array(
batch->payload->recv_initial_metadata.recv_initial_metadata);
chand->creds->processor.process(
chand->creds->processor.state, chand->auth_context,
calld->md.metadata, calld->md.count, on_md_processing_done, elem);
chand->creds->auth_metadata_processor().process(
chand->creds->auth_metadata_processor().state,
chand->auth_context.get(), calld->md.metadata, calld->md.count,
on_md_processing_done, elem);
return;
}
}
@ -290,23 +296,19 @@ static void destroy_call_elem(grpc_call_element* elem,
static grpc_error* init_channel_elem(grpc_channel_element* elem,
grpc_channel_element_args* args) {
GPR_ASSERT(!args->is_last);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
grpc_auth_context* auth_context =
grpc_find_auth_context_in_args(args->channel_args);
GPR_ASSERT(auth_context != nullptr);
chand->auth_context =
GRPC_AUTH_CONTEXT_REF(auth_context, "server_auth_filter");
grpc_server_credentials* creds =
grpc_find_server_credentials_in_args(args->channel_args);
chand->creds = grpc_server_credentials_ref(creds);
new (elem->channel_data) channel_data(auth_context, creds);
return GRPC_ERROR_NONE;
}
/* Destructor for channel data */
static void destroy_channel_elem(grpc_channel_element* elem) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "server_auth_filter");
grpc_server_credentials_unref(chand->creds);
chand->~channel_data();
}
const grpc_channel_filter grpc_server_auth_filter = {

@ -261,10 +261,10 @@ void MetadataCredentialsPluginWrapper::InvokePlugin(
grpc_status_code* status_code, const char** error_details) {
std::multimap<grpc::string, grpc::string> metadata;
// const_cast is safe since the SecureAuthContext does not take owndership and
// the object is passed as a const ref to plugin_->GetMetadata.
// const_cast is safe since the SecureAuthContext only inc/dec the refcount
// and the object is passed as a const ref to plugin_->GetMetadata.
SecureAuthContext cpp_channel_auth_context(
const_cast<grpc_auth_context*>(context.channel_auth_context), false);
const_cast<grpc_auth_context*>(context.channel_auth_context));
Status status = plugin_->GetMetadata(context.service_url, context.method_name,
cpp_channel_auth_context, &metadata);

@ -24,6 +24,7 @@
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/config.h>
#include "src/core/lib/security/credentials/credentials.h"
#include "src/cpp/server/thread_pool_interface.h"
namespace grpc {
@ -31,7 +32,9 @@ namespace grpc {
class SecureChannelCredentials final : public ChannelCredentials {
public:
explicit SecureChannelCredentials(grpc_channel_credentials* c_creds);
~SecureChannelCredentials() { grpc_channel_credentials_release(c_creds_); }
~SecureChannelCredentials() {
if (c_creds_ != nullptr) c_creds_->Unref();
}
grpc_channel_credentials* GetRawCreds() { return c_creds_; }
std::shared_ptr<grpc::Channel> CreateChannel(
@ -51,7 +54,9 @@ class SecureChannelCredentials final : public ChannelCredentials {
class SecureCallCredentials final : public CallCredentials {
public:
explicit SecureCallCredentials(grpc_call_credentials* c_creds);
~SecureCallCredentials() { grpc_call_credentials_release(c_creds_); }
~SecureCallCredentials() {
if (c_creds_ != nullptr) c_creds_->Unref();
}
grpc_call_credentials* GetRawCreds() { return c_creds_; }
bool ApplyToCall(grpc_call* call) override;

@ -22,19 +22,12 @@
namespace grpc {
SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx,
bool take_ownership)
: ctx_(ctx), take_ownership_(take_ownership) {}
SecureAuthContext::~SecureAuthContext() {
if (take_ownership_) grpc_auth_context_release(ctx_);
}
std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const {
if (!ctx_) {
if (ctx_ == nullptr) {
return std::vector<grpc::string_ref>();
}
grpc_auth_property_iterator iter = grpc_auth_context_peer_identity(ctx_);
grpc_auth_property_iterator iter =
grpc_auth_context_peer_identity(ctx_.get());
std::vector<grpc::string_ref> identity;
const grpc_auth_property* property = nullptr;
while ((property = grpc_auth_property_iterator_next(&iter))) {
@ -45,20 +38,20 @@ std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const {
}
grpc::string SecureAuthContext::GetPeerIdentityPropertyName() const {
if (!ctx_) {
if (ctx_ == nullptr) {
return "";
}
const char* name = grpc_auth_context_peer_identity_property_name(ctx_);
const char* name = grpc_auth_context_peer_identity_property_name(ctx_.get());
return name == nullptr ? "" : name;
}
std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues(
const grpc::string& name) const {
if (!ctx_) {
if (ctx_ == nullptr) {
return std::vector<grpc::string_ref>();
}
grpc_auth_property_iterator iter =
grpc_auth_context_find_properties_by_name(ctx_, name.c_str());
grpc_auth_context_find_properties_by_name(ctx_.get(), name.c_str());
const grpc_auth_property* property = nullptr;
std::vector<grpc::string_ref> values;
while ((property = grpc_auth_property_iterator_next(&iter))) {
@ -68,9 +61,9 @@ std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues(
}
AuthPropertyIterator SecureAuthContext::begin() const {
if (ctx_) {
if (ctx_ != nullptr) {
grpc_auth_property_iterator iter =
grpc_auth_context_property_iterator(ctx_);
grpc_auth_context_property_iterator(ctx_.get());
const grpc_auth_property* property =
grpc_auth_property_iterator_next(&iter);
return AuthPropertyIterator(property, &iter);
@ -85,19 +78,20 @@ AuthPropertyIterator SecureAuthContext::end() const {
void SecureAuthContext::AddProperty(const grpc::string& key,
const grpc::string_ref& value) {
if (!ctx_) return;
grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size());
if (ctx_ == nullptr) return;
grpc_auth_context_add_property(ctx_.get(), key.c_str(), value.data(),
value.size());
}
bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) {
if (!ctx_) return false;
return grpc_auth_context_set_peer_identity_property_name(ctx_,
if (ctx_ == nullptr) return false;
return grpc_auth_context_set_peer_identity_property_name(ctx_.get(),
name.c_str()) != 0;
}
bool SecureAuthContext::IsPeerAuthenticated() const {
if (!ctx_) return false;
return grpc_auth_context_peer_is_authenticated(ctx_) != 0;
if (ctx_ == nullptr) return false;
return grpc_auth_context_peer_is_authenticated(ctx_.get()) != 0;
}
} // namespace grpc

@ -21,15 +21,17 @@
#include <grpcpp/security/auth_context.h>
struct grpc_auth_context;
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
namespace grpc {
class SecureAuthContext final : public AuthContext {
public:
SecureAuthContext(grpc_auth_context* ctx, bool take_ownership);
explicit SecureAuthContext(grpc_auth_context* ctx)
: ctx_(ctx != nullptr ? ctx->Ref() : nullptr) {}
~SecureAuthContext() override;
~SecureAuthContext() override = default;
bool IsPeerAuthenticated() const override;
@ -50,8 +52,7 @@ class SecureAuthContext final : public AuthContext {
virtual bool SetPeerIdentityPropertyName(const grpc::string& name) override;
private:
grpc_auth_context* ctx_;
bool take_ownership_;
grpc_core::RefCountedPtr<grpc_auth_context> ctx_;
};
} // namespace grpc

@ -20,6 +20,7 @@
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
#include <grpcpp/security/auth_context.h>
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/cpp/common/secure_auth_context.h"
namespace grpc {
@ -28,8 +29,8 @@ std::shared_ptr<const AuthContext> CreateAuthContext(grpc_call* call) {
if (call == nullptr) {
return std::shared_ptr<const AuthContext>();
}
return std::shared_ptr<const AuthContext>(
new SecureAuthContext(grpc_call_auth_context(call), true));
grpc_core::RefCountedPtr<grpc_auth_context> ctx(grpc_call_auth_context(call));
return std::make_shared<SecureAuthContext>(ctx.get());
}
} // namespace grpc

@ -61,7 +61,7 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor(
metadata.insert(std::make_pair(StringRefFromSlice(&md[i].key),
StringRefFromSlice(&md[i].value)));
}
SecureAuthContext context(ctx, false);
SecureAuthContext context(ctx);
AuthMetadataProcessor::OutputMetadata consumed_metadata;
AuthMetadataProcessor::OutputMetadata response_metadata;

@ -33,40 +33,34 @@ using grpc_core::internal::grpc_alts_auth_context_from_tsi_peer;
/* This file contains unit tests of grpc_alts_auth_context_from_tsi_peer(). */
static void test_invalid_input_failure() {
tsi_peer peer;
grpc_auth_context* ctx;
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(nullptr, &ctx) ==
GRPC_SECURITY_ERROR);
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, nullptr) ==
GRPC_SECURITY_ERROR);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_alts_auth_context_from_tsi_peer(nullptr);
GPR_ASSERT(ctx == nullptr);
}
static void test_empty_certificate_type_failure() {
tsi_peer peer;
grpc_auth_context* ctx = nullptr;
GPR_ASSERT(tsi_construct_peer(0, &peer) == TSI_OK);
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) ==
GRPC_SECURITY_ERROR);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_alts_auth_context_from_tsi_peer(&peer);
GPR_ASSERT(ctx == nullptr);
tsi_peer_destruct(&peer);
}
static void test_empty_peer_property_failure() {
tsi_peer peer;
grpc_auth_context* ctx;
GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK);
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
&peer.properties[0]) == TSI_OK);
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) ==
GRPC_SECURITY_ERROR);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_alts_auth_context_from_tsi_peer(&peer);
GPR_ASSERT(ctx == nullptr);
tsi_peer_destruct(&peer);
}
static void test_missing_rpc_protocol_versions_property_failure() {
tsi_peer peer;
grpc_auth_context* ctx;
GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK);
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
@ -74,23 +68,22 @@ static void test_missing_rpc_protocol_versions_property_failure() {
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice",
&peer.properties[1]) == TSI_OK);
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) ==
GRPC_SECURITY_ERROR);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_alts_auth_context_from_tsi_peer(&peer);
GPR_ASSERT(ctx == nullptr);
tsi_peer_destruct(&peer);
}
static void test_unknown_peer_property_failure() {
tsi_peer peer;
grpc_auth_context* ctx;
GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK);
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
&peer.properties[0]) == TSI_OK);
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
"unknown", "alice", &peer.properties[1]) == TSI_OK);
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) ==
GRPC_SECURITY_ERROR);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_alts_auth_context_from_tsi_peer(&peer);
GPR_ASSERT(ctx == nullptr);
tsi_peer_destruct(&peer);
}
@ -119,7 +112,6 @@ static bool test_identity(const grpc_auth_context* ctx,
static void test_alts_peer_to_auth_context_success() {
tsi_peer peer;
grpc_auth_context* ctx;
GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK);
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
@ -144,11 +136,12 @@ static void test_alts_peer_to_auth_context_success() {
GRPC_SLICE_START_PTR(serialized_peer_versions)),
GRPC_SLICE_LENGTH(serialized_peer_versions),
&peer.properties[2]) == TSI_OK);
GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) ==
GRPC_SECURITY_OK);
GPR_ASSERT(
test_identity(ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice"));
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_alts_auth_context_from_tsi_peer(&peer);
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(test_identity(ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY,
"alice"));
ctx.reset(DEBUG_LOCATION, "test");
grpc_slice_unref(serialized_peer_versions);
tsi_peer_destruct(&peer);
}

@ -19,114 +19,122 @@
#include <string.h>
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "test/core/util/test_config.h"
#include <grpc/support/log.h>
static void test_empty_context(void) {
grpc_auth_context* ctx = grpc_auth_context_create(nullptr);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_property_iterator it;
gpr_log(GPR_INFO, "test_empty_context");
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx) == nullptr);
it = grpc_auth_context_peer_identity(ctx);
GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx.get()) ==
nullptr);
it = grpc_auth_context_peer_identity(ctx.get());
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
it = grpc_auth_context_property_iterator(ctx);
it = grpc_auth_context_property_iterator(ctx.get());
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
it = grpc_auth_context_find_properties_by_name(ctx, "foo");
it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo");
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx, "bar") ==
0);
GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx) == nullptr);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
GPR_ASSERT(
grpc_auth_context_set_peer_identity_property_name(ctx.get(), "bar") == 0);
GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx.get()) ==
nullptr);
ctx.reset(DEBUG_LOCATION, "test");
}
static void test_simple_context(void) {
grpc_auth_context* ctx = grpc_auth_context_create(nullptr);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_property_iterator it;
size_t i;
gpr_log(GPR_INFO, "test_simple_context");
GPR_ASSERT(ctx != nullptr);
grpc_auth_context_add_cstring_property(ctx, "name", "chapi");
grpc_auth_context_add_cstring_property(ctx, "name", "chapo");
grpc_auth_context_add_cstring_property(ctx, "foo", "bar");
GPR_ASSERT(ctx->properties.count == 3);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx, "name") ==
1);
GPR_ASSERT(
strcmp(grpc_auth_context_peer_identity_property_name(ctx), "name") == 0);
it = grpc_auth_context_property_iterator(ctx);
for (i = 0; i < ctx->properties.count; i++) {
grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapi");
grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapo");
grpc_auth_context_add_cstring_property(ctx.get(), "foo", "bar");
GPR_ASSERT(ctx->properties().count == 3);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx.get(),
"name") == 1);
GPR_ASSERT(strcmp(grpc_auth_context_peer_identity_property_name(ctx.get()),
"name") == 0);
it = grpc_auth_context_property_iterator(ctx.get());
for (i = 0; i < ctx->properties().count; i++) {
const grpc_auth_property* p = grpc_auth_property_iterator_next(&it);
GPR_ASSERT(p == &ctx->properties.array[i]);
GPR_ASSERT(p == &ctx->properties().array[i]);
}
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
it = grpc_auth_context_find_properties_by_name(ctx, "foo");
it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo");
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&ctx->properties.array[2]);
&ctx->properties().array[2]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
it = grpc_auth_context_peer_identity(ctx);
it = grpc_auth_context_peer_identity(ctx.get());
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&ctx->properties.array[0]);
&ctx->properties().array[0]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&ctx->properties.array[1]);
&ctx->properties().array[1]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
static void test_chained_context(void) {
grpc_auth_context* chained = grpc_auth_context_create(nullptr);
grpc_auth_context* ctx = grpc_auth_context_create(chained);
grpc_core::RefCountedPtr<grpc_auth_context> chained =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context* chained_ptr = chained.get();
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(std::move(chained));
grpc_auth_property_iterator it;
size_t i;
gpr_log(GPR_INFO, "test_chained_context");
GRPC_AUTH_CONTEXT_UNREF(chained, "chained");
grpc_auth_context_add_cstring_property(chained, "name", "padapo");
grpc_auth_context_add_cstring_property(chained, "foo", "baz");
grpc_auth_context_add_cstring_property(ctx, "name", "chapi");
grpc_auth_context_add_cstring_property(ctx, "name", "chap0");
grpc_auth_context_add_cstring_property(ctx, "foo", "bar");
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx, "name") ==
1);
GPR_ASSERT(
strcmp(grpc_auth_context_peer_identity_property_name(ctx), "name") == 0);
it = grpc_auth_context_property_iterator(ctx);
for (i = 0; i < ctx->properties.count; i++) {
grpc_auth_context_add_cstring_property(chained_ptr, "name", "padapo");
grpc_auth_context_add_cstring_property(chained_ptr, "foo", "baz");
grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapi");
grpc_auth_context_add_cstring_property(ctx.get(), "name", "chap0");
grpc_auth_context_add_cstring_property(ctx.get(), "foo", "bar");
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx.get(),
"name") == 1);
GPR_ASSERT(strcmp(grpc_auth_context_peer_identity_property_name(ctx.get()),
"name") == 0);
it = grpc_auth_context_property_iterator(ctx.get());
for (i = 0; i < ctx->properties().count; i++) {
const grpc_auth_property* p = grpc_auth_property_iterator_next(&it);
GPR_ASSERT(p == &ctx->properties.array[i]);
GPR_ASSERT(p == &ctx->properties().array[i]);
}
for (i = 0; i < chained->properties.count; i++) {
for (i = 0; i < chained_ptr->properties().count; i++) {
const grpc_auth_property* p = grpc_auth_property_iterator_next(&it);
GPR_ASSERT(p == &chained->properties.array[i]);
GPR_ASSERT(p == &chained_ptr->properties().array[i]);
}
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
it = grpc_auth_context_find_properties_by_name(ctx, "foo");
it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo");
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&ctx->properties.array[2]);
&ctx->properties().array[2]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&chained->properties.array[1]);
&chained_ptr->properties().array[1]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
it = grpc_auth_context_peer_identity(ctx);
it = grpc_auth_context_peer_identity(ctx.get());
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&ctx->properties.array[0]);
&ctx->properties().array[0]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&ctx->properties.array[1]);
&ctx->properties().array[1]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) ==
&chained->properties.array[0]);
&chained_ptr->properties().array[0]);
GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
int main(int argc, char** argv) {

@ -46,19 +46,6 @@
using grpc_core::internal::grpc_flush_cached_google_default_credentials;
using grpc_core::internal::set_gce_tenancy_checker_for_testing;
/* -- Mock channel credentials. -- */
static grpc_channel_credentials* grpc_mock_channel_credentials_create(
const grpc_channel_credentials_vtable* vtable) {
grpc_channel_credentials* c =
static_cast<grpc_channel_credentials*>(gpr_malloc(sizeof(*c)));
memset(c, 0, sizeof(*c));
c->type = "mock";
c->vtable = vtable;
gpr_ref_init(&c->refcount, 1);
return c;
}
/* -- Constants. -- */
static const char test_google_iam_authorization_token[] = "blahblahblhahb";
@ -377,9 +364,9 @@ static void run_request_metadata_test(grpc_call_credentials* creds,
grpc_auth_metadata_context auth_md_ctx,
request_metadata_state* state) {
grpc_error* error = GRPC_ERROR_NONE;
if (grpc_call_credentials_get_request_metadata(
creds, &state->pollent, auth_md_ctx, &state->md_array,
&state->on_request_metadata, &error)) {
if (creds->get_request_metadata(&state->pollent, auth_md_ctx,
&state->md_array, &state->on_request_metadata,
&error)) {
// Synchronous result. Invoke the callback directly.
check_request_metadata(state, error);
GRPC_ERROR_UNREF(error);
@ -400,7 +387,7 @@ static void test_google_iam_creds(void) {
grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method,
nullptr, nullptr};
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_call_credentials_unref(creds);
creds->Unref();
}
static void test_access_token_creds(void) {
@ -412,28 +399,36 @@ static void test_access_token_creds(void) {
grpc_access_token_credentials_create("blah", nullptr);
grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method,
nullptr, nullptr};
GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0);
GPR_ASSERT(strcmp(creds->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0);
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_call_credentials_unref(creds);
creds->Unref();
}
static grpc_security_status check_channel_oauth2_create_security_connector(
grpc_channel_credentials* c, grpc_call_credentials* call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
GPR_ASSERT(strcmp(c->type, "mock") == 0);
GPR_ASSERT(call_creds != nullptr);
GPR_ASSERT(strcmp(call_creds->type, GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0);
return GRPC_SECURITY_OK;
}
namespace {
class check_channel_oauth2 final : public grpc_channel_credentials {
public:
check_channel_oauth2() : grpc_channel_credentials("mock") {}
~check_channel_oauth2() override = default;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_args** new_args) override {
GPR_ASSERT(strcmp(type(), "mock") == 0);
GPR_ASSERT(call_creds != nullptr);
GPR_ASSERT(strcmp(call_creds->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) ==
0);
return nullptr;
}
};
} // namespace
static void test_channel_oauth2_composite_creds(void) {
grpc_core::ExecCtx exec_ctx;
grpc_channel_args* new_args;
grpc_channel_credentials_vtable vtable = {
nullptr, check_channel_oauth2_create_security_connector, nullptr};
grpc_channel_credentials* channel_creds =
grpc_mock_channel_credentials_create(&vtable);
grpc_core::New<check_channel_oauth2>();
grpc_call_credentials* oauth2_creds =
grpc_access_token_credentials_create("blah", nullptr);
grpc_channel_credentials* channel_oauth2_creds =
@ -441,9 +436,8 @@ static void test_channel_oauth2_composite_creds(void) {
nullptr);
grpc_channel_credentials_release(channel_creds);
grpc_call_credentials_release(oauth2_creds);
GPR_ASSERT(grpc_channel_credentials_create_security_connector(
channel_oauth2_creds, nullptr, nullptr, nullptr, &new_args) ==
GRPC_SECURITY_OK);
channel_oauth2_creds->create_security_connector(nullptr, nullptr, nullptr,
&new_args);
grpc_channel_credentials_release(channel_oauth2_creds);
}
@ -467,47 +461,54 @@ static void test_oauth2_google_iam_composite_creds(void) {
grpc_call_credentials* composite_creds =
grpc_composite_call_credentials_create(oauth2_creds, google_iam_creds,
nullptr);
grpc_call_credentials_unref(oauth2_creds);
grpc_call_credentials_unref(google_iam_creds);
GPR_ASSERT(
strcmp(composite_creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0);
const grpc_call_credentials_array* creds_array =
grpc_composite_call_credentials_get_credentials(composite_creds);
GPR_ASSERT(creds_array->num_creds == 2);
GPR_ASSERT(strcmp(creds_array->creds_array[0]->type,
oauth2_creds->Unref();
google_iam_creds->Unref();
GPR_ASSERT(strcmp(composite_creds->type(),
GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0);
const grpc_call_credentials_array& creds_array =
static_cast<const grpc_composite_call_credentials*>(composite_creds)
->inner();
GPR_ASSERT(creds_array.size() == 2);
GPR_ASSERT(strcmp(creds_array.get(0)->type(),
GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0);
GPR_ASSERT(strcmp(creds_array->creds_array[1]->type,
GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0);
GPR_ASSERT(
strcmp(creds_array.get(1)->type(), GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0);
run_request_metadata_test(composite_creds, auth_md_ctx, state);
grpc_call_credentials_unref(composite_creds);
composite_creds->Unref();
}
static grpc_security_status
check_channel_oauth2_google_iam_create_security_connector(
grpc_channel_credentials* c, grpc_call_credentials* call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
const grpc_call_credentials_array* creds_array;
GPR_ASSERT(strcmp(c->type, "mock") == 0);
GPR_ASSERT(call_creds != nullptr);
GPR_ASSERT(strcmp(call_creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) ==
0);
creds_array = grpc_composite_call_credentials_get_credentials(call_creds);
GPR_ASSERT(strcmp(creds_array->creds_array[0]->type,
GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0);
GPR_ASSERT(strcmp(creds_array->creds_array[1]->type,
GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0);
return GRPC_SECURITY_OK;
}
namespace {
class check_channel_oauth2_google_iam final : public grpc_channel_credentials {
public:
check_channel_oauth2_google_iam() : grpc_channel_credentials("mock") {}
~check_channel_oauth2_google_iam() override = default;
grpc_core::RefCountedPtr<grpc_channel_security_connector>
create_security_connector(
grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
grpc_channel_args** new_args) override {
GPR_ASSERT(strcmp(type(), "mock") == 0);
GPR_ASSERT(call_creds != nullptr);
GPR_ASSERT(
strcmp(call_creds->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0);
const grpc_call_credentials_array& creds_array =
static_cast<const grpc_composite_call_credentials*>(call_creds.get())
->inner();
GPR_ASSERT(strcmp(creds_array.get(0)->type(),
GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0);
GPR_ASSERT(strcmp(creds_array.get(1)->type(),
GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0);
return nullptr;
}
};
} // namespace
static void test_channel_oauth2_google_iam_composite_creds(void) {
grpc_core::ExecCtx exec_ctx;
grpc_channel_args* new_args;
grpc_channel_credentials_vtable vtable = {
nullptr, check_channel_oauth2_google_iam_create_security_connector,
nullptr};
grpc_channel_credentials* channel_creds =
grpc_mock_channel_credentials_create(&vtable);
grpc_core::New<check_channel_oauth2_google_iam>();
grpc_call_credentials* oauth2_creds =
grpc_access_token_credentials_create("blah", nullptr);
grpc_channel_credentials* channel_oauth2_creds =
@ -524,9 +525,8 @@ static void test_channel_oauth2_google_iam_composite_creds(void) {
grpc_channel_credentials_release(channel_oauth2_creds);
grpc_call_credentials_release(google_iam_creds);
GPR_ASSERT(grpc_channel_credentials_create_security_connector(
channel_oauth2_iam_creds, nullptr, nullptr, nullptr,
&new_args) == GRPC_SECURITY_OK);
channel_oauth2_iam_creds->create_security_connector(nullptr, nullptr, nullptr,
&new_args);
grpc_channel_credentials_release(channel_oauth2_iam_creds);
}
@ -578,7 +578,7 @@ static int httpcli_get_should_not_be_called(const grpc_httpcli_request* request,
return 1;
}
static void test_compute_engine_creds_success(void) {
static void test_compute_engine_creds_success() {
grpc_core::ExecCtx exec_ctx;
expected_md emd[] = {
{"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}};
@ -603,7 +603,7 @@ static void test_compute_engine_creds_success(void) {
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_core::ExecCtx::Get()->Flush();
grpc_call_credentials_unref(creds);
creds->Unref();
grpc_httpcli_set_override(nullptr, nullptr);
}
@ -620,7 +620,7 @@ static void test_compute_engine_creds_failure(void) {
grpc_httpcli_set_override(compute_engine_httpcli_get_failure_override,
httpcli_post_should_not_be_called);
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_call_credentials_unref(creds);
creds->Unref();
grpc_httpcli_set_override(nullptr, nullptr);
}
@ -692,7 +692,7 @@ static void test_refresh_token_creds_success(void) {
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_core::ExecCtx::Get()->Flush();
grpc_call_credentials_unref(creds);
creds->Unref();
grpc_httpcli_set_override(nullptr, nullptr);
}
@ -709,7 +709,7 @@ static void test_refresh_token_creds_failure(void) {
grpc_httpcli_set_override(httpcli_get_should_not_be_called,
refresh_token_httpcli_post_failure);
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_call_credentials_unref(creds);
creds->Unref();
grpc_httpcli_set_override(nullptr, nullptr);
}
@ -762,7 +762,7 @@ static char* encode_and_sign_jwt_should_not_be_called(
static grpc_service_account_jwt_access_credentials* creds_as_jwt(
grpc_call_credentials* creds) {
GPR_ASSERT(creds != nullptr);
GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_JWT) == 0);
GPR_ASSERT(strcmp(creds->type(), GRPC_CALL_CREDENTIALS_TYPE_JWT) == 0);
return reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds);
}
@ -773,7 +773,7 @@ static void test_jwt_creds_lifetime(void) {
grpc_call_credentials* jwt_creds =
grpc_service_account_jwt_access_credentials_create(
json_key_string, grpc_max_auth_token_lifetime(), nullptr);
GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime,
GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(),
grpc_max_auth_token_lifetime()) == 0);
grpc_call_credentials_release(jwt_creds);
@ -782,8 +782,8 @@ static void test_jwt_creds_lifetime(void) {
GPR_ASSERT(gpr_time_cmp(grpc_max_auth_token_lifetime(), token_lifetime) > 0);
jwt_creds = grpc_service_account_jwt_access_credentials_create(
json_key_string, token_lifetime, nullptr);
GPR_ASSERT(
gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime, token_lifetime) == 0);
GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(),
token_lifetime) == 0);
grpc_call_credentials_release(jwt_creds);
// Cropped lifetime.
@ -791,7 +791,7 @@ static void test_jwt_creds_lifetime(void) {
token_lifetime = gpr_time_add(grpc_max_auth_token_lifetime(), add_to_max);
jwt_creds = grpc_service_account_jwt_access_credentials_create(
json_key_string, token_lifetime, nullptr);
GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime,
GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(),
grpc_max_auth_token_lifetime()) == 0);
grpc_call_credentials_release(jwt_creds);
@ -834,7 +834,7 @@ static void test_jwt_creds_success(void) {
run_request_metadata_test(creds, auth_md_ctx, state);
grpc_core::ExecCtx::Get()->Flush();
grpc_call_credentials_unref(creds);
creds->Unref();
gpr_free(json_key_string);
gpr_free(expected_md_value);
grpc_jwt_encode_and_sign_set_override(nullptr);
@ -856,7 +856,7 @@ static void test_jwt_creds_signing_failure(void) {
run_request_metadata_test(creds, auth_md_ctx, state);
gpr_free(json_key_string);
grpc_call_credentials_unref(creds);
creds->Unref();
grpc_jwt_encode_and_sign_set_override(nullptr);
}
@ -875,8 +875,6 @@ static void set_google_default_creds_env_var_with_file_contents(
static void test_google_default_creds_auth_key(void) {
grpc_core::ExecCtx exec_ctx;
grpc_service_account_jwt_access_credentials* jwt;
grpc_google_default_channel_credentials* default_creds;
grpc_composite_channel_credentials* creds;
char* json_key = test_json_key_str();
grpc_flush_cached_google_default_credentials();
@ -885,37 +883,39 @@ static void test_google_default_creds_auth_key(void) {
gpr_free(json_key);
creds = reinterpret_cast<grpc_composite_channel_credentials*>(
grpc_google_default_credentials_create());
default_creds = reinterpret_cast<grpc_google_default_channel_credentials*>(
creds->inner_creds);
GPR_ASSERT(default_creds->ssl_creds != nullptr);
jwt = reinterpret_cast<grpc_service_account_jwt_access_credentials*>(
creds->call_creds);
auto* default_creds =
reinterpret_cast<const grpc_google_default_channel_credentials*>(
creds->inner_creds());
GPR_ASSERT(default_creds->ssl_creds() != nullptr);
auto* jwt =
reinterpret_cast<const grpc_service_account_jwt_access_credentials*>(
creds->call_creds());
GPR_ASSERT(
strcmp(jwt->key.client_id,
strcmp(jwt->key().client_id,
"777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent.com") ==
0);
grpc_channel_credentials_unref(&creds->base);
creds->Unref();
gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */
}
static void test_google_default_creds_refresh_token(void) {
grpc_core::ExecCtx exec_ctx;
grpc_google_refresh_token_credentials* refresh;
grpc_google_default_channel_credentials* default_creds;
grpc_composite_channel_credentials* creds;
grpc_flush_cached_google_default_credentials();
set_google_default_creds_env_var_with_file_contents(
"refresh_token_google_default_creds", test_refresh_token_str);
creds = reinterpret_cast<grpc_composite_channel_credentials*>(
grpc_google_default_credentials_create());
default_creds = reinterpret_cast<grpc_google_default_channel_credentials*>(
creds->inner_creds);
GPR_ASSERT(default_creds->ssl_creds != nullptr);
refresh = reinterpret_cast<grpc_google_refresh_token_credentials*>(
creds->call_creds);
GPR_ASSERT(strcmp(refresh->refresh_token.client_id,
auto* default_creds =
reinterpret_cast<const grpc_google_default_channel_credentials*>(
creds->inner_creds());
GPR_ASSERT(default_creds->ssl_creds() != nullptr);
auto* refresh =
reinterpret_cast<const grpc_google_refresh_token_credentials*>(
creds->call_creds());
GPR_ASSERT(strcmp(refresh->refresh_token().client_id,
"32555999999.apps.googleusercontent.com") == 0);
grpc_channel_credentials_unref(&creds->base);
creds->Unref();
gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */
}
@ -965,16 +965,16 @@ static void test_google_default_creds_gce(void) {
/* Verify that the default creds actually embeds a GCE creds. */
GPR_ASSERT(creds != nullptr);
GPR_ASSERT(creds->call_creds != nullptr);
GPR_ASSERT(creds->call_creds() != nullptr);
grpc_httpcli_set_override(compute_engine_httpcli_get_success_override,
httpcli_post_should_not_be_called);
run_request_metadata_test(creds->call_creds, auth_md_ctx, state);
run_request_metadata_test(creds->mutable_call_creds(), auth_md_ctx, state);
grpc_core::ExecCtx::Get()->Flush();
GPR_ASSERT(g_test_gce_tenancy_checker_called == true);
/* Cleanup. */
grpc_channel_credentials_unref(&creds->base);
creds->Unref();
grpc_httpcli_set_override(nullptr, nullptr);
grpc_override_well_known_credentials_path_getter(nullptr);
}
@ -1003,14 +1003,14 @@ static void test_google_default_creds_non_gce(void) {
grpc_google_default_credentials_create());
/* Verify that the default creds actually embeds a GCE creds. */
GPR_ASSERT(creds != nullptr);
GPR_ASSERT(creds->call_creds != nullptr);
GPR_ASSERT(creds->call_creds() != nullptr);
grpc_httpcli_set_override(compute_engine_httpcli_get_success_override,
httpcli_post_should_not_be_called);
run_request_metadata_test(creds->call_creds, auth_md_ctx, state);
run_request_metadata_test(creds->mutable_call_creds(), auth_md_ctx, state);
grpc_core::ExecCtx::Get()->Flush();
GPR_ASSERT(g_test_gce_tenancy_checker_called == true);
/* Cleanup. */
grpc_channel_credentials_unref(&creds->base);
creds->Unref();
grpc_httpcli_set_override(nullptr, nullptr);
grpc_override_well_known_credentials_path_getter(nullptr);
}
@ -1121,7 +1121,7 @@ static void test_metadata_plugin_success(void) {
GPR_ASSERT(state == PLUGIN_INITIAL_STATE);
run_request_metadata_test(creds, auth_md_ctx, md_state);
GPR_ASSERT(state == PLUGIN_GET_METADATA_CALLED_STATE);
grpc_call_credentials_unref(creds);
creds->Unref();
GPR_ASSERT(state == PLUGIN_DESTROY_CALLED_STATE);
}
@ -1149,7 +1149,7 @@ static void test_metadata_plugin_failure(void) {
GPR_ASSERT(state == PLUGIN_INITIAL_STATE);
run_request_metadata_test(creds, auth_md_ctx, md_state);
GPR_ASSERT(state == PLUGIN_GET_METADATA_CALLED_STATE);
grpc_call_credentials_unref(creds);
creds->Unref();
GPR_ASSERT(state == PLUGIN_DESTROY_CALLED_STATE);
}
@ -1176,25 +1176,23 @@ static void test_channel_creds_duplicate_without_call_creds(void) {
grpc_channel_credentials* channel_creds =
grpc_fake_transport_security_credentials_create();
grpc_channel_credentials* dup =
grpc_channel_credentials_duplicate_without_call_credentials(
channel_creds);
grpc_core::RefCountedPtr<grpc_channel_credentials> dup =
channel_creds->duplicate_without_call_credentials();
GPR_ASSERT(dup == channel_creds);
grpc_channel_credentials_unref(dup);
dup.reset();
grpc_call_credentials* call_creds =
grpc_access_token_credentials_create("blah", nullptr);
grpc_channel_credentials* composite_creds =
grpc_composite_channel_credentials_create(channel_creds, call_creds,
nullptr);
grpc_call_credentials_unref(call_creds);
dup = grpc_channel_credentials_duplicate_without_call_credentials(
composite_creds);
call_creds->Unref();
dup = composite_creds->duplicate_without_call_credentials();
GPR_ASSERT(dup == channel_creds);
grpc_channel_credentials_unref(dup);
dup.reset();
grpc_channel_credentials_unref(channel_creds);
grpc_channel_credentials_unref(composite_creds);
channel_creds->Unref();
composite_creds->Unref();
}
typedef struct {

@ -86,9 +86,8 @@ char* grpc_test_fetch_oauth2_token_with_credentials(
grpc_schedule_on_exec_ctx);
grpc_error* error = GRPC_ERROR_NONE;
if (grpc_call_credentials_get_request_metadata(creds, &request.pops, null_ctx,
&request.md_array,
&request.closure, &error)) {
if (creds->get_request_metadata(&request.pops, null_ctx, &request.md_array,
&request.closure, &error)) {
// Synchronous result; invoke callback directly.
on_oauth2_response(&request, error);
GRPC_ERROR_UNREF(error);

@ -96,11 +96,10 @@ int main(int argc, char** argv) {
grpc_schedule_on_exec_ctx);
error = GRPC_ERROR_NONE;
if (grpc_call_credentials_get_request_metadata(
(reinterpret_cast<grpc_composite_channel_credentials*>(creds))
->call_creds,
&sync.pops, context, &sync.md_array, &sync.on_request_metadata,
&error)) {
if (reinterpret_cast<grpc_composite_channel_credentials*>(creds)
->mutable_call_creds()
->get_request_metadata(&sync.pops, context, &sync.md_array,
&sync.on_request_metadata, &error)) {
// Synchronous response. Invoke callback directly.
on_metadata_response(&sync, error);
GRPC_ERROR_UNREF(error);

@ -27,6 +27,7 @@
#include "src/core/lib/gpr/env.h"
#include "src/core/lib/gpr/string.h"
#include "src/core/lib/gpr/tmpfile.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/security/security_connector/ssl_utils.h"
@ -83,22 +84,22 @@ static int check_ssl_peer_equivalence(const tsi_peer* original,
static void test_unauthenticated_ssl_peer(void) {
tsi_peer peer;
tsi_peer rpeer;
grpc_auth_context* ctx;
GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK);
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE,
&peer.properties[0]) == TSI_OK);
ctx = grpc_ssl_peer_to_auth_context(&peer);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_ssl_peer_to_auth_context(&peer);
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(!grpc_auth_context_peer_is_authenticated(ctx));
GPR_ASSERT(check_transport_security_type(ctx));
GPR_ASSERT(!grpc_auth_context_peer_is_authenticated(ctx.get()));
GPR_ASSERT(check_transport_security_type(ctx.get()));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx);
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get());
GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer));
grpc_shallow_peer_destruct(&rpeer);
tsi_peer_destruct(&peer);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
static int check_identity(const grpc_auth_context* ctx,
@ -175,7 +176,6 @@ static int check_x509_pem_cert(const grpc_auth_context* ctx,
static void test_cn_only_ssl_peer_to_auth_context(void) {
tsi_peer peer;
tsi_peer rpeer;
grpc_auth_context* ctx;
const char* expected_cn = "cn1";
const char* expected_pem_cert = "pem_cert1";
GPR_ASSERT(tsi_construct_peer(3, &peer) == TSI_OK);
@ -188,26 +188,27 @@ static void test_cn_only_ssl_peer_to_auth_context(void) {
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert,
&peer.properties[2]) == TSI_OK);
ctx = grpc_ssl_peer_to_auth_context(&peer);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_ssl_peer_to_auth_context(&peer);
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx));
GPR_ASSERT(check_identity(ctx, GRPC_X509_CN_PROPERTY_NAME, &expected_cn, 1));
GPR_ASSERT(check_transport_security_type(ctx));
GPR_ASSERT(check_x509_cn(ctx, expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert));
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get()));
GPR_ASSERT(
check_identity(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, &expected_cn, 1));
GPR_ASSERT(check_transport_security_type(ctx.get()));
GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx);
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get());
GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer));
grpc_shallow_peer_destruct(&rpeer);
tsi_peer_destruct(&peer);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
static void test_cn_and_one_san_ssl_peer_to_auth_context(void) {
tsi_peer peer;
tsi_peer rpeer;
grpc_auth_context* ctx;
const char* expected_cn = "cn1";
const char* expected_san = "san1";
const char* expected_pem_cert = "pem_cert1";
@ -224,27 +225,28 @@ static void test_cn_and_one_san_ssl_peer_to_auth_context(void) {
GPR_ASSERT(tsi_construct_string_peer_property_from_cstring(
TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert,
&peer.properties[3]) == TSI_OK);
ctx = grpc_ssl_peer_to_auth_context(&peer);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_ssl_peer_to_auth_context(&peer);
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx));
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get()));
GPR_ASSERT(
check_identity(ctx, GRPC_X509_SAN_PROPERTY_NAME, &expected_san, 1));
GPR_ASSERT(check_transport_security_type(ctx));
GPR_ASSERT(check_x509_cn(ctx, expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert));
check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, &expected_san, 1));
GPR_ASSERT(check_transport_security_type(ctx.get()));
GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx);
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get());
GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer));
grpc_shallow_peer_destruct(&rpeer);
tsi_peer_destruct(&peer);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
static void test_cn_and_multiple_sans_ssl_peer_to_auth_context(void) {
tsi_peer peer;
tsi_peer rpeer;
grpc_auth_context* ctx;
const char* expected_cn = "cn1";
const char* expected_sans[] = {"san1", "san2", "san3"};
const char* expected_pem_cert = "pem_cert1";
@ -265,28 +267,28 @@ static void test_cn_and_multiple_sans_ssl_peer_to_auth_context(void) {
TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
expected_sans[i], &peer.properties[3 + i]) == TSI_OK);
}
ctx = grpc_ssl_peer_to_auth_context(&peer);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_ssl_peer_to_auth_context(&peer);
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx));
GPR_ASSERT(check_identity(ctx, GRPC_X509_SAN_PROPERTY_NAME, expected_sans,
GPR_ARRAY_SIZE(expected_sans)));
GPR_ASSERT(check_transport_security_type(ctx));
GPR_ASSERT(check_x509_cn(ctx, expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx);
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get()));
GPR_ASSERT(check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME,
expected_sans, GPR_ARRAY_SIZE(expected_sans)));
GPR_ASSERT(check_transport_security_type(ctx.get()));
GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get());
GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer));
grpc_shallow_peer_destruct(&rpeer);
tsi_peer_destruct(&peer);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
static void test_cn_and_multiple_sans_and_others_ssl_peer_to_auth_context(
void) {
tsi_peer peer;
tsi_peer rpeer;
grpc_auth_context* ctx;
const char* expected_cn = "cn1";
const char* expected_pem_cert = "pem_cert1";
const char* expected_sans[] = {"san1", "san2", "san3"};
@ -311,21 +313,22 @@ static void test_cn_and_multiple_sans_and_others_ssl_peer_to_auth_context(
TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
expected_sans[i], &peer.properties[5 + i]) == TSI_OK);
}
ctx = grpc_ssl_peer_to_auth_context(&peer);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_ssl_peer_to_auth_context(&peer);
GPR_ASSERT(ctx != nullptr);
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx));
GPR_ASSERT(check_identity(ctx, GRPC_X509_SAN_PROPERTY_NAME, expected_sans,
GPR_ARRAY_SIZE(expected_sans)));
GPR_ASSERT(check_transport_security_type(ctx));
GPR_ASSERT(check_x509_cn(ctx, expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx);
GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get()));
GPR_ASSERT(check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME,
expected_sans, GPR_ARRAY_SIZE(expected_sans)));
GPR_ASSERT(check_transport_security_type(ctx.get()));
GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn));
GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert));
rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get());
GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer));
grpc_shallow_peer_destruct(&rpeer);
tsi_peer_destruct(&peer);
GRPC_AUTH_CONTEXT_UNREF(ctx, "test");
ctx.reset(DEBUG_LOCATION, "test");
}
static const char* roots_for_override_api = "roots for override api";

@ -82,16 +82,15 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
ca_cert, &pem_key_cert_pair, 1, 0, nullptr);
// Create security connector
grpc_server_security_connector* sc = nullptr;
grpc_security_status status =
grpc_server_credentials_create_security_connector(creds, &sc);
GPR_ASSERT(status == GRPC_SECURITY_OK);
grpc_core::RefCountedPtr<grpc_server_security_connector> sc =
creds->create_security_connector();
GPR_ASSERT(sc != nullptr);
grpc_millis deadline = GPR_MS_PER_SEC + grpc_core::ExecCtx::Get()->Now();
struct handshake_state state;
state.done_callback_called = false;
grpc_handshake_manager* handshake_mgr = grpc_handshake_manager_create();
grpc_server_security_connector_add_handshakers(sc, nullptr, handshake_mgr);
sc->add_handshakers(nullptr, handshake_mgr);
grpc_handshake_manager_do_handshake(
handshake_mgr, mock_endpoint, nullptr /* channel_args */, deadline,
nullptr /* acceptor */, on_handshake_done, &state);
@ -110,7 +109,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
GPR_ASSERT(state.done_callback_called);
grpc_handshake_manager_destroy(handshake_mgr);
GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "test");
sc.reset(DEBUG_LOCATION, "test");
grpc_server_credentials_release(creds);
grpc_slice_unref(cert_slice);
grpc_slice_unref(key_slice);

@ -39,7 +39,7 @@ void test_unknown_scheme_target(void) {
GPR_ASSERT(0 == strcmp(elem->filter->name, "lame-client"));
grpc_core::ExecCtx exec_ctx;
GRPC_CHANNEL_INTERNAL_UNREF(chan, "test");
grpc_channel_credentials_unref(creds);
creds->Unref();
}
void test_security_connector_already_in_arg(void) {

@ -40,15 +40,14 @@ class TestAuthPropertyIterator : public AuthPropertyIterator {
class AuthPropertyIteratorTest : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = grpc_auth_context_create(nullptr);
grpc_auth_context_add_cstring_property(ctx_, "name", "chapi");
grpc_auth_context_add_cstring_property(ctx_, "name", "chapo");
grpc_auth_context_add_cstring_property(ctx_, "foo", "bar");
EXPECT_EQ(1,
grpc_auth_context_set_peer_identity_property_name(ctx_, "name"));
ctx_ = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(ctx_.get(), "name", "chapi");
grpc_auth_context_add_cstring_property(ctx_.get(), "name", "chapo");
grpc_auth_context_add_cstring_property(ctx_.get(), "foo", "bar");
EXPECT_EQ(1, grpc_auth_context_set_peer_identity_property_name(ctx_.get(),
"name"));
}
void TearDown() override { grpc_auth_context_release(ctx_); }
grpc_auth_context* ctx_;
grpc_core::RefCountedPtr<grpc_auth_context> ctx_;
};
TEST_F(AuthPropertyIteratorTest, DefaultCtor) {
@ -59,7 +58,7 @@ TEST_F(AuthPropertyIteratorTest, DefaultCtor) {
TEST_F(AuthPropertyIteratorTest, GeneralTest) {
grpc_auth_property_iterator c_iter =
grpc_auth_context_property_iterator(ctx_);
grpc_auth_context_property_iterator(ctx_.get());
const grpc_auth_property* property =
grpc_auth_property_iterator_next(&c_iter);
TestAuthPropertyIterator iter(property, &c_iter);

@ -33,7 +33,7 @@ class SecureAuthContextTest : public ::testing::Test {};
// Created with nullptr
TEST_F(SecureAuthContextTest, EmptyContext) {
SecureAuthContext context(nullptr, true);
SecureAuthContext context(nullptr);
EXPECT_TRUE(context.GetPeerIdentity().empty());
EXPECT_TRUE(context.GetPeerIdentityPropertyName().empty());
EXPECT_TRUE(context.FindPropertyValues("").empty());
@ -42,8 +42,10 @@ TEST_F(SecureAuthContextTest, EmptyContext) {
}
TEST_F(SecureAuthContextTest, Properties) {
grpc_auth_context* ctx = grpc_auth_context_create(nullptr);
SecureAuthContext context(ctx, true);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
SecureAuthContext context(ctx.get());
ctx.reset();
context.AddProperty("name", "chapi");
context.AddProperty("name", "chapo");
context.AddProperty("foo", "bar");
@ -60,8 +62,10 @@ TEST_F(SecureAuthContextTest, Properties) {
}
TEST_F(SecureAuthContextTest, Iterators) {
grpc_auth_context* ctx = grpc_auth_context_create(nullptr);
SecureAuthContext context(ctx, true);
grpc_core::RefCountedPtr<grpc_auth_context> ctx =
grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
SecureAuthContext context(ctx.get());
ctx.reset();
context.AddProperty("name", "chapi");
context.AddProperty("name", "chapo");
context.AddProperty("foo", "bar");

@ -414,8 +414,8 @@ class GrpclbEnd2endTest : public ::testing::Test {
std::shared_ptr<ChannelCredentials> creds(
new SecureChannelCredentials(grpc_composite_channel_credentials_create(
channel_creds, call_creds, nullptr)));
grpc_call_credentials_unref(call_creds);
grpc_channel_credentials_unref(channel_creds);
call_creds->Unref();
channel_creds->Unref();
channel_ = CreateCustomChannel(uri.str(), creds, args);
stub_ = grpc::testing::EchoTestService::NewStub(channel_);
}

Loading…
Cancel
Save