1package graphql
2
3import (
4 "context"
5 "fmt"
6 "sync"
7
8 "github.com/vektah/gqlparser/ast"
9 "github.com/vektah/gqlparser/gqlerror"
10)
11
type (
	// Resolver is a function that, given a context, returns a value or an
	// error. It is the unit that a FieldMiddleware wraps.
	Resolver func(ctx context.Context) (res interface{}, err error)

	// FieldMiddleware wraps a Resolver. It receives the context and the next
	// Resolver in the chain, and returns the (possibly altered) result.
	FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

	// RequestMiddleware wraps request-level work. next returns a []byte result
	// which the middleware may pass through unchanged or replace.
	RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
)
15
// RequestContext holds per-request state: the query text and variables, the
// parsed document, the pluggable hooks (error presenter, recover function,
// middlewares), and the errors collected so far. NewRequestContext builds one
// with every hook set to its package default.
type RequestContext struct {
	// RawQuery is the query string passed to NewRequestContext.
	RawQuery string
	// Variables are the variables passed to NewRequestContext.
	Variables map[string]interface{}
	// Doc is the parsed query document passed to NewRequestContext.
	Doc *ast.QueryDocument
	// ErrorPresenter will be used to generate the error
	// message from errors given to Error() and Errorf().
	ErrorPresenter ErrorPresenterFunc
	// Recover defaults to DefaultRecover. NOTE(review): presumably invoked on
	// panics during execution; the call site is not in this file.
	Recover RecoverFunc
	// ResolverMiddleware defaults to DefaultResolverMiddleware, which simply
	// calls next. NOTE(review): presumably wrapped around each field resolver
	// by code outside this file — confirm.
	ResolverMiddleware FieldMiddleware
	// DirectiveMiddleware defaults to DefaultDirectiveMiddleware, which simply
	// calls next.
	DirectiveMiddleware FieldMiddleware
	// RequestMiddleware defaults to DefaultRequestMiddleware, which simply
	// calls next.
	RequestMiddleware RequestMiddleware

	// errorsMu serializes access to Errors from Error, Errorf and HasError.
	errorsMu sync.Mutex
	// Errors accumulates the presented errors for this request.
	// NOTE(review): the field is exported, so direct reads/writes bypass
	// errorsMu; only touch it directly once no resolver can still add errors.
	Errors gqlerror.List
}
31
32func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
33 return next(ctx)
34}
35
36func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
37 return next(ctx)
38}
39
// DefaultRequestMiddleware is the RequestMiddleware that NewRequestContext
// installs by default. It calls next with ctx unchanged and returns the
// bytes next produced.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	body := next(ctx)
	return body
}
43
44func NewRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *RequestContext {
45 return &RequestContext{
46 Doc: doc,
47 RawQuery: query,
48 Variables: variables,
49 ResolverMiddleware: DefaultResolverMiddleware,
50 DirectiveMiddleware: DefaultDirectiveMiddleware,
51 RequestMiddleware: DefaultRequestMiddleware,
52 Recover: DefaultRecover,
53 ErrorPresenter: DefaultErrorPresenter,
54 }
55}
56
// key is the unexported type used for this package's context keys, so the
// values stored here cannot collide with keys defined by other packages.
type key string

const (
	// request is the context key under which the *RequestContext is stored.
	request = key("request_context")
	// resolver is the context key under which the *ResolverContext is stored.
	resolver = key("resolver_context")
)
63
64func GetRequestContext(ctx context.Context) *RequestContext {
65 val := ctx.Value(request)
66 if val == nil {
67 return nil
68 }
69
70 return val.(*RequestContext)
71}
72
73func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
74 return context.WithValue(ctx, request, rc)
75}
76
// ResolverContext describes one step in the chain of resolvers for the
// current position in the response. Contexts are linked upward through
// Parent (set by WithResolverContext), and Path walks that chain.
type ResolverContext struct {
	// Parent is the ResolverContext that was in ctx when this one was attached
	// with WithResolverContext; nil at the top of the chain.
	Parent *ResolverContext
	// The name of the type this field belongs to
	Object string
	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
	Args map[string]interface{}
	// The raw field. Path uses its Alias when Field.Field is non-nil.
	Field CollectedField
	// The index of array in path. When non-nil, Path emits *Index for this
	// step instead of the field alias.
	Index *int
	// The result object of resolver
	Result interface{}
}
90
91func (r *ResolverContext) Path() []interface{} {
92 var path []interface{}
93 for it := r; it != nil; it = it.Parent {
94 if it.Index != nil {
95 path = append(path, *it.Index)
96 } else if it.Field.Field != nil {
97 path = append(path, it.Field.Alias)
98 }
99 }
100
101 // because we are walking up the chain, all the elements are backwards, do an inplace flip.
102 for i := len(path)/2 - 1; i >= 0; i-- {
103 opp := len(path) - 1 - i
104 path[i], path[opp] = path[opp], path[i]
105 }
106
107 return path
108}
109
110func GetResolverContext(ctx context.Context) *ResolverContext {
111 val, _ := ctx.Value(resolver).(*ResolverContext)
112 return val
113}
114
115func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
116 rc.Parent = GetResolverContext(ctx)
117 return context.WithValue(ctx, resolver, rc)
118}
119
120// This is just a convenient wrapper method for CollectFields
121func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
122 resctx := GetResolverContext(ctx)
123 return CollectFields(ctx, resctx.Field.Selections, satisfies)
124}
125
126// Errorf sends an error string to the client, passing it through the formatter.
127func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
128 c.errorsMu.Lock()
129 defer c.errorsMu.Unlock()
130
131 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, fmt.Errorf(format, args...)))
132}
133
134// Error sends an error to the client, passing it through the formatter.
135func (c *RequestContext) Error(ctx context.Context, err error) {
136 c.errorsMu.Lock()
137 defer c.errorsMu.Unlock()
138
139 c.Errors = append(c.Errors, c.ErrorPresenter(ctx, err))
140}
141
142// HasError returns true if the current field has already errored
143func (c *RequestContext) HasError(rctx *ResolverContext) bool {
144 c.errorsMu.Lock()
145 defer c.errorsMu.Unlock()
146 path := rctx.Path()
147
148 for _, err := range c.Errors {
149 if equalPath(err.Path, path) {
150 return true
151 }
152 }
153 return false
154}
155
// equalPath reports whether a and b have the same length and pairwise-equal
// elements under ==. A nil slice and an empty slice compare equal. As with
// any interface comparison, it panics if two compared elements hold the same
// uncomparable dynamic type.
func equalPath(a []interface{}, b []interface{}) bool {
	same := len(a) == len(b)
	for i := 0; same && i < len(a); i++ {
		same = a[i] == b[i]
	}
	return same
}
169
170// AddError is a convenience method for adding an error to the current response
171func AddError(ctx context.Context, err error) {
172 GetRequestContext(ctx).Error(ctx, err)
173}
174
175// AddErrorf is a convenience method for adding an error to the current response
176func AddErrorf(ctx context.Context, format string, args ...interface{}) {
177 GetRequestContext(ctx).Errorf(ctx, format, args...)
178}