| // Copyright 2020 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package progress |
| |
| import ( |
| "context" |
| "fmt" |
| "sync" |
| "testing" |
| |
| "golang.org/x/tools/gopls/internal/protocol" |
| ) |
| |
| type fakeClient struct { |
| protocol.Client |
| |
| token protocol.ProgressToken |
| |
| mu sync.Mutex |
| created, begun, reported, messages, ended int |
| } |
| |
| func (c *fakeClient) checkToken(token protocol.ProgressToken) { |
| if token == nil { |
| panic("nil token in progress message") |
| } |
| if c.token != nil && c.token != token { |
| panic(fmt.Errorf("invalid token in progress message: got %v, want %v", token, c.token)) |
| } |
| } |
| |
| func (c *fakeClient) WorkDoneProgressCreate(ctx context.Context, params *protocol.WorkDoneProgressCreateParams) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| c.checkToken(params.Token) |
| c.created++ |
| return nil |
| } |
| |
| func (c *fakeClient) Progress(ctx context.Context, params *protocol.ProgressParams) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| c.checkToken(params.Token) |
| switch params.Value.(type) { |
| case *protocol.WorkDoneProgressBegin: |
| c.begun++ |
| case *protocol.WorkDoneProgressReport: |
| c.reported++ |
| case *protocol.WorkDoneProgressEnd: |
| c.ended++ |
| default: |
| panic(fmt.Errorf("unknown progress value %T", params.Value)) |
| } |
| return nil |
| } |
| |
| func (c *fakeClient) ShowMessage(context.Context, *protocol.ShowMessageParams) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| c.messages++ |
| return nil |
| } |
| |
| func setup() (context.Context, *Tracker, *fakeClient) { |
| c := &fakeClient{} |
| tracker := NewTracker(c) |
| tracker.SetSupportsWorkDoneProgress(true) |
| return context.Background(), tracker, c |
| } |
| |
| func TestProgressTracker_Reporting(t *testing.T) { |
| for _, test := range []struct { |
| name string |
| supported bool |
| token protocol.ProgressToken |
| wantReported, wantCreated, wantBegun, wantEnded int |
| wantMessages int |
| }{ |
| { |
| name: "unsupported", |
| wantMessages: 2, |
| }, |
| { |
| name: "random token", |
| supported: true, |
| wantCreated: 1, |
| wantBegun: 1, |
| wantReported: 1, |
| wantEnded: 1, |
| }, |
| { |
| name: "string token", |
| supported: true, |
| token: "token", |
| wantBegun: 1, |
| wantReported: 1, |
| wantEnded: 1, |
| }, |
| { |
| name: "numeric token", |
| supported: true, |
| token: 1, |
| wantReported: 1, |
| wantBegun: 1, |
| wantEnded: 1, |
| }, |
| } { |
| test := test |
| t.Run(test.name, func(t *testing.T) { |
| ctx, tracker, client := setup() |
| ctx, cancel := context.WithCancel(ctx) |
| defer cancel() |
| tracker.supportsWorkDoneProgress = test.supported |
| work := tracker.Start(ctx, "work", "message", test.token, nil) |
| client.mu.Lock() |
| gotCreated, gotBegun := client.created, client.begun |
| client.mu.Unlock() |
| if gotCreated != test.wantCreated { |
| t.Errorf("got %d created tokens, want %d", gotCreated, test.wantCreated) |
| } |
| if gotBegun != test.wantBegun { |
| t.Errorf("got %d work begun, want %d", gotBegun, test.wantBegun) |
| } |
| // Ignore errors: this is just testing the reporting behavior. |
| work.Report(ctx, "report", 50) |
| client.mu.Lock() |
| gotReported := client.reported |
| client.mu.Unlock() |
| if gotReported != test.wantReported { |
| t.Errorf("got %d progress reports, want %d", gotReported, test.wantCreated) |
| } |
| work.End(ctx, "done") |
| client.mu.Lock() |
| gotEnded, gotMessages := client.ended, client.messages |
| client.mu.Unlock() |
| if gotEnded != test.wantEnded { |
| t.Errorf("got %d ended reports, want %d", gotEnded, test.wantEnded) |
| } |
| if gotMessages != test.wantMessages { |
| t.Errorf("got %d messages, want %d", gotMessages, test.wantMessages) |
| } |
| }) |
| } |
| } |
| |
| func TestProgressTracker_Cancellation(t *testing.T) { |
| for _, token := range []protocol.ProgressToken{nil, 1, "a"} { |
| ctx, tracker, _ := setup() |
| var canceled bool |
| cancel := func() { canceled = true } |
| work := tracker.Start(ctx, "work", "message", token, cancel) |
| if err := tracker.Cancel(work.Token()); err != nil { |
| t.Fatal(err) |
| } |
| if !canceled { |
| t.Errorf("tracker.cancel(...): cancel not called") |
| } |
| } |
| } |