diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index b52a43b76..f9678506c 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -1243,9 +1243,57 @@ var ( errUnimplemented = errors.New("child process does not implement needed flags") ) -// accept accepts a connection from listener, unless waitChan signals a process -// exit first. -func acceptOrWait(listener *net.TCPListener, waitChan chan error) (net.Conn, error) { +type shimProcess struct { + cmd *exec.Cmd + waitChan chan error + listener *net.TCPListener + stdout, stderr bytes.Buffer +} + +// newShimProcess starts a new shim with the specified executable, flags, and +// environment. It internally creates a TCP listener and adds the the -port +// flag. +func newShimProcess(shimPath string, flags []string, env []string) (*shimProcess, error) { + shim := new(shimProcess) + var err error + shim.listener, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback}) + if err != nil { + shim.listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}}) + } + if err != nil { + return nil, err + } + + flags = append([]string{"-port", strconv.Itoa(shim.listener.Addr().(*net.TCPAddr).Port)}, flags...) + if *useValgrind { + shim.cmd = valgrindOf(false, shimPath, flags...) + } else if *useGDB { + shim.cmd = gdbOf(shimPath, flags...) + } else if *useLLDB { + shim.cmd = lldbOf(shimPath, flags...) + } else if *useRR { + shim.cmd = rrOf(shimPath, flags...) + } else { + shim.cmd = exec.Command(shimPath, flags...) + } + shim.cmd.Stdin = os.Stdin + shim.cmd.Stdout = &shim.stdout + shim.cmd.Stderr = &shim.stderr + shim.cmd.Env = env + + if err := shim.cmd.Start(); err != nil { + shim.listener.Close() + return nil, err + } + + shim.waitChan = make(chan error, 1) + go func() { shim.waitChan <- shim.cmd.Wait() }() + return shim, nil +} + +// accept returns a new TCP connection with the shim process, or returns an +// error on timeout or shim exit. +func (s *shimProcess) accept() (net.Conn, error) { type connOrError struct { conn net.Conn err error @@ -1253,21 +1301,93 @@ func acceptOrWait(listener *net.TCPListener, waitChan chan error) (net.Conn, err connChan := make(chan connOrError, 1) go func() { if !useDebugger() { - listener.SetDeadline(time.Now().Add(*idleTimeout)) + s.listener.SetDeadline(time.Now().Add(*idleTimeout)) } - conn, err := listener.Accept() + conn, err := s.listener.Accept() connChan <- connOrError{conn, err} close(connChan) }() select { case result := <-connChan: return result.conn, result.err - case childErr := <-waitChan: - waitChan <- childErr + case childErr := <-s.waitChan: + s.waitChan <- childErr + if childErr == nil { + return nil, fmt.Errorf("child exited early with no error") + } return nil, fmt.Errorf("child exited early: %s", childErr) } } +// wait finishes the test and waits for the shim process to exit. +func (s *shimProcess) wait() error { + // Close the listener now. This is to avoid hangs if the shim tries to open + // more connections than expected. + s.listener.Close() + + if !useDebugger() { + waitTimeout := time.AfterFunc(*idleTimeout, func() { + s.cmd.Process.Kill() + }) + defer waitTimeout.Stop() + } + + err := <-s.waitChan + s.waitChan <- err + return err +} + +// close releases resources associated with the shimProcess. This is safe to +// call before or after |wait|. +func (s *shimProcess) close() { + s.listener.Close() + s.cmd.Process.Kill() +} + +func doExchanges(test *testCase, shim *shimProcess, resumeCount int, transcripts *[][]byte) error { + config := test.config + if *deterministic { + config.Rand = &deterministicRand{} + } + + conn, err := shim.accept() + if err != nil { + return err + } + err = doExchange(test, &config, conn, false /* not a resumption */, transcripts, 0) + conn.Close() + if err != nil { + return err + } + + for i := 0; i < resumeCount; i++ { + var resumeConfig Config + if test.resumeConfig != nil { + resumeConfig = *test.resumeConfig + if !test.newSessionsOnResume { + resumeConfig.SessionTicketKey = config.SessionTicketKey + resumeConfig.ClientSessionCache = config.ClientSessionCache + resumeConfig.ServerSessionCache = config.ServerSessionCache + } + resumeConfig.Rand = config.Rand + } else { + resumeConfig = config + } + var connResume net.Conn + connResume, err = shim.accept() + if err != nil { + return err + } + err = doExchange(test, &resumeConfig, connResume, true /* resumption */, transcripts, i+1) + connResume.Close() + if err != nil { + return err + } + } + + return nil +} + func translateExpectedError(errorStr string) string { if translated, ok := shimConfig.ErrorMap[errorStr]; ok { return translated @@ -1293,20 +1413,7 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN } }() - listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback}) - if err != nil { - listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}}) - } - if err != nil { - panic(err) - } - defer func() { - if listener != nil { - listener.Close() - } - }() - - flags := []string{"-port", strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)} + var flags []string if test.testType == serverTest { flags = append(flags, "-server") @@ -1495,88 +1602,26 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN flags = append(flags, test.flags...) - var shim *exec.Cmd - if *useValgrind { - shim = valgrindOf(false, shimPath, flags...) - } else if *useGDB { - shim = gdbOf(shimPath, flags...) - } else if *useLLDB { - shim = lldbOf(shimPath, flags...) - } else if *useRR { - shim = rrOf(shimPath, flags...) - } else { - shim = exec.Command(shimPath, flags...) - } - shim.Stdin = os.Stdin - var stdoutBuf, stderrBuf bytes.Buffer - shim.Stdout = &stdoutBuf - shim.Stderr = &stderrBuf + var env []string if mallocNumToFail >= 0 { - shim.Env = os.Environ() - shim.Env = append(shim.Env, "MALLOC_NUMBER_TO_FAIL="+strconv.FormatInt(mallocNumToFail, 10)) + env = os.Environ() + env = append(env, "MALLOC_NUMBER_TO_FAIL="+strconv.FormatInt(mallocNumToFail, 10)) if *mallocTestDebug { - shim.Env = append(shim.Env, "MALLOC_BREAK_ON_FAIL=1") + env = append(env, "MALLOC_BREAK_ON_FAIL=1") } - shim.Env = append(shim.Env, "_MALLOC_CHECK=1") - } - - if err := shim.Start(); err != nil { - panic(err) - } - statusChan <- statusMsg{test: test, statusType: statusShimStarted, pid: shim.Process.Pid} - waitChan := make(chan error, 1) - go func() { waitChan <- shim.Wait() }() - - config := test.config - - if *deterministic { - config.Rand = &deterministicRand{} + env = append(env, "_MALLOC_CHECK=1") } - conn, err := acceptOrWait(listener, waitChan) - if err == nil { - err = doExchange(test, &config, conn, false /* not a resumption */, &transcripts, 0) - conn.Close() - } - - for i := 0; err == nil && i < resumeCount; i++ { - var resumeConfig Config - if test.resumeConfig != nil { - resumeConfig = *test.resumeConfig - if !test.newSessionsOnResume { - resumeConfig.SessionTicketKey = config.SessionTicketKey - resumeConfig.ClientSessionCache = config.ClientSessionCache - resumeConfig.ServerSessionCache = config.ServerSessionCache - } - resumeConfig.Rand = config.Rand - } else { - resumeConfig = config - } - var connResume net.Conn - connResume, err = acceptOrWait(listener, waitChan) - if err == nil { - err = doExchange(test, &resumeConfig, connResume, true /* resumption */, &transcripts, i+1) - connResume.Close() - } + shim, err := newShimProcess(shimPath, flags, env) + if err != nil { + return err } + defer shim.close() - // Close the listener now. This is to avoid hangs should the shim try to - // open more connections than expected. - listener.Close() - listener = nil + localErr := doExchanges(test, shim, resumeCount, &transcripts) + childErr := shim.wait() - var childErr error - if useDebugger() { - childErr = <-waitChan - } else { - waitTimeout := time.AfterFunc(*idleTimeout, func() { - shim.Process.Kill() - }) - childErr = <-waitChan - waitTimeout.Stop() - } - - // Now that the shim has exitted, all the settings files have been + // Now that the shim has exited, all the settings files have been // written. Append the saved transcripts. for i, transcript := range transcripts { if err := appendTranscript(transcriptPrefix+strconv.Itoa(i), transcript); err != nil { @@ -1599,8 +1644,8 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN } // Account for Windows line endings. - stdout := strings.Replace(string(stdoutBuf.Bytes()), "\r\n", "\n", -1) - stderr := strings.Replace(string(stderrBuf.Bytes()), "\r\n", "\n", -1) + stdout := strings.Replace(shim.stdout.String(), "\r\n", "\n", -1) + stderr := strings.Replace(shim.stderr.String(), "\r\n", "\n", -1) // Work around an NDK / Android bug. The NDK r16 sometimes generates // binaries with the DF_1_PIE, which the runtime linker on Android N @@ -1622,22 +1667,22 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN extraStderr = stderrParts[1] } - failed := err != nil || childErr != nil + failed := localErr != nil || childErr != nil expectedError := translateExpectedError(test.expectedError) correctFailure := len(expectedError) == 0 || strings.Contains(stderr, expectedError) - localError := "none" - if err != nil { - localError = err.Error() + localErrString := "none" + if localErr != nil { + localErrString = localErr.Error() } if len(test.expectedLocalError) != 0 { - correctFailure = correctFailure && strings.Contains(localError, test.expectedLocalError) + correctFailure = correctFailure && strings.Contains(localErrString, test.expectedLocalError) } if failed != test.shouldFail || failed && !correctFailure || mustFail { - childError := "none" + childErrString := "none" if childErr != nil { - childError = childErr.Error() + childErrString = childErr.Error() } var msg string @@ -1654,7 +1699,7 @@ func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocN panic("internal error") } - return fmt.Errorf("%s: local error '%s', child error '%s', stdout:\n%s\nstderr:\n%s\n%s", msg, localError, childError, stdout, stderr, extraStderr) + return fmt.Errorf("%s: local error '%s', child error '%s', stdout:\n%s\nstderr:\n%s\n%s", msg, localErrString, childErrString, stdout, stderr, extraStderr) } if len(extraStderr) > 0 || (!failed && len(stderr) > 0) {