http: fix chunking bug during content sniffing
R=golang-dev, bradfitz, gri
CC=golang-dev
https://golang.org/cl/4807044
diff --git a/src/pkg/http/httptest/server.go b/src/pkg/http/httptest/server.go
index 879f04f..2ec36d0 100644
--- a/src/pkg/http/httptest/server.go
+++ b/src/pkg/http/httptest/server.go
@@ -9,6 +9,7 @@
import (
"crypto/rand"
"crypto/tls"
+ "flag"
"fmt"
"http"
"net"
@@ -49,15 +50,34 @@
return l
}
+// When debugging a particular http server-based test,
+// this flag lets you run
+// gotest -run=BrokenTest -httptest.serve=127.0.0.1:8000
+// to start the broken server so you can interact with it manually.
+var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
+
// NewServer starts and returns a new Server.
// The caller should call Close when finished, to shut it down.
func NewServer(handler http.Handler) *Server {
ts := new(Server)
- l := newLocalListener()
+ var l net.Listener
+ if *serve != "" {
+ var err os.Error
+ l, err = net.Listen("tcp", *serve)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
+ }
+ } else {
+ l = newLocalListener()
+ }
ts.Listener = &historyListener{l, make([]net.Conn, 0)}
ts.URL = "http://" + l.Addr().String()
server := &http.Server{Handler: handler}
go server.Serve(ts.Listener)
+ if *serve != "" {
+ fmt.Println(os.Stderr, "httptest: serving on", ts.URL)
+ select {}
+ }
return ts
}
diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go
index b3fb8e1..f14ef8c 100644
--- a/src/pkg/http/server.go
+++ b/src/pkg/http/server.go
@@ -255,9 +255,7 @@
} else {
// If no content type, apply sniffing algorithm to body.
if w.header.Get("Content-Type") == "" {
- // NOTE(dsymonds): the sniffing mechanism in this file is currently broken.
- //w.needSniff = true
- w.header.Set("Content-Type", "text/html; charset=utf-8")
+ w.needSniff = true
}
}
@@ -364,10 +362,16 @@
fmt.Fprintf(w.conn.buf, "Content-Type: %s\r\n", DetectContentType(data))
io.WriteString(w.conn.buf, "\r\n")
- if w.chunking && len(data) > 0 {
+ if len(data) == 0 {
+ return
+ }
+ if w.chunking {
fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
}
- w.conn.buf.Write(data)
+ _, err := w.conn.buf.Write(data)
+ if w.chunking && err == nil {
+ io.WriteString(w.conn.buf, "\r\n")
+ }
}
// bodyAllowed returns true if a Write is allowed for this response type.
@@ -401,12 +405,23 @@
var m int
if w.needSniff {
+ // We need to sniff the beginning of the output to
+ // determine the content type. Accumulate the
+ // initial writes in w.conn.body.
body := w.conn.body
- m = copy(body[len(body):], data)
+ m = copy(body[len(body):cap(body)], data)
w.conn.body = body[:len(body)+m]
if m == len(data) {
+ // Copied everything into the buffer.
+ // Wait for next write.
return m, nil
}
+
+ // Filled the buffer; more data remains.
+ // Sniff the content (flushes the buffer)
+ // and then proceed with the remainder
+ // of the data as a normal Write.
+ // Calling sniff clears needSniff.
w.sniff()
data = data[m:]
}
diff --git a/src/pkg/http/sniff_test.go b/src/pkg/http/sniff_test.go
index 770496f..2d01807 100644
--- a/src/pkg/http/sniff_test.go
+++ b/src/pkg/http/sniff_test.go
@@ -2,16 +2,22 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package http
+package http_test
import (
+ "bytes"
+ . "http"
+ "http/httptest"
+ "io/ioutil"
+ "log"
+ "strconv"
"testing"
)
var sniffTests = []struct {
- desc string
- data []byte
- exp string
+ desc string
+ data []byte
+ contentType string
}{
// Some nonsense.
{"Empty", []byte{}, "text/plain; charset=utf-8"},
@@ -30,11 +36,41 @@
{"GIF 89a", []byte(`GIF89a...`), "image/gif"},
}
-func TestSniffing(t *testing.T) {
- for _, st := range sniffTests {
- got := DetectContentType(st.data)
- if got != st.exp {
- t.Errorf("%v: sniffed as %v, want %v", st.desc, got, st.exp)
+func TestDetectContentType(t *testing.T) {
+ for _, tt := range sniffTests {
+ ct := DetectContentType(tt.data)
+ if ct != tt.contentType {
+ t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType)
}
}
}
+
+func TestServerContentType(t *testing.T) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ i, _ := strconv.Atoi(r.FormValue("i"))
+ tt := sniffTests[i]
+ n, err := w.Write(tt.data)
+ if n != len(tt.data) || err != nil {
+ log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
+ }
+ }))
+ defer ts.Close()
+
+ for i, tt := range sniffTests {
+ resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i))
+ if err != nil {
+ t.Errorf("%v: %v", tt.desc, err)
+ continue
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != tt.contentType {
+ t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, tt.contentType)
+ }
+ data, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("%v: reading body: %v", tt.desc, err)
+ } else if !bytes.Equal(data, tt.data) {
+ t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data)
+ }
+ resp.Body.Close()
+ }
+}