1package config
2
3import (
4 "fmt"
5 "go/types"
6 "io/ioutil"
7 "os"
8 "path/filepath"
9 "regexp"
10 "sort"
11 "strings"
12
13 "golang.org/x/tools/go/packages"
14
15 "github.com/99designs/gqlgen/internal/code"
16 "github.com/pkg/errors"
17 "github.com/vektah/gqlparser"
18 "github.com/vektah/gqlparser/ast"
19 yaml "gopkg.in/yaml.v2"
20)
21
// Config is the in-memory form of a gqlgen config file. The yaml tags list the
// keys accepted by LoadConfig, which decodes strictly (unknown keys are errors).
type Config struct {
	// SchemaFilename lists the schema files. LoadConfig expands each entry as a
	// glob (or a recursive "**" pattern) into concrete, de-duplicated paths.
	SchemaFilename StringList `yaml:"schema,omitempty"`
	// Exec is the output package config for the generated code
	// (DefaultConfig sets its filename to "generated.go").
	Exec PackageConfig `yaml:"exec"`
	// Model is the output package config for generated models
	// (DefaultConfig sets its filename to "models_gen.go").
	Model PackageConfig `yaml:"model"`
	// Resolver is optional; it is only validated/normalized when its Filename is set.
	Resolver PackageConfig `yaml:"resolver,omitempty"`
	// AutoBind lists package patterns that Autobind loads with packages.Load and
	// searches for Go types named like schema types.
	AutoBind []string `yaml:"autobind"`
	// Models maps GraphQL type names to their model configuration.
	Models TypeMap `yaml:"models,omitempty"`
	// StructTag is read from the config but not used in this file.
	// NOTE(review): presumably the struct tag consulted when binding fields — confirm.
	StructTag string `yaml:"struct_tag,omitempty"`
	// Directives holds per-directive settings keyed by directive name.
	Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
}
32
// cfgFilenames are the config file names findCfgInDir probes, in order of
// preference within a single directory: the first one that exists wins.
var cfgFilenames = []string{
	".gqlgen.yml",
	"gqlgen.yml",
	"gqlgen.yaml",
}
34
35// DefaultConfig creates a copy of the default config
36func DefaultConfig() *Config {
37 return &Config{
38 SchemaFilename: StringList{"schema.graphql"},
39 Model: PackageConfig{Filename: "models_gen.go"},
40 Exec: PackageConfig{Filename: "generated.go"},
41 Directives: map[string]DirectiveConfig{
42 "skip": {
43 SkipRuntime: true,
44 },
45 "include": {
46 SkipRuntime: true,
47 },
48 "deprecated": {
49 SkipRuntime: true,
50 },
51 },
52 }
53}
54
55// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
56// walking up the tree. The closest config file will be returned.
57func LoadConfigFromDefaultLocations() (*Config, error) {
58 cfgFile, err := findCfg()
59 if err != nil {
60 return nil, err
61 }
62
63 err = os.Chdir(filepath.Dir(cfgFile))
64 if err != nil {
65 return nil, errors.Wrap(err, "unable to enter config dir")
66 }
67 return LoadConfig(cfgFile)
68}
69
// path2regex rewrites the glob tail of a "**" schema pattern into regular
// expression syntax:
//   - "." is escaped so it matches a literal dot,
//   - "*" becomes ".+" (one or more characters, path separators included),
//   - "/" and "\" both become a class matching either separator.
//
// Other regex metacharacters are passed through unchanged.
var path2regex = func() *strings.Replacer {
	pairs := []string{
		`.`, `\.`,
		`*`, `.+`,
		`\`, `[\\/]`,
		`/`, `[\\/]`,
	}
	return strings.NewReplacer(pairs...)
}()
76
77// LoadConfig reads the gqlgen.yml config file
78func LoadConfig(filename string) (*Config, error) {
79 config := DefaultConfig()
80
81 b, err := ioutil.ReadFile(filename)
82 if err != nil {
83 return nil, errors.Wrap(err, "unable to read config")
84 }
85
86 if err := yaml.UnmarshalStrict(b, config); err != nil {
87 return nil, errors.Wrap(err, "unable to parse config")
88 }
89
90 preGlobbing := config.SchemaFilename
91 config.SchemaFilename = StringList{}
92 for _, f := range preGlobbing {
93 var matches []string
94
95 // for ** we want to override default globbing patterns and walk all
96 // subdirectories to match schema files.
97 if strings.Contains(f, "**") {
98 pathParts := strings.SplitN(f, "**", 2)
99 rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
100 // turn the rest of the glob into a regex, anchored only at the end because ** allows
101 // for any number of dirs in between and walk will let us match against the full path name
102 globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
103
104 if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
105 if err != nil {
106 return err
107 }
108
109 if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
110 matches = append(matches, path)
111 }
112
113 return nil
114 }); err != nil {
115 return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0])
116 }
117 } else {
118 matches, err = filepath.Glob(f)
119 if err != nil {
120 return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
121 }
122 }
123
124 for _, m := range matches {
125 if config.SchemaFilename.Has(m) {
126 continue
127 }
128 config.SchemaFilename = append(config.SchemaFilename, m)
129 }
130 }
131
132 return config, nil
133}
134
// PackageConfig describes one generated Go output file and the package it belongs to.
type PackageConfig struct {
	// Filename is the output .go file; normalize makes it absolute and slash-separated.
	Filename string `yaml:"filename,omitempty"`
	// Package is the Go package name; normalize derives it from the output directory when empty.
	Package string `yaml:"package,omitempty"`
	// Type is not used in this file.
	// NOTE(review): presumably the resolver type name for the resolver section — confirm.
	Type string `yaml:"type,omitempty"`
}
140
// TypeMapEntry is the configuration for one GraphQL type under the "models" key.
type TypeMapEntry struct {
	// Model lists Go type specifiers of the form "import/path.TypeName" that the
	// GraphQL type may be bound to.
	Model StringList `yaml:"model"`
	// Fields holds per-field overrides; presumably keyed by GraphQL field name — confirm.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
145
// TypeMapField holds per-field overrides for a mapped type. Neither field is
// read in this file.
type TypeMapField struct {
	// NOTE(review): presumably forces a resolver method for the field — confirm.
	Resolver bool `yaml:"resolver"`
	// NOTE(review): presumably names the Go struct field to bind to — confirm.
	FieldName string `yaml:"fieldName"`
}
150
// StringList is a list of strings that may be written in YAML either as a
// single scalar or as a sequence.
type StringList []string

// UnmarshalYAML decodes a scalar into a one-element list, or otherwise a
// sequence into the full list. If neither form decodes, the error from the
// sequence attempt is returned and the receiver is left untouched.
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var one string
	if unmarshal(&one) == nil {
		*a = []string{one}
		return nil
	}

	var many []string
	if err := unmarshal(&many); err != nil {
		return err
	}
	*a = many
	return nil
}

// Has reports whether file is an element of the list (exact string match).
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
179
180func (c *PackageConfig) normalize() error {
181 if c.Filename == "" {
182 return errors.New("Filename is required")
183 }
184 c.Filename = abs(c.Filename)
185 // If Package is not set, first attempt to load the package at the output dir. If that fails
186 // fallback to just the base dir name of the output filename.
187 if c.Package == "" {
188 c.Package = code.NameForDir(c.Dir())
189 }
190
191 return nil
192}
193
194func (c *PackageConfig) ImportPath() string {
195 return code.ImportPathForDir(c.Dir())
196}
197
198func (c *PackageConfig) Dir() string {
199 return filepath.Dir(c.Filename)
200}
201
202func (c *PackageConfig) Check() error {
203 if strings.ContainsAny(c.Package, "./\\") {
204 return fmt.Errorf("package should be the output package name only, do not include the output filename")
205 }
206 if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
207 return fmt.Errorf("filename should be path to a go source file")
208 }
209
210 return c.normalize()
211}
212
213func (c *PackageConfig) Pkg() *types.Package {
214 return types.NewPackage(c.ImportPath(), c.Dir())
215}
216
217func (c *PackageConfig) IsDefined() bool {
218 return c.Filename != ""
219}
220
221func (c *Config) Check() error {
222 if err := c.Models.Check(); err != nil {
223 return errors.Wrap(err, "config.models")
224 }
225 if err := c.Exec.Check(); err != nil {
226 return errors.Wrap(err, "config.exec")
227 }
228 if err := c.Model.Check(); err != nil {
229 return errors.Wrap(err, "config.model")
230 }
231 if c.Resolver.IsDefined() {
232 if err := c.Resolver.Check(); err != nil {
233 return errors.Wrap(err, "config.resolver")
234 }
235 }
236
237 // check packages names against conflict, if present in the same dir
238 // and check filenames for uniqueness
239 packageConfigList := []PackageConfig{
240 c.Model,
241 c.Exec,
242 c.Resolver,
243 }
244 filesMap := make(map[string]bool)
245 pkgConfigsByDir := make(map[string]PackageConfig)
246 for _, current := range packageConfigList {
247 _, fileFound := filesMap[current.Filename]
248 if fileFound {
249 return fmt.Errorf("filename %s defined more than once", current.Filename)
250 }
251 filesMap[current.Filename] = true
252 previous, inSameDir := pkgConfigsByDir[current.Dir()]
253 if inSameDir && current.Package != previous.Package {
254 return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(current.Filename), stripPath(previous.Filename))
255 }
256 pkgConfigsByDir[current.Dir()] = current
257 }
258
259 return c.normalize()
260}
261
// stripPath returns the final element of path (filepath.Base semantics, so ""
// yields "."). It keeps error messages short.
func stripPath(path string) string {
	base := filepath.Base(path)
	return base
}
265
// TypeMap maps GraphQL type names to their model configuration.
type TypeMap map[string]TypeMapEntry
267
268func (tm TypeMap) Exists(typeName string) bool {
269 _, ok := tm[typeName]
270 return ok
271}
272
273func (tm TypeMap) UserDefined(typeName string) bool {
274 m, ok := tm[typeName]
275 return ok && len(m.Model) > 0
276}
277
278func (tm TypeMap) Check() error {
279 for typeName, entry := range tm {
280 for _, model := range entry.Model {
281 if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
282 return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
283 }
284 }
285 }
286 return nil
287}
288
289func (tm TypeMap) ReferencedPackages() []string {
290 var pkgs []string
291
292 for _, typ := range tm {
293 for _, model := range typ.Model {
294 if model == "map[string]interface{}" || model == "interface{}" {
295 continue
296 }
297 pkg, _ := code.PkgAndType(model)
298 if pkg == "" || inStrSlice(pkgs, pkg) {
299 continue
300 }
301 pkgs = append(pkgs, code.QualifyPackagePath(pkg))
302 }
303 }
304
305 sort.Slice(pkgs, func(i, j int) bool {
306 return pkgs[i] > pkgs[j]
307 })
308 return pkgs
309}
310
311func (tm TypeMap) Add(Name string, goType string) {
312 modelCfg := tm[Name]
313 modelCfg.Model = append(modelCfg.Model, goType)
314 tm[Name] = modelCfg
315}
316
// DirectiveConfig holds the settings for one directive under the "directives" key.
type DirectiveConfig struct {
	// SkipRuntime is set to true by DefaultConfig for skip, include and deprecated.
	// NOTE(review): presumably means no runtime handler is generated/required — confirm.
	SkipRuntime bool `yaml:"skip_runtime"`
}
320
// inStrSlice reports whether needle is an element of haystack. A nil or empty
// haystack contains nothing.
func inStrSlice(haystack []string, needle string) bool {
	found := false
	for i := 0; i < len(haystack) && !found; i++ {
		found = haystack[i] == needle
	}
	return found
}
330
331// findCfg searches for the config file in this directory and all parents up the tree
332// looking for the closest match
333func findCfg() (string, error) {
334 dir, err := os.Getwd()
335 if err != nil {
336 return "", errors.Wrap(err, "unable to get working dir to findCfg")
337 }
338
339 cfg := findCfgInDir(dir)
340
341 for cfg == "" && dir != filepath.Dir(dir) {
342 dir = filepath.Dir(dir)
343 cfg = findCfgInDir(dir)
344 }
345
346 if cfg == "" {
347 return "", os.ErrNotExist
348 }
349
350 return cfg, nil
351}
352
353func findCfgInDir(dir string) string {
354 for _, cfgName := range cfgFilenames {
355 path := filepath.Join(dir, cfgName)
356 if _, err := os.Stat(path); err == nil {
357 return path
358 }
359 }
360 return ""
361}
362
363func (c *Config) normalize() error {
364 if err := c.Model.normalize(); err != nil {
365 return errors.Wrap(err, "model")
366 }
367
368 if err := c.Exec.normalize(); err != nil {
369 return errors.Wrap(err, "exec")
370 }
371
372 if c.Resolver.IsDefined() {
373 if err := c.Resolver.normalize(); err != nil {
374 return errors.Wrap(err, "resolver")
375 }
376 }
377
378 if c.Models == nil {
379 c.Models = TypeMap{}
380 }
381
382 return nil
383}
384
385func (c *Config) Autobind(s *ast.Schema) error {
386 if len(c.AutoBind) == 0 {
387 return nil
388 }
389 ps, err := packages.Load(&packages.Config{Mode: packages.LoadTypes}, c.AutoBind...)
390 if err != nil {
391 return err
392 }
393
394 for _, t := range s.Types {
395 if c.Models.UserDefined(t.Name) {
396 continue
397 }
398
399 for _, p := range ps {
400 if t := p.Types.Scope().Lookup(t.Name); t != nil {
401 c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
402 break
403 }
404 }
405 }
406
407 return nil
408}
409
410func (c *Config) InjectBuiltins(s *ast.Schema) {
411 builtins := TypeMap{
412 "__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
413 "__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
414 "__Type": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
415 "__TypeKind": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
416 "__Field": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
417 "__EnumValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
418 "__InputValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
419 "__Schema": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
420 "Float": {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
421 "String": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
422 "Boolean": {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
423 "Int": {Model: StringList{
424 "github.com/99designs/gqlgen/graphql.Int",
425 "github.com/99designs/gqlgen/graphql.Int32",
426 "github.com/99designs/gqlgen/graphql.Int64",
427 }},
428 "ID": {
429 Model: StringList{
430 "github.com/99designs/gqlgen/graphql.ID",
431 "github.com/99designs/gqlgen/graphql.IntID",
432 },
433 },
434 }
435
436 for typeName, entry := range builtins {
437 if !c.Models.Exists(typeName) {
438 c.Models[typeName] = entry
439 }
440 }
441
442 // These are additional types that are injected if defined in the schema as scalars.
443 extraBuiltins := TypeMap{
444 "Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
445 "Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
446 "Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
447 "Any": {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
448 }
449
450 for typeName, entry := range extraBuiltins {
451 if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
452 c.Models[typeName] = entry
453 }
454 }
455}
456
457func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
458 schemaStrings := map[string]string{}
459
460 var sources []*ast.Source
461
462 for _, filename := range c.SchemaFilename {
463 filename = filepath.ToSlash(filename)
464 var err error
465 var schemaRaw []byte
466 schemaRaw, err = ioutil.ReadFile(filename)
467 if err != nil {
468 fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
469 os.Exit(1)
470 }
471 schemaStrings[filename] = string(schemaRaw)
472 sources = append(sources, &ast.Source{Name: filename, Input: schemaStrings[filename]})
473 }
474
475 schema, err := gqlparser.LoadSchema(sources...)
476 if err != nil {
477 return nil, nil, err
478 }
479 return schema, schemaStrings, nil
480}
481
// abs returns path as an absolute, slash-separated path. It panics if
// filepath.Abs fails, which can only happen when the working directory cannot be
// determined for a relative path.
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(resolved)
	}
	panic(err)
}