config.go

  1package config
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"io/ioutil"
  7	"os"
  8	"path/filepath"
  9	"regexp"
 10	"sort"
 11	"strings"
 12
 13	"golang.org/x/tools/go/packages"
 14
 15	"github.com/99designs/gqlgen/internal/code"
 16	"github.com/pkg/errors"
 17	"github.com/vektah/gqlparser"
 18	"github.com/vektah/gqlparser/ast"
 19	yaml "gopkg.in/yaml.v2"
 20)
 21
// Config is the in-memory form of a gqlgen config file (see LoadConfig).
type Config struct {
	// SchemaFilename lists schema files; LoadConfig expands glob patterns
	// (including "**") into concrete paths, and LoadSchema reads each one.
	SchemaFilename StringList                 `yaml:"schema,omitempty"`
	// Exec configures the exec output file (DefaultConfig: "generated.go").
	Exec           PackageConfig              `yaml:"exec"`
	// Model configures the generated models file (DefaultConfig: "models_gen.go").
	Model          PackageConfig              `yaml:"model"`
	// Resolver is optional; it is only checked/normalized when IsDefined.
	Resolver       PackageConfig              `yaml:"resolver,omitempty"`
	// AutoBind lists package patterns searched by Autobind for objects named
	// like schema types that have no user-defined model.
	AutoBind       []string                   `yaml:"autobind"`
	// Models maps schema type names to Go bindings; normalize makes it non-nil.
	Models         TypeMap                    `yaml:"models,omitempty"`
	// StructTag is not read in this file. NOTE(review): presumably consumed by
	// the model generator — confirm.
	StructTag      string                     `yaml:"struct_tag,omitempty"`
	// Directives holds per-directive settings keyed by directive name;
	// DefaultConfig pre-populates skip, include and deprecated.
	Directives     map[string]DirectiveConfig `yaml:"directives,omitempty"`
}
 32
 33var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
 34
 35// DefaultConfig creates a copy of the default config
 36func DefaultConfig() *Config {
 37	return &Config{
 38		SchemaFilename: StringList{"schema.graphql"},
 39		Model:          PackageConfig{Filename: "models_gen.go"},
 40		Exec:           PackageConfig{Filename: "generated.go"},
 41		Directives: map[string]DirectiveConfig{
 42			"skip": {
 43				SkipRuntime: true,
 44			},
 45			"include": {
 46				SkipRuntime: true,
 47			},
 48			"deprecated": {
 49				SkipRuntime: true,
 50			},
 51		},
 52	}
 53}
 54
 55// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
 56// walking up the tree. The closest config file will be returned.
 57func LoadConfigFromDefaultLocations() (*Config, error) {
 58	cfgFile, err := findCfg()
 59	if err != nil {
 60		return nil, err
 61	}
 62
 63	err = os.Chdir(filepath.Dir(cfgFile))
 64	if err != nil {
 65		return nil, errors.Wrap(err, "unable to enter config dir")
 66	}
 67	return LoadConfig(cfgFile)
 68}
 69
 70var path2regex = strings.NewReplacer(
 71	`.`, `\.`,
 72	`*`, `.+`,
 73	`\`, `[\\/]`,
 74	`/`, `[\\/]`,
 75)
 76
 77// LoadConfig reads the gqlgen.yml config file
 78func LoadConfig(filename string) (*Config, error) {
 79	config := DefaultConfig()
 80
 81	b, err := ioutil.ReadFile(filename)
 82	if err != nil {
 83		return nil, errors.Wrap(err, "unable to read config")
 84	}
 85
 86	if err := yaml.UnmarshalStrict(b, config); err != nil {
 87		return nil, errors.Wrap(err, "unable to parse config")
 88	}
 89
 90	preGlobbing := config.SchemaFilename
 91	config.SchemaFilename = StringList{}
 92	for _, f := range preGlobbing {
 93		var matches []string
 94
 95		// for ** we want to override default globbing patterns and walk all
 96		// subdirectories to match schema files.
 97		if strings.Contains(f, "**") {
 98			pathParts := strings.SplitN(f, "**", 2)
 99			rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
100			// turn the rest of the glob into a regex, anchored only at the end because ** allows
101			// for any number of dirs in between and walk will let us match against the full path name
102			globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
103
104			if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
105				if err != nil {
106					return err
107				}
108
109				if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
110					matches = append(matches, path)
111				}
112
113				return nil
114			}); err != nil {
115				return nil, errors.Wrapf(err, "failed to walk schema at root %s", pathParts[0])
116			}
117		} else {
118			matches, err = filepath.Glob(f)
119			if err != nil {
120				return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
121			}
122		}
123
124		for _, m := range matches {
125			if config.SchemaFilename.Has(m) {
126				continue
127			}
128			config.SchemaFilename = append(config.SchemaFilename, m)
129		}
130	}
131
132	return config, nil
133}
134
// PackageConfig describes one generated Go output file and its package.
type PackageConfig struct {
	// Filename is the Go file to write; normalize makes it absolute and
	// slash-separated, and Check requires a ".go" suffix when it is set.
	Filename string `yaml:"filename,omitempty"`
	// Package is the Go package name; when empty, normalize derives it from
	// the file's directory via code.NameForDir.
	Package  string `yaml:"package,omitempty"`
	// Type is not read in this file. NOTE(review): presumably the generated
	// resolver's type name — confirm against callers.
	Type     string `yaml:"type,omitempty"`
}
140
// TypeMapEntry is the configured binding for one schema type.
type TypeMapEntry struct {
	// Model lists candidate Go types as "import/path.TypeName" strings;
	// TypeMap.Check rejects entries with no type name after the last "/".
	Model  StringList              `yaml:"model"`
	// Fields holds per-field overrides. NOTE(review): keys are presumably
	// GraphQL field names — confirm against the code that reads them.
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
145
// TypeMapField holds overrides for a single field of a mapped type.
// NOTE(review): neither field is read in this file; the meanings below are
// inferred from their names — confirm in the code generator.
type TypeMapField struct {
	// Resolver presumably forces a resolver method for the field.
	Resolver  bool   `yaml:"resolver"`
	// FieldName presumably names the Go struct field to bind to.
	FieldName string `yaml:"fieldName"`
}
150
// StringList is a list of strings that may be written in YAML either as a
// single scalar string or as a sequence of strings.
type StringList []string

// UnmarshalYAML decodes a scalar into a one-element list, or otherwise a
// sequence into the full list. If neither form decodes, the error from the
// sequence attempt is returned and the list is left unchanged.
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var single string
	if scalarErr := unmarshal(&single); scalarErr != nil {
		// Not a scalar: fall back to a sequence of strings.
		var multi []string
		if seqErr := unmarshal(&multi); seqErr != nil {
			return seqErr
		}
		*a = multi
		return nil
	}

	*a = StringList{single}
	return nil
}

// Has reports whether file is an element of the list (exact string match).
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
179
180func (c *PackageConfig) normalize() error {
181	if c.Filename == "" {
182		return errors.New("Filename is required")
183	}
184	c.Filename = abs(c.Filename)
185	// If Package is not set, first attempt to load the package at the output dir. If that fails
186	// fallback to just the base dir name of the output filename.
187	if c.Package == "" {
188		c.Package = code.NameForDir(c.Dir())
189	}
190
191	return nil
192}
193
194func (c *PackageConfig) ImportPath() string {
195	return code.ImportPathForDir(c.Dir())
196}
197
198func (c *PackageConfig) Dir() string {
199	return filepath.Dir(c.Filename)
200}
201
202func (c *PackageConfig) Check() error {
203	if strings.ContainsAny(c.Package, "./\\") {
204		return fmt.Errorf("package should be the output package name only, do not include the output filename")
205	}
206	if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
207		return fmt.Errorf("filename should be path to a go source file")
208	}
209
210	return c.normalize()
211}
212
213func (c *PackageConfig) Pkg() *types.Package {
214	return types.NewPackage(c.ImportPath(), c.Dir())
215}
216
217func (c *PackageConfig) IsDefined() bool {
218	return c.Filename != ""
219}
220
221func (c *Config) Check() error {
222	if err := c.Models.Check(); err != nil {
223		return errors.Wrap(err, "config.models")
224	}
225	if err := c.Exec.Check(); err != nil {
226		return errors.Wrap(err, "config.exec")
227	}
228	if err := c.Model.Check(); err != nil {
229		return errors.Wrap(err, "config.model")
230	}
231	if c.Resolver.IsDefined() {
232		if err := c.Resolver.Check(); err != nil {
233			return errors.Wrap(err, "config.resolver")
234		}
235	}
236
237	// check packages names against conflict, if present in the same dir
238	// and check filenames for uniqueness
239	packageConfigList := []PackageConfig{
240		c.Model,
241		c.Exec,
242		c.Resolver,
243	}
244	filesMap := make(map[string]bool)
245	pkgConfigsByDir := make(map[string]PackageConfig)
246	for _, current := range packageConfigList {
247		_, fileFound := filesMap[current.Filename]
248		if fileFound {
249			return fmt.Errorf("filename %s defined more than once", current.Filename)
250		}
251		filesMap[current.Filename] = true
252		previous, inSameDir := pkgConfigsByDir[current.Dir()]
253		if inSameDir && current.Package != previous.Package {
254			return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(current.Filename), stripPath(previous.Filename))
255		}
256		pkgConfigsByDir[current.Dir()] = current
257	}
258
259	return c.normalize()
260}
261
// stripPath returns only the last element of path (filepath.Base semantics),
// which keeps file names in error messages short.
func stripPath(path string) string {
	base := filepath.Base(path)
	return base
}
265
266type TypeMap map[string]TypeMapEntry
267
268func (tm TypeMap) Exists(typeName string) bool {
269	_, ok := tm[typeName]
270	return ok
271}
272
273func (tm TypeMap) UserDefined(typeName string) bool {
274	m, ok := tm[typeName]
275	return ok && len(m.Model) > 0
276}
277
278func (tm TypeMap) Check() error {
279	for typeName, entry := range tm {
280		for _, model := range entry.Model {
281			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
282				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
283			}
284		}
285	}
286	return nil
287}
288
289func (tm TypeMap) ReferencedPackages() []string {
290	var pkgs []string
291
292	for _, typ := range tm {
293		for _, model := range typ.Model {
294			if model == "map[string]interface{}" || model == "interface{}" {
295				continue
296			}
297			pkg, _ := code.PkgAndType(model)
298			if pkg == "" || inStrSlice(pkgs, pkg) {
299				continue
300			}
301			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
302		}
303	}
304
305	sort.Slice(pkgs, func(i, j int) bool {
306		return pkgs[i] > pkgs[j]
307	})
308	return pkgs
309}
310
311func (tm TypeMap) Add(Name string, goType string) {
312	modelCfg := tm[Name]
313	modelCfg.Model = append(modelCfg.Model, goType)
314	tm[Name] = modelCfg
315}
316
// DirectiveConfig holds the settings for one entry of the "directives" section.
type DirectiveConfig struct {
	// SkipRuntime is set by DefaultConfig for skip, include and deprecated.
	// NOTE(review): presumably means no runtime handler is generated for the
	// directive — confirm in the code generator.
	SkipRuntime bool `yaml:"skip_runtime"`
}
320
// inStrSlice reports whether needle occurs in haystack, using a linear scan.
func inStrSlice(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
330
331// findCfg searches for the config file in this directory and all parents up the tree
332// looking for the closest match
333func findCfg() (string, error) {
334	dir, err := os.Getwd()
335	if err != nil {
336		return "", errors.Wrap(err, "unable to get working dir to findCfg")
337	}
338
339	cfg := findCfgInDir(dir)
340
341	for cfg == "" && dir != filepath.Dir(dir) {
342		dir = filepath.Dir(dir)
343		cfg = findCfgInDir(dir)
344	}
345
346	if cfg == "" {
347		return "", os.ErrNotExist
348	}
349
350	return cfg, nil
351}
352
353func findCfgInDir(dir string) string {
354	for _, cfgName := range cfgFilenames {
355		path := filepath.Join(dir, cfgName)
356		if _, err := os.Stat(path); err == nil {
357			return path
358		}
359	}
360	return ""
361}
362
363func (c *Config) normalize() error {
364	if err := c.Model.normalize(); err != nil {
365		return errors.Wrap(err, "model")
366	}
367
368	if err := c.Exec.normalize(); err != nil {
369		return errors.Wrap(err, "exec")
370	}
371
372	if c.Resolver.IsDefined() {
373		if err := c.Resolver.normalize(); err != nil {
374			return errors.Wrap(err, "resolver")
375		}
376	}
377
378	if c.Models == nil {
379		c.Models = TypeMap{}
380	}
381
382	return nil
383}
384
385func (c *Config) Autobind(s *ast.Schema) error {
386	if len(c.AutoBind) == 0 {
387		return nil
388	}
389	ps, err := packages.Load(&packages.Config{Mode: packages.LoadTypes}, c.AutoBind...)
390	if err != nil {
391		return err
392	}
393
394	for _, t := range s.Types {
395		if c.Models.UserDefined(t.Name) {
396			continue
397		}
398
399		for _, p := range ps {
400			if t := p.Types.Scope().Lookup(t.Name); t != nil {
401				c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
402				break
403			}
404		}
405	}
406
407	return nil
408}
409
410func (c *Config) InjectBuiltins(s *ast.Schema) {
411	builtins := TypeMap{
412		"__Directive":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
413		"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
414		"__Type":              {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
415		"__TypeKind":          {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
416		"__Field":             {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
417		"__EnumValue":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
418		"__InputValue":        {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
419		"__Schema":            {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
420		"Float":               {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
421		"String":              {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
422		"Boolean":             {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
423		"Int": {Model: StringList{
424			"github.com/99designs/gqlgen/graphql.Int",
425			"github.com/99designs/gqlgen/graphql.Int32",
426			"github.com/99designs/gqlgen/graphql.Int64",
427		}},
428		"ID": {
429			Model: StringList{
430				"github.com/99designs/gqlgen/graphql.ID",
431				"github.com/99designs/gqlgen/graphql.IntID",
432			},
433		},
434	}
435
436	for typeName, entry := range builtins {
437		if !c.Models.Exists(typeName) {
438			c.Models[typeName] = entry
439		}
440	}
441
442	// These are additional types that are injected if defined in the schema as scalars.
443	extraBuiltins := TypeMap{
444		"Time":   {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
445		"Map":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
446		"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
447		"Any":    {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
448	}
449
450	for typeName, entry := range extraBuiltins {
451		if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
452			c.Models[typeName] = entry
453		}
454	}
455}
456
457func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
458	schemaStrings := map[string]string{}
459
460	var sources []*ast.Source
461
462	for _, filename := range c.SchemaFilename {
463		filename = filepath.ToSlash(filename)
464		var err error
465		var schemaRaw []byte
466		schemaRaw, err = ioutil.ReadFile(filename)
467		if err != nil {
468			fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
469			os.Exit(1)
470		}
471		schemaStrings[filename] = string(schemaRaw)
472		sources = append(sources, &ast.Source{Name: filename, Input: schemaStrings[filename]})
473	}
474
475	schema, err := gqlparser.LoadSchema(sources...)
476	if err != nil {
477		return nil, nil, err
478	}
479	return schema, schemaStrings, nil
480}
481
// abs returns path as a cleaned, absolute, forward-slash-separated path.
// It panics if filepath.Abs fails, which only happens when the working
// directory cannot be determined for a relative path.
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(resolved)
	}
	panic(err)
}