1package config
2
3import (
4 "fmt"
5 "go/types"
6 "io/ioutil"
7 "os"
8 "path/filepath"
9 "sort"
10 "strings"
11
12 "github.com/99designs/gqlgen/internal/code"
13 "github.com/pkg/errors"
14 "github.com/vektah/gqlparser"
15 "github.com/vektah/gqlparser/ast"
16 yaml "gopkg.in/yaml.v2"
17)
18
// Config holds the settings read from a gqlgen config file. The yaml tags
// name the key each field is read from; DefaultConfig supplies the defaults
// that the file is decoded on top of.
type Config struct {
	// SchemaFilename lists schema files. LoadConfig treats every entry as a
	// filepath.Glob pattern and replaces the list with the matched paths.
	SchemaFilename StringList `yaml:"schema,omitempty"`
	// Exec is the output file/package for the generated exec code
	// (DefaultConfig: generated.go).
	Exec PackageConfig `yaml:"exec"`
	// Model is the output file/package for generated models
	// (DefaultConfig: models_gen.go).
	Model PackageConfig `yaml:"model"`
	// Resolver is optional: Check and normalize only process it when
	// IsDefined reports true.
	Resolver PackageConfig `yaml:"resolver,omitempty"`
	// Models maps GraphQL type names to bound Go types ("models" key).
	Models TypeMap `yaml:"models,omitempty"`
	// StructTag is not read anywhere in this file.
	// NOTE(review): presumably consumed by model generation — confirm.
	StructTag string `yaml:"struct_tag,omitempty"`
}
27
// cfgFilenames are the config file names probed in each directory, in
// priority order: the first one that exists in a directory wins.
var cfgFilenames = []string{
	".gqlgen.yml",
	"gqlgen.yml",
	"gqlgen.yaml",
}
29
30// DefaultConfig creates a copy of the default config
31func DefaultConfig() *Config {
32 return &Config{
33 SchemaFilename: StringList{"schema.graphql"},
34 Model: PackageConfig{Filename: "models_gen.go"},
35 Exec: PackageConfig{Filename: "generated.go"},
36 }
37}
38
39// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
40// walking up the tree. The closest config file will be returned.
41func LoadConfigFromDefaultLocations() (*Config, error) {
42 cfgFile, err := findCfg()
43 if err != nil {
44 return nil, err
45 }
46
47 err = os.Chdir(filepath.Dir(cfgFile))
48 if err != nil {
49 return nil, errors.Wrap(err, "unable to enter config dir")
50 }
51 return LoadConfig(cfgFile)
52}
53
54// LoadConfig reads the gqlgen.yml config file
55func LoadConfig(filename string) (*Config, error) {
56 config := DefaultConfig()
57
58 b, err := ioutil.ReadFile(filename)
59 if err != nil {
60 return nil, errors.Wrap(err, "unable to read config")
61 }
62
63 if err := yaml.UnmarshalStrict(b, config); err != nil {
64 return nil, errors.Wrap(err, "unable to parse config")
65 }
66
67 preGlobbing := config.SchemaFilename
68 config.SchemaFilename = StringList{}
69 for _, f := range preGlobbing {
70 matches, err := filepath.Glob(f)
71 if err != nil {
72 return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
73 }
74
75 for _, m := range matches {
76 if config.SchemaFilename.Has(m) {
77 continue
78 }
79 config.SchemaFilename = append(config.SchemaFilename, m)
80 }
81 }
82
83 return config, nil
84}
85
// PackageConfig describes one generated output: the Go file to write and the
// package it belongs to.
type PackageConfig struct {
	// Filename is the output file. Check requires a ".go" suffix when it is
	// set, and normalize rewrites it as an absolute, slash-separated path.
	Filename string `yaml:"filename,omitempty"`
	// Package is a bare Go package name (Check rejects '.', '/' and '\').
	// When empty, normalize derives it from the output directory.
	Package string `yaml:"package,omitempty"`
	// Type is not read anywhere in this file.
	// NOTE(review): presumably the generated resolver type name — confirm.
	Type string `yaml:"type,omitempty"`
}
91
// TypeMapEntry is the configuration for a single GraphQL type.
type TypeMapEntry struct {
	// Model lists the Go types bound to the GraphQL type, written as
	// "import/path.TypeName"; a single YAML string is accepted as well.
	Model StringList `yaml:"model"`
	// Fields holds per-field settings.
	// NOTE(review): presumably keyed by GraphQL field name — confirm.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
96
// TypeMapField holds the settings for one field inside a TypeMapEntry.
// NOTE(review): neither field is read in this file; the descriptions below
// are inferred from the names — confirm against the code generator.
type TypeMapField struct {
	// Resolver presumably requests a resolver method for the field.
	Resolver bool `yaml:"resolver"`
	// FieldName presumably names the Go struct field to bind to.
	FieldName string `yaml:"fieldName"`
}
101
// StringList is a list of strings that may be written in YAML either as a
// single scalar string or as a sequence of strings.
type StringList []string

// UnmarshalYAML matches the yaml.v2 Unmarshaler signature. A value that
// decodes as a string becomes a one-element list; otherwise the value is
// decoded as a string sequence. If neither form decodes, the error from the
// sequence attempt is returned and the receiver is left untouched.
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var one string
	if unmarshal(&one) == nil {
		*a = StringList{one}
		return nil
	}

	var many []string
	if err := unmarshal(&many); err != nil {
		return err
	}
	*a = many
	return nil
}

// Has reports whether file is an element of the list.
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
130
131func (c *PackageConfig) normalize() error {
132 if c.Filename == "" {
133 return errors.New("Filename is required")
134 }
135 c.Filename = abs(c.Filename)
136 // If Package is not set, first attempt to load the package at the output dir. If that fails
137 // fallback to just the base dir name of the output filename.
138 if c.Package == "" {
139 c.Package = code.NameForPackage(c.ImportPath())
140 }
141
142 return nil
143}
144
145func (c *PackageConfig) ImportPath() string {
146 return code.ImportPathForDir(c.Dir())
147}
148
149func (c *PackageConfig) Dir() string {
150 return filepath.Dir(c.Filename)
151}
152
153func (c *PackageConfig) Check() error {
154 if strings.ContainsAny(c.Package, "./\\") {
155 return fmt.Errorf("package should be the output package name only, do not include the output filename")
156 }
157 if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
158 return fmt.Errorf("filename should be path to a go source file")
159 }
160
161 return c.normalize()
162}
163
164func (c *PackageConfig) Pkg() *types.Package {
165 return types.NewPackage(c.ImportPath(), c.Dir())
166}
167
168func (c *PackageConfig) IsDefined() bool {
169 return c.Filename != ""
170}
171
172func (c *Config) Check() error {
173 if err := c.Models.Check(); err != nil {
174 return errors.Wrap(err, "config.models")
175 }
176 if err := c.Exec.Check(); err != nil {
177 return errors.Wrap(err, "config.exec")
178 }
179 if err := c.Model.Check(); err != nil {
180 return errors.Wrap(err, "config.model")
181 }
182 if c.Resolver.IsDefined() {
183 if err := c.Resolver.Check(); err != nil {
184 return errors.Wrap(err, "config.resolver")
185 }
186 }
187
188 // check packages names against conflict, if present in the same dir
189 // and check filenames for uniqueness
190 packageConfigList := []PackageConfig{
191 c.Model,
192 c.Exec,
193 c.Resolver,
194 }
195 filesMap := make(map[string]bool)
196 pkgConfigsByDir := make(map[string]PackageConfig)
197 for _, current := range packageConfigList {
198 _, fileFound := filesMap[current.Filename]
199 if fileFound {
200 return fmt.Errorf("filename %s defined more than once", current.Filename)
201 }
202 filesMap[current.Filename] = true
203 previous, inSameDir := pkgConfigsByDir[current.Dir()]
204 if inSameDir && current.Package != previous.Package {
205 return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(current.Filename), stripPath(previous.Filename))
206 }
207 pkgConfigsByDir[current.Dir()] = current
208 }
209
210 return c.normalize()
211}
212
// stripPath returns only the last element of path (its base name), which
// keeps error messages short.
func stripPath(path string) string {
	base := filepath.Base(path)
	return base
}
216
// TypeMap maps GraphQL type names to the Go types bound to them (see
// TypeMapEntry). It is decoded from the "models" section of the config.
type TypeMap map[string]TypeMapEntry
218
219func (tm TypeMap) Exists(typeName string) bool {
220 _, ok := tm[typeName]
221 return ok
222}
223
224func (tm TypeMap) UserDefined(typeName string) bool {
225 m, ok := tm[typeName]
226 return ok && len(m.Model) > 0
227}
228
229func (tm TypeMap) Check() error {
230 for typeName, entry := range tm {
231 for _, model := range entry.Model {
232 if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
233 return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
234 }
235 }
236 }
237 return nil
238}
239
240func (tm TypeMap) ReferencedPackages() []string {
241 var pkgs []string
242
243 for _, typ := range tm {
244 for _, model := range typ.Model {
245 if model == "map[string]interface{}" || model == "interface{}" {
246 continue
247 }
248 pkg, _ := code.PkgAndType(model)
249 if pkg == "" || inStrSlice(pkgs, pkg) {
250 continue
251 }
252 pkgs = append(pkgs, code.QualifyPackagePath(pkg))
253 }
254 }
255
256 sort.Slice(pkgs, func(i, j int) bool {
257 return pkgs[i] > pkgs[j]
258 })
259 return pkgs
260}
261
262func (tm TypeMap) Add(Name string, goType string) {
263 modelCfg := tm[Name]
264 modelCfg.Model = append(modelCfg.Model, goType)
265 tm[Name] = modelCfg
266}
267
// inStrSlice reports whether needle occurs in haystack, using a linear scan.
func inStrSlice(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
277
278// findCfg searches for the config file in this directory and all parents up the tree
279// looking for the closest match
280func findCfg() (string, error) {
281 dir, err := os.Getwd()
282 if err != nil {
283 return "", errors.Wrap(err, "unable to get working dir to findCfg")
284 }
285
286 cfg := findCfgInDir(dir)
287
288 for cfg == "" && dir != filepath.Dir(dir) {
289 dir = filepath.Dir(dir)
290 cfg = findCfgInDir(dir)
291 }
292
293 if cfg == "" {
294 return "", os.ErrNotExist
295 }
296
297 return cfg, nil
298}
299
300func findCfgInDir(dir string) string {
301 for _, cfgName := range cfgFilenames {
302 path := filepath.Join(dir, cfgName)
303 if _, err := os.Stat(path); err == nil {
304 return path
305 }
306 }
307 return ""
308}
309
310func (c *Config) normalize() error {
311 if err := c.Model.normalize(); err != nil {
312 return errors.Wrap(err, "model")
313 }
314
315 if err := c.Exec.normalize(); err != nil {
316 return errors.Wrap(err, "exec")
317 }
318
319 if c.Resolver.IsDefined() {
320 if err := c.Resolver.normalize(); err != nil {
321 return errors.Wrap(err, "resolver")
322 }
323 }
324
325 if c.Models == nil {
326 c.Models = TypeMap{}
327 }
328
329 return nil
330}
331
332func (c *Config) InjectBuiltins(s *ast.Schema) {
333 builtins := TypeMap{
334 "__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
335 "__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
336 "__Type": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
337 "__TypeKind": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
338 "__Field": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
339 "__EnumValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
340 "__InputValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
341 "__Schema": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
342 "Float": {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
343 "String": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
344 "Boolean": {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
345 "Int": {Model: StringList{
346 "github.com/99designs/gqlgen/graphql.Int",
347 "github.com/99designs/gqlgen/graphql.Int32",
348 "github.com/99designs/gqlgen/graphql.Int64",
349 }},
350 "ID": {
351 Model: StringList{
352 "github.com/99designs/gqlgen/graphql.ID",
353 "github.com/99designs/gqlgen/graphql.IntID",
354 },
355 },
356 }
357
358 for typeName, entry := range builtins {
359 if !c.Models.Exists(typeName) {
360 c.Models[typeName] = entry
361 }
362 }
363
364 // These are additional types that are injected if defined in the schema as scalars.
365 extraBuiltins := TypeMap{
366 "Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
367 "Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
368 }
369
370 for typeName, entry := range extraBuiltins {
371 if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
372 c.Models[typeName] = entry
373 }
374 }
375}
376
377func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
378 schemaStrings := map[string]string{}
379
380 var sources []*ast.Source
381
382 for _, filename := range c.SchemaFilename {
383 filename = filepath.ToSlash(filename)
384 var err error
385 var schemaRaw []byte
386 schemaRaw, err = ioutil.ReadFile(filename)
387 if err != nil {
388 fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
389 os.Exit(1)
390 }
391 schemaStrings[filename] = string(schemaRaw)
392 sources = append(sources, &ast.Source{Name: filename, Input: schemaStrings[filename]})
393 }
394
395 schema, err := gqlparser.LoadSchema(sources...)
396 if err != nil {
397 return nil, nil, err
398 }
399 return schema, schemaStrings, nil
400}
401
// abs returns path as an absolute, cleaned, forward-slash path. filepath.Abs
// only fails when the working directory cannot be determined; that situation
// is treated as unrecoverable and panics.
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(resolved)
	}
	panic(err)
}