1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11
12 "github.com/99designs/gqlgen/complexity"
13 "github.com/99designs/gqlgen/graphql"
14 "github.com/gorilla/websocket"
15 "github.com/hashicorp/golang-lru"
16 "github.com/vektah/gqlparser/ast"
17 "github.com/vektah/gqlparser/gqlerror"
18 "github.com/vektah/gqlparser/parser"
19 "github.com/vektah/gqlparser/validator"
20)
21
// params is the client payload of a GraphQL-over-HTTP request. ServeHTTP
// decodes it from the JSON body of POST requests and fills it from the URL
// query string ("query", "operationName", "variables") for GET requests.
type params struct {
	Query         string                 `json:"query"`
	OperationName string                 `json:"operationName"`
	Variables     map[string]interface{} `json:"variables"`
}
27
28type Config struct {
29 cacheSize int
30 upgrader websocket.Upgrader
31 recover graphql.RecoverFunc
32 errorPresenter graphql.ErrorPresenterFunc
33 resolverHook graphql.FieldMiddleware
34 requestHook graphql.RequestMiddleware
35 tracer graphql.Tracer
36 complexityLimit int
37 complexityLimitFunc graphql.ComplexityLimitFunc
38 disableIntrospection bool
39 connectionKeepAlivePingInterval time.Duration
40}
41
42func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
43 reqCtx := graphql.NewRequestContext(doc, query, variables)
44 reqCtx.DisableIntrospection = c.disableIntrospection
45
46 if hook := c.recover; hook != nil {
47 reqCtx.Recover = hook
48 }
49
50 if hook := c.errorPresenter; hook != nil {
51 reqCtx.ErrorPresenter = hook
52 }
53
54 if hook := c.resolverHook; hook != nil {
55 reqCtx.ResolverMiddleware = hook
56 }
57
58 if hook := c.requestHook; hook != nil {
59 reqCtx.RequestMiddleware = hook
60 }
61
62 if hook := c.tracer; hook != nil {
63 reqCtx.Tracer = hook
64 }
65
66 if c.complexityLimit > 0 || c.complexityLimitFunc != nil {
67 reqCtx.ComplexityLimit = c.complexityLimit
68 operationComplexity := complexity.Calculate(es, op, variables)
69 reqCtx.OperationComplexity = operationComplexity
70 }
71
72 return reqCtx
73}
74
75type Option func(cfg *Config)
76
77func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
78 return func(cfg *Config) {
79 cfg.upgrader = upgrader
80 }
81}
82
83func RecoverFunc(recover graphql.RecoverFunc) Option {
84 return func(cfg *Config) {
85 cfg.recover = recover
86 }
87}
88
89// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
90// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
91// implementation in graphql.DefaultErrorPresenter for an example.
92func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
93 return func(cfg *Config) {
94 cfg.errorPresenter = f
95 }
96}
97
98// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
99// want clients introspecting the full schema.
100func IntrospectionEnabled(enabled bool) Option {
101 return func(cfg *Config) {
102 cfg.disableIntrospection = !enabled
103 }
104}
105
106// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
107// If a query is submitted that exceeds the limit, a 422 status code will be returned.
108func ComplexityLimit(limit int) Option {
109 return func(cfg *Config) {
110 cfg.complexityLimit = limit
111 }
112}
113
114// ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed
115// to be executed.
116// If a query is submitted that exceeds the limit, a 422 status code will be returned.
117func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option {
118 return func(cfg *Config) {
119 cfg.complexityLimitFunc = complexityLimitFunc
120 }
121}
122
123// ResolverMiddleware allows you to define a function that will be called around every resolver,
124// useful for logging.
125func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
126 return func(cfg *Config) {
127 if cfg.resolverHook == nil {
128 cfg.resolverHook = middleware
129 return
130 }
131
132 lastResolve := cfg.resolverHook
133 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
134 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
135 return middleware(ctx, next)
136 })
137 }
138 }
139}
140
141// RequestMiddleware allows you to define a function that will be called around the root request,
142// after the query has been parsed. This is useful for logging
143func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
144 return func(cfg *Config) {
145 if cfg.requestHook == nil {
146 cfg.requestHook = middleware
147 return
148 }
149
150 lastResolve := cfg.requestHook
151 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
152 return lastResolve(ctx, func(ctx context.Context) []byte {
153 return middleware(ctx, next)
154 })
155 }
156 }
157}
158
159// Tracer allows you to add a request/resolver tracer that will be called around the root request,
160// calling resolver. This is useful for tracing
161func Tracer(tracer graphql.Tracer) Option {
162 return func(cfg *Config) {
163 if cfg.tracer == nil {
164 cfg.tracer = tracer
165
166 } else {
167 lastResolve := cfg.tracer
168 cfg.tracer = &tracerWrapper{
169 tracer1: lastResolve,
170 tracer2: tracer,
171 }
172 }
173
174 opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
175 ctx = tracer.StartOperationExecution(ctx)
176 resp := next(ctx)
177 tracer.EndOperationExecution(ctx)
178
179 return resp
180 })
181 opt(cfg)
182 }
183}
184
185type tracerWrapper struct {
186 tracer1 graphql.Tracer
187 tracer2 graphql.Tracer
188}
189
190func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
191 ctx = tw.tracer1.StartOperationParsing(ctx)
192 ctx = tw.tracer2.StartOperationParsing(ctx)
193 return ctx
194}
195
196func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
197 tw.tracer2.EndOperationParsing(ctx)
198 tw.tracer1.EndOperationParsing(ctx)
199}
200
201func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
202 ctx = tw.tracer1.StartOperationValidation(ctx)
203 ctx = tw.tracer2.StartOperationValidation(ctx)
204 return ctx
205}
206
207func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
208 tw.tracer2.EndOperationValidation(ctx)
209 tw.tracer1.EndOperationValidation(ctx)
210}
211
212func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
213 ctx = tw.tracer1.StartOperationExecution(ctx)
214 ctx = tw.tracer2.StartOperationExecution(ctx)
215 return ctx
216}
217
218func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
219 ctx = tw.tracer1.StartFieldExecution(ctx, field)
220 ctx = tw.tracer2.StartFieldExecution(ctx, field)
221 return ctx
222}
223
224func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
225 ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
226 ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
227 return ctx
228}
229
230func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
231 ctx = tw.tracer1.StartFieldChildExecution(ctx)
232 ctx = tw.tracer2.StartFieldChildExecution(ctx)
233 return ctx
234}
235
236func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
237 tw.tracer2.EndFieldExecution(ctx)
238 tw.tracer1.EndFieldExecution(ctx)
239}
240
241func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
242 tw.tracer2.EndOperationExecution(ctx)
243 tw.tracer1.EndOperationExecution(ctx)
244}
245
246// CacheSize sets the maximum size of the query cache.
247// If size is less than or equal to 0, the cache is disabled.
248func CacheSize(size int) Option {
249 return func(cfg *Config) {
250 cfg.cacheSize = size
251 }
252}
253
254// WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior.
255// By default, keepalive is enabled with a DefaultConnectionKeepAlivePingInterval
256// duration. Set handler.connectionKeepAlivePingInterval = 0 to disable keepalive
257// altogether.
258func WebsocketKeepAliveDuration(duration time.Duration) Option {
259 return func(cfg *Config) {
260 cfg.connectionKeepAlivePingInterval = duration
261 }
262}
263
264const DefaultCacheSize = 1000
265const DefaultConnectionKeepAlivePingInterval = 25 * time.Second
266
267func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
268 cfg := &Config{
269 cacheSize: DefaultCacheSize,
270 connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval,
271 upgrader: websocket.Upgrader{
272 ReadBufferSize: 1024,
273 WriteBufferSize: 1024,
274 },
275 }
276
277 for _, option := range options {
278 option(cfg)
279 }
280
281 var cache *lru.Cache
282 if cfg.cacheSize > 0 {
283 var err error
284 cache, err = lru.New(cfg.cacheSize)
285 if err != nil {
286 // An error is only returned for non-positive cache size
287 // and we already checked for that.
288 panic("unexpected error creating cache: " + err.Error())
289 }
290 }
291 if cfg.tracer == nil {
292 cfg.tracer = &graphql.NopTracer{}
293 }
294
295 handler := &graphqlHandler{
296 cfg: cfg,
297 cache: cache,
298 exec: exec,
299 }
300
301 return handler.ServeHTTP
302}
303
304var _ http.Handler = (*graphqlHandler)(nil)
305
306type graphqlHandler struct {
307 cfg *Config
308 cache *lru.Cache
309 exec graphql.ExecutableSchema
310}
311
312func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
313 if r.Method == http.MethodOptions {
314 w.Header().Set("Allow", "OPTIONS, GET, POST")
315 w.WriteHeader(http.StatusOK)
316 return
317 }
318
319 if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
320 connectWs(gh.exec, w, r, gh.cfg, gh.cache)
321 return
322 }
323
324 w.Header().Set("Content-Type", "application/json")
325 var reqParams params
326 switch r.Method {
327 case http.MethodGet:
328 reqParams.Query = r.URL.Query().Get("query")
329 reqParams.OperationName = r.URL.Query().Get("operationName")
330
331 if variables := r.URL.Query().Get("variables"); variables != "" {
332 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
333 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
334 return
335 }
336 }
337 case http.MethodPost:
338 if err := jsonDecode(r.Body, &reqParams); err != nil {
339 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
340 return
341 }
342 default:
343 w.WriteHeader(http.StatusMethodNotAllowed)
344 return
345 }
346
347 ctx := r.Context()
348
349 var doc *ast.QueryDocument
350 var cacheHit bool
351 if gh.cache != nil {
352 val, ok := gh.cache.Get(reqParams.Query)
353 if ok {
354 doc = val.(*ast.QueryDocument)
355 cacheHit = true
356 }
357 }
358
359 ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
360 Query: reqParams.Query,
361 CachedDoc: doc,
362 })
363 if gqlErr != nil {
364 sendError(w, http.StatusUnprocessableEntity, gqlErr)
365 return
366 }
367
368 ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
369 Doc: doc,
370 OperationName: reqParams.OperationName,
371 CacheHit: cacheHit,
372 R: r,
373 Variables: reqParams.Variables,
374 })
375 if len(listErr) != 0 {
376 sendError(w, http.StatusUnprocessableEntity, listErr...)
377 return
378 }
379
380 if gh.cache != nil && !cacheHit {
381 gh.cache.Add(reqParams.Query, doc)
382 }
383
384 reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
385 ctx = graphql.WithRequestContext(ctx, reqCtx)
386
387 defer func() {
388 if err := recover(); err != nil {
389 userErr := reqCtx.Recover(ctx, err)
390 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
391 }
392 }()
393
394 if gh.cfg.complexityLimitFunc != nil {
395 reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx)
396 }
397
398 if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
399 sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
400 return
401 }
402
403 switch op.Operation {
404 case ast.Query:
405 b, err := json.Marshal(gh.exec.Query(ctx, op))
406 if err != nil {
407 panic(err)
408 }
409 w.Write(b)
410 case ast.Mutation:
411 b, err := json.Marshal(gh.exec.Mutation(ctx, op))
412 if err != nil {
413 panic(err)
414 }
415 w.Write(b)
416 default:
417 sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
418 }
419}
420
421type parseOperationArgs struct {
422 Query string
423 CachedDoc *ast.QueryDocument
424}
425
426func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
427 ctx = gh.cfg.tracer.StartOperationParsing(ctx)
428 defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
429
430 if args.CachedDoc != nil {
431 return ctx, args.CachedDoc, nil
432 }
433
434 doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
435 if gqlErr != nil {
436 return ctx, nil, gqlErr
437 }
438
439 return ctx, doc, nil
440}
441
442type validateOperationArgs struct {
443 Doc *ast.QueryDocument
444 OperationName string
445 CacheHit bool
446 R *http.Request
447 Variables map[string]interface{}
448}
449
450func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
451 ctx = gh.cfg.tracer.StartOperationValidation(ctx)
452 defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
453
454 if !args.CacheHit {
455 listErr := validator.Validate(gh.exec.Schema(), args.Doc)
456 if len(listErr) != 0 {
457 return ctx, nil, nil, listErr
458 }
459 }
460
461 op := args.Doc.Operations.ForName(args.OperationName)
462 if op == nil {
463 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
464 }
465
466 if op.Operation != ast.Query && args.R.Method == http.MethodGet {
467 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
468 }
469
470 vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
471 if err != nil {
472 return ctx, nil, nil, gqlerror.List{err}
473 }
474
475 return ctx, op, vars, nil
476}
477
// jsonDecode decodes the next JSON value read from r into val. Numbers are
// decoded as json.Number instead of float64 so that large integer variables
// keep their exact textual value.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	if err := decoder.Decode(val); err != nil {
		return err
	}
	return nil
}
483
484func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
485 w.WriteHeader(code)
486 b, err := json.Marshal(&graphql.Response{Errors: errors})
487 if err != nil {
488 panic(err)
489 }
490 w.Write(b)
491}
492
493func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
494 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
495}