prune.go

  1// Wrapper around x/tools/imports that only removes imports, never adds new ones.
  2
  3package imports
  4
  5import (
  6	"bytes"
  7	"go/ast"
  8	"go/build"
  9	"go/parser"
 10	"go/printer"
 11	"go/token"
 12	"path/filepath"
 13	"strings"
 14
 15	"golang.org/x/tools/imports"
 16
 17	"golang.org/x/tools/go/ast/astutil"
 18)
 19
 20type visitFn func(node ast.Node)
 21
 22func (fn visitFn) Visit(node ast.Node) ast.Visitor {
 23	fn(node)
 24	return fn
 25}
 26
 27// Prune removes any unused imports
 28func Prune(filename string, src []byte) ([]byte, error) {
 29	fset := token.NewFileSet()
 30
 31	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
 32	if err != nil {
 33		return nil, err
 34	}
 35
 36	unused, err := getUnusedImports(file, filename)
 37	if err != nil {
 38		return nil, err
 39	}
 40	for ipath, name := range unused {
 41		astutil.DeleteNamedImport(fset, file, name, ipath)
 42	}
 43	printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
 44
 45	var buf bytes.Buffer
 46	if err := printConfig.Fprint(&buf, fset, file); err != nil {
 47		return nil, err
 48	}
 49
 50	return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
 51}
 52
 53func getUnusedImports(file ast.Node, filename string) (map[string]string, error) {
 54	imported := map[string]*ast.ImportSpec{}
 55	used := map[string]bool{}
 56
 57	abs, err := filepath.Abs(filename)
 58	if err != nil {
 59		return nil, err
 60	}
 61	srcDir := filepath.Dir(abs)
 62
 63	ast.Walk(visitFn(func(node ast.Node) {
 64		if node == nil {
 65			return
 66		}
 67		switch v := node.(type) {
 68		case *ast.ImportSpec:
 69			if v.Name != nil {
 70				imported[v.Name.Name] = v
 71				break
 72			}
 73			ipath := strings.Trim(v.Path.Value, `"`)
 74			if ipath == "C" {
 75				break
 76			}
 77
 78			local := importPathToName(ipath, srcDir)
 79
 80			imported[local] = v
 81		case *ast.SelectorExpr:
 82			xident, ok := v.X.(*ast.Ident)
 83			if !ok {
 84				break
 85			}
 86			if xident.Obj != nil {
 87				// if the parser can resolve it, it's not a package ref
 88				break
 89			}
 90			used[xident.Name] = true
 91		}
 92	}), file)
 93
 94	for pkg := range used {
 95		delete(imported, pkg)
 96	}
 97
 98	unusedImport := map[string]string{}
 99	for pkg, is := range imported {
100		if !used[pkg] && pkg != "_" && pkg != "." {
101			name := ""
102			if is.Name != nil {
103				name = is.Name.Name
104			}
105			unusedImport[strings.Trim(is.Path.Value, `"`)] = name
106		}
107	}
108
109	return unusedImport, nil
110}
111
// importPathToName returns the package name that importPath brings into scope
// when imported without an explicit name from a file located in srcDir.
//
// The name is taken from the package itself via go/build. When the package
// cannot be loaded (not downloaded, outside GOPATH/module, no Go files, ...)
// the name is guessed from the import path rather than returned as "": an
// empty name can never equal an identifier used in the file, so the unused-
// import analysis would report every unresolvable import as unused, and all
// such imports would collide on the same "" key.
func importPathToName(importPath, srcDir string) (packageName string) {
	pkg, err := build.Default.Import(importPath, srcDir, 0)
	if err == nil && pkg.Name != "" {
		return pkg.Name
	}

	return assumedPackageName(importPath)
}

// assumedPackageName guesses a package name from the last element of
// importPath using the common naming conventions: a trailing major-version
// element ("/v2") is skipped in favour of the element before it, a leading
// "go-" is dropped, and the name is cut at the first ASCII character that
// cannot appear in an identifier (so "gopkg.in/yaml.v3" yields "yaml").
// It returns "" only when no usable element exists.
func assumedPackageName(importPath string) string {
	trimmed := strings.TrimRight(importPath, "/")

	dir, base := "", trimmed
	if i := strings.LastIndexByte(trimmed, '/'); i >= 0 {
		dir, base = trimmed[:i], trimmed[i+1:]
	}

	// "vN" as the final element is a module major version, not a package name.
	if dir != "" && len(base) > 1 && base[0] == 'v' &&
		strings.Trim(base[1:], "0123456789") == "" {
		base = dir[strings.LastIndexByte(dir, '/')+1:]
	}

	base = strings.TrimPrefix(base, "go-")

	notIdent := func(r rune) bool {
		switch {
		case r == '_', r >= 0x80:
			return false
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return false
		}
		return true
	}
	if i := strings.IndexFunc(base, notIdent); i >= 0 {
		base = base[:i]
	}

	return base
}