config.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"go/build"
  6	"io/ioutil"
  7	"os"
  8	"path/filepath"
  9	"sort"
 10	"strings"
 11
 12	"github.com/99designs/gqlgen/internal/gopath"
 13	"github.com/pkg/errors"
 14	"github.com/vektah/gqlparser/ast"
 15	"gopkg.in/yaml.v2"
 16)
 17
 18var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
 19
 20// DefaultConfig creates a copy of the default config
 21func DefaultConfig() *Config {
 22	return &Config{
 23		SchemaFilename: SchemaFilenames{"schema.graphql"},
 24		SchemaStr:      map[string]string{},
 25		Model:          PackageConfig{Filename: "models_gen.go"},
 26		Exec:           PackageConfig{Filename: "generated.go"},
 27	}
 28}
 29
 30// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
 31// walking up the tree. The closest config file will be returned.
 32func LoadConfigFromDefaultLocations() (*Config, error) {
 33	cfgFile, err := findCfg()
 34	if err != nil {
 35		return nil, err
 36	}
 37
 38	err = os.Chdir(filepath.Dir(cfgFile))
 39	if err != nil {
 40		return nil, errors.Wrap(err, "unable to enter config dir")
 41	}
 42	return LoadConfig(cfgFile)
 43}
 44
 45// LoadConfig reads the gqlgen.yml config file
 46func LoadConfig(filename string) (*Config, error) {
 47	config := DefaultConfig()
 48
 49	b, err := ioutil.ReadFile(filename)
 50	if err != nil {
 51		return nil, errors.Wrap(err, "unable to read config")
 52	}
 53
 54	if err := yaml.UnmarshalStrict(b, config); err != nil {
 55		return nil, errors.Wrap(err, "unable to parse config")
 56	}
 57
 58	preGlobbing := config.SchemaFilename
 59	config.SchemaFilename = SchemaFilenames{}
 60	for _, f := range preGlobbing {
 61		matches, err := filepath.Glob(f)
 62		if err != nil {
 63			return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
 64		}
 65
 66		for _, m := range matches {
 67			if config.SchemaFilename.Has(m) {
 68				continue
 69			}
 70			config.SchemaFilename = append(config.SchemaFilename, m)
 71		}
 72	}
 73
 74	config.FilePath = filename
 75	config.SchemaStr = map[string]string{}
 76
 77	return config, nil
 78}
 79
 80type Config struct {
 81	SchemaFilename SchemaFilenames   `yaml:"schema,omitempty"`
 82	SchemaStr      map[string]string `yaml:"-"`
 83	Exec           PackageConfig     `yaml:"exec"`
 84	Model          PackageConfig     `yaml:"model"`
 85	Resolver       PackageConfig     `yaml:"resolver,omitempty"`
 86	Models         TypeMap           `yaml:"models,omitempty"`
 87	StructTag      string            `yaml:"struct_tag,omitempty"`
 88
 89	FilePath string `yaml:"-"`
 90
 91	schema *ast.Schema `yaml:"-"`
 92}
 93
 94type PackageConfig struct {
 95	Filename string `yaml:"filename,omitempty"`
 96	Package  string `yaml:"package,omitempty"`
 97	Type     string `yaml:"type,omitempty"`
 98}
 99
100type TypeMapEntry struct {
101	Model  string                  `yaml:"model"`
102	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
103}
104
105type TypeMapField struct {
106	Resolver  bool   `yaml:"resolver"`
107	FieldName string `yaml:"fieldName"`
108}
109
// SchemaFilenames is the list of schema file entries from the config's
// "schema" key. In YAML it may be written either as a single string or as a
// list of strings.
type SchemaFilenames []string

// UnmarshalYAML decodes either a scalar string (stored as a one-element list)
// or a sequence of strings. If the value is neither, the error from the
// sequence attempt is returned and the receiver is left unchanged.
func (a *SchemaFilenames) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var one string
	if unmarshal(&one) == nil {
		*a = SchemaFilenames{one}
		return nil
	}

	var many []string
	if err := unmarshal(&many); err != nil {
		return err
	}
	*a = many
	return nil
}

// Has reports whether file is already in the list (exact string comparison).
func (a SchemaFilenames) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
138
139func (c *PackageConfig) normalize() error {
140	if c.Filename == "" {
141		return errors.New("Filename is required")
142	}
143	c.Filename = abs(c.Filename)
144	// If Package is not set, first attempt to load the package at the output dir. If that fails
145	// fallback to just the base dir name of the output filename.
146	if c.Package == "" {
147		cwd, _ := os.Getwd()
148		pkg, _ := build.Default.Import(c.ImportPath(), cwd, 0)
149		if pkg.Name != "" {
150			c.Package = pkg.Name
151		} else {
152			c.Package = filepath.Base(c.Dir())
153		}
154	}
155	c.Package = sanitizePackageName(c.Package)
156	return nil
157}
158
159func (c *PackageConfig) ImportPath() string {
160	return gopath.MustDir2Import(c.Dir())
161}
162
163func (c *PackageConfig) Dir() string {
164	return filepath.Dir(c.Filename)
165}
166
167func (c *PackageConfig) Check() error {
168	if strings.ContainsAny(c.Package, "./\\") {
169		return fmt.Errorf("package should be the output package name only, do not include the output filename")
170	}
171	if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
172		return fmt.Errorf("filename should be path to a go source file")
173	}
174	return nil
175}
176
177func (c *PackageConfig) IsDefined() bool {
178	return c.Filename != ""
179}
180
181func (cfg *Config) Check() error {
182	if err := cfg.Models.Check(); err != nil {
183		return errors.Wrap(err, "config.models")
184	}
185	if err := cfg.Exec.Check(); err != nil {
186		return errors.Wrap(err, "config.exec")
187	}
188	if err := cfg.Model.Check(); err != nil {
189		return errors.Wrap(err, "config.model")
190	}
191	if err := cfg.Resolver.Check(); err != nil {
192		return errors.Wrap(err, "config.resolver")
193	}
194	return nil
195}
196
197type TypeMap map[string]TypeMapEntry
198
199func (tm TypeMap) Exists(typeName string) bool {
200	_, ok := tm[typeName]
201	return ok
202}
203
204func (tm TypeMap) Check() error {
205	for typeName, entry := range tm {
206		if strings.LastIndex(entry.Model, ".") < strings.LastIndex(entry.Model, "/") {
207			return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
208		}
209	}
210	return nil
211}
212
213func (tm TypeMap) referencedPackages() []string {
214	var pkgs []string
215
216	for _, typ := range tm {
217		if typ.Model == "map[string]interface{}" {
218			continue
219		}
220		pkg, _ := pkgAndType(typ.Model)
221		if pkg == "" || inStrSlice(pkgs, pkg) {
222			continue
223		}
224		pkgs = append(pkgs, pkg)
225	}
226
227	sort.Slice(pkgs, func(i, j int) bool {
228		return pkgs[i] > pkgs[j]
229	})
230	return pkgs
231}
232
// inStrSlice reports whether needle occurs in haystack (exact comparison).
// A nil or empty haystack never contains anything.
func inStrSlice(haystack []string, needle string) bool {
	for i := range haystack {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
242
243// findCfg searches for the config file in this directory and all parents up the tree
244// looking for the closest match
245func findCfg() (string, error) {
246	dir, err := os.Getwd()
247	if err != nil {
248		return "", errors.Wrap(err, "unable to get working dir to findCfg")
249	}
250
251	cfg := findCfgInDir(dir)
252
253	for cfg == "" && dir != filepath.Dir(dir) {
254		dir = filepath.Dir(dir)
255		cfg = findCfgInDir(dir)
256	}
257
258	if cfg == "" {
259		return "", os.ErrNotExist
260	}
261
262	return cfg, nil
263}
264
265func findCfgInDir(dir string) string {
266	for _, cfgName := range cfgFilenames {
267		path := filepath.Join(dir, cfgName)
268		if _, err := os.Stat(path); err == nil {
269			return path
270		}
271	}
272	return ""
273}