| // Copyright 2015 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // +build go1.6 |
| |
| package http2 |
| |
| import ( |
| "crypto/tls" |
| "fmt" |
| "net/http" |
| ) |
| |
| func configureTransport(t1 *http.Transport) (*Transport, error) { |
| connPool := new(clientConnPool) |
| t2 := &Transport{ |
| ConnPool: noDialClientConnPool{connPool}, |
| t1: t1, |
| } |
| connPool.t = t2 |
| if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { |
| return nil, err |
| } |
| if t1.TLSClientConfig == nil { |
| t1.TLSClientConfig = new(tls.Config) |
| } |
| if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { |
| t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) |
| } |
| if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { |
| t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") |
| } |
| upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { |
| addr := authorityAddr("https", authority) |
| if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { |
| go c.Close() |
| return erringRoundTripper{err} |
| } else if !used { |
| // Turns out we don't need this c. |
| // For example, two goroutines made requests to the same host |
| // at the same time, both kicking off TCP dials. (since protocol |
| // was unknown) |
| go c.Close() |
| } |
| return t2 |
| } |
| if m := t1.TLSNextProto; len(m) == 0 { |
| t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ |
| "h2": upgradeFn, |
| } |
| } else { |
| m["h2"] = upgradeFn |
| } |
| return t2, nil |
| } |
| |
// registerHTTPSProtocol registers rt as the "https" protocol handler on t
// via Transport.RegisterProtocol. A panic raised by that call is recovered
// and returned as an error whose message is the panic value formatted with
// %v; on success the result is nil.
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// The named result lets the deferred function replace the return value.
		err = fmt.Errorf("%v", r)
	}()
	t.RegisterProtocol("https", rt)
	return nil
}
| |
| // noDialH2RoundTripper is a RoundTripper which only tries to complete the request |
| // if there's already has a cached connection to the host. |
| type noDialH2RoundTripper struct{ t *Transport } |
| |
| func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
| res, err := rt.t.RoundTrip(req) |
| if err == ErrNoCachedConn { |
| return nil, http.ErrSkipAltProtocol |
| } |
| return res, err |
| } |