go/analysis/passes/shift: update for generics

Warn about shifts that exceed the width of a type in the type parameter
type set.

For golang/go#48704

Change-Id: If16650b0c7214415f650884014eda9690d8da11a
Reviewed-on: https://go-review.googlesource.com/c/tools/+/357758
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Tim King <taking@google.com>
diff --git a/go/analysis/passes/shift/shift.go b/go/analysis/passes/shift/shift.go
index 1f3df07..640de28 100644
--- a/go/analysis/passes/shift/shift.go
+++ b/go/analysis/passes/shift/shift.go
@@ -14,11 +14,13 @@
 	"go/ast"
 	"go/constant"
 	"go/token"
+	"math"
 
 	"golang.org/x/tools/go/analysis"
 	"golang.org/x/tools/go/analysis/passes/inspect"
 	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
 	"golang.org/x/tools/go/ast/inspector"
+	"golang.org/x/tools/internal/typeparams"
 )
 
 const Doc = "check for shifts that equal or exceed the width of the integer"
@@ -93,9 +95,27 @@
 	if t == nil {
 		return
 	}
-	size := 8 * pass.TypesSizes.Sizeof(t)
-	if amt >= size {
+	terms, err := typeparams.StructuralTerms(t)
+	if err != nil {
+		return // invalid type
+	}
+	sizes := make(map[int64]struct{})
+	for _, term := range terms {
+		size := 8 * pass.TypesSizes.Sizeof(term.Type())
+		sizes[size] = struct{}{}
+	}
+	minSize := int64(math.MaxInt64)
+	for size := range sizes {
+		if size < minSize {
+			minSize = size
+		}
+	}
+	if amt >= minSize {
 		ident := analysisutil.Format(pass.Fset, x)
-		pass.ReportRangef(node, "%s (%d bits) too small for shift of %d", ident, size, amt)
+		qualifier := ""
+		if len(sizes) > 1 {
+			qualifier = "may be "
+		}
+		pass.ReportRangef(node, "%s (%s%d bits) too small for shift of %d", ident, qualifier, minSize, amt)
 	}
 }
diff --git a/go/analysis/passes/shift/shift_test.go b/go/analysis/passes/shift/shift_test.go
index 8b41b60..e60943e 100644
--- a/go/analysis/passes/shift/shift_test.go
+++ b/go/analysis/passes/shift/shift_test.go
@@ -9,9 +9,14 @@
 
 	"golang.org/x/tools/go/analysis/analysistest"
 	"golang.org/x/tools/go/analysis/passes/shift"
+	"golang.org/x/tools/internal/typeparams"
 )
 
 func Test(t *testing.T) {
 	testdata := analysistest.TestData()
-	analysistest.Run(t, testdata, shift.Analyzer, "a")
+	pkgs := []string{"a"}
+	if typeparams.Enabled {
+		pkgs = append(pkgs, "typeparams")
+	}
+	analysistest.Run(t, testdata, shift.Analyzer, pkgs...)
 }
diff --git a/go/analysis/passes/shift/testdata/src/typeparams/typeparams.go b/go/analysis/passes/shift/testdata/src/typeparams/typeparams.go
new file mode 100644
index 0000000..a76df88
--- /dev/null
+++ b/go/analysis/passes/shift/testdata/src/typeparams/typeparams.go
@@ -0,0 +1,32 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package typeparams
+
+import "unsafe"
+
+func GenericShiftTest[DifferentSize ~int8|int16|int64, SameSize int8|byte]() {
+	var d DifferentSize
+	_ = d << 7
+	_ = d << 8        // want "d .may be 8 bits. too small for shift of 8"
+	_ = d << 15       // want "d .may be 8 bits. too small for shift of 15"
+	_ = (d + 1) << 8  // want ".d . 1. .may be 8 bits. too small for shift of 8"
+	_ = (d + 1) << 16 // want ".d . 1. .may be 8 bits. too small for shift of 16"
+	_ = d << (7 + 1)  // want "d .may be 8 bits. too small for shift of 8"
+	_ = d >> 8        // want "d .may be 8 bits. too small for shift of 8"
+	d <<= 8           // want "d .may be 8 bits. too small for shift of 8"
+	d >>= 8           // want "d .may be 8 bits. too small for shift of 8"
+
+	// go/types does not compute constant sizes for type parameters, so we do not
+	// report a diagnostic here.
+	_ = d << (8 * DifferentSize(unsafe.Sizeof(d)))
+
+	var s SameSize
+	_ = s << 7
+	_ = s << 8        // want "s .8 bits. too small for shift of 8"
+	_ = s << (7 + 1)  // want "s .8 bits. too small for shift of 8"
+	_ = s >> 8        // want "s .8 bits. too small for shift of 8"
+	s <<= 8           // want "s .8 bits. too small for shift of 8"
+	s >>= 8           // want "s .8 bits. too small for shift of 8"
+}