blob: fbcd46aa29e2851dcc4bf101bf16f50c13511e1c [file] [log] [blame]
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"reflect"
"strings"
"unicode"
"utf8"
)
// initRewrite compiles the rule held in *rewriteRule (presumably set from
// the command line) into the package-level rewrite function. An empty rule
// leaves rewrite untouched. A rule that is not exactly "pattern -> replacement",
// or whose halves do not parse as expressions, is reported on standard error
// and the program exits with status 2.
func initRewrite() {
	rule := *rewriteRule
	if rule == "" {
		return
	}
	parts := strings.Split(rule, "->", -1)
	if len(parts) != 2 {
		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
		os.Exit(2)
	}
	// The pattern is parsed (and any error reported) before the replacement.
	pattern, replace := parseExpr(parts[0], "pattern"), parseExpr(parts[1], "replacement")
	rewrite = func(p *ast.File) *ast.File {
		return rewriteFile(pattern, replace, p)
	}
}
// parseExpr parses s as a Go expression and returns it; what names the
// role of s ("pattern" or "replacement") in the error message.
// A parse failure is reported on standard error and exits with status 2.
//
// It might make sense to expand this to allow statement patterns,
// but there are problems with preserving formatting and also
// with what a wildcard for a statement looks like.
func parseExpr(s string, what string) ast.Expr {
	expr, err := parser.ParseExpr(fset, "input", s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s: %s\n", what, s, err)
	os.Exit(2)
	return nil // not reached; os.Exit does not return
}
// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
//
// The traversal is bottom-up. The children of each value are rewritten first
// via apply. Then the (possibly updated) value itself is matched against
// pattern. On success, it is replaced by a copy of replace that has the
// wildcard bindings substituted in and uses the matched node's position for
// the new tokens.
func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
	m := make(map[string]reflect.Value)
	pat := reflect.NewValue(pattern)
	repl := reflect.NewValue(replace)
	var f func(val reflect.Value) reflect.Value // f is recursive
	f = func(val reflect.Value) reflect.Value {
		val = apply(f, val)
		// Reset the wildcard bindings only after apply. The recursive calls
		// made by apply share m and can leave bindings behind, including
		// partial ones from failed matches. Those stale bindings would wrongly
		// constrain the match of val below.
		for k := range m {
			m[k] = nil, false
		}
		if match(m, pat, val) {
			val = subst(m, repl, reflect.NewValue(val.Interface().(ast.Node).Pos()))
		}
		return val
	}
	return apply(f, reflect.NewValue(p)).Interface().(*ast.File)
}
// setValue performs x.SetValue(y). A "type mismatch" panic, meaning y cannot
// be stored in x, is swallowed so that this one rewrite is simply skipped.
// Any other panic is re-raised unchanged.
func setValue(x, y reflect.Value) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if msg, isString := r.(string); isString && strings.HasPrefix(msg, "type mismatch") {
			// x cannot be set to y - ignore this rewrite
			return
		}
		panic(r)
	}()
	x.SetValue(y)
}
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
if val == nil {
return nil
}
switch v := reflect.Indirect(val).(type) {
case *reflect.SliceValue:
for i := 0; i < v.Len(); i++ {
e := v.Elem(i)
setValue(e, f(e))
}
case *reflect.StructValue:
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
setValue(e, f(e))
}
case *reflect.InterfaceValue:
e := v.Elem()
setValue(v, f(e))
}
return val
}
// Types that match and subst handle specially.
var (
	// positionType is the type of token positions (token.Pos).
	positionType = reflect.Typeof(token.NoPos)
	// identType is the type of identifier nodes (*ast.Ident).
	identType = reflect.Typeof((*ast.Ident)(nil))
)
// isWildcard reports whether s names a rewrite wildcard, i.e. s consists of
// exactly one Unicode lower-case letter. The empty string is not a wildcard.
func isWildcard(s string) bool {
	r, n := utf8.DecodeRuneInString(s)
	if n != len(s) {
		return false // more than one rune
	}
	return unicode.IsLower(r)
}
// match returns true if pattern matches val,
// recording wildcard submatches in m.
// If m == nil, match checks whether pattern == val.
func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// Wildcard matches any expression. If it appears multiple
// times in the pattern, it must match the same expression
// each time.
if m != nil && pattern != nil && pattern.Type() == identType {
name := pattern.Interface().(*ast.Ident).Name
if isWildcard(name) && val != nil {
// wildcards only match expressions
if _, ok := val.Interface().(ast.Expr); ok {
if old, ok := m[name]; ok {
return match(nil, old, val)
}
m[name] = val
return true
}
}
}
// Otherwise, pattern and val must match recursively.
if pattern == nil || val == nil {
return pattern == nil && val == nil
}
if pattern.Type() != val.Type() {
return false
}
// Special cases.
switch pattern.Type() {
case positionType:
// token positions don't need to match
return true
case identType:
// For identifiers, only the names need to match
// (and none of the other *ast.Object information).
// This is a common case, handle it all here instead
// of recursing down any further via reflection.
p := pattern.Interface().(*ast.Ident)
v := val.Interface().(*ast.Ident)
return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
}
p := reflect.Indirect(pattern)
v := reflect.Indirect(val)
if p == nil || v == nil {
return p == nil && v == nil
}
switch p := p.(type) {
case *reflect.SliceValue:
v := v.(*reflect.SliceValue)
if p.Len() != v.Len() {
return false
}
for i := 0; i < p.Len(); i++ {
if !match(m, p.Elem(i), v.Elem(i)) {
return false
}
}
return true
case *reflect.StructValue:
v := v.(*reflect.StructValue)
if p.NumField() != v.NumField() {
return false
}
for i := 0; i < p.NumField(); i++ {
if !match(m, p.Field(i), v.Field(i)) {
return false
}
}
return true
case *reflect.InterfaceValue:
v := v.(*reflect.InterfaceValue)
return match(m, p.Elem(), v.Elem())
}
// Handle token integers, etc.
return p.Interface() == v.Interface()
}
// subst returns a deep copy of pattern with two substitutions applied:
// each wildcard identifier bound in m is replaced by a copy of its bound
// value, and each valid token position is replaced by pos.
//
// If m == nil, wildcards are not substituted. If pos == nil, token
// positions are copied unchanged, so subst(nil, x, nil) is a plain deep
// copy of x. Invalid positions in pattern are always kept as they are.
func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
	if pattern == nil {
		return nil
	}
	// Wildcard gets replaced with map value.
	// The bound value is copied (not shared) and keeps its own positions.
	// An unbound wildcard falls through and is copied like any identifier.
	if m != nil && pattern.Type() == identType {
		name := pattern.Interface().(*ast.Ident).Name
		if isWildcard(name) {
			if old, ok := m[name]; ok {
				return subst(nil, old, nil)
			}
		}
	}
	if pos != nil && pattern.Type() == positionType {
		// use new position only if old position was valid in the first place
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}
	// Otherwise copy.
	switch p := pattern.(type) {
	case *reflect.SliceValue:
		// New slice of the same type and length; elements copied recursively.
		v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Elem(i).SetValue(subst(m, p.Elem(i), pos))
		}
		return v
	case *reflect.StructValue:
		// Zero value of the same struct type, filled field by field.
		v := reflect.MakeZero(p.Type()).(*reflect.StructValue)
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).SetValue(subst(m, p.Field(i), pos))
		}
		return v
	case *reflect.PtrValue:
		// New pointer aimed at a copy of the pointee.
		// NOTE(review): for a nil pointer, p.Elem() is presumably nil and so is
		// the recursive result; this relies on PointTo accepting a nil Value.
		// Confirm against the reflect version in use.
		v := reflect.MakeZero(p.Type()).(*reflect.PtrValue)
		v.PointTo(subst(m, p.Elem(), pos))
		return v
	case *reflect.InterfaceValue:
		// New interface value holding a copy of the dynamic value.
		v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue)
		v.SetValue(subst(m, p.Elem(), pos))
		return v
	}
	// Any other kind of value (e.g. strings and token constants) is
	// returned as-is rather than copied.
	return pattern
}