templates.go

  1package templates
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"go/types"
  7	"io/ioutil"
  8	"os"
  9	"path/filepath"
 10	"reflect"
 11	"runtime"
 12	"sort"
 13	"strconv"
 14	"strings"
 15	"text/template"
 16	"unicode"
 17
 18	"github.com/99designs/gqlgen/internal/imports"
 19	"github.com/pkg/errors"
 20)
 21
 22// this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually.
 23var CurrentImports *Imports
 24
 25type Options struct {
 26	PackageName     string
 27	Filename        string
 28	RegionTags      bool
 29	GeneratedHeader bool
 30	Data            interface{}
 31	Funcs           template.FuncMap
 32}
 33
 34func Render(cfg Options) error {
 35	if CurrentImports != nil {
 36		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
 37	}
 38	CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
 39
 40	// load path relative to calling source file
 41	_, callerFile, _, _ := runtime.Caller(1)
 42	rootDir := filepath.Dir(callerFile)
 43
 44	funcs := Funcs()
 45	for n, f := range cfg.Funcs {
 46		funcs[n] = f
 47	}
 48	t := template.New("").Funcs(funcs)
 49
 50	var roots []string
 51	// load all the templates in the directory
 52	err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
 53		if err != nil {
 54			return err
 55		}
 56		name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
 57		if !strings.HasSuffix(info.Name(), ".gotpl") {
 58			return nil
 59		}
 60		b, err := ioutil.ReadFile(path)
 61		if err != nil {
 62			return err
 63		}
 64
 65		t, err = t.New(name).Parse(string(b))
 66		if err != nil {
 67			return errors.Wrap(err, cfg.Filename)
 68		}
 69
 70		roots = append(roots, name)
 71
 72		return nil
 73	})
 74	if err != nil {
 75		return errors.Wrap(err, "locating templates")
 76	}
 77
 78	// then execute all the important looking ones in order, adding them to the same file
 79	sort.Slice(roots, func(i, j int) bool {
 80		// important files go first
 81		if strings.HasSuffix(roots[i], "!.gotpl") {
 82			return true
 83		}
 84		if strings.HasSuffix(roots[j], "!.gotpl") {
 85			return false
 86		}
 87		return roots[i] < roots[j]
 88	})
 89	var buf bytes.Buffer
 90	for _, root := range roots {
 91		if cfg.RegionTags {
 92			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
 93		}
 94		err = t.Lookup(root).Execute(&buf, cfg.Data)
 95		if err != nil {
 96			return errors.Wrap(err, root)
 97		}
 98		if cfg.RegionTags {
 99			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
100		}
101	}
102
103	var result bytes.Buffer
104	if cfg.GeneratedHeader {
105		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
106	}
107	result.WriteString("package ")
108	result.WriteString(cfg.PackageName)
109	result.WriteString("\n\n")
110	result.WriteString("import (\n")
111	result.WriteString(CurrentImports.String())
112	result.WriteString(")\n")
113	_, err = buf.WriteTo(&result)
114	if err != nil {
115		return err
116	}
117	CurrentImports = nil
118
119	return write(cfg.Filename, result.Bytes())
120}
121
// center pads s on both sides with pad so that the result is width long,
// putting the extra pad on the right when the split is uneven. If s would not
// leave room for at least one pad on each side, s is returned unchanged.
// Lengths are measured in bytes, and the arithmetic assumes pad is one byte.
func center(width int, pad string, s string) string {
	n := len(s)
	if n+2 > width {
		return s
	}
	left := (width - n) / 2
	right := width - n - left

	var sb strings.Builder
	sb.WriteString(strings.Repeat(pad, left))
	sb.WriteString(s)
	sb.WriteString(strings.Repeat(pad, right))
	return sb.String()
}
130
131func Funcs() template.FuncMap {
132	return template.FuncMap{
133		"ucFirst":       ucFirst,
134		"lcFirst":       lcFirst,
135		"quote":         strconv.Quote,
136		"rawQuote":      rawQuote,
137		"dump":          Dump,
138		"ref":           ref,
139		"ts":            TypeIdentifier,
140		"call":          Call,
141		"prefixLines":   prefixLines,
142		"notNil":        notNil,
143		"reserveImport": CurrentImports.Reserve,
144		"lookupImport":  CurrentImports.Lookup,
145		"go":            ToGo,
146		"goPrivate":     ToGoPrivate,
147		"add": func(a, b int) int {
148			return a + b
149		},
150		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
151			return render(resolveName(filename, 0), tpldata)
152		},
153	}
154}
155
// mapFirstRune returns s with its first rune replaced by conv applied to it.
// s is round-tripped through []rune, so invalid UTF-8 bytes become U+FFFD.
func mapFirstRune(s string, conv func(rune) rune) string {
	if len(s) == 0 {
		return s
	}
	rs := []rune(s)
	rs[0] = conv(rs[0])
	return string(rs)
}

// ucFirst returns s with its first rune upper-cased.
func ucFirst(s string) string {
	return mapFirstRune(s, unicode.ToUpper)
}

// lcFirst returns s with its first rune lower-cased.
func lcFirst(s string) string {
	return mapFirstRune(s, unicode.ToLower)
}
174
// isDelimiter reports whether c separates words in an input name: a hyphen,
// an underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	default:
		return unicode.IsSpace(c)
	}
}
178
179func ref(p types.Type) string {
180	return CurrentImports.LookupType(p)
181}
182
// pkgReplacer rewrites the characters of an import path that are not allowed
// in Go identifiers ('/', '.', '-') into letters from the Unicode Ogham block.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
)

// TypeIdentifier encodes t as a string made only of characters that are
// valid in a Go identifier, for building generated names from types.
//
// Each pointer level contributes "ᚖ" and each slice level "ᚕ". A named type
// is written as its encoded package path, "ᚐ" and its name; named types
// without a package (such as the predeclared error) are written by name
// alone. Basic types use their name, and maps and interfaces collapse to
// "map" and "interface". Any other kind of type panics.
func TypeIdentifier(t types.Type) string {
	var b strings.Builder
	for {
		switch it := t.(type) {
		case *types.Pointer:
			b.WriteString("ᚖ")
			t = it.Elem()
		case *types.Slice:
			b.WriteString("ᚕ")
			t = it.Elem()
		case *types.Named:
			// Universe-scope named types (error, comparable) have no package;
			// dereferencing Pkg() unconditionally would panic for them.
			if pkg := it.Obj().Pkg(); pkg != nil {
				b.WriteString(pkgReplacer.Replace(pkg.Path()))
				b.WriteString("ᚐ")
			}
			b.WriteString(it.Obj().Name())
			return b.String()
		case *types.Basic:
			b.WriteString(it.Name())
			return b.String()
		case *types.Map:
			b.WriteString("map")
			return b.String()
		case *types.Interface:
			b.WriteString("interface")
			return b.String()
		default:
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}
219
220func Call(p *types.Func) string {
221	pkg := CurrentImports.Lookup(p.Pkg().Path())
222
223	if pkg != "" {
224		pkg += "."
225	}
226
227	if p.Type() != nil {
228		// make sure the returned type is listed in our imports.
229		ref(p.Type().(*types.Signature).Results().At(0).Type())
230	}
231
232	return pkg + p.Name()
233}
234
235func ToGo(name string) string {
236	runes := make([]rune, 0, len(name))
237
238	wordWalker(name, func(info *wordInfo) {
239		word := info.Word
240		if info.MatchCommonInitial {
241			word = strings.ToUpper(word)
242		} else if !info.HasCommonInitial {
243			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
244				// FOO or foo → Foo
245				// FOo → FOo
246				word = ucFirst(strings.ToLower(word))
247			}
248		}
249		runes = append(runes, []rune(word)...)
250	})
251
252	return string(runes)
253}
254
255func ToGoPrivate(name string) string {
256	runes := make([]rune, 0, len(name))
257
258	first := true
259	wordWalker(name, func(info *wordInfo) {
260		word := info.Word
261		if first {
262			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
263				// ID → id, CAMEL → camel
264				word = strings.ToLower(info.Word)
265			} else {
266				// ITicket → iTicket
267				word = lcFirst(info.Word)
268			}
269			first = false
270		} else if info.MatchCommonInitial {
271			word = strings.ToUpper(word)
272		} else if !info.HasCommonInitial {
273			word = ucFirst(strings.ToLower(word))
274		}
275		runes = append(runes, []rune(word)...)
276	})
277
278	return sanitizeKeywords(string(runes))
279}
280
// wordInfo describes one word reported by wordWalker.
type wordInfo struct {
	// Word is the word's text.
	Word string
	// MatchCommonInitial reports whether strings.ToUpper(Word) is a key of
	// commonInitialisms.
	MatchCommonInitial bool
	// HasCommonInitial is true whenever MatchCommonInitial is, and also when
	// the word begins with a common initialism followed by lower-case runes
	// (as in "URLs").
	HasCommonInitial bool
}
286
287// This function is based on the following code.
288// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
289func wordWalker(str string, f func(*wordInfo)) {
290	runes := []rune(str)
291	w, i := 0, 0 // index of start of word, scan
292	hasCommonInitial := false
293	for i+1 <= len(runes) {
294		eow := false // whether we hit the end of a word
295		if i+1 == len(runes) {
296			eow = true
297		} else if isDelimiter(runes[i+1]) {
298			// underscore; shift the remainder forward over any run of underscores
299			eow = true
300			n := 1
301			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
302				n++
303			}
304
305			// Leave at most one underscore if the underscore is between two digits
306			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
307				n--
308			}
309
310			copy(runes[i+1:], runes[i+n+1:])
311			runes = runes[:len(runes)-n]
312		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
313			// lower->non-lower
314			eow = true
315		}
316		i++
317
318		// [w,i) is a word.
319		word := string(runes[w:i])
320		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
321			// through
322			// split IDFoo → ID, Foo
323			// but URLs → URLs
324		} else if !eow {
325			if commonInitialisms[word] {
326				hasCommonInitial = true
327			}
328			continue
329		}
330
331		matchCommonInitial := false
332		if commonInitialisms[strings.ToUpper(word)] {
333			hasCommonInitial = true
334			matchCommonInitial = true
335		}
336
337		f(&wordInfo{
338			Word:               word,
339			MatchCommonInitial: matchCommonInitial,
340			HasCommonInitial:   hasCommonInitial,
341		})
342		hasCommonInitial = false
343		w = i
344	}
345}
346
// keywords lists Go's reserved words, laid out and ordered like the keyword
// table of the Go specification, followed by the blank identifier "_".
var keywords = []string{
	"break", "default", "func", "interface", "select",
	"case", "defer", "go", "map", "struct",
	"chan", "else", "goto", "package", "switch",
	"const", "fallthrough", "if", "range", "type",
	"continue", "for", "import", "return", "var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to
// resolver functions: a name found in keywords gets "Arg" appended, and any
// other name is returned unchanged.
func sanitizeKeywords(name string) string {
	for i := range keywords {
		if keywords[i] == name {
			return name + "Arg"
		}
	}
	return name
}
385
// commonInitialisms is the set of initialisms that ToGo and ToGoPrivate
// upper-case as a whole (for example "ID" rather than "Id").
// Only add entries that are highly unlikely to be non-initialisms:
// "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = func() map[string]bool {
	words := []string{
		"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID",
		"HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS",
		"RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP",
		"TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL",
		"UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
	}
	set := make(map[string]bool, len(words))
	for _, w := range words {
		set[w] = true
	}
	return set
}()
429
// rawQuoteReplacer splices the characters a Go raw string literal cannot
// carry into the literal as interpreted-string fragments: a backtick would
// terminate the literal, and carriage returns are silently discarded from
// raw string values by the Go compiler.
var rawQuoteReplacer = strings.NewReplacer(
	"`", "`+\"`\"+`",
	"\r", "`+\"\\r\"+`",
)

// rawQuote returns a Go expression, built mostly from raw string literals,
// that evaluates to exactly s.
func rawQuote(s string) string {
	return "`" + rawQuoteReplacer.Replace(s) + "`"
}
433
// notNil reports whether data (a struct or a pointer to one) has a field
// named field holding a non-nil value. Fields whose kind cannot be nil
// (strings, numbers, structs, arrays, ...) always count as non-nil. It
// returns false when data is not a struct, is a nil pointer, or has no such
// field.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	// reflect.Value.IsNil panics for kinds that cannot hold nil, so only ask
	// it for the nillable ones.
	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}
447
// Dump renders val as Go source for an equivalent literal value.
//
// Supported types are int, int64, float64, string, bool, nil,
// []interface{} and map[string]interface{}; slices and maps are rendered
// recursively, and map keys are emitted in sorted order so the output is
// deterministic. Any other type panics.
//
// A float64 is written in the shortest decimal form that round-trips, and
// always keeps a decimal point. This keeps the generated constant an untyped
// float, so it is still a float64 once stored in an interface{}.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		// "%f" would round to six decimals (0.1234567 → 0.123457, 1e-7 → 0.000000).
		s := strconv.FormatFloat(val, 'f', -1, 64)
		// Integral values need ".0" to stay floats. Non-finite values
		// ("+Inf", "-Inf", "NaN") have no literal form and are passed through.
		if !strings.ContainsAny(s, ".IN") {
			s += ".0"
		}
		return s
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		var parts []string
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		buf := bytes.Buffer{}
		buf.WriteString("map[string]interface{}{")
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		for _, key := range keys {
			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(val[key]))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}
491
// prefixLines prepends prefix to every "\n"-separated line of s. An empty s
// yields prefix alone, and a trailing newline produces a final line that
// consists only of prefix.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	for i := range lines {
		lines[i] = prefix + lines[i]
	}
	return strings.Join(lines, "\n")
}
495
// resolveName turns a template name into a file path.
//
// A name starting with "." has that dot removed and is resolved relative to
// the directory of a calling source file: skip=0 selects resolveName's direct
// caller, and each increment moves one frame further up the stack. Any other
// name, including the empty name, is resolved relative to the directory of
// this source file.
func resolveName(name string, skip int) string {
	// HasPrefix rather than name[0], which panics on an empty name.
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}
507
508func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
509	t := template.New("").Funcs(Funcs())
510
511	b, err := ioutil.ReadFile(filename)
512	if err != nil {
513		return nil, err
514	}
515
516	t, err = t.New(filepath.Base(filename)).Parse(string(b))
517	if err != nil {
518		panic(err)
519	}
520
521	buf := &bytes.Buffer{}
522	return buf, t.Execute(buf, tpldata)
523}
524
525func write(filename string, b []byte) error {
526	err := os.MkdirAll(filepath.Dir(filename), 0755)
527	if err != nil {
528		return errors.Wrap(err, "failed to create directory")
529	}
530
531	formatted, err := imports.Prune(filename, b)
532	if err != nil {
533		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
534		formatted = b
535	}
536
537	err = ioutil.WriteFile(filename, formatted, 0644)
538	if err != nil {
539		return errors.Wrapf(err, "failed to write %s", filename)
540	}
541
542	return nil
543}