| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package httptest |
| |
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "net/http" |
| "net/textproto" |
| "strconv" |
| "strings" |
| |
| "golang.org/x/net/http/httpguts" |
| ) |
| |
// ResponseRecorder is an http.ResponseWriter implementation that keeps
// a record of everything a handler does to it, so that tests can
// inspect the outcome afterwards.
type ResponseRecorder struct {
	// Code is the HTTP status code stored by WriteHeader.
	//
	// A Handler that calls neither WriteHeader nor Write may leave
	// this at 0 instead of the implicit http.StatusOK. The Result
	// method reports the implicit value in that case.
	Code int

	// HeaderMap holds the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body receives the data passed to the Handler's Write and
	// WriteString calls. When it is nil that data is silently dropped.
	Body *bytes.Buffer

	// Flushed reports whether the Handler called Flush.
	Flushed bool

	// result caches the value returned by Result.
	result *http.Response
	// snapHeader is a copy of HeaderMap taken when the header is
	// written (by WriteHeader), or by Result if that never happened.
	snapHeader http.Header
	// wroteHeader records that WriteHeader has already taken effect.
	wroteHeader bool
}
| |
| // NewRecorder returns an initialized ResponseRecorder. |
| func NewRecorder() *ResponseRecorder { |
| return &ResponseRecorder{ |
| HeaderMap: make(http.Header), |
| Body: new(bytes.Buffer), |
| Code: 200, |
| } |
| } |
| |
// DefaultRemoteAddr is the fallback remote address meant to be reported
// as RemoteAddr when no explicit address has been configured.
//
// NOTE(review): nothing in this file reads this constant, and
// ResponseRecorder declares no remote-address setting; confirm the
// intended consumer before relying on it.
const DefaultRemoteAddr = "1.2.3.4"
| |
| // Header implements http.ResponseWriter. It returns the response |
| // headers to mutate within a handler. To test the headers that were |
| // written after a handler completes, use the Result method and see |
| // the returned Response value's Header. |
| func (rw *ResponseRecorder) Header() http.Header { |
| m := rw.HeaderMap |
| if m == nil { |
| m = make(http.Header) |
| rw.HeaderMap = m |
| } |
| return m |
| } |
| |
| // writeHeader writes a header if it was not written yet and |
| // detects Content-Type if needed. |
| // |
| // bytes or str are the beginning of the response body. |
| // We pass both to avoid unnecessarily generate garbage |
| // in rw.WriteString which was created for performance reasons. |
| // Non-nil bytes win. |
| func (rw *ResponseRecorder) writeHeader(b []byte, str string) { |
| if rw.wroteHeader { |
| return |
| } |
| if len(str) > 512 { |
| str = str[:512] |
| } |
| |
| m := rw.Header() |
| |
| _, hasType := m["Content-Type"] |
| hasTE := m.Get("Transfer-Encoding") != "" |
| if !hasType && !hasTE { |
| if b == nil { |
| b = []byte(str) |
| } |
| m.Set("Content-Type", http.DetectContentType(b)) |
| } |
| |
| rw.WriteHeader(200) |
| } |
| |
| // Write implements http.ResponseWriter. The data in buf is written to |
| // rw.Body, if not nil. |
| func (rw *ResponseRecorder) Write(buf []byte) (int, error) { |
| rw.writeHeader(buf, "") |
| if rw.Body != nil { |
| rw.Body.Write(buf) |
| } |
| return len(buf), nil |
| } |
| |
| // WriteString implements io.StringWriter. The data in str is written |
| // to rw.Body, if not nil. |
| func (rw *ResponseRecorder) WriteString(str string) (int, error) { |
| rw.writeHeader(nil, str) |
| if rw.Body != nil { |
| rw.Body.WriteString(str) |
| } |
| return len(str), nil |
| } |
| |
| // WriteHeader implements http.ResponseWriter. |
| func (rw *ResponseRecorder) WriteHeader(code int) { |
| if rw.wroteHeader { |
| return |
| } |
| rw.Code = code |
| rw.wroteHeader = true |
| if rw.HeaderMap == nil { |
| rw.HeaderMap = make(http.Header) |
| } |
| rw.snapHeader = rw.HeaderMap.Clone() |
| } |
| |
| // Flush implements http.Flusher. To test whether Flush was |
| // called, see rw.Flushed. |
| func (rw *ResponseRecorder) Flush() { |
| if !rw.wroteHeader { |
| rw.WriteHeader(200) |
| } |
| rw.Flushed = true |
| } |
| |
| // Result returns the response generated by the handler. |
| // |
| // The returned Response will have at least its StatusCode, |
| // Header, Body, and optionally Trailer populated. |
| // More fields may be populated in the future, so callers should |
| // not DeepEqual the result in tests. |
| // |
| // The Response.Header is a snapshot of the headers at the time of the |
| // first write call, or at the time of this call, if the handler never |
| // did a write. |
| // |
| // The Response.Body is guaranteed to be non-nil and Body.Read call is |
| // guaranteed to not return any error other than io.EOF. |
| // |
| // Result must only be called after the handler has finished running. |
| func (rw *ResponseRecorder) Result() *http.Response { |
| if rw.result != nil { |
| return rw.result |
| } |
| if rw.snapHeader == nil { |
| rw.snapHeader = rw.HeaderMap.Clone() |
| } |
| res := &http.Response{ |
| Proto: "HTTP/1.1", |
| ProtoMajor: 1, |
| ProtoMinor: 1, |
| StatusCode: rw.Code, |
| Header: rw.snapHeader, |
| } |
| rw.result = res |
| if res.StatusCode == 0 { |
| res.StatusCode = 200 |
| } |
| res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) |
| if rw.Body != nil { |
| res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) |
| } else { |
| res.Body = http.NoBody |
| } |
| res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) |
| |
| if trailers, ok := rw.snapHeader["Trailer"]; ok { |
| res.Trailer = make(http.Header, len(trailers)) |
| for _, k := range trailers { |
| k = http.CanonicalHeaderKey(k) |
| if !httpguts.ValidTrailerHeader(k) { |
| // Ignore since forbidden by RFC 7230, section 4.1.2. |
| continue |
| } |
| vv, ok := rw.HeaderMap[k] |
| if !ok { |
| continue |
| } |
| vv2 := make([]string, len(vv)) |
| copy(vv2, vv) |
| res.Trailer[k] = vv2 |
| } |
| } |
| for k, vv := range rw.HeaderMap { |
| if !strings.HasPrefix(k, http.TrailerPrefix) { |
| continue |
| } |
| if res.Trailer == nil { |
| res.Trailer = make(http.Header) |
| } |
| for _, v := range vv { |
| res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) |
| } |
| } |
| return res |
| } |
| |
// parseContentLength trims whitespace from cl and returns the value it
// holds when that is a non-negative decimal integer that fits in an
// int64. It returns -1 when cl is empty or not such a number.
//
// This is a modified version of the same function found in
// net/http/transfer.go. This one just ignores an invalid header.
func parseContentLength(cl string) int64 {
	trimmed := textproto.TrimString(cl)
	if trimmed == "" {
		return -1
	}
	// A bit size of 63 keeps every accepted value within int64 range.
	if n, err := strconv.ParseUint(trimmed, 10, 63); err == nil {
		return int64(n)
	}
	return -1
}