1package graphql
2
3import (
4 "context"
5 "fmt"
6 "sync"
7
8 "github.com/vektah/gqlgen/neelance/query"
9)
10
11type Resolver func(ctx context.Context) (res interface{}, err error)
12type ResolverMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
13type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
14
// RequestContext holds per-request state: the raw query and its parsed
// document, the variables, the configured hooks (error presenter, recover
// function, middlewares) and the errors collected so far. It is attached to a
// context.Context with WithRequestContext and read back with
// GetRequestContext.
type RequestContext struct {
	// RawQuery is the unparsed query string.
	RawQuery string
	// Variables are the variables supplied alongside the query.
	Variables map[string]interface{}
	// Doc is the parsed query document.
	Doc *query.Document
	// ErrorPresenter converts every error passed to Error or Errorf into the
	// *Error value that is appended to Errors.
	ErrorPresenter ErrorPresenterFunc
	// Recover is the RecoverFunc for this request; NewRequestContext sets it to
	// DefaultRecover. NOTE(review): presumably invoked on panics during
	// execution — confirm with callers.
	Recover RecoverFunc
	// ResolverMiddleware wraps field resolvers; NewRequestContext sets it to
	// DefaultResolverMiddleware.
	ResolverMiddleware ResolverMiddleware
	// RequestMiddleware wraps request execution; NewRequestContext sets it to
	// DefaultRequestMiddleware.
	RequestMiddleware RequestMiddleware

	// errorsMu serialises appends to Errors made by Error and Errorf.
	errorsMu sync.Mutex
	// Errors holds the presented errors. Error and Errorf append to it under
	// errorsMu; direct reads or writes of this field are not synchronised.
	Errors []*Error
}
29
// DefaultResolverMiddleware is the pass-through ResolverMiddleware: it calls
// next with the unchanged context and hands back whatever next produced.
func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
	res, err = next(ctx)
	return res, err
}
33
// DefaultRequestMiddleware is the pass-through RequestMiddleware: it calls
// next with the unchanged context and returns its bytes untouched.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	out := next(ctx)
	return out
}
37
38func NewRequestContext(doc *query.Document, query string, variables map[string]interface{}) *RequestContext {
39 return &RequestContext{
40 Doc: doc,
41 RawQuery: query,
42 Variables: variables,
43 ResolverMiddleware: DefaultResolverMiddleware,
44 RequestMiddleware: DefaultRequestMiddleware,
45 Recover: DefaultRecover,
46 ErrorPresenter: DefaultErrorPresenter,
47 }
48}
49
// key is the unexported type used for this package's context keys. No other
// package can name it, so values stored under these keys cannot collide with
// context values set elsewhere.
type key string

// request is the context key under which WithRequestContext stores the
// *RequestContext.
const request key = "request_context"

// resolver is the context key under which WithResolverContext stores the
// *ResolverContext.
const resolver key = "resolver_context"
56
// GetRequestContext returns the RequestContext that WithRequestContext stored
// in ctx, or nil when ctx carries none.
func GetRequestContext(ctx context.Context) *RequestContext {
	if val := ctx.Value(request); val != nil {
		return val.(*RequestContext)
	}
	return nil
}
65
// WithRequestContext returns a context derived from ctx that carries rc.
// GetRequestContext retrieves it.
func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
	derived := context.WithValue(ctx, request, rc)
	return derived
}
69
// ResolverContext describes the field currently being resolved. It is
// attached to a context.Context with WithResolverContext and read back with
// GetResolverContext.
type ResolverContext struct {
	// Object is the name of the type this field belongs to.
	Object string
	// Args are the arguments after processing. Middleware may mutate them to
	// change what the resolver receives.
	Args map[string]interface{}
	// Field is the raw collected field being resolved.
	Field CollectedField
	// Path is the sequence of fields and list positions leading to this
	// resolver. Entries are field aliases (string, added by PushField) or list
	// indices (int, added by PushIndex). WithResolverContext rebuilds it from
	// the parent ResolverContext's Path.
	Path []interface{}
}
80
81func (r *ResolverContext) PushField(alias string) {
82 r.Path = append(r.Path, alias)
83}
84
85func (r *ResolverContext) PushIndex(index int) {
86 r.Path = append(r.Path, index)
87}
88
89func (r *ResolverContext) Pop() {
90 r.Path = r.Path[0 : len(r.Path)-1]
91}
92
// GetResolverContext returns the ResolverContext that WithResolverContext
// stored in ctx, or nil when ctx carries none.
func GetResolverContext(ctx context.Context) *ResolverContext {
	if val := ctx.Value(resolver); val != nil {
		return val.(*ResolverContext)
	}
	return nil
}
101
// WithResolverContext returns a context derived from ctx that carries rc.
//
// It mutates rc. Any existing rc.Path is discarded and rebuilt as a fresh
// copy of the Path of the ResolverContext already in ctx (if any). Then
// rc.Field.Alias is pushed when it is non-empty. Copying rather than sharing
// the parent's slice means later pushes on rc cannot overwrite the parent's
// path.
func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
	parent := GetResolverContext(ctx)

	// Clear rc.Path before copying. If rc is itself the ResolverContext already
	// stored in ctx, its previous path is then not carried over.
	rc.Path = nil
	if parent != nil && len(parent.Path) > 0 {
		rc.Path = append([]interface{}(nil), parent.Path...)
	}

	if alias := rc.Field.Alias; alias != "" {
		rc.PushField(alias)
	}

	return context.WithValue(ctx, resolver, rc)
}
113
114// This is just a convenient wrapper method for CollectFields
115func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
116 reqctx := GetRequestContext(ctx)
117 resctx := GetResolverContext(ctx)
118 return CollectFields(reqctx.Doc, resctx.Field.Selections, satisfies, reqctx.Variables)
119}
120
// Errorf formats an error like fmt.Errorf, passes it through the request's
// ErrorPresenter and appends the result to c.Errors so it is sent to the
// client.
//
// The presenter runs before errorsMu is acquired, so the lock guards only the
// append. ErrorPresenter is caller-supplied code. Running it under the lock
// would serialise all concurrent error reporting behind it, and would
// deadlock if the presenter itself reported an error on this RequestContext.
// As a result, ErrorPresenter may be invoked concurrently and must be safe
// for concurrent use.
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
	presented := c.ErrorPresenter(ctx, fmt.Errorf(format, args...))

	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	c.Errors = append(c.Errors, presented)
}
128
// Error passes err through the request's ErrorPresenter and appends the
// result to c.Errors so it is sent to the client.
//
// The presenter runs before errorsMu is acquired, so the lock guards only the
// append. ErrorPresenter is caller-supplied code. Running it under the lock
// would serialise all concurrent error reporting behind it, and would
// deadlock if the presenter itself reported an error on this RequestContext.
// As a result, ErrorPresenter may be invoked concurrently and must be safe
// for concurrent use.
func (c *RequestContext) Error(ctx context.Context, err error) {
	presented := c.ErrorPresenter(ctx, err)

	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	c.Errors = append(c.Errors, presented)
}
136
// AddError records err on the RequestContext stored in ctx; see
// RequestContext.Error. ctx must carry a RequestContext (see
// WithRequestContext). Otherwise the call panics with a nil pointer
// dereference.
func AddError(ctx context.Context, err error) {
	rc := GetRequestContext(ctx)
	rc.Error(ctx, err)
}
141
// AddErrorf formats an error like fmt.Errorf and records it on the
// RequestContext stored in ctx; see RequestContext.Errorf. ctx must carry a
// RequestContext (see WithRequestContext). Otherwise the call panics with a
// nil pointer dereference.
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
	rc := GetRequestContext(ctx)
	rc.Errorf(ctx, format, args...)
}