| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package analysisinternal |
| |
| import ( |
| "go/ast" |
| "go/parser" |
| "go/token" |
| "testing" |
| |
| "golang.org/x/tools/go/ast/inspector" |
| "golang.org/x/tools/internal/astutil/cursor" |
| ) |
| |
| func TestCanImport(t *testing.T) { |
| for _, tt := range []struct { |
| from string |
| to string |
| want bool |
| }{ |
| {"fmt", "internal", true}, |
| {"fmt", "internal/foo", true}, |
| {"a.com/b", "internal", false}, |
| {"a.com/b", "xinternal", true}, |
| {"a.com/b", "internal/foo", false}, |
| {"a.com/b", "xinternal/foo", true}, |
| {"a.com/b", "a.com/internal", true}, |
| {"a.com/b", "a.com/b/internal", true}, |
| {"a.com/b", "a.com/b/internal/foo", true}, |
| {"a.com/b", "a.com/c/internal", false}, |
| {"a.com/b", "a.com/c/xinternal", true}, |
| {"a.com/b", "a.com/c/internal/foo", false}, |
| {"a.com/b", "a.com/c/xinternal/foo", true}, |
| } { |
| got := CanImport(tt.from, tt.to) |
| if got != tt.want { |
| t.Errorf("CanImport(%q, %q) = %v, want %v", tt.from, tt.to, got, tt.want) |
| } |
| } |
| } |
| |
| func TestDeleteStmt(t *testing.T) { |
| type testCase struct { |
| in string |
| which int // count of ast.Stmt in ast.Inspect traversal to remove |
| want string |
| name string // should contain exactly one of [block,switch,case,comm,for,type] |
| } |
| tests := []testCase{ |
| { // do nothing when asked to remove a function body |
| in: "package p; func f() { }", |
| which: 0, |
| want: "package p; func f() { }", |
| name: "block0", |
| }, |
| { |
| in: "package p; func f() { abcd()}", |
| which: 1, |
| want: "package p; func f() { }", |
| name: "block1", |
| }, |
| { |
| in: "package p; func f() { a() }", |
| which: 1, |
| want: "package p; func f() { }", |
| name: "block2", |
| }, |
| { |
| in: "package p; func f() { a();}", |
| which: 1, |
| want: "package p; func f() { ;}", |
| name: "block3", |
| }, |
| { |
| in: "package p; func f() {\n a() \n\n}", |
| which: 1, |
| want: "package p; func f() {\n\n}", |
| name: "block4", |
| }, |
| { |
| in: "package p; func f() { a()// comment\n}", |
| which: 1, |
| want: "package p; func f() { // comment\n}", |
| name: "block5", |
| }, |
| { |
| in: "package p; func f() { /*c*/a() \n}", |
| which: 1, |
| want: "package p; func f() { /*c*/ \n}", |
| name: "block6", |
| }, |
| { |
| in: "package p; func f() { a();b();}", |
| which: 2, |
| want: "package p; func f() { a();;}", |
| name: "block7", |
| }, |
| { |
| in: "package p; func f() {\n\ta()\n\tb()\n}", |
| which: 2, |
| want: "package p; func f() {\n\ta()\n}", |
| name: "block8", |
| }, |
| { |
| in: "package p; func f() {\n\ta()\n\tb()\n\tc()\n}", |
| which: 2, |
| want: "package p; func f() {\n\ta()\n\tc()\n}", |
| name: "block9", |
| }, |
| { |
| in: "package p\nfunc f() {a()+b()}", |
| which: 1, |
| want: "package p\nfunc f() {}", |
| name: "block10", |
| }, |
| { |
| in: "package p\nfunc f() {(a()+b())}", |
| which: 1, |
| want: "package p\nfunc f() {}", |
| name: "block11", |
| }, |
| { |
| in: "package p; func f() { switch a(); b() {}}", |
| which: 2, // 0 is the func body, 1 is the switch statement |
| want: "package p; func f() { switch ; b() {}}", |
| name: "switch0", |
| }, |
| { |
| in: "package p; func f() { switch /*c*/a(); {}}", |
| which: 2, // 0 is the func body, 1 is the switch statement |
| want: "package p; func f() { switch /*c*/; {}}", |
| name: "switch1", |
| }, |
| { |
| in: "package p; func f() { switch a()/*c*/; {}}", |
| which: 2, // 0 is the func body, 1 is the switch statement |
| want: "package p; func f() { switch /*c*/; {}}", |
| name: "switch2", |
| }, |
| { |
| in: "package p; func f() { select {default: a()}}", |
| which: 4, // 0 is the func body, 1 is the select statement, 2 is its body, 3 is the comm clause |
| want: "package p; func f() { select {default: }}", |
| name: "comm0", |
| }, |
| { |
| in: "package p; func f(x chan any) { select {case x <- a: a(x)}}", |
| which: 5, // 0 is the func body, 1 is the select statement, 2 is its body, 3 is the comm clause |
| want: "package p; func f(x chan any) { select {case x <- a: }}", |
| name: "comm1", |
| }, |
| { |
| in: "package p; func f(x chan any) { select {case x <- a: a(x)}}", |
| which: 4, // 0 is the func body, 1 is the select statement, 2 is its body, 3 is the comm clause |
| want: "package p; func f(x chan any) { select {case x <- a: a(x)}}", |
| name: "comm2", |
| }, |
| { |
| in: "package p; func f() { switch {default: a()}}", |
| which: 4, // 0 is the func body, 1 is the select statement, 2 is its body |
| want: "package p; func f() { switch {default: }}", |
| name: "case0", |
| }, |
| { |
| in: "package p; func f() { switch {case 3: a()}}", |
| which: 4, // 0 is the func body, 1 is the select statement, 2 is its body |
| want: "package p; func f() { switch {case 3: }}", |
| name: "case1", |
| }, |
| { |
| in: "package p; func f() {for a();;b() {}}", |
| which: 2, |
| want: "package p; func f() {for ;;b() {}}", |
| name: "for0", |
| }, |
| { |
| in: "package p; func f() {for a();c();b() {}}", |
| which: 3, |
| want: "package p; func f() {for a();c(); {}}", |
| name: "for1", |
| }, |
| { |
| in: "package p; func f() {for\na();c()\nb() {}}", |
| which: 2, |
| want: "package p; func f() {for\n;c()\nb() {}}", |
| name: "for2", |
| }, |
| { |
| in: "package p; func f() {for a();\nc();b() {}}", |
| which: 3, |
| want: "package p; func f() {for a();\nc(); {}}", |
| name: "for3", |
| }, |
| { |
| in: "package p; func f() {switch a();b().(type){}}", |
| which: 2, |
| want: "package p; func f() {switch ;b().(type){}}", |
| name: "type0", |
| }, |
| { |
| in: "package p; func f() {switch a();b().(type){}}", |
| which: 3, |
| want: "package p; func f() {switch a();b().(type){}}", |
| name: "type1", |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| fset := token.NewFileSet() |
| f, err := parser.ParseFile(fset, tt.name, tt.in, parser.ParseComments) |
| if err != nil { |
| t.Fatalf("%s: %v", tt.name, err) |
| } |
| insp := inspector.New([]*ast.File{f}) |
| root := cursor.Root(insp) |
| var stmt cursor.Cursor |
| cnt := 0 |
| for cn := range root.Preorder() { // Preorder(ast.Stmt(nil)) doesn't work |
| if _, ok := cn.Node().(ast.Stmt); !ok { |
| continue |
| } |
| if cnt == tt.which { |
| stmt = cn |
| break |
| } |
| cnt++ |
| } |
| if cnt != tt.which { |
| t.Fatalf("test %s does not contain desired statement %d", tt.name, tt.which) |
| } |
| edits := DeleteStmt(fset, f, stmt.Node().(ast.Stmt), nil) |
| if tt.want == tt.in { |
| if len(edits) != 0 { |
| t.Fatalf("%s: got %d edits, expected 0", tt.name, len(edits)) |
| } |
| return |
| } |
| if len(edits) != 1 { |
| t.Fatalf("%s: got %d edits, expected 1", tt.name, len(edits)) |
| } |
| tokFile := fset.File(f.Pos()) |
| |
| left := tokFile.Offset(edits[0].Pos) |
| right := tokFile.Offset(edits[0].End) |
| |
| got := tt.in[:left] + tt.in[right:] |
| if got != tt.want { |
| t.Errorf("%s: got\n%q, want\n%q", tt.name, got, tt.want) |
| } |
| }) |
| |
| } |
| } |