Revert "Rework version selection following David's comment."

This reverts commit 70764178bf.
pull/24955/head
Matthew Stevenson 4 years ago
parent 70764178bf
commit d660e2a47a
  1. 421
      src/core/tsi/ssl_transport_security.cc

@ -915,53 +915,39 @@ static tsi_result tsi_set_min_and_max_tls_versions(
// |SSL_CTX_set_min_proto_version| and |SSL_CTX_set_max_proto_version| APIs // |SSL_CTX_set_min_proto_version| and |SSL_CTX_set_max_proto_version| APIs
// only exist in this version range. // only exist in this version range.
switch (min_tls_version) { switch (min_tls_version) {
case tsi_tls_version::TSI_TLS1_2:
SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
break;
// If the library does not support TLS 1.3, and the caller requested a minimum
// of TLS 1.3, return an error. The caller's request cannot be satisfied.
#if defined(TLS1_3_VERSION) #if defined(TLS1_3_VERSION)
case tsi_tls_version::TSI_TLS1_3: case tsi_tls_version::TSI_TLS1_3:
SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION); SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION);
break; break;
#endif #endif
default: default:
gpr_log(GPR_INFO, "TLS version is not supported."); SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
return TSI_FAILED_PRECONDITION; break;
witch(min_tls_version) {} }
// Set the max TLS version of the SSL context. // Set the max TLS version of the SSL context.
switch (max_tls_version) { switch (max_tls_version) {
case tsi_tls_version::TSI_TLS1_2:
SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
break;
case tsi_tls_version::TSI_TLS1_3:
#if defined(TLS1_3_VERSION) #if defined(TLS1_3_VERSION)
case tsi_tls_version::TSI_TLS1_3:
SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION); SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION);
#else
// The library doesn't support TLS 1.3, so set a maximum of
// TLS 1.2 instead.
SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
#endif
break; break;
#endif
default: default:
gpr_log(GPR_INFO, "TLS version is not supported."); SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
return TSI_FAILED_PRECONDITION; break;
} }
#endif #endif
return TSI_OK; return TSI_OK;
} }
/* --- tsi_ssl_root_certs_store methods implementation. ---*/ /* --- tsi_ssl_root_certs_store methods implementation. ---*/
tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create( tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create(
const char* pem_roots) { const char* pem_roots) {
if (pem_roots == nullptr) { if (pem_roots == nullptr) {
gpr_log(GPR_ERROR, "The root certificates are empty."); gpr_log(GPR_ERROR, "The root certificates are empty.");
return nullptr; return nullptr;
} }
tsi_ssl_root_certs_store* root_store = tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>(
static_cast<tsi_ssl_root_certs_store*>(
gpr_zalloc(sizeof(tsi_ssl_root_certs_store))); gpr_zalloc(sizeof(tsi_ssl_root_certs_store)));
if (root_store == nullptr) { if (root_store == nullptr) {
gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store."); gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store.");
@ -982,36 +968,37 @@ static tsi_result tsi_set_min_and_max_tls_versions(
return nullptr; return nullptr;
} }
return root_store; return root_store;
} }
void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store * self) { void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) {
if (self == nullptr) return; if (self == nullptr) return;
X509_STORE_free(self->store); X509_STORE_free(self->store);
gpr_free(self); gpr_free(self);
} }
/* --- tsi_ssl_session_cache methods implementation. ---*/ /* --- tsi_ssl_session_cache methods implementation. ---*/
tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) { tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) {
/* Pointer will be dereferenced by unref call. */ /* Pointer will be dereferenced by unref call. */
return reinterpret_cast<tsi_ssl_session_cache*>( return reinterpret_cast<tsi_ssl_session_cache*>(
tsi::SslSessionLRUCache::Create(capacity).release()); tsi::SslSessionLRUCache::Create(capacity).release());
} }
void tsi_ssl_session_cache_ref(tsi_ssl_session_cache * cache) { void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) {
/* Pointer will be dereferenced by unref call. */ /* Pointer will be dereferenced by unref call. */
reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Ref().release(); reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Ref().release();
} }
void tsi_ssl_session_cache_unref(tsi_ssl_session_cache * cache) { void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) {
reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Unref(); reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Unref();
} }
/* --- tsi_frame_protector methods implementation. ---*/ /* --- tsi_frame_protector methods implementation. ---*/
static tsi_result ssl_protector_protect( static tsi_result ssl_protector_protect(tsi_frame_protector* self,
tsi_frame_protector * self, const unsigned char* unprotected_bytes, const unsigned char* unprotected_bytes,
size_t* unprotected_bytes_size, unsigned char* protected_output_frames, size_t* unprotected_bytes_size,
unsigned char* protected_output_frames,
size_t* protected_output_frames_size) { size_t* protected_output_frames_size) {
tsi_ssl_frame_protector* impl = tsi_ssl_frame_protector* impl =
reinterpret_cast<tsi_ssl_frame_protector*>(self); reinterpret_cast<tsi_ssl_frame_protector*>(self);
@ -1062,10 +1049,10 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*unprotected_bytes_size = available; *unprotected_bytes_size = available;
impl->buffer_offset = 0; impl->buffer_offset = 0;
return TSI_OK; return TSI_OK;
} }
static tsi_result ssl_protector_protect_flush( static tsi_result ssl_protector_protect_flush(
tsi_frame_protector * self, unsigned char* protected_output_frames, tsi_frame_protector* self, unsigned char* protected_output_frames,
size_t* protected_output_frames_size, size_t* still_pending_size) { size_t* protected_output_frames_size, size_t* still_pending_size) {
tsi_result result = TSI_OK; tsi_result result = TSI_OK;
tsi_ssl_frame_protector* impl = tsi_ssl_frame_protector* impl =
@ -1096,10 +1083,10 @@ static tsi_result tsi_set_min_and_max_tls_versions(
GPR_ASSERT(pending >= 0); GPR_ASSERT(pending >= 0);
*still_pending_size = static_cast<size_t>(pending); *still_pending_size = static_cast<size_t>(pending);
return TSI_OK; return TSI_OK;
} }
static tsi_result ssl_protector_unprotect( static tsi_result ssl_protector_unprotect(
tsi_frame_protector * self, const unsigned char* protected_frames_bytes, tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes, size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
size_t* unprotected_bytes_size) { size_t* unprotected_bytes_size) {
tsi_result result = TSI_OK; tsi_result result = TSI_OK;
@ -1123,8 +1110,7 @@ static tsi_result tsi_set_min_and_max_tls_versions(
/* Then, try to write some data to ssl. */ /* Then, try to write some data to ssl. */
GPR_ASSERT(*protected_frames_bytes_size <= INT_MAX); GPR_ASSERT(*protected_frames_bytes_size <= INT_MAX);
written_into_ssl = written_into_ssl = BIO_write(impl->network_io, protected_frames_bytes,
BIO_write(impl->network_io, protected_frames_bytes,
static_cast<int>(*protected_frames_bytes_size)); static_cast<int>(*protected_frames_bytes_size));
if (written_into_ssl < 0) { if (written_into_ssl < 0) {
gpr_log(GPR_ERROR, "Sending protected frame to ssl failed with %d", gpr_log(GPR_ERROR, "Sending protected frame to ssl failed with %d",
@ -1140,28 +1126,28 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*unprotected_bytes_size += output_bytes_offset; *unprotected_bytes_size += output_bytes_offset;
} }
return result; return result;
} }
static void ssl_protector_destroy(tsi_frame_protector * self) { static void ssl_protector_destroy(tsi_frame_protector* self) {
tsi_ssl_frame_protector* impl = tsi_ssl_frame_protector* impl =
reinterpret_cast<tsi_ssl_frame_protector*>(self); reinterpret_cast<tsi_ssl_frame_protector*>(self);
if (impl->buffer != nullptr) gpr_free(impl->buffer); if (impl->buffer != nullptr) gpr_free(impl->buffer);
if (impl->ssl != nullptr) SSL_free(impl->ssl); if (impl->ssl != nullptr) SSL_free(impl->ssl);
if (impl->network_io != nullptr) BIO_free(impl->network_io); if (impl->network_io != nullptr) BIO_free(impl->network_io);
gpr_free(self); gpr_free(self);
} }
static const tsi_frame_protector_vtable frame_protector_vtable = { static const tsi_frame_protector_vtable frame_protector_vtable = {
ssl_protector_protect, ssl_protector_protect,
ssl_protector_protect_flush, ssl_protector_protect_flush,
ssl_protector_unprotect, ssl_protector_unprotect,
ssl_protector_destroy, ssl_protector_destroy,
}; };
/* --- tsi_server_handshaker_factory methods implementation. --- */ /* --- tsi_server_handshaker_factory methods implementation. --- */
static void tsi_ssl_handshaker_factory_destroy(tsi_ssl_handshaker_factory * static void tsi_ssl_handshaker_factory_destroy(
factory) { tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return; if (factory == nullptr) return;
if (factory->vtable != nullptr && factory->vtable->destroy != nullptr) { if (factory->vtable != nullptr && factory->vtable->destroy != nullptr) {
@ -1170,40 +1156,39 @@ static tsi_result tsi_set_min_and_max_tls_versions(
/* Note, we don't free(self) here because this object is always directly /* Note, we don't free(self) here because this object is always directly
* embedded in another object. If tsi_ssl_handshaker_factory_init allocates * embedded in another object. If tsi_ssl_handshaker_factory_init allocates
* any memory, it should be free'd here. */ * any memory, it should be free'd here. */
} }
static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref( static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref(
tsi_ssl_handshaker_factory * factory) { tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return nullptr; if (factory == nullptr) return nullptr;
gpr_refn(&factory->refcount, 1); gpr_refn(&factory->refcount, 1);
return factory; return factory;
} }
static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory * static void tsi_ssl_handshaker_factory_unref(
factory) { tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return; if (factory == nullptr) return;
if (gpr_unref(&factory->refcount)) { if (gpr_unref(&factory->refcount)) {
tsi_ssl_handshaker_factory_destroy(factory); tsi_ssl_handshaker_factory_destroy(factory);
} }
} }
static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = { static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr};
nullptr};
/* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for /* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
* allocating memory for the factory. */ * allocating memory for the factory. */
static void tsi_ssl_handshaker_factory_init(tsi_ssl_handshaker_factory * static void tsi_ssl_handshaker_factory_init(
factory) { tsi_ssl_handshaker_factory* factory) {
GPR_ASSERT(factory != nullptr); GPR_ASSERT(factory != nullptr);
factory->vtable = &handshaker_factory_vtable; factory->vtable = &handshaker_factory_vtable;
gpr_ref_init(&factory->refcount, 1); gpr_ref_init(&factory->refcount, 1);
} }
/* Gets the X509 cert chain in PEM format as a tsi_peer_property. */ /* Gets the X509 cert chain in PEM format as a tsi_peer_property. */
tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain, tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain,
tsi_peer_property * property) { tsi_peer_property* property) {
BIO* bio = BIO_new(BIO_s_mem()); BIO* bio = BIO_new(BIO_s_mem());
const auto peer_chain_len = sk_X509_num(peer_chain); const auto peer_chain_len = sk_X509_num(peer_chain);
for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) { for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) {
@ -1223,10 +1208,10 @@ static tsi_result tsi_set_min_and_max_tls_versions(
property); property);
BIO_free(bio); BIO_free(bio);
return result; return result;
} }
/* --- tsi_handshaker_result methods implementation. ---*/ /* --- tsi_handshaker_result methods implementation. ---*/
static tsi_result ssl_handshaker_result_extract_peer( static tsi_result ssl_handshaker_result_extract_peer(
const tsi_handshaker_result* self, tsi_peer* peer) { const tsi_handshaker_result* self, tsi_peer* peer) {
tsi_result result = TSI_OK; tsi_result result = TSI_OK;
const unsigned char* alpn_selected = nullptr; const unsigned char* alpn_selected = nullptr;
@ -1284,19 +1269,17 @@ static tsi_result tsi_set_min_and_max_tls_versions(
if (result != TSI_OK) return result; if (result != TSI_OK) return result;
peer->property_count++; peer->property_count++;
const char* session_reused = const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false";
SSL_session_reused(impl->ssl) ? "true" : "false";
result = tsi_construct_string_peer_property_from_cstring( result = tsi_construct_string_peer_property_from_cstring(
TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused, TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused,
&peer->properties[peer->property_count]); &peer->properties[peer->property_count]);
if (result != TSI_OK) return result; if (result != TSI_OK) return result;
peer->property_count++; peer->property_count++;
return result; return result;
} }
static tsi_result ssl_handshaker_result_create_frame_protector( static tsi_result ssl_handshaker_result_create_frame_protector(
const tsi_handshaker_result* self, const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
size_t* max_output_protected_frame_size,
tsi_frame_protector** protector) { tsi_frame_protector** protector) {
size_t actual_max_output_protected_frame_size = size_t actual_max_output_protected_frame_size =
TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND; TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
@ -1319,8 +1302,8 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
actual_max_output_protected_frame_size = *max_output_protected_frame_size; actual_max_output_protected_frame_size = *max_output_protected_frame_size;
} }
protector_impl->buffer_size = actual_max_output_protected_frame_size - protector_impl->buffer_size =
TSI_SSL_MAX_PROTECTION_OVERHEAD; actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD;
protector_impl->buffer = protector_impl->buffer =
static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size)); static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size));
if (protector_impl->buffer == nullptr) { if (protector_impl->buffer == nullptr) {
@ -1338,9 +1321,9 @@ static tsi_result tsi_set_min_and_max_tls_versions(
protector_impl->base.vtable = &frame_protector_vtable; protector_impl->base.vtable = &frame_protector_vtable;
*protector = &protector_impl->base; *protector = &protector_impl->base;
return TSI_OK; return TSI_OK;
} }
static tsi_result ssl_handshaker_result_get_unused_bytes( static tsi_result ssl_handshaker_result_get_unused_bytes(
const tsi_handshaker_result* self, const unsigned char** bytes, const tsi_handshaker_result* self, const unsigned char** bytes,
size_t* bytes_size) { size_t* bytes_size) {
const tsi_ssl_handshaker_result* impl = const tsi_ssl_handshaker_result* impl =
@ -1348,27 +1331,27 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*bytes_size = impl->unused_bytes_size; *bytes_size = impl->unused_bytes_size;
*bytes = impl->unused_bytes; *bytes = impl->unused_bytes;
return TSI_OK; return TSI_OK;
} }
static void ssl_handshaker_result_destroy(tsi_handshaker_result * self) { static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
tsi_ssl_handshaker_result* impl = tsi_ssl_handshaker_result* impl =
reinterpret_cast<tsi_ssl_handshaker_result*>(self); reinterpret_cast<tsi_ssl_handshaker_result*>(self);
SSL_free(impl->ssl); SSL_free(impl->ssl);
BIO_free(impl->network_io); BIO_free(impl->network_io);
gpr_free(impl->unused_bytes); gpr_free(impl->unused_bytes);
gpr_free(impl); gpr_free(impl);
} }
static const tsi_handshaker_result_vtable handshaker_result_vtable = { static const tsi_handshaker_result_vtable handshaker_result_vtable = {
ssl_handshaker_result_extract_peer, ssl_handshaker_result_extract_peer,
nullptr, /* create_zero_copy_grpc_protector */ nullptr, /* create_zero_copy_grpc_protector */
ssl_handshaker_result_create_frame_protector, ssl_handshaker_result_create_frame_protector,
ssl_handshaker_result_get_unused_bytes, ssl_handshaker_result_get_unused_bytes,
ssl_handshaker_result_destroy, ssl_handshaker_result_destroy,
}; };
static tsi_result ssl_handshaker_result_create( static tsi_result ssl_handshaker_result_create(
tsi_ssl_handshaker * handshaker, unsigned char* unused_bytes, tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes,
size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) { size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) {
if (handshaker == nullptr || handshaker_result == nullptr || if (handshaker == nullptr || handshaker_result == nullptr ||
(unused_bytes_size > 0 && unused_bytes == nullptr)) { (unused_bytes_size > 0 && unused_bytes == nullptr)) {
@ -1387,12 +1370,12 @@ static tsi_result tsi_set_min_and_max_tls_versions(
result->unused_bytes_size = unused_bytes_size; result->unused_bytes_size = unused_bytes_size;
*handshaker_result = &result->base; *handshaker_result = &result->base;
return TSI_OK; return TSI_OK;
} }
/* --- tsi_handshaker methods implementation. ---*/ /* --- tsi_handshaker methods implementation. ---*/
static tsi_result ssl_handshaker_get_bytes_to_send_to_peer( static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
tsi_ssl_handshaker * impl, unsigned char* bytes, size_t* bytes_size) { tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size) {
int bytes_read_from_ssl = 0; int bytes_read_from_ssl = 0;
if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 || if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 ||
*bytes_size > INT_MAX) { *bytes_size > INT_MAX) {
@ -1412,19 +1395,18 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
*bytes_size = static_cast<size_t>(bytes_read_from_ssl); *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA; return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
} }
static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker * impl) { static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) && if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
SSL_is_init_finished(impl->ssl)) { SSL_is_init_finished(impl->ssl)) {
impl->result = TSI_OK; impl->result = TSI_OK;
} }
return impl->result; return impl->result;
} }
static tsi_result ssl_handshaker_process_bytes_from_peer( static tsi_result ssl_handshaker_process_bytes_from_peer(
tsi_ssl_handshaker * impl, const unsigned char* bytes, tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size) {
size_t* bytes_size) {
int bytes_written_into_ssl_size = 0; int bytes_written_into_ssl_size = 0;
if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) { if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
return TSI_INVALID_ARGUMENT; return TSI_INVALID_ARGUMENT;
@ -1466,20 +1448,20 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
} }
} }
} }
static void ssl_handshaker_destroy(tsi_handshaker * self) { static void ssl_handshaker_destroy(tsi_handshaker* self) {
tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self); tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
SSL_free(impl->ssl); SSL_free(impl->ssl);
BIO_free(impl->network_io); BIO_free(impl->network_io);
gpr_free(impl->outgoing_bytes_buffer); gpr_free(impl->outgoing_bytes_buffer);
tsi_ssl_handshaker_factory_unref(impl->factory_ref); tsi_ssl_handshaker_factory_unref(impl->factory_ref);
gpr_free(impl); gpr_free(impl);
} }
// Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to // Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to
// |bytes_remaining|. // |bytes_remaining|.
static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker * impl, static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
unsigned char** bytes_remaining, unsigned char** bytes_remaining,
size_t* bytes_remaining_size) { size_t* bytes_remaining_size) {
if (impl == nullptr || bytes_remaining == nullptr || if (impl == nullptr || bytes_remaining == nullptr ||
@ -1494,8 +1476,8 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*bytes_remaining = static_cast<uint8_t*>(gpr_malloc(bytes_in_ssl)); *bytes_remaining = static_cast<uint8_t*>(gpr_malloc(bytes_in_ssl));
int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining, int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining,
static_cast<int>(bytes_in_ssl)); static_cast<int>(bytes_in_ssl));
// If an unexpected number of bytes were read, return an error status and // If an unexpected number of bytes were read, return an error status and free
// free all of the bytes that were read. // all of the bytes that were read.
if (bytes_read < 0 || static_cast<size_t>(bytes_read) != bytes_in_ssl) { if (bytes_read < 0 || static_cast<size_t>(bytes_read) != bytes_in_ssl) {
gpr_log(GPR_ERROR, gpr_log(GPR_ERROR,
"Failed to read the expected number of bytes from SSL object."); "Failed to read the expected number of bytes from SSL object.");
@ -1505,10 +1487,10 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
*bytes_remaining_size = static_cast<size_t>(bytes_read); *bytes_remaining_size = static_cast<size_t>(bytes_read);
return TSI_OK; return TSI_OK;
} }
static tsi_result ssl_handshaker_next( static tsi_result ssl_handshaker_next(
tsi_handshaker * self, const unsigned char* received_bytes, tsi_handshaker* self, const unsigned char* received_bytes,
size_t received_bytes_size, const unsigned char** bytes_to_send, size_t received_bytes_size, const unsigned char** bytes_to_send,
size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result, size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) { tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/) {
@ -1548,9 +1530,9 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*handshaker_result = nullptr; *handshaker_result = nullptr;
} else { } else {
// Any bytes that remain in |impl->ssl|'s read BIO after the handshake is // Any bytes that remain in |impl->ssl|'s read BIO after the handshake is
// complete must be extracted and set to the unused bytes of the // complete must be extracted and set to the unused bytes of the handshaker
// handshaker result. This indicates to the gRPC stack that there are // result. This indicates to the gRPC stack that there are bytes from the
// bytes from the peer that must be processed. // peer that must be processed.
unsigned char* unused_bytes = nullptr; unsigned char* unused_bytes = nullptr;
size_t unused_bytes_size = 0; size_t unused_bytes_size = 0;
status = ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size); status = ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size);
@ -1560,18 +1542,18 @@ static tsi_result tsi_set_min_and_max_tls_versions(
gpr_free(unused_bytes); gpr_free(unused_bytes);
return TSI_INTERNAL_ERROR; return TSI_INTERNAL_ERROR;
} }
status = ssl_handshaker_result_create( status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
impl, unused_bytes, unused_bytes_size, handshaker_result); handshaker_result);
if (status == TSI_OK) { if (status == TSI_OK) {
/* Indicates that the handshake has completed and that a /* Indicates that the handshake has completed and that a handshaker_result
* handshaker_result has been created. */ * has been created. */
self->handshaker_result_created = true; self->handshaker_result_created = true;
} }
} }
return status; return status;
} }
static const tsi_handshaker_vtable handshaker_vtable = { static const tsi_handshaker_vtable handshaker_vtable = {
nullptr, /* get_bytes_to_send_to_peer -- deprecated */ nullptr, /* get_bytes_to_send_to_peer -- deprecated */
nullptr, /* process_bytes_from_peer -- deprecated */ nullptr, /* process_bytes_from_peer -- deprecated */
nullptr, /* get_result -- deprecated */ nullptr, /* get_result -- deprecated */
@ -1580,14 +1562,13 @@ static tsi_result tsi_set_min_and_max_tls_versions(
ssl_handshaker_destroy, ssl_handshaker_destroy,
ssl_handshaker_next, ssl_handshaker_next,
nullptr, /* shutdown */ nullptr, /* shutdown */
}; };
/* --- tsi_ssl_handshaker_factory common methods. --- */ /* --- tsi_ssl_handshaker_factory common methods. --- */
static void tsi_ssl_handshaker_resume_session( static void tsi_ssl_handshaker_resume_session(
SSL * ssl, tsi::SslSessionLRUCache * session_cache) { SSL* ssl, tsi::SslSessionLRUCache* session_cache) {
const char* server_name = const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (server_name == nullptr) { if (server_name == nullptr) {
return; return;
} }
@ -1596,11 +1577,12 @@ static tsi_result tsi_set_min_and_max_tls_versions(
// SSL_set_session internally increments reference counter. // SSL_set_session internally increments reference counter.
SSL_set_session(ssl, session.get()); SSL_set_session(ssl, session.get());
} }
} }
static tsi_result create_tsi_ssl_handshaker( static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
SSL_CTX * ctx, int is_client, const char* server_name_indication, const char* server_name_indication,
tsi_ssl_handshaker_factory* factory, tsi_handshaker** handshaker) { tsi_ssl_handshaker_factory* factory,
tsi_handshaker** handshaker) {
SSL* ssl = SSL_new(ctx); SSL* ssl = SSL_new(ctx);
BIO* network_io = nullptr; BIO* network_io = nullptr;
BIO* ssl_io = nullptr; BIO* ssl_io = nullptr;
@ -1642,8 +1624,7 @@ static tsi_result tsi_set_min_and_max_tls_versions(
ssl_result = SSL_do_handshake(ssl); ssl_result = SSL_do_handshake(ssl);
ssl_result = SSL_get_error(ssl, ssl_result); ssl_result = SSL_get_error(ssl, ssl_result);
if (ssl_result != SSL_ERROR_WANT_READ) { if (ssl_result != SSL_ERROR_WANT_READ) {
gpr_log( gpr_log(GPR_ERROR,
GPR_ERROR,
"Unexpected error received from first SSL_do_handshake call: %s", "Unexpected error received from first SSL_do_handshake call: %s",
ssl_error_string(ssl_result)); ssl_error_string(ssl_result));
SSL_free(ssl); SSL_free(ssl);
@ -1660,18 +1641,20 @@ static tsi_result tsi_set_min_and_max_tls_versions(
impl->result = TSI_HANDSHAKE_IN_PROGRESS; impl->result = TSI_HANDSHAKE_IN_PROGRESS;
impl->outgoing_bytes_buffer_size = impl->outgoing_bytes_buffer_size =
TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE; TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
impl->outgoing_bytes_buffer = static_cast<unsigned char*>( impl->outgoing_bytes_buffer =
gpr_zalloc(impl->outgoing_bytes_buffer_size)); static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
impl->base.vtable = &handshaker_vtable; impl->base.vtable = &handshaker_vtable;
impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory); impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
*handshaker = &impl->base; *handshaker = &impl->base;
return TSI_OK; return TSI_OK;
} }
static int select_protocol_list( static int select_protocol_list(const unsigned char** out,
const unsigned char** out, unsigned char* outlen, unsigned char* outlen,
const unsigned char* client_list, size_t client_list_len, const unsigned char* client_list,
const unsigned char* server_list, size_t server_list_len) { size_t client_list_len,
const unsigned char* server_list,
size_t server_list_len) {
const unsigned char* client_current = client_list; const unsigned char* client_current = client_list;
while (static_cast<unsigned int>(client_current - client_list) < while (static_cast<unsigned int>(client_current - client_list) <
client_list_len) { client_list_len) {
@ -1692,26 +1675,26 @@ static tsi_result tsi_set_min_and_max_tls_versions(
client_current += client_current_len; client_current += client_current_len;
} }
return SSL_TLSEXT_ERR_NOACK; return SSL_TLSEXT_ERR_NOACK;
} }
/* --- tsi_ssl_client_handshaker_factory methods implementation. --- */ /* --- tsi_ssl_client_handshaker_factory methods implementation. --- */
tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
tsi_ssl_client_handshaker_factory * factory, tsi_ssl_client_handshaker_factory* factory,
const char* server_name_indication, tsi_handshaker** handshaker) { const char* server_name_indication, tsi_handshaker** handshaker) {
return create_tsi_ssl_handshaker(factory->ssl_context, 1, return create_tsi_ssl_handshaker(factory->ssl_context, 1,
server_name_indication, &factory->base, server_name_indication, &factory->base,
handshaker); handshaker);
} }
void tsi_ssl_client_handshaker_factory_unref( void tsi_ssl_client_handshaker_factory_unref(
tsi_ssl_client_handshaker_factory * factory) { tsi_ssl_client_handshaker_factory* factory) {
if (factory == nullptr) return; if (factory == nullptr) return;
tsi_ssl_handshaker_factory_unref(&factory->base); tsi_ssl_handshaker_factory_unref(&factory->base);
} }
static void tsi_ssl_client_handshaker_factory_destroy( static void tsi_ssl_client_handshaker_factory_destroy(
tsi_ssl_handshaker_factory * factory) { tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return; if (factory == nullptr) return;
tsi_ssl_client_handshaker_factory* self = tsi_ssl_client_handshaker_factory* self =
reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory); reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
@ -1719,9 +1702,9 @@ static tsi_result tsi_set_min_and_max_tls_versions(
if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list); if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
self->session_cache.reset(); self->session_cache.reset();
gpr_free(self); gpr_free(self);
} }
static int client_handshaker_factory_npn_callback( static int client_handshaker_factory_npn_callback(
SSL* /*ssl*/, unsigned char** out, unsigned char* outlen, SSL* /*ssl*/, unsigned char** out, unsigned char* outlen,
const unsigned char* in, unsigned int inlen, void* arg) { const unsigned char* in, unsigned int inlen, void* arg) {
tsi_ssl_client_handshaker_factory* factory = tsi_ssl_client_handshaker_factory* factory =
@ -1729,28 +1712,27 @@ static tsi_result tsi_set_min_and_max_tls_versions(
return select_protocol_list(const_cast<const unsigned char**>(out), outlen, return select_protocol_list(const_cast<const unsigned char**>(out), outlen,
factory->alpn_protocol_list, factory->alpn_protocol_list,
factory->alpn_protocol_list_length, in, inlen); factory->alpn_protocol_list_length, in, inlen);
} }
/* --- tsi_ssl_server_handshaker_factory methods implementation. --- */ /* --- tsi_ssl_server_handshaker_factory methods implementation. --- */
tsi_result tsi_ssl_server_handshaker_factory_create_handshaker( tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
tsi_ssl_server_handshaker_factory * factory, tsi_ssl_server_handshaker_factory* factory, tsi_handshaker** handshaker) {
tsi_handshaker * *handshaker) {
if (factory->ssl_context_count == 0) return TSI_INVALID_ARGUMENT; if (factory->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
/* Create the handshaker with the first context. We will switch if needed /* Create the handshaker with the first context. We will switch if needed
because of SNI in ssl_server_handshaker_factory_servername_callback. */ because of SNI in ssl_server_handshaker_factory_servername_callback. */
return create_tsi_ssl_handshaker(factory->ssl_contexts[0], 0, nullptr, return create_tsi_ssl_handshaker(factory->ssl_contexts[0], 0, nullptr,
&factory->base, handshaker); &factory->base, handshaker);
} }
void tsi_ssl_server_handshaker_factory_unref( void tsi_ssl_server_handshaker_factory_unref(
tsi_ssl_server_handshaker_factory * factory) { tsi_ssl_server_handshaker_factory* factory) {
if (factory == nullptr) return; if (factory == nullptr) return;
tsi_ssl_handshaker_factory_unref(&factory->base); tsi_ssl_handshaker_factory_unref(&factory->base);
} }
static void tsi_ssl_server_handshaker_factory_destroy( static void tsi_ssl_server_handshaker_factory_destroy(
tsi_ssl_handshaker_factory * factory) { tsi_ssl_handshaker_factory* factory) {
if (factory == nullptr) return; if (factory == nullptr) return;
tsi_ssl_server_handshaker_factory* self = tsi_ssl_server_handshaker_factory* self =
reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory); reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory);
@ -1767,9 +1749,9 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list); if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
gpr_free(self); gpr_free(self);
} }
static int does_entry_match_name(absl::string_view entry, static int does_entry_match_name(absl::string_view entry,
absl::string_view name) { absl::string_view name) {
if (entry.empty()) return 0; if (entry.empty()) return 0;
@ -1808,10 +1790,11 @@ static tsi_result tsi_set_min_and_max_tls_versions(
name_subdomain.remove_suffix(1); name_subdomain.remove_suffix(1);
} }
return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry); return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry);
} }
static int ssl_server_handshaker_factory_servername_callback( static int ssl_server_handshaker_factory_servername_callback(SSL* ssl,
SSL * ssl, int* /*ap*/, void* arg) { int* /*ap*/,
void* arg) {
tsi_ssl_server_handshaker_factory* impl = tsi_ssl_server_handshaker_factory* impl =
static_cast<tsi_ssl_server_handshaker_factory*>(arg); static_cast<tsi_ssl_server_handshaker_factory*>(arg);
size_t i = 0; size_t i = 0;
@ -1829,10 +1812,10 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
gpr_log(GPR_ERROR, "No match found for server name: %s.", servername); gpr_log(GPR_ERROR, "No match found for server name: %s.", servername);
return SSL_TLSEXT_ERR_NOACK; return SSL_TLSEXT_ERR_NOACK;
} }
#if TSI_OPENSSL_ALPN_SUPPORT #if TSI_OPENSSL_ALPN_SUPPORT
static int server_handshaker_factory_alpn_callback( static int server_handshaker_factory_alpn_callback(
SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen, SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen,
const unsigned char* in, unsigned int inlen, void* arg) { const unsigned char* in, unsigned int inlen, void* arg) {
tsi_ssl_server_handshaker_factory* factory = tsi_ssl_server_handshaker_factory* factory =
@ -1840,28 +1823,27 @@ static tsi_result tsi_set_min_and_max_tls_versions(
return select_protocol_list(out, outlen, in, inlen, return select_protocol_list(out, outlen, in, inlen,
factory->alpn_protocol_list, factory->alpn_protocol_list,
factory->alpn_protocol_list_length); factory->alpn_protocol_list_length);
} }
#endif /* TSI_OPENSSL_ALPN_SUPPORT */ #endif /* TSI_OPENSSL_ALPN_SUPPORT */
static int server_handshaker_factory_npn_advertised_callback( static int server_handshaker_factory_npn_advertised_callback(
SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, void* arg) {
void* arg) {
tsi_ssl_server_handshaker_factory* factory = tsi_ssl_server_handshaker_factory* factory =
static_cast<tsi_ssl_server_handshaker_factory*>(arg); static_cast<tsi_ssl_server_handshaker_factory*>(arg);
*out = factory->alpn_protocol_list; *out = factory->alpn_protocol_list;
GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX); GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX);
*outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length); *outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length);
return SSL_TLSEXT_ERR_OK; return SSL_TLSEXT_ERR_OK;
} }
/// This callback is called when new \a session is established and ready to /// This callback is called when new \a session is established and ready to
/// be cached. This session can be reused for new connections to similar /// be cached. This session can be reused for new connections to similar
/// servers at later point of time. /// servers at later point of time.
/// It's intended to be used with SSL_CTX_sess_set_new_cb function. /// It's intended to be used with SSL_CTX_sess_set_new_cb function.
/// ///
/// It returns 1 if callback takes ownership over \a session and 0 otherwise. /// It returns 1 if callback takes ownership over \a session and 0 otherwise.
static int server_handshaker_factory_new_session_callback( static int server_handshaker_factory_new_session_callback(
SSL * ssl, SSL_SESSION * session) { SSL* ssl, SSL_SESSION* session) {
SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl); SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
if (ssl_context == nullptr) { if (ssl_context == nullptr) {
return 0; return 0;
@ -1869,22 +1851,21 @@ static tsi_result tsi_set_min_and_max_tls_versions(
void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index); void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
tsi_ssl_client_handshaker_factory* factory = tsi_ssl_client_handshaker_factory* factory =
static_cast<tsi_ssl_client_handshaker_factory*>(arg); static_cast<tsi_ssl_client_handshaker_factory*>(arg);
const char* server_name = const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (server_name == nullptr) { if (server_name == nullptr) {
return 0; return 0;
} }
factory->session_cache->Put(server_name, tsi::SslSessionPtr(session)); factory->session_cache->Put(server_name, tsi::SslSessionPtr(session));
// Return 1 to indicate transferred ownership over the given session. // Return 1 to indicate transferred ownership over the given session.
return 1; return 1;
} }
/* --- tsi_ssl_handshaker_factory constructors. --- */ /* --- tsi_ssl_handshaker_factory constructors. --- */
static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = { static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
tsi_ssl_client_handshaker_factory_destroy}; tsi_ssl_client_handshaker_factory_destroy};
tsi_result tsi_create_ssl_client_handshaker_factory( tsi_result tsi_create_ssl_client_handshaker_factory(
const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair, const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair,
const char* pem_root_certs, const char* cipher_suites, const char* pem_root_certs, const char* cipher_suites,
const char** alpn_protocols, uint16_t num_alpn_protocols, const char** alpn_protocols, uint16_t num_alpn_protocols,
@ -1897,9 +1878,9 @@ static tsi_result tsi_set_min_and_max_tls_versions(
options.num_alpn_protocols = num_alpn_protocols; options.num_alpn_protocols = num_alpn_protocols;
return tsi_create_ssl_client_handshaker_factory_with_options(&options, return tsi_create_ssl_client_handshaker_factory_with_options(&options,
factory); factory);
} }
tsi_result tsi_create_ssl_client_handshaker_factory_with_options( tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
const tsi_ssl_client_handshaker_options* options, const tsi_ssl_client_handshaker_options* options,
tsi_ssl_client_handshaker_factory** factory) { tsi_ssl_client_handshaker_factory** factory) {
SSL_CTX* ssl_context = nullptr; SSL_CTX* ssl_context = nullptr;
@ -1955,11 +1936,10 @@ static tsi_result tsi_set_min_and_max_tls_versions(
SSL_CTX_set_cert_store(ssl_context, options->root_store->store); SSL_CTX_set_cert_store(ssl_context, options->root_store->store);
} }
#endif #endif
if (OPENSSL_VERSION_NUMBER < 0x10100000 || if (OPENSSL_VERSION_NUMBER < 0x10100000 || options->root_store == nullptr) {
options->root_store == nullptr) {
result = ssl_ctx_load_verification_certs( result = ssl_ctx_load_verification_certs(
ssl_context, options->pem_root_certs, ssl_context, options->pem_root_certs, strlen(options->pem_root_certs),
strlen(options->pem_root_certs), nullptr); nullptr);
if (result != TSI_OK) { if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Cannot load server root certificates."); gpr_log(GPR_ERROR, "Cannot load server root certificates.");
break; break;
@ -2002,12 +1982,12 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*factory = impl; *factory = impl;
return TSI_OK; return TSI_OK;
} }
static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = { static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
tsi_ssl_server_handshaker_factory_destroy}; tsi_ssl_server_handshaker_factory_destroy};
tsi_result tsi_create_ssl_server_handshaker_factory( tsi_result tsi_create_ssl_server_handshaker_factory(
const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs, const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
size_t num_key_cert_pairs, const char* pem_client_root_certs, size_t num_key_cert_pairs, const char* pem_client_root_certs,
int force_client_auth, const char* cipher_suites, int force_client_auth, const char* cipher_suites,
@ -2015,19 +1995,17 @@ static tsi_result tsi_set_min_and_max_tls_versions(
tsi_ssl_server_handshaker_factory** factory) { tsi_ssl_server_handshaker_factory** factory) {
return tsi_create_ssl_server_handshaker_factory_ex( return tsi_create_ssl_server_handshaker_factory_ex(
pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs, pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs,
force_client_auth force_client_auth ? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
: TSI_DONT_REQUEST_CLIENT_CERTIFICATE, : TSI_DONT_REQUEST_CLIENT_CERTIFICATE,
cipher_suites, alpn_protocols, num_alpn_protocols, factory); cipher_suites, alpn_protocols, num_alpn_protocols, factory);
} }
tsi_result tsi_create_ssl_server_handshaker_factory_ex( tsi_result tsi_create_ssl_server_handshaker_factory_ex(
const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs, const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
size_t num_key_cert_pairs, const char* pem_client_root_certs, size_t num_key_cert_pairs, const char* pem_client_root_certs,
tsi_client_certificate_request_type client_certificate_request, tsi_client_certificate_request_type client_certificate_request,
const char* cipher_suites, const char** alpn_protocols, const char* cipher_suites, const char** alpn_protocols,
uint16_t num_alpn_protocols, uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) {
tsi_ssl_server_handshaker_factory** factory) {
tsi_ssl_server_handshaker_options options; tsi_ssl_server_handshaker_options options;
options.pem_key_cert_pairs = pem_key_cert_pairs; options.pem_key_cert_pairs = pem_key_cert_pairs;
options.num_key_cert_pairs = num_key_cert_pairs; options.num_key_cert_pairs = num_key_cert_pairs;
@ -2038,9 +2016,9 @@ static tsi_result tsi_set_min_and_max_tls_versions(
options.num_alpn_protocols = num_alpn_protocols; options.num_alpn_protocols = num_alpn_protocols;
return tsi_create_ssl_server_handshaker_factory_with_options(&options, return tsi_create_ssl_server_handshaker_factory_with_options(&options,
factory); factory);
} }
tsi_result tsi_create_ssl_server_handshaker_factory_with_options( tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
const tsi_ssl_server_handshaker_options* options, const tsi_ssl_server_handshaker_options* options,
tsi_ssl_server_handshaker_factory** factory) { tsi_ssl_server_handshaker_factory** factory) {
tsi_ssl_server_handshaker_factory* impl = nullptr; tsi_ssl_server_handshaker_factory* impl = nullptr;
@ -2149,15 +2127,14 @@ static tsi_result tsi_set_min_and_max_tls_versions(
SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr); SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
break; break;
case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
SSL_CTX_set_verify( SSL_CTX_set_verify(impl->ssl_contexts[i],
impl->ssl_contexts[i],
SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
NullVerifyCallback); NullVerifyCallback);
break; break;
case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
SSL_CTX_set_verify( SSL_CTX_set_verify(impl->ssl_contexts[i],
impl->ssl_contexts[i], SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); nullptr);
break; break;
} }
/* TODO(jboeuf): Add revocation verification. */ /* TODO(jboeuf): Add revocation verification. */
@ -2173,8 +2150,7 @@ static tsi_result tsi_set_min_and_max_tls_versions(
SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl); SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl);
#if TSI_OPENSSL_ALPN_SUPPORT #if TSI_OPENSSL_ALPN_SUPPORT
SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i], SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i],
server_handshaker_factory_alpn_callback, server_handshaker_factory_alpn_callback, impl);
impl);
#endif /* TSI_OPENSSL_ALPN_SUPPORT */ #endif /* TSI_OPENSSL_ALPN_SUPPORT */
SSL_CTX_set_next_protos_advertised_cb( SSL_CTX_set_next_protos_advertised_cb(
impl->ssl_contexts[i], impl->ssl_contexts[i],
@ -2189,11 +2165,11 @@ static tsi_result tsi_set_min_and_max_tls_versions(
*factory = impl; *factory = impl;
return TSI_OK; return TSI_OK;
} }
/* --- tsi_ssl utils. --- */ /* --- tsi_ssl utils. --- */
int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) { int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) {
size_t i = 0; size_t i = 0;
size_t san_count = 0; size_t san_count = 0;
const tsi_peer_property* cn_property = nullptr; const tsi_peer_property* cn_property = nullptr;
@ -2230,17 +2206,16 @@ static tsi_result tsi_set_min_and_max_tls_versions(
} }
return 0; /* Not found. */ return 0; /* Not found. */
} }
/* --- Testing support. --- */ /* --- Testing support. --- */
const tsi_ssl_handshaker_factory_vtable* const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable(
tsi_ssl_handshaker_factory_swap_vtable( tsi_ssl_handshaker_factory* factory,
tsi_ssl_handshaker_factory * factory, tsi_ssl_handshaker_factory_vtable* new_vtable) {
tsi_ssl_handshaker_factory_vtable * new_vtable) {
GPR_ASSERT(factory != nullptr); GPR_ASSERT(factory != nullptr);
GPR_ASSERT(factory->vtable != nullptr); GPR_ASSERT(factory->vtable != nullptr);
const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable; const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable;
factory->vtable = new_vtable; factory->vtable = new_vtable;
return orig_vtable; return orig_vtable;
} }

Loading…
Cancel
Save