templates.go

  1package templates
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"go/types"
  7	"io/ioutil"
  8	"os"
  9	"path/filepath"
 10	"reflect"
 11	"runtime"
 12	"sort"
 13	"strconv"
 14	"strings"
 15	"text/template"
 16	"unicode"
 17
 18	"github.com/99designs/gqlgen/internal/imports"
 19	"github.com/pkg/errors"
 20)
 21
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because subtemplates currently get called in functions. Let's aim to remove this eventually.
//
// Render assigns it at the start of a render and panics if it is already
// non-nil, so only one Render may be in progress at a time. The template helpers
// reserveImport, lookupImport, ref and call (see Funcs) all read it.
var CurrentImports *Imports
 25
// Options specifies the parameters for rendering a template with Render.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// Render walks the directory of the source file that called it
	// (including subdirectories) and loads every .gotpl file it finds.
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	// Its directory is also the destination directory used by CurrentImports.
	Filename string
	// RegionTags, when true, wraps the output of each root template in
	// "// region" / "// endregion" comment lines naming the template.
	RegionTags bool
	// GeneratedHeader, when true, starts the output with the
	// "// Code generated by github.com/99designs/gqlgen, DO NOT EDIT." line.
	GeneratedHeader bool
	// Data will be passed to the template execution.
	Data interface{}
	// Funcs holds extra template functions. They are added on top of the
	// defaults returned by Funcs() and replace any default with the same name.
	Funcs template.FuncMap
}
 47
 48// Render renders a gql plugin template from the given Options. Render is an
 49// abstraction of the text/template package that makes it easier to write gqlgen
 50// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
 51// files inside the directory where you wrote the plugin.
 52func Render(cfg Options) error {
 53	if CurrentImports != nil {
 54		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
 55	}
 56	CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
 57
 58	// load path relative to calling source file
 59	_, callerFile, _, _ := runtime.Caller(1)
 60	rootDir := filepath.Dir(callerFile)
 61
 62	funcs := Funcs()
 63	for n, f := range cfg.Funcs {
 64		funcs[n] = f
 65	}
 66	t := template.New("").Funcs(funcs)
 67
 68	var roots []string
 69	if cfg.Template != "" {
 70		var err error
 71		t, err = t.New("template.gotpl").Parse(cfg.Template)
 72		if err != nil {
 73			return errors.Wrap(err, "error with provided template")
 74		}
 75		roots = append(roots, "template.gotpl")
 76	} else {
 77		// load all the templates in the directory
 78		err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
 79			if err != nil {
 80				return err
 81			}
 82			name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
 83			if !strings.HasSuffix(info.Name(), ".gotpl") {
 84				return nil
 85			}
 86			b, err := ioutil.ReadFile(path)
 87			if err != nil {
 88				return err
 89			}
 90
 91			t, err = t.New(name).Parse(string(b))
 92			if err != nil {
 93				return errors.Wrap(err, cfg.Filename)
 94			}
 95
 96			roots = append(roots, name)
 97
 98			return nil
 99		})
100		if err != nil {
101			return errors.Wrap(err, "locating templates")
102		}
103	}
104
105	// then execute all the important looking ones in order, adding them to the same file
106	sort.Slice(roots, func(i, j int) bool {
107		// important files go first
108		if strings.HasSuffix(roots[i], "!.gotpl") {
109			return true
110		}
111		if strings.HasSuffix(roots[j], "!.gotpl") {
112			return false
113		}
114		return roots[i] < roots[j]
115	})
116	var buf bytes.Buffer
117	for _, root := range roots {
118		if cfg.RegionTags {
119			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
120		}
121		err := t.Lookup(root).Execute(&buf, cfg.Data)
122		if err != nil {
123			return errors.Wrap(err, root)
124		}
125		if cfg.RegionTags {
126			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
127		}
128	}
129
130	var result bytes.Buffer
131	if cfg.GeneratedHeader {
132		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
133	}
134	result.WriteString("package ")
135	result.WriteString(cfg.PackageName)
136	result.WriteString("\n\n")
137	result.WriteString("import (\n")
138	result.WriteString(CurrentImports.String())
139	result.WriteString(")\n")
140	_, err := buf.WriteTo(&result)
141	if err != nil {
142		return err
143	}
144	CurrentImports = nil
145
146	return write(cfg.Filename, result.Bytes())
147}
148
func center(width int, pad string, s string) string {
	// center surrounds s with copies of pad so the result is width long when
	// pad is a single byte. Any odd leftover goes on the right. If fewer than
	// two columns are free, s is returned unchanged.
	free := width - len(s)
	if free < 2 {
		return s
	}
	left := free / 2
	right := free - left
	return strings.Repeat(pad, left) + s + strings.Repeat(pad, right)
}
157
158func Funcs() template.FuncMap {
159	return template.FuncMap{
160		"ucFirst":       ucFirst,
161		"lcFirst":       lcFirst,
162		"quote":         strconv.Quote,
163		"rawQuote":      rawQuote,
164		"dump":          Dump,
165		"ref":           ref,
166		"ts":            TypeIdentifier,
167		"call":          Call,
168		"prefixLines":   prefixLines,
169		"notNil":        notNil,
170		"reserveImport": CurrentImports.Reserve,
171		"lookupImport":  CurrentImports.Lookup,
172		"go":            ToGo,
173		"goPrivate":     ToGoPrivate,
174		"add": func(a, b int) int {
175			return a + b
176		},
177		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
178			return render(resolveName(filename, 0), tpldata)
179		},
180	}
181}
182
// ucFirst returns s with its first rune converted to upper case.
// The empty string is returned unchanged.
func ucFirst(s string) string {
	chars := []rune(s)
	if len(chars) > 0 {
		chars[0] = unicode.ToUpper(chars[0])
	}
	return string(chars)
}
191
// lcFirst returns s with its first rune converted to lower case.
// The empty string is returned unchanged.
func lcFirst(s string) string {
	chars := []rune(s)
	if len(chars) > 0 {
		chars[0] = unicode.ToLower(chars[0])
	}
	return string(chars)
}
201
// isDelimiter reports whether c separates words for wordWalker:
// a hyphen, an underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	default:
		return unicode.IsSpace(c)
	}
}
205
206func ref(p types.Type) string {
207	return CurrentImports.LookupType(p)
208}
209
// pkgReplacer maps the characters of a package path that cannot appear in a Go
// identifier ("-", "." and "/") to non-ASCII runes. TypeIdentifier uses it.
var pkgReplacer = strings.NewReplacer(
	"-", "ᚑ",
	".", "ᚗ",
	"/", "ᚋ",
)
215
216func TypeIdentifier(t types.Type) string {
217	res := ""
218	for {
219		switch it := t.(type) {
220		case *types.Pointer:
221			t.Underlying()
222			res += "ᚖ"
223			t = it.Elem()
224		case *types.Slice:
225			res += "ᚕ"
226			t = it.Elem()
227		case *types.Named:
228			res += pkgReplacer.Replace(it.Obj().Pkg().Path())
229			res += "ᚐ"
230			res += it.Obj().Name()
231			return res
232		case *types.Basic:
233			res += it.Name()
234			return res
235		case *types.Map:
236			res += "map"
237			return res
238		case *types.Interface:
239			res += "interface"
240			return res
241		default:
242			panic(fmt.Errorf("unexpected type %T", it))
243		}
244	}
245}
246
247func Call(p *types.Func) string {
248	pkg := CurrentImports.Lookup(p.Pkg().Path())
249
250	if pkg != "" {
251		pkg += "."
252	}
253
254	if p.Type() != nil {
255		// make sure the returned type is listed in our imports.
256		ref(p.Type().(*types.Signature).Results().At(0).Type())
257	}
258
259	return pkg + p.Name()
260}
261
262func ToGo(name string) string {
263	runes := make([]rune, 0, len(name))
264
265	wordWalker(name, func(info *wordInfo) {
266		word := info.Word
267		if info.MatchCommonInitial {
268			word = strings.ToUpper(word)
269		} else if !info.HasCommonInitial {
270			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
271				// FOO or foo → Foo
272				// FOo → FOo
273				word = ucFirst(strings.ToLower(word))
274			}
275		}
276		runes = append(runes, []rune(word)...)
277	})
278
279	return string(runes)
280}
281
282func ToGoPrivate(name string) string {
283	runes := make([]rune, 0, len(name))
284
285	first := true
286	wordWalker(name, func(info *wordInfo) {
287		word := info.Word
288		switch {
289		case first:
290			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
291				// ID → id, CAMEL → camel
292				word = strings.ToLower(info.Word)
293			} else {
294				// ITicket → iTicket
295				word = lcFirst(info.Word)
296			}
297			first = false
298		case info.MatchCommonInitial:
299			word = strings.ToUpper(word)
300		case !info.HasCommonInitial:
301			word = ucFirst(strings.ToLower(word))
302		}
303		runes = append(runes, []rune(word)...)
304	})
305
306	return sanitizeKeywords(string(runes))
307}
308
// wordInfo describes one word emitted by wordWalker.
type wordInfo struct {
	// Word is the word's text as split out by wordWalker.
	Word string
	// MatchCommonInitial is set when the upper-cased word is itself a key of
	// commonInitialisms (for example "id" or "URL").
	MatchCommonInitial bool
	// HasCommonInitial is set when an initialism was recognised while scanning
	// the word (for example the "URL" in "URLs"), or when MatchCommonInitial is set.
	HasCommonInitial bool
}
314
315// This function is based on the following code.
316// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
317func wordWalker(str string, f func(*wordInfo)) {
318	runes := []rune(str)
319	w, i := 0, 0 // index of start of word, scan
320	hasCommonInitial := false
321	for i+1 <= len(runes) {
322		eow := false // whether we hit the end of a word
323		switch {
324		case i+1 == len(runes):
325			eow = true
326		case isDelimiter(runes[i+1]):
327			// underscore; shift the remainder forward over any run of underscores
328			eow = true
329			n := 1
330			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
331				n++
332			}
333
334			// Leave at most one underscore if the underscore is between two digits
335			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
336				n--
337			}
338
339			copy(runes[i+1:], runes[i+n+1:])
340			runes = runes[:len(runes)-n]
341		case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
342			// lower->non-lower
343			eow = true
344		}
345		i++
346
347		// [w,i) is a word.
348		word := string(runes[w:i])
349		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
350			// through
351			// split IDFoo → ID, Foo
352			// but URLs → URLs
353		} else if !eow {
354			if commonInitialisms[word] {
355				hasCommonInitial = true
356			}
357			continue
358		}
359
360		matchCommonInitial := false
361		if commonInitialisms[strings.ToUpper(word)] {
362			hasCommonInitial = true
363			matchCommonInitial = true
364		}
365
366		f(&wordInfo{
367			Word:               word,
368			MatchCommonInitial: matchCommonInitial,
369			HasCommonInitial:   hasCommonInitial,
370		})
371		hasCommonInitial = false
372		w = i
373	}
374}
375
// keywords lists Go's 25 reserved keywords plus the blank identifier "_".
// sanitizeKeywords appends "Arg" to any generated name equal to one of them.
var keywords = []string{
	// declarations
	"const", "func", "import", "package", "type", "var",
	// composite types
	"chan", "interface", "map", "struct",
	// control flow
	"break", "case", "continue", "default", "defer", "else", "fallthrough",
	"for", "go", "goto", "if", "range", "return", "select", "switch",
	// blank identifier
	"_",
}
404
405// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions
406func sanitizeKeywords(name string) string {
407	for _, k := range keywords {
408		if name == k {
409			return name + "Arg"
410		}
411	}
412	return name
413}
414
// commonInitialisms is a set of common initialisms that wordWalker recognises
// so that ToGo/ToGoPrivate can write them in upper case.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = func() map[string]bool {
	const list = `
		ACL API ASCII CPU CSS DNS EOF GUID HTML HTTP HTTPS ID IP JSON LHS QPS
		RAM RHS RPC SLA SMTP SQL SSH TCP TLS TTL UDP UI UID UUID URI URL UTF8
		VM XML XMPP XSRF XSS`
	set := make(map[string]bool)
	for _, word := range strings.Fields(list) {
		set[word] = true
	}
	return set
}()
458
// rawQuote returns s as a Go raw string literal. A backtick cannot appear inside
// a raw string, so each one is written as `+"`"+`: close the raw string, append
// an interpreted "`", then reopen it.
func rawQuote(s string) string {
	const escapedBacktick = "`+\"`\"+`"
	segments := strings.Split(s, "`")
	return "`" + strings.Join(segments, escapedBacktick) + "`"
}
462
// notNil reports whether data has a field named field whose value is not nil.
// data may be a struct or a pointer to a struct.
//
// It returns false when:
//   - data is not a struct or pointer to struct (including a nil pointer),
//   - the struct has no such field,
//   - the field holds a nil chan, func, interface, map, pointer or slice.
//
// Fields of kinds that can never be nil (strings, numbers, structs, ...) are
// reported as non-nil. reflect.Value.IsNil would panic on them.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}
476
// Dump renders val as a Go source literal. Supported values are:
//   - int, int64, float64, string and bool;
//   - nil;
//   - []interface{} and map[string]interface{} of supported values
//     (map keys are emitted in sorted order).
//
// Any other type panics.
//
// Floats are written with the shortest representation that round-trips, and
// always keep a decimal point (1 → "1.0") so the literal stays a float constant.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		// "%f" rounded to six decimals (1e-7 became "0.000000").
		s := strconv.FormatFloat(val, 'f', -1, 64)
		// Integral values (only sign and digits) need ".0" to remain floats.
		// NaN/Inf are left as-is, as before.
		if strings.Trim(s, "-0123456789") == "" {
			s += ".0"
		}
		return s
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		var parts []string
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		buf := bytes.Buffer{}
		buf.WriteString("map[string]interface{}{")
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		// Sort for deterministic output; map iteration order is random.
		sort.Strings(keys)

		for _, key := range keys {
			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(val[key]))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}
520
// prefixLines puts prefix in front of every line of s, including the first one
// and any empty lines.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	for i := range lines {
		lines[i] = prefix + lines[i]
	}
	return strings.Join(lines, "\n")
}
524
// resolveName turns a template name into a file path.
//
// Names starting with "." are resolved relative to the directory of a caller's
// source file: skip=0 means resolveName's direct caller, and each extra skip
// goes one frame further up the stack. All other names, including the empty
// name, are resolved relative to the directory of this source file.
func resolveName(name string, skip int) string {
	// HasPrefix instead of name[0] so an empty name does not panic.
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, thisFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(thisFile), name)
}
536
537func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
538	t := template.New("").Funcs(Funcs())
539
540	b, err := ioutil.ReadFile(filename)
541	if err != nil {
542		return nil, err
543	}
544
545	t, err = t.New(filepath.Base(filename)).Parse(string(b))
546	if err != nil {
547		panic(err)
548	}
549
550	buf := &bytes.Buffer{}
551	return buf, t.Execute(buf, tpldata)
552}
553
554func write(filename string, b []byte) error {
555	err := os.MkdirAll(filepath.Dir(filename), 0755)
556	if err != nil {
557		return errors.Wrap(err, "failed to create directory")
558	}
559
560	formatted, err := imports.Prune(filename, b)
561	if err != nil {
562		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
563		formatted = b
564	}
565
566	err = ioutil.WriteFile(filename, formatted, 0644)
567	if err != nil {
568		return errors.Wrapf(err, "failed to write %s", filename)
569	}
570
571	return nil
572}