| // Copyright 2020 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package go2go rewrites polymorphic code into non-polymorphic code. |
| package go2go |
| |
| import ( |
| "bytes" |
| "fmt" |
| "go/ast" |
| "go/parser" |
| "go/token" |
| "go/types" |
| "io" |
| "os" |
| "path/filepath" |
| "sort" |
| "strings" |
| ) |
| |
const (
	// parseTypeParams is the go/parser mode bit that enables parsing of type
	// parameters. Its value must stay in sync with go/parser/interface.go.
	parseTypeParams parser.Mode = 1 << 30

	// rewritePrefix is written at the very start of every .go file that
	// go2go generates; checkGoFile uses it to recognize such files.
	rewritePrefix = "// Code generated by go2go; DO NOT EDIT.\n\n"
)
| |
| // Rewrite rewrites the contents of a single directory. |
| // It looks for all files with the extension .go2, and parses |
| // them as a single package. It writes out a .go file with any |
| // polymorphic code rewritten into normal code. |
| func Rewrite(importer *Importer, dir string) error { |
| _, err := rewriteToPkgs(importer, "", dir) |
| return err |
| } |
| |
| // rewriteToPkgs rewrites the contents of a single directory, |
| // and returns the types.Packages that it computes. |
| func rewriteToPkgs(importer *Importer, importPath, dir string) ([]*types.Package, error) { |
| go2files, gofiles, err := go2Files(dir) |
| if err != nil { |
| return nil, err |
| } |
| |
| if err := checkAndRemoveGofiles(dir, gofiles); err != nil { |
| return nil, err |
| } |
| |
| return rewriteFilesInPath(importer, importPath, dir, go2files) |
| } |
| |
// namedAST pairs the name of a parsed file with its syntax tree, so that a
// package's files can be kept in a sorted slice rather than a map.
type namedAST struct {
	name string    // file name the AST was parsed from (as keyed by the caller)
	ast  *ast.File // syntax tree produced for that file
}
| |
| // RewriteFiles rewrites a set of .go2 files in dir. |
| func RewriteFiles(importer *Importer, dir string, go2files []string) ([]*types.Package, error) { |
| return rewriteFilesInPath(importer, "", dir, go2files) |
| } |
| |
| // rewriteFilesInPath rewrites a set of .go2 files in dir for importPath. |
| func rewriteFilesInPath(importer *Importer, importPath, dir string, go2files []string) ([]*types.Package, error) { |
| fset := token.NewFileSet() |
| pkgs, err := parseFiles(importer, dir, go2files, fset) |
| if err != nil { |
| return nil, err |
| } |
| |
| var rpkgs []*types.Package |
| var tpkgs [][]namedAST |
| for _, pkg := range pkgs { |
| pkgfiles := make([]namedAST, 0, len(pkg.Files)) |
| for n, f := range pkg.Files { |
| pkgfiles = append(pkgfiles, namedAST{n, f}) |
| } |
| sort.Slice(pkgfiles, func(i, j int) bool { |
| return pkgfiles[i].name < pkgfiles[j].name |
| }) |
| |
| asts := make([]*ast.File, 0, len(pkgfiles)) |
| for _, a := range pkgfiles { |
| asts = append(asts, a.ast) |
| } |
| |
| var merr multiErr |
| conf := types.Config{ |
| Importer: importer, |
| Error: merr.add, |
| } |
| path := importPath |
| if path == "" { |
| path = pkg.Name |
| } |
| tpkg, err := conf.Check(path, fset, asts, importer.info) |
| if err != nil { |
| return nil, fmt.Errorf("type checking failed for %s\n%v", pkg.Name, merr) |
| } |
| |
| importer.record(pkg.Name, pkgfiles, importPath, tpkg, asts) |
| |
| rpkgs = append(rpkgs, tpkg) |
| tpkgs = append(tpkgs, pkgfiles) |
| } |
| |
| for i, tpkg := range tpkgs { |
| addImportable := 0 |
| for j, pkgfile := range tpkg { |
| if !strings.HasSuffix(pkgfile.name, "_test.go2") { |
| addImportable = j |
| break |
| } |
| } |
| |
| for j, pkgfile := range tpkg { |
| if err := rewriteFile(dir, fset, importer, importPath, rpkgs[i], pkgfile.name, pkgfile.ast, j == addImportable); err != nil { |
| return nil, err |
| } |
| } |
| } |
| |
| return rpkgs, nil |
| } |
| |
| // RewriteBuffer rewrites the contents of a single file, in a buffer. |
| // It returns a modified buffer. The filename parameter is only used |
| // for error messages. |
| func RewriteBuffer(importer *Importer, filename string, file []byte) ([]byte, error) { |
| fset := token.NewFileSet() |
| pf, err := parser.ParseFile(fset, filename, file, parseTypeParams) |
| if err != nil { |
| return nil, err |
| } |
| var merr multiErr |
| conf := types.Config{ |
| Importer: importer, |
| Error: merr.add, |
| } |
| tpkg, err := conf.Check(pf.Name.Name, fset, []*ast.File{pf}, importer.info) |
| if err != nil { |
| return nil, fmt.Errorf("type checking failed for %s\n%v", pf.Name.Name, merr) |
| } |
| importer.addIDs(pf) |
| if err := rewriteAST(fset, importer, "", tpkg, pf, true); err != nil { |
| return nil, err |
| } |
| var buf bytes.Buffer |
| fmt.Fprintln(&buf, rewritePrefix) |
| if err := config.Fprint(&buf, fset, pf); err != nil { |
| return nil, err |
| } |
| return buf.Bytes(), nil |
| } |
| |
// go2Files lists the entries of dir. It returns, each in sorted order, the
// names (not full paths) of the entries with a .go2 extension and of those
// with a .go extension; all other entries are ignored.
//
// Entries are classified by name alone. This function does not look inside
// .go files; rejecting .go files that go2go did not generate is the job of
// checkAndRemoveGofiles.
func go2Files(dir string) (go2files []string, gofiles []string, err error) {
	d, err := os.Open(dir)
	if err != nil {
		return nil, nil, err
	}
	defer d.Close()

	names, err := d.Readdirnames(0)
	if err != nil {
		return nil, nil, fmt.Errorf("reading directory %s: %w", dir, err)
	}
	// Readdirnames returns entries in directory order, which is unspecified.
	// Sort so that parsing, error reporting and file removal are deterministic.
	sort.Strings(names)

	go2files = make([]string, 0, len(names))
	gofiles = make([]string, 0, len(names))
	for _, name := range names {
		switch filepath.Ext(name) {
		case ".go2":
			go2files = append(go2files, name)
		case ".go":
			gofiles = append(gofiles, name)
		}
	}

	return go2files, gofiles, nil
}
| |
| // checkAndRemoveGofiles looks through all the .go files. |
| // Any .go file that starts with rewritePrefix is removed. |
| // Any other .go file is reported as an error. |
| // This is intended to make it harder for go2go to break a |
| // traditional Go package. |
| func checkAndRemoveGofiles(dir string, gofiles []string) error { |
| for _, f := range gofiles { |
| if err := checkGoFile(dir, f); err != nil { |
| return err |
| } |
| if err := os.Remove(filepath.Join(dir, f)); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| // checkGofile reports an error if the file does not start with rewritePrefix. |
| func checkGoFile(dir, f string) error { |
| o, err := os.Open(filepath.Join(dir, f)) |
| if err != nil { |
| return err |
| } |
| defer o.Close() |
| var buf [100]byte |
| n, err := o.Read(buf[:]) |
| if n > 0 && !strings.HasPrefix(string(buf[:n]), rewritePrefix) { |
| return fmt.Errorf("Go file %s was not created by go2go", f) |
| } |
| if err != nil && err != io.EOF { |
| return err |
| } |
| return nil |
| } |
| |
| // parseFiles parses a list of .go2 files. |
| func parseFiles(importer *Importer, dir string, go2files []string, fset *token.FileSet) ([]*ast.Package, error) { |
| pkgs := make(map[string]*ast.Package) |
| for _, go2f := range go2files { |
| mode := parseTypeParams |
| |
| filename := filepath.Join(dir, go2f) |
| pf, err := parser.ParseFile(fset, filename, nil, mode) |
| if err != nil { |
| return nil, err |
| } |
| |
| name := pf.Name.Name |
| pkg, ok := pkgs[name] |
| if !ok { |
| pkg = &ast.Package{ |
| Name: name, |
| Files: make(map[string]*ast.File), |
| } |
| pkgs[name] = pkg |
| } |
| pkg.Files[filename] = pf |
| } |
| |
| rpkgs := make([]*ast.Package, 0, len(pkgs)) |
| for _, pkg := range pkgs { |
| rpkgs = append(rpkgs, pkg) |
| } |
| sort.Slice(rpkgs, func(i, j int) bool { |
| return rpkgs[i].Name < rpkgs[j].Name |
| }) |
| |
| return rpkgs, nil |
| } |
| |
// multiErr is an error value that accumulates every error reported by the
// type checker, so that all of them can be shown rather than only the first.
type multiErr []error

// add appends err to m. Its method value is suitable for use as
// types.Config.Error.
func (m *multiErr) add(err error) {
	*m = append(*m, err)
}

// Error implements the error interface. Each accumulated error is rendered
// with %v on its own newline-terminated line. An empty multiErr, which should
// never be reported, yields a fixed internal-error message.
func (m multiErr) Error() string {
	if len(m) == 0 {
		return "internal error: empty multiErr"
	}
	var b strings.Builder
	for i := range m {
		b.WriteString(fmt.Sprint(m[i]))
		b.WriteByte('\n')
	}
	return b.String()
}