@ -126,10 +126,11 @@
BSSL_NAMESPACE_BEGIN
static int do_tls_write ( SSL * ssl , int type , const uint8_t * in , unsigned len ) ;
static int do_tls_write ( SSL * ssl , size_t * out_bytes_written , uint8_t type ,
Span < const uint8_t > in ) ;
int tls_write_app_data ( SSL * ssl , bool * out_needs_handshake , const uint8_t * in ,
int le n) {
int tls_write_app_data ( SSL * ssl , bool * out_needs_handshake ,
size_t * out_bytes_written , Span < const uint8_t > i n) {
assert ( ssl_can_write ( ssl ) ) ;
assert ( ! ssl - > s3 - > aead_write_ctx - > is_null_cipher ( ) ) ;
@ -140,32 +141,28 @@ int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in,
return - 1 ;
}
// TODO(davidben): Switch this logic to |size_t| and |bssl::Span|.
assert ( ssl - > s3 - > wnum < = INT_MAX ) ;
unsigned tot = ssl - > s3 - > wnum ;
// Ensure that if we end up with a smaller value of data to write out than
// the the original len from a write which didn't complete for non-blocking
// I/O and also somehow ended up avoiding the check for this in
// do_tls_write/SSL_R_BAD_WRITE_RETRY as it must never be possible to end up
// with (len-tot) as a large number that will then promptly send beyond the
// end of the users buffer ... so we trap and report the error in a way the
// user will notice.
if ( len < 0 | | ( size_t ) len < tot ) {
size_t total_bytes_written = ssl - > s3 - > unreported_bytes_written ;
if ( in . size ( ) < total_bytes_written ) {
// This can happen if the caller disables |SSL_MODE_ENABLE_PARTIAL_WRITE|,
// asks us to write some input of length N, we successfully encrypt M bytes
// and write it, but fail to write the rest. We will report
// |SSL_ERROR_WANT_WRITE|. If the caller then retries with fewer than M
// bytes, we cannot satisfy that request. The caller is required to always
// retry with at least as many bytes as the previous attempt.
OPENSSL_PUT_ERROR ( SSL , SSL_R_BAD_LENGTH ) ;
return - 1 ;
}
const int is_early_data_write =
! ssl - > server & & SSL_in_early_data ( ssl ) & & ssl - > s3 - > hs - > can_early_write ;
in = in . subspan ( total_bytes_written ) ;
unsigned n = len - tot ;
const bool is_early_data_write =
! ssl - > server & & SSL_in_early_data ( ssl ) & & ssl - > s3 - > hs - > can_early_write ;
for ( ; ; ) {
size_t max_send_fragment = ssl - > max_send_fragment ;
if ( is_early_data_write ) {
SSL_HANDSHAKE * hs = ssl - > s3 - > hs . get ( ) ;
if ( hs - > early_data_written > = hs - > early_session - > ticket_max_early_data ) {
ssl - > s3 - > wnum = tot ;
ssl - > s3 - > unreported_bytes_ writte n = total_bytes_written ;
hs - > can_early_write = false ;
* out_needs_handshake = true ;
return - 1 ;
@ -175,35 +172,43 @@ int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in,
hs - > early_data_written } ) ;
}
const size_t nw = std : : min ( max_send_fragment , size_t { n } ) ;
int ret = do_tls_write ( ssl , SSL3_RT_APPLICATION_DATA , & in [ tot ] , nw ) ;
const size_t to_write = std : : min ( max_send_fragment , in . size ( ) ) ;
size_t bytes_written ;
int ret = do_tls_write ( ssl , & bytes_written , SSL3_RT_APPLICATION_DATA ,
in . subspan ( 0 , to_write ) ) ;
if ( ret < = 0 ) {
ssl - > s3 - > wnum = tot ;
ssl - > s3 - > unreported_bytes_ writte n = total_bytes_written ;
return ret ;
}
// Note |bytes_written| may be less than |to_write| if there was a pending
// record from a smaller write attempt.
assert ( bytes_written < = to_write ) ;
total_bytes_written + = bytes_written ;
in = in . subspan ( bytes_written ) ;
if ( is_early_data_write ) {
ssl - > s3 - > hs - > early_data_written + = ret ;
ssl - > s3 - > hs - > early_data_written + = bytes_written ;
}
if ( ret = = ( int ) n | | ( ssl - > mode & SSL_MODE_ENABLE_PARTIAL_WRITE ) ) {
ssl - > s3 - > wnum = 0 ;
return tot + ret ;
if ( in . empty ( ) | | ( ssl - > mode & SSL_MODE_ENABLE_PARTIAL_WRITE ) ) {
ssl - > s3 - > unreported_bytes_written = 0 ;
* out_bytes_written = total_bytes_written ;
return 1 ;
}
n - = ret ;
tot + = ret ;
}
}
// do_tls_write writes an SSL record of the given type.
static int do_tls_write ( SSL * ssl , int type , const uint8_t * in , unsigned len ) {
// do_tls_write writes an SSL record of the given type. On success, it sets
// |*out_bytes_written| to number of bytes successfully written and returns one.
// On error, it returns a value <= 0 from the underlying |BIO|.
static int do_tls_write ( SSL * ssl , size_t * out_bytes_written , uint8_t type ,
Span < const uint8_t > in ) {
// If there is a pending write, the retry must be consistent.
if ( ssl - > s3 - > wpend_tot > 0 & &
( ssl - > s3 - > wpend_tot > ( int ) len | |
if ( ! ssl - > s3 - > pending_write . empty ( ) & &
( ssl - > s3 - > pending_write . size ( ) > in . size ( ) | |
( ! ( ssl - > mode & SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER ) & &
ssl - > s3 - > wpend_buf ! = in ) | |
ssl - > s3 - > w pend_type ! = type ) ) {
ssl - > s3 - > pending_write . data ( ) ! = in . data ( ) ) | |
ssl - > s3 - > pending_write _type ! = type ) ) {
OPENSSL_PUT_ERROR ( SSL , SSL_R_BAD_WRITE_RETRY ) ;
return - 1 ;
}
@ -216,15 +221,14 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) {
}
// If there is a pending write, we just completed it. Report it to the caller.
if ( ssl - > s3 - > wpend_tot > 0 ) {
ret = ssl - > s3 - > wpend_tot ;
ssl - > s3 - > wpend_buf = nullptr ;
ssl - > s3 - > wpend_tot = 0 ;
return ret ;
if ( ! ssl - > s3 - > pending_write . empty ( ) ) {
* out_bytes_written = ssl - > s3 - > pending_write . size ( ) ;
ssl - > s3 - > pending_write = { } ;
return 1 ;
}
SSLBuffer * buf = & ssl - > s3 - > write_buffer ;
if ( l en > SSL3_RT_MAX_PLAIN_LENGTH | | buf - > size ( ) > 0 ) {
if ( in . siz e( ) > SSL3_RT_MAX_PLAIN_LENGTH | | buf - > size ( ) > 0 ) {
OPENSSL_PUT_ERROR ( SSL , ERR_R_INTERNAL_ERROR ) ;
return - 1 ;
}
@ -233,16 +237,22 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) {
return - 1 ;
}
size_t flight_len = 0 ;
// We may have unflushed handshake data that must be written before |in|. This
// may be a KeyUpdate acknowledgment, 0-RTT key change messages, or a
// NewSessionTicket.
Span < const uint8_t > pending_flight ;
if ( ssl - > s3 - > pending_flight ! = nullptr ) {
flight_len =
ssl - > s3 - > pending_flight - > length - ssl - > s3 - > pending_flight_offset ;
pending_flight = MakeConstSpan (
reinterpret_cast < const uint8_t * > ( ssl - > s3 - > pending_flight - > data ) ,
ssl - > s3 - > pending_flight - > length ) ;
pending_flight = pending_flight . subspan ( ssl - > s3 - > pending_flight_offset ) ;
}
size_t max_out = flight_len ;
if ( len > 0 ) {
const size_t max_ciphertext_len = len + SSL_max_seal_overhead ( ssl ) ;
if ( max_ciphertext_len < len | | max_out + max_ciphertext_len < max_out ) {
size_t max_out = pending_flight . size ( ) ;
if ( ! in . empty ( ) ) {
const size_t max_ciphertext_len = in . size ( ) + SSL_max_seal_overhead ( ssl ) ;
if ( max_ciphertext_len < in . size ( ) | |
max_out + max_ciphertext_len < max_out ) {
OPENSSL_PUT_ERROR ( SSL , ERR_R_OVERFLOW ) ;
return - 1 ;
}
@ -250,31 +260,29 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) {
}
if ( max_out = = 0 ) {
return 0 ;
// Nothing to write.
* out_bytes_written = 0 ;
return 1 ;
}
if ( ! buf - > EnsureCap ( flight_len + ssl_seal_align_prefix_len ( ssl ) , max_out ) ) {
if ( ! buf - > EnsureCap ( pending_flight . size ( ) + ssl_seal_align_prefix_len ( ssl ) ,
max_out ) ) {
return - 1 ;
}
// Add any unflushed handshake data as a prefix. This may be a KeyUpdate
// acknowledgment or 0-RTT key change messages. |pending_flight| must be clear
// when data is added to |write_buffer| or it will be written in the wrong
// order.
if ( ssl - > s3 - > pending_flight ! = nullptr ) {
OPENSSL_memcpy (
buf - > remaining ( ) . data ( ) ,
ssl - > s3 - > pending_flight - > data + ssl - > s3 - > pending_flight_offset ,
flight_len ) ;
// Copy |pending_flight| to the output.
if ( ! pending_flight . empty ( ) ) {
OPENSSL_memcpy ( buf - > remaining ( ) . data ( ) , pending_flight . data ( ) ,
pending_flight . size ( ) ) ;
ssl - > s3 - > pending_flight . reset ( ) ;
ssl - > s3 - > pending_flight_offset = 0 ;
buf - > DidWrite ( flight_len ) ;
buf - > DidWrite ( pending_flight . size ( ) ) ;
}
if ( len > 0 ) {
if ( ! in . empty ( ) ) {
size_t ciphertext_len ;
if ( ! tls_seal_record ( ssl , buf - > remaining ( ) . data ( ) , & ciphertext_len ,
buf - > remaining ( ) . size ( ) , type , in , l en ) ) {
buf - > remaining ( ) . size ( ) , type , in . data ( ) , in . siz e( ) ) ) {
return - 1 ;
}
buf - > DidWrite ( ciphertext_len ) ;
@ -288,15 +296,15 @@ static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) {
ret = ssl_write_buffer_flush ( ssl ) ;
if ( ret < = 0 ) {
// Track the unfinished write.
if ( len > 0 ) {
ssl - > s3 - > wpend_tot = len ;
ssl - > s3 - > wpend_buf = in ;
ssl - > s3 - > wpend_type = type ;
if ( ! in . empty ( ) ) {
ssl - > s3 - > pending_write = in ;
ssl - > s3 - > pending_write_type = type ;
}
return ret ;
}
return len ;
* out_bytes_written = in . size ( ) ;
return 1 ;
}
ssl_open_record_t tls_open_app_data ( SSL * ssl , Span < uint8_t > * out ,
@ -434,10 +442,13 @@ int tls_dispatch_alert(SSL *ssl) {
return 0 ;
}
} else {
int ret = do_tls_write ( ssl , SSL3_RT_ALERT , & ssl - > s3 - > send_alert [ 0 ] , 2 ) ;
size_t bytes_written ;
int ret =
do_tls_write ( ssl , & bytes_written , SSL3_RT_ALERT , ssl - > s3 - > send_alert ) ;
if ( ret < = 0 ) {
return ret ;
}
assert ( bytes_written = = 2 ) ;
}
ssl - > s3 - > alert_dispatch = false ;