| // Copyright 2009 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package websocket |
| |
| import ( |
| "bufio" |
| "fmt" |
| "io" |
| "net/http" |
| ) |
| |
| func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { |
| var hs serverHandshaker = &hybiServerHandshaker{Config: config} |
| code, err := hs.ReadHandshake(buf.Reader, req) |
| if err == ErrBadWebSocketVersion { |
| fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) |
| buf.WriteString("\r\n") |
| buf.WriteString(err.Error()) |
| buf.Flush() |
| return |
| } |
| if err != nil { |
| fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| buf.WriteString("\r\n") |
| buf.WriteString(err.Error()) |
| buf.Flush() |
| return |
| } |
| if handshake != nil { |
| err = handshake(config, req) |
| if err != nil { |
| code = http.StatusForbidden |
| fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| buf.WriteString("\r\n") |
| buf.Flush() |
| return |
| } |
| } |
| err = hs.AcceptHandshake(buf.Writer) |
| if err != nil { |
| code = http.StatusBadRequest |
| fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| buf.WriteString("\r\n") |
| buf.Flush() |
| return |
| } |
| conn = hs.NewServerConn(buf, rwc, req) |
| return |
| } |
| |
| // Server represents a server of a WebSocket. |
| type Server struct { |
| // Config is a WebSocket configuration for new WebSocket connection. |
| Config |
| |
| // Handshake is an optional function in WebSocket handshake. |
| // For example, you can check, or don't check Origin header. |
| // Another example, you can select config.Protocol. |
| Handshake func(*Config, *http.Request) error |
| |
| // Handler handles a WebSocket connection. |
| Handler |
| } |
| |
| // ServeHTTP implements the http.Handler interface for a WebSocket |
| func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
| s.serveWebSocket(w, req) |
| } |
| |
| func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { |
| rwc, buf, err := w.(http.Hijacker).Hijack() |
| if err != nil { |
| panic("Hijack failed: " + err.Error()) |
| } |
| // The server should abort the WebSocket connection if it finds |
| // the client did not send a handshake that matches with protocol |
| // specification. |
| defer rwc.Close() |
| conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) |
| if err != nil { |
| return |
| } |
| if conn == nil { |
| panic("unexpected nil conn") |
| } |
| s.Handler(conn) |
| } |
| |
| // Handler is a simple interface to a WebSocket browser client. |
| // It checks if Origin header is valid URL by default. |
| // You might want to verify websocket.Conn.Config().Origin in the func. |
| // If you use Server instead of Handler, you could call websocket.Origin and |
| // check the origin in your Handshake func. So, if you want to accept |
| // non-browser clients, which do not send an Origin header, set a |
| // Server.Handshake that does not check the origin. |
| type Handler func(*Conn) |
| |
| func checkOrigin(config *Config, req *http.Request) (err error) { |
| config.Origin, err = Origin(config, req) |
| if err == nil && config.Origin == nil { |
| return fmt.Errorf("null origin") |
| } |
| return err |
| } |
| |
| // ServeHTTP implements the http.Handler interface for a WebSocket |
| func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
| s := Server{Handler: h, Handshake: checkOrigin} |
| s.serveWebSocket(w, req) |
| } |