1package modelgen
2
3import (
4 "fmt"
5 "go/types"
6 "sort"
7
8 "github.com/99designs/gqlgen/codegen/config"
9 "github.com/99designs/gqlgen/codegen/templates"
10 "github.com/99designs/gqlgen/internal/code"
11 "github.com/99designs/gqlgen/plugin"
12 "github.com/vektah/gqlparser/ast"
13)
14
// ModelBuild is the template data that MutateConfig renders into the
// generated models file (cfg.Model.Filename).
type ModelBuild struct {
	// PackageName is the Go package name of the generated file
	// (cfg.Model.Package).
	PackageName string
	// Interfaces has one entry per schema interface or union type that is not
	// bound to a user-defined model.
	Interfaces []*Interface
	// Models has one entry per schema object or input object type that is not
	// bound to a user-defined model, excluding the Query, Mutation and
	// Subscription root types.
	Models []*Object
	// Enums has one entry per schema enum type that is not bound to a
	// user-defined model.
	Enums []*Enum
	// Scalars lists the names of schema scalar types that are not bound to a
	// user-defined model; MutateConfig maps each of them to graphql.String.
	Scalars []string
}
22
// Interface describes a Go interface to generate for a GraphQL interface or
// union type.
type Interface struct {
	// Description is the schema description of the GraphQL type.
	Description string
	// Name is the GraphQL type name; it is converted to a Go identifier
	// (templates.ToGo) when registered in cfg.Models.
	Name string
}
27
// Object describes a Go struct to generate for a GraphQL object or input
// object type.
type Object struct {
	// Description is the schema description of the GraphQL type.
	Description string
	// Name is the GraphQL type name.
	Name string
	// Fields are the struct fields, in schema field order.
	Fields []*Field
	// Implements holds the names of the abstract types that
	// schema.GetImplements reports for this type.
	Implements []string
}
34
// Field describes one struct field of a generated Object.
type Field struct {
	// Description is the schema description of the GraphQL field.
	Description string
	// Name is the GraphQL field name, or the FieldName override configured
	// for this field in cfg.Models when one is set.
	Name string
	// Type is the Go type of the field with the schema's list/non-null
	// modifiers applied; struct types of objects and input objects are
	// wrapped in a pointer.
	Type types.Type
	// Tag is the struct tag, of the form json:"<GraphQL field name>".
	Tag string
}
41
// Enum describes a Go enum type to generate for a GraphQL enum type.
type Enum struct {
	// Description is the schema description of the GraphQL enum.
	Description string
	// Name is the GraphQL enum type name.
	Name string
	// Values are the enum values, in schema order.
	Values []*EnumValue
}
47
// EnumValue describes a single value of a generated Enum.
type EnumValue struct {
	// Description is the schema description of the enum value.
	Description string
	// Name is the enum value exactly as written in the schema.
	Name string
}
52
53func New() plugin.Plugin {
54 return &Plugin{}
55}
56
// Plugin is the modelgen plugin. It holds no state of its own.
type Plugin struct{}
58
59var _ plugin.ConfigMutator = &Plugin{}
60
61func (m *Plugin) Name() string {
62 return "modelgen"
63}
64
65func (m *Plugin) MutateConfig(cfg *config.Config) error {
66 if err := cfg.Check(); err != nil {
67 return err
68 }
69
70 schema, _, err := cfg.LoadSchema()
71 if err != nil {
72 return err
73 }
74
75 cfg.InjectBuiltins(schema)
76
77 binder, err := cfg.NewBinder(schema)
78 if err != nil {
79 return err
80 }
81
82 b := &ModelBuild{
83 PackageName: cfg.Model.Package,
84 }
85
86 for _, schemaType := range schema.Types {
87 if cfg.Models.UserDefined(schemaType.Name) {
88 continue
89 }
90
91 switch schemaType.Kind {
92 case ast.Interface, ast.Union:
93 it := &Interface{
94 Description: schemaType.Description,
95 Name: schemaType.Name,
96 }
97
98 b.Interfaces = append(b.Interfaces, it)
99 case ast.Object, ast.InputObject:
100 if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription {
101 continue
102 }
103 it := &Object{
104 Description: schemaType.Description,
105 Name: schemaType.Name,
106 }
107
108 for _, implementor := range schema.GetImplements(schemaType) {
109 it.Implements = append(it.Implements, implementor.Name)
110 }
111
112 for _, field := range schemaType.Fields {
113 var typ types.Type
114 fieldDef := schema.Types[field.Type.Name()]
115
116 if cfg.Models.UserDefined(field.Type.Name()) {
117 pkg, typeName := code.PkgAndType(cfg.Models[field.Type.Name()].Model[0])
118 typ, err = binder.FindType(pkg, typeName)
119 if err != nil {
120 return err
121 }
122 } else {
123 switch fieldDef.Kind {
124 case ast.Scalar:
125 // no user defined model, referencing a default scalar
126 typ = types.NewNamed(
127 types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
128 nil,
129 nil,
130 )
131
132 case ast.Interface, ast.Union:
133 // no user defined model, referencing a generated interface type
134 typ = types.NewNamed(
135 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
136 types.NewInterfaceType([]*types.Func{}, []types.Type{}),
137 nil,
138 )
139
140 case ast.Enum:
141 // no user defined model, must reference a generated enum
142 typ = types.NewNamed(
143 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
144 nil,
145 nil,
146 )
147
148 case ast.Object, ast.InputObject:
149 // no user defined model, must reference a generated struct
150 typ = types.NewNamed(
151 types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
152 types.NewStruct(nil, nil),
153 nil,
154 )
155
156 default:
157 panic(fmt.Errorf("unknown ast type %s", fieldDef.Kind))
158 }
159 }
160
161 name := field.Name
162 if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
163 name = nameOveride
164 }
165
166 typ = binder.CopyModifiersFromAst(field.Type, typ)
167
168 if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
169 typ = types.NewPointer(typ)
170 }
171
172 it.Fields = append(it.Fields, &Field{
173 Name: name,
174 Type: typ,
175 Description: field.Description,
176 Tag: `json:"` + field.Name + `"`,
177 })
178 }
179
180 b.Models = append(b.Models, it)
181 case ast.Enum:
182 it := &Enum{
183 Name: schemaType.Name,
184 Description: schemaType.Description,
185 }
186
187 for _, v := range schemaType.EnumValues {
188 it.Values = append(it.Values, &EnumValue{
189 Name: v.Name,
190 Description: v.Description,
191 })
192 }
193
194 b.Enums = append(b.Enums, it)
195 case ast.Scalar:
196 b.Scalars = append(b.Scalars, schemaType.Name)
197 }
198 }
199
200 sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
201 sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
202 sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
203
204 for _, it := range b.Enums {
205 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
206 }
207 for _, it := range b.Models {
208 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
209 }
210 for _, it := range b.Interfaces {
211 cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
212 }
213 for _, it := range b.Scalars {
214 cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
215 }
216
217 if len(b.Models) == 0 && len(b.Enums) == 0 {
218 return nil
219 }
220
221 return templates.Render(templates.Options{
222 PackageName: cfg.Model.Package,
223 Filename: cfg.Model.Filename,
224 Data: b,
225 GeneratedHeader: true,
226 })
227}
228
// isStruct reports whether t's underlying type is a struct. Named types are
// looked through, but pointers are not: a pointer to a struct reports false.
func isStruct(t types.Type) bool {
	switch t.Underlying().(type) {
	case *types.Struct:
		return true
	default:
		return false
	}
}