context.go

  1package graphql
  2
  3import (
  4	"context"
  5	"fmt"
  6	"sync"
  7
  8	"github.com/vektah/gqlparser/ast"
  9	"github.com/vektah/gqlparser/gqlerror"
 10)
 11
 12type Resolver func(ctx context.Context) (res interface{}, err error)
 13type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
 14type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
 15type ComplexityLimitFunc func(ctx context.Context) int
 16
// RequestContext holds the state for one GraphQL request: the query being
// executed, the hooks used while executing it, and the errors and extensions
// collected along the way. NewRequestContext fills the hooks with this
// package's defaults; WithRequestContext and GetRequestContext store it in and
// fetch it from a context.Context.
type RequestContext struct {
	// RawQuery, Variables and Doc are the query text, the variable values and
	// the parsed document passed to NewRequestContext.
	RawQuery  string
	Variables map[string]interface{}
	Doc       *ast.QueryDocument

	// NOTE(review): these three fields are not read in this file. From their
	// names, ComplexityLimit presumably caps OperationComplexity and
	// DisableIntrospection presumably rejects introspection queries — confirm
	// against the executor.
	ComplexityLimit      int
	OperationComplexity  int
	DisableIntrospection bool

	// ErrorPresenter converts every error given to Error or Errorf into the
	// entry that is appended to Errors. Recover, the three middlewares and
	// Tracer are the remaining hooks; NewRequestContext sets them to
	// DefaultRecover, pass-through middlewares that just call next, and a
	// *NopTracer.
	ErrorPresenter      ErrorPresenterFunc
	Recover             RecoverFunc
	ResolverMiddleware  FieldMiddleware
	DirectiveMiddleware FieldMiddleware
	RequestMiddleware   RequestMiddleware
	Tracer              Tracer

	// errorsMu guards Errors inside Error, Errorf, HasError and GetErrors;
	// extensionsMu guards Extensions inside RegisterExtension, which also
	// allocates the map on first use.
	// NOTE(review): Errors and Extensions are exported, so code that touches
	// them directly bypasses the locks — only safe once no resolver can still
	// be writing to them.
	errorsMu     sync.Mutex
	Errors       gqlerror.List
	extensionsMu sync.Mutex
	Extensions   map[string]interface{}
}
 40
 41func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
 42	return next(ctx)
 43}
 44
 45func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
 46	return next(ctx)
 47}
 48
 49func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
 50	return next(ctx)
 51}
 52
 53func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
 54	return &RequestContext{
 55		Doc:                 doc,
 56		RawQuery:            query,
 57		Variables:           variables,
 58		ResolverMiddleware:  DefaultResolverMiddleware,
 59		DirectiveMiddleware: DefaultDirectiveMiddleware,
 60		RequestMiddleware:   DefaultRequestMiddleware,
 61		Recover:             DefaultRecover,
 62		ErrorPresenter:      DefaultErrorPresenter,
 63		Tracer:              &NopTracer{},
 64	}
 65}
 66
 67type key string
 68
 69const (
 70	request  key = "request_context"
 71	resolver key = "resolver_context"
 72)
 73
 74func GetRequestContext(ctx context.Context) *RequestContext {
 75	if val, ok := ctx.Value(request).(*RequestContext); ok {
 76		return val
 77	}
 78	return nil
 79}
 80
 81func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
 82	return context.WithValue(ctx, request, rc)
 83}
 84
// ResolverContext describes one step of the resolution path (a field, or an
// element of a list) and links to the step above it through Parent, forming a
// chain from the current position back to the root. WithResolverContext pushes
// a new step onto a context.Context, and Path renders the chain.
type ResolverContext struct {
	// Parent is the enclosing step; WithResolverContext sets it from the
	// context the step is attached to, so it is nil at the root.
	Parent *ResolverContext
	// Object is the name of the type this field belongs to.
	Object string
	// Args holds the field arguments after processing; middleware may mutate
	// them to change what the resolver receives.
	Args map[string]interface{}
	// Field is the collected field being resolved; Path uses its Alias when
	// Index is nil.
	Field CollectedField
	// Index, when non-nil, marks this step as a list element; Path then emits
	// *Index for it instead of the field alias.
	Index *int
	// Result holds the resolver's result. NOTE(review): not written in this
	// file — confirm where it is populated.
	Result interface{}
	// IsMethod reports whether the field is resolved by a method (presumably as
	// opposed to a plain struct field — not read in this file).
	IsMethod bool
}
100
101func (r *ResolverContext) Path() []interface{} {
102	var path []interface{}
103	for it := r; it != nil; it = it.Parent {
104		if it.Index != nil {
105			path = append(path, *it.Index)
106		} else if it.Field.Field != nil {
107			path = append(path, it.Field.Alias)
108		}
109	}
110
111	// because we are walking up the chain, all the elements are backwards, do an inplace flip.
112	for i := len(path)/2 - 1; i >= 0; i-- {
113		opp := len(path) - 1 - i
114		path[i], path[opp] = path[opp], path[i]
115	}
116
117	return path
118}
119
120func GetResolverContext(ctx context.Context) *ResolverContext {
121	if val, ok := ctx.Value(resolver).(*ResolverContext); ok {
122		return val
123	}
124	return nil
125}
126
127func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
128	rc.Parent = GetResolverContext(ctx)
129	return context.WithValue(ctx, resolver, rc)
130}
131
132// This is just a convenient wrapper method for CollectFields
133func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
134	resctx := GetResolverContext(ctx)
135	return CollectFields(ctx, resctx.Field.Selections, satisfies)
136}
137
138// CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
139// The slice will contain the unique set of all field names requested regardless of fragment type conditions.
140func CollectAllFields(ctx context.Context) []string {
141	resctx := GetResolverContext(ctx)
142	collected := CollectFields(ctx, resctx.Field.Selections, nil)
143	uniq := make([]string, 0, len(collected))
144Next:
145	for _, f := range collected {
146		for _, name := range uniq {
147			if name == f.Name {
148				continue Next
149			}
150		}
151		uniq = append(uniq, f.Name)
152	}
153	return uniq
154}
155
156// Errorf sends an error string to the client, passing it through the formatter.
157func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
158	c.errorsMu.Lock()
159	defer c.errorsMu.Unlock()
160
161	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
162}
163
164// Error sends an error to the client, passing it through the formatter.
165func (c *RequestContext) Error(ctx context.Context, err error) {
166	c.errorsMu.Lock()
167	defer c.errorsMu.Unlock()
168
169	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
170}
171
172// HasError returns true if the current field has already errored
173func (c *RequestContext) HasError(rctx *ResolverContext) bool {
174	c.errorsMu.Lock()
175	defer c.errorsMu.Unlock()
176	path := rctx.Path()
177
178	for _, err := range c.Errors {
179		if equalPath(err.Path, path) {
180			return true
181		}
182	}
183	return false
184}
185
186// GetErrors returns a list of errors that occurred in the current field
187func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
188	c.errorsMu.Lock()
189	defer c.errorsMu.Unlock()
190	path := rctx.Path()
191
192	var errs gqlerror.List
193	for _, err := range c.Errors {
194		if equalPath(err.Path, path) {
195			errs = append(errs, err)
196		}
197	}
198	return errs
199}
200
// equalPath reports whether a and b have the same length and element-wise
// equal entries under Go's == on interface values. Entries must match in
// dynamic type as well as value (int 1 differs from int64 1). Comparing two
// entries that share a non-comparable dynamic type, such as a slice, panics.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for i, segment := range a {
		if segment != b[i] {
			return false
		}
	}
	return true
}
214
215// AddError is a convenience method for adding an error to the current response
216func AddError(ctx context.Context, err error) {
217	GetRequestContext(ctx).Error(ctx, err)
218}
219
220// AddErrorf is a convenience method for adding an error to the current response
221func AddErrorf(ctx context.Context, format string, args ...interface{}) {
222	GetRequestContext(ctx).Errorf(ctx, format, args...)
223}
224
225// RegisterExtension registers an extension, returns error if extension has already been registered
226func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
227	c.extensionsMu.Lock()
228	defer c.extensionsMu.Unlock()
229
230	if c.Extensions == nil {
231		c.Extensions = make(map[string]interface{})
232	}
233
234	if _, ok := c.Extensions[key]; ok {
235		return fmt.Errorf("extension already registered for key %s", key)
236	}
237
238	c.Extensions[key] = value
239	return nil
240}
241
242// ChainFieldMiddleware add chain by FieldMiddleware
243func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
244	n := len(handleFunc)
245
246	if n > 1 {
247		lastI := n - 1
248		return func(ctx context.Context, next Resolver) (interface{}, error) {
249			var (
250				chainHandler Resolver
251				curI         int
252			)
253			chainHandler = func(currentCtx context.Context) (interface{}, error) {
254				if curI == lastI {
255					return next(currentCtx)
256				}
257				curI++
258				res, err := handleFunc[curI](currentCtx, chainHandler)
259				curI--
260				return res, err
261
262			}
263			return handleFunc[0](ctx, chainHandler)
264		}
265	}
266
267	if n == 1 {
268		return handleFunc[0]
269	}
270
271	return func(ctx context.Context, next Resolver) (interface{}, error) {
272		return next(ctx)
273	}
274}