| // Copyright 2016 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
// This program generates a test to verify that the standard arithmetic,
// shift, and comparison operators properly handle constant folding. The
// test file should be generated with a known working version of Go.
// Launch with `go run constFoldGen.go`; a file called constFold_test.go
// containing the tests will be written into the grandparent directory
// (../../, relative to the current working directory).
| |
| package main |
| |
| import ( |
| "bytes" |
| "fmt" |
| "go/format" |
| "io/ioutil" |
| "log" |
| ) |
| |
// op describes one binary operator exercised by the generated tests.
type op struct {
	name   string // fragment used in generated test function names, e.g. "add"
	symbol string // Go operator token emitted into the test source, e.g. "+"
}
// szD describes one integer type under test and the operand values
// used for it.
type szD struct {
	name string   // Go type name, e.g. "uint32"; used in declarations and test names
	sn   string   // bit width as text, e.g. "32" (appears unused by this generator)
	u    []uint64 // operand values when the type is unsigned
	i    []int64  // operand values when the type is signed
}
| |
| var szs []szD = []szD{ |
| szD{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0xffffFFFFffffFFFF}}, |
| szD{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF, |
| -4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}}, |
| |
| szD{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}}, |
| szD{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0, |
| 1, 0x7FFFFFFF}}, |
| |
| szD{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}}, |
| szD{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}}, |
| |
| szD{name: "uint8", sn: "8", u: []uint64{0, 1, 255}}, |
| szD{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}}, |
| } |
| |
| var ops = []op{ |
| op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"}, |
| op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"}, |
| } |
| |
// ansU returns the decimal text of i op j, computed in uint64 arithmetic
// and then truncated to the unsigned type named by t ("uint32", "uint16"
// or "uint8"; any other t keeps all 64 bits). A zero divisor for "/" or
// "%", or an unrecognized op, yields "0".
func ansU(i, j uint64, t, op string) string {
	var res uint64
	switch {
	case op == "+":
		res = i + j
	case op == "-":
		res = i - j
	case op == "*":
		res = i * j
	case op == "/" && j != 0:
		res = i / j
	case op == "%" && j != 0:
		res = i % j
	case op == "<<":
		res = i << j
	case op == ">>":
		res = i >> j
	}

	// Truncate by masking off every bit above the target width.
	switch t {
	case "uint32":
		res &= 0xFFFFFFFF
	case "uint16":
		res &= 0xFFFF
	case "uint8":
		res &= 0xFF
	}
	return fmt.Sprint(res)
}
| |
// ansS returns the decimal text of i op j, computed in int64 arithmetic
// and then wrapped to the signed type named by t ("int32", "int16" or
// "int8"; any other t keeps all 64 bits). Shift counts are reinterpreted
// as uint64, so a negative j behaves like a very large count. A zero
// divisor for "/" or "%", or an unrecognized op, yields "0".
func ansS(i, j int64, t, op string) string {
	var res int64
	switch {
	case op == "+":
		res = i + j
	case op == "-":
		res = i - j
	case op == "*":
		res = i * j
	case op == "/" && j != 0:
		res = i / j
	case op == "%" && j != 0:
		res = i % j
	case op == "<<":
		res = i << uint64(j)
	case op == ">>":
		res = i >> uint64(j)
	}

	bits := uint(64)
	switch t {
	case "int32":
		bits = 32
	case "int16":
		bits = 16
	case "int8":
		bits = 8
	}
	// Shifting left and then arithmetically right by the same amount keeps
	// the low `bits` bits and sign-extends from the highest of them. This is
	// the same result as converting through the narrower signed type.
	pad := 64 - bits
	res = res << pad >> pad
	return fmt.Sprint(res)
}
| |
| func main() { |
| w := new(bytes.Buffer) |
| fmt.Fprintf(w, "// run\n") |
| fmt.Fprintf(w, "// Code generated by gen/constFoldGen.go. DO NOT EDIT.\n\n") |
| fmt.Fprintf(w, "package gc\n") |
| fmt.Fprintf(w, "import \"testing\"\n") |
| |
| for _, s := range szs { |
| for _, o := range ops { |
| if o.symbol == "<<" || o.symbol == ">>" { |
| // shifts handled separately below, as they can have |
| // different types on the LHS and RHS. |
| continue |
| } |
| fmt.Fprintf(w, "func TestConstFold%s%s(t *testing.T) {\n", s.name, o.name) |
| fmt.Fprintf(w, "\tvar x, y, r %s\n", s.name) |
| // unsigned test cases |
| for _, c := range s.u { |
| fmt.Fprintf(w, "\tx = %d\n", c) |
| for _, d := range s.u { |
| if d == 0 && (o.symbol == "/" || o.symbol == "%") { |
| continue |
| } |
| fmt.Fprintf(w, "\ty = %d\n", d) |
| fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) |
| want := ansU(c, d, s.name, o.symbol) |
| fmt.Fprintf(w, "\tif r != %s {\n", want) |
| fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) |
| fmt.Fprintf(w, "\t}\n") |
| } |
| } |
| // signed test cases |
| for _, c := range s.i { |
| fmt.Fprintf(w, "\tx = %d\n", c) |
| for _, d := range s.i { |
| if d == 0 && (o.symbol == "/" || o.symbol == "%") { |
| continue |
| } |
| fmt.Fprintf(w, "\ty = %d\n", d) |
| fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) |
| want := ansS(c, d, s.name, o.symbol) |
| fmt.Fprintf(w, "\tif r != %s {\n", want) |
| fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) |
| fmt.Fprintf(w, "\t}\n") |
| } |
| } |
| fmt.Fprintf(w, "}\n") |
| } |
| } |
| |
| // Special signed/unsigned cases for shifts |
| for _, ls := range szs { |
| for _, rs := range szs { |
| if rs.name[0] != 'u' { |
| continue |
| } |
| for _, o := range ops { |
| if o.symbol != "<<" && o.symbol != ">>" { |
| continue |
| } |
| fmt.Fprintf(w, "func TestConstFold%s%s%s(t *testing.T) {\n", ls.name, rs.name, o.name) |
| fmt.Fprintf(w, "\tvar x, r %s\n", ls.name) |
| fmt.Fprintf(w, "\tvar y %s\n", rs.name) |
| // unsigned LHS |
| for _, c := range ls.u { |
| fmt.Fprintf(w, "\tx = %d\n", c) |
| for _, d := range rs.u { |
| fmt.Fprintf(w, "\ty = %d\n", d) |
| fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) |
| want := ansU(c, d, ls.name, o.symbol) |
| fmt.Fprintf(w, "\tif r != %s {\n", want) |
| fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) |
| fmt.Fprintf(w, "\t}\n") |
| } |
| } |
| // signed LHS |
| for _, c := range ls.i { |
| fmt.Fprintf(w, "\tx = %d\n", c) |
| for _, d := range rs.u { |
| fmt.Fprintf(w, "\ty = %d\n", d) |
| fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) |
| want := ansS(c, int64(d), ls.name, o.symbol) |
| fmt.Fprintf(w, "\tif r != %s {\n", want) |
| fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) |
| fmt.Fprintf(w, "\t}\n") |
| } |
| } |
| fmt.Fprintf(w, "}\n") |
| } |
| } |
| } |
| |
| // Constant folding for comparisons |
| for _, s := range szs { |
| fmt.Fprintf(w, "func TestConstFoldCompare%s(t *testing.T) {\n", s.name) |
| for _, x := range s.i { |
| for _, y := range s.i { |
| fmt.Fprintf(w, "\t{\n") |
| fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x) |
| fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y) |
| if x == y { |
| fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n") |
| } |
| if x != y { |
| fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n") |
| } |
| if x < y { |
| fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n") |
| } |
| if x > y { |
| fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n") |
| } |
| if x <= y { |
| fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n") |
| } |
| if x >= y { |
| fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n") |
| } |
| fmt.Fprintf(w, "\t}\n") |
| } |
| } |
| for _, x := range s.u { |
| for _, y := range s.u { |
| fmt.Fprintf(w, "\t{\n") |
| fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x) |
| fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y) |
| if x == y { |
| fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n") |
| } |
| if x != y { |
| fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n") |
| } |
| if x < y { |
| fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n") |
| } |
| if x > y { |
| fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n") |
| } |
| if x <= y { |
| fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n") |
| } |
| if x >= y { |
| fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n") |
| } else { |
| fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n") |
| } |
| fmt.Fprintf(w, "\t}\n") |
| } |
| } |
| fmt.Fprintf(w, "}\n") |
| } |
| |
| // gofmt result |
| b := w.Bytes() |
| src, err := format.Source(b) |
| if err != nil { |
| fmt.Printf("%s\n", b) |
| panic(err) |
| } |
| |
| // write to file |
| err = ioutil.WriteFile("../../constFold_test.go", src, 0666) |
| if err != nil { |
| log.Fatalf("can't write output: %v\n", err) |
| } |
| } |