1package config
2
3import (
4 "fmt"
5 "go/types"
6 "io/ioutil"
7 "os"
8 "path/filepath"
9 "regexp"
10 "sort"
11 "strings"
12
13 "golang.org/x/tools/go/packages"
14
15 "github.com/99designs/gqlgen/internal/code"
16 "github.com/pkg/errors"
17 "github.com/vektah/gqlparser"
18 "github.com/vektah/gqlparser/ast"
19 "gopkg.in/yaml.v2"
20)
21
// Config is the in-memory form of a gqlgen config file. LoadConfig fills it by
// strictly unmarshalling YAML on top of DefaultConfig; the yaml tags below are the
// accepted keys.
type Config struct {
	// SchemaFilename lists the schema files to load. In YAML it may be a single string
	// or a list; LoadConfig replaces glob patterns (including "**") with the concrete
	// paths they match.
	SchemaFilename StringList `yaml:"schema,omitempty"`
	// Exec describes the exec output file/package; Check requires its Filename.
	Exec PackageConfig `yaml:"exec"`
	// Model describes the model output file/package; Check requires its Filename.
	Model PackageConfig `yaml:"model"`
	// Resolver describes the resolver output; it is optional and only checked and
	// normalized when its Filename is set.
	Resolver PackageConfig `yaml:"resolver,omitempty"`
	// AutoBind lists Go package patterns that Autobind loads and searches for
	// declarations named like schema types.
	AutoBind []string `yaml:"autobind"`
	// Models maps GraphQL type names to the Go model types bound to them.
	Models TypeMap `yaml:"models,omitempty"`
	// StructTag is read from the config file but not used in this file.
	// NOTE(review): presumably the struct tag consulted when binding fields — confirm.
	StructTag string `yaml:"struct_tag,omitempty"`
	// Directives holds per-directive settings keyed by directive name. LoadConfig adds
	// SkipRuntime entries for skip, include and deprecated unless the file defines them.
	Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
}
32
// cfgFilenames are the config file names findCfgInDir probes in each directory, in
// priority order: the first one that exists wins.
var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
34
35// DefaultConfig creates a copy of the default config
36func DefaultConfig() *Config {
37 return &Config{
38 SchemaFilename: StringList{"schema.graphql"},
39 Model: PackageConfig{Filename: "models_gen.go"},
40 Exec: PackageConfig{Filename: "generated.go"},
41 Directives: map[string]DirectiveConfig{},
42 }
43}
44
// LoadConfigFromDefaultLocations finds the nearest config file (the current
// directory first, then each parent in turn; see findCfg), makes the directory
// containing it the process working directory, and loads it with LoadConfig.
//
// Because the working directory is switched before loading, relative schema
// patterns and output filenames in the config are resolved against the config
// file's directory. The change is process-wide and is not undone. When no config
// file exists, findCfg's os.ErrNotExist is returned as-is.
func LoadConfigFromDefaultLocations() (*Config, error) {
	path, err := findCfg()
	if err != nil {
		return nil, err
	}

	if err := os.Chdir(filepath.Dir(path)); err != nil {
		return nil, errors.Wrap(err, "unable to enter config dir")
	}

	return LoadConfig(path)
}
59
60var path2regex = strings.NewReplacer(
61 `.`, `\.`,
62 `*`, `.+`,
63 `\`, `[\\/]`,
64 `/`, `[\\/]`,
65)
66
// LoadConfig reads the gqlgen config file at filename, applies it on top of
// DefaultConfig, and returns the result.
//
// Unknown keys are rejected (yaml.UnmarshalStrict). The directives skip, include and
// deprecated get SkipRuntime settings unless the file configures them itself.
//
// Every schema entry is expanded into the files it matches:
//   - An entry containing "**" walks the tree rooted at the text before "**" (the
//     current directory when that text is empty). The remainder of the pattern is
//     converted to a regular expression with path2regex and matched against the end
//     of each non-directory path.
//   - Any other entry is expanded with filepath.Glob.
//
// Duplicate matches are dropped, keeping first-seen order. A malformed "**" pattern
// is returned as an error rather than causing a panic.
func LoadConfig(filename string) (*Config, error) {
	config := DefaultConfig()

	b, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, errors.Wrap(err, "unable to read config")
	}

	if err := yaml.UnmarshalStrict(b, config); err != nil {
		return nil, errors.Wrap(err, "unable to parse config")
	}

	// A `directives:` key with an empty/null value makes the YAML decoder reset the
	// map to nil, and assigning into a nil map below would panic.
	if config.Directives == nil {
		config.Directives = map[string]DirectiveConfig{}
	}

	defaultDirectives := map[string]DirectiveConfig{
		"skip":       {SkipRuntime: true},
		"include":    {SkipRuntime: true},
		"deprecated": {SkipRuntime: true},
	}

	// Settings from the config file take precedence over these defaults.
	for key, value := range defaultDirectives {
		if _, defined := config.Directives[key]; !defined {
			config.Directives[key] = value
		}
	}

	preGlobbing := config.SchemaFilename
	config.SchemaFilename = StringList{}
	for _, f := range preGlobbing {
		var matches []string

		// for ** we want to override default globbing patterns and walk all
		// subdirectories to match schema files.
		if strings.Contains(f, "**") {
			pathParts := strings.SplitN(f, "**", 2)
			root := pathParts[0]
			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
			// Turn the rest of the glob into a regex, anchored only at the end because **
			// allows any number of dirs in between and walk lets us match against the full
			// path name. Compile (not MustCompile) so that a bad pattern in user config is
			// reported instead of crashing the process.
			globRe, err := regexp.Compile(path2regex.Replace(rest) + `$`)
			if err != nil {
				return nil, errors.Wrapf(err, "invalid schema pattern %s", f)
			}

			// A pattern that begins with "**" has an empty root, and filepath.Walk("")
			// fails; search the current directory instead.
			walkRoot := root
			if walkRoot == "" {
				walkRoot = "."
			}

			if err := filepath.Walk(walkRoot, func(path string, info os.FileInfo, err error) error {
				if err != nil {
					return err
				}

				// A directory is never a schema file; including one would only make
				// LoadSchema fail later when it tries to read it.
				if info.IsDir() {
					return nil
				}

				if globRe.MatchString(strings.TrimPrefix(path, root)) {
					matches = append(matches, path)
				}

				return nil
			}); err != nil {
				return nil, errors.Wrapf(err, "failed to walk schema at root %s", root)
			}
		} else {
			matches, err = filepath.Glob(f)
			if err != nil {
				return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
			}
		}

		for _, m := range matches {
			if config.SchemaFilename.Has(m) {
				continue
			}
			config.SchemaFilename = append(config.SchemaFilename, m)
		}
	}

	return config, nil
}
136
// PackageConfig describes one generated output: the Go file to write and the package
// it belongs to.
type PackageConfig struct {
	// Filename is the output .go file. Check rejects other extensions, normalize
	// requires it to be non-empty and rewrites it to an absolute, slash-separated path.
	Filename string `yaml:"filename,omitempty"`
	// Package is the Go package name of the output. Check rejects values containing
	// '.', '/' or '\'; normalize derives it from the output directory when empty.
	Package string `yaml:"package,omitempty"`
	// Type is read from the config file but not used in this file.
	// NOTE(review): presumably a type name for the generated code (e.g. the resolver
	// root type) — confirm with the callers that read it.
	Type string `yaml:"type,omitempty"`
}
142
// TypeMapEntry is the `models:` configuration for one GraphQL type.
type TypeMapEntry struct {
	// Model lists the Go types bound to the GraphQL type, each as
	// "import/path.TypeName". It may be written as a single string or a list.
	Model StringList `yaml:"model"`
	// Fields holds per-field overrides keyed by field name.
	// NOTE(review): presumably GraphQL field names — confirm with the code that reads it.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
147
// TypeMapField is the per-field configuration inside a TypeMapEntry. Neither field is
// read in this file.
type TypeMapField struct {
	// Resolver — NOTE(review): presumably forces a resolver for this field instead of
	// binding it directly; confirm with the generator.
	Resolver bool `yaml:"resolver"`
	// FieldName — NOTE(review): presumably the Go field/method name to bind to instead
	// of the default name match; confirm with the generator.
	FieldName string `yaml:"fieldName"`
}
152
// StringList is a list of strings that YAML may supply either as a single scalar
// ("schema.graphql") or as a sequence (["a.graphql", "b.graphql"]).
type StringList []string

// UnmarshalYAML implements the yaml.v2 Unmarshaler interface. A lone string is
// wrapped in a one-element list. Otherwise the value is decoded as a string sequence.
// When neither form decodes, the sequence attempt's error is returned and the
// receiver is left untouched.
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var scalar string
	if unmarshal(&scalar) == nil {
		*a = StringList{scalar}
		return nil
	}

	var sequence []string
	if err := unmarshal(&sequence); err != nil {
		return err
	}
	*a = sequence
	return nil
}

// Has reports whether file is an element of the list (exact string comparison).
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
181
182func (c *PackageConfig) normalize() error {
183 if c.Filename == "" {
184 return errors.New("Filename is required")
185 }
186 c.Filename = abs(c.Filename)
187 // If Package is not set, first attempt to load the package at the output dir. If that fails
188 // fallback to just the base dir name of the output filename.
189 if c.Package == "" {
190 c.Package = code.NameForDir(c.Dir())
191 }
192
193 return nil
194}
195
// ImportPath returns the Go import path of the directory containing Filename, as
// resolved by code.ImportPathForDir.
func (c *PackageConfig) ImportPath() string {
	dir := c.Dir()
	return code.ImportPathForDir(dir)
}

// Dir returns the directory part of Filename, with filepath.Dir semantics ("." for a
// bare file name or an empty Filename).
func (c *PackageConfig) Dir() string {
	filename := c.Filename
	return filepath.Dir(filename)
}
203
// Check validates the user-supplied settings and then normalizes them in place.
// Package must be a bare package name (no '.', '/' or '\'), and Filename, when set,
// must end in ".go". An empty Filename passes these checks but is rejected by
// normalize.
func (c *PackageConfig) Check() error {
	switch {
	case strings.ContainsAny(c.Package, "./\\"):
		return fmt.Errorf("package should be the output package name only, do not include the output filename")
	case c.Filename != "" && !strings.HasSuffix(c.Filename, ".go"):
		return fmt.Errorf("filename should be path to a go source file")
	default:
		return c.normalize()
	}
}
214
// Pkg returns a go/types package handle for this output. The handle is identified by
// ImportPath and named after the configured Go package. When Package is still empty
// (the config has not been normalized), the name falls back to code.NameForDir of the
// output directory, the same default that normalize applies.
func (c *PackageConfig) Pkg() *types.Package {
	// types.NewPackage takes (path, name): the second argument must be the Go
	// package identifier, not the output directory, or Name() returns a path.
	name := c.Package
	if name == "" {
		name = code.NameForDir(c.Dir())
	}
	return types.NewPackage(c.ImportPath(), name)
}
218
219func (c *PackageConfig) IsDefined() bool {
220 return c.Filename != ""
221}
222
// Check validates the whole config and then normalizes it, returning the first
// problem found.
//
// It validates, in order:
//   - the model type map;
//   - the exec and model outputs;
//   - the resolver output, only when it is defined.
//
// The per-output checks also normalize those outputs, so filenames are absolute
// before the cross-checks run. The cross-checks require that no two of
// model/exec/resolver share a filename and that outputs in the same directory
// declare the same package.
func (c *Config) Check() error {
	if err := c.Models.Check(); err != nil {
		return errors.Wrap(err, "config.models")
	}
	if err := c.Exec.Check(); err != nil {
		return errors.Wrap(err, "config.exec")
	}
	if err := c.Model.Check(); err != nil {
		return errors.Wrap(err, "config.model")
	}
	if c.Resolver.IsDefined() {
		if err := c.Resolver.Check(); err != nil {
			return errors.Wrap(err, "config.resolver")
		}
	}

	seenFiles := map[string]struct{}{}
	byDir := map[string]PackageConfig{}
	for _, pc := range []PackageConfig{c.Model, c.Exec, c.Resolver} {
		if _, dup := seenFiles[pc.Filename]; dup {
			return fmt.Errorf("filename %s defined more than once", pc.Filename)
		}
		seenFiles[pc.Filename] = struct{}{}

		dir := pc.Dir()
		if prev, ok := byDir[dir]; ok && prev.Package != pc.Package {
			return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(pc.Filename), stripPath(prev.Filename))
		}
		byDir[dir] = pc
	}

	return c.normalize()
}
263
// stripPath returns only the last element of path (filepath.Base semantics), used to
// keep error messages short.
func stripPath(path string) string {
	base := filepath.Base(path)
	return base
}
267
268type TypeMap map[string]TypeMapEntry
269
270func (tm TypeMap) Exists(typeName string) bool {
271 _, ok := tm[typeName]
272 return ok
273}
274
275func (tm TypeMap) UserDefined(typeName string) bool {
276 m, ok := tm[typeName]
277 return ok && len(m.Model) > 0
278}
279
// Check rejects model specifiers whose last '/' comes after their last '.'. That
// pattern means an import path with no type name appended, e.g. "github.com/x/pkg"
// instead of "github.com/x/pkg.Type".
//
// The error names the offending specifier itself, not the whole model list. Type
// names are visited in sorted order so the same config always reports the same error.
func (tm TypeMap) Check() error {
	typeNames := make([]string, 0, len(tm))
	for typeName := range tm {
		typeNames = append(typeNames, typeName)
	}
	sort.Strings(typeNames)

	for _, typeName := range typeNames {
		for _, model := range tm[typeName].Model {
			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, model)
			}
		}
	}
	return nil
}
290
// ReferencedPackages returns the distinct import paths of all Go models in the map,
// passed through code.QualifyPackagePath and sorted in descending order.
// "map[string]interface{}", "interface{}" and models without a package part are
// skipped.
func (tm TypeMap) ReferencedPackages() []string {
	var pkgs []string

	for _, typ := range tm {
		for _, model := range typ.Model {
			if model == "map[string]interface{}" || model == "interface{}" {
				continue
			}
			pkg, _ := code.PkgAndType(model)
			if pkg == "" {
				continue
			}
			// Deduplicate on the path that is actually stored. QualifyPackagePath may
			// return something other than its input, and comparing the raw path against
			// qualified entries would let duplicates through.
			qualified := code.QualifyPackagePath(pkg)
			if inStrSlice(pkgs, qualified) {
				continue
			}
			pkgs = append(pkgs, qualified)
		}
	}

	sort.Slice(pkgs, func(i, j int) bool {
		return pkgs[i] > pkgs[j]
	})
	return pkgs
}
312
313func (tm TypeMap) Add(name string, goType string) {
314 modelCfg := tm[name]
315 modelCfg.Model = append(modelCfg.Model, goType)
316 tm[name] = modelCfg
317}
318
// DirectiveConfig holds the settings for one schema directive. LoadConfig enables
// SkipRuntime by default for the skip, include and deprecated directives.
type DirectiveConfig struct {
	// SkipRuntime — NOTE(review): presumably tells the generator not to emit runtime
	// handling for this directive; confirm with the code that reads it.
	SkipRuntime bool `yaml:"skip_runtime"`
}
322
// inStrSlice reports whether needle occurs in haystack (linear scan, exact match).
func inStrSlice(haystack []string, needle string) bool {
	for i := range haystack {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
332
// findCfg returns the path of the config file closest to the current working
// directory. It checks the working directory first and then each parent up to the
// filesystem root, using findCfgInDir at every level. When no directory on that
// path has a config file, it returns os.ErrNotExist unwrapped.
func findCfg() (string, error) {
	dir, err := os.Getwd()
	if err != nil {
		return "", errors.Wrap(err, "unable to get working dir to findCfg")
	}

	for {
		if cfg := findCfgInDir(dir); cfg != "" {
			return cfg, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir only returns its input unchanged at the root: nowhere left to look.
			return "", os.ErrNotExist
		}
		dir = parent
	}
}
354
// findCfgInDir returns the path of the first name in cfgFilenames that exists in dir,
// or "" when none does. Any os.Stat error, not only "does not exist", counts as absent.
func findCfgInDir(dir string) string {
	for i := range cfgFilenames {
		candidate := filepath.Join(dir, cfgFilenames[i])
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		return candidate
	}
	return ""
}
364
365func (c *Config) normalize() error {
366 if err := c.Model.normalize(); err != nil {
367 return errors.Wrap(err, "model")
368 }
369
370 if err := c.Exec.normalize(); err != nil {
371 return errors.Wrap(err, "exec")
372 }
373
374 if c.Resolver.IsDefined() {
375 if err := c.Resolver.normalize(); err != nil {
376 return errors.Wrap(err, "resolver")
377 }
378 }
379
380 if c.Models == nil {
381 c.Models = TypeMap{}
382 }
383
384 return nil
385}
386
387func (c *Config) Autobind(s *ast.Schema) error {
388 if len(c.AutoBind) == 0 {
389 return nil
390 }
391 ps, err := packages.Load(&packages.Config{Mode: packages.LoadTypes}, c.AutoBind...)
392 if err != nil {
393 return err
394 }
395
396 for _, t := range s.Types {
397 if c.Models.UserDefined(t.Name) {
398 continue
399 }
400
401 for _, p := range ps {
402 if t := p.Types.Scope().Lookup(t.Name); t != nil {
403 c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
404 break
405 }
406 }
407 }
408
409 return nil
410}
411
412func (c *Config) InjectBuiltins(s *ast.Schema) {
413 builtins := TypeMap{
414 "__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
415 "__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
416 "__Type": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
417 "__TypeKind": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
418 "__Field": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
419 "__EnumValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
420 "__InputValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
421 "__Schema": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
422 "Float": {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
423 "String": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
424 "Boolean": {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
425 "Int": {Model: StringList{
426 "github.com/99designs/gqlgen/graphql.Int",
427 "github.com/99designs/gqlgen/graphql.Int32",
428 "github.com/99designs/gqlgen/graphql.Int64",
429 }},
430 "ID": {
431 Model: StringList{
432 "github.com/99designs/gqlgen/graphql.ID",
433 "github.com/99designs/gqlgen/graphql.IntID",
434 },
435 },
436 }
437
438 for typeName, entry := range builtins {
439 if !c.Models.Exists(typeName) {
440 c.Models[typeName] = entry
441 }
442 }
443
444 // These are additional types that are injected if defined in the schema as scalars.
445 extraBuiltins := TypeMap{
446 "Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
447 "Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
448 "Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
449 "Any": {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
450 }
451
452 for typeName, entry := range extraBuiltins {
453 if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
454 c.Models[typeName] = entry
455 }
456 }
457}
458
459func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
460 schemaStrings := map[string]string{}
461
462 sources := make([]*ast.Source, len(c.SchemaFilename))
463
464 for i, filename := range c.SchemaFilename {
465 filename = filepath.ToSlash(filename)
466 var err error
467 var schemaRaw []byte
468 schemaRaw, err = ioutil.ReadFile(filename)
469 if err != nil {
470 fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
471 os.Exit(1)
472 }
473 schemaStrings[filename] = string(schemaRaw)
474 sources[i] = &ast.Source{Name: filename, Input: schemaStrings[filename]}
475 }
476
477 schema, err := gqlparser.LoadSchema(sources...)
478 if err != nil {
479 return nil, nil, err
480 }
481 return schema, schemaStrings, nil
482}
483
// abs returns path as a cleaned, absolute, slash-separated path, resolving relative
// paths against the current working directory. It panics if filepath.Abs fails,
// which only happens when the working directory cannot be determined.
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(resolved)
	}
	panic(err)
}