Make QUIC tests work with early data.

This changes the format of the mock QUIC transport to include an
explicit encryption level, matching real QUIC a bit better. In
particular, we need that extra data to properly skip rejected early data
on the shim side. (On the runner, we manage it by synchronizing with the
TLS stack. Still, the levels make it a bit more accurate.)

Testing sending and receiving of actual early data is not very relevant
in QUIC since application I/O is external, but this allows us to more
easily run the same tests in TLS and QUIC.

Along the way, improve error-reporting in mock_quick_transport.cc so
it's easier to diagnose record-level mismatches.

Change-Id: I96175a4023134b03d61dac089f8e7ff4eb627933
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/44988
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
chromium-5359
David Benjamin 4 years ago committed by CQ bot account: commit-bot@chromium.org
parent 7a55c80271
commit 47d1274fd2
  1. 138
      ssl/test/mock_quic_transport.cc
  2. 10
      ssl/test/mock_quic_transport.h
  3. 12
      ssl/test/runner/conn.go
  4. 12
      ssl/test/runner/handshake_client.go
  5. 10
      ssl/test/runner/handshake_server.go
  6. 53
      ssl/test/runner/mock_quic_transport.go
  7. 5
      ssl/test/runner/runner.go
  8. 9
      ssl/test/test_config.cc
  9. 1
      ssl/test/test_config.h

@ -73,47 +73,97 @@ bool ReadAll(BIO *bio, bssl::Span<uint8_t> out) {
return true;
}
const char *LevelToString(ssl_encryption_level_t level) {
switch (level) {
case ssl_encryption_initial:
return "initial";
case ssl_encryption_early_data:
return "early_data";
case ssl_encryption_handshake:
return "handshake";
case ssl_encryption_application:
return "application";
}
return "";
}
} // namespace
bool MockQuicTransport::ReadHeader(uint8_t *out_tag, size_t *out_len) {
uint8_t header[7];
if (!ReadAll(bio_.get(), header)) {
return false;
}
*out_tag = header[0];
uint16_t cipher_suite = header[1] << 8 | header[2];
size_t remaining_bytes =
header[3] << 24 | header[4] << 16 | header[5] << 8 | header[6];
enum ssl_encryption_level_t level = SSL_quic_read_level(ssl_);
if (*out_tag == kTagApplication) {
if (SSL_in_early_data(ssl_)) {
level = ssl_encryption_early_data;
} else {
level = ssl_encryption_application;
bool MockQuicTransport::ReadHeader(uint8_t *out_tag,
enum ssl_encryption_level_t *out_level,
size_t *out_len) {
for (;;) {
uint8_t header[8];
if (!ReadAll(bio_.get(), header)) {
// TODO(davidben): Distinguish between errors and EOF. See
// ReadApplicationData.
return false;
}
CBS cbs;
uint8_t level_id;
uint16_t cipher_suite;
uint32_t remaining_bytes;
CBS_init(&cbs, header, sizeof(header));
if (!CBS_get_u8(&cbs, out_tag) ||
!CBS_get_u8(&cbs, &level_id) ||
!CBS_get_u16(&cbs, &cipher_suite) ||
!CBS_get_u32(&cbs, &remaining_bytes) ||
level_id >= read_levels_.size()) {
fprintf(stderr, "Error parsing record header.\n");
return false;
}
auto level = static_cast<ssl_encryption_level_t>(level_id);
// Non-initial levels must be configured before use.
uint16_t expect_cipher = read_levels_[level].cipher;
if (expect_cipher == 0 && level != ssl_encryption_initial) {
if (level == ssl_encryption_early_data) {
// If we receive early data records without any early data keys, skip
// the record. This means early data was rejected.
std::vector<uint8_t> discard(remaining_bytes);
if (!ReadAll(bio_.get(), bssl::MakeSpan(discard))) {
return false;
}
continue;
}
fprintf(stderr,
"Got record at level %s, but keys were not configured.\n",
LevelToString(level));
return false;
}
if (cipher_suite != expect_cipher) {
fprintf(stderr, "Got cipher suite 0x%04x at level %s, wanted 0x%04x.\n",
cipher_suite, LevelToString(level), expect_cipher);
return false;
}
const std::vector<uint8_t> &secret = read_levels_[level].secret;
std::vector<uint8_t> read_secret(secret.size());
if (remaining_bytes < secret.size()) {
fprintf(stderr, "Record at level %s too small.\n", LevelToString(level));
return false;
}
remaining_bytes -= secret.size();
if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret))) {
fprintf(stderr, "Error reading record secret.\n");
return false;
}
if (read_secret != secret) {
fprintf(stderr, "Encryption secret at level %s did not match.\n",
LevelToString(level));
return false;
}
*out_level = level;
*out_len = remaining_bytes;
return true;
}
if (cipher_suite != read_levels_[level].cipher) {
return false;
}
const std::vector<uint8_t> &secret = read_levels_[level].secret;
std::vector<uint8_t> read_secret(secret.size());
if (remaining_bytes < secret.size()) {
return false;
}
remaining_bytes -= secret.size();
if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret)) ||
read_secret != secret) {
return false;
}
*out_len = remaining_bytes;
return true;
}
bool MockQuicTransport::ReadHandshake() {
uint8_t tag;
ssl_encryption_level_t level;
size_t len;
if (!ReadHeader(&tag, &len)) {
if (!ReadHeader(&tag, &level, &len)) {
return false;
}
if (tag != kTagHandshake) {
@ -124,8 +174,7 @@ bool MockQuicTransport::ReadHandshake() {
if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) {
return false;
}
return SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(),
buf.size());
return SSL_provide_quic_data(ssl_, level, buf.data(), buf.size());
}
int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
@ -144,9 +193,10 @@ int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
}
uint8_t tag = 0;
ssl_encryption_level_t level;
size_t len;
while (true) {
if (!ReadHeader(&tag, &len)) {
if (!ReadHeader(&tag, &level, &len)) {
// Assume that a failure to read the header means there's no more to read,
// not an error reading.
return 0;
@ -162,8 +212,7 @@ int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) {
return -1;
}
if (SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(),
buf.size()) != 1) {
if (SSL_provide_quic_data(ssl_, level, buf.data(), buf.size()) != 1) {
return -1;
}
if (SSL_in_init(ssl_)) {
@ -203,14 +252,15 @@ bool MockQuicTransport::WriteRecord(enum ssl_encryption_level_t level,
uint16_t cipher_suite = write_levels_[level].cipher;
const std::vector<uint8_t> &secret = write_levels_[level].secret;
size_t tlv_len = secret.size() + len;
uint8_t header[7];
uint8_t header[8];
header[0] = tag;
header[1] = (cipher_suite >> 8) & 0xff;
header[2] = cipher_suite & 0xff;
header[3] = (tlv_len >> 24) & 0xff;
header[4] = (tlv_len >> 16) & 0xff;
header[5] = (tlv_len >> 8) & 0xff;
header[6] = tlv_len & 0xff;
header[1] = level;
header[2] = (cipher_suite >> 8) & 0xff;
header[3] = cipher_suite & 0xff;
header[4] = (tlv_len >> 24) & 0xff;
header[5] = (tlv_len >> 16) & 0xff;
header[6] = (tlv_len >> 8) & 0xff;
header[7] = tlv_len & 0xff;
return BIO_write_all(bio_.get(), header, sizeof(header)) &&
BIO_write_all(bio_.get(), secret.data(), secret.size()) &&
BIO_write_all(bio_.get(), data, len);

@ -45,10 +45,12 @@ class MockQuicTransport {
// Reads a record header from |bio_| and returns whether the record was read
// successfully. As part of reading the header, this function checks that the
// cipher suite and secret in the header are correct. On success, the tag
// indicating the TLS record type is put in |*out_tag|, the length of the TLS
// record is put in |*out_len|, and the next thing to be read from |bio_| is
// |*out_len| bytes of the TLS record.
bool ReadHeader(uint8_t *out_tag, size_t *out_len);
// indicating the TLS record type is put in |*out_tag|, the encryption level
// is put in |*out_level|, the length of the TLS record is put in |*out_len|,
// and the next thing to be read from |bio_| is |*out_len| bytes of the TLS
// record.
bool ReadHeader(uint8_t *out_tag, enum ssl_encryption_level_t *out_level,
size_t *out_len);
// Writes a MockQuicTransport record to |bio_| at encryption level |level|
// with record type |tag| and a TLS record payload of length |len| from

@ -754,7 +754,7 @@ func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
return b, bb
}
func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []byte) error {
func (c *Conn) useInTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) error {
if c.hand.Len() != 0 {
return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change"))
}
@ -763,6 +763,7 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b
side = clientWrite
}
if c.config.Bugs.MockQUICTransport != nil {
c.config.Bugs.MockQUICTransport.readLevel = level
c.config.Bugs.MockQUICTransport.readSecret = secret
c.config.Bugs.MockQUICTransport.readCipherSuite = suite.id
}
@ -771,12 +772,13 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b
return nil
}
func (c *Conn) useOutTrafficSecret(version uint16, suite *cipherSuite, secret []byte) {
func (c *Conn) useOutTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) {
side := serverWrite
if c.isClient {
side = clientWrite
}
if c.config.Bugs.MockQUICTransport != nil {
c.config.Bugs.MockQUICTransport.writeLevel = level
c.config.Bugs.MockQUICTransport.writeSecret = secret
c.config.Bugs.MockQUICTransport.writeCipherSuite = suite.id
}
@ -1677,7 +1679,7 @@ func (c *Conn) handlePostHandshakeMessage() error {
if c.config.Bugs.RejectUnsolicitedKeyUpdate {
return errors.New("tls: unexpected KeyUpdate message")
}
if err := c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil {
if err := c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil {
return err
}
if keyUpdate.keyUpdateRequest == keyUpdateRequested {
@ -1711,7 +1713,7 @@ func (c *Conn) ReadKeyUpdateACK() error {
return errors.New("tls: received invalid KeyUpdate message")
}
return c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret))
return c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret))
}
func (c *Conn) Renegotiate() error {
@ -2065,7 +2067,7 @@ func (c *Conn) sendKeyUpdateLocked(keyUpdateRequest byte) error {
if err := c.flushHandshake(); err != nil {
return err
}
c.useOutTrafficSecret(c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret))
c.useOutTrafficSecret(encryptionApplication, c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret))
return nil
}

@ -518,7 +518,7 @@ NextCipherSuite:
earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel)
c.earlyExporterSecret = finishedHash.deriveSecret(earlyExporterLabel)
c.useOutTrafficSecret(session.wireVersion, pskCipherSuite, earlyTrafficSecret)
c.useOutTrafficSecret(encryptionEarlyData, session.wireVersion, pskCipherSuite, earlyTrafficSecret)
for _, earlyData := range c.config.Bugs.SendEarlyData {
if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil {
return err
@ -923,7 +923,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// traffic key.
clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil {
if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil {
return err
}
@ -1098,7 +1098,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// Switch to application data keys on read. In particular, any alerts
// from the client certificate are read over these keys.
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret); err != nil {
if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret); err != nil {
return err
}
@ -1133,7 +1133,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// Send EndOfEarlyData and then switch write key to handshake
// traffic key.
if encryptedExtensions.extensions.hasEarlyData && c.out.cipher != nil && !c.config.Bugs.SkipEndOfEarlyData {
if encryptedExtensions.extensions.hasEarlyData && !c.config.Bugs.SkipEndOfEarlyData && c.config.Bugs.MockQUICTransport == nil {
if c.config.Bugs.SendStrayEarlyHandshake {
helloRequest := new(helloRequestMsg)
c.writeRecord(recordTypeHandshake, helloRequest.marshal())
@ -1157,7 +1157,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
}
c.useOutTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
// The client EncryptedExtensions message is sent if some extension uses it.
// (Currently only ALPS does.)
@ -1263,7 +1263,7 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.flushHandshake()
// Switch to application data keys.
c.useOutTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret)
c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret)
c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
for _, ticket := range deferredTickets {
if err := c.processTLS13NewSessionTicket(ticket, hs.suite); err != nil {

@ -747,7 +747,7 @@ ResendHelloRetryRequest:
}
sessionCipher := cipherSuiteFromID(hs.sessionState.cipherSuite)
if err := c.useInTrafficSecret(c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil {
if err := c.useInTrafficSecret(encryptionEarlyData, c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil {
return err
}
@ -854,7 +854,7 @@ ResendHelloRetryRequest:
// Switch to handshake traffic keys.
serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
c.useOutTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
// Derive handshake traffic read key, but don't switch yet.
clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
@ -1038,7 +1038,7 @@ ResendHelloRetryRequest:
// Switch to application data keys on write. In particular, any alerts
// from the client certificate are sent over these keys.
c.useOutTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret)
c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret)
// Send 0.5-RTT messages.
for _, halfRTTMsg := range config.Bugs.SendHalfRTTData {
@ -1063,7 +1063,7 @@ ResendHelloRetryRequest:
}
// Switch input stream to handshake traffic keys.
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil {
if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil {
return err
}
@ -1192,7 +1192,7 @@ ResendHelloRetryRequest:
hs.writeClientHash(clientFinished.marshal())
// Switch to application data keys on read.
if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret); err != nil {
if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret); err != nil {
return err
}

@ -26,6 +26,15 @@ const tagHandshake = byte('H')
const tagApplication = byte('A')
const tagAlert = byte('L')
type encryptionLevel byte
const (
encryptionInitial encryptionLevel = 0
encryptionEarlyData encryptionLevel = 1
encryptionHandshake encryptionLevel = 2
encryptionApplication encryptionLevel = 3
)
// mockQUICTransport provides a record layer for sending/receiving messages
// when testing TLS over QUIC. It is only intended for testing, as it runs over
// an in-order reliable transport, looks nothing like the QUIC wire image, and
@ -43,6 +52,7 @@ const tagAlert = byte('L')
// cipher suite ID or tag.
type mockQUICTransport struct {
net.Conn
readLevel, writeLevel encryptionLevel
readSecret, writeSecret []byte
readCipherSuite, writeCipherSuite uint16
skipEarlyData bool
@ -54,37 +64,39 @@ func newMockQUICTransport(conn net.Conn) *mockQUICTransport {
func (m *mockQUICTransport) read() (byte, []byte, error) {
for {
header := make([]byte, 7)
header := make([]byte, 8)
if _, err := io.ReadFull(m.Conn, header); err != nil {
return 0, nil, err
}
cipherSuite := binary.BigEndian.Uint16(header[1:3])
length := binary.BigEndian.Uint32(header[3:])
tag := header[0]
level := header[1]
cipherSuite := binary.BigEndian.Uint16(header[2:4])
length := binary.BigEndian.Uint32(header[4:])
value := make([]byte, length)
if _, err := io.ReadFull(m.Conn, value); err != nil {
return 0, nil, fmt.Errorf("Error reading record")
return 0, nil, fmt.Errorf("error reading record")
}
if cipherSuite != m.readCipherSuite {
if m.skipEarlyData {
if level != byte(m.readLevel) {
if m.skipEarlyData && level == byte(encryptionEarlyData) {
continue
}
return 0, nil, fmt.Errorf("Received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite)
return 0, nil, fmt.Errorf("received level %d does not match expected %d", level, m.readLevel)
}
if cipherSuite != m.readCipherSuite {
return 0, nil, fmt.Errorf("received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite)
}
if len(m.readSecret) > len(value) {
return 0, nil, fmt.Errorf("Input length too short")
return 0, nil, fmt.Errorf("input length too short")
}
secret := value[:len(m.readSecret)]
out := value[len(m.readSecret):]
if !bytes.Equal(secret, m.readSecret) {
if m.skipEarlyData {
continue
}
return 0, nil, fmt.Errorf("secrets don't match: got %x but expected %x", secret, m.readSecret)
}
if m.skipEarlyData && header[0] == tagHandshake {
m.skipEarlyData = false
}
return header[0], out, nil
// Although not true for QUIC in general, our transport is ordered, so
// we expect to stop skipping early data after a valid record.
m.skipEarlyData = false
return tag, out, nil
}
}
@ -114,12 +126,13 @@ func (m *mockQUICTransport) writeRecord(typ recordType, data []byte) (int, error
return 0, fmt.Errorf("unsupported record type %d\n", typ)
}
length := len(m.writeSecret) + len(data)
payload := make([]byte, 1+2+4+length)
payload := make([]byte, 1+1+2+4+length)
payload[0] = tag
binary.BigEndian.PutUint16(payload[1:3], m.writeCipherSuite)
binary.BigEndian.PutUint32(payload[3:7], uint32(length))
copy(payload[7:], m.writeSecret)
copy(payload[7+len(m.writeSecret):], data)
payload[1] = byte(m.writeLevel)
binary.BigEndian.PutUint16(payload[2:4], m.writeCipherSuite)
binary.BigEndian.PutUint32(payload[4:8], uint32(length))
copy(payload[8:], m.writeSecret)
copy(payload[8+len(m.writeSecret):], data)
if _, err := m.Conn.Write(payload); err != nil {
return 0, err
}

@ -1350,6 +1350,11 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN
flags = append(flags, "-on-resume-expect-accept-early-data")
}
if test.protocol == quic {
// QUIC requires an early data context string.
flags = append(flags, "-quic-early-data-context", "context")
}
flags = append(flags, "-enable-early-data")
if test.testType == clientTest {
// Configure the runner with default maximum early data.

@ -183,6 +183,7 @@ const Flag<std::string> kStringFlags[] = {
{"-handshaker-path", &TestConfig::handshaker_path},
{"-delegated-credential", &TestConfig::delegated_credential},
{"-expect-early-data-reason", &TestConfig::expect_early_data_reason},
{"-quic-early-data-context", &TestConfig::quic_early_data_context},
};
// TODO(davidben): When we can depend on C++17 or Abseil, switch this to
@ -1797,5 +1798,13 @@ bssl::UniquePtr<SSL> TestConfig::NewSSL(
}
}
if (!quic_early_data_context.empty() &&
!SSL_set_quic_early_data_context(
ssl.get(),
reinterpret_cast<const uint8_t *>(quic_early_data_context.data()),
quic_early_data_context.size())) {
return nullptr;
}
return ssl;
}

@ -182,6 +182,7 @@ struct TestConfig {
bool expect_hrr = false;
bool expect_no_hrr = false;
bool wait_for_debugger = false;
std::string quic_early_data_context;
int argc;
char **argv;

Loading…
Cancel
Save