| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package revdial implements a Dialer and Listener which work together |
| // to turn an accepted connection (for instance, a Hijacked HTTP request) into |
| // a Dialer which can then create net.Conns connecting back to the original |
| // dialer, which then gets a net.Listener accepting those conns. |
| // |
| // This is basically a very minimal SOCKS5 client & server. |
| // |
| // The motivation is that sometimes you want to run a server on a |
| // machine deep inside a NAT. Rather than connecting to the machine |
| // directly (which you can't, because of the NAT), you have the |
| // sequestered machine connect out to a public machine. Both sides |
| // then use revdial and the public machine can become a client for the |
| // NATed machine. |
| package revdial |
| |
| import ( |
| "bufio" |
| "context" |
| "crypto/rand" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "log" |
| "net" |
| "net/http" |
| "net/url" |
| "strings" |
| "sync" |
| "time" |
| ) |
| |
// dialerUniqParam names the URL query parameter that carries a Dialer's
// random unique ID on connection pick-up requests.
const (
	dialerUniqParam = "revdial.dialer"
)
| |
// Dialer is the side of a revdial pairing that creates new connections
// back to the peer's Listener.
type Dialer struct {
	// Identity and addressing.
	uniqID     string // random ID under which this Dialer is registered
	path       string // e.g. "/revdial"
	pickupPath string // path + uniqID: "/revdial?revdial.dialer="+uniqID

	// conn is the control connection (a hijacked client conn) used to
	// exchange control messages with the peer.
	conn net.Conn

	// Channels coordinating Dial, serve, and matchConn.
	connReady    chan bool     // Dial -> serve: a new connection is wanted
	incomingConn chan net.Conn // matchConn -> Dial: the picked-up connection
	pickupFailed chan error    // control reader -> Dial: the peer reported a failed pick-up

	// Shutdown state.
	donec     chan struct{} // closed when the Dialer is closed
	closeOnce sync.Once     // makes close run at most once
}
| |
| var ( |
| dmapMu sync.Mutex |
| dialers = map[string]*Dialer{} |
| ) |
| |
| // NewDialer returns the side of the connection which will initiate |
| // new connections. This will typically be the side which did the HTTP |
| // Hijack. The connection is (typically) the hijacked HTTP client |
| // connection. The connPath is the HTTP path and optional query (but |
| // without scheme or host) on the dialer where the ConnHandler is |
| // mounted. |
| func NewDialer(c net.Conn, connPath string) *Dialer { |
| d := &Dialer{ |
| path: connPath, |
| uniqID: newUniqID(), |
| conn: c, |
| donec: make(chan struct{}), |
| connReady: make(chan bool), |
| incomingConn: make(chan net.Conn), |
| pickupFailed: make(chan error), |
| } |
| |
| join := "?" |
| if strings.Contains(connPath, "?") { |
| join = "&" |
| } |
| d.pickupPath = connPath + join + dialerUniqParam + "=" + d.uniqID |
| d.register() |
| go d.serve() |
| return d |
| } |
| |
// newUniqID returns a fresh identifier built from 16 cryptographically
// random bytes, rendered as 32 lowercase hexadecimal characters.
//
// The ID is what ties a pick-up request to its Dialer, so it must not be
// predictable. If the system's random source fails, the function panics
// instead of returning an ID derived from an unfilled (all-zero) buffer.
func newUniqID() string {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("revdial: reading random bytes for dialer ID: %v", err))
	}
	return fmt.Sprintf("%x", buf)
}
| |
| func (d *Dialer) register() { |
| dmapMu.Lock() |
| defer dmapMu.Unlock() |
| dialers[d.uniqID] = d |
| } |
| |
| func (d *Dialer) unregister() { |
| dmapMu.Lock() |
| defer dmapMu.Unlock() |
| delete(dialers, d.uniqID) |
| } |
| |
| // Done returns a channel which is closed when d is closed (either by |
| // this process on purpose, by a local error, or close or error from |
| // the peer). |
| func (d *Dialer) Done() <-chan struct{} { return d.donec } |
| |
| // Close closes the Dialer. |
| func (d *Dialer) Close() error { |
| d.closeOnce.Do(d.close) |
| return nil |
| } |
| |
| func (d *Dialer) close() { |
| d.unregister() |
| d.conn.Close() |
| close(d.donec) |
| } |
| |
| // Dial creates a new connection back to the Listener. |
| func (d *Dialer) Dial(ctx context.Context) (net.Conn, error) { |
| // First, tell serve that we want a connection: |
| select { |
| case d.connReady <- true: |
| case <-d.donec: |
| return nil, errors.New("revdial.Dialer closed") |
| case <-ctx.Done(): |
| return nil, ctx.Err() |
| } |
| |
| // Then pick it up: |
| select { |
| case c := <-d.incomingConn: |
| return c, nil |
| case err := <-d.pickupFailed: |
| return nil, err |
| case <-d.donec: |
| return nil, errors.New("revdial.Dialer closed") |
| case <-ctx.Done(): |
| return nil, ctx.Err() |
| } |
| } |
| |
| func (d *Dialer) matchConn(c net.Conn) { |
| select { |
| case d.incomingConn <- c: |
| case <-d.donec: |
| } |
| } |
| |
| // serve blocks and runs the control message loop, keeping the peer |
| // alive and notifying the peer when new connections are available. |
| func (d *Dialer) serve() error { |
| defer d.Close() |
| go func() { |
| defer d.Close() |
| br := bufio.NewReader(d.conn) |
| for { |
| line, err := br.ReadSlice('\n') |
| if err != nil { |
| return |
| } |
| var msg controlMsg |
| if err := json.Unmarshal(line, &msg); err != nil { |
| log.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err) |
| return |
| } |
| switch msg.Command { |
| case "pickup-failed": |
| err := fmt.Errorf("revdial listener failed to pick up connection: %v", msg.Err) |
| select { |
| case d.pickupFailed <- err: |
| case <-d.donec: |
| return |
| } |
| } |
| } |
| }() |
| for { |
| if err := d.sendMessage(controlMsg{Command: "keep-alive"}); err != nil { |
| return err |
| } |
| |
| t := time.NewTimer(30 * time.Second) |
| select { |
| case <-t.C: |
| continue |
| case <-d.connReady: |
| t.Stop() |
| if err := d.sendMessage(controlMsg{ |
| Command: "conn-ready", |
| ConnPath: d.pickupPath, |
| }); err != nil { |
| return err |
| } |
| case <-d.donec: |
| t.Stop() |
| return errors.New("revdial.Dialer closed") |
| } |
| } |
| } |
| |
| func (d *Dialer) sendMessage(m controlMsg) error { |
| j, _ := json.Marshal(m) |
| d.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) |
| j = append(j, '\n') |
| _, err := d.conn.Write(j) |
| d.conn.SetWriteDeadline(time.Time{}) |
| return err |
| } |
| |
| // NewListener returns a new Listener, accepting connections which |
| // arrive from the provided server connection, which should be after |
| // any necessary authentication (usually after an HTTP exchange). |
| // |
| // The provided dialServer func is responsible for connecting back to |
| // the server and doing TLS setup. |
| func NewListener(serverConn net.Conn, dialServer func(context.Context) (net.Conn, error)) *Listener { |
| ln := &Listener{ |
| sc: serverConn, |
| dial: dialServer, |
| connc: make(chan net.Conn, 8), // arbitrary |
| donec: make(chan struct{}), |
| } |
| go ln.run() |
| return ln |
| } |
| |
| var _ net.Listener = (*Listener)(nil) |
| |
| // Listener is a net.Listener, returning new connections which arrive |
| // from a corresponding Dialer. |
| type Listener struct { |
| sc net.Conn |
| connc chan net.Conn |
| donec chan struct{} |
| dial func(context.Context) (net.Conn, error) |
| writec chan<- []byte |
| |
| mu sync.Mutex // guards below, closing connc, and writing to rw |
| readErr error |
| closed bool |
| } |
| |
// controlMsg is the JSON message exchanged, one per line, over the control
// connection between a Dialer and a Listener. Empty fields are omitted
// from the encoding.
type controlMsg struct {
	Command  string `json:"command,omitempty"`  // "keep-alive", "conn-ready", "pickup-failed"
	ConnPath string `json:"connPath,omitempty"` // conn pick-up URL path for "conn-ready", "pickup-failed"
	Err      string `json:"err,omitempty"`      // failure description sent with "pickup-failed"
}
| |
| // run reads control messages from the public server forever until the connection dies, which |
| // then closes the listener. |
| func (ln *Listener) run() { |
| defer ln.Close() |
| |
| // Write loop |
| writec := make(chan []byte, 8) |
| ln.writec = writec |
| go func() { |
| for { |
| select { |
| case <-ln.donec: |
| return |
| case msg := <-writec: |
| if _, err := ln.sc.Write(msg); err != nil { |
| log.Printf("revdial.Listener: error writing message to server: %v", err) |
| ln.Close() |
| return |
| } |
| } |
| } |
| }() |
| |
| // Read loop |
| br := bufio.NewReader(ln.sc) |
| for { |
| line, err := br.ReadSlice('\n') |
| if err != nil { |
| return |
| } |
| var msg controlMsg |
| if err := json.Unmarshal(line, &msg); err != nil { |
| log.Printf("revdial.Listener read invalid JSON: %q: %v", line, err) |
| return |
| } |
| switch msg.Command { |
| case "keep-alive": |
| // Occasional no-op message from server to keep |
| // us alive through NAT timeouts. |
| case "conn-ready": |
| go ln.grabConn(msg.ConnPath) |
| default: |
| // Ignore unknown messages |
| } |
| } |
| } |
| |
| func (ln *Listener) sendMessage(m controlMsg) { |
| j, _ := json.Marshal(m) |
| j = append(j, '\n') |
| ln.writec <- j |
| } |
| |
| func (ln *Listener) grabConn(path string) { |
| ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) |
| defer cancel() |
| c, err := ln.dial(ctx) |
| if err != nil { |
| ln.sendMessage(controlMsg{Command: "pickup-failed", ConnPath: path, Err: err.Error()}) |
| return |
| } |
| failPickup := func(err error) { |
| c.Close() |
| log.Printf("revdial.Listener: failed to pick up connection to %s: %v", path, err) |
| ln.sendMessage(controlMsg{Command: "pickup-failed", ConnPath: path, Err: err.Error()}) |
| } |
| bufr := bufio.NewReader(c) |
| |
| success := false |
| const maxRedirects = 2 |
| for i := 0; i < maxRedirects; i++ { |
| req, _ := http.NewRequest("GET", path, nil) |
| if err := req.Write(c); err != nil { |
| failPickup(err) |
| return |
| } |
| path, err = ReadProtoSwitchOrRedirect(bufr, req) |
| if err != nil { |
| failPickup(fmt.Errorf("switch failed: %v", err)) |
| return |
| } |
| if path == "" { |
| success = true |
| break |
| } |
| } |
| if !success { |
| failPickup(errors.New("too many redirects")) |
| return |
| } |
| |
| select { |
| case ln.connc <- c: |
| case <-ln.donec: |
| } |
| } |
| |
| // Closed reports whether the listener has been closed. |
| func (ln *Listener) Closed() bool { |
| ln.mu.Lock() |
| defer ln.mu.Unlock() |
| return ln.closed |
| } |
| |
| // Accept blocks and returns a new connection, or an error. |
| func (ln *Listener) Accept() (net.Conn, error) { |
| c, ok := <-ln.connc |
| if !ok { |
| ln.mu.Lock() |
| err, closed := ln.readErr, ln.closed |
| ln.mu.Unlock() |
| if err != nil && !closed { |
| return nil, fmt.Errorf("revdial: Listener closed; %v", err) |
| } |
| return nil, ErrListenerClosed |
| } |
| return c, nil |
| } |
| |
| // ErrListenerClosed is returned by Accept after Close has been called. |
| var ErrListenerClosed = errors.New("revdial: Listener closed") |
| |
| // Close closes the Listener, making future Accept calls return an |
| // error. |
| func (ln *Listener) Close() error { |
| ln.mu.Lock() |
| defer ln.mu.Unlock() |
| if ln.closed { |
| return nil |
| } |
| go ln.sc.Close() |
| ln.closed = true |
| close(ln.connc) |
| close(ln.donec) |
| return nil |
| } |
| |
| // Addr returns a dummy address. This exists only to conform to the |
| // net.Listener interface. |
| func (ln *Listener) Addr() net.Addr { return fakeAddr{} } |
| |
// fakeAddr is the placeholder address type reported by Listener.Addr.
type fakeAddr struct{}

// Network returns the fixed network name "revdial".
func (fakeAddr) Network() string {
	const network = "revdial"
	return network
}

// String returns the fixed address text "revdialconn".
func (fakeAddr) String() string {
	const address = "revdialconn"
	return address
}
| |
| // ConnHandler returns the HTTP handler that needs to be mounted somewhere |
| // that the Listeners can dial out and get to. A dialer to connect to it |
| // is given to NewListener and the path to reach it is given to NewDialer |
| // to use in messages to the listener. |
| func ConnHandler() http.Handler { |
| return http.HandlerFunc(connHandler) |
| } |
| |
| func connHandler(w http.ResponseWriter, r *http.Request) { |
| if r.TLS == nil { |
| http.Error(w, "handler requires TLS", http.StatusInternalServerError) |
| return |
| } |
| if r.Method != "GET" { |
| w.Header().Set("Allow", "GET") |
| http.Error(w, "expected GET request to revdial conn handler", http.StatusMethodNotAllowed) |
| return |
| } |
| dialerUniq := r.FormValue(dialerUniqParam) |
| |
| dmapMu.Lock() |
| d, ok := dialers[dialerUniq] |
| dmapMu.Unlock() |
| if !ok { |
| http.Error(w, "unknown dialer", http.StatusBadRequest) |
| return |
| } |
| |
| conn, _, err := w.(http.Hijacker).Hijack() |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } |
| (&http.Response{StatusCode: http.StatusSwitchingProtocols, Proto: "HTTP/1.1"}).Write(conn) |
| d.matchConn(conn) |
| } |
| |
// checkRelativeURL reports an error unless s parses as a URL that carries
// neither a scheme nor a host, i.e. one that cannot redirect a request to
// a different scheme or host.
func checkRelativeURL(s string) error {
	parsed, err := url.Parse(s)
	switch {
	case err != nil:
		return err
	case parsed.Scheme != "":
		return fmt.Errorf("URL %q is not relative: contains scheme", s)
	case parsed.Host != "":
		return fmt.Errorf("URL %q is not relative: contains host", s)
	}
	return nil
}
| |
| // ReadProtoSwitchOrRedirect is a helper for completing revdial protocol switch |
| // requests. If the response indicates successful switch, nothing is returned. |
| // If the response indicates a redirect, the new location is returned. |
| func ReadProtoSwitchOrRedirect(r *bufio.Reader, req *http.Request) (location string, err error) { |
| resp, err := http.ReadResponse(r, req) |
| if err != nil { |
| return "", fmt.Errorf("error reading response: %v", err) |
| } |
| switch resp.StatusCode { |
| case http.StatusSwitchingProtocols: |
| // Success! Don't read body, as caller may want it. |
| return "", nil |
| case http.StatusTemporaryRedirect: |
| // Redirect. Discard body. |
| msg, _ := io.ReadAll(resp.Body) |
| location := resp.Header.Get("Location") |
| if location == "" { |
| return "", fmt.Errorf("redirect missing Location header; got %+v:\n\t%s", resp, msg) |
| } |
| if err := checkRelativeURL(location); err != nil { |
| return "", fmt.Errorf("redirect Location must be relative: %w", err) |
| } |
| // Retry at new location. |
| return location, nil |
| default: |
| msg, _ := io.ReadAll(resp.Body) |
| return "", fmt.Errorf("want HTTP status 101 or 307; got %v:\n\t%s", resp.Status, msg) |
| } |
| } |