directive.go

  1package codegen
  2
  3import (
  4	"fmt"
  5	"strconv"
  6	"strings"
  7
  8	"github.com/99designs/gqlgen/codegen/templates"
  9	"github.com/pkg/errors"
 10	"github.com/vektah/gqlparser/ast"
 11)
 12
// Directive is the code-generation view of a GraphQL directive. It is
// produced both for the directive definitions in the schema (buildDirectives)
// and for each place a directive is applied (getDirectives).
type Directive struct {
	// Name is the directive's name as it appears in the schema.
	Name    string
	// Args are the directive's declared arguments, in declaration order.
	Args    []*FieldArgument
	// Builtin is set by buildDirectives for the skip, include and
	// deprecated directives.
	Builtin bool
}
 18
 19func (b *builder) buildDirectives() (map[string]*Directive, error) {
 20	directives := make(map[string]*Directive, len(b.Schema.Directives))
 21
 22	for name, dir := range b.Schema.Directives {
 23		if _, ok := directives[name]; ok {
 24			return nil, errors.Errorf("directive with name %s already exists", name)
 25		}
 26
 27		var builtin bool
 28		if name == "skip" || name == "include" || name == "deprecated" {
 29			builtin = true
 30		}
 31
 32		var args []*FieldArgument
 33		for _, arg := range dir.Arguments {
 34			tr, err := b.Binder.TypeReference(arg.Type, nil)
 35			if err != nil {
 36				return nil, err
 37			}
 38
 39			newArg := &FieldArgument{
 40				ArgumentDefinition: arg,
 41				TypeReference:      tr,
 42				VarName:            templates.ToGoPrivate(arg.Name),
 43			}
 44
 45			if arg.DefaultValue != nil {
 46				var err error
 47				newArg.Default, err = arg.DefaultValue.Value(nil)
 48				if err != nil {
 49					return nil, errors.Errorf("default value for directive argument %s(%s) is not valid: %s", dir.Name, arg.Name, err.Error())
 50				}
 51			}
 52			args = append(args, newArg)
 53		}
 54
 55		directives[name] = &Directive{
 56			Name:    name,
 57			Args:    args,
 58			Builtin: builtin,
 59		}
 60	}
 61
 62	return directives, nil
 63}
 64
 65func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
 66	dirs := make([]*Directive, len(list))
 67	for i, d := range list {
 68		argValues := make(map[string]interface{}, len(d.Arguments))
 69		for _, da := range d.Arguments {
 70			val, err := da.Value.Value(nil)
 71			if err != nil {
 72				return nil, err
 73			}
 74			argValues[da.Name] = val
 75		}
 76		def, ok := b.Directives[d.Name]
 77		if !ok {
 78			return nil, fmt.Errorf("directive %s not found", d.Name)
 79		}
 80
 81		var args []*FieldArgument
 82		for _, a := range def.Args {
 83			value := a.Default
 84			if argValue, ok := argValues[a.Name]; ok {
 85				value = argValue
 86			}
 87			args = append(args, &FieldArgument{
 88				ArgumentDefinition: a.ArgumentDefinition,
 89				Value:              value,
 90				VarName:            a.VarName,
 91				TypeReference:      a.TypeReference,
 92			})
 93		}
 94		dirs[i] = &Directive{
 95			Name: d.Name,
 96			Args: args,
 97		}
 98
 99	}
100
101	return dirs, nil
102}
103
104func (d *Directive) ArgsFunc() string {
105	if len(d.Args) == 0 {
106		return ""
107	}
108
109	return "dir_" + d.Name + "_args"
110}
111
112func (d *Directive) CallArgs() string {
113	args := []string{"ctx", "obj", "n"}
114
115	for _, arg := range d.Args {
116		args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
117	}
118
119	return strings.Join(args, ", ")
120}
121
122func (d *Directive) ResolveArgs(obj string, next string) string {
123	args := []string{"ctx", obj, next}
124
125	for _, arg := range d.Args {
126		dArg := "&" + arg.VarName
127		if !arg.TypeReference.IsPtr() {
128			if arg.Value != nil {
129				dArg = templates.Dump(arg.Value)
130			} else {
131				dArg = templates.Dump(arg.Default)
132			}
133		} else if arg.Value == nil && arg.Default == nil {
134			dArg = "nil"
135		}
136
137		args = append(args, dArg)
138	}
139
140	return strings.Join(args, ", ")
141}
142
143func (d *Directive) Declaration() string {
144	res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
145
146	for _, arg := range d.Args {
147		res += fmt.Sprintf(", %s %s", arg.Name, templates.CurrentImports.LookupType(arg.TypeReference.GO))
148	}
149
150	res += ") (res interface{}, err error)"
151	return res
152}