| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mcp |
| |
| import ( |
| "context" |
| "fmt" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| |
| "github.com/google/go-cmp/cmp" |
| ) |
| |
| func TestSSEServer(t *testing.T) { |
| for _, closeServerFirst := range []bool{false, true} { |
| t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) { |
| ctx := context.Background() |
| server := NewServer("testServer", "v1.0.0", nil) |
| server.AddTools(NewServerTool("greet", "say hi", sayHi)) |
| |
| sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) |
| |
| conns := make(chan *ServerSession, 1) |
| sseHandler.onConnection = func(cc *ServerSession) { |
| select { |
| case conns <- cc: |
| default: |
| } |
| } |
| httpServer := httptest.NewServer(sseHandler) |
| defer httpServer.Close() |
| |
| clientTransport := NewSSEClientTransport(httpServer.URL, nil) |
| |
| c := NewClient("testClient", "v1.0.0", nil) |
| cs, err := c.Connect(ctx, clientTransport) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if err := cs.Ping(ctx, nil); err != nil { |
| t.Fatal(err) |
| } |
| ss := <-conns |
| gotHi, err := cs.CallTool(ctx, &CallToolParams{ |
| Name: "greet", |
| Arguments: map[string]any{"Name": "user"}, |
| }) |
| if err != nil { |
| t.Fatal(err) |
| } |
| wantHi := &CallToolResult{ |
| Content: []*Content{{Type: "text", Text: "hi user"}}, |
| } |
| if diff := cmp.Diff(wantHi, gotHi); diff != "" { |
| t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) |
| } |
| |
| // Test that closing either end of the connection terminates the other |
| // end. |
| if closeServerFirst { |
| cs.Close() |
| ss.Wait() |
| } else { |
| ss.Close() |
| cs.Wait() |
| } |
| }) |
| } |
| } |
| |
| func TestScanEvents(t *testing.T) { |
| tests := []struct { |
| name string |
| input string |
| want []event |
| wantErr string |
| }{ |
| { |
| name: "simple event", |
| input: "event: message\nid: 1\ndata: hello\n\n", |
| want: []event{ |
| {name: "message", id: "1", data: []byte("hello")}, |
| }, |
| }, |
| { |
| name: "multiple data lines", |
| input: "data: line 1\ndata: line 2\n\n", |
| want: []event{ |
| {data: []byte("line 1\nline 2")}, |
| }, |
| }, |
| { |
| name: "multiple events", |
| input: "data: first\n\nevent: second\ndata: second\n\n", |
| want: []event{ |
| {data: []byte("first")}, |
| {name: "second", data: []byte("second")}, |
| }, |
| }, |
| { |
| name: "no trailing newline", |
| input: "data: hello", |
| want: []event{ |
| {data: []byte("hello")}, |
| }, |
| }, |
| { |
| name: "malformed line", |
| input: "invalid line\n\n", |
| wantErr: "malformed line", |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| r := strings.NewReader(tt.input) |
| var got []event |
| var err error |
| for e, err2 := range scanEvents(r) { |
| if err2 != nil { |
| err = err2 |
| break |
| } |
| got = append(got, e) |
| } |
| |
| if tt.wantErr != "" { |
| if err == nil { |
| t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr) |
| } |
| if !strings.Contains(err.Error(), tt.wantErr) { |
| t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr) |
| } |
| return |
| } |
| |
| if err != nil { |
| t.Fatalf("scanEvents() returned unexpected error: %v", err) |
| } |
| |
| if len(got) != len(tt.want) { |
| t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want)) |
| } |
| |
| for i := range got { |
| if g, w := got[i].name, tt.want[i].name; g != w { |
| t.Errorf("event %d: name = %q, want %q", i, g, w) |
| } |
| if g, w := got[i].id, tt.want[i].id; g != w { |
| t.Errorf("event %d: id = %q, want %q", i, g, w) |
| } |
| if g, w := string(got[i].data), string(tt.want[i].data); g != w { |
| t.Errorf("event %d: data = %q, want %q", i, g, w) |
| } |
| } |
| }) |
| } |
| } |