object_build.go

  1package codegen
  2
  3import (
  4	"log"
  5	"sort"
  6
  7	"github.com/pkg/errors"
  8	"github.com/vektah/gqlparser/ast"
  9	"golang.org/x/tools/go/loader"
 10)
 11
 12func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) {
 13	var objects Objects
 14
 15	for _, typ := range cfg.schema.Types {
 16		if typ.Kind != ast.Object {
 17			continue
 18		}
 19
 20		obj, err := cfg.buildObject(types, typ, imports)
 21		if err != nil {
 22			return nil, err
 23		}
 24
 25		def, err := findGoType(prog, obj.Package, obj.GoType)
 26		if err != nil {
 27			return nil, err
 28		}
 29		if def != nil {
 30			for _, bindErr := range bindObject(def.Type(), obj, imports, cfg.StructTag) {
 31				log.Println(bindErr.Error())
 32				log.Println("  Adding resolver method")
 33			}
 34		}
 35
 36		objects = append(objects, obj)
 37	}
 38
 39	sort.Slice(objects, func(i, j int) bool {
 40		return objects[i].GQLType < objects[j].GQLType
 41	})
 42
 43	return objects, nil
 44}
 45
 46var keywords = []string{
 47	"break",
 48	"default",
 49	"func",
 50	"interface",
 51	"select",
 52	"case",
 53	"defer",
 54	"go",
 55	"map",
 56	"struct",
 57	"chan",
 58	"else",
 59	"goto",
 60	"package",
 61	"switch",
 62	"const",
 63	"fallthrough",
 64	"if",
 65	"range",
 66	"type",
 67	"continue",
 68	"for",
 69	"import",
 70	"return",
 71	"var",
 72}
 73
 74// sanitizeArgName prevents collisions with go keywords for arguments to resolver functions
 75func sanitizeArgName(name string) string {
 76	for _, k := range keywords {
 77		if name == k {
 78			return name + "Arg"
 79		}
 80	}
 81	return name
 82}
 83
 84func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *Imports) (*Object, error) {
 85	obj := &Object{NamedType: types[typ.Name]}
 86	typeEntry, entryExists := cfg.Models[typ.Name]
 87
 88	imp := imports.findByPath(cfg.Exec.ImportPath())
 89	obj.ResolverInterface = &Ref{GoType: obj.GQLType + "Resolver", Import: imp}
 90
 91	if typ == cfg.schema.Query {
 92		obj.Root = true
 93	}
 94
 95	if typ == cfg.schema.Mutation {
 96		obj.Root = true
 97		obj.DisableConcurrency = true
 98	}
 99
100	if typ == cfg.schema.Subscription {
101		obj.Root = true
102		obj.Stream = true
103	}
104
105	obj.Satisfies = append(obj.Satisfies, typ.Interfaces...)
106
107	for _, field := range typ.Fields {
108		if typ == cfg.schema.Query && field.Name == "__type" {
109			obj.Fields = append(obj.Fields, Field{
110				Type:           &Type{types["__Schema"], []string{modPtr}, ast.NamedType("__Schema", nil), nil},
111				GQLName:        "__schema",
112				NoErr:          true,
113				GoFieldType:    GoFieldMethod,
114				GoReceiverName: "ec",
115				GoFieldName:    "introspectSchema",
116				Object:         obj,
117				Description:    field.Description,
118			})
119			continue
120		}
121		if typ == cfg.schema.Query && field.Name == "__schema" {
122			obj.Fields = append(obj.Fields, Field{
123				Type:           &Type{types["__Type"], []string{modPtr}, ast.NamedType("__Schema", nil), nil},
124				GQLName:        "__type",
125				NoErr:          true,
126				GoFieldType:    GoFieldMethod,
127				GoReceiverName: "ec",
128				GoFieldName:    "introspectType",
129				Args: []FieldArgument{
130					{GQLName: "name", Type: &Type{types["String"], []string{}, ast.NamedType("String", nil), nil}, Object: &Object{}},
131				},
132				Object: obj,
133			})
134			continue
135		}
136
137		var forceResolver bool
138		var goName string
139		if entryExists {
140			if typeField, ok := typeEntry.Fields[field.Name]; ok {
141				goName = typeField.FieldName
142				forceResolver = typeField.Resolver
143			}
144		}
145
146		var args []FieldArgument
147		for _, arg := range field.Arguments {
148			newArg := FieldArgument{
149				GQLName:   arg.Name,
150				Type:      types.getType(arg.Type),
151				Object:    obj,
152				GoVarName: sanitizeArgName(arg.Name),
153			}
154
155			if !newArg.Type.IsInput && !newArg.Type.IsScalar {
156				return nil, errors.Errorf("%s cannot be used as argument of %s.%s. only input and scalar types are allowed", arg.Type, obj.GQLType, field.Name)
157			}
158
159			if arg.DefaultValue != nil {
160				var err error
161				newArg.Default, err = arg.DefaultValue.Value(nil)
162				if err != nil {
163					return nil, errors.Errorf("default value for %s.%s is not valid: %s", typ.Name, field.Name, err.Error())
164				}
165				newArg.StripPtr()
166			}
167			args = append(args, newArg)
168		}
169
170		obj.Fields = append(obj.Fields, Field{
171			GQLName:       field.Name,
172			Type:          types.getType(field.Type),
173			Args:          args,
174			Object:        obj,
175			GoFieldName:   goName,
176			ForceResolver: forceResolver,
177		})
178	}
179
180	return obj, nil
181}