blob: b98172e06f9eda82db81968a101fcd384209eeef [file] [log] [blame]
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
	"bufio"
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"

	"github.com/bradfitz/http2/hpack"
)
// Transport is an HTTP/2 http.RoundTripper for "https" URLs; requests with
// any other scheme are delegated to Fallback.
type Transport struct {
	// Fallback, if non-nil, handles requests whose URL scheme is not
	// "https". If nil, such requests fail with an error.
	Fallback http.RoundTripper

	// InsecureTLSDial, when true, disables TLS certificate verification
	// (InsecureSkipVerify) and the host name check when dialing.
	// TODO: remove this and make more general with a TLS dial hook, like http
	InsecureTLSDial bool
}
// clientConn is the client side of a single HTTP/2 connection, set up by
// Transport.RoundTrip over a TLS connection that negotiated NextProtoTLS.
type clientConn struct {
	tconn      *tls.Conn
	bw         *bufio.Writer // buffered writes to tconn via stickyErrWriter, which records werr
	br         *bufio.Reader // buffered reads from tconn
	fr         *Framer       // frames are written to bw and read from br
	readerDone chan struct{} // closed when readLoop returns (it only returns on error)
	readerErr  error         // set before readerDone is closed
	werr       error         // first write error that has occurred
	hbuf       bytes.Buffer  // HPACK encoder writes into this
	henc       *hpack.Encoder
	hdec       *hpack.Decoder // decodes response header blocks, calling onNewHeaderField per field
	nextRes    http.Header    // response header block currently being decoded

	// Settings from peer:
	maxFrameSize uint32 // upper bound on each HEADERS/CONTINUATION fragment we send

	mu           sync.Mutex // guards streams and nextStreamID (see newStream, streamByID)
	streams      map[uint32]*clientStream
	nextStreamID uint32 // ID newStream hands out next; advanced by 2 per stream
}
// clientStream is the client's state for one request/response stream.
type clientStream struct {
	ID   uint32
	resc chan *http.Response // readLoop sends the response here once its header block ends
	pw   *io.PipeWriter      // readLoop writes DATA payloads here and closes it on END_STREAM
	pr   *io.PipeReader      // read side of the pipe; becomes the response Body
}
// stickyErrWriter forwards writes to w and stores each write's error in *err.
// Once *err holds a non-nil error, every later Write returns that error
// immediately with n == 0 and never reaches w, so a caller can issue a run of
// writes and check *err once at the end.
type stickyErrWriter struct {
	w   io.Writer
	err *error
}

func (sew stickyErrWriter) Write(p []byte) (int, error) {
	if prev := *sew.err; prev != nil {
		return 0, prev
	}
	written, werr := sew.w.Write(p)
	*sew.err = werr
	return written, werr
}
// RoundTrip implements http.RoundTripper.
//
// Requests whose URL scheme is not "https" are handed to t.Fallback, or fail
// if it is nil. For "https" requests RoundTrip:
//   - dials a new TLS connection, which must negotiate NextProtoTLS via ALPN;
//   - writes the client preface and exchanges SETTINGS;
//   - writes the request header block on a fresh stream, split into HEADERS
//     and CONTINUATION frames no larger than the peer's max frame size;
//   - waits for readLoop to deliver the response headers.
//
// Request bodies are not sent yet (see hasBody).
//
// The connection is closed on every error path. On success it stays open so
// the response body can be read. NOTE(review): nothing closes it afterwards
// yet, so each successful request keeps a connection and a readLoop
// goroutine alive.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	// RFC 7540 §6.5.2: SETTINGS_MAX_FRAME_SIZE is 2^14 until the peer says
	// otherwise, and advertised values must lie in [2^14, 2^24-1].
	const (
		defaultMaxFrameSize = 1 << 14
		maxMaxFrameSize     = 1<<24 - 1
	)

	if req.URL.Scheme != "https" {
		if t.Fallback == nil {
			return nil, errors.New("http2: unsupported scheme and no Fallback")
		}
		return t.Fallback.RoundTrip(req)
	}
	host, port, err := net.SplitHostPort(req.URL.Host)
	if err != nil {
		// No port given. Strip the brackets of an IPv6 literal such as
		// "[::1]" so ServerName and JoinHostPort see the bare address.
		host = req.URL.Host
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
		port = "443"
	}
	cfg := &tls.Config{
		ServerName:         host,
		NextProtos:         []string{NextProtoTLS},
		InsecureSkipVerify: t.InsecureTLSDial,
	}
	// JoinHostPort brackets IPv6 literals; host+":"+port would not.
	tconn, err := tls.Dial("tcp", net.JoinHostPort(host, port), cfg)
	if err != nil {
		return nil, err
	}
	// fail closes tconn before reporting err, so no error path leaks the
	// socket. Once readLoop is running, closing tconn also makes it exit.
	fail := func(err error) (*http.Response, error) {
		tconn.Close()
		return nil, err
	}
	if err := tconn.Handshake(); err != nil {
		return fail(err)
	}
	if !t.InsecureTLSDial {
		if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
			return fail(err)
		}
	}
	state := tconn.ConnectionState()
	if p := state.NegotiatedProtocol; p != NextProtoTLS {
		// TODO(bradfitz): fall back to Fallback
		return fail(fmt.Errorf("bad protocol: %v", p))
	}
	if !state.NegotiatedProtocolIsMutual {
		return fail(errors.New("could not negotiate protocol mutually"))
	}
	if _, err := tconn.Write(clientPreface); err != nil {
		return fail(err)
	}

	cc := &clientConn{
		tconn:        tconn,
		readerDone:   make(chan struct{}),
		nextStreamID: 1,
		streams:      make(map[uint32]*clientStream),
		// Without a default, a peer that never advertises
		// SETTINGS_MAX_FRAME_SIZE would leave this at 0 and the header
		// chunking loop below would spin forever on empty chunks.
		maxFrameSize: defaultMaxFrameSize,
	}
	cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
	cc.br = bufio.NewReader(tconn)
	cc.fr = NewFramer(cc.bw, cc.br)
	cc.henc = hpack.NewEncoder(&cc.hbuf)

	cc.fr.WriteSettings()
	// TODO: re-send more conn-level flow control tokens when server uses all these.
	cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
	cc.bw.Flush()
	if cc.werr != nil {
		return fail(cc.werr)
	}

	// Read the obligatory SETTINGS frame
	f, err := cc.fr.ReadFrame()
	if err != nil {
		return fail(err)
	}
	sf, ok := f.(*SettingsFrame)
	if !ok {
		return fail(fmt.Errorf("expected settings frame, got: %T", f))
	}
	var settingsErr error
	sf.ForeachSetting(func(s Setting) error {
		switch s.ID {
		case SettingMaxFrameSize:
			if s.Val < defaultMaxFrameSize || s.Val > maxMaxFrameSize {
				settingsErr = fmt.Errorf("http2: invalid SETTINGS_MAX_FRAME_SIZE %d", s.Val)
				return settingsErr
			}
			cc.maxFrameSize = s.Val
		// TODO(bradfitz): handle the others
		default:
			log.Printf("Unhandled Setting: %v", s)
		}
		return nil
	})
	if settingsErr != nil {
		return fail(settingsErr)
	}
	// Acknowledge only after the settings have been validated and applied.
	cc.fr.WriteSettingsAck()
	cc.bw.Flush()
	if cc.werr != nil {
		return fail(cc.werr)
	}

	// TODO: figure out henc size
	cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)

	go cc.readLoop()

	cs := cc.newStream()
	hasBody := false // TODO

	// we send: HEADERS[+CONTINUATION] + (DATA?)
	hdrs := cc.encodeHeaders(req)
	first := true
	for len(hdrs) > 0 {
		chunk := hdrs
		if len(chunk) > int(cc.maxFrameSize) {
			chunk = chunk[:cc.maxFrameSize]
		}
		hdrs = hdrs[len(chunk):]
		endHeaders := len(hdrs) == 0
		if first {
			cc.fr.WriteHeaders(HeadersFrameParam{
				StreamID:      cs.ID,
				BlockFragment: chunk,
				EndStream:     !hasBody,
				EndHeaders:    endHeaders,
			})
			first = false
		} else {
			cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
		}
	}
	cc.bw.Flush()
	if cc.werr != nil {
		return fail(cc.werr)
	}

	select {
	case res := <-cs.resc:
		return res, nil
	case <-cc.readerDone:
	}
	// readLoop has stopped. It may have delivered the response just before
	// failing (select picks randomly among ready cases), so prefer a
	// buffered response over the read error.
	select {
	case res := <-cs.resc:
		return res, nil
	default:
	}
	err = cc.readerErr
	if err == nil {
		err = errors.New("http2: connection closed before response headers")
	}
	return fail(err)
}
// encodeHeaders HPACK-encodes req's header block into cc.hbuf and returns the
// encoded bytes. The returned slice aliases cc.hbuf, so it is only valid until
// the next call.
//
// The pseudo-headers come first:
//   - :authority is req.Host, or req.URL.Host when that is empty.
//   - :method is req.Method, with "" meaning GET.
//   - :path is the request URI, including the query.
//   - :scheme is "https".
//
// They are followed by req.Header with lower-cased names. Host is dropped
// because :authority carries it. Connection-specific fields are dropped too,
// since RFC 7540 §8.1.2.2 forbids them in HTTP/2.
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
	cc.hbuf.Reset()

	// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	// net/http treats an empty client Method as GET; an empty :method is invalid.
	method := req.Method
	if method == "" {
		method = "GET"
	}

	cc.writeHeader(":authority", host) // probably not right for all sites
	cc.writeHeader(":method", method)
	// RequestURI keeps the escaped path plus "?query" and yields "/" for an
	// empty path. URL.Path alone dropped the query and could send "".
	cc.writeHeader(":path", req.URL.RequestURI())
	cc.writeHeader(":scheme", "https")

	for k, vv := range req.Header {
		lowKey := strings.ToLower(k)
		switch lowKey {
		case "host", "connection", "proxy-connection", "keep-alive", "transfer-encoding", "upgrade":
			continue
		}
		for _, v := range vv {
			cc.writeHeader(lowKey, v)
		}
	}
	return cc.hbuf.Bytes()
}
// writeHeader logs one header field and HPACK-encodes it with cc.henc, which
// appends the encoded bytes to cc.hbuf.
// NOTE(review): the log line includes the value verbatim, so credentials such
// as Authorization or Cookie end up in the log.
func (cc *clientConn) writeHeader(name, value string) {
	log.Printf("sending %q = %q", name, value)
	field := hpack.HeaderField{Name: name, Value: value}
	cc.henc.WriteField(field)
}
// newStream takes the next client stream ID, registers a new clientStream
// under it in cc.streams, and returns the stream. The ID counter advances by
// two; RoundTrip seeds it at 1, so client stream IDs stay odd. The response
// channel has a buffer of one, so readLoop's send does not wait for RoundTrip.
func (cc *clientConn) newStream() *clientStream {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	id := cc.nextStreamID
	cc.nextStreamID = id + 2
	cs := &clientStream{ID: id, resc: make(chan *http.Response, 1)}
	cc.streams[id] = cs
	return cs
}
// streamByID returns the stream registered under id, or nil when there is
// none (for example, connection-level frames on stream 0).
func (cc *clientConn) streamByID(id uint32) *clientStream {
	cc.mu.Lock()
	found := cc.streams[id]
	cc.mu.Unlock()
	return found
}
// readLoop runs in its own goroutine, started by RoundTrip. It reads frames
// from cc.fr until reading or HPACK decoding fails:
//   - HEADERS and CONTINUATION fragments are fed to cc.hdec, which fills
//     cc.nextRes through onNewHeaderField.
//   - When a header block ends, the stream's response is sent on cs.resc,
//     with the stream's pipe as its Body.
//   - DATA payloads are written to that pipe, and END_STREAM closes it.
//
// Frames for streams that are not registered (including connection-level
// frames on stream 0) are not delivered anywhere. Their header fragments are
// still decoded, because HPACK state is shared across the whole connection.
//
// On error, readLoop records it in cc.readerErr and fails every registered
// stream's body pipe with it. It then closes cc.readerDone.
func (cc *clientConn) readLoop() {
	defer close(cc.readerDone)
	for {
		f, err := cc.fr.ReadFrame()
		if err != nil {
			// TODO: don't log it.
			log.Printf("ReadFrame: %v", err)
			cc.abortStreams(err)
			return
		}
		cs := cc.streamByID(f.Header().StreamID)
		log.Printf("Read %v: %#v", f.Header(), f)
		headersEnded := false
		streamEnded := false
		if ff, ok := f.(interface {
			StreamEnded() bool
		}); ok {
			streamEnded = ff.StreamEnded()
		}
		switch f := f.(type) {
		case *HeadersFrame:
			cc.nextRes = make(http.Header)
			if cs != nil {
				// TODO: a second HEADERS frame (trailers) replaces this pipe.
				cs.pr, cs.pw = io.Pipe()
			}
			if _, err := cc.hdec.Write(f.HeaderBlockFragment()); err != nil {
				cc.abortStreams(err)
				return
			}
			headersEnded = f.HeadersEnded()
		case *ContinuationFrame:
			// TODO: verify stream id is the same
			if _, err := cc.hdec.Write(f.HeaderBlockFragment()); err != nil {
				cc.abortStreams(err)
				return
			}
			headersEnded = f.HeadersEnded()
		case *DataFrame:
			log.Printf("DATA: %q", f.Data())
			if cs != nil && cs.pw != nil {
				cs.pw.Write(f.Data())
			}
		}
		if cs == nil || cs.pw == nil {
			// Unknown stream, a connection-level frame, or a frame that
			// arrived before the stream's HEADERS: nothing to deliver to.
			continue
		}
		if streamEnded {
			cs.pw.Close()
		}
		if headersEnded {
			cs.resc <- cc.newResponse(cs)
		}
	}
}

// abortStreams records err as the reason readLoop stopped. It also fails the
// body pipe of every registered stream with err, so anyone reading a response
// body gets err instead of blocking forever on a dead connection. Pipes that
// already saw END_STREAM keep their EOF, because io.PipeWriter.CloseWithError
// never overwrites an earlier close (Go 1.10+).
func (cc *clientConn) abortStreams(err error) {
	cc.readerErr = err
	cc.mu.Lock()
	defer cc.mu.Unlock()
	for _, cs := range cc.streams {
		if cs.pw != nil {
			cs.pw.CloseWithError(err)
		}
	}
}

// newResponse builds cs's response from the header block just decoded into
// cc.nextRes. A numeric ":status" pseudo-header is moved into StatusCode and
// Status and removed from Header. Otherwise StatusCode stays 0 and the
// header map is returned as decoded.
func (cc *clientConn) newResponse(cs *clientStream) *http.Response {
	res := &http.Response{
		Proto:      "HTTP/2.0",
		ProtoMajor: 2,
		Header:     cc.nextRes,
		Body:       cs.pr,
	}
	if status := res.Header.Get(":status"); status != "" {
		if code, err := strconv.Atoi(status); err == nil {
			res.StatusCode = code
			res.Status = status + " " + http.StatusText(code)
			res.Header.Del(":status")
		}
	}
	return res
}
// onNewHeaderField is the callback cc.hdec invokes for each decoded field. It
// logs the field and appends its value to cc.nextRes under the canonicalized
// field name.
func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
	log.Printf("Header field: %+v", f)
	key := http.CanonicalHeaderKey(f.Name)
	cc.nextRes.Add(key, f.Value)
}