Make QUIC tests work with early data.

This changes the format of the mock QUIC transport to include an
explicit encryption level, matching real QUIC a bit better. In
particular, we need that extra data to properly skip rejected early data
on the shim side. (On the runner, we manage it by synchronizing with the
TLS stack. Still, the levels make it a bit more accurate.)

Testing sending and receiving of actual early data is not very relevant
in QUIC since application I/O is external, but this allows us to more
easily run the same tests in TLS and QUIC.

Along the way, improve error-reporting in mock_quick_transport.cc so
it's easier to diagnose record-level mismatches.

Change-Id: I96175a4023134b03d61dac089f8e7ff4eb627933
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/44988
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
chromium-5359
David Benjamin 4 years ago committed by CQ bot account: commit-bot@chromium.org
parent 7a55c80271
commit 47d1274fd2
  1. 106
      ssl/test/mock_quic_transport.cc
  2. 10
      ssl/test/mock_quic_transport.h
  3. 12
      ssl/test/runner/conn.go
  4. 12
      ssl/test/runner/handshake_client.go
  5. 10
      ssl/test/runner/handshake_server.go
  6. 51
      ssl/test/runner/mock_quic_transport.go
  7. 5
      ssl/test/runner/runner.go
  8. 9
      ssl/test/test_config.cc
  9. 1
      ssl/test/test_config.h

@ -73,47 +73,97 @@ bool ReadAll(BIO *bio, bssl::Span<uint8_t> out) {
return true; return true;
} }
const char *LevelToString(ssl_encryption_level_t level) {
switch (level) {
case ssl_encryption_initial:
return "initial";
case ssl_encryption_early_data:
return "early_data";
case ssl_encryption_handshake:
return "handshake";
case ssl_encryption_application:
return "application";
}
return "";
}
} // namespace } // namespace
bool MockQuicTransport::ReadHeader(uint8_t *out_tag, size_t *out_len) { bool MockQuicTransport::ReadHeader(uint8_t *out_tag,
uint8_t header[7]; enum ssl_encryption_level_t *out_level,
size_t *out_len) {
for (;;) {
uint8_t header[8];
if (!ReadAll(bio_.get(), header)) { if (!ReadAll(bio_.get(), header)) {
// TODO(davidben): Distinguish between errors and EOF. See
// ReadApplicationData.
return false; return false;
} }
*out_tag = header[0];
uint16_t cipher_suite = header[1] << 8 | header[2];
size_t remaining_bytes =
header[3] << 24 | header[4] << 16 | header[5] << 8 | header[6];
enum ssl_encryption_level_t level = SSL_quic_read_level(ssl_); CBS cbs;
if (*out_tag == kTagApplication) { uint8_t level_id;
if (SSL_in_early_data(ssl_)) { uint16_t cipher_suite;
level = ssl_encryption_early_data; uint32_t remaining_bytes;
} else { CBS_init(&cbs, header, sizeof(header));
level = ssl_encryption_application; if (!CBS_get_u8(&cbs, out_tag) ||
!CBS_get_u8(&cbs, &level_id) ||
!CBS_get_u16(&cbs, &cipher_suite) ||
!CBS_get_u32(&cbs, &remaining_bytes) ||
level_id >= read_levels_.size()) {
fprintf(stderr, "Error parsing record header.\n");
return false;
}
auto level = static_cast<ssl_encryption_level_t>(level_id);
// Non-initial levels must be configured before use.
uint16_t expect_cipher = read_levels_[level].cipher;
if (expect_cipher == 0 && level != ssl_encryption_initial) {
if (level == ssl_encryption_early_data) {
// If we receive early data records without any early data keys, skip
// the record. This means early data was rejected.
std::vector<uint8_t> discard(remaining_bytes);
if (!ReadAll(bio_.get(), bssl::MakeSpan(discard))) {
return false;
}
continue;
} }
fprintf(stderr,
"Got record at level %s, but keys were not configured.\n",
LevelToString(level));
return false;
} }
if (cipher_suite != read_levels_[level].cipher) { if (cipher_suite != expect_cipher) {
fprintf(stderr, "Got cipher suite 0x%04x at level %s, wanted 0x%04x.\n",
cipher_suite, LevelToString(level), expect_cipher);
return false; return false;
} }
const std::vector<uint8_t> &secret = read_levels_[level].secret; const std::vector<uint8_t> &secret = read_levels_[level].secret;
std::vector<uint8_t> read_secret(secret.size()); std::vector<uint8_t> read_secret(secret.size());
if (remaining_bytes < secret.size()) { if (remaining_bytes < secret.size()) {
fprintf(stderr, "Record at level %s too small.\n", LevelToString(level));
return false; return false;
} }
remaining_bytes -= secret.size(); remaining_bytes -= secret.size();
if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret)) || if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret))) {
read_secret != secret) { fprintf(stderr, "Error reading record secret.\n");
return false; return false;
} }
if (read_secret != secret) {
fprintf(stderr, "Encryption secret at level %s did not match.\n",
LevelToString(level));
return false;
}
*out_level = level;
*out_len = remaining_bytes; *out_len = remaining_bytes;
return true; return true;
}
} }
bool MockQuicTransport::ReadHandshake() { bool MockQuicTransport::ReadHandshake() {
uint8_t tag; uint8_t tag;
ssl_encryption_level_t level;
size_t len; size_t len;
if (!ReadHeader(&tag, &len)) { if (!ReadHeader(&tag, &level, &len)) {
return false; return false;
} }
if (tag != kTagHandshake) { if (tag != kTagHandshake) {
@ -124,8 +174,7 @@ bool MockQuicTransport::ReadHandshake() {
if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) { if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) {
return false; return false;
} }
return SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(), return SSL_provide_quic_data(ssl_, level, buf.data(), buf.size());
buf.size());
} }
int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) { int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
@ -144,9 +193,10 @@ int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
} }
uint8_t tag = 0; uint8_t tag = 0;
ssl_encryption_level_t level;
size_t len; size_t len;
while (true) { while (true) {
if (!ReadHeader(&tag, &len)) { if (!ReadHeader(&tag, &level, &len)) {
// Assume that a failure to read the header means there's no more to read, // Assume that a failure to read the header means there's no more to read,
// not an error reading. // not an error reading.
return 0; return 0;
@ -162,8 +212,7 @@ int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) { if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) {
return -1; return -1;
} }
if (SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(), if (SSL_provide_quic_data(ssl_, level, buf.data(), buf.size()) != 1) {
buf.size()) != 1) {
return -1; return -1;
} }
if (SSL_in_init(ssl_)) { if (SSL_in_init(ssl_)) {
@ -203,14 +252,15 @@ bool MockQuicTransport::WriteRecord(enum ssl_encryption_level_t level,
uint16_t cipher_suite = write_levels_[level].cipher; uint16_t cipher_suite = write_levels_[level].cipher;
const std::vector<uint8_t> &secret = write_levels_[level].secret; const std::vector<uint8_t> &secret = write_levels_[level].secret;
size_t tlv_len = secret.size() + len; size_t tlv_len = secret.size() + len;
uint8_t header[7]; uint8_t header[8];
header[0] = tag; header[0] = tag;
header[1] = (cipher_suite >> 8) & 0xff; header[1] = level;
header[2] = cipher_suite & 0xff; header[2] = (cipher_suite >> 8) & 0xff;
header[3] = (tlv_len >> 24) & 0xff; header[3] = cipher_suite & 0xff;
header[4] = (tlv_len >> 16) & 0xff; header[4] = (tlv_len >> 24) & 0xff;
header[5] = (tlv_len >> 8) & 0xff; header[5] = (tlv_len >> 16) & 0xff;
header[6] = tlv_len & 0xff; header[6] = (tlv_len >> 8) & 0xff;
header[7] = tlv_len & 0xff;
return BIO_write_all(bio_.get(), header, sizeof(header)) && return BIO_write_all(bio_.get(), header, sizeof(header)) &&
BIO_write_all(bio_.get(), secret.data(), secret.size()) && BIO_write_all(bio_.get(), secret.data(), secret.size()) &&
BIO_write_all(bio_.get(), data, len); BIO_write_all(bio_.get(), data, len);

@ -45,10 +45,12 @@ class MockQuicTransport {
// Reads a record header from |bio_| and returns whether the record was read // Reads a record header from |bio_| and returns whether the record was read
// successfully. As part of reading the header, this function checks that the // successfully. As part of reading the header, this function checks that the
// cipher suite and secret in the header are correct. On success, the tag // cipher suite and secret in the header are correct. On success, the tag
// indicating the TLS record type is put in |*out_tag|, the length of the TLS // indicating the TLS record type is put in |*out_tag|, the encryption level
// record is put in |*out_len|, and the next thing to be read from |bio_| is // is put in |*out_level|, the length of the TLS record is put in |*out_len|,
// |*out_len| bytes of the TLS record. // and the next thing to be read from |bio_| is |*out_len| bytes of the TLS
bool ReadHeader(uint8_t *out_tag, size_t *out_len); // record.
bool ReadHeader(uint8_t *out_tag, enum ssl_encryption_level_t *out_level,
size_t *out_len);
// Writes a MockQuicTransport record to |bio_| at encryption level |level| // Writes a MockQuicTransport record to |bio_| at encryption level |level|
// with record type |tag| and a TLS record payload of length |len| from // with record type |tag| and a TLS record payload of length |len| from

@ -754,7 +754,7 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
return b, bb return b, bb
} }
func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []byte) error { func (c *Conn) useInTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) error {
if c.hand.Len() != 0 { if c.hand.Len() != 0 {
return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change")) return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change"))
} }
@ -763,6 +763,7 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b
side = clientWrite side = clientWrite
} }
if c.config.Bugs.MockQUICTransport != nil { if c.config.Bugs.MockQUICTransport != nil {
c.config.Bugs.MockQUICTransport.readLevel = level
c.config.Bugs.MockQUICTransport.readSecret = secret c.config.Bugs.MockQUICTransport.readSecret = secret
c.config.Bugs.MockQUICTransport.readCipherSuite = suite.id c.config.Bugs.MockQUICTransport.readCipherSuite = suite.id
} }
@ -771,12 +772,13 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b
return nil return nil
} }
func (c *Conn) useOutTrafficSecret(version uint16, suite *cipherSuite, secret []byte) { func (c *Conn) useOutTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) {
side := serverWrite side := serverWrite
if c.isClient { if c.isClient {
side = clientWrite side = clientWrite
} }
if c.config.Bugs.MockQUICTransport != nil { if c.config.Bugs.MockQUICTransport != nil {
c.config.Bugs.MockQUICTransport.writeLevel = level
c.config.Bugs.MockQUICTransport.writeSecret = secret c.config.Bugs.MockQUICTransport.writeSecret = secret
c.config.Bugs.MockQUICTransport.writeCipherSuite = suite.id c.config.Bugs.MockQUICTransport.writeCipherSuite = suite.id
} }
@ -1677,7 +1679,7 @@ func (c *Conn) handlePostHandshakeMessage() error {
if c.config.Bugs.RejectUnsolicitedKeyUpdate { if c.config.Bugs.RejectUnsolicitedKeyUpdate {
return errors.New("tls: unexpected KeyUpdate message") return errors.New("tls: unexpected KeyUpdate message")
} }
if err := c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil { if err := c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil {
return err return err
} }
if keyUpdate.keyUpdateRequest == keyUpdateRequested { if keyUpdate.keyUpdateRequest == keyUpdateRequested {
@ -1711,7 +1713,7 @@ func (c *Conn) ReadKeyUpdateACK() error {
return errors.New("tls: received invalid KeyUpdate message") return errors.New("tls: received invalid KeyUpdate message")
} }
return c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)) return c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret))
} }
func (c *Conn) Renegotiate() error { func (c *Conn) Renegotiate() error {
@ -2065,7 +2067,7 @@ func (c *Conn) sendKeyUpdateLocked(keyUpdateRequest byte) error {
if err := c.flushHandshake(); err != nil { if err := c.flushHandshake(); err != nil {
return err return err
} }
c.useOutTrafficSecret(c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret)) c.useOutTrafficSecret(encryptionApplication, c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret))
return nil return nil
} }

@ -518,7 +518,7 @@ NextCipherSuite:
earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel) earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel)
c.earlyExporterSecret = finishedHash.deriveSecret(earlyExporterLabel) c.earlyExporterSecret = finishedHash.deriveSecret(earlyExporterLabel)
c.useOutTrafficSecret(session.wireVersion, pskCipherSuite, earlyTrafficSecret) c.useOutTrafficSecret(encryptionEarlyData, session.wireVersion, pskCipherSuite, earlyTrafficSecret)
for _, earlyData := range c.config.Bugs.SendEarlyData { for _, earlyData := range c.config.Bugs.SendEarlyData {
if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil { if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil {
return err return err
@ -923,7 +923,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// traffic key. // traffic key.
clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel) clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel) serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil { if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil {
return err return err
} }
@ -1098,7 +1098,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// Switch to application data keys on read. In particular, any alerts // Switch to application data keys on read. In particular, any alerts
// from the client certificate are read over these keys. // from the client certificate are read over these keys.
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret); err != nil { if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret); err != nil {
return err return err
} }
@ -1133,7 +1133,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// Send EndOfEarlyData and then switch write key to handshake // Send EndOfEarlyData and then switch write key to handshake
// traffic key. // traffic key.
if encryptedExtensions.extensions.hasEarlyData && c.out.cipher != nil && !c.config.Bugs.SkipEndOfEarlyData { if encryptedExtensions.extensions.hasEarlyData && !c.config.Bugs.SkipEndOfEarlyData && c.config.Bugs.MockQUICTransport == nil {
if c.config.Bugs.SendStrayEarlyHandshake { if c.config.Bugs.SendStrayEarlyHandshake {
helloRequest := new(helloRequestMsg) helloRequest := new(helloRequestMsg)
c.writeRecord(recordTypeHandshake, helloRequest.marshal()) c.writeRecord(recordTypeHandshake, helloRequest.marshal())
@ -1157,7 +1157,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
} }
c.useOutTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret) c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
// The client EncryptedExtensions message is sent if some extension uses it. // The client EncryptedExtensions message is sent if some extension uses it.
// (Currently only ALPS does.) // (Currently only ALPS does.)
@ -1263,7 +1263,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.flushHandshake() c.flushHandshake()
// Switch to application data keys. // Switch to application data keys.
c.useOutTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret) c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret)
c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel) c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
for _, ticket := range deferredTickets { for _, ticket := range deferredTickets {
if err := c.processTLS13NewSessionTicket(ticket, hs.suite); err != nil { if err := c.processTLS13NewSessionTicket(ticket, hs.suite); err != nil {

@ -747,7 +747,7 @@ ResendHelloRetryRequest:
} }
sessionCipher := cipherSuiteFromID(hs.sessionState.cipherSuite) sessionCipher := cipherSuiteFromID(hs.sessionState.cipherSuite)
if err := c.useInTrafficSecret(c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil { if err := c.useInTrafficSecret(encryptionEarlyData, c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil {
return err return err
} }
@ -854,7 +854,7 @@ ResendHelloRetryRequest:
// Switch to handshake traffic keys. // Switch to handshake traffic keys.
serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel) serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
c.useOutTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret) c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
// Derive handshake traffic read key, but don't switch yet. // Derive handshake traffic read key, but don't switch yet.
clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel) clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
@ -1038,7 +1038,7 @@ ResendHelloRetryRequest:
// Switch to application data keys on write. In particular, any alerts // Switch to application data keys on write. In particular, any alerts
// from the client certificate are sent over these keys. // from the client certificate are sent over these keys.
c.useOutTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret) c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret)
// Send 0.5-RTT messages. // Send 0.5-RTT messages.
for _, halfRTTMsg := range config.Bugs.SendHalfRTTData { for _, halfRTTMsg := range config.Bugs.SendHalfRTTData {
@ -1063,7 +1063,7 @@ ResendHelloRetryRequest:
} }
// Switch input stream to handshake traffic keys. // Switch input stream to handshake traffic keys.
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil { if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil {
return err return err
} }
@ -1192,7 +1192,7 @@ ResendHelloRetryRequest:
hs.writeClientHash(clientFinished.marshal()) hs.writeClientHash(clientFinished.marshal())
// Switch to application data keys on read. // Switch to application data keys on read.
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret); err != nil { if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret); err != nil {
return err return err
} }

@ -26,6 +26,15 @@ const tagHandshake = byte('H')
const tagApplication = byte('A') const tagApplication = byte('A')
const tagAlert = byte('L') const tagAlert = byte('L')
type encryptionLevel byte
const (
encryptionInitial encryptionLevel = 0
encryptionEarlyData encryptionLevel = 1
encryptionHandshake encryptionLevel = 2
encryptionApplication encryptionLevel = 3
)
// mockQUICTransport provides a record layer for sending/receiving messages // mockQUICTransport provides a record layer for sending/receiving messages
// when testing TLS over QUIC. It is only intended for testing, as it runs over // when testing TLS over QUIC. It is only intended for testing, as it runs over
// an in-order reliable transport, looks nothing like the QUIC wire image, and // an in-order reliable transport, looks nothing like the QUIC wire image, and
@ -43,6 +52,7 @@ const tagAlert = byte('L')
// cipher suite ID or tag. // cipher suite ID or tag.
type mockQUICTransport struct { type mockQUICTransport struct {
net.Conn net.Conn
readLevel, writeLevel encryptionLevel
readSecret, writeSecret []byte readSecret, writeSecret []byte
readCipherSuite, writeCipherSuite uint16 readCipherSuite, writeCipherSuite uint16
skipEarlyData bool skipEarlyData bool
@ -54,37 +64,39 @@ func newMockQUICTransport(conn net.Conn) *mockQUICTransport {
func (m *mockQUICTransport) read() (byte, []byte, error) { func (m *mockQUICTransport) read() (byte, []byte, error) {
for { for {
header := make([]byte, 7) header := make([]byte, 8)
if _, err := io.ReadFull(m.Conn, header); err != nil { if _, err := io.ReadFull(m.Conn, header); err != nil {
return 0, nil, err return 0, nil, err
} }
cipherSuite := binary.BigEndian.Uint16(header[1:3]) tag := header[0]
length := binary.BigEndian.Uint32(header[3:]) level := header[1]
cipherSuite := binary.BigEndian.Uint16(header[2:4])
length := binary.BigEndian.Uint32(header[4:])
value := make([]byte, length) value := make([]byte, length)
if _, err := io.ReadFull(m.Conn, value); err != nil { if _, err := io.ReadFull(m.Conn, value); err != nil {
return 0, nil, fmt.Errorf("Error reading record") return 0, nil, fmt.Errorf("error reading record")
} }
if cipherSuite != m.readCipherSuite { if level != byte(m.readLevel) {
if m.skipEarlyData { if m.skipEarlyData && level == byte(encryptionEarlyData) {
continue continue
} }
return 0, nil, fmt.Errorf("Received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite) return 0, nil, fmt.Errorf("received level %d does not match expected %d", level, m.readLevel)
}
if cipherSuite != m.readCipherSuite {
return 0, nil, fmt.Errorf("received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite)
} }
if len(m.readSecret) > len(value) { if len(m.readSecret) > len(value) {
return 0, nil, fmt.Errorf("Input length too short") return 0, nil, fmt.Errorf("input length too short")
} }
secret := value[:len(m.readSecret)] secret := value[:len(m.readSecret)]
out := value[len(m.readSecret):] out := value[len(m.readSecret):]
if !bytes.Equal(secret, m.readSecret) { if !bytes.Equal(secret, m.readSecret) {
if m.skipEarlyData {
continue
}
return 0, nil, fmt.Errorf("secrets don't match: got %x but expected %x", secret, m.readSecret) return 0, nil, fmt.Errorf("secrets don't match: got %x but expected %x", secret, m.readSecret)
} }
if m.skipEarlyData && header[0] == tagHandshake { // Although not true for QUIC in general, our transport is ordered, so
// we expect to stop skipping early data after a valid record.
m.skipEarlyData = false m.skipEarlyData = false
} return tag, out, nil
return header[0], out, nil
} }
} }
@ -114,12 +126,13 @@ func (m *mockQUICTransport) writeRecord(typ recordType, data []byte) (int, error
return 0, fmt.Errorf("unsupported record type %d\n", typ) return 0, fmt.Errorf("unsupported record type %d\n", typ)
} }
length := len(m.writeSecret) + len(data) length := len(m.writeSecret) + len(data)
payload := make([]byte, 1+2+4+length) payload := make([]byte, 1+1+2+4+length)
payload[0] = tag payload[0] = tag
binary.BigEndian.PutUint16(payload[1:3], m.writeCipherSuite) payload[1] = byte(m.writeLevel)
binary.BigEndian.PutUint32(payload[3:7], uint32(length)) binary.BigEndian.PutUint16(payload[2:4], m.writeCipherSuite)
copy(payload[7:], m.writeSecret) binary.BigEndian.PutUint32(payload[4:8], uint32(length))
copy(payload[7+len(m.writeSecret):], data) copy(payload[8:], m.writeSecret)
copy(payload[8+len(m.writeSecret):], data)
if _, err := m.Conn.Write(payload); err != nil { if _, err := m.Conn.Write(payload); err != nil {
return 0, err return 0, err
} }

@ -1350,6 +1350,11 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN
flags = append(flags, "-on-resume-expect-accept-early-data") flags = append(flags, "-on-resume-expect-accept-early-data")
} }
if test.protocol == quic {
// QUIC requires an early data context string.
flags = append(flags, "-quic-early-data-context", "context")
}
flags = append(flags, "-enable-early-data") flags = append(flags, "-enable-early-data")
if test.testType == clientTest { if test.testType == clientTest {
// Configure the runner with default maximum early data. // Configure the runner with default maximum early data.

@ -183,6 +183,7 @@ const Flag<std::string> kStringFlags[] = {
{"-handshaker-path", &TestConfig::handshaker_path}, {"-handshaker-path", &TestConfig::handshaker_path},
{"-delegated-credential", &TestConfig::delegated_credential}, {"-delegated-credential", &TestConfig::delegated_credential},
{"-expect-early-data-reason", &TestConfig::expect_early_data_reason}, {"-expect-early-data-reason", &TestConfig::expect_early_data_reason},
{"-quic-early-data-context", &TestConfig::quic_early_data_context},
}; };
// TODO(davidben): When we can depend on C++17 or Abseil, switch this to // TODO(davidben): When we can depend on C++17 or Abseil, switch this to
@ -1797,5 +1798,13 @@ bssl::UniquePtr<SSL> TestConfig::NewSSL(
} }
} }
if (!quic_early_data_context.empty() &&
!SSL_set_quic_early_data_context(
ssl.get(),
reinterpret_cast<const uint8_t *>(quic_early_data_context.data()),
quic_early_data_context.size())) {
return nullptr;
}
return ssl; return ssl;
} }

@ -182,6 +182,7 @@ struct TestConfig {
bool expect_hrr = false; bool expect_hrr = false;
bool expect_no_hrr = false; bool expect_no_hrr = false;
bool wait_for_debugger = false; bool wait_for_debugger = false;
std::string quic_early_data_context;
int argc; int argc;
char **argv; char **argv;

Loading…
Cancel
Save