1package graphql
2
3import (
4 "context"
5 "fmt"
6 "sync"
7
8 "github.com/vektah/gqlparser/ast"
9 "github.com/vektah/gqlparser/gqlerror"
10)
11
type (
	// Resolver computes the value of a single field for the given context.
	Resolver func(ctx context.Context) (res interface{}, err error)

	// FieldMiddleware wraps field resolution: it receives the context and the
	// next Resolver in the chain, and returns that resolver's (possibly
	// altered) result.
	FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

	// RequestMiddleware wraps handling of a whole request. next yields a []byte
	// (presumably the encoded response — TODO confirm against the executor).
	RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
)
15
// RequestContext holds the per-request state and configurable hooks used while
// executing a single GraphQL operation. NewRequestContext fills every hook with
// a package default.
//
// NOTE(review): Errors and Extensions are exported, but the methods on this type
// guard them with errorsMu / extensionsMu. Reading or writing them directly while
// resolvers may be running concurrently bypasses those locks.
type RequestContext struct {
	// RawQuery, Variables and Doc are the query text, variables and parsed
	// document given to NewRequestContext.
	RawQuery  string
	Variables map[string]interface{}
	Doc       *ast.QueryDocument

	// Complexity and introspection settings. They are not read in this part of
	// the file; presumably the executor checks OperationComplexity against
	// ComplexityLimit and honours DisableIntrospection — TODO confirm.
	ComplexityLimit      int
	OperationComplexity  int
	DisableIntrospection bool

	// ErrorPresenter converts errors given to Error() and Errorf() into the
	// entries appended to Errors.
	ErrorPresenter ErrorPresenterFunc
	// Recover defaults to DefaultRecover in NewRequestContext.
	Recover RecoverFunc
	// ResolverMiddleware, DirectiveMiddleware and RequestMiddleware default to
	// pass-through functions that just call next.
	ResolverMiddleware  FieldMiddleware
	DirectiveMiddleware FieldMiddleware
	RequestMiddleware   RequestMiddleware
	// Tracer defaults to a *NopTracer in NewRequestContext.
	Tracer Tracer

	// errorsMu guards Errors in Error, Errorf, HasError and GetErrors.
	errorsMu sync.Mutex
	Errors   gqlerror.List
	// extensionsMu guards Extensions in RegisterExtension.
	extensionsMu sync.Mutex
	Extensions   map[string]interface{}
}
39
40func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
41 return next(ctx)
42}
43
44func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
45 return next(ctx)
46}
47
// DefaultRequestMiddleware is the no-op RequestMiddleware that NewRequestContext
// installs: it calls next with the unchanged context and returns its bytes.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	out := next(ctx)
	return out
}
51
52func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
53 return &RequestContext{
54 Doc: doc,
55 RawQuery: query,
56 Variables: variables,
57 ResolverMiddleware: DefaultResolverMiddleware,
58 DirectiveMiddleware: DefaultDirectiveMiddleware,
59 RequestMiddleware: DefaultRequestMiddleware,
60 Recover: DefaultRecover,
61 ErrorPresenter: DefaultErrorPresenter,
62 Tracer: &NopTracer{},
63 }
64}
65
// key is an unexported type for context keys. Values stored under these keys
// cannot collide with keys defined by other packages.
type key string

// request and resolver are the context keys under which the RequestContext and
// the innermost ResolverContext are stored.
const request, resolver key = "request_context", "resolver_context"
72
73func GetRequestContext(ctx context.Context) *RequestContext {
74 val := ctx.Value(request)
75 if val == nil {
76 return nil
77 }
78
79 return val.(*RequestContext)
80}
81
82func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
83 return context.WithValue(ctx, request, rc)
84}
85
// ResolverContext describes the field currently being resolved. Contexts form a
// chain through Parent (set by WithResolverContext), which Path walks to build
// the field's response path.
type ResolverContext struct {
	// Parent is the enclosing ResolverContext, or nil when none was present in
	// the context passed to WithResolverContext.
	Parent *ResolverContext
	// Object is the name of the type this field belongs to.
	Object string
	// Args holds the arguments after processing; middleware may mutate them to
	// change what the resolver receives.
	Args map[string]interface{}
	// Field is the raw collected field; Path uses its Alias.
	Field CollectedField
	// Index, when non-nil, is this element's position within a list; Path emits
	// it in place of a field alias.
	Index *int
	// Result is the result object of the resolver (not set in this part of the
	// file).
	Result interface{}
}
99
100func (r *ResolverContext) Path() []interface{} {
101 var path []interface{}
102 for it := r; it != nil; it = it.Parent {
103 if it.Index != nil {
104 path = append(path, *it.Index)
105 } else if it.Field.Field != nil {
106 path = append(path, it.Field.Alias)
107 }
108 }
109
110 // because we are walking up the chain, all the elements are backwards, do an inplace flip.
111 for i := len(path)/2 - 1; i >= 0; i-- {
112 opp := len(path) - 1 - i
113 path[i], path[opp] = path[opp], path[i]
114 }
115
116 return path
117}
118
119func GetResolverContext(ctx context.Context) *ResolverContext {
120 val, _ := ctx.Value(resolver).(*ResolverContext)
121 return val
122}
123
124func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
125 rc.Parent = GetResolverContext(ctx)
126 return context.WithValue(ctx, resolver, rc)
127}
128
129// This is just a convenient wrapper method for CollectFields
130func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
131 resctx := GetResolverContext(ctx)
132 return CollectFields(ctx, resctx.Field.Selections, satisfies)
133}
134
135// Errorf sends an error string to the client, passing it through the formatter.
136func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
137 c.errorsMu.Lock()
138 defer c.errorsMu.Unlock()
139
140 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
141}
142
143// Error sends an error to the client, passing it through the formatter.
144func (c *RequestContext) Error(ctx context.Context, err error) {
145 c.errorsMu.Lock()
146 defer c.errorsMu.Unlock()
147
148 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
149}
150
151// HasError returns true if the current field has already errored
152func (c *RequestContext) HasError(rctx *ResolverContext) bool {
153 c.errorsMu.Lock()
154 defer c.errorsMu.Unlock()
155 path := rctx.Path()
156
157 for _, err := range c.Errors {
158 if equalPath(err.Path, path) {
159 return true
160 }
161 }
162 return false
163}
164
165// GetErrors returns a list of errors that occurred in the current field
166func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
167 c.errorsMu.Lock()
168 defer c.errorsMu.Unlock()
169 path := rctx.Path()
170
171 var errs gqlerror.List
172 for _, err := range c.Errors {
173 if equalPath(err.Path, path) {
174 errs = append(errs, err)
175 }
176 }
177 return errs
178}
179
// equalPath reports whether a and b have the same length and element-wise equal
// values under ==. Elements are compared as interface values, so differing
// dynamic types are unequal (int(1) != int64(1)). Comparing two elements that
// share a non-comparable dynamic type (slice, map, func) panics at run time.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for i, av := range a {
		if av != b[i] {
			return false
		}
	}
	return true
}
193
194// AddError is a convenience method for adding an error to the current response
195func AddError(ctx context.Context, err error) {
196 GetRequestContext(ctx).Error(ctx, err)
197}
198
199// AddErrorf is a convenience method for adding an error to the current response
200func AddErrorf(ctx context.Context, format string, args ...interface{}) {
201 GetRequestContext(ctx).Errorf(ctx, format, args...)
202}
203
204// RegisterExtension registers an extension, returns error if extension has already been registered
205func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
206 c.extensionsMu.Lock()
207 defer c.extensionsMu.Unlock()
208
209 if c.Extensions == nil {
210 c.Extensions = make(map[string]interface{})
211 }
212
213 if _, ok := c.Extensions[key]; ok {
214 return fmt.Errorf("extension already registered for key %s", key)
215 }
216
217 c.Extensions[key] = value
218 return nil
219}