go.net/websocket: allow server configurable

Add websocket.Server to configure WebSocket server handler.

- Config.Header is additional headers to send, so you can use it
  to send cookies or so.
  To read cookies, you can use Conn.Request().Header.
- factor out Handshake.
  You can set func to check origin, subprotocol etc.
  Handler checks origin by default.

Fixes golang/go#4198.
Fixes golang/go#5178.

R=golang-dev, mikioh.mikioh, crobin
CC=golang-dev
https://golang.org/cl/8731044
diff --git a/websocket/client.go b/websocket/client.go
index e59da0b..df54a68 100644
--- a/websocket/client.go
+++ b/websocket/client.go
@@ -9,6 +9,7 @@
 	"crypto/tls"
 	"io"
 	"net"
+	"net/http"
 	"net/url"
 )
 
@@ -34,6 +35,7 @@
 	if err != nil {
 		return
 	}
+	config.Header = http.Header(make(map[string][]string))
 	return
 }
 
diff --git a/websocket/hybi.go b/websocket/hybi.go
index 0023d1c..c6ba6cf 100644
--- a/websocket/hybi.go
+++ b/websocket/hybi.go
@@ -46,6 +46,17 @@
 	ErrBadClosingStatus      = &ProtocolError{"bad closing status"}
 	ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
 	ErrNotImplemented        = &ProtocolError{"not implemented"}
+
+	handshakeHeader = map[string]bool{
+		"Host":                   true,
+		"Upgrade":                true,
+		"Connection":             true,
+		"Sec-Websocket-Key":      true,
+		"Sec-Websocket-Origin":   true,
+		"Sec-Websocket-Version":  true,
+		"Sec-Websocket-Protocol": true,
+		"Sec-Websocket-Accept":   true,
+	}
 )
 
 // A hybiFrameHeader is a frame header as defined in hybi draft.
@@ -408,8 +419,11 @@
 	if len(config.Protocol) > 0 {
 		bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
 	}
-	// TODO(ukai): send extensions.
-	// TODO(ukai): send cookie if any.
+	// TODO(ukai): send Sec-WebSocket-Extensions.
+	err = config.Header.WriteSubset(bw, handshakeHeader)
+	if err != nil {
+		return err
+	}
 
 	bw.WriteString("\r\n")
 	if err = bw.Flush(); err != nil {
@@ -483,21 +497,14 @@
 		return http.StatusBadRequest, ErrChallengeResponse
 	}
 	version := req.Header.Get("Sec-Websocket-Version")
-	var origin string
 	switch version {
 	case "13":
 		c.Version = ProtocolVersionHybi13
-		origin = req.Header.Get("Origin")
 	case "8":
 		c.Version = ProtocolVersionHybi08
-		origin = req.Header.Get("Sec-Websocket-Origin")
 	default:
 		return http.StatusBadRequest, ErrBadWebSocketVersion
 	}
-	c.Origin, err = url.ParseRequestURI(origin)
-	if err != nil {
-		return http.StatusForbidden, err
-	}
 	var scheme string
 	if req.TLS != nil {
 		scheme = "wss"
@@ -520,6 +527,22 @@
 	return http.StatusSwitchingProtocols, nil
 }
 
+// Origin parses Origin header in "req".
+// If origin is "null", returns (nil, nil).
+func Origin(config *Config, req *http.Request) (*url.URL, error) {
+	var origin string
+	switch config.Version {
+	case ProtocolVersionHybi13:
+		origin = req.Header.Get("Origin")
+	case ProtocolVersionHybi08:
+		origin = req.Header.Get("Sec-Websocket-Origin")
+	}
+	if origin == "null" {
+		return nil, nil
+	}
+	return url.ParseRequestURI(origin)
+}
+
 func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
 	if len(c.Protocol) > 0 {
 		if len(c.Protocol) != 1 {
@@ -533,7 +556,13 @@
 	if len(c.Protocol) > 0 {
 		buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
 	}
-	// TODO(ukai): support extensions
+	// TODO(ukai): send Sec-WebSocket-Extensions.
+	if c.Header != nil {
+		err := c.Header.WriteSubset(buf, handshakeHeader)
+		if err != nil {
+			return err
+		}
+	}
 	buf.WriteString("\r\n")
 	return buf.Flush()
 }
diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go
index b527e0b..01ed9e9 100644
--- a/websocket/hybi_test.go
+++ b/websocket/hybi_test.go
@@ -92,6 +92,71 @@
 	}
 }
 
+func TestHybiClientHandshakeWithHeader(t *testing.T) {
+	b := bytes.NewBuffer([]byte{})
+	bw := bufio.NewWriter(b)
+	br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
+Upgrade: websocket
+Connection: Upgrade
+Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+Sec-WebSocket-Protocol: chat
+
+`))
+	var err error
+	config := new(Config)
+	config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
+	if err != nil {
+		t.Fatal("location url", err)
+	}
+	config.Origin, err = url.ParseRequestURI("http://example.com")
+	if err != nil {
+		t.Fatal("origin url", err)
+	}
+	config.Protocol = append(config.Protocol, "chat")
+	config.Protocol = append(config.Protocol, "superchat")
+	config.Version = ProtocolVersionHybi13
+	config.Header = http.Header(make(map[string][]string))
+	config.Header.Add("User-Agent", "test")
+
+	config.handshakeData = map[string]string{
+		"key": "dGhlIHNhbXBsZSBub25jZQ==",
+	}
+	err = hybiClientHandshake(config, br, bw)
+	if err != nil {
+		t.Errorf("handshake failed: %v", err)
+	}
+	req, err := http.ReadRequest(bufio.NewReader(b))
+	if err != nil {
+		t.Fatalf("read request: %v", err)
+	}
+	if req.Method != "GET" {
+		t.Errorf("request method expected GET, but got %q", req.Method)
+	}
+	if req.URL.Path != "/chat" {
+		t.Errorf("request path expected /chat, but got %q", req.URL.Path)
+	}
+	if req.Proto != "HTTP/1.1" {
+		t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
+	}
+	if req.Host != "server.example.com" {
+		t.Errorf("request Host expected server.example.com, but got %v", req.Host)
+	}
+	var expectedHeader = map[string]string{
+		"Connection":             "Upgrade",
+		"Upgrade":                "websocket",
+		"Sec-Websocket-Key":      config.handshakeData["key"],
+		"Origin":                 config.Origin.String(),
+		"Sec-Websocket-Protocol": "chat, superchat",
+		"Sec-Websocket-Version":  fmt.Sprintf("%d", ProtocolVersionHybi13),
+		"User-Agent":             "test",
+	}
+	for k, v := range expectedHeader {
+		if req.Header.Get(k) != v {
+			t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
+		}
+	}
+}
+
 func TestHybiClientHandshakeHybi08(t *testing.T) {
 	b := bytes.NewBuffer([]byte{})
 	bw := bufio.NewWriter(b)
diff --git a/websocket/server.go b/websocket/server.go
index 428bfb4..54e05b4 100644
--- a/websocket/server.go
+++ b/websocket/server.go
@@ -11,8 +11,7 @@
 	"net/http"
 )
 
-func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request) (conn *Conn, err error) {
-	config := new(Config)
+func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
 	var hs serverHandshaker = &hybiServerHandshaker{Config: config}
 	code, err := hs.ReadHandshake(buf.Reader, req)
 	if err == ErrBadWebSocketVersion {
@@ -38,8 +37,16 @@
 		buf.Flush()
 		return
 	}
-	config.Protocol = nil
-
+	if handshake != nil {
+		err = handshake(config, req)
+		if err != nil {
+			code = http.StatusForbidden
+			fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
+			buf.WriteString("\r\n")
+			buf.Flush()
+			return
+		}
+	}
 	err = hs.AcceptHandshake(buf.Writer)
 	if err != nil {
 		code = http.StatusBadRequest
@@ -52,11 +59,26 @@
 	return
 }
 
-// Handler is an interface to a WebSocket.
-type Handler func(*Conn)
+// Server represents a server of a WebSocket.
+type Server struct {
+	// Config is a WebSocket configuration for new WebSocket connection.
+	Config
 
-// ServeHTTP implements the http.Handler interface for a Web Socket
-func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	// Handshake is an optional function in WebSocket handshake.
+	// For example, you can check, or don't check Origin header.
+	// Another example, you can select config.Protocol.
+	Handshake func(*Config, *http.Request) error
+
+	// Handler handles a WebSocket connection.
+	Handler
+}
+
+// ServeHTTP implements the http.Handler interface for a WebSocket
+func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	s.serveWebSocket(w, req)
+}
+
+func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
 	rwc, buf, err := w.(http.Hijacker).Hijack()
 	if err != nil {
 		panic("Hijack failed: " + err.Error())
@@ -66,12 +88,35 @@
 	// the client did not send a handshake that matches with protocol
 	// specification.
 	defer rwc.Close()
-	conn, err := newServerConn(rwc, buf, req)
+	conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
 	if err != nil {
 		return
 	}
 	if conn == nil {
 		panic("unexpected nil conn")
 	}
-	h(conn)
+	s.Handler(conn)
+}
+
+// Handler is a simple interface to a WebSocket browser client.
+// It checks if Origin header is valid URL by default.
+// You might want to verify websocket.Conn.Config().Origin in the func.
+// If you use Server instead of Handler, you could call websocket.Origin and
+// check the origin in your Handshake func. So, if you want to accept
+// non-browser client, which doesn't send Origin header, you could use Server
+//. that doesn't check origin in its Handshake.
+type Handler func(*Conn)
+
+func checkOrigin(config *Config, req *http.Request) (err error) {
+	config.Origin, err = Origin(config, req)
+	if err == nil && config.Origin == nil {
+		return fmt.Errorf("null origin")
+	}
+	return err
+}
+
+// ServeHTTP implements the http.Handler interface for a WebSocket
+func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	s := Server{Handler: h, Handshake: checkOrigin}
+	s.serveWebSocket(w, req)
 }
diff --git a/websocket/websocket.go b/websocket/websocket.go
index 793e510..861b3c6 100644
--- a/websocket/websocket.go
+++ b/websocket/websocket.go
@@ -87,6 +87,9 @@
 	// TLS config for secure WebSocket (wss).
 	TlsConfig *tls.Config
 
+	// Additional header fields to be sent in WebSocket opening handshake.
+	Header http.Header
+
 	handshakeData map[string]string
 }
 
diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go
index 40c147f..53e445b 100644
--- a/websocket/websocket_test.go
+++ b/websocket/websocket_test.go
@@ -44,9 +44,30 @@
 	}
 }
 
+func subProtocolHandshake(config *Config, req *http.Request) error {
+	for _, proto := range config.Protocol {
+		if proto == "chat" {
+			config.Protocol = []string{proto}
+			return nil
+		}
+	}
+	return ErrBadWebSocketProtocol
+}
+
+func subProtoServer(ws *Conn) {
+	for _, proto := range ws.Config().Protocol {
+		io.WriteString(ws, proto)
+	}
+}
+
 func startServer() {
 	http.Handle("/echo", Handler(echoServer))
 	http.Handle("/count", Handler(countServer))
+	subproto := Server{
+		Handshake: subProtocolHandshake,
+		Handler:   Handler(subProtoServer),
+	}
+	http.Handle("/subproto", subproto)
 	server := httptest.NewServer(nil)
 	serverAddr = server.Listener.Addr().String()
 	log.Print("Test WebSocket server listening on ", serverAddr)
@@ -177,7 +198,7 @@
 	ws.Close()
 }
 
-func TestWithProtocol(t *testing.T) {
+func testWithProtocol(t *testing.T, subproto []string) (string, error) {
 	once.Do(startServer)
 
 	client, err := net.Dial("tcp", serverAddr)
@@ -185,15 +206,47 @@
 		t.Fatal("dialing", err)
 	}
 
-	config := newConfig(t, "/echo")
-	config.Protocol = append(config.Protocol, "test")
+	config := newConfig(t, "/subproto")
+	config.Protocol = subproto
 
 	ws, err := NewClient(config, client)
 	if err != nil {
-		t.Errorf("WebSocket handshake: %v", err)
-		return
+		return "", err
+	}
+	msg := make([]byte, 16)
+	n, err := ws.Read(msg)
+	if err != nil {
+		return "", err
 	}
 	ws.Close()
+	return string(msg[:n]), nil
+}
+
+func TestWithProtocol(t *testing.T) {
+	proto, err := testWithProtocol(t, []string{"chat"})
+	if err != nil {
+		t.Errorf("SubProto: unexpected error: %v", err)
+	}
+	if proto != "chat" {
+		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
+	}
+}
+
+func TestWithTwoProtocol(t *testing.T) {
+	proto, err := testWithProtocol(t, []string{"test", "chat"})
+	if err != nil {
+		t.Errorf("SubProto: unexpected error: %v", err)
+	}
+	if proto != "chat" {
+		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
+	}
+}
+
+func TestWithBadProtocol(t *testing.T) {
+	_, err := testWithProtocol(t, []string{"test"})
+	if err != ErrBadStatus {
+		t.Errorf("SubProto: expected %q, got %q", ErrBadStatus)
+	}
 }
 
 func TestHTTP(t *testing.T) {