Disallow basic types as keys in context.WithValue().

Signed-off-by: Joe Tsai <joetsai@google.com>
diff --git a/lint.go b/lint.go
index 14cc626..6768516 100644
--- a/lint.go
+++ b/lint.go
@@ -187,6 +187,7 @@
 	f.lintErrorReturn()
 	f.lintUnexportedReturn()
 	f.lintTimeNames()
+	f.lintContextKeyTypes()
 }
 
 type link string
@@ -1430,6 +1431,46 @@
 	})
 }
 
+// lintContextKeyTypes checks for call expressions to context.WithValue with
+// basic types used for the key argument.
+// See: https://golang.org/issue/17293
+func (f *file) lintContextKeyTypes() {
+	f.walk(func(node ast.Node) bool {
+		switch node := node.(type) {
+		case *ast.CallExpr:
+			f.checkContextKeyType(node)
+		}
+
+		return true
+	})
+}
+
+// checkContextKeyType reports an error if the call expression calls
+// context.WithValue with a key argument of basic type.
+func (f *file) checkContextKeyType(x *ast.CallExpr) {
+	sel, ok := x.Fun.(*ast.SelectorExpr)
+	if !ok {
+		return
+	}
+	pkg, ok := sel.X.(*ast.Ident)
+	if !ok || pkg.Name != "context" {
+		return
+	}
+	if sel.Sel.Name != "WithValue" {
+		return
+	}
+
+	// key is second argument to context.WithValue
+	if len(x.Args) != 3 {
+		return
+	}
+	key := f.pkg.typesInfo.Types[x.Args[1]]
+
+	if _, ok := key.Type.(*types.Basic); ok {
+		f.errorf(x, 1.0, category("context"), fmt.Sprintf("should not use basic type %s as key in context.WithValue", key.Type))
+	}
+}
+
 // receiverType returns the named type of the method receiver, sans "*",
 // or "invalid-type" if fn.Recv is ill formed.
 func receiverType(fn *ast.FuncDecl) string {
diff --git a/testdata/contextkeytypes.go b/testdata/contextkeytypes.go
new file mode 100644
index 0000000..6433e62
--- /dev/null
+++ b/testdata/contextkeytypes.go
@@ -0,0 +1,37 @@
+// Package contextkeytypes verifies that correct types are used as keys in
+// calls to context.WithValue.
+package contextkeytypes
+
+import (
+	"context"
+	"fmt"
+)
+
+type ctxKey struct{}
+
+func contextKeyTypeTests() {
+	fmt.Println()                               // not in package context
+	context.TODO()                              // wrong function
+	c := context.Background()                   // wrong function
+	context.WithValue(c, "foo", "bar")          // MATCH /should not use basic type( untyped|)? string as key in context.WithValue/
+	context.WithValue(c, true, "bar")           // MATCH /should not use basic type( untyped|)? bool as key in context.WithValue/
+	context.WithValue(c, 1, "bar")              // MATCH /should not use basic type( untyped|)? int as key in context.WithValue/
+	context.WithValue(c, int8(1), "bar")        // MATCH /should not use basic type int8 as key in context.WithValue/
+	context.WithValue(c, int16(1), "bar")       // MATCH /should not use basic type int16 as key in context.WithValue/
+	context.WithValue(c, int32(1), "bar")       // MATCH /should not use basic type int32 as key in context.WithValue/
+	context.WithValue(c, rune(1), "bar")        // MATCH /should not use basic type rune as key in context.WithValue/
+	context.WithValue(c, int64(1), "bar")       // MATCH /should not use basic type int64 as key in context.WithValue/
+	context.WithValue(c, uint(1), "bar")        // MATCH /should not use basic type uint as key in context.WithValue/
+	context.WithValue(c, uint8(1), "bar")       // MATCH /should not use basic type uint8 as key in context.WithValue/
+	context.WithValue(c, byte(1), "bar")        // MATCH /should not use basic type byte as key in context.WithValue/
+	context.WithValue(c, uint16(1), "bar")      // MATCH /should not use basic type uint16 as key in context.WithValue/
+	context.WithValue(c, uint32(1), "bar")      // MATCH /should not use basic type uint32 as key in context.WithValue/
+	context.WithValue(c, uint64(1), "bar")      // MATCH /should not use basic type uint64 as key in context.WithValue/
+	context.WithValue(c, uintptr(1), "bar")     // MATCH /should not use basic type uintptr as key in context.WithValue/
+	context.WithValue(c, float32(1.0), "bar")   // MATCH /should not use basic type float32 as key in context.WithValue/
+	context.WithValue(c, float64(1.0), "bar")   // MATCH /should not use basic type float64 as key in context.WithValue/
+	context.WithValue(c, complex64(1i), "bar")  // MATCH /should not use basic type complex64 as key in context.WithValue/
+	context.WithValue(c, complex128(1i), "bar") // MATCH /should not use basic type complex128 as key in context.WithValue/
+	context.WithValue(c, ctxKey{}, "bar")       // ok
+	context.WithValue(c, &ctxKey{}, "bar")      // ok
+}