blob: 7d8e978bb0cd167a40cf02efd67d0c25cf67a083 [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package go2go rewrites polymorphic code into non-polymorphic code.
package go2go
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"io"
"os"
"path/filepath"
"sort"
"strings"
)
const (
	// parseTypeParams is the go/parser mode bit that turns on parsing of
	// type parameters. Its value must be kept in sync with
	// go/parser/interface.go.
	parseTypeParams parser.Mode = 1 << 30

	// rewritePrefix is written at the start of every .go file that go2go
	// generates; a .go file beginning with it is treated as go2go output.
	rewritePrefix = "// Code generated by go2go; DO NOT EDIT.\n\n"
)
// Rewrite rewrites the .go2 files found in dir into ordinary .go files.
//
// Any .go files already in dir are checked and removed first; an error is
// returned if dir holds a .go file that go2go did not generate. The .go2
// files are grouped by their package clause, type checked using importer,
// and each one is rewritten. Rewrite uses an empty import path, so each
// package is type checked under its package name.
func Rewrite(importer *Importer, dir string) error {
	if _, err := rewriteToPkgs(importer, "", dir); err != nil {
		return err
	}
	return nil
}
// rewriteToPkgs rewrites the .go2 files in dir and returns the
// type-checked packages computed along the way.
//
// importPath is the path the packages are type checked under; when it is
// empty, each package's own name is used instead. Before anything is
// parsed, the .go files already present in dir are passed to
// checkAndRemoveGofiles.
func rewriteToPkgs(importer *Importer, importPath, dir string) ([]*types.Package, error) {
	sources, existing, err := go2Files(dir)
	if err == nil {
		err = checkAndRemoveGofiles(dir, existing)
	}
	if err != nil {
		return nil, err
	}
	return rewriteFilesInPath(importer, importPath, dir, sources)
}
// namedAST holds a file name and the AST parsed from that file.
// Values are built positionally (namedAST{name, file}) in
// rewriteFilesInPath, so the field order is significant.
type namedAST struct {
// name is the file name, taken from the key of ast.Package.Files.
name string
// ast is the syntax tree parsed from that file.
ast *ast.File
}
// RewriteFiles rewrites the given .go2 files, whose names are relative to
// dir, and returns the type-checked packages they make up.
//
// Unlike Rewrite, it does not scan dir or touch existing .go files. An
// empty import path is used, so each package is type checked under its
// package name.
func RewriteFiles(importer *Importer, dir string, go2files []string) ([]*types.Package, error) {
	const noImportPath = ""
	return rewriteFilesInPath(importer, noImportPath, dir, go2files)
}
// rewriteFilesInPath rewrites the .go2 files named in go2files (relative to
// dir) and returns the type-checked packages, in the order parseFiles
// yields them.
//
// It works in two phases.
//
//  1. Every package is type checked and recorded with importer. importPath
//     is used as the package path, or the package name when importPath is
//     empty.
//  2. Every file is rewritten.
//
// The first type checking failure aborts the whole operation before any
// file is rewritten. Its error message lists all errors the checker
// reported.
func rewriteFilesInPath(importer *Importer, importPath, dir string, go2files []string) ([]*types.Package, error) {
	fset := token.NewFileSet()
	pkgs, err := parseFiles(importer, dir, go2files, fset)
	if err != nil {
		return nil, err
	}

	var checked []*types.Package
	// perPkgFiles[i] holds the sorted files of the package checked[i].
	perPkgFiles := make([][]namedAST, 0, len(pkgs))

	for _, pkg := range pkgs {
		files := make([]namedAST, 0, len(pkg.Files))
		for name, file := range pkg.Files {
			files = append(files, namedAST{name, file})
		}
		// Map iteration order is random; sorting keeps type checking and
		// rewriting deterministic.
		sort.Slice(files, func(a, b int) bool { return files[a].name < files[b].name })

		asts := make([]*ast.File, len(files))
		for k := range files {
			asts[k] = files[k].ast
		}

		checkPath := pkg.Name
		if importPath != "" {
			checkPath = importPath
		}

		var typeErrs multiErr
		conf := types.Config{Importer: importer, Error: typeErrs.add}
		tpkg, err := conf.Check(checkPath, fset, asts, importer.info)
		if err != nil {
			return nil, fmt.Errorf("type checking failed for %s\n%v", pkg.Name, typeErrs)
		}
		importer.record(pkg.Name, files, importPath, tpkg, asts)
		checked = append(checked, tpkg)
		perPkgFiles = append(perPkgFiles, files)
	}

	for i, files := range perPkgFiles {
		// Exactly one file per package is rewritten with the addImportable
		// flag set. That file is the first non-test file, or the first file
		// when every file is a _test.go2 file.
		importableIdx := 0
		for j := range files {
			if !strings.HasSuffix(files[j].name, "_test.go2") {
				importableIdx = j
				break
			}
		}
		for j, f := range files {
			if err := rewriteFile(dir, fset, importer, importPath, checked[i], f.name, f.ast, j == importableIdx); err != nil {
				return nil, err
			}
		}
	}
	return checked, nil
}
// RewriteBuffer rewrites the Go source in file, which must form a complete
// package on its own, and returns the rewritten source.
//
// filename is only handed to the parser, so it shows up in positions and
// error messages; nothing is read from or written to disk. The package is
// type checked under its package name. The output begins with
// rewritePrefix followed by a newline, then the printed rewritten AST.
func RewriteBuffer(importer *Importer, filename string, file []byte) ([]byte, error) {
	fset := token.NewFileSet()
	pf, err := parser.ParseFile(fset, filename, file, parseTypeParams)
	if err != nil {
		return nil, err
	}
	pkgName := pf.Name.Name

	var typeErrs multiErr
	conf := types.Config{Importer: importer, Error: typeErrs.add}
	tpkg, err := conf.Check(pkgName, fset, []*ast.File{pf}, importer.info)
	if err != nil {
		return nil, fmt.Errorf("type checking failed for %s\n%v", pkgName, typeErrs)
	}

	importer.addIDs(pf)
	if err = rewriteAST(fset, importer, "", tpkg, pf, true); err != nil {
		return nil, err
	}

	out := new(bytes.Buffer)
	fmt.Fprintln(out, rewritePrefix)
	if err = config.Fprint(out, fset, pf); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// go2Files lists dir and returns the base names of the entries with a .go2
// extension and of those with a .go extension, in directory-listing order
// (not sorted).
//
// Subdirectories are skipped even when their names end in .go2 or .go.
// Otherwise a directory such as "x.go" would later be opened and read as a
// Go file, and the whole rewrite would fail.
//
// go2Files does not look at file contents. Rejecting .go files that go2go
// did not generate is the job of checkAndRemoveGofiles.
func go2Files(dir string) (go2files []string, gofiles []string, err error) {
	d, err := os.Open(dir)
	if err != nil {
		return nil, nil, err
	}
	defer d.Close()
	infos, err := d.Readdir(0)
	if err != nil {
		return nil, nil, fmt.Errorf("reading directory %s: %w", dir, err)
	}
	go2files = make([]string, 0, len(infos))
	gofiles = make([]string, 0, len(infos))
	for _, fi := range infos {
		if fi.IsDir() {
			continue
		}
		name := fi.Name()
		switch filepath.Ext(name) {
		case ".go2":
			go2files = append(go2files, name)
		case ".go":
			gofiles = append(gofiles, name)
		}
	}
	return go2files, gofiles, nil
}
// checkAndRemoveGofiles vets and then deletes the .go files named in
// gofiles (names relative to dir).
//
// Every file is checked with checkGoFile before any file is removed. A
// single .go file that go2go did not generate therefore aborts the
// operation without deleting the generated files that sit next to it. This
// is intended to make it harder for go2go to break a traditional Go
// package.
func checkAndRemoveGofiles(dir string, gofiles []string) error {
	for _, f := range gofiles {
		if err := checkGoFile(dir, f); err != nil {
			return err
		}
	}
	for _, f := range gofiles {
		if err := os.Remove(filepath.Join(dir, f)); err != nil {
			return err
		}
	}
	return nil
}
// checkGofile reports an error if the file does not start with rewritePrefix.
func checkGoFile(dir, f string) error {
o, err := os.Open(filepath.Join(dir, f))
if err != nil {
return err
}
defer o.Close()
var buf [100]byte
n, err := o.Read(buf[:])
if n > 0 && !strings.HasPrefix(string(buf[:n]), rewritePrefix) {
return fmt.Errorf("Go file %s was not created by go2go", f)
}
if err != nil && err != io.EOF {
return err
}
return nil
}
// parseFiles parses each of go2files (names relative to dir) with type
// parameter parsing enabled, and groups the ASTs by the package name in
// each file's package clause.
//
// Within a package, files are keyed by filepath.Join(dir, name). The
// returned packages are sorted by name; with no input files the result is
// an empty, non-nil slice. The first parse error stops processing and is
// returned. The importer argument is not used.
func parseFiles(importer *Importer, dir string, go2files []string, fset *token.FileSet) ([]*ast.Package, error) {
	byName := make(map[string]*ast.Package)
	var names []string
	for _, name := range go2files {
		path := filepath.Join(dir, name)
		file, err := parser.ParseFile(fset, path, nil, parseTypeParams)
		if err != nil {
			return nil, err
		}
		pkgName := file.Name.Name
		pkg := byName[pkgName]
		if pkg == nil {
			pkg = &ast.Package{Name: pkgName, Files: make(map[string]*ast.File)}
			byName[pkgName] = pkg
			names = append(names, pkgName)
		}
		pkg.Files[path] = file
	}
	sort.Strings(names)
	result := make([]*ast.Package, len(names))
	for i, n := range names {
		result[i] = byName[n]
	}
	return result, nil
}
// multiErr collects the errors the type checker reports through
// types.Config.Error so that they can be returned together.
type multiErr []error

// add appends err to m. Its signature matches the types.Config.Error
// callback, so a method value such as merr.add can be installed there.
func (m *multiErr) add(err error) {
	*m = append(*m, err)
}

// Error implements the error interface on the value type. It prints each
// collected error on its own line, with every line newline-terminated.
// An empty multiErr yields an "internal error" message instead.
func (m multiErr) Error() string {
	if len(m) == 0 {
		return "internal error: empty multiErr"
	}
	lines := make([]string, len(m))
	for i, e := range m {
		lines[i] = fmt.Sprint(e)
	}
	return strings.Join(lines, "\n") + "\n"
}