1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "io/ioutil"
10 "mime"
11 "net/http"
12 "os"
13 "strconv"
14 "strings"
15 "time"
16
17 "github.com/99designs/gqlgen/complexity"
18 "github.com/99designs/gqlgen/graphql"
19 "github.com/gorilla/websocket"
20 lru "github.com/hashicorp/golang-lru"
21 "github.com/vektah/gqlparser/ast"
22 "github.com/vektah/gqlparser/gqlerror"
23 "github.com/vektah/gqlparser/parser"
24 "github.com/vektah/gqlparser/validator"
25)
26
// params is the decoded GraphQL request payload, whether it arrived as URL
// query parameters (GET), a JSON body, or the "operations" field of a
// multipart upload.
type params struct {
	// Query is the GraphQL document text.
	Query string `json:"query"`
	// OperationName selects which operation in Query to run.
	OperationName string `json:"operationName"`
	// Variables holds the raw variable values supplied by the client.
	Variables map[string]interface{} `json:"variables"`
}
32
33type Config struct {
34 cacheSize int
35 upgrader websocket.Upgrader
36 recover graphql.RecoverFunc
37 errorPresenter graphql.ErrorPresenterFunc
38 resolverHook graphql.FieldMiddleware
39 requestHook graphql.RequestMiddleware
40 tracer graphql.Tracer
41 complexityLimit int
42 complexityLimitFunc graphql.ComplexityLimitFunc
43 disableIntrospection bool
44 connectionKeepAlivePingInterval time.Duration
45 uploadMaxMemory int64
46 uploadMaxSize int64
47}
48
49func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
50 reqCtx := graphql.NewRequestContext(doc, query, variables)
51 reqCtx.DisableIntrospection = c.disableIntrospection
52
53 if hook := c.recover; hook != nil {
54 reqCtx.Recover = hook
55 }
56
57 if hook := c.errorPresenter; hook != nil {
58 reqCtx.ErrorPresenter = hook
59 }
60
61 if hook := c.resolverHook; hook != nil {
62 reqCtx.ResolverMiddleware = hook
63 }
64
65 if hook := c.requestHook; hook != nil {
66 reqCtx.RequestMiddleware = hook
67 }
68
69 if hook := c.tracer; hook != nil {
70 reqCtx.Tracer = hook
71 }
72
73 if c.complexityLimit > 0 || c.complexityLimitFunc != nil {
74 reqCtx.ComplexityLimit = c.complexityLimit
75 operationComplexity := complexity.Calculate(es, op, variables)
76 reqCtx.OperationComplexity = operationComplexity
77 }
78
79 return reqCtx
80}
81
82type Option func(cfg *Config)
83
84func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
85 return func(cfg *Config) {
86 cfg.upgrader = upgrader
87 }
88}
89
90func RecoverFunc(recover graphql.RecoverFunc) Option {
91 return func(cfg *Config) {
92 cfg.recover = recover
93 }
94}
95
96// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
97// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
98// implementation in graphql.DefaultErrorPresenter for an example.
99func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
100 return func(cfg *Config) {
101 cfg.errorPresenter = f
102 }
103}
104
105// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
106// want clients introspecting the full schema.
107func IntrospectionEnabled(enabled bool) Option {
108 return func(cfg *Config) {
109 cfg.disableIntrospection = !enabled
110 }
111}
112
113// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
114// If a query is submitted that exceeds the limit, a 422 status code will be returned.
115func ComplexityLimit(limit int) Option {
116 return func(cfg *Config) {
117 cfg.complexityLimit = limit
118 }
119}
120
121// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed
122// to be executed.
123// If a query is submitted that exceeds the limit, a 422 status code will be returned.
124func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option {
125 return func(cfg *Config) {
126 cfg.complexityLimitFunc = complexityLimitFunc
127 }
128}
129
130// ResolverMiddleware allows you to define a function that will be called around every resolver,
131// useful for logging.
132func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
133 return func(cfg *Config) {
134 if cfg.resolverHook == nil {
135 cfg.resolverHook = middleware
136 return
137 }
138
139 lastResolve := cfg.resolverHook
140 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
141 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
142 return middleware(ctx, next)
143 })
144 }
145 }
146}
147
148// RequestMiddleware allows you to define a function that will be called around the root request,
149// after the query has been parsed. This is useful for logging
150func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
151 return func(cfg *Config) {
152 if cfg.requestHook == nil {
153 cfg.requestHook = middleware
154 return
155 }
156
157 lastResolve := cfg.requestHook
158 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
159 return lastResolve(ctx, func(ctx context.Context) []byte {
160 return middleware(ctx, next)
161 })
162 }
163 }
164}
165
166// Tracer allows you to add a request/resolver tracer that will be called around the root request,
167// calling resolver. This is useful for tracing
168func Tracer(tracer graphql.Tracer) Option {
169 return func(cfg *Config) {
170 if cfg.tracer == nil {
171 cfg.tracer = tracer
172
173 } else {
174 lastResolve := cfg.tracer
175 cfg.tracer = &tracerWrapper{
176 tracer1: lastResolve,
177 tracer2: tracer,
178 }
179 }
180
181 opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
182 ctx = tracer.StartOperationExecution(ctx)
183 resp := next(ctx)
184 tracer.EndOperationExecution(ctx)
185
186 return resp
187 })
188 opt(cfg)
189 }
190}
191
192type tracerWrapper struct {
193 tracer1 graphql.Tracer
194 tracer2 graphql.Tracer
195}
196
197func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
198 ctx = tw.tracer1.StartOperationParsing(ctx)
199 ctx = tw.tracer2.StartOperationParsing(ctx)
200 return ctx
201}
202
203func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
204 tw.tracer2.EndOperationParsing(ctx)
205 tw.tracer1.EndOperationParsing(ctx)
206}
207
208func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
209 ctx = tw.tracer1.StartOperationValidation(ctx)
210 ctx = tw.tracer2.StartOperationValidation(ctx)
211 return ctx
212}
213
214func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
215 tw.tracer2.EndOperationValidation(ctx)
216 tw.tracer1.EndOperationValidation(ctx)
217}
218
219func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
220 ctx = tw.tracer1.StartOperationExecution(ctx)
221 ctx = tw.tracer2.StartOperationExecution(ctx)
222 return ctx
223}
224
225func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
226 ctx = tw.tracer1.StartFieldExecution(ctx, field)
227 ctx = tw.tracer2.StartFieldExecution(ctx, field)
228 return ctx
229}
230
231func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
232 ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
233 ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
234 return ctx
235}
236
237func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
238 ctx = tw.tracer1.StartFieldChildExecution(ctx)
239 ctx = tw.tracer2.StartFieldChildExecution(ctx)
240 return ctx
241}
242
243func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
244 tw.tracer2.EndFieldExecution(ctx)
245 tw.tracer1.EndFieldExecution(ctx)
246}
247
248func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
249 tw.tracer2.EndOperationExecution(ctx)
250 tw.tracer1.EndOperationExecution(ctx)
251}
252
253// CacheSize sets the maximum size of the query cache.
254// If size is less than or equal to 0, the cache is disabled.
255func CacheSize(size int) Option {
256 return func(cfg *Config) {
257 cfg.cacheSize = size
258 }
259}
260
261// UploadMaxSize sets the maximum number of bytes used to parse a request body
262// as multipart/form-data.
263func UploadMaxSize(size int64) Option {
264 return func(cfg *Config) {
265 cfg.uploadMaxSize = size
266 }
267}
268
269// UploadMaxMemory sets the maximum number of bytes used to parse a request body
270// as multipart/form-data in memory, with the remainder stored on disk in
271// temporary files.
272func UploadMaxMemory(size int64) Option {
273 return func(cfg *Config) {
274 cfg.uploadMaxMemory = size
275 }
276}
277
278// WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior.
279// By default, keepalive is enabled with a DefaultConnectionKeepAlivePingInterval
280// duration. Set handler.connectionKeepAlivePingInterval = 0 to disable keepalive
281// altogether.
282func WebsocketKeepAliveDuration(duration time.Duration) Option {
283 return func(cfg *Config) {
284 cfg.connectionKeepAlivePingInterval = duration
285 }
286}
287
// Defaults installed by GraphQL before any Option is applied.
const (
	// DefaultCacheSize is the number of parsed query documents kept in the
	// query cache.
	DefaultCacheSize = 1000

	// DefaultConnectionKeepAlivePingInterval is the default interval between
	// websocket keepalive pings.
	DefaultConnectionKeepAlivePingInterval = 25 * time.Second

	// DefaultUploadMaxMemory (32 MiB) is the maximum number of bytes used to
	// parse a multipart/form-data request body in memory, with the remainder
	// stored on disk in temporary files.
	DefaultUploadMaxMemory = 32 << 20

	// DefaultUploadMaxSize (32 MiB) is the maximum number of bytes accepted in
	// a multipart/form-data request body.
	DefaultUploadMaxSize = 32 << 20
)
299
300func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
301 cfg := &Config{
302 cacheSize: DefaultCacheSize,
303 uploadMaxMemory: DefaultUploadMaxMemory,
304 uploadMaxSize: DefaultUploadMaxSize,
305 connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval,
306 upgrader: websocket.Upgrader{
307 ReadBufferSize: 1024,
308 WriteBufferSize: 1024,
309 },
310 }
311
312 for _, option := range options {
313 option(cfg)
314 }
315
316 var cache *lru.Cache
317 if cfg.cacheSize > 0 {
318 var err error
319 cache, err = lru.New(cfg.cacheSize)
320 if err != nil {
321 // An error is only returned for non-positive cache size
322 // and we already checked for that.
323 panic("unexpected error creating cache: " + err.Error())
324 }
325 }
326 if cfg.tracer == nil {
327 cfg.tracer = &graphql.NopTracer{}
328 }
329
330 handler := &graphqlHandler{
331 cfg: cfg,
332 cache: cache,
333 exec: exec,
334 }
335
336 return handler.ServeHTTP
337}
338
339var _ http.Handler = (*graphqlHandler)(nil)
340
341type graphqlHandler struct {
342 cfg *Config
343 cache *lru.Cache
344 exec graphql.ExecutableSchema
345}
346
347func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
348 if r.Method == http.MethodOptions {
349 w.Header().Set("Allow", "OPTIONS, GET, POST")
350 w.WriteHeader(http.StatusOK)
351 return
352 }
353
354 if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
355 connectWs(gh.exec, w, r, gh.cfg, gh.cache)
356 return
357 }
358
359 w.Header().Set("Content-Type", "application/json")
360 var reqParams params
361 switch r.Method {
362 case http.MethodGet:
363 reqParams.Query = r.URL.Query().Get("query")
364 reqParams.OperationName = r.URL.Query().Get("operationName")
365
366 if variables := r.URL.Query().Get("variables"); variables != "" {
367 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
368 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
369 return
370 }
371 }
372 case http.MethodPost:
373 mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
374 if err != nil {
375 sendErrorf(w, http.StatusBadRequest, "error parsing request Content-Type")
376 return
377 }
378
379 switch mediaType {
380 case "application/json":
381 if err := jsonDecode(r.Body, &reqParams); err != nil {
382 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
383 return
384 }
385
386 case "multipart/form-data":
387 var closers []io.Closer
388 var tmpFiles []string
389 defer func() {
390 for i := len(closers) - 1; 0 <= i; i-- {
391 _ = closers[i].Close()
392 }
393 for _, tmpFile := range tmpFiles {
394 _ = os.Remove(tmpFile)
395 }
396 }()
397 if err := processMultipart(w, r, &reqParams, &closers, &tmpFiles, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil {
398 sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error())
399 return
400 }
401 default:
402 sendErrorf(w, http.StatusBadRequest, "unsupported Content-Type: "+mediaType)
403 return
404 }
405 default:
406 w.WriteHeader(http.StatusMethodNotAllowed)
407 return
408 }
409
410 ctx := r.Context()
411
412 var doc *ast.QueryDocument
413 var cacheHit bool
414 if gh.cache != nil {
415 val, ok := gh.cache.Get(reqParams.Query)
416 if ok {
417 doc = val.(*ast.QueryDocument)
418 cacheHit = true
419 }
420 }
421
422 ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
423 Query: reqParams.Query,
424 CachedDoc: doc,
425 })
426 if gqlErr != nil {
427 sendError(w, http.StatusUnprocessableEntity, gqlErr)
428 return
429 }
430
431 ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
432 Doc: doc,
433 OperationName: reqParams.OperationName,
434 CacheHit: cacheHit,
435 R: r,
436 Variables: reqParams.Variables,
437 })
438 if len(listErr) != 0 {
439 sendError(w, http.StatusUnprocessableEntity, listErr...)
440 return
441 }
442
443 if gh.cache != nil && !cacheHit {
444 gh.cache.Add(reqParams.Query, doc)
445 }
446
447 reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
448 ctx = graphql.WithRequestContext(ctx, reqCtx)
449
450 defer func() {
451 if err := recover(); err != nil {
452 userErr := reqCtx.Recover(ctx, err)
453 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
454 }
455 }()
456
457 if gh.cfg.complexityLimitFunc != nil {
458 reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx)
459 }
460
461 if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
462 sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
463 return
464 }
465
466 switch op.Operation {
467 case ast.Query:
468 b, err := json.Marshal(gh.exec.Query(ctx, op))
469 if err != nil {
470 panic(err)
471 }
472 w.Write(b)
473 case ast.Mutation:
474 b, err := json.Marshal(gh.exec.Mutation(ctx, op))
475 if err != nil {
476 panic(err)
477 }
478 w.Write(b)
479 default:
480 sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
481 }
482}
483
484type parseOperationArgs struct {
485 Query string
486 CachedDoc *ast.QueryDocument
487}
488
489func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
490 ctx = gh.cfg.tracer.StartOperationParsing(ctx)
491 defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
492
493 if args.CachedDoc != nil {
494 return ctx, args.CachedDoc, nil
495 }
496
497 doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
498 if gqlErr != nil {
499 return ctx, nil, gqlErr
500 }
501
502 return ctx, doc, nil
503}
504
505type validateOperationArgs struct {
506 Doc *ast.QueryDocument
507 OperationName string
508 CacheHit bool
509 R *http.Request
510 Variables map[string]interface{}
511}
512
513func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
514 ctx = gh.cfg.tracer.StartOperationValidation(ctx)
515 defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
516
517 if !args.CacheHit {
518 listErr := validator.Validate(gh.exec.Schema(), args.Doc)
519 if len(listErr) != 0 {
520 return ctx, nil, nil, listErr
521 }
522 }
523
524 op := args.Doc.Operations.ForName(args.OperationName)
525 if op == nil {
526 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
527 }
528
529 if op.Operation != ast.Query && args.R.Method == http.MethodGet {
530 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
531 }
532
533 vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
534 if err != nil {
535 return ctx, nil, nil, gqlerror.List{err}
536 }
537
538 return ctx, op, vars, nil
539}
540
// jsonDecode decodes one JSON value from r into val. Numbers are decoded as
// json.Number rather than float64, so large integers keep their precision.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	if err := decoder.Decode(val); err != nil {
		return err
	}
	return nil
}
546
547func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
548 w.WriteHeader(code)
549 b, err := json.Marshal(&graphql.Response{Errors: errors})
550 if err != nil {
551 panic(err)
552 }
553 w.Write(b)
554}
555
556func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
557 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
558}
559
// bytesReader is a minimal io.Reader over a byte slice that may be shared by
// several readers; each reader keeps its own offset.
type bytesReader struct {
	s        *[]byte // shared data; Read fails if this is nil
	i        int64   // offset of the next byte to read
	prevRune int     // reset to -1 by every successful Read; not read by code visible here
}

// Read copies the next unread bytes into b. It returns io.EOF once all data
// has been consumed.
func (r *bytesReader) Read(b []byte) (n int, err error) {
	if r.s == nil {
		return 0, errors.New("byte slice pointer is nil")
	}
	data := *r.s
	if r.i >= int64(len(data)) {
		return 0, io.EOF
	}
	r.prevRune = -1
	n = copy(b, data[r.i:])
	r.i += int64(n)
	return n, nil
}
578
579func processMultipart(w http.ResponseWriter, r *http.Request, request *params, closers *[]io.Closer, tmpFiles *[]string, uploadMaxSize, uploadMaxMemory int64) error {
580 var err error
581 if r.ContentLength > uploadMaxSize {
582 return errors.New("failed to parse multipart form, request body too large")
583 }
584 r.Body = http.MaxBytesReader(w, r.Body, uploadMaxSize)
585 if err = r.ParseMultipartForm(uploadMaxMemory); err != nil {
586 if strings.Contains(err.Error(), "request body too large") {
587 return errors.New("failed to parse multipart form, request body too large")
588 }
589 return errors.New("failed to parse multipart form")
590 }
591 *closers = append(*closers, r.Body)
592
593 if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &request); err != nil {
594 return errors.New("operations form field could not be decoded")
595 }
596
597 var uploadsMap = map[string][]string{}
598 if err = json.Unmarshal([]byte(r.Form.Get("map")), &uploadsMap); err != nil {
599 return errors.New("map form field could not be decoded")
600 }
601
602 var upload graphql.Upload
603 for key, paths := range uploadsMap {
604 if len(paths) == 0 {
605 return fmt.Errorf("invalid empty operations paths list for key %s", key)
606 }
607 file, header, err := r.FormFile(key)
608 if err != nil {
609 return fmt.Errorf("failed to get key %s from form", key)
610 }
611 *closers = append(*closers, file)
612
613 if len(paths) == 1 {
614 upload = graphql.Upload{
615 File: file,
616 Size: header.Size,
617 Filename: header.Filename,
618 }
619 err = addUploadToOperations(request, upload, key, paths[0])
620 if err != nil {
621 return err
622 }
623 } else {
624 if r.ContentLength < uploadMaxMemory {
625 fileBytes, err := ioutil.ReadAll(file)
626 if err != nil {
627 return fmt.Errorf("failed to read file for key %s", key)
628 }
629 for _, path := range paths {
630 upload = graphql.Upload{
631 File: &bytesReader{s: &fileBytes, i: 0, prevRune: -1},
632 Size: header.Size,
633 Filename: header.Filename,
634 }
635 err = addUploadToOperations(request, upload, key, path)
636 if err != nil {
637 return err
638 }
639 }
640 } else {
641 tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-")
642 if err != nil {
643 return fmt.Errorf("failed to create temp file for key %s", key)
644 }
645 tmpName := tmpFile.Name()
646 *tmpFiles = append(*tmpFiles, tmpName)
647 _, err = io.Copy(tmpFile, file)
648 if err != nil {
649 if err := tmpFile.Close(); err != nil {
650 return fmt.Errorf("failed to copy to temp file and close temp file for key %s", key)
651 }
652 return fmt.Errorf("failed to copy to temp file for key %s", key)
653 }
654 if err := tmpFile.Close(); err != nil {
655 return fmt.Errorf("failed to close temp file for key %s", key)
656 }
657 for _, path := range paths {
658 pathTmpFile, err := os.Open(tmpName)
659 if err != nil {
660 return fmt.Errorf("failed to open temp file for key %s", key)
661 }
662 *closers = append(*closers, pathTmpFile)
663 upload = graphql.Upload{
664 File: pathTmpFile,
665 Size: header.Size,
666 Filename: header.Filename,
667 }
668 err = addUploadToOperations(request, upload, key, path)
669 if err != nil {
670 return err
671 }
672 }
673 }
674 }
675 }
676 return nil
677}
678
679func addUploadToOperations(request *params, upload graphql.Upload, key, path string) error {
680 if !strings.HasPrefix(path, "variables.") {
681 return fmt.Errorf("invalid operations paths for key %s", key)
682 }
683
684 var ptr interface{} = request.Variables
685 parts := strings.Split(path, ".")
686
687 // skip the first part (variables) because we started there
688 for i, p := range parts[1:] {
689 last := i == len(parts)-2
690 if ptr == nil {
691 return fmt.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path)
692 }
693 if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil {
694 if last {
695 ptr.([]interface{})[index] = upload
696 } else {
697 ptr = ptr.([]interface{})[index]
698 }
699 } else {
700 if last {
701 ptr.(map[string]interface{})[p] = upload
702 } else {
703 ptr = ptr.(map[string]interface{})[p]
704 }
705 }
706 }
707
708 return nil
709}