directive.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"strconv"
  6	"strings"
  7
  8	"github.com/99designs/gqlgen/codegen/templates"
  9	"github.com/pkg/errors"
 10	"github.com/vektah/gqlparser/ast"
 11)
 12
 13type DirectiveList map[string]*Directive
 14
 15//LocationDirectives filter directives by location
 16func (dl DirectiveList) LocationDirectives(location string) DirectiveList {
 17	return locationDirectives(dl, ast.DirectiveLocation(location))
 18}
 19
 20type Directive struct {
 21	*ast.DirectiveDefinition
 22	Name    string
 23	Args    []*FieldArgument
 24	Builtin bool
 25}
 26
 27//IsLocation check location directive
 28func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool {
 29	for _, l := range d.Locations {
 30		for _, a := range location {
 31			if l == a {
 32				return true
 33			}
 34		}
 35	}
 36
 37	return false
 38}
 39
 40func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive {
 41	mDirectives := make(map[string]*Directive)
 42	for name, d := range directives {
 43		if d.IsLocation(location...) {
 44			mDirectives[name] = d
 45		}
 46	}
 47	return mDirectives
 48}
 49
 50func (b *builder) buildDirectives() (map[string]*Directive, error) {
 51	directives := make(map[string]*Directive, len(b.Schema.Directives))
 52
 53	for name, dir := range b.Schema.Directives {
 54		if _, ok := directives[name]; ok {
 55			return nil, errors.Errorf("directive with name %s already exists", name)
 56		}
 57
 58		var args []*FieldArgument
 59		for _, arg := range dir.Arguments {
 60			tr, err := b.Binder.TypeReference(arg.Type, nil)
 61			if err != nil {
 62				return nil, err
 63			}
 64
 65			newArg := &FieldArgument{
 66				ArgumentDefinition: arg,
 67				TypeReference:      tr,
 68				VarName:            templates.ToGoPrivate(arg.Name),
 69			}
 70
 71			if arg.DefaultValue != nil {
 72				var err error
 73				newArg.Default, err = arg.DefaultValue.Value(nil)
 74				if err != nil {
 75					return nil, errors.Errorf("default value for directive argument %s(%s) is not valid: %s", dir.Name, arg.Name, err.Error())
 76				}
 77			}
 78			args = append(args, newArg)
 79		}
 80
 81		directives[name] = &Directive{
 82			DirectiveDefinition: dir,
 83			Name:                name,
 84			Args:                args,
 85			Builtin:             b.Config.Directives[name].SkipRuntime,
 86		}
 87	}
 88
 89	return directives, nil
 90}
 91
 92func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
 93	dirs := make([]*Directive, len(list))
 94	for i, d := range list {
 95		argValues := make(map[string]interface{}, len(d.Arguments))
 96		for _, da := range d.Arguments {
 97			val, err := da.Value.Value(nil)
 98			if err != nil {
 99				return nil, err
100			}
101			argValues[da.Name] = val
102		}
103		def, ok := b.Directives[d.Name]
104		if !ok {
105			return nil, fmt.Errorf("directive %s not found", d.Name)
106		}
107
108		var args []*FieldArgument
109		for _, a := range def.Args {
110			value := a.Default
111			if argValue, ok := argValues[a.Name]; ok {
112				value = argValue
113			}
114			args = append(args, &FieldArgument{
115				ArgumentDefinition: a.ArgumentDefinition,
116				Value:              value,
117				VarName:            a.VarName,
118				TypeReference:      a.TypeReference,
119			})
120		}
121		dirs[i] = &Directive{
122			Name:                d.Name,
123			Args:                args,
124			DirectiveDefinition: list[i].Definition,
125			Builtin:             b.Config.Directives[d.Name].SkipRuntime,
126		}
127
128	}
129
130	return dirs, nil
131}
132
133func (d *Directive) ArgsFunc() string {
134	if len(d.Args) == 0 {
135		return ""
136	}
137
138	return "dir_" + d.Name + "_args"
139}
140
141func (d *Directive) CallArgs() string {
142	args := []string{"ctx", "obj", "n"}
143
144	for _, arg := range d.Args {
145		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
146	}
147
148	return strings.Join(args, ", ")
149}
150
151func (d *Directive) ResolveArgs(obj string, next int) string {
152	args := []string{"ctx", obj, fmt.Sprintf("directive%d", next)}
153
154	for _, arg := range d.Args {
155		dArg := arg.VarName
156		if arg.Value == nil && arg.Default == nil {
157			dArg = "nil"
158		}
159
160		args = append(args, dArg)
161	}
162
163	return strings.Join(args, ", ")
164}
165
166func (d *Directive) Declaration() string {
167	res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
168
169	for _, arg := range d.Args {
170		res += fmt.Sprintf(", %s %s", arg.Name, templates.CurrentImports.LookupType(arg.TypeReference.GO))
171	}
172
173	res += ") (res interface{}, err error)"
174	return res
175}