1package modelgen
2
3import (
4 "go/types"
5 "sort"
6
7 "github.com/99designs/gqlgen/codegen/config"
8 "github.com/99designs/gqlgen/codegen/templates"
9 "github.com/99designs/gqlgen/internal/code"
10 "github.com/99designs/gqlgen/plugin"
11 "github.com/vektah/gqlparser/ast"
12)
13
14type ModelBuild struct {
15 PackageName string
16 Interfaces []*Interface
17 Models []*Object
18 Enums []*Enum
19 Scalars []string
20}
21
// Interface describes a GraphQL interface or union type that has no
// user-defined model and is therefore generated.
type Interface struct {
	// Description is the schema description; Name is the GraphQL type name.
	Description, Name string
}
26
27type Object struct {
28 Description string
29 Name string
30 Fields []*Field
31 Implements []string
32}
33
34type Field struct {
35 Description string
36 Name string
37 Type types.Type
38 Tag string
39}
40
41type Enum struct {
42 Description string
43 Name string
44 Values []*EnumValue
45}
46
// EnumValue is a single value of a generated Enum.
type EnumValue struct {
	// Description is the schema description; Name is the GraphQL value name.
	Description, Name string
}
51
52func New() plugin.Plugin {
53 return &Plugin{}
54}
55
56type Plugin struct{}
57
58var _ plugin.ConfigMutator = &Plugin{}
59
60func (m *Plugin) Name() string {
61 return "modelgen"
62}
63
// MutateConfig collects every schema type that the config does not already
// bind to a user-defined model, registers a binding for each in cfg.Models,
// and renders the collected types to cfg.Model.Filename.
//
// Collected kinds:
//   - interfaces and unions -> ModelBuild.Interfaces (bound to the model package)
//   - objects and input objects, excluding the Query/Mutation/Subscription
//     roots -> ModelBuild.Models (bound to the model package)
//   - enums -> ModelBuild.Enums (bound to the model package)
//   - scalars -> ModelBuild.Scalars (bound to graphql.String)
//
// Nothing is rendered when no model-package type (object, enum, interface)
// was collected. Errors from config checking, schema loading, binder creation,
// type lookup and rendering are returned unchanged.
func (m *Plugin) MutateConfig(cfg *config.Config) error {
	if err := cfg.Check(); err != nil {
		return err
	}

	schema, _, err := cfg.LoadSchema()
	if err != nil {
		return err
	}

	cfg.InjectBuiltins(schema)

	binder, err := cfg.NewBinder(schema)
	if err != nil {
		return err
	}

	b := &ModelBuild{
		PackageName: cfg.Model.Package,
	}

	for _, schemaType := range schema.Types {
		// Types already bound to a user-defined model need no generated code.
		if cfg.Models.UserDefined(schemaType.Name) {
			continue
		}

		switch schemaType.Kind {
		case ast.Interface, ast.Union:
			it := &Interface{
				Description: schemaType.Description,
				Name:        schemaType.Name,
			}

			b.Interfaces = append(b.Interfaces, it)
		case ast.Object, ast.InputObject:
			// Root operation types are skipped; they get no model.
			if schemaType == schema.Query || schemaType == schema.Mutation || schemaType == schema.Subscription {
				continue
			}
			it := &Object{
				Description: schemaType.Description,
				Name:        schemaType.Name,
			}

			for _, implementor := range schema.GetImplements(schemaType) {
				it.Implements = append(it.Implements, implementor.Name)
			}
			// The order returned by GetImplements is not guaranteed here
			// (NOTE(review): it appears to depend on map iteration inside the
			// schema validator); sort so the generated output is stable.
			sort.Strings(it.Implements)

			for _, field := range schemaType.Fields {
				var typ types.Type

				if cfg.Models.UserDefined(field.Type.Name()) {
					// The field's type is bound to a user model: resolve the real Go type.
					pkg, typeName := code.PkgAndType(cfg.Models[field.Type.Name()].Model[0])
					typ, err = binder.FindType(pkg, typeName)
					if err != nil {
						return err
					}
				} else {
					fieldDef := schema.Types[field.Type.Name()]
					switch fieldDef.Kind {
					case ast.Scalar:
						// no user defined model, referencing a default scalar
						typ = types.NewNamed(
							types.NewTypeName(0, cfg.Model.Pkg(), "string", nil),
							nil,
							nil,
						)
					case ast.Interface, ast.Union:
						// no user defined model, referencing a generated interface type
						typ = types.NewNamed(
							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
							types.NewInterfaceType([]*types.Func{}, []types.Type{}),
							nil,
						)
					default:
						// no user defined model, must reference another generated model
						typ = types.NewNamed(
							types.NewTypeName(0, cfg.Model.Pkg(), templates.ToGo(field.Type.Name()), nil),
							nil,
							nil,
						)
					}
				}

				// A field_name override in the config replaces the schema field name
				// on the Go side; the json tag keeps the schema name.
				name := field.Name
				if nameOverride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOverride != "" {
					name = nameOverride
				}

				it.Fields = append(it.Fields, &Field{
					Name:        name,
					Type:        binder.CopyModifiersFromAst(field.Type, typ),
					Description: field.Description,
					Tag:         `json:"` + field.Name + `"`,
				})
			}

			b.Models = append(b.Models, it)
		case ast.Enum:
			it := &Enum{
				Name:        schemaType.Name,
				Description: schemaType.Description,
			}

			for _, v := range schemaType.EnumValues {
				it.Values = append(it.Values, &EnumValue{
					Name:        v.Name,
					Description: v.Description,
				})
			}

			b.Enums = append(b.Enums, it)
		case ast.Scalar:
			b.Scalars = append(b.Scalars, schemaType.Name)
		}
	}

	// schema.Types is a map, so impose a stable order on everything collected.
	sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
	sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
	sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
	sort.Strings(b.Scalars)

	for _, it := range b.Enums {
		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
	}
	for _, it := range b.Models {
		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
	}
	for _, it := range b.Interfaces {
		cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
	}
	for _, it := range b.Scalars {
		cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
	}

	// Enums, models and interfaces were all bound to the model package above,
	// so the file must be rendered if any of them exist; otherwise the
	// bindings would point at types that are never written. Scalars are bound
	// to graphql.String and do not need the file.
	if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 {
		return nil
	}

	return templates.Render(templates.Options{
		PackageName:     cfg.Model.Package,
		Filename:        cfg.Model.Filename,
		Data:            b,
		GeneratedHeader: true,
	})
}