package eg

// This file defines the AST rewriting pass.
// Most of it was plundered directly from
// $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).

import (
	"fmt"
	"go/ast"
	"go/token"
	"os"
	"reflect"
	"sort"
	"strconv"
	"strings"

	"golang.org/x/tools/astutil"
	"golang.org/x/tools/go/types"
)

// Transform applies the transformation to the specified parsed file,
// whose type information is supplied in info, and returns the number
// of replacements that were made.
//
// pkg is treated as the current package for the duration of the call:
// it is never added to the file's imports.
//
// It mutates the AST in place (the identity of the root node is
// unchanged), and may add nodes for which no type information is
// available in info.
//
// When tr.verbose is set, the before/after patterns and every match
// (with its wildcard bindings, listed in sorted name order) are
// reported on os.Stderr.
//
// Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
	// Merge info into the transformer's own type information,
	// but only the first time this particular info is seen.
	if !tr.seenInfos[info] {
		tr.seenInfos[info] = true
		mergeTypeInfo(&tr.info.Info, info)
	}
	tr.currentPkg = pkg
	tr.nsubsts = 0

	if tr.verbose {
		fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
		fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
	}

	// f rewrites a single value.  It is handed to apply, which calls it
	// back for every child, so the tree is processed bottom-up.
	var f func(rv reflect.Value) reflect.Value
	f = func(rv reflect.Value) reflect.Value {
		// don't bother if val is invalid to start with
		if !rv.IsValid() {
			return reflect.Value{}
		}

		// Rewrite the children before trying to match this node.
		rv = apply(f, rv)

		e := rvToExpr(rv)
		if e != nil {
			// Each match attempt gets a fresh environment; the
			// previous one is restored afterwards.
			savedEnv := tr.env
			tr.env = make(map[string]ast.Expr) // inefficient!  Use a slice of k/v pairs

			if tr.matchExpr(tr.before, e) {
				if tr.verbose {
					fmt.Fprintf(os.Stderr, "%s matches %s",
						astString(tr.fset, tr.before), astString(tr.fset, e))
					if len(tr.env) > 0 {
						fmt.Fprintf(os.Stderr, " with:")
						// Map iteration order is unspecified, so sort
						// the names to keep the log reproducible.
						names := make([]string, 0, len(tr.env))
						for name := range tr.env {
							names = append(names, name)
						}
						sort.Strings(names)
						for _, name := range names {
							fmt.Fprintf(os.Stderr, " %s->%s",
								name, astString(tr.fset, tr.env[name]))
						}
					}
					fmt.Fprintf(os.Stderr, "\n")
				}
				tr.nsubsts++

				// Clone the replacement tree, performing parameter substitution.
				// We update all positions to e.Pos() to aid comment placement.
				rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
					reflect.ValueOf(e.Pos()))
			}
			tr.env = savedEnv
		}

		return rv
	}
	file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File)

	// By construction, the root node is unchanged.
	if file != file2 {
		panic("BUG")
	}

	// Add any necessary imports.
	// TODO(adonovan): remove no-longer needed imports too.
	if tr.nsubsts > 0 {
		// Start from the packages of all imported objects...
		pkgs := make(map[string]*types.Package)
		for obj := range tr.importedObjs {
			pkgs[obj.Pkg().Path()] = obj.Pkg()
		}

		// ...and drop those the file already imports.
		for _, imp := range file.Imports {
			// The error is deliberately ignored: if the literal cannot
			// be unquoted, path is "" and the delete is harmless.
			path, _ := strconv.Unquote(imp.Path.Value)
			delete(pkgs, path)
		}
		delete(pkgs, pkg.Path()) // don't import self

		// NB: AddImport may completely replace the AST!
		// It thus renders info and tr.info no longer relevant to file.
		//
		// Imports are added in sorted path order so that the result
		// does not depend on map iteration order.
		var paths []string
		for path := range pkgs {
			paths = append(paths, path)
		}
		sort.Strings(paths)
		for _, path := range paths {
			astutil.AddImport(tr.fset, file, path)
		}
	}

	tr.currentPkg = nil

	return tr.nsubsts
}

// setValue performs x.Set(y), tolerating assignments that reflection
// rejects because y's type does not fit x: such a rewrite is silently
// dropped and x keeps its old value.  An invalid y is ignored, and any
// other panic raised by Set is propagated to the caller.
func setValue(x, y reflect.Value) {
	if !y.IsValid() {
		// nothing to assign
		return
	}

	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		if isString &&
			(strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable")) {
			// y does not fit into x; skip this rewrite
			return
		}
		// Not an assignability problem: re-raise it.
		panic(r)
	}()

	x.Set(y)
}

// Reflection types of the AST values that receive special treatment
// during rewriting.
var (
	identType        = reflect.TypeOf((*ast.Ident)(nil))
	selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
	callExprType     = reflect.TypeOf((*ast.CallExpr)(nil))
	objectPtrType    = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType     = reflect.TypeOf((*ast.Scope)(nil))
	positionType     = reflect.TypeOf(token.NoPos)
)

// Typed nil pointers, in reflect.Value form, for *ast.Object and
// *ast.Scope.
var (
	objectPtrNil = reflect.Zero(objectPtrType)
	scopePtrNil  = reflect.Zero(scopePtrType)
)

// apply rewrites val in place and returns it.  After following at
// most one level of pointer, each element of a slice, each field of a
// struct, or the value held by an interface is replaced by the result
// of calling f on it.  f works on reflect.Values directly so that no
// extra conversions are needed.
//
// An invalid val yields the zero reflect.Value.  *ast.Object and
// *ast.Scope values are not traversed: a typed nil is returned in
// their place.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
	if !val.IsValid() {
		return reflect.Value{}
	}

	switch val.Type() {
	case objectPtrType:
		// *ast.Objects introduce cycles and are likely incorrect
		// after a rewrite; drop them rather than follow them.
		return objectPtrNil
	case scopePtrType:
		// Scopes are likewise likely incorrect after a rewrite.
		return scopePtrNil
	}

	target := reflect.Indirect(val)
	switch target.Kind() {
	case reflect.Slice:
		for i := 0; i < target.Len(); i++ {
			elem := target.Index(i)
			setValue(elem, f(elem))
		}
	case reflect.Struct:
		for i := 0; i < target.NumField(); i++ {
			field := target.Field(i)
			setValue(field, f(field))
		}
	case reflect.Interface:
		setValue(target, f(target.Elem()))
	}
	return val
}

// subst returns a copy of the (replacement) pattern in which each
// identifier whose name is bound in env is replaced by a copy of the
// bound expression, and each valid token position is replaced by pos.
//
// If env is nil, no wildcard substitution is performed.  If pos is the
// zero reflect.Value, positions are copied unchanged; this is how the
// expressions bound in env are cloned.  Invalid positions in the
// pattern are always left as they are.
//
// *ast.Object and *ast.Scope values are not copied; typed nils are
// returned in their place.  An invalid pattern yields the zero
// reflect.Value.
func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
	if !pattern.IsValid() {
		return reflect.Value{}
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead
	if pattern.Type() == objectPtrType {
		return objectPtrNil
	}

	// similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil
	if pattern.Type() == scopePtrType {
		return scopePtrNil
	}

	// Wildcard gets replaced with map value.
	//
	// The identifier pointer may be nil (e.g. the Label field of an
	// unlabelled break or continue statement inside a function
	// literal).  Such a pointer has no name to look up and is simply
	// copied by the reflect.Ptr case below.
	if env != nil && pattern.Type() == identType {
		if id := pattern.Interface().(*ast.Ident); id != nil {
			if old, ok := env[id.Name]; ok {
				// Clone the bound expression as is: no further
				// substitution and no change of positions.
				return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
			}
		}
	}

	// Emit qualified identifiers in the pattern by appropriate
	// (possibly qualified) identifier in the input.
	//
	// The template cannot contain dot imports, so all identifiers
	// for imported objects are explicitly qualified.
	//
	// We assume (unsoundly) that there are no dot or named
	// imports in the input code, nor are any imported package
	// names shadowed, so the usual normal qualified identifier
	// syntax may be used.
	// TODO(adonovan): fix: avoid this assumption.
	//
	// A refactoring may be applied to a package referenced by the
	// template.  Objects belonging to the current package are
	// denoted by unqualified identifiers.
	//
	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
		obj := isRef(pattern.Interface().(*ast.SelectorExpr), &tr.info)
		if obj != nil {
			if sel, ok := tr.importedObjs[obj]; ok {
				var id ast.Expr
				if obj.Pkg() == tr.currentPkg {
					id = sel.Sel // unqualified
				} else {
					id = sel // pkg-qualified
				}

				// Return a clone of id.
				// importedObjs is cleared for the duration of the
				// recursive call so that the selector being cloned
				// does not take this branch again.
				saved := tr.importedObjs
				tr.importedObjs = nil // break cycle
				r := tr.subst(nil, reflect.ValueOf(id), pos)
				tr.importedObjs = saved
				return r
			}
		}
	}

	if pos.IsValid() && pattern.Type() == positionType {
		// use new position only if old position was valid in the first place
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}

	// Otherwise copy.
	switch p := pattern; p.Kind() {
	case reflect.Slice:
		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
		}
		return v

	case reflect.Struct:
		v := reflect.New(p.Type()).Elem()
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
		}
		return v

	case reflect.Ptr:
		// v starts as a nil pointer and stays nil if p is nil.
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos).Addr())
		}

		// Duplicate type information for duplicated ast.Expr.
		// All ast.Node implementations are *structs,
		// so this case catches them all.
		if e := rvToExpr(v); e != nil {
			updateTypeInfo(&tr.info.Info, e, p.Interface().(ast.Expr))
		}
		return v

	case reflect.Interface:
		// v starts as a nil interface and stays nil if p is nil.
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos))
		}
		return v
	}

	// Remaining kinds are returned as is.
	return pattern
}

// -- utilities -------------------------------------------------------

// rvToExpr returns the value held by rv as an ast.Expr, or nil if rv
// cannot be converted to an interface value or its value does not
// implement ast.Expr.
func rvToExpr(rv reflect.Value) ast.Expr {
	if !rv.CanInterface() {
		return nil
	}
	// On a failed assertion expr is the nil ast.Expr.
	expr, _ := rv.Interface().(ast.Expr)
	return expr
}

// updateTypeInfo copies the entries that info holds for the existing
// expression old onto the duplicated expression new: Defs and Uses for
// identifiers, Selections for selector expressions, and Types for any
// expression.  Entries absent for old are not created for new.
//
// old must have the same dynamic type as new when new is an
// *ast.Ident or *ast.SelectorExpr; otherwise the type assertion panics.
func updateTypeInfo(info *types.Info, new, old ast.Expr) {
	if dupID, isIdent := new.(*ast.Ident); isIdent {
		srcID := old.(*ast.Ident)
		if def, found := info.Defs[srcID]; found {
			info.Defs[dupID] = def
		}
		if use, found := info.Uses[srcID]; found {
			info.Uses[dupID] = use
		}
	} else if dupSel, isSel := new.(*ast.SelectorExpr); isSel {
		srcSel := old.(*ast.SelectorExpr)
		if selection, found := info.Selections[srcSel]; found {
			info.Selections[dupSel] = selection
		}
	}

	// Type-and-value information applies to every kind of expression.
	if tv, found := info.Types[old]; found {
		info.Types[new] = tv
	}
}
