From 4bd32a848308a6a1acd02f62db1710e5ad31e1c8 Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Fri, 14 May 2021 17:06:29 -0400 Subject: [PATCH] Convert more of the SSL write path to size_t and Spans. We still have our <= 0 return values because anything with BIOs tries to preserve BIO_write's error returns. (Maybe we can stop doing this? BIO_read's error return is a little subtle with EOF vs error, but BIO_write's is uninteresting.) But the rest of the logic is size_t-clean and hopefully a little clearer. We still have to support SSL_write's rather goofy calling convention, however. I haven't pushed Spans down into the low-level record construction logic yet. We should probably do that, but there are enough offsets tossed around there that they warrant their own CL. Bug: 507 Change-Id: Ia0c702d1a2d3713e71b0bbfa8d65649d3b20da9b Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47544 Commit-Queue: Bob Beck Reviewed-by: Bob Beck --- ssl/d1_pkt.cc | 35 ++++++------ ssl/internal.h | 39 +++++++++---- ssl/s3_pkt.cc | 147 ++++++++++++++++++++++++++----------------------- ssl/ssl_lib.cc | 16 ++++-- 4 files changed, 135 insertions(+), 102 deletions(-) diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc index b9b0ef931..b86615626 100644 --- a/ssl/d1_pkt.cc +++ b/ssl/d1_pkt.cc @@ -186,8 +186,8 @@ ssl_open_record_t dtls1_open_app_data(SSL *ssl, Span *out, return ssl_open_record_success; } -int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, - int len) { +int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake, + size_t *out_bytes_written, Span in) { assert(!SSL_in_init(ssl)); *out_needs_handshake = false; @@ -196,47 +196,46 @@ int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, return -1; } - if (len > SSL3_RT_MAX_PLAIN_LENGTH) { + // DTLS does not split the input across records. + if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH) { OPENSSL_PUT_ERROR(SSL, SSL_R_DTLS_MESSAGE_TOO_BIG); return -1; } - if (len < 0) { - OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH); - return -1; - } - - if (len == 0) { - return 0; + if (in.empty()) { + *out_bytes_written = 0; + return 1; } - int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in, (size_t)len, + int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in, dtls1_use_current_epoch); if (ret <= 0) { return ret; } - return len; + *out_bytes_written = in.size(); + return 1; } -int dtls1_write_record(SSL *ssl, int type, const uint8_t *in, size_t len, +int dtls1_write_record(SSL *ssl, int type, Span in, enum dtls1_use_epoch_t use_epoch) { SSLBuffer *buf = &ssl->s3->write_buffer; - assert(len <= SSL3_RT_MAX_PLAIN_LENGTH); + assert(in.size() <= SSL3_RT_MAX_PLAIN_LENGTH); // There should never be a pending write buffer in DTLS. One can't write half // a datagram, so the write buffer is always dropped in // |ssl_write_buffer_flush|. assert(buf->empty()); - if (len > SSL3_RT_MAX_PLAIN_LENGTH) { + if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH) { OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); return -1; } size_t ciphertext_len; if (!buf->EnsureCap(ssl_seal_align_prefix_len(ssl), - len + SSL_max_seal_overhead(ssl)) || + in.size() + SSL_max_seal_overhead(ssl)) || !dtls_seal_record(ssl, buf->remaining().data(), &ciphertext_len, - buf->remaining().size(), type, in, len, use_epoch)) { + buf->remaining().size(), type, in.data(), in.size(), + use_epoch)) { buf->Clear(); return -1; } @@ -250,7 +249,7 @@ int dtls1_write_record(SSL *ssl, int type, const uint8_t *in, size_t len, } int dtls1_dispatch_alert(SSL *ssl) { - int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, &ssl->s3->send_alert[0], 2, + int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, ssl->s3->send_alert, dtls1_use_current_epoch); if (ret <= 0) { return ret; diff --git a/ssl/internal.h b/ssl/internal.h index 1a78f6335..0a15ace65 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -2456,8 +2456,13 @@ struct SSL_PROTOCOL_METHOD { ssl_open_record_t (*open_app_data)(SSL *ssl, Span *out, size_t *out_consumed, uint8_t *out_alert, Span in); - int (*write_app_data)(SSL *ssl, bool *out_needs_handshake, const uint8_t *buf, - int len); + // write_app_data encrypts and writes |in| as application data. On success, it + // returns one and sets |*out_bytes_written| to the number of bytes of |in| + // written. Otherwise, it returns <= 0 and sets |*out_needs_handshake| to + // whether the operation failed because the caller needs to drive the + // handshake. + int (*write_app_data)(SSL *ssl, bool *out_needs_handshake, + size_t *out_bytes_written, Span in); int (*dispatch_alert)(SSL *ssl); // init_message begins a new handshake message of type |type|. |cbb| is the // root CBB to be passed into |finish_message|. |*body| is set to a child CBB @@ -2646,11 +2651,23 @@ struct SSL3_STATE { // |read_buffer|. Span pending_app_data; - // partial write - check the numbers match - unsigned int wnum = 0; // number of bytes sent so far - int wpend_tot = 0; // number bytes written - int wpend_type = 0; - const uint8_t *wpend_buf = nullptr; + // unreported_bytes_written is the number of bytes successfully written to the + // transport, but not yet reported to the caller. The next |SSL_write| will + // skip this many bytes from the input. This is used if + // |SSL_MODE_ENABLE_PARTIAL_WRITE| is disabled, in which case |SSL_write| only + // reports bytes written when the full caller input is written. + size_t unreported_bytes_written = 0; + + // pending_write, if |has_pending_write| is true, is the caller-supplied data + // corresponding to the current pending write. This is used to check the + // caller retried with a compatible buffer. + Span pending_write; + + // pending_write_type, if |has_pending_write| is true, is the record type + // for the current pending write. + // + // TODO(davidben): Remove this when alerts are moved out of this write path. + uint8_t pending_write_type = 0; // read_shutdown is the shutdown state for the read half of the connection. enum ssl_shutdown_t read_shutdown = ssl_shutdown_none; @@ -3214,8 +3231,8 @@ ssl_open_record_t tls_open_app_data(SSL *ssl, Span *out, ssl_open_record_t tls_open_change_cipher_spec(SSL *ssl, size_t *out_consumed, uint8_t *out_alert, Span in); -int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *buf, - int len); +int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, + size_t *out_bytes_written, Span in); bool tls_new(SSL *ssl); void tls_free(SSL *ssl); @@ -3248,11 +3265,11 @@ ssl_open_record_t dtls1_open_change_cipher_spec(SSL *ssl, size_t *out_consumed, Span in); int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake, - const uint8_t *buf, int len); + size_t *out_bytes_written, Span in); // dtls1_write_record sends a record. It returns one on success and <= 0 on // error. -int dtls1_write_record(SSL *ssl, int type, const uint8_t *buf, size_t len, +int dtls1_write_record(SSL *ssl, int type, Span in, enum dtls1_use_epoch_t use_epoch); int dtls1_retransmit_outgoing_messages(SSL *ssl); diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc index efe5905e8..bc0d13dd5 100644 --- a/ssl/s3_pkt.cc +++ b/ssl/s3_pkt.cc @@ -126,10 +126,11 @@ BSSL_NAMESPACE_BEGIN -static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len); +static int do_tls_write(SSL *ssl, size_t *out_bytes_written, uint8_t type, + Span in); -int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, - int len) { +int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, + size_t *out_bytes_written, Span in) { assert(ssl_can_write(ssl)); assert(!ssl->s3->aead_write_ctx->is_null_cipher()); @@ -140,32 +141,28 @@ int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, return -1; } - // TODO(davidben): Switch this logic to |size_t| and |bssl::Span|. - assert(ssl->s3->wnum <= INT_MAX); - unsigned tot = ssl->s3->wnum; - - // Ensure that if we end up with a smaller value of data to write out than - // the the original len from a write which didn't complete for non-blocking - // I/O and also somehow ended up avoiding the check for this in - // do_tls_write/SSL_R_BAD_WRITE_RETRY as it must never be possible to end up - // with (len-tot) as a large number that will then promptly send beyond the - // end of the users buffer ... so we trap and report the error in a way the - // user will notice. - if (len < 0 || (size_t)len < tot) { + size_t total_bytes_written = ssl->s3->unreported_bytes_written; + if (in.size() < total_bytes_written) { + // This can happen if the caller disables |SSL_MODE_ENABLE_PARTIAL_WRITE|, + // asks us to write some input of length N, we successfully encrypt M bytes + // and write it, but fail to write the rest. We will report + // |SSL_ERROR_WANT_WRITE|. If the caller then retries with fewer than M + // bytes, we cannot satisfy that request. The caller is required to always + // retry with at least as many bytes as the previous attempt. OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH); return -1; } - const int is_early_data_write = - !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write; + in = in.subspan(total_bytes_written); - unsigned n = len - tot; + const bool is_early_data_write = + !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write; for (;;) { size_t max_send_fragment = ssl->max_send_fragment; if (is_early_data_write) { SSL_HANDSHAKE *hs = ssl->s3->hs.get(); if (hs->early_data_written >= hs->early_session->ticket_max_early_data) { - ssl->s3->wnum = tot; + ssl->s3->unreported_bytes_written = total_bytes_written; hs->can_early_write = false; *out_needs_handshake = true; return -1; @@ -175,35 +172,43 @@ int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, hs->early_data_written}); } - const size_t nw = std::min(max_send_fragment, size_t{n}); - int ret = do_tls_write(ssl, SSL3_RT_APPLICATION_DATA, &in[tot], nw); + const size_t to_write = std::min(max_send_fragment, in.size()); + size_t bytes_written; + int ret = do_tls_write(ssl, &bytes_written, SSL3_RT_APPLICATION_DATA, + in.subspan(0, to_write)); if (ret <= 0) { - ssl->s3->wnum = tot; + ssl->s3->unreported_bytes_written = total_bytes_written; return ret; } + // Note |bytes_written| may be less than |to_write| if there was a pending + // record from a smaller write attempt. + assert(bytes_written <= to_write); + total_bytes_written += bytes_written; + in = in.subspan(bytes_written); if (is_early_data_write) { - ssl->s3->hs->early_data_written += ret; + ssl->s3->hs->early_data_written += bytes_written; } - if (ret == (int)n || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) { - ssl->s3->wnum = 0; - return tot + ret; + if (in.empty() || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) { + ssl->s3->unreported_bytes_written = 0; + *out_bytes_written = total_bytes_written; + return 1; } - - n -= ret; - tot += ret; } } -// do_tls_write writes an SSL record of the given type. -static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { +// do_tls_write writes an SSL record of the given type. On success, it sets +// |*out_bytes_written| to number of bytes successfully written and returns one. +// On error, it returns a value <= 0 from the underlying |BIO|. +static int do_tls_write(SSL *ssl, size_t *out_bytes_written, uint8_t type, + Span in) { // If there is a pending write, the retry must be consistent. - if (ssl->s3->wpend_tot > 0 && - (ssl->s3->wpend_tot > (int)len || + if (!ssl->s3->pending_write.empty() && + (ssl->s3->pending_write.size() > in.size() || (!(ssl->mode & SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER) && - ssl->s3->wpend_buf != in) || - ssl->s3->wpend_type != type)) { + ssl->s3->pending_write.data() != in.data()) || + ssl->s3->pending_write_type != type)) { OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_WRITE_RETRY); return -1; } @@ -216,15 +221,14 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { } // If there is a pending write, we just completed it. Report it to the caller. - if (ssl->s3->wpend_tot > 0) { - ret = ssl->s3->wpend_tot; - ssl->s3->wpend_buf = nullptr; - ssl->s3->wpend_tot = 0; - return ret; + if (!ssl->s3->pending_write.empty()) { + *out_bytes_written = ssl->s3->pending_write.size(); + ssl->s3->pending_write = {}; + return 1; } SSLBuffer *buf = &ssl->s3->write_buffer; - if (len > SSL3_RT_MAX_PLAIN_LENGTH || buf->size() > 0) { + if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH || buf->size() > 0) { OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); return -1; } @@ -233,16 +237,22 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { return -1; } - size_t flight_len = 0; + // We may have unflushed handshake data that must be written before |in|. This + // may be a KeyUpdate acknowledgment, 0-RTT key change messages, or a + // NewSessionTicket. + Span pending_flight; if (ssl->s3->pending_flight != nullptr) { - flight_len = - ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset; + pending_flight = MakeConstSpan( + reinterpret_cast(ssl->s3->pending_flight->data), + ssl->s3->pending_flight->length); + pending_flight = pending_flight.subspan(ssl->s3->pending_flight_offset); } - size_t max_out = flight_len; - if (len > 0) { - const size_t max_ciphertext_len = len + SSL_max_seal_overhead(ssl); - if (max_ciphertext_len < len || max_out + max_ciphertext_len < max_out) { + size_t max_out = pending_flight.size(); + if (!in.empty()) { + const size_t max_ciphertext_len = in.size() + SSL_max_seal_overhead(ssl); + if (max_ciphertext_len < in.size() || + max_out + max_ciphertext_len < max_out) { OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); return -1; } @@ -250,31 +260,29 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { } if (max_out == 0) { - return 0; + // Nothing to write. + *out_bytes_written = 0; + return 1; } - if (!buf->EnsureCap(flight_len + ssl_seal_align_prefix_len(ssl), max_out)) { + if (!buf->EnsureCap(pending_flight.size() + ssl_seal_align_prefix_len(ssl), + max_out)) { return -1; } - // Add any unflushed handshake data as a prefix. This may be a KeyUpdate - // acknowledgment or 0-RTT key change messages. |pending_flight| must be clear - // when data is added to |write_buffer| or it will be written in the wrong - // order. - if (ssl->s3->pending_flight != nullptr) { - OPENSSL_memcpy( - buf->remaining().data(), - ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset, - flight_len); + // Copy |pending_flight| to the output. + if (!pending_flight.empty()) { + OPENSSL_memcpy(buf->remaining().data(), pending_flight.data(), + pending_flight.size()); ssl->s3->pending_flight.reset(); ssl->s3->pending_flight_offset = 0; - buf->DidWrite(flight_len); + buf->DidWrite(pending_flight.size()); } - if (len > 0) { + if (!in.empty()) { size_t ciphertext_len; if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len, - buf->remaining().size(), type, in, len)) { + buf->remaining().size(), type, in.data(), in.size())) { return -1; } buf->DidWrite(ciphertext_len); @@ -288,15 +296,15 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { ret = ssl_write_buffer_flush(ssl); if (ret <= 0) { // Track the unfinished write. - if (len > 0) { - ssl->s3->wpend_tot = len; - ssl->s3->wpend_buf = in; - ssl->s3->wpend_type = type; + if (!in.empty()) { + ssl->s3->pending_write = in; + ssl->s3->pending_write_type = type; } return ret; } - return len; + *out_bytes_written = in.size(); + return 1; } ssl_open_record_t tls_open_app_data(SSL *ssl, Span *out, @@ -434,10 +442,13 @@ int tls_dispatch_alert(SSL *ssl) { return 0; } } else { - int ret = do_tls_write(ssl, SSL3_RT_ALERT, &ssl->s3->send_alert[0], 2); + size_t bytes_written; + int ret = + do_tls_write(ssl, &bytes_written, SSL3_RT_ALERT, ssl->s3->send_alert); if (ret <= 0) { return ret; } + assert(bytes_written == 2); } ssl->s3->alert_dispatch = false; diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index 703574847..4d3ad446e 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc @@ -1058,6 +1058,7 @@ int SSL_write(SSL *ssl, const void *buf, int num) { } int ret = 0; + size_t bytes_written = 0; bool needs_handshake = false; do { // If necessary, complete the handshake implicitly. @@ -1072,10 +1073,16 @@ int SSL_write(SSL *ssl, const void *buf, int num) { } } - ret = ssl->method->write_app_data(ssl, &needs_handshake, - (const uint8_t *)buf, num); + if (num < 0) { + OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH); + return -1; + } + ret = ssl->method->write_app_data( + ssl, &needs_handshake, &bytes_written, + MakeConstSpan(static_cast(buf), + static_cast(num))); } while (needs_handshake); - return ret; + return ret <= 0 ? ret : static_cast(bytes_written); } int SSL_key_update(SSL *ssl, int request_type) { @@ -1239,8 +1246,7 @@ void SSL_reset_early_data_reject(SSL *ssl) { // Discard any unfinished writes from the perspective of |SSL_write|'s // retry. The handshake will transparently flush out the pending record // (discarded by the server) to keep the framing correct. - ssl->s3->wpend_buf = nullptr; - ssl->s3->wpend_tot = 0; + ssl->s3->pending_write = {}; } enum ssl_early_data_reason_t SSL_get_early_data_reason(const SSL *ssl) {