1package templates
2
3import (
4 "bytes"
5 "fmt"
6 "go/types"
7 "io/ioutil"
8 "os"
9 "path/filepath"
10 "reflect"
11 "runtime"
12 "sort"
13 "strconv"
14 "strings"
15 "text/template"
16 "unicode"
17
18 "github.com/99designs/gqlgen/internal/imports"
19 "github.com/pkg/errors"
20)
21
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because subtemplates currently get called in functions. Let's aim to remove this eventually.
//
// Render assigns a fresh value at the start of a render and panics if it is
// already non-nil. ref and Call read it directly, and Funcs binds its Reserve
// and Lookup methods into the template function map.
var CurrentImports *Imports
25
// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// the plugin processor will look for .gotpl files
	// in the same directory of where you wrote the plugin.
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	// Its directory is also the destination directory used when resolving imports.
	Filename string
	// RegionTags, when true, wraps the output of each executed template in
	// "// region" / "// endregion" comment lines that name the template.
	RegionTags bool
	// GeneratedHeader, when true, prepends the line
	// "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT." to the output.
	GeneratedHeader bool
	// Data will be passed to the template execution.
	Data interface{}
	// Funcs are merged into the template function map returned by Funcs.
	// Entries with the same name replace the built-in helpers.
	Funcs template.FuncMap
}
47
48// Render renders a gql plugin template from the given Options. Render is an
49// abstraction of the text/template package that makes it easier to write gqlgen
50// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
51// files inside the directory where you wrote the plugin.
52func Render(cfg Options) error {
53 if CurrentImports != nil {
54 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
55 }
56 CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
57
58 // load path relative to calling source file
59 _, callerFile, _, _ := runtime.Caller(1)
60 rootDir := filepath.Dir(callerFile)
61
62 funcs := Funcs()
63 for n, f := range cfg.Funcs {
64 funcs[n] = f
65 }
66 t := template.New("").Funcs(funcs)
67
68 var roots []string
69 if cfg.Template != "" {
70 var err error
71 t, err = t.New("template.gotpl").Parse(cfg.Template)
72 if err != nil {
73 return errors.Wrap(err, "error with provided template")
74 }
75 roots = append(roots, "template.gotpl")
76 } else {
77 // load all the templates in the directory
78 err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
79 if err != nil {
80 return err
81 }
82 name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
83 if !strings.HasSuffix(info.Name(), ".gotpl") {
84 return nil
85 }
86 b, err := ioutil.ReadFile(path)
87 if err != nil {
88 return err
89 }
90
91 t, err = t.New(name).Parse(string(b))
92 if err != nil {
93 return errors.Wrap(err, cfg.Filename)
94 }
95
96 roots = append(roots, name)
97
98 return nil
99 })
100 if err != nil {
101 return errors.Wrap(err, "locating templates")
102 }
103 }
104
105 // then execute all the important looking ones in order, adding them to the same file
106 sort.Slice(roots, func(i, j int) bool {
107 // important files go first
108 if strings.HasSuffix(roots[i], "!.gotpl") {
109 return true
110 }
111 if strings.HasSuffix(roots[j], "!.gotpl") {
112 return false
113 }
114 return roots[i] < roots[j]
115 })
116 var buf bytes.Buffer
117 for _, root := range roots {
118 if cfg.RegionTags {
119 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n")
120 }
121 err := t.Lookup(root).Execute(&buf, cfg.Data)
122 if err != nil {
123 return errors.Wrap(err, root)
124 }
125 if cfg.RegionTags {
126 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
127 }
128 }
129
130 var result bytes.Buffer
131 if cfg.GeneratedHeader {
132 result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
133 }
134 result.WriteString("package ")
135 result.WriteString(cfg.PackageName)
136 result.WriteString("\n\n")
137 result.WriteString("import (\n")
138 result.WriteString(CurrentImports.String())
139 result.WriteString(")\n")
140 _, err := buf.WriteTo(&result)
141 if err != nil {
142 return err
143 }
144 CurrentImports = nil
145
146 return write(cfg.Filename, result.Bytes())
147}
148
// center returns s with pad repeated on both sides so that, for a one-byte pad,
// the result is width bytes long. When the padding cannot be split evenly, the
// extra pad goes on the right. If s is too long to leave at least one pad on
// each side, s is returned unchanged. Lengths are measured in bytes.
func center(width int, pad string, s string) string {
	size := len(s)
	if width < size+2 {
		return s
	}
	left := (width - size) / 2
	right := width - size - left

	var sb strings.Builder
	sb.WriteString(strings.Repeat(pad, left))
	sb.WriteString(s)
	sb.WriteString(strings.Repeat(pad, right))
	return sb.String()
}
157
158func Funcs() template.FuncMap {
159 return template.FuncMap{
160 "ucFirst": ucFirst,
161 "lcFirst": lcFirst,
162 "quote": strconv.Quote,
163 "rawQuote": rawQuote,
164 "dump": Dump,
165 "ref": ref,
166 "ts": TypeIdentifier,
167 "call": Call,
168 "prefixLines": prefixLines,
169 "notNil": notNil,
170 "reserveImport": CurrentImports.Reserve,
171 "lookupImport": CurrentImports.Lookup,
172 "go": ToGo,
173 "goPrivate": ToGoPrivate,
174 "add": func(a, b int) int {
175 return a + b
176 },
177 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
178 return render(resolveName(filename, 0), tpldata)
179 },
180 }
181}
182
// ucFirst returns s with its first rune mapped to upper case.
// s is round-tripped through []rune, so invalid UTF-8 bytes come back as U+FFFD.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) > 0 {
		runes[0] = unicode.ToUpper(runes[0])
	}
	return string(runes)
}
191
// lcFirst returns s with its first rune mapped to lower case.
// s is round-tripped through []rune, so invalid UTF-8 bytes come back as U+FFFD.
func lcFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return s
	}
	runes[0] = unicode.ToLower(runes[0])
	return string(runes)
}
201
// isDelimiter reports whether c separates words for wordWalker:
// a hyphen, an underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	}
	return unicode.IsSpace(c)
}
205
206func ref(p types.Type) string {
207 return CurrentImports.LookupType(p)
208}
209
// pkgReplacer maps characters that can occur in a Go import path, but not in an
// identifier, to look-alike letters so a package path can be embedded in a
// generated identifier.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
)

// TypeIdentifier returns a string derived from t that is usable inside a Go
// identifier.
//
// Each pointer level contributes "ᚖ" and each slice level "ᚕ". A named type
// contributes its package path (passed through pkgReplacer), "ᚐ" and its name.
// Named types without a package, such as the predeclared error type, contribute
// only their name. Basic types contribute their name, and any map or interface
// contributes just "map" or "interface".
//
// It panics for any other kind of type (arrays, channels, funcs, structs, ...).
func TypeIdentifier(t types.Type) string {
	var sb strings.Builder
	for {
		switch it := t.(type) {
		case *types.Pointer:
			sb.WriteString("ᚖ")
			t = it.Elem()
		case *types.Slice:
			sb.WriteString("ᚕ")
			t = it.Elem()
		case *types.Named:
			// Objects in the universe scope (e.g. error) have a nil package;
			// calling Path on it would panic.
			if pkg := it.Obj().Pkg(); pkg != nil {
				sb.WriteString(pkgReplacer.Replace(pkg.Path()))
				sb.WriteString("ᚐ")
			}
			sb.WriteString(it.Obj().Name())
			return sb.String()
		case *types.Basic:
			sb.WriteString(it.Name())
			return sb.String()
		case *types.Map:
			sb.WriteString("map")
			return sb.String()
		case *types.Interface:
			sb.WriteString("interface")
			return sb.String()
		default:
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}
246
247func Call(p *types.Func) string {
248 pkg := CurrentImports.Lookup(p.Pkg().Path())
249
250 if pkg != "" {
251 pkg += "."
252 }
253
254 if p.Type() != nil {
255 // make sure the returned type is listed in our imports.
256 ref(p.Type().(*types.Signature).Results().At(0).Type())
257 }
258
259 return pkg + p.Name()
260}
261
262func ToGo(name string) string {
263 runes := make([]rune, 0, len(name))
264
265 wordWalker(name, func(info *wordInfo) {
266 word := info.Word
267 if info.MatchCommonInitial {
268 word = strings.ToUpper(word)
269 } else if !info.HasCommonInitial {
270 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
271 // FOO or foo → Foo
272 // FOo → FOo
273 word = ucFirst(strings.ToLower(word))
274 }
275 }
276 runes = append(runes, []rune(word)...)
277 })
278
279 return string(runes)
280}
281
282func ToGoPrivate(name string) string {
283 runes := make([]rune, 0, len(name))
284
285 first := true
286 wordWalker(name, func(info *wordInfo) {
287 word := info.Word
288 if first {
289 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
290 // ID → id, CAMEL → camel
291 word = strings.ToLower(info.Word)
292 } else {
293 // ITicket → iTicket
294 word = lcFirst(info.Word)
295 }
296 first = false
297 } else if info.MatchCommonInitial {
298 word = strings.ToUpper(word)
299 } else if !info.HasCommonInitial {
300 word = ucFirst(strings.ToLower(word))
301 }
302 runes = append(runes, []rune(word)...)
303 })
304
305 return sanitizeKeywords(string(runes))
306}
307
// wordInfo describes one word that wordWalker passes to its callback.
type wordInfo struct {
	// Word is the text of the word.
	Word string
	// MatchCommonInitial is true when the upper-cased Word is in commonInitialisms.
	MatchCommonInitial bool
	// HasCommonInitial is true when MatchCommonInitial is true. It is also
	// true when some prefix of Word is in commonInitialisms and is followed by a
	// lower-case rune (e.g. "URLs").
	HasCommonInitial bool
}
313
// wordWalker splits str into words and calls f once per word, in order.
//
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
//
// A word ends:
//   - at the end of the input;
//   - before a run of delimiters (see isDelimiter). The run is removed from the
//     working rune slice, except that when it sits between two digits one
//     delimiter is kept as the first rune of the next word
//     ("v1_2" yields "v", "1", "_2");
//   - at a lower-case rune followed by a non-lower-case rune
//     ("fooBar" yields "foo", "Bar");
//   - after a prefix that is in commonInitialisms and is followed by a
//     non-lower-case rune ("IDFoo" yields "ID", "Foo").
//
// An initialism prefix followed by a lower-case rune does not end the word
// ("URLs" stays one word) but sets HasCommonInitial for it.
//
// NOTE(review): only runes[i+1] is tested for delimiters, so a delimiter at
// index 0 is never removed and stays attached to the first word.
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(str)
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if isDelimiter(runes[i+1]) {
			// delimiter; shift the remainder forward over the whole run of delimiters
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Leave exactly one delimiter if the run is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// fall through and emit the initialism as its own word:
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// keep growing the word; remember that it began with an initialism
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		hasCommonInitial = false
		w = i
	}
}
373
// keywords lists the 25 Go keywords plus the blank identifier "_".
// sanitizeKeywords renames generated names that collide with any of them.
var keywords = []string{
	"break", "default", "func", "interface", "select",
	"case", "defer", "go", "map", "struct",
	"chan", "else", "goto", "package", "switch",
	"const", "fallthrough", "if", "range", "type",
	"continue", "for", "import", "return", "var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to
// resolver functions. If name is exactly one of keywords, it is returned with
// "Arg" appended; otherwise it is returned unchanged.
func sanitizeKeywords(name string) string {
	for i := range keywords {
		if keywords[i] != name {
			continue
		}
		return name + "Arg"
	}
	return name
}
412
// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
//
// wordWalker uses this set to split and flag words, and ToGo/ToGoPrivate write
// matching words in all caps. Adding or removing an entry therefore changes
// the identifiers in generated code.
var commonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"LHS":   true,
	"QPS":   true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"UUID":  true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
}
456
// rawQuote returns s as a Go raw (backquoted) string literal. A raw string
// cannot contain a backquote, so each one is emitted as: close the raw string,
// concatenate an interpreted "`" literal, and reopen the raw string.
func rawQuote(s string) string {
	const escapedBacktick = "`+\"`\"+`"
	body := strings.Join(strings.Split(s, "`"), escapedBacktick)
	return "`" + body + "`"
}
460
// notNil reports whether data, a struct or pointer to a struct, has a field
// named field whose value is not nil.
//
// It returns false when data is not a struct or pointer to struct (a nil
// pointer included) or has no such field. Fields whose kind can never be nil
// (strings, numbers, structs, arrays, ...) are reported as not nil. Calling
// reflect.Value.IsNil on them would panic.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}
474
// Dump returns a Go source literal for val, for embedding in generated code.
//
// Supported types are int, int64, float64, string, bool, nil, []interface{}
// and map[string]interface{}. Containers are dumped recursively, and map keys
// are emitted in sorted order so the output is deterministic. Any other type
// panics.
//
// float64 values are formatted with %f, which keeps the literal a
// floating-point constant. %f rounds to six decimals, so when that would
// change the value the shortest exact exponent form (e.g. "1e-07") is used
// instead.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		s := fmt.Sprintf("%f", val)
		if parsed, err := strconv.ParseFloat(s, 64); err != nil || parsed != val {
			s = strconv.FormatFloat(val, 'e', -1, 64)
		}
		return s
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		parts := make([]string, 0, len(val))
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		var buf bytes.Buffer
		buf.WriteString("map[string]interface{}{")
		for _, key := range keys {
			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(val[key]))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}
518
// prefixLines returns s with prefix inserted at the start of every line. This
// includes the first line and the empty line that follows a trailing newline.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	for i, line := range lines {
		lines[i] = prefix + line
	}
	return strings.Join(lines, "\n")
}
522
// resolveName turns a template file name into a path.
//
// Names starting with "." (such as "./foo.gotpl" or "../dir/foo.gotpl") are
// joined onto the directory of the calling source file. skip selects how many
// further frames to skip; 0 means resolveName's direct caller. All other names,
// including "", are joined onto the directory of this source file.
func resolveName(name string, skip int) string {
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file. The name is joined
		// as-is, so "./x" and "../x" resolve like ordinary relative paths.
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name)
	}

	// load path relative to this directory
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}
534
535func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
536 t := template.New("").Funcs(Funcs())
537
538 b, err := ioutil.ReadFile(filename)
539 if err != nil {
540 return nil, err
541 }
542
543 t, err = t.New(filepath.Base(filename)).Parse(string(b))
544 if err != nil {
545 panic(err)
546 }
547
548 buf := &bytes.Buffer{}
549 return buf, t.Execute(buf, tpldata)
550}
551
552func write(filename string, b []byte) error {
553 err := os.MkdirAll(filepath.Dir(filename), 0755)
554 if err != nil {
555 return errors.Wrap(err, "failed to create directory")
556 }
557
558 formatted, err := imports.Prune(filename, b)
559 if err != nil {
560 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
561 formatted = b
562 }
563
564 err = ioutil.WriteFile(filename, formatted, 0644)
565 if err != nil {
566 return errors.Wrapf(err, "failed to write %s", filename)
567 }
568
569 return nil
570}