blob: de0f79a73fd00db1791ad441d5699a2c285e3a5e [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mcp
import (
"context"
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/tools/internal/mcp/jsonschema"
)
// Item is the element type that paginate yields in the tests in this file:
// a name paired with a string value.
type Item struct {
	Name, Value string
}
// ListTestParams is the request-parameter type handed to paginate by the tests
// in this file. It carries only the pagination cursor.
type ListTestParams struct {
	Cursor string
}

// cursorPtr exposes the Cursor field by address, so the caller can both read
// and overwrite the cursor stored in p.
func (p *ListTestParams) cursorPtr() *string { return &p.Cursor }
// ListTestResult is one mock page of results returned by a list function in
// the tests in this file: the page's items plus the cursor for the next page
// (empty on the last page, as produced by generatePaginatedResults).
type ListTestResult struct {
	Items      []*Item
	NextCursor string
}

// nextCursorPtr exposes the NextCursor field by address, so the caller can
// both read and overwrite the next-page cursor stored in r.
func (r *ListTestResult) nextCursorPtr() *string { return &r.NextCursor }
// allItems is the shared fixture of eleven items, named after the NATO
// alphabet from "alpha" to "kilo", with values "val-A" through "val-K".
var allItems = []*Item{
	{Name: "alpha", Value: "val-A"},
	{Name: "bravo", Value: "val-B"},
	{Name: "charlie", Value: "val-C"},
	{Name: "delta", Value: "val-D"},
	{Name: "echo", Value: "val-E"},
	{Name: "foxtrot", Value: "val-F"},
	{Name: "golf", Value: "val-G"},
	{Name: "hotel", Value: "val-H"},
	{Name: "india", Value: "val-I"},
	{Name: "juliet", Value: "val-J"},
	{Name: "kilo", Value: "val-K"},
}
// generatePaginatedResults is a helper to create a sequence of mock responses for pagination.
// It simulates a server returning the items of all in pages of pageSize items.
//
// Every page except the last has a NextCursor of the form "cursor_<n>", where n
// is the index in all of the first item of the following page; the last page
// has an empty NextCursor. An empty all yields a single empty page regardless
// of pageSize. A non-positive pageSize with a non-empty all panics.
//
// Each page's Items is a view of all whose capacity equals its length, so
// appending to one page's Items reallocates instead of overwriting items that
// belong to later pages (or to all itself, which may be a shared fixture).
func generatePaginatedResults(all []*Item, pageSize int) []*ListTestResult {
	if len(all) == 0 {
		return []*ListTestResult{{Items: []*Item{}, NextCursor: ""}}
	}
	if pageSize <= 0 {
		panic("pageSize must be greater than 0")
	}
	numPages := (len(all) + pageSize - 1) / pageSize // ceiling division
	results := make([]*ListTestResult, 0, numPages)
	for i := range numPages {
		start := i * pageSize
		end := min(start+pageSize, len(all)) // last page may be short
		nextCursor := ""
		if end < len(all) { // more items remain after this page
			nextCursor = fmt.Sprintf("cursor_%d", end)
		}
		// Full slice expression: cap == len, so the page shares no spare
		// capacity with all.
		results = append(results, &ListTestResult{Items: all[start:end:end], NextCursor: nextCursor})
	}
	return results
}
// TestClientPaginateBasic drives paginate with a scripted listFunc and checks
// both the collected items and the iteration error for a table of scenarios:
// a single page, several pages, no items, an error from the first list call,
// a caller-supplied initial cursor, and a cursor past the end of the data.
func TestClientPaginateBasic(t *testing.T) {
	ctx := context.Background()
	tests := []struct {
		name    string
		pages   []*ListTestResult // responses handed out by listFunc, in call order
		listErr error             // error returned alongside every response
		params  *ListTestParams   // initial params; nil means &ListTestParams{}
		want    []*Item
		wantErr bool
	}{
		{
			name:  "SinglePageAllItems",
			pages: generatePaginatedResults(allItems, len(allItems)),
			want:  allItems,
		},
		{
			name:  "MultiplePages",
			pages: generatePaginatedResults(allItems, 3),
			want:  allItems,
		},
		{
			name:  "EmptyResults",
			pages: generatePaginatedResults([]*Item{}, 10),
		},
		{
			name:    "ListFuncReturnsErrorImmediately",
			pages:   []*ListTestResult{{}},
			listErr: fmt.Errorf("API error on first call"),
			wantErr: true,
		},
		{
			name:   "InitialCursorProvided",
			params: &ListTestParams{Cursor: "cursor_2"},
			pages:  generatePaginatedResults(allItems[2:], 3),
			want:   allItems[2:],
		},
		{
			name:   "CursorBeyondAllItems",
			params: &ListTestParams{Cursor: "cursor_999"},
			pages:  []*ListTestResult{{Items: []*Item{}, NextCursor: ""}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			remaining := tt.pages
			listFunc := func(ctx context.Context, params *ListTestParams) (*ListTestResult, error) {
				if len(remaining) == 0 {
					t.Fatalf("listFunc called but no more results defined for test case %q", tt.name)
				}
				page := remaining[0]
				remaining = remaining[1:]
				return page, tt.listErr
			}

			params := &ListTestParams{}
			if tt.params != nil {
				params = tt.params
			}

			var got []*Item
			var iterErr error
			for item, err := range paginate(ctx, params, listFunc, func(r *ListTestResult) []*Item { return r.Items }) {
				if err != nil {
					iterErr = err
					break
				}
				got = append(got, item)
			}

			switch {
			case tt.wantErr && iterErr == nil:
				t.Errorf("paginate() expected an error during iteration, but got none")
			case !tt.wantErr && iterErr != nil:
				t.Errorf("paginate() got: %v, want: nil", iterErr)
			}
			if diff := cmp.Diff(tt.want, got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
				t.Fatalf("paginate() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestClientPaginateVariousPageSizes checks that paginate yields every element
// of allItems, in order and without error, for every page size from 1 up to
// len(allItems).
func TestClientPaginateVariousPageSizes(t *testing.T) {
	ctx := context.Background()
	for pageSize := 1; pageSize <= len(allItems); pageSize++ {
		t.Run(fmt.Sprintf("PageSize=%d", pageSize), func(t *testing.T) {
			results := generatePaginatedResults(allItems, pageSize)
			listFunc := func(ctx context.Context, params *ListTestParams) (*ListTestResult, error) {
				// Fail the subtest with a clear message, rather than panicking on
				// results[0], if paginate asks for more pages than were generated.
				if len(results) == 0 {
					t.Fatalf("listFunc called but no more results defined for page size %d", pageSize)
				}
				res := results[0]
				results = results[1:]
				return res, nil
			}
			var gotItems []*Item
			seq := paginate(ctx, &ListTestParams{}, listFunc, func(r *ListTestResult) []*Item { return r.Items })
			for item, err := range seq {
				if err != nil {
					t.Fatalf("paginate() unexpected error during iteration: %v", err)
				}
				gotItems = append(gotItems, item)
			}
			if diff := cmp.Diff(allItems, gotItems, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
				t.Fatalf("paginate() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}