blob: 618887cdf29ced674082c4ab81be9a5e3e74fc23 [file]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unify
import (
"fmt"
"iter"
"reflect"
"slices"
"strings"
)
// A nonDetEnv is a non-deterministic mapping from [ident]s to [Value]s.
//
// Logically, this is just a set of deterministic environments, where each
// deterministic environment is a complete mapping from each [ident] to exactly
// one [Value]. In particular, [ident]s are NOT necessarily independent of each
// other. For example, an environment may have both {x: 1, y: 1} and {x: 2, y:
// 2}, but not {x: 1, y: 2}.
//
// A nonDetEnv is immutable.
//
// Often [ident]s are independent of each other, so the representation optimizes
// for this by using a cross-product of environment factors, where each factor
// is a sum of deterministic environments. These operations obey the usual
// distributive laws, so we can always canonicalize into this form. (It MAY be
// worthwhile to allow more general expressions of sums and products.)
//
// For example, to represent {{x: 1, y: 1}, {x: 2, y: 2}}, in which the
// variables x and y are dependent, we need a single factor that covers x and y
// and consists of two terms: {x: 1, y: 1} + {x: 2, y: 2}.
//
// If we add a third variable z that can be 1 or 2, independent of x and y, we
// get four logical environments:
//
//	{x: 1, y: 1, z: 1}
//	{x: 2, y: 2, z: 1}
//	{x: 1, y: 1, z: 2}
//	{x: 2, y: 2, z: 2}
//
// This could be represented as a single factor that is the sum of these four
// detEnvs, but because z is independent, it can be a separate factor. Hence,
// the most compact representation of this environment is:
//
//	({x: 1, y: 1} + {x: 2, y: 2}) ⨯ ({z: 1} + {z: 2})
//
// That is, two factors, where each is the sum of two terms.
type nonDetEnv struct {
// factors is a list of the multiplicative factors in this environment. The
// set of deterministic environments is the cross-product of these factors.
// All factors must have disjoint variables.
//
// An empty list is the multiplicative identity (topEnv). A list whose
// first factor has no terms is bottom; isBottom inspects only that first
// factor.
factors []*envSum
}
// envSum is a sum of deterministic environments, all with the same set of
// variables. An envSum with no terms is the additive identity (see
// zeroEnvFactor).
type envSum struct {
// ids lists the variables bound by every term of this sum. Position j in
// ids corresponds to position j in each term's vals.
ids []*ident // TODO: Do we ever use this as a slice? Should it be a map?
// terms holds the summands: one deterministic environment per term.
terms []detEnv
}
// detEnv is a single deterministic environment within an [envSum]: one
// [Value] for each of the enclosing envSum's ids.
type detEnv struct {
vals []*Value // Indexes correspond to envSum.ids
}
var (
// zeroEnvFactor is the "0" value of an [envSum]. It's a factor with no
// sum terms. This is easiest to think of as: an empty sum must be the
// additive identity, 0.
zeroEnvFactor = &envSum{}
// topEnv is the algebraic one value of a [nonDetEnv]. It has no factors
// because the product of no factors is the multiplicative identity.
topEnv = nonDetEnv{}
// bottomEnv is the algebraic zero value of a [nonDetEnv]. The product of
// bottomEnv with x is bottomEnv, and the sum of bottomEnv with y is y.
// Its only factor is zeroEnvFactor, so it contains no deterministic
// environments.
bottomEnv = nonDetEnv{factors: []*envSum{zeroEnvFactor}}
)
// bind binds id to each of vals in e.
//
// The result is the product of e with a new single-variable factor whose terms
// are {id: vals[0]} + {id: vals[1]} + .... If e is bottom, or if vals is
// empty, the result is [bottomEnv]: an empty sum is the additive zero, and the
// product of zero with anything is zero.
//
// It panics if id is already bound in e.
//
// Environments are typically initially constructed by starting with [topEnv]
// and calling bind one or more times.
func (e nonDetEnv) bind(id *ident, vals ...*Value) nonDetEnv {
	if e.isBottom() {
		return bottomEnv
	}

	// TODO: If any of vals are _, should we just not do anything? We're kind of
	// inconsistent about whether an id missing from e means id is invalid or
	// means id is _.

	// Check that id isn't present in e. This is done before the empty-vals
	// shortcut below so that a double bind is always reported.
	for _, f := range e.factors {
		if slices.Contains(f.ids, id) {
			panic("id " + id.name + " already present in environment")
		}
	}

	if len(vals) == 0 {
		// A factor with no terms makes the whole product empty. isBottom only
		// looks at the first factor, so appending such a factor after existing
		// ones would yield an environment with no members that is not
		// recognized as bottom. Return the canonical bottom instead.
		return bottomEnv
	}

	// Create the new sum term: one single-variable detEnv per value.
	terms := make([]detEnv, 0, len(vals))
	for _, val := range vals {
		terms = append(terms, detEnv{vals: []*Value{val}})
	}
	sum := &envSum{ids: []*ident{id}, terms: terms}

	// Multiply it in. The capacity of e.factors is clipped so that append
	// always copies and never writes into a backing array shared with e.
	factors := append(e.factors[:len(e.factors):len(e.factors)], sum)
	return nonDetEnv{factors}
}
// isBottom reports whether e is bottom, i.e. contains no deterministic
// environments.
//
// An environment with no factors is top, not bottom. Otherwise only the first
// factor is inspected: e is bottom if that factor has no terms.
func (e nonDetEnv) isBottom() bool {
	return len(e.factors) > 0 && len(e.factors[0].terms) == 0
}
// vars returns an iterator over every [ident] bound in e, factor by factor and,
// within a factor, in the order of that factor's ids.
func (e nonDetEnv) vars() iter.Seq[*ident] {
	return func(yield func(*ident) bool) {
		for fi := range e.factors {
			ids := e.factors[fi].ids
			for k := 0; k < len(ids); k++ {
				if !yield(ids[k]) {
					// The consumer stopped early.
					return
				}
			}
		}
	}
}
// all enumerates all deterministic environments in e.
//
// Each yielded slice holds one value per variable, in the same order as the
// identifiers produced by [nonDetEnv.vars]. The slice is reused between
// iterations, so a caller that retains it must copy it first.
func (e nonDetEnv) all() iter.Seq[[]*Value] {
	return func(yield func([]*Value) bool) {
		// cur accumulates the values chosen so far, one factor at a time.
		var cur []*Value

		// descend picks a term for factor depth and recurses into the
		// remaining factors. It returns false once the consumer stops.
		var descend func(depth int) bool
		descend = func(depth int) bool {
			if depth == len(e.factors) {
				// A term has been chosen from every factor.
				return yield(cur)
			}
			base := len(cur)
			terms := e.factors[depth].terms
			for k := range terms {
				// Replace whatever the previous term of this factor
				// contributed with this term's values.
				cur = append(cur[:base], terms[k].vals...)
				if !descend(depth + 1) {
					return false
				}
			}
			return true
		}
		descend(0)
	}
}
// allOrdered is like all, but idOrder controls the order of the values in the
// resulting slice. Any [ident]s in idOrder that are missing from e are set to
// topValue. The values of idOrder must be a bijection with [0, n).
//
// The yielded slice is reused between iterations. An [ident] bound in e but
// absent from idOrder is looked up as index 0, so callers should make sure
// idOrder covers every variable of e.
func (e nonDetEnv) allOrdered(idOrder map[*ident]int) iter.Seq[[]*Value] {
	// The output width is one more than the largest index in idOrder. This is
	// computed when allOrdered is called, not when iteration starts.
	width := 0
	for _, pos := range idOrder {
		if pos+1 > width {
			width = pos + 1
		}
	}

	return func(yield func([]*Value) bool) {
		// Start with every slot at top; slots for variables e doesn't bind
		// are never overwritten below.
		out := make([]*Value, width)
		for i := range out {
			out[i] = topValue
		}

		// descend picks a term for factor depth, writes its values into
		// their slots, and recurses. It returns false once the consumer
		// stops.
		var descend func(depth int) bool
		descend = func(depth int) bool {
			if depth == len(e.factors) {
				return yield(out)
			}
			factor := e.factors[depth]
			for _, term := range factor.terms {
				for col, id := range factor.ids {
					out[idOrder[id]] = term.vals[col]
				}
				if !descend(depth + 1) {
					return false
				}
			}
			return true
		}
		descend(0)
	}
}
// crossEnvs returns the product of envs, formed by concatenating their
// factors.
//
// If any of envs is bottom, the result is bottomEnv. It panics if the same
// [ident] is bound by more than one factor across envs. With no arguments the
// result has no factors.
func crossEnvs(envs ...nonDetEnv) nonDetEnv {
	seen := make(map[*ident]struct{})
	var product []*envSum
	for _, env := range envs {
		if env.isBottom() {
			// Zero times anything is zero.
			return bottomEnv
		}

		// Factors must have disjoint variables, both within env and with
		// respect to everything multiplied in so far.
		for _, factor := range env.factors {
			for _, id := range factor.ids {
				if _, dup := seen[id]; dup {
					panic("conflict on " + id.name)
				}
				seen[id] = struct{}{}
			}
		}

		// No conflicts, so env's factors join the product as-is.
		product = append(product, env.factors...)
	}
	return nonDetEnv{factors: product}
}
// sumEnvs returns the sum of envs.
//
// Factors that commonFactors reports as common are kept as separate factors of
// the result; everything else is multiplied out into a single new factor.
func sumEnvs(envs ...nonDetEnv) nonDetEnv {
	// nonDetEnv is a product at the top level, so we implement summation using
	// the distributive law. We also use associativity to keep as many top-level
	// factors as we can, since those are what keep the environment compact.
	//
	//	a * b * c + a * d   (where a, b, c, and d are factors)
	//	  (combine common factors)
	//	= a * (b * c + d)
	//	  (expand factors into their sum terms)
	//	= a * ((b_1 + b_2 + ...) * (c_1 + c_2 + ...) + d)
	//	  (where b_i and c_i are deterministic environments)
	//	  (FOIL)
	//	= a * (b_1 * c_1 + b_1 * c_2 + b_2 * c_1 + b_2 * c_2 + d)
	//	  (all factors are now in canonical form)
	//	= a * e
	//
	// The product of two deterministic environments is a deterministic
	// environment, and the sum of deterministic environments is a factor, so
	// this process results in the canonical product-of-sums form.
	//
	// TODO: This is a bit of a one-way process. We could try to factor the
	// environment to reduce the number of sums. I'm not sure how to do this
	// efficiently. It might be possible to guide it by gathering the
	// distributions of each ID's bindings. E.g., if there are 12 deterministic
	// environments in a sum and $x is bound to 4 different values, each 3
	// times, then it *might* be possible to factor out $x into a 4-way sum of
	// its own.
	shared, rest := commonFactors(envs)
	if len(rest) == 0 {
		// Nothing left to multiply out.
		return nonDetEnv{factors: shared}
	}

	// Assign every variable appearing in rest a column, in first-seen order.
	var order []*ident
	column := make(map[*ident]int)
	for _, env := range rest {
		for id := range env.vars() {
			if _, known := column[id]; known {
				continue
			}
			column[id] = len(order)
			order = append(order, id)
		}
	}

	// Flatten each remaining environment into deterministic terms over that
	// column order. allOrdered reuses its slice, so each row is cloned.
	var rows []detEnv
	for _, env := range rest {
		for vals := range env.allOrdered(column) {
			rows = append(rows, detEnv{vals: slices.Clone(vals)})
		}
	}

	merged := &envSum{ids: order, terms: rows}
	return nonDetEnv{factors: append(shared, merged)}
}
// commonFactors finds common factors that can be factored out of a summation of
// [nonDetEnv]s.
//
// Bottom environments in envs are ignored. common is the list of factors
// (compared by pointer) that appear in every remaining environment, in the
// order they appear in the first one. toSum holds, for each remaining
// environment that has any, its non-common factors. If every environment is
// bottom (or envs is empty), common is bottomEnv's factors and toSum is nil.
//
// envs is not modified.
func commonFactors(envs []nonDetEnv) (common []*envSum, toSum []nonDetEnv) {
	// Drop any bottom environments. They don't contribute to the sum and they
	// would complicate some logic below.
	//
	// This filters into a fresh slice rather than in place: envs is the
	// caller's slice (sumEnvs passes its variadic argument straight through),
	// and an in-place filter would shift and zero its elements.
	live := make([]nonDetEnv, 0, len(envs))
	for _, e := range envs {
		if !e.isBottom() {
			live = append(live, e)
		}
	}
	if len(live) == 0 {
		return bottomEnv.factors, nil
	}

	// It's very common that the exact same factor will appear across all envs.
	// Keep those factored out.
	//
	// TODO: Is it also common to have vars that are bound to the same value
	// across all envs? If so, we could also factor those into common terms.
	counts := map[*envSum]int{}
	for _, e := range live {
		for _, f := range e.factors {
			counts[f]++
		}
	}
	for _, f := range live[0].factors {
		if counts[f] == len(live) {
			// Common factor
			common = append(common, f)
		}
	}

	// Any other factors need to be multiplied out.
	//
	// NOTE(review): an environment whose factors are all common contributes
	// nothing to toSum even when other environments have extra factors, so
	// a + a*b comes out as a*b. Confirm callers never sum such a mix.
	for _, env := range live {
		var newFactors []*envSum
		for _, f := range env.factors {
			if counts[f] != len(live) {
				newFactors = append(newFactors, f)
			}
		}
		if len(newFactors) > 0 {
			toSum = append(toSum, nonDetEnv{factors: newFactors})
		}
	}

	return common, toSum
}
// envPartition is a subset of an env where id is bound to value in all
// deterministic environments.
type envPartition struct {
// id is the variable the environment was partitioned by.
id *ident
// value is the value id has throughout this partition.
value *Value
// env is the partition's environment. As built by partitionBy, it no
// longer binds id.
env nonDetEnv
}
// partitionBy splits e into one [envPartition] per distinct value of id.
//
// Values are distinguished by pointer, and partitions appear in the order
// their value is first seen among the terms of id's factor. Each partition's
// environment is e narrowed to the terms where id has that value, with id
// itself removed.
//
// If e is bottom, the result is a single partition with value bottomValue and
// environment e. Otherwise partitionBy panics if id is not bound in e, or is
// bound by more than one factor.
func (e nonDetEnv) partitionBy(id *ident) []envPartition {
	if e.isBottom() {
		// Bottom contains all variables
		return []envPartition{{id: id, value: bottomValue, env: e}}
	}

	// Locate the one factor that binds id, and id's column within it.
	factorAt, colAt := -1, -1
	var restIDs []*ident
	for fi, f := range e.factors {
		col := slices.Index(f.ids, id)
		switch {
		case col < 0:
			// This factor doesn't mention id.
		case factorAt != -1:
			panic("multiple factors containing id " + id.name)
		default:
			factorAt, colAt = fi, col
			// The partitioned factor keeps every id except this one.
			restIDs = without(f.ids, col)
		}
	}
	if factorAt == -1 {
		panic("id " + id.name + " not found in environment")
	}

	// If id is the only variable in its factor, removing it leaves the unit
	// factor, so the factor can simply be dropped. (If it was the only factor
	// at all, that leaves [topEnv], which is exactly right.) Every partition
	// can then share the same list of remaining factors.
	soleVar := len(restIDs) == 0
	var sharedFactors []*envSum
	if soleVar {
		sharedFactors = without(e.factors, factorAt)
	}

	// Walk the terms of id's factor, opening a new partition the first time
	// each value is seen.
	var out []envPartition
	seen := map[*Value]int{}
	for _, term := range e.factors[factorAt].terms {
		v := term.vals[colAt]
		pi, ok := seen[v]
		if !ok {
			pi = len(out)
			seen[v] = pi
			fs := sharedFactors
			if !soleVar {
				// Copy the factor list and give this partition its own,
				// initially empty, replacement for id's factor.
				fs = slices.Clone(e.factors)
				fs[factorAt] = &envSum{ids: restIDs}
			}
			out = append(out, envPartition{id: id, value: v, env: nonDetEnv{factors: fs}})
		}
		if soleVar {
			continue
		}
		// Add this term, minus id's column, to the partition's factor.
		target := out[pi].env.factors[factorAt]
		target.terms = append(target.terms, detEnv{vals: without(term.vals, colAt)})
	}
	return out
}
// ident is the identity of a variable. Idents are compared by pointer, so two
// idents with the same name are still distinct variables; identPrinter exists
// to give such idents distinct printed names.
type ident struct {
_ [0]func() // Not comparable (only compare *ident)
name string // Human-readable name; not necessarily unique
}
// Var is a [Domain] that stands for a variable. Its value is not stored in the
// Var itself; unification looks up the values bound to id in the [nonDetEnv].
type Var struct {
id *ident // The variable this Var refers to
}
// Exact always panics. A Var is not a concrete value, so callers are not
// expected to ask whether it is exact.
func (d Var) Exact() bool {
// These can't appear in concrete Values.
panic("Exact called on non-concrete Value")
}
// decode always fails: it returns an inexactError naming "var" and the type of
// the destination rv. Nothing is written to rv.
func (d Var) decode(rv reflect.Value) error {
	target := rv.Type().String()
	return &inexactError{"var", target}
}
// unify unifies the variable d with w under environment e.
//
// e is partitioned by the value bound to d.id. Within each partition, w is
// unified with that value, and on success d.id is re-bound to the result in
// the partition's resulting environment. The returned environment is the sum
// of the surviving partitions, and the returned Domain is d itself. If no
// partition survives, the result is a nil Domain and bottomEnv. An error from
// any partition aborts unification and is returned as is.
//
// swap and uf are passed through to the nested unifications unchanged.
func (d Var) unify(w *Value, e nonDetEnv, swap bool, uf *unifier) (Domain, nonDetEnv, error) {
	// TODO: Vars from !sums in the input can have a huge number of values.
	// Unifying these could be way more efficient with some indexes over any
	// exact values we can pull out, like Def fields that are exact Strings.
	// Maybe we try to produce an array of yes/no/maybe matches and then we only
	// have to do deeper evaluation of the maybes. We could probably cache this
	// on an envSum. It may also help to special-case Var/Var unification to
	// pick which one to index versus enumerate.

	if other, isVar := w.Domain.(Var); isVar && other.id == d.id {
		// Unifying $x with $x results in $x. If we descend into this we'll have
		// problems because we strip $x out of the environment to keep ourselves
		// honest and then can't find it on the other side.
		//
		// TODO: I'm not positive this is the right fix.
		return other, e, nil
	}

	// We need to unify w with the value of d in each possible environment. We
	// can save some work by grouping environments by the value of d, since
	// there will be a lot of redundancy here.
	var branches []nonDetEnv
	for idx, part := range e.partitionBy(d.id) {
		scope := uf.enterVar(d.id, idx)
		// Each branch logically gets its own copy of the initial environment
		// (narrowed down to just this binding of the variable), and each branch
		// may result in different changes to that starting environment.
		unified, branchEnv, err := w.unify(part.value, part.env, swap, uf)
		scope.exit()
		if err != nil {
			return nil, nonDetEnv{}, err
		}
		if unified.Domain == nil {
			// This branch entirely failed to unify, so it's gone.
			continue
		}
		// The partition's environment no longer binds d.id; put it back,
		// bound to the unified value.
		branches = append(branches, branchEnv.bind(d.id, unified))
	}

	if len(branches) == 0 {
		// All branches failed
		return nil, bottomEnv, nil
	}

	// The effect of this is entirely captured in the environment. We can return
	// back the same Var.
	return d, sumEnvs(branches...), nil
}
// An identPrinter maps [ident]s to unique string names.
//
// The zero value is ready to use: unique allocates both maps on first call.
type identPrinter struct {
// ids records the name already assigned to each ident.
ids map[*ident]string
// idGen counts, per ident name, how many distinct idents with that name
// have been assigned a printed name so far.
idGen map[string]int
}
// unique returns the printed name for id, assigning one on first use.
//
// The first ident seen with a given name prints as that name. Each later,
// distinct ident with the same name prints as "name#N", where N counts up from
// 1. Repeated calls with the same id return the same string.
func (p *identPrinter) unique(id *ident) string {
	if p.ids == nil {
		// Lazily set up the zero value.
		p.ids = make(map[*ident]string)
		p.idGen = make(map[string]int)
	}
	if known, ok := p.ids[id]; ok {
		return known
	}

	// n is how many other idents with this name were named before this one.
	n := p.idGen[id.name]
	p.idGen[id.name] = n + 1

	label := id.name
	if n != 0 {
		label = fmt.Sprintf("%s#%d", id.name, n)
	}
	p.ids[id] = label
	return label
}
// slice formats ids as a bracketed, comma-separated list of their unique
// names, e.g. "[x, y, x#1]". An empty or nil ids formats as "[]".
func (p *identPrinter) slice(ids []*ident) string {
	names := make([]string, len(ids))
	for i, id := range ids {
		names[i] = p.unique(id)
	}
	joined := strings.Join(names, ", ")
	return fmt.Sprintf("[%s]", joined)
}
// without returns a slice holding the elements of s except the one at index i.
// s itself is not modified. It panics if i is out of range for s.
//
// When i is the last index, the result shares s's backing array, but its
// capacity is clipped so appending to it never writes into s.
func without[Elt any](s []Elt, i int) []Elt {
	// Clip the prefix's capacity to its length so that append must copy rather
	// than overwrite s[i].
	prefix := s[:i:i]
	suffix := s[i+1:]
	return append(prefix, suffix...)
}