diff --git a/go.mod b/go.mod index 23038f02c..a6a6724f6 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module boringssl.googlesource.com/boringssl go 1.19 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 + golang.org/x/crypto v0.10.0 + golang.org/x/net v0.11.0 ) require ( - golang.org/x/sys v0.5.0 // indirect - golang.org/x/term v0.5.0 // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/term v0.9.0 // indirect ) diff --git a/go.sum b/go.sum index a97a96075..05c20c32f 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= +golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= +golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index a51ed558e..0ed0094e9 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go @@ -21,6 +21,7 @@ import ( "time" "boringssl.googlesource.com/boringssl/ssl/test/runner/hpke" + "golang.org/x/crypto/cryptobyte" ) const echBadPayloadByte = 0xff @@ -71,9 +72,12 @@ func replaceClientHello(hello *clientHelloMsg, in []byte) (*clientHelloMsg, erro // Replace |newHellos|'s key shares with those of |hello|. For simplicity, // we require their lengths match, which is satisfied by matching the // DefaultCurves setting to the selection in the replacement ClientHello. - bb := newByteBuilder() + bb := cryptobyte.NewBuilder(nil) hello.marshalKeyShares(bb) - keyShares := bb.finish() + keyShares, err := bb.Bytes() + if err != nil { + return nil, err + } if len(keyShares) != len(newHello.keySharesRaw) { return nil, errors.New("tls: ClientHello key share length is inconsistent with DefaultCurves setting") } diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index b253b0c84..6ea7faaa8 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -5,239 +5,49 @@ package runner import ( - "encoding/binary" "errors" "fmt" -) - -func writeLen(buf []byte, v, size int) { - for i := 0; i < size; i++ { - buf[size-i-1] = byte(v) - v >>= 8 - } - if v != 0 { - panic("length is too long") - } -} - -type byteBuilder struct { - buf *[]byte - start int - prefixLen int - child *byteBuilder -} - -func newByteBuilder() *byteBuilder { - buf := make([]byte, 0, 32) - return &byteBuilder{buf: &buf} -} - -func (bb *byteBuilder) len() int { - return len(*bb.buf) - bb.start - bb.prefixLen -} - -func (bb *byteBuilder) data() []byte { - bb.flush() - return (*bb.buf)[bb.start+bb.prefixLen:] -} - -func (bb *byteBuilder) flush() { - if bb.child == nil { - return - } - bb.child.flush() - writeLen((*bb.buf)[bb.child.start:], bb.child.len(), bb.child.prefixLen) - bb.child = nil - return -} - -func (bb *byteBuilder) finish() []byte { - bb.flush() - return *bb.buf -} - -func (bb *byteBuilder) addU8(u uint8) { - bb.flush() - *bb.buf = append(*bb.buf, u) -} - -func (bb *byteBuilder) addU16(u uint16) { - bb.flush() - *bb.buf = append(*bb.buf, byte(u>>8), byte(u)) -} - -func (bb *byteBuilder) addU24(u int) { - bb.flush() - *bb.buf = append(*bb.buf, byte(u>>16), byte(u>>8), byte(u)) -} - -func (bb *byteBuilder) addU32(u uint32) { - bb.flush() - *bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u)) -} - -func (bb *byteBuilder) addU64(u uint64) { - bb.flush() - var b [8]byte - binary.BigEndian.PutUint64(b[:], u) - *bb.buf = append(*bb.buf, b[:]...) -} - -func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder { - return bb.createChild(1) -} - -func (bb *byteBuilder) addU16LengthPrefixed() *byteBuilder { - return bb.createChild(2) -} - -func (bb *byteBuilder) addU24LengthPrefixed() *byteBuilder { - return bb.createChild(3) -} - -func (bb *byteBuilder) addU32LengthPrefixed() *byteBuilder { - return bb.createChild(4) -} - -func (bb *byteBuilder) addBytes(b []byte) { - bb.flush() - *bb.buf = append(*bb.buf, b...) -} - -func (bb *byteBuilder) createChild(lengthPrefixSize int) *byteBuilder { - bb.flush() - bb.child = &byteBuilder{ - buf: bb.buf, - start: len(*bb.buf), - prefixLen: lengthPrefixSize, - } - for i := 0; i < lengthPrefixSize; i++ { - *bb.buf = append(*bb.buf, 0) - } - return bb.child -} -func (bb *byteBuilder) discardChild() { - if bb.child == nil { - return - } - *bb.buf = (*bb.buf)[:bb.child.start] - bb.child = nil -} - -type byteReader []byte - -func (br *byteReader) readInternal(out *byteReader, n int) bool { - if len(*br) < n { - return false - } - *out = (*br)[:n] - *br = (*br)[n:] - return true -} - -func (br *byteReader) readBytes(out *[]byte, n int) bool { - var child byteReader - if !br.readInternal(&child, n) { - return false - } - *out = []byte(child) - return true -} - -func (br *byteReader) readUint(out *uint64, n int) bool { - var b []byte - if !br.readBytes(&b, n) { - return false - } - *out = 0 - for _, v := range b { - *out <<= 8 - *out |= uint64(v) - } - return true -} - -func (br *byteReader) readU8(out *uint8) bool { - var b []byte - if !br.readBytes(&b, 1) { - return false - } - *out = b[0] - return true -} + "golang.org/x/crypto/cryptobyte" +) -func (br *byteReader) readU16(out *uint16) bool { - var v uint64 - if !br.readUint(&v, 2) { +func readUint8LengthPrefixedBytes(s *cryptobyte.String, out *[]byte) bool { + var child cryptobyte.String + if !s.ReadUint8LengthPrefixed(&child) { return false } - *out = uint16(v) + *out = child return true } -func (br *byteReader) readU24(out *uint32) bool { - var v uint64 - if !br.readUint(&v, 3) { +func readUint16LengthPrefixedBytes(s *cryptobyte.String, out *[]byte) bool { + var child cryptobyte.String + if !s.ReadUint16LengthPrefixed(&child) { return false } - *out = uint32(v) + *out = child return true } -func (br *byteReader) readU32(out *uint32) bool { - var v uint64 - if !br.readUint(&v, 4) { +func readUint24LengthPrefixedBytes(s *cryptobyte.String, out *[]byte) bool { + var child cryptobyte.String + if !s.ReadUint24LengthPrefixed(&child) { return false } - *out = uint32(v) + *out = child return true } -func (br *byteReader) readU64(out *uint64) bool { - return br.readUint(out, 8) +func addUint8LengthPrefixedBytes(b *cryptobyte.Builder, v []byte) { + b.AddUint8LengthPrefixed(func(child *cryptobyte.Builder) { child.AddBytes(v) }) } -func (br *byteReader) readLengthPrefixed(out *byteReader, n int) bool { - var length uint64 - return br.readUint(&length, n) && - uint64(len(*br)) >= length && - br.readInternal(out, int(length)) +func addUint16LengthPrefixedBytes(b *cryptobyte.Builder, v []byte) { + b.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { child.AddBytes(v) }) } -func (br *byteReader) readLengthPrefixedBytes(out *[]byte, n int) bool { - var length uint64 - return br.readUint(&length, n) && - uint64(len(*br)) >= length && - br.readBytes(out, int(length)) -} - -func (br *byteReader) readU8LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 1) -} -func (br *byteReader) readU8LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 1) -} - -func (br *byteReader) readU16LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 2) -} -func (br *byteReader) readU16LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 2) -} - -func (br *byteReader) readU24LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 3) -} -func (br *byteReader) readU24LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 3) -} - -func (br *byteReader) readU32LengthPrefixed(out *byteReader) bool { - return br.readLengthPrefixed(out, 4) -} -func (br *byteReader) readU32LengthPrefixedBytes(out *[]byte) bool { - return br.readLengthPrefixedBytes(out, 4) +func addUint24LengthPrefixedBytes(b *cryptobyte.Builder, v []byte) { + b.AddUint24LengthPrefixed(func(child *cryptobyte.Builder) { child.AddBytes(v) }) } type keyShareEntry struct { @@ -269,48 +79,52 @@ type ECHConfig struct { } func CreateECHConfig(template *ECHConfig) *ECHConfig { - bb := newByteBuilder() + bb := cryptobyte.NewBuilder(nil) // ECHConfig reuses the encrypted_client_hello extension codepoint as a // version identifier. - bb.addU16(extensionEncryptedClientHello) - contents := bb.addU16LengthPrefixed() - contents.addU8(template.ConfigID) - contents.addU16(template.KEM) - contents.addU16LengthPrefixed().addBytes(template.PublicKey) - cipherSuites := contents.addU16LengthPrefixed() - for _, suite := range template.CipherSuites { - cipherSuites.addU16(suite.KDF) - cipherSuites.addU16(suite.AEAD) - } - contents.addU8(template.MaxNameLen) - contents.addU8LengthPrefixed().addBytes([]byte(template.PublicName)) - extensions := contents.addU16LengthPrefixed() - // Mandatory extensions have the high bit set. - if template.UnsupportedExtension { - extensions.addU16(0x1111) - extensions.addU16LengthPrefixed().addBytes([]byte("test")) - } - if template.UnsupportedMandatoryExtension { - extensions.addU16(0xaaaa) - extensions.addU16LengthPrefixed().addBytes([]byte("test")) - } - - // This ought to be a call to a function like ParseECHConfig(bb.finish()), + bb.AddUint16(extensionEncryptedClientHello) + bb.AddUint16LengthPrefixed(func(contents *cryptobyte.Builder) { + contents.AddUint8(template.ConfigID) + contents.AddUint16(template.KEM) + addUint16LengthPrefixedBytes(contents, template.PublicKey) + contents.AddUint16LengthPrefixed(func(cipherSuites *cryptobyte.Builder) { + for _, suite := range template.CipherSuites { + cipherSuites.AddUint16(suite.KDF) + cipherSuites.AddUint16(suite.AEAD) + } + }) + contents.AddUint8(template.MaxNameLen) + addUint8LengthPrefixedBytes(contents, []byte(template.PublicName)) + contents.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + // Mandatory extensions have the high bit set. + if template.UnsupportedExtension { + extensions.AddUint16(0x1111) + addUint16LengthPrefixedBytes(extensions, []byte("test")) + } + if template.UnsupportedMandatoryExtension { + extensions.AddUint16(0xaaaa) + addUint16LengthPrefixedBytes(extensions, []byte("test")) + } + }) + }) + + // This ought to be a call to a function like ParseECHConfig(bb.BytesOrPanic()), // but this constrains us to constructing ECHConfigs we are willing to // support. We need to test the client's behavior in response to unparsable // or unsupported ECHConfigs, so populate fields from the template directly. ret := *template - ret.Raw = bb.finish() + ret.Raw = bb.BytesOrPanic() return &ret } func CreateECHConfigList(configs ...[]byte) []byte { - bb := newByteBuilder() - list := bb.addU16LengthPrefixed() - for _, config := range configs { - list.addBytes(config) - } - return bb.finish() + bb := cryptobyte.NewBuilder(nil) + bb.AddUint16LengthPrefixed(func(list *cryptobyte.Builder) { + for _, config := range configs { + list.AddBytes(config) + } + }) + return bb.BytesOrPanic() } type ServerECHConfig struct { @@ -392,16 +206,16 @@ type clientHelloMsg struct { rawExtensions []byte } -func (m *clientHelloMsg) marshalKeyShares(bb *byteBuilder) { - keyShares := bb.addU16LengthPrefixed() - for _, keyShare := range m.keyShares { - keyShares.addU16(uint16(keyShare.group)) - keyExchange := keyShares.addU16LengthPrefixed() - keyExchange.addBytes(keyShare.keyExchange) - } - if m.trailingKeyShareData { - keyShares.addU8(0) - } +func (m *clientHelloMsg) marshalKeyShares(bb *cryptobyte.Builder) { + bb.AddUint16LengthPrefixed(func(keyShares *cryptobyte.Builder) { + for _, keyShare := range m.keyShares { + keyShares.AddUint16(uint16(keyShare.group)) + addUint16LengthPrefixedBytes(keyShares, keyShare.keyExchange) + } + if m.trailingKeyShareData { + keyShares.AddUint8(0) + } + }) } type clientHelloType int @@ -411,23 +225,27 @@ const ( clientHelloEncodedInner ) -func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { - hello.addU16(m.vers) - hello.addBytes(m.random) - sessionID := hello.addU8LengthPrefixed() - if typ != clientHelloEncodedInner { - sessionID.addBytes(m.sessionID) - } +func (m *clientHelloMsg) marshalBody(hello *cryptobyte.Builder, typ clientHelloType) { + hello.AddUint16(m.vers) + hello.AddBytes(m.random) + hello.AddUint8LengthPrefixed(func(sessionID *cryptobyte.Builder) { + if typ != clientHelloEncodedInner { + sessionID.AddBytes(m.sessionID) + } + }) if m.isDTLS { - cookie := hello.addU8LengthPrefixed() - cookie.addBytes(m.cookie) - } - cipherSuites := hello.addU16LengthPrefixed() - for _, suite := range m.cipherSuites { - cipherSuites.addU16(suite) + hello.AddUint8LengthPrefixed(func(cookie *cryptobyte.Builder) { + cookie.AddBytes(m.cookie) + }) } - compressionMethods := hello.addU8LengthPrefixed() - compressionMethods.addBytes(m.compressionMethods) + hello.AddUint16LengthPrefixed(func(cipherSuites *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + cipherSuites.AddUint16(suite) + } + }) + hello.AddUint8LengthPrefixed(func(compressionMethods *cryptobyte.Builder) { + compressionMethods.AddBytes(m.compressionMethods) + }) type extension struct { id uint16 @@ -462,99 +280,99 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { // ServerName server_name_list<1..2^16-1> // } ServerNameList; - serverNameList := newByteBuilder() - serverName := serverNameList.addU16LengthPrefixed() - serverName.addU8(0) // NameType host_name(0) - hostName := serverName.addU16LengthPrefixed() - hostName.addBytes([]byte(m.serverName)) + serverNameList := cryptobyte.NewBuilder(nil) + serverNameList.AddUint16LengthPrefixed(func(serverName *cryptobyte.Builder) { + serverName.AddUint8(0) // NameType host_name(0) + addUint16LengthPrefixedBytes(serverName, []byte(m.serverName)) + }) extensions = append(extensions, extension{ id: extensionServerName, - body: serverNameList.finish(), + body: serverNameList.BytesOrPanic(), }) } if m.echOuter != nil { - body := newByteBuilder() - body.addU8(echClientTypeOuter) - body.addU16(m.echOuter.kdfID) - body.addU16(m.echOuter.aeadID) - body.addU8(m.echOuter.configID) - body.addU16LengthPrefixed().addBytes(m.echOuter.enc) - body.addU16LengthPrefixed().addBytes(m.echOuter.payload) + body := cryptobyte.NewBuilder(nil) + body.AddUint8(echClientTypeOuter) + body.AddUint16(m.echOuter.kdfID) + body.AddUint16(m.echOuter.aeadID) + body.AddUint8(m.echOuter.configID) + addUint16LengthPrefixedBytes(body, m.echOuter.enc) + addUint16LengthPrefixedBytes(body, m.echOuter.payload) extensions = append(extensions, extension{ id: extensionEncryptedClientHello, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.echInner { - body := newByteBuilder() - body.addU8(echClientTypeInner) + body := cryptobyte.NewBuilder(nil) + body.AddUint8(echClientTypeInner) // If unset, invalidECHInner is empty, which is the correct serialization. - body.addBytes(m.invalidECHInner) + body.AddBytes(m.invalidECHInner) extensions = append(extensions, extension{ id: extensionEncryptedClientHello, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.ocspStapling { - certificateStatusRequest := newByteBuilder() + certificateStatusRequest := cryptobyte.NewBuilder(nil) // RFC 4366, section 3.6 - certificateStatusRequest.addU8(1) // OCSP type + certificateStatusRequest.AddUint8(1) // OCSP type // Two zero valued uint16s for the two lengths. - certificateStatusRequest.addU16(0) // ResponderID length - certificateStatusRequest.addU16(0) // Extensions length + certificateStatusRequest.AddUint16(0) // ResponderID length + certificateStatusRequest.AddUint16(0) // Extensions length extensions = append(extensions, extension{ id: extensionStatusRequest, - body: certificateStatusRequest.finish(), + body: certificateStatusRequest.BytesOrPanic(), }) } if len(m.supportedCurves) > 0 { // http://tools.ietf.org/html/rfc4492#section-5.1.1 - supportedCurvesList := newByteBuilder() - supportedCurves := supportedCurvesList.addU16LengthPrefixed() - for _, curve := range m.supportedCurves { - supportedCurves.addU16(uint16(curve)) - } + supportedCurvesList := cryptobyte.NewBuilder(nil) + supportedCurvesList.AddUint16LengthPrefixed(func(supportedCurves *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + supportedCurves.AddUint16(uint16(curve)) + } + }) extensions = append(extensions, extension{ id: extensionSupportedCurves, - body: supportedCurvesList.finish(), + body: supportedCurvesList.BytesOrPanic(), }) } if len(m.supportedPoints) > 0 { // http://tools.ietf.org/html/rfc4492#section-5.1.2 - supportedPointsList := newByteBuilder() - supportedPoints := supportedPointsList.addU8LengthPrefixed() - supportedPoints.addBytes(m.supportedPoints) + supportedPointsList := cryptobyte.NewBuilder(nil) + addUint8LengthPrefixedBytes(supportedPointsList, m.supportedPoints) extensions = append(extensions, extension{ id: extensionSupportedPoints, - body: supportedPointsList.finish(), + body: supportedPointsList.BytesOrPanic(), }) } if m.hasKeyShares { - keyShareList := newByteBuilder() + keyShareList := cryptobyte.NewBuilder(nil) m.marshalKeyShares(keyShareList) extensions = append(extensions, extension{ id: extensionKeyShare, - body: keyShareList.finish(), + body: keyShareList.BytesOrPanic(), }) } if len(m.pskKEModes) > 0 { - pskModesExtension := newByteBuilder() - pskModesExtension.addU8LengthPrefixed().addBytes(m.pskKEModes) + pskModesExtension := cryptobyte.NewBuilder(nil) + addUint8LengthPrefixedBytes(pskModesExtension, m.pskKEModes) extensions = append(extensions, extension{ id: extensionPSKKeyExchangeModes, - body: pskModesExtension.finish(), + body: pskModesExtension.BytesOrPanic(), }) } if m.hasEarlyData { extensions = append(extensions, extension{id: extensionEarlyData}) } if len(m.tls13Cookie) > 0 { - body := newByteBuilder() - body.addU16LengthPrefixed().addBytes(m.tls13Cookie) + body := cryptobyte.NewBuilder(nil) + addUint16LengthPrefixedBytes(body, m.tls13Cookie) extensions = append(extensions, extension{ id: extensionCookie, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.ticketSupported { @@ -566,57 +384,60 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { } if len(m.signatureAlgorithms) > 0 { // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 - signatureAlgorithmsExtension := newByteBuilder() - signatureAlgorithms := signatureAlgorithmsExtension.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureAlgorithms.addU16(uint16(sigAlg)) - } + signatureAlgorithmsExtension := cryptobyte.NewBuilder(nil) + signatureAlgorithmsExtension.AddUint16LengthPrefixed(func(signatureAlgorithms *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureAlgorithms.AddUint16(uint16(sigAlg)) + } + }) extensions = append(extensions, extension{ id: extensionSignatureAlgorithms, - body: signatureAlgorithmsExtension.finish(), + body: signatureAlgorithmsExtension.BytesOrPanic(), }) } if len(m.signatureAlgorithmsCert) > 0 { - signatureAlgorithmsCertExtension := newByteBuilder() - signatureAlgorithmsCert := signatureAlgorithmsCertExtension.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithmsCert { - signatureAlgorithmsCert.addU16(uint16(sigAlg)) - } + signatureAlgorithmsCertExtension := cryptobyte.NewBuilder(nil) + signatureAlgorithmsCertExtension.AddUint16LengthPrefixed(func(signatureAlgorithmsCert *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithmsCert { + signatureAlgorithmsCert.AddUint16(uint16(sigAlg)) + } + }) extensions = append(extensions, extension{ id: extensionSignatureAlgorithmsCert, - body: signatureAlgorithmsCertExtension.finish(), + body: signatureAlgorithmsCertExtension.BytesOrPanic(), }) } if len(m.supportedVersions) > 0 { - supportedVersionsExtension := newByteBuilder() - supportedVersions := supportedVersionsExtension.addU8LengthPrefixed() - for _, version := range m.supportedVersions { - supportedVersions.addU16(uint16(version)) - } + supportedVersionsExtension := cryptobyte.NewBuilder(nil) + supportedVersionsExtension.AddUint8LengthPrefixed(func(supportedVersions *cryptobyte.Builder) { + for _, version := range m.supportedVersions { + supportedVersions.AddUint16(uint16(version)) + } + }) extensions = append(extensions, extension{ id: extensionSupportedVersions, - body: supportedVersionsExtension.finish(), + body: supportedVersionsExtension.BytesOrPanic(), }) } if m.secureRenegotiation != nil { - secureRenegoExt := newByteBuilder() - secureRenegoExt.addU8LengthPrefixed().addBytes(m.secureRenegotiation) + secureRenegoExt := cryptobyte.NewBuilder(nil) + addUint8LengthPrefixedBytes(secureRenegoExt, m.secureRenegotiation) extensions = append(extensions, extension{ id: extensionRenegotiationInfo, - body: secureRenegoExt.finish(), + body: secureRenegoExt.BytesOrPanic(), }) } if len(m.alpnProtocols) > 0 { // https://tools.ietf.org/html/rfc7301#section-3.1 - alpnExtension := newByteBuilder() - protocolNameList := alpnExtension.addU16LengthPrefixed() - for _, s := range m.alpnProtocols { - protocolName := protocolNameList.addU8LengthPrefixed() - protocolName.addBytes([]byte(s)) - } + alpnExtension := cryptobyte.NewBuilder(nil) + alpnExtension.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + for _, s := range m.alpnProtocols { + addUint8LengthPrefixedBytes(protocolNameList, []byte(s)) + } + }) extensions = append(extensions, extension{ id: extensionALPN, - body: alpnExtension.finish(), + body: alpnExtension.BytesOrPanic(), }) } if len(m.quicTransportParams) > 0 { @@ -644,18 +465,18 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { } if len(m.srtpProtectionProfiles) > 0 { // https://tools.ietf.org/html/rfc5764#section-4.1.1 - useSrtpExt := newByteBuilder() + useSrtpExt := cryptobyte.NewBuilder(nil) - srtpProtectionProfiles := useSrtpExt.addU16LengthPrefixed() - for _, p := range m.srtpProtectionProfiles { - srtpProtectionProfiles.addU16(p) - } - srtpMki := useSrtpExt.addU8LengthPrefixed() - srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier)) + useSrtpExt.AddUint16LengthPrefixed(func(srtpProtectionProfiles *cryptobyte.Builder) { + for _, p := range m.srtpProtectionProfiles { + srtpProtectionProfiles.AddUint16(p) + } + }) + addUint8LengthPrefixedBytes(useSrtpExt, []byte(m.srtpMasterKeyIdentifier)) extensions = append(extensions, extension{ id: extensionUseSRTP, - body: useSrtpExt.finish(), + body: useSrtpExt.BytesOrPanic(), }) } if m.sctListSupported { @@ -668,130 +489,138 @@ func (m *clientHelloMsg) marshalBody(hello *byteBuilder, typ clientHelloType) { }) } if len(m.compressedCertAlgs) > 0 { - body := newByteBuilder() - algIDs := body.addU8LengthPrefixed() - for _, v := range m.compressedCertAlgs { - algIDs.addU16(v) - } + body := cryptobyte.NewBuilder(nil) + body.AddUint8LengthPrefixed(func(algIDs *cryptobyte.Builder) { + for _, v := range m.compressedCertAlgs { + algIDs.AddUint16(v) + } + }) extensions = append(extensions, extension{ id: extensionCompressedCertAlgs, - body: body.finish(), + body: body.BytesOrPanic(), }) } if m.delegatedCredentials { - body := newByteBuilder() - signatureSchemeList := body.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureSchemeList.addU16(uint16(sigAlg)) - } + body := cryptobyte.NewBuilder(nil) + body.AddUint16LengthPrefixed(func(signatureSchemeList *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureSchemeList.AddUint16(uint16(sigAlg)) + } + }) extensions = append(extensions, extension{ id: extensionDelegatedCredentials, - body: body.finish(), + body: body.BytesOrPanic(), }) } if len(m.alpsProtocols) > 0 { - body := newByteBuilder() - protocolNameList := body.addU16LengthPrefixed() - for _, s := range m.alpsProtocols { - protocolNameList.addU8LengthPrefixed().addBytes([]byte(s)) - } + body := cryptobyte.NewBuilder(nil) + body.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + for _, s := range m.alpsProtocols { + addUint8LengthPrefixedBytes(protocolNameList, []byte(s)) + } + }) extensions = append(extensions, extension{ id: extensionApplicationSettings, - body: body.finish(), + body: body.BytesOrPanic(), }) } // The PSK extension must be last. See https://tools.ietf.org/html/rfc8446#section-4.2.11 if len(m.pskIdentities) > 0 { - pskExtension := newByteBuilder() - pskIdentities := pskExtension.addU16LengthPrefixed() - for _, psk := range m.pskIdentities { - pskIdentities.addU16LengthPrefixed().addBytes(psk.ticket) - pskIdentities.addU32(psk.obfuscatedTicketAge) - } - pskBinders := pskExtension.addU16LengthPrefixed() - for _, binder := range m.pskBinders { - pskBinders.addU8LengthPrefixed().addBytes(binder) - } + pskExtension := cryptobyte.NewBuilder(nil) + pskExtension.AddUint16LengthPrefixed(func(pskIdentities *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + addUint16LengthPrefixedBytes(pskIdentities, psk.ticket) + pskIdentities.AddUint32(psk.obfuscatedTicketAge) + } + }) + pskExtension.AddUint16LengthPrefixed(func(pskBinders *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + addUint8LengthPrefixedBytes(pskBinders, binder) + } + }) extensions = append(extensions, extension{ id: extensionPreSharedKey, - body: pskExtension.finish(), + body: pskExtension.BytesOrPanic(), }) } - extensionsBB := hello.addU16LengthPrefixed() - extMap := make(map[uint16][]byte) - extsWritten := make(map[uint16]struct{}) - for _, ext := range extensions { - extMap[ext.id] = ext.body - } - // Write each of the prefix extensions, if we have it. - for _, extID := range m.prefixExtensions { - if body, ok := extMap[extID]; ok { - extensionsBB.addU16(extID) - extensionsBB.addU16LengthPrefixed().addBytes(body) - extsWritten[extID] = struct{}{} - } + if m.omitExtensions { + return } - // Write outer extensions, possibly in compressed form. - if m.outerExtensions != nil { - if typ == clientHelloEncodedInner && !m.reorderOuterExtensionsWithoutCompressing { - extensionsBB.addU16(extensionECHOuterExtensions) - list := extensionsBB.addU16LengthPrefixed().addU8LengthPrefixed() - for _, extID := range m.outerExtensions { - list.addU16(extID) + hello.AddUint16LengthPrefixed(func(extensionsBB *cryptobyte.Builder) { + if m.emptyExtensions { + return + } + extMap := make(map[uint16][]byte) + extsWritten := make(map[uint16]struct{}) + for _, ext := range extensions { + extMap[ext.id] = ext.body + } + // Write each of the prefix extensions, if we have it. + for _, extID := range m.prefixExtensions { + if body, ok := extMap[extID]; ok { + extensionsBB.AddUint16(extID) + addUint16LengthPrefixedBytes(extensionsBB, body) extsWritten[extID] = struct{}{} } - } else { - for _, extID := range m.outerExtensions { - // m.outerExtensions may intentionally contain duplicates to test the - // server's reaction. If m.reorderOuterExtensionsWithoutCompressing - // is set, we are targetting the second ClientHello and wish to send a - // valid first ClientHello. In that case, deduplicate so the error - // only appears later. - if _, written := extsWritten[extID]; m.reorderOuterExtensionsWithoutCompressing && written { - continue - } - if body, ok := extMap[extID]; ok { - extensionsBB.addU16(extID) - extensionsBB.addU16LengthPrefixed().addBytes(body) - extsWritten[extID] = struct{}{} + } + // Write outer extensions, possibly in compressed form. + if m.outerExtensions != nil { + if typ == clientHelloEncodedInner && !m.reorderOuterExtensionsWithoutCompressing { + extensionsBB.AddUint16(extensionECHOuterExtensions) + extensionsBB.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddUint8LengthPrefixed(func(list *cryptobyte.Builder) { + for _, extID := range m.outerExtensions { + list.AddUint16(extID) + extsWritten[extID] = struct{}{} + } + }) + }) + } else { + for _, extID := range m.outerExtensions { + // m.outerExtensions may intentionally contain duplicates to test the + // server's reaction. If m.reorderOuterExtensionsWithoutCompressing + // is set, we are targetting the second ClientHello and wish to send a + // valid first ClientHello. In that case, deduplicate so the error + // only appears later. + if _, written := extsWritten[extID]; m.reorderOuterExtensionsWithoutCompressing && written { + continue + } + if body, ok := extMap[extID]; ok { + extensionsBB.AddUint16(extID) + addUint16LengthPrefixedBytes(extensionsBB, body) + extsWritten[extID] = struct{}{} + } } } } - } - - // Write each of the remaining extensions in their original order. - for _, ext := range extensions { - if _, written := extsWritten[ext.id]; !written { - extensionsBB.addU16(ext.id) - extensionsBB.addU16LengthPrefixed().addBytes(ext.body) - } - } - if m.pad != 0 && hello.len()%m.pad != 0 { - extensionsBB.addU16(extensionPadding) - padding := extensionsBB.addU16LengthPrefixed() - // Note hello.len() has changed at this point from the length - // prefix. - if l := hello.len() % m.pad; l != 0 { - padding.addBytes(make([]byte, m.pad-l)) + // Write each of the remaining extensions in their original order. + for _, ext := range extensions { + if _, written := extsWritten[ext.id]; !written { + extensionsBB.AddUint16(ext.id) + addUint16LengthPrefixedBytes(extensionsBB, ext.body) + } } - } - if m.omitExtensions || m.emptyExtensions { - // Silently erase any extensions which were sent. - hello.discardChild() - if m.emptyExtensions { - hello.addU16(0) + if m.pad != 0 && len(hello.BytesOrPanic())%m.pad != 0 { + extensionsBB.AddUint16(extensionPadding) + extensionsBB.AddUint16LengthPrefixed(func(padding *cryptobyte.Builder) { + // Note hello.len() has changed at this point from the length + // prefix. + if l := len(hello.BytesOrPanic()) % m.pad; l != 0 { + padding.AddBytes(make([]byte, m.pad-l)) + } + }) } - } + }) } func (m *clientHelloMsg) marshalForEncodedInner() []byte { - hello := newByteBuilder() + hello := cryptobyte.NewBuilder(nil) m.marshalBody(hello, clientHelloEncodedInner) - return hello.finish() + return hello.BytesOrPanic() } func (m *clientHelloMsg) marshal() []byte { @@ -800,26 +629,27 @@ func (m *clientHelloMsg) marshal() []byte { } if m.isV2ClientHello { - v2Msg := newByteBuilder() - v2Msg.addU8(1) - v2Msg.addU16(m.vers) - v2Msg.addU16(uint16(len(m.cipherSuites) * 3)) - v2Msg.addU16(uint16(len(m.sessionID))) - v2Msg.addU16(uint16(len(m.v2Challenge))) + v2Msg := cryptobyte.NewBuilder(nil) + v2Msg.AddUint8(1) + v2Msg.AddUint16(m.vers) + v2Msg.AddUint16(uint16(len(m.cipherSuites) * 3)) + v2Msg.AddUint16(uint16(len(m.sessionID))) + v2Msg.AddUint16(uint16(len(m.v2Challenge))) for _, spec := range m.cipherSuites { - v2Msg.addU24(int(spec)) + v2Msg.AddUint24(uint32(spec)) } - v2Msg.addBytes(m.sessionID) - v2Msg.addBytes(m.v2Challenge) - m.raw = v2Msg.finish() + v2Msg.AddBytes(m.sessionID) + v2Msg.AddBytes(m.v2Challenge) + m.raw = v2Msg.BytesOrPanic() return m.raw } - handshakeMsg := newByteBuilder() - handshakeMsg.addU8(typeClientHello) - hello := handshakeMsg.addU24LengthPrefixed() - m.marshalBody(hello, clientHelloNormal) - m.raw = handshakeMsg.finish() + handshakeMsg := cryptobyte.NewBuilder(nil) + handshakeMsg.AddUint8(typeClientHello) + handshakeMsg.AddUint24LengthPrefixed(func(hello *cryptobyte.Builder) { + m.marshalBody(hello, clientHelloNormal) + }) + m.raw = handshakeMsg.BytesOrPanic() // Sanity-check padding. if m.pad != 0 && (len(m.raw)-4)%m.pad != 0 { panic(fmt.Sprintf("%d is not a multiple of %d", len(m.raw)-4, m.pad)) @@ -827,9 +657,9 @@ func (m *clientHelloMsg) marshal() []byte { return m.raw } -func parseSignatureAlgorithms(reader *byteReader, out *[]signatureAlgorithm, allowEmpty bool) bool { - var sigAlgs byteReader - if !reader.readU16LengthPrefixed(&sigAlgs) { +func parseSignatureAlgorithms(reader *cryptobyte.String, out *[]signatureAlgorithm, allowEmpty bool) bool { + var sigAlgs cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&sigAlgs) { return false } if !allowEmpty && len(sigAlgs) == 0 { @@ -838,7 +668,7 @@ func parseSignatureAlgorithms(reader *byteReader, out *[]signatureAlgorithm, all *out = make([]signatureAlgorithm, 0, len(sigAlgs)/2) for len(sigAlgs) > 0 { var v uint16 - if !sigAlgs.readU16(&v) { + if !sigAlgs.ReadUint16(&v) { return false } if signatureAlgorithm(v) == signatureRSAPKCS1WithMD5AndSHA1 { @@ -852,13 +682,13 @@ func parseSignatureAlgorithms(reader *byteReader, out *[]signatureAlgorithm, all return true } -func checkDuplicateExtensions(extensions byteReader) bool { +func checkDuplicateExtensions(extensions cryptobyte.String) bool { seen := make(map[uint16]struct{}) for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } if _, ok := seen[extension]; ok { @@ -871,26 +701,26 @@ func checkDuplicateExtensions(extensions byteReader) bool { func (m *clientHelloMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - if !reader.readU16(&m.vers) || - !reader.readBytes(&m.random, 32) || - !reader.readU8LengthPrefixedBytes(&m.sessionID) || + reader := cryptobyte.String(data[4:]) + if !reader.ReadUint16(&m.vers) || + !reader.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixedBytes(&reader, &m.sessionID) || len(m.sessionID) > 32 { return false } - if m.isDTLS && !reader.readU8LengthPrefixedBytes(&m.cookie) { + if m.isDTLS && !readUint8LengthPrefixedBytes(&reader, &m.cookie) { return false } - var cipherSuites byteReader - if !reader.readU16LengthPrefixed(&cipherSuites) || - !reader.readU8LengthPrefixedBytes(&m.compressionMethods) { + var cipherSuites cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&cipherSuites) || + !readUint8LengthPrefixedBytes(&reader, &m.compressionMethods) { return false } m.cipherSuites = make([]uint16, 0, len(cipherSuites)/2) for len(cipherSuites) > 0 { var v uint16 - if !cipherSuites.readU16(&v) { + if !cipherSuites.ReadUint16(&v) { return false } m.cipherSuites = append(m.cipherSuites, v) @@ -921,29 +751,29 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return true } - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { return false } m.rawExtensions = extensions for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionServerName: - var names byteReader - if !body.readU16LengthPrefixed(&names) || len(body) != 0 { + var names cryptobyte.String + if !body.ReadUint16LengthPrefixed(&names) || len(body) != 0 { return false } for len(names) > 0 { var nameType byte var name []byte - if !names.readU8(&nameType) || - !names.readU16LengthPrefixedBytes(&name) { + if !names.ReadUint8(&nameType) || + !readUint16LengthPrefixedBytes(&names, &name) { return false } if nameType == 0 { @@ -952,17 +782,17 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } case extensionEncryptedClientHello: var typ byte - if !body.readU8(&typ) { + if !body.ReadUint8(&typ) { return false } switch typ { case echClientTypeOuter: var echOuter echClientOuter - if !body.readU16(&echOuter.kdfID) || - !body.readU16(&echOuter.aeadID) || - !body.readU8(&echOuter.configID) || - !body.readU16LengthPrefixedBytes(&echOuter.enc) || - !body.readU16LengthPrefixedBytes(&echOuter.payload) || + if !body.ReadUint16(&echOuter.kdfID) || + !body.ReadUint16(&echOuter.aeadID) || + !body.ReadUint8(&echOuter.configID) || + !readUint16LengthPrefixedBytes(&body, &echOuter.enc) || + !readUint16LengthPrefixedBytes(&body, &echOuter.payload) || len(echOuter.payload) == 0 || len(body) > 0 { return false @@ -989,11 +819,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { // extensibility, but we expect our client to only send empty // requests of type OCSP. var statusType uint8 - var responderIDList, innerExtensions byteReader - if !body.readU8(&statusType) || + var responderIDList, innerExtensions cryptobyte.String + if !body.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !body.readU16LengthPrefixed(&responderIDList) || - !body.readU16LengthPrefixed(&innerExtensions) || + !body.ReadUint16LengthPrefixed(&responderIDList) || + !body.ReadUint16LengthPrefixed(&innerExtensions) || len(responderIDList) != 0 || len(innerExtensions) != 0 || len(body) != 0 { @@ -1002,21 +832,21 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.ocspStapling = true case extensionSupportedCurves: // http://tools.ietf.org/html/rfc4492#section-5.5.1 - var curves byteReader - if !body.readU16LengthPrefixed(&curves) || len(body) != 0 { + var curves cryptobyte.String + if !body.ReadUint16LengthPrefixed(&curves) || len(body) != 0 { return false } m.supportedCurves = make([]CurveID, 0, len(curves)/2) for len(curves) > 0 { var v uint16 - if !curves.readU16(&v) { + if !curves.ReadUint16(&v) { return false } m.supportedCurves = append(m.supportedCurves, CurveID(v)) } case extensionSupportedPoints: // http://tools.ietf.org/html/rfc4492#section-5.1.2 - if !body.readU8LengthPrefixedBytes(&m.supportedPoints) || len(m.supportedPoints) == 0 || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.supportedPoints) || len(m.supportedPoints) == 0 || len(body) != 0 { return false } case extensionSessionTicket: @@ -1027,15 +857,15 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { // https://tools.ietf.org/html/rfc8446#section-4.2.8 m.hasKeyShares = true m.keySharesRaw = body - var keyShares byteReader - if !body.readU16LengthPrefixed(&keyShares) || len(body) != 0 { + var keyShares cryptobyte.String + if !body.ReadUint16LengthPrefixed(&keyShares) || len(body) != 0 { return false } for len(keyShares) > 0 { var entry keyShareEntry var group uint16 - if !keyShares.readU16(&group) || - !keyShares.readU16LengthPrefixedBytes(&entry.keyExchange) { + if !keyShares.ReadUint16(&group) || + !readUint16LengthPrefixedBytes(&keyShares, &entry.keyExchange) { return false } entry.group = CurveID(group) @@ -1043,23 +873,23 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } case extensionPreSharedKey: // https://tools.ietf.org/html/rfc8446#section-4.2.11 - var psks, binders byteReader - if !body.readU16LengthPrefixed(&psks) || - !body.readU16LengthPrefixed(&binders) || + var psks, binders cryptobyte.String + if !body.ReadUint16LengthPrefixed(&psks) || + !body.ReadUint16LengthPrefixed(&binders) || len(body) != 0 { return false } for len(psks) > 0 { var psk pskIdentity - if !psks.readU16LengthPrefixedBytes(&psk.ticket) || - !psks.readU32(&psk.obfuscatedTicketAge) { + if !readUint16LengthPrefixedBytes(&psks, &psk.ticket) || + !psks.ReadUint32(&psk.obfuscatedTicketAge) { return false } m.pskIdentities = append(m.pskIdentities, psk) } for len(binders) > 0 { var binder []byte - if !binders.readU8LengthPrefixedBytes(&binder) { + if !readUint8LengthPrefixedBytes(&binders, &binder) { return false } m.pskBinders = append(m.pskBinders, binder) @@ -1071,7 +901,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } case extensionPSKKeyExchangeModes: // https://tools.ietf.org/html/rfc8446#section-4.2.9 - if !body.readU8LengthPrefixedBytes(&m.pskKEModes) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.pskKEModes) || len(body) != 0 { return false } case extensionEarlyData: @@ -1081,7 +911,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.hasEarlyData = true case extensionCookie: - if !body.readU16LengthPrefixedBytes(&m.tls13Cookie) || len(body) != 0 { + if !readUint16LengthPrefixedBytes(&body, &m.tls13Cookie) || len(body) != 0 { return false } case extensionSignatureAlgorithms: @@ -1094,30 +924,30 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } case extensionSupportedVersions: - var versions byteReader - if !body.readU8LengthPrefixed(&versions) || len(body) != 0 { + var versions cryptobyte.String + if !body.ReadUint8LengthPrefixed(&versions) || len(body) != 0 { return false } m.supportedVersions = make([]uint16, 0, len(versions)/2) for len(versions) > 0 { var v uint16 - if !versions.readU16(&v) { + if !versions.ReadUint16(&v) { return false } m.supportedVersions = append(m.supportedVersions, v) } case extensionRenegotiationInfo: - if !body.readU8LengthPrefixedBytes(&m.secureRenegotiation) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.secureRenegotiation) || len(body) != 0 { return false } case extensionALPN: - var protocols byteReader - if !body.readU16LengthPrefixed(&protocols) || len(body) != 0 { + var protocols cryptobyte.String + if !body.ReadUint16LengthPrefixed(&protocols) || len(body) != 0 { return false } for len(protocols) > 0 { var protocol []byte - if !protocols.readU8LengthPrefixedBytes(&protocol) || len(protocol) == 0 { + if !readUint8LengthPrefixedBytes(&protocols, &protocol) || len(protocol) == 0 { return false } m.alpnProtocols = append(m.alpnProtocols, string(protocol)) @@ -1137,17 +967,17 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.extendedMasterSecret = true case extensionUseSRTP: - var profiles byteReader + var profiles cryptobyte.String var mki []byte - if !body.readU16LengthPrefixed(&profiles) || - !body.readU8LengthPrefixedBytes(&mki) || + if !body.ReadUint16LengthPrefixed(&profiles) || + !readUint8LengthPrefixedBytes(&body, &mki) || len(body) != 0 { return false } m.srtpProtectionProfiles = make([]uint16, 0, len(profiles)/2) for len(profiles) > 0 { var v uint16 - if !profiles.readU16(&v) { + if !profiles.ReadUint16(&v) { return false } m.srtpProtectionProfiles = append(m.srtpProtectionProfiles, v) @@ -1161,15 +991,15 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { case extensionCustom: m.customExtension = string(body) case extensionCompressedCertAlgs: - var algIDs byteReader - if !body.readU8LengthPrefixed(&algIDs) { + var algIDs cryptobyte.String + if !body.ReadUint8LengthPrefixed(&algIDs) { return false } seen := make(map[uint16]struct{}) for len(algIDs) > 0 { var algID uint16 - if !algIDs.readU16(&algID) { + if !algIDs.ReadUint16(&algID) { return false } if _, ok := seen[algID]; ok { @@ -1191,13 +1021,13 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.delegatedCredentials = true case extensionApplicationSettings: - var protocols byteReader - if !body.readU16LengthPrefixed(&protocols) || len(body) != 0 { + var protocols cryptobyte.String + if !body.ReadUint16LengthPrefixed(&protocols) || len(body) != 0 { return false } for len(protocols) > 0 { var protocol []byte - if !protocols.readU8LengthPrefixedBytes(&protocol) || len(protocol) == 0 { + if !readUint8LengthPrefixedBytes(&protocols, &protocol) || len(protocol) == 0 { return false } m.alpsProtocols = append(m.alpsProtocols, string(protocol)) @@ -1213,15 +1043,15 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } func decodeClientHelloInner(config *Config, encoded []byte, helloOuter *clientHelloMsg) (*clientHelloMsg, error) { - reader := byteReader(encoded) + reader := cryptobyte.String(encoded) var versAndRandom, sessionID, cipherSuites, compressionMethods []byte - var extensions byteReader - if !reader.readBytes(&versAndRandom, 2+32) || - !reader.readU8LengthPrefixedBytes(&sessionID) || + var extensions cryptobyte.String + if !reader.ReadBytes(&versAndRandom, 2+32) || + !readUint8LengthPrefixedBytes(&reader, &sessionID) || len(sessionID) != 0 || // Copied from |helloOuter| - !reader.readU16LengthPrefixedBytes(&cipherSuites) || - !reader.readU8LengthPrefixedBytes(&compressionMethods) || - !reader.readU16LengthPrefixed(&extensions) { + !readUint16LengthPrefixedBytes(&reader, &cipherSuites) || + !readUint8LengthPrefixedBytes(&reader, &compressionMethods) || + !reader.ReadUint16LengthPrefixed(&extensions) { return nil, errors.New("tls: error parsing EncodedClientHelloInner") } @@ -1232,64 +1062,77 @@ func decodeClientHelloInner(config *Config, encoded []byte, helloOuter *clientHe } } - builder := newByteBuilder() - builder.addU8(typeClientHello) - body := builder.addU24LengthPrefixed() - body.addBytes(versAndRandom) - body.addU8LengthPrefixed().addBytes(helloOuter.sessionID) - body.addU16LengthPrefixed().addBytes(cipherSuites) - body.addU8LengthPrefixed().addBytes(compressionMethods) - newExtensions := body.addU16LengthPrefixed() - - var seenOuterExtensions bool - outerExtensions := byteReader(helloOuter.rawExtensions) copied := make(map[uint16]struct{}) - for len(extensions) > 0 { - var extType uint16 - var extBody byteReader - if !extensions.readU16(&extType) || - !extensions.readU16LengthPrefixed(&extBody) { - return nil, errors.New("tls: error parsing EncodedClientHelloInner") - } - if extType != extensionECHOuterExtensions { - newExtensions.addU16(extType) - newExtensions.addU16LengthPrefixed().addBytes(extBody) - continue - } - if seenOuterExtensions { - return nil, errors.New("tls: duplicate ech_outer_extensions extension") - } - seenOuterExtensions = true - var extList byteReader - if !extBody.readU8LengthPrefixed(&extList) || len(extList) == 0 || len(extBody) != 0 { - return nil, errors.New("tls: error parsing ech_outer_extensions") - } - for len(extList) != 0 { - var newExtType uint16 - if !extList.readU16(&newExtType) { - return nil, errors.New("tls: error parsing ech_outer_extensions") - } - if newExtType == extensionEncryptedClientHello { - return nil, errors.New("tls: error parsing ech_outer_extensions") - } - for { - if len(outerExtensions) == 0 { - return nil, fmt.Errorf("tls: extension %d not found in ClientHelloOuter", newExtType) + builder := cryptobyte.NewBuilder(nil) + builder.AddUint8(typeClientHello) + builder.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddBytes(versAndRandom) + addUint8LengthPrefixedBytes(body, helloOuter.sessionID) + addUint16LengthPrefixedBytes(body, cipherSuites) + addUint8LengthPrefixedBytes(body, compressionMethods) + body.AddUint16LengthPrefixed(func(newExtensions *cryptobyte.Builder) { + var seenOuterExtensions bool + outerExtensions := cryptobyte.String(helloOuter.rawExtensions) + for len(extensions) > 0 { + var extType uint16 + var extBody cryptobyte.String + if !extensions.ReadUint16(&extType) || + !extensions.ReadUint16LengthPrefixed(&extBody) { + newExtensions.SetError(errors.New("tls: error parsing EncodedClientHelloInner")) + return + } + if extType != extensionECHOuterExtensions { + newExtensions.AddUint16(extType) + addUint16LengthPrefixedBytes(newExtensions, extBody) + continue } - var foundExt uint16 - var newExtBody []byte - if !outerExtensions.readU16(&foundExt) || - !outerExtensions.readU16LengthPrefixedBytes(&newExtBody) { - return nil, errors.New("tls: error parsing ClientHelloOuter") + if seenOuterExtensions { + newExtensions.SetError(errors.New("tls: duplicate ech_outer_extensions extension")) + return } - if foundExt == newExtType { - newExtensions.addU16(newExtType) - newExtensions.addU16LengthPrefixed().addBytes(newExtBody) - copied[newExtType] = struct{}{} - break + seenOuterExtensions = true + var extList cryptobyte.String + if !extBody.ReadUint8LengthPrefixed(&extList) || len(extList) == 0 || len(extBody) != 0 { + newExtensions.SetError(errors.New("tls: error parsing ech_outer_extensions")) + return + } + for len(extList) != 0 { + var newExtType uint16 + if !extList.ReadUint16(&newExtType) { + newExtensions.SetError(errors.New("tls: error parsing ech_outer_extensions")) + return + } + if newExtType == extensionEncryptedClientHello { + newExtensions.SetError(errors.New("tls: error parsing ech_outer_extensions")) + return + } + for { + if len(outerExtensions) == 0 { + newExtensions.SetError(fmt.Errorf("tls: extension %d not found in ClientHelloOuter", newExtType)) + return + } + var foundExt uint16 + var newExtBody []byte + if !outerExtensions.ReadUint16(&foundExt) || + !readUint16LengthPrefixedBytes(&outerExtensions, &newExtBody) { + newExtensions.SetError(errors.New("tls: error parsing ClientHelloOuter")) + return + } + if foundExt == newExtType { + newExtensions.AddUint16(newExtType) + addUint16LengthPrefixedBytes(newExtensions, newExtBody) + copied[newExtType] = struct{}{} + break + } + } } } - } + }) + }) + + bytes, err := builder.Bytes() + if err != nil { + return nil, err } for _, expected := range config.Bugs.ExpectECHOuterExtensions { @@ -1304,9 +1147,10 @@ func decodeClientHelloInner(config *Config, encoded []byte, helloOuter *clientHe } ret := new(clientHelloMsg) - if !ret.unmarshal(builder.finish()) { + if !ret.unmarshal(bytes) { return nil, errors.New("tls: error parsing reconstructed ClientHello") } + return ret, nil } @@ -1337,102 +1181,100 @@ func (m *serverHelloMsg) marshal() []byte { return m.raw } - handshakeMsg := newByteBuilder() - handshakeMsg.addU8(typeServerHello) - hello := handshakeMsg.addU24LengthPrefixed() - - // m.vers is used both to determine the format of the rest of the - // ServerHello and to override the value, so include a second version - // field. - vers, ok := wireToVersion(m.vers, m.isDTLS) - if !ok { - panic("unknown version") - } - if m.versOverride != 0 { - hello.addU16(m.versOverride) - } else if vers >= VersionTLS13 { - hello.addU16(VersionTLS12) - } else { - hello.addU16(m.vers) - } - - hello.addBytes(m.random) - sessionID := hello.addU8LengthPrefixed() - sessionID.addBytes(m.sessionID) - hello.addU16(m.cipherSuite) - hello.addU8(m.compressionMethod) - - extensions := hello.addU16LengthPrefixed() - - if vers >= VersionTLS13 { - if m.hasKeyShare { - extensions.addU16(extensionKeyShare) - keyShare := extensions.addU16LengthPrefixed() - keyShare.addU16(uint16(m.keyShare.group)) - keyExchange := keyShare.addU16LengthPrefixed() - keyExchange.addBytes(m.keyShare.keyExchange) - } - if m.hasPSKIdentity { - extensions.addU16(extensionPreSharedKey) - extensions.addU16(2) // Length - extensions.addU16(m.pskIdentity) - } - if !m.omitSupportedVers { - extensions.addU16(extensionSupportedVersions) - extensions.addU16(2) // Length - if m.supportedVersOverride != 0 { - extensions.addU16(m.supportedVersOverride) + handshakeMsg := cryptobyte.NewBuilder(nil) + handshakeMsg.AddUint8(typeServerHello) + handshakeMsg.AddUint24LengthPrefixed(func(hello *cryptobyte.Builder) { + // m.vers is used both to determine the format of the rest of the + // ServerHello and to override the value, so include a second version + // field. + vers, ok := wireToVersion(m.vers, m.isDTLS) + if !ok { + panic("unknown version") + } + if m.versOverride != 0 { + hello.AddUint16(m.versOverride) + } else if vers >= VersionTLS13 { + hello.AddUint16(VersionTLS12) + } else { + hello.AddUint16(m.vers) + } + + hello.AddBytes(m.random) + addUint8LengthPrefixedBytes(hello, m.sessionID) + hello.AddUint16(m.cipherSuite) + hello.AddUint8(m.compressionMethod) + + hello.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if vers >= VersionTLS13 { + if m.hasKeyShare { + extensions.AddUint16(extensionKeyShare) + extensions.AddUint16LengthPrefixed(func(keyShare *cryptobyte.Builder) { + keyShare.AddUint16(uint16(m.keyShare.group)) + addUint16LengthPrefixedBytes(keyShare, m.keyShare.keyExchange) + }) + } + if m.hasPSKIdentity { + extensions.AddUint16(extensionPreSharedKey) + extensions.AddUint16(2) // Length + extensions.AddUint16(m.pskIdentity) + } + if !m.omitSupportedVers { + extensions.AddUint16(extensionSupportedVersions) + extensions.AddUint16(2) // Length + if m.supportedVersOverride != 0 { + extensions.AddUint16(m.supportedVersOverride) + } else { + extensions.AddUint16(m.vers) + } + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) + } + if len(m.unencryptedALPN) > 0 { + extensions.AddUint16(extensionALPN) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(protocolNameList, []byte(m.unencryptedALPN)) + }) + }) + } } else { - extensions.addU16(m.vers) - } - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - customExt := extensions.addU16LengthPrefixed() - customExt.addBytes([]byte(m.customExtension)) - } - if len(m.unencryptedALPN) > 0 { - extensions.addU16(extensionALPN) - extension := extensions.addU16LengthPrefixed() - - protocolNameList := extension.addU16LengthPrefixed() - protocolName := protocolNameList.addU8LengthPrefixed() - protocolName.addBytes([]byte(m.unencryptedALPN)) - } - } else { - m.extensions.marshal(extensions) - if m.omitExtensions || m.emptyExtensions { - // Silently erasing server extensions will break the handshake. Instead, - // assert that tests which use this field also disable all features which - // would write an extension. - if extensions.len() != 0 { - panic(fmt.Sprintf("ServerHello unexpectedly contained extensions: %x, %+v", extensions.data(), m)) - } - hello.discardChild() - if m.emptyExtensions { - hello.addU16(0) + m.extensions.marshal(extensions) + } + if m.omitExtensions || m.emptyExtensions { + // Silently erasing server extensions will break the handshake. Instead, + // assert that tests which use this field also disable all features which + // would write an extension. Note the length includes the length prefix. + if b := extensions.BytesOrPanic(); len(b) != 2 { + panic(fmt.Sprintf("ServerHello unexpectedly contained extensions: %x, %+v", b, m)) + } } + }) + // Remove the length prefix. + if m.omitExtensions { + hello.Unwrite(2) } - } + }) - m.raw = handshakeMsg.finish() + m.raw = handshakeMsg.BytesOrPanic() return m.raw } func (m *serverHelloMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - if !reader.readU16(&m.vers) || - !reader.readBytes(&m.random, 32) { + reader := cryptobyte.String(data[4:]) + if !reader.ReadUint16(&m.vers) || + !reader.ReadBytes(&m.random, 32) { return false } vers, ok := wireToVersion(m.vers, m.isDTLS) if !ok { return false } - if !reader.readU8LengthPrefixedBytes(&m.sessionID) || - !reader.readU16(&m.cipherSuite) || - !reader.readU8(&m.compressionMethod) { + if !readUint8LengthPrefixedBytes(&reader, &m.sessionID) || + !reader.ReadUint16(&m.cipherSuite) || + !reader.ReadUint8(&m.compressionMethod) { return false } @@ -1443,8 +1285,8 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return true } - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { return false } @@ -1453,13 +1295,13 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { extensionsCopy := extensions for len(extensionsCopy) > 0 { var extension uint16 - var body byteReader - if !extensionsCopy.readU16(&extension) || - !extensionsCopy.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensionsCopy.ReadUint16(&extension) || + !extensionsCopy.ReadUint16LengthPrefixed(&body) { return false } if extension == extensionSupportedVersions { - if !body.readU16(&m.vers) || len(body) != 0 { + if !body.ReadUint16(&m.vers) || len(body) != 0 { return false } vers, ok = wireToVersion(m.vers, m.isDTLS) @@ -1473,23 +1315,23 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { if vers >= VersionTLS13 { for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionKeyShare: m.hasKeyShare = true var group uint16 - if !body.readU16(&group) || - !body.readU16LengthPrefixedBytes(&m.keyShare.keyExchange) || + if !body.ReadUint16(&group) || + !readUint16LengthPrefixedBytes(&body, &m.keyShare.keyExchange) || len(body) != 0 { return false } m.keyShare.group = CurveID(group) case extensionPreSharedKey: - if !body.readU16(&m.pskIdentity) || len(body) != 0 { + if !body.ReadUint16(&m.pskIdentity) || len(body) != 0 { return false } m.hasPSKIdentity = true @@ -1519,23 +1361,25 @@ func (m *encryptedExtensionsMsg) marshal() []byte { return m.raw } - encryptedExtensionsMsg := newByteBuilder() - encryptedExtensionsMsg.addU8(typeEncryptedExtensions) - encryptedExtensions := encryptedExtensionsMsg.addU24LengthPrefixed() - if !m.empty { - extensions := encryptedExtensions.addU16LengthPrefixed() - m.extensions.marshal(extensions) - } + encryptedExtensionsMsg := cryptobyte.NewBuilder(nil) + encryptedExtensionsMsg.AddUint8(typeEncryptedExtensions) + encryptedExtensionsMsg.AddUint24LengthPrefixed(func(encryptedExtensions *cryptobyte.Builder) { + if !m.empty { + encryptedExtensions.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + m.extensions.marshal(extensions) + }) + } + }) - m.raw = encryptedExtensionsMsg.finish() + m.raw = encryptedExtensionsMsg.BytesOrPanic() return m.raw } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 { + reader := cryptobyte.String(data[4:]) + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 { return false } return m.extensions.unmarshal(extensions, VersionTLS13) @@ -1571,147 +1415,137 @@ type serverExtensions struct { echRetryConfigs []byte } -func (m *serverExtensions) marshal(extensions *byteBuilder) { +func (m *serverExtensions) marshal(extensions *cryptobyte.Builder) { if m.duplicateExtension { // Add a duplicate bogus extension at the beginning and end. - extensions.addU16(extensionDuplicate) - extensions.addU16(0) // length = 0 for empty extension + extensions.AddUint16(extensionDuplicate) + extensions.AddUint16(0) // length = 0 for empty extension } if m.nextProtoNeg && !m.npnAfterAlpn { - extensions.addU16(extensionNextProtoNeg) - extension := extensions.addU16LengthPrefixed() - - for _, v := range m.nextProtos { - if len(v) > 255 { - v = v[:255] + extensions.AddUint16(extensionNextProtoNeg) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + for _, v := range m.nextProtos { + addUint8LengthPrefixedBytes(extension, []byte(v)) } - npn := extension.addU8LengthPrefixed() - npn.addBytes([]byte(v)) - } + }) } if m.ocspStapling { - extensions.addU16(extensionStatusRequest) - extensions.addU16(0) + extensions.AddUint16(extensionStatusRequest) + extensions.AddUint16(0) } if m.ticketSupported { - extensions.addU16(extensionSessionTicket) - extensions.addU16(0) + extensions.AddUint16(extensionSessionTicket) + extensions.AddUint16(0) } if m.secureRenegotiation != nil { - extensions.addU16(extensionRenegotiationInfo) - extension := extensions.addU16LengthPrefixed() - secureRenego := extension.addU8LengthPrefixed() - secureRenego.addBytes(m.secureRenegotiation) + extensions.AddUint16(extensionRenegotiationInfo) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(extension, m.secureRenegotiation) + }) } if len(m.alpnProtocol) > 0 || m.alpnProtocolEmpty { - extensions.addU16(extensionALPN) - extension := extensions.addU16LengthPrefixed() - - protocolNameList := extension.addU16LengthPrefixed() - protocolName := protocolNameList.addU8LengthPrefixed() - protocolName.addBytes([]byte(m.alpnProtocol)) + extensions.AddUint16(extensionALPN) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(protocolNameList *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(protocolNameList, []byte(m.alpnProtocol)) + }) + }) } if m.channelIDRequested { - extensions.addU16(extensionChannelID) - extensions.addU16(0) + extensions.AddUint16(extensionChannelID) + extensions.AddUint16(0) } if m.duplicateExtension { // Add a duplicate bogus extension at the beginning and end. - extensions.addU16(extensionDuplicate) - extensions.addU16(0) + extensions.AddUint16(extensionDuplicate) + extensions.AddUint16(0) } if m.extendedMasterSecret { - extensions.addU16(extensionExtendedMasterSecret) - extensions.addU16(0) + extensions.AddUint16(extensionExtendedMasterSecret) + extensions.AddUint16(0) } if m.srtpProtectionProfile != 0 { - extensions.addU16(extensionUseSRTP) - extension := extensions.addU16LengthPrefixed() - - srtpProtectionProfiles := extension.addU16LengthPrefixed() - srtpProtectionProfiles.addU16(m.srtpProtectionProfile) - srtpMki := extension.addU8LengthPrefixed() - srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier)) + extensions.AddUint16(extensionUseSRTP) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(srtpProtectionProfiles *cryptobyte.Builder) { + srtpProtectionProfiles.AddUint16(m.srtpProtectionProfile) + }) + addUint8LengthPrefixedBytes(extension, []byte(m.srtpMasterKeyIdentifier)) + }) } if m.sctList != nil { - extensions.addU16(extensionSignedCertificateTimestamp) - extension := extensions.addU16LengthPrefixed() - extension.addBytes(m.sctList) + extensions.AddUint16(extensionSignedCertificateTimestamp) + addUint16LengthPrefixedBytes(extensions, m.sctList) } if l := len(m.customExtension); l > 0 { - extensions.addU16(extensionCustom) - customExt := extensions.addU16LengthPrefixed() - customExt.addBytes([]byte(m.customExtension)) + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) } if m.nextProtoNeg && m.npnAfterAlpn { - extensions.addU16(extensionNextProtoNeg) - extension := extensions.addU16LengthPrefixed() - - for _, v := range m.nextProtos { - if len(v) > 255 { - v = v[0:255] + extensions.AddUint16(extensionNextProtoNeg) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + for _, v := range m.nextProtos { + addUint8LengthPrefixedBytes(extension, []byte(v)) } - npn := extension.addU8LengthPrefixed() - npn.addBytes([]byte(v)) - } + }) } if m.hasKeyShare { - extensions.addU16(extensionKeyShare) - keyShare := extensions.addU16LengthPrefixed() - keyShare.addU16(uint16(m.keyShare.group)) - keyExchange := keyShare.addU16LengthPrefixed() - keyExchange.addBytes(m.keyShare.keyExchange) + extensions.AddUint16(extensionKeyShare) + extensions.AddUint16LengthPrefixed(func(keyShare *cryptobyte.Builder) { + keyShare.AddUint16(uint16(m.keyShare.group)) + addUint16LengthPrefixedBytes(keyShare, m.keyShare.keyExchange) + }) } if m.supportedVersion != 0 { - extensions.addU16(extensionSupportedVersions) - extensions.addU16(2) // Length - extensions.addU16(m.supportedVersion) + extensions.AddUint16(extensionSupportedVersions) + extensions.AddUint16(2) // Length + extensions.AddUint16(m.supportedVersion) } if len(m.supportedPoints) > 0 { // http://tools.ietf.org/html/rfc4492#section-5.1.2 - extensions.addU16(extensionSupportedPoints) - supportedPointsList := extensions.addU16LengthPrefixed() - supportedPoints := supportedPointsList.addU8LengthPrefixed() - supportedPoints.addBytes(m.supportedPoints) + extensions.AddUint16(extensionSupportedPoints) + extensions.AddUint16LengthPrefixed(func(supportedPointsList *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(supportedPointsList, m.supportedPoints) + }) } if len(m.supportedCurves) > 0 { // https://tools.ietf.org/html/rfc8446#section-4.2.7 - extensions.addU16(extensionSupportedCurves) - supportedCurvesList := extensions.addU16LengthPrefixed() - supportedCurves := supportedCurvesList.addU16LengthPrefixed() - for _, curve := range m.supportedCurves { - supportedCurves.addU16(uint16(curve)) - } + extensions.AddUint16(extensionSupportedCurves) + extensions.AddUint16LengthPrefixed(func(supportedCurvesList *cryptobyte.Builder) { + supportedCurvesList.AddUint16LengthPrefixed(func(supportedCurves *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + supportedCurves.AddUint16(uint16(curve)) + } + }) + }) } if len(m.quicTransportParams) > 0 { - extensions.addU16(extensionQUICTransportParams) - params := extensions.addU16LengthPrefixed() - params.addBytes(m.quicTransportParams) + extensions.AddUint16(extensionQUICTransportParams) + addUint16LengthPrefixedBytes(extensions, m.quicTransportParams) } if len(m.quicTransportParamsLegacy) > 0 { - extensions.addU16(extensionQUICTransportParamsLegacy) - params := extensions.addU16LengthPrefixed() - params.addBytes(m.quicTransportParamsLegacy) + extensions.AddUint16(extensionQUICTransportParamsLegacy) + addUint16LengthPrefixedBytes(extensions, m.quicTransportParamsLegacy) } if m.hasEarlyData { - extensions.addU16(extensionEarlyData) - extensions.addBytes([]byte{0, 0}) + extensions.AddUint16(extensionEarlyData) + extensions.AddBytes([]byte{0, 0}) } if m.serverNameAck { - extensions.addU16(extensionServerName) - extensions.addU16(0) // zero length + extensions.AddUint16(extensionServerName) + extensions.AddUint16(0) // zero length } if m.hasApplicationSettings { - extensions.addU16(extensionApplicationSettings) - extensions.addU16LengthPrefixed().addBytes(m.applicationSettings) + extensions.AddUint16(extensionApplicationSettings) + addUint16LengthPrefixedBytes(extensions, m.applicationSettings) } if len(m.echRetryConfigs) > 0 { - extensions.addU16(extensionEncryptedClientHello) - extensions.addU16LengthPrefixed().addBytes(m.echRetryConfigs) + extensions.AddUint16(extensionEncryptedClientHello) + addUint16LengthPrefixedBytes(extensions, m.echRetryConfigs) } } -func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { +func (m *serverExtensions) unmarshal(data cryptobyte.String, version uint16) bool { // Reset all fields. *m = serverExtensions{} @@ -1721,9 +1555,9 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { for len(data) > 0 { var extension uint16 - var body byteReader - if !data.readU16(&extension) || - !data.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !data.ReadUint16(&extension) || + !data.ReadUint16LengthPrefixed(&body) { return false } switch extension { @@ -1731,7 +1565,7 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { m.nextProtoNeg = true for len(body) > 0 { var protocol []byte - if !body.readU8LengthPrefixedBytes(&protocol) { + if !readUint8LengthPrefixedBytes(&body, &protocol) { return false } m.nextProtos = append(m.nextProtos, string(protocol)) @@ -1747,14 +1581,14 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { } m.ticketSupported = true case extensionRenegotiationInfo: - if !body.readU8LengthPrefixedBytes(&m.secureRenegotiation) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.secureRenegotiation) || len(body) != 0 { return false } case extensionALPN: - var protocols, protocol byteReader - if !body.readU16LengthPrefixed(&protocols) || + var protocols, protocol cryptobyte.String + if !body.ReadUint16LengthPrefixed(&protocols) || len(body) != 0 || - !protocols.readU8LengthPrefixed(&protocol) || + !protocols.ReadUint8LengthPrefixed(&protocol) || len(protocols) != 0 { return false } @@ -1771,11 +1605,11 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { } m.extendedMasterSecret = true case extensionUseSRTP: - var profiles, mki byteReader - if !body.readU16LengthPrefixed(&profiles) || - !profiles.readU16(&m.srtpProtectionProfile) || + var profiles, mki cryptobyte.String + if !body.ReadUint16LengthPrefixed(&profiles) || + !profiles.ReadUint16(&m.srtpProtectionProfile) || len(profiles) != 0 || - !body.readU8LengthPrefixed(&mki) || + !body.ReadUint8LengthPrefixed(&mki) || len(body) != 0 { return false } @@ -1795,7 +1629,7 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { return false } // http://tools.ietf.org/html/rfc4492#section-5.5.2 - if !body.readU8LengthPrefixedBytes(&m.supportedPoints) || len(body) != 0 { + if !readUint8LengthPrefixedBytes(&body, &m.supportedPoints) || len(body) != 0 { return false } case extensionSupportedCurves: @@ -1822,15 +1656,15 @@ func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { m.echRetryConfigs = body // Validate the ECHConfig with a top-level parse. - var echConfigs byteReader - if !body.readU16LengthPrefixed(&echConfigs) { + var echConfigs cryptobyte.String + if !body.ReadUint16LengthPrefixed(&echConfigs) { return false } for len(echConfigs) > 0 { var version uint16 - var contents byteReader - if !echConfigs.readU16(&version) || - !echConfigs.readU16LengthPrefixed(&contents) { + var contents cryptobyte.String + if !echConfigs.ReadUint16(&version) || + !echConfigs.ReadUint16LengthPrefixed(&contents) { return false } } @@ -1858,29 +1692,31 @@ func (m *clientEncryptedExtensionsMsg) marshal() (x []byte) { return m.raw } - builder := newByteBuilder() - builder.addU8(typeEncryptedExtensions) - body := builder.addU24LengthPrefixed() - extensions := body.addU16LengthPrefixed() - if m.hasApplicationSettings { - extensions.addU16(extensionApplicationSettings) - extensions.addU16LengthPrefixed().addBytes(m.applicationSettings) - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - extensions.addU16LengthPrefixed().addBytes(m.customExtension) - } + builder := cryptobyte.NewBuilder(nil) + builder.AddUint8(typeEncryptedExtensions) + builder.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if m.hasApplicationSettings { + extensions.AddUint16(extensionApplicationSettings) + addUint16LengthPrefixedBytes(extensions, m.applicationSettings) + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, m.customExtension) + } + }) + }) - m.raw = builder.finish() + m.raw = builder.BytesOrPanic() return m.raw } func (m *clientEncryptedExtensionsMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) - var extensions byteReader - if !reader.readU16LengthPrefixed(&extensions) || + var extensions cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 { return false } @@ -1891,9 +1727,9 @@ func (m *clientEncryptedExtensionsMsg) unmarshal(data []byte) bool { for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { @@ -1928,92 +1764,93 @@ func (m *helloRetryRequestMsg) marshal() []byte { return m.raw } - retryRequestMsg := newByteBuilder() - retryRequestMsg.addU8(typeServerHello) - retryRequest := retryRequestMsg.addU24LengthPrefixed() - retryRequest.addU16(VersionTLS12) - retryRequest.addBytes(tls13HelloRetryRequest) - sessionID := retryRequest.addU8LengthPrefixed() - sessionID.addBytes(m.sessionID) - retryRequest.addU16(m.cipherSuite) - retryRequest.addU8(m.compressionMethod) - - extensions := retryRequest.addU16LengthPrefixed() + retryRequestMsg := cryptobyte.NewBuilder(nil) + retryRequestMsg.AddUint8(typeServerHello) + retryRequestMsg.AddUint24LengthPrefixed(func(retryRequest *cryptobyte.Builder) { + retryRequest.AddUint16(VersionTLS12) + retryRequest.AddBytes(tls13HelloRetryRequest) + addUint8LengthPrefixedBytes(retryRequest, m.sessionID) + retryRequest.AddUint16(m.cipherSuite) + retryRequest.AddUint8(m.compressionMethod) - count := 1 - if m.duplicateExtensions { - count = 2 - } + retryRequest.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + count := 1 + if m.duplicateExtensions { + count = 2 + } - for i := 0; i < count; i++ { - extensions.addU16(extensionSupportedVersions) - extensions.addU16(2) // Length - extensions.addU16(m.vers) - if m.hasSelectedGroup { - extensions.addU16(extensionKeyShare) - extensions.addU16(2) // length - extensions.addU16(uint16(m.selectedGroup)) - } - // m.cookie may be a non-nil empty slice for empty cookie tests. - if m.cookie != nil { - extensions.addU16(extensionCookie) - body := extensions.addU16LengthPrefixed() - body.addU16LengthPrefixed().addBytes(m.cookie) - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - extensions.addU16LengthPrefixed().addBytes([]byte(m.customExtension)) - } - if len(m.echConfirmation) > 0 { - extensions.addU16(extensionEncryptedClientHello) - extensions.addU16LengthPrefixed().addBytes(m.echConfirmation) - } - } + for i := 0; i < count; i++ { + extensions.AddUint16(extensionSupportedVersions) + extensions.AddUint16(2) // Length + extensions.AddUint16(m.vers) + if m.hasSelectedGroup { + extensions.AddUint16(extensionKeyShare) + extensions.AddUint16(2) // length + extensions.AddUint16(uint16(m.selectedGroup)) + } + // m.cookie may be a non-nil empty slice for empty cookie tests. + if m.cookie != nil { + extensions.AddUint16(extensionCookie) + extensions.AddUint16LengthPrefixed(func(body *cryptobyte.Builder) { + addUint16LengthPrefixedBytes(body, m.cookie) + }) + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) + } + if len(m.echConfirmation) > 0 { + extensions.AddUint16(extensionEncryptedClientHello) + addUint16LengthPrefixedBytes(extensions, m.echConfirmation) + } + } + }) + }) - m.raw = retryRequestMsg.finish() + m.raw = retryRequestMsg.BytesOrPanic() return m.raw } func (m *helloRetryRequestMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) var legacyVers uint16 var random []byte var compressionMethod byte - var extensions byteReader - if !reader.readU16(&legacyVers) || + var extensions cryptobyte.String + if !reader.ReadUint16(&legacyVers) || legacyVers != VersionTLS12 || - !reader.readBytes(&random, 32) || - !reader.readU8LengthPrefixedBytes(&m.sessionID) || - !reader.readU16(&m.cipherSuite) || - !reader.readU8(&compressionMethod) || + !reader.ReadBytes(&random, 32) || + !readUint8LengthPrefixedBytes(&reader, &m.sessionID) || + !reader.ReadUint16(&m.cipherSuite) || + !reader.ReadUint8(&compressionMethod) || compressionMethod != 0 || - !reader.readU16LengthPrefixed(&extensions) || + !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 { return false } for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionSupportedVersions: - if !body.readU16(&m.vers) || + if !body.ReadUint16(&m.vers) || len(body) != 0 { return false } case extensionKeyShare: var v uint16 - if !body.readU16(&v) || len(body) != 0 { + if !body.ReadUint16(&v) || len(body) != 0 { return false } m.hasSelectedGroup = true m.selectedGroup = CurveID(v) case extensionCookie: - if !body.readU16LengthPrefixedBytes(&m.cookie) || + if !readUint16LengthPrefixedBytes(&body, &m.cookie) || len(m.cookie) == 0 || len(body) != 0 { return false @@ -2063,85 +1900,86 @@ func (m *certificateMsg) marshal() (x []byte) { return m.raw } - certMsg := newByteBuilder() - certMsg.addU8(typeCertificate) - certificate := certMsg.addU24LengthPrefixed() - if m.hasRequestContext { - context := certificate.addU8LengthPrefixed() - context.addBytes(m.requestContext) - } - certificateList := certificate.addU24LengthPrefixed() - for _, cert := range m.certificates { - certEntry := certificateList.addU24LengthPrefixed() - certEntry.addBytes(cert.data) + certMsg := cryptobyte.NewBuilder(nil) + certMsg.AddUint8(typeCertificate) + certMsg.AddUint24LengthPrefixed(func(certificate *cryptobyte.Builder) { if m.hasRequestContext { - extensions := certificateList.addU16LengthPrefixed() - count := 1 - if cert.duplicateExtensions { - count = 2 - } - - for i := 0; i < count; i++ { - if cert.ocspResponse != nil { - extensions.addU16(extensionStatusRequest) - body := extensions.addU16LengthPrefixed() - body.addU8(statusTypeOCSP) - response := body.addU24LengthPrefixed() - response.addBytes(cert.ocspResponse) - } - - if cert.sctList != nil { - extensions.addU16(extensionSignedCertificateTimestamp) - extension := extensions.addU16LengthPrefixed() - extension.addBytes(cert.sctList) + addUint8LengthPrefixedBytes(certificate, m.requestContext) + } + certificate.AddUint24LengthPrefixed(func(certificateList *cryptobyte.Builder) { + for _, cert := range m.certificates { + addUint24LengthPrefixedBytes(certificateList, cert.data) + if m.hasRequestContext { + certificateList.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + count := 1 + if cert.duplicateExtensions { + count = 2 + } + + for i := 0; i < count; i++ { + if cert.ocspResponse != nil { + extensions.AddUint16(extensionStatusRequest) + extensions.AddUint16LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint8(statusTypeOCSP) + addUint24LengthPrefixedBytes(body, cert.ocspResponse) + }) + } + + if cert.sctList != nil { + extensions.AddUint16(extensionSignedCertificateTimestamp) + addUint16LengthPrefixedBytes(extensions, cert.sctList) + } + } + if cert.extraExtension != nil { + extensions.AddBytes(cert.extraExtension) + } + }) } } - if cert.extraExtension != nil { - extensions.addBytes(cert.extraExtension) - } - } - } + }) - m.raw = certMsg.finish() + }) + + m.raw = certMsg.BytesOrPanic() return m.raw } func (m *certificateMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) - if m.hasRequestContext && !reader.readU8LengthPrefixedBytes(&m.requestContext) { + if m.hasRequestContext && !readUint8LengthPrefixedBytes(&reader, &m.requestContext) { return false } - var certs byteReader - if !reader.readU24LengthPrefixed(&certs) || len(reader) != 0 { + var certs cryptobyte.String + if !reader.ReadUint24LengthPrefixed(&certs) || len(reader) != 0 { return false } m.certificates = nil for len(certs) > 0 { var cert certificateEntry - if !certs.readU24LengthPrefixedBytes(&cert.data) { + if !readUint24LengthPrefixedBytes(&certs, &cert.data) { return false } if m.hasRequestContext { - var extensions byteReader - if !certs.readU16LengthPrefixed(&extensions) || !checkDuplicateExtensions(extensions) { + var extensions cryptobyte.String + if !certs.ReadUint16LengthPrefixed(&extensions) || !checkDuplicateExtensions(extensions) { return false } for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { case extensionStatusRequest: var statusType byte - if !body.readU8(&statusType) || + if !body.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !body.readU24LengthPrefixedBytes(&cert.ocspResponse) || + !readUint24LengthPrefixedBytes(&body, &cert.ocspResponse) || len(body) != 0 { return false } @@ -2157,11 +1995,11 @@ func (m *certificateMsg) unmarshal(data []byte) bool { origBody := body var expectedCertVerifyAlgo, algorithm uint16 - if !body.readU32(&dc.lifetimeSecs) || - !body.readU16(&expectedCertVerifyAlgo) || - !body.readU24LengthPrefixedBytes(&dc.pkixPublicKey) || - !body.readU16(&algorithm) || - !body.readU16LengthPrefixedBytes(&dc.signature) || + if !body.ReadUint32(&dc.lifetimeSecs) || + !body.ReadUint16(&expectedCertVerifyAlgo) || + !readUint24LengthPrefixedBytes(&body, &dc.pkixPublicKey) || + !body.ReadUint16(&algorithm) || + !readUint16LengthPrefixedBytes(&body, &dc.signature) || len(body) != 0 { return false } @@ -2193,25 +2031,25 @@ func (m *compressedCertificateMsg) marshal() (x []byte) { return m.raw } - certMsg := newByteBuilder() - certMsg.addU8(typeCompressedCertificate) - certificate := certMsg.addU24LengthPrefixed() - certificate.addU16(m.algID) - certificate.addU24(int(m.uncompressedLength)) - compressed := certificate.addU24LengthPrefixed() - compressed.addBytes(m.compressed) + certMsg := cryptobyte.NewBuilder(nil) + certMsg.AddUint8(typeCompressedCertificate) + certMsg.AddUint24LengthPrefixed(func(certificate *cryptobyte.Builder) { + certificate.AddUint16(m.algID) + certificate.AddUint24(m.uncompressedLength) + addUint24LengthPrefixedBytes(certificate, m.compressed) + }) - m.raw = certMsg.finish() + m.raw = certMsg.BytesOrPanic() return m.raw } func (m *compressedCertificateMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) - if !reader.readU16(&m.algID) || - !reader.readU24(&m.uncompressedLength) || - !reader.readU24LengthPrefixedBytes(&m.compressed) || + if !reader.ReadUint16(&m.algID) || + !reader.ReadUint24(&m.uncompressedLength) || + !readUint24LengthPrefixedBytes(&reader, &m.compressed) || len(reader) != 0 { return false } @@ -2232,10 +2070,10 @@ func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw } - msg := newByteBuilder() - msg.addU8(typeServerKeyExchange) - msg.addU24LengthPrefixed().addBytes(m.key) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeServerKeyExchange) + addUint24LengthPrefixedBytes(msg, m.key) + m.raw = msg.BytesOrPanic() return m.raw } @@ -2261,12 +2099,13 @@ func (m *certificateStatusMsg) marshal() []byte { var x []byte if m.statusType == statusTypeOCSP { - msg := newByteBuilder() - msg.addU8(typeCertificateStatus) - body := msg.addU24LengthPrefixed() - body.addU8(statusTypeOCSP) - body.addU24LengthPrefixed().addBytes(m.response) - x = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeCertificateStatus) + msg.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint8(statusTypeOCSP) + addUint24LengthPrefixedBytes(body, m.response) + }) + x = msg.BytesOrPanic() } else { x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} } @@ -2277,10 +2116,10 @@ func (m *certificateStatusMsg) marshal() []byte { func (m *certificateStatusMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) - if !reader.readU8(&m.statusType) || + reader := cryptobyte.String(data[4:]) + if !reader.ReadUint8(&m.statusType) || m.statusType != statusTypeOCSP || - !reader.readU24LengthPrefixedBytes(&m.response) || + !readUint24LengthPrefixedBytes(&reader, &m.response) || len(reader) != 0 { return false } @@ -2308,10 +2147,10 @@ func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw } - msg := newByteBuilder() - msg.addU8(typeClientKeyExchange) - msg.addU24LengthPrefixed().addBytes(m.ciphertext) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeClientKeyExchange) + addUint24LengthPrefixedBytes(msg, m.ciphertext) + m.raw = msg.BytesOrPanic() return m.raw } @@ -2338,10 +2177,10 @@ func (m *finishedMsg) marshal() []byte { return m.raw } - msg := newByteBuilder() - msg.addU8(typeFinished) - msg.addU24LengthPrefixed().addBytes(m.verifyData) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeFinished) + addUint24LengthPrefixedBytes(msg, m.verifyData) + m.raw = msg.BytesOrPanic() return m.raw } @@ -2366,21 +2205,22 @@ func (m *nextProtoMsg) marshal() []byte { padding := 32 - (len(m.proto)+2)%32 - msg := newByteBuilder() - msg.addU8(typeNextProtocol) - body := msg.addU24LengthPrefixed() - body.addU8LengthPrefixed().addBytes([]byte(m.proto)) - body.addU8LengthPrefixed().addBytes(make([]byte, padding)) - m.raw = msg.finish() + msg := cryptobyte.NewBuilder(nil) + msg.AddUint8(typeNextProtocol) + msg.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + addUint8LengthPrefixedBytes(body, []byte(m.proto)) + addUint8LengthPrefixedBytes(body, make([]byte, padding)) + }) + m.raw = msg.BytesOrPanic() return m.raw } func (m *nextProtoMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) var proto, padding []byte - if !reader.readU8LengthPrefixedBytes(&proto) || - !reader.readU8LengthPrefixedBytes(&padding) || + if !readUint8LengthPrefixedBytes(&reader, &proto) || + !readUint8LengthPrefixedBytes(&reader, &padding) || len(reader) != 0 { return false } @@ -2427,72 +2267,79 @@ func (m *certificateRequestMsg) marshal() []byte { } // See http://tools.ietf.org/html/rfc4346#section-7.4.4 - builder := newByteBuilder() - builder.addU8(typeCertificateRequest) - body := builder.addU24LengthPrefixed() - - if m.hasRequestContext { - requestContext := body.addU8LengthPrefixed() - requestContext.addBytes(m.requestContext) - extensions := newByteBuilder() - extensions = body.addU16LengthPrefixed() - if m.hasSignatureAlgorithm { - extensions.addU16(extensionSignatureAlgorithms) - signatureAlgorithms := extensions.addU16LengthPrefixed().addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureAlgorithms.addU16(uint16(sigAlg)) - } - } - if len(m.signatureAlgorithmsCert) > 0 { - extensions.addU16(extensionSignatureAlgorithmsCert) - signatureAlgorithmsCert := extensions.addU16LengthPrefixed().addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithmsCert { - signatureAlgorithmsCert.addU16(uint16(sigAlg)) - } - } - if len(m.certificateAuthorities) > 0 { - extensions.addU16(extensionCertificateAuthorities) - certificateAuthorities := extensions.addU16LengthPrefixed().addU16LengthPrefixed() - for _, ca := range m.certificateAuthorities { - caEntry := certificateAuthorities.addU16LengthPrefixed() - caEntry.addBytes(ca) - } - } + builder := cryptobyte.NewBuilder(nil) + builder.AddUint8(typeCertificateRequest) + builder.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + if m.hasRequestContext { + addUint8LengthPrefixedBytes(body, m.requestContext) + body.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + extensions.AddUint16(extensionSignatureAlgorithms) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(signatureAlgorithms *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureAlgorithms.AddUint16(uint16(sigAlg)) + } + }) + }) + } + if len(m.signatureAlgorithmsCert) > 0 { + extensions.AddUint16(extensionSignatureAlgorithmsCert) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(signatureAlgorithmsCert *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithmsCert { + signatureAlgorithmsCert.AddUint16(uint16(sigAlg)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + extensions.AddUint16(extensionCertificateAuthorities) + extensions.AddUint16LengthPrefixed(func(extension *cryptobyte.Builder) { + extension.AddUint16LengthPrefixed(func(certificateAuthorities *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + addUint16LengthPrefixedBytes(certificateAuthorities, ca) + } + }) + }) + } - if m.customExtension > 0 { - extensions.addU16(m.customExtension) - extensions.addU16LengthPrefixed() - } - } else { - certificateTypes := body.addU8LengthPrefixed() - certificateTypes.addBytes(m.certificateTypes) + if m.customExtension > 0 { + extensions.AddUint16(m.customExtension) + extensions.AddUint16(0) // Empty extension + } + }) + } else { + addUint8LengthPrefixedBytes(body, m.certificateTypes) - if m.hasSignatureAlgorithm { - signatureAlgorithms := body.addU16LengthPrefixed() - for _, sigAlg := range m.signatureAlgorithms { - signatureAlgorithms.addU16(uint16(sigAlg)) + if m.hasSignatureAlgorithm { + body.AddUint16LengthPrefixed(func(signatureAlgorithms *cryptobyte.Builder) { + for _, sigAlg := range m.signatureAlgorithms { + signatureAlgorithms.AddUint16(uint16(sigAlg)) + } + }) } - } - certificateAuthorities := body.addU16LengthPrefixed() - for _, ca := range m.certificateAuthorities { - caEntry := certificateAuthorities.addU16LengthPrefixed() - caEntry.addBytes(ca) + body.AddUint16LengthPrefixed(func(certificateAuthorities *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + addUint16LengthPrefixedBytes(certificateAuthorities, ca) + } + }) } - } + }) - m.raw = builder.finish() + m.raw = builder.BytesOrPanic() return m.raw } -func parseCAs(reader *byteReader, out *[][]byte) bool { - var cas byteReader - if !reader.readU16LengthPrefixed(&cas) { +func parseCAs(reader *cryptobyte.String, out *[][]byte) bool { + var cas cryptobyte.String + if !reader.ReadUint16LengthPrefixed(&cas) { return false } for len(cas) > 0 { var ca []byte - if !cas.readU16LengthPrefixedBytes(&ca) { + if !readUint16LengthPrefixedBytes(&cas, &ca) { return false } *out = append(*out, ca) @@ -2502,21 +2349,21 @@ func parseCAs(reader *byteReader, out *[][]byte) bool { func (m *certificateRequestMsg) unmarshal(data []byte) bool { m.raw = data - reader := byteReader(data[4:]) + reader := cryptobyte.String(data[4:]) if m.hasRequestContext { - var extensions byteReader - if !reader.readU8LengthPrefixedBytes(&m.requestContext) || - !reader.readU16LengthPrefixed(&extensions) || + var extensions cryptobyte.String + if !readUint8LengthPrefixedBytes(&reader, &m.requestContext) || + !reader.ReadUint16LengthPrefixed(&extensions) || len(reader) != 0 || !checkDuplicateExtensions(extensions) { return false } for len(extensions) > 0 { var extension uint16 - var body byteReader - if !extensions.readU16(&extension) || - !extensions.readU16LengthPrefixed(&body) { + var body cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&body) { return false } switch extension { @@ -2536,7 +2383,7 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { } } } else { - if !reader.readU8LengthPrefixedBytes(&m.certificateTypes) { + if !readUint8LengthPrefixedBytes(&reader, &m.certificateTypes) { return false } // In TLS 1.2, the supported_signature_algorithms field in @@ -2648,35 +2495,40 @@ func (m *newSessionTicketMsg) marshal() []byte { } // See http://tools.ietf.org/html/rfc5077#section-3.3 - ticketMsg := newByteBuilder() - ticketMsg.addU8(typeNewSessionTicket) - body := ticketMsg.addU24LengthPrefixed() - body.addU32(m.ticketLifetime) - if version >= VersionTLS13 { - body.addU32(m.ticketAgeAdd) - body.addU8LengthPrefixed().addBytes(m.ticketNonce) - } - - ticket := body.addU16LengthPrefixed() - ticket.addBytes(m.ticket) - - if version >= VersionTLS13 { - extensions := body.addU16LengthPrefixed() - if m.maxEarlyDataSize > 0 { - extensions.addU16(extensionEarlyData) - extensions.addU16LengthPrefixed().addU32(m.maxEarlyDataSize) - if m.duplicateEarlyDataExtension { - extensions.addU16(extensionEarlyData) - extensions.addU16LengthPrefixed().addU32(m.maxEarlyDataSize) - } - } - if len(m.customExtension) > 0 { - extensions.addU16(extensionCustom) - extensions.addU16LengthPrefixed().addBytes([]byte(m.customExtension)) + ticketMsg := cryptobyte.NewBuilder(nil) + ticketMsg.AddUint8(typeNewSessionTicket) + ticketMsg.AddUint24LengthPrefixed(func(body *cryptobyte.Builder) { + body.AddUint32(m.ticketLifetime) + if version >= VersionTLS13 { + body.AddUint32(m.ticketAgeAdd) + addUint8LengthPrefixedBytes(body, m.ticketNonce) + } + + addUint16LengthPrefixedBytes(body, m.ticket) + + if version >= VersionTLS13 { + body.AddUint16LengthPrefixed(func(extensions *cryptobyte.Builder) { + if m.maxEarlyDataSize > 0 { + extensions.AddUint16(extensionEarlyData) + extensions.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddUint32(m.maxEarlyDataSize) + }) + if m.duplicateEarlyDataExtension { + extensions.AddUint16(extensionEarlyData) + extensions.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) { + child.AddUint32(m.maxEarlyDataSize) + }) + } + } + if len(m.customExtension) > 0 { + extensions.AddUint16(extensionCustom) + addUint16LengthPrefixedBytes(extensions, []byte(m.customExtension)) + } + }) } - } + }) - m.raw = ticketMsg.finish() + m.raw = ticketMsg.BytesOrPanic() return m.raw } diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index 4130d9bc9..5c49afbc2 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go @@ -20,6 +20,7 @@ import ( "time" "boringssl.googlesource.com/boringssl/ssl/test/runner/hpke" + "golang.org/x/crypto/cryptobyte" ) // serverHandshakeState contains details of a server handshake in progress. @@ -2443,18 +2444,18 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) } // Skip the handshake message header. - aReader := byteReader(a[4:]) - bReader := byteReader(b[4:]) + aReader := cryptobyte.String(a[4:]) + bReader := cryptobyte.String(b[4:]) var aVers, bVers uint16 var aRandom, bRandom []byte var aSessionID, bSessionID []byte - if !aReader.readU16(&aVers) || - !bReader.readU16(&bVers) || - !aReader.readBytes(&aRandom, 32) || - !bReader.readBytes(&bRandom, 32) || - !aReader.readU8LengthPrefixedBytes(&aSessionID) || - !bReader.readU8LengthPrefixedBytes(&bSessionID) { + if !aReader.ReadUint16(&aVers) || + !bReader.ReadUint16(&bVers) || + !aReader.ReadBytes(&aRandom, 32) || + !bReader.ReadBytes(&bRandom, 32) || + !readUint8LengthPrefixedBytes(&aReader, &aSessionID) || + !readUint8LengthPrefixedBytes(&bReader, &bSessionID) { return errors.New("tls: could not parse ClientHello") } @@ -2474,17 +2475,17 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) // cookie altogether. If we implement DTLS 1.3, we'll need to ensure // that parsing logic above this function rejects this cookie. var aCookie, bCookie []byte - if !aReader.readU8LengthPrefixedBytes(&aCookie) || - !bReader.readU8LengthPrefixedBytes(&bCookie) { + if !readUint8LengthPrefixedBytes(&aReader, &aCookie) || + !readUint8LengthPrefixedBytes(&bReader, &bCookie) { return errors.New("tls: could not parse ClientHello") } } var aCipherSuites, bCipherSuites, aCompressionMethods, bCompressionMethods []byte - if !aReader.readU16LengthPrefixedBytes(&aCipherSuites) || - !bReader.readU16LengthPrefixedBytes(&bCipherSuites) || - !aReader.readU8LengthPrefixedBytes(&aCompressionMethods) || - !bReader.readU8LengthPrefixedBytes(&bCompressionMethods) { + if !readUint16LengthPrefixedBytes(&aReader, &aCipherSuites) || + !readUint16LengthPrefixedBytes(&bReader, &bCipherSuites) || + !readUint8LengthPrefixedBytes(&aReader, &aCompressionMethods) || + !readUint8LengthPrefixedBytes(&bReader, &bCompressionMethods) { return errors.New("tls: could not parse ClientHello") } if !bytes.Equal(aCipherSuites, bCipherSuites) { @@ -2499,9 +2500,9 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) return nil } - var aExtensions, bExtensions byteReader - if !aReader.readU16LengthPrefixed(&aExtensions) || - !bReader.readU16LengthPrefixed(&bExtensions) || + var aExtensions, bExtensions cryptobyte.String + if !aReader.ReadUint16LengthPrefixed(&aExtensions) || + !bReader.ReadUint16LengthPrefixed(&bExtensions) || len(aReader) != 0 || len(bReader) != 0 { return errors.New("tls: could not parse ClientHello") @@ -2510,8 +2511,8 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) for len(aExtensions) != 0 { var aID uint16 var aBody []byte - if !aExtensions.readU16(&aID) || - !aExtensions.readU16LengthPrefixedBytes(&aBody) { + if !aExtensions.ReadUint16(&aID) || + !readUint16LengthPrefixedBytes(&aExtensions, &aBody) { return errors.New("tls: could not parse ClientHello") } if _, ok := ignoreExtensionsSet[aID]; ok { @@ -2524,8 +2525,8 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) } var bID uint16 var bBody []byte - if !bExtensions.readU16(&bID) || - !bExtensions.readU16LengthPrefixedBytes(&bBody) { + if !bExtensions.ReadUint16(&bID) || + !readUint16LengthPrefixedBytes(&bExtensions, &bBody) { return errors.New("tls: could not parse ClientHello") } if _, ok := ignoreExtensionsSet[bID]; ok { @@ -2546,8 +2547,8 @@ func checkClientHellosEqual(a, b []byte, isDTLS bool, ignoreExtensions []uint16) for len(bExtensions) != 0 { var id uint16 var body []byte - if !bExtensions.readU16(&id) || - !bExtensions.readU16LengthPrefixedBytes(&body) { + if !bExtensions.ReadUint16(&id) || + !readUint16LengthPrefixedBytes(&bExtensions, &body) { return errors.New("tls: could not parse ClientHello") } if _, ok := ignoreExtensionsSet[id]; !ok { diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go index fc67d7503..4cdc7c8c0 100644 --- a/ssl/test/runner/prf.go +++ b/ssl/test/runner/prf.go @@ -13,6 +13,7 @@ import ( "encoding" "hash" + "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/hkdf" ) @@ -228,15 +229,15 @@ type finishedHash struct { } func (h *finishedHash) UpdateForHelloRetryRequest() { - data := newByteBuilder() - data.addU8(typeMessageHash) - data.addU24(h.hash.Size()) - data.addBytes(h.Sum()) + data := cryptobyte.NewBuilder(nil) + data.AddUint8(typeMessageHash) + data.AddUint24(uint32(h.hash.Size())) + data.AddBytes(h.Sum()) h.hash = h.suite.hash().New() if h.buffer != nil { h.buffer = []byte{} } - h.Write(data.finish()) + h.Write(data.BytesOrPanic()) } func (h *finishedHash) Write(msg []byte) (n int, err error) { diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index e01e1d541..fcdd11a3d 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -47,6 +47,7 @@ import ( "boringssl.googlesource.com/boringssl/ssl/test/runner/hpke" "boringssl.googlesource.com/boringssl/util/testresult" + "golang.org/x/crypto/cryptobyte" ) var ( @@ -819,10 +820,10 @@ func doExchange(test *testCase, config *Config, conn net.Conn, isResume bool, tr if err := os.MkdirAll(dir, 0755); err != nil { return err } - bb := newByteBuilder() - bb.addU24LengthPrefixed().addBytes(encodedInner) - bb.addBytes(outer) - return os.WriteFile(filepath.Join(dir, name), bb.finish(), 0644) + bb := cryptobyte.NewBuilder(nil) + addUint24LengthPrefixedBytes(bb, encodedInner) + bb.AddBytes(outer) + return os.WriteFile(filepath.Join(dir, name), bb.BytesOrPanic(), 0644) } } diff --git a/ssl/test/runner/ticket.go b/ssl/test/runner/ticket.go index 46a6b3579..f0a8bf18a 100644 --- a/ssl/test/runner/ticket.go +++ b/ssl/test/runner/ticket.go @@ -13,6 +13,8 @@ import ( "errors" "io" "time" + + "golang.org/x/crypto/cryptobyte" ) // sessionState contains the information that is serialized into a session @@ -35,49 +37,45 @@ type sessionState struct { } func (s *sessionState) marshal() []byte { - msg := newByteBuilder() - msg.addU16(s.vers) - msg.addU16(s.cipherSuite) - secret := msg.addU16LengthPrefixed() - secret.addBytes(s.secret) - handshakeHash := msg.addU16LengthPrefixed() - handshakeHash.addBytes(s.handshakeHash) - msg.addU16(uint16(len(s.certificates))) + msg := cryptobyte.NewBuilder(nil) + msg.AddUint16(s.vers) + msg.AddUint16(s.cipherSuite) + addUint16LengthPrefixedBytes(msg, s.secret) + addUint16LengthPrefixedBytes(msg, s.handshakeHash) + msg.AddUint16(uint16(len(s.certificates))) for _, cert := range s.certificates { - certMsg := msg.addU32LengthPrefixed() - certMsg.addBytes(cert) + addUint24LengthPrefixedBytes(msg, cert) } if s.extendedMasterSecret { - msg.addU8(1) + msg.AddUint8(1) } else { - msg.addU8(0) + msg.AddUint8(0) } if s.vers >= VersionTLS13 { - msg.addU64(uint64(s.ticketCreationTime.UnixNano())) - msg.addU64(uint64(s.ticketExpiration.UnixNano())) - msg.addU32(s.ticketFlags) - msg.addU32(s.ticketAgeAdd) + msg.AddUint64(uint64(s.ticketCreationTime.UnixNano())) + msg.AddUint64(uint64(s.ticketExpiration.UnixNano())) + msg.AddUint32(s.ticketFlags) + msg.AddUint32(s.ticketAgeAdd) } - earlyALPN := msg.addU16LengthPrefixed() - earlyALPN.addBytes(s.earlyALPN) + addUint16LengthPrefixedBytes(msg, s.earlyALPN) if s.hasApplicationSettings { - msg.addU8(1) - msg.addU16LengthPrefixed().addBytes(s.localApplicationSettings) - msg.addU16LengthPrefixed().addBytes(s.peerApplicationSettings) + msg.AddUint8(1) + addUint16LengthPrefixedBytes(msg, s.localApplicationSettings) + addUint16LengthPrefixedBytes(msg, s.peerApplicationSettings) } else { - msg.addU8(0) + msg.AddUint8(0) } - return msg.finish() + return msg.BytesOrPanic() } -func readBool(reader *byteReader, out *bool) bool { +func readBool(reader *cryptobyte.String, out *bool) bool { var value uint8 - if !reader.readU8(&value) { + if !reader.ReadUint8(&value) { return false } if value == 0 { @@ -92,19 +90,19 @@ func readBool(reader *byteReader, out *bool) bool { } func (s *sessionState) unmarshal(data []byte) bool { - reader := byteReader(data) + reader := cryptobyte.String(data) var numCerts uint16 - if !reader.readU16(&s.vers) || - !reader.readU16(&s.cipherSuite) || - !reader.readU16LengthPrefixedBytes(&s.secret) || - !reader.readU16LengthPrefixedBytes(&s.handshakeHash) || - !reader.readU16(&numCerts) { + if !reader.ReadUint16(&s.vers) || + !reader.ReadUint16(&s.cipherSuite) || + !readUint16LengthPrefixedBytes(&reader, &s.secret) || + !readUint16LengthPrefixedBytes(&reader, &s.handshakeHash) || + !reader.ReadUint16(&numCerts) { return false } s.certificates = make([][]byte, int(numCerts)) for i := range s.certificates { - if !reader.readU32LengthPrefixedBytes(&s.certificates[i]) { + if !readUint24LengthPrefixedBytes(&reader, &s.certificates[i]) { return false } } @@ -115,24 +113,24 @@ func (s *sessionState) unmarshal(data []byte) bool { if s.vers >= VersionTLS13 { var ticketCreationTime, ticketExpiration uint64 - if !reader.readU64(&ticketCreationTime) || - !reader.readU64(&ticketExpiration) || - !reader.readU32(&s.ticketFlags) || - !reader.readU32(&s.ticketAgeAdd) { + if !reader.ReadUint64(&ticketCreationTime) || + !reader.ReadUint64(&ticketExpiration) || + !reader.ReadUint32(&s.ticketFlags) || + !reader.ReadUint32(&s.ticketAgeAdd) { return false } s.ticketCreationTime = time.Unix(0, int64(ticketCreationTime)) s.ticketExpiration = time.Unix(0, int64(ticketExpiration)) } - if !reader.readU16LengthPrefixedBytes(&s.earlyALPN) || + if !readUint16LengthPrefixedBytes(&reader, &s.earlyALPN) || !readBool(&reader, &s.hasApplicationSettings) { return false } if s.hasApplicationSettings { - if !reader.readU16LengthPrefixedBytes(&s.localApplicationSettings) || - !reader.readU16LengthPrefixedBytes(&s.peerApplicationSettings) { + if !readUint16LengthPrefixedBytes(&reader, &s.localApplicationSettings) || + !readUint16LengthPrefixedBytes(&reader, &s.peerApplicationSettings) { return false } }