config.go

  1package config
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"io/ioutil"
  7	"os"
  8	"path/filepath"
  9	"regexp"
 10	"sort"
 11	"strings"
 12
 13	"golang.org/x/tools/go/packages"
 14
 15	"github.com/99designs/gqlgen/internal/code"
 16	"github.com/pkg/errors"
 17	"github.com/vektah/gqlparser"
 18	"github.com/vektah/gqlparser/ast"
 19	"gopkg.in/yaml.v2"
 20)
 21
// Config is the in-memory form of a gqlgen config file (see cfgFilenames).
// LoadConfig starts from DefaultConfig and overlays the YAML file on top of it.
type Config struct {
	// SchemaFilename holds the schema entries from the file; after LoadConfig it
	// contains the glob-expanded, de-duplicated list of schema file paths.
	SchemaFilename StringList                 `yaml:"schema,omitempty"`
	// Exec is the output file for generated execution code (default "generated.go").
	Exec           PackageConfig              `yaml:"exec"`
	// Model is the output file for generated models (default "models_gen.go").
	Model          PackageConfig              `yaml:"model"`
	// Resolver is optional; it is only checked/normalized when its Filename is set.
	Resolver       PackageConfig              `yaml:"resolver,omitempty"`
	// AutoBind lists package patterns that Autobind loads and searches for Go
	// types named like schema types.
	AutoBind       []string                   `yaml:"autobind"`
	// Models maps GraphQL type names to Go types; InjectBuiltins and Autobind add
	// entries that the user did not configure.
	Models         TypeMap                    `yaml:"models,omitempty"`
	// StructTag is not read in this file. NOTE(review): presumably the struct tag
	// used on generated model fields — confirm.
	StructTag      string                     `yaml:"struct_tag,omitempty"`
	// Directives holds per-directive options keyed by directive name; LoadConfig
	// fills in skip, include and deprecated when the file does not set them.
	Directives     map[string]DirectiveConfig `yaml:"directives,omitempty"`
}
 32
 33var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
 34
 35// DefaultConfig creates a copy of the default config
 36func DefaultConfig() *Config {
 37	return &Config{
 38		SchemaFilename: StringList{"schema.graphql"},
 39		Model:          PackageConfig{Filename: "models_gen.go"},
 40		Exec:           PackageConfig{Filename: "generated.go"},
 41		Directives:     map[string]DirectiveConfig{},
 42	}
 43}
 44
 45// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
 46// walking up the tree. The closest config file will be returned.
 47func LoadConfigFromDefaultLocations() (*Config, error) {
 48	cfgFile, err := findCfg()
 49	if err != nil {
 50		return nil, err
 51	}
 52
 53	err = os.Chdir(filepath.Dir(cfgFile))
 54	if err != nil {
 55		return nil, errors.Wrap(err, "unable to enter config dir")
 56	}
 57	return LoadConfig(cfgFile)
 58}
 59
 60var path2regex = strings.NewReplacer(
 61	`.`, `\.`,
 62	`*`, `.+`,
 63	`\`, `[\\/]`,
 64	`/`, `[\\/]`,
 65)
 66
 67// LoadConfig reads the gqlgen.yml config file
 68func LoadConfig(filename string) (*Config, error) {
 69	config := DefaultConfig()
 70
 71	b, err := ioutil.ReadFile(filename)
 72	if err != nil {
 73		return nil, errors.Wrap(err, "unable to read config")
 74	}
 75
 76	if err := yaml.UnmarshalStrict(b, config); err != nil {
 77		return nil, errors.Wrap(err, "unable to parse config")
 78	}
 79
 80	defaultDirectives := map[string]DirectiveConfig{
 81		"skip":       {SkipRuntime: true},
 82		"include":    {SkipRuntime: true},
 83		"deprecated": {SkipRuntime: true},
 84	}
 85
 86	for key, value := range defaultDirectives {
 87		if _, defined := config.Directives[key]; !defined {
 88			config.Directives[key] = value
 89		}
 90	}
 91
 92	preGlobbing := config.SchemaFilename
 93	config.SchemaFilename = StringList{}
 94	for _, f := range preGlobbing {
 95		var matches []string
 96
 97		// for ** we want to override default globbing patterns and walk all
 98		// subdirectories to match schema files.
 99		if strings.Contains(f, "**") {
100			pathParts := strings.SplitN(f, "**", 2)
101			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
102			// turn the rest of the glob into a regex, anchored only at the end because ** allows
103			// for any number of dirs in between and walk will let us match against the full path name
104			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
105
106			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
107				if err != nil {
108					return err
109				}
110
111				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
112					matches = append(matches, path)
113				}
114
115				return nil
116			}); err != nil {
117				return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0])
118			}
119		} else {
120			matches, err = filepath.Glob(f)
121			if err != nil {
122				return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
123			}
124		}
125
126		for _, m := range matches {
127			if config.SchemaFilename.Has(m) {
128				continue
129			}
130			config.SchemaFilename = append(config.SchemaFilename, m)
131		}
132	}
133
134	return config, nil
135}
136
// PackageConfig describes one generated Go output file and its package.
type PackageConfig struct {
	// Filename is the path of the generated .go file; normalize rewrites it as
	// an absolute, forward-slash path.
	Filename string `yaml:"filename,omitempty"`
	// Package is the Go package name; when empty, normalize derives it from the
	// output directory via code.NameForDir.
	Package  string `yaml:"package,omitempty"`
	// Type is not read in this file. NOTE(review): presumably the name of the
	// generated resolver type — confirm.
	Type     string `yaml:"type,omitempty"`
}
142
// TypeMapEntry is the Go binding configured for one GraphQL type.
type TypeMapEntry struct {
	// Model lists Go types written as "import/path.TypeName"; TypeMap.Check
	// rejects entries that name only a package.
	Model  StringList              `yaml:"model"`
	// Fields holds per-field overrides. NOTE(review): presumably keyed by
	// GraphQL field name — confirm.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
147
// TypeMapField holds overrides for a single field of a mapped type.
// Neither field is read in this file.
type TypeMapField struct {
	// Resolver — NOTE(review): presumably forces a resolver method for the
	// field — confirm.
	Resolver  bool   `yaml:"resolver"`
	// FieldName — NOTE(review): presumably the Go struct field to bind the
	// GraphQL field to — confirm.
	FieldName string `yaml:"fieldName"`
}
152
// StringList is a list of strings that may be written in YAML either as a
// single scalar or as a sequence of scalars.
type StringList []string

// UnmarshalYAML implements the yaml.v2 Unmarshaler hook. A scalar becomes a
// one-element list. Otherwise the value is decoded as a sequence, and the
// sequence decode error is returned if that fails as well.
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var one string
	if unmarshal(&one) == nil {
		*a = StringList{one}
		return nil
	}

	var many []string
	if err := unmarshal(&many); err != nil {
		return err
	}
	*a = many
	return nil
}

// Has reports whether file is present in the list (exact string match).
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
181
// normalize validates and canonicalizes c in place.
// Filename is required and is rewritten via abs as an absolute, forward-slash
// path. An empty Package is filled in from the output directory using
// code.NameForDir.
func (c *PackageConfig) normalize() error {
	if c.Filename == "" {
		return errors.New("Filename is required")
	}
	c.Filename = abs(c.Filename)

	if c.Package != "" {
		return nil
	}
	// NOTE(review): code.NameForDir presumably tries the package already in
	// that directory and falls back to the directory's base name — confirm.
	c.Package = code.NameForDir(c.Dir())
	return nil
}
195
196func (c *PackageConfig) ImportPath() string {
197	return code.ImportPathForDir(c.Dir())
198}
199
// Dir returns the directory part of Filename with filepath.Dir semantics
// (so "." when Filename has no directory component or is empty).
func (c *PackageConfig) Dir() string {
	dir := filepath.Dir(c.Filename)
	return dir
}
203
// Check validates the user-supplied values and then normalizes c.
// Package must be a bare package name (no '.', '/' or '\'), and a non-empty
// Filename must end in ".go". An empty Filename is rejected by normalize.
func (c *PackageConfig) Check() error {
	switch {
	case strings.ContainsAny(c.Package, "./\\"):
		return fmt.Errorf("package should be the output package name only, do not include the output filename")
	case c.Filename != "" && !strings.HasSuffix(c.Filename, ".go"):
		return fmt.Errorf("filename should be path to a go source file")
	}
	return c.normalize()
}
214
// Pkg returns a go/types package describing this output: its path is
// ImportPath() and its name is the configured Go package name.
//
// types.NewPackage takes (path, name). Previously the output directory was
// passed as the name, which gave the package a filesystem path as its name.
// Package is filled in by normalize, so it is set once Check has run.
func (c *PackageConfig) Pkg() *types.Package {
	return types.NewPackage(c.ImportPath(), c.Package)
}
218
219func (c *PackageConfig) IsDefined() bool {
220	return c.Filename != ""
221}
222
// Check validates and normalizes the whole configuration.
//
// It checks the type map, then the exec and model outputs, and the resolver
// output when one is configured. Each of these package checks normalizes its
// own PackageConfig in place. It then requires every output file to be
// distinct and every pair of outputs sharing a directory to agree on the
// package name, before finally normalizing the Config itself.
func (c *Config) Check() error {
	if err := c.Models.Check(); err != nil {
		return errors.Wrap(err, "config.models")
	}
	if err := c.Exec.Check(); err != nil {
		return errors.Wrap(err, "config.exec")
	}
	if err := c.Model.Check(); err != nil {
		return errors.Wrap(err, "config.model")
	}
	if c.Resolver.IsDefined() {
		if err := c.Resolver.Check(); err != nil {
			return errors.Wrap(err, "config.resolver")
		}
	}

	// An undefined Resolver takes part with an empty Filename. Model and Exec
	// were normalized above (absolute, non-empty paths), so it cannot clash.
	seenFiles := map[string]bool{}
	byDir := map[string]PackageConfig{}
	for _, pc := range []PackageConfig{c.Model, c.Exec, c.Resolver} {
		if seenFiles[pc.Filename] {
			return fmt.Errorf("filename %s defined more than once", pc.Filename)
		}
		seenFiles[pc.Filename] = true

		dir := pc.Dir()
		if prev, ok := byDir[dir]; ok && prev.Package != pc.Package {
			return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(pc.Filename), stripPath(prev.Filename))
		}
		byDir[dir] = pc
	}

	return c.normalize()
}
263
// stripPath returns only the last element of path (filepath.Base), which keeps
// error messages short.
func stripPath(path string) string {
	base := filepath.Base(path)
	return base
}
267
// TypeMap maps GraphQL type names (e.g. "Float", "__Schema") to their Go bindings.
type TypeMap map[string]TypeMapEntry
269
270func (tm TypeMap) Exists(typeName string) bool {
271	_, ok := tm[typeName]
272	return ok
273}
274
275func (tm TypeMap) UserDefined(typeName string) bool {
276	m, ok := tm[typeName]
277	return ok && len(m.Model) > 0
278}
279
// Check verifies that every model names a type and not only a package.
// The last '.' (which separates the package path from the type name) must
// come after the last '/'. The error names the GraphQL type and the single
// offending model; when several are invalid, which one is reported depends
// on map iteration order.
func (tm TypeMap) Check() error {
	for typeName, entry := range tm {
		for _, model := range entry.Model {
			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
				// Report the offending model, not the whole model list.
				return fmt.Errorf("model %s: invalid type specifier %q - you need to specify a struct to map to", typeName, model)
			}
		}
	}
	return nil
}
290
// ReferencedPackages returns the distinct qualified import paths (via
// code.QualifyPackagePath) of the packages referenced by models in tm, sorted
// in descending order. The "map[string]interface{}" and "interface{}" models,
// and models without a package part, are skipped.
func (tm TypeMap) ReferencedPackages() []string {
	var pkgs []string
	seen := map[string]bool{}

	for _, typ := range tm {
		for _, model := range typ.Model {
			if model == "map[string]interface{}" || model == "interface{}" {
				continue
			}
			pkg, _ := code.PkgAndType(model)
			if pkg == "" {
				continue
			}
			// Deduplicate on the value actually returned. Comparing the raw path
			// against already-qualified entries let duplicates through whenever
			// QualifyPackagePath rewrites the path.
			qualified := code.QualifyPackagePath(pkg)
			if seen[qualified] {
				continue
			}
			seen[qualified] = true
			pkgs = append(pkgs, qualified)
		}
	}

	sort.Slice(pkgs, func(i, j int) bool {
		return pkgs[i] > pkgs[j]
	})
	return pkgs
}
312
313func (tm TypeMap) Add(name string, goType string) {
314	modelCfg := tm[name]
315	modelCfg.Model = append(modelCfg.Model, goType)
316	tm[name] = modelCfg
317}
318
// DirectiveConfig holds per-directive generation options.
type DirectiveConfig struct {
	// SkipRuntime is set by LoadConfig for the skip, include and deprecated
	// directives. NOTE(review): presumably it means no runtime directive handler
	// is generated — confirm.
	SkipRuntime bool `yaml:"skip_runtime"`
}
322
// inStrSlice reports whether needle is an element of haystack.
func inStrSlice(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
332
// findCfg looks for a config file (any of cfgFilenames) starting in the
// current working directory and moving up one parent at a time until the
// filesystem root. It returns the first path found, or os.ErrNotExist.
func findCfg() (string, error) {
	dir, err := os.Getwd()
	if err != nil {
		return "", errors.Wrap(err, "unable to get working dir to findCfg")
	}

	for {
		if cfg := findCfgInDir(dir); cfg != "" {
			return cfg, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// Reached the root without finding a config file.
			return "", os.ErrNotExist
		}
		dir = parent
	}
}
354
// findCfgInDir returns the path of the first name in cfgFilenames that
// os.Stat succeeds on inside dir, or "" when none does.
func findCfgInDir(dir string) string {
	for _, name := range cfgFilenames {
		candidate := filepath.Join(dir, name)
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		return candidate
	}
	return ""
}
364
365func (c *Config) normalize() error {
366	if err := c.Model.normalize(); err != nil {
367		return errors.Wrap(err, "model")
368	}
369
370	if err := c.Exec.normalize(); err != nil {
371		return errors.Wrap(err, "exec")
372	}
373
374	if c.Resolver.IsDefined() {
375		if err := c.Resolver.normalize(); err != nil {
376			return errors.Wrap(err, "resolver")
377		}
378	}
379
380	if c.Models == nil {
381		c.Models = TypeMap{}
382	}
383
384	return nil
385}
386
// Autobind binds schema types to Go types found in the AutoBind packages.
//
// For each schema type that has no user-configured model, the loaded packages
// are searched (in the order packages.Load returns them). The first
// package-level type declaration with exactly the same name is added to
// Models. Errors from packages.Load are returned unchanged.
func (c *Config) Autobind(s *ast.Schema) error {
	if len(c.AutoBind) == 0 {
		return nil
	}
	ps, err := packages.Load(&packages.Config{Mode: packages.LoadTypes}, c.AutoBind...)
	if err != nil {
		return err
	}

	for _, schemaType := range s.Types {
		if c.Models.UserDefined(schemaType.Name) {
			continue
		}

		for _, p := range ps {
			// A package that failed to load may carry no type information.
			if p.Types == nil {
				continue
			}
			// Only a type declaration can back a GraphQL type; a func, var or
			// const that happens to share the name must not be bound.
			obj, ok := p.Types.Scope().Lookup(schemaType.Name).(*types.TypeName)
			if !ok {
				continue
			}
			c.Models.Add(obj.Name(), obj.Pkg().Path()+"."+obj.Name())
			break
		}
	}

	return nil
}
411
412func (c *Config) InjectBuiltins(s *ast.Schema) {
413	builtins := TypeMap{
414		"__Directive":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
415		"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
416		"__Type":              {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
417		"__TypeKind":          {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
418		"__Field":             {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
419		"__EnumValue":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
420		"__InputValue":        {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
421		"__Schema":            {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
422		"Float":               {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
423		"String":              {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
424		"Boolean":             {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
425		"Int": {Model: StringList{
426			"github.com/99designs/gqlgen/graphql.Int",
427			"github.com/99designs/gqlgen/graphql.Int32",
428			"github.com/99designs/gqlgen/graphql.Int64",
429		}},
430		"ID": {
431			Model: StringList{
432				"github.com/99designs/gqlgen/graphql.ID",
433				"github.com/99designs/gqlgen/graphql.IntID",
434			},
435		},
436	}
437
438	for typeName, entry := range builtins {
439		if !c.Models.Exists(typeName) {
440			c.Models[typeName] = entry
441		}
442	}
443
444	// These are additional types that are injected if defined in the schema as scalars.
445	extraBuiltins := TypeMap{
446		"Time":   {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
447		"Map":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
448		"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
449		"Any":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
450	}
451
452	for typeName, entry := range extraBuiltins {
453		if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
454			c.Models[typeName] = entry
455		}
456	}
457}
458
459func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
460	schemaStrings := map[string]string{}
461
462	sources := make([]*ast.Source, len(c.SchemaFilename))
463
464	for i, filename := range c.SchemaFilename {
465		filename = filepath.ToSlash(filename)
466		var err error
467		var schemaRaw []byte
468		schemaRaw, err = ioutil.ReadFile(filename)
469		if err != nil {
470			fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
471			os.Exit(1)
472		}
473		schemaStrings[filename] = string(schemaRaw)
474		sources[i] = &ast.Source{Name: filename, Input: schemaStrings[filename]}
475	}
476
477	schema, err := gqlparser.LoadSchema(sources...)
478	if err != nil {
479		return nil, nil, err
480	}
481	return schema, schemaStrings, nil
482}
483
// abs returns path as an absolute path with forward slashes (filepath.Abs
// followed by filepath.ToSlash). It panics if filepath.Abs fails.
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(resolved)
	}
	panic(err)
}