package handler

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"mime"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/99designs/gqlgen/complexity"
	"github.com/99designs/gqlgen/graphql"
	"github.com/gorilla/websocket"
	lru "github.com/hashicorp/golang-lru"
	"github.com/vektah/gqlparser/ast"
	"github.com/vektah/gqlparser/gqlerror"
	"github.com/vektah/gqlparser/parser"
	"github.com/vektah/gqlparser/validator"
)

// params holds the fields of an incoming GraphQL request.
type params struct {
	Query         string                 `json:"query"`
	OperationName string                 `json:"operationName"`
	Variables     map[string]interface{} `json:"variables"`
	Extensions    *extensions            `json:"extensions"`
}

type extensions struct {
	PersistedQuery *persistedQuery `json:"persistedQuery"`
}

type persistedQuery struct {
	Sha256  string `json:"sha256Hash"`
	Version int64  `json:"version"`
}

const (
	errPersistedQueryNotSupported = "PersistedQueryNotSupported"
	errPersistedQueryNotFound     = "PersistedQueryNotFound"
)

// PersistedQueryCache is the interface implemented by stores that back
// automatic persisted queries (APQ); see EnablePersistedQueryCache.
type PersistedQueryCache interface {
	Add(ctx context.Context, hash string, query string)
	Get(ctx context.Context, hash string) (string, bool)
}
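
// A minimal in-memory PersistedQueryCache sketch, assuming a plain map guarded
// by a sync.RWMutex is acceptable for your deployment; the memoryAPQCache type
// below is hypothetical and not part of this package's API:
//
//	type memoryAPQCache struct {
//		mu      sync.RWMutex
//		queries map[string]string
//	}
//
//	func (c *memoryAPQCache) Add(ctx context.Context, hash, query string) {
//		c.mu.Lock()
//		defer c.mu.Unlock()
//		c.queries[hash] = query
//	}
//
//	func (c *memoryAPQCache) Get(ctx context.Context, hash string) (string, bool) {
//		c.mu.RLock()
//		defer c.mu.RUnlock()
//		q, ok := c.queries[hash]
//		return q, ok
//	}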

type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error

// Config holds the handler configuration assembled from Options.
type Config struct {
	cacheSize                       int
	upgrader                        websocket.Upgrader
	recover                         graphql.RecoverFunc
	errorPresenter                  graphql.ErrorPresenterFunc
	resolverHook                    graphql.FieldMiddleware
	requestHook                     graphql.RequestMiddleware
	tracer                          graphql.Tracer
	complexityLimit                 int
	complexityLimitFunc             graphql.ComplexityLimitFunc
	websocketInitFunc               websocketInitFunc
	disableIntrospection            bool
	connectionKeepAlivePingInterval time.Duration
	uploadMaxMemory                 int64
	uploadMaxSize                   int64
	apqCache                        PersistedQueryCache
}

func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
	reqCtx := graphql.NewRequestContext(doc, query, variables)
	reqCtx.DisableIntrospection = c.disableIntrospection

	if hook := c.recover; hook != nil {
		reqCtx.Recover = hook
	}

	if hook := c.errorPresenter; hook != nil {
		reqCtx.ErrorPresenter = hook
	}

	if hook := c.resolverHook; hook != nil {
		reqCtx.ResolverMiddleware = hook
	}

	if hook := c.requestHook; hook != nil {
		reqCtx.RequestMiddleware = hook
	}

	if hook := c.tracer; hook != nil {
		reqCtx.Tracer = hook
	}

	if c.complexityLimit > 0 || c.complexityLimitFunc != nil {
		reqCtx.ComplexityLimit = c.complexityLimit
		operationComplexity := complexity.Calculate(es, op, variables)
		reqCtx.OperationComplexity = operationComplexity
	}

	return reqCtx
}

// Option configures the handler; pass Options to GraphQL.
type Option func(cfg *Config)

// WebsocketUpgrader sets the websocket.Upgrader used for subscription connections.
func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
	return func(cfg *Config) {
		cfg.upgrader = upgrader
	}
}

// RecoverFunc sets the function used to recover from panics raised while serving a request.
func RecoverFunc(recover graphql.RecoverFunc) Option {
	return func(cfg *Config) {
		cfg.recover = recover
	}
}

// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
// a good place to add any extra fields, like error.type, that might be desired by your frontend. See the default
// implementation, graphql.DefaultErrorPresenter, for an example.
func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
	return func(cfg *Config) {
		cfg.errorPresenter = f
	}
}
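
// A usage sketch for ErrorPresenter, assuming you want to decorate the default
// presentation with an extension field; the "code" key is chosen purely for
// illustration:
//
//	handler.GraphQL(exec, handler.ErrorPresenter(
//		func(ctx context.Context, err error) *gqlerror.Error {
//			gqlErr := graphql.DefaultErrorPresenter(ctx, err)
//			if gqlErr.Extensions == nil {
//				gqlErr.Extensions = map[string]interface{}{}
//			}
//			gqlErr.Extensions["code"] = "INTERNAL_ERROR"
//			return gqlErr
//		},
//	))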

// IntrospectionEnabled controls whether clients may run introspection queries.
// Passing false forbids introspection, which can be useful in production when you
// don't want clients introspecting the full schema.
func IntrospectionEnabled(enabled bool) Option {
	return func(cfg *Config) {
		cfg.disableIntrospection = !enabled
	}
}

// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
func ComplexityLimit(limit int) Option {
	return func(cfg *Config) {
		cfg.complexityLimit = limit
	}
}

// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed
// to be executed.
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option {
	return func(cfg *Config) {
		cfg.complexityLimitFunc = complexityLimitFunc
	}
}
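
// A usage sketch for ComplexityLimitFunc, assuming you want a per-request limit
// derived from the context; isAdmin is a hypothetical helper standing in for your
// own auth check:
//
//	handler.GraphQL(exec, handler.ComplexityLimitFunc(
//		func(ctx context.Context) int {
//			if isAdmin(ctx) {
//				return 1000
//			}
//			return 100
//		},
//	))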

// ResolverMiddleware allows you to define a function that will be called around every resolver,
// useful for logging.
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
	return func(cfg *Config) {
		if cfg.resolverHook == nil {
			cfg.resolverHook = middleware
			return
		}

		lastResolve := cfg.resolverHook
		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
				return middleware(ctx, next)
			})
		}
	}
}
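
// A logging sketch for ResolverMiddleware, assuming the standard library log and
// time packages are imported in the calling code:
//
//	handler.GraphQL(exec, handler.ResolverMiddleware(
//		func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
//			rc := graphql.GetResolverContext(ctx)
//			start := time.Now()
//			res, err := next(ctx)
//			log.Printf("resolved %s in %s", rc.Field.Name, time.Since(start))
//			return res, err
//		},
//	))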

// RequestMiddleware allows you to define a function that will be called around the root request,
// after the query has been parsed. This is useful for logging.
func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
	return func(cfg *Config) {
		if cfg.requestHook == nil {
			cfg.requestHook = middleware
			return
		}

		lastResolve := cfg.requestHook
		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
			return lastResolve(ctx, func(ctx context.Context) []byte {
				return middleware(ctx, next)
			})
		}
	}
}
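
// A timing sketch for RequestMiddleware, again assuming the standard library log
// and time packages in the calling code:
//
//	handler.GraphQL(exec, handler.RequestMiddleware(
//		func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
//			start := time.Now()
//			resp := next(ctx)
//			log.Printf("request took %s (%d response bytes)", time.Since(start), len(resp))
//			return resp
//		},
//	))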

// Tracer allows you to add a request/resolver tracer that will be called around the root request
// and around each resolver call. This is useful for tracing.
func Tracer(tracer graphql.Tracer) Option {
	return func(cfg *Config) {
		if cfg.tracer == nil {
			cfg.tracer = tracer
		} else {
			lastResolve := cfg.tracer
			cfg.tracer = &tracerWrapper{
				tracer1: lastResolve,
				tracer2: tracer,
			}
		}

		opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
			ctx = tracer.StartOperationExecution(ctx)
			resp := next(ctx)
			tracer.EndOperationExecution(ctx)

			return resp
		})
		opt(cfg)
	}
}

// tracerWrapper chains two tracers: start hooks run tracer1 then tracer2,
// end hooks run in the reverse order.
type tracerWrapper struct {
	tracer1 graphql.Tracer
	tracer2 graphql.Tracer
}

func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
	ctx = tw.tracer1.StartOperationParsing(ctx)
	ctx = tw.tracer2.StartOperationParsing(ctx)
	return ctx
}

func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
	tw.tracer2.EndOperationParsing(ctx)
	tw.tracer1.EndOperationParsing(ctx)
}

func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
	ctx = tw.tracer1.StartOperationValidation(ctx)
	ctx = tw.tracer2.StartOperationValidation(ctx)
	return ctx
}

func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
	tw.tracer2.EndOperationValidation(ctx)
	tw.tracer1.EndOperationValidation(ctx)
}

func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
	ctx = tw.tracer1.StartOperationExecution(ctx)
	ctx = tw.tracer2.StartOperationExecution(ctx)
	return ctx
}

func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
	ctx = tw.tracer1.StartFieldExecution(ctx, field)
	ctx = tw.tracer2.StartFieldExecution(ctx, field)
	return ctx
}

func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
	ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
	ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
	return ctx
}

func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
	ctx = tw.tracer1.StartFieldChildExecution(ctx)
	ctx = tw.tracer2.StartFieldChildExecution(ctx)
	return ctx
}

func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
	tw.tracer2.EndFieldExecution(ctx)
	tw.tracer1.EndFieldExecution(ctx)
}

func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
	tw.tracer2.EndOperationExecution(ctx)
	tw.tracer1.EndOperationExecution(ctx)
}

// WebsocketInitFunc is called when the server receives a connection init message from the client.
// This can be used to check the initial payload to decide whether to accept the websocket connection.
func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option {
	return func(cfg *Config) {
		cfg.websocketInitFunc = websocketInitFunc
	}
}
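
// An auth-check sketch for WebsocketInitFunc, assuming the client puts a token
// under an "Authorization" payload key; validateToken is a hypothetical helper:
//
//	handler.GraphQL(exec, handler.WebsocketInitFunc(
//		func(ctx context.Context, initPayload handler.InitPayload) error {
//			if err := validateToken(initPayload.GetString("Authorization")); err != nil {
//				return errors.New("unauthorized websocket connection")
//			}
//			return nil
//		},
//	))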

// CacheSize sets the maximum size of the query cache.
// If size is less than or equal to 0, the cache is disabled.
func CacheSize(size int) Option {
	return func(cfg *Config) {
		cfg.cacheSize = size
	}
}

// UploadMaxSize sets the maximum number of bytes used to parse a request body
// as multipart/form-data.
func UploadMaxSize(size int64) Option {
	return func(cfg *Config) {
		cfg.uploadMaxSize = size
	}
}

// UploadMaxMemory sets the maximum number of bytes used to parse a request body
// as multipart/form-data in memory, with the remainder stored on disk in
// temporary files.
func UploadMaxMemory(size int64) Option {
	return func(cfg *Config) {
		cfg.uploadMaxMemory = size
	}
}

// WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior.
// By default, keepalive is enabled with the DefaultConnectionKeepAlivePingInterval
// duration. Pass a duration of 0 to disable keepalive altogether.
func WebsocketKeepAliveDuration(duration time.Duration) Option {
	return func(cfg *Config) {
		cfg.connectionKeepAlivePingInterval = duration
	}
}

// EnablePersistedQueryCache adds a cache that will hold queries for automatic persisted queries (APQ).
func EnablePersistedQueryCache(cache PersistedQueryCache) Option {
	return func(cfg *Config) {
		cfg.apqCache = cache
	}
}

// DefaultCacheSize is the default maximum number of parsed queries kept in the query cache.
const DefaultCacheSize = 1000

// DefaultConnectionKeepAlivePingInterval is the default interval between websocket keepalive pings.
const DefaultConnectionKeepAlivePingInterval = 25 * time.Second

// DefaultUploadMaxMemory is the maximum number of bytes used to parse a request body
// as multipart/form-data in memory, with the remainder stored on disk in
// temporary files.
const DefaultUploadMaxMemory = 32 << 20

// DefaultUploadMaxSize is the maximum number of bytes used to parse a request body
// as multipart/form-data.
const DefaultUploadMaxSize = 32 << 20

// GraphQL returns an http.HandlerFunc that serves the given executable schema,
// handling GET, POST (JSON and multipart) and websocket requests.
func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
	cfg := &Config{
		cacheSize:                       DefaultCacheSize,
		uploadMaxMemory:                 DefaultUploadMaxMemory,
		uploadMaxSize:                   DefaultUploadMaxSize,
		connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval,
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
		},
	}

	for _, option := range options {
		option(cfg)
	}

	var cache *lru.Cache
	if cfg.cacheSize > 0 {
		var err error
		cache, err = lru.New(cfg.cacheSize)
		if err != nil {
			// An error is only returned for a non-positive cache size,
			// and we already checked for that.
			panic("unexpected error creating cache: " + err.Error())
		}
	}
	if cfg.tracer == nil {
		cfg.tracer = &graphql.NopTracer{}
	}

	handler := &graphqlHandler{
		cfg:   cfg,
		cache: cache,
		exec:  exec,
	}

	return handler.ServeHTTP
}
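
// A wiring sketch, assuming a gqlgen-generated package "yourapp" exposing
// NewExecutableSchema, Config and Resolver (all placeholder names for your own
// generated code):
//
//	http.Handle("/query", handler.GraphQL(
//		yourapp.NewExecutableSchema(yourapp.Config{Resolvers: &yourapp.Resolver{}}),
//		handler.ComplexityLimit(200),
//		handler.CacheSize(2000),
//	))
//	log.Fatal(http.ListenAndServe(":8080", nil))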

var _ http.Handler = (*graphqlHandler)(nil)

type graphqlHandler struct {
	cfg   *Config
	cache *lru.Cache
	exec  graphql.ExecutableSchema
}

// computeQueryHash returns the hex-encoded SHA-256 hash of the query string,
// as used by the automatic persisted queries protocol.
func computeQueryHash(query string) string {
	b := sha256.Sum256([]byte(query))
	return hex.EncodeToString(b[:])
}

func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodOptions {
		w.Header().Set("Allow", "OPTIONS, GET, POST")
		w.WriteHeader(http.StatusOK)
		return
	}

	if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
		connectWs(gh.exec, w, r, gh.cfg, gh.cache)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	var reqParams params
	switch r.Method {
	case http.MethodGet:
		reqParams.Query = r.URL.Query().Get("query")
		reqParams.OperationName = r.URL.Query().Get("operationName")

		if variables := r.URL.Query().Get("variables"); variables != "" {
			if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
				sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
				return
			}
		}

		if extensions := r.URL.Query().Get("extensions"); extensions != "" {
			if err := jsonDecode(strings.NewReader(extensions), &reqParams.Extensions); err != nil {
				sendErrorf(w, http.StatusBadRequest, "extensions could not be decoded")
				return
			}
		}
	case http.MethodPost:
		mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
		if err != nil {
			sendErrorf(w, http.StatusBadRequest, "error parsing request Content-Type")
			return
		}

		switch mediaType {
		case "application/json":
			if err := jsonDecode(r.Body, &reqParams); err != nil {
				sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
				return
			}

		case "multipart/form-data":
			var closers []io.Closer
			var tmpFiles []string
			defer func() {
				for i := len(closers) - 1; 0 <= i; i-- {
					_ = closers[i].Close()
				}
				for _, tmpFile := range tmpFiles {
					_ = os.Remove(tmpFile)
				}
			}()
			if err := processMultipart(w, r, &reqParams, &closers, &tmpFiles, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil {
				sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error())
				return
			}
		default:
			sendErrorf(w, http.StatusBadRequest, "unsupported Content-Type: "+mediaType)
			return
		}
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	ctx := r.Context()

	var queryHash string
	apqRegister := false
	apq := reqParams.Extensions != nil && reqParams.Extensions.PersistedQuery != nil
	if apq {
		// client has enabled apq
		queryHash = reqParams.Extensions.PersistedQuery.Sha256
		if gh.cfg.apqCache == nil {
			// server has disabled apq
			sendErrorf(w, http.StatusOK, errPersistedQueryNotSupported)
			return
		}
		if reqParams.Extensions.PersistedQuery.Version != 1 {
			sendErrorf(w, http.StatusOK, "Unsupported persisted query version")
			return
		}
		if reqParams.Query == "" {
			// client sent an optimistic query hash without the query string
			query, ok := gh.cfg.apqCache.Get(ctx, queryHash)
			if !ok {
				sendErrorf(w, http.StatusOK, errPersistedQueryNotFound)
				return
			}
			reqParams.Query = query
		} else {
			if computeQueryHash(reqParams.Query) != queryHash {
				sendErrorf(w, http.StatusOK, "provided sha does not match query")
				return
			}
			apqRegister = true
		}
	} else if reqParams.Query == "" {
		sendErrorf(w, http.StatusUnprocessableEntity, "Must provide query string")
		return
	}
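
	// For reference, an APQ-enabled client first sends only the hash in the
	// "extensions" field and resends the full query when it receives
	// errPersistedQueryNotFound. A request body along these lines (values
	// illustrative) is handled by the branch above:
	//
	//	{
	//	  "query": "",
	//	  "extensions": {
	//	    "persistedQuery": {
	//	      "version": 1,
	//	      "sha256Hash": "<hex-encoded sha256 of the query>"
	//	    }
	//	  }
	//	}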

	var doc *ast.QueryDocument
	var cacheHit bool
	if gh.cache != nil {
		val, ok := gh.cache.Get(reqParams.Query)
		if ok {
			doc = val.(*ast.QueryDocument)
			cacheHit = true
		}
	}

	ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
		Query:     reqParams.Query,
		CachedDoc: doc,
	})
	if gqlErr != nil {
		sendError(w, http.StatusUnprocessableEntity, gqlErr)
		return
	}

	ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
		Doc:           doc,
		OperationName: reqParams.OperationName,
		CacheHit:      cacheHit,
		R:             r,
		Variables:     reqParams.Variables,
	})
	if len(listErr) != 0 {
		sendError(w, http.StatusUnprocessableEntity, listErr...)
		return
	}

	if gh.cache != nil && !cacheHit {
		gh.cache.Add(reqParams.Query, doc)
	}

	reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
	ctx = graphql.WithRequestContext(ctx, reqCtx)

	defer func() {
		if err := recover(); err != nil {
			userErr := reqCtx.Recover(ctx, err)
			sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
		}
	}()

	if gh.cfg.complexityLimitFunc != nil {
		reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx)
	}

	if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
		sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
		return
	}

	if apqRegister && gh.cfg.apqCache != nil {
		// Add to persisted query cache
		gh.cfg.apqCache.Add(ctx, queryHash, reqParams.Query)
	}

	switch op.Operation {
	case ast.Query:
		b, err := json.Marshal(gh.exec.Query(ctx, op))
		if err != nil {
			panic(err)
		}
		w.Write(b)
	case ast.Mutation:
		b, err := json.Marshal(gh.exec.Mutation(ctx, op))
		if err != nil {
			panic(err)
		}
		w.Write(b)
	default:
		sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
	}
}

type parseOperationArgs struct {
	Query     string
	CachedDoc *ast.QueryDocument
}

func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
	ctx = gh.cfg.tracer.StartOperationParsing(ctx)
	defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()

	if args.CachedDoc != nil {
		return ctx, args.CachedDoc, nil
	}

	doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
	if gqlErr != nil {
		return ctx, nil, gqlErr
	}

	return ctx, doc, nil
}

type validateOperationArgs struct {
	Doc           *ast.QueryDocument
	OperationName string
	CacheHit      bool
	R             *http.Request
	Variables     map[string]interface{}
}

func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
	ctx = gh.cfg.tracer.StartOperationValidation(ctx)
	defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()

	if !args.CacheHit {
		listErr := validator.Validate(gh.exec.Schema(), args.Doc)
		if len(listErr) != 0 {
			return ctx, nil, nil, listErr
		}
	}

	op := args.Doc.Operations.ForName(args.OperationName)
	if op == nil {
		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
	}

	if op.Operation != ast.Query && args.R.Method == http.MethodGet {
		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
	}

	vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
	if err != nil {
		return ctx, nil, nil, gqlerror.List{err}
	}

	return ctx, op, vars, nil
}

func jsonDecode(r io.Reader, val interface{}) error {
	dec := json.NewDecoder(r)
	dec.UseNumber()
	return dec.Decode(val)
}

func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
	w.WriteHeader(code)
	b, err := json.Marshal(&graphql.Response{Errors: errors})
	if err != nil {
		panic(err)
	}
	w.Write(b)
}

func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}

// bytesReader is a trimmed-down bytes.Reader that reads from a shared byte
// slice, allowing the same uploaded file contents to back multiple readers.
type bytesReader struct {
	s        *[]byte
	i        int64 // current reading index
	prevRune int   // index of previous rune; or < 0
}

func (r *bytesReader) Read(b []byte) (n int, err error) {
	if r.s == nil {
		return 0, errors.New("byte slice pointer is nil")
	}
	if r.i >= int64(len(*r.s)) {
		return 0, io.EOF
	}
	r.prevRune = -1
	n = copy(b, (*r.s)[r.i:])
	r.i += int64(n)
	return
}

// processMultipart parses a multipart/form-data upload request: an "operations"
// form field carries the request params, a "map" field associates form file keys
// with variable paths, and each key has a matching file part.
func processMultipart(w http.ResponseWriter, r *http.Request, request *params, closers *[]io.Closer, tmpFiles *[]string, uploadMaxSize, uploadMaxMemory int64) error {
	var err error
	if r.ContentLength > uploadMaxSize {
		return errors.New("failed to parse multipart form, request body too large")
	}
	r.Body = http.MaxBytesReader(w, r.Body, uploadMaxSize)
	if err = r.ParseMultipartForm(uploadMaxMemory); err != nil {
		if strings.Contains(err.Error(), "request body too large") {
			return errors.New("failed to parse multipart form, request body too large")
		}
		return errors.New("failed to parse multipart form")
	}
	*closers = append(*closers, r.Body)

	if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &request); err != nil {
		return errors.New("operations form field could not be decoded")
	}

	var uploadsMap = map[string][]string{}
	if err = json.Unmarshal([]byte(r.Form.Get("map")), &uploadsMap); err != nil {
		return errors.New("map form field could not be decoded")
	}

	var upload graphql.Upload
	for key, paths := range uploadsMap {
		if len(paths) == 0 {
			return fmt.Errorf("invalid empty operations paths list for key %s", key)
		}
		file, header, err := r.FormFile(key)
		if err != nil {
			return fmt.Errorf("failed to get key %s from form", key)
		}
		*closers = append(*closers, file)

		if len(paths) == 1 {
			upload = graphql.Upload{
				File:     file,
				Size:     header.Size,
				Filename: header.Filename,
			}
			err = addUploadToOperations(request, upload, key, paths[0])
			if err != nil {
				return err
			}
		} else {
			if r.ContentLength < uploadMaxMemory {
				fileBytes, err := ioutil.ReadAll(file)
				if err != nil {
					return fmt.Errorf("failed to read file for key %s", key)
				}
				for _, path := range paths {
					upload = graphql.Upload{
						File:     &bytesReader{s: &fileBytes, i: 0, prevRune: -1},
						Size:     header.Size,
						Filename: header.Filename,
					}
					err = addUploadToOperations(request, upload, key, path)
					if err != nil {
						return err
					}
				}
			} else {
				// The file is too large to duplicate in memory, so spool it to a
				// temporary file and open that once per path that references it.
				tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-")
				if err != nil {
					return fmt.Errorf("failed to create temp file for key %s", key)
				}
				tmpName := tmpFile.Name()
				*tmpFiles = append(*tmpFiles, tmpName)
				_, err = io.Copy(tmpFile, file)
				if err != nil {
					if err := tmpFile.Close(); err != nil {
						return fmt.Errorf("failed to copy to temp file and close temp file for key %s", key)
					}
					return fmt.Errorf("failed to copy to temp file for key %s", key)
				}
				if err := tmpFile.Close(); err != nil {
					return fmt.Errorf("failed to close temp file for key %s", key)
				}
				for _, path := range paths {
					pathTmpFile, err := os.Open(tmpName)
					if err != nil {
						return fmt.Errorf("failed to open temp file for key %s", key)
					}
					*closers = append(*closers, pathTmpFile)
					upload = graphql.Upload{
						File:     pathTmpFile,
						Size:     header.Size,
						Filename: header.Filename,
					}
					err = addUploadToOperations(request, upload, key, path)
					if err != nil {
						return err
					}
				}
			}
		}
	}
	return nil
}
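
// For reference, a request handled by processMultipart carries form fields along
// these lines (field values illustrative, following the usual GraphQL multipart
// request convention):
//
//	operations: {"query":"mutation ($file: Upload!) { upload(file: $file) }","variables":{"file":null}}
//	map:        {"0":["variables.file"]}
//	0:          <the file part of the multipart body>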

// addUploadToOperations walks request.Variables along the dot-separated path
// (which must start with "variables.") and stores the upload at the final element.
func addUploadToOperations(request *params, upload graphql.Upload, key, path string) error {
	if !strings.HasPrefix(path, "variables.") {
		return fmt.Errorf("invalid operations paths for key %s", key)
	}

	var ptr interface{} = request.Variables
	parts := strings.Split(path, ".")

	// skip the first part (variables) because we started there
	for i, p := range parts[1:] {
		last := i == len(parts)-2
		if ptr == nil {
			return fmt.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path)
		}
		if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil {
			if last {
				ptr.([]interface{})[index] = upload
			} else {
				ptr = ptr.([]interface{})[index]
			}
		} else {
			if last {
				ptr.(map[string]interface{})[p] = upload
			} else {
				ptr = ptr.(map[string]interface{})[p]
			}
		}
	}

	return nil
}
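
// A worked example of the traversal above, assuming the "operations" field decoded
// to variables equivalent to the literal below; the path "variables.input.files.1"
// walks Variables["input"], then the "files" slice, and stores the upload at index 1:
//
//	request.Variables = map[string]interface{}{
//		"input": map[string]interface{}{
//			"files": []interface{}{nil, nil},
//		},
//	}
//	_ = addUploadToOperations(request, upload, "0", "variables.input.files.1")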