import_build.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"go/build"
  6	"sort"
  7	"strconv"
  8
  9	// Import and ignore the ambient imports listed below so dependency managers
 10	// don't prune unused code for us. Both lists should be kept in sync.
 11	_ "github.com/99designs/gqlgen/graphql"
 12	_ "github.com/99designs/gqlgen/graphql/introspection"
 13	"github.com/99designs/gqlgen/internal/gopath"
 14	_ "github.com/vektah/gqlparser"
 15	_ "github.com/vektah/gqlparser/ast"
 16)
 17
 18// These imports are referenced by the generated code, and are assumed to have the
 19// default alias. So lets make sure they get added first, and any later collisions get
 20// renamed.
 21var ambientImports = []string{
 22	"context",
 23	"fmt",
 24	"io",
 25	"strconv",
 26	"time",
 27	"sync",
 28	"errors",
 29
 30	"github.com/vektah/gqlparser",
 31	"github.com/vektah/gqlparser/ast",
 32	"github.com/99designs/gqlgen/graphql",
 33	"github.com/99designs/gqlgen/graphql/introspection",
 34}
 35
 36func buildImports(types NamedTypes, destDir string) *Imports {
 37	imports := Imports{
 38		destDir: destDir,
 39	}
 40
 41	for _, ambient := range ambientImports {
 42		imports.add(ambient)
 43	}
 44
 45	// Imports from top level user types
 46	for _, t := range types {
 47		t.Import = imports.add(t.Package)
 48	}
 49
 50	return &imports
 51}
 52
 53func (s *Imports) add(path string) *Import {
 54	if path == "" {
 55		return nil
 56	}
 57
 58	// if we are referencing our own package we dont need an import
 59	if gopath.MustDir2Import(s.destDir) == path {
 60		return nil
 61	}
 62
 63	if existing := s.findByPath(path); existing != nil {
 64		return existing
 65	}
 66
 67	pkg, err := build.Default.Import(path, s.destDir, 0)
 68	if err != nil {
 69		panic(err)
 70	}
 71
 72	imp := &Import{
 73		Name: pkg.Name,
 74		Path: path,
 75	}
 76	s.imports = append(s.imports, imp)
 77
 78	return imp
 79}
 80
 81func (s Imports) finalize() []*Import {
 82	// ensure stable ordering by sorting
 83	sort.Slice(s.imports, func(i, j int) bool {
 84		return s.imports[i].Path > s.imports[j].Path
 85	})
 86
 87	for _, imp := range s.imports {
 88		alias := imp.Name
 89
 90		i := 1
 91		for s.findByAlias(alias) != nil {
 92			alias = imp.Name + strconv.Itoa(i)
 93			i++
 94			if i > 10 {
 95				panic(fmt.Errorf("too many collisions, last attempt was %s", alias))
 96			}
 97		}
 98		imp.alias = alias
 99	}
100
101	return s.imports
102}
103
104func (s Imports) findByPath(importPath string) *Import {
105	for _, imp := range s.imports {
106		if imp.Path == importPath {
107			return imp
108		}
109	}
110	return nil
111}
112
113func (s Imports) findByAlias(alias string) *Import {
114	for _, imp := range s.imports {
115		if imp.alias == alias {
116			return imp
117		}
118	}
119	return nil
120}