1package templates
2
3import (
4 "bytes"
5 "fmt"
6 "go/types"
7 "io/ioutil"
8 "os"
9 "path/filepath"
10 "reflect"
11 "runtime"
12 "sort"
13 "strconv"
14 "strings"
15 "text/template"
16 "unicode"
17
18 "github.com/99designs/gqlgen/internal/imports"
19 "github.com/pkg/errors"
20)
21
22// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
23// this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually.
24var CurrentImports *Imports
25
26// Options specify various parameters to rendering a template.
27type Options struct {
28 // PackageName is a helper that specifies the package header declaration.
29 // In other words, when you write the template you don't need to specify `package X`
30 // at the top of the file. By providing PackageName in the Options, the Render
31 // function will do that for you.
32 PackageName string
33 // Template is a string of the entire template that
34 // will be parsed and rendered. If it's empty,
35 // the plugin processor will look for .gotpl files
36 // in the same directory of where you wrote the plugin.
37 Template string
38 // Filename is the name of the file that will be
39 // written to the system disk once the template is rendered.
40 Filename string
41 RegionTags bool
42 GeneratedHeader bool
43 // Data will be passed to the template execution.
44 Data interface{}
45 Funcs template.FuncMap
46}
47
// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
// files inside the directory where you wrote the plugin (and its subdirectories).
//
// Every template is executed against cfg.Data and the outputs are concatenated
// into one file. Templates whose name ends in "!.gotpl" come first, and the rest
// follow; each group is ordered by name. The result is prefixed with the package
// clause and the import block collected in CurrentImports, then written to
// cfg.Filename.
//
// Render relies on the package-level CurrentImports, so it is not reentrant. It
// panics if CurrentImports is already set on entry. CurrentImports is reset on
// every return path, including errors and panics, so a failed render does not
// make later calls panic.
func Render(cfg Options) error {
	if CurrentImports != nil {
		panic(fmt.Errorf("recursive or concurrent call to Render detected"))
	}
	CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
	// Previously only the success path cleared this; any error return left it
	// set and every subsequent Render call panicked.
	defer func() { CurrentImports = nil }()

	// Load templates relative to the calling source file. This call must stay
	// directly in Render: skip=1 refers to Render's caller.
	_, callerFile, _, _ := runtime.Caller(1)
	rootDir := filepath.Dir(callerFile)

	// Caller-supplied funcs override built-in ones with the same name.
	funcs := Funcs()
	for n, f := range cfg.Funcs {
		funcs[n] = f
	}
	t := template.New("").Funcs(funcs)

	var roots []string
	if cfg.Template != "" {
		var err error
		t, err = t.New("template.gotpl").Parse(cfg.Template)
		if err != nil {
			return errors.Wrap(err, "error with provided template")
		}
		roots = append(roots, "template.gotpl")
	} else {
		// load all the templates in the directory tree
		err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}
			if !strings.HasSuffix(info.Name(), ".gotpl") {
				return nil
			}
			// template names are slash-separated paths relative to rootDir
			name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
			b, err := ioutil.ReadFile(path)
			if err != nil {
				return err
			}

			t, err = t.New(name).Parse(string(b))
			if err != nil {
				return errors.Wrap(err, cfg.Filename)
			}

			roots = append(roots, name)
			return nil
		})
		if err != nil {
			return errors.Wrap(err, "locating templates")
		}
	}

	// Important ("!.gotpl") templates go first, then the rest, each group in
	// name order. The old comparator reported less(i, j) and less(j, i) both
	// true for two important names, which is not a valid ordering for
	// sort.Slice and left their relative order unspecified.
	sort.Slice(roots, func(i, j int) bool {
		iImportant := strings.HasSuffix(roots[i], "!.gotpl")
		jImportant := strings.HasSuffix(roots[j], "!.gotpl")
		if iImportant != jImportant {
			return iImportant
		}
		return roots[i] < roots[j]
	})

	var buf bytes.Buffer
	for _, root := range roots {
		if cfg.RegionTags {
			buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n")
		}
		if err := t.Lookup(root).Execute(&buf, cfg.Data); err != nil {
			return errors.Wrap(err, root)
		}
		if cfg.RegionTags {
			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
		}
	}

	// Templates register imports while they execute, so the header (with the
	// import block) is assembled only after all of them have run.
	var result bytes.Buffer
	if cfg.GeneratedHeader {
		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
	}
	result.WriteString("package ")
	result.WriteString(cfg.PackageName)
	result.WriteString("\n\n")
	result.WriteString("import (\n")
	result.WriteString(CurrentImports.String())
	result.WriteString(")\n")
	if _, err := buf.WriteTo(&result); err != nil {
		return err
	}

	return write(cfg.Filename, result.Bytes())
}
148
// center pads s on both sides with repetitions of pad so that the result is
// width repetitions-plus-s long (the right side takes the extra one when the
// padding is odd). When fewer than two padding slots would remain, s is
// returned unchanged. Lengths are measured in bytes and pad is assumed to be a
// single byte.
func center(width int, pad string, s string) string {
	spare := width - len(s)
	if spare < 2 {
		return s
	}
	left := spare / 2
	right := spare - left
	return strings.Repeat(pad, left) + s + strings.Repeat(pad, right)
}
157
// Funcs returns the template.FuncMap every gqlgen template is parsed with.
//
// The reserveImport and lookupImport entries are method values, so they are
// bound to whatever CurrentImports points to at the moment Funcs is called.
// Render sets CurrentImports before it calls Funcs.
func Funcs() template.FuncMap {
	fm := make(template.FuncMap)

	// string and identifier helpers
	fm["ucFirst"] = ucFirst
	fm["lcFirst"] = lcFirst
	fm["quote"] = strconv.Quote
	fm["rawQuote"] = rawQuote
	fm["prefixLines"] = prefixLines
	fm["go"] = ToGo
	fm["goPrivate"] = ToGoPrivate

	// value and type helpers
	fm["dump"] = Dump
	fm["ref"] = ref
	fm["ts"] = TypeIdentifier
	fm["call"] = Call
	fm["notNil"] = notNil

	// import bookkeeping
	fm["reserveImport"] = CurrentImports.Reserve
	fm["lookupImport"] = CurrentImports.Lookup

	// arithmetic
	fm["add"] = func(a, b int) int {
		return a + b
	}

	// nested template rendering; names are resolved by resolveName
	fm["render"] = func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
		return render(resolveName(filename, 0), tpldata)
	}

	return fm
}
182
// ucFirst returns s with its first rune converted to upper case.
// The empty string is returned unchanged.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	head := unicode.ToUpper(runes[0])
	return string(head) + string(runes[1:])
}
191
// lcFirst returns s with its first rune converted to lower case.
// The empty string is returned unchanged.
func lcFirst(s string) string {
	if len(s) == 0 {
		return s
	}
	runes := []rune(s)
	return string(append([]rune{unicode.ToLower(runes[0])}, runes[1:]...))
}
201
// isDelimiter reports whether c separates words in a name:
// a hyphen, an underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	}
	return unicode.IsSpace(c)
}
205
// ref returns the string CurrentImports.LookupType produces for p —
// presumably p's Go type expression qualified with the import aliases of the
// file being rendered (TODO confirm against Imports.LookupType).
func ref(p types.Type) string {
	current := CurrentImports
	return current.LookupType(p)
}
209
210var pkgReplacer = strings.NewReplacer(
211 "/", "ᚋ",
212 ".", "ᚗ",
213 "-", "ᚑ",
214)
215
// TypeIdentifier returns a string that identifies t, built from the outside in:
//   - each pointer level adds "ᚖ" and each slice level adds "ᚕ";
//   - a named type adds its package path (rewritten by pkgReplacer), then "ᚐ",
//     then the type name. Universe-scope named types such as error have no
//     package and contribute only their name;
//   - a basic type adds its name (e.g. "string");
//   - every map type adds "map" and every interface type adds "interface",
//     without their key/element types, so distinct maps share an identifier.
//
// Any other type (arrays, channels, funcs, unnamed structs, ...) panics.
func TypeIdentifier(t types.Type) string {
	res := ""
	for {
		switch it := t.(type) {
		case *types.Pointer:
			res += "ᚖ"
			t = it.Elem()
		case *types.Slice:
			res += "ᚕ"
			t = it.Elem()
		case *types.Named:
			// Obj().Pkg() is nil for universe-scope objects (e.g. error);
			// dereferencing it unconditionally used to panic.
			if pkg := it.Obj().Pkg(); pkg != nil {
				res += pkgReplacer.Replace(pkg.Path())
				res += "ᚐ"
			}
			res += it.Obj().Name()
			return res
		case *types.Basic:
			res += it.Name()
			return res
		case *types.Map:
			res += "map"
			return res
		case *types.Interface:
			res += "interface"
			return res
		default:
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}
246
// Call returns the expression that refers to function p from the file being
// rendered. This is p's name prefixed with "alias.", where alias is what
// CurrentImports.Lookup returns for p's package. There is no prefix when that
// alias is empty or when p has no package.
//
// As a side effect, p's first result type (if p has any results) is passed
// through ref so that it is listed in the imports.
func Call(p *types.Func) string {
	pkg := ""
	if p.Pkg() != nil {
		pkg = CurrentImports.Lookup(p.Pkg().Path())
	}
	if pkg != "" {
		pkg += "."
	}

	// make sure the returned type is listed in our imports.
	// Functions without results have nothing to register; Results().At(0)
	// used to panic on them.
	if sig, ok := p.Type().(*types.Signature); ok && sig.Results().Len() > 0 {
		ref(sig.Results().At(0).Type())
	}

	return pkg + p.Name()
}
261
262func ToGo(name string) string {
263 runes := make([]rune, 0, len(name))
264
265 wordWalker(name, func(info *wordInfo) {
266 word := info.Word
267 if info.MatchCommonInitial {
268 word = strings.ToUpper(word)
269 } else if !info.HasCommonInitial {
270 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
271 // FOO or foo → Foo
272 // FOo → FOo
273 word = ucFirst(strings.ToLower(word))
274 }
275 }
276 runes = append(runes, []rune(word)...)
277 })
278
279 return string(runes)
280}
281
282func ToGoPrivate(name string) string {
283 runes := make([]rune, 0, len(name))
284
285 first := true
286 wordWalker(name, func(info *wordInfo) {
287 word := info.Word
288 switch {
289 case first:
290 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
291 // ID → id, CAMEL → camel
292 word = strings.ToLower(info.Word)
293 } else {
294 // ITicket → iTicket
295 word = lcFirst(info.Word)
296 }
297 first = false
298 case info.MatchCommonInitial:
299 word = strings.ToUpper(word)
300 case !info.HasCommonInitial:
301 word = ucFirst(strings.ToLower(word))
302 }
303 runes = append(runes, []rune(word)...)
304 })
305
306 return sanitizeKeywords(string(runes))
307}
308
// wordInfo describes one word reported by wordWalker.
type wordInfo struct {
	// Word is the text of the word as split out by wordWalker.
	Word string
	// MatchCommonInitial is true when strings.ToUpper(Word) is in commonInitialisms.
	MatchCommonInitial bool
	// HasCommonInitial is true when MatchCommonInitial is, or when a prefix of
	// Word (matched case-sensitively) was found in commonInitialisms while
	// the word was being scanned.
	HasCommonInitial bool
}
314
// wordWalker splits str into words and calls f once per word, in order.
//
// A word ends:
//   - at the end of the input;
//   - before a run of delimiters (see isDelimiter). The run is removed from the
//     working rune slice, except that a single delimiter is kept when the run
//     sits between two digits; that delimiter then starts the next word;
//   - at a lower-case → non-lower-case transition (fooBar → foo, Bar);
//   - after a common initialism followed by a non-lower-case rune
//     (IDFoo → ID, Foo), while URLs stays a single word.
//
// For each word, f receives the flags described on wordInfo. The flags are
// reset after every word.
//
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(str)
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether runes[i] is the last rune of the current word
		switch {
		case i+1 == len(runes):
			eow = true
		case isDelimiter(runes[i+1]):
			// A run of delimiters ('-', '_' or white space) starts at i+1: end
			// the word here and remove the whole run by shifting the remainder
			// of runes left over it.
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Keep one delimiter when the run sits between two digits
			// (presumably so that adjacent numbers do not merge).
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
			// lower->non-lower, e.g. fooBar splits before B
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// The word so far is an initialism followed by a non-lower-case
			// rune: end it here and fall through to emit it.
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// Still inside the word; remember if its prefix is an initialism.
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		// The whole word, upper-cased, is an initialism.
		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		// reset per-word state; the next word starts at i
		hasCommonInitial = false
		w = i
	}
}
375
// keywords lists Go's 25 reserved words plus the blank identifier "_",
// none of which can be used as an ordinary identifier.
var keywords = []string{
	"break", "default", "func", "interface", "select",
	"case", "defer", "go", "map", "struct",
	"chan", "else", "goto", "package", "switch",
	"const", "fallthrough", "if", "range", "type",
	"continue", "for", "import", "return", "var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions.
// A name exactly equal to an entry of keywords gets "Arg" appended; any other
// name is returned unchanged.
func sanitizeKeywords(name string) string {
	for i := range keywords {
		if keywords[i] != name {
			continue
		}
		return name + "Arg"
	}
	return name
}
414
415// commonInitialisms is a set of common initialisms.
416// Only add entries that are highly unlikely to be non-initialisms.
417// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
418var commonInitialisms = map[string]bool{
419 "ACL": true,
420 "API": true,
421 "ASCII": true,
422 "CPU": true,
423 "CSS": true,
424 "DNS": true,
425 "EOF": true,
426 "GUID": true,
427 "HTML": true,
428 "HTTP": true,
429 "HTTPS": true,
430 "ID": true,
431 "IP": true,
432 "JSON": true,
433 "LHS": true,
434 "QPS": true,
435 "RAM": true,
436 "RHS": true,
437 "RPC": true,
438 "SLA": true,
439 "SMTP": true,
440 "SQL": true,
441 "SSH": true,
442 "TCP": true,
443 "TLS": true,
444 "TTL": true,
445 "UDP": true,
446 "UI": true,
447 "UID": true,
448 "UUID": true,
449 "URI": true,
450 "URL": true,
451 "UTF8": true,
452 "VM": true,
453 "XML": true,
454 "XMPP": true,
455 "XSRF": true,
456 "XSS": true,
457}
458
// rawQuote returns s as a Go raw string literal. Backquotes inside s cannot
// appear in a raw literal, so each one closes the literal, is concatenated as
// "`", and the raw literal is reopened.
func rawQuote(s string) string {
	pieces := strings.Split(s, "`")
	return "`" + strings.Join(pieces, "`+\"`\"+`") + "`"
}
462
// notNil reports whether data, a struct or a pointer to a struct, has a field
// named field whose value is not nil.
//
// It returns false when data is nil, a nil pointer, or not a struct, and when
// the struct has no such field. Fields of a kind that can never be nil
// (strings, numbers, structs, arrays, ...) are reported as not nil.
// reflect.Value.IsNil panics on such kinds, which used to crash template
// execution.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}
476
// Dump returns Go source for a literal equal to val. val must be built from
// int, int64, float64, string, bool, nil, []interface{} and
// map[string]interface{} values; any other type panics. Map keys are emitted
// in sorted order so the output is deterministic.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		// Shortest exact decimal form. "%f" rounded to six places, so e.g.
		// 1e-9 was emitted as 0.000000.
		s := strconv.FormatFloat(val, 'f', -1, 64)
		// Keep a decimal point: a bare "2" would be typed int inside an
		// interface{} literal. The digit check skips "NaN" and "±Inf", which
		// have no Go literal form either way.
		if last := s[len(s)-1]; last >= '0' && last <= '9' && !strings.Contains(s, ".") {
			s += ".0"
		}
		return s
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		parts := make([]string, 0, len(val))
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		var buf bytes.Buffer
		buf.WriteString("map[string]interface{}{")
		for _, key := range keys {
			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(val[key]))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}
520
// prefixLines puts prefix at the start of every line of s, including the
// (possibly empty) line after a trailing newline.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	for i := range lines {
		lines[i] = prefix + lines[i]
	}
	return strings.Join(lines, "\n")
}
524
// resolveName turns a template name into a file path.
//
// A name starting with '.' is resolved against the directory of the source
// file skip+1 frames up the stack (skip=0 means resolveName's direct caller),
// after the leading '.' is dropped. Any other name, including the empty name
// (which used to panic on name[0]), is resolved against the directory of this
// source file.
func resolveName(name string, skip int) string {
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}
536
// render parses the template file at filename with the standard Funcs and
// executes it against tpldata, returning the rendered output.
//
// Read, parse and execution errors are all returned to the caller with a nil
// buffer. A parse error used to panic even though the function returns error.
func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
	b, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err
	}

	t, err := template.New(filepath.Base(filename)).Funcs(Funcs()).Parse(string(b))
	if err != nil {
		return nil, err
	}

	buf := &bytes.Buffer{}
	if err := t.Execute(buf, tpldata); err != nil {
		return nil, err
	}
	return buf, nil
}
553
// write stores b at filename, creating parent directories as needed.
// The source is first passed through imports.Prune. If that fails, the failure
// is reported on stderr and the unformatted bytes are written instead.
func write(filename string, b []byte) error {
	if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil {
		return errors.Wrap(err, "failed to create directory")
	}

	out := b
	if formatted, err := imports.Prune(filename, b); err == nil {
		out = formatted
	} else {
		// best effort: report and fall back to the unformatted source
		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
	}

	if err := ioutil.WriteFile(filename, out, 0644); err != nil {
		return errors.Wrapf(err, "failed to write %s", filename)
	}
	return nil
}