config.go

  1package config
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"io/ioutil"
  7	"os"
  8	"path/filepath"
  9	"sort"
 10	"strings"
 11
 12	"github.com/99designs/gqlgen/internal/code"
 13	"github.com/pkg/errors"
 14	"github.com/vektah/gqlparser"
 15	"github.com/vektah/gqlparser/ast"
 16	yaml "gopkg.in/yaml.v2"
 17)
 18
 19type Config struct {
 20	SchemaFilename StringList    `yaml:"schema,omitempty"`
 21	Exec           PackageConfig `yaml:"exec"`
 22	Model          PackageConfig `yaml:"model"`
 23	Resolver       PackageConfig `yaml:"resolver,omitempty"`
 24	Models         TypeMap       `yaml:"models,omitempty"`
 25	StructTag      string        `yaml:"struct_tag,omitempty"`
 26}
 27
 28var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
 29
 30// DefaultConfig creates a copy of the default config
 31func DefaultConfig() *Config {
 32	return &Config{
 33		SchemaFilename: StringList{"schema.graphql"},
 34		Model:          PackageConfig{Filename: "models_gen.go"},
 35		Exec:           PackageConfig{Filename: "generated.go"},
 36	}
 37}
 38
 39// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
 40// walking up the tree. The closest config file will be returned.
 41func LoadConfigFromDefaultLocations() (*Config, error) {
 42	cfgFile, err := findCfg()
 43	if err != nil {
 44		return nil, err
 45	}
 46
 47	err = os.Chdir(filepath.Dir(cfgFile))
 48	if err != nil {
 49		return nil, errors.Wrap(err, "unable to enter config dir")
 50	}
 51	return LoadConfig(cfgFile)
 52}
 53
 54// LoadConfig reads the gqlgen.yml config file
 55func LoadConfig(filename string) (*Config, error) {
 56	config := DefaultConfig()
 57
 58	b, err := ioutil.ReadFile(filename)
 59	if err != nil {
 60		return nil, errors.Wrap(err, "unable to read config")
 61	}
 62
 63	if err := yaml.UnmarshalStrict(b, config); err != nil {
 64		return nil, errors.Wrap(err, "unable to parse config")
 65	}
 66
 67	preGlobbing := config.SchemaFilename
 68	config.SchemaFilename = StringList{}
 69	for _, f := range preGlobbing {
 70		matches, err := filepath.Glob(f)
 71		if err != nil {
 72			return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
 73		}
 74
 75		for _, m := range matches {
 76			if config.SchemaFilename.Has(m) {
 77				continue
 78			}
 79			config.SchemaFilename = append(config.SchemaFilename, m)
 80		}
 81	}
 82
 83	return config, nil
 84}
 85
 86type PackageConfig struct {
 87	Filename string `yaml:"filename,omitempty"`
 88	Package  string `yaml:"package,omitempty"`
 89	Type     string `yaml:"type,omitempty"`
 90}
 91
 92type TypeMapEntry struct {
 93	Model  StringList              `yaml:"model"`
 94	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
 95}
 96
 97type TypeMapField struct {
 98	Resolver  bool   `yaml:"resolver"`
 99	FieldName string `yaml:"fieldName"`
100}
101
// StringList is a list of strings that may be written in YAML either as a
// single scalar or as a sequence of scalars.
type StringList []string

// UnmarshalYAML implements the yaml.v2 Unmarshaler interface. It first tries
// to decode a single string, producing a one-element list. If that fails, it
// decodes a sequence of strings. If both fail, the sequence decoding error is
// returned and the receiver is left unchanged.
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var one string
	if unmarshal(&one) == nil {
		*a = []string{one}
		return nil
	}

	var many []string
	if err := unmarshal(&many); err != nil {
		return err
	}
	*a = many
	return nil
}

// Has reports whether file is one of the list's elements (exact match).
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
130
131func (c *PackageConfig) normalize() error {
132	if c.Filename == "" {
133		return errors.New("Filename is required")
134	}
135	c.Filename = abs(c.Filename)
136	// If Package is not set, first attempt to load the package at the output dir. If that fails
137	// fallback to just the base dir name of the output filename.
138	if c.Package == "" {
139		c.Package = code.NameForPackage(c.ImportPath())
140	}
141
142	return nil
143}
144
145func (c *PackageConfig) ImportPath() string {
146	return code.ImportPathForDir(c.Dir())
147}
148
// Dir returns the directory part of Filename, with filepath.Dir semantics
// (so an empty Filename yields ".").
func (c *PackageConfig) Dir() string {
	dir := filepath.Dir(c.Filename)
	return dir
}
152
// Check validates the user-supplied settings and then normalizes c.
//
// Package must be a bare package name (no '.', '/' or '\'). A non-empty
// Filename must end in ".go". If both hold, the result of normalize is
// returned; normalize also fails when Filename is empty.
func (c *PackageConfig) Check() error {
	switch {
	case strings.ContainsAny(c.Package, "./\\"):
		return fmt.Errorf("package should be the output package name only, do not include the output filename")
	case c.Filename != "" && !strings.HasSuffix(c.Filename, ".go"):
		return fmt.Errorf("filename should be path to a go source file")
	default:
		return c.normalize()
	}
}
163
// Pkg returns a go/types package descriptor for this output. Its path is
// ImportPath and its name is the configured Go package name.
//
// types.NewPackage takes the package *name* (not a directory) as its second
// argument. If Package has not been filled in yet (normalize not run), the
// name is derived the same way normalize derives it.
func (c *PackageConfig) Pkg() *types.Package {
	importPath := c.ImportPath()
	name := c.Package
	if name == "" {
		name = code.NameForPackage(importPath)
	}
	return types.NewPackage(importPath, name)
}
167
168func (c *PackageConfig) IsDefined() bool {
169	return c.Filename != ""
170}
171
// Check validates the whole configuration and then normalizes it in place.
//
// It checks, in order: the models type map, the exec output, the model
// output, and the resolver output if one is configured. Each
// PackageConfig.Check also normalizes that output. It then requires the
// configured outputs to use distinct filenames, and outputs sharing a
// directory to declare the same package. Sub-check errors are wrapped with
// the config key they belong to.
func (c *Config) Check() error {
	if err := c.Models.Check(); err != nil {
		return errors.Wrap(err, "config.models")
	}
	if err := c.Exec.Check(); err != nil {
		return errors.Wrap(err, "config.exec")
	}
	if err := c.Model.Check(); err != nil {
		return errors.Wrap(err, "config.model")
	}

	outputs := []PackageConfig{c.Model, c.Exec}
	if c.Resolver.IsDefined() {
		if err := c.Resolver.Check(); err != nil {
			return errors.Wrap(err, "config.resolver")
		}
		outputs = append(outputs, c.Resolver)
	}
	// An undefined resolver has no filename and therefore nothing to collide
	// with, so it is left out of the uniqueness checks below.

	seenFiles := make(map[string]bool, len(outputs))
	byDir := make(map[string]PackageConfig, len(outputs))
	for _, out := range outputs {
		if seenFiles[out.Filename] {
			return fmt.Errorf("filename %s defined more than once", out.Filename)
		}
		seenFiles[out.Filename] = true

		dir := out.Dir()
		if prev, ok := byDir[dir]; ok && out.Package != prev.Package {
			return fmt.Errorf("filenames %s and %s are in the same directory but have different package definitions", stripPath(out.Filename), stripPath(prev.Filename))
		}
		byDir[dir] = out
	}

	return c.normalize()
}
212
// stripPath returns the last element of path (filepath.Base semantics). It is
// used to keep error messages short.
func stripPath(path string) string {
	base := filepath.Base(path)
	return base
}
216
217type TypeMap map[string]TypeMapEntry
218
219func (tm TypeMap) Exists(typeName string) bool {
220	_, ok := tm[typeName]
221	return ok
222}
223
224func (tm TypeMap) UserDefined(typeName string) bool {
225	m, ok := tm[typeName]
226	return ok && len(m.Model) > 0
227}
228
// Check verifies that every model reference names a type, not just a package.
// A reference containing '/' must have a '.' after its last '/', as in
// "github.com/foo/bar.Type"; references without any '/' are not rejected.
//
// The returned error names the GraphQL type and the single offending
// reference. Types are visited in sorted order, so the reported error is
// deterministic when several entries are invalid.
func (tm TypeMap) Check() error {
	typeNames := make([]string, 0, len(tm))
	for typeName := range tm {
		typeNames = append(typeNames, typeName)
	}
	sort.Strings(typeNames)

	for _, typeName := range typeNames {
		for _, model := range tm[typeName].Model {
			if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
				// Report the offending reference, not the whole model list.
				return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, model)
			}
		}
	}
	return nil
}
239
// ReferencedPackages returns the import paths of every package referenced by a
// model in the map. Each path is passed through code.QualifyPackagePath and
// listed once, and the result is sorted in descending lexical order. The
// models "map[string]interface{}" and "interface{}", and models without a
// package part, are skipped.
func (tm TypeMap) ReferencedPackages() []string {
	var pkgs []string
	// Deduplicate on the raw package path. pkgs holds *qualified* paths, which
	// can differ from the raw path (see code.QualifyPackagePath), so searching
	// pkgs for the raw path would miss and emit duplicates.
	seen := make(map[string]bool)

	for _, typ := range tm {
		for _, model := range typ.Model {
			if model == "map[string]interface{}" || model == "interface{}" {
				continue
			}
			pkg, _ := code.PkgAndType(model)
			if pkg == "" || seen[pkg] {
				continue
			}
			seen[pkg] = true
			pkgs = append(pkgs, code.QualifyPackagePath(pkg))
		}
	}

	sort.Slice(pkgs, func(i, j int) bool {
		return pkgs[i] > pkgs[j]
	})
	return pkgs
}
261
262func (tm TypeMap) Add(Name string, goType string) {
263	modelCfg := tm[Name]
264	modelCfg.Model = append(modelCfg.Model, goType)
265	tm[Name] = modelCfg
266}
267
// inStrSlice reports whether needle is an element of haystack (exact match).
func inStrSlice(haystack []string, needle string) bool {
	for i := range haystack {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
277
// findCfg returns the path of the closest config file. It looks in the current
// working directory first, then in each parent directory up to the filesystem
// root. It returns os.ErrNotExist when no directory on that path has one.
func findCfg() (string, error) {
	dir, err := os.Getwd()
	if err != nil {
		return "", errors.Wrap(err, "unable to get working dir to findCfg")
	}

	for {
		if cfg := findCfgInDir(dir); cfg != "" {
			return cfg, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir is a fixed point only at the root: nothing left to search.
			return "", os.ErrNotExist
		}
		dir = parent
	}
}
299
// findCfgInDir returns dir joined with the first name in cfgFilenames that
// os.Stat succeeds on, or "" if none does.
// NOTE(review): os.Stat also succeeds for a directory with one of these names.
func findCfgInDir(dir string) string {
	for i := range cfgFilenames {
		candidate := filepath.Join(dir, cfgFilenames[i])
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		return candidate
	}
	return ""
}
309
310func (c *Config) normalize() error {
311	if err := c.Model.normalize(); err != nil {
312		return errors.Wrap(err, "model")
313	}
314
315	if err := c.Exec.normalize(); err != nil {
316		return errors.Wrap(err, "exec")
317	}
318
319	if c.Resolver.IsDefined() {
320		if err := c.Resolver.normalize(); err != nil {
321			return errors.Wrap(err, "resolver")
322		}
323	}
324
325	if c.Models == nil {
326		c.Models = TypeMap{}
327	}
328
329	return nil
330}
331
332func (c *Config) InjectBuiltins(s *ast.Schema) {
333	builtins := TypeMap{
334		"__Directive":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
335		"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
336		"__Type":              {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
337		"__TypeKind":          {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
338		"__Field":             {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
339		"__EnumValue":         {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
340		"__InputValue":        {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
341		"__Schema":            {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
342		"Float":               {Model: StringList{"github.com/99designs/gqlgen/graphql.Float"}},
343		"String":              {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
344		"Boolean":             {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
345		"Int": {Model: StringList{
346			"github.com/99designs/gqlgen/graphql.Int",
347			"github.com/99designs/gqlgen/graphql.Int32",
348			"github.com/99designs/gqlgen/graphql.Int64",
349		}},
350		"ID": {
351			Model: StringList{
352				"github.com/99designs/gqlgen/graphql.ID",
353				"github.com/99designs/gqlgen/graphql.IntID",
354			},
355		},
356	}
357
358	for typeName, entry := range builtins {
359		if !c.Models.Exists(typeName) {
360			c.Models[typeName] = entry
361		}
362	}
363
364	// These are additional types that are injected if defined in the schema as scalars.
365	extraBuiltins := TypeMap{
366		"Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
367		"Map":  {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
368	}
369
370	for typeName, entry := range extraBuiltins {
371		if t, ok := s.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
372			c.Models[typeName] = entry
373		}
374	}
375}
376
377func (c *Config) LoadSchema() (*ast.Schema, map[string]string, error) {
378	schemaStrings := map[string]string{}
379
380	var sources []*ast.Source
381
382	for _, filename := range c.SchemaFilename {
383		filename = filepath.ToSlash(filename)
384		var err error
385		var schemaRaw []byte
386		schemaRaw, err = ioutil.ReadFile(filename)
387		if err != nil {
388			fmt.Fprintln(os.Stderr, "unable to open schema: "+err.Error())
389			os.Exit(1)
390		}
391		schemaStrings[filename] = string(schemaRaw)
392		sources = append(sources, &ast.Source{Name: filename, Input: schemaStrings[filename]})
393	}
394
395	schema, err := gqlparser.LoadSchema(sources...)
396	if err != nil {
397		return nil, nil, err
398	}
399	return schema, schemaStrings, nil
400}
401
// abs returns the absolute form of path with forward slashes as separators.
// It panics if filepath.Abs fails.
func abs(path string) string {
	full, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(full)
	}
	panic(err)
}