// Wrapper around golang.org/x/tools/imports that only removes unused imports and never adds new ones.
2
3package imports
4
import (
	"bytes"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"strconv"
	"strings"

	"github.com/99designs/gqlgen/internal/code"

	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/imports"
)
18
// visitFn adapts an ordinary function to the ast.Visitor interface so a
// closure can be handed straight to ast.Walk.
type visitFn func(node ast.Node)

// Visit invokes the wrapped function on node (ast.Walk passes nil once a
// node's children have been visited) and hands itself back as the visitor
// for the children, so the walk always descends into the whole tree.
func (f visitFn) Visit(node ast.Node) ast.Visitor {
	var next ast.Visitor = f
	f(node)
	return next
}
25
26// Prune removes any unused imports
27func Prune(filename string, src []byte) ([]byte, error) {
28 fset := token.NewFileSet()
29
30 file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
31 if err != nil {
32 return nil, err
33 }
34
35 unused, err := getUnusedImports(file, filename)
36 if err != nil {
37 return nil, err
38 }
39 for ipath, name := range unused {
40 astutil.DeleteNamedImport(fset, file, name, ipath)
41 }
42 printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
43
44 var buf bytes.Buffer
45 if err := printConfig.Fprint(&buf, fset, file); err != nil {
46 return nil, err
47 }
48
49 return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
50}
51
52func getUnusedImports(file ast.Node, filename string) (map[string]string, error) {
53 imported := map[string]*ast.ImportSpec{}
54 used := map[string]bool{}
55
56 ast.Walk(visitFn(func(node ast.Node) {
57 if node == nil {
58 return
59 }
60 switch v := node.(type) {
61 case *ast.ImportSpec:
62 if v.Name != nil {
63 imported[v.Name.Name] = v
64 break
65 }
66 ipath := strings.Trim(v.Path.Value, `"`)
67 if ipath == "C" {
68 break
69 }
70
71 local := code.NameForPackage(ipath)
72
73 imported[local] = v
74 case *ast.SelectorExpr:
75 xident, ok := v.X.(*ast.Ident)
76 if !ok {
77 break
78 }
79 if xident.Obj != nil {
80 // if the parser can resolve it, it's not a package ref
81 break
82 }
83 used[xident.Name] = true
84 }
85 }), file)
86
87 for pkg := range used {
88 delete(imported, pkg)
89 }
90
91 unusedImport := map[string]string{}
92 for pkg, is := range imported {
93 if !used[pkg] && pkg != "_" && pkg != "." {
94 name := ""
95 if is.Name != nil {
96 name = is.Name.Name
97 }
98 unusedImport[strings.Trim(is.Path.Value, `"`)] = name
99 }
100 }
101
102 return unusedImport, nil
103}