cmd/compile: implement shifts by signed amounts
Allow shifts by signed amounts. Panic if the shift amount is negative.
TODO: We end up doing two compares per shift, see Ian's comment
https://github.com/golang/go/issues/19113#issuecomment-443241799 that
we could do it with a single comparison in the normal case.
The prove pass mostly handles this code well. For instance, it removes the
<0 check for cases like this:
if s >= 0 { _ = x << s }
_ = x << len(a)
This case isn't handled well yet:
_ = x << (y & 0xf)
I'll do followon CLs for unhandled cases as needed.
Update #19113
R=go1.13
Change-Id: I839a5933d94b54ab04deb9dd5149f32c51c90fa1
Reviewed-on: https://go-review.googlesource.com/c/158719
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Josh Bleecher Snyder <josharian@gmail.com>
diff --git a/src/cmd/compile/internal/gc/builtin.go b/src/cmd/compile/internal/gc/builtin.go
index 04f4cbf..f32fcd6 100644
--- a/src/cmd/compile/internal/gc/builtin.go
+++ b/src/cmd/compile/internal/gc/builtin.go
@@ -13,6 +13,7 @@
{"panicindex", funcTag, 5},
{"panicslice", funcTag, 5},
{"panicdivide", funcTag, 5},
+ {"panicshift", funcTag, 5},
{"panicmakeslicelen", funcTag, 5},
{"throwinit", funcTag, 5},
{"panicwrap", funcTag, 5},
diff --git a/src/cmd/compile/internal/gc/builtin/runtime.go b/src/cmd/compile/internal/gc/builtin/runtime.go
index fc879ba..210881a 100644
--- a/src/cmd/compile/internal/gc/builtin/runtime.go
+++ b/src/cmd/compile/internal/gc/builtin/runtime.go
@@ -18,6 +18,7 @@
func panicindex()
func panicslice()
func panicdivide()
+func panicshift()
func panicmakeslicelen()
func throwinit()
func panicwrap()
diff --git a/src/cmd/compile/internal/gc/go.go b/src/cmd/compile/internal/gc/go.go
index 376637b..2213d8d 100644
--- a/src/cmd/compile/internal/gc/go.go
+++ b/src/cmd/compile/internal/gc/go.go
@@ -296,6 +296,7 @@
msanwrite,
newproc,
panicdivide,
+ panicshift,
panicdottypeE,
panicdottypeI,
panicindex,
diff --git a/src/cmd/compile/internal/gc/ssa.go b/src/cmd/compile/internal/gc/ssa.go
index e201376..6ddc9fb 100644
--- a/src/cmd/compile/internal/gc/ssa.go
+++ b/src/cmd/compile/internal/gc/ssa.go
@@ -84,6 +84,7 @@
panicnildottype = sysfunc("panicnildottype")
panicoverflow = sysfunc("panicoverflow")
panicslice = sysfunc("panicslice")
+ panicshift = sysfunc("panicshift")
raceread = sysfunc("raceread")
racereadrange = sysfunc("racereadrange")
racewrite = sysfunc("racewrite")
@@ -2128,7 +2129,13 @@
case OLSH, ORSH:
a := s.expr(n.Left)
b := s.expr(n.Right)
- return s.newValue2(s.ssaShiftOp(n.Op, n.Type, n.Right.Type), a.Type, a, b)
+ bt := b.Type
+ if bt.IsSigned() {
+ cmp := s.newValue2(s.ssaOp(OGE, bt), types.Types[TBOOL], b, s.zeroVal(bt))
+ s.check(cmp, panicshift)
+ bt = bt.ToUnsigned()
+ }
+ return s.newValue2(s.ssaShiftOp(n.Op, n.Type, bt), a.Type, a, b)
case OANDAND, OOROR:
// To implement OANDAND (and OOROR), we introduce a
// new temporary variable to hold the result. The
diff --git a/src/cmd/compile/internal/gc/typecheck.go b/src/cmd/compile/internal/gc/typecheck.go
index 4fc1c5c..63e0d78 100644
--- a/src/cmd/compile/internal/gc/typecheck.go
+++ b/src/cmd/compile/internal/gc/typecheck.go
@@ -660,8 +660,8 @@
r = defaultlit(r, types.Types[TUINT])
n.Right = r
t := r.Type
- if !t.IsInteger() || t.IsSigned() {
- yyerror("invalid operation: %v (shift count type %v, must be unsigned integer)", n, r.Type)
+ if !t.IsInteger() {
+ yyerror("invalid operation: %v (shift count type %v, must be integer)", n, r.Type)
n.Type = nil
return n
}
diff --git a/src/cmd/compile/internal/ssa/rewrite.go b/src/cmd/compile/internal/ssa/rewrite.go
index a154249..6edb593 100644
--- a/src/cmd/compile/internal/ssa/rewrite.go
+++ b/src/cmd/compile/internal/ssa/rewrite.go
@@ -1115,7 +1115,8 @@
case OpStaticCall:
switch v.Aux.(fmt.Stringer).String() {
case "runtime.racefuncenter", "runtime.racefuncexit", "runtime.panicindex",
- "runtime.panicslice", "runtime.panicdivide", "runtime.panicwrap":
+ "runtime.panicslice", "runtime.panicdivide", "runtime.panicwrap",
+ "runtime.panicshift":
// Check for racefuncenter will encounter racefuncexit and vice versa.
// Allow calls to panic*
default:
diff --git a/src/cmd/internal/obj/x86/obj6.go b/src/cmd/internal/obj/x86/obj6.go
index babfd38..a6931e8 100644
--- a/src/cmd/internal/obj/x86/obj6.go
+++ b/src/cmd/internal/obj/x86/obj6.go
@@ -968,7 +968,7 @@
return false
}
switch s.Name {
- case "runtime.panicindex", "runtime.panicslice", "runtime.panicdivide", "runtime.panicwrap":
+ case "runtime.panicindex", "runtime.panicslice", "runtime.panicdivide", "runtime.panicwrap", "runtime.panicshift":
return true
}
return false
diff --git a/src/runtime/panic.go b/src/runtime/panic.go
index bb83be4..59916dd 100644
--- a/src/runtime/panic.go
+++ b/src/runtime/panic.go
@@ -23,16 +23,16 @@
var indexError = error(errorString("index out of range"))
-// The panicindex, panicslice, and panicdivide functions are called by
+// The panic{index,slice,divide,shift} functions are called by
// code generated by the compiler for out of bounds index expressions,
-// out of bounds slice expressions, and division by zero. The
-// panicdivide (again), panicoverflow, panicfloat, and panicmem
+// out of bounds slice expressions, division by zero, and shift by negative.
+// The panicdivide (again), panicoverflow, panicfloat, and panicmem
// functions are called by the signal handler when a signal occurs
// indicating the respective problem.
//
-// Since panicindex and panicslice are never called directly, and
+// Since panic{index,slice,shift} are never called directly, and
// since the runtime package should never have an out of bounds slice
-// or array reference, if we see those functions called from the
+// or array reference or negative shift, if we see those functions called from the
// runtime package we turn the panic into a throw. That will dump the
// entire runtime stack for easier debugging.
@@ -68,6 +68,16 @@
panic(overflowError)
}
+var shiftError = error(errorString("negative shift amount"))
+
+func panicshift() {
+ if hasPrefix(funcname(findfunc(getcallerpc())), "runtime.") {
+ throw(string(shiftError.(errorString)))
+ }
+ panicCheckMalloc(shiftError)
+ panic(shiftError)
+}
+
var floatError = error(errorString("floating point error"))
func panicfloat() {
diff --git a/test/fixedbugs/bug073.go b/test/fixedbugs/bug073.go
index 49b47ae..f3605b3 100644
--- a/test/fixedbugs/bug073.go
+++ b/test/fixedbugs/bug073.go
@@ -1,4 +1,4 @@
-// errorcheck
+// compile
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
@@ -7,8 +7,8 @@
package main
func main() {
- var s int = 0;
- var x int = 0;
- x = x << s; // ERROR "illegal|inval|shift"
- x = x >> s; // ERROR "illegal|inval|shift"
+ var s int = 0
+ var x int = 0
+ x = x << s // as of 1.13, these are ok
+ x = x >> s // as of 1.13, these are ok
}
diff --git a/test/fixedbugs/issue19113.go b/test/fixedbugs/issue19113.go
new file mode 100644
index 0000000..5e01dde
--- /dev/null
+++ b/test/fixedbugs/issue19113.go
@@ -0,0 +1,108 @@
+// run
+
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "reflect"
+
+var tests = []interface{}{
+ func(x int, s int) int {
+ return x << s
+ },
+ func(x int, s int64) int {
+ return x << s
+ },
+ func(x int, s int32) int {
+ return x << s
+ },
+ func(x int, s int16) int {
+ return x << s
+ },
+ func(x int, s int8) int {
+ return x << s
+ },
+ func(x int, s int) int {
+ return x >> s
+ },
+ func(x int, s int64) int {
+ return x >> s
+ },
+ func(x int, s int32) int {
+ return x >> s
+ },
+ func(x int, s int16) int {
+ return x >> s
+ },
+ func(x int, s int8) int {
+ return x >> s
+ },
+ func(x uint, s int) uint {
+ return x << s
+ },
+ func(x uint, s int64) uint {
+ return x << s
+ },
+ func(x uint, s int32) uint {
+ return x << s
+ },
+ func(x uint, s int16) uint {
+ return x << s
+ },
+ func(x uint, s int8) uint {
+ return x << s
+ },
+ func(x uint, s int) uint {
+ return x >> s
+ },
+ func(x uint, s int64) uint {
+ return x >> s
+ },
+ func(x uint, s int32) uint {
+ return x >> s
+ },
+ func(x uint, s int16) uint {
+ return x >> s
+ },
+ func(x uint, s int8) uint {
+ return x >> s
+ },
+}
+
+func main() {
+ for _, t := range tests {
+ runTest(reflect.ValueOf(t))
+ }
+}
+
+func runTest(f reflect.Value) {
+ xt := f.Type().In(0)
+ st := f.Type().In(1)
+
+ for _, x := range []int{1, 0, -1} {
+ for _, s := range []int{-99, -64, -63, -32, -31, -16, -15, -8, -7, -1, 0, 1, 7, 8, 15, 16, 31, 32, 63, 64, 99} {
+ args := []reflect.Value{
+ reflect.ValueOf(x).Convert(xt),
+ reflect.ValueOf(s).Convert(st),
+ }
+ if s < 0 {
+ shouldPanic(func() {
+ f.Call(args)
+ })
+ } else {
+ f.Call(args) // should not panic
+ }
+ }
+ }
+}
+
+func shouldPanic(f func()) {
+ defer func() {
+ if recover() == nil {
+ panic("did not panic")
+ }
+ }()
+ f()
+}