// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package go2go rewrites polymorphic code into non-polymorphic code.
package go2go

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"io"
	"os"
	"path/filepath"
	"sort"
	"strings"
)

const (
	// parseTypeParams is the go/parser mode bit that enables parsing of
	// type parameters. Its value must match the corresponding flag in
	// go/parser/interface.go; keep the two in sync.
	parseTypeParams = parser.Mode(1 << 30)

	// rewritePrefix is written at the very start of every .go file that
	// go2go generates. It is also how go2go recognizes its own output.
	rewritePrefix = "// Code generated by go2go; DO NOT EDIT.\n\n"
)

// Rewrite rewrites the package in directory dir.
//
// Every file in dir with a .go2 extension is parsed, with files grouped
// by their package clause, and type checked. For each file, a .go file is
// written in which any polymorphic code has been rewritten into ordinary
// Go code. Before that, .go files previously generated by go2go are
// removed. Any other .go file in dir causes an error.
func Rewrite(importer *Importer, dir string) error {
	if _, err := rewriteToPkgs(importer, "", dir); err != nil {
		return err
	}
	return nil
}

// rewriteToPkgs rewrites the .go2 files in directory dir for importPath.
// It returns the type-checked packages it computed along the way.
//
// The .go files already in dir are checked and removed first (via
// checkAndRemoveGofiles), so stale generated output never survives a
// rewrite.
func rewriteToPkgs(importer *Importer, importPath, dir string) ([]*types.Package, error) {
	sources, existingGo, err := go2Files(dir)
	if err != nil {
		return nil, err
	}

	if err = checkAndRemoveGofiles(dir, existingGo); err != nil {
		return nil, err
	}

	return rewriteFilesInPath(importer, importPath, dir, sources)
}

// namedAST holds a file name and the AST parsed from that file.
// It pairs the two so that a package's files can be sorted by name
// while keeping each name attached to its syntax tree.
type namedAST struct {
	name string    // file name as used as the key in ast.Package.Files
	ast  *ast.File // syntax tree parsed from that file
}

// RewriteFiles rewrites the given .go2 files, whose names are relative to
// dir, and returns the type-checked packages.
//
// No import path is supplied, so each package is type checked under its
// own package name. Unlike Rewrite, existing .go files in dir are neither
// checked nor removed.
func RewriteFiles(importer *Importer, dir string, go2files []string) ([]*types.Package, error) {
	const noImportPath = ""
	return rewriteFilesInPath(importer, noImportPath, dir, go2files)
}

// rewriteFilesInPath rewrites the given .go2 files in dir for importPath.
//
// The files are parsed and grouped into packages by parseFiles. Each
// package's files are sorted by name and type checked. The package path
// is importPath when it is non-empty, and otherwise the package name.
// Each checked package is recorded with the importer. Only after every
// package has type checked successfully is each file rewritten with
// rewriteFile.
//
// The returned packages are in the same order as the parsed packages.
// Any type-checking error aborts the operation, and the returned error
// lists all collected type errors.
func rewriteFilesInPath(importer *Importer, importPath, dir string, go2files []string) ([]*types.Package, error) {
	fset := token.NewFileSet()
	parsed, err := parseFiles(importer, dir, go2files, fset)
	if err != nil {
		return nil, err
	}

	var checked []*types.Package
	var fileSets [][]namedAST
	for _, pkg := range parsed {
		files := make([]namedAST, 0, len(pkg.Files))
		for name, file := range pkg.Files {
			files = append(files, namedAST{name, file})
		}
		sort.Slice(files, func(a, b int) bool {
			return files[a].name < files[b].name
		})

		asts := make([]*ast.File, len(files))
		for k := range files {
			asts[k] = files[k].ast
		}

		checkPath := importPath
		if checkPath == "" {
			checkPath = pkg.Name
		}

		// Route every type error into errs so all of them can be reported,
		// not just the first one returned by Check.
		var errs multiErr
		conf := types.Config{Importer: importer, Error: errs.add}
		tpkg, err := conf.Check(checkPath, fset, asts, importer.info)
		if err != nil {
			return nil, fmt.Errorf("type checking failed for %s\n%v", pkg.Name, errs)
		}

		importer.record(pkg.Name, files, importPath, tpkg, asts)

		checked = append(checked, tpkg)
		fileSets = append(fileSets, files)
	}

	for i, files := range fileSets {
		// rewriteFile's last argument is true for exactly one file per
		// package. That file is the first (by name) that is not a
		// _test.go2 file, or the first file overall if all are tests.
		// NOTE(review): what the flag triggers is defined by rewriteFile.
		importable := 0
		for j := range files {
			if !strings.HasSuffix(files[j].name, "_test.go2") {
				importable = j
				break
			}
		}

		for j, f := range files {
			if err := rewriteFile(dir, fset, importer, importPath, checked[i], f.name, f.ast, j == importable); err != nil {
				return nil, err
			}
		}
	}

	return checked, nil
}

// RewriteBuffer rewrites a single file whose contents are held in memory
// and returns the rewritten source.
//
// filename is only used for error messages; the file is not read from
// disk. The source is type checked as a package named by its own package
// clause. The output starts with rewritePrefix.
func RewriteBuffer(importer *Importer, filename string, file []byte) ([]byte, error) {
	fset := token.NewFileSet()
	pf, err := parser.ParseFile(fset, filename, file, parseTypeParams)
	if err != nil {
		return nil, err
	}

	pkgName := pf.Name.Name
	var errs multiErr
	conf := types.Config{Importer: importer, Error: errs.add}
	tpkg, err := conf.Check(pkgName, fset, []*ast.File{pf}, importer.info)
	if err != nil {
		return nil, fmt.Errorf("type checking failed for %s\n%v", pkgName, errs)
	}

	importer.addIDs(pf)
	if err = rewriteAST(fset, importer, "", tpkg, pf, true); err != nil {
		return nil, err
	}

	var out bytes.Buffer
	// Fprintln appends one more newline after rewritePrefix, which
	// already ends in "\n\n".
	fmt.Fprintln(&out, rewritePrefix)
	if err = config.Fprint(&out, fset, pf); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// go2Files reads directory dir and classifies its entries by extension.
// It returns the names (relative to dir) of entries ending in .go2 and,
// separately, of entries ending in .go. Entries with any other extension
// are ignored.
//
// Both lists are sorted, so callers process files, and report errors, in
// a deterministic order.
//
// Classification is by name alone, so a subdirectory named with a .go or
// .go2 suffix is included too. go2Files does not inspect file contents;
// verifying that the .go files were generated by go2go is the job of
// checkAndRemoveGofiles.
func go2Files(dir string) (go2files []string, gofiles []string, err error) {
	d, err := os.Open(dir)
	if err != nil {
		return nil, nil, err
	}
	defer d.Close()

	names, err := d.Readdirnames(0)
	if err != nil {
		return nil, nil, fmt.Errorf("reading directory %s: %w", dir, err)
	}
	// Readdirnames returns entries in directory order, which varies by
	// file system; sort for reproducible behavior.
	sort.Strings(names)

	go2files = make([]string, 0, len(names))
	gofiles = make([]string, 0, len(names))
	for _, name := range names {
		switch filepath.Ext(name) {
		case ".go2":
			go2files = append(go2files, name)
		case ".go":
			gofiles = append(gofiles, name)
		}
	}

	return go2files, gofiles, nil
}

// checkAndRemoveGofiles checks the .go files named in gofiles (relative to
// dir).
//
// If every one of them starts with rewritePrefix, they are all removed.
// If any of them does not, the error from checkGoFile is returned and no
// file is removed.
//
// This is intended to make it harder for go2go to break a traditional Go
// package.
func checkAndRemoveGofiles(dir string, gofiles []string) error {
	// Validate everything before deleting anything. Otherwise a
	// hand-written .go file late in the list would be reported only after
	// some generated files had already been removed.
	for _, f := range gofiles {
		if err := checkGoFile(dir, f); err != nil {
			return err
		}
	}
	for _, f := range gofiles {
		if err := os.Remove(filepath.Join(dir, f)); err != nil {
			return err
		}
	}
	return nil
}

// checkGofile reports an error if the file does not start with rewritePrefix.
func checkGoFile(dir, f string) error {
	o, err := os.Open(filepath.Join(dir, f))
	if err != nil {
		return err
	}
	defer o.Close()
	var buf [100]byte
	n, err := o.Read(buf[:])
	if n > 0 && !strings.HasPrefix(string(buf[:n]), rewritePrefix) {
		return fmt.Errorf("Go file %s was not created by go2go", f)
	}
	if err != nil && err != io.EOF {
		return err
	}
	return nil
}

// parseFiles parses the .go2 files named in go2files into fset. The names
// are relative to dir.
//
// Files are grouped into packages according to their package clause.
// Within a package, each file is keyed by its dir-joined path. The
// packages are returned sorted by package name. Parsing stops at the
// first file that fails to parse. The importer parameter is not used
// here.
func parseFiles(importer *Importer, dir string, go2files []string, fset *token.FileSet) ([]*ast.Package, error) {
	byName := make(map[string]*ast.Package)
	for _, base := range go2files {
		path := filepath.Join(dir, base)
		file, err := parser.ParseFile(fset, path, nil, parseTypeParams)
		if err != nil {
			return nil, err
		}

		pkgName := file.Name.Name
		if byName[pkgName] == nil {
			byName[pkgName] = &ast.Package{
				Name:  pkgName,
				Files: make(map[string]*ast.File),
			}
		}
		byName[pkgName].Files[path] = file
	}

	names := make([]string, 0, len(byName))
	for name := range byName {
		names = append(names, name)
	}
	sort.Strings(names)

	result := make([]*ast.Package, 0, len(names))
	for _, name := range names {
		result = append(result, byName[name])
	}
	return result, nil
}

// multiErr collects the errors reported during type checking, so that
// all of them can be returned together as a single error value.
type multiErr []error

// add appends err to m. The method value m.add has the signature
// expected by types.Config.Error.
func (m *multiErr) add(err error) {
	collected := *m
	*m = append(collected, err)
}

// Error implements the error interface.
//
// Each collected error is written on its own line, and every line,
// including the last, ends with a newline. An empty multiErr yields a
// placeholder message marked as an internal error.
func (m multiErr) Error() string {
	if len(m) == 0 {
		return "internal error: empty multiErr"
	}
	var b strings.Builder
	for i := range m {
		fmt.Fprintf(&b, "%v\n", m[i])
	}
	return b.String()
}
