object_build.go

  1package codegen
  2
  3import (
  4	"log"
  5	"sort"
  6
  7	"github.com/pkg/errors"
  8	"github.com/vektah/gqlparser/ast"
  9	"golang.org/x/tools/go/loader"
 10)
 11
 12func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program) (Objects, error) {
 13	var objects Objects
 14
 15	for _, typ := range cfg.schema.Types {
 16		if typ.Kind != ast.Object {
 17			continue
 18		}
 19
 20		obj, err := cfg.buildObject(types, typ)
 21		if err != nil {
 22			return nil, err
 23		}
 24
 25		def, err := findGoType(prog, obj.Package, obj.GoType)
 26		if err != nil {
 27			return nil, err
 28		}
 29		if def != nil {
 30			for _, bindErr := range bindObject(def.Type(), obj, cfg.StructTag) {
 31				log.Println(bindErr.Error())
 32				log.Println("  Adding resolver method")
 33			}
 34		}
 35
 36		objects = append(objects, obj)
 37	}
 38
 39	sort.Slice(objects, func(i, j int) bool {
 40		return objects[i].GQLType < objects[j].GQLType
 41	})
 42
 43	return objects, nil
 44}
 45
 46var keywords = []string{
 47	"break",
 48	"default",
 49	"func",
 50	"interface",
 51	"select",
 52	"case",
 53	"defer",
 54	"go",
 55	"map",
 56	"struct",
 57	"chan",
 58	"else",
 59	"goto",
 60	"package",
 61	"switch",
 62	"const",
 63	"fallthrough",
 64	"if",
 65	"range",
 66	"type",
 67	"continue",
 68	"for",
 69	"import",
 70	"return",
 71	"var",
 72}
 73
 74// sanitizeArgName prevents collisions with go keywords for arguments to resolver functions
 75func sanitizeArgName(name string) string {
 76	for _, k := range keywords {
 77		if name == k {
 78			return name + "Arg"
 79		}
 80	}
 81	return name
 82}
 83
 84func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition) (*Object, error) {
 85	obj := &Object{NamedType: types[typ.Name]}
 86	typeEntry, entryExists := cfg.Models[typ.Name]
 87
 88	obj.ResolverInterface = &Ref{GoType: obj.GQLType + "Resolver"}
 89
 90	if typ == cfg.schema.Query {
 91		obj.Root = true
 92	}
 93
 94	if typ == cfg.schema.Mutation {
 95		obj.Root = true
 96		obj.DisableConcurrency = true
 97	}
 98
 99	if typ == cfg.schema.Subscription {
100		obj.Root = true
101		obj.Stream = true
102	}
103
104	obj.Satisfies = append(obj.Satisfies, typ.Interfaces...)
105
106	for _, intf := range cfg.schema.GetImplements(typ) {
107		obj.Implements = append(obj.Implements, types[intf.Name])
108	}
109
110	for _, field := range typ.Fields {
111		if typ == cfg.schema.Query && field.Name == "__type" {
112			obj.Fields = append(obj.Fields, Field{
113				Type:           &Type{types["__Schema"], []string{modPtr}, ast.NamedType("__Schema", nil), nil},
114				GQLName:        "__schema",
115				GoFieldType:    GoFieldMethod,
116				GoReceiverName: "ec",
117				GoFieldName:    "introspectSchema",
118				Object:         obj,
119				Description:    field.Description,
120			})
121			continue
122		}
123		if typ == cfg.schema.Query && field.Name == "__schema" {
124			obj.Fields = append(obj.Fields, Field{
125				Type:           &Type{types["__Type"], []string{modPtr}, ast.NamedType("__Schema", nil), nil},
126				GQLName:        "__type",
127				GoFieldType:    GoFieldMethod,
128				GoReceiverName: "ec",
129				GoFieldName:    "introspectType",
130				Args: []FieldArgument{
131					{GQLName: "name", Type: &Type{types["String"], []string{}, ast.NamedType("String", nil), nil}, Object: &Object{}},
132				},
133				Object: obj,
134			})
135			continue
136		}
137
138		var forceResolver bool
139		var goName string
140		if entryExists {
141			if typeField, ok := typeEntry.Fields[field.Name]; ok {
142				goName = typeField.FieldName
143				forceResolver = typeField.Resolver
144			}
145		}
146
147		var args []FieldArgument
148		for _, arg := range field.Arguments {
149			newArg := FieldArgument{
150				GQLName:   arg.Name,
151				Type:      types.getType(arg.Type),
152				Object:    obj,
153				GoVarName: sanitizeArgName(arg.Name),
154			}
155
156			if !newArg.Type.IsInput && !newArg.Type.IsScalar {
157				return nil, errors.Errorf("%s cannot be used as argument of %s.%s. only input and scalar types are allowed", arg.Type, obj.GQLType, field.Name)
158			}
159
160			if arg.DefaultValue != nil {
161				var err error
162				newArg.Default, err = arg.DefaultValue.Value(nil)
163				if err != nil {
164					return nil, errors.Errorf("default value for %s.%s is not valid: %s", typ.Name, field.Name, err.Error())
165				}
166			}
167			args = append(args, newArg)
168		}
169
170		obj.Fields = append(obj.Fields, Field{
171			GQLName:       field.Name,
172			Type:          types.getType(field.Type),
173			Args:          args,
174			Object:        obj,
175			GoFieldName:   goName,
176			ForceResolver: forceResolver,
177		})
178	}
179
180	return obj, nil
181}