From 8a1cdec2ae0324b5d1bdf92197ca7ae3105163fc Mon Sep 17 00:00:00 2001 From: Carl Mastrangelo Date: Wed, 18 Nov 2015 16:48:57 -0800 Subject: [PATCH] Add cipher suite test to http2 interop tests, and honor test_case flag --- tools/http2_interop/goaway.go | 72 ++++++++++++++++++++++++ tools/http2_interop/http2interop.go | 52 +++++++++++++++++ tools/http2_interop/http2interop_test.go | 46 ++++++++++----- 3 files changed, 155 insertions(+), 15 deletions(-) create mode 100644 tools/http2_interop/goaway.go diff --git a/tools/http2_interop/goaway.go b/tools/http2_interop/goaway.go new file mode 100644 index 00000000000..289442d615b --- /dev/null +++ b/tools/http2_interop/goaway.go @@ -0,0 +1,72 @@ +package http2interop + +import ( + "encoding/binary" + "fmt" + "io" +) + +type GoAwayFrame struct { + Header FrameHeader + Reserved + StreamID + // TODO(carl-mastrangelo): make an enum out of this. + Code uint32 + Data []byte +} + +func (f *GoAwayFrame) GetHeader() *FrameHeader { + return &f.Header +} + +func (f *GoAwayFrame) ParsePayload(r io.Reader) error { + raw := make([]byte, f.Header.Length) + if _, err := io.ReadFull(r, raw); err != nil { + return err + } + return f.UnmarshalPayload(raw) +} + +func (f *GoAwayFrame) UnmarshalPayload(raw []byte) error { + if f.Header.Length != len(raw) { + return fmt.Errorf("Invalid Payload length %d != %d", f.Header.Length, len(raw)) + } + if f.Header.Length < 8 { + return fmt.Errorf("Invalid Payload length %d", f.Header.Length) + } + *f = GoAwayFrame{ + Reserved: Reserved(raw[0]>>7 == 1), + StreamID: StreamID(binary.BigEndian.Uint32(raw[0:4]) & 0x7fffffff), + Code: binary.BigEndian.Uint32(raw[4:8]), + Data: []byte(string(raw[8:])), + } + + return nil +} + +func (f *GoAwayFrame) MarshalPayload() ([]byte, error) { + raw := make([]byte, 8, 8+len(f.Data)) + binary.BigEndian.PutUint32(raw[:4], uint32(f.StreamID)) + binary.BigEndian.PutUint32(raw[4:8], f.Code) + raw = append(raw, f.Data...) + + return raw, nil +} + +func (f *GoAwayFrame) MarshalBinary() ([]byte, error) { + payload, err := f.MarshalPayload() + if err != nil { + return nil, err + } + + f.Header.Length = len(payload) + f.Header.Type = GoAwayFrameType + header, err := f.Header.MarshalBinary() + if err != nil { + return nil, err + } + + header = append(header, payload...) + + return header, nil +} diff --git a/tools/http2_interop/http2interop.go b/tools/http2_interop/http2interop.go index 8585a044e53..bef8b0b656e 100644 --- a/tools/http2_interop/http2interop.go +++ b/tools/http2_interop/http2interop.go @@ -252,6 +252,58 @@ func testTLSApplicationProtocol(ctx *HTTP2InteropCtx) error { return nil } +func testTLSBadCipherSuites(ctx *HTTP2InteropCtx) error { + config := buildTlsConfig(ctx) + // These are the suites that Go supports, but are forbidden by http2. + config.CipherSuites = []uint16{ + tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + } + conn, err := connectWithTls(ctx, config) + if err != nil { + return err + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(defaultTimeout)) + + if err := http2Connect(conn, nil); err != nil { + return err + } + + for { + f, err := parseFrame(conn) + if err != nil { + return err + } + if gf, ok := f.(*GoAwayFrame); ok { + return fmt.Errorf("Got goaway frame %d", gf.Code) + } + } + return nil +} + +func http2Connect(c net.Conn, sf *SettingsFrame) error { + if _, err := c.Write([]byte(Preface)); err != nil { + return err + } + if sf == nil { + sf = &SettingsFrame{} + } + if err := streamFrame(c, sf); err != nil { + return err + } + return nil +} + func connect(ctx *HTTP2InteropCtx) (net.Conn, error) { var conn net.Conn var err error diff --git a/tools/http2_interop/http2interop_test.go b/tools/http2_interop/http2interop_test.go index dc2960048f1..8fd838422b8 100644 --- a/tools/http2_interop/http2interop_test.go +++ b/tools/http2_interop/http2interop_test.go @@ -3,13 +3,13 @@ package http2interop import ( "crypto/tls" "crypto/x509" - "strings" "flag" "fmt" "io" "io/ioutil" "os" "strconv" + "strings" "testing" ) @@ -17,8 +17,7 @@ var ( serverHost = flag.String("server_host", "", "The host to test") serverPort = flag.Int("server_port", 443, "The port to test") useTls = flag.Bool("use_tls", true, "Should TLS tests be run") - // TODO: implement - testCase = flag.String("test_case", "", "What test cases to run") + testCase = flag.String("test_case", "", "What test cases to run") // The rest of these are unused, but present to fulfill the client interface serverHostOverride = flag.String("server_host_override", "", "Unused") @@ -86,33 +85,50 @@ func TestUnknownFrameType(t *testing.T) { } func TestTLSApplicationProtocol(t *testing.T) { + if *testCase != "tls" { + return + } ctx := InteropCtx(t) - err := testTLSApplicationProtocol(ctx); + err := testTLSApplicationProtocol(ctx) matchError(t, err, "EOF") } func TestTLSMaxVersion(t *testing.T) { + if *testCase != "tls" { + return + } ctx := InteropCtx(t) - err := testTLSMaxVersion(ctx, tls.VersionTLS11); + err := testTLSMaxVersion(ctx, tls.VersionTLS11) + // TODO(carl-mastrangelo): maybe this should be some other error. If the server picks + // the wrong protocol version, thats bad too. matchError(t, err, "EOF", "server selected unsupported protocol") } +func TestTLSBadCipherSuites(t *testing.T) { + if *testCase != "tls" { + return + } + ctx := InteropCtx(t) + err := testTLSBadCipherSuites(ctx) + matchError(t, err, "EOF", "Got goaway frame") +} + func TestClientPrefaceWithStreamId(t *testing.T) { ctx := InteropCtx(t) err := testClientPrefaceWithStreamId(ctx) matchError(t, err, "EOF") } -func matchError(t *testing.T, err error, matches ... string) { - if err == nil { - t.Fatal("Expected an error") - } - for _, s := range matches { - if strings.Contains(err.Error(), s) { - return - } - } - t.Fatalf("Error %v not in %+v", err, matches) +func matchError(t *testing.T, err error, matches ...string) { + if err == nil { + t.Fatal("Expected an error") + } + for _, s := range matches { + if strings.Contains(err.Error(), s) { + return + } + } + t.Fatalf("Error %v not in %+v", err, matches) } func TestMain(m *testing.M) {