1package codegen
2
3import (
4 "fmt"
5 "go/build"
6 "sort"
7 "strconv"
8
9 // Import and ignore the ambient imports listed below so dependency managers
10 // don't prune unused code for us. Both lists should be kept in sync.
11 _ "github.com/99designs/gqlgen/graphql"
12 _ "github.com/99designs/gqlgen/graphql/introspection"
13 "github.com/99designs/gqlgen/internal/gopath"
14 _ "github.com/vektah/gqlparser"
15 _ "github.com/vektah/gqlparser/ast"
16)
17
// ambientImports lists the packages that the generated code references
// directly. The generated code assumes each of them is available under its
// default package name, so none of them should ever be renamed to resolve an
// alias collision; any other package that clashes with one of these names is
// the one that should get renamed. buildImports registers these before any
// package contributed by a user type.
//
// The third-party entries must be kept in sync with the blank imports at the
// top of this file, which exist only so that dependency managers do not prune
// these packages.
var ambientImports = []string{
	"context",
	"fmt",
	"io",
	"strconv",
	"time",
	"sync",
	"errors",

	"github.com/vektah/gqlparser",
	"github.com/vektah/gqlparser/ast",
	"github.com/99designs/gqlgen/graphql",
	"github.com/99designs/gqlgen/graphql/introspection",
}
35
36func buildImports(types NamedTypes, destDir string) *Imports {
37 imports := Imports{
38 destDir: destDir,
39 }
40
41 for _, ambient := range ambientImports {
42 imports.add(ambient)
43 }
44
45 // Imports from top level user types
46 for _, t := range types {
47 t.Import = imports.add(t.Package)
48 }
49
50 return &imports
51}
52
53func (s *Imports) add(path string) *Import {
54 if path == "" {
55 return nil
56 }
57
58 // if we are referencing our own package we dont need an import
59 if gopath.MustDir2Import(s.destDir) == path {
60 return nil
61 }
62
63 if existing := s.findByPath(path); existing != nil {
64 return existing
65 }
66
67 pkg, err := build.Default.Import(path, s.destDir, 0)
68 if err != nil {
69 panic(err)
70 }
71
72 imp := &Import{
73 Name: pkg.Name,
74 Path: path,
75 }
76 s.imports = append(s.imports, imp)
77
78 return imp
79}
80
81func (s Imports) finalize() []*Import {
82 // ensure stable ordering by sorting
83 sort.Slice(s.imports, func(i, j int) bool {
84 return s.imports[i].Path > s.imports[j].Path
85 })
86
87 for _, imp := range s.imports {
88 alias := imp.Name
89
90 i := 1
91 for s.findByAlias(alias) != nil {
92 alias = imp.Name + strconv.Itoa(i)
93 i++
94 if i > 10 {
95 panic(fmt.Errorf("too many collisions, last attempt was %s", alias))
96 }
97 }
98 imp.alias = alias
99 }
100
101 return s.imports
102}
103
104func (s Imports) findByPath(importPath string) *Import {
105 for _, imp := range s.imports {
106 if imp.Path == importPath {
107 return imp
108 }
109 }
110 return nil
111}
112
113func (s Imports) findByAlias(alias string) *Import {
114 for _, imp := range s.imports {
115 if imp.alias == alias {
116 return imp
117 }
118 }
119 return nil
120}