codegen.go

  1package codegen
  2
  3import (
  4	"bytes"
  5	"fmt"
  6	"io/ioutil"
  7	"os"
  8	"path/filepath"
  9	"regexp"
 10	"syscall"
 11
 12	"github.com/pkg/errors"
 13	"github.com/vektah/gqlgen/codegen/templates"
 14	"github.com/vektah/gqlgen/neelance/schema"
 15	"golang.org/x/tools/imports"
 16)
 17
 18func Generate(cfg Config) error {
 19	if err := cfg.normalize(); err != nil {
 20		return err
 21	}
 22
 23	_ = syscall.Unlink(cfg.Exec.Filename)
 24	_ = syscall.Unlink(cfg.Model.Filename)
 25
 26	modelsBuild, err := cfg.models()
 27	if err != nil {
 28		return errors.Wrap(err, "model plan failed")
 29	}
 30	if len(modelsBuild.Models) > 0 || len(modelsBuild.Enums) > 0 {
 31		var buf *bytes.Buffer
 32		buf, err = templates.Run("models.gotpl", modelsBuild)
 33		if err != nil {
 34			return errors.Wrap(err, "model generation failed")
 35		}
 36
 37		if err = write(cfg.Model.Filename, buf.Bytes()); err != nil {
 38			return err
 39		}
 40		for _, model := range modelsBuild.Models {
 41			modelCfg := cfg.Models[model.GQLType]
 42			modelCfg.Model = cfg.Model.ImportPath() + "." + model.GoType
 43			cfg.Models[model.GQLType] = modelCfg
 44		}
 45
 46		for _, enum := range modelsBuild.Enums {
 47			modelCfg := cfg.Models[enum.GQLType]
 48			modelCfg.Model = cfg.Model.ImportPath() + "." + enum.GoType
 49			cfg.Models[enum.GQLType] = modelCfg
 50		}
 51	}
 52
 53	build, err := cfg.bind()
 54	if err != nil {
 55		return errors.Wrap(err, "exec plan failed")
 56	}
 57
 58	var buf *bytes.Buffer
 59	buf, err = templates.Run("generated.gotpl", build)
 60	if err != nil {
 61		return errors.Wrap(err, "exec codegen failed")
 62	}
 63
 64	if err = write(cfg.Exec.Filename, buf.Bytes()); err != nil {
 65		return err
 66	}
 67
 68	if err = cfg.validate(); err != nil {
 69		return errors.Wrap(err, "validation failed")
 70	}
 71
 72	return nil
 73}
 74
 75func (cfg *Config) normalize() error {
 76	if err := cfg.Model.normalize(); err != nil {
 77		return errors.Wrap(err, "model")
 78	}
 79
 80	if err := cfg.Exec.normalize(); err != nil {
 81		return errors.Wrap(err, "exec")
 82	}
 83
 84	builtins := TypeMap{
 85		"__Directive":  {Model: "github.com/vektah/gqlgen/neelance/introspection.Directive"},
 86		"__Type":       {Model: "github.com/vektah/gqlgen/neelance/introspection.Type"},
 87		"__Field":      {Model: "github.com/vektah/gqlgen/neelance/introspection.Field"},
 88		"__EnumValue":  {Model: "github.com/vektah/gqlgen/neelance/introspection.EnumValue"},
 89		"__InputValue": {Model: "github.com/vektah/gqlgen/neelance/introspection.InputValue"},
 90		"__Schema":     {Model: "github.com/vektah/gqlgen/neelance/introspection.Schema"},
 91		"Int":          {Model: "github.com/vektah/gqlgen/graphql.Int"},
 92		"Float":        {Model: "github.com/vektah/gqlgen/graphql.Float"},
 93		"String":       {Model: "github.com/vektah/gqlgen/graphql.String"},
 94		"Boolean":      {Model: "github.com/vektah/gqlgen/graphql.Boolean"},
 95		"ID":           {Model: "github.com/vektah/gqlgen/graphql.ID"},
 96		"Time":         {Model: "github.com/vektah/gqlgen/graphql.Time"},
 97		"Map":          {Model: "github.com/vektah/gqlgen/graphql.Map"},
 98	}
 99
100	if cfg.Models == nil {
101		cfg.Models = TypeMap{}
102	}
103	for typeName, entry := range builtins {
104		if !cfg.Models.Exists(typeName) {
105			cfg.Models[typeName] = entry
106		}
107	}
108
109	cfg.schema = schema.New()
110	return cfg.schema.Parse(cfg.SchemaStr)
111}
112
// invalidPackageNameChar matches any single character outside the \w class.
var invalidPackageNameChar = regexp.MustCompile(`[^\w]`)

// sanitizePackageName takes the last element of pkg (per filepath.Base) and
// replaces every character matched by invalidPackageNameChar with "_".
func sanitizePackageName(pkg string) string {
	base := filepath.Base(pkg)
	// The replacement contains no "$", so no template expansion takes place.
	return invalidPackageNameChar.ReplaceAllString(base, "_")
}
118
// abs returns the absolute form of path with forward slashes as separators.
// It panics if filepath.Abs reports an error.
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err == nil {
		return filepath.ToSlash(resolved)
	}
	panic(err)
}
126
127func gofmt(filename string, b []byte) ([]byte, error) {
128	out, err := imports.Process(filename, b, nil)
129	if err != nil {
130		return b, errors.Wrap(err, "unable to gofmt")
131	}
132	return out, nil
133}
134
135func write(filename string, b []byte) error {
136	err := os.MkdirAll(filepath.Dir(filename), 0755)
137	if err != nil {
138		return errors.Wrap(err, "failed to create directory")
139	}
140
141	formatted, err := gofmt(filename, b)
142	if err != nil {
143		fmt.Fprintf(os.Stderr, "gofmt failed: %s\n", err.Error())
144		formatted = b
145	}
146
147	err = ioutil.WriteFile(filename, formatted, 0644)
148	if err != nil {
149		return errors.Wrapf(err, "failed to write %s", filename)
150	}
151
152	return nil
153}