1package codegen
2
3import (
4 "fmt"
5 "go/build"
6 "go/types"
7 "os"
8
9 "github.com/pkg/errors"
10 "golang.org/x/tools/go/loader"
11)
12
// Build is the template data assembled by (*Config).bind for generating the
// executable schema package. PackageName comes from cfg.Exec.Package, and
// SchemaRaw / SchemaFilename are copied from cfg.SchemaStr / cfg.SchemaFilename.
//
// QueryRoot, MutationRoot and SubscriptionRoot are looked up by name in Objects.
// MutationRoot and SubscriptionRoot stay nil when the schema does not declare
// those operation types. NOTE(review): a root can also be nil if Objects.ByName
// finds no match; confirm ByName's contract.
//
// SchemaRaw is presumably keyed by schema filename; confirm against cfg.SchemaStr.
type Build struct {
	PackageName      string
	Objects          Objects
	Inputs           Objects
	Interfaces       []*Interface
	QueryRoot        *Object
	MutationRoot     *Object
	SubscriptionRoot *Object
	SchemaRaw        map[string]string
	SchemaFilename   SchemaFilenames
	Directives       []*Directive
}
25
// ModelBuild is the template data assembled by (*Config).models for generating
// the model package. PackageName comes from cfg.Model.Package. Models and Enums
// are built from the schema's named types.
type ModelBuild struct {
	PackageName string
	Models      []Model
	Enums       []Enum
}
31
// ResolverBuild is the template data assembled by (*Config).resolver for
// generating the resolver stub. PackageName and ResolverType come from
// cfg.Resolver. ResolverFound reports whether a type named ResolverType was
// found in the resolver package when it was loaded.
type ResolverBuild struct {
	PackageName   string
	ResolverType  string
	Objects       Objects
	ResolverFound bool
}
38
// ServerBuild is the template data assembled by (*Config).server for generating
// the server entry point.
//
// PackageName is the package name (cfg.Resolver.Package). Despite their names,
// ExecPackageName and ResolverPackageName hold import paths (the ImportPath() of
// cfg.Exec and cfg.Resolver respectively).
type ServerBuild struct {
	PackageName         string
	ExecPackageName     string
	ResolverPackageName string
}
44
45// Create a list of models that need to be generated
46func (cfg *Config) models() (*ModelBuild, error) {
47 namedTypes := cfg.buildNamedTypes()
48
49 progLoader := cfg.newLoaderWithoutErrors()
50
51 prog, err := progLoader.Load()
52 if err != nil {
53 return nil, errors.Wrap(err, "loading failed")
54 }
55
56 cfg.bindTypes(namedTypes, cfg.Model.Dir(), prog)
57
58 models, err := cfg.buildModels(namedTypes, prog)
59 if err != nil {
60 return nil, err
61 }
62 return &ModelBuild{
63 PackageName: cfg.Model.Package,
64 Models: models,
65 Enums: cfg.buildEnums(namedTypes),
66 }, nil
67}
68
69// bind a schema together with some code to generate a Build
70func (cfg *Config) resolver() (*ResolverBuild, error) {
71 progLoader := cfg.newLoaderWithoutErrors()
72 progLoader.Import(cfg.Resolver.ImportPath())
73
74 prog, err := progLoader.Load()
75 if err != nil {
76 return nil, err
77 }
78
79 destDir := cfg.Resolver.Dir()
80
81 namedTypes := cfg.buildNamedTypes()
82
83 cfg.bindTypes(namedTypes, destDir, prog)
84
85 objects, err := cfg.buildObjects(namedTypes, prog)
86 if err != nil {
87 return nil, err
88 }
89
90 def, _ := findGoType(prog, cfg.Resolver.ImportPath(), cfg.Resolver.Type)
91 resolverFound := def != nil
92
93 return &ResolverBuild{
94 PackageName: cfg.Resolver.Package,
95 Objects: objects,
96 ResolverType: cfg.Resolver.Type,
97 ResolverFound: resolverFound,
98 }, nil
99}
100
101func (cfg *Config) server(destDir string) *ServerBuild {
102 return &ServerBuild{
103 PackageName: cfg.Resolver.Package,
104 ExecPackageName: cfg.Exec.ImportPath(),
105 ResolverPackageName: cfg.Resolver.ImportPath(),
106 }
107}
108
109// bind a schema together with some code to generate a Build
110func (cfg *Config) bind() (*Build, error) {
111 namedTypes := cfg.buildNamedTypes()
112
113 progLoader := cfg.newLoaderWithoutErrors()
114 prog, err := progLoader.Load()
115 if err != nil {
116 return nil, errors.Wrap(err, "loading failed")
117 }
118
119 cfg.bindTypes(namedTypes, cfg.Exec.Dir(), prog)
120
121 objects, err := cfg.buildObjects(namedTypes, prog)
122 if err != nil {
123 return nil, err
124 }
125
126 inputs, err := cfg.buildInputs(namedTypes, prog)
127 if err != nil {
128 return nil, err
129 }
130 directives, err := cfg.buildDirectives(namedTypes)
131 if err != nil {
132 return nil, err
133 }
134
135 b := &Build{
136 PackageName: cfg.Exec.Package,
137 Objects: objects,
138 Interfaces: cfg.buildInterfaces(namedTypes, prog),
139 Inputs: inputs,
140 SchemaRaw: cfg.SchemaStr,
141 SchemaFilename: cfg.SchemaFilename,
142 Directives: directives,
143 }
144
145 if cfg.schema.Query != nil {
146 b.QueryRoot = b.Objects.ByName(cfg.schema.Query.Name)
147 } else {
148 return b, fmt.Errorf("query entry point missing")
149 }
150
151 if cfg.schema.Mutation != nil {
152 b.MutationRoot = b.Objects.ByName(cfg.schema.Mutation.Name)
153 }
154
155 if cfg.schema.Subscription != nil {
156 b.SubscriptionRoot = b.Objects.ByName(cfg.schema.Subscription.Name)
157 }
158 return b, nil
159}
160
161func (cfg *Config) validate() error {
162 progLoader := cfg.newLoaderWithErrors()
163 _, err := progLoader.Load()
164 return err
165}
166
167func (cfg *Config) newLoaderWithErrors() loader.Config {
168 conf := loader.Config{}
169
170 for _, pkg := range cfg.Models.referencedPackages() {
171 conf.Import(pkg)
172 }
173 return conf
174}
175
176func (cfg *Config) newLoaderWithoutErrors() loader.Config {
177 conf := cfg.newLoaderWithErrors()
178 conf.AllowErrors = true
179 conf.TypeChecker = types.Config{
180 Error: func(e error) {},
181 }
182 return conf
183}
184
// resolvePkg resolves pkgName to its canonical import path using go/build's
// default context. Relative names such as "./models" are interpreted against
// the current working directory. build.FindOnly means only the package
// directory is located; its source files are not read.
//
// It returns an error if the working directory cannot be determined or the
// package cannot be located.
func resolvePkg(pkgName string) (string, error) {
	cwd, err := os.Getwd()
	if err != nil {
		// Previously ignored. An empty srcDir makes relative imports fail later
		// with a far less helpful "relative to unknown directory" error.
		return "", err
	}

	pkg, err := build.Default.Import(pkgName, cwd, build.FindOnly)
	if err != nil {
		return "", err
	}

	return pkg.ImportPath, nil
}