blob: 1e390487149f5ff7f9805cc2f2fbef4ae1f37f5a [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
// contextKey is an unexported key type for request-context values. Using a
// distinct named type (rather than a built-in type) keeps this test's key
// from colliding with context keys defined by any other package.
type contextKey int

// key is the context key under which the test middleware store and read
// the running integer value.
const key contextKey = 1
// TestChain checks that Chain (defined elsewhere in this package) applies
// every middleware it is given, and in the order they are passed.
//
// Each middleware rewrites the integer stored under key in the request
// context, starting from the zero value when none is present: add computes
// v+2 and multiply computes v*2. The expected body "4" equals (0+2)*2, which
// is only produced when add runs before multiply. Running them in the
// reverse order yields 0*2+2 = 2. Dropping either one yields 0 or 2.
func TestChain(t *testing.T) {
	// handler writes the final context value as the response body. The type
	// assertion is checked so that a missing value produces a 500 response
	// with a readable message. An unchecked assertion would panic inside the
	// server goroutine, and the client would only see a connection error.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		v, ok := r.Context().Value(key).(int)
		if !ok {
			http.Error(w, "no int value in request context", http.StatusInternalServerError)
			return
		}
		fmt.Fprintf(w, "%d", v)
	})

	// contextOp builds a middleware that replaces the context value v with
	// f(v) before calling the next handler. A missing value is treated as 0.
	contextOp := func(f func(int) int) Middleware {
		return Middleware(func(h http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				v, _ := r.Context().Value(key).(int) // zero value if absent
				ctx := context.WithValue(r.Context(), key, f(v))
				h.ServeHTTP(w, r.WithContext(ctx))
			})
		})
	}
	add := contextOp(func(v int) int { return v + 2 })
	multiply := contextOp(func(v int) int { return v * 2 })

	ts := httptest.NewServer(Chain(add, multiply)(handler))
	defer ts.Close()

	resp, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatalf("GET got error %v, want nil", err)
	}
	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		t.Fatalf("ioutil.ReadAll(resp.Body): %v", err)
	}
	// A non-200 status means the handler never received a context value.
	// Report that case directly rather than as a confusing body mismatch.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("GET returned status %d (body %q), want %d", resp.StatusCode, body, http.StatusOK)
	}
	// Test that both middleware executed, in the correct order.
	if got, want := string(body), "4"; got != want {
		t.Errorf("GET returned body %q, want %q", got, want)
	}
}