blob: 96196dce95e8f0d85f4e49b0e26d5e48c4c0405f [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package middleware implements a simple middleware pattern for http handlers,
// along with implementations for some common middlewares.
package middleware
import (
"fmt"
"net/http"
"runtime/debug"
"time"
"golang.org/x/exp/slog"
)
// A Middleware is a func that wraps an http.Handler: it takes the handler to
// delegate to and returns a new http.Handler, which typically runs extra
// logic before and/or after calling the wrapped one.
type Middleware func(http.Handler) http.Handler
// Chain combines the given Middlewares into a single Middleware. The first
// argument becomes the outermost wrapper, so on each request the middlewares
// run in the order in which they were passed.
//
// In other words, Chain(m1, m2)(handler) = m1(m2(handler))
//
// A similar pattern is used in e.g. github.com/justinas/alice:
// https://github.com/justinas/alice/blob/ce87934/chain.go#L45
func Chain(middlewares ...Middleware) Middleware {
	return func(inner http.Handler) http.Handler {
		wrapped := inner
		// Wrap from the last middleware back to the first so that the first
		// one ends up on the outside.
		for i := len(middlewares) - 1; i >= 0; i-- {
			wrapped = middlewares[i](wrapped)
		}
		return wrapped
	}
}
// Log returns a middleware that writes two records per request to logger: a
// "request start" record before the delegate handler runs, and a closing
// record afterwards carrying the response status and the elapsed time. Both
// records include the request method and URI.
//
// The closing record is logged at Error level with message "request error"
// for 5XX statuses, at Warn level with "request rejected" for 4XX statuses,
// and at Info level with "request end" otherwise.
func Log(logger *slog.Logger) Middleware {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			ctx := r.Context()
			reqLog := logger.With(
				slog.String("method", r.Method),
				slog.String("uri", r.RequestURI),
				// TODO(hyangah): set trace context from X-Cloud-Trace-Context
			)
			reqLog.InfoContext(ctx, "request start")

			// Start from 200: that is the status recorded when the handler
			// never calls WriteHeader itself.
			rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
			next.ServeHTTP(rec, r)

			level, msg := slog.LevelInfo, "request end"
			if class := rec.status / 100; class == 5 {
				// 5XX error
				level, msg = slog.LevelError, "request error"
			} else if class == 4 {
				// 4XX error
				level, msg = slog.LevelWarn, "request rejected"
			}
			reqLog.Log(ctx, level, msg,
				slog.Int("status", rec.status),
				slog.Duration("duration", time.Since(began)),
			)
		}
		return http.HandlerFunc(serve)
	}
}
// Recover is a middleware that recovers from panics in the delegate
// handler. On a panic it replies with 500 Internal Server Error, logs the
// panic value at Error level through the default slog logger (message is the
// request URI, panic value under the "error" key), and prints a stack trace
// to standard output.
//
// A panic with http.ErrAbortHandler is re-raised untouched. net/http
// documents that value as the sentinel for deliberately aborting a response
// without logging, so it must reach the server rather than be reported as a
// 500.
func Recover() Middleware {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer func() {
				v := recover()
				if v == nil {
					return
				}
				if v == http.ErrAbortHandler {
					panic(v)
				}
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				// slog.Error takes alternating key/value arguments after the
				// message, so the error needs an explicit key; a bare value
				// would be logged under "!BADKEY". %v (rather than %s) keeps
				// non-string panic values readable.
				slog.Error(r.RequestURI, "error", fmt.Errorf(`panic("%v")`, v))
				fmt.Println(string(debug.Stack()))
			}()
			h.ServeHTTP(w, r)
		})
	}
}
// RequestSize returns a middleware that caps the size of incoming request
// bodies at n bytes by replacing each request's Body with an
// http.MaxBytesReader before calling the delegate handler.
func RequestSize(n int64) Middleware {
	limit := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, n)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(serve)
	}
	return limit
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader so it can be inspected after the handler returns.
// The creator supplies the initial status, which stays recorded if the
// handler never writes a final status explicitly.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

// WriteHeader records code and forwards it to the wrapped ResponseWriter.
//
// Informational 1xx codes (other than 101 Switching Protocols) are forwarded
// but not recorded. net/http lets a handler send any number of them before
// the final status, so recording them would misreport the response when the
// final status is subsequently written implicitly.
func (rec *statusRecorder) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational {
		rec.status = code
	}
	rec.ResponseWriter.WriteHeader(code)
}

// Unwrap returns the wrapped ResponseWriter. It lets http.ResponseController
// reach optional features of the underlying writer (Flush, Hijack,
// deadlines) that embedding the plain interface would otherwise hide.
func (rec *statusRecorder) Unwrap() http.ResponseWriter {
	return rec.ResponseWriter
}
// Timeout returns a new Middleware that times out each request after the given
// duration. It delegates to http.TimeoutHandler, using "request timed out"
// as the message sent to the client when the time limit is exceeded.
func Timeout(d time.Duration) Middleware {
	return func(h http.Handler) http.Handler {
		// http.TimeoutHandler already returns an http.Handler meant to serve
		// many requests, so build it once per wrapped handler instead of
		// allocating a new one on every request.
		return http.TimeoutHandler(h, d, "request timed out")
	}
}