1package modelgen
2
3import (
4 "fmt"
5 "go/types"
6 "sort"
7
8 "github.com/99designs/gqlgen/codegen/config"
9 "github.com/99designs/gqlgen/codegen/templates"
10 "github.com/99designs/gqlgen/internal/code"
11 "github.com/99designs/gqlgen/plugin"
12 "github.com/vektah/gqlparser/ast"
13)
14
15type ModelBuild struct {
16 PackageName string
17 Interfaces []*Interface
18 Models []*Object
19 Enums []*Enum
20 Scalars []string
21}
22
// Interface describes a GraphQL interface or union that has no user-defined
// model and is therefore emitted into the generated models package.
type Interface struct {
	Description, Name string
}
27
28type Object struct {
29 Description string
30 Name string
31 Fields []*Field
32 Implements []string
33}
34
// Field is a single member of a generated model struct: its Go name, resolved
// Go type, and the struct tag emitted alongside it.
type Field struct {
	Description, Name string
	Type              types.Type
	Tag               string
}
41
42type Enum struct {
43 Description string
44 Name string
45 Values []*EnumValue
46}
47
// EnumValue is one member of a generated Enum.
type EnumValue struct {
	Description, Name string
}
52
53func New() plugin.Plugin {
54 return &Plugin{}
55}
56
57type Plugin struct{}
58
59var _ plugin.ConfigMutator = &Plugin{}
60
61func (m *Plugin) Name() string {
62 return "modelgen"
63}
64
65func (m *Plugin) MutateConfig(cfg *config.Config) error {
66 if err := cfg.Check(); err != nil {
67 return err
68 }
69
70 schema, _, err := cfg.LoadSchema()
71 if err != nil {
72 return err
73 }
74
75 err = cfg.Autobind(schema)
76 if err != nil {
77 return err
78 }
79
80 cfg.InjectBuiltins(schema)
81
82 binder, err := cfg.NewBinder(schema)
83 if err != nil {
84 return err
85 }
86
87 b := &ModelBuild{
88 PackageName: cfg.Model.Package,
89 }
90
91 for _, schemaType := range schema.Types {
92 if cfg.Models.UserDefined(schemaType.Name) {
93 continue
94 }
95
96 switch schemaType.Kind {
97 case ast.Interface, ast.Union:
98 it := &Interface{
99 Description: schemaType.Description,
100 Name: schemaType.Name,
101 }
102
103 b.Interfaces = append(b.Interfaces, it)
104 case ast.Object, ast.InputObject:
105 if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription {
106 continue
107 }
108 it := &Object{
109 Description: schemaType.Description,
110 Name: schemaType.Name,
111 }
112
113 for _, implementor := range schema.GetImplements(schemaType) {
114 it.Implements = append(it.Implements, implementor.Name)
115 }
116
117 for _, field := range schemaType.Fields {
118 var typ types.Type
119 fieldDef := schema.Types[field.Type.Name()]
120
121 if cfg.Models.UserDefined(field.Type.Name()) {
122 pkg, typeName := code.PkgAndType(cfg.Models[field.Type.Name()].Model[0])
123 typ, err = binder.FindType(pkg, typeName)
124 if err != nil {
125 return err
126 }
127 } else {
128 switch fieldDef.Kind {
129 case ast.Scalar:
130 // no user defined model, referencing a default scalar
131 typ = types.NewNamed(
132 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
133 nil,
134 nil,
135 )
136
137 case ast.Interface, ast.Union:
138 // no user defined model, referencing a generated interface type
139 typ = types.NewNamed(
140 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
141 types.NewInterfaceType([]*types.Func{}, []types.Type{}),
142 nil,
143 )
144
145 case ast.Enum:
146 // no user defined model, must reference a generated enum
147 typ = types.NewNamed(
148 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
149 nil,
150 nil,
151 )
152
153 case ast.Object, ast.InputObject:
154 // no user defined model, must reference a generated struct
155 typ = types.NewNamed(
156 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
157 types.NewStruct(nil, nil),
158 nil,
159 )
160
161 default:
162 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
163 }
164 }
165
166 name := field.Name
167 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
168 name = nameOveride
169 }
170
171 typ = binder.CopyModifiersFromAst(field.Type, typ)
172
173 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
174 typ = types.NewPointer(typ)
175 }
176
177 it.Fields = append(it.Fields, &Field{
178 Name: name,
179 Type: typ,
180 Description: field.Description,
181 Tag: `json:"` + field.Name + `"`,
182 })
183 }
184
185 b.Models = append(b.Models, it)
186 case ast.Enum:
187 it := &Enum{
188 Name: schemaType.Name,
189 Description: schemaType.Description,
190 }
191
192 for _, v := range schemaType.EnumValues {
193 it.Values = append(it.Values, &EnumValue{
194 Name: v.Name,
195 Description: v.Description,
196 })
197 }
198
199 b.Enums = append(b.Enums, it)
200 case ast.Scalar:
201 b.Scalars = append(b.Scalars, schemaType.Name)
202 }
203 }
204
205 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
206 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
207 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
208
209 for _, it := range b.Enums {
210 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
211 }
212 for _, it := range b.Models {
213 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
214 }
215 for _, it := range b.Interfaces {
216 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
217 }
218 for _, it := range b.Scalars {
219 cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
220 }
221
222 if len(b.Models) == 0 && len(b.Enums) == 0 {
223 return nil
224 }
225
226 return templates.Render(templates.Options{
227 PackageName: cfg.Model.Package,
228 Filename: cfg.Model.Filename,
229 Data: b,
230 GeneratedHeader: true,
231 })
232}
233
// isStruct reports whether t's underlying type is a struct. A named type
// backed by a struct counts; a pointer to a struct does not.
func isStruct(t types.Type) bool {
	switch t.Underlying().(type) {
	case *types.Struct:
		return true
	default:
		return false
	}
}