From fbf10f0d968beb56622eb4927bace53a0e931189 Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Mon, 12 Feb 2024 16:57:18 -0500 Subject: [PATCH] Make an internal RefCounted base class for libssl This is still a bit more tedious than I'd like, but we've got three of these and I'm about to add a fourth. Add something like Chromium's base class. But where Chromium integrates the base class directly with scoped_refptr (giving a place for a static_assert that you did the subclassing right), we don't quite have that since we need to integrate with the external C API. Instead, use the "passkey" pattern and have RefCounted's protected constructor take a struct that only T can construct. The passkey ensures that only T can construct RefCounted, and the protectedness ensures that T subclassed RefCounted. (I think the latter already comes from the static_cast in DecRef, but may as well.) Change-Id: Icf4cbc7d4168010ee46dfa3a7b0a2e7c20aaf383 Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/66369 Reviewed-by: Bob Beck Commit-Queue: David Benjamin --- ssl/encrypted_client_hello.cc | 12 ++----- ssl/internal.h | 63 +++++++++++++++++++++++++++-------- ssl/ssl_lib.cc | 13 +++----- ssl/ssl_session.cc | 13 +++----- 4 files changed, 62 insertions(+), 39 deletions(-) diff --git a/ssl/encrypted_client_hello.cc b/ssl/encrypted_client_hello.cc index a5492e9a0..8c4a42ce8 100644 --- a/ssl/encrypted_client_hello.cc +++ b/ssl/encrypted_client_hello.cc @@ -1012,18 +1012,12 @@ int SSL_marshal_ech_config(uint8_t **out, size_t *out_len, uint8_t config_id, SSL_ECH_KEYS *SSL_ECH_KEYS_new() { return New(); } -void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) { - CRYPTO_refcount_inc(&keys->references); -} +void SSL_ECH_KEYS_up_ref(SSL_ECH_KEYS *keys) { keys->UpRefInternal(); } void SSL_ECH_KEYS_free(SSL_ECH_KEYS *keys) { - if (keys == nullptr || - !CRYPTO_refcount_dec_and_test_zero(&keys->references)) { - return; + if (keys != nullptr) { + keys->DecRefInternal(); } - - keys->~ssl_ech_keys_st(); - OPENSSL_free(keys); } int SSL_ECH_KEYS_add(SSL_ECH_KEYS *configs, int is_retry_config, diff --git a/ssl/internal.h b/ssl/internal.h index f1d02a0fb..dcc546bd3 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -460,6 +460,48 @@ inline size_t GetAllNames(const char **out, size_t max_out, return fixed_names.size() + objects.size(); } +// RefCounted is a common base for ref-counted types. This is an instance of the +// C++ curiously-recurring template pattern, so a type Foo must subclass +// RefCounted. It additionally must friend RefCounted to allow calling +// the destructor. +template +class RefCounted { + public: + RefCounted(const RefCounted &) = delete; + RefCounted &operator=(const RefCounted &) = delete; + + // These methods are intentionally named differently from `bssl::UpRef` to + // avoid a collision. Only the implementations of `FOO_up_ref` and `FOO_free` + // should call these. + void UpRefInternal() { CRYPTO_refcount_inc(&references_); } + void DecRefInternal() { + if (CRYPTO_refcount_dec_and_test_zero(&references_)) { + Derived *d = static_cast(this); + d->~Derived(); + OPENSSL_free(d); + } + } + + protected: + // Ensure that only `Derived`, which must inherit from `RefCounted`, + // can call the constructor. This catches bugs where someone inherited from + // the wrong base. + class CheckSubClass { + private: + friend Derived; + CheckSubClass() = default; + }; + RefCounted(CheckSubClass) { + static_assert(std::is_base_of::value, + "Derived must subclass RefCounted"); + } + + ~RefCounted() = default; + + private: + CRYPTO_refcount_t references_ = 1; +}; + // Protocol versions. // @@ -3446,7 +3488,7 @@ struct ssl_method_st { const bssl::SSL_X509_METHOD *x509_method; }; -struct ssl_ctx_st { +struct ssl_ctx_st : public bssl::RefCounted { explicit ssl_ctx_st(const SSL_METHOD *ssl_method); ssl_ctx_st(const ssl_ctx_st &) = delete; ssl_ctx_st &operator=(const ssl_ctx_st &) = delete; @@ -3516,8 +3558,6 @@ struct ssl_ctx_st { SSL_SESSION *(*get_session_cb)(SSL *ssl, const uint8_t *data, int len, int *copy) = nullptr; - CRYPTO_refcount_t references = 1; - // if defined, these override the X509_verify_cert() calls int (*app_verify_callback)(X509_STORE_CTX *store_ctx, void *arg) = nullptr; void *app_verify_arg = nullptr; @@ -3754,8 +3794,8 @@ struct ssl_ctx_st { bool aes_hw_override_value : 1; private: + friend RefCounted; ~ssl_ctx_st(); - friend OPENSSL_EXPORT void SSL_CTX_free(SSL_CTX *); }; struct ssl_st { @@ -3847,13 +3887,11 @@ struct ssl_st { bool enable_early_data : 1; }; -struct ssl_session_st { +struct ssl_session_st : public bssl::RefCounted { explicit ssl_session_st(const bssl::SSL_X509_METHOD *method); ssl_session_st(const ssl_session_st &) = delete; ssl_session_st &operator=(const ssl_session_st &) = delete; - CRYPTO_refcount_t references = 1; - // ssl_version is the (D)TLS version that established the session. uint16_t ssl_version = 0; @@ -3996,21 +4034,18 @@ struct ssl_session_st { bssl::Array quic_early_data_context; private: + friend RefCounted; ~ssl_session_st(); - friend OPENSSL_EXPORT void SSL_SESSION_free(SSL_SESSION *); }; -struct ssl_ech_keys_st { - ssl_ech_keys_st() = default; - ssl_ech_keys_st(const ssl_ech_keys_st &) = delete; - ssl_ech_keys_st &operator=(const ssl_ech_keys_st &) = delete; +struct ssl_ech_keys_st : public bssl::RefCounted { + ssl_ech_keys_st() : RefCounted(CheckSubClass()) {} bssl::GrowableArray> configs; - CRYPTO_refcount_t references = 1; private: + friend RefCounted; ~ssl_ech_keys_st() = default; - friend OPENSSL_EXPORT void SSL_ECH_KEYS_free(SSL_ECH_KEYS *); }; #endif // OPENSSL_HEADER_SSL_INTERNAL_H diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index 58b68e675..91741fdf5 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc @@ -523,7 +523,8 @@ static int ssl_session_cmp(const SSL_SESSION *a, const SSL_SESSION *b) { } ssl_ctx_st::ssl_ctx_st(const SSL_METHOD *ssl_method) - : method(ssl_method->method), + : RefCounted(CheckSubClass()), + method(ssl_method->method), x509_method(ssl_method->x509_method), retain_only_sha256_of_client_certs(false), quiet_shutdown(false), @@ -589,18 +590,14 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *method) { } int SSL_CTX_up_ref(SSL_CTX *ctx) { - CRYPTO_refcount_inc(&ctx->references); + ctx->UpRefInternal(); return 1; } void SSL_CTX_free(SSL_CTX *ctx) { - if (ctx == NULL || - !CRYPTO_refcount_dec_and_test_zero(&ctx->references)) { - return; + if (ctx != nullptr) { + ctx->DecRefInternal(); } - - ctx->~ssl_ctx_st(); - OPENSSL_free(ctx); } ssl_st::ssl_st(SSL_CTX *ctx_arg) diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc index 979ac5972..5275b69c4 100644 --- a/ssl/ssl_session.cc +++ b/ssl/ssl_session.cc @@ -935,7 +935,8 @@ BSSL_NAMESPACE_END using namespace bssl; ssl_session_st::ssl_session_st(const SSL_X509_METHOD *method) - : x509_method(method), + : RefCounted(CheckSubClass()), + x509_method(method), extended_master_secret(false), peer_sha256_valid(false), not_resumable(false), @@ -957,18 +958,14 @@ SSL_SESSION *SSL_SESSION_new(const SSL_CTX *ctx) { } int SSL_SESSION_up_ref(SSL_SESSION *session) { - CRYPTO_refcount_inc(&session->references); + session->UpRefInternal(); return 1; } void SSL_SESSION_free(SSL_SESSION *session) { - if (session == NULL || - !CRYPTO_refcount_dec_and_test_zero(&session->references)) { - return; + if (session != nullptr) { + session->DecRefInternal(); } - - session->~ssl_session_st(); - OPENSSL_free(session); } const uint8_t *SSL_SESSION_get_id(const SSL_SESSION *session,