|  | // Copyright 2025 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package mcp | 
|  |  | 
|  | import ( | 
|  | "bytes" | 
|  | "context" | 
|  | "errors" | 
|  | "fmt" | 
|  | "io" | 
|  | "log/slog" | 
|  | "net/url" | 
|  | "path/filepath" | 
|  | "runtime" | 
|  | "slices" | 
|  | "strings" | 
|  | "sync" | 
|  | "testing" | 
|  | "time" | 
|  |  | 
|  | "github.com/google/go-cmp/cmp" | 
|  | "github.com/google/go-cmp/cmp/cmpopts" | 
|  | jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" | 
|  | "golang.org/x/tools/internal/mcp/jsonschema" | 
|  | ) | 
|  |  | 
// hiParams holds the arguments to the "greet" tool (see sayHi).
// NOTE(review): the tool's input schema appears to be inferred from this
// struct (TestEndToEnd expects a required string property "Name"), so
// renaming or retyping Name will change that expected schema.
type hiParams struct {
	Name string
}
|  |  | 
|  | func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { | 
|  | if err := ss.Ping(ctx, nil); err != nil { | 
|  | return nil, fmt.Errorf("ping failed: %v", err) | 
|  | } | 
|  | return &CallToolResultFor[any]{Content: []*Content{NewTextContent("hi " + params.Arguments.Name)}}, nil | 
|  | } | 
|  |  | 
|  | func TestEndToEnd(t *testing.T) { | 
|  | ctx := context.Background() | 
|  | var ct, st Transport = NewInMemoryTransports() | 
|  |  | 
|  | // Channels to check if notification callbacks happened. | 
|  | notificationChans := map[string]chan int{} | 
|  | for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} { | 
|  | notificationChans[name] = make(chan int, 1) | 
|  | } | 
|  | waitForNotification := func(t *testing.T, name string) { | 
|  | t.Helper() | 
|  | select { | 
|  | case <-notificationChans[name]: | 
|  | case <-time.After(time.Second): | 
|  | t.Fatalf("%s handler never called", name) | 
|  | } | 
|  | } | 
|  |  | 
|  | sopts := &ServerOptions{ | 
|  | InitializedHandler:      func(context.Context, *ServerSession, *InitializedParams) { notificationChans["initialized"] <- 0 }, | 
|  | RootsListChangedHandler: func(context.Context, *ServerSession, *RootsListChangedParams) { notificationChans["roots"] <- 0 }, | 
|  | ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { | 
|  | notificationChans["progress_server"] <- 0 | 
|  | }, | 
|  | } | 
|  | s := NewServer("testServer", "v1.0.0", sopts) | 
|  | add(tools, s.AddTools, "greet", "fail") | 
|  | add(prompts, s.AddPrompts, "code_review", "fail") | 
|  | add(resources, s.AddResources, "info.txt", "fail.txt") | 
|  |  | 
|  | // Connect the server. | 
|  | ss, err := s.Connect(ctx, st) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if got := slices.Collect(s.Sessions()); len(got) != 1 { | 
|  | t.Errorf("after connection, Clients() has length %d, want 1", len(got)) | 
|  | } | 
|  |  | 
|  | // Wait for the server to exit after the client closes its connection. | 
|  | var clientWG sync.WaitGroup | 
|  | clientWG.Add(1) | 
|  | go func() { | 
|  | if err := ss.Wait(); err != nil { | 
|  | t.Errorf("server failed: %v", err) | 
|  | } | 
|  | clientWG.Done() | 
|  | }() | 
|  |  | 
|  | loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging | 
|  | opts := &ClientOptions{ | 
|  | CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { | 
|  | return &CreateMessageResult{Model: "aModel"}, nil | 
|  | }, | 
|  | ToolListChangedHandler:     func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, | 
|  | PromptListChangedHandler:   func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, | 
|  | ResourceListChangedHandler: func(context.Context, *ClientSession, *ResourceListChangedParams) { notificationChans["resources"] <- 0 }, | 
|  | LoggingMessageHandler: func(_ context.Context, _ *ClientSession, lm *LoggingMessageParams) { | 
|  | loggingMessages <- lm | 
|  | }, | 
|  | ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { | 
|  | notificationChans["progress_client"] <- 0 | 
|  | }, | 
|  | } | 
|  | c := NewClient("testClient", "v1.0.0", opts) | 
|  | rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | c.AddRoots(&Root{URI: "file://" + rootAbs}) | 
|  |  | 
|  | // Connect the client. | 
|  | cs, err := c.Connect(ctx, ct) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | waitForNotification(t, "initialized") | 
|  | if err := cs.Ping(ctx, nil); err != nil { | 
|  | t.Fatalf("ping failed: %v", err) | 
|  | } | 
|  | t.Run("prompts", func(t *testing.T) { | 
|  | res, err := cs.ListPrompts(ctx, nil) | 
|  | if err != nil { | 
|  | t.Fatalf("prompts/list failed: %v", err) | 
|  | } | 
|  | wantPrompts := []*Prompt{ | 
|  | { | 
|  | Name:        "code_review", | 
|  | Description: "do a code review", | 
|  | Arguments:   []*PromptArgument{{Name: "Code", Required: true}}, | 
|  | }, | 
|  | {Name: "fail"}, | 
|  | } | 
|  | if diff := cmp.Diff(wantPrompts, res.Prompts); diff != "" { | 
|  | t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | gotReview, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "code_review", Arguments: map[string]string{"Code": "1+1"}}) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | wantReview := &GetPromptResult{ | 
|  | Description: "Code review prompt", | 
|  | Messages: []*PromptMessage{{ | 
|  | Content: NewTextContent("Please review the following code: 1+1"), | 
|  | Role:    "user", | 
|  | }}, | 
|  | } | 
|  | if diff := cmp.Diff(wantReview, gotReview); diff != "" { | 
|  | t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | if _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), errTestFailure.Error()) { | 
|  | t.Errorf("fail returned unexpected error: got %v, want containing %v", err, errTestFailure) | 
|  | } | 
|  |  | 
|  | s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) | 
|  | waitForNotification(t, "prompts") | 
|  | s.RemovePrompts("T") | 
|  | waitForNotification(t, "prompts") | 
|  | }) | 
|  |  | 
|  | t.Run("tools", func(t *testing.T) { | 
|  | res, err := cs.ListTools(ctx, nil) | 
|  | if err != nil { | 
|  | t.Errorf("tools/list failed: %v", err) | 
|  | } | 
|  | wantTools := []*Tool{ | 
|  | { | 
|  | Name:        "fail", | 
|  | InputSchema: nil, | 
|  | }, | 
|  | { | 
|  | Name:        "greet", | 
|  | Description: "say hi", | 
|  | InputSchema: &jsonschema.Schema{ | 
|  | Type:     "object", | 
|  | Required: []string{"Name"}, | 
|  | Properties: map[string]*jsonschema.Schema{ | 
|  | "Name": {Type: "string"}, | 
|  | }, | 
|  | AdditionalProperties: falseSchema(), | 
|  | }, | 
|  | }, | 
|  | } | 
|  | if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { | 
|  | t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | gotHi, err := cs.CallTool(ctx, &CallToolParams{ | 
|  | Name:      "greet", | 
|  | Arguments: map[string]any{"name": "user"}, | 
|  | }) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | wantHi := &CallToolResult{ | 
|  | Content: []*Content{{Type: "text", Text: "hi user"}}, | 
|  | } | 
|  | if diff := cmp.Diff(wantHi, gotHi); diff != "" { | 
|  | t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | gotFail, err := cs.CallTool(ctx, &CallToolParams{ | 
|  | Name:      "fail", | 
|  | Arguments: map[string]any{}, | 
|  | }) | 
|  | // Counter-intuitively, when a tool fails, we don't expect an RPC error for | 
|  | // call tool: instead, the failure is embedded in the result. | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | wantFail := &CallToolResult{ | 
|  | IsError: true, | 
|  | Content: []*Content{{Type: "text", Text: errTestFailure.Error()}}, | 
|  | } | 
|  | if diff := cmp.Diff(wantFail, gotFail); diff != "" { | 
|  | t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}, Handler: nopHandler}) | 
|  | waitForNotification(t, "tools") | 
|  | s.RemoveTools("T") | 
|  | waitForNotification(t, "tools") | 
|  | }) | 
|  |  | 
|  | t.Run("resources", func(t *testing.T) { | 
|  | if runtime.GOOS == "windows" { | 
|  | t.Skip("TODO: fix for Windows") | 
|  | } | 
|  | wantResources := []*Resource{resource2, resource1} | 
|  | lrres, err := cs.ListResources(ctx, nil) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if diff := cmp.Diff(wantResources, lrres.Resources); diff != "" { | 
|  | t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | template := &ResourceTemplate{ | 
|  | Name:        "rt", | 
|  | MIMEType:    "text/template", | 
|  | URITemplate: "file:///{+filename}", // the '+' means that filename can contain '/' | 
|  | } | 
|  | st := &ServerResourceTemplate{ResourceTemplate: template, Handler: readHandler} | 
|  | s.AddResourceTemplates(st) | 
|  | tres, err := cs.ListResourceTemplates(ctx, nil) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if diff := cmp.Diff([]*ResourceTemplate{template}, tres.ResourceTemplates); diff != "" { | 
|  | t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | for _, tt := range []struct { | 
|  | uri      string | 
|  | mimeType string // "": not found; "text/plain": resource; "text/template": template | 
|  | fail     bool   // non-nil error returned | 
|  | }{ | 
|  | {"file:///info.txt", "text/plain", false}, | 
|  | {"file:///fail.txt", "", false}, | 
|  | {"file:///template.txt", "text/template", false}, | 
|  | {"file:///../private.txt", "", true}, // not found: escaping disallowed | 
|  | } { | 
|  | rres, err := cs.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) | 
|  | if err != nil { | 
|  | if code := errorCode(err); code == CodeResourceNotFound { | 
|  | if tt.mimeType != "" { | 
|  | t.Errorf("%s: not found but expected it to be", tt.uri) | 
|  | } | 
|  | } else if !tt.fail { | 
|  | t.Errorf("%s: unexpected error %v", tt.uri, err) | 
|  | } | 
|  | } else { | 
|  | if tt.fail { | 
|  | t.Errorf("%s: unexpected success", tt.uri) | 
|  | } else if g, w := len(rres.Contents), 1; g != w { | 
|  | t.Errorf("got %d contents, wanted %d", g, w) | 
|  | } else { | 
|  | c := rres.Contents[0] | 
|  | if got := c.URI; got != tt.uri { | 
|  | t.Errorf("got uri %q, want %q", got, tt.uri) | 
|  | } | 
|  | if got := c.MIMEType; got != tt.mimeType { | 
|  | t.Errorf("%s: got MIME type %q, want %q", tt.uri, got, tt.mimeType) | 
|  | } | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | s.AddResources(&ServerResource{Resource: &Resource{URI: "http://U"}}) | 
|  | waitForNotification(t, "resources") | 
|  | s.RemoveResources("http://U") | 
|  | waitForNotification(t, "resources") | 
|  | }) | 
|  | t.Run("roots", func(t *testing.T) { | 
|  | rootRes, err := ss.ListRoots(ctx, &ListRootsParams{}) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | gotRoots := rootRes.Roots | 
|  | wantRoots := slices.Collect(c.roots.all()) | 
|  | if diff := cmp.Diff(wantRoots, gotRoots); diff != "" { | 
|  | t.Errorf("roots/list mismatch (-want +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | c.AddRoots(&Root{URI: "U"}) | 
|  | waitForNotification(t, "roots") | 
|  | c.RemoveRoots("U") | 
|  | waitForNotification(t, "roots") | 
|  | }) | 
|  | t.Run("sampling", func(t *testing.T) { | 
|  | // TODO: test that a client that doesn't have the handler returns CodeUnsupportedMethod. | 
|  | res, err := ss.CreateMessage(ctx, &CreateMessageParams{}) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if g, w := res.Model, "aModel"; g != w { | 
|  | t.Errorf("got %q, want %q", g, w) | 
|  | } | 
|  | }) | 
|  | t.Run("logging", func(t *testing.T) { | 
|  | want := []*LoggingMessageParams{ | 
|  | { | 
|  | Logger: "test", | 
|  | Level:  "warning", | 
|  | Data: map[string]any{ | 
|  | "msg":     "first", | 
|  | "name":    "Pat", | 
|  | "logtest": true, | 
|  | }, | 
|  | }, | 
|  | { | 
|  | Logger: "test", | 
|  | Level:  "alert", | 
|  | Data: map[string]any{ | 
|  | "msg":     "second", | 
|  | "count":   2.0, | 
|  | "logtest": true, | 
|  | }, | 
|  | }, | 
|  | } | 
|  |  | 
|  | check := func(t *testing.T) { | 
|  | t.Helper() | 
|  | var got []*LoggingMessageParams | 
|  | // Read messages from this test until we've seen all we expect. | 
|  | for len(got) < len(want) { | 
|  | select { | 
|  | case p := <-loggingMessages: | 
|  | // Ignore logging from other tests. | 
|  | if m, ok := p.Data.(map[string]any); ok && m["logtest"] != nil { | 
|  | delete(m, "time") // remove time because it changes | 
|  | got = append(got, p) | 
|  | } | 
|  | case <-time.After(time.Second): | 
|  | t.Fatal("timed out waiting for log messages") | 
|  | } | 
|  | } | 
|  | if diff := cmp.Diff(want, got); diff != "" { | 
|  | t.Errorf("mismatch (-want, +got):\n%s", diff) | 
|  | } | 
|  | } | 
|  |  | 
|  | t.Run("direct", func(t *testing.T) { // Use the LoggingMessage method directly. | 
|  |  | 
|  | mustLog := func(level LoggingLevel, data any) { | 
|  | t.Helper() | 
|  | if err := ss.LoggingMessage(ctx, &LoggingMessageParams{ | 
|  | Logger: "test", | 
|  | Level:  level, | 
|  | Data:   data, | 
|  | }); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | } | 
|  |  | 
|  | // Nothing should be logged until the client sets a level. | 
|  | mustLog("info", "before") | 
|  | if err := cs.SetLevel(ctx, &SetLevelParams{Level: "warning"}); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | mustLog("warning", want[0].Data) | 
|  | mustLog("debug", "nope")    // below the level | 
|  | mustLog("info", "negative") // below the level | 
|  | mustLog("alert", want[1].Data) | 
|  | check(t) | 
|  | }) | 
|  |  | 
|  | t.Run("handler", func(t *testing.T) { // Use the slog handler. | 
|  | // We can't check the "before SetLevel" behavior because it's already been set. | 
|  | // Not a big deal: that check is in LoggingMessage anyway. | 
|  | logger := slog.New(NewLoggingHandler(ss, &LoggingHandlerOptions{LoggerName: "test"})) | 
|  | logger.Warn("first", "name", "Pat", "logtest", true) | 
|  | logger.Debug("nope")    // below the level | 
|  | logger.Info("negative") // below the level | 
|  | logger.Log(ctx, LevelAlert, "second", "count", 2, "logtest", true) | 
|  | check(t) | 
|  | }) | 
|  | }) | 
|  | t.Run("progress", func(t *testing.T) { | 
|  | ss.NotifyProgress(ctx, &ProgressNotificationParams{ | 
|  | ProgressToken: "token-xyz", | 
|  | Message:       "progress update", | 
|  | Progress:      0.5, | 
|  | Total:         2, | 
|  | }) | 
|  | waitForNotification(t, "progress_client") | 
|  | cs.NotifyProgress(ctx, &ProgressNotificationParams{ | 
|  | ProgressToken: "token-abc", | 
|  | Message:       "progress update", | 
|  | Progress:      1, | 
|  | Total:         2, | 
|  | }) | 
|  | waitForNotification(t, "progress_server") | 
|  | }) | 
|  |  | 
|  | // Disconnect. | 
|  | cs.Close() | 
|  | clientWG.Wait() | 
|  |  | 
|  | // After disconnecting, neither client nor server should have any | 
|  | // connections. | 
|  | for range s.Sessions() { | 
|  | t.Errorf("unexpected client after disconnection") | 
|  | } | 
|  | } | 
|  |  | 
// Registry of values to be referenced in tests.
//
// Tests install features from these maps by name using add, e.g.
// add(tools, s.AddTools, "greet", "fail").
var (
	// errTestFailure is returned by the "fail" tool and the "fail" prompt, so
	// tests can recognize that failure in results and error messages.
	errTestFailure = errors.New("mcp failure")

	// tools maps names to server tools: "greet" is backed by sayHi, and "fail"
	// always returns errTestFailure.
	tools = map[string]*ServerTool{
		"greet": NewServerTool("greet", "say hi", sayHi),
		"fail": {
			Tool: &Tool{Name: "fail"},
			Handler: func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
				return nil, errTestFailure
			},
		},
	}

	// prompts maps names to server prompts: "code_review" embeds its "Code"
	// argument in a single user message, and "fail" always returns
	// errTestFailure.
	prompts = map[string]*ServerPrompt{
		"code_review": {
			Prompt: &Prompt{
				Name:        "code_review",
				Description: "do a code review",
				Arguments:   []*PromptArgument{{Name: "Code", Required: true}},
			},
			Handler: func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) {
				return &GetPromptResult{
					Description: "Code review prompt",
					Messages: []*PromptMessage{
						{Role: "user", Content: NewTextContent("Please review the following code: " + params.Arguments["Code"])},
					},
				}, nil
			},
		},
		"fail": {
			Prompt: &Prompt{Name: "fail"},
			Handler: func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) {
				return nil, errTestFailure
			},
		},
	}

	// resource1 and resource2 are file: resources served by readHandler.
	// resource3 is served by handleEmbeddedResource.
	resource1 = &Resource{
		Name:     "public",
		MIMEType: "text/plain",
		URI:      "file:///info.txt",
	}
	// NOTE(review): TestEndToEnd expects reading this URI to report
	// not-found, which suggests fail.txt is absent from testdata/files.
	resource2 = &Resource{
		Name:     "public", // names are not unique IDs
		MIMEType: "text/plain",
		URI:      "file:///fail.txt",
	}
	resource3 = &Resource{
		Name:     "info",
		MIMEType: "text/plain",
		URI:      "embedded:info",
	}
	// readHandler is the file-backed handler for the testdata/files directory
	// (see fileResourceHandler, defined elsewhere). It is also reused as the
	// handler of the resource template registered in TestEndToEnd.
	// resources maps registry names to server resources.
	readHandler = fileResourceHandler("testdata/files")
	resources   = map[string]*ServerResource{
		"info.txt": {resource1, readHandler},
		"fail.txt": {resource2, readHandler},
		"info":     {resource3, handleEmbeddedResource},
	}
)
|  |  | 
|  | var embeddedResources = map[string]string{ | 
|  | "info": "This is the MCP test server.", | 
|  | } | 
|  |  | 
|  | func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { | 
|  | u, err := url.Parse(params.URI) | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | if u.Scheme != "embedded" { | 
|  | return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) | 
|  | } | 
|  | key := u.Opaque | 
|  | text, ok := embeddedResources[key] | 
|  | if !ok { | 
|  | return nil, fmt.Errorf("no embedded resource named %q", key) | 
|  | } | 
|  | return &ReadResourceResult{ | 
|  | Contents: []*ResourceContents{NewTextResourceContents(params.URI, "text/plain", text)}, | 
|  | }, nil | 
|  | } | 
|  |  | 
// add looks up each of names in m, in order, and passes the feature found to
// add. It panics on the first name that is missing from m. A missing name is
// a mistake in test setup, not a condition a test should handle.
func add[T any](m map[string]T, add func(...T), names ...string) {
	for i := 0; i < len(names); i++ {
		name := names[i]
		if feat, ok := m[name]; ok {
			add(feat)
			continue
		}
		panic("missing feature " + name)
	}
}
|  |  | 
|  | // errorCode returns the code associated with err. | 
|  | // If err is nil, it returns 0. | 
|  | // If there is no code, it returns -1. | 
|  | func errorCode(err error) int64 { | 
|  | if err == nil { | 
|  | return 0 | 
|  | } | 
|  | var werr *jsonrpc2.WireError | 
|  | if errors.As(err, &werr) { | 
|  | return werr.Code | 
|  | } | 
|  | return -1 | 
|  | } | 
|  |  | 
|  | // basicConnection returns a new basic client-server connection configured with | 
|  | // the provided tools. | 
|  | // | 
|  | // The caller should cancel either the client connection or server connection | 
|  | // when the connections are no longer needed. | 
|  | func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) { | 
|  | t.Helper() | 
|  |  | 
|  | ctx := context.Background() | 
|  | ct, st := NewInMemoryTransports() | 
|  |  | 
|  | s := NewServer("testServer", "v1.0.0", nil) | 
|  |  | 
|  | // The 'greet' tool says hi. | 
|  | s.AddTools(tools...) | 
|  | ss, err := s.Connect(ctx, st) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | c := NewClient("testClient", "v1.0.0", nil) | 
|  | cs, err := c.Connect(ctx, ct) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | return ss, cs | 
|  | } | 
|  |  | 
|  | func TestServerClosing(t *testing.T) { | 
|  | cc, cs := basicConnection(t, NewServerTool("greet", "say hi", sayHi)) | 
|  | defer cs.Close() | 
|  |  | 
|  | ctx := context.Background() | 
|  | var wg sync.WaitGroup | 
|  | wg.Add(1) | 
|  | go func() { | 
|  | if err := cs.Wait(); err != nil { | 
|  | t.Errorf("server connection failed: %v", err) | 
|  | } | 
|  | wg.Done() | 
|  | }() | 
|  | if _, err := cs.CallTool(ctx, &CallToolParams{ | 
|  | Name:      "greet", | 
|  | Arguments: map[string]any{"name": "user"}, | 
|  | }); err != nil { | 
|  | t.Fatalf("after connecting: %v", err) | 
|  | } | 
|  | cc.Close() | 
|  | wg.Wait() | 
|  | if _, err := cs.CallTool(ctx, &CallToolParams{ | 
|  | Name:      "greet", | 
|  | Arguments: map[string]any{"name": "user"}, | 
|  | }); !errors.Is(err, ErrConnectionClosed) { | 
|  | t.Errorf("after disconnection, got error %v, want EOF", err) | 
|  | } | 
|  | } | 
|  |  | 
|  | func TestBatching(t *testing.T) { | 
|  | ctx := context.Background() | 
|  | ct, st := NewInMemoryTransports() | 
|  |  | 
|  | s := NewServer("testServer", "v1.0.0", nil) | 
|  | _, err := s.Connect(ctx, st) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | c := NewClient("testClient", "v1.0.0", nil) | 
|  | // TODO: this test is broken, because increasing the batch size here causes | 
|  | // 'initialize' to block. Therefore, we can only test with a size of 1. | 
|  | // Since batching is being removed, we can probably just delete this. | 
|  | const batchSize = 1 | 
|  | cs, err := c.Connect(ctx, ct) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | defer cs.Close() | 
|  |  | 
|  | errs := make(chan error, batchSize) | 
|  | for i := range batchSize { | 
|  | go func() { | 
|  | _, err := cs.ListTools(ctx, nil) | 
|  | errs <- err | 
|  | }() | 
|  | time.Sleep(2 * time.Millisecond) | 
|  | if i < batchSize-1 { | 
|  | select { | 
|  | case <-errs: | 
|  | t.Errorf("ListTools: unexpected result for incomplete batch: %v", err) | 
|  | default: | 
|  | } | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | func TestCancellation(t *testing.T) { | 
|  | var ( | 
|  | start     = make(chan struct{}) | 
|  | cancelled = make(chan struct{}, 1) // don't block the request | 
|  | ) | 
|  |  | 
|  | slowRequest := func(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { | 
|  | start <- struct{}{} | 
|  | select { | 
|  | case <-ctx.Done(): | 
|  | cancelled <- struct{}{} | 
|  | case <-time.After(5 * time.Second): | 
|  | return nil, nil | 
|  | } | 
|  | return nil, nil | 
|  | } | 
|  | st := &ServerTool{ | 
|  | Tool:    &Tool{Name: "slow"}, | 
|  | Handler: slowRequest, | 
|  | } | 
|  | _, cs := basicConnection(t, st) | 
|  | defer cs.Close() | 
|  |  | 
|  | ctx, cancel := context.WithCancel(context.Background()) | 
|  | go cs.CallTool(ctx, &CallToolParams{Name: "slow"}) | 
|  | <-start | 
|  | cancel() | 
|  | select { | 
|  | case <-cancelled: | 
|  | case <-time.After(5 * time.Second): | 
|  | t.Fatal("timeout waiting for cancellation") | 
|  | } | 
|  | } | 
|  |  | 
|  | func TestMiddleware(t *testing.T) { | 
|  | ctx := context.Background() | 
|  | ct, st := NewInMemoryTransports() | 
|  |  | 
|  | s := NewServer("testServer", "v1.0.0", nil) | 
|  | ss, err := s.Connect(ctx, st) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | // Wait for the server to exit after the client closes its connection. | 
|  | var clientWG sync.WaitGroup | 
|  | clientWG.Add(1) | 
|  | go func() { | 
|  | if err := ss.Wait(); err != nil { | 
|  | t.Errorf("server failed: %v", err) | 
|  | } | 
|  | clientWG.Done() | 
|  | }() | 
|  |  | 
|  | var sbuf, cbuf bytes.Buffer | 
|  | sbuf.WriteByte('\n') | 
|  | cbuf.WriteByte('\n') | 
|  |  | 
|  | // "1" is the outer middleware layer, called first; then "2" is called, and finally | 
|  | // the default dispatcher. | 
|  | s.AddSendingMiddleware(traceCalls[*ServerSession](&sbuf, "S1"), traceCalls[*ServerSession](&sbuf, "S2")) | 
|  | s.AddReceivingMiddleware(traceCalls[*ServerSession](&sbuf, "R1"), traceCalls[*ServerSession](&sbuf, "R2")) | 
|  |  | 
|  | c := NewClient("testClient", "v1.0.0", nil) | 
|  | c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) | 
|  | c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) | 
|  |  | 
|  | cs, err := c.Connect(ctx, ct) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := cs.ListTools(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := ss.ListRoots(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | wantServer := ` | 
|  | R1 >initialize | 
|  | R2 >initialize | 
|  | R2 <initialize | 
|  | R1 <initialize | 
|  | R1 >notifications/initialized | 
|  | R2 >notifications/initialized | 
|  | R2 <notifications/initialized | 
|  | R1 <notifications/initialized | 
|  | R1 >tools/list | 
|  | R2 >tools/list | 
|  | R2 <tools/list | 
|  | R1 <tools/list | 
|  | S1 >roots/list | 
|  | S2 >roots/list | 
|  | S2 <roots/list | 
|  | S1 <roots/list | 
|  | ` | 
|  | if diff := cmp.Diff(wantServer, sbuf.String()); diff != "" { | 
|  | t.Errorf("server mismatch (-want, +got):\n%s", diff) | 
|  | } | 
|  |  | 
|  | wantClient := ` | 
|  | S1 >initialize | 
|  | S2 >initialize | 
|  | S2 <initialize | 
|  | S1 <initialize | 
|  | S1 >notifications/initialized | 
|  | S2 >notifications/initialized | 
|  | S2 <notifications/initialized | 
|  | S1 <notifications/initialized | 
|  | S1 >tools/list | 
|  | S2 >tools/list | 
|  | S2 <tools/list | 
|  | S1 <tools/list | 
|  | R1 >roots/list | 
|  | R2 >roots/list | 
|  | R2 <roots/list | 
|  | R1 <roots/list | 
|  | ` | 
|  | if diff := cmp.Diff(wantClient, cbuf.String()); diff != "" { | 
|  | t.Errorf("client mismatch (-want, +got):\n%s", diff) | 
|  | } | 
|  | } | 
|  |  | 
// safeBuffer is a bytes.Buffer guarded by a mutex. Writers, such as a
// logging transport, can use it concurrently while a test reads it.
type safeBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends data to the buffer while holding the lock.
func (b *safeBuffer) Write(data []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.Write(data)
}

// Bytes returns a copy of the data written so far.
//
// It returns a copy rather than the internal slice for two reasons.
// bytes.Buffer.Bytes is only valid until the next modification of the buffer,
// and writers may modify it as soon as the lock is released. Also, a caller
// modifying the result must not corrupt the buffer's contents.
func (b *safeBuffer) Bytes() []byte {
	b.mu.Lock()
	defer b.mu.Unlock()
	return bytes.Clone(b.buf.Bytes())
}
|  |  | 
|  | func TestNoJSONNull(t *testing.T) { | 
|  | ctx := context.Background() | 
|  | var ct, st Transport = NewInMemoryTransports() | 
|  |  | 
|  | // Collect logs, to sanity check that we don't write JSON null anywhere. | 
|  | var logbuf safeBuffer | 
|  | ct = NewLoggingTransport(ct, &logbuf) | 
|  |  | 
|  | s := NewServer("testServer", "v1.0.0", nil) | 
|  | ss, err := s.Connect(ctx, st) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | c := NewClient("testClient", "v1.0.0", nil) | 
|  | cs, err := c.Connect(ctx, ct) | 
|  | if err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := cs.ListTools(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := cs.ListPrompts(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := cs.ListResources(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := cs.ListResourceTemplates(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  | if _, err := ss.ListRoots(ctx, nil); err != nil { | 
|  | t.Fatal(err) | 
|  | } | 
|  |  | 
|  | cs.Close() | 
|  | ss.Wait() | 
|  |  | 
|  | logs := logbuf.Bytes() | 
|  | if i := bytes.Index(logs, []byte("null")); i >= 0 { | 
|  | start := max(i-20, 0) | 
|  | end := min(i+20, len(logs)) | 
|  | t.Errorf("conformance violation: MCP logs contain JSON null: %s", "..."+string(logs[start:end])+"...") | 
|  | } | 
|  | } | 
|  |  | 
|  | // traceCalls creates a middleware function that prints the method before and after each call | 
|  | // with the given prefix. | 
|  | func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { | 
|  | return func(h MethodHandler[S]) MethodHandler[S] { | 
|  | return func(ctx context.Context, sess S, method string, params Params) (Result, error) { | 
|  | fmt.Fprintf(w, "%s >%s\n", prefix, method) | 
|  | defer fmt.Fprintf(w, "%s <%s\n", prefix, method) | 
|  | return h(ctx, sess, method, params) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // A function, because schemas must form a tree (they have hidden state). | 
|  | func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} } | 
|  |  | 
|  | func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { | 
|  | return nil, nil | 
|  | } |