| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package server |
| |
| import ( |
| "context" |
| "crypto/sha256" |
| "encoding/hex" |
| "os" |
| "path/filepath" |
| "sync" |
| "testing" |
| |
| "golang.org/x/mod/modfile" |
| "golang.org/x/tools/gopls/internal/filecache" |
| "golang.org/x/tools/gopls/internal/protocol" |
| "golang.org/x/tools/gopls/internal/settings" |
| ) |
| |
| func TestComputeGoModHash(t *testing.T) { |
| tests := []struct { |
| name string |
| content string |
| want string |
| wantErr bool |
| }{ |
| { |
| name: "empty file", |
| content: "module example.com", |
| want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // sha256 of empty string |
| }, |
| { |
| name: "with require", |
| content: ` |
| module example.com |
| require ( |
| golang.org/x/tools v0.1.0 |
| golang.org/x/vuln v0.2.0 |
| ) |
| `, |
| want: func() string { |
| h := sha256.New() |
| h.Write([]byte("golang.org/x/toolsv0.1.0")) |
| h.Write([]byte("golang.org/x/vulnv0.2.0")) |
| return hex.EncodeToString(h.Sum(nil)) |
| }(), |
| }, |
| { |
| name: "with exclude", |
| content: ` |
| module example.com |
| exclude ( |
| golang.org/x/tools v0.1.0 |
| ) |
| `, |
| want: func() string { |
| h := sha256.New() |
| h.Write([]byte("golang.org/x/toolsv0.1.0")) |
| return hex.EncodeToString(h.Sum(nil)) |
| }(), |
| }, |
| { |
| name: "with replace", |
| content: ` |
| module example.com |
| replace ( |
| golang.org/x/tools v0.1.0 => golang.org/x/tools v0.2.0 |
| ) |
| `, |
| want: func() string { |
| h := sha256.New() |
| h.Write([]byte("golang.org/x/toolsv0.1.0golang.org/x/toolsv0.2.0")) |
| return hex.EncodeToString(h.Sum(nil)) |
| }(), |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| modFile, err := modfile.Parse("go.mod", []byte(tt.content), nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| got, err := computeGoModHash(modFile) |
| if (err != nil) != tt.wantErr { |
| t.Errorf("computeGoModHash() error = %v, wantErr %v", err, tt.wantErr) |
| return |
| } |
| if got != tt.want { |
| t.Errorf("computeGoModHash() = %v, want %v", got, tt.want) |
| } |
| }) |
| } |
| } |
| |
| type mockClient struct { |
| protocol.Client |
| showMessageRequest func(context.Context, *protocol.ShowMessageRequestParams) (*protocol.MessageActionItem, error) |
| } |
| |
| func (c *mockClient) ShowMessageRequest(ctx context.Context, params *protocol.ShowMessageRequestParams) (*protocol.MessageActionItem, error) { |
| if c.showMessageRequest != nil { |
| return c.showMessageRequest(ctx, params) |
| } |
| return nil, nil |
| } |
| |
| func (c *mockClient) Close() error { |
| return nil |
| } |
| |
| func TestCheckGoModDeps(t *testing.T) { |
| const ( |
| yes = "Yes" |
| no = "No" |
| always = "Always" |
| never = "Never" |
| ) |
| |
| tests := []struct { |
| name string |
| vulncheckMode settings.VulncheckMode |
| oldContent string |
| newContent string |
| userAction string |
| wantPrompt bool |
| wantHashUpdated bool |
| }{ |
| { |
| name: "vulncheck disabled", |
| vulncheckMode: settings.ModeVulncheckOff, |
| oldContent: "module example.com", |
| newContent: ` |
| module example.com |
| require golang.org/x/tools v0.1.0 |
| `, |
| wantPrompt: false, |
| }, |
| { |
| name: "no changes", |
| vulncheckMode: settings.ModeVulncheckPrompt, |
| oldContent: "module example.com", |
| newContent: "module example.com", |
| wantPrompt: false, |
| }, |
| { |
| name: "user says yes", |
| vulncheckMode: settings.ModeVulncheckPrompt, |
| oldContent: "module example.com", |
| newContent: ` |
| module example.com |
| require golang.org/x/tools v0.1.0 |
| `, |
| userAction: yes, |
| wantPrompt: true, |
| wantHashUpdated: true, |
| }, |
| { |
| name: "user says no", |
| vulncheckMode: settings.ModeVulncheckPrompt, |
| oldContent: "module example.com", |
| newContent: ` |
| module example.com |
| require golang.org/x/tools v0.1.0 |
| `, |
| userAction: no, |
| wantPrompt: true, |
| }, |
| { |
| name: "user says always", |
| vulncheckMode: settings.ModeVulncheckPrompt, |
| oldContent: "module example.com", |
| newContent: ` |
| module example.com |
| require golang.org/x/tools v0.1.0 |
| `, |
| userAction: always, |
| wantPrompt: true, |
| wantHashUpdated: true, |
| }, |
| { |
| name: "user says never", |
| vulncheckMode: settings.ModeVulncheckPrompt, |
| oldContent: "module example.com", |
| newContent: ` |
| module example.com |
| require golang.org/x/tools v0.1.0 |
| `, |
| userAction: never, |
| wantPrompt: true, |
| }, |
| { |
| name: "user dismisses prompt", |
| vulncheckMode: settings.ModeVulncheckPrompt, |
| oldContent: "module example.com", |
| newContent: ` |
| module example.com |
| require golang.org/x/tools v0.1.0 |
| `, |
| userAction: "", |
| wantPrompt: true, |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| ctx := context.Background() |
| var promptShown bool |
| var wg sync.WaitGroup |
| if tt.wantPrompt { |
| wg.Add(1) |
| } |
| client := &mockClient{ |
| showMessageRequest: func(ctx context.Context, params *protocol.ShowMessageRequestParams) (*protocol.MessageActionItem, error) { |
| promptShown = true |
| defer wg.Done() |
| if tt.userAction == "" { |
| return nil, nil |
| } |
| return &protocol.MessageActionItem{Title: tt.userAction}, nil |
| }, |
| } |
| s := &server{ |
| client: client, |
| options: &settings.Options{ |
| UserOptions: settings.UserOptions{ |
| UIOptions: settings.UIOptions{ |
| DiagnosticOptions: settings.DiagnosticOptions{ |
| Vulncheck: tt.vulncheckMode, |
| }, |
| }, |
| }, |
| }, |
| } |
| dir := t.TempDir() |
| goModPath := filepath.Join(dir, "go.mod") |
| if err := os.WriteFile(goModPath, []byte(tt.oldContent), 0644); err != nil { |
| t.Fatal(err) |
| } |
| uri := protocol.URIFromPath(goModPath) |
| |
| // Set the initial hash in the cache. |
| oldModFile, err := modfile.Parse("go.mod", []byte(tt.oldContent), nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| oldHash, err := computeGoModHash(oldModFile) |
| if err != nil { |
| t.Fatal(err) |
| } |
| pathHash := sha256.Sum256([]byte(uri.Path())) |
| if err := filecache.Set(goModHashKind, pathHash, []byte(oldHash)); err != nil { |
| t.Fatal(err) |
| } |
| |
| // Simulate the file change. |
| if err := os.WriteFile(goModPath, []byte(tt.newContent), 0644); err != nil { |
| t.Fatal(err) |
| } |
| |
| s.checkGoModDeps(ctx, uri) |
| |
| wg.Wait() |
| |
| if promptShown != tt.wantPrompt { |
| t.Errorf("promptShown = %v, want %v", promptShown, tt.wantPrompt) |
| } |
| |
| // Check if the hash was updated. |
| newModFile, err := modfile.Parse("go.mod", []byte(tt.newContent), nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| newHash, err := computeGoModHash(newModFile) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| cachedHashBytes, err := filecache.Get(goModHashKind, pathHash) |
| if err != nil && err != filecache.ErrNotFound { |
| t.Fatal(err) |
| } |
| cachedHash := string(cachedHashBytes) |
| |
| if tt.wantHashUpdated { |
| if cachedHash != newHash { |
| t.Errorf("hash was not updated in cache") |
| } |
| } else { |
| if cachedHash == newHash && oldHash != newHash { |
| t.Errorf("hash was updated in cache, but should not have been") |
| } |
| } |
| }) |
| } |
| } |