blob: a7a1d986e4cb19f2a55e70e8ce02baa91c8c10b4 [file] [log] [blame] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
"gopkg.in/yaml.v3"
)
// ReadOpts provides options to [Read] and related functions. The zero value is
// the default options.
type ReadOpts struct {
// FS, if non-nil, is the file system from which to resolve !import file
// names.
FS fs.FS
}
// Read reads a [Closure] in YAML format from r, using path for error messages.
//
// It maps YAML nodes into terminal Values as follows:
//
// - "_" or !top _ is the top value ([Top]).
//
// - "_|_" or !bottom _ is the bottom value. This is an error during
// unmarshaling, but can appear in marshaled values.
//
// - "$<name>" or !var <name> is a variable ([Var]). Everywhere the same name
// appears within a single unmarshal operation, it is mapped to the same
// variable. Different unmarshal operations get different variables, even if
// they have the same string name.
//
// - !regex "x" is a regular expression ([String]), as is any string that
// doesn't match "_", "_|_", or "$...". Regular expressions are implicitly
// anchored at the beginning and end. If the string doesn't contain any
// meta-characters (that is, it's a "literal" regular expression), then it's
// treated as an exact string.
//
// - !string "x", or any int, float, bool, or binary value is an exact string
// ([String]).
//
// - !regex [x, y, ...] is an intersection of regular expressions ([String]).
//
// It maps YAML nodes into non-terminal Values as follows:
//
// - Sequence nodes like [x, y, z] are tuples ([Tuple]).
//
// - !repeat [x] is a repeated tuple ([Tuple]), which is 0 or more instances of
// x. There must be exactly one element in the list.
//
// - Mapping nodes like {a: x, b: y} are defs ([Def]). Any fields not listed are
// implicitly top.
//
// - !sum [x, y, z] is a sum of its children. This can be thought of as a union
// of the values x, y, and z, or as a non-deterministic choice between x, y, and
// z. If a variable appears both inside the sum and outside of it, only the
// non-deterministic choice view really works. The unifier does not directly
// implement sums; instead, this is decoded as a fresh variable that's
// simultaneously bound to x, y, and z.
//
// - !import glob is like a !sum, but its children are read from all files
// matching the given glob pattern, which is interpreted relative to the current
// file path. Each file gets its own variable scope.
func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) {
dec := yamlDecoder{opts: opts, path: path, env: topEnv}
v, err := dec.read(r)
if err != nil {
return Closure{}, err
}
return dec.close(v), nil
}
// ReadFile reads a [Closure] in YAML format from a file.
//
// The file must consist of a single YAML document.
//
// If opts.FS is not set, this sets it to a FS rooted at path's directory.
//
// See [Read] for details.
func ReadFile(path string, opts ReadOpts) (Closure, error) {
f, err := os.Open(path)
if err != nil {
return Closure{}, err
}
defer f.Close()
if opts.FS == nil {
opts.FS = os.DirFS(filepath.Dir(path))
}
return Read(f, path, opts)
}
// UnmarshalYAML implements [yaml.Unmarshaler].
//
// Since there is no way to pass [ReadOpts] to this function, it assumes default
// options.
func (c *Closure) UnmarshalYAML(node *yaml.Node) error {
dec := yamlDecoder{path: "<yaml.Node>", env: topEnv}
v, err := dec.root(node)
if err != nil {
return err
}
*c = dec.close(v)
return nil
}
type yamlDecoder struct {
opts ReadOpts
path string
vars map[string]*ident
nSums int
env envSet
}
func (dec *yamlDecoder) read(r io.Reader) (*Value, error) {
n, err := readOneNode(r)
if err != nil {
return nil, fmt.Errorf("%s: %w", dec.path, err)
}
// Decode YAML node to a Value
v, err := dec.root(n)
if err != nil {
return nil, fmt.Errorf("%s: %w", dec.path, err)
}
return v, nil
}
// readOneNode reads a single YAML document from r and returns an error if there
// are more documents in r.
func readOneNode(r io.Reader) (*yaml.Node, error) {
yd := yaml.NewDecoder(r)
// Decode as a YAML node
var node yaml.Node
if err := yd.Decode(&node); err != nil {
return nil, err
}
np := &node
if np.Kind == yaml.DocumentNode {
np = node.Content[0]
}
// Ensure there are no more YAML docs in this file
if err := yd.Decode(nil); err == nil {
return nil, fmt.Errorf("must not contain multiple documents")
} else if err != io.EOF {
return nil, err
}
return np, nil
}
// root parses the root of a file.
func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) {
// Prepare for variable name resolution in this file. This may be a nested
// root, so restore the current values when we're done.
oldVars, oldNSums := dec.vars, dec.nSums
defer func() {
dec.vars, dec.nSums = oldVars, oldNSums
}()
dec.vars = make(map[string]*ident, 0)
dec.nSums = 0
return dec.value(node)
}
// close wraps a decoded [Value] into a [Closure].
func (dec *yamlDecoder) close(v *Value) Closure {
return Closure{v, dec.env}
}
func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) {
pos := &Pos{Path: dec.path, Line: node.Line}
// Resolve alias nodes.
if node.Kind == yaml.AliasNode {
node = node.Alias
}
mk := func(d Domain) (*Value, error) {
v := &Value{Domain: d, pos: pos}
return v, nil
}
mk2 := func(d Domain, err error) (*Value, error) {
if err != nil {
return nil, err
}
return mk(d)
}
// is tests the kind and long tag of node.
is := func(kind yaml.Kind, tag string) bool {
return node.Kind == kind && node.LongTag() == tag
}
isExact := func() bool {
if node.Kind != yaml.ScalarNode {
return false
}
// We treat any string-ish YAML node as a string.
switch node.LongTag() {
case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary":
return true
}
return false
}
// !!str nodes provide a short-hand syntax for several leaf domains that are
// also available under explicit tags. To simplify checking below, we set
// strVal to non-"" only for !!str nodes.
strVal := ""
isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str")
if isStr {
strVal = node.Value
}
switch {
case is(yaml.ScalarNode, "!var"):
strVal = "$" + node.Value
fallthrough
case strings.HasPrefix(strVal, "$"):
id, ok := dec.vars[strVal]
if !ok {
// We encode different idents with the same string name by adding a
// #N suffix. Strip that off so it doesn't accumulate. This isn't
// meant to be used in user-written input, though nothing stops that.
name, _, _ := strings.Cut(strVal, "#")
id = &ident{name: name}
dec.vars[strVal] = id
dec.env = dec.env.bind(id, topValue)
}
return mk(Var{id: id})
case strVal == "_" || is(yaml.ScalarNode, "!top"):
return mk(Top{})
case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"):
return nil, errors.New("found bottom")
case isExact():
val := node.Value
return mk(NewStringExact(val))
case isStr || is(yaml.ScalarNode, "!regex"):
// Any other string we treat as a regex. This will produce an exact
// string anyway if the regex is literal.
val := node.Value
return mk2(NewStringRegex(val))
case is(yaml.SequenceNode, "!regex"):
var vals []string
if err := node.Decode(&vals); err != nil {
return nil, err
}
return mk2(NewStringRegex(vals...))
case is(yaml.MappingNode, "tag:yaml.org,2002:map"):
var db DefBuilder
for i := 0; i < len(node.Content); i += 2 {
key := node.Content[i]
if key.Kind != yaml.ScalarNode {
return nil, fmt.Errorf("non-scalar key %q", key.Value)
}
val, err := dec.value(node.Content[i+1])
if err != nil {
return nil, err
}
db.Add(key.Value, val)
}
return mk(db.Build())
case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"):
elts := node.Content
vs := make([]*Value, 0, len(elts))
for _, elt := range elts {
v, err := dec.value(elt)
if err != nil {
return nil, err
}
vs = append(vs, v)
}
return mk(NewTuple(vs...))
case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"):
// !repeat must have one child. !repeat-unify is used internally for
// delayed unification, and is the same, it's just allowed to have more
// than one child.
if node.LongTag() == "!repeat" && len(node.Content) != 1 {
return nil, fmt.Errorf("!repeat must have exactly one child")
}
// Decode the children to make sure they're well-formed, but otherwise
// discard that decoding and do it again every time we need a new
// element.
var gen []func(e envSet) (*Value, envSet)
origEnv := dec.env
elts := node.Content
for i, elt := range elts {
_, err := dec.value(elt)
if err != nil {
return nil, err
}
// Undo any effects on the environment. We *do* keep any named
// variables that were added to the vars map in case they were
// introduced within the element.
//
// TODO: If we change how we implement repeat nodes, we might be
// able to drop yamlEncoder.env and yamlDecoder.env.
dec.env = origEnv
// Add a generator function
gen = append(gen, func(e envSet) (*Value, envSet) {
dec.env = e
// TODO: If this is in a sum, this tends to generate a ton of
// fresh variables that are different on each branch of the
// parent sum. Does it make sense to hold on to the i'th value
// of the tuple after we've generated it?
v, err := dec.value(elts[i])
if err != nil {
// It worked the first time, so this really shouldn't hapen.
panic("decoding repeat element failed")
}
return v, dec.env
})
}
return mk(NewRepeat(gen...))
case is(yaml.SequenceNode, "!sum"):
vs := make([]*Value, 0, len(node.Content))
for _, elt := range node.Content {
v, err := dec.value(elt)
if err != nil {
return nil, err
}
vs = append(vs, v)
}
if len(vs) == 1 {
return vs[0], nil
}
// A sum is implemented as a fresh variable that's simultaneously bound
// to each of the descendants.
id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)}
dec.nSums++
dec.env = dec.env.bind(id, vs...)
return mk(Var{id: id})
case is(yaml.ScalarNode, "!import"):
if dec.opts.FS == nil {
return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)")
}
pat := node.Value
if !fs.ValidPath(pat) {
// This will result in Glob returning no results. Give a more useful
// error message for this case.
return nil, fmt.Errorf("!import path must not contain '.' or '..'")
}
ms, err := fs.Glob(dec.opts.FS, pat)
if err != nil {
return nil, fmt.Errorf("resolving !import: %w", err)
}
if len(ms) == 0 {
return nil, fmt.Errorf("!import did not match any files")
}
// Parse each file
vs := make([]*Value, 0, len(ms))
for _, m := range ms {
v, err := dec.import1(m)
if err != nil {
return nil, err
}
vs = append(vs, v)
}
// Create a sum.
if len(vs) == 1 {
return vs[0], nil
}
id := &ident{name: "import"}
dec.env = dec.env.bind(id, vs...)
return mk(Var{id: id})
}
return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag)
}
func (dec *yamlDecoder) import1(path string) (*Value, error) {
// Make sure we can open the path first.
f, err := dec.opts.FS.Open(path)
if err != nil {
return nil, fmt.Errorf("!import failed: %w", err)
}
defer f.Close()
// Prepare the enter path.
oldFS, oldPath := dec.opts.FS, dec.path
defer func() {
dec.opts.FS, dec.path = oldFS, oldPath
}()
// Enter path, which is relative to the current path's directory.
newPath := filepath.Join(filepath.Dir(dec.path), path)
subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path))
if err != nil {
return nil, err
}
dec.opts.FS, dec.path = subFS, newPath
// Parse the file.
return dec.read(f)
}
type yamlEncoder struct {
idp identPrinter
e envSet // We track the environment for !repeat nodes.
}
// TODO: Switch some Value marshaling to Closure?
func (c Closure) MarshalYAML() (any, error) {
// TODO: If the environment is trivial, just marshal the value.
enc := &yamlEncoder{}
return enc.closure(c), nil
}
func (c Closure) String() string {
b, err := yaml.Marshal(c)
if err != nil {
return fmt.Sprintf("marshal failed: %s", err)
}
return string(b)
}
func (v *Value) MarshalYAML() (any, error) {
enc := &yamlEncoder{e: topEnv}
return enc.value(v), nil
}
func (v *Value) String() string {
b, err := yaml.Marshal(v)
if err != nil {
return fmt.Sprintf("marshal failed: %s", err)
}
return string(b)
}
func (enc *yamlEncoder) closure(c Closure) *yaml.Node {
enc.e = c.env
var n yaml.Node
n.Kind = yaml.MappingNode
n.Tag = "!closure"
n.Content = make([]*yaml.Node, 4)
n.Content[0] = new(yaml.Node)
n.Content[0].SetString("env")
n.Content[2] = new(yaml.Node)
n.Content[2].SetString("in")
n.Content[3] = enc.value(c.val)
// Fill in the env after we've written the value in case value encoding
// affects the env.
n.Content[1] = enc.env(enc.e)
enc.e = envSet{} // Allow GC'ing the env
return &n
}
func (enc *yamlEncoder) env(e envSet) *yaml.Node {
var encode func(e *envExpr) *yaml.Node
encode = func(e *envExpr) *yaml.Node {
var n yaml.Node
switch e.kind {
default:
panic("bad kind")
case envZero:
n.SetString("0")
case envUnit:
n.SetString("1")
case envBinding:
var id yaml.Node
id.SetString(enc.idp.unique(e.id))
n.Kind = yaml.MappingNode
n.Content = []*yaml.Node{&id, enc.value(e.val)}
case envProduct, envSum:
n.Kind = yaml.SequenceNode
if e.kind == envProduct {
n.Tag = "!product"
} else {
n.Tag = "!sum"
}
for _, e2 := range e.operands {
n.Content = append(n.Content, encode(e2))
}
}
return &n
}
return encode(e.root)
}
var yamlIntRe = regexp.MustCompile(`^-?[0-9]+$`)
func (enc *yamlEncoder) value(v *Value) *yaml.Node {
var n yaml.Node
switch d := v.Domain.(type) {
case nil:
// Not allowed by unmarshaler, but useful for understanding when
// something goes horribly wrong.
//
// TODO: We might be able to track useful provenance for this, which
// would really help with debugging unexpected bottoms.
n.SetString("_|_")
return &n
case Top:
n.SetString("_")
return &n
case Def:
n.Kind = yaml.MappingNode
for k, elt := range d.All() {
var kn yaml.Node
kn.SetString(k)
n.Content = append(n.Content, &kn, enc.value(elt))
}
n.HeadComment = v.PosString()
return &n
case Tuple:
n.Kind = yaml.SequenceNode
if d.repeat == nil {
for _, elt := range d.vs {
n.Content = append(n.Content, enc.value(elt))
}
} else {
if len(d.repeat) == 1 {
n.Tag = "!repeat"
} else {
n.Tag = "!repeat-unify"
}
// TODO: I'm not positive this will round-trip everything correctly.
for _, gen := range d.repeat {
v, e := gen(enc.e)
enc.e = e
n.Content = append(n.Content, enc.value(v))
}
}
return &n
case String:
switch d.kind {
case stringExact:
n.SetString(d.exact)
switch {
// Make this into a "nice" !!int node if I can.
case yamlIntRe.MatchString(d.exact):
n.Tag = "tag:yaml.org,2002:int"
// Or a "nice" !!bool node.
case d.exact == "false" || d.exact == "true":
n.Tag = "tag:yaml.org,2002:bool"
// If this doesn't require escaping, leave it as a str node to avoid
// the annoying YAML tags. Otherwise, mark it as an exact string.
// Alternatively, we could always emit a str node with regexp
// quoting.
case d.exact != regexp.QuoteMeta(d.exact):
n.Tag = "!string"
}
return &n
case stringRegex:
o := make([]string, 0, 1)
for _, re := range d.re {
s := re.String()
s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`)
o = append(o, s)
}
if len(o) == 1 {
n.SetString(o[0])
return &n
}
n.Encode(o)
n.Tag = "!regex"
return &n
}
panic("bad String kind")
case Var:
// TODO: If Var only appears once in the whole Value and is independent
// in the environment (part of a term that is only over Var), then emit
// this as a !sum instead.
if false {
var vs []*Value // TODO: Get values of this var.
if len(vs) == 1 {
return enc.value(vs[0])
}
n.Kind = yaml.SequenceNode
n.Tag = "!sum"
for _, elt := range vs {
n.Content = append(n.Content, enc.value(elt))
}
return &n
}
n.SetString(enc.idp.unique(d.id))
if !strings.HasPrefix(d.id.name, "$") {
n.Tag = "!var"
}
return &n
}
panic(fmt.Sprintf("unknown domain type %T", v.Domain))
}