diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc index 457696d81..450f7dc0b 100644 --- a/ssl/s3_pkt.cc +++ b/ssl/s3_pkt.cc @@ -112,6 +112,8 @@ #include #include +#include + #include #include #include @@ -138,10 +140,9 @@ int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, return -1; } - unsigned tot, n, nw; - + // TODO(davidben): Switch this logic to |size_t| and |bssl::Span|. assert(ssl->s3->wnum <= INT_MAX); - tot = ssl->s3->wnum; + unsigned tot = ssl->s3->wnum; ssl->s3->wnum = 0; // Ensure that if we end up with a smaller value of data to write out than @@ -159,29 +160,23 @@ int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in, const int is_early_data_write = !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write; - n = len - tot; + unsigned n = len - tot; for (;;) { - // max contains the maximum number of bytes that we can put into a record. - unsigned max = ssl->max_send_fragment; - if (is_early_data_write && - max > ssl->session->ticket_max_early_data - - ssl->s3->hs->early_data_written) { - max = - ssl->session->ticket_max_early_data - ssl->s3->hs->early_data_written; - if (max == 0) { + size_t max_send_fragment = ssl->max_send_fragment; + if (is_early_data_write) { + SSL_HANDSHAKE *hs = ssl->s3->hs.get(); + if (hs->early_data_written >= hs->early_session->ticket_max_early_data) { ssl->s3->wnum = tot; - ssl->s3->hs->can_early_write = false; + hs->can_early_write = false; *out_needs_handshake = true; return -1; } + max_send_fragment = std::min( + max_send_fragment, size_t{hs->early_session->ticket_max_early_data - + hs->early_data_written}); } - if (n > max) { - nw = max; - } else { - nw = n; - } - + const size_t nw = std::min(max_send_fragment, size_t{n}); int ret = do_tls_write(ssl, SSL3_RT_APPLICATION_DATA, &in[tot], nw); if (ret <= 0) { ssl->s3->wnum = tot; diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc index 89313493c..9438c1f11 100644 --- a/ssl/test/bssl_shim.cc +++ b/ssl/test/bssl_shim.cc @@ -851,6 +851,7 @@ static bool DoExchange(bssl::UniquePtr *out_session, int ret; SSL *ssl = ssl_uniqueptr->get(); SSL_CTX *session_ctx = SSL_get_SSL_CTX(ssl); + TestState *test_state = GetTestState(ssl); if (!config->implicit_handshake) { if (config->handoff) { @@ -859,6 +860,7 @@ static bool DoExchange(bssl::UniquePtr *out_session, return false; } ssl = ssl_uniqueptr->get(); + test_state = GetTestState(ssl); #else fprintf(stderr, "The external handshaker can only be used on Linux\n"); return false; @@ -903,9 +905,44 @@ static bool DoExchange(bssl::UniquePtr *out_session, return false; } + if (config->early_write_after_message != 0) { + if (!SSL_in_early_data(ssl) || config->is_server) { + fprintf(stderr, + "-early-write-after-message only works for 0-RTT connections " + "on servers.\n"); + return false; + } + if (!config->shim_writes_first || !config->async) { + fprintf(stderr, + "-early-write-after-message requires -shim-writes-first and " + "-async.\n"); + return false; + } + // Run the handshake until the specified message. Note that, if a + // handshake record contains multiple messages, |SSL_do_handshake| usually + // processes both atomically. The test must ensure there is a record + // boundary after the desired message. Checking |last_message_received| + // confirms this. + do { + ret = SSL_do_handshake(ssl); + } while (test_state->last_message_received != + config->early_write_after_message && + RetryAsync(ssl, ret)); + if (ret == 1) { + fprintf(stderr, "Handshake unexpectedly succeeded.\n"); + return false; + } + if (test_state->last_message_received != + config->early_write_after_message) { + // The handshake failed before we saw the target message. The generic + // error-handling logic in the caller will print the error. + return false; + } + } + // Reset the state to assert later that the callback isn't called in // renegotations. - GetTestState(ssl)->got_new_session = false; + test_state->got_new_session = false; } if (config->export_keying_material > 0) { @@ -1005,7 +1042,7 @@ static bool DoExchange(bssl::UniquePtr *out_session, } // Let only one byte of the record through. - AsyncBioAllowWrite(GetTestState(ssl)->async_bio, 1); + AsyncBioAllowWrite(test_state->async_bio, 1); int write_ret = SSL_write(ssl, kInitialWrite, strlen(kInitialWrite)); if (SSL_get_error(ssl, write_ret) != SSL_ERROR_WANT_WRITE) { @@ -1060,7 +1097,7 @@ static bool DoExchange(bssl::UniquePtr *out_session, // After a successful read, with or without False Start, the handshake // must be complete unless we are doing early data. - if (!GetTestState(ssl)->handshake_done && + if (!test_state->handshake_done && !SSL_early_data_accepted(ssl)) { fprintf(stderr, "handshake was not completed after SSL_read\n"); return false; @@ -1094,7 +1131,7 @@ static bool DoExchange(bssl::UniquePtr *out_session, !config->implicit_handshake && // Session tickets are sent post-handshake in TLS 1.3. GetProtocolVersion(ssl) < TLS1_3_VERSION && - GetTestState(ssl)->got_new_session) { + test_state->got_new_session) { fprintf(stderr, "new session was established after the handshake\n"); return false; } @@ -1102,16 +1139,16 @@ static bool DoExchange(bssl::UniquePtr *out_session, if (GetProtocolVersion(ssl) >= TLS1_3_VERSION && !config->is_server) { bool expect_new_session = !config->expect_no_session && !config->shim_shuts_down; - if (expect_new_session != GetTestState(ssl)->got_new_session) { + if (expect_new_session != test_state->got_new_session) { fprintf(stderr, "new session was%s cached, but we expected the opposite\n", - GetTestState(ssl)->got_new_session ? "" : " not"); + test_state->got_new_session ? "" : " not"); return false; } if (expect_new_session) { bool got_early_data = - GetTestState(ssl)->new_session->ticket_max_early_data != 0; + test_state->new_session->ticket_max_early_data != 0; if (config->expect_ticket_supports_early_data != got_early_data) { fprintf(stderr, "new session did%s support early data, but we expected the " @@ -1123,7 +1160,7 @@ static bool DoExchange(bssl::UniquePtr *out_session, } if (out_session) { - *out_session = std::move(GetTestState(ssl)->new_session); + *out_session = std::move(test_state->new_session); } ret = DoShutdown(ssl); @@ -1172,10 +1209,10 @@ static bool DoExchange(bssl::UniquePtr *out_session, if (config->renegotiate_explicit && SSL_total_renegotiations(ssl) != - GetTestState(ssl)->explicit_renegotiates) { + test_state->explicit_renegotiates) { fprintf(stderr, "Performed %d renegotiations, but triggered %d of them\n", SSL_total_renegotiations(ssl), - GetTestState(ssl)->explicit_renegotiates); + test_state->explicit_renegotiates); return false; } diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index c169d2d05..f802585f1 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -1060,7 +1060,7 @@ func doExchange(test *testCase, config *Config, conn net.Conn, isResume bool, tr shimPrefix = test.resumeShimPrefix } if test.shimWritesFirst || test.readWithUnfinishedWrite { - shimPrefix = "hello" + shimPrefix = shimInitialWrite } if test.renegotiate > 0 { // If readWithUnfinishedWrite is set, the shim prefix will be @@ -1294,6 +1294,10 @@ func translateExpectedError(errorStr string) string { return errorStr } +// shimInitialWrite is the data we expect from the shim when the +// -shim-writes-first flag is used. +const shimInitialWrite = "hello" + func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocNumToFail int64) error { // Help debugging panics on the Go side. defer func() { @@ -1433,7 +1437,7 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN // Configure the shim to send some data in early data. flags = append(flags, "-on-resume-shim-writes-first") if resumeConfig.Bugs.ExpectEarlyData == nil { - resumeConfig.Bugs.ExpectEarlyData = [][]byte{[]byte("hello")} + resumeConfig.Bugs.ExpectEarlyData = [][]byte{[]byte(shimInitialWrite)} } } else { // By default, send some early data and expect half-RTT data response. @@ -4842,10 +4846,10 @@ func addStateMachineCoverageTests(config stateMachineTestConfig) { MinVersion: VersionTLS13, MaxEarlyDataSize: 2, Bugs: ProtocolBugs{ - ExpectEarlyData: [][]byte{{'h', 'e'}}, + ExpectEarlyData: [][]byte{[]byte(shimInitialWrite[:2])}, }, }, - resumeShimPrefix: "llo", + resumeShimPrefix: shimInitialWrite[2:], resumeSession: true, earlyData: true, }) @@ -4865,8 +4869,9 @@ func addStateMachineCoverageTests(config stateMachineTestConfig) { MaxVersion: VersionTLS13, MinVersion: VersionTLS13, Bugs: ProtocolBugs{ + // Write the server response before expecting early data. ExpectEarlyData: [][]byte{}, - ExpectLateEarlyData: [][]byte{{'h', 'e', 'l', 'l', 'o'}}, + ExpectLateEarlyData: [][]byte{[]byte(shimInitialWrite)}, }, }, resumeSession: true, @@ -15147,6 +15152,51 @@ func addTLS13HandshakeTests() { expectedError: ":CIPHER_MISMATCH_ON_EARLY_DATA:", expectedLocalError: "remote error: illegal parameter", }) + + // Test that the client can write early data when it has received a partial + // ServerHello..Finished flight. See https://crbug.com/1208784. Note the + // EncryptedExtensions test assumes EncryptedExtensions and Finished are in + // separate records, i.e. that PackHandshakeFlight is disabled. + testCases = append(testCases, testCase{ + testType: clientTest, + name: "EarlyData-WriteAfterServerHello", + config: Config{ + MinVersion: VersionTLS13, + MaxVersion: VersionTLS13, + Bugs: ProtocolBugs{ + // Write the server response before expecting early data. + ExpectEarlyData: [][]byte{}, + ExpectLateEarlyData: [][]byte{[]byte(shimInitialWrite)}, + }, + }, + resumeSession: true, + earlyData: true, + flags: []string{ + "-async", + "-on-resume-early-write-after-message", + strconv.Itoa(int(typeServerHello)), + }, + }) + testCases = append(testCases, testCase{ + testType: clientTest, + name: "EarlyData-WriteAfterEncryptedExtensions", + config: Config{ + MinVersion: VersionTLS13, + MaxVersion: VersionTLS13, + Bugs: ProtocolBugs{ + // Write the server response before expecting early data. + ExpectEarlyData: [][]byte{}, + ExpectLateEarlyData: [][]byte{[]byte(shimInitialWrite)}, + }, + }, + resumeSession: true, + earlyData: true, + flags: []string{ + "-async", + "-on-resume-early-write-after-message", + strconv.Itoa(int(typeEncryptedExtensions)), + }, + }) } func addTLS13CipherPreferenceTests() { diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index fff536f3e..e933f0f99 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc @@ -235,6 +235,7 @@ const Flag kIntFlags[] = { {"-read-size", &TestConfig::read_size}, {"-expect-ticket-age-skew", &TestConfig::expect_ticket_age_skew}, {"-quic-use-legacy-codepoint", &TestConfig::quic_use_legacy_codepoint}, + {"-early-write-after-message", &TestConfig::early_write_after_message}, }; const Flag> kIntVectorFlags[] = { @@ -599,6 +600,9 @@ static void MessageCallback(int is_write, int version, int content_type, char text[16]; snprintf(text, sizeof(text), "hs %d\n", type); state->msg_callback_text += text; + if (!is_write) { + state->last_message_received = type; + } return; } diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h index f4e3f61a5..f4ddba236 100644 --- a/ssl/test/test_config.h +++ b/ssl/test/test_config.h @@ -190,6 +190,7 @@ struct TestConfig { bool expect_no_hrr = false; bool wait_for_debugger = false; std::string quic_early_data_context; + int early_write_after_message = 0; int argc; char **argv; diff --git a/ssl/test/test_state.h b/ssl/test/test_state.h index 2c558a425..d9fe945ac 100644 --- a/ssl/test/test_state.h +++ b/ssl/test/test_state.h @@ -68,6 +68,7 @@ struct TestState { bool cert_verified = false; int explicit_renegotiates = 0; std::function get_handshake_hints_cb; + int last_message_received = -1; }; bool SetTestState(SSL *ssl, std::unique_ptr state);