1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10
11 "github.com/99designs/gqlgen/complexity"
12 "github.com/99designs/gqlgen/graphql"
13 "github.com/gorilla/websocket"
14 "github.com/hashicorp/golang-lru"
15 "github.com/vektah/gqlparser/ast"
16 "github.com/vektah/gqlparser/gqlerror"
17 "github.com/vektah/gqlparser/parser"
18 "github.com/vektah/gqlparser/validator"
19)
20
// params holds the GraphQL request fields a client sends, either as the JSON
// body of a POST request or as URL query parameters of a GET request.
type params struct {
	// Query is the GraphQL document text.
	Query string `json:"query"`
	// OperationName selects which operation of Query should be run.
	OperationName string `json:"operationName"`
	// Variables holds the raw, not yet coerced, variable values.
	Variables map[string]interface{} `json:"variables"`
}
26
27type Config struct {
28 cacheSize int
29 upgrader websocket.Upgrader
30 recover graphql.RecoverFunc
31 errorPresenter graphql.ErrorPresenterFunc
32 resolverHook graphql.FieldMiddleware
33 requestHook graphql.RequestMiddleware
34 tracer graphql.Tracer
35 complexityLimit int
36 disableIntrospection bool
37}
38
39func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
40 reqCtx := graphql.NewRequestContext(doc, query, variables)
41 reqCtx.DisableIntrospection = c.disableIntrospection
42
43 if hook := c.recover; hook != nil {
44 reqCtx.Recover = hook
45 }
46
47 if hook := c.errorPresenter; hook != nil {
48 reqCtx.ErrorPresenter = hook
49 }
50
51 if hook := c.resolverHook; hook != nil {
52 reqCtx.ResolverMiddleware = hook
53 }
54
55 if hook := c.requestHook; hook != nil {
56 reqCtx.RequestMiddleware = hook
57 }
58
59 if hook := c.tracer; hook != nil {
60 reqCtx.Tracer = hook
61 } else {
62 reqCtx.Tracer = &graphql.NopTracer{}
63 }
64
65 if c.complexityLimit > 0 {
66 reqCtx.ComplexityLimit = c.complexityLimit
67 operationComplexity := complexity.Calculate(es, op, variables)
68 reqCtx.OperationComplexity = operationComplexity
69 }
70
71 return reqCtx
72}
73
74type Option func(cfg *Config)
75
76func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
77 return func(cfg *Config) {
78 cfg.upgrader = upgrader
79 }
80}
81
82func RecoverFunc(recover graphql.RecoverFunc) Option {
83 return func(cfg *Config) {
84 cfg.recover = recover
85 }
86}
87
88// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
89// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
90// implementation in graphql.DefaultErrorPresenter for an example.
91func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
92 return func(cfg *Config) {
93 cfg.errorPresenter = f
94 }
95}
96
97// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
98// want clients introspecting the full schema.
99func IntrospectionEnabled(enabled bool) Option {
100 return func(cfg *Config) {
101 cfg.disableIntrospection = !enabled
102 }
103}
104
105// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
106// If a query is submitted that exceeds the limit, a 422 status code will be returned.
107func ComplexityLimit(limit int) Option {
108 return func(cfg *Config) {
109 cfg.complexityLimit = limit
110 }
111}
112
113// ResolverMiddleware allows you to define a function that will be called around every resolver,
114// useful for logging.
115func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
116 return func(cfg *Config) {
117 if cfg.resolverHook == nil {
118 cfg.resolverHook = middleware
119 return
120 }
121
122 lastResolve := cfg.resolverHook
123 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
124 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
125 return middleware(ctx, next)
126 })
127 }
128 }
129}
130
131// RequestMiddleware allows you to define a function that will be called around the root request,
132// after the query has been parsed. This is useful for logging
133func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
134 return func(cfg *Config) {
135 if cfg.requestHook == nil {
136 cfg.requestHook = middleware
137 return
138 }
139
140 lastResolve := cfg.requestHook
141 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
142 return lastResolve(ctx, func(ctx context.Context) []byte {
143 return middleware(ctx, next)
144 })
145 }
146 }
147}
148
149// Tracer allows you to add a request/resolver tracer that will be called around the root request,
150// calling resolver. This is useful for tracing
151func Tracer(tracer graphql.Tracer) Option {
152 return func(cfg *Config) {
153 if cfg.tracer == nil {
154 cfg.tracer = tracer
155
156 } else {
157 lastResolve := cfg.tracer
158 cfg.tracer = &tracerWrapper{
159 tracer1: lastResolve,
160 tracer2: tracer,
161 }
162 }
163
164 opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
165 ctx = tracer.StartOperationExecution(ctx)
166 resp := next(ctx)
167 tracer.EndOperationExecution(ctx)
168
169 return resp
170 })
171 opt(cfg)
172 }
173}
174
175type tracerWrapper struct {
176 tracer1 graphql.Tracer
177 tracer2 graphql.Tracer
178}
179
180func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
181 ctx = tw.tracer1.StartOperationParsing(ctx)
182 ctx = tw.tracer2.StartOperationParsing(ctx)
183 return ctx
184}
185
186func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
187 tw.tracer2.EndOperationParsing(ctx)
188 tw.tracer1.EndOperationParsing(ctx)
189}
190
191func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
192 ctx = tw.tracer1.StartOperationValidation(ctx)
193 ctx = tw.tracer2.StartOperationValidation(ctx)
194 return ctx
195}
196
197func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
198 tw.tracer2.EndOperationValidation(ctx)
199 tw.tracer1.EndOperationValidation(ctx)
200}
201
202func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
203 ctx = tw.tracer1.StartOperationExecution(ctx)
204 ctx = tw.tracer2.StartOperationExecution(ctx)
205 return ctx
206}
207
208func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
209 ctx = tw.tracer1.StartFieldExecution(ctx, field)
210 ctx = tw.tracer2.StartFieldExecution(ctx, field)
211 return ctx
212}
213
214func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
215 ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
216 ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
217 return ctx
218}
219
220func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
221 ctx = tw.tracer1.StartFieldChildExecution(ctx)
222 ctx = tw.tracer2.StartFieldChildExecution(ctx)
223 return ctx
224}
225
226func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
227 tw.tracer2.EndFieldExecution(ctx)
228 tw.tracer1.EndFieldExecution(ctx)
229}
230
231func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
232 tw.tracer2.EndOperationExecution(ctx)
233 tw.tracer1.EndOperationExecution(ctx)
234}
235
236// CacheSize sets the maximum size of the query cache.
237// If size is less than or equal to 0, the cache is disabled.
238func CacheSize(size int) Option {
239 return func(cfg *Config) {
240 cfg.cacheSize = size
241 }
242}
243
// DefaultCacheSize is the query cache size GraphQL starts from when no
// CacheSize option is given.
const DefaultCacheSize = 1000
245
246func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
247 cfg := &Config{
248 cacheSize: DefaultCacheSize,
249 upgrader: websocket.Upgrader{
250 ReadBufferSize: 1024,
251 WriteBufferSize: 1024,
252 },
253 }
254
255 for _, option := range options {
256 option(cfg)
257 }
258
259 var cache *lru.Cache
260 if cfg.cacheSize > 0 {
261 var err error
262 cache, err = lru.New(DefaultCacheSize)
263 if err != nil {
264 // An error is only returned for non-positive cache size
265 // and we already checked for that.
266 panic("unexpected error creating cache: " + err.Error())
267 }
268 }
269 if cfg.tracer == nil {
270 cfg.tracer = &graphql.NopTracer{}
271 }
272
273 handler := &graphqlHandler{
274 cfg: cfg,
275 cache: cache,
276 exec: exec,
277 }
278
279 return handler.ServeHTTP
280}
281
282var _ http.Handler = (*graphqlHandler)(nil)
283
284type graphqlHandler struct {
285 cfg *Config
286 cache *lru.Cache
287 exec graphql.ExecutableSchema
288}
289
290func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
291 if r.Method == http.MethodOptions {
292 w.Header().Set("Allow", "OPTIONS, GET, POST")
293 w.WriteHeader(http.StatusOK)
294 return
295 }
296
297 if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
298 connectWs(gh.exec, w, r, gh.cfg)
299 return
300 }
301
302 var reqParams params
303 switch r.Method {
304 case http.MethodGet:
305 reqParams.Query = r.URL.Query().Get("query")
306 reqParams.OperationName = r.URL.Query().Get("operationName")
307
308 if variables := r.URL.Query().Get("variables"); variables != "" {
309 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
310 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
311 return
312 }
313 }
314 case http.MethodPost:
315 if err := jsonDecode(r.Body, &reqParams); err != nil {
316 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
317 return
318 }
319 default:
320 w.WriteHeader(http.StatusMethodNotAllowed)
321 return
322 }
323 w.Header().Set("Content-Type", "application/json")
324
325 ctx := r.Context()
326
327 var doc *ast.QueryDocument
328 var cacheHit bool
329 if gh.cache != nil {
330 val, ok := gh.cache.Get(reqParams.Query)
331 if ok {
332 doc = val.(*ast.QueryDocument)
333 cacheHit = true
334 }
335 }
336
337 ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
338 Query: reqParams.Query,
339 CachedDoc: doc,
340 })
341 if gqlErr != nil {
342 sendError(w, http.StatusUnprocessableEntity, gqlErr)
343 return
344 }
345
346 ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
347 Doc: doc,
348 OperationName: reqParams.OperationName,
349 CacheHit: cacheHit,
350 R: r,
351 Variables: reqParams.Variables,
352 })
353 if len(listErr) != 0 {
354 sendError(w, http.StatusUnprocessableEntity, listErr...)
355 return
356 }
357
358 if gh.cache != nil && !cacheHit {
359 gh.cache.Add(reqParams.Query, doc)
360 }
361
362 reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
363 ctx = graphql.WithRequestContext(ctx, reqCtx)
364
365 defer func() {
366 if err := recover(); err != nil {
367 userErr := reqCtx.Recover(ctx, err)
368 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
369 }
370 }()
371
372 if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
373 sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
374 return
375 }
376
377 switch op.Operation {
378 case ast.Query:
379 b, err := json.Marshal(gh.exec.Query(ctx, op))
380 if err != nil {
381 panic(err)
382 }
383 w.Write(b)
384 case ast.Mutation:
385 b, err := json.Marshal(gh.exec.Mutation(ctx, op))
386 if err != nil {
387 panic(err)
388 }
389 w.Write(b)
390 default:
391 sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
392 }
393}
394
395type parseOperationArgs struct {
396 Query string
397 CachedDoc *ast.QueryDocument
398}
399
400func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
401 ctx = gh.cfg.tracer.StartOperationParsing(ctx)
402 defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
403
404 if args.CachedDoc != nil {
405 return ctx, args.CachedDoc, nil
406 }
407
408 doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
409 if gqlErr != nil {
410 return ctx, nil, gqlErr
411 }
412
413 return ctx, doc, nil
414}
415
416type validateOperationArgs struct {
417 Doc *ast.QueryDocument
418 OperationName string
419 CacheHit bool
420 R *http.Request
421 Variables map[string]interface{}
422}
423
424func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
425 ctx = gh.cfg.tracer.StartOperationValidation(ctx)
426 defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
427
428 if !args.CacheHit {
429 listErr := validator.Validate(gh.exec.Schema(), args.Doc)
430 if len(listErr) != 0 {
431 return ctx, nil, nil, listErr
432 }
433 }
434
435 op := args.Doc.Operations.ForName(args.OperationName)
436 if op == nil {
437 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
438 }
439
440 if op.Operation != ast.Query && args.R.Method == http.MethodGet {
441 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
442 }
443
444 vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
445 if err != nil {
446 return ctx, nil, nil, gqlerror.List{err}
447 }
448
449 return ctx, op, vars, nil
450}
451
// jsonDecode reads one JSON value from r into val. Numbers are decoded as
// json.Number rather than float64, so their original literal text is kept.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	err := decoder.Decode(val)
	return err
}
457
458func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
459 w.WriteHeader(code)
460 b, err := json.Marshal(&graphql.Response{Errors: errors})
461 if err != nil {
462 panic(err)
463 }
464 w.Write(b)
465}
466
467func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
468 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
469}