| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package cache |
| |
| import ( |
| "bytes" |
| "context" |
| "fmt" |
| "go/ast" |
| "go/parser" |
| "go/scanner" |
| "go/token" |
| "go/types" |
| "reflect" |
| "strconv" |
| "strings" |
| |
| "golang.org/x/tools/internal/event" |
| "golang.org/x/tools/internal/lsp/debug/tag" |
| "golang.org/x/tools/internal/lsp/diff" |
| "golang.org/x/tools/internal/lsp/diff/myers" |
| "golang.org/x/tools/internal/lsp/protocol" |
| "golang.org/x/tools/internal/lsp/source" |
| "golang.org/x/tools/internal/memoize" |
| "golang.org/x/tools/internal/span" |
| errors "golang.org/x/xerrors" |
| ) |
| |
| // parseKey uniquely identifies a parsed Go file. The same file identity |
| // parsed in a different mode produces a distinct key. snapshot.parseGoHandle |
| // uses it both to look up an existing handle (getGoFile/addGoFile) and as |
| // the key passed to generation.Bind. |
| type parseKey struct { |
| // file is the identity of the file handle being parsed. |
| file source.FileIdentity |
| // mode is the parse mode requested for the file. |
| mode source.ParseMode |
| } |
| |
| // parseGoHandle pairs a memoized parse computation with the file handle and |
| // parse mode it was created for. |
| type parseGoHandle struct { |
| // handle is the memoize handle whose value is the *parseGoData for the file. |
| handle *memoize.Handle |
| // file is the file handle that gets parsed. |
| file source.FileHandle |
| // mode is the mode the file is parsed in. |
| mode source.ParseMode |
| } |
| |
| // parseGoData is the value produced by parseGo and retrieved through a |
| // parseGoHandle's memoize handle. |
| type parseGoData struct { |
| // parsed is the parsed file. parseGo leaves it nil when it fails before |
| // parsing (non-Go file or read error). |
| parsed *source.ParsedGoFile |
| |
| // If true, we adjusted the AST to make it type check better, and |
| // it may not match the source code. |
| fixed bool |
| // err is set by parseGo for a non-Go file or a failed read; syntax errors |
| // are reported through parsed.ParseErr instead. |
| err error // any other errors |
| } |
| |
| func (s *snapshot) parseGoHandle(ctx context.Context, fh source.FileHandle, mode source.ParseMode) *parseGoHandle { |
| key := parseKey{ |
| file: fh.FileIdentity(), |
| mode: mode, |
| } |
| if pgh := s.getGoFile(key); pgh != nil { |
| return pgh |
| } |
| parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { |
| snapshot := arg.(*snapshot) |
| return parseGo(ctx, snapshot.FileSet(), fh, mode) |
| }, nil) |
| |
| pgh := &parseGoHandle{ |
| handle: parseHandle, |
| file: fh, |
| mode: mode, |
| } |
| return s.addGoFile(key, pgh) |
| } |
| |
| func (pgh *parseGoHandle) String() string { |
| return pgh.File().URI().Filename() |
| } |
| |
| // File returns the file handle this parse handle was created for. |
| func (pgh *parseGoHandle) File() source.FileHandle { |
| return pgh.file |
| } |
| |
| // Mode returns the parse mode this parse handle was created with. |
| func (pgh *parseGoHandle) Mode() source.ParseMode { |
| return pgh.mode |
| } |
| |
| func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode source.ParseMode) (*source.ParsedGoFile, error) { |
| pgh := s.parseGoHandle(ctx, fh, mode) |
| pgf, _, err := s.parseGo(ctx, pgh) |
| return pgf, err |
| } |
| |
| func (s *snapshot) parseGo(ctx context.Context, pgh *parseGoHandle) (*source.ParsedGoFile, bool, error) { |
| if pgh.mode == source.ParseExported { |
| panic("only type checking should use Exported") |
| } |
| d, err := pgh.handle.Get(ctx, s.generation, s) |
| if err != nil { |
| return nil, false, err |
| } |
| data := d.(*parseGoData) |
| return data.parsed, data.fixed, data.err |
| } |
| |
| // astCacheKey is the memoization key for the AST cache of one file within |
| // one package handle (see snapshot.astCacheData). |
| type astCacheKey struct { |
| // pkg is the key of the package handle the file belongs to. |
| pkg packageHandleKey |
| // uri is the URI of the parsed file the cache is built from. |
| uri span.URI |
| } |
| |
| func (s *snapshot) astCacheData(ctx context.Context, spkg source.Package, pos token.Pos) (*astCacheData, error) { |
| pkg := spkg.(*pkg) |
| pkgHandle := s.getPackage(pkg.m.ID, pkg.mode) |
| if pkgHandle == nil { |
| return nil, fmt.Errorf("could not reconstruct package handle for %v", pkg.m.ID) |
| } |
| tok := s.FileSet().File(pos) |
| if tok == nil { |
| return nil, fmt.Errorf("no file for pos %v", pos) |
| } |
| pgf, err := pkg.File(span.URIFromPath(tok.Name())) |
| if err != nil { |
| return nil, err |
| } |
| astHandle := s.generation.Bind(astCacheKey{pkgHandle.key, pgf.URI}, func(ctx context.Context, arg memoize.Arg) interface{} { |
| return buildASTCache(pgf) |
| }, nil) |
| |
| d, err := astHandle.Get(ctx, s.generation, s) |
| if err != nil { |
| return nil, err |
| } |
| data := d.(*astCacheData) |
| if data.err != nil { |
| return nil, data.err |
| } |
| return data, nil |
| } |
| |
| func (s *snapshot) PosToDecl(ctx context.Context, spkg source.Package, pos token.Pos) (ast.Decl, error) { |
| data, err := s.astCacheData(ctx, spkg, pos) |
| if err != nil { |
| return nil, err |
| } |
| return data.posToDecl[pos], nil |
| } |
| |
| func (s *snapshot) PosToField(ctx context.Context, spkg source.Package, pos token.Pos) (*ast.Field, error) { |
| data, err := s.astCacheData(ctx, spkg, pos) |
| if err != nil { |
| return nil, err |
| } |
| return data.posToField[pos], nil |
| } |
| |
| // astCacheData holds per-file lookup tables built by buildASTCache that map |
| // token positions back to syntax nodes. |
| type astCacheData struct { |
| // err, when non-nil, is returned by snapshot.astCacheData instead of the data. |
| err error |
| |
| // posToDecl maps the position of a declared name, or of a field, field |
| // name or variadic element type, to its enclosing declaration. |
| posToDecl map[token.Pos]ast.Decl |
| // posToField maps the position of a field, one of its names, or the |
| // element type of a variadic parameter to the *ast.Field containing it. |
| posToField map[token.Pos]*ast.Field |
| } |
| |
| // buildASTCache builds caches to aid in quickly going from the typed |
| // world to the syntactic world. |
| func buildASTCache(pgf *source.ParsedGoFile) *astCacheData { |
| var ( |
| // path contains all ancestors, including n. |
| path []ast.Node |
| // decls contains all ancestors that are decls. |
| decls []ast.Decl |
| ) |
| |
| data := &astCacheData{ |
| posToDecl: make(map[token.Pos]ast.Decl), |
| posToField: make(map[token.Pos]*ast.Field), |
| } |
| |
| ast.Inspect(pgf.File, func(n ast.Node) bool { |
| if n == nil { |
| lastP := path[len(path)-1] |
| path = path[:len(path)-1] |
| if len(decls) > 0 && decls[len(decls)-1] == lastP { |
| decls = decls[:len(decls)-1] |
| } |
| return false |
| } |
| |
| path = append(path, n) |
| |
| switch n := n.(type) { |
| case *ast.Field: |
| addField := func(f ast.Node) { |
| if f.Pos().IsValid() { |
| data.posToField[f.Pos()] = n |
| if len(decls) > 0 { |
| data.posToDecl[f.Pos()] = decls[len(decls)-1] |
| } |
| } |
| } |
| |
| // Add mapping for *ast.Field itself. This handles embedded |
| // fields which have no associated *ast.Ident name. |
| addField(n) |
| |
| // Add mapping for each field name since you can have |
| // multiple names for the same type expression. |
| for _, name := range n.Names { |
| addField(name) |
| } |
| |
| // Also map "X" in "...X" to the containing *ast.Field. This |
| // makes it easy to format variadic signature params |
| // properly. |
| if elips, ok := n.Type.(*ast.Ellipsis); ok && elips.Elt != nil { |
| addField(elips.Elt) |
| } |
| case *ast.FuncDecl: |
| decls = append(decls, n) |
| |
| if n.Name != nil && n.Name.Pos().IsValid() { |
| data.posToDecl[n.Name.Pos()] = n |
| } |
| case *ast.GenDecl: |
| decls = append(decls, n) |
| |
| for _, spec := range n.Specs { |
| switch spec := spec.(type) { |
| case *ast.TypeSpec: |
| if spec.Name != nil && spec.Name.Pos().IsValid() { |
| data.posToDecl[spec.Name.Pos()] = n |
| } |
| case *ast.ValueSpec: |
| for _, id := range spec.Names { |
| if id != nil && id.Pos().IsValid() { |
| data.posToDecl[id.Pos()] = n |
| } |
| } |
| } |
| } |
| } |
| |
| return true |
| }) |
| |
| return data |
| } |
| |
| func parseGo(ctx context.Context, fset *token.FileSet, fh source.FileHandle, mode source.ParseMode) *parseGoData { |
| ctx, done := event.Start(ctx, "cache.parseGo", tag.File.Of(fh.URI().Filename())) |
| defer done() |
| |
| if fh.Kind() != source.Go { |
| return &parseGoData{err: errors.Errorf("cannot parse non-Go file %s", fh.URI())} |
| } |
| src, err := fh.Read() |
| if err != nil { |
| return &parseGoData{err: err} |
| } |
| |
| parserMode := parser.AllErrors | parser.ParseComments |
| if mode == source.ParseHeader { |
| parserMode = parser.ImportsOnly | parser.ParseComments |
| } |
| |
| file, err := parser.ParseFile(fset, fh.URI().Filename(), src, parserMode) |
| var parseErr scanner.ErrorList |
| if err != nil { |
| // We passed a byte slice, so the only possible error is a parse error. |
| parseErr = err.(scanner.ErrorList) |
| } |
| |
| tok := fset.File(file.Pos()) |
| if tok == nil { |
| // file.Pos is the location of the package declaration. If there was |
| // none, we can't find the token.File that ParseFile created, and we |
| // have no choice but to recreate it. |
| tok = fset.AddFile(fh.URI().Filename(), -1, len(src)) |
| tok.SetLinesForContent(src) |
| } |
| |
| fixed := false |
| // If there were parse errors, attempt to fix them up. |
| if parseErr != nil { |
| // Fix any badly parsed parts of the AST. |
| fixed = fixAST(ctx, file, tok, src) |
| |
| for i := 0; i < 10; i++ { |
| // Fix certain syntax errors that render the file unparseable. |
| newSrc := fixSrc(file, tok, src) |
| if newSrc == nil { |
| break |
| } |
| |
| // If we thought there was something to fix 10 times in a row, |
| // it is likely we got stuck in a loop somehow. Log out a diff |
| // of the last changes we made to aid in debugging. |
| if i == 9 { |
| edits, err := myers.ComputeEdits(fh.URI(), string(src), string(newSrc)) |
| if err != nil { |
| event.Error(ctx, "error generating fixSrc diff", err, tag.File.Of(tok.Name())) |
| } else { |
| unified := diff.ToUnified("before", "after", string(src), edits) |
| event.Log(ctx, fmt.Sprintf("fixSrc loop - last diff:\n%v", unified), tag.File.Of(tok.Name())) |
| } |
| } |
| |
| newFile, _ := parser.ParseFile(fset, fh.URI().Filename(), newSrc, parserMode) |
| if newFile != nil { |
| // Maintain the original parseError so we don't try formatting the doctored file. |
| file = newFile |
| src = newSrc |
| tok = fset.File(file.Pos()) |
| |
| fixed = fixAST(ctx, file, tok, src) |
| } |
| } |
| } |
| |
| return &parseGoData{ |
| parsed: &source.ParsedGoFile{ |
| URI: fh.URI(), |
| Mode: mode, |
| Src: src, |
| File: file, |
| Tok: tok, |
| Mapper: &protocol.ColumnMapper{ |
| URI: fh.URI(), |
| Converter: span.NewTokenConverter(fset, tok), |
| Content: src, |
| }, |
| ParseErr: parseErr, |
| }, |
| fixed: fixed, |
| } |
| } |
| |
| // An unexportedFilter removes as much unexported AST from a set of Files as possible. |
| type unexportedFilter struct { |
| // uses is the set of identifier names that must be kept even though they |
| // are not exported. recordIdents and ProcessErrors write to it, so it |
| // must be non-nil before they are called. |
| uses map[string]bool |
| } |
| |
| // Filter records uses of unexported identifiers and filters out all other |
| // unexported declarations. |
| func (f *unexportedFilter) Filter(files []*ast.File) { |
| // Iterate to fixed point -- unexported types can include other unexported types. |
| oldLen := len(f.uses) |
| for { |
| for _, file := range files { |
| f.recordUses(file) |
| } |
| if len(f.uses) == oldLen { |
| break |
| } |
| oldLen = len(f.uses) |
| } |
| |
| for _, file := range files { |
| var newDecls []ast.Decl |
| for _, decl := range file.Decls { |
| if f.filterDecl(decl) { |
| newDecls = append(newDecls, decl) |
| } |
| } |
| file.Decls = newDecls |
| file.Scope = nil |
| file.Unresolved = nil |
| file.Comments = nil |
| trimAST(file) |
| } |
| } |
| |
| func (f *unexportedFilter) keep(ident *ast.Ident) bool { |
| return ast.IsExported(ident.Name) || f.uses[ident.Name] |
| } |
| |
| func (f *unexportedFilter) filterDecl(decl ast.Decl) bool { |
| switch decl := decl.(type) { |
| case *ast.FuncDecl: |
| if ident := recvIdent(decl); ident != nil && !f.keep(ident) { |
| return false |
| } |
| return f.keep(decl.Name) |
| case *ast.GenDecl: |
| if decl.Tok == token.CONST { |
| // Constants can involve iota, and iota is hard to deal with. |
| return true |
| } |
| var newSpecs []ast.Spec |
| for _, spec := range decl.Specs { |
| if f.filterSpec(spec) { |
| newSpecs = append(newSpecs, spec) |
| } |
| } |
| decl.Specs = newSpecs |
| return len(newSpecs) != 0 |
| case *ast.BadDecl: |
| return false |
| } |
| panic(fmt.Sprintf("unknown ast.Decl %T", decl)) |
| } |
| |
| func (f *unexportedFilter) filterSpec(spec ast.Spec) bool { |
| switch spec := spec.(type) { |
| case *ast.ImportSpec: |
| return true |
| case *ast.ValueSpec: |
| var newNames []*ast.Ident |
| for _, name := range spec.Names { |
| if f.keep(name) { |
| newNames = append(newNames, name) |
| } |
| } |
| spec.Names = newNames |
| return len(spec.Names) != 0 |
| case *ast.TypeSpec: |
| if !f.keep(spec.Name) { |
| return false |
| } |
| switch typ := spec.Type.(type) { |
| case *ast.StructType: |
| f.filterFieldList(typ.Fields) |
| case *ast.InterfaceType: |
| f.filterFieldList(typ.Methods) |
| } |
| return true |
| } |
| panic(fmt.Sprintf("unknown ast.Spec %T", spec)) |
| } |
| |
| func (f *unexportedFilter) filterFieldList(fields *ast.FieldList) { |
| var newFields []*ast.Field |
| for _, field := range fields.List { |
| if len(field.Names) == 0 { |
| // Keep embedded fields: they can export methods and fields. |
| newFields = append(newFields, field) |
| } |
| for _, name := range field.Names { |
| if f.keep(name) { |
| newFields = append(newFields, field) |
| break |
| } |
| } |
| } |
| fields.List = newFields |
| } |
| |
| func (f *unexportedFilter) recordUses(file *ast.File) { |
| for _, decl := range file.Decls { |
| switch decl := decl.(type) { |
| case *ast.FuncDecl: |
| // Ignore methods on dropped types. |
| if ident := recvIdent(decl); ident != nil && !f.keep(ident) { |
| break |
| } |
| // Ignore functions with dropped names. |
| if !f.keep(decl.Name) { |
| break |
| } |
| f.recordFuncType(decl.Type) |
| case *ast.GenDecl: |
| for _, spec := range decl.Specs { |
| switch spec := spec.(type) { |
| case *ast.ValueSpec: |
| for i, name := range spec.Names { |
| // Don't mess with constants -- iota is hard. |
| if f.keep(name) || decl.Tok == token.CONST { |
| f.recordIdents(spec.Type) |
| if len(spec.Values) > i { |
| f.recordIdents(spec.Values[i]) |
| } |
| } |
| } |
| case *ast.TypeSpec: |
| switch typ := spec.Type.(type) { |
| case *ast.StructType: |
| f.recordFieldUses(false, typ.Fields) |
| case *ast.InterfaceType: |
| f.recordFieldUses(false, typ.Methods) |
| } |
| } |
| } |
| } |
| } |
| } |
| |
// recvIdent returns the identifier naming the receiver type of a method,
// looking through one level of pointer (so both "T" and "*T" yield T). It
// returns nil for plain functions and for receiver types that are not a
// simple identifier.
func recvIdent(decl *ast.FuncDecl) *ast.Ident {
	recv := decl.Recv
	if recv == nil || len(recv.List) == 0 {
		return nil
	}

	typ := recv.List[0].Type
	if ptr, isPtr := typ.(*ast.StarExpr); isPtr {
		typ = ptr.X
	}

	// A failed assertion leaves ident nil, which is the "no identifier"
	// result.
	ident, _ := typ.(*ast.Ident)
	return ident
}
| |
| // recordIdents records unexported identifiers in an Expr in uses. |
| // These may be types, e.g. in map[key]value, function names, e.g. in foo(), |
| // or simple variable references. References that will be discarded, such |
| // as those in function literal bodies, are ignored. |
| func (f *unexportedFilter) recordIdents(x ast.Expr) { |
| ast.Inspect(x, func(n ast.Node) bool { |
| if n == nil { |
| return false |
| } |
| if complit, ok := n.(*ast.CompositeLit); ok { |
| // We clear out composite literal contents; just record their type. |
| f.recordIdents(complit.Type) |
| return false |
| } |
| if flit, ok := n.(*ast.FuncLit); ok { |
| f.recordFuncType(flit.Type) |
| return false |
| } |
| if ident, ok := n.(*ast.Ident); ok && !ast.IsExported(ident.Name) { |
| f.uses[ident.Name] = true |
| } |
| return true |
| }) |
| } |
| |
| // recordFuncType records the types mentioned by a function type. |
| func (f *unexportedFilter) recordFuncType(x *ast.FuncType) { |
| f.recordFieldUses(true, x.Params) |
| f.recordFieldUses(true, x.Results) |
| } |
| |
| // recordFieldUses records unexported identifiers used in fields, which may be |
| // struct members, interface members, or function parameter/results. |
| func (f *unexportedFilter) recordFieldUses(isParams bool, fields *ast.FieldList) { |
| if fields == nil { |
| return |
| } |
| for _, field := range fields.List { |
| if isParams { |
| // Parameter types of retained functions need to be retained. |
| f.recordIdents(field.Type) |
| continue |
| } |
| if ft, ok := field.Type.(*ast.FuncType); ok { |
| // Function declarations in interfaces need all their types retained. |
| f.recordFuncType(ft) |
| continue |
| } |
| if len(field.Names) == 0 { |
| // Embedded fields might contribute exported names. |
| f.recordIdents(field.Type) |
| } |
| for _, name := range field.Names { |
| // We only need normal fields if they're exported. |
| if ast.IsExported(name.Name) { |
| f.recordIdents(field.Type) |
| break |
| } |
| } |
| } |
| } |
| |
| // ProcessErrors records additional uses from errors, returning the new uses |
| // and any unexpected errors. |
| func (f *unexportedFilter) ProcessErrors(errors []types.Error) (map[string]bool, []types.Error) { |
| var unexpected []types.Error |
| missing := map[string]bool{} |
| for _, err := range errors { |
| if strings.Contains(err.Msg, "missing return") { |
| continue |
| } |
| const undeclared = "undeclared name: " |
| if strings.HasPrefix(err.Msg, undeclared) { |
| missing[strings.TrimPrefix(err.Msg, undeclared)] = true |
| f.uses[strings.TrimPrefix(err.Msg, undeclared)] = true |
| continue |
| } |
| unexpected = append(unexpected, err) |
| } |
| return missing, unexpected |
| } |
| |
| // trimAST clears any part of the AST not relevant to type checking |
| // expressions at pos. |
| func trimAST(file *ast.File) { |
| ast.Inspect(file, func(n ast.Node) bool { |
| if n == nil { |
| return false |
| } |
| switch n := n.(type) { |
| case *ast.FuncDecl: |
| n.Body = nil |
| case *ast.BlockStmt: |
| n.List = nil |
| case *ast.CaseClause: |
| n.Body = nil |
| case *ast.CommClause: |
| n.Body = nil |
| case *ast.CompositeLit: |
| // types.Info.Types for long slice/array literals are particularly |
| // expensive. Try to clear them out. |
| at, ok := n.Type.(*ast.ArrayType) |
| if !ok { |
| // Composite literal. No harm removing all its fields. |
| n.Elts = nil |
| break |
| } |
| // Removing the elements from an ellipsis array changes its type. |
| // Try to set the length explicitly so we can continue. |
| if _, ok := at.Len.(*ast.Ellipsis); ok { |
| length, ok := arrayLength(n) |
| if !ok { |
| break |
| } |
| at.Len = &ast.BasicLit{ |
| Kind: token.INT, |
| Value: fmt.Sprint(length), |
| ValuePos: at.Len.Pos(), |
| } |
| } |
| n.Elts = nil |
| } |
| return true |
| }) |
| } |
| |
// arrayLength returns the length of some simple forms of ellipsis array
// literal. Notably, it handles the tables in golang.org/x/text.
//
// The length is computed as the language defines it: an element with a key
// sits at that index, an element without a key sits at the index after the
// previous element (or 0 for the first), and the length is the largest
// index plus one. Keys may be integer literals (in any base) or a
// subtraction of two integer literals.
//
// The second result is false when the length cannot be determined, i.e.
// when some key has any other form or evaluates to a negative index. In
// that case the caller must not rely on the returned length.
func arrayLength(array *ast.CompositeLit) (int, bool) {
	// litVal evaluates an integer literal.
	litVal := func(expr ast.Expr) (int, bool) {
		lit, ok := expr.(*ast.BasicLit)
		if !ok || lit.Kind != token.INT {
			return 0, false
		}
		// Base 0 lets ParseInt honor 0x/0o/0b prefixes and legacy octal.
		val, err := strconv.ParseInt(lit.Value, 0, 64)
		if err != nil {
			return 0, false
		}
		return int(val), true
	}

	// keyVal evaluates the supported forms of element key.
	keyVal := func(expr ast.Expr) (int, bool) {
		switch key := expr.(type) {
		case *ast.BasicLit:
			return litVal(key)
		case *ast.BinaryExpr:
			// golang.org/x/text uses subtraction (and only subtraction)
			// in its indices.
			if key.Op != token.SUB {
				return 0, false
			}
			x, ok := litVal(key.X)
			if !ok {
				return 0, false
			}
			y, ok := litVal(key.Y)
			if !ok {
				return 0, false
			}
			return x - y, true
		}
		return 0, false
	}

	length := 0 // largest index seen so far, plus one
	next := 0   // index the next unkeyed element would get
	for _, elt := range array.Elts {
		if kve, ok := elt.(*ast.KeyValueExpr); ok {
			idx, ok := keyVal(kve.Key)
			if !ok || idx < 0 {
				// Guessing here would give the array the wrong type.
				return 0, false
			}
			next = idx
		}
		next++
		if next > length {
			length = next
		}
	}
	return length, true
}
| |
| // fixAST inspects the AST and potentially modifies any *ast.BadStmts so that it can be |
| // type-checked more effectively. |
| func fixAST(ctx context.Context, n ast.Node, tok *token.File, src []byte) (fixed bool) { |
| var err error |
| walkASTWithParent(n, func(n, parent ast.Node) bool { |
| switch n := n.(type) { |
| case *ast.BadStmt: |
| if fixed = fixDeferOrGoStmt(n, parent, tok, src); fixed { |
| // Recursively fix in our fixed node. |
| _ = fixAST(ctx, parent, tok, src) |
| } else { |
| err = errors.Errorf("unable to parse defer or go from *ast.BadStmt: %v", err) |
| } |
| return false |
| case *ast.BadExpr: |
| if fixed = fixArrayType(n, parent, tok, src); fixed { |
| // Recursively fix in our fixed node. |
| _ = fixAST(ctx, parent, tok, src) |
| return false |
| } |
| |
| // Fix cases where parser interprets if/for/switch "init" |
| // statement as "cond" expression, e.g.: |
| // |
| // // "i := foo" is init statement, not condition. |
| // for i := foo |
| // |
| fixInitStmt(n, parent, tok, src) |
| |
| return false |
| case *ast.SelectorExpr: |
| // Fix cases where a keyword prefix results in a phantom "_" selector, e.g.: |
| // |
| // foo.var<> // want to complete to "foo.variance" |
| // |
| fixPhantomSelector(n, tok, src) |
| return true |
| |
| case *ast.BlockStmt: |
| switch parent.(type) { |
| case *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt: |
| // Adjust closing curly brace of empty switch/select |
| // statements so we can complete inside them. |
| fixEmptySwitch(n, tok, src) |
| } |
| |
| return true |
| default: |
| return true |
| } |
| }) |
| return fixed |
| } |
| |
// walkASTWithParent walks the AST rooted at n in the same order as
// ast.Inspect, calling f with each node and its parent (nil for the root).
// As with ast.Inspect, returning false from f skips the node's children;
// unlike ast.Inspect, f is never called with a nil node.
func walkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) {
	// stack holds the nodes whose children are currently being visited.
	var stack []ast.Node

	ast.Inspect(n, func(cur ast.Node) bool {
		if cur == nil {
			// Finished the children of the innermost open node.
			stack = stack[:len(stack)-1]
			return false
		}

		var parent ast.Node
		if depth := len(stack); depth > 0 {
			parent = stack[depth-1]
		}

		if !f(cur, parent) {
			return false
		}

		// Only nodes we descend into get a matching nil callback, so only
		// those are pushed.
		stack = append(stack, cur)
		return true
	})
}
| |
| // fixSrc attempts to modify the file's source code to fix certain |
| // syntax errors that leave the rest of the file unparsed. |
| func fixSrc(f *ast.File, tok *token.File, src []byte) (newSrc []byte) { |
| walkASTWithParent(f, func(n, parent ast.Node) bool { |
| if newSrc != nil { |
| return false |
| } |
| |
| switch n := n.(type) { |
| case *ast.BlockStmt: |
| newSrc = fixMissingCurlies(f, n, parent, tok, src) |
| case *ast.SelectorExpr: |
| newSrc = fixDanglingSelector(n, tok, src) |
| } |
| |
| return newSrc == nil |
| }) |
| |
| return newSrc |
| } |
| |
| // fixMissingCurlies adds in curly braces for block statements that |
| // are missing curly braces. For example: |
| // |
| // if foo |
| // |
| // becomes |
| // |
| // if foo {} |
| func fixMissingCurlies(f *ast.File, b *ast.BlockStmt, parent ast.Node, tok *token.File, src []byte) []byte { |
| // If the "{" is already in the source code, there isn't anything to |
| // fix since we aren't missing curlies. |
| if b.Lbrace.IsValid() { |
| braceOffset := tok.Offset(b.Lbrace) |
| if braceOffset < len(src) && src[braceOffset] == '{' { |
| return nil |
| } |
| } |
| |
| parentLine := tok.Line(parent.Pos()) |
| |
| if parentLine >= tok.LineCount() { |
| // If we are the last line in the file, no need to fix anything. |
| return nil |
| } |
| |
| // Insert curlies at the end of parent's starting line. The parent |
| // is the statement that contains the block, e.g. *ast.IfStmt. The |
| // block's Pos()/End() can't be relied upon because they are based |
| // on the (missing) curly braces. We assume the statement is a |
| // single line for now and try sticking the curly braces at the end. |
| insertPos := tok.LineStart(parentLine+1) - 1 |
| |
| // Scootch position backwards until it's not in a comment. For example: |
| // |
| // if foo<> // some amazing comment | |
| // someOtherCode() |
| // |
| // insertPos will be located at "|", so we back it out of the comment. |
| didSomething := true |
| for didSomething { |
| didSomething = false |
| for _, c := range f.Comments { |
| if c.Pos() < insertPos && insertPos <= c.End() { |
| insertPos = c.Pos() |
| didSomething = true |
| } |
| } |
| } |
| |
| // Bail out if line doesn't end in an ident or ".". This is to avoid |
| // cases like below where we end up making things worse by adding |
| // curlies: |
| // |
| // if foo && |
| // bar<> |
| switch precedingToken(insertPos, tok, src) { |
| case token.IDENT, token.PERIOD: |
| // ok |
| default: |
| return nil |
| } |
| |
| var buf bytes.Buffer |
| buf.Grow(len(src) + 3) |
| buf.Write(src[:tok.Offset(insertPos)]) |
| |
| // Detect if we need to insert a semicolon to fix "for" loop situations like: |
| // |
| // for i := foo(); foo<> |
| // |
| // Just adding curlies is not sufficient to make things parse well. |
| if fs, ok := parent.(*ast.ForStmt); ok { |
| if _, ok := fs.Cond.(*ast.BadExpr); !ok { |
| if xs, ok := fs.Post.(*ast.ExprStmt); ok { |
| if _, ok := xs.X.(*ast.BadExpr); ok { |
| buf.WriteByte(';') |
| } |
| } |
| } |
| } |
| |
| // Insert "{}" at insertPos. |
| buf.WriteByte('{') |
| buf.WriteByte('}') |
| buf.Write(src[tok.Offset(insertPos):]) |
| return buf.Bytes() |
| } |
| |
// fixEmptySwitch moves the closing curly brace of an empty switch/select
// body down one line, so that an incomplete "case" or "default" keyword is
// detected as being inside the statement. For example:
//
//	switch {
//	def<>
//	}
//
// gets parsed like:
//
//	switch {
//	}
//
// The "def" token is pulled out manually later, but the "<>" position must
// be seen as inside the switch block. Moving the brace makes it look like:
//
//	switch {
//
//	}
//
// Nothing happens when the body is non-empty, when the recorded brace
// position is invalid or really holds a "}" in src, or when that position
// is on the last line of the file.
func fixEmptySwitch(body *ast.BlockStmt, tok *token.File, src []byte) {
	// Only empty bodies with a usable brace position are of interest.
	if len(body.List) != 0 || !body.Rbrace.IsValid() {
		return
	}

	// A brace that really is where the AST says it is needs no fixing.
	if off := tok.Offset(body.Rbrace); off < len(src) && src[off] == '}' {
		return
	}

	// Push the brace to the start of the next line, if there is one.
	if line := tok.Line(body.Rbrace); line < tok.LineCount() {
		body.Rbrace = tok.LineStart(line + 1)
	}
}
| |
| // fixDanglingSelector inserts real "_" selector expressions in place |
| // of phantom "_" selectors. For example: |
| // |
| // func _() { |
| // x.<> |
| // } |
| // var x struct { i int } |
| // |
| // To fix completion at "<>", we insert a real "_" after the "." so the |
| // following declaration of "x" can be parsed and type checked |
| // normally. |
| func fixDanglingSelector(s *ast.SelectorExpr, tok *token.File, src []byte) []byte { |
| if !isPhantomUnderscore(s.Sel, tok, src) { |
| return nil |
| } |
| |
| if !s.X.End().IsValid() { |
| return nil |
| } |
| |
| // Insert directly after the selector's ".". |
| insertOffset := tok.Offset(s.X.End()) + 1 |
| if src[insertOffset-1] != '.' { |
| return nil |
| } |
| |
| var buf bytes.Buffer |
| buf.Grow(len(src) + 1) |
| buf.Write(src[:insertOffset]) |
| buf.WriteByte('_') |
| buf.Write(src[insertOffset:]) |
| return buf.Bytes() |
| } |
| |
| // fixPhantomSelector tries to fix selector expressions with phantom |
| // "_" selectors. In particular, we check if the selector is a |
| // keyword, and if so we swap in an *ast.Ident with the keyword text. For example: |
| // |
| // foo.var |
| // |
| // yields a "_" selector instead of "var" since "var" is a keyword. |
| func fixPhantomSelector(sel *ast.SelectorExpr, tok *token.File, src []byte) { |
| if !isPhantomUnderscore(sel.Sel, tok, src) { |
| return |
| } |
| |
| // Only consider selectors directly abutting the selector ".". This |
| // avoids false positives in cases like: |
| // |
| // foo. // don't think "var" is our selector |
| // var bar = 123 |
| // |
| if sel.Sel.Pos() != sel.X.End()+1 { |
| return |
| } |
| |
| maybeKeyword := readKeyword(sel.Sel.Pos(), tok, src) |
| if maybeKeyword == "" { |
| return |
| } |
| |
| replaceNode(sel, sel.Sel, &ast.Ident{ |
| Name: maybeKeyword, |
| NamePos: sel.Sel.Pos(), |
| }) |
| } |
| |
// isPhantomUnderscore reports whether id is a "_" identifier that does not
// actually appear in the program text. The parser sometimes inserts such
// phantom underscores when it encounters otherwise unparseable input. A nil
// id, or an id with any other name, is never a phantom underscore.
func isPhantomUnderscore(id *ast.Ident, tok *token.File, src []byte) bool {
	if id == nil || id.Name != "_" {
		return false
	}

	// The identifier is phantom if its position is past the end of the
	// text or the text holds something other than "_" there.
	off := tok.Offset(id.Pos())
	if off >= len(src) {
		return true
	}
	return src[off] != '_'
}
| |
| // fixInitStmt fixes cases where the parser misinterprets an |
| // if/for/switch "init" statement as the "cond" conditional. In cases |
| // like "if i := 0" the user hasn't typed the semicolon yet so the |
| // parser is looking for the conditional expression. However, "i := 0" |
| // are not valid expressions, so we get a BadExpr. |
| func fixInitStmt(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) { |
| if !bad.Pos().IsValid() || !bad.End().IsValid() { |
| return |
| } |
| |
| // Try to extract a statement from the BadExpr. |
| stmtBytes := src[tok.Offset(bad.Pos()) : tok.Offset(bad.End()-1)+1] |
| stmt, err := parseStmt(bad.Pos(), stmtBytes) |
| if err != nil { |
| return |
| } |
| |
| // If the parent statement doesn't already have an "init" statement, |
| // move the extracted statement into the "init" field and insert a |
| // dummy expression into the required "cond" field. |
| switch p := parent.(type) { |
| case *ast.IfStmt: |
| if p.Init != nil { |
| return |
| } |
| p.Init = stmt |
| p.Cond = &ast.Ident{ |
| Name: "_", |
| NamePos: stmt.End(), |
| } |
| case *ast.ForStmt: |
| if p.Init != nil { |
| return |
| } |
| p.Init = stmt |
| p.Cond = &ast.Ident{ |
| Name: "_", |
| NamePos: stmt.End(), |
| } |
| case *ast.SwitchStmt: |
| if p.Init != nil { |
| return |
| } |
| p.Init = stmt |
| p.Tag = nil |
| } |
| } |
| |
// readKeyword returns the Go keyword that starts at pos in src, or "" if
// the run of lowercase ASCII letters starting there is not a keyword (or
// is too long to be one).
func readKeyword(pos token.Pos, tok *token.File, src []byte) string {
	// maxLen is an arbitrarily chosen too-long-for-a-keyword length.
	const maxLen = 15

	start := tok.Offset(pos)
	end := start
	for ; end < len(src); end++ {
		// A simplified identifier check suffices since keywords are
		// always lowercase ASCII.
		if c := src[end]; c < 'a' || c > 'z' {
			break
		}
		if end-start+1 > maxLen {
			return ""
		}
	}

	// No letters at all (including a start offset at or past the end of
	// src) cannot be a keyword.
	if end == start {
		return ""
	}

	if word := string(src[start:end]); token.Lookup(word).IsKeyword() {
		return word
	}
	return ""
}
| |
| // fixArrayType tries to parse an *ast.BadExpr into an *ast.ArrayType. |
| // go/parser often turns lone array types like "[]int" into BadExprs |
| // if it isn't expecting a type. |
| func fixArrayType(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) bool { |
| // Our expected input is a bad expression that looks like "[]someExpr". |
| |
| from := bad.Pos() |
| to := bad.End() |
| |
| if !from.IsValid() || !to.IsValid() { |
| return false |
| } |
| |
| exprBytes := make([]byte, 0, int(to-from)+3) |
| // Avoid doing tok.Offset(to) since that panics if badExpr ends at EOF. |
| // It also panics if the position is not in the range of the file, and |
| // badExprs may not necessarily have good positions, so check first. |
| if !source.InRange(tok, from) { |
| return false |
| } |
| if !source.InRange(tok, to-1) { |
| return false |
| } |
| fromOffset := tok.Offset(from) |
| toOffset := tok.Offset(to-1) + 1 |
| exprBytes = append(exprBytes, src[fromOffset:toOffset]...) |
| exprBytes = bytes.TrimSpace(exprBytes) |
| |
| // If our expression ends in "]" (e.g. "[]"), add a phantom selector |
| // so we can complete directly after the "[]". |
| if len(exprBytes) > 0 && exprBytes[len(exprBytes)-1] == ']' { |
| exprBytes = append(exprBytes, '_') |
| } |
| |
| // Add "{}" to turn our ArrayType into a CompositeLit. This is to |
| // handle the case of "[...]int" where we must make it a composite |
| // literal to be parseable. |
| exprBytes = append(exprBytes, '{', '}') |
| |
| expr, err := parseExpr(from, exprBytes) |
| if err != nil { |
| return false |
| } |
| |
| cl, _ := expr.(*ast.CompositeLit) |
| if cl == nil { |
| return false |
| } |
| |
| at, _ := cl.Type.(*ast.ArrayType) |
| if at == nil { |
| return false |
| } |
| |
| return replaceNode(parent, bad, at) |
| } |
| |
| // precedingToken scans src to find the token preceding pos. |
| func precedingToken(pos token.Pos, tok *token.File, src []byte) token.Token { |
| s := &scanner.Scanner{} |
| s.Init(tok, src, nil, 0) |
| |
| var lastTok token.Token |
| for { |
| p, t, _ := s.Scan() |
| if t == token.EOF || p >= pos { |
| break |
| } |
| |
| lastTok = t |
| } |
| return lastTok |
| } |
| |
| // fixDeferOrGoStmt tries to parse an *ast.BadStmt into a defer or a go statement. |
| // |
| // go/parser packages a statement of the form "defer x." as an *ast.BadStmt because |
| // it does not include a call expression. This means that go/types skips type-checking |
| // this statement entirely, and we can't use the type information when completing. |
| // Here, we try to generate a fake *ast.DeferStmt or *ast.GoStmt to put into the AST, |
| // instead of the *ast.BadStmt. |
| func fixDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src []byte) bool { |
| // Check if we have a bad statement containing either a "go" or "defer". |
| s := &scanner.Scanner{} |
| s.Init(tok, src, nil, 0) |
| |
| var ( |
| pos token.Pos |
| tkn token.Token |
| ) |
| for { |
| if tkn == token.EOF { |
| return false |
| } |
| if pos >= bad.From { |
| break |
| } |
| pos, tkn, _ = s.Scan() |
| } |
| |
| var stmt ast.Stmt |
| switch tkn { |
| case token.DEFER: |
| stmt = &ast.DeferStmt{ |
| Defer: pos, |
| } |
| case token.GO: |
| stmt = &ast.GoStmt{ |
| Go: pos, |
| } |
| default: |
| return false |
| } |
| |
| var ( |
| from, to, last token.Pos |
| lastToken token.Token |
| braceDepth int |
| phantomSelectors []token.Pos |
| ) |
| FindTo: |
| for { |
| to, tkn, _ = s.Scan() |
| |
| if from == token.NoPos { |
| from = to |
| } |
| |
| switch tkn { |
| case token.EOF: |
| break FindTo |
| case token.SEMICOLON: |
| // If we aren't in nested braces, end of statement means |
| // end of expression. |
| if braceDepth == 0 { |
| break FindTo |
| } |
| case token.LBRACE: |
| braceDepth++ |
| } |
| |
| // This handles the common dangling selector case. For example in |
| // |
| // defer fmt. |
| // y := 1 |
| // |
| // we notice the dangling period and end our expression. |
| // |
| // If the previous token was a "." and we are looking at a "}", |
| // the period is likely a dangling selector and needs a phantom |
| // "_". Likewise if the current token is on a different line than |
| // the period, the period is likely a dangling selector. |
| if lastToken == token.PERIOD && (tkn == token.RBRACE || tok.Line(to) > tok.Line(last)) { |
| // Insert phantom "_" selector after the dangling ".". |
| phantomSelectors = append(phantomSelectors, last+1) |
| // If we aren't in a block then end the expression after the ".". |
| if braceDepth == 0 { |
| to = last + 1 |
| break |
| } |
| } |
| |
| lastToken = tkn |
| last = to |
| |
| switch tkn { |
| case token.RBRACE: |
| braceDepth-- |
| if braceDepth <= 0 { |
| if braceDepth == 0 { |
| // +1 to include the "}" itself. |
| to += 1 |
| } |
| break FindTo |
| } |
| } |
| } |
| |
| if !from.IsValid() || tok.Offset(from) >= len(src) { |
| return false |
| } |
| |
| if !to.IsValid() || tok.Offset(to) >= len(src) { |
| return false |
| } |
| |
| // Insert any phantom selectors needed to prevent dangling "." from messing |
| // up the AST. |
| exprBytes := make([]byte, 0, int(to-from)+len(phantomSelectors)) |
| for i, b := range src[tok.Offset(from):tok.Offset(to)] { |
| if len(phantomSelectors) > 0 && from+token.Pos(i) == phantomSelectors[0] { |
| exprBytes = append(exprBytes, '_') |
| phantomSelectors = phantomSelectors[1:] |
| } |
| exprBytes = append(exprBytes, b) |
| } |
| |
| if len(phantomSelectors) > 0 { |
| exprBytes = append(exprBytes, '_') |
| } |
| |
| expr, err := parseExpr(from, exprBytes) |
| if err != nil { |
| return false |
| } |
| |
| // Package the expression into a fake *ast.CallExpr and re-insert |
| // into the function. |
| call := &ast.CallExpr{ |
| Fun: expr, |
| Lparen: to, |
| Rparen: to, |
| } |
| |
| switch stmt := stmt.(type) { |
| case *ast.DeferStmt: |
| stmt.Call = call |
| case *ast.GoStmt: |
| stmt.Call = call |
| } |
| |
| return replaceNode(parent, bad, stmt) |
| } |
| |
| // parseStmt parses the statement in src and updates its position to |
| // start at pos. |
| func parseStmt(pos token.Pos, src []byte) (ast.Stmt, error) { |
| // Wrap our expression to make it a valid Go file we can pass to ParseFile. |
| fileSrc := bytes.Join([][]byte{ |
| []byte("package fake;func _(){"), |
| src, |
| []byte("}"), |
| }, nil) |
| |
| // Use ParseFile instead of ParseExpr because ParseFile has |
| // best-effort behavior, whereas ParseExpr fails hard on any error. |
| fakeFile, err := parser.ParseFile(token.NewFileSet(), "", fileSrc, 0) |
| if fakeFile == nil { |
| return nil, errors.Errorf("error reading fake file source: %v", err) |
| } |
| |
| // Extract our expression node from inside the fake file. |
| if len(fakeFile.Decls) == 0 { |
| return nil, errors.Errorf("error parsing fake file: %v", err) |
| } |
| |
| fakeDecl, _ := fakeFile.Decls[0].(*ast.FuncDecl) |
| if fakeDecl == nil || len(fakeDecl.Body.List) == 0 { |
| return nil, errors.Errorf("no statement in %s: %v", src, err) |
| } |
| |
| stmt := fakeDecl.Body.List[0] |
| |
| // parser.ParseFile returns undefined positions. |
| // Adjust them for the current file. |
| offsetPositions(stmt, pos-1-(stmt.Pos()-1)) |
| |
| return stmt, nil |
| } |
| |
| // parseExpr parses the expression in src and updates its position to |
| // start at pos. |
| func parseExpr(pos token.Pos, src []byte) (ast.Expr, error) { |
| stmt, err := parseStmt(pos, src) |
| if err != nil { |
| return nil, err |
| } |
| |
| exprStmt, ok := stmt.(*ast.ExprStmt) |
| if !ok { |
| return nil, errors.Errorf("no expr in %s: %v", src, err) |
| } |
| |
| return exprStmt.X, nil |
| } |
| |
// tokenPosType is the reflect.Type of token.Pos. It is used to pick out
// position fields when inspecting AST node structs via reflection.
var tokenPosType = reflect.TypeOf(token.NoPos)

// offsetPositions applies an offset to the positions in an ast.Node.
//
// It visits n and every node reachable from it through ast.Inspect and
// adds offset to each settable struct field of type token.Pos.
//
// Fields holding an invalid position (token.NoPos) are left untouched.
// Such fields mean "absent" (for example CallExpr.Ellipsis on a call
// without "..."), and shifting them would turn them into bogus valid
// positions.
func offsetPositions(n ast.Node, offset token.Pos) {
	ast.Inspect(n, func(n ast.Node) bool {
		if n == nil {
			return false
		}

		v := reflect.ValueOf(n).Elem()
		if v.Kind() != reflect.Struct {
			return true
		}

		for i := 0; i < v.NumField(); i++ {
			f := v.Field(i)
			if f.Type() != tokenPosType || !f.CanSet() {
				continue
			}

			// Invalid positions must stay invalid.
			if !token.Pos(f.Int()).IsValid() {
				continue
			}

			f.SetInt(f.Int() + int64(offset))
		}

		return true
	})
}
| |
// replaceNode swaps oldChild for newChild among the direct children of
// parent and reports whether a swap took place.
//
// parent's struct fields are examined in declaration order. Interface
// and pointer fields are checked directly; slice fields are checked
// element by element. Only the first slot that holds oldChild and whose
// type accepts newChild is rewritten. The result is false when any
// argument is nil, when parent does not point to a struct, or when no
// such slot exists.
func replaceNode(parent, oldChild, newChild ast.Node) bool {
	if parent == nil || oldChild == nil || newChild == nil {
		return false
	}

	structVal := reflect.ValueOf(parent).Elem()
	if structVal.Kind() != reflect.Struct {
		return false
	}

	replacement := reflect.ValueOf(newChild)

	// swap overwrites slot with the replacement if slot currently holds
	// oldChild and the replacement's type is assignable to it.
	swap := func(slot reflect.Value) bool {
		if !slot.CanSet() || !slot.CanInterface() {
			return false
		}
		if slot.Interface() != oldChild {
			return false
		}
		if !replacement.Type().AssignableTo(slot.Type()) {
			return false
		}
		slot.Set(replacement)
		return true
	}

	for fieldIdx := 0; fieldIdx < structVal.NumField(); fieldIdx++ {
		field := structVal.Field(fieldIdx)

		switch field.Kind() {
		case reflect.Interface, reflect.Ptr:
			// A single child stored directly in the field.
			if swap(field) {
				return true
			}

		case reflect.Slice:
			// A list of children; look at each element.
			for elemIdx := 0; elemIdx < field.Len(); elemIdx++ {
				if swap(field.Index(elemIdx)) {
					return true
				}
			}
		}
	}

	return false
}