1package templates
2
3import (
4 "bytes"
5 "fmt"
6 "go/types"
7 "io/ioutil"
8 "os"
9 "path/filepath"
10 "reflect"
11 "runtime"
12 "sort"
13 "strconv"
14 "strings"
15 "text/template"
16 "unicode"
17
18 "github.com/99designs/gqlgen/internal/imports"
19 "github.com/pkg/errors"
20)
21
22// this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually.
23var CurrentImports *Imports
24
// Options configures a single Render call.
type Options struct {
	// PackageName is written into the generated "package" clause.
	PackageName string
	// Filename is the output path; its directory is also the destination
	// directory handed to the import tracker.
	Filename string
	// RegionTags wraps each template's output in "// region" / "// endregion"
	// marker comments.
	RegionTags bool
	// GeneratedHeader prepends the "Code generated ... DO NOT EDIT." line.
	GeneratedHeader bool
	// Data is passed to every template as its dot value.
	Data interface{}
	// Funcs is merged over (and may override) the helpers returned by Funcs().
	Funcs template.FuncMap
}
33
34func Render(cfg Options) error {
35 if CurrentImports != nil {
36 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
37 }
38 CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
39
40 // load path relative to calling source file
41 _, callerFile, _, _ := runtime.Caller(1)
42 rootDir := filepath.Dir(callerFile)
43
44 funcs := Funcs()
45 for n, f := range cfg.Funcs {
46 funcs[n] = f
47 }
48 t := template.New("").Funcs(funcs)
49
50 var roots []string
51 // load all the templates in the directory
52 err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
53 if err != nil {
54 return err
55 }
56 name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
57 if !strings.HasSuffix(info.Name(), ".gotpl") {
58 return nil
59 }
60 b, err := ioutil.ReadFile(path)
61 if err != nil {
62 return err
63 }
64
65 t, err = t.New(name).Parse(string(b))
66 if err != nil {
67 return errors.Wrap(err, cfg.Filename)
68 }
69
70 roots = append(roots, name)
71
72 return nil
73 })
74 if err != nil {
75 return errors.Wrap(err, "locating templates")
76 }
77
78 // then execute all the important looking ones in order, adding them to the same file
79 sort.Slice(roots, func(i, j int) bool {
80 // important files go first
81 if strings.HasSuffix(roots[i], "!.gotpl") {
82 return true
83 }
84 if strings.HasSuffix(roots[j], "!.gotpl") {
85 return false
86 }
87 return roots[i] < roots[j]
88 })
89 var buf bytes.Buffer
90 for _, root := range roots {
91 if cfg.RegionTags {
92 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n")
93 }
94 err = t.Lookup(root).Execute(&buf, cfg.Data)
95 if err != nil {
96 return errors.Wrap(err, root)
97 }
98 if cfg.RegionTags {
99 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
100 }
101 }
102
103 var result bytes.Buffer
104 if cfg.GeneratedHeader {
105 result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
106 }
107 result.WriteString("package ")
108 result.WriteString(cfg.PackageName)
109 result.WriteString("\n\n")
110 result.WriteString("import (\n")
111 result.WriteString(CurrentImports.String())
112 result.WriteString(")\n")
113 _, err = buf.WriteTo(&result)
114 if err != nil {
115 return err
116 }
117 CurrentImports = nil
118
119 return write(cfg.Filename, result.Bytes())
120}
121
// center pads s on both sides with repetitions of pad so that, for a
// single-byte pad, the result is width bytes long. Any odd extra byte goes
// on the right. If s plus two bytes would exceed width, s is returned
// unchanged. Lengths are measured in bytes.
func center(width int, pad string, s string) string {
	n := len(s)
	if width < n+2 {
		return s
	}
	left := (width - n) / 2
	return strings.Repeat(pad, left) + s + strings.Repeat(pad, width-n-left)
}
130
131func Funcs() template.FuncMap {
132 return template.FuncMap{
133 "ucFirst": ucFirst,
134 "lcFirst": lcFirst,
135 "quote": strconv.Quote,
136 "rawQuote": rawQuote,
137 "dump": Dump,
138 "ref": ref,
139 "ts": TypeIdentifier,
140 "call": Call,
141 "prefixLines": prefixLines,
142 "notNil": notNil,
143 "reserveImport": CurrentImports.Reserve,
144 "lookupImport": CurrentImports.Lookup,
145 "go": ToGo,
146 "goPrivate": ToGoPrivate,
147 "add": func(a, b int) int {
148 return a + b
149 },
150 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
151 return render(resolveName(filename, 0), tpldata)
152 },
153 }
154}
155
// ucFirst returns s with its first rune upper-cased. The string is
// round-tripped through []rune, so any invalid UTF-8 bytes come back as
// U+FFFD. The empty string is returned as is.
func ucFirst(s string) string {
	chars := []rune(s)
	if len(chars) > 0 {
		chars[0] = unicode.ToUpper(chars[0])
	}
	return string(chars)
}
164
// lcFirst returns s with its first rune lower-cased. The string is
// round-tripped through []rune, so any invalid UTF-8 bytes come back as
// U+FFFD. The empty string is returned as is.
func lcFirst(s string) string {
	chars := []rune(s)
	if len(chars) > 0 {
		chars[0] = unicode.ToLower(chars[0])
	}
	return string(chars)
}
174
// isDelimiter reports whether c separates words in an identifier: a hyphen,
// an underscore, or any rune for which unicode.IsSpace is true.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	default:
		return unicode.IsSpace(c)
	}
}
178
179func ref(p types.Type) string {
180 return CurrentImports.LookupType(p)
181}
182
// pkgReplacer rewrites the import-path characters '/', '.' and '-' into
// Unicode letters, so that a package path can be embedded in a generated
// identifier.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
)

// TypeIdentifier encodes t as an identifier-friendly string. Templates use
// it under the name "ts".
//
// Each pointer level contributes "ᚖ" and each slice level "ᚕ", followed by
// the innermost type:
//   - a named type becomes its package path (mangled by pkgReplacer), "ᚐ"
//     and its name; predeclared named types such as error have no package
//     and therefore encode as "ᚐ" plus the name;
//   - a basic type becomes its name;
//   - any map becomes "map" and any interface "interface".
//
// Any other kind of type panics.
func TypeIdentifier(t types.Type) string {
	res := ""
	for {
		switch it := t.(type) {
		case *types.Pointer:
			res += "ᚖ"
			t = it.Elem()
		case *types.Slice:
			res += "ᚕ"
			t = it.Elem()
		case *types.Named:
			// Objects in the universe scope (error, comparable) report a nil
			// package; dereferencing it unconditionally would panic.
			if pkg := it.Obj().Pkg(); pkg != nil {
				res += pkgReplacer.Replace(pkg.Path())
			}
			res += "ᚐ"
			res += it.Obj().Name()
			return res
		case *types.Basic:
			res += it.Name()
			return res
		case *types.Map:
			res += "map"
			return res
		case *types.Interface:
			res += "interface"
			return res
		default:
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}
219
220func Call(p *types.Func) string {
221 pkg := CurrentImports.Lookup(p.Pkg().Path())
222
223 if pkg != "" {
224 pkg += "."
225 }
226
227 if p.Type() != nil {
228 // make sure the returned type is listed in our imports.
229 ref(p.Type().(*types.Signature).Results().At(0).Type())
230 }
231
232 return pkg + p.Name()
233}
234
235func ToGo(name string) string {
236 runes := make([]rune, 0, len(name))
237
238 wordWalker(name, func(info *wordInfo) {
239 word := info.Word
240 if info.MatchCommonInitial {
241 word = strings.ToUpper(word)
242 } else if !info.HasCommonInitial {
243 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
244 // FOO or foo → Foo
245 // FOo → FOo
246 word = ucFirst(strings.ToLower(word))
247 }
248 }
249 runes = append(runes, []rune(word)...)
250 })
251
252 return string(runes)
253}
254
255func ToGoPrivate(name string) string {
256 runes := make([]rune, 0, len(name))
257
258 first := true
259 wordWalker(name, func(info *wordInfo) {
260 word := info.Word
261 if first {
262 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
263 // ID → id, CAMEL → camel
264 word = strings.ToLower(info.Word)
265 } else {
266 // ITicket → iTicket
267 word = lcFirst(info.Word)
268 }
269 first = false
270 } else if info.MatchCommonInitial {
271 word = strings.ToUpper(word)
272 } else if !info.HasCommonInitial {
273 word = ucFirst(strings.ToLower(word))
274 }
275 runes = append(runes, []rune(word)...)
276 })
277
278 return sanitizeKeywords(string(runes))
279}
280
// wordInfo describes one word emitted by wordWalker.
type wordInfo struct {
	// Word is the word's text with its case unchanged.
	Word string
	// MatchCommonInitial is set when the upper-cased word is itself a common
	// initialism. HasCommonInitial is set whenever MatchCommonInitial is, and
	// also when the word merely starts with an initialism followed by a
	// lower-case rune (e.g. "URLs").
	MatchCommonInitial, HasCommonInitial bool
}
286
// wordWalker splits str into words and calls f once per word, in order. It
// never calls f for an empty str.
//
// A word ends:
//   - at the end of the input;
//   - before a run of delimiters (see isDelimiter). The run is removed from
//     the input, except that exactly one delimiter is kept when the run sits
//     between two digits; that kept delimiter becomes the first rune of the
//     following word ("v1_2" → "v", "1", "_2");
//   - at a lower-case → non-lower-case transition ("fooBar" → "foo", "Bar");
//   - when the word scanned so far is a common initialism (exact-case lookup)
//     and the next rune is not lower-case ("IDFoo" → "ID", "Foo").
//
// For each word, f receives MatchCommonInitial when the upper-cased word is
// a common initialism. It receives HasCommonInitial when MatchCommonInitial
// holds, or when a prefix of the word scanned earlier was an initialism
// followed by a lower-case rune ("URLs").
//
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(str)
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if isDelimiter(runes[i+1]) {
			// delimiter ('-', '_' or white space); shift the remainder forward
			// over the whole run of delimiters
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Keep exactly one delimiter if the run is between two digits, so
			// that e.g. "1_2" does not collapse into "12".
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// through: emit the initialism as its own word
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			// Still inside a word; remember that it started with an
			// initialism (e.g. "URL" of "URLs").
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		// Start the next word with fresh state.
		hasCommonInitial = false
		w = i
	}
}
346
// keywords lists the 25 reserved words of the Go language plus the blank
// identifier "_". None of these can serve as a parameter name.
var keywords = []string{
	"break", "default", "func", "interface", "select",
	"case", "defer", "go", "map", "struct",
	"chan", "else", "goto", "package", "switch",
	"const", "fallthrough", "if", "range", "type",
	"continue", "for", "import", "return", "var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to
// resolver functions. A name exactly equal to an entry of keywords gets
// "Arg" appended; any other name is returned unchanged.
func sanitizeKeywords(name string) string {
	reserved := false
	for i := 0; i < len(keywords) && !reserved; i++ {
		reserved = keywords[i] == name
	}
	if !reserved {
		return name
	}
	return name + "Arg"
}
385
// commonInitialisms is a set of common initialisms; identifier conversion
// (ToGo, ToGoPrivate) writes words found in it in all caps.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = func() map[string]bool {
	set := make(map[string]bool)
	for _, initialism := range []string{
		"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID",
		"HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS",
		"RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP",
		"TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL",
		"UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
	} {
		set[initialism] = true
	}
	return set
}()
429
// rawQuote renders s as a Go raw string literal. A backtick cannot appear
// inside a raw string, so each one is spliced in as a concatenated
// interpreted literal: a`b becomes `a`+"`"+`b`.
func rawQuote(s string) string {
	const tick = "`"
	splice := tick + `+"` + tick + `"+` + tick
	return tick + strings.Replace(s, tick, splice, -1) + tick
}
433
// notNil reports whether data, a struct or a pointer to one, has a field
// named field whose value is not nil.
//
// It returns false when data is not a struct or pointer to struct, when data
// is a nil pointer, or when the field does not exist. A field whose kind can
// never be nil (string, int, struct, ...) is reported as not nil rather than
// letting reflect.Value.IsNil panic.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	// IsNil is only defined for these kinds and panics for all others.
	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}
447
// Dump renders val as Go source for an equivalent literal. Supported values
// are int, int64, float64, string, bool, nil, and (recursively)
// []interface{} and map[string]interface{}. Map keys are emitted in sorted
// order so the output is deterministic. Any other type panics.
//
// float64 values keep the historical "%f" form whenever it reproduces the
// value exactly. Values that "%f" would round (e.g. 1e-7 → "0.000000") are
// emitted in exact exponent form instead ("1e-07"), which is still a float
// literal in Go.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		s := fmt.Sprintf("%f", val)
		if f, err := strconv.ParseFloat(s, 64); err != nil || f != val {
			// %f lost precision; use the shortest exact representation.
			s = strconv.FormatFloat(val, 'e', -1, 64)
		}
		return s
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		parts := make([]string, 0, len(val))
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		buf := bytes.Buffer{}
		buf.WriteString("map[string]interface{}{")
		for _, key := range keys {
			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(val[key]))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}
491
// prefixLines puts prefix in front of every line of s. That includes the
// first line, and the empty line that follows a trailing newline.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	return prefix + strings.Join(lines, "\n"+prefix)
}
495
// resolveName turns a template name into a file path.
//
// A name starting with "." is resolved relative to the directory of the
// source file skip+1 frames up the stack; skip = 0 means the function that
// called resolveName. Any other name, including the empty string, is
// resolved relative to the directory containing this source file.
func resolveName(name string, skip int) string {
	// HasPrefix instead of name[0]: an empty name must not panic.
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}
507
508func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
509 t := template.New("").Funcs(Funcs())
510
511 b, err := ioutil.ReadFile(filename)
512 if err != nil {
513 return nil, err
514 }
515
516 t, err = t.New(filepath.Base(filename)).Parse(string(b))
517 if err != nil {
518 panic(err)
519 }
520
521 buf := &bytes.Buffer{}
522 return buf, t.Execute(buf, tpldata)
523}
524
525func write(filename string, b []byte) error {
526 err := os.MkdirAll(filepath.Dir(filename), 0755)
527 if err != nil {
528 return errors.Wrap(err, "failed to create directory")
529 }
530
531 formatted, err := imports.Prune(filename, b)
532 if err != nil {
533 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
534 formatted = b
535 }
536
537 err = ioutil.WriteFile(filename, formatted, 0644)
538 if err != nil {
539 return errors.Wrapf(err, "failed to write %s", filename)
540 }
541
542 return nil
543}