go/types: check for duplicate values in expression switches

Fixes #11578.

Change-Id: I29a542be247127f470ba6c39aac0d0f6a18de553
Reviewed-on: https://go-review.googlesource.com/13285
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/src/go/types/stmt.go b/src/go/types/stmt.go
index 0ab2492..062b767 100644
--- a/src/go/types/stmt.go
+++ b/src/go/types/stmt.go
@@ -155,20 +155,84 @@
 	check.errorf(x.pos(), "%s %s %s", keyword, msg, &x)
 }
 
-func (check *Checker) caseValues(x *operand, values []ast.Expr) {
-	// No duplicate checking for now. See issue 4524.
+// goVal returns the Go value for val, or nil.
+func goVal(val constant.Value) interface{} {
+	// val should exist, but be conservative and check
+	if val == nil {
+		return nil
+	}
+	// Match implementation restriction of other compilers.
+	// gc only checks duplicates for integer, floating-point
+	// and string values, so only create Go values for these
+	// types.
+	switch val.Kind() {
+	case constant.Int:
+		if x, ok := constant.Int64Val(val); ok {
+			return x
+		}
+		if x, ok := constant.Uint64Val(val); ok {
+			return x
+		}
+	case constant.Float:
+		if x, ok := constant.Float64Val(val); ok {
+			return x
+		}
+	case constant.String:
+		return constant.StringVal(val)
+	}
+	return nil
+}
+
+// A valueMap maps a case value (of a basic Go type) to a list of positions
+// where the same case value appeared, together with the corresponding case
+// types.
+// Since two case values may have the same "underlying" value but different
+// types we need to also check the value's types (e.g., byte(1) vs myByte(1))
+// when the switch expression is of interface type.
+type (
+	valueMap  map[interface{}][]valueType // underlying Go value -> valueType
+	valueType struct {
+		pos token.Pos
+		typ Type
+	}
+)
+
+func (check *Checker) caseValues(x *operand, values []ast.Expr, seen valueMap) {
+L:
 	for _, e := range values {
 		var v operand
 		check.expr(&v, e)
 		if x.mode == invalid || v.mode == invalid {
-			continue
+			continue L
 		}
 		check.convertUntyped(&v, x.typ)
 		if v.mode == invalid {
-			continue
+			continue L
 		}
 		// Order matters: By comparing v against x, error positions are at the case values.
-		check.comparison(&v, x, token.EQL)
+		res := v // keep original v unchanged
+		check.comparison(&res, x, token.EQL)
+		if res.mode == invalid {
+			continue L
+		}
+		if v.mode != constant_ {
+			continue L // we're done
+		}
+		// look for duplicate values
+		if val := goVal(v.val); val != nil {
+			if list := seen[val]; list != nil {
+				// look for duplicate types for a given value
+				// (quadratic algorithm, but these lists tend to be very short)
+				for _, vt := range list {
+					if Identical(v.typ, vt.typ) {
+						check.errorf(v.pos(), "duplicate case %s in expression switch", &v)
+						check.error(vt.pos, "\tprevious case") // secondary error, \t indented
+						continue L
+					}
+				}
+			}
+			seen[val] = append(seen[val], valueType{v.pos(), v.typ})
+		}
 	}
 }
 
@@ -177,15 +241,19 @@
 	for _, e := range types {
 		T = check.typOrNil(e)
 		if T == Typ[Invalid] {
-			continue
+			continue L
 		}
-		// complain about duplicate types
-		// TODO(gri) use a type hash to avoid quadratic algorithm
+		// look for duplicate types
+		// (quadratic algorithm, but type switches tend to be reasonably small)
 		for t, pos := range seen {
 			if T == nil && t == nil || T != nil && t != nil && Identical(T, t) {
 				// talk about "case" rather than "type" because of nil case
-				check.error(e.Pos(), "duplicate case in type switch")
-				check.errorf(pos, "\tprevious case %s", T) // secondary error, \t indented
+				Ts := "nil"
+				if T != nil {
+					Ts = T.String()
+				}
+				check.errorf(e.Pos(), "duplicate case %s in type switch", Ts)
+				check.error(pos, "\tprevious case") // secondary error, \t indented
 				continue L
 			}
 		}
@@ -408,13 +476,14 @@
 
 		check.multipleDefaults(s.Body.List)
 
+		seen := make(valueMap) // map of seen case values to positions and types
 		for i, c := range s.Body.List {
 			clause, _ := c.(*ast.CaseClause)
 			if clause == nil {
 				check.invalidAST(c.Pos(), "incorrect expression switch case")
 				continue
 			}
-			check.caseValues(&x, clause.List)
+			check.caseValues(&x, clause.List, seen)
 			check.openScope(clause, "case")
 			inner := inner
 			if i+1 < len(s.Body.List) {
diff --git a/src/go/types/testdata/stmt0.src b/src/go/types/testdata/stmt0.src
index 7e28c23..e946066 100644
--- a/src/go/types/testdata/stmt0.src
+++ b/src/go/types/testdata/stmt0.src
@@ -438,14 +438,85 @@
 
 	switch x {
 	case 1:
-	case 1 /* DISABLED "duplicate case" */ :
+	case 1 /* ERROR "duplicate case" */ :
+	case ( /* ERROR "duplicate case" */ 1):
 	case 2, 3, 4:
-	case 1 /* DISABLED "duplicate case" */ :
+	case 5, 1 /* ERROR "duplicate case" */ :
 	}
 
 	switch uint64(x) {
-	case 1 /* DISABLED duplicate case */ <<64-1:
-	case 1 /* DISABLED duplicate case */ <<64-1:
+	case 1<<64 - 1:
+	case 1 /* ERROR duplicate case */ <<64 - 1:
+	case 2, 3, 4:
+	case 5, 1 /* ERROR duplicate case */ <<64 - 1:
+	}
+
+	var y32 float32
+	switch y32 {
+	case 1.1:
+	case 11/10: // integer division!
+	case 11. /* ERROR duplicate case */ /10:
+	case 2, 3.0, 4.1:
+	case 5.2, 1.10 /* ERROR duplicate case */ :
+	}
+
+	var y64 float64
+	switch y64 {
+	case 1.1:
+	case 11/10: // integer division!
+	case 11. /* ERROR duplicate case */ /10:
+	case 2, 3.0, 4.1:
+	case 5.2, 1.10 /* ERROR duplicate case */ :
+	}
+
+	var s string
+	switch s {
+	case "foo":
+	case "foo" /* ERROR duplicate case */ :
+	case "f" /* ERROR duplicate case */ + "oo":
+	case "abc", "def", "ghi":
+	case "jkl", "foo" /* ERROR duplicate case */ :
+	}
+
+	type T int
+	type F float64
+	type S string
+	type B bool
+	var i interface{}
+	switch i {
+	case nil:
+	case nil: // no duplicate detection
+	case (*int)(nil):
+	case (*int)(nil): // do duplicate detection
+	case 1:
+	case byte(1):
+	case int /* ERROR duplicate case */ (1):
+	case T(1):
+	case 1.0:
+	case F(1.0):
+	case F /* ERROR duplicate case */ (1.0):
+	case "hello":
+	case S("hello"):
+	case S /* ERROR duplicate case */ ("hello"):
+	case 1==1, B(false):
+	case false, B(2==2):
+	}
+
+	// switch on array
+	var a [3]int
+	switch a {
+	case [3]int{1, 2, 3}:
+	case [3]int{1, 2, 3}: // no duplicate detection
+	case [ /* ERROR "mismatched types */ 4]int{4, 5, 6}:
+	}
+
+	// switch on channel
+	var c1, c2 chan int
+	switch c1 {
+	case nil:
+	case c1:
+	case c2:
+	case c1, c2: // no duplicate detection
 	}
 }