models.go

  1package modelgen
  2
  3import (
  4	"fmt"
  5	"go/types"
  6	"sort"
  7
  8	"github.com/99designs/gqlgen/codegen/config"
  9	"github.com/99designs/gqlgen/codegen/templates"
 10	"github.com/99designs/gqlgen/internal/code"
 11	"github.com/99designs/gqlgen/plugin"
 12	"github.com/vektah/gqlparser/ast"
 13)
 14
 15type ModelBuild struct {
 16	PackageName string
 17	Interfaces  []*Interface
 18	Models      []*Object
 19	Enums       []*Enum
 20	Scalars     []string
 21}
 22
 23type Interface struct {
 24	Description string
 25	Name        string
 26}
 27
 28type Object struct {
 29	Description string
 30	Name        string
 31	Fields      []*Field
 32	Implements  []string
 33}
 34
 35type Field struct {
 36	Description string
 37	Name        string
 38	Type        types.Type
 39	Tag         string
 40}
 41
 42type Enum struct {
 43	Description string
 44	Name        string
 45	Values      []*EnumValue
 46}
 47
 48type EnumValue struct {
 49	Description string
 50	Name        string
 51}
 52
 53func New() plugin.Plugin {
 54	return &Plugin{}
 55}
 56
 57type Plugin struct{}
 58
 59var _ plugin.ConfigMutator = &Plugin{}
 60
 61func (m *Plugin) Name() string {
 62	return "modelgen"
 63}
 64
 65func (m *Plugin) MutateConfig(cfg *config.Config) error {
 66	if err := cfg.Check(); err != nil {
 67		return err
 68	}
 69
 70	schema, _, err := cfg.LoadSchema()
 71	if err != nil {
 72		return err
 73	}
 74
 75	err = cfg.Autobind(schema)
 76	if err != nil {
 77		return err
 78	}
 79
 80	cfg.InjectBuiltins(schema)
 81
 82	binder, err := cfg.NewBinder(schema)
 83	if err != nil {
 84		return err
 85	}
 86
 87	b := &ModelBuild{
 88		PackageName: cfg.Model.Package,
 89	}
 90
 91	for _, schemaType := range schema.Types {
 92		if cfg.Models.UserDefined(schemaType.Name) {
 93			continue
 94		}
 95
 96		switch schemaType.Kind {
 97		case ast.Interface, ast.Union:
 98			it := &Interface{
 99				Description: schemaType.Description,
100				Name:        schemaType.Name,
101			}
102
103			b.Interfaces = append(b.Interfaces, it)
104		case ast.Object, ast.InputObject:
105			if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription {
106				continue
107			}
108			it := &Object{
109				Description: schemaType.Description,
110				Name:        schemaType.Name,
111			}
112
113			for _, implementor := range schema.GetImplements(schemaType) {
114				it.Implements = append(it.Implements, implementor.Name)
115			}
116
117			for _, field := range schemaType.Fields {
118				var typ types.Type
119				fieldDef := schema.Types[field.Type.Name()]
120
121				if cfg.Models.UserDefined(field.Type.Name()) {
122					pkg, typeName := code.PkgAndType(cfg.Models[field.Type.Name()].Model[0])
123					typ, err = binder.FindType(pkg, typeName)
124					if err != nil {
125						return err
126					}
127				} else {
128					switch fieldDef.Kind {
129					case ast.Scalar:
130						// no user defined model, referencing a default scalar
131						typ = types.NewNamed(
132							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
133							nil,
134							nil,
135						)
136
137					case ast.Interface, ast.Union:
138						// no user defined model, referencing a generated interface type
139						typ = types.NewNamed(
140							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
141							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
142							nil,
143						)
144
145					case ast.Enum:
146						// no user defined model, must reference a generated enum
147						typ = types.NewNamed(
148							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
149							nil,
150							nil,
151						)
152
153					case ast.Object, ast.InputObject:
154						// no user defined model, must reference a generated struct
155						typ = types.NewNamed(
156							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
157							types.NewStruct(nil, nil),
158							nil,
159						)
160
161					default:
162						panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
163					}
164				}
165
166				name := field.Name
167				if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
168					name = nameOveride
169				}
170
171				typ = binder.CopyModifiersFromAst(field.Type, typ)
172
173				if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
174					typ = types.NewPointer(typ)
175				}
176
177				it.Fields = append(it.Fields, &Field{
178					Name:        name,
179					Type:        typ,
180					Description: field.Description,
181					Tag:         `json:"` + field.Name + `"`,
182				})
183			}
184
185			b.Models = append(b.Models, it)
186		case ast.Enum:
187			it := &Enum{
188				Name:        schemaType.Name,
189				Description: schemaType.Description,
190			}
191
192			for _, v := range schemaType.EnumValues {
193				it.Values = append(it.Values, &EnumValue{
194					Name:        v.Name,
195					Description: v.Description,
196				})
197			}
198
199			b.Enums = append(b.Enums, it)
200		case ast.Scalar:
201			b.Scalars = append(b.Scalars, schemaType.Name)
202		}
203	}
204
205	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
206	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
207	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
208
209	for _, it := range b.Enums {
210		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
211	}
212	for _, it := range b.Models {
213		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
214	}
215	for _, it := range b.Interfaces {
216		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
217	}
218	for _, it := range b.Scalars {
219		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
220	}
221
222	if len(b.Models) == 0 && len(b.Enums) == 0 {
223		return nil
224	}
225
226	return templates.Render(templates.Options{
227		PackageName:     cfg.Model.Package,
228		Filename:        cfg.Model.Filename,
229		Data:            b,
230		GeneratedHeader: true,
231	})
232}
233
234func isStruct(t types.Type) bool {
235	_, is := t.Underlying().(*types.Struct)
236	return is
237}