[dev.cc] cmd/asm: check for overflow on multiply and left shift

The internal size of integers is not part of the definition of the assembler,
so if bits roll out the top it's a portability problem at best.

If you need to use shift to create a mask, use & to restrict the bit count
before shifting. That will make it portable, too.

Change-Id: I24f9a4d2152c3f9f253e22ff75270fe50c18612b
Reviewed-on: https://go-review.googlesource.com/3451
Reviewed-by: Russ Cox <rsc@golang.org>
diff --git a/src/cmd/asm/internal/asm/overflow.go b/src/cmd/asm/internal/asm/overflow.go
index 9e03e7a..a429201 100644
--- a/src/cmd/asm/internal/asm/overflow.go
+++ b/src/cmd/asm/internal/asm/overflow.go
@@ -33,6 +33,23 @@
 			}
 		}
 	}
+	overflow := func(a, b int) bool {
+		for ; b > 0; b-- {
+			a <<= 1
+			if a >= 256 {
+				return true
+			}
+		}
+		return false
+	}
+	for a := 0; a <= 255; a++ {
+		for b := 0; b <= 255; b++ {
+			ovfl := overflow(a, b)
+			if shiftOverflows(uint8(a), uint8(b)) != ovfl {
+				fmt.Printf("%d<<%d fails\n", a, b)
+			}
+		}
+	}
 */
 
 func addOverflows(a, b uint64) bool {
@@ -51,6 +68,11 @@
 	return c/b != a
 }
 
+func shiftOverflows(a, b uint64) bool {
+	c := a << b
+	return c>>b != a
+}
+
 /*
 For the record, signed overflow:
 
@@ -91,4 +113,17 @@
 	c := a * b
 	return c/b != a
 }
+
+func signedShiftOverflows(a, b int64) bool {
+	// Avoid right shift of a negative number.
+	if a >= 0 {
+		c := a << b
+		return c>>b != a
+	}
+	// Otherwise it's negative, so we complement, which
+	// puts zeros at the top.
+	a = ^a
+	c := a << b
+	return c>>b != a
+}
 */
diff --git a/src/cmd/asm/internal/asm/parse.go b/src/cmd/asm/internal/asm/parse.go
index 18ec932..be616ec 100644
--- a/src/cmd/asm/internal/asm/parse.go
+++ b/src/cmd/asm/internal/asm/parse.go
@@ -408,7 +408,11 @@
 		switch p.peek() {
 		case '*':
 			p.next()
-			value *= p.factor() // OVERFLOW?
+			x := p.factor()
+			if mulOverflows(value, x) {
+				p.errorf("%d * %d overflows", value, x)
+			}
+			value *= x
 		case '/':
 			p.next()
 			value /= p.factor()
@@ -418,10 +422,13 @@
 		case lex.LSH:
 			p.next()
 			shift := p.factor()
-			if shift < 0 {
+			if int64(shift) < 0 {
 				p.errorf("negative left shift %d", shift)
 			}
-			value <<= uint(shift) // OVERFLOW?
+			if shiftOverflows(value, shift) {
+				p.errorf("%d << %d overflows", value, shift)
+			}
+			return value << shift
 		case lex.RSH:
 			p.next()
 			shift := p.term()