| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package unify |
| |
| import ( |
| "fmt" |
| "iter" |
| "maps" |
| "slices" |
| ) |
| |
| type Closure struct { |
| val *Value |
| env envSet |
| } |
| |
| func NewSum(vs ...*Value) Closure { |
| id := &ident{name: "sum"} |
| return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)} |
| } |
| |
| // IsBottom returns whether c consists of no values. |
| func (c Closure) IsBottom() bool { |
| return c.val.Domain == nil |
| } |
| |
| // Summands returns the top-level Values of c. This assumes the top-level of c |
| // was constructed as a sum, and is mostly useful for debugging. |
| func (c Closure) Summands() iter.Seq[*Value] { |
| return func(yield func(*Value) bool) { |
| var rec func(v *Value, env envSet) bool |
| rec = func(v *Value, env envSet) bool { |
| switch d := v.Domain.(type) { |
| case Var: |
| parts := env.partitionBy(d.id) |
| for _, part := range parts { |
| // It may be a sum of sums. Walk into this value. |
| if !rec(part.value, part.env) { |
| return false |
| } |
| } |
| return true |
| default: |
| return yield(v) |
| } |
| } |
| rec(c.val, c.env) |
| } |
| } |
| |
| // All enumerates all possible concrete values of c by substituting variables |
| // from the environment. |
| // |
| // E.g., enumerating this Value |
| // |
| // a: !sum [1, 2] |
| // b: !sum [3, 4] |
| // |
| // results in |
| // |
| // - {a: 1, b: 3} |
| // - {a: 1, b: 4} |
| // - {a: 2, b: 3} |
| // - {a: 2, b: 4} |
| func (c Closure) All() iter.Seq[*Value] { |
| // In order to enumerate all concrete values under all possible variable |
| // bindings, we use a "non-deterministic continuation passing style" to |
| // implement this. We use CPS to traverse the Value tree, threading the |
| // (possibly narrowing) environment through that CPS following an Euler |
| // tour. Where the environment permits multiple choices, we invoke the same |
| // continuation for each choice. Similar to a yield function, the |
| // continuation can return false to stop the non-deterministic walk. |
| return func(yield func(*Value) bool) { |
| c.val.all1(c.env, func(v *Value, e envSet) bool { |
| return yield(v) |
| }) |
| } |
| } |
| |
| func (v *Value) all1(e envSet, cont func(*Value, envSet) bool) bool { |
| switch d := v.Domain.(type) { |
| default: |
| panic(fmt.Sprintf("unknown domain type %T", d)) |
| |
| case nil: |
| return true |
| |
| case Top, String: |
| return cont(v, e) |
| |
| case Def: |
| fields := d.keys() |
| // We can reuse this parts slice because we're doing a DFS through the |
| // state space. (Otherwise, we'd have to do some messy threading of an |
| // immutable slice-like value through allElt.) |
| parts := make(map[string]*Value, len(fields)) |
| |
| // TODO: If there are no Vars or Sums under this Def, then nothing can |
| // change the Value or env, so we could just cont(v, e). |
| var allElt func(elt int, e envSet) bool |
| allElt = func(elt int, e envSet) bool { |
| if elt == len(fields) { |
| // Build a new Def from the concrete parts. Clone parts because |
| // we may reuse it on other non-deterministic branches. |
| nVal := newValueFrom(Def{maps.Clone(parts)}, v) |
| return cont(nVal, e) |
| } |
| |
| return d.fields[fields[elt]].all1(e, func(v *Value, e envSet) bool { |
| parts[fields[elt]] = v |
| return allElt(elt+1, e) |
| }) |
| } |
| return allElt(0, e) |
| |
| case Tuple: |
| // Essentially the same as Def. |
| if d.repeat != nil { |
| // There's nothing we can do with this. |
| return cont(v, e) |
| } |
| parts := make([]*Value, len(d.vs)) |
| var allElt func(elt int, e envSet) bool |
| allElt = func(elt int, e envSet) bool { |
| if elt == len(d.vs) { |
| // Build a new tuple from the concrete parts. Clone parts because |
| // we may reuse it on other non-deterministic branches. |
| nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v) |
| return cont(nVal, e) |
| } |
| |
| return d.vs[elt].all1(e, func(v *Value, e envSet) bool { |
| parts[elt] = v |
| return allElt(elt+1, e) |
| }) |
| } |
| return allElt(0, e) |
| |
| case Var: |
| // Go each way this variable can be bound. |
| for _, ePart := range e.partitionBy(d.id) { |
| // d.id is no longer bound in this environment partition. We'll may |
| // need it later in the Euler tour, so bind it back to this single |
| // value. |
| env := ePart.env.bind(d.id, ePart.value) |
| if !ePart.value.all1(env, cont) { |
| return false |
| } |
| } |
| return true |
| } |
| } |