gofmt: add -r flag to rewrite source code according to pattern
a little slow, but usable (speed unchanged when not using -r)
tweak go/printer to handle nodes without line numbers
more gracefully in a couple cases.
R=gri
https://golang.org/cl/156103
diff --git a/src/cmd/gofmt/Makefile b/src/cmd/gofmt/Makefile
index a93b8c3..dbc134f 100644
--- a/src/cmd/gofmt/Makefile
+++ b/src/cmd/gofmt/Makefile
@@ -7,6 +7,7 @@
TARG=gofmt
GOFILES=\
gofmt.go\
+ rewrite.go\
include $(GOROOT)/src/Make.cmd
diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go
index bec4c88..d7c96dc 100644
--- a/src/cmd/gofmt/gofmt.go
+++ b/src/cmd/gofmt/gofmt.go
@@ -8,6 +8,7 @@
"bytes";
"flag";
"fmt";
+ "go/ast";
"go/parser";
"go/printer";
"go/scanner";
@@ -20,8 +21,9 @@
var (
// main operation modes
- list = flag.Bool("l", false, "list files whose formatting differs from gofmt's");
- write = flag.Bool("w", false, "write result to (source) file instead of stdout");
+ list = flag.Bool("l", false, "list files whose formatting differs from gofmt's");
+ write = flag.Bool("w", false, "write result to (source) file instead of stdout");
+ rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')");
// debugging support
comments = flag.Bool("comments", true, "print comments");
@@ -34,6 +36,8 @@
var exitCode = 0
+var rewrite func(*ast.File) *ast.File
+
func report(err os.Error) {
scanner.PrintError(os.Stderr, err);
@@ -86,6 +90,10 @@
return err
}
+ if rewrite != nil {
+ file = rewrite(file)
+ }
+
var res bytes.Buffer;
_, err = (&printer.Config{printerMode(), *tabwidth, nil}).Fprint(&res, file);
if err != nil {
@@ -154,6 +162,8 @@
os.Exit(2);
}
+ initRewrite();
+
if flag.NArg() == 0 {
if err := processFile("/dev/stdin"); err != nil {
report(err)
diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go
new file mode 100644
index 0000000..9399bcd
--- /dev/null
+++ b/src/cmd/gofmt/rewrite.go
@@ -0,0 +1,226 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt";
+ "go/ast";
+ "go/parser";
+ "go/token";
+ "os";
+ "reflect";
+ "strings";
+ "unicode";
+ "utf8";
+)
+
+
+func initRewrite() {
+ if *rewriteRule == "" {
+ return
+ }
+ f := strings.Split(*rewriteRule, "->", 0);
+ if len(f) != 2 {
+ fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n");
+ os.Exit(2);
+ }
+ pattern := parseExpr(f[0], "pattern");
+ replace := parseExpr(f[1], "replacement");
+ rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) };
+}
+
+
+// parseExpr parses s as an expression.
+// It might make sense to expand this to allow statement patterns,
+// but there are problems with preserving formatting and also
+// with what a wildcard for a statement looks like.
+func parseExpr(s string, what string) ast.Expr {
+ stmts, err := parser.ParseStmtList("input", s);
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "parsing %s %s: %s\n", what, s, err);
+ os.Exit(2);
+ }
+ if len(stmts) != 1 {
+ fmt.Fprintf(os.Stderr, "%s must be single expression\n", what);
+ os.Exit(2);
+ }
+ x, ok := stmts[0].(*ast.ExprStmt);
+ if !ok {
+ fmt.Fprintf(os.Stderr, "%s must be single expression\n", what);
+ os.Exit(2);
+ }
+ return x.X;
+}
+
+
+// rewriteFile applys the rewrite rule pattern -> replace to an entire file.
+func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
+ m := make(map[string]reflect.Value);
+ pat := reflect.NewValue(pattern);
+ repl := reflect.NewValue(replace);
+ var f func(val reflect.Value) reflect.Value; // f is recursive
+ f = func(val reflect.Value) reflect.Value {
+ for k := range m {
+ m[k] = nil, false
+ }
+ if match(m, pat, val) {
+ return subst(m, repl)
+ }
+ return apply(f, val);
+ };
+ return apply(f, reflect.NewValue(p)).Interface().(*ast.File);
+}
+
+
+var positionType = reflect.Typeof(token.Position{})
+var zeroPosition = reflect.NewValue(token.Position{})
+var identType = reflect.Typeof((*ast.Ident)(nil))
+
+
+func isWildcard(s string) bool {
+ rune, _ := utf8.DecodeRuneInString(s);
+ return unicode.Is(unicode.Greek, rune) && unicode.IsLower(rune);
+}
+
+
+// apply replaces each AST field x in val with f(x), returning val.
+// To avoid extra conversions, f operates on the reflect.Value form.
+func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
+ if val == nil {
+ return nil
+ }
+ switch v := reflect.Indirect(val).(type) {
+ case *reflect.SliceValue:
+ for i := 0; i < v.Len(); i++ {
+ e := v.Elem(i);
+ e.SetValue(f(e));
+ }
+ case *reflect.StructValue:
+ for i := 0; i < v.NumField(); i++ {
+ e := v.Field(i);
+ e.SetValue(f(e));
+ }
+ case *reflect.InterfaceValue:
+ e := v.Elem();
+ v.SetValue(f(e));
+ }
+ return val;
+}
+
+
+// match returns true if pattern matches val,
+// recording wildcard submatches in m.
+// If m == nil, match checks whether pattern == val.
+func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
+ // Wildcard matches any expression. If it appears multiple
+ // times in the pattern, it must match the same expression
+ // each time.
+ if m != nil && pattern.Type() == identType {
+ name := pattern.Interface().(*ast.Ident).Value;
+ if isWildcard(name) {
+ if old, ok := m[name]; ok {
+ return match(nil, old, val)
+ }
+ m[name] = val;
+ return true;
+ }
+ }
+
+ // Otherwise, the expressions must match recursively.
+ if pattern == nil || val == nil {
+ return pattern == nil && val == nil
+ }
+ if pattern.Type() != val.Type() {
+ return false
+ }
+
+ // Token positions need not match.
+ if pattern.Type() == positionType {
+ return true
+ }
+
+ p := reflect.Indirect(pattern);
+ v := reflect.Indirect(val);
+
+ switch p := p.(type) {
+ case *reflect.SliceValue:
+ v := v.(*reflect.SliceValue);
+ for i := 0; i < p.Len(); i++ {
+ if !match(m, p.Elem(i), v.Elem(i)) {
+ return false
+ }
+ }
+ return true;
+
+ case *reflect.StructValue:
+ v := v.(*reflect.StructValue);
+ for i := 0; i < p.NumField(); i++ {
+ if !match(m, p.Field(i), v.Field(i)) {
+ return false
+ }
+ }
+ return true;
+
+ case *reflect.InterfaceValue:
+ v := v.(*reflect.InterfaceValue);
+ return match(m, p.Elem(), v.Elem());
+ }
+
+ // Handle token integers, etc.
+ return p.Interface() == v.Interface();
+}
+
+
+// subst returns a copy of pattern with values from m substituted in place of wildcards.
+// if m == nil, subst returns a copy of pattern.
+// Either way, the returned value has no valid line number information.
+func subst(m map[string]reflect.Value, pattern reflect.Value) reflect.Value {
+ if pattern == nil {
+ return nil
+ }
+
+ // Wildcard gets replaced with map value.
+ if m != nil && pattern.Type() == identType {
+ name := pattern.Interface().(*ast.Ident).Value;
+ if isWildcard(name) {
+ if old, ok := m[name]; ok {
+ return subst(nil, old)
+ }
+ }
+ }
+
+ if pattern.Type() == positionType {
+ return zeroPosition
+ }
+
+ // Otherwise copy.
+ switch p := pattern.(type) {
+ case *reflect.SliceValue:
+ v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len());
+ for i := 0; i < p.Len(); i++ {
+ v.Elem(i).SetValue(subst(m, p.Elem(i)))
+ }
+ return v;
+
+ case *reflect.StructValue:
+ v := reflect.MakeZero(p.Type()).(*reflect.StructValue);
+ for i := 0; i < p.NumField(); i++ {
+ v.Field(i).SetValue(subst(m, p.Field(i)))
+ }
+ return v;
+
+ case *reflect.PtrValue:
+ v := reflect.MakeZero(p.Type()).(*reflect.PtrValue);
+ v.PointTo(subst(m, p.Elem()));
+ return v;
+
+ case *reflect.InterfaceValue:
+ v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue);
+ v.SetValue(subst(m, p.Elem()));
+ return v;
+ }
+
+ return pattern;
+}
diff --git a/src/pkg/go/printer/nodes.go b/src/pkg/go/printer/nodes.go
index 6304830..1c74603 100644
--- a/src/pkg/go/printer/nodes.go
+++ b/src/pkg/go/printer/nodes.go
@@ -52,6 +52,17 @@
case n > max:
n = max
}
+
+ // TODO(gri): try to avoid direct manipulation of p.pos
+ // demo of why this is necessary: run gofmt -r 'i < i -> i < j' x.go on this x.go:
+ // package main
+ // func main() {
+ // i < i;
+ // j < 10;
+ // }
+ //
+ p.pos.Line += n;
+
if n > 0 {
p.print(ws);
if newSection {
@@ -199,7 +210,7 @@
if mode&commaSep != 0 {
p.print(token.COMMA)
}
- if prev < line {
+ if prev < line && prev > 0 && line > 0 {
if p.linebreak(line, 1, 2, ws, true) {
ws = ignore;
*multiLine = true;
@@ -564,8 +575,7 @@
xline := p.pos.Line; // before the operator (it may be on the next line!)
yline := x.Y.Pos().Line;
p.print(x.OpPos, x.Op);
- if xline != yline {
- //println(x.OpPos.String());
+ if xline != yline && xline > 0 && yline > 0 {
// at least one line break, but respect an extra empty line
// in the source
if p.linebreak(yline, 1, 2, ws, true) {