// Wrapper around golang.org/x/tools/imports that only removes unused imports and never adds new ones.
2
3package imports
4
5import (
6 "bytes"
7 "go/ast"
8 "go/build"
9 "go/parser"
10 "go/printer"
11 "go/token"
12 "path/filepath"
13 "strings"
14
15 "golang.org/x/tools/imports"
16
17 "golang.org/x/tools/go/ast/astutil"
18)
19
// visitFn adapts an ordinary function to the ast.Visitor interface so it can
// be handed straight to ast.Walk. The function is called for every node the
// walk reaches, and also with nil each time ast.Walk finishes a node's
// children, so implementations must tolerate a nil argument.
type visitFn func(node ast.Node)

// Visit invokes f on n and hands f back as the visitor for n's children,
// which makes ast.Walk descend through the entire tree with the same function.
func (f visitFn) Visit(n ast.Node) ast.Visitor {
	f(n)
	return f
}
26
27// Prune removes any unused imports
28func Prune(filename string, src []byte) ([]byte, error) {
29 fset := token.NewFileSet()
30
31 file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
32 if err != nil {
33 return nil, err
34 }
35
36 unused, err := getUnusedImports(file, filename)
37 if err != nil {
38 return nil, err
39 }
40 for ipath, name := range unused {
41 astutil.DeleteNamedImport(fset, file, name, ipath)
42 }
43 printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
44
45 var buf bytes.Buffer
46 if err := printConfig.Fprint(&buf, fset, file); err != nil {
47 return nil, err
48 }
49
50 return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
51}
52
53func getUnusedImports(file ast.Node, filename string) (map[string]string, error) {
54 imported := map[string]*ast.ImportSpec{}
55 used := map[string]bool{}
56
57 abs, err := filepath.Abs(filename)
58 if err != nil {
59 return nil, err
60 }
61 srcDir := filepath.Dir(abs)
62
63 ast.Walk(visitFn(func(node ast.Node) {
64 if node == nil {
65 return
66 }
67 switch v := node.(type) {
68 case *ast.ImportSpec:
69 if v.Name != nil {
70 imported[v.Name.Name] = v
71 break
72 }
73 ipath := strings.Trim(v.Path.Value, `"`)
74 if ipath == "C" {
75 break
76 }
77
78 local := importPathToName(ipath, srcDir)
79
80 imported[local] = v
81 case *ast.SelectorExpr:
82 xident, ok := v.X.(*ast.Ident)
83 if !ok {
84 break
85 }
86 if xident.Obj != nil {
87 // if the parser can resolve it, it's not a package ref
88 break
89 }
90 used[xident.Name] = true
91 }
92 }), file)
93
94 for pkg := range used {
95 delete(imported, pkg)
96 }
97
98 unusedImport := map[string]string{}
99 for pkg, is := range imported {
100 if !used[pkg] && pkg != "_" && pkg != "." {
101 name := ""
102 if is.Name != nil {
103 name = is.Name.Name
104 }
105 unusedImport[strings.Trim(is.Path.Value, `"`)] = name
106 }
107 }
108
109 return unusedImport, nil
110}
111
// importPathToName looks up the package at importPath using go/build's
// default context, resolving relative to srcDir, and returns the name from
// its package clause. It returns "" whenever the lookup reports an error.
func importPathToName(importPath, srcDir string) (packageName string) {
	if pkg, err := build.Default.Import(importPath, srcDir, 0); err == nil {
		packageName = pkg.Name
	}
	return packageName
}