blob: 960f1151993d576eb45a4477091902a60dd7b24d [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package packagestest
import (
"fmt"
"go/token"
"reflect"
"regexp"
"strings"
"golang.org/x/tools/go/expect"
)
// Names that carry special meaning inside expectation notes.
const (
	// eofIdentifier, used as a position argument, refers to the end of the
	// file that contains the note.
	eofIdentifier = "EOF"
	// markMethod is the note name used to declare a marker.
	markMethod = "mark"
)
// Expect invokes the supplied methods for all expectation notes found in
// the exported source files.
//
// All exported go source files are parsed to collect the expectation
// notes.
// See the documentation for expect.Parse for how the notes are collected
// and parsed.
//
// The methods are supplied as a map of name to function, and those functions
// will be matched against the expectations by name.
// Notes with no matching function will be skipped, and functions with no
// matching notes will not be invoked.
// If there are no registered markers yet, a special pass will be run first
// which adds any markers declared with @mark(Name, pattern) or @name. These
// call the Mark method to add the marker to the global set.
// A note with no argument list, @name, is treated as @mark(name, "name"),
// i.e. it marks the text name found on the note's own line.
// You can register the "mark" method to override these in your own call to
// Expect. The bound Mark function is usable directly in your method map, so
//
//	exported.Expect(map[string]interface{}{"mark": exported.Mark})
//
// replicates the built in behavior.
//
// Every value in methods must be a non-nil function. Expect returns an error
// if one is not, or if a parameter has a type that cannot be converted.
// Values returned by the methods are currently ignored.
//
// Method invocation
//
// When invoking a method the expressions in the parameter list need to be
// converted to values to be passed to the method.
// There are a very limited set of types the arguments are allowed to be.
//
//	*expect.Note : passed the Note instance being evaluated.
//	string : can be supplied either a string literal or an identifier.
//	expect.Identifier : can only be supplied an identifier.
//	int64 : can only be supplied an integer literal.
//	bool : can only be supplied a boolean.
//	token.Pos, token.Position, Range : a position calculated as described below.
//	slices of the above : filled from all the remaining arguments.
//
// A variadic final parameter is filled from the remaining arguments in the
// same way as a slice parameter.
//
// Position calculation
//
// There is some extra handling when a parameter is being coerced into a
// token.Pos, token.Position or Range type argument.
//
// If the parameter is an identifier, it will be treated as the name of a
// marker to look up (as if markers were global variables). The special
// identifier EOF refers to the end of the file containing the note.
//
// If it is a string or regular expression, then it will be passed to
// expect.MatchBefore to look up a match in the line at which it was declared.
//
// It is safe to call this repeatedly with different method sets, but it is
// not safe to call it concurrently.
func (e *Exported) Expect(methods map[string]interface{}) error {
	if err := e.getNotes(); err != nil {
		return err
	}
	if err := e.getMarkers(); err != nil {
		return err
	}
	ms := make(map[string]method, len(methods))
	for name, f := range methods {
		fv := reflect.ValueOf(f)
		// Reject non-functions up front: Type().NumIn() and Call would
		// otherwise panic instead of reporting a usable error.
		if fv.Kind() != reflect.Func {
			return fmt.Errorf("invalid method %v: %T is not a function", name, f)
		}
		if fv.IsNil() {
			return fmt.Errorf("invalid method %v: nil function", name)
		}
		ft := fv.Type()
		mi := method{f: fv, converters: make([]converter, ft.NumIn())}
		for i := range mi.converters {
			c, err := e.buildConverter(ft.In(i))
			if err != nil {
				return fmt.Errorf("invalid method %v: %v", name, err)
			}
			mi.converters[i] = c
		}
		ms[name] = mi
	}
	for _, n := range e.notes {
		if n.Args == nil {
			// simple identifier form, convert to a call to mark
			n = &expect.Note{
				Pos:  n.Pos,
				Name: markMethod,
				Args: []interface{}{n.Name, n.Name},
			}
		}
		mi, ok := ms[n.Name]
		if !ok {
			continue
		}
		params := make([]reflect.Value, len(mi.converters))
		args := n.Args
		for i, convert := range mi.converters {
			var err error
			params[i], args, err = convert(n, args)
			if err != nil {
				return fmt.Errorf("%v: %v", e.fset.Position(n.Pos), err)
			}
		}
		if len(args) > 0 {
			return fmt.Errorf("%v: unwanted args got %+v extra", e.fset.Position(n.Pos), args)
		}
		//TODO: catch the error returned from the method
		if mi.f.Type().IsVariadic() {
			// The converter for the final (variadic) parameter already built
			// the whole slice; Call would try to wrap it again and panic.
			mi.f.CallSlice(params)
		} else {
			mi.f.Call(params)
		}
	}
	return nil
}
// Range is a span of source positions, as produced by a marker or a matched
// pattern. End may be token.NoPos when only the starting position is known,
// as is the case for the EOF identifier.
type Range struct {
	Start, End token.Pos
}
// Mark records r as the range of the marker called name. It replaces any
// marker already registered under that name, and allocates the marker set
// on first use.
func (e *Exported) Mark(name string, r Range) {
	set := e.markers
	if set == nil {
		set = make(map[string]Range)
		e.markers = set
	}
	set[name] = r
}
// getNotes parses every exported .go file with expect.Parse and stores the
// collected notes in e.notes. Once notes are stored, later calls do nothing.
func (e *Exported) getNotes() error {
	if e.notes != nil {
		return nil
	}
	// Start from a non-nil slice so that a tree without any notes is still
	// recorded as loaded (e.notes != nil) and is not parsed again.
	collected := make([]*expect.Note, 0)
	for _, files := range e.written {
		for _, path := range files {
			if !strings.HasSuffix(path, ".go") {
				continue
			}
			parsed, err := expect.Parse(e.fset, path, nil)
			if err != nil {
				return fmt.Errorf("Failed to extract expectations: %v", err)
			}
			collected = append(collected, parsed...)
		}
	}
	e.notes = collected
	return nil
}
// getMarkers runs the built-in "mark" pass the first time it is needed. The
// pass fills e.markers from @mark(...) and bare @name notes. It does nothing
// if the marker set already exists, including when Mark was called directly
// beforehand.
func (e *Exported) getMarkers() error {
	if e.markers != nil {
		return nil
	}
	// Expect calls getMarkers itself. Creating the map first makes that
	// nested call return immediately instead of recursing.
	e.markers = make(map[string]Range)
	builtin := map[string]interface{}{markMethod: e.Mark}
	return e.Expect(builtin)
}
var (
noteType = reflect.TypeOf((*expect.Note)(nil))
identifierType = reflect.TypeOf(expect.Identifier(""))
posType = reflect.TypeOf(token.Pos(0))
positionType = reflect.TypeOf(token.Position{})
rangeType = reflect.TypeOf(Range{})
)
// A converter produces the reflect.Value for one method parameter from the
// arguments of a note. It receives the note being evaluated and the arguments
// not yet consumed. It returns the converted value together with the
// arguments it left over. A converter may therefore consume no arguments (as
// the *expect.Note converter does) or, for compound types, several.
type converter func(note *expect.Note, remaining []interface{}) (value reflect.Value, rest []interface{}, err error)
// method caches the reflection work for one function passed to Expect. The
// work is done once per function rather than once per note.
type method struct {
	converters []converter   // converters[i] builds the i'th argument
	f          reflect.Value // the function that is called for each note
}
// buildConverter returns the converter that produces a method argument of
// type pt from a note's arguments. It is called once per method parameter,
// when only the target type is known, so each converter it returns accepts
// every argument form supported for that type.
//
// Supported parameter types:
//   - *expect.Note: the note being evaluated (consumes no arguments);
//   - token.Pos, token.Position, Range: resolved via rangeConverter;
//   - expect.Identifier: an identifier argument;
//   - any string-kinded type: a string literal or an identifier;
//   - any signed-integer-kinded type: an integer literal that fits the type;
//   - any bool-kinded type: a boolean argument;
//   - a slice of any supported type: filled from all remaining arguments.
//
// Any other type is rejected with an error, including a slice whose element
// type is unsupported. Converters report a "missing argument" error instead of
// panicking when the note has too few arguments.
func (e *Exported) buildConverter(pt reflect.Type) (converter, error) {
	switch kind := pt.Kind(); {
	case pt == noteType:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			return reflect.ValueOf(n), args, nil
		}, nil
	case pt == posType:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			r, remains, err := e.rangeConverter(n, args)
			if err != nil {
				return reflect.Value{}, nil, err
			}
			return reflect.ValueOf(r.Start), remains, nil
		}, nil
	case pt == positionType:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			r, remains, err := e.rangeConverter(n, args)
			if err != nil {
				return reflect.Value{}, nil, err
			}
			return reflect.ValueOf(e.fset.Position(r.Start)), remains, nil
		}, nil
	case pt == rangeType:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			r, remains, err := e.rangeConverter(n, args)
			if err != nil {
				return reflect.Value{}, nil, err
			}
			return reflect.ValueOf(r), remains, nil
		}, nil
	case pt == identifierType:
		// Must precede the generic String case: expect.Identifier is a
		// string kind but only accepts identifier arguments.
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			if len(args) == 0 {
				return reflect.Value{}, nil, fmt.Errorf("missing argument")
			}
			arg, rest := args[0], args[1:]
			id, ok := arg.(expect.Identifier)
			if !ok {
				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to identifier", arg)
			}
			return reflect.ValueOf(id), rest, nil
		}, nil
	case kind == reflect.String:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			if len(args) == 0 {
				return reflect.Value{}, nil, fmt.Errorf("missing argument")
			}
			arg, rest := args[0], args[1:]
			var s string
			switch arg := arg.(type) {
			case expect.Identifier:
				s = string(arg)
			case string:
				s = arg
			default:
				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
			}
			// Convert so that named string types receive a value of their
			// own type; reflect.Call panics on a plain string otherwise.
			return reflect.ValueOf(s).Convert(pt), rest, nil
		}, nil
	case kind == reflect.Int, kind == reflect.Int8, kind == reflect.Int16,
		kind == reflect.Int32, kind == reflect.Int64:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			if len(args) == 0 {
				return reflect.Value{}, nil, fmt.Errorf("missing argument")
			}
			arg, rest := args[0], args[1:]
			i, ok := arg.(int64)
			if !ok {
				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to int", arg)
			}
			// Narrower integer kinds must not silently truncate the literal.
			if reflect.Zero(pt).OverflowInt(i) {
				return reflect.Value{}, nil, fmt.Errorf("%v overflows %v", i, pt)
			}
			return reflect.ValueOf(i).Convert(pt), rest, nil
		}, nil
	case kind == reflect.Bool:
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			if len(args) == 0 {
				return reflect.Value{}, nil, fmt.Errorf("missing argument")
			}
			arg, rest := args[0], args[1:]
			b, ok := arg.(bool)
			if !ok {
				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to bool", arg)
			}
			return reflect.ValueOf(b).Convert(pt), rest, nil
		}, nil
	case kind == reflect.Slice:
		// Build the element converter once, here, so that an unsupported
		// element type is reported when the method is registered rather than
		// only when a matching note is evaluated.
		convertElem, err := e.buildConverter(pt.Elem())
		if err != nil {
			return nil, err
		}
		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
			result := reflect.MakeSlice(pt, 0, len(args))
			for len(args) > 0 {
				value, rest, err := convertElem(n, args)
				if err != nil {
					return reflect.Value{}, nil, err
				}
				if len(rest) >= len(args) {
					// The element converter made no progress; stop rather
					// than loop forever or fabricate duplicate elements.
					return reflect.Value{}, nil, fmt.Errorf("cannot fill %v: element converter consumed no arguments", pt)
				}
				result = reflect.Append(result, value)
				args = rest
			}
			return result, args, nil
		}, nil
	default:
		return nil, fmt.Errorf("param has invalid type %v", pt)
	}
}
// rangeConverter consumes the first of args and resolves it to a Range
// relative to the note n:
//   - the identifier EOF yields the end of the file containing n, with End
//     set to token.NoPos;
//   - any other identifier is looked up by name in the marker set;
//   - a string or *regexp.Regexp is passed to expect.MatchBefore to find a
//     match before n on its line.
//
// It returns the resolved range and the arguments it did not consume.
func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (Range, []interface{}, error) {
	if len(args) == 0 {
		return Range{}, nil, fmt.Errorf("missing argument")
	}
	first, rest := args[0], args[1:]
	switch v := first.(type) {
	case expect.Identifier:
		if v == eofIdentifier {
			// End of the file that holds the note itself.
			file := e.fset.File(n.Pos)
			return Range{Start: file.Pos(file.Size()), End: token.NoPos}, rest, nil
		}
		marked, found := e.markers[string(v)]
		if !found {
			return Range{}, nil, fmt.Errorf("cannot find marker %v", v)
		}
		return marked, rest, nil
	case string, *regexp.Regexp:
		start, end, err := expect.MatchBefore(e.fset, e.fileContents, n.Pos, v)
		if err != nil {
			return Range{}, nil, err
		}
		if !start.IsValid() {
			return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.fset.Position(n.Pos), v)
		}
		return Range{Start: start, End: end}, rest, nil
	default:
		return Range{}, nil, fmt.Errorf("cannot convert %v to pos", v)
	}
}