1package codegen
2
3import (
4 "fmt"
5 "go/build"
6 "io/ioutil"
7 "os"
8 "path/filepath"
9 "sort"
10 "strings"
11
12 "github.com/99designs/gqlgen/internal/gopath"
13 "github.com/pkg/errors"
14 "github.com/vektah/gqlparser/ast"
15 "gopkg.in/yaml.v2"
16)
17
// cfgFilenames lists the config file names looked for in each directory, in priority order:
// findCfgInDir returns the first of these that exists.
var cfgFilenames = []string{
	".gqlgen.yml",
	"gqlgen.yml",
	"gqlgen.yaml",
}
19
20// DefaultConfig creates a copy of the default config
21func DefaultConfig() *Config {
22 return &Config{
23 SchemaFilename: SchemaFilenames{"schema.graphql"},
24 SchemaStr: map[string]string{},
25 Model: PackageConfig{Filename: "models_gen.go"},
26 Exec: PackageConfig{Filename: "generated.go"},
27 }
28}
29
30// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
31// walking up the tree. The closest config file will be returned.
32func LoadConfigFromDefaultLocations() (*Config, error) {
33 cfgFile, err := findCfg()
34 if err != nil {
35 return nil, err
36 }
37
38 err = os.Chdir(filepath.Dir(cfgFile))
39 if err != nil {
40 return nil, errors.Wrap(err, "unable to enter config dir")
41 }
42 return LoadConfig(cfgFile)
43}
44
45// LoadConfig reads the gqlgen.yml config file
46func LoadConfig(filename string) (*Config, error) {
47 config := DefaultConfig()
48
49 b, err := ioutil.ReadFile(filename)
50 if err != nil {
51 return nil, errors.Wrap(err, "unable to read config")
52 }
53
54 if err := yaml.UnmarshalStrict(b, config); err != nil {
55 return nil, errors.Wrap(err, "unable to parse config")
56 }
57
58 preGlobbing := config.SchemaFilename
59 config.SchemaFilename = SchemaFilenames{}
60 for _, f := range preGlobbing {
61 matches, err := filepath.Glob(f)
62 if err != nil {
63 return nil, errors.Wrapf(err, "failed to glob schema filename %s", f)
64 }
65
66 for _, m := range matches {
67 if config.SchemaFilename.Has(m) {
68 continue
69 }
70 config.SchemaFilename = append(config.SchemaFilename, m)
71 }
72 }
73
74 config.FilePath = filename
75 config.SchemaStr = map[string]string{}
76
77 return config, nil
78}
79
// Config is the in-memory form of a gqlgen config file, as built by DefaultConfig and LoadConfig.
type Config struct {
	// SchemaFilename lists schema files under the YAML key "schema", given as one string or a
	// list (see SchemaFilenames.UnmarshalYAML). LoadConfig replaces each entry with its
	// filepath.Glob matches and drops duplicates. DefaultConfig sets it to "schema.graphql".
	SchemaFilename SchemaFilenames `yaml:"schema,omitempty"`
	// SchemaStr is never read from YAML. DefaultConfig and LoadConfig initialise it to an
	// empty map.
	// NOTE(review): presumably filled with schema file contents elsewhere — confirm.
	SchemaStr map[string]string `yaml:"-"`
	// Exec is the output configuration under the "exec" key (default filename "generated.go").
	Exec PackageConfig `yaml:"exec"`
	// Model is the output configuration under the "model" key (default filename "models_gen.go").
	Model PackageConfig `yaml:"model"`
	// Resolver is the optional output configuration under the "resolver" key. DefaultConfig
	// leaves it empty.
	Resolver PackageConfig `yaml:"resolver,omitempty"`
	// Models is the "models" section, keyed by type name (presumably the GraphQL type name —
	// confirm). TypeMap.Check validates each entry's Model specifier.
	Models TypeMap `yaml:"models,omitempty"`
	// StructTag is read from "struct_tag".
	// NOTE(review): presumably an extra struct tag name for generated fields — confirm.
	StructTag string `yaml:"struct_tag,omitempty"`

	// FilePath is set by LoadConfig to the path it read. It is not part of the YAML.
	FilePath string `yaml:"-"`

	// schema is unexported, so the YAML library ignores it regardless of its tag. It is not
	// populated anywhere in this file.
	schema *ast.Schema `yaml:"-"`
}
93
// PackageConfig describes one Go output file and the package it belongs to. Check requires
// Filename (when set) to end in ".go" and Package to be a bare package name.
type PackageConfig struct {
	// Filename is the path of the Go file to write. normalize requires it to be non-empty and
	// passes it through abs (presumably making it absolute — confirm).
	Filename string `yaml:"filename,omitempty"`
	// Package is the Go package name. When empty, normalize derives it from the output
	// directory.
	Package string `yaml:"package,omitempty"`
	// Type is not used by the code in this file.
	// NOTE(review): meaning unconfirmed — check its consumers.
	Type string `yaml:"type,omitempty"`
}
99
// TypeMapEntry is one entry of the "models" config section.
type TypeMapEntry struct {
	// Model is a Go type specifier such as "import/path.TypeName". TypeMap.Check rejects
	// specifiers whose last "/" comes after their last ".", i.e. a bare package path.
	// referencedPackages special-cases the literal "map[string]interface{}". An empty Model
	// passes Check.
	Model string `yaml:"model"`
	// Fields holds per-field overrides, keyed by field name (presumably the GraphQL field
	// name — confirm).
	Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
104
// TypeMapField holds the per-field override options for one field of a mapped model.
type TypeMapField struct {
	// Resolver is read from "resolver".
	// NOTE(review): presumably forces a resolver method for this field — confirm.
	Resolver bool `yaml:"resolver"`
	// FieldName is read from the camelCase key "fieldName" (unlike the snake_case keys
	// elsewhere in this config).
	// NOTE(review): presumably the Go struct field to bind to — confirm.
	FieldName string `yaml:"fieldName"`
}
109
// SchemaFilenames is a list of schema file paths or glob patterns. In YAML it may be written as
// a single string or as a list of strings (see UnmarshalYAML).
type SchemaFilenames []string
111
112func (a *SchemaFilenames) UnmarshalYAML(unmarshal func(interface{}) error) error {
113 var single string
114 err := unmarshal(&single)
115 if err == nil {
116 *a = []string{single}
117 return nil
118 }
119
120 var multi []string
121 err = unmarshal(&multi)
122 if err != nil {
123 return err
124 }
125
126 *a = multi
127 return nil
128}
129
130func (a SchemaFilenames) Has(file string) bool {
131 for _, existing := range a {
132 if existing == file {
133 return true
134 }
135 }
136 return false
137}
138
139func (c *PackageConfig) normalize() error {
140 if c.Filename == "" {
141 return errors.New("Filename is required")
142 }
143 c.Filename = abs(c.Filename)
144 // If Package is not set, first attempt to load the package at the output dir. If that fails
145 // fallback to just the base dir name of the output filename.
146 if c.Package == "" {
147 cwd, _ := os.Getwd()
148 pkg, _ := build.Default.Import(c.ImportPath(), cwd, 0)
149 if pkg.Name != "" {
150 c.Package = pkg.Name
151 } else {
152 c.Package = filepath.Base(c.Dir())
153 }
154 }
155 c.Package = sanitizePackageName(c.Package)
156 return nil
157}
158
159func (c *PackageConfig) ImportPath() string {
160 return gopath.MustDir2Import(c.Dir())
161}
162
163func (c *PackageConfig) Dir() string {
164 return filepath.Dir(c.Filename)
165}
166
167func (c *PackageConfig) Check() error {
168 if strings.ContainsAny(c.Package, "./\\") {
169 return fmt.Errorf("package should be the output package name only, do not include the output filename")
170 }
171 if c.Filename != "" && !strings.HasSuffix(c.Filename, ".go") {
172 return fmt.Errorf("filename should be path to a go source file")
173 }
174 return nil
175}
176
177func (c *PackageConfig) IsDefined() bool {
178 return c.Filename != ""
179}
180
181func (cfg *Config) Check() error {
182 if err := cfg.Models.Check(); err != nil {
183 return errors.Wrap(err, "config.models")
184 }
185 if err := cfg.Exec.Check(); err != nil {
186 return errors.Wrap(err, "config.exec")
187 }
188 if err := cfg.Model.Check(); err != nil {
189 return errors.Wrap(err, "config.model")
190 }
191 if err := cfg.Resolver.Check(); err != nil {
192 return errors.Wrap(err, "config.resolver")
193 }
194 return nil
195}
196
// TypeMap is the "models" config section, mapping a type name to its TypeMapEntry.
type TypeMap map[string]TypeMapEntry
198
199func (tm TypeMap) Exists(typeName string) bool {
200 _, ok := tm[typeName]
201 return ok
202}
203
204func (tm TypeMap) Check() error {
205 for typeName, entry := range tm {
206 if strings.LastIndex(entry.Model, ".") < strings.LastIndex(entry.Model, "/") {
207 return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
208 }
209 }
210 return nil
211}
212
213func (tm TypeMap) referencedPackages() []string {
214 var pkgs []string
215
216 for _, typ := range tm {
217 if typ.Model == "map[string]interface{}" {
218 continue
219 }
220 pkg, _ := pkgAndType(typ.Model)
221 if pkg == "" || inStrSlice(pkgs, pkg) {
222 continue
223 }
224 pkgs = append(pkgs, pkg)
225 }
226
227 sort.Slice(pkgs, func(i, j int) bool {
228 return pkgs[i] > pkgs[j]
229 })
230 return pkgs
231}
232
// inStrSlice reports whether needle occurs in haystack, using a linear scan with exact
// comparison.
func inStrSlice(haystack []string, needle string) bool {
	found := false
	for i := 0; i < len(haystack) && !found; i++ {
		found = haystack[i] == needle
	}
	return found
}
242
243// findCfg searches for the config file in this directory and all parents up the tree
244// looking for the closest match
245func findCfg() (string, error) {
246 dir, err := os.Getwd()
247 if err != nil {
248 return "", errors.Wrap(err, "unable to get working dir to findCfg")
249 }
250
251 cfg := findCfgInDir(dir)
252
253 for cfg == "" && dir != filepath.Dir(dir) {
254 dir = filepath.Dir(dir)
255 cfg = findCfgInDir(dir)
256 }
257
258 if cfg == "" {
259 return "", os.ErrNotExist
260 }
261
262 return cfg, nil
263}
264
265func findCfgInDir(dir string) string {
266 for _, cfgName := range cfgFilenames {
267 path := filepath.Join(dir, cfgName)
268 if _, err := os.Stat(path); err == nil {
269 return path
270 }
271 }
272 return ""
273}