prune.go

  1// Wrapper around x/tools/imports that only removes imports, never adds new ones.
  2
  3package imports
  4
import (
	"bytes"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"strconv"
	"strings"

	"github.com/99designs/gqlgen/internal/code"

	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/imports"
)
 18
 19type visitFn func(node ast.Node)
 20
 21func (fn visitFn) Visit(node ast.Node) ast.Visitor {
 22	fn(node)
 23	return fn
 24}
 25
 26// Prune removes any unused imports
 27func Prune(filename string, src []byte) ([]byte, error) {
 28	fset := token.NewFileSet()
 29
 30	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
 31	if err != nil {
 32		return nil, err
 33	}
 34
 35	unused := getUnusedImports(file)
 36	for ipath, name := range unused {
 37		astutil.DeleteNamedImport(fset, file, name, ipath)
 38	}
 39	printConfig := &printer.Config{Mode: printer.TabIndent, Tabwidth: 8}
 40
 41	var buf bytes.Buffer
 42	if err := printConfig.Fprint(&buf, fset, file); err != nil {
 43		return nil, err
 44	}
 45
 46	return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
 47}
 48
 49func getUnusedImports(file ast.Node) map[string]string {
 50	imported := map[string]*ast.ImportSpec{}
 51	used := map[string]bool{}
 52
 53	ast.Walk(visitFn(func(node ast.Node) {
 54		if node == nil {
 55			return
 56		}
 57		switch v := node.(type) {
 58		case *ast.ImportSpec:
 59			if v.Name != nil {
 60				imported[v.Name.Name] = v
 61				break
 62			}
 63			ipath := strings.Trim(v.Path.Value, `"`)
 64			if ipath == "C" {
 65				break
 66			}
 67
 68			local := code.NameForPackage(ipath)
 69
 70			imported[local] = v
 71		case *ast.SelectorExpr:
 72			xident, ok := v.X.(*ast.Ident)
 73			if !ok {
 74				break
 75			}
 76			if xident.Obj != nil {
 77				// if the parser can resolve it, it's not a package ref
 78				break
 79			}
 80			used[xident.Name] = true
 81		}
 82	}), file)
 83
 84	for pkg := range used {
 85		delete(imported, pkg)
 86	}
 87
 88	unusedImport := map[string]string{}
 89	for pkg, is := range imported {
 90		if !used[pkg] && pkg != "_" && pkg != "." {
 91			name := ""
 92			if is.Name != nil {
 93				name = is.Name.Name
 94			}
 95			unusedImport[strings.Trim(is.Path.Value, `"`)] = name
 96		}
 97	}
 98
 99	return unusedImport
100}