diff --git a/ssl/test/mock_quic_transport.cc b/ssl/test/mock_quic_transport.cc index 6a3f0e8ac..45d664a65 100644 --- a/ssl/test/mock_quic_transport.cc +++ b/ssl/test/mock_quic_transport.cc @@ -73,47 +73,97 @@ bool ReadAll(BIO *bio, bssl::Span out) { return true; } +const char *LevelToString(ssl_encryption_level_t level) { + switch (level) { + case ssl_encryption_initial: + return "initial"; + case ssl_encryption_early_data: + return "early_data"; + case ssl_encryption_handshake: + return "handshake"; + case ssl_encryption_application: + return "application"; + } + return ""; +} + } // namespace -bool MockQuicTransport::ReadHeader(uint8_t *out_tag, size_t *out_len) { - uint8_t header[7]; - if (!ReadAll(bio_.get(), header)) { - return false; - } - *out_tag = header[0]; - uint16_t cipher_suite = header[1] << 8 | header[2]; - size_t remaining_bytes = - header[3] << 24 | header[4] << 16 | header[5] << 8 | header[6]; - - enum ssl_encryption_level_t level = SSL_quic_read_level(ssl_); - if (*out_tag == kTagApplication) { - if (SSL_in_early_data(ssl_)) { - level = ssl_encryption_early_data; - } else { - level = ssl_encryption_application; +bool MockQuicTransport::ReadHeader(uint8_t *out_tag, + enum ssl_encryption_level_t *out_level, + size_t *out_len) { + for (;;) { + uint8_t header[8]; + if (!ReadAll(bio_.get(), header)) { + // TODO(davidben): Distinguish between errors and EOF. See + // ReadApplicationData. + return false; } + + CBS cbs; + uint8_t level_id; + uint16_t cipher_suite; + uint32_t remaining_bytes; + CBS_init(&cbs, header, sizeof(header)); + if (!CBS_get_u8(&cbs, out_tag) || + !CBS_get_u8(&cbs, &level_id) || + !CBS_get_u16(&cbs, &cipher_suite) || + !CBS_get_u32(&cbs, &remaining_bytes) || + level_id >= read_levels_.size()) { + fprintf(stderr, "Error parsing record header.\n"); + return false; + } + + auto level = static_cast(level_id); + // Non-initial levels must be configured before use. + uint16_t expect_cipher = read_levels_[level].cipher; + if (expect_cipher == 0 && level != ssl_encryption_initial) { + if (level == ssl_encryption_early_data) { + // If we receive early data records without any early data keys, skip + // the record. This means early data was rejected. + std::vector discard(remaining_bytes); + if (!ReadAll(bio_.get(), bssl::MakeSpan(discard))) { + return false; + } + continue; + } + fprintf(stderr, + "Got record at level %s, but keys were not configured.\n", + LevelToString(level)); + return false; + } + if (cipher_suite != expect_cipher) { + fprintf(stderr, "Got cipher suite 0x%04x at level %s, wanted 0x%04x.\n", + cipher_suite, LevelToString(level), expect_cipher); + return false; + } + const std::vector &secret = read_levels_[level].secret; + std::vector read_secret(secret.size()); + if (remaining_bytes < secret.size()) { + fprintf(stderr, "Record at level %s too small.\n", LevelToString(level)); + return false; + } + remaining_bytes -= secret.size(); + if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret))) { + fprintf(stderr, "Error reading record secret.\n"); + return false; + } + if (read_secret != secret) { + fprintf(stderr, "Encryption secret at level %s did not match.\n", + LevelToString(level)); + return false; + } + *out_level = level; + *out_len = remaining_bytes; + return true; } - if (cipher_suite != read_levels_[level].cipher) { - return false; - } - const std::vector &secret = read_levels_[level].secret; - std::vector read_secret(secret.size()); - if (remaining_bytes < secret.size()) { - return false; - } - remaining_bytes -= secret.size(); - if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret)) || - read_secret != secret) { - return false; - } - *out_len = remaining_bytes; - return true; } bool MockQuicTransport::ReadHandshake() { uint8_t tag; + ssl_encryption_level_t level; size_t len; - if (!ReadHeader(&tag, &len)) { + if (!ReadHeader(&tag, &level, &len)) { return false; } if (tag != kTagHandshake) { @@ -124,8 +174,7 @@ bool MockQuicTransport::ReadHandshake() { if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) { return false; } - return SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(), - buf.size()); + return SSL_provide_quic_data(ssl_, level, buf.data(), buf.size()); } int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) { @@ -144,9 +193,10 @@ int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) { } uint8_t tag = 0; + ssl_encryption_level_t level; size_t len; while (true) { - if (!ReadHeader(&tag, &len)) { + if (!ReadHeader(&tag, &level, &len)) { // Assume that a failure to read the header means there's no more to read, // not an error reading. return 0; @@ -162,8 +212,7 @@ int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) { if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) { return -1; } - if (SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(), - buf.size()) != 1) { + if (SSL_provide_quic_data(ssl_, level, buf.data(), buf.size()) != 1) { return -1; } if (SSL_in_init(ssl_)) { @@ -203,14 +252,15 @@ bool MockQuicTransport::WriteRecord(enum ssl_encryption_level_t level, uint16_t cipher_suite = write_levels_[level].cipher; const std::vector &secret = write_levels_[level].secret; size_t tlv_len = secret.size() + len; - uint8_t header[7]; + uint8_t header[8]; header[0] = tag; - header[1] = (cipher_suite >> 8) & 0xff; - header[2] = cipher_suite & 0xff; - header[3] = (tlv_len >> 24) & 0xff; - header[4] = (tlv_len >> 16) & 0xff; - header[5] = (tlv_len >> 8) & 0xff; - header[6] = tlv_len & 0xff; + header[1] = level; + header[2] = (cipher_suite >> 8) & 0xff; + header[3] = cipher_suite & 0xff; + header[4] = (tlv_len >> 24) & 0xff; + header[5] = (tlv_len >> 16) & 0xff; + header[6] = (tlv_len >> 8) & 0xff; + header[7] = tlv_len & 0xff; return BIO_write_all(bio_.get(), header, sizeof(header)) && BIO_write_all(bio_.get(), secret.data(), secret.size()) && BIO_write_all(bio_.get(), data, len); diff --git a/ssl/test/mock_quic_transport.h b/ssl/test/mock_quic_transport.h index a56652dbf..114f059f4 100644 --- a/ssl/test/mock_quic_transport.h +++ b/ssl/test/mock_quic_transport.h @@ -45,10 +45,12 @@ class MockQuicTransport { // Reads a record header from |bio_| and returns whether the record was read // successfully. As part of reading the header, this function checks that the // cipher suite and secret in the header are correct. On success, the tag - // indicating the TLS record type is put in |*out_tag|, the length of the TLS - // record is put in |*out_len|, and the next thing to be read from |bio_| is - // |*out_len| bytes of the TLS record. - bool ReadHeader(uint8_t *out_tag, size_t *out_len); + // indicating the TLS record type is put in |*out_tag|, the encryption level + // is put in |*out_level|, the length of the TLS record is put in |*out_len|, + // and the next thing to be read from |bio_| is |*out_len| bytes of the TLS + // record. + bool ReadHeader(uint8_t *out_tag, enum ssl_encryption_level_t *out_level, + size_t *out_len); // Writes a MockQuicTransport record to |bio_| at encryption level |level| // with record type |tag| and a TLS record payload of length |len| from diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index c0c91d29b..9fa5c0508 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -754,7 +754,7 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) { return b, bb } -func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []byte) error { +func (c *Conn) useInTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) error { if c.hand.Len() != 0 { return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change")) } @@ -763,6 +763,7 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b side = clientWrite } if c.config.Bugs.MockQUICTransport != nil { + c.config.Bugs.MockQUICTransport.readLevel = level c.config.Bugs.MockQUICTransport.readSecret = secret c.config.Bugs.MockQUICTransport.readCipherSuite = suite.id } @@ -771,12 +772,13 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b return nil } -func (c *Conn) useOutTrafficSecret(version uint16, suite *cipherSuite, secret []byte) { +func (c *Conn) useOutTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) { side := serverWrite if c.isClient { side = clientWrite } if c.config.Bugs.MockQUICTransport != nil { + c.config.Bugs.MockQUICTransport.writeLevel = level c.config.Bugs.MockQUICTransport.writeSecret = secret c.config.Bugs.MockQUICTransport.writeCipherSuite = suite.id } @@ -1677,7 +1679,7 @@ func (c *Conn) handlePostHandshakeMessage() error { if c.config.Bugs.RejectUnsolicitedKeyUpdate { return errors.New("tls: unexpected KeyUpdate message") } - if err := c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil { + if err := c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil { return err } if keyUpdate.keyUpdateRequest == keyUpdateRequested { @@ -1711,7 +1713,7 @@ func (c *Conn) ReadKeyUpdateACK() error { return errors.New("tls: received invalid KeyUpdate message") } - return c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)) + return c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)) } func (c *Conn) Renegotiate() error { @@ -2065,7 +2067,7 @@ func (c *Conn) sendKeyUpdateLocked(keyUpdateRequest byte) error { if err := c.flushHandshake(); err != nil { return err } - c.useOutTrafficSecret(c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret)) + c.useOutTrafficSecret(encryptionApplication, c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret)) return nil } diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index bf89c0163..ad01f1e22 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go @@ -518,7 +518,7 @@ NextCipherSuite: earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel) c.earlyExporterSecret = finishedHash.deriveSecret(earlyExporterLabel) - c.useOutTrafficSecret(session.wireVersion, pskCipherSuite, earlyTrafficSecret) + c.useOutTrafficSecret(encryptionEarlyData, session.wireVersion, pskCipherSuite, earlyTrafficSecret) for _, earlyData := range c.config.Bugs.SendEarlyData { if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil { return err @@ -923,7 +923,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { // traffic key. clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel) serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel) - if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil { + if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil { return err } @@ -1098,7 +1098,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { // Switch to application data keys on read. In particular, any alerts // from the client certificate are read over these keys. - if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret); err != nil { + if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret); err != nil { return err } @@ -1133,7 +1133,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { // Send EndOfEarlyData and then switch write key to handshake // traffic key. - if encryptedExtensions.extensions.hasEarlyData && c.out.cipher != nil && !c.config.Bugs.SkipEndOfEarlyData { + if encryptedExtensions.extensions.hasEarlyData && !c.config.Bugs.SkipEndOfEarlyData && c.config.Bugs.MockQUICTransport == nil { if c.config.Bugs.SendStrayEarlyHandshake { helloRequest := new(helloRequestMsg) c.writeRecord(recordTypeHandshake, helloRequest.marshal()) @@ -1157,7 +1157,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) } - c.useOutTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret) + c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret) // The client EncryptedExtensions message is sent if some extension uses it. // (Currently only ALPS does.) @@ -1263,7 +1263,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { c.flushHandshake() // Switch to application data keys. - c.useOutTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret) + c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret) c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel) for _, ticket := range deferredTickets { if err := c.processTLS13NewSessionTicket(ticket, hs.suite); err != nil { diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index df74ccd9d..3cdebefd9 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go @@ -747,7 +747,7 @@ ResendHelloRetryRequest: } sessionCipher := cipherSuiteFromID(hs.sessionState.cipherSuite) - if err := c.useInTrafficSecret(c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil { + if err := c.useInTrafficSecret(encryptionEarlyData, c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil { return err } @@ -854,7 +854,7 @@ ResendHelloRetryRequest: // Switch to handshake traffic keys. serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel) - c.useOutTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret) + c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret) // Derive handshake traffic read key, but don't switch yet. clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel) @@ -1038,7 +1038,7 @@ ResendHelloRetryRequest: // Switch to application data keys on write. In particular, any alerts // from the client certificate are sent over these keys. - c.useOutTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret) + c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret) // Send 0.5-RTT messages. for _, halfRTTMsg := range config.Bugs.SendHalfRTTData { @@ -1063,7 +1063,7 @@ ResendHelloRetryRequest: } // Switch input stream to handshake traffic keys. - if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil { + if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil { return err } @@ -1192,7 +1192,7 @@ ResendHelloRetryRequest: hs.writeClientHash(clientFinished.marshal()) // Switch to application data keys on read. - if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret); err != nil { + if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret); err != nil { return err } diff --git a/ssl/test/runner/mock_quic_transport.go b/ssl/test/runner/mock_quic_transport.go index 27a8bc326..99ce0f77f 100644 --- a/ssl/test/runner/mock_quic_transport.go +++ b/ssl/test/runner/mock_quic_transport.go @@ -26,6 +26,15 @@ const tagHandshake = byte('H') const tagApplication = byte('A') const tagAlert = byte('L') +type encryptionLevel byte + +const ( + encryptionInitial encryptionLevel = 0 + encryptionEarlyData encryptionLevel = 1 + encryptionHandshake encryptionLevel = 2 + encryptionApplication encryptionLevel = 3 +) + // mockQUICTransport provides a record layer for sending/receiving messages // when testing TLS over QUIC. It is only intended for testing, as it runs over // an in-order reliable transport, looks nothing like the QUIC wire image, and @@ -43,6 +52,7 @@ const tagAlert = byte('L') // cipher suite ID or tag. type mockQUICTransport struct { net.Conn + readLevel, writeLevel encryptionLevel readSecret, writeSecret []byte readCipherSuite, writeCipherSuite uint16 skipEarlyData bool @@ -54,37 +64,39 @@ func newMockQUICTransport(conn net.Conn) *mockQUICTransport { func (m *mockQUICTransport) read() (byte, []byte, error) { for { - header := make([]byte, 7) + header := make([]byte, 8) if _, err := io.ReadFull(m.Conn, header); err != nil { return 0, nil, err } - cipherSuite := binary.BigEndian.Uint16(header[1:3]) - length := binary.BigEndian.Uint32(header[3:]) + tag := header[0] + level := header[1] + cipherSuite := binary.BigEndian.Uint16(header[2:4]) + length := binary.BigEndian.Uint32(header[4:]) value := make([]byte, length) if _, err := io.ReadFull(m.Conn, value); err != nil { - return 0, nil, fmt.Errorf("Error reading record") + return 0, nil, fmt.Errorf("error reading record") } - if cipherSuite != m.readCipherSuite { - if m.skipEarlyData { + if level != byte(m.readLevel) { + if m.skipEarlyData && level == byte(encryptionEarlyData) { continue } - return 0, nil, fmt.Errorf("Received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite) + return 0, nil, fmt.Errorf("received level %d does not match expected %d", level, m.readLevel) + } + if cipherSuite != m.readCipherSuite { + return 0, nil, fmt.Errorf("received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite) } if len(m.readSecret) > len(value) { - return 0, nil, fmt.Errorf("Input length too short") + return 0, nil, fmt.Errorf("input length too short") } secret := value[:len(m.readSecret)] out := value[len(m.readSecret):] if !bytes.Equal(secret, m.readSecret) { - if m.skipEarlyData { - continue - } return 0, nil, fmt.Errorf("secrets don't match: got %x but expected %x", secret, m.readSecret) } - if m.skipEarlyData && header[0] == tagHandshake { - m.skipEarlyData = false - } - return header[0], out, nil + // Although not true for QUIC in general, our transport is ordered, so + // we expect to stop skipping early data after a valid record. + m.skipEarlyData = false + return tag, out, nil } } @@ -114,12 +126,13 @@ func (m *mockQUICTransport) writeRecord(typ recordType, data []byte) (int, error return 0, fmt.Errorf("unsupported record type %d\n", typ) } length := len(m.writeSecret) + len(data) - payload := make([]byte, 1+2+4+length) + payload := make([]byte, 1+1+2+4+length) payload[0] = tag - binary.BigEndian.PutUint16(payload[1:3], m.writeCipherSuite) - binary.BigEndian.PutUint32(payload[3:7], uint32(length)) - copy(payload[7:], m.writeSecret) - copy(payload[7+len(m.writeSecret):], data) + payload[1] = byte(m.writeLevel) + binary.BigEndian.PutUint16(payload[2:4], m.writeCipherSuite) + binary.BigEndian.PutUint32(payload[4:8], uint32(length)) + copy(payload[8:], m.writeSecret) + copy(payload[8+len(m.writeSecret):], data) if _, err := m.Conn.Write(payload); err != nil { return 0, err } diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index f1ec122c3..f3847d6fa 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -1350,6 +1350,11 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN flags = append(flags, "-on-resume-expect-accept-early-data") } + if test.protocol == quic { + // QUIC requires an early data context string. + flags = append(flags, "-quic-early-data-context", "context") + } + flags = append(flags, "-enable-early-data") if test.testType == clientTest { // Configure the runner with default maximum early data. diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index eb863ebf3..4f8486746 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc @@ -183,6 +183,7 @@ const Flag kStringFlags[] = { {"-handshaker-path", &TestConfig::handshaker_path}, {"-delegated-credential", &TestConfig::delegated_credential}, {"-expect-early-data-reason", &TestConfig::expect_early_data_reason}, + {"-quic-early-data-context", &TestConfig::quic_early_data_context}, }; // TODO(davidben): When we can depend on C++17 or Abseil, switch this to @@ -1797,5 +1798,13 @@ bssl::UniquePtr TestConfig::NewSSL( } } + if (!quic_early_data_context.empty() && + !SSL_set_quic_early_data_context( + ssl.get(), + reinterpret_cast(quic_early_data_context.data()), + quic_early_data_context.size())) { + return nullptr; + } + return ssl; } diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h index 67cab951c..9279fca4c 100644 --- a/ssl/test/test_config.h +++ b/ssl/test/test_config.h @@ -182,6 +182,7 @@ struct TestConfig { bool expect_hrr = false; bool expect_no_hrr = false; bool wait_for_debugger = false; + std::string quic_early_data_context; int argc; char **argv;