Disallow basic types as keys in context.WithValue().
Signed-off-by: Joe Tsai <joetsai@google.com>
diff --git a/lint.go b/lint.go
index 14cc626..6768516 100644
--- a/lint.go
+++ b/lint.go
@@ -187,6 +187,7 @@
f.lintErrorReturn()
f.lintUnexportedReturn()
f.lintTimeNames()
+ f.lintContextKeyTypes()
}
type link string
@@ -1430,6 +1431,46 @@
})
}
+// lintContextKeyTypes checks for call expressions to context.WithValue with
+// basic types used for the key argument.
+// See: https://golang.org/issue/17293
+func (f *file) lintContextKeyTypes() {
+ f.walk(func(node ast.Node) bool {
+ switch node := node.(type) {
+ case *ast.CallExpr:
+ f.checkContextKeyType(node)
+ }
+
+ return true
+ })
+}
+
+// checkContextKeyType reports an error if the call expression calls
+// context.WithValue with a key argument of basic type.
+func (f *file) checkContextKeyType(x *ast.CallExpr) {
+ sel, ok := x.Fun.(*ast.SelectorExpr)
+ if !ok {
+ return
+ }
+ pkg, ok := sel.X.(*ast.Ident)
+ if !ok || pkg.Name != "context" {
+ return
+ }
+ if sel.Sel.Name != "WithValue" {
+ return
+ }
+
+ // key is second argument to context.WithValue
+ if len(x.Args) != 3 {
+ return
+ }
+ key := f.pkg.typesInfo.Types[x.Args[1]]
+
+ if _, ok := key.Type.(*types.Basic); ok {
+ f.errorf(x, 1.0, category("context"), fmt.Sprintf("should not use basic type %s as key in context.WithValue", key.Type))
+ }
+}
+
// receiverType returns the named type of the method receiver, sans "*",
// or "invalid-type" if fn.Recv is ill formed.
func receiverType(fn *ast.FuncDecl) string {
diff --git a/testdata/contextkeytypes.go b/testdata/contextkeytypes.go
new file mode 100644
index 0000000..6433e62
--- /dev/null
+++ b/testdata/contextkeytypes.go
@@ -0,0 +1,37 @@
+// Package contextkeytypes verifies that correct types are used as keys in
+// calls to context.WithValue.
+package contextkeytypes
+
+import (
+ "context"
+ "fmt"
+)
+
+type ctxKey struct{}
+
+func contextKeyTypeTests() {
+ fmt.Println() // not in package context
+ context.TODO() // wrong function
+ c := context.Background() // wrong function
+ context.WithValue(c, "foo", "bar") // MATCH /should not use basic type( untyped|)? string as key in context.WithValue/
+ context.WithValue(c, true, "bar") // MATCH /should not use basic type( untyped|)? bool as key in context.WithValue/
+ context.WithValue(c, 1, "bar") // MATCH /should not use basic type( untyped|)? int as key in context.WithValue/
+ context.WithValue(c, int8(1), "bar") // MATCH /should not use basic type int8 as key in context.WithValue/
+ context.WithValue(c, int16(1), "bar") // MATCH /should not use basic type int16 as key in context.WithValue/
+ context.WithValue(c, int32(1), "bar") // MATCH /should not use basic type int32 as key in context.WithValue/
+ context.WithValue(c, rune(1), "bar") // MATCH /should not use basic type rune as key in context.WithValue/
+ context.WithValue(c, int64(1), "bar") // MATCH /should not use basic type int64 as key in context.WithValue/
+ context.WithValue(c, uint(1), "bar") // MATCH /should not use basic type uint as key in context.WithValue/
+ context.WithValue(c, uint8(1), "bar") // MATCH /should not use basic type uint8 as key in context.WithValue/
+ context.WithValue(c, byte(1), "bar") // MATCH /should not use basic type byte as key in context.WithValue/
+ context.WithValue(c, uint16(1), "bar") // MATCH /should not use basic type uint16 as key in context.WithValue/
+ context.WithValue(c, uint32(1), "bar") // MATCH /should not use basic type uint32 as key in context.WithValue/
+ context.WithValue(c, uint64(1), "bar") // MATCH /should not use basic type uint64 as key in context.WithValue/
+ context.WithValue(c, uintptr(1), "bar") // MATCH /should not use basic type uintptr as key in context.WithValue/
+ context.WithValue(c, float32(1.0), "bar") // MATCH /should not use basic type float32 as key in context.WithValue/
+ context.WithValue(c, float64(1.0), "bar") // MATCH /should not use basic type float64 as key in context.WithValue/
+ context.WithValue(c, complex64(1i), "bar") // MATCH /should not use basic type complex64 as key in context.WithValue/
+ context.WithValue(c, complex128(1i), "bar") // MATCH /should not use basic type complex128 as key in context.WithValue/
+ context.WithValue(c, ctxKey{}, "bar") // ok
+ context.WithValue(c, &ctxKey{}, "bar") // ok
+}