graphql.go

  1package handler
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11
 12	"github.com/99designs/gqlgen/complexity"
 13	"github.com/99designs/gqlgen/graphql"
 14	"github.com/gorilla/websocket"
 15	"github.com/hashicorp/golang-lru"
 16	"github.com/vektah/gqlparser/ast"
 17	"github.com/vektah/gqlparser/gqlerror"
 18	"github.com/vektah/gqlparser/parser"
 19	"github.com/vektah/gqlparser/validator"
 20)
 21
 22type params struct {
 23	Query         string                 `json:"query"`
 24	OperationName string                 `json:"operationName"`
 25	Variables     map[string]interface{} `json:"variables"`
 26}
 27
 28type Config struct {
 29	cacheSize                       int
 30	upgrader                        websocket.Upgrader
 31	recover                         graphql.RecoverFunc
 32	errorPresenter                  graphql.ErrorPresenterFunc
 33	resolverHook                    graphql.FieldMiddleware
 34	requestHook                     graphql.RequestMiddleware
 35	tracer                          graphql.Tracer
 36	complexityLimit                 int
 37	disableIntrospection            bool
 38	connectionKeepAlivePingInterval time.Duration
 39}
 40
 41func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
 42	reqCtx := graphql.NewRequestContext(doc, query, variables)
 43	reqCtx.DisableIntrospection = c.disableIntrospection
 44
 45	if hook := c.recover; hook != nil {
 46		reqCtx.Recover = hook
 47	}
 48
 49	if hook := c.errorPresenter; hook != nil {
 50		reqCtx.ErrorPresenter = hook
 51	}
 52
 53	if hook := c.resolverHook; hook != nil {
 54		reqCtx.ResolverMiddleware = hook
 55	}
 56
 57	if hook := c.requestHook; hook != nil {
 58		reqCtx.RequestMiddleware = hook
 59	}
 60
 61	if hook := c.tracer; hook != nil {
 62		reqCtx.Tracer = hook
 63	} else {
 64		reqCtx.Tracer = &graphql.NopTracer{}
 65	}
 66
 67	if c.complexityLimit > 0 {
 68		reqCtx.ComplexityLimit = c.complexityLimit
 69		operationComplexity := complexity.Calculate(es, op, variables)
 70		reqCtx.OperationComplexity = operationComplexity
 71	}
 72
 73	return reqCtx
 74}
 75
 76type Option func(cfg *Config)
 77
 78func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
 79	return func(cfg *Config) {
 80		cfg.upgrader = upgrader
 81	}
 82}
 83
 84func RecoverFunc(recover graphql.RecoverFunc) Option {
 85	return func(cfg *Config) {
 86		cfg.recover = recover
 87	}
 88}
 89
 90// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
 91// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
 92// implementation in graphql.DefaultErrorPresenter for an example.
 93func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
 94	return func(cfg *Config) {
 95		cfg.errorPresenter = f
 96	}
 97}
 98
 99// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
100// want clients introspecting the full schema.
101func IntrospectionEnabled(enabled bool) Option {
102	return func(cfg *Config) {
103		cfg.disableIntrospection = !enabled
104	}
105}
106
107// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
108// If a query is submitted that exceeds the limit, a 422 status code will be returned.
109func ComplexityLimit(limit int) Option {
110	return func(cfg *Config) {
111		cfg.complexityLimit = limit
112	}
113}
114
115// ResolverMiddleware allows you to define a function that will be called around every resolver,
116// useful for logging.
117func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
118	return func(cfg *Config) {
119		if cfg.resolverHook == nil {
120			cfg.resolverHook = middleware
121			return
122		}
123
124		lastResolve := cfg.resolverHook
125		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
126			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
127				return middleware(ctx, next)
128			})
129		}
130	}
131}
132
133// RequestMiddleware allows you to define a function that will be called around the root request,
134// after the query has been parsed. This is useful for logging
135func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
136	return func(cfg *Config) {
137		if cfg.requestHook == nil {
138			cfg.requestHook = middleware
139			return
140		}
141
142		lastResolve := cfg.requestHook
143		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
144			return lastResolve(ctx, func(ctx context.Context) []byte {
145				return middleware(ctx, next)
146			})
147		}
148	}
149}
150
151// Tracer allows you to add a request/resolver tracer that will be called around the root request,
152// calling resolver. This is useful for tracing
153func Tracer(tracer graphql.Tracer) Option {
154	return func(cfg *Config) {
155		if cfg.tracer == nil {
156			cfg.tracer = tracer
157
158		} else {
159			lastResolve := cfg.tracer
160			cfg.tracer = &tracerWrapper{
161				tracer1: lastResolve,
162				tracer2: tracer,
163			}
164		}
165
166		opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
167			ctx = tracer.StartOperationExecution(ctx)
168			resp := next(ctx)
169			tracer.EndOperationExecution(ctx)
170
171			return resp
172		})
173		opt(cfg)
174	}
175}
176
177type tracerWrapper struct {
178	tracer1 graphql.Tracer
179	tracer2 graphql.Tracer
180}
181
182func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
183	ctx = tw.tracer1.StartOperationParsing(ctx)
184	ctx = tw.tracer2.StartOperationParsing(ctx)
185	return ctx
186}
187
188func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
189	tw.tracer2.EndOperationParsing(ctx)
190	tw.tracer1.EndOperationParsing(ctx)
191}
192
193func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
194	ctx = tw.tracer1.StartOperationValidation(ctx)
195	ctx = tw.tracer2.StartOperationValidation(ctx)
196	return ctx
197}
198
199func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
200	tw.tracer2.EndOperationValidation(ctx)
201	tw.tracer1.EndOperationValidation(ctx)
202}
203
204func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
205	ctx = tw.tracer1.StartOperationExecution(ctx)
206	ctx = tw.tracer2.StartOperationExecution(ctx)
207	return ctx
208}
209
210func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
211	ctx = tw.tracer1.StartFieldExecution(ctx, field)
212	ctx = tw.tracer2.StartFieldExecution(ctx, field)
213	return ctx
214}
215
216func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
217	ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
218	ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
219	return ctx
220}
221
222func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
223	ctx = tw.tracer1.StartFieldChildExecution(ctx)
224	ctx = tw.tracer2.StartFieldChildExecution(ctx)
225	return ctx
226}
227
228func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
229	tw.tracer2.EndFieldExecution(ctx)
230	tw.tracer1.EndFieldExecution(ctx)
231}
232
233func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
234	tw.tracer2.EndOperationExecution(ctx)
235	tw.tracer1.EndOperationExecution(ctx)
236}
237
238// CacheSize sets the maximum size of the query cache.
239// If size is less than or equal to 0, the cache is disabled.
240func CacheSize(size int) Option {
241	return func(cfg *Config) {
242		cfg.cacheSize = size
243	}
244}
245
// DefaultCacheSize is the capacity of the parsed-query cache that GraphQL
// uses when the CacheSize option is not supplied.
const DefaultCacheSize = 1000
247
248// WebsocketKeepAliveDuration allows you to reconfigure the keepAlive behavior.
249// By default, keep-alive is disabled.
250func WebsocketKeepAliveDuration(duration time.Duration) Option {
251	return func(cfg *Config) {
252		cfg.connectionKeepAlivePingInterval = duration
253	}
254}
255
256func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
257	cfg := &Config{
258		cacheSize: DefaultCacheSize,
259		upgrader: websocket.Upgrader{
260			ReadBufferSize:  1024,
261			WriteBufferSize: 1024,
262		},
263	}
264
265	for _, option := range options {
266		option(cfg)
267	}
268
269	var cache *lru.Cache
270	if cfg.cacheSize > 0 {
271		var err error
272		cache, err = lru.New(DefaultCacheSize)
273		if err != nil {
274			// An error is only returned for non-positive cache size
275			// and we already checked for that.
276			panic("unexpected error creating cache: " + err.Error())
277		}
278	}
279	if cfg.tracer == nil {
280		cfg.tracer = &graphql.NopTracer{}
281	}
282
283	handler := &graphqlHandler{
284		cfg:   cfg,
285		cache: cache,
286		exec:  exec,
287	}
288
289	return handler.ServeHTTP
290}
291
292var _ http.Handler = (*graphqlHandler)(nil)
293
294type graphqlHandler struct {
295	cfg   *Config
296	cache *lru.Cache
297	exec  graphql.ExecutableSchema
298}
299
300func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
301	if r.Method == http.MethodOptions {
302		w.Header().Set("Allow", "OPTIONS, GET, POST")
303		w.WriteHeader(http.StatusOK)
304		return
305	}
306
307	if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
308		connectWs(gh.exec, w, r, gh.cfg)
309		return
310	}
311
312	var reqParams params
313	switch r.Method {
314	case http.MethodGet:
315		reqParams.Query = r.URL.Query().Get("query")
316		reqParams.OperationName = r.URL.Query().Get("operationName")
317
318		if variables := r.URL.Query().Get("variables"); variables != "" {
319			if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
320				sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
321				return
322			}
323		}
324	case http.MethodPost:
325		if err := jsonDecode(r.Body, &reqParams); err != nil {
326			sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
327			return
328		}
329	default:
330		w.WriteHeader(http.StatusMethodNotAllowed)
331		return
332	}
333	w.Header().Set("Content-Type", "application/json")
334
335	ctx := r.Context()
336
337	var doc *ast.QueryDocument
338	var cacheHit bool
339	if gh.cache != nil {
340		val, ok := gh.cache.Get(reqParams.Query)
341		if ok {
342			doc = val.(*ast.QueryDocument)
343			cacheHit = true
344		}
345	}
346
347	ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
348		Query:     reqParams.Query,
349		CachedDoc: doc,
350	})
351	if gqlErr != nil {
352		sendError(w, http.StatusUnprocessableEntity, gqlErr)
353		return
354	}
355
356	ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
357		Doc:           doc,
358		OperationName: reqParams.OperationName,
359		CacheHit:      cacheHit,
360		R:             r,
361		Variables:     reqParams.Variables,
362	})
363	if len(listErr) != 0 {
364		sendError(w, http.StatusUnprocessableEntity, listErr...)
365		return
366	}
367
368	if gh.cache != nil && !cacheHit {
369		gh.cache.Add(reqParams.Query, doc)
370	}
371
372	reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
373	ctx = graphql.WithRequestContext(ctx, reqCtx)
374
375	defer func() {
376		if err := recover(); err != nil {
377			userErr := reqCtx.Recover(ctx, err)
378			sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
379		}
380	}()
381
382	if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
383		sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
384		return
385	}
386
387	switch op.Operation {
388	case ast.Query:
389		b, err := json.Marshal(gh.exec.Query(ctx, op))
390		if err != nil {
391			panic(err)
392		}
393		w.Write(b)
394	case ast.Mutation:
395		b, err := json.Marshal(gh.exec.Mutation(ctx, op))
396		if err != nil {
397			panic(err)
398		}
399		w.Write(b)
400	default:
401		sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
402	}
403}
404
405type parseOperationArgs struct {
406	Query     string
407	CachedDoc *ast.QueryDocument
408}
409
410func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
411	ctx = gh.cfg.tracer.StartOperationParsing(ctx)
412	defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
413
414	if args.CachedDoc != nil {
415		return ctx, args.CachedDoc, nil
416	}
417
418	doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
419	if gqlErr != nil {
420		return ctx, nil, gqlErr
421	}
422
423	return ctx, doc, nil
424}
425
426type validateOperationArgs struct {
427	Doc           *ast.QueryDocument
428	OperationName string
429	CacheHit      bool
430	R             *http.Request
431	Variables     map[string]interface{}
432}
433
434func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
435	ctx = gh.cfg.tracer.StartOperationValidation(ctx)
436	defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
437
438	if !args.CacheHit {
439		listErr := validator.Validate(gh.exec.Schema(), args.Doc)
440		if len(listErr) != 0 {
441			return ctx, nil, nil, listErr
442		}
443	}
444
445	op := args.Doc.Operations.ForName(args.OperationName)
446	if op == nil {
447		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
448	}
449
450	if op.Operation != ast.Query && args.R.Method == http.MethodGet {
451		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
452	}
453
454	vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
455	if err != nil {
456		return ctx, nil, nil, gqlerror.List{err}
457	}
458
459	return ctx, op, vars, nil
460}
461
// jsonDecode decodes a single JSON value from r into val. Numbers are decoded
// as json.Number instead of float64 so that large integer variables keep
// their exact value.
func jsonDecode(r io.Reader, val interface{}) error {
	d := json.NewDecoder(r)
	d.UseNumber()
	err := d.Decode(val)
	return err
}
467
468func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
469	w.WriteHeader(code)
470	b, err := json.Marshal(&graphql.Response{Errors: errors})
471	if err != nil {
472		panic(err)
473	}
474	w.Write(b)
475}
476
477func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
478	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
479}