// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	exec "internal/execabs"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"strings"
)

// Partial type checker.
//
// The fact that it is partial is very important: the input is
// an AST and a description of some type information to
// assume about one or more packages, but not all the
// packages that the program imports. The checker is
// expected to do as much as it can with what it has been
// given. There is not enough information supplied to do
// a full type check, but the type checker is expected to
// apply information that can be derived from variable
// declarations, function and method returns, and type switches
// as far as it can, so that the caller can still tell the types
// of expressions relevant to a particular fix.
//
// TODO(rsc,gri): Replace with go/types (originally planned as go/typechecker).
// Doing that could be an interesting test case for go/types:
// the constraints about working with partial information will
// likely exercise it in interesting ways. The ideal interface would
// be to pass typecheck a map from importpath to package API text
// (Go source code), but for now we use data structures (TypeConfig, Type).
//
// The strings mostly use gofmt form.
//
// A Field or FieldList has as its type a comma-separated list
// of the types of the fields. For example, the field list
//	x, y, z int
// has type "int, int, int".

// The prefix "type " is the type of a type.
// For example, given
//	var x int
//	type T int
// x's type is "int" but T's type is "type int".
// mkType inserts the "type " prefix.
// getType removes it.
// isType tests for it.

// mkType returns the type-of-a-type string for t by adding the
// "type " prefix, e.g. "int" becomes "type int".
func mkType(t string) string {
	const typePrefix = "type "
	return typePrefix + t
}

// getType strips the "type " prefix from t, yielding the type that t
// denotes. It returns "" when t is not a type-of-a-type string.
func getType(t string) string {
	if isType(t) {
		return t[len("type "):]
	}
	return ""
}

func isType(t string) bool {
	return strings.HasPrefix(t, "type ")
}

// TypeConfig describes the universe of relevant types.
// For ease of creation, the types are all referred to by string
// name (e.g., "reflect.Value"); the checker resolves such a name
// simply by indexing the maps below with it.
type TypeConfig struct {
	// Type maps a type name to a description of its fields, methods,
	// and embedded types.
	// Var maps a variable name, such as "x" or "p.X", to its type.
	// Func maps a function name, such as "p.F", to a string that the
	// typeof method appends to "func()" to form the function's type
	// (presumably the result type list — TODO confirm with callers).
	Type map[string]*Type
	Var  map[string]string
	Func map[string]string

	// External maps from a name to its type.
	// It provides additional typings not present in the Go source itself.
	// For now, the only additional typings are those generated by cgo.
	External map[string]string
}

// typeof returns the type of the given name, which may be of
// the form "x" or "p.X". A variable entry in cfg.Var wins over a
// function entry in cfg.Func; a function entry is reported as a type
// by prefixing it with "func()". Unknown names yield "".
func (cfg *TypeConfig) typeof(name string) string {
	// Indexing a nil map is safe in Go and yields "", so no nil checks
	// are needed before these lookups.
	if varType := cfg.Var[name]; varType != "" {
		return varType
	}
	if fnType := cfg.Func[name]; fnType != "" {
		return "func()" + fnType
	}
	return ""
}

// Type describes the Fields and Methods of a type.
// If the field or method cannot be found there, it is next
// looked for in the Embed list (see the dot method).
//
// Method values are returned as-is as the type of a selector
// expression, and call expressions are only typed when that string
// starts with "func(", so they should be complete function types such
// as "func() int".
//
// Def, when non-empty, is the underlying type expression, such as
// "[]int" or "map[string]T". typecheck fills it in for local array,
// slice, pointer, and map type declarations. typecheck1 uses it to see
// through the named type when indexing, ranging, dereferencing, or
// typing composite literals.
type Type struct {
	Field  map[string]string // map field name to type
	Method map[string]string // map method name to its function type, e.g. "func() int"
	Embed  []string          // list of types this type embeds (for extra methods)
	Def    string            // definition of named type
}

// dot returns the type string of the selector "typ.name" using the type
// information in cfg. It looks at typ's fields first, then its methods,
// then recursively at each embedded type that cfg knows about, in
// order. It returns "" when name is not found.
func (typ *Type) dot(cfg *TypeConfig, name string) string {
	// Reads from nil maps yield "", so Field/Method need no nil checks.
	if fieldType := typ.Field[name]; fieldType != "" {
		return fieldType
	}
	if methodType := typ.Method[name]; methodType != "" {
		return methodType
	}

	for _, embeddedName := range typ.Embed {
		embedded := cfg.Type[embeddedName]
		if embedded == nil {
			continue
		}
		if found := embedded.dot(cfg, name); found != "" {
			return found
		}
	}
	return ""
}

// typecheck type checks the AST f assuming the information in cfg.
// It returns two maps with type information:
// typeof maps AST nodes (plus synthetic string keys of the form
// "Recv.Method" for methods declared in f) to type information in
// gofmt string form.
// assign maps a type string T to the expressions that were handed a
// value of type T (by an assignment, declaration, return, or range
// clause) while already having a different recorded type.
//
// If f imports "C", typecheck runs "go tool cgo" on a gofmt'ed copy of
// f in a temporary directory. It then records the types of the
// generated _Cfunc_ wrappers, with _Ctype_ names rewritten to "C.",
// under keys "C.<name>" in cfg.External. Failing to obtain those types
// only prints a warning on stderr. Note that cfg.External is replaced
// with a fresh map on every call, even when f does not import "C".
func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
	typeof = make(map[interface{}]string)
	assign = make(map[string][]interface{})
	cfg1 := &TypeConfig{}
	*cfg1 = *cfg // make copy so we can add locally
	copied := false

	// If we import "C", add types of cgo objects.
	// The fresh External map is shared by cfg and cfg1.
	cfg.External = map[string]string{}
	cfg1.External = cfg.External
	if imports(f, "C") {
		// Run cgo on gofmtFile(f)
		// Parse, extract decls from _cgo_gotypes.go
		// Map _Ctype_* types to C.* types.
		err := func() error {
			txt, err := gofmtFile(f)
			if err != nil {
				return err
			}
			dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck")
			if err != nil {
				return err
			}
			defer os.RemoveAll(dir)
			err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
			if err != nil {
				return err
			}
			// runtime.GOROOT can return "" (for example for a binary built
			// with -trimpath); in that case look "go" up in $PATH instead
			// of running the bogus relative path "bin/go".
			goCmd := "go"
			if goroot := runtime.GOROOT(); goroot != "" {
				goCmd = filepath.Join(goroot, "bin", "go")
			}
			cmd := exec.Command(goCmd, "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
			err = cmd.Run()
			if err != nil {
				return err
			}
			out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
			if err != nil {
				return err
			}
			cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
			if err != nil {
				return err
			}

			// cgoTypes lists the types declared by fl, one entry per
			// declared name (or one per unnamed field), with _Ctype_X
			// rewritten to C.X. fl may be nil: go/ast leaves
			// FuncType.Results nil when a function returns nothing.
			cgoTypes := func(fl *ast.FieldList) []string {
				if fl == nil {
					return nil
				}
				var types []string
				for _, field := range fl.List {
					t := strings.ReplaceAll(gofmt(field.Type), "_Ctype_", "C.")
					count := len(field.Names)
					if count == 0 {
						count = 1
					}
					for k := 0; k < count; k++ {
						types = append(types, t)
					}
				}
				return types
			}

			const cfuncPrefix = "_Cfunc_"
			for _, decl := range cgo.Decls {
				fn, ok := decl.(*ast.FuncDecl)
				if !ok || !strings.HasPrefix(fn.Name.Name, cfuncPrefix) {
					continue
				}
				name := strings.TrimPrefix(fn.Name.Name, cfuncPrefix)
				cfg.External["C."+name] = joinFunc(cgoTypes(fn.Type.Params), cgoTypes(fn.Type.Results))
			}
			return nil
		}()
		if err != nil {
			fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err)
		}
	}

	// Gather function declarations. A top-level function's type is
	// recorded under its name (and its object). A method's type is
	// recorded under the synthetic key "Recv.Name", with any leading
	// "*" of the receiver type dropped.
	for _, decl := range f.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		typecheck1(cfg, fn.Type, typeof, assign)
		t := typeof[fn.Type]
		if fn.Recv != nil {
			// The receiver must be a type.
			rcvr := typeof[fn.Recv]
			if !isType(rcvr) {
				if len(fn.Recv.List) != 1 {
					continue
				}
				rcvr = mkType(gofmt(fn.Recv.List[0].Type))
				typeof[fn.Recv.List[0].Type] = rcvr
			}
			rcvr = strings.TrimPrefix(getType(rcvr), "*")
			typeof[rcvr+"."+fn.Name.Name] = t
		} else {
			if isType(t) {
				t = getType(t)
			} else {
				t = gofmt(fn.Type)
			}
			typeof[fn.Name] = t

			// Record typeof[fn.Name.Obj] for future references to fn.Name.
			typeof[fn.Name.Obj] = t
		}
	}

	// Gather type declarations into cfg1.Type: struct fields, or the
	// definition of array, slice, pointer, and map types. Types already
	// known to cfg take precedence over local declarations.
	for _, decl := range f.Decls {
		d, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range d.Specs {
			s, ok := spec.(*ast.TypeSpec)
			if !ok || cfg1.Type[s.Name.Name] != nil {
				continue
			}
			if !copied {
				copied = true
				// Copy the map lazily, on the first local type, so that
				// adding local types never mutates the caller's cfg.Type.
				cfg1.Type = make(map[string]*Type)
				for k, v := range cfg.Type {
					cfg1.Type[k] = v
				}
			}
			t := &Type{Field: map[string]string{}}
			cfg1.Type[s.Name.Name] = t
			switch st := s.Type.(type) {
			case *ast.StructType:
				for _, field := range st.Fields.List {
					for _, n := range field.Names {
						t.Field[n.Name] = gofmt(field.Type)
					}
				}
			case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
				t.Def = gofmt(st)
			}
		}
	}

	typecheck1(cfg1, f, typeof, assign)
	return typeof, assign
}

func makeExprList(a []*ast.Ident) []ast.Expr {
	var b []ast.Expr
	for _, x := range a {
		b = append(b, x)
	}
	return b
}

// typecheck1 is the recursive form of typecheck.
// It is like typecheck but adds to the information in typeof and
// assign instead of allocating new maps. f is walked with
// walkBeforeAfter; callers in this file pass an *ast.File, an
// *ast.FuncType, an *ast.BlockStmt, or a []ast.Stmt.
func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
	// set sets the type of n to typ.
	// If isDecl is true, n is being declared.
	// If n already has a type, or typ is empty, n's type is left alone.
	// When the existing type differs from typ, n is recorded in
	// assign[typ] instead.
	set := func(n ast.Expr, typ string, isDecl bool) {
		if typeof[n] != "" || typ == "" {
			if typeof[n] != typ {
				assign[typ] = append(assign[typ], n)
			}
			return
		}
		typeof[n] = typ

		// If we obtained typ from the declaration of x
		// propagate the type to all the uses.
		// The !isDecl case is a cheat here, but it makes
		// up in some cases for not paying attention to
		// struct fields. The real type checker will be
		// more accurate so we won't need the cheat.
		if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
			typeof[id.Obj] = typ
		}
	}

	// Type-check an assignment lhs = rhs.
	// If isDecl is true, this is := so we can update
	// the types of the objects that lhs refers to.
	typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
		if len(lhs) > 1 && len(rhs) == 1 {
			if _, ok := rhs[0].(*ast.CallExpr); ok {
				t := split(typeof[rhs[0]])
				// Lists should have same length but may not; pair what can be paired.
				for i := 0; i < len(lhs) && i < len(t); i++ {
					set(lhs[i], t[i], isDecl)
				}
				return
			}
		}
		if len(lhs) == 1 && len(rhs) == 2 {
			// x = y, ok
			rhs = rhs[:1]
		} else if len(lhs) == 2 && len(rhs) == 1 {
			// x, ok = y
			lhs = lhs[:1]
		}

		// Match as much as we can.
		for i := 0; i < len(lhs) && i < len(rhs); i++ {
			x, y := lhs[i], rhs[i]
			if typeof[y] != "" {
				set(x, typeof[y], isDecl)
			} else {
				set(y, typeof[x], false)
			}
		}
	}

	// expand replaces a named type by its recorded definition (Type.Def),
	// if any, so that e.g. a named slice type can be indexed.
	expand := func(s string) string {
		typ := cfg.Type[s]
		if typ != nil && typ.Def != "" {
			return typ.Def
		}
		return s
	}

	// The main type check is a recursive algorithm implemented
	// by walkBeforeAfter(n, before, after).
	// Most of it is bottom-up, but in a few places we need
	// to know the type of the function we are checking.
	// The before function records that information on
	// the curfn stack.
	var curfn []*ast.FuncType

	before := func(n interface{}) {
		// push function type on stack
		switch n := n.(type) {
		case *ast.FuncDecl:
			curfn = append(curfn, n.Type)
		case *ast.FuncLit:
			curfn = append(curfn, n.Type)
		}
	}

	// After is the real type checker.
	after := func(n interface{}) {
		if n == nil {
			return
		}
		if false && reflect.TypeOf(n).Kind() == reflect.Pointer { // debugging trace
			defer func() {
				if t := typeof[n]; t != "" {
					pos := fset.Position(n.(ast.Node).Pos())
					fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
				}
			}()
		}

		switch n := n.(type) {
		case *ast.FuncDecl, *ast.FuncLit:
			// pop function type off stack
			curfn = curfn[:len(curfn)-1]

		case *ast.FuncType:
			typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))

		case *ast.FieldList:
			// Field list is concatenation of sub-lists.
			t := ""
			for _, field := range n.List {
				if t != "" {
					t += ", "
				}
				t += typeof[field]
			}
			typeof[n] = t

		case *ast.Field:
			// Field is one instance of the type per name.
			all := ""
			t := typeof[n.Type]
			if !isType(t) {
				// Create a type, because it is typically *T or *p.T
				// and we might care about that type.
				t = mkType(gofmt(n.Type))
				typeof[n.Type] = t
			}
			t = getType(t)
			if len(n.Names) == 0 {
				all = t
			} else {
				for _, id := range n.Names {
					if all != "" {
						all += ", "
					}
					all += t
					typeof[id.Obj] = t
					typeof[id] = t
				}
			}
			typeof[n] = all

		case *ast.ValueSpec:
			// var declaration. Use type if present.
			if n.Type != nil {
				t := typeof[n.Type]
				if !isType(t) {
					t = mkType(gofmt(n.Type))
					typeof[n.Type] = t
				}
				t = getType(t)
				for _, id := range n.Names {
					set(id, t, true)
				}
			}
			// Now treat same as assignment.
			typecheckAssign(makeExprList(n.Names), n.Values, true)

		case *ast.AssignStmt:
			typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)

		case *ast.Ident:
			// Identifier can take its type from underlying object.
			if t := typeof[n.Obj]; t != "" {
				typeof[n] = t
			}

		case *ast.SelectorExpr:
			// Field or method.
			name := n.Sel.Name
			if t := typeof[n.X]; t != "" {
				t = strings.TrimPrefix(t, "*") // implicit *
				if typ := cfg.Type[t]; typ != nil {
					if t := typ.dot(cfg, name); t != "" {
						typeof[n] = t
						return
					}
				}
				tt := typeof[t+"."+name]
				if isType(tt) {
					typeof[n] = getType(tt)
					return
				}
			}
			// Package selector.
			if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
				str := x.Name + "." + name
				if cfg.Type[str] != nil {
					typeof[n] = mkType(str)
					return
				}
				if t := cfg.typeof(x.Name + "." + name); t != "" {
					typeof[n] = t
					return
				}
			}

		case *ast.CallExpr:
			// make(T) has type T.
			if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
				typeof[n] = gofmt(n.Args[0])
				return
			}
			// new(T) has type *T
			if isTopName(n.Fun, "new") && len(n.Args) == 1 {
				typeof[n] = "*" + gofmt(n.Args[0])
				return
			}
			// Otherwise, use type of function to determine arguments.
			t := typeof[n.Fun]
			if t == "" {
				t = cfg.External[gofmt(n.Fun)]
			}
			in, out := splitFunc(t)
			if in == nil && out == nil {
				return
			}
			typeof[n] = join(out)
			for i, arg := range n.Args {
				if i >= len(in) {
					break
				}
				if typeof[arg] == "" {
					typeof[arg] = in[i]
				}
			}

		case *ast.TypeAssertExpr:
			// x.(type) has type of x.
			if n.Type == nil {
				typeof[n] = typeof[n.X]
				return
			}
			// x.(T) has type T.
			if t := typeof[n.Type]; isType(t) {
				typeof[n] = getType(t)
			} else {
				typeof[n] = gofmt(n.Type)
			}

		case *ast.SliceExpr:
			// x[i:j] has type of x.
			typeof[n] = typeof[n.X]

		case *ast.IndexExpr:
			// x[i] has the element (value) type of x's type.
			t := expand(typeof[n.X])
			if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
				// Lazy: assume there are no nested [] in the array
				// length or map key type.
				if _, elem, ok := strings.Cut(t, "]"); ok {
					typeof[n] = elem
				}
			}

		case *ast.StarExpr:
			// *x for x of type *T has type T when x is an expr.
			// We don't use the result when *x is a type, but
			// compute it anyway.
			t := expand(typeof[n.X])
			if isType(t) {
				typeof[n] = "type *" + getType(t)
			} else if strings.HasPrefix(t, "*") {
				typeof[n] = t[len("*"):]
			}

		case *ast.UnaryExpr:
			// &x for x of type T has type *T.
			t := typeof[n.X]
			if t != "" && n.Op == token.AND {
				typeof[n] = "*" + t
			}

		case *ast.CompositeLit:
			// T{...} has type T. When the type is elided, as for the
			// inner literals of []T{{...}}, n.Type is nil. Formatting a
			// nil node cannot produce a meaningful type string, and
			// storing one would stop the enclosing literal from
			// propagating the real element type below. So keep whatever
			// type is already recorded (e.g. from an earlier pass).
			if n.Type != nil {
				typeof[n] = gofmt(n.Type)
			}

			// Propagate types down to values used in the composite literal.
			t := expand(typeof[n])
			if strings.HasPrefix(t, "[") { // array or slice
				// Lazy: assume there are no nested [] in the array length.
				if _, et, ok := strings.Cut(t, "]"); ok {
					for _, e := range n.Elts {
						if kv, ok := e.(*ast.KeyValueExpr); ok {
							e = kv.Value
						}
						if typeof[e] == "" {
							typeof[e] = et
						}
					}
				}
			}
			if strings.HasPrefix(t, "map[") { // map
				// Lazy: assume there are no nested [] in the map key type.
				if kt, vt, ok := strings.Cut(t[len("map["):], "]"); ok {
					for _, e := range n.Elts {
						if kv, ok := e.(*ast.KeyValueExpr); ok {
							if typeof[kv.Key] == "" {
								typeof[kv.Key] = kt
							}
							if typeof[kv.Value] == "" {
								typeof[kv.Value] = vt
							}
						}
					}
				}
			}
			if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 { // struct
				for _, e := range n.Elts {
					kv, ok := e.(*ast.KeyValueExpr)
					if !ok {
						continue
					}
					// Struct literal keys are field names.
					key, ok := kv.Key.(*ast.Ident)
					if !ok {
						continue
					}
					if ft := typ.Field[key.Name]; ft != "" && typeof[kv.Value] == "" {
						typeof[kv.Value] = ft
					}
				}
			}

		case *ast.ParenExpr:
			// (x) has type of x.
			typeof[n] = typeof[n.X]

		case *ast.RangeStmt:
			t := expand(typeof[n.X])
			if t == "" {
				return
			}
			var key, value string
			if t == "string" {
				key, value = "int", "rune"
			} else if strings.HasPrefix(t, "[") {
				key = "int"
				_, value, _ = strings.Cut(t, "]")
			} else if strings.HasPrefix(t, "map[") {
				if k, v, ok := strings.Cut(t[len("map["):], "]"); ok {
					key, value = k, v
				}
			}
			changed := false
			if n.Key != nil && key != "" {
				changed = true
				set(n.Key, key, n.Tok == token.DEFINE)
			}
			if n.Value != nil && value != "" {
				changed = true
				set(n.Value, value, n.Tok == token.DEFINE)
			}
			// Ugly failure of vision: already type-checked body.
			// Do it again now that we have that type info.
			if changed {
				typecheck1(cfg, n.Body, typeof, assign)
			}

		case *ast.TypeSwitchStmt:
			// Type of variable changes for each case in type switch,
			// but go/parser generates just one variable.
			// Repeat type check for each case with more precise
			// type information.
			as, ok := n.Assign.(*ast.AssignStmt)
			if !ok {
				return
			}
			varx, ok := as.Lhs[0].(*ast.Ident)
			if !ok {
				return
			}
			t := typeof[varx]
			for _, cas := range n.Body.List {
				cas := cas.(*ast.CaseClause)
				if len(cas.List) == 1 {
					// Variable has specific type only when there is
					// exactly one type in the case list.
					if tt := typeof[cas.List[0]]; isType(tt) {
						tt = getType(tt)
						typeof[varx] = tt
						typeof[varx.Obj] = tt
						typecheck1(cfg, cas.Body, typeof, assign)
					}
				}
			}
			// Restore t.
			typeof[varx] = t
			typeof[varx.Obj] = t

		case *ast.ReturnStmt:
			if len(curfn) == 0 {
				// Probably can't happen.
				return
			}
			f := curfn[len(curfn)-1]
			res := n.Results
			if f.Results != nil {
				t := split(typeof[f.Results])
				for i := 0; i < len(res) && i < len(t); i++ {
					set(res[i], t[i], false)
				}
			}

		case *ast.BinaryExpr:
			// Propagate types across binary ops that require two args of the same type.
			switch n.Op {
			case token.EQL, token.NEQ: // TODO: more cases. This is enough for the cftype fix.
				if typeof[n.X] != "" && typeof[n.Y] == "" {
					typeof[n.Y] = typeof[n.X]
				}
				if typeof[n.X] == "" && typeof[n.Y] != "" {
					typeof[n.X] = typeof[n.Y]
				}
			}
		}
	}
	walkBeforeAfter(f, before, after)
}

// Convert between function type strings and lists of types.
// Using strings makes this a little harder, but it makes
// a lot of the rest of the code easier. This will all go away
// when we can use go/typechecker directly.

// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
func splitFunc(s string) (in, out []string) {
	if !strings.HasPrefix(s, "func(") {
		return nil, nil
	}

	i := len("func(") // index of beginning of 'in' arguments
	nparen := 0
	for j := i; j < len(s); j++ {
		switch s[j] {
		case '(':
			nparen++
		case ')':
			nparen--
			if nparen < 0 {
				// found end of parameter list
				out := strings.TrimSpace(s[j+1:])
				if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
					out = out[1 : len(out)-1]
				}
				return split(s[i:j]), split(out)
			}
		}
	}
	return nil, nil
}

// joinFunc is the inverse of splitFunc: it builds a function type string
// from parameter and result type lists. A single result is written bare,
// as in "func(int) error". Two or more results are parenthesized, as in
// "func() (int, error)". No results leaves nothing after the parameters.
func joinFunc(in, out []string) string {
	signature := "func(" + join(in) + ")"
	switch len(out) {
	case 0:
		return signature
	case 1:
		return signature + " " + out[0]
	default:
		return signature + " (" + join(out) + ")"
	}
}

// split splits a comma-separated list of types such as "int, float"
// into ["int", "float"]. It splits "" into an empty, non-nil slice.
// Commas nested inside (), [] or {} do not separate elements, so
// "func(a, b int)", "Pair[K, V]" and "struct{ x, y int }" each stay a
// single element. Leading spaces of each element are dropped. split
// returns nil if the bracket nesting depth ever goes negative or does
// not return to zero by the end of s.
func split(s string) []string {
	out := []string{}
	i := 0     // current type being scanned is s[i:j].
	depth := 0 // combined nesting depth of (), [] and {}
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if i == j {
				i++
			}
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			depth--
			if depth < 0 {
				// Unbalanced closing bracket; probably can't happen.
				return nil
			}
		case ',':
			if depth == 0 {
				if i < j {
					out = append(out, s[i:j])
				}
				i = j + 1
			}
		}
	}
	if depth != 0 {
		// Unbalanced opening bracket; probably can't happen.
		return nil
	}
	if i < len(s) {
		out = append(out, s[i:])
	}
	return out
}

// join is the inverse of split.
func join(x []string) string {
	return strings.Join(x, ", ")
}
