models.go

  1package modelgen
  2
import (
	"fmt"
	"go/types"
	"sort"

	"github.com/99designs/gqlgen/codegen/config"
	"github.com/99designs/gqlgen/codegen/templates"
	"github.com/99designs/gqlgen/internal/code"
	"github.com/99designs/gqlgen/plugin"
	"github.com/vektah/gqlparser/ast"
)
 13
 14type ModelBuild struct {
 15	PackageName string
 16	Interfaces  []*Interface
 17	Models      []*Object
 18	Enums       []*Enum
 19	Scalars     []string
 20}
 21
 22type Interface struct {
 23	Description string
 24	Name        string
 25}
 26
 27type Object struct {
 28	Description string
 29	Name        string
 30	Fields      []*Field
 31	Implements  []string
 32}
 33
 34type Field struct {
 35	Description string
 36	Name        string
 37	Type        types.Type
 38	Tag         string
 39}
 40
 41type Enum struct {
 42	Description string
 43	Name        string
 44	Values      []*EnumValue
 45}
 46
 47type EnumValue struct {
 48	Description string
 49	Name        string
 50}
 51
 52func New() plugin.Plugin {
 53	return &Plugin{}
 54}
 55
 56type Plugin struct{}
 57
 58var _ plugin.ConfigMutator = &Plugin{}
 59
 60func (m *Plugin) Name() string {
 61	return "modelgen"
 62}
 63
 64func (m *Plugin) MutateConfig(cfg *config.Config) error {
 65	if err := cfg.Check(); err != nil {
 66		return err
 67	}
 68
 69	schema, _, err := cfg.LoadSchema()
 70	if err != nil {
 71		return err
 72	}
 73
 74	cfg.InjectBuiltins(schema)
 75
 76	binder, err := cfg.NewBinder(schema)
 77	if err != nil {
 78		return err
 79	}
 80
 81	b := &ModelBuild{
 82		PackageName: cfg.Model.Package,
 83	}
 84
 85	for _, schemaType := range schema.Types {
 86		if cfg.Models.UserDefined(schemaType.Name) {
 87			continue
 88		}
 89
 90		switch schemaType.Kind {
 91		case ast.Interface, ast.Union:
 92			it := &Interface{
 93				Description: schemaType.Description,
 94				Name:        schemaType.Name,
 95			}
 96
 97			b.Interfaces = append(b.Interfaces, it)
 98		case ast.Object, ast.InputObject:
 99			if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription {
100				continue
101			}
102			it := &Object{
103				Description: schemaType.Description,
104				Name:        schemaType.Name,
105			}
106
107			for _, implementor := range schema.GetImplements(schemaType) {
108				it.Implements = append(it.Implements, implementor.Name)
109			}
110
111			for _, field := range schemaType.Fields {
112				var typ types.Type
113
114				if cfg.Models.UserDefined(field.Type.Name()) {
115					pkg, typeName := code.PkgAndType(cfg.Models[field.Type.Name()].Model[0])
116					typ, err = binder.FindType(pkg, typeName)
117					if err != nil {
118						return err
119					}
120				} else {
121					fieldDef := schema.Types[field.Type.Name()]
122					switch fieldDef.Kind {
123					case ast.Scalar:
124						// no user defined model, referencing a default scalar
125						typ = types.NewNamed(
126							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
127							nil,
128							nil,
129						)
130					case ast.Interface, ast.Union:
131						// no user defined model, referencing a generated interface type
132						typ = types.NewNamed(
133							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
134							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
135							nil,
136						)
137					default:
138						// no user defined model, must reference another generated model
139						typ = types.NewNamed(
140							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
141							nil,
142							nil,
143						)
144					}
145				}
146
147				name := field.Name
148				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
149					name = nameOveride
150				}
151
152				it.Fields = append(it.Fields, &Field{
153					Name:        name,
154					Type:        binder.CopyModifiersFromAst(field.Type, typ),
155					Description: field.Description,
156					Tag:         `json:"` + field.Name + `"`,
157				})
158			}
159
160			b.Models = append(b.Models, it)
161		case ast.Enum:
162			it := &Enum{
163				Name:        schemaType.Name,
164				Description: schemaType.Description,
165			}
166
167			for _, v := range schemaType.EnumValues {
168				it.Values = append(it.Values, &EnumValue{
169					Name:        v.Name,
170					Description: v.Description,
171				})
172			}
173
174			b.Enums = append(b.Enums, it)
175		case ast.Scalar:
176			b.Scalars = append(b.Scalars, schemaType.Name)
177		}
178	}
179
180	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
181	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
182	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
183
184	for _, it := range b.Enums {
185		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
186	}
187	for _, it := range b.Models {
188		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
189	}
190	for _, it := range b.Interfaces {
191		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
192	}
193	for _, it := range b.Scalars {
194		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
195	}
196
197	if len(b.Models) == 0 && len(b.Enums) == 0 {
198		return nil
199	}
200
201	return templates.Render(templates.Options{
202		PackageName:     cfg.Model.Package,
203		Filename:        cfg.Model.Filename,
204		Data:            b,
205		GeneratedHeader: true,
206	})
207}