| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package httprr |
| |
| import ( |
| "bytes" |
| "errors" |
| "io" |
| "net/http" |
| "net/http/httptest" |
| "os" |
| "strings" |
| "testing" |
| "testing/iotest" |
| ) |
| |
| func handler(w http.ResponseWriter, r *http.Request) { |
| if strings.HasSuffix(r.URL.Path, "/redirect") { |
| http.Error(w, "redirect me!", 304) |
| return |
| } |
| if r.Method == "GET" { |
| if r.Header.Get("Secret") != "key" { |
| http.Error(w, "missing secret", 666) |
| return |
| } |
| } |
| if r.Method == "POST" { |
| data, err := io.ReadAll(r.Body) |
| if err != nil { |
| panic(err) |
| } |
| if !strings.Contains(string(data), "my Secret") { |
| http.Error(w, "missing body secret", 667) |
| return |
| } |
| } |
| } |
| |
| func always555(w http.ResponseWriter, r *http.Request) { |
| http.Error(w, "should not be making HTTP requests", 555) |
| } |
| |
| func dropPort(r *http.Request) error { |
| if r.URL.Port() != "" { |
| r.URL.Host = r.URL.Host[:strings.LastIndex(r.URL.Host, ":")] |
| r.Host = r.Host[:strings.LastIndex(r.Host, ":")] |
| } |
| return nil |
| } |
| |
| func dropSecretHeader(r *http.Request) error { |
| r.Header.Del("Secret") |
| return nil |
| } |
| |
| func hideSecretBody(r *http.Request) error { |
| if r.Body != nil { |
| body := r.Body.(*Body) |
| body.Data = []byte("redacted") |
| } |
| return nil |
| } |
| |
| func doNothing(b *bytes.Buffer) error { |
| return nil |
| } |
| |
| func doRefresh(b *bytes.Buffer) error { |
| s := b.String() |
| b.Reset() |
| _, _ = b.WriteString(s) |
| return nil |
| } |
| |
| func TestRecordReplay(t *testing.T) { |
| dir := t.TempDir() |
| file := dir + "/rr" |
| |
| // 4 passes: |
| // 0: create |
| // 1: open |
| // 2: Open with -httprecord="r+" |
| // 3: Open with -httprecord="" |
| for pass := range 4 { |
| start := open |
| h := always555 |
| *record = "" |
| switch pass { |
| case 0: |
| start = create |
| h = handler |
| case 2: |
| start = Open |
| *record = "r+" |
| h = handler |
| case 3: |
| start = Open |
| } |
| rr, err := start(file, http.DefaultTransport) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if rr.Recording() { |
| t.Log("RECORDING") |
| } else { |
| t.Log("REPLAYING") |
| } |
| rr.ScrubReq(dropPort, dropSecretHeader) |
| rr.ScrubReq(hideSecretBody) |
| rr.ScrubResp(doNothing, doRefresh) |
| |
| mustNewRequest := func(method, url string, body io.Reader) *http.Request { |
| req, err := http.NewRequest(method, url, body) |
| if err != nil { |
| t.Helper() |
| t.Fatal(err) |
| } |
| return req |
| } |
| |
| mustDo := func(req *http.Request, status int) { |
| resp, err := rr.Client().Do(req) |
| if err != nil { |
| t.Helper() |
| t.Fatal(err) |
| } |
| body, _ := io.ReadAll(resp.Body) |
| resp.Body.Close() |
| if resp.StatusCode != status { |
| t.Helper() |
| t.Fatalf("%v: %s\n%s", req.URL, resp.Status, body) |
| } |
| } |
| |
| srv := httptest.NewServer(http.HandlerFunc(h)) |
| defer srv.Close() |
| |
| req := mustNewRequest("GET", srv.URL+"/myrequest", nil) |
| req.Header.Set("Secret", "key") |
| mustDo(req, 200) |
| |
| req = mustNewRequest("POST", srv.URL+"/myrequest", strings.NewReader("my Secret")) |
| mustDo(req, 200) |
| |
| req = mustNewRequest("GET", srv.URL+"/redirect", nil) |
| mustDo(req, 304) |
| |
| if !rr.Recording() { |
| req = mustNewRequest("GET", srv.URL+"/uncached", nil) |
| resp, err := rr.Client().Do(req) |
| if err == nil { |
| body, _ := io.ReadAll(resp.Body) |
| t.Fatalf("%v: %s\n%s", req.URL, resp.Status, body) |
| } |
| } |
| |
| if err := rr.Close(); err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| data, err := os.ReadFile(file) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if strings.Contains(string(data), "Secret") { |
| t.Fatalf("rr file contains Secret:\n%s", data) |
| } |
| } |
| |
| var badResponseTrace = []byte("httprr trace v1\n" + |
| "92 75\n" + |
| "GET http://127.0.0.1/myrequest HTTP/1.1\r\n" + |
| "Host: 127.0.0.1\r\n" + |
| "User-Agent: Go-http-client/1.1\r\n" + |
| "\r\n" + |
| "HZZP/1.1 200 OK\r\n" + |
| "Date: Wed, 12 Jun 2024 13:55:02 GMT\r\n" + |
| "Content-Length: 0\r\n" + |
| "\r\n") |
| |
| func TestErrors(t *testing.T) { |
| dir := t.TempDir() |
| |
| makeTmpFile := func() string { |
| f, err := os.CreateTemp(dir, "TestErrors") |
| if err != nil { |
| t.Fatalf("failed to create tmp file for test: %v", err) |
| } |
| name := f.Name() |
| f.Close() |
| return name |
| } |
| |
| // -httprecord regexp parsing |
| *record = "+" |
| if _, err := Open(makeTmpFile(), nil); err == nil || !strings.Contains(err.Error(), "invalid -httprecord flag") { |
| t.Errorf("did not diagnose bad -httprecord: err = %v", err) |
| } |
| *record = "" |
| |
| // invalid httprr trace |
| if _, err := Open(makeTmpFile(), nil); err == nil || !strings.Contains(err.Error(), "not an httprr trace") { |
| t.Errorf("did not diagnose invalid httprr trace: err = %v", err) |
| } |
| |
| // corrupt httprr trace |
| corruptTraceFile := makeTmpFile() |
| os.WriteFile(corruptTraceFile, []byte("httprr trace v1\ngarbage\n"), 0666) |
| if _, err := Open(corruptTraceFile, nil); err == nil || !strings.Contains(err.Error(), "corrupt httprr trace") { |
| t.Errorf("did not diagnose invalid httprr trace: err = %v", err) |
| } |
| |
| // os.Create error creating trace |
| if _, err := create("invalid\x00file", nil); err == nil { |
| t.Errorf("did not report failure from os.Create: err = %v", err) |
| } |
| |
| // os.ReadAll error reading trace |
| if _, err := open("nonexistent", nil); err == nil { |
| t.Errorf("did not report failure from os.ReadFile: err = %v", err) |
| } |
| |
| // error reading body |
| rr, err := create(makeTmpFile(), nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := rr.Client().Post("http://127.0.0.1/nonexist", "x/error", iotest.ErrReader(errors.New("MY ERROR"))); err == nil || !strings.Contains(err.Error(), "MY ERROR") { |
| t.Errorf("did not report failure from io.ReadAll(body): err = %v", err) |
| } |
| |
| // error during request scrub |
| rr.ScrubReq(func(*http.Request) error { return errors.New("SCRUB ERROR") }) |
| if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "SCRUB ERROR") { |
| t.Errorf("did not report failure from scrub: err = %v", err) |
| } |
| rr.Close() |
| |
| // error during response scrub |
| rr.ScrubResp(func(*bytes.Buffer) error { return errors.New("SCRUB ERROR") }) |
| if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "SCRUB ERROR") { |
| t.Errorf("did not report failure from scrub: err = %v", err) |
| } |
| rr.Close() |
| |
| // error during rkey.WriteProxy |
| rr, err = create(makeTmpFile(), nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| rr.ScrubReq(func(req *http.Request) error { |
| req.URL = nil |
| req.Host = "" |
| return nil |
| }) |
| rr.ScrubResp(func(b *bytes.Buffer) error { |
| b.Reset() |
| return nil |
| }) |
| if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "no Host or URL set") { |
| t.Errorf("did not report failure from rkey.WriteProxy: err = %v", err) |
| } |
| rr.Close() |
| |
| // error during resp.Write |
| rr, err = create(makeTmpFile(), badRespTransport{}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "TRANSPORT ERROR") { |
| t.Errorf("did not report failure from resp.Write: err = %v", err) |
| } |
| rr.Close() |
| |
| // error during Write logging request |
| srv := httptest.NewServer(http.HandlerFunc(always555)) |
| defer srv.Close() |
| rr, err = create(makeTmpFile(), http.DefaultTransport) |
| if err != nil { |
| t.Fatal(err) |
| } |
| rr.ScrubReq(dropPort) |
| rr.record.Close() // cause write error |
| if _, err := rr.Client().Get(srv.URL + "/redirect"); err == nil || !strings.Contains(err.Error(), "file already closed") { |
| t.Errorf("did not report failure from record write: err = %v", err) |
| } |
| rr.writeErr = errors.New("BROKEN ERROR") |
| if _, err := rr.Client().Get(srv.URL + "/redirect"); err == nil || !strings.Contains(err.Error(), "BROKEN ERROR") { |
| t.Errorf("did not report previous write failure: err = %v", err) |
| } |
| if err := rr.Close(); err == nil || !strings.Contains(err.Error(), "BROKEN ERROR") { |
| t.Errorf("did not report write failure during close: err = %v", err) |
| } |
| |
| // error during RoundTrip |
| rr, err = create(makeTmpFile(), errTransport{errors.New("TRANSPORT ERROR")}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := rr.Client().Get(srv.URL); err == nil || !strings.Contains(err.Error(), "TRANSPORT ERROR") { |
| t.Errorf("did not report failure from transport: err = %v", err) |
| } |
| rr.Close() |
| |
| // error during http.ReadResponse: trace is structurally okay but has malformed response inside |
| tmpFile := makeTmpFile() |
| if err := os.WriteFile(tmpFile, badResponseTrace, 0666); err != nil { |
| t.Fatal(err) |
| } |
| rr, err = Open(tmpFile, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := rr.Client().Get("http://127.0.0.1/myrequest"); err == nil || !strings.Contains(err.Error(), "corrupt httprr trace:") { |
| t.Errorf("did not diagnose invalid httprr trace: err = %v", err) |
| } |
| rr.Close() |
| } |
| |
| type errTransport struct{ err error } |
| |
| func (e errTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
| return nil, e.err |
| } |
| |
| type badRespTransport struct{} |
| |
| func (badRespTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
| resp := new(http.Response) |
| resp.Body = io.NopCloser(iotest.ErrReader(errors.New("TRANSPORT ERROR"))) |
| return resp, nil |
| } |