context.go

  1package graphql
  2
  3import (
  4	"context"
  5	"fmt"
  6	"sync"
  7
  8	"github.com/vektah/gqlparser/ast"
  9	"github.com/vektah/gqlparser/gqlerror"
 10)
 11
 12type Resolver func(ctx context.Context) (res interface{}, err error)
 13type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
 14type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
 15
 16type RequestContext struct {
 17	RawQuery  string
 18	Variables map[string]interface{}
 19	Doc       *ast.QueryDocument
 20
 21	ComplexityLimit      int
 22	OperationComplexity  int
 23	DisableIntrospection bool
 24
 25	// ErrorPresenter will be used to generate the error
 26	// message from errors given to Error().
 27	ErrorPresenter      ErrorPresenterFunc
 28	Recover             RecoverFunc
 29	ResolverMiddleware  FieldMiddleware
 30	DirectiveMiddleware FieldMiddleware
 31	RequestMiddleware   RequestMiddleware
 32	Tracer              Tracer
 33
 34	errorsMu     sync.Mutex
 35	Errors       gqlerror.List
 36	extensionsMu sync.Mutex
 37	Extensions   map[string]interface{}
 38}
 39
 40func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
 41	return next(ctx)
 42}
 43
 44func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
 45	return next(ctx)
 46}
 47
 48func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
 49	return next(ctx)
 50}
 51
 52func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
 53	return &RequestContext{
 54		Doc:                 doc,
 55		RawQuery:            query,
 56		Variables:           variables,
 57		ResolverMiddleware:  DefaultResolverMiddleware,
 58		DirectiveMiddleware: DefaultDirectiveMiddleware,
 59		RequestMiddleware:   DefaultRequestMiddleware,
 60		Recover:             DefaultRecover,
 61		ErrorPresenter:      DefaultErrorPresenter,
 62		Tracer:              &NopTracer{},
 63	}
 64}
 65
 66type key string
 67
 68const (
 69	request  key = "request_context"
 70	resolver key = "resolver_context"
 71)
 72
 73func GetRequestContext(ctx context.Context) *RequestContext {
 74	val := ctx.Value(request)
 75	if val == nil {
 76		return nil
 77	}
 78
 79	return val.(*RequestContext)
 80}
 81
 82func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
 83	return context.WithValue(ctx, request, rc)
 84}
 85
 86type ResolverContext struct {
 87	Parent *ResolverContext
 88	// The name of the type this field belongs to
 89	Object string
 90	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
 91	Args map[string]interface{}
 92	// The raw field
 93	Field CollectedField
 94	// The index of array in path.
 95	Index *int
 96	// The result object of resolver
 97	Result interface{}
 98}
 99
100func (r *ResolverContext) Path() []interface{} {
101	var path []interface{}
102	for it := r; it != nil; it = it.Parent {
103		if it.Index != nil {
104			path = append(path, *it.Index)
105		} else if it.Field.Field != nil {
106			path = append(path, it.Field.Alias)
107		}
108	}
109
110	// because we are walking up the chain, all the elements are backwards, do an inplace flip.
111	for i := len(path)/2 - 1; i >= 0; i-- {
112		opp := len(path) - 1 - i
113		path[i], path[opp] = path[opp], path[i]
114	}
115
116	return path
117}
118
119func GetResolverContext(ctx context.Context) *ResolverContext {
120	val, _ := ctx.Value(resolver).(*ResolverContext)
121	return val
122}
123
124func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
125	rc.Parent = GetResolverContext(ctx)
126	return context.WithValue(ctx, resolver, rc)
127}
128
129// This is just a convenient wrapper method for CollectFields
130func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
131	resctx := GetResolverContext(ctx)
132	return CollectFields(ctx, resctx.Field.Selections, satisfies)
133}
134
135// Errorf sends an error string to the client, passing it through the formatter.
136func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
137	c.errorsMu.Lock()
138	defer c.errorsMu.Unlock()
139
140	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
141}
142
143// Error sends an error to the client, passing it through the formatter.
144func (c *RequestContext) Error(ctx context.Context, err error) {
145	c.errorsMu.Lock()
146	defer c.errorsMu.Unlock()
147
148	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
149}
150
151// HasError returns true if the current field has already errored
152func (c *RequestContext) HasError(rctx *ResolverContext) bool {
153	c.errorsMu.Lock()
154	defer c.errorsMu.Unlock()
155	path := rctx.Path()
156
157	for _, err := range c.Errors {
158		if equalPath(err.Path, path) {
159			return true
160		}
161	}
162	return false
163}
164
165// GetErrors returns a list of errors that occurred in the current field
166func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
167	c.errorsMu.Lock()
168	defer c.errorsMu.Unlock()
169	path := rctx.Path()
170
171	var errs gqlerror.List
172	for _, err := range c.Errors {
173		if equalPath(err.Path, path) {
174			errs = append(errs, err)
175		}
176	}
177	return errs
178}
179
// equalPath reports whether a and b have the same length and pairwise-equal
// elements. Elements are compared with == on interface values, which panics
// if two elements hold the same non-comparable dynamic type.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(b) != len(a) {
		return false
	}
	for i, elem := range a {
		if elem != b[i] {
			return false
		}
	}
	return true
}
193
194// AddError is a convenience method for adding an error to the current response
195func AddError(ctx context.Context, err error) {
196	GetRequestContext(ctx).Error(ctx, err)
197}
198
199// AddErrorf is a convenience method for adding an error to the current response
200func AddErrorf(ctx context.Context, format string, args ...interface{}) {
201	GetRequestContext(ctx).Errorf(ctx, format, args...)
202}
203
204// RegisterExtension registers an extension, returns error if extension has already been registered
205func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
206	c.extensionsMu.Lock()
207	defer c.extensionsMu.Unlock()
208
209	if c.Extensions == nil {
210		c.Extensions = make(map[string]interface{})
211	}
212
213	if _, ok := c.Extensions[key]; ok {
214		return fmt.Errorf("extension already registered for key %s", key)
215	}
216
217	c.Extensions[key] = value
218	return nil
219}