package slogtest

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"runtime"
	"time"

	"golang.org/x/exp/slog"
)

// testCase describes one scenario exercised by TestHandler: a single
// logging call, an optional rewrite of the resulting Record, and the
// checks that the Handler's output for that call must pass.
type testCase struct {
	// explanation, when non-empty, describes the constraint that the
	// Handler violated if any of the checks fail.
	explanation string

	// f emits exactly one log event through the Logger it is given.
	// mkdescs.sh relies on the body of f appearing on a single line
	// whose first non-whitespace characters are "l.", so that it can
	// generate the right description.
	f func(l *slog.Logger)

	// mod, if non-nil, is called on the Record produced by the Logger
	// before that Record reaches the Handler.
	mod func(r *slog.Record)

	// checks are run against the result map for this case.
	checks []check
}

// TestHandler tests a [slog.Handler].
// If TestHandler finds any misbehaviors, it returns an error for each,
// combined into a single error with errors.Join.
//
// TestHandler installs the given Handler in a [slog.Logger] and
// makes several calls to the Logger's output methods.
//
// The results function is invoked exactly once, after all such calls.
// It should return a slice of map[string]any, one for each call to a Logger output method.
// The keys and values of the map should correspond to the keys and values of the Handler's
// output. Each group in the output should be represented as its own nested map[string]any.
// The standard keys slog.TimeKey, slog.LevelKey and slog.MessageKey should be used.
// If the number of maps returned differs from the number of output calls,
// TestHandler returns a single error describing the mismatch and runs no checks.
//
// If the Handler outputs JSON, then calling [encoding/json.Unmarshal] with a `map[string]any`
// will create the right data structure.
//
// If a Handler intentionally drops an attribute that is checked by a test,
// then the results function should check for its absence and add it to the map it returns.
func TestHandler(h slog.Handler, results func() []map[string]any) error {
	cases := []testCase{
		{
			explanation: withSource("this test expects slog.TimeKey, slog.LevelKey and slog.MessageKey"),
			f: func(l *slog.Logger) {
				l.Info("message")
			},
			checks: []check{
				hasKey(slog.TimeKey),
				hasKey(slog.LevelKey),
				hasAttr(slog.MessageKey, "message"),
			},
		},
		{
			explanation: withSource("a Handler should output attributes passed to the logging function"),
			f: func(l *slog.Logger) {
				l.Info("message", "k", "v")
			},
			checks: []check{
				hasAttr("k", "v"),
			},
		},
		{
			explanation: withSource("a Handler should ignore an empty Attr"),
			f: func(l *slog.Logger) {
				l.Info("msg", "a", "b", "", nil, "c", "d")
			},
			checks: []check{
				hasAttr("a", "b"),
				missingKey(""),
				hasAttr("c", "d"),
			},
		},
		{
			explanation: withSource("a Handler should ignore a zero Record.Time"),
			f: func(l *slog.Logger) {
				l.Info("msg", "k", "v")
			},
			mod: func(r *slog.Record) { r.Time = time.Time{} },
			checks: []check{
				missingKey(slog.TimeKey),
			},
		},
		{
			explanation: withSource("a Handler should include the attributes from the WithAttrs method"),
			f: func(l *slog.Logger) {
				l.With("a", "b").Info("msg", "k", "v")
			},
			checks: []check{
				hasAttr("a", "b"),
				hasAttr("k", "v"),
			},
		},
		{
			explanation: withSource("a Handler should handle Group attributes"),
			f: func(l *slog.Logger) {
				l.Info("msg", "a", "b", slog.Group("G", slog.String("c", "d")), "e", "f")
			},
			checks: []check{
				hasAttr("a", "b"),
				inGroup("G", hasAttr("c", "d")),
				hasAttr("e", "f"),
			},
		},
		{
			explanation: withSource("a Handler should ignore an empty group"),
			f: func(l *slog.Logger) {
				l.Info("msg", "a", "b", slog.Group("G"), "e", "f")
			},
			checks: []check{
				hasAttr("a", "b"),
				missingKey("G"),
				hasAttr("e", "f"),
			},
		},
		{
			explanation: withSource("a Handler should inline the Attrs of a group with an empty key"),
			f: func(l *slog.Logger) {
				l.Info("msg", "a", "b", slog.Group("", slog.String("c", "d")), "e", "f")
			},
			checks: []check{
				hasAttr("a", "b"),
				hasAttr("c", "d"),
				hasAttr("e", "f"),
			},
		},
		{
			explanation: withSource("a Handler should handle the WithGroup method"),
			f: func(l *slog.Logger) {
				l.WithGroup("G").Info("msg", "a", "b")
			},
			checks: []check{
				hasKey(slog.TimeKey),
				hasKey(slog.LevelKey),
				hasAttr(slog.MessageKey, "msg"),
				missingKey("a"),
				inGroup("G", hasAttr("a", "b")),
			},
		},
		{
			explanation: withSource("a Handler should handle multiple WithGroup and WithAttr calls"),
			f: func(l *slog.Logger) {
				l.With("a", "b").WithGroup("G").With("c", "d").WithGroup("H").Info("msg", "e", "f")
			},
			checks: []check{
				hasKey(slog.TimeKey),
				hasKey(slog.LevelKey),
				hasAttr(slog.MessageKey, "msg"),
				hasAttr("a", "b"),
				inGroup("G", hasAttr("c", "d")),
				inGroup("G", inGroup("H", hasAttr("e", "f"))),
			},
		},
	}

	// Run the handler on the test cases.
	for _, c := range cases {
		ht := h
		if c.mod != nil {
			ht = &wrapper{h, c.mod}
		}
		l := slog.New(ht)
		c.f(l)
	}

	// Collect and check the results. results is called only once. A results
	// function may drain or reset the Handler's output, so calling it again
	// could yield different data than the slice whose length was validated
	// here. A longer slice would even index past the end of cases.
	res := results()
	if g, w := len(res), len(cases); g != w {
		return fmt.Errorf("got %d results, want %d", g, w)
	}
	var errs []error
	for i, got := range res {
		c := cases[i]
		// Named chk rather than check to avoid shadowing the check type.
		for _, chk := range c.checks {
			if p := chk(got); p != "" {
				errs = append(errs, fmt.Errorf("%s: %s", p, c.explanation))
			}
		}
	}
	return errors.Join(errs...)
}

type (
	// check inspects the map describing one log record's output and
	// returns a description of the problem it found, or "" on success.
	check func(m map[string]any) string
)

// hasKey returns a check that passes when m contains key, regardless of
// the value stored under it, and reports the key as missing otherwise.
func hasKey(key string) check {
	return func(m map[string]any) string {
		_, present := m[key]
		if present {
			return ""
		}
		return fmt.Sprintf("missing key %q", key)
	}
}

// missingKey returns a check that passes only when m has no entry for key.
// An entry whose value is nil still counts as present.
func missingKey(key string) check {
	return func(m map[string]any) string {
		_, present := m[key]
		if !present {
			return ""
		}
		return fmt.Sprintf("unexpected key %q", key)
	}
}

// hasAttr returns a check that passes only when m maps key to a value
// that is reflect.DeepEqual to wantVal. An absent key is reported as
// `missing key "<key>"`; a differing value is reported with both values
// in Go-syntax form.
func hasAttr(key string, wantVal any) check {
	return func(m map[string]any) string {
		gotVal, present := m[key]
		switch {
		case !present:
			return fmt.Sprintf("missing key %q", key)
		case reflect.DeepEqual(gotVal, wantVal):
			return ""
		default:
			return fmt.Sprintf("%q: got %#v, want %#v", key, gotVal, wantVal)
		}
	}
}

// inGroup returns a check that looks up name in m, requires the value to
// be a nested map[string]any, and then applies c to that nested map.
// Failures of the lookup or the type assertion are reported directly;
// otherwise the result of c is returned unchanged.
func inGroup(name string, c check) check {
	return func(m map[string]any) string {
		raw, present := m[name]
		if !present {
			return fmt.Sprintf("missing group %q", name)
		}
		if group, isMap := raw.(map[string]any); isMap {
			return c(group)
		}
		return fmt.Sprintf("value for group %q is not map[string]any", name)
	}
}

// wrapper is a slog.Handler that calls mod on each Record before passing
// it on to the embedded Handler. Methods not defined here (such as Enabled)
// are promoted from the embedded Handler.
type wrapper struct {
	slog.Handler
	mod func(*slog.Record)
}

// Handle applies h.mod to r (received by value) and forwards the modified
// Record to the embedded Handler.
func (h *wrapper) Handle(ctx context.Context, r slog.Record) error {
	h.mod(&r)
	return h.Handler.Handle(ctx, r)
}

// WithAttrs wraps the embedded Handler's WithAttrs result so that Loggers
// derived via With keep applying mod. Without this method the promoted
// WithAttrs would return the bare inner Handler and mod would be dropped.
func (h *wrapper) WithAttrs(attrs []slog.Attr) slog.Handler {
	return &wrapper{h.Handler.WithAttrs(attrs), h.mod}
}

// WithGroup wraps the embedded Handler's WithGroup result so that Loggers
// derived via WithGroup keep applying mod. Without this method the promoted
// WithGroup would return the bare inner Handler and mod would be dropped.
func (h *wrapper) WithGroup(name string) slog.Handler {
	return &wrapper{h.Handler.WithGroup(name), h.mod}
}

// withSource returns s followed by the file and line of its immediate
// caller, in the form "s (file:line)". It panics if the runtime cannot
// report the caller.
//
// runtime.Caller is invoked directly here (not via a helper) so that skip
// level 1 identifies the line that called withSource.
func withSource(s string) string {
	if _, file, line, ok := runtime.Caller(1); ok {
		return fmt.Sprintf("%s (%s:%d)", s, file, line)
	}
	panic("runtime.Caller failed")
}
