graphql.go

  1package handler
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11
 12	"github.com/99designs/gqlgen/complexity"
 13	"github.com/99designs/gqlgen/graphql"
 14	"github.com/gorilla/websocket"
 15	"github.com/hashicorp/golang-lru"
 16	"github.com/vektah/gqlparser/ast"
 17	"github.com/vektah/gqlparser/gqlerror"
 18	"github.com/vektah/gqlparser/parser"
 19	"github.com/vektah/gqlparser/validator"
 20)
 21
// params is the GraphQL request envelope. ServeHTTP decodes it from the JSON body
// of POST requests, and fills it from the "query", "operationName" and
// "variables" URL query parameters of GET requests.
type params struct {
	Query         string                 `json:"query"`
	OperationName string                 `json:"operationName"`
	Variables     map[string]interface{} `json:"variables"`
}
 27
// Config holds the handler settings. GraphQL builds one with its defaults and then
// applies each Option to it in order. The fields are unexported; use the Option
// constructors in this file to change them.
type Config struct {
	// cacheSize is the capacity of the LRU cache of parsed query documents;
	// values <= 0 disable the cache. Defaults to DefaultCacheSize.
	cacheSize                       int
	// upgrader is the websocket upgrader handed (via the Config) to connectWs.
	// Defaults to 1024-byte read and write buffers.
	upgrader                        websocket.Upgrader
	// recover, when non-nil, replaces RequestContext.Recover and is called by
	// ServeHTTP when execution panics.
	recover                         graphql.RecoverFunc
	// errorPresenter, when non-nil, replaces RequestContext.ErrorPresenter.
	errorPresenter                  graphql.ErrorPresenterFunc
	// resolverHook is the chain of all ResolverMiddleware options, or nil.
	resolverHook                    graphql.FieldMiddleware
	// requestHook is the chain of all RequestMiddleware options (including the
	// ones registered by Tracer), or nil.
	requestHook                     graphql.RequestMiddleware
	// tracer combines all registered tracers; GraphQL sets a NopTracer if none.
	tracer                          graphql.Tracer
	// complexityLimit is the static operation complexity limit; <= 0 means none.
	complexityLimit                 int
	// complexityLimitFunc, when set, supplies the limit per HTTP request and
	// overrides complexityLimit.
	complexityLimitFunc             graphql.ComplexityLimitFunc
	// disableIntrospection is copied to RequestContext.DisableIntrospection.
	disableIntrospection            bool
	// connectionKeepAlivePingInterval defaults to
	// DefaultConnectionKeepAlivePingInterval; presumably consumed by the
	// websocket transport (not visible here) — verify there.
	connectionKeepAlivePingInterval time.Duration
}
 41
 42func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
 43	reqCtx := graphql.NewRequestContext(doc, query, variables)
 44	reqCtx.DisableIntrospection = c.disableIntrospection
 45
 46	if hook := c.recover; hook != nil {
 47		reqCtx.Recover = hook
 48	}
 49
 50	if hook := c.errorPresenter; hook != nil {
 51		reqCtx.ErrorPresenter = hook
 52	}
 53
 54	if hook := c.resolverHook; hook != nil {
 55		reqCtx.ResolverMiddleware = hook
 56	}
 57
 58	if hook := c.requestHook; hook != nil {
 59		reqCtx.RequestMiddleware = hook
 60	}
 61
 62	if hook := c.tracer; hook != nil {
 63		reqCtx.Tracer = hook
 64	}
 65
 66	if c.complexityLimit > 0 || c.complexityLimitFunc != nil {
 67		reqCtx.ComplexityLimit = c.complexityLimit
 68		operationComplexity := complexity.Calculate(es, op, variables)
 69		reqCtx.OperationComplexity = operationComplexity
 70	}
 71
 72	return reqCtx
 73}
 74
 75type Option func(cfg *Config)
 76
 77func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
 78	return func(cfg *Config) {
 79		cfg.upgrader = upgrader
 80	}
 81}
 82
 83func RecoverFunc(recover graphql.RecoverFunc) Option {
 84	return func(cfg *Config) {
 85		cfg.recover = recover
 86	}
 87}
 88
 89// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
 90// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
 91// implementation in graphql.DefaultErrorPresenter for an example.
 92func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
 93	return func(cfg *Config) {
 94		cfg.errorPresenter = f
 95	}
 96}
 97
 98// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
 99// want clients introspecting the full schema.
100func IntrospectionEnabled(enabled bool) Option {
101	return func(cfg *Config) {
102		cfg.disableIntrospection = !enabled
103	}
104}
105
106// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
107// If a query is submitted that exceeds the limit, a 422 status code will be returned.
108func ComplexityLimit(limit int) Option {
109	return func(cfg *Config) {
110		cfg.complexityLimit = limit
111	}
112}
113
114// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed
115// to be executed.
116// If a query is submitted that exceeds the limit, a 422 status code will be returned.
117func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option {
118	return func(cfg *Config) {
119		cfg.complexityLimitFunc = complexityLimitFunc
120	}
121}
122
123// ResolverMiddleware allows you to define a function that will be called around every resolver,
124// useful for logging.
125func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
126	return func(cfg *Config) {
127		if cfg.resolverHook == nil {
128			cfg.resolverHook = middleware
129			return
130		}
131
132		lastResolve := cfg.resolverHook
133		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
134			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
135				return middleware(ctx, next)
136			})
137		}
138	}
139}
140
141// RequestMiddleware allows you to define a function that will be called around the root request,
142// after the query has been parsed. This is useful for logging
143func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
144	return func(cfg *Config) {
145		if cfg.requestHook == nil {
146			cfg.requestHook = middleware
147			return
148		}
149
150		lastResolve := cfg.requestHook
151		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
152			return lastResolve(ctx, func(ctx context.Context) []byte {
153				return middleware(ctx, next)
154			})
155		}
156	}
157}
158
159// Tracer allows you to add a request/resolver tracer that will be called around the root request,
160// calling resolver. This is useful for tracing
161func Tracer(tracer graphql.Tracer) Option {
162	return func(cfg *Config) {
163		if cfg.tracer == nil {
164			cfg.tracer = tracer
165
166		} else {
167			lastResolve := cfg.tracer
168			cfg.tracer = &tracerWrapper{
169				tracer1: lastResolve,
170				tracer2: tracer,
171			}
172		}
173
174		opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
175			ctx = tracer.StartOperationExecution(ctx)
176			resp := next(ctx)
177			tracer.EndOperationExecution(ctx)
178
179			return resp
180		})
181		opt(cfg)
182	}
183}
184
185type tracerWrapper struct {
186	tracer1 graphql.Tracer
187	tracer2 graphql.Tracer
188}
189
190func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
191	ctx = tw.tracer1.StartOperationParsing(ctx)
192	ctx = tw.tracer2.StartOperationParsing(ctx)
193	return ctx
194}
195
196func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
197	tw.tracer2.EndOperationParsing(ctx)
198	tw.tracer1.EndOperationParsing(ctx)
199}
200
201func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
202	ctx = tw.tracer1.StartOperationValidation(ctx)
203	ctx = tw.tracer2.StartOperationValidation(ctx)
204	return ctx
205}
206
207func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
208	tw.tracer2.EndOperationValidation(ctx)
209	tw.tracer1.EndOperationValidation(ctx)
210}
211
212func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
213	ctx = tw.tracer1.StartOperationExecution(ctx)
214	ctx = tw.tracer2.StartOperationExecution(ctx)
215	return ctx
216}
217
218func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
219	ctx = tw.tracer1.StartFieldExecution(ctx, field)
220	ctx = tw.tracer2.StartFieldExecution(ctx, field)
221	return ctx
222}
223
224func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
225	ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
226	ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
227	return ctx
228}
229
230func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
231	ctx = tw.tracer1.StartFieldChildExecution(ctx)
232	ctx = tw.tracer2.StartFieldChildExecution(ctx)
233	return ctx
234}
235
236func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
237	tw.tracer2.EndFieldExecution(ctx)
238	tw.tracer1.EndFieldExecution(ctx)
239}
240
241func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
242	tw.tracer2.EndOperationExecution(ctx)
243	tw.tracer1.EndOperationExecution(ctx)
244}
245
246// CacheSize sets the maximum size of the query cache.
247// If size is less than or equal to 0, the cache is disabled.
248func CacheSize(size int) Option {
249	return func(cfg *Config) {
250		cfg.cacheSize = size
251	}
252}
253
254// WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior.
255// By default, keepalive is enabled with a DefaultConnectionKeepAlivePingInterval
256// duration. Set handler.connectionKeepAlivePingInterval = 0 to disable keepalive
257// altogether.
258func WebsocketKeepAliveDuration(duration time.Duration) Option {
259	return func(cfg *Config) {
260		cfg.connectionKeepAlivePingInterval = duration
261	}
262}
263
const (
	// DefaultCacheSize is the query-document cache capacity used when no
	// CacheSize option is given.
	DefaultCacheSize = 1000
	// DefaultConnectionKeepAlivePingInterval is the websocket keepalive interval
	// used when no WebsocketKeepAliveDuration option is given.
	DefaultConnectionKeepAlivePingInterval = 25 * time.Second
)
266
267func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
268	cfg := &Config{
269		cacheSize:                       DefaultCacheSize,
270		connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval,
271		upgrader: websocket.Upgrader{
272			ReadBufferSize:  1024,
273			WriteBufferSize: 1024,
274		},
275	}
276
277	for _, option := range options {
278		option(cfg)
279	}
280
281	var cache *lru.Cache
282	if cfg.cacheSize > 0 {
283		var err error
284		cache, err = lru.New(cfg.cacheSize)
285		if err != nil {
286			// An error is only returned for non-positive cache size
287			// and we already checked for that.
288			panic("unexpected error creating cache: " + err.Error())
289		}
290	}
291	if cfg.tracer == nil {
292		cfg.tracer = &graphql.NopTracer{}
293	}
294
295	handler := &graphqlHandler{
296		cfg:   cfg,
297		cache: cache,
298		exec:  exec,
299	}
300
301	return handler.ServeHTTP
302}
303
// Compile-time assertion that *graphqlHandler implements http.Handler.
var _ http.Handler = (*graphqlHandler)(nil)

// graphqlHandler serves GraphQL requests for a single ExecutableSchema. It is
// constructed by GraphQL. cache holds validated query documents keyed by query
// text and is nil when caching is disabled (cacheSize <= 0).
type graphqlHandler struct {
	cfg   *Config
	cache *lru.Cache
	exec  graphql.ExecutableSchema
}
311
312func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
313	if r.Method == http.MethodOptions {
314		w.Header().Set("Allow", "OPTIONS, GET, POST")
315		w.WriteHeader(http.StatusOK)
316		return
317	}
318
319	if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
320		connectWs(gh.exec, w, r, gh.cfg, gh.cache)
321		return
322	}
323
324	w.Header().Set("Content-Type", "application/json")
325	var reqParams params
326	switch r.Method {
327	case http.MethodGet:
328		reqParams.Query = r.URL.Query().Get("query")
329		reqParams.OperationName = r.URL.Query().Get("operationName")
330
331		if variables := r.URL.Query().Get("variables"); variables != "" {
332			if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
333				sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
334				return
335			}
336		}
337	case http.MethodPost:
338		if err := jsonDecode(r.Body, &reqParams); err != nil {
339			sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
340			return
341		}
342	default:
343		w.WriteHeader(http.StatusMethodNotAllowed)
344		return
345	}
346
347	ctx := r.Context()
348
349	var doc *ast.QueryDocument
350	var cacheHit bool
351	if gh.cache != nil {
352		val, ok := gh.cache.Get(reqParams.Query)
353		if ok {
354			doc = val.(*ast.QueryDocument)
355			cacheHit = true
356		}
357	}
358
359	ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
360		Query:     reqParams.Query,
361		CachedDoc: doc,
362	})
363	if gqlErr != nil {
364		sendError(w, http.StatusUnprocessableEntity, gqlErr)
365		return
366	}
367
368	ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
369		Doc:           doc,
370		OperationName: reqParams.OperationName,
371		CacheHit:      cacheHit,
372		R:             r,
373		Variables:     reqParams.Variables,
374	})
375	if len(listErr) != 0 {
376		sendError(w, http.StatusUnprocessableEntity, listErr...)
377		return
378	}
379
380	if gh.cache != nil && !cacheHit {
381		gh.cache.Add(reqParams.Query, doc)
382	}
383
384	reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
385	ctx = graphql.WithRequestContext(ctx, reqCtx)
386
387	defer func() {
388		if err := recover(); err != nil {
389			userErr := reqCtx.Recover(ctx, err)
390			sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
391		}
392	}()
393
394	if gh.cfg.complexityLimitFunc != nil {
395		reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx)
396	}
397
398	if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
399		sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
400		return
401	}
402
403	switch op.Operation {
404	case ast.Query:
405		b, err := json.Marshal(gh.exec.Query(ctx, op))
406		if err != nil {
407			panic(err)
408		}
409		w.Write(b)
410	case ast.Mutation:
411		b, err := json.Marshal(gh.exec.Mutation(ctx, op))
412		if err != nil {
413			panic(err)
414		}
415		w.Write(b)
416	default:
417		sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
418	}
419}
420
421type parseOperationArgs struct {
422	Query     string
423	CachedDoc *ast.QueryDocument
424}
425
426func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
427	ctx = gh.cfg.tracer.StartOperationParsing(ctx)
428	defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
429
430	if args.CachedDoc != nil {
431		return ctx, args.CachedDoc, nil
432	}
433
434	doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
435	if gqlErr != nil {
436		return ctx, nil, gqlErr
437	}
438
439	return ctx, doc, nil
440}
441
442type validateOperationArgs struct {
443	Doc           *ast.QueryDocument
444	OperationName string
445	CacheHit      bool
446	R             *http.Request
447	Variables     map[string]interface{}
448}
449
450func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
451	ctx = gh.cfg.tracer.StartOperationValidation(ctx)
452	defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
453
454	if !args.CacheHit {
455		listErr := validator.Validate(gh.exec.Schema(), args.Doc)
456		if len(listErr) != 0 {
457			return ctx, nil, nil, listErr
458		}
459	}
460
461	op := args.Doc.Operations.ForName(args.OperationName)
462	if op == nil {
463		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
464	}
465
466	if op.Operation != ast.Query && args.R.Method == http.MethodGet {
467		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
468	}
469
470	vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
471	if err != nil {
472		return ctx, nil, nil, gqlerror.List{err}
473	}
474
475	return ctx, op, vars, nil
476}
477
// jsonDecode reads one JSON value from r into val. Numbers destined for an
// interface{} are decoded as json.Number rather than float64, so large integers
// keep their full precision.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	return decoder.Decode(val)
}
483
484func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
485	w.WriteHeader(code)
486	b, err := json.Marshal(&graphql.Response{Errors: errors})
487	if err != nil {
488		panic(err)
489	}
490	w.Write(b)
491}
492
493func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
494	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
495}