templates.go

  1package templates
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"go/types"
  7	"io/ioutil"
  8	"os"
  9	"path/filepath"
 10	"reflect"
 11	"runtime"
 12	"sort"
 13	"strconv"
 14	"strings"
 15	"text/template"
 16	"unicode"
 17
 18	"github.com/99designs/gqlgen/internal/imports"
 19	"github.com/pkg/errors"
 20)
 21
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because sub-templates are currently rendered from template functions
// (see the "render" entry in Funcs), which cannot receive extra state. Let's aim to remove this eventually.
//
// Render assigns it when it starts and panics if it is already non-nil, so only one
// render may be in progress at a time.
var CurrentImports *Imports
 25
// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// the plugin processor will look for .gotpl files
	// in the same directory of where you wrote the plugin.
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	Filename string
	// RegionTags, when true, makes Render wrap the output of each template
	// in "// region" / "// endregion" comment markers naming that template.
	RegionTags bool
	// GeneratedHeader, when true, makes Render prepend the
	// "Code generated by github.com/99designs/gqlgen, DO NOT EDIT." header.
	GeneratedHeader bool
	// Data will be passed to the template execution.
	Data interface{}
	// Funcs holds extra template functions. Render merges them over the
	// default set returned by Funcs(), so an entry here replaces a default
	// function with the same name.
	Funcs template.FuncMap
}
 47
 48// Render renders a gql plugin template from the given Options. Render is an
 49// abstraction of the text/template package that makes it easier to write gqlgen
 50// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
 51// files inside the directory where you wrote the plugin.
 52func Render(cfg Options) error {
 53	if CurrentImports != nil {
 54		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
 55	}
 56	CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
 57
 58	// load path relative to calling source file
 59	_, callerFile, _, _ := runtime.Caller(1)
 60	rootDir := filepath.Dir(callerFile)
 61
 62	funcs := Funcs()
 63	for n, f := range cfg.Funcs {
 64		funcs[n] = f
 65	}
 66	t := template.New("").Funcs(funcs)
 67
 68	var roots []string
 69	if cfg.Template != "" {
 70		var err error
 71		t, err = t.New("template.gotpl").Parse(cfg.Template)
 72		if err != nil {
 73			return errors.Wrap(err, "error with provided template")
 74		}
 75		roots = append(roots, "template.gotpl")
 76	} else {
 77		// load all the templates in the directory
 78		err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
 79			if err != nil {
 80				return err
 81			}
 82			name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
 83			if !strings.HasSuffix(info.Name(), ".gotpl") {
 84				return nil
 85			}
 86			b, err := ioutil.ReadFile(path)
 87			if err != nil {
 88				return err
 89			}
 90
 91			t, err = t.New(name).Parse(string(b))
 92			if err != nil {
 93				return errors.Wrap(err, cfg.Filename)
 94			}
 95
 96			roots = append(roots, name)
 97
 98			return nil
 99		})
100		if err != nil {
101			return errors.Wrap(err, "locating templates")
102		}
103	}
104
105	// then execute all the important looking ones in order, adding them to the same file
106	sort.Slice(roots, func(i, j int) bool {
107		// important files go first
108		if strings.HasSuffix(roots[i], "!.gotpl") {
109			return true
110		}
111		if strings.HasSuffix(roots[j], "!.gotpl") {
112			return false
113		}
114		return roots[i] < roots[j]
115	})
116	var buf bytes.Buffer
117	for _, root := range roots {
118		if cfg.RegionTags {
119			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
120		}
121		err := t.Lookup(root).Execute(&buf, cfg.Data)
122		if err != nil {
123			return errors.Wrap(err, root)
124		}
125		if cfg.RegionTags {
126			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
127		}
128	}
129
130	var result bytes.Buffer
131	if cfg.GeneratedHeader {
132		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
133	}
134	result.WriteString("package ")
135	result.WriteString(cfg.PackageName)
136	result.WriteString("\n\n")
137	result.WriteString("import (\n")
138	result.WriteString(CurrentImports.String())
139	result.WriteString(")\n")
140	_, err := buf.WriteTo(&result)
141	if err != nil {
142		return err
143	}
144	CurrentImports = nil
145
146	return write(cfg.Filename, result.Bytes())
147}
148
// center surrounds s with width-len(s) copies of pad, split as evenly as
// possible with the extra copy on the right. len(s) is measured in bytes.
// When fewer than two copies of pad would be added, s is returned unchanged.
func center(width int, pad string, s string) string {
	total := width - len(s)
	if total < 2 {
		return s
	}
	left := total / 2
	right := total - left
	return strings.Repeat(pad, left) + s + strings.Repeat(pad, right)
}
157
158func Funcs() template.FuncMap {
159	return template.FuncMap{
160		"ucFirst":       ucFirst,
161		"lcFirst":       lcFirst,
162		"quote":         strconv.Quote,
163		"rawQuote":      rawQuote,
164		"dump":          Dump,
165		"ref":           ref,
166		"ts":            TypeIdentifier,
167		"call":          Call,
168		"prefixLines":   prefixLines,
169		"notNil":        notNil,
170		"reserveImport": CurrentImports.Reserve,
171		"lookupImport":  CurrentImports.Lookup,
172		"go":            ToGo,
173		"goPrivate":     ToGoPrivate,
174		"add": func(a, b int) int {
175			return a + b
176		},
177		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
178			return render(resolveName(filename, 0), tpldata)
179		},
180	}
181}
182
// ucFirst returns s with its first rune converted to upper case. The string is
// handled as runes, so a multi-byte first character is converted correctly.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) > 0 {
		runes[0] = unicode.ToUpper(runes[0])
	}
	return string(runes)
}
191
// lcFirst returns s with its first rune converted to lower case. The string is
// handled as runes, so a multi-byte first character is converted correctly.
func lcFirst(s string) string {
	runes := []rune(s)
	if len(runes) > 0 {
		runes[0] = unicode.ToLower(runes[0])
	}
	return string(runes)
}
201
// isDelimiter reports whether c separates words in a name: a hyphen, an
// underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	}
	return unicode.IsSpace(c)
}
205
206func ref(p types.Type) string {
207	return CurrentImports.LookupType(p)
208}
209
// pkgReplacer maps the characters '/', '.' and '-' to look-alike Unicode
// letters, so a package path can be embedded in an identifier.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
)

// TypeIdentifier encodes t as a string meant to be embedded in generated
// identifiers. The encoding is:
//   - pointer → "ᚖ" followed by the element
//   - slice → "ᚕ" followed by the element
//   - named type → the package path escaped via pkgReplacer, "ᚐ", then the type
//     name (types with no package, such as the predeclared `error`, are
//     encoded by name alone)
//   - basic type → its name
//   - any map → "map", any interface → "interface" (element and method
//     details are not encoded)
//
// Any other type kind panics.
func TypeIdentifier(t types.Type) string {
	var res strings.Builder
	for {
		switch it := t.(type) {
		case *types.Pointer:
			res.WriteString("ᚖ")
			t = it.Elem()
		case *types.Slice:
			res.WriteString("ᚕ")
			t = it.Elem()
		case *types.Named:
			obj := it.Obj()
			// Universe types such as `error` have no package. Dereferencing it
			// unconditionally used to panic with a nil pointer.
			if pkg := obj.Pkg(); pkg != nil {
				res.WriteString(pkgReplacer.Replace(pkg.Path()))
				res.WriteString("ᚐ")
			}
			res.WriteString(obj.Name())
			return res.String()
		case *types.Basic:
			res.WriteString(it.Name())
			return res.String()
		case *types.Map:
			res.WriteString("map")
			return res.String()
		case *types.Interface:
			res.WriteString("interface")
			return res.String()
		default:
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}
246
247func Call(p *types.Func) string {
248	pkg := CurrentImports.Lookup(p.Pkg().Path())
249
250	if pkg != "" {
251		pkg += "."
252	}
253
254	if p.Type() != nil {
255		// make sure the returned type is listed in our imports.
256		ref(p.Type().(*types.Signature).Results().At(0).Type())
257	}
258
259	return pkg + p.Name()
260}
261
262func ToGo(name string) string {
263	runes := make([]rune, 0, len(name))
264
265	wordWalker(name, func(info *wordInfo) {
266		word := info.Word
267		if info.MatchCommonInitial {
268			word = strings.ToUpper(word)
269		} else if !info.HasCommonInitial {
270			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
271				// FOO or foo → Foo
272				// FOo → FOo
273				word = ucFirst(strings.ToLower(word))
274			}
275		}
276		runes = append(runes, []rune(word)...)
277	})
278
279	return string(runes)
280}
281
282func ToGoPrivate(name string) string {
283	runes := make([]rune, 0, len(name))
284
285	first := true
286	wordWalker(name, func(info *wordInfo) {
287		word := info.Word
288		if first {
289			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
290				// ID → id, CAMEL → camel
291				word = strings.ToLower(info.Word)
292			} else {
293				// ITicket → iTicket
294				word = lcFirst(info.Word)
295			}
296			first = false
297		} else if info.MatchCommonInitial {
298			word = strings.ToUpper(word)
299		} else if !info.HasCommonInitial {
300			word = ucFirst(strings.ToLower(word))
301		}
302		runes = append(runes, []rune(word)...)
303	})
304
305	return sanitizeKeywords(string(runes))
306}
307
// wordInfo describes one word reported by wordWalker.
type wordInfo struct {
	// Word is the text of the word.
	Word string
	// MatchCommonInitial is set when the upper-cased word is itself a common
	// initialism. HasCommonInitial is set when the word was recognized as
	// containing a common initialism; wordWalker also sets it whenever
	// MatchCommonInitial is set.
	MatchCommonInitial, HasCommonInitial bool
}
313
// wordWalker splits str into words and calls f once per word, in order.
//
// A word ends:
//   - at the end of the input;
//   - before a run of delimiters (see isDelimiter), which is dropped;
//   - at a lower-case → non-lower-case transition (fooBar → foo, Bar);
//   - right after a common initialism that is followed by a non-lower-case
//     rune (IDFoo → ID, Foo; but URLs stays one word).
//
// When a delimiter run sits between two digits, one delimiter is kept and
// begins the next word (v1_2 → v1, _2).
//
// NOTE(review): the initialism split fires on the shortest matching prefix,
// so e.g. "UIDFoo" is split as "UI", "DFoo" rather than "UID", "Foo" — confirm
// this is acceptable for callers such as ToGoPrivate.
//
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(str)
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if isDelimiter(runes[i+1]) {
			// delimiter ('-', '_' or white space); shift the remainder forward
			// over the whole run of delimiters, removing it from runes in place
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
				n++
			}

			// Keep one delimiter (the last of the run) if the run is between
			// two digits, so that e.g. 1_2 does not collapse into 12
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// Fall through and emit the word now:
			// split IDFoo → ID, Foo
			// but URLs → URLs (handled by the branch below)
		} else if !eow {
			// Still inside a word; remember that it started with an initialism.
			if commonInitialisms[word] {
				hasCommonInitial = true
			}
			continue
		}

		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true
		}

		f(&wordInfo{
			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		})
		hasCommonInitial = false
		w = i
	}
}
373
// keywords lists every Go keyword plus the blank identifier "_"; none of them
// can serve as a usable argument name.
var keywords = []string{
	"_",
	"break", "case", "chan", "const", "continue",
	"default", "defer", "else", "fallthrough", "for",
	"func", "go", "goto", "if", "import",
	"interface", "map", "package", "range", "return",
	"select", "struct", "switch", "type", "var",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions.
// A name that exactly equals an entry in keywords gets "Arg" appended (type → typeArg);
// any other name is returned unchanged.
func sanitizeKeywords(name string) string {
	for i := range keywords {
		if keywords[i] != name {
			continue
		}
		return name + "Arg"
	}
	return name
}
412
// commonInitialisms is the set of initialisms that ToGo/ToGoPrivate keep
// fully upper-case. Only add entries that are highly unlikely to be
// non-initialisms: "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = map[string]bool{
	"ACL": true, "API": true, "ASCII": true, "CPU": true, "CSS": true,
	"DNS": true, "EOF": true, "GUID": true, "HTML": true, "HTTP": true,
	"HTTPS": true, "ID": true, "IP": true, "JSON": true, "LHS": true,
	"QPS": true, "RAM": true, "RHS": true, "RPC": true, "SLA": true,
	"SMTP": true, "SQL": true, "SSH": true, "TCP": true, "TLS": true,
	"TTL": true, "UDP": true, "UI": true, "UID": true, "UUID": true,
	"URI": true, "URL": true, "UTF8": true, "VM": true, "XML": true,
	"XMPP": true, "XSRF": true, "XSS": true,
}
456
// rawQuote returns s as a Go raw (backtick) string literal. A raw string
// cannot contain a backtick, so each one is spliced in by closing the raw
// string, concatenating "`", and reopening it.
func rawQuote(s string) string {
	const escapedBacktick = "`+\"`\"+`"
	return "`" + strings.Join(strings.Split(s, "`"), escapedBacktick) + "`"
}
460
// notNil reports whether data (a struct, or a pointer to one) has a field named
// field whose value is not nil.
//
// It returns false when data is not a struct, is a nil pointer, or has no such
// field. A field whose kind can never be nil (int, string, struct, ...) is
// reported as not nil. Previously reflect.Value.IsNil panicked on such fields.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}
	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		// Non-nillable kinds always hold a value.
		return true
	}
}
474
// Dump renders val as a Go source literal so that it can be embedded in
// generated code. Supported values are int, int64, float64, string, bool, nil,
// []interface{} and map[string]interface{} (recursively). Any other type
// panics. Map entries are emitted in sorted key order so the output is
// deterministic.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		return dumpFloat(val)
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		parts := make([]string, 0, len(val))
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		var buf bytes.Buffer
		buf.WriteString("map[string]interface{}{")
		for _, key := range keys {
			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(val[key]))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}

// dumpFloat formats f as a Go floating-point literal. It keeps the historical
// "%f" form (1.5 → "1.500000") whenever that form parses back to exactly f.
// Otherwise it falls back to the shortest exact representation, so values such
// as 1e-7 are no longer truncated to "0.000000". The result always remains a
// float literal (it contains '.' or an exponent); NaN and ±Inf have no Go
// literal form and are emitted as strconv prints them.
func dumpFloat(f float64) string {
	s := fmt.Sprintf("%f", f)
	if parsed, err := strconv.ParseFloat(s, 64); err == nil && parsed == f {
		return s
	}
	s = strconv.FormatFloat(f, 'g', -1, 64)
	// An integral result would be typed as int inside interface{} literals.
	// The 'n'/'N' check skips NaN and Inf.
	if !strings.ContainsAny(s, ".eEnN") {
		s += ".0"
	}
	return s
}
518
// prefixLines prepends prefix to every line of s. That includes the first line
// and the empty line that follows a trailing newline.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	return prefix + strings.Join(lines, "\n"+prefix)
}
522
// resolveName turns a template name into a file path.
//
// A name starting with '.' is resolved relative to the directory of the source
// file skip+1 frames up the call stack (skip=0 means the direct caller of
// resolveName). Only that single leading '.' is removed, so "./x" and ".x"
// both resolve to <dir>/x. Every other name is resolved relative to the
// directory containing this source file. An empty name resolves to that
// directory; previously it panicked on name[0].
func resolveName(name string, skip int) string {
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, thisFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(thisFile), name)
}
534
535func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
536	t := template.New("").Funcs(Funcs())
537
538	b, err := ioutil.ReadFile(filename)
539	if err != nil {
540		return nil, err
541	}
542
543	t, err = t.New(filepath.Base(filename)).Parse(string(b))
544	if err != nil {
545		panic(err)
546	}
547
548	buf := &bytes.Buffer{}
549	return buf, t.Execute(buf, tpldata)
550}
551
552func write(filename string, b []byte) error {
553	err := os.MkdirAll(filepath.Dir(filename), 0755)
554	if err != nil {
555		return errors.Wrap(err, "failed to create directory")
556	}
557
558	formatted, err := imports.Prune(filename, b)
559	if err != nil {
560		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
561		formatted = b
562	}
563
564	err = ioutil.WriteFile(filename, formatted, 0644)
565	if err != nil {
566		return errors.Wrapf(err, "failed to write %s", filename)
567	}
568
569	return nil
570}