// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// The mkmerge command parses generated source files and merges common
// consts, funcs, and types into a common source file, per GOOS.
//
// Usage:
//
//	$ mkmerge -out MERGED FILE [FILE ...]
//
// Example:
//
//	# Remove all common consts, funcs, and types from zerrors_linux_*.go
//	# and write the common code into zerrors_linux.go
//	$ mkmerge -out zerrors_linux.go zerrors_linux_*.go
//
// mkmerge performs the merge in the following steps:
//  1. Construct the set of common code that is identical in all
//     architecture-specific files.
//  2. Write this common code to the merged file.
//  3. Remove the common code from all architecture-specific files.
package main

import (
	"bufio"
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"log"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
)

// validGOOS is a regexp alternation listing every GOOS that is accepted in
// the name of a merged file.
const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"

// goosFileRE matches a file name that ends in "_GOOS.go" for one of the
// values in validGOOS and captures the GOOS part. It is compiled once at
// package initialization rather than on every getValidGOOS call.
var goosFileRE = regexp.MustCompile(`_(` + validGOOS + `)\.go$`)

// getValidGOOS returns GOOS, true if filename ends with a valid "_GOOS.go".
// Otherwise it returns "", false.
func getValidGOOS(filename string) (string, bool) {
	matches := goosFileRE.FindStringSubmatch(filename)
	if len(matches) != 2 {
		return "", false
	}
	return matches[1], true
}

// codeElem is a comparable stand-in for a declaration or spec: two codeElems
// are equal when both their token and their formatted source are equal, which
// allows them to be used as map keys.
type codeElem struct {
	tok token.Token // kind of the element, e.g. token.CONST, token.TYPE, or token.FUNC
	src string      // the element rendered as source code by go/format
}

// newCodeElem formats node as source code and pairs the result with tok.
// If node cannot be formatted, the zero codeElem and the formatting error
// are returned.
func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
	var rendered strings.Builder
	if err := format.Node(&rendered, token.NewFileSet(), node); err != nil {
		return codeElem{}, err
	}
	return codeElem{tok: tok, src: rendered.String()}, nil
}

// codeSet is an unordered collection of distinct codeElems.
type codeSet struct {
	set map[codeElem]bool // members map to true; absent elements read as false
}

// newCodeSet returns an empty codeSet that is ready for use.
func newCodeSet() *codeSet {
	return &codeSet{set: map[codeElem]bool{}}
}

// add inserts elem into c.
func (c *codeSet) add(elem codeElem) {
	c.set[elem] = true
}

// has reports whether elem is a member of c.
func (c *codeSet) has(elem codeElem) bool {
	return c.set[elem]
}

// isEmpty reports whether c has no members.
func (c *codeSet) isEmpty() bool {
	return len(c.set) == 0
}

// intersection returns a newly allocated set holding the elements of c that
// are also members of a. Neither c nor a is modified.
func (c *codeSet) intersection(a *codeSet) *codeSet {
	common := newCodeSet()
	for member := range c.set {
		if !a.has(member) {
			continue
		}
		common.add(member)
	}
	return common
}

// keepCommon is a filterFn used to produce the merged file: it reports
// whether elem belongs in the file holding the common declarations.
// An elem with any other token than the ones handled below terminates the
// program through log.Fatalf.
func (c *codeSet) keepCommon(elem codeElem) bool {
	switch elem.tok {
	case token.IMPORT:
		// Imports are always kept here; unused ones are removed by filterImports.
		return true
	case token.VAR:
		// Vars never go into the merged file.
		return false
	case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
		// Consts, types, funcs, and file-level comments are kept only when
		// they are part of the common set.
		return c.has(elem)
	default:
		log.Fatalf("keepCommon: invalid elem %v", elem)
		return true
	}
}

// keepArchSpecific is a filterFn used to rewrite the GOARCH-specific files:
// it reports whether elem should stay in the architecture-specific file.
// Consts, types, and funcs are kept only when they are not in the common set;
// every other kind of element is always kept.
func (c *codeSet) keepArchSpecific(elem codeElem) bool {
	if tok := elem.tok; tok == token.CONST || tok == token.TYPE || tok == token.FUNC {
		return !c.has(elem)
	}
	return true
}

// srcFile represents a source file: its name together with its contents.
type srcFile struct {
	name string // name of the file; merge uses it to write the filtered source back
	src  []byte // contents of the file
}

// filterFn is the predicate type accepted by filter. It is called with a
// codeElem and reports whether that element should be kept in the output.
type filterFn func(codeElem) bool

// filter parses the Go source code in src and removes every top-level
// declaration, spec, and file-level comment group for which keep returns
// false. The remaining code is formatted, empty lines within groups of
// adjacent specs are removed, and unused imports are dropped before the
// result is returned.
// For src parameter, please see docs for parser.ParseFile.
func filter(src interface{}, keep filterFn) ([]byte, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	// The comment map is built from the unfiltered AST; it is used below to
	// recompute file.Comments once declarations have been removed.
	comments := ast.NewCommentMap(fset, file, file.Comments)

	// groups records, for each kept spec, the ID of the run of adjacent
	// lines it was found in; groupID is the ID of the current run.
	groups := make(specGroups)
	groupID := 0
	lineOf := func(pos token.Pos) int { return fset.Position(pos).Line }

	// Filter the declarations in place, reusing the backing array.
	allDecls := file.Decls
	file.Decls = allDecls[:0]
	for _, d := range allDecls {
		switch decl := d.(type) {
		case *ast.FuncDecl:
			// A func is kept or dropped as a whole.
			elem, err := newCodeElem(token.FUNC, decl)
			if err != nil {
				return nil, err
			}
			if keep(elem) {
				file.Decls = append(file.Decls, decl)
			}

		case *ast.GenDecl:
			// Imports, consts, types, and vars are filtered spec by spec,
			// again in place.
			allSpecs := decl.Specs
			decl.Specs = allSpecs[:0]
			for i, spec := range allSpecs {
				elem, err := newCodeElem(decl.Tok, spec)
				if err != nil {
					return nil, err
				}

				// An empty line between the previous spec and this one
				// starts a new group. allSpecs[i-1] is still the original
				// previous spec: the in-place filtering only ever writes to
				// indices that hold specs which were kept.
				if i > 0 {
					prevEnd := lineOf(allSpecs[i-1].End())
					if prevEnd < lineOf(spec.Pos())-1 {
						groupID++
					}
				}

				if !keep(elem) {
					continue
				}
				decl.Specs = append(decl.Specs, spec)
				// The error returned by add is not checked: when the spec
				// source cannot be formatted on its own, add stores nothing
				// and the spec is simply not tracked in any group.
				groups.add(elem.src, groupID)
			}
			// The declaration itself survives only if a spec is left in it.
			if len(decl.Specs) != 0 {
				file.Decls = append(file.Decls, decl)
			}
		}
	}

	// Filter the comment groups that are associated with the file itself.
	if fileComments := comments[file]; fileComments != nil {
		kept := fileComments[:0]
		for _, group := range fileComments {
			if keep(codeElem{token.COMMENT, group.Text()}) {
				kept = append(kept, group)
			}
		}
		comments[file] = kept
	}
	file.Comments = comments.Filter(file).Comments()

	// Render the filtered AST back to source code.
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return nil, err
	}

	cleaned, err := groups.filterEmptyLines(&buf)
	if err != nil {
		return nil, err
	}
	return filterImports(cleaned)
}

// getCommonSet returns the set of code elements that are present in every
// one of files, as computed by getCodeSet for each file.
// It returns an error if files is empty or if any file cannot be processed.
func getCommonSet(files []srcFile) (*codeSet, error) {
	if len(files) == 0 {
		return nil, fmt.Errorf("no files provided")
	}

	// The first file provides the baseline; every following file narrows it
	// down to the elements that the file has in common with the baseline.
	var common *codeSet
	for i, file := range files {
		set, err := getCodeSet(file.src)
		if err != nil {
			return nil, err
		}
		if i == 0 {
			common = set
			continue
		}
		common = common.intersection(set)
	}
	return common, nil
}

// getCodeSet returns the set of all top-level const specs, type specs, and
// funcs found in src, plus one element for each file-level comment group.
// src must be string, []byte, or io.Reader (see go/parser.ParseFile docs)
func getCodeSet(src interface{}) (*codeSet, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	set := newCodeSet()
	// record formats node and adds it to set under tok.
	record := func(tok token.Token, node ast.Node) error {
		elem, err := newCodeElem(tok, node)
		if err != nil {
			return err
		}
		set.add(elem)
		return nil
	}

	for _, d := range file.Decls {
		switch decl := d.(type) {
		case *ast.FuncDecl:
			if err := record(token.FUNC, decl); err != nil {
				return nil, err
			}
		case *ast.GenDecl:
			// Only const and type declarations are recorded; each of their
			// specs becomes an element of its own.
			if decl.Tok != token.CONST && decl.Tok != token.TYPE {
				continue
			}
			for _, spec := range decl.Specs {
				if err := record(decl.Tok, spec); err != nil {
					return nil, err
				}
			}
		}
	}

	// Record the comment groups that are associated with the file itself.
	for _, group := range ast.NewCommentMap(fset, file, file.Comments)[file] {
		text := group.Text()
		// ast.CommentGroup.Text doesn't include comment directives like
		// "//go:build" in the text. A group with empty text that consists of
		// a single //go:build constraint line is therefore recorded by that
		// line with the leading "//" removed. This is enough for mkmerge needs.
		isBuildLine := text == "" &&
			len(group.List) == 1 &&
			strings.HasPrefix(group.List[0].Text, "//go:build ")
		if isBuildLine {
			text = strings.TrimPrefix(group.List[0].Text, "//") + "\n"
		}
		set.add(codeElem{token.COMMENT, text})
	}

	return set, nil
}

// importName returns the identifier (PackageName) under which the package
// imported by iSpec is referenced: the explicit name given in the import
// spec when there is one, otherwise the last element of the import path.
// An error is returned if the import path cannot be unquoted.
func importName(iSpec *ast.ImportSpec) (string, error) {
	if alias := iSpec.Name; alias != nil {
		return alias.Name, nil
	}

	importPath, err := strconv.Unquote(iSpec.Path.Value)
	if err != nil {
		return "", err
	}
	return path.Base(importPath), nil
}

// specGroups tracks grouped const/type specs. It maps the formatted source
// of a spec to the ID of the group the spec belongs to.
type specGroups map[string]int

// add records that the spec with source src belongs to group groupID.
// The key that is stored is src with surrounding white space removed and
// formatted by format.Source. If src cannot be formatted, nothing is stored
// and the error is returned.
func (s specGroups) add(src string, groupID int) error {
	srcBytes, err := format.Source(bytes.TrimSpace([]byte(src)))
	if err != nil {
		return err
	}
	s[string(srcBytes)] = groupID
	return nil
}

// filterEmptyLines removes empty lines within groups of const/type specs:
// empty lines are dropped when the non-empty lines before and after them
// both belong to the same group in s. All other lines are copied unchanged.
// Returns the filtered source, or the error reported while reading src.
func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
	// noGroup marks a line that does not belong to any group in s.
	const noGroup = -1

	scanner := bufio.NewScanner(src)
	var out bytes.Buffer

	// emptyLines buffers the empty lines seen since the last non-empty line
	// until it is known whether they have to be kept.
	var emptyLines bytes.Buffer
	prevGroupID := noGroup
	for scanner.Scan() {
		line := bytes.TrimSpace(scanner.Bytes())

		if len(line) == 0 {
			fmt.Fprintf(&emptyLines, "%s\n", scanner.Bytes())
			continue
		}

		// Discard emptyLines if previous non-empty line belonged to the same
		// group as this line. Lines that cannot be formatted on their own
		// leave prevGroupID untouched.
		if formatted, err := format.Source(line); err == nil {
			groupID, ok := s[string(formatted)]
			if !ok {
				// A failed lookup yields the zero value, which must not be
				// mistaken for group 0: otherwise the empty lines between
				// an ungrouped line and a following group-0 spec would be
				// discarded as well.
				groupID = noGroup
			} else if groupID == prevGroupID {
				emptyLines.Reset()
			}
			prevGroupID = groupID
		}

		emptyLines.WriteTo(&out)
		fmt.Fprintf(&out, "%s\n", scanner.Bytes())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// filterImports removes unused imports from fileSrc, and returns a formatted src.
// An import counts as used when its name (see importName) is among the
// unresolved identifiers that the parser reports for the file.
func filterImports(fileSrc []byte) ([]byte, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	cmap := ast.NewCommentMap(fset, file, file.Comments)

	// Collect the names of all unresolved identifiers; references to
	// imported packages are among them.
	used := make(map[string]bool, len(file.Unresolved))
	for _, ident := range file.Unresolved {
		used[ident.Name] = true
	}

	// Filter the declarations in place: everything that is not an import
	// declaration is kept as is, import declarations are filtered spec by
	// spec and dropped once they are empty.
	keptDecls := file.Decls[:0]
	for _, decl := range file.Decls {
		gen, isGen := decl.(*ast.GenDecl)
		if !isGen || gen.Tok != token.IMPORT {
			keptDecls = append(keptDecls, decl)
			continue
		}

		keptSpecs := gen.Specs[:0]
		for _, spec := range gen.Specs {
			iSpec := spec.(*ast.ImportSpec)
			name, err := importName(iSpec)
			if err != nil {
				return nil, err
			}
			if used[name] {
				keptSpecs = append(keptSpecs, iSpec)
			}
		}
		gen.Specs = keptSpecs

		if len(gen.Specs) != 0 {
			keptDecls = append(keptDecls, gen)
		}
	}
	file.Decls = keptDecls

	// Apply the same filter to file.Imports.
	keptImports := file.Imports[:0]
	for _, iSpec := range file.Imports {
		name, err := importName(iSpec)
		if err != nil {
			return nil, err
		}
		if used[name] {
			keptImports = append(keptImports, iSpec)
		}
	}
	file.Imports = keptImports

	// Recompute the comment list for the filtered file.
	file.Comments = cmap.Filter(file).Comments()

	var out bytes.Buffer
	if err := format.Node(&out, fset, file); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// merge extracts duplicate code from archFiles and merges it to mergedFile.
//  1. Construct commonSet: the set of code that is identical in all archFiles.
//  2. Write the code in commonSet to mergedFile.
//  3. Remove the commonSet code from all archFiles.
//
// The name of mergedFile must end in "_GOOS.go" for a valid GOOS (see
// getValidGOOS); the GOOS is used for the build constraint of the merged
// file. If the archFiles have no code in common, no file is modified and nil
// is returned. Otherwise the first error encountered is returned.
func merge(mergedFile string, archFiles ...string) error {
	// extract and validate the GOOS part of the merged filename
	goos, ok := getValidGOOS(mergedFile)
	if !ok {
		return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
	}

	// Read architecture files
	var inSrc []srcFile
	for _, file := range archFiles {
		src, err := os.ReadFile(file)
		if err != nil {
			return fmt.Errorf("cannot read archfile %s: %w", file, err)
		}

		inSrc = append(inSrc, srcFile{file, src})
	}

	// 1. Construct the set of top-level declarations common for all files
	commonSet, err := getCommonSet(inSrc)
	if err != nil {
		return err
	}
	if commonSet.isEmpty() {
		// No common code => do not modify any files
		return nil
	}

	// 2. Write the merged file. The first architecture file is used as the
	// source from which everything but the common code is filtered out.
	mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
	if err != nil {
		return err
	}

	// The complete file is assembled in memory and written with a single
	// os.WriteFile call, so that no file handle is left open when writing
	// fails. Mode 0666 is the mode os.Create uses for new files.
	var out bytes.Buffer
	fmt.Fprintln(&out, "// Code generated by mkmerge; DO NOT EDIT.")
	fmt.Fprintln(&out)
	fmt.Fprintf(&out, "//go:build %s\n", goos)
	fmt.Fprintln(&out)
	out.Write(mergedSrc)

	if err := os.WriteFile(mergedFile, out.Bytes(), 0666); err != nil {
		return err
	}

	// 3. Remove duplicate declarations from the architecture files
	for _, inFile := range inSrc {
		src, err := filter(inFile.src, commonSet.keepArchSpecific)
		if err != nil {
			return err
		}
		err = os.WriteFile(inFile.name, src, 0644)
		if err != nil {
			return err
		}
	}
	return nil
}

// main parses the command line, expands wildcards in the file arguments, and
// merges the resulting files into the file given with -out. With fewer than
// two input files there is nothing to merge and main returns without doing
// anything. On error a message is printed to standard error and the program
// exits with status 1.
func main() {
	mergedFile := flag.String("out", "", "Write merged code to `FILE`")
	flag.Parse()

	// Every argument is treated as a glob pattern.
	var filenames []string
	for _, pattern := range flag.Args() {
		matches, err := filepath.Glob(pattern)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", pattern, err)
			os.Exit(1)
		}
		filenames = append(filenames, matches...)
	}

	// Merging needs at least two files.
	if len(filenames) >= 2 {
		if err := merge(*mergedFile, filenames...); err != nil {
			fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
			os.Exit(1)
		}
	}
}
