graphql.go

  1package handler
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10
 11	"github.com/99designs/gqlgen/complexity"
 12	"github.com/99designs/gqlgen/graphql"
 13	"github.com/gorilla/websocket"
 14	"github.com/hashicorp/golang-lru"
 15	"github.com/vektah/gqlparser/ast"
 16	"github.com/vektah/gqlparser/gqlerror"
 17	"github.com/vektah/gqlparser/parser"
 18	"github.com/vektah/gqlparser/validator"
 19)
 20
 21type params struct {
 22	Query         string                 `json:"query"`
 23	OperationName string                 `json:"operationName"`
 24	Variables     map[string]interface{} `json:"variables"`
 25}
 26
 27type Config struct {
 28	cacheSize            int
 29	upgrader             websocket.Upgrader
 30	recover              graphql.RecoverFunc
 31	errorPresenter       graphql.ErrorPresenterFunc
 32	resolverHook         graphql.FieldMiddleware
 33	requestHook          graphql.RequestMiddleware
 34	tracer               graphql.Tracer
 35	complexityLimit      int
 36	disableIntrospection bool
 37}
 38
 39func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
 40	reqCtx := graphql.NewRequestContext(doc, query, variables)
 41	reqCtx.DisableIntrospection = c.disableIntrospection
 42
 43	if hook := c.recover; hook != nil {
 44		reqCtx.Recover = hook
 45	}
 46
 47	if hook := c.errorPresenter; hook != nil {
 48		reqCtx.ErrorPresenter = hook
 49	}
 50
 51	if hook := c.resolverHook; hook != nil {
 52		reqCtx.ResolverMiddleware = hook
 53	}
 54
 55	if hook := c.requestHook; hook != nil {
 56		reqCtx.RequestMiddleware = hook
 57	}
 58
 59	if hook := c.tracer; hook != nil {
 60		reqCtx.Tracer = hook
 61	} else {
 62		reqCtx.Tracer = &graphql.NopTracer{}
 63	}
 64
 65	if c.complexityLimit > 0 {
 66		reqCtx.ComplexityLimit = c.complexityLimit
 67		operationComplexity := complexity.Calculate(es, op, variables)
 68		reqCtx.OperationComplexity = operationComplexity
 69	}
 70
 71	return reqCtx
 72}
 73
 74type Option func(cfg *Config)
 75
 76func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
 77	return func(cfg *Config) {
 78		cfg.upgrader = upgrader
 79	}
 80}
 81
 82func RecoverFunc(recover graphql.RecoverFunc) Option {
 83	return func(cfg *Config) {
 84		cfg.recover = recover
 85	}
 86}
 87
 88// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
 89// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
 90// implementation in graphql.DefaultErrorPresenter for an example.
 91func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
 92	return func(cfg *Config) {
 93		cfg.errorPresenter = f
 94	}
 95}
 96
 97// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
 98// want clients introspecting the full schema.
 99func IntrospectionEnabled(enabled bool) Option {
100	return func(cfg *Config) {
101		cfg.disableIntrospection = !enabled
102	}
103}
104
105// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
106// If a query is submitted that exceeds the limit, a 422 status code will be returned.
107func ComplexityLimit(limit int) Option {
108	return func(cfg *Config) {
109		cfg.complexityLimit = limit
110	}
111}
112
// ResolverMiddleware registers a function that will be called around every
// resolver, which is useful for logging. It may be used several times:
// middleware registered earlier wraps middleware registered later.
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
	return func(cfg *Config) {
		outer := cfg.resolverHook
		if outer == nil {
			cfg.resolverHook = middleware
			return
		}

		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
			// The new middleware runs inside the existing chain, just before next.
			inner := func(ctx context.Context) (interface{}, error) {
				return middleware(ctx, next)
			}
			return outer(ctx, inner)
		}
	}
}
130
// RequestMiddleware registers a function that will be called around the root
// request, after the query has been parsed. This is useful for logging.
// It may be used several times: middleware registered earlier wraps
// middleware registered later.
func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
	return func(cfg *Config) {
		outer := cfg.requestHook
		if outer == nil {
			cfg.requestHook = middleware
			return
		}

		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
			// The new middleware runs inside the existing chain, just before next.
			inner := func(ctx context.Context) []byte {
				return middleware(ctx, next)
			}
			return outer(ctx, inner)
		}
	}
}
148
// Tracer installs a request/resolver tracer. If a tracer is already
// configured, the two are chained with tracerWrapper, with the earlier tracer
// called first. The option also registers a request middleware that calls
// this tracer's StartOperationExecution before the root request and
// EndOperationExecution after it. This is useful for tracing.
func Tracer(tracer graphql.Tracer) Option {
	return func(cfg *Config) {
		if prev := cfg.tracer; prev != nil {
			cfg.tracer = &tracerWrapper{
				tracer1: prev,
				tracer2: tracer,
			}
		} else {
			cfg.tracer = tracer
		}

		// Operation execution is bracketed via the request middleware chain
		// rather than through cfg.tracer, so each tracer gets its own pair of calls.
		RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
			ctx = tracer.StartOperationExecution(ctx)
			out := next(ctx)
			tracer.EndOperationExecution(ctx)
			return out
		})(cfg)
	}
}
174
175type tracerWrapper struct {
176	tracer1 graphql.Tracer
177	tracer2 graphql.Tracer
178}
179
180func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
181	ctx = tw.tracer1.StartOperationParsing(ctx)
182	ctx = tw.tracer2.StartOperationParsing(ctx)
183	return ctx
184}
185
186func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
187	tw.tracer2.EndOperationParsing(ctx)
188	tw.tracer1.EndOperationParsing(ctx)
189}
190
191func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
192	ctx = tw.tracer1.StartOperationValidation(ctx)
193	ctx = tw.tracer2.StartOperationValidation(ctx)
194	return ctx
195}
196
197func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
198	tw.tracer2.EndOperationValidation(ctx)
199	tw.tracer1.EndOperationValidation(ctx)
200}
201
202func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
203	ctx = tw.tracer1.StartOperationExecution(ctx)
204	ctx = tw.tracer2.StartOperationExecution(ctx)
205	return ctx
206}
207
208func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
209	ctx = tw.tracer1.StartFieldExecution(ctx, field)
210	ctx = tw.tracer2.StartFieldExecution(ctx, field)
211	return ctx
212}
213
214func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
215	ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
216	ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
217	return ctx
218}
219
220func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
221	ctx = tw.tracer1.StartFieldChildExecution(ctx)
222	ctx = tw.tracer2.StartFieldChildExecution(ctx)
223	return ctx
224}
225
226func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
227	tw.tracer2.EndFieldExecution(ctx)
228	tw.tracer1.EndFieldExecution(ctx)
229}
230
231func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
232	tw.tracer2.EndOperationExecution(ctx)
233	tw.tracer1.EndOperationExecution(ctx)
234}
235
236// CacheSize sets the maximum size of the query cache.
237// If size is less than or equal to 0, the cache is disabled.
238func CacheSize(size int) Option {
239	return func(cfg *Config) {
240		cfg.cacheSize = size
241	}
242}
243
244const DefaultCacheSize = 1000
245
246func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
247	cfg := &Config{
248		cacheSize: DefaultCacheSize,
249		upgrader: websocket.Upgrader{
250			ReadBufferSize:  1024,
251			WriteBufferSize: 1024,
252		},
253	}
254
255	for _, option := range options {
256		option(cfg)
257	}
258
259	var cache *lru.Cache
260	if cfg.cacheSize > 0 {
261		var err error
262		cache, err = lru.New(DefaultCacheSize)
263		if err != nil {
264			// An error is only returned for non-positive cache size
265			// and we already checked for that.
266			panic("unexpected error creating cache: " + err.Error())
267		}
268	}
269	if cfg.tracer == nil {
270		cfg.tracer = &graphql.NopTracer{}
271	}
272
273	handler := &graphqlHandler{
274		cfg:   cfg,
275		cache: cache,
276		exec:  exec,
277	}
278
279	return handler.ServeHTTP
280}
281
282var _ http.Handler = (*graphqlHandler)(nil)
283
284type graphqlHandler struct {
285	cfg   *Config
286	cache *lru.Cache
287	exec  graphql.ExecutableSchema
288}
289
// ServeHTTP implements http.Handler for GraphQL over HTTP.
//
// Requests are dispatched as follows:
//   - OPTIONS: answered with an Allow header and status 200.
//   - Upgrade header containing "websocket": delegated to connectWs.
//   - GET: reads "query", "operationName" and the JSON "variables" from the URL.
//   - POST: decodes the same fields from a JSON body.
//   - Any other method: status 405.
//
// The query document is taken from the LRU cache when present and parsed
// otherwise. It is then validated; schema validation is skipped on a cache
// hit. A document is added to the cache only after validation succeeds.
//
// The following are reported as JSON error responses with status 422:
//   - parse and validation failures;
//   - operations whose complexity exceeds the configured limit;
//   - panics during execution, converted by the RequestContext's Recover func.
func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodOptions {
		w.Header().Set("Allow", "OPTIONS, GET, POST")
		w.WriteHeader(http.StatusOK)
		return
	}

	if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
		connectWs(gh.exec, w, r, gh.cfg)
		return
	}

	var reqParams params
	switch r.Method {
	case http.MethodGet:
		query := r.URL.Query()
		reqParams.Query = query.Get("query")
		reqParams.OperationName = query.Get("operationName")

		if variables := query.Get("variables"); variables != "" {
			if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
				sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
				return
			}
		}
	case http.MethodPost:
		if err := jsonDecode(r.Body, &reqParams); err != nil {
			// The decoder's message can echo client input (e.g. "invalid
			// character '%' ..."), so pass it as an argument rather than
			// splicing it into the format string.
			sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: %s", err.Error())
			return
		}
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")

	ctx := r.Context()

	var doc *ast.QueryDocument
	var cacheHit bool
	if gh.cache != nil {
		if val, ok := gh.cache.Get(reqParams.Query); ok {
			doc = val.(*ast.QueryDocument)
			cacheHit = true
		}
	}

	ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
		Query:     reqParams.Query,
		CachedDoc: doc,
	})
	if gqlErr != nil {
		sendError(w, http.StatusUnprocessableEntity, gqlErr)
		return
	}

	ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
		Doc:           doc,
		OperationName: reqParams.OperationName,
		CacheHit:      cacheHit,
		R:             r,
		Variables:     reqParams.Variables,
	})
	if len(listErr) != 0 {
		sendError(w, http.StatusUnprocessableEntity, listErr...)
		return
	}

	// Only documents that passed validation are cached, because cache hits
	// skip schema validation.
	if gh.cache != nil && !cacheHit {
		gh.cache.Add(reqParams.Query, doc)
	}

	reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
	ctx = graphql.WithRequestContext(ctx, reqCtx)

	defer func() {
		if err := recover(); err != nil {
			userErr := reqCtx.Recover(ctx, err)
			// The recovered message is data, not a format string.
			sendErrorf(w, http.StatusUnprocessableEntity, "%s", userErr.Error())
		}
	}()

	if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
		sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
		return
	}

	switch op.Operation {
	case ast.Query:
		b, err := json.Marshal(gh.exec.Query(ctx, op))
		if err != nil {
			panic(err)
		}
		w.Write(b)
	case ast.Mutation:
		b, err := json.Marshal(gh.exec.Mutation(ctx, op))
		if err != nil {
			panic(err)
		}
		w.Write(b)
	default:
		sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
	}
}
394
395type parseOperationArgs struct {
396	Query     string
397	CachedDoc *ast.QueryDocument
398}
399
// parseOperation parses args.Query into a query document, bracketed by the
// configured tracer's StartOperationParsing/EndOperationParsing calls.
// A non-nil args.CachedDoc short-circuits parsing; it is still traced.
// The returned context is the one produced by StartOperationParsing.
func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
	ctx = gh.cfg.tracer.StartOperationParsing(ctx)
	defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()

	if cached := args.CachedDoc; cached != nil {
		return ctx, cached, nil
	}

	parsed, parseErr := parser.ParseQuery(&ast.Source{Input: args.Query})
	if parseErr != nil {
		return ctx, nil, parseErr
	}
	return ctx, parsed, nil
}
415
416type validateOperationArgs struct {
417	Doc           *ast.QueryDocument
418	OperationName string
419	CacheHit      bool
420	R             *http.Request
421	Variables     map[string]interface{}
422}
423
// validateOperation validates args.Doc against the schema (unless it came
// from the cache) and selects the operation named args.OperationName.
// GET requests may only run query operations. Variables are then coerced
// against the operation's definitions.
//
// The work is bracketed by the tracer's StartOperationValidation and
// EndOperationValidation calls. Any failure is returned as a non-empty
// gqlerror.List, with nil operation and variables.
func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
	ctx = gh.cfg.tracer.StartOperationValidation(ctx)
	defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()

	if !args.CacheHit {
		if errs := validator.Validate(gh.exec.Schema(), args.Doc); len(errs) != 0 {
			return ctx, nil, nil, errs
		}
	}

	op := args.Doc.Operations.ForName(args.OperationName)
	switch {
	case op == nil:
		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
	case args.R.Method == http.MethodGet && op.Operation != ast.Query:
		return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
	}

	coerced, varErr := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
	if varErr != nil {
		return ctx, nil, nil, gqlerror.List{varErr}
	}
	return ctx, op, coerced, nil
}
451
// jsonDecode decodes a single JSON value from r into val. Numbers are kept
// as json.Number rather than float64, so integer variables are not rounded.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	err := decoder.Decode(val)
	return err
}
457
// sendError writes a GraphQL response containing only errs, with HTTP status
// code.
//
// If the caller has not chosen a Content-Type, application/json is set.
// Without this, responses written before ServeHTTP sets the header (e.g.
// request-decoding failures) would have their JSON body sniffed as text/plain.
//
// The body is marshalled before the status is committed, so a marshalling
// panic does not leave a half-written response behind.
func sendError(w http.ResponseWriter, code int, errs ...*gqlerror.Error) {
	b, err := json.Marshal(&graphql.Response{Errors: errs})
	if err != nil {
		panic(err)
	}

	if w.Header().Get("Content-Type") == "" {
		w.Header().Set("Content-Type", "application/json")
	}
	w.WriteHeader(code)
	w.Write(b)
}
466
// sendErrorf formats a single error message with fmt.Sprintf semantics and
// writes it through sendError. format is interpreted as a format string, so
// untrusted text must be passed in args (e.g. with "%s"), never as format.
func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...)
	sendError(w, code, &gqlerror.Error{Message: msg})
}