[release-branch.r58] gofix: fixes for os/signal changes

««« CL 4630056 / 8fe2bc5c3d53
gofix: fixes for os/signal changes

Fixes #1971.

R=adg, rsc
CC=golang-dev
https://golang.org/cl/4630056

»»»

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/4645071
diff --git a/src/cmd/gofix/Makefile b/src/cmd/gofix/Makefile
index d19de5c..c4127d9 100644
--- a/src/cmd/gofix/Makefile
+++ b/src/cmd/gofix/Makefile
@@ -14,6 +14,7 @@
 	httpserver.go\
 	procattr.go\
 	reflect.go\
+	signal.go\
 	typecheck.go\
 
 include ../../Make.cmd
diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go
index 0852ce2..c1c5a74 100644
--- a/src/cmd/gofix/fix.go
+++ b/src/cmd/gofix/fix.go
@@ -10,6 +10,7 @@
 	"go/token"
 	"os"
 	"strconv"
+	"strings"
 )
 
 type fix struct {
@@ -258,13 +259,28 @@
 
 // imports returns true if f imports path.
 func imports(f *ast.File, path string) bool {
+	return importSpec(f, path) != nil
+}
+
+// importSpec returns the import spec if f imports path,
+// or nil otherwise.
+func importSpec(f *ast.File, path string) *ast.ImportSpec {
 	for _, s := range f.Imports {
-		t, err := strconv.Unquote(s.Path.Value)
-		if err == nil && t == path {
-			return true
+		if importPath(s) == path {
+			return s
 		}
 	}
-	return false
+	return nil
+}
+
+// importPath returns the unquoted import path of s,
+// or "" if the path is not properly quoted.
+func importPath(s *ast.ImportSpec) string {
+	t, err := strconv.Unquote(s.Path.Value)
+	if err == nil {
+		return t
+	}
+	return ""
 }
 
 // isPkgDot returns true if t is the expression "pkg.name"
@@ -420,3 +436,138 @@
 		},
 	}
 }
+
+// addImport adds the import path to the file f, if absent.
+func addImport(f *ast.File, path string) {
+	if imports(f, path) {
+		return
+	}
+
+	newImport := &ast.ImportSpec{
+		Path: &ast.BasicLit{
+			Kind:  token.STRING,
+			Value: strconv.Quote(path),
+		},
+	}
+
+	var impdecl *ast.GenDecl
+
+	// Find an import decl to add to.
+	for _, decl := range f.Decls {
+		gen, ok := decl.(*ast.GenDecl)
+
+		if ok && gen.Tok == token.IMPORT {
+			impdecl = gen
+			break
+		}
+	}
+
+	// No import decl found.  Add one.
+	if impdecl == nil {
+		impdecl = &ast.GenDecl{
+			Tok: token.IMPORT,
+		}
+		f.Decls = append(f.Decls, nil)
+		copy(f.Decls[1:], f.Decls)
+		f.Decls[0] = impdecl
+	}
+
+	// Ensure the import decl has parentheses, if needed.
+	if len(impdecl.Specs) > 0 && !impdecl.Lparen.IsValid() {
+		impdecl.Lparen = impdecl.Pos()
+	}
+
+	// Assume the import paths are alphabetically ordered.
+	// If they are not, the result is ugly, but legal.
+	insertAt := len(impdecl.Specs) // default to end of specs
+	for i, spec := range impdecl.Specs {
+		impspec := spec.(*ast.ImportSpec)
+		if importPath(impspec) > path {
+			insertAt = i
+			break
+		}
+	}
+
+	impdecl.Specs = append(impdecl.Specs, nil)
+	copy(impdecl.Specs[insertAt+1:], impdecl.Specs[insertAt:])
+	impdecl.Specs[insertAt] = newImport
+
+	f.Imports = append(f.Imports, newImport)
+}
+
+// deleteImport deletes the import path from the file f, if present.
+func deleteImport(f *ast.File, path string) {
+	oldImport := importSpec(f, path)
+
+	// Find the import node that imports path, if any.
+	for i, decl := range f.Decls {
+		gen, ok := decl.(*ast.GenDecl)
+		if !ok || gen.Tok != token.IMPORT {
+			continue
+		}
+		for j, spec := range gen.Specs {
+			impspec := spec.(*ast.ImportSpec)
+
+			if oldImport != impspec {
+				continue
+			}
+
+			// We found an import spec that imports path.
+			// Delete it.
+			copy(gen.Specs[j:], gen.Specs[j+1:])
+			gen.Specs = gen.Specs[:len(gen.Specs)-1]
+
+			// If this was the last import spec in this decl,
+			// delete the decl, too.
+			if len(gen.Specs) == 0 {
+				copy(f.Decls[i:], f.Decls[i+1:])
+				f.Decls = f.Decls[:len(f.Decls)-1]
+			} else if len(gen.Specs) == 1 {
+				gen.Lparen = token.NoPos // drop parens
+			}
+
+			break
+		}
+	}
+
+	// Delete it from f.Imports.
+	for i, imp := range f.Imports {
+		if imp == oldImport {
+			copy(f.Imports[i:], f.Imports[i+1:])
+			f.Imports = f.Imports[:len(f.Imports)-1]
+			break
+		}
+	}
+}
+
+func usesImport(f *ast.File, path string) (used bool) {
+	spec := importSpec(f, path)
+	if spec == nil {
+		return
+	}
+
+	name := spec.Name.String()
+	switch name {
+	case "<nil>":
+		// If the package name is not explicitly specified,
+		// make an educated guess. This is not guaranteed to be correct.
+		lastSlash := strings.LastIndex(path, "/")
+		if lastSlash == -1 {
+			name = path
+		} else {
+			name = path[lastSlash+1:]
+		}
+	case "_", ".":
+		// Not sure if this import is used - err on the side of caution.
+		return true
+	}
+
+	walk(f, func(n interface{}) {
+		sel, ok := n.(*ast.SelectorExpr)
+		if ok && isTopName(sel.X, name) {
+			used = true
+		}
+	})
+
+	return
+}
diff --git a/src/cmd/gofix/signal.go b/src/cmd/gofix/signal.go
new file mode 100644
index 0000000..53c3388
--- /dev/null
+++ b/src/cmd/gofix/signal.go
@@ -0,0 +1,49 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	"go/ast"
+	"strings"
+)
+
+func init() {
+	register(fix{
+		"signal",
+		signal,
+		`Adapt code to types moved from os/signal to signal.
+
+http://codereview.appspot.com/4437091
+`,
+	})
+}
+
+func signal(f *ast.File) (fixed bool) {
+	if !imports(f, "os/signal") {
+		return
+	}
+
+	walk(f, func(n interface{}) {
+		s, ok := n.(*ast.SelectorExpr)
+
+		if !ok || !isTopName(s.X, "signal") {
+			return
+		}
+
+		sel := s.Sel.String()
+		if sel == "Signal" || sel == "UnixSignal" || strings.HasPrefix(sel, "SIG") {
+			s.X = &ast.Ident{Name: "os"}
+			fixed = true
+		}
+	})
+
+	if fixed {
+		addImport(f, "os")
+		if !usesImport(f, "os/signal") {
+			deleteImport(f, "os/signal")
+		}
+	}
+	return
+}
diff --git a/src/cmd/gofix/signal_test.go b/src/cmd/gofix/signal_test.go
new file mode 100644
index 0000000..2500e9c
--- /dev/null
+++ b/src/cmd/gofix/signal_test.go
@@ -0,0 +1,96 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+func init() {
+	addTestCases(signalTests)
+}
+
+var signalTests = []testCase{
+	{
+		Name: "signal.0",
+		In: `package main
+
+import (
+	_ "a"
+	"os/signal"
+	_ "z"
+)
+
+type T1 signal.UnixSignal
+type T2 signal.Signal
+
+func f() {
+	_ = signal.SIGHUP
+	_ = signal.Incoming
+}
+`,
+		Out: `package main
+
+import (
+	_ "a"
+	"os"
+	"os/signal"
+	_ "z"
+)
+
+type T1 os.UnixSignal
+type T2 os.Signal
+
+func f() {
+	_ = os.SIGHUP
+	_ = signal.Incoming
+}
+`,
+	},
+	{
+		Name: "signal.1",
+		In: `package main
+
+import (
+	"os"
+	"os/signal"
+)
+
+func f() {
+	var _ os.Error
+	_ = signal.SIGHUP
+}
+`,
+		Out: `package main
+
+import "os"
+
+
+func f() {
+	var _ os.Error
+	_ = os.SIGHUP
+}
+`,
+	},
+	{
+		Name: "signal.2",
+		In: `package main
+
+import "os"
+import "os/signal"
+
+func f() {
+	var _ os.Error
+	_ = signal.SIGHUP
+}
+`,
+		Out: `package main
+
+import "os"
+
+
+func f() {
+	var _ os.Error
+	_ = os.SIGHUP
+}
+`,
+	},
+}