| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mcp |
| |
| import ( |
| "bytes" |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "log/slog" |
| "net/url" |
| "path/filepath" |
| "runtime" |
| "slices" |
| "strings" |
| "sync" |
| "testing" |
| "time" |
| |
| "github.com/google/go-cmp/cmp" |
| "github.com/google/go-cmp/cmp/cmpopts" |
| jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2" |
| "golang.org/x/tools/internal/mcp/jsonschema" |
| ) |
| |
| type hiParams struct { |
| Name string |
| } |
| |
| func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { |
| if err := ss.Ping(ctx, nil); err != nil { |
| return nil, fmt.Errorf("ping failed: %v", err) |
| } |
| return &CallToolResultFor[any]{Content: []*Content{NewTextContent("hi " + params.Arguments.Name)}}, nil |
| } |
| |
| func TestEndToEnd(t *testing.T) { |
| ctx := context.Background() |
| var ct, st Transport = NewInMemoryTransports() |
| |
| // Channels to check if notification callbacks happened. |
| notificationChans := map[string]chan int{} |
| for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} { |
| notificationChans[name] = make(chan int, 1) |
| } |
| waitForNotification := func(t *testing.T, name string) { |
| t.Helper() |
| select { |
| case <-notificationChans[name]: |
| case <-time.After(time.Second): |
| t.Fatalf("%s handler never called", name) |
| } |
| } |
| |
| sopts := &ServerOptions{ |
| InitializedHandler: func(context.Context, *ServerSession, *InitializedParams) { notificationChans["initialized"] <- 0 }, |
| RootsListChangedHandler: func(context.Context, *ServerSession, *RootsListChangedParams) { notificationChans["roots"] <- 0 }, |
| ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { |
| notificationChans["progress_server"] <- 0 |
| }, |
| } |
| s := NewServer("testServer", "v1.0.0", sopts) |
| add(tools, s.AddTools, "greet", "fail") |
| add(prompts, s.AddPrompts, "code_review", "fail") |
| add(resources, s.AddResources, "info.txt", "fail.txt") |
| |
| // Connect the server. |
| ss, err := s.Connect(ctx, st) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if got := slices.Collect(s.Sessions()); len(got) != 1 { |
| t.Errorf("after connection, Clients() has length %d, want 1", len(got)) |
| } |
| |
| // Wait for the server to exit after the client closes its connection. |
| var clientWG sync.WaitGroup |
| clientWG.Add(1) |
| go func() { |
| if err := ss.Wait(); err != nil { |
| t.Errorf("server failed: %v", err) |
| } |
| clientWG.Done() |
| }() |
| |
| loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging |
| opts := &ClientOptions{ |
| CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { |
| return &CreateMessageResult{Model: "aModel"}, nil |
| }, |
| ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, |
| PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, |
| ResourceListChangedHandler: func(context.Context, *ClientSession, *ResourceListChangedParams) { notificationChans["resources"] <- 0 }, |
| LoggingMessageHandler: func(_ context.Context, _ *ClientSession, lm *LoggingMessageParams) { |
| loggingMessages <- lm |
| }, |
| ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { |
| notificationChans["progress_client"] <- 0 |
| }, |
| } |
| c := NewClient("testClient", "v1.0.0", opts) |
| rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) |
| if err != nil { |
| t.Fatal(err) |
| } |
| c.AddRoots(&Root{URI: "file://" + rootAbs}) |
| |
| // Connect the client. |
| cs, err := c.Connect(ctx, ct) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| waitForNotification(t, "initialized") |
| if err := cs.Ping(ctx, nil); err != nil { |
| t.Fatalf("ping failed: %v", err) |
| } |
| t.Run("prompts", func(t *testing.T) { |
| res, err := cs.ListPrompts(ctx, nil) |
| if err != nil { |
| t.Fatalf("prompts/list failed: %v", err) |
| } |
| wantPrompts := []*Prompt{ |
| { |
| Name: "code_review", |
| Description: "do a code review", |
| Arguments: []*PromptArgument{{Name: "Code", Required: true}}, |
| }, |
| {Name: "fail"}, |
| } |
| if diff := cmp.Diff(wantPrompts, res.Prompts); diff != "" { |
| t.Fatalf("prompts/list mismatch (-want +got):\n%s", diff) |
| } |
| |
| gotReview, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "code_review", Arguments: map[string]string{"Code": "1+1"}}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| wantReview := &GetPromptResult{ |
| Description: "Code review prompt", |
| Messages: []*PromptMessage{{ |
| Content: NewTextContent("Please review the following code: 1+1"), |
| Role: "user", |
| }}, |
| } |
| if diff := cmp.Diff(wantReview, gotReview); diff != "" { |
| t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) |
| } |
| |
| if _, err := cs.GetPrompt(ctx, &GetPromptParams{Name: "fail"}); err == nil || !strings.Contains(err.Error(), errTestFailure.Error()) { |
| t.Errorf("fail returned unexpected error: got %v, want containing %v", err, errTestFailure) |
| } |
| |
| s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) |
| waitForNotification(t, "prompts") |
| s.RemovePrompts("T") |
| waitForNotification(t, "prompts") |
| }) |
| |
| t.Run("tools", func(t *testing.T) { |
| res, err := cs.ListTools(ctx, nil) |
| if err != nil { |
| t.Errorf("tools/list failed: %v", err) |
| } |
| wantTools := []*Tool{ |
| { |
| Name: "fail", |
| InputSchema: nil, |
| }, |
| { |
| Name: "greet", |
| Description: "say hi", |
| InputSchema: &jsonschema.Schema{ |
| Type: "object", |
| Required: []string{"Name"}, |
| Properties: map[string]*jsonschema.Schema{ |
| "Name": {Type: "string"}, |
| }, |
| AdditionalProperties: falseSchema(), |
| }, |
| }, |
| } |
| if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { |
| t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) |
| } |
| |
| gotHi, err := cs.CallTool(ctx, &CallToolParams{ |
| Name: "greet", |
| Arguments: map[string]any{"name": "user"}, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| wantHi := &CallToolResult{ |
| Content: []*Content{{Type: "text", Text: "hi user"}}, |
| } |
| if diff := cmp.Diff(wantHi, gotHi); diff != "" { |
| t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) |
| } |
| |
| gotFail, err := cs.CallTool(ctx, &CallToolParams{ |
| Name: "fail", |
| Arguments: map[string]any{}, |
| }) |
| // Counter-intuitively, when a tool fails, we don't expect an RPC error for |
| // call tool: instead, the failure is embedded in the result. |
| if err != nil { |
| t.Fatal(err) |
| } |
| wantFail := &CallToolResult{ |
| IsError: true, |
| Content: []*Content{{Type: "text", Text: errTestFailure.Error()}}, |
| } |
| if diff := cmp.Diff(wantFail, gotFail); diff != "" { |
| t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) |
| } |
| |
| s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}, Handler: nopHandler}) |
| waitForNotification(t, "tools") |
| s.RemoveTools("T") |
| waitForNotification(t, "tools") |
| }) |
| |
| t.Run("resources", func(t *testing.T) { |
| if runtime.GOOS == "windows" { |
| t.Skip("TODO: fix for Windows") |
| } |
| wantResources := []*Resource{resource2, resource1} |
| lrres, err := cs.ListResources(ctx, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if diff := cmp.Diff(wantResources, lrres.Resources); diff != "" { |
| t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) |
| } |
| |
| template := &ResourceTemplate{ |
| Name: "rt", |
| MIMEType: "text/template", |
| URITemplate: "file:///{+filename}", // the '+' means that filename can contain '/' |
| } |
| st := &ServerResourceTemplate{ResourceTemplate: template, Handler: readHandler} |
| s.AddResourceTemplates(st) |
| tres, err := cs.ListResourceTemplates(ctx, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if diff := cmp.Diff([]*ResourceTemplate{template}, tres.ResourceTemplates); diff != "" { |
| t.Errorf("resources/list mismatch (-want, +got):\n%s", diff) |
| } |
| |
| for _, tt := range []struct { |
| uri string |
| mimeType string // "": not found; "text/plain": resource; "text/template": template |
| fail bool // non-nil error returned |
| }{ |
| {"file:///info.txt", "text/plain", false}, |
| {"file:///fail.txt", "", false}, |
| {"file:///template.txt", "text/template", false}, |
| {"file:///../private.txt", "", true}, // not found: escaping disallowed |
| } { |
| rres, err := cs.ReadResource(ctx, &ReadResourceParams{URI: tt.uri}) |
| if err != nil { |
| if code := errorCode(err); code == CodeResourceNotFound { |
| if tt.mimeType != "" { |
| t.Errorf("%s: not found but expected it to be", tt.uri) |
| } |
| } else if !tt.fail { |
| t.Errorf("%s: unexpected error %v", tt.uri, err) |
| } |
| } else { |
| if tt.fail { |
| t.Errorf("%s: unexpected success", tt.uri) |
| } else if g, w := len(rres.Contents), 1; g != w { |
| t.Errorf("got %d contents, wanted %d", g, w) |
| } else { |
| c := rres.Contents[0] |
| if got := c.URI; got != tt.uri { |
| t.Errorf("got uri %q, want %q", got, tt.uri) |
| } |
| if got := c.MIMEType; got != tt.mimeType { |
| t.Errorf("%s: got MIME type %q, want %q", tt.uri, got, tt.mimeType) |
| } |
| } |
| } |
| } |
| |
| s.AddResources(&ServerResource{Resource: &Resource{URI: "http://U"}}) |
| waitForNotification(t, "resources") |
| s.RemoveResources("http://U") |
| waitForNotification(t, "resources") |
| }) |
| t.Run("roots", func(t *testing.T) { |
| rootRes, err := ss.ListRoots(ctx, &ListRootsParams{}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| gotRoots := rootRes.Roots |
| wantRoots := slices.Collect(c.roots.all()) |
| if diff := cmp.Diff(wantRoots, gotRoots); diff != "" { |
| t.Errorf("roots/list mismatch (-want +got):\n%s", diff) |
| } |
| |
| c.AddRoots(&Root{URI: "U"}) |
| waitForNotification(t, "roots") |
| c.RemoveRoots("U") |
| waitForNotification(t, "roots") |
| }) |
| t.Run("sampling", func(t *testing.T) { |
| // TODO: test that a client that doesn't have the handler returns CodeUnsupportedMethod. |
| res, err := ss.CreateMessage(ctx, &CreateMessageParams{}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if g, w := res.Model, "aModel"; g != w { |
| t.Errorf("got %q, want %q", g, w) |
| } |
| }) |
| t.Run("logging", func(t *testing.T) { |
| want := []*LoggingMessageParams{ |
| { |
| Logger: "test", |
| Level: "warning", |
| Data: map[string]any{ |
| "msg": "first", |
| "name": "Pat", |
| "logtest": true, |
| }, |
| }, |
| { |
| Logger: "test", |
| Level: "alert", |
| Data: map[string]any{ |
| "msg": "second", |
| "count": 2.0, |
| "logtest": true, |
| }, |
| }, |
| } |
| |
| check := func(t *testing.T) { |
| t.Helper() |
| var got []*LoggingMessageParams |
| // Read messages from this test until we've seen all we expect. |
| for len(got) < len(want) { |
| select { |
| case p := <-loggingMessages: |
| // Ignore logging from other tests. |
| if m, ok := p.Data.(map[string]any); ok && m["logtest"] != nil { |
| delete(m, "time") // remove time because it changes |
| got = append(got, p) |
| } |
| case <-time.After(time.Second): |
| t.Fatal("timed out waiting for log messages") |
| } |
| } |
| if diff := cmp.Diff(want, got); diff != "" { |
| t.Errorf("mismatch (-want, +got):\n%s", diff) |
| } |
| } |
| |
| t.Run("direct", func(t *testing.T) { // Use the LoggingMessage method directly. |
| |
| mustLog := func(level LoggingLevel, data any) { |
| t.Helper() |
| if err := ss.LoggingMessage(ctx, &LoggingMessageParams{ |
| Logger: "test", |
| Level: level, |
| Data: data, |
| }); err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| // Nothing should be logged until the client sets a level. |
| mustLog("info", "before") |
| if err := cs.SetLevel(ctx, &SetLevelParams{Level: "warning"}); err != nil { |
| t.Fatal(err) |
| } |
| mustLog("warning", want[0].Data) |
| mustLog("debug", "nope") // below the level |
| mustLog("info", "negative") // below the level |
| mustLog("alert", want[1].Data) |
| check(t) |
| }) |
| |
| t.Run("handler", func(t *testing.T) { // Use the slog handler. |
| // We can't check the "before SetLevel" behavior because it's already been set. |
| // Not a big deal: that check is in LoggingMessage anyway. |
| logger := slog.New(NewLoggingHandler(ss, &LoggingHandlerOptions{LoggerName: "test"})) |
| logger.Warn("first", "name", "Pat", "logtest", true) |
| logger.Debug("nope") // below the level |
| logger.Info("negative") // below the level |
| logger.Log(ctx, LevelAlert, "second", "count", 2, "logtest", true) |
| check(t) |
| }) |
| }) |
| t.Run("progress", func(t *testing.T) { |
| ss.NotifyProgress(ctx, &ProgressNotificationParams{ |
| ProgressToken: "token-xyz", |
| Message: "progress update", |
| Progress: 0.5, |
| Total: 2, |
| }) |
| waitForNotification(t, "progress_client") |
| cs.NotifyProgress(ctx, &ProgressNotificationParams{ |
| ProgressToken: "token-abc", |
| Message: "progress update", |
| Progress: 1, |
| Total: 2, |
| }) |
| waitForNotification(t, "progress_server") |
| }) |
| |
| // Disconnect. |
| cs.Close() |
| clientWG.Wait() |
| |
| // After disconnecting, neither client nor server should have any |
| // connections. |
| for range s.Sessions() { |
| t.Errorf("unexpected client after disconnection") |
| } |
| } |
| |
// Registry of values to be referenced in tests. The feature maps below are
// keyed by name and installed through add, so a test registers only the
// features it names.
var (
	// errTestFailure is returned by every "fail" handler below. Tests check
	// that its text reaches the client, either in an error or in a tool result.
	errTestFailure = errors.New("mcp failure")

	// tools holds the tools available to tests. "greet" is backed by sayHi.
	// "fail" always returns errTestFailure.
	tools = map[string]*ServerTool{
		"greet": NewServerTool("greet", "say hi", sayHi),
		"fail": {
			Tool: &Tool{Name: "fail"},
			Handler: func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
				return nil, errTestFailure
			},
		},
	}

	// prompts holds the prompts available to tests. "code_review" echoes its
	// "Code" argument into a single user message. "fail" always returns
	// errTestFailure.
	prompts = map[string]*ServerPrompt{
		"code_review": {
			Prompt: &Prompt{
				Name:        "code_review",
				Description: "do a code review",
				Arguments:   []*PromptArgument{{Name: "Code", Required: true}},
			},
			Handler: func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) {
				return &GetPromptResult{
					Description: "Code review prompt",
					Messages: []*PromptMessage{
						{Role: "user", Content: NewTextContent("Please review the following code: " + params.Arguments["Code"])},
					},
				}, nil
			},
		},
		"fail": {
			Prompt: &Prompt{Name: "fail"},
			Handler: func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) {
				return nil, errTestFailure
			},
		},
	}

	// resource1 and resource2 are file:/// resources read through readHandler.
	// NOTE(review): fileResourceHandler is defined elsewhere; it presumably
	// serves files from testdata/files, and TestEndToEnd expects reading
	// fail.txt to report CodeResourceNotFound — confirm against the testdata.
	resource1 = &Resource{
		Name:     "public",
		MIMEType: "text/plain",
		URI:      "file:///info.txt",
	}
	resource2 = &Resource{
		Name:     "public", // names are not unique IDs
		MIMEType: "text/plain",
		URI:      "file:///fail.txt",
	}
	// resource3 is served from the in-memory embeddedResources map by
	// handleEmbeddedResource rather than from disk.
	resource3 = &Resource{
		Name:     "info",
		MIMEType: "text/plain",
		URI:      "embedded:info",
	}
	readHandler = fileResourceHandler("testdata/files")
	// resources pairs each resource with its handler, keyed by feature name.
	resources = map[string]*ServerResource{
		"info.txt": {resource1, readHandler},
		"fail.txt": {resource2, readHandler},
		"info":     {resource3, handleEmbeddedResource},
	}
)
| |
| var embeddedResources = map[string]string{ |
| "info": "This is the MCP test server.", |
| } |
| |
| func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { |
| u, err := url.Parse(params.URI) |
| if err != nil { |
| return nil, err |
| } |
| if u.Scheme != "embedded" { |
| return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) |
| } |
| key := u.Opaque |
| text, ok := embeddedResources[key] |
| if !ok { |
| return nil, fmt.Errorf("no embedded resource named %q", key) |
| } |
| return &ReadResourceResult{ |
| Contents: []*ResourceContents{NewTextResourceContents(params.URI, "text/plain", text)}, |
| }, nil |
| } |
| |
// add looks up each of names in m and passes the corresponding feature to
// addFunc, one call per name, in the order given. It panics if a name is
// missing from m, because that is a mistake in the test setup rather than a
// condition to recover from.
func add[T any](m map[string]T, addFunc func(...T), names ...string) {
	for _, name := range names {
		feat, ok := m[name]
		if !ok {
			panic("missing feature " + name)
		}
		addFunc(feat)
	}
}
| |
| // errorCode returns the code associated with err. |
| // If err is nil, it returns 0. |
| // If there is no code, it returns -1. |
| func errorCode(err error) int64 { |
| if err == nil { |
| return 0 |
| } |
| var werr *jsonrpc2.WireError |
| if errors.As(err, &werr) { |
| return werr.Code |
| } |
| return -1 |
| } |
| |
| // basicConnection returns a new basic client-server connection configured with |
| // the provided tools. |
| // |
| // The caller should cancel either the client connection or server connection |
| // when the connections are no longer needed. |
| func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) { |
| t.Helper() |
| |
| ctx := context.Background() |
| ct, st := NewInMemoryTransports() |
| |
| s := NewServer("testServer", "v1.0.0", nil) |
| |
| // The 'greet' tool says hi. |
| s.AddTools(tools...) |
| ss, err := s.Connect(ctx, st) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| c := NewClient("testClient", "v1.0.0", nil) |
| cs, err := c.Connect(ctx, ct) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return ss, cs |
| } |
| |
| func TestServerClosing(t *testing.T) { |
| cc, cs := basicConnection(t, NewServerTool("greet", "say hi", sayHi)) |
| defer cs.Close() |
| |
| ctx := context.Background() |
| var wg sync.WaitGroup |
| wg.Add(1) |
| go func() { |
| if err := cs.Wait(); err != nil { |
| t.Errorf("server connection failed: %v", err) |
| } |
| wg.Done() |
| }() |
| if _, err := cs.CallTool(ctx, &CallToolParams{ |
| Name: "greet", |
| Arguments: map[string]any{"name": "user"}, |
| }); err != nil { |
| t.Fatalf("after connecting: %v", err) |
| } |
| cc.Close() |
| wg.Wait() |
| if _, err := cs.CallTool(ctx, &CallToolParams{ |
| Name: "greet", |
| Arguments: map[string]any{"name": "user"}, |
| }); !errors.Is(err, ErrConnectionClosed) { |
| t.Errorf("after disconnection, got error %v, want EOF", err) |
| } |
| } |
| |
| func TestBatching(t *testing.T) { |
| ctx := context.Background() |
| ct, st := NewInMemoryTransports() |
| |
| s := NewServer("testServer", "v1.0.0", nil) |
| _, err := s.Connect(ctx, st) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| c := NewClient("testClient", "v1.0.0", nil) |
| // TODO: this test is broken, because increasing the batch size here causes |
| // 'initialize' to block. Therefore, we can only test with a size of 1. |
| // Since batching is being removed, we can probably just delete this. |
| const batchSize = 1 |
| cs, err := c.Connect(ctx, ct) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer cs.Close() |
| |
| errs := make(chan error, batchSize) |
| for i := range batchSize { |
| go func() { |
| _, err := cs.ListTools(ctx, nil) |
| errs <- err |
| }() |
| time.Sleep(2 * time.Millisecond) |
| if i < batchSize-1 { |
| select { |
| case <-errs: |
| t.Errorf("ListTools: unexpected result for incomplete batch: %v", err) |
| default: |
| } |
| } |
| } |
| } |
| |
| func TestCancellation(t *testing.T) { |
| var ( |
| start = make(chan struct{}) |
| cancelled = make(chan struct{}, 1) // don't block the request |
| ) |
| |
| slowRequest := func(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { |
| start <- struct{}{} |
| select { |
| case <-ctx.Done(): |
| cancelled <- struct{}{} |
| case <-time.After(5 * time.Second): |
| return nil, nil |
| } |
| return nil, nil |
| } |
| st := &ServerTool{ |
| Tool: &Tool{Name: "slow"}, |
| Handler: slowRequest, |
| } |
| _, cs := basicConnection(t, st) |
| defer cs.Close() |
| |
| ctx, cancel := context.WithCancel(context.Background()) |
| go cs.CallTool(ctx, &CallToolParams{Name: "slow"}) |
| <-start |
| cancel() |
| select { |
| case <-cancelled: |
| case <-time.After(5 * time.Second): |
| t.Fatal("timeout waiting for cancellation") |
| } |
| } |
| |
| func TestMiddleware(t *testing.T) { |
| ctx := context.Background() |
| ct, st := NewInMemoryTransports() |
| |
| s := NewServer("testServer", "v1.0.0", nil) |
| ss, err := s.Connect(ctx, st) |
| if err != nil { |
| t.Fatal(err) |
| } |
| // Wait for the server to exit after the client closes its connection. |
| var clientWG sync.WaitGroup |
| clientWG.Add(1) |
| go func() { |
| if err := ss.Wait(); err != nil { |
| t.Errorf("server failed: %v", err) |
| } |
| clientWG.Done() |
| }() |
| |
| var sbuf, cbuf bytes.Buffer |
| sbuf.WriteByte('\n') |
| cbuf.WriteByte('\n') |
| |
| // "1" is the outer middleware layer, called first; then "2" is called, and finally |
| // the default dispatcher. |
| s.AddSendingMiddleware(traceCalls[*ServerSession](&sbuf, "S1"), traceCalls[*ServerSession](&sbuf, "S2")) |
| s.AddReceivingMiddleware(traceCalls[*ServerSession](&sbuf, "R1"), traceCalls[*ServerSession](&sbuf, "R2")) |
| |
| c := NewClient("testClient", "v1.0.0", nil) |
| c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) |
| c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) |
| |
| cs, err := c.Connect(ctx, ct) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := cs.ListTools(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := ss.ListRoots(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| |
| wantServer := ` |
| R1 >initialize |
| R2 >initialize |
| R2 <initialize |
| R1 <initialize |
| R1 >notifications/initialized |
| R2 >notifications/initialized |
| R2 <notifications/initialized |
| R1 <notifications/initialized |
| R1 >tools/list |
| R2 >tools/list |
| R2 <tools/list |
| R1 <tools/list |
| S1 >roots/list |
| S2 >roots/list |
| S2 <roots/list |
| S1 <roots/list |
| ` |
| if diff := cmp.Diff(wantServer, sbuf.String()); diff != "" { |
| t.Errorf("server mismatch (-want, +got):\n%s", diff) |
| } |
| |
| wantClient := ` |
| S1 >initialize |
| S2 >initialize |
| S2 <initialize |
| S1 <initialize |
| S1 >notifications/initialized |
| S2 >notifications/initialized |
| S2 <notifications/initialized |
| S1 <notifications/initialized |
| S1 >tools/list |
| S2 >tools/list |
| S2 <tools/list |
| S1 <tools/list |
| R1 >roots/list |
| R2 >roots/list |
| R2 <roots/list |
| R1 <roots/list |
| ` |
| if diff := cmp.Diff(wantClient, cbuf.String()); diff != "" { |
| t.Errorf("client mismatch (-want, +got):\n%s", diff) |
| } |
| } |
| |
// safeBuffer is an io.Writer that can be used from multiple goroutines:
// every access to the underlying bytes.Buffer is serialized by mu.
type safeBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends data to the buffer. It implements io.Writer.
func (b *safeBuffer) Write(data []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.Write(data)
}

// Bytes returns a copy of all data written so far.
//
// It returns a copy rather than the buffer's internal slice. The internal
// slice would stay aliased to the buffer after the lock is released, so it
// could race with later writes or let the caller modify the buffer's contents.
func (b *safeBuffer) Bytes() []byte {
	b.mu.Lock()
	defer b.mu.Unlock()
	return bytes.Clone(b.buf.Bytes())
}
| |
| func TestNoJSONNull(t *testing.T) { |
| ctx := context.Background() |
| var ct, st Transport = NewInMemoryTransports() |
| |
| // Collect logs, to sanity check that we don't write JSON null anywhere. |
| var logbuf safeBuffer |
| ct = NewLoggingTransport(ct, &logbuf) |
| |
| s := NewServer("testServer", "v1.0.0", nil) |
| ss, err := s.Connect(ctx, st) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| c := NewClient("testClient", "v1.0.0", nil) |
| cs, err := c.Connect(ctx, ct) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := cs.ListTools(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := cs.ListPrompts(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := cs.ListResources(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := cs.ListResourceTemplates(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| if _, err := ss.ListRoots(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| |
| cs.Close() |
| ss.Wait() |
| |
| logs := logbuf.Bytes() |
| if i := bytes.Index(logs, []byte("null")); i >= 0 { |
| start := max(i-20, 0) |
| end := min(i+20, len(logs)) |
| t.Errorf("conformance violation: MCP logs contain JSON null: %s", "..."+string(logs[start:end])+"...") |
| } |
| } |
| |
| // traceCalls creates a middleware function that prints the method before and after each call |
| // with the given prefix. |
| func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { |
| return func(h MethodHandler[S]) MethodHandler[S] { |
| return func(ctx context.Context, sess S, method string, params Params) (Result, error) { |
| fmt.Fprintf(w, "%s >%s\n", prefix, method) |
| defer fmt.Fprintf(w, "%s <%s\n", prefix, method) |
| return h(ctx, sess, method, params) |
| } |
| } |
| } |
| |
| // A function, because schemas must form a tree (they have hidden state). |
| func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} } |
| |
| func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { |
| return nil, nil |
| } |