// Package httpcache provides an http.RoundTripper implementation that works as a
// mostly RFC-compliant cache for HTTP responses.
| // |
| // It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client |
| // and not for a shared proxy). |
| // |
| package httpcache |
| |
| import ( |
| "bufio" |
| "bytes" |
| "errors" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net/http" |
| "net/http/httputil" |
| "strings" |
| "sync" |
| "time" |
| ) |
| |
// Freshness classifications returned by getFreshness.
const (
	stale       = iota // the cached response must be revalidated before use
	fresh              // the cached response may be served directly
	transparent        // the cached response must not be used for this request
)

// XFromCache is the header added to responses that are returned from the cache
const XFromCache = "X-From-Cache"
| |
// Cache is the storage backend used by Transport. Entries are the serialized
// (wire-format) bytes of responses, keyed by a string derived from the request.
type Cache interface {
	// Get looks up key and returns its stored response bytes, with ok
	// reporting whether an entry was found.
	Get(key string) (responseBytes []byte, ok bool)
	// Set stores responseBytes under key.
	Set(key string, responseBytes []byte)
	// Delete removes whatever is stored under key.
	Delete(key string)
}
| |
// cacheKey returns the cache key for req: the request URL for GET requests,
// and "<METHOD> <URL>" for every other method, so that entries for different
// methods on the same URL never collide.
func cacheKey(req *http.Request) string {
	key := req.URL.String()
	if req.Method != http.MethodGet {
		key = req.Method + " " + key
	}
	return key
}
| |
| // CachedResponse returns the cached http.Response for req if present, and nil |
| // otherwise. |
| func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) { |
| cachedVal, ok := c.Get(cacheKey(req)) |
| if !ok { |
| return |
| } |
| |
| b := bytes.NewBuffer(cachedVal) |
| return http.ReadResponse(bufio.NewReader(b), req) |
| } |
| |
// MemoryCache is an implementation of Cache that stores responses in an in-memory map.
// All methods take the internal lock, so a *MemoryCache may be shared between
// goroutines. The zero value is an empty cache ready to use.
type MemoryCache struct {
	mu    sync.RWMutex
	items map[string][]byte
}

// Get returns the []byte representation of the response stored under key and
// true if present, or nil and false if not.
func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	resp, ok = c.items[key] // reading a nil map is safe and reports !ok
	return resp, ok
}

// Set saves response resp to the cache with key, replacing any previous value.
func (c *MemoryCache) Set(key string, resp []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.items == nil {
		// Allocate lazily so that a zero-value MemoryCache does not panic.
		c.items = make(map[string][]byte)
	}
	c.items[key] = resp
}

// Delete removes key from the cache. Deleting an absent key is a no-op.
func (c *MemoryCache) Delete(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.items, key) // delete on a nil map is a no-op
}

// NewMemoryCache returns a new, empty Cache that will store items in an in-memory map.
func NewMemoryCache() *MemoryCache {
	return &MemoryCache{items: map[string][]byte{}}
}
| |
| // Transport is an implementation of http.RoundTripper that will return values from a cache |
| // where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) |
| // to repeated requests allowing servers to return 304 / Not Modified |
| type Transport struct { |
| // The RoundTripper interface actually used to make requests |
| // If nil, http.DefaultTransport is used |
| Transport http.RoundTripper |
| Cache Cache |
| // If true, responses returned from the cache will be given an extra header, X-From-Cache |
| MarkCachedResponses bool |
| } |
| |
| // NewTransport returns a new Transport with the |
| // provided Cache implementation and MarkCachedResponses set to true |
| func NewTransport(c Cache) *Transport { |
| return &Transport{Cache: c, MarkCachedResponses: true} |
| } |
| |
| // Client returns an *http.Client that caches responses. |
| func (t *Transport) Client() *http.Client { |
| return &http.Client{Transport: t} |
| } |
| |
| // varyMatches will return false unless all of the cached values for the headers listed in Vary |
| // match the new request |
| func varyMatches(cachedResp *http.Response, req *http.Request) bool { |
| for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { |
| header = http.CanonicalHeaderKey(header) |
| if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { |
| return false |
| } |
| } |
| return true |
| } |
| |
| // RoundTrip takes a Request and returns a Response |
| // |
| // If there is a fresh Response already in cache, then it will be returned without connecting to |
| // the server. |
| // |
| // If there is a stale Response, then any validators it contains will be set on the new request |
| // to give the server a chance to respond with NotModified. If this happens, then the cached Response |
| // will be returned. |
| func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { |
| cacheKey := cacheKey(req) |
| cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" |
| var cachedResp *http.Response |
| if cacheable { |
| cachedResp, err = CachedResponse(t.Cache, req) |
| } else { |
| // Need to invalidate an existing value |
| t.Cache.Delete(cacheKey) |
| } |
| |
| transport := t.Transport |
| if transport == nil { |
| transport = http.DefaultTransport |
| } |
| |
| if cacheable && cachedResp != nil && err == nil { |
| if t.MarkCachedResponses { |
| cachedResp.Header.Set(XFromCache, "1") |
| } |
| |
| if varyMatches(cachedResp, req) { |
| // Can only use cached value if the new request doesn't Vary significantly |
| freshness := getFreshness(cachedResp.Header, req.Header) |
| if freshness == fresh { |
| return cachedResp, nil |
| } |
| |
| if freshness == stale { |
| var req2 *http.Request |
| // Add validators if caller hasn't already done so |
| etag := cachedResp.Header.Get("etag") |
| if etag != "" && req.Header.Get("etag") == "" { |
| req2 = cloneRequest(req) |
| req2.Header.Set("if-none-match", etag) |
| } |
| lastModified := cachedResp.Header.Get("last-modified") |
| if lastModified != "" && req.Header.Get("last-modified") == "" { |
| if req2 == nil { |
| req2 = cloneRequest(req) |
| } |
| req2.Header.Set("if-modified-since", lastModified) |
| } |
| if req2 != nil { |
| req = req2 |
| } |
| } |
| } |
| |
| resp, err = transport.RoundTrip(req) |
| if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { |
| // Replace the 304 response with the one from cache, but update with some new headers |
| endToEndHeaders := getEndToEndHeaders(resp.Header) |
| for _, header := range endToEndHeaders { |
| cachedResp.Header[header] = resp.Header[header] |
| } |
| cachedResp.Status = fmt.Sprintf("%d %s", http.StatusOK, http.StatusText(http.StatusOK)) |
| cachedResp.StatusCode = http.StatusOK |
| |
| resp = cachedResp |
| } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && |
| req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { |
| // In case of transport failure and stale-if-error activated, returns cached content |
| // when available |
| cachedResp.Status = fmt.Sprintf("%d %s", http.StatusOK, http.StatusText(http.StatusOK)) |
| cachedResp.StatusCode = http.StatusOK |
| return cachedResp, nil |
| } else { |
| if err != nil || resp.StatusCode != http.StatusOK { |
| t.Cache.Delete(cacheKey) |
| } |
| if err != nil { |
| return nil, err |
| } |
| } |
| } else { |
| reqCacheControl := parseCacheControl(req.Header) |
| if _, ok := reqCacheControl["only-if-cached"]; ok { |
| resp = newGatewayTimeoutResponse(req) |
| } else { |
| resp, err = transport.RoundTrip(req) |
| if err != nil { |
| return nil, err |
| } |
| } |
| } |
| |
| if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { |
| for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { |
| varyKey = http.CanonicalHeaderKey(varyKey) |
| fakeHeader := "X-Varied-" + varyKey |
| reqValue := req.Header.Get(varyKey) |
| if reqValue != "" { |
| resp.Header.Set(fakeHeader, reqValue) |
| } |
| } |
| switch req.Method { |
| case "GET": |
| // Delay caching until EOF is reached. |
| resp.Body = &cachingReadCloser{ |
| R: resp.Body, |
| OnEOF: func(r io.Reader) { |
| resp := *resp |
| resp.Body = ioutil.NopCloser(r) |
| respBytes, err := httputil.DumpResponse(&resp, true) |
| if err == nil { |
| t.Cache.Set(cacheKey, respBytes) |
| } |
| }, |
| } |
| default: |
| respBytes, err := httputil.DumpResponse(resp, true) |
| if err == nil { |
| t.Cache.Set(cacheKey, respBytes) |
| } |
| } |
| } else { |
| t.Cache.Delete(cacheKey) |
| } |
| return resp, nil |
| } |
| |
// ErrNoDateHeader indicates that the HTTP headers contained no Date header.
var ErrNoDateHeader = errors.New("no Date header")

// Date parses and returns the value of the Date header.
//
// The three HTTP-date formats that RFC 7231 §7.1.1.1 requires recipients to
// accept (IMF-fixdate, RFC 850 and ANSI C asctime) are tried first via
// http.ParseTime; those results are in UTC. For backward compatibility, a value
// none of them accepts is then parsed with time.RFC1123, which also allows zone
// abbreviations other than GMT. ErrNoDateHeader is returned when the header is
// absent or empty.
func Date(respHeaders http.Header) (date time.Time, err error) {
	dateHeader := respHeaders.Get("date")
	if dateHeader == "" {
		return time.Time{}, ErrNoDateHeader
	}
	if date, err = http.ParseTime(dateHeader); err == nil {
		return date, nil
	}
	return time.Parse(time.RFC1123, dateHeader)
}
| |
// timer abstracts "how long ago was d"; the freshness calculations go through
// the package-level clock variable so that another implementation can be
// swapped in.
type timer interface {
	since(d time.Time) time.Duration
}

// realClock is the timer backed by the system clock.
type realClock struct{}

// since returns the time elapsed since d according to the system clock.
func (c *realClock) since(d time.Time) time.Duration {
	return time.Now().Sub(d)
}

// clock is the timer used to compute response ages.
var clock timer = new(realClock)
| |
| // getFreshness will return one of fresh/stale/transparent based on the cache-control |
| // values of the request and the response |
| // |
| // fresh indicates the response can be returned |
| // stale indicates that the response needs validating before it is returned |
| // transparent indicates the response should not be used to fulfil the request |
| // |
| // Because this is only a private cache, 'public' and 'private' in cache-control aren't |
| // signficant. Similarly, smax-age isn't used. |
| func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { |
| respCacheControl := parseCacheControl(respHeaders) |
| reqCacheControl := parseCacheControl(reqHeaders) |
| if _, ok := reqCacheControl["no-cache"]; ok { |
| return transparent |
| } |
| if _, ok := respCacheControl["no-cache"]; ok { |
| return stale |
| } |
| if _, ok := reqCacheControl["only-if-cached"]; ok { |
| return fresh |
| } |
| |
| date, err := Date(respHeaders) |
| if err != nil { |
| return stale |
| } |
| currentAge := clock.since(date) |
| |
| var lifetime time.Duration |
| var zeroDuration time.Duration |
| |
| // If a response includes both an Expires header and a max-age directive, |
| // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. |
| if maxAge, ok := respCacheControl["max-age"]; ok { |
| lifetime, err = time.ParseDuration(maxAge + "s") |
| if err != nil { |
| lifetime = zeroDuration |
| } |
| } else { |
| expiresHeader := respHeaders.Get("Expires") |
| if expiresHeader != "" { |
| expires, err := time.Parse(time.RFC1123, expiresHeader) |
| if err != nil { |
| lifetime = zeroDuration |
| } else { |
| lifetime = expires.Sub(date) |
| } |
| } |
| } |
| |
| if maxAge, ok := reqCacheControl["max-age"]; ok { |
| // the client is willing to accept a response whose age is no greater than the specified time in seconds |
| lifetime, err = time.ParseDuration(maxAge + "s") |
| if err != nil { |
| lifetime = zeroDuration |
| } |
| } |
| if minfresh, ok := reqCacheControl["min-fresh"]; ok { |
| // the client wants a response that will still be fresh for at least the specified number of seconds. |
| minfreshDuration, err := time.ParseDuration(minfresh + "s") |
| if err == nil { |
| currentAge = time.Duration(currentAge + minfreshDuration) |
| } |
| } |
| |
| if maxstale, ok := reqCacheControl["max-stale"]; ok { |
| // Indicates that the client is willing to accept a response that has exceeded its expiration time. |
| // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded |
| // its expiration time by no more than the specified number of seconds. |
| // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. |
| // |
| // Responses served only because of a max-stale value are supposed to have a Warning header added to them, |
| // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different |
| // return-value available here. |
| if maxstale == "" { |
| return fresh |
| } |
| maxstaleDuration, err := time.ParseDuration(maxstale + "s") |
| if err == nil { |
| currentAge = time.Duration(currentAge - maxstaleDuration) |
| } |
| } |
| |
| if lifetime > currentAge { |
| return fresh |
| } |
| |
| return stale |
| } |
| |
| // Returns true if either the request or the response includes the stale-if-error |
| // cache control extension: https://tools.ietf.org/html/rfc5861 |
| func canStaleOnError(respHeaders, reqHeaders http.Header) bool { |
| respCacheControl := parseCacheControl(respHeaders) |
| reqCacheControl := parseCacheControl(reqHeaders) |
| |
| var err error |
| lifetime := time.Duration(-1) |
| |
| if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { |
| if staleMaxAge != "" { |
| lifetime, err = time.ParseDuration(staleMaxAge + "s") |
| if err != nil { |
| return false |
| } |
| } else { |
| return true |
| } |
| } |
| if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { |
| if staleMaxAge != "" { |
| lifetime, err = time.ParseDuration(staleMaxAge + "s") |
| if err != nil { |
| return false |
| } |
| } else { |
| return true |
| } |
| } |
| |
| if lifetime >= 0 { |
| date, err := Date(respHeaders) |
| if err != nil { |
| return false |
| } |
| currentAge := clock.since(date) |
| if lifetime > currentAge { |
| return true |
| } |
| } |
| |
| return false |
| } |
| |
// getEndToEndHeaders returns the (canonical) names of the headers in
// respHeaders that are end-to-end, i.e. not hop-by-hop. Hop-by-hop headers are
// the fixed set below plus every header named in any Connection header line.
// The order of the result is unspecified.
func getEndToEndHeaders(respHeaders http.Header) []string {
	// These headers are always hop-by-hop.
	hopByHopHeaders := map[string]struct{}{
		"Connection":          {},
		"Keep-Alive":          {},
		"Proxy-Authenticate":  {},
		"Proxy-Authorization": {},
		"Te":                  {},
		"Trailer":             {},
		"Trailers":            {}, // spelling used in RFC 2616 §13.5.1; kept for compatibility
		"Transfer-Encoding":   {},
		"Upgrade":             {},
	}

	// Any header listed in Connection, if present, is also considered hop-by-hop.
	// Names are trimmed *before* canonicalisation: CanonicalHeaderKey leaves
	// strings containing spaces untouched, so " X-Foo" would otherwise never match.
	for _, line := range respHeaders[http.CanonicalHeaderKey("connection")] {
		for _, extra := range strings.Split(line, ",") {
			if name := strings.TrimSpace(extra); name != "" {
				hopByHopHeaders[http.CanonicalHeaderKey(name)] = struct{}{}
			}
		}
	}

	endToEndHeaders := make([]string, 0, len(respHeaders))
	for respHeader := range respHeaders {
		if _, ok := hopByHopHeaders[respHeader]; !ok {
			endToEndHeaders = append(endToEndHeaders, respHeader)
		}
	}
	return endToEndHeaders
}
| |
| func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { |
| if _, ok := respCacheControl["no-store"]; ok { |
| return false |
| } |
| if _, ok := reqCacheControl["no-store"]; ok { |
| return false |
| } |
| return true |
| } |
| |
// newGatewayTimeoutResponse builds an empty "504 Gateway Timeout" response for
// req. It is used when the caller asked for only-if-cached and nothing usable
// is cached.
func newGatewayTimeoutResponse(req *http.Request) *http.Response {
	const raw = "HTTP/1.1 504 Gateway Timeout\r\n\r\n"
	resp, err := http.ReadResponse(bufio.NewReader(strings.NewReader(raw)), req)
	if err != nil {
		// raw is a constant, well-formed response, so this indicates a programming error.
		panic(err)
	}
	return resp
}
| |
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct, except that it gets its own Header
// map with its own copy of every value slice. Header edits on the clone (Set,
// Add, Del) therefore never affect r. Other reference fields such as URL and
// Body are still shared with r.
// (This function copyright goauth2 authors: https://code.google.com/p/goauth2)
func cloneRequest(r *http.Request) *http.Request {
	// shallow copy of the struct
	r2 := new(http.Request)
	*r2 = *r
	// deep copy of the Header: the value slices are copied too, otherwise an
	// Add on r2 could append into r's backing array.
	r2.Header = make(http.Header, len(r.Header))
	for k, s := range r.Header {
		vs := make([]string, len(s))
		copy(vs, s)
		r2.Header[k] = vs
	}
	return r2
}
| |
// cacheControl maps lower-cased Cache-Control directive names to their values.
// Directives given without a value map to "".
type cacheControl map[string]string

// parseCacheControl parses every Cache-Control header line in headers into a
// cacheControl.
//
// Directive names are case-insensitive (RFC 7234 §5.2), so they are
// lower-cased. Values have surrounding whitespace and a single pair of
// surrounding double quotes removed, so max-age="60" reads as "60". Only the
// first '=' separates name from value. When a directive repeats, the last
// occurrence wins.
func parseCacheControl(headers http.Header) cacheControl {
	cc := cacheControl{}
	for _, line := range headers[http.CanonicalHeaderKey("Cache-Control")] {
		for _, part := range strings.Split(line, ",") {
			part = strings.TrimSpace(part)
			if part == "" {
				continue
			}
			key, val := part, ""
			if kv := strings.SplitN(part, "=", 2); len(kv) == 2 {
				key = strings.TrimSpace(kv[0])
				val = strings.TrimSpace(kv[1])
				if len(val) >= 2 && val[0] == '"' && val[len(val)-1] == '"' {
					val = val[1 : len(val)-1]
				}
			}
			cc[strings.ToLower(key)] = val
		}
	}
	return cc
}
| |
// headerAllCommaSepValues collects the comma-separated fields of every
// occurrence of header name in headers, trimming whitespace from each field.
// Empty fields are kept as "", and nil is returned when the header is absent.
// Concatenating repeated occurrences follows Section 4.2 of the HTTP/1.1 spec
// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2) for headers
// whose value is a comma-separated list.
func headerAllCommaSepValues(headers http.Header, name string) []string {
	var out []string
	key := http.CanonicalHeaderKey(name)
	for _, line := range headers[key] {
		for _, field := range strings.Split(line, ",") {
			out = append(out, strings.TrimSpace(field))
		}
	}
	return out
}
| |
// cachingReadCloser is a wrapper around ReadCloser R that calls the OnEOF
// handler with a full copy of the content read from R when EOF is
// reached.
type cachingReadCloser struct {
	// Underlying ReadCloser.
	R io.ReadCloser
	// OnEOF is called with a copy of the content of R when EOF is first
	// reached. It is called at most once; a nil OnEOF is ignored.
	OnEOF func(io.Reader)

	buf  bytes.Buffer // buf stores a copy of the content of R.
	done bool         // done is set once EOF has been seen and OnEOF handled.
}

// Read reads the next len(p) bytes from R or until R is drained. The
// return value n is the number of bytes read. The first time R reports
// io.EOF, OnEOF is called with a full copy of what has been read so far.
// Later reads after EOF (which return io.EOF again) do not call it a second time.
func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
	n, err = r.R.Read(p)
	if r.done {
		return n, err
	}
	r.buf.Write(p[:n])
	if err == io.EOF {
		r.done = true
		if r.OnEOF != nil {
			r.OnEOF(bytes.NewReader(r.buf.Bytes()))
		}
	}
	return n, err
}

// Close closes the underlying ReadCloser R.
func (r *cachingReadCloser) Close() error {
	return r.R.Close()
}
| |
| // NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation |
| func NewMemoryCacheTransport() *Transport { |
| c := NewMemoryCache() |
| t := NewTransport(c) |
| return t |
| } |