go/analysis/passes/printf: improve support for %w

Report use of %w with non-error arguments.

Report multiple %w in a format.

Report use of %w with non-Errorf functions.

Fixes golang/go#32070

Change-Id: I65d8fcc235ae2f3717582d00352356eeb0eaf73c
Reviewed-on: https://go-review.googlesource.com/c/tools/+/177601
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/go/analysis/passes/printf/printf.go b/go/analysis/passes/printf/printf.go
index f59e95d..52de8b0 100644
--- a/go/analysis/passes/printf/printf.go
+++ b/go/analysis/passes/printf/printf.go
@@ -67,15 +67,20 @@
 `
 
 // isWrapper is a fact indicating that a function is a print or printf wrapper.
-type isWrapper struct{ Printf bool }
+type isWrapper struct{ Kind funcKind }
 
 func (f *isWrapper) AFact() {}
 
 func (f *isWrapper) String() string {
-	if f.Printf {
+	switch f.Kind {
+	case kindPrintf:
 		return "printfWrapper"
-	} else {
+	case kindPrint:
 		return "printWrapper"
+	case kindErrorf:
+		return "errorfWrapper"
+	default:
+		return "unknownWrapper"
 	}
 }
 
@@ -223,16 +228,20 @@
 	return ok && info.ObjectOf(id) == param
 }
 
+type funcKind int
+
 const (
-	kindPrintf = 1
-	kindPrint  = 2
+	kindUnknown funcKind = iota
+	kindPrintf           = iota
+	kindPrint
+	kindErrorf
 )
 
 // checkPrintfFwd checks that a printf-forwarding wrapper is forwarding correctly.
 // It diagnoses writing fmt.Printf(format, args) instead of fmt.Printf(format, args...).
-func checkPrintfFwd(pass *analysis.Pass, w *printfWrapper, call *ast.CallExpr, kind int) {
+func checkPrintfFwd(pass *analysis.Pass, w *printfWrapper, call *ast.CallExpr, kind funcKind) {
 	matched := kind == kindPrint ||
-		kind == kindPrintf && len(call.Args) >= 2 && match(pass.TypesInfo, call.Args[len(call.Args)-2], w.format)
+		kind != kindUnknown && len(call.Args) >= 2 && match(pass.TypesInfo, call.Args[len(call.Args)-2], w.format)
 	if !matched {
 		return
 	}
@@ -262,7 +271,7 @@
 	fn := w.obj
 	var fact isWrapper
 	if !pass.ImportObjectFact(fn, &fact) {
-		fact.Printf = kind == kindPrintf
+		fact.Kind = kind
 		pass.ExportObjectFact(fn, &fact)
 		for _, caller := range w.callers {
 			checkPrintfFwd(pass, caller.w, caller.call, kind)
@@ -414,42 +423,42 @@
 		call := n.(*ast.CallExpr)
 		fn, kind := printfNameAndKind(pass, call)
 		switch kind {
-		case kindPrintf:
-			checkPrintf(pass, call, fn)
+		case kindPrintf, kindErrorf:
+			checkPrintf(pass, kind, call, fn)
 		case kindPrint:
 			checkPrint(pass, call, fn)
 		}
 	})
 }
 
-func printfNameAndKind(pass *analysis.Pass, call *ast.CallExpr) (fn *types.Func, kind int) {
+func printfNameAndKind(pass *analysis.Pass, call *ast.CallExpr) (fn *types.Func, kind funcKind) {
 	fn, _ = typeutil.Callee(pass.TypesInfo, call).(*types.Func)
 	if fn == nil {
 		return nil, 0
 	}
 
-	var fact isWrapper
-	if pass.ImportObjectFact(fn, &fact) {
-		if fact.Printf {
-			return fn, kindPrintf
-		} else {
-			return fn, kindPrint
-		}
-	}
-
 	_, ok := isPrint[fn.FullName()]
 	if !ok {
 		// Next look up just "printf", for use with -printf.funcs.
 		_, ok = isPrint[strings.ToLower(fn.Name())]
 	}
 	if ok {
-		if strings.HasSuffix(fn.Name(), "f") {
+		if fn.Name() == "Errorf" {
+			kind = kindErrorf
+		} else if strings.HasSuffix(fn.Name(), "f") {
 			kind = kindPrintf
 		} else {
 			kind = kindPrint
 		}
+		return fn, kind
 	}
-	return fn, kind
+
+	var fact isWrapper
+	if pass.ImportObjectFact(fn, &fact) {
+		return fn, fact.Kind
+	}
+
+	return fn, kindUnknown
 }
 
 // isFormatter reports whether t satisfies fmt.Formatter.
@@ -491,7 +500,7 @@
 }
 
 // checkPrintf checks a call to a formatted print routine such as Printf.
-func checkPrintf(pass *analysis.Pass, call *ast.CallExpr, fn *types.Func) {
+func checkPrintf(pass *analysis.Pass, kind funcKind, call *ast.CallExpr, fn *types.Func) {
 	format, idx := formatString(pass, call)
 	if idx < 0 {
 		if false {
@@ -511,6 +520,7 @@
 	argNum := firstArg
 	maxArgNum := firstArg
 	anyIndex := false
+	anyW := false
 	for i, w := 0, 0; i < len(format); i += w {
 		w = 1
 		if format[i] != '%' {
@@ -527,6 +537,17 @@
 		if state.hasIndex {
 			anyIndex = true
 		}
+		if state.verb == 'w' {
+			if kind != kindErrorf {
+				pass.Reportf(call.Pos(), "%s call has error-wrapping directive %%w", state.name)
+				return
+			}
+			if anyW {
+				pass.Reportf(call.Pos(), "%s call has more than one error-wrapping directive %%w", state.name)
+				return
+			}
+			anyW = true
+		}
 		if len(state.argNums) > 0 {
 			// Continue with the next sequential argument.
 			argNum = state.argNums[len(state.argNums)-1] + 1
@@ -697,6 +718,7 @@
 	argFloat
 	argComplex
 	argPointer
+	argError
 	anyType printfArgType = ^0
 )
 
@@ -739,7 +761,7 @@
 	{'T', "-", anyType},
 	{'U', "-#", argRune | argInt},
 	{'v', allFlags, anyType},
-	{'w', noFlag, anyType},
+	{'w', allFlags, argError},
 	{'x', sharpNumFlag, argRune | argInt | argString | argPointer},
 	{'X', sharpNumFlag, argRune | argInt | argString | argPointer},
 }
diff --git a/go/analysis/passes/printf/testdata/src/a/a.go b/go/analysis/passes/printf/testdata/src/a/a.go
index b783f10..6995176 100644
--- a/go/analysis/passes/printf/testdata/src/a/a.go
+++ b/go/analysis/passes/printf/testdata/src/a/a.go
@@ -97,7 +97,6 @@
 	fmt.Printf("%T", notstringerv)
 	fmt.Printf("%q", stringerarrayv)
 	fmt.Printf("%v", stringerarrayv)
-	fmt.Printf("%w", err)
 	fmt.Printf("%s", stringerarrayv)
 	fmt.Printf("%v", notstringerarrayv)
 	fmt.Printf("%T", notstringerarrayv)
@@ -323,6 +322,16 @@
 
 	// Issue 26486
 	dbg("", 1) // no error "call has arguments but no formatting directive"
+
+	// %w
+	_ = fmt.Errorf("%w", err)
+	_ = fmt.Errorf("%#w", err)
+	_ = fmt.Errorf("%[2]w %[1]s", "x", err)
+	_ = fmt.Errorf("%[2]w %[1]s", e, "x") // want `Errorf format %\[2\]w has arg "x" of wrong type string`
+	_ = fmt.Errorf("%w", "x")             // want `Errorf format %w has arg "x" of wrong type string`
+	_ = fmt.Errorf("%w %w", err, err)     // want `Errorf call has more than one error-wrapping directive %w`
+	fmt.Printf("%w", err)                 // want `Printf call has error-wrapping directive %w`
+	Errorf(0, "%w", err)
 }
 
 func someString() string { return "X" }
@@ -367,13 +376,13 @@
 
 // Errorf is used by the test for a case in which the first parameter
 // is not a format string.
-func Errorf(i int, format string, args ...interface{}) { // want Errorf:"printfWrapper"
+func Errorf(i int, format string, args ...interface{}) { // want Errorf:"errorfWrapper"
 	_ = fmt.Errorf(format, args...)
 }
 
 // errorf is used by the test for a case in which the function accepts multiple
 // string parameters before variadic arguments
-func errorf(level, format string, args ...interface{}) { // want errorf:"printfWrapper"
+func errorf(level, format string, args ...interface{}) { // want errorf:"errorfWrapper"
 	_ = fmt.Errorf(format, args...)
 }
 
diff --git a/go/analysis/passes/printf/types.go b/go/analysis/passes/printf/types.go
index 12286fd..5000d9a 100644
--- a/go/analysis/passes/printf/types.go
+++ b/go/analysis/passes/printf/types.go
@@ -37,6 +37,12 @@
 			return true // probably a type check problem
 		}
 	}
+
+	// %w accepts only errors.
+	if t == argError {
+		return types.ConvertibleTo(typ, errorType)
+	}
+
 	// If the type implements fmt.Formatter, we have nothing to check.
 	if isFormatter(typ) {
 		return true