blob: 4153370f3281b88599066e91b42b82b09cc738fe [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build goexperiment.synctest || go1.25
package mcp
import (
"bytes"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"testing"
"testing/synctest"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
"golang.org/x/tools/txtar"
)
// update is set by the -update flag. When true, conformance tests rewrite the
// expected output stored in their testdata archives instead of comparing the
// actual output against it.
var update = flag.Bool("update", false, "if set, update conformance test data")
// A conformanceTest checks JSON-level conformance of a test server or client.
// Such tests let us confirm that we can handle the input or output of other
// SDKs, even when those SDKs behave differently at the JSON level (for
// example, in how they treat optional fields).
//
// Each test is loaded from a txtar archive in testdata. The archive's "client"
// and "server" files each hold a stream of encoded JSON-RPC messages, which are
// decoded into the client and server fields below.
//
// In a server test, the client messages are sent by a synthetic client, and
// the server messages are the ones expected from the real server. Client tests
// reverse these roles: the server messages are synthetic, and the client
// messages are the ones expected from the real client.
//
// Run the tests with -update to have the runner rewrite the expected output,
// which is either the client or the server file depending on the perspective
// of the test.
type conformanceTest struct {
	// Where the test came from.
	name    string         // test name, relative to the conformance test directory
	path    string         // path to the txtar file
	archive *txtar.Archive // raw archive, kept so that -update can rewrite it

	// Named features to install on the server under test.
	tools     []string
	prompts   []string
	resources []string

	// Decoded JSON-RPC message streams.
	client []JSONRPCMessage // messages sent by the client
	server []JSONRPCMessage // messages sent by the server
}
// TODO(rfindley): add client conformance tests.

// TestServerConformance runs every server conformance test found under
// testdata/conformance/server. Each file there ending in .txtar is one test;
// see conformanceTest for the archive format.
func TestServerConformance(t *testing.T) {
	var tests []*conformanceTest
	dir := filepath.Join("testdata", "conformance", "server")
	if err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		// Only files are tests. A directory whose name happens to end in
		// .txtar must not be handed to loadConformanceTest, which would fail
		// trying to read it.
		if d.IsDir() || !strings.HasSuffix(path, ".txtar") {
			return nil
		}
		test, err := loadConformanceTest(dir, path)
		if err != nil {
			// Not every load error mentions the path (e.g. duplicate
			// sections), so always prefix it.
			return fmt.Errorf("%s: %w", path, err)
		}
		tests = append(tests, test)
		return nil
	}); err != nil {
		t.Fatal(err)
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// We use synctest here because in general, there is no way to know when the
			// server is done processing any notifications. As long as our server doesn't
			// do background work, synctest provides an easy way for us to detect when the
			// server is done processing.
			//
			// By comparison, gopls has a complicated framework based on progress
			// reporting and careful accounting to detect when all 'expected' work
			// on the server is complete.
			runSyncTest(t, func(t *testing.T) { runServerTest(t, test) })
			// TODO: in 1.25, use the following instead:
			// synctest.Test(t, func(t *testing.T) {
			// 	runServerTest(t, test)
			// })
		})
	}
}
// runServerTest runs the server conformance test.
// It must be executed in a synctest bubble.
func runServerTest(t *testing.T, test *conformanceTest) {
ctx := t.Context()
// Construct the server based on features listed in the test.
s := NewServer("testServer", "v1.0.0", nil)
add(tools, s.AddTools, test.tools...)
add(prompts, s.AddPrompts, test.prompts...)
add(resources, s.AddResources, test.resources...)
// Connect the server, and connect the client stream,
// but don't connect an actual client.
cTransport, sTransport := NewInMemoryTransports()
ss, err := s.Connect(ctx, sTransport)
if err != nil {
t.Fatal(err)
}
cStream, err := cTransport.Connect(ctx)
if err != nil {
t.Fatal(err)
}
writeMsg := func(msg JSONRPCMessage) {
if err := cStream.Write(ctx, msg); err != nil {
t.Fatalf("Write failed: %v", err)
}
}
var (
serverMessages []JSONRPCMessage
outRequests []*JSONRPCRequest
outResponses []*JSONRPCResponse
)
// Separate client requests and responses; we use them differently.
for _, msg := range test.client {
switch msg := msg.(type) {
case *JSONRPCRequest:
outRequests = append(outRequests, msg)
case *JSONRPCResponse:
outResponses = append(outResponses, msg)
default:
t.Fatalf("bad message type %T", msg)
}
}
// nextResponse handles incoming requests and notifications, and returns the
// next incoming response.
nextResponse := func() (*JSONRPCResponse, error, bool) {
for {
msg, err := cStream.Read(ctx)
if err != nil {
// TODO(rfindley): we don't document (or want to document) that the in
// memory transports use a net.Pipe. How can users detect this failure?
// Should we promote it to EOF?
if errors.Is(err, io.ErrClosedPipe) {
err = nil
}
return nil, err, false
}
serverMessages = append(serverMessages, msg)
if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() {
// Pair up the next outgoing response with this request.
// We assume requests arrive in the same order every time.
if len(outResponses) == 0 {
t.Fatalf("no outgoing response for request %v", req)
}
outResponses[0].ID = req.ID
writeMsg(outResponses[0])
outResponses = outResponses[1:]
continue
}
return msg.(*JSONRPCResponse), nil, true
}
}
// Synthetic peer interacts with real peer.
for _, req := range outRequests {
writeMsg(req)
if req.ID.IsValid() {
// A request (as opposed to a notification). Wait for the response.
res, err, ok := nextResponse()
if err != nil {
t.Fatalf("reading server messages failed: %v", err)
}
if !ok {
t.Fatalf("missing response for request %v", req)
}
if res.ID != req.ID {
t.Fatalf("out-of-order response %v to request %v", req, res)
}
}
}
// There might be more notifications or requests, but there shouldn't be more
// responses.
// Run this in a goroutine so the current thread can wait for it.
var extra *JSONRPCResponse
go func() {
extra, err, _ = nextResponse()
}()
// Before closing the stream, wait for all messages to be processed.
synctest.Wait() // => stdversion "requires go1.25" false positive (#75367)
if err != nil {
t.Fatalf("reading server messages failedd: %v", err)
}
if extra != nil {
t.Fatalf("got extra response: %v", extra)
}
if err := cStream.Close(); err != nil {
t.Fatalf("Stream.Close failed: %v", err)
}
ss.Wait() // ignore error
// Handle server output. If -update is set, write the 'server' file.
// Otherwise, compare with expected.
if *update {
arch := &txtar.Archive{
Comment: test.archive.Comment,
}
var buf bytes.Buffer
for _, msg := range serverMessages {
data, err := jsonrpc2.EncodeIndent(msg, "", "\t")
if err != nil {
t.Fatalf("jsonrpc2.EncodeIndent failed: %v", err)
}
buf.Write(data)
buf.WriteByte('\n')
}
serverFile := txtar.File{Name: "server", Data: buf.Bytes()}
seenServer := false // replace or append the 'server' file
for _, f := range test.archive.Files {
if f.Name == "server" {
seenServer = true
arch.Files = append(arch.Files, serverFile)
} else {
arch.Files = append(arch.Files, f)
}
}
if !seenServer {
arch.Files = append(arch.Files, serverFile)
}
if err := os.WriteFile(test.path, txtar.Format(arch), 0o666); err != nil {
t.Fatalf("os.WriteFile(%q) failed: %v", test.path, err)
}
} else {
// JSONRPCMessages are not comparable, so we instead compare lines of JSON.
transform := cmpopts.AcyclicTransformer("toJSON", func(msg JSONRPCMessage) []string {
encoded, err := jsonrpc2.EncodeIndent(msg, "", "\t")
if err != nil {
t.Fatal(err)
}
return strings.Split(string(encoded), "\n")
})
if diff := cmp.Diff(test.server, serverMessages, transform); diff != "" {
t.Errorf("Mismatching server messages (-want +got):\n%s", diff)
}
}
}
// loadConformanceTest loads one conformance test from the txtar file at path,
// which must be contained in the root directory dir. The test's name is path
// relative to dir.
//
// Recognized archive sections are "tools", "prompts" and "resources" (one
// feature name per non-blank line), and "client" and "server" (a stream of
// JSON-RPC messages). An archive with no sections, a duplicated section, an
// unknown section, or an undecodable message stream is an error. Underlying
// errors are wrapped with %w so callers can inspect them.
func loadConformanceTest(dir, path string) (*conformanceTest, error) {
	content, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	test := &conformanceTest{
		name:    strings.TrimPrefix(path, dir+string(filepath.Separator)),
		path:    path,
		archive: txtar.Parse(content),
	}
	if len(test.archive.Files) == 0 {
		return nil, fmt.Errorf("txtar archive %q has no '-- filename --' sections", path)
	}

	// decodeMessages decodes a stream of concatenated JSON values into
	// JSON-RPC messages.
	decodeMessages := func(data []byte) ([]JSONRPCMessage, error) {
		dec := json.NewDecoder(bytes.NewReader(data))
		var res []JSONRPCMessage
		for dec.More() {
			var raw json.RawMessage
			if err := dec.Decode(&raw); err != nil {
				return nil, err
			}
			m, err := jsonrpc2.DecodeMessage(raw)
			if err != nil {
				return nil, err
			}
			res = append(res, m)
		}
		return res, nil
	}

	// loadFeatures returns the non-blank lines of data, with surrounding
	// whitespace trimmed.
	loadFeatures := func(data []byte) []string {
		var feats []string
		for line := range strings.Lines(string(data)) {
			if f := strings.TrimSpace(line); f != "" {
				feats = append(feats, f)
			}
		}
		return feats
	}

	seen := make(map[string]bool) // catch accidentally duplicate files
	for _, f := range test.archive.Files {
		if seen[f.Name] {
			return nil, fmt.Errorf("duplicate file name %q", f.Name)
		}
		seen[f.Name] = true
		switch f.Name {
		case "tools":
			test.tools = loadFeatures(f.Data)
		case "prompts":
			test.prompts = loadFeatures(f.Data)
		case "resources":
			test.resources = loadFeatures(f.Data)
		case "client":
			test.client, err = decodeMessages(f.Data)
			if err != nil {
				return nil, fmt.Errorf("txtar archive %q contains bad -- client -- section: %w", path, err)
			}
		case "server":
			test.server, err = decodeMessages(f.Data)
			if err != nil {
				return nil, fmt.Errorf("txtar archive %q contains bad -- server -- section: %w", path, err)
			}
		default:
			return nil, fmt.Errorf("txtar archive %q contains unexpected file %q", path, f.Name)
		}
	}
	return test, nil
}