context.go

  1package graphql
  2
  3import (
  4	"context"
  5	"fmt"
  6	"sync"
  7
  8	"github.com/vektah/gqlparser/ast"
  9	"github.com/vektah/gqlparser/gqlerror"
 10)
 11
 12type Resolver func(ctx context.Context) (res interface{}, err error)
 13type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
 14type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
 15
 16type RequestContext struct {
 17	RawQuery  string
 18	Variables map[string]interface{}
 19	Doc       *ast.QueryDocument
 20	// ErrorPresenter will be used to generate the error
 21	// message from errors given to Error().
 22	ErrorPresenter      ErrorPresenterFunc
 23	Recover             RecoverFunc
 24	ResolverMiddleware  FieldMiddleware
 25	DirectiveMiddleware FieldMiddleware
 26	RequestMiddleware   RequestMiddleware
 27
 28	errorsMu sync.Mutex
 29	Errors   gqlerror.List
 30}
 31
 32func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
 33	return next(ctx)
 34}
 35
 36func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
 37	return next(ctx)
 38}
 39
 40func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
 41	return next(ctx)
 42}
 43
 44func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
 45	return &RequestContext{
 46		Doc:                 doc,
 47		RawQuery:            query,
 48		Variables:           variables,
 49		ResolverMiddleware:  DefaultResolverMiddleware,
 50		DirectiveMiddleware: DefaultDirectiveMiddleware,
 51		RequestMiddleware:   DefaultRequestMiddleware,
 52		Recover:             DefaultRecover,
 53		ErrorPresenter:      DefaultErrorPresenter,
 54	}
 55}
 56
 57type key string
 58
 59const (
 60	request  key = "request_context"
 61	resolver key = "resolver_context"
 62)
 63
 64func GetRequestContext(ctx context.Context) *RequestContext {
 65	val := ctx.Value(request)
 66	if val == nil {
 67		return nil
 68	}
 69
 70	return val.(*RequestContext)
 71}
 72
 73func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
 74	return context.WithValue(ctx, request, rc)
 75}
 76
 77type ResolverContext struct {
 78	Parent *ResolverContext
 79	// The name of the type this field belongs to
 80	Object string
 81	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
 82	Args map[string]interface{}
 83	// The raw field
 84	Field CollectedField
 85	// The index of array in path.
 86	Index *int
 87	// The result object of resolver
 88	Result interface{}
 89}
 90
 91func (r *ResolverContext) Path() []interface{} {
 92	var path []interface{}
 93	for it := r; it != nil; it = it.Parent {
 94		if it.Index != nil {
 95			path = append(path, *it.Index)
 96		} else if it.Field.Field != nil {
 97			path = append(path, it.Field.Alias)
 98		}
 99	}
100
101	// because we are walking up the chain, all the elements are backwards, do an inplace flip.
102	for i := len(path)/2 - 1; i >= 0; i-- {
103		opp := len(path) - 1 - i
104		path[i], path[opp] = path[opp], path[i]
105	}
106
107	return path
108}
109
110func GetResolverContext(ctx context.Context) *ResolverContext {
111	val, _ := ctx.Value(resolver).(*ResolverContext)
112	return val
113}
114
115func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
116	rc.Parent = GetResolverContext(ctx)
117	return context.WithValue(ctx, resolver, rc)
118}
119
120// This is just a convenient wrapper method for CollectFields
121func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
122	resctx := GetResolverContext(ctx)
123	return CollectFields(ctx, resctx.Field.Selections, satisfies)
124}
125
126// Errorf sends an error string to the client, passing it through the formatter.
127func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
128	c.errorsMu.Lock()
129	defer c.errorsMu.Unlock()
130
131	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
132}
133
134// Error sends an error to the client, passing it through the formatter.
135func (c *RequestContext) Error(ctx context.Context, err error) {
136	c.errorsMu.Lock()
137	defer c.errorsMu.Unlock()
138
139	c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
140}
141
142// HasError returns true if the current field has already errored
143func (c *RequestContext) HasError(rctx *ResolverContext) bool {
144	c.errorsMu.Lock()
145	defer c.errorsMu.Unlock()
146	path := rctx.Path()
147
148	for _, err := range c.Errors {
149		if equalPath(err.Path, path) {
150			return true
151		}
152	}
153	return false
154}
155
// equalPath reports whether a and b have the same length and pairwise-equal
// elements, compared with the interface == operator. The dynamic types must
// match, so int(1) and int64(1) differ. A nil slice equals an empty one.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, elem := range a {
		if elem != b[idx] {
			return false
		}
	}
	return true
}
169
170// AddError is a convenience method for adding an error to the current response
171func AddError(ctx context.Context, err error) {
172	GetRequestContext(ctx).Error(ctx, err)
173}
174
175// AddErrorf is a convenience method for adding an error to the current response
176func AddErrorf(ctx context.Context, format string, args ...interface{}) {
177	GetRequestContext(ctx).Errorf(ctx, format, args...)
178}