1package graphql
2
3import (
4 "context"
5 "fmt"
6 "sync"
7
8 "github.com/vektah/gqlparser/ast"
9 "github.com/vektah/gqlparser/gqlerror"
10)
11
type (
	// Resolver computes a result value, or an error, for the given context.
	Resolver func(ctx context.Context) (res interface{}, err error)

	// FieldMiddleware wraps a Resolver: it is handed the downstream resolver as
	// next and decides whether, when and with which context to invoke it.
	FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

	// RequestMiddleware wraps a function that produces the response for a
	// request as bytes (presumably the serialized response — confirm with callers).
	RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte

	// ComplexityLimitFunc returns an integer limit for the given context;
	// presumably the maximum allowed query complexity — confirm with callers.
	ComplexityLimitFunc func(ctx context.Context) int
)
16
// RequestContext holds per-request state: the query and its variables, the
// configurable hooks (middleware, error presenter, recover, tracer), and the
// errors and extensions recorded through its methods. It embeds sync.Mutex
// values, so it must not be copied after first use; NewRequestContext returns a
// pointer for that reason.
type RequestContext struct {
	// RawQuery is the query string passed to NewRequestContext.
	RawQuery string
	// Variables are the variables passed to NewRequestContext.
	Variables map[string]interface{}
	// Doc is the query document passed to NewRequestContext (presumably the
	// parsed form of RawQuery — confirm with callers).
	Doc *ast.QueryDocument

	// NOTE(review): the complexity and introspection settings below are not
	// read anywhere in this file; presumably the executor consults them —
	// confirm their exact semantics before relying on them.
	ComplexityLimit      int
	OperationComplexity  int
	DisableIntrospection bool

	// ErrorPresenter will be used to generate the error
	// message from errors given to Error() and Errorf().
	ErrorPresenter ErrorPresenterFunc
	// Recover defaults to DefaultRecover in NewRequestContext; it is not
	// invoked in this file.
	Recover RecoverFunc
	// The middleware hooks default to pass-through implementations
	// (DefaultResolverMiddleware, DefaultDirectiveMiddleware,
	// DefaultRequestMiddleware) in NewRequestContext.
	ResolverMiddleware  FieldMiddleware
	DirectiveMiddleware FieldMiddleware
	RequestMiddleware   RequestMiddleware
	// Tracer defaults to &NopTracer{} in NewRequestContext.
	Tracer Tracer

	// errorsMu guards Errors inside Error, Errorf, HasError and GetErrors.
	errorsMu sync.Mutex
	// Errors collects the presented errors appended by Error/Errorf. Because the
	// field is exported, direct reads or writes bypass errorsMu; only touch it
	// directly once nothing can still be recording errors.
	Errors gqlerror.List
	// extensionsMu guards Extensions inside RegisterExtension.
	extensionsMu sync.Mutex
	// Extensions holds values registered via RegisterExtension, which allocates
	// the map on first use. Direct access bypasses extensionsMu.
	Extensions map[string]interface{}
}
40
41func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
42 return next(ctx)
43}
44
45func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
46 return next(ctx)
47}
48
// DefaultRequestMiddleware is the pass-through request middleware: it calls
// next with the unmodified ctx and returns its bytes as-is.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	body := next(ctx)
	return body
}
52
53func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
54 return &RequestContext{
55 Doc: doc,
56 RawQuery: query,
57 Variables: variables,
58 ResolverMiddleware: DefaultResolverMiddleware,
59 DirectiveMiddleware: DefaultDirectiveMiddleware,
60 RequestMiddleware: DefaultRequestMiddleware,
61 Recover: DefaultRecover,
62 ErrorPresenter: DefaultErrorPresenter,
63 Tracer: &NopTracer{},
64 }
65}
66
// key is the unexported type used for this package's context keys; because no
// other package can name it, values stored under these keys cannot collide
// with context values set elsewhere.
type key string

// Context keys under which the resolver and request contexts are stored.
const (
	resolver key = "resolver_context"
	request  key = "request_context"
)
73
74func GetRequestContext(ctx context.Context) *RequestContext {
75 if val, ok := ctx.Value(request).(*RequestContext); ok {
76 return val
77 }
78 return nil
79}
80
81func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
82 return context.WithValue(ctx, request, rc)
83}
84
// ResolverContext describes one level of field resolution. Levels are chained
// through Parent (linked by WithResolverContext), and Path walks that chain to
// build the path of the current field.
type ResolverContext struct {
	// Parent is the resolver context ctx already carried when this one was
	// attached with WithResolverContext; nil for the outermost level.
	Parent *ResolverContext
	// The name of the type this field belongs to
	Object string
	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
	Args map[string]interface{}
	// The raw field. Path contributes its Alias when Field.Field is non-nil
	// and Index is nil.
	Field CollectedField
	// The index of array in path. When non-nil, Path contributes this index
	// for the level instead of the field alias.
	Index *int
	// The result object of resolver
	Result interface{}
	// IsMethod indicates if the resolver is a method
	IsMethod bool
}
100
101func (r *ResolverContext) Path() []interface{} {
102 var path []interface{}
103 for it := r; it != nil; it = it.Parent {
104 if it.Index != nil {
105 path = append(path, *it.Index)
106 } else if it.Field.Field != nil {
107 path = append(path, it.Field.Alias)
108 }
109 }
110
111 // because we are walking up the chain, all the elements are backwards, do an inplace flip.
112 for i := len(path)/2 - 1; i >= 0; i-- {
113 opp := len(path) - 1 - i
114 path[i], path[opp] = path[opp], path[i]
115 }
116
117 return path
118}
119
120func GetResolverContext(ctx context.Context) *ResolverContext {
121 if val, ok := ctx.Value(resolver).(*ResolverContext); ok {
122 return val
123 }
124 return nil
125}
126
127func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
128 rc.Parent = GetResolverContext(ctx)
129 return context.WithValue(ctx, resolver, rc)
130}
131
132// This is just a convenient wrapper method for CollectFields
133func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
134 resctx := GetResolverContext(ctx)
135 return CollectFields(ctx, resctx.Field.Selections, satisfies)
136}
137
138// CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
139// The slice will contain the unique set of all field names requested regardless of fragment type conditions.
140func CollectAllFields(ctx context.Context) []string {
141 resctx := GetResolverContext(ctx)
142 collected := CollectFields(ctx, resctx.Field.Selections, nil)
143 uniq := make([]string, 0, len(collected))
144Next:
145 for _, f := range collected {
146 for _, name := range uniq {
147 if name == f.Name {
148 continue Next
149 }
150 }
151 uniq = append(uniq, f.Name)
152 }
153 return uniq
154}
155
156// Errorf sends an error string to the client, passing it through the formatter.
157func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
158 c.errorsMu.Lock()
159 defer c.errorsMu.Unlock()
160
161 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
162}
163
164// Error sends an error to the client, passing it through the formatter.
165func (c *RequestContext) Error(ctx context.Context, err error) {
166 c.errorsMu.Lock()
167 defer c.errorsMu.Unlock()
168
169 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
170}
171
172// HasError returns true if the current field has already errored
173func (c *RequestContext) HasError(rctx *ResolverContext) bool {
174 c.errorsMu.Lock()
175 defer c.errorsMu.Unlock()
176 path := rctx.Path()
177
178 for _, err := range c.Errors {
179 if equalPath(err.Path, path) {
180 return true
181 }
182 }
183 return false
184}
185
186// GetErrors returns a list of errors that occurred in the current field
187func (c *RequestContext) GetErrors(rctx *ResolverContext) gqlerror.List {
188 c.errorsMu.Lock()
189 defer c.errorsMu.Unlock()
190 path := rctx.Path()
191
192 var errs gqlerror.List
193 for _, err := range c.Errors {
194 if equalPath(err.Path, path) {
195 errs = append(errs, err)
196 }
197 }
198 return errs
199}
200
// equalPath reports whether a and b have the same length and pairwise-equal
// elements under interface equality (same dynamic type and value). A nil and
// an empty path compare equal.
func equalPath(a []interface{}, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for i, elem := range a {
		if elem != b[i] {
			return false
		}
	}
	return true
}
214
215// AddError is a convenience method for adding an error to the current response
216func AddError(ctx context.Context, err error) {
217 GetRequestContext(ctx).Error(ctx, err)
218}
219
220// AddErrorf is a convenience method for adding an error to the current response
221func AddErrorf(ctx context.Context, format string, args ...interface{}) {
222 GetRequestContext(ctx).Errorf(ctx, format, args...)
223}
224
225// RegisterExtension registers an extension, returns error if extension has already been registered
226func (c *RequestContext) RegisterExtension(key string, value interface{}) error {
227 c.extensionsMu.Lock()
228 defer c.extensionsMu.Unlock()
229
230 if c.Extensions == nil {
231 c.Extensions = make(map[string]interface{})
232 }
233
234 if _, ok := c.Extensions[key]; ok {
235 return fmt.Errorf("extension already registered for key %s", key)
236 }
237
238 c.Extensions[key] = value
239 return nil
240}
241
242// ChainFieldMiddleware add chain by FieldMiddleware
243func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
244 n := len(handleFunc)
245
246 if n > 1 {
247 lastI := n - 1
248 return func(ctx context.Context, next Resolver) (interface{}, error) {
249 var (
250 chainHandler Resolver
251 curI int
252 )
253 chainHandler = func(currentCtx context.Context) (interface{}, error) {
254 if curI == lastI {
255 return next(currentCtx)
256 }
257 curI++
258 res, err := handleFunc[curI](currentCtx, chainHandler)
259 curI--
260 return res, err
261
262 }
263 return handleFunc[0](ctx, chainHandler)
264 }
265 }
266
267 if n == 1 {
268 return handleFunc[0]
269 }
270
271 return func(ctx context.Context, next Resolver) (interface{}, error) {
272 return next(ctx)
273 }
274}