| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package main |
| |
| import ( |
| "fmt" |
| "go/ast" |
| "go/parser" |
| "go/token" |
| exec "internal/execabs" |
| "os" |
| "path/filepath" |
| "reflect" |
| "runtime" |
| "strings" |
| ) |
| |
| // Partial type checker. |
| // |
| // The fact that it is partial is very important: the input is |
| // an AST and a description of some type information to |
| // assume about one or more packages, but not all the |
| // packages that the program imports. The checker is |
| // expected to do as much as it can with what it has been |
| // given. There is not enough information supplied to do |
| // a full type check, but the type checker is expected to |
| // apply information that can be derived from variable |
| // declarations, function and method returns, and type switches |
| // as far as it can, so that the caller can still tell the types |
| // of expressions relevant to a particular fix. |
| // |
| // TODO(rsc,gri): Replace with go/typechecker. |
| // Doing that could be an interesting test case for go/typechecker: |
| // the constraints about working with partial information will |
| // likely exercise it in interesting ways. The ideal interface would |
| // be to pass typecheck a map from importpath to package API text |
| // (Go source code), but for now we use data structures (TypeConfig, Type). |
| // |
| // The strings mostly use gofmt form. |
| // |
| // A Field or FieldList has as its type a comma-separated list |
| // of the types of the fields. For example, the field list |
| // x, y, z int |
| // has type "int, int, int". |
| |
| // The prefix "type " is the type of a type. |
| // For example, given |
| // var x int |
| // type T int |
| // x's type is "int" but T's type is "type int". |
| // mkType inserts the "type " prefix. |
| // getType removes it. |
| // isType tests for it. |
| |
// mkType returns the type-of-a-type string for t, that is,
// t with the "type " prefix inserted in front of it.
func mkType(t string) string {
	const typePrefix = "type "
	return typePrefix + t
}
| |
| func getType(t string) string { |
| if !isType(t) { |
| return "" |
| } |
| return t[len("type "):] |
| } |
| |
// isType reports whether t is the type of a type,
// that is, whether it begins with the "type " prefix.
func isType(t string) bool {
	const typePrefix = "type "
	return strings.HasPrefix(t, typePrefix)
}
| |
// TypeConfig describes the universe of relevant types.
// For ease of creation, the types are all referred to by string
// name (e.g., "reflect.Value"). TypeByName is the only place
// where the strings are resolved.
// NOTE(review): no TypeByName is defined in the part of the file
// reviewed here; confirm that the reference is still accurate.
type TypeConfig struct {
	Type map[string]*Type  // maps a type name (such as "T" or "p.T") to its description
	Var map[string]string  // maps a variable name to its type; consulted first by typeof
	Func map[string]string // maps a function name to its result type; typeof prepends "func()"

	// External maps from a name to its type.
	// It provides additional typings not present in the Go source itself.
	// For now, the only additional typings are those generated by cgo.
	External map[string]string
}
| |
| // typeof returns the type of the given name, which may be of |
| // the form "x" or "p.X". |
| func (cfg *TypeConfig) typeof(name string) string { |
| if cfg.Var != nil { |
| if t := cfg.Var[name]; t != "" { |
| return t |
| } |
| } |
| if cfg.Func != nil { |
| if t := cfg.Func[name]; t != "" { |
| return "func()" + t |
| } |
| } |
| return "" |
| } |
| |
// Type describes the Fields and Methods of a type.
// If the field or method cannot be found there, it is next
// looked for in the Embed list (each entry is looked up by
// name in TypeConfig.Type; see the dot method).
type Type struct {
	Field map[string]string // map field name to type
	Method map[string]string // map method name to comma-separated return types (should start with "func ")
	Embed []string // list of types this type embeds (for extra methods)
	Def string // definition of named type; when non-empty it stands in for the name when indexing, dereferencing or ranging
}
| |
| // dot returns the type of "typ.name", making its decision |
| // using the type information in cfg. |
| func (typ *Type) dot(cfg *TypeConfig, name string) string { |
| if typ.Field != nil { |
| if t := typ.Field[name]; t != "" { |
| return t |
| } |
| } |
| if typ.Method != nil { |
| if t := typ.Method[name]; t != "" { |
| return t |
| } |
| } |
| |
| for _, e := range typ.Embed { |
| etyp := cfg.Type[e] |
| if etyp != nil { |
| if t := etyp.dot(cfg, name); t != "" { |
| return t |
| } |
| } |
| } |
| |
| return "" |
| } |
| |
| // typecheck type checks the AST f assuming the information in cfg. |
| // It returns two maps with type information: |
| // typeof maps AST nodes to type information in gofmt string form. |
| // assign maps type strings to lists of expressions that were assigned |
| // to values of another type that were assigned to that type. |
| func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[any]string, assign map[string][]any) { |
| typeof = make(map[any]string) |
| assign = make(map[string][]any) |
| cfg1 := &TypeConfig{} |
| *cfg1 = *cfg // make copy so we can add locally |
| copied := false |
| |
| // If we import "C", add types of cgo objects. |
| cfg.External = map[string]string{} |
| cfg1.External = cfg.External |
| if imports(f, "C") { |
| // Run cgo on gofmtFile(f) |
| // Parse, extract decls from _cgo_gotypes.go |
| // Map _Ctype_* types to C.* types. |
| err := func() error { |
| txt, err := gofmtFile(f) |
| if err != nil { |
| return err |
| } |
| dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck") |
| if err != nil { |
| return err |
| } |
| defer os.RemoveAll(dir) |
| err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600) |
| if err != nil { |
| return err |
| } |
| cmd := exec.Command(filepath.Join(runtime.GOROOT(), "bin", "go"), "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go") |
| err = cmd.Run() |
| if err != nil { |
| return err |
| } |
| out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go")) |
| if err != nil { |
| return err |
| } |
| cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0) |
| if err != nil { |
| return err |
| } |
| for _, decl := range cgo.Decls { |
| fn, ok := decl.(*ast.FuncDecl) |
| if !ok { |
| continue |
| } |
| if strings.HasPrefix(fn.Name.Name, "_Cfunc_") { |
| var params, results []string |
| for _, p := range fn.Type.Params.List { |
| t := gofmt(p.Type) |
| t = strings.ReplaceAll(t, "_Ctype_", "C.") |
| params = append(params, t) |
| } |
| for _, r := range fn.Type.Results.List { |
| t := gofmt(r.Type) |
| t = strings.ReplaceAll(t, "_Ctype_", "C.") |
| results = append(results, t) |
| } |
| cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results) |
| } |
| } |
| return nil |
| }() |
| if err != nil { |
| fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err) |
| } |
| } |
| |
| // gather function declarations |
| for _, decl := range f.Decls { |
| fn, ok := decl.(*ast.FuncDecl) |
| if !ok { |
| continue |
| } |
| typecheck1(cfg, fn.Type, typeof, assign) |
| t := typeof[fn.Type] |
| if fn.Recv != nil { |
| // The receiver must be a type. |
| rcvr := typeof[fn.Recv] |
| if !isType(rcvr) { |
| if len(fn.Recv.List) != 1 { |
| continue |
| } |
| rcvr = mkType(gofmt(fn.Recv.List[0].Type)) |
| typeof[fn.Recv.List[0].Type] = rcvr |
| } |
| rcvr = getType(rcvr) |
| if rcvr != "" && rcvr[0] == '*' { |
| rcvr = rcvr[1:] |
| } |
| typeof[rcvr+"."+fn.Name.Name] = t |
| } else { |
| if isType(t) { |
| t = getType(t) |
| } else { |
| t = gofmt(fn.Type) |
| } |
| typeof[fn.Name] = t |
| |
| // Record typeof[fn.Name.Obj] for future references to fn.Name. |
| typeof[fn.Name.Obj] = t |
| } |
| } |
| |
| // gather struct declarations |
| for _, decl := range f.Decls { |
| d, ok := decl.(*ast.GenDecl) |
| if ok { |
| for _, s := range d.Specs { |
| switch s := s.(type) { |
| case *ast.TypeSpec: |
| if cfg1.Type[s.Name.Name] != nil { |
| break |
| } |
| if !copied { |
| copied = true |
| // Copy map lazily: it's time. |
| cfg1.Type = make(map[string]*Type) |
| for k, v := range cfg.Type { |
| cfg1.Type[k] = v |
| } |
| } |
| t := &Type{Field: map[string]string{}} |
| cfg1.Type[s.Name.Name] = t |
| switch st := s.Type.(type) { |
| case *ast.StructType: |
| for _, f := range st.Fields.List { |
| for _, n := range f.Names { |
| t.Field[n.Name] = gofmt(f.Type) |
| } |
| } |
| case *ast.ArrayType, *ast.StarExpr, *ast.MapType: |
| t.Def = gofmt(st) |
| } |
| } |
| } |
| } |
| } |
| |
| typecheck1(cfg1, f, typeof, assign) |
| return typeof, assign |
| } |
| |
// makeExprList converts a list of identifiers into the equivalent
// list of expressions. An empty input yields a nil slice.
func makeExprList(a []*ast.Ident) []ast.Expr {
	if len(a) == 0 {
		return nil
	}
	exprs := make([]ast.Expr, len(a))
	for i, id := range a {
		exprs[i] = id
	}
	return exprs
}
| |
| // Typecheck1 is the recursive form of typecheck. |
| // It is like typecheck but adds to the information in typeof |
| // instead of allocating a new map. |
| func typecheck1(cfg *TypeConfig, f any, typeof map[any]string, assign map[string][]any) { |
| // set sets the type of n to typ. |
| // If isDecl is true, n is being declared. |
| set := func(n ast.Expr, typ string, isDecl bool) { |
| if typeof[n] != "" || typ == "" { |
| if typeof[n] != typ { |
| assign[typ] = append(assign[typ], n) |
| } |
| return |
| } |
| typeof[n] = typ |
| |
| // If we obtained typ from the declaration of x |
| // propagate the type to all the uses. |
| // The !isDecl case is a cheat here, but it makes |
| // up in some cases for not paying attention to |
| // struct fields. The real type checker will be |
| // more accurate so we won't need the cheat. |
| if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") { |
| typeof[id.Obj] = typ |
| } |
| } |
| |
| // Type-check an assignment lhs = rhs. |
| // If isDecl is true, this is := so we can update |
| // the types of the objects that lhs refers to. |
| typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) { |
| if len(lhs) > 1 && len(rhs) == 1 { |
| if _, ok := rhs[0].(*ast.CallExpr); ok { |
| t := split(typeof[rhs[0]]) |
| // Lists should have same length but may not; pair what can be paired. |
| for i := 0; i < len(lhs) && i < len(t); i++ { |
| set(lhs[i], t[i], isDecl) |
| } |
| return |
| } |
| } |
| if len(lhs) == 1 && len(rhs) == 2 { |
| // x = y, ok |
| rhs = rhs[:1] |
| } else if len(lhs) == 2 && len(rhs) == 1 { |
| // x, ok = y |
| lhs = lhs[:1] |
| } |
| |
| // Match as much as we can. |
| for i := 0; i < len(lhs) && i < len(rhs); i++ { |
| x, y := lhs[i], rhs[i] |
| if typeof[y] != "" { |
| set(x, typeof[y], isDecl) |
| } else { |
| set(y, typeof[x], false) |
| } |
| } |
| } |
| |
| expand := func(s string) string { |
| typ := cfg.Type[s] |
| if typ != nil && typ.Def != "" { |
| return typ.Def |
| } |
| return s |
| } |
| |
| // The main type check is a recursive algorithm implemented |
| // by walkBeforeAfter(n, before, after). |
| // Most of it is bottom-up, but in a few places we need |
| // to know the type of the function we are checking. |
| // The before function records that information on |
| // the curfn stack. |
| var curfn []*ast.FuncType |
| |
| before := func(n any) { |
| // push function type on stack |
| switch n := n.(type) { |
| case *ast.FuncDecl: |
| curfn = append(curfn, n.Type) |
| case *ast.FuncLit: |
| curfn = append(curfn, n.Type) |
| } |
| } |
| |
| // After is the real type checker. |
| after := func(n any) { |
| if n == nil { |
| return |
| } |
| if false && reflect.TypeOf(n).Kind() == reflect.Pointer { // debugging trace |
| defer func() { |
| if t := typeof[n]; t != "" { |
| pos := fset.Position(n.(ast.Node).Pos()) |
| fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t) |
| } |
| }() |
| } |
| |
| switch n := n.(type) { |
| case *ast.FuncDecl, *ast.FuncLit: |
| // pop function type off stack |
| curfn = curfn[:len(curfn)-1] |
| |
| case *ast.FuncType: |
| typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results]))) |
| |
| case *ast.FieldList: |
| // Field list is concatenation of sub-lists. |
| t := "" |
| for _, field := range n.List { |
| if t != "" { |
| t += ", " |
| } |
| t += typeof[field] |
| } |
| typeof[n] = t |
| |
| case *ast.Field: |
| // Field is one instance of the type per name. |
| all := "" |
| t := typeof[n.Type] |
| if !isType(t) { |
| // Create a type, because it is typically *T or *p.T |
| // and we might care about that type. |
| t = mkType(gofmt(n.Type)) |
| typeof[n.Type] = t |
| } |
| t = getType(t) |
| if len(n.Names) == 0 { |
| all = t |
| } else { |
| for _, id := range n.Names { |
| if all != "" { |
| all += ", " |
| } |
| all += t |
| typeof[id.Obj] = t |
| typeof[id] = t |
| } |
| } |
| typeof[n] = all |
| |
| case *ast.ValueSpec: |
| // var declaration. Use type if present. |
| if n.Type != nil { |
| t := typeof[n.Type] |
| if !isType(t) { |
| t = mkType(gofmt(n.Type)) |
| typeof[n.Type] = t |
| } |
| t = getType(t) |
| for _, id := range n.Names { |
| set(id, t, true) |
| } |
| } |
| // Now treat same as assignment. |
| typecheckAssign(makeExprList(n.Names), n.Values, true) |
| |
| case *ast.AssignStmt: |
| typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE) |
| |
| case *ast.Ident: |
| // Identifier can take its type from underlying object. |
| if t := typeof[n.Obj]; t != "" { |
| typeof[n] = t |
| } |
| |
| case *ast.SelectorExpr: |
| // Field or method. |
| name := n.Sel.Name |
| if t := typeof[n.X]; t != "" { |
| t = strings.TrimPrefix(t, "*") // implicit * |
| if typ := cfg.Type[t]; typ != nil { |
| if t := typ.dot(cfg, name); t != "" { |
| typeof[n] = t |
| return |
| } |
| } |
| tt := typeof[t+"."+name] |
| if isType(tt) { |
| typeof[n] = getType(tt) |
| return |
| } |
| } |
| // Package selector. |
| if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil { |
| str := x.Name + "." + name |
| if cfg.Type[str] != nil { |
| typeof[n] = mkType(str) |
| return |
| } |
| if t := cfg.typeof(x.Name + "." + name); t != "" { |
| typeof[n] = t |
| return |
| } |
| } |
| |
| case *ast.CallExpr: |
| // make(T) has type T. |
| if isTopName(n.Fun, "make") && len(n.Args) >= 1 { |
| typeof[n] = gofmt(n.Args[0]) |
| return |
| } |
| // new(T) has type *T |
| if isTopName(n.Fun, "new") && len(n.Args) == 1 { |
| typeof[n] = "*" + gofmt(n.Args[0]) |
| return |
| } |
| // Otherwise, use type of function to determine arguments. |
| t := typeof[n.Fun] |
| if t == "" { |
| t = cfg.External[gofmt(n.Fun)] |
| } |
| in, out := splitFunc(t) |
| if in == nil && out == nil { |
| return |
| } |
| typeof[n] = join(out) |
| for i, arg := range n.Args { |
| if i >= len(in) { |
| break |
| } |
| if typeof[arg] == "" { |
| typeof[arg] = in[i] |
| } |
| } |
| |
| case *ast.TypeAssertExpr: |
| // x.(type) has type of x. |
| if n.Type == nil { |
| typeof[n] = typeof[n.X] |
| return |
| } |
| // x.(T) has type T. |
| if t := typeof[n.Type]; isType(t) { |
| typeof[n] = getType(t) |
| } else { |
| typeof[n] = gofmt(n.Type) |
| } |
| |
| case *ast.SliceExpr: |
| // x[i:j] has type of x. |
| typeof[n] = typeof[n.X] |
| |
| case *ast.IndexExpr: |
| // x[i] has key type of x's type. |
| t := expand(typeof[n.X]) |
| if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") { |
| // Lazy: assume there are no nested [] in the array |
| // length or map key type. |
| if _, elem, ok := strings.Cut(t, "]"); ok { |
| typeof[n] = elem |
| } |
| } |
| |
| case *ast.StarExpr: |
| // *x for x of type *T has type T when x is an expr. |
| // We don't use the result when *x is a type, but |
| // compute it anyway. |
| t := expand(typeof[n.X]) |
| if isType(t) { |
| typeof[n] = "type *" + getType(t) |
| } else if strings.HasPrefix(t, "*") { |
| typeof[n] = t[len("*"):] |
| } |
| |
| case *ast.UnaryExpr: |
| // &x for x of type T has type *T. |
| t := typeof[n.X] |
| if t != "" && n.Op == token.AND { |
| typeof[n] = "*" + t |
| } |
| |
| case *ast.CompositeLit: |
| // T{...} has type T. |
| typeof[n] = gofmt(n.Type) |
| |
| // Propagate types down to values used in the composite literal. |
| t := expand(typeof[n]) |
| if strings.HasPrefix(t, "[") { // array or slice |
| // Lazy: assume there are no nested [] in the array length. |
| if _, et, ok := strings.Cut(t, "]"); ok { |
| for _, e := range n.Elts { |
| if kv, ok := e.(*ast.KeyValueExpr); ok { |
| e = kv.Value |
| } |
| if typeof[e] == "" { |
| typeof[e] = et |
| } |
| } |
| } |
| } |
| if strings.HasPrefix(t, "map[") { // map |
| // Lazy: assume there are no nested [] in the map key type. |
| if kt, vt, ok := strings.Cut(t[len("map["):], "]"); ok { |
| for _, e := range n.Elts { |
| if kv, ok := e.(*ast.KeyValueExpr); ok { |
| if typeof[kv.Key] == "" { |
| typeof[kv.Key] = kt |
| } |
| if typeof[kv.Value] == "" { |
| typeof[kv.Value] = vt |
| } |
| } |
| } |
| } |
| } |
| if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 { // struct |
| for _, e := range n.Elts { |
| if kv, ok := e.(*ast.KeyValueExpr); ok { |
| if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" { |
| if typeof[kv.Value] == "" { |
| typeof[kv.Value] = ft |
| } |
| } |
| } |
| } |
| } |
| |
| case *ast.ParenExpr: |
| // (x) has type of x. |
| typeof[n] = typeof[n.X] |
| |
| case *ast.RangeStmt: |
| t := expand(typeof[n.X]) |
| if t == "" { |
| return |
| } |
| var key, value string |
| if t == "string" { |
| key, value = "int", "rune" |
| } else if strings.HasPrefix(t, "[") { |
| key = "int" |
| _, value, _ = strings.Cut(t, "]") |
| } else if strings.HasPrefix(t, "map[") { |
| if k, v, ok := strings.Cut(t[len("map["):], "]"); ok { |
| key, value = k, v |
| } |
| } |
| changed := false |
| if n.Key != nil && key != "" { |
| changed = true |
| set(n.Key, key, n.Tok == token.DEFINE) |
| } |
| if n.Value != nil && value != "" { |
| changed = true |
| set(n.Value, value, n.Tok == token.DEFINE) |
| } |
| // Ugly failure of vision: already type-checked body. |
| // Do it again now that we have that type info. |
| if changed { |
| typecheck1(cfg, n.Body, typeof, assign) |
| } |
| |
| case *ast.TypeSwitchStmt: |
| // Type of variable changes for each case in type switch, |
| // but go/parser generates just one variable. |
| // Repeat type check for each case with more precise |
| // type information. |
| as, ok := n.Assign.(*ast.AssignStmt) |
| if !ok { |
| return |
| } |
| varx, ok := as.Lhs[0].(*ast.Ident) |
| if !ok { |
| return |
| } |
| t := typeof[varx] |
| for _, cas := range n.Body.List { |
| cas := cas.(*ast.CaseClause) |
| if len(cas.List) == 1 { |
| // Variable has specific type only when there is |
| // exactly one type in the case list. |
| if tt := typeof[cas.List[0]]; isType(tt) { |
| tt = getType(tt) |
| typeof[varx] = tt |
| typeof[varx.Obj] = tt |
| typecheck1(cfg, cas.Body, typeof, assign) |
| } |
| } |
| } |
| // Restore t. |
| typeof[varx] = t |
| typeof[varx.Obj] = t |
| |
| case *ast.ReturnStmt: |
| if len(curfn) == 0 { |
| // Probably can't happen. |
| return |
| } |
| f := curfn[len(curfn)-1] |
| res := n.Results |
| if f.Results != nil { |
| t := split(typeof[f.Results]) |
| for i := 0; i < len(res) && i < len(t); i++ { |
| set(res[i], t[i], false) |
| } |
| } |
| |
| case *ast.BinaryExpr: |
| // Propagate types across binary ops that require two args of the same type. |
| switch n.Op { |
| case token.EQL, token.NEQ: // TODO: more cases. This is enough for the cftype fix. |
| if typeof[n.X] != "" && typeof[n.Y] == "" { |
| typeof[n.Y] = typeof[n.X] |
| } |
| if typeof[n.X] == "" && typeof[n.Y] != "" { |
| typeof[n.X] = typeof[n.Y] |
| } |
| } |
| } |
| } |
| walkBeforeAfter(f, before, after) |
| } |
| |
| // Convert between function type strings and lists of types. |
| // Using strings makes this a little harder, but it makes |
| // a lot of the rest of the code easier. This will all go away |
| // when we can use go/typechecker directly. |
| |
| // splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"]. |
| func splitFunc(s string) (in, out []string) { |
| if !strings.HasPrefix(s, "func(") { |
| return nil, nil |
| } |
| |
| i := len("func(") // index of beginning of 'in' arguments |
| nparen := 0 |
| for j := i; j < len(s); j++ { |
| switch s[j] { |
| case '(': |
| nparen++ |
| case ')': |
| nparen-- |
| if nparen < 0 { |
| // found end of parameter list |
| out := strings.TrimSpace(s[j+1:]) |
| if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' { |
| out = out[1 : len(out)-1] |
| } |
| return split(s[i:j]), split(out) |
| } |
| } |
| } |
| return nil, nil |
| } |
| |
| // joinFunc is the inverse of splitFunc. |
| func joinFunc(in, out []string) string { |
| outs := "" |
| if len(out) == 1 { |
| outs = " " + out[0] |
| } else if len(out) > 1 { |
| outs = " (" + join(out) + ")" |
| } |
| return "func(" + join(in) + ")" + outs |
| } |
| |
// split splits "int, float" into ["int", "float"] and splits "" into [].
//
// Only commas at the top level separate types: a comma nested inside
// parentheses, square brackets or braces (as in "func(a, b)",
// "Pair[int, string]" or "struct{ a, b int }") belongs to the type
// that contains it. Spaces at the start of an element are dropped.
//
// split returns nil if the brackets in s are unbalanced. A well-formed
// but empty list yields an empty, non-nil slice so that callers can
// tell the two situations apart.
func split(s string) []string {
	out := []string{} // deliberately non-nil; nil signals malformed input
	start := 0        // current type being scanned is s[start:j]
	depth := 0        // bracket nesting depth at s[j]
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if start == j {
				start++
			}
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			depth--
			if depth < 0 {
				// closing bracket without an opening one
				return nil
			}
		case ',':
			if depth == 0 {
				if start < j {
					out = append(out, s[start:j])
				}
				start = j + 1
			}
		}
	}
	if depth != 0 {
		// some bracket was never closed
		return nil
	}
	if start < len(s) {
		out = append(out, s[start:])
	}
	return out
}
| |
// join is the inverse of split: it glues the type strings in x
// together, separating them with a comma and a space.
func join(x []string) string {
	const sep = ", "
	return strings.Join(x, sep)
}