1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11
12 "github.com/99designs/gqlgen/complexity"
13 "github.com/99designs/gqlgen/graphql"
14 "github.com/gorilla/websocket"
15 "github.com/hashicorp/golang-lru"
16 "github.com/vektah/gqlparser/ast"
17 "github.com/vektah/gqlparser/gqlerror"
18 "github.com/vektah/gqlparser/parser"
19 "github.com/vektah/gqlparser/validator"
20)
21
// params is the decoded form of an incoming GraphQL request: the JSON body of
// a POST, or the equivalent values read from the query string of a GET.
type params struct {
	// Query is the GraphQL document text.
	Query string `json:"query"`

	// OperationName names the operation inside Query that should be run.
	OperationName string `json:"operationName"`

	// Variables holds the variable values supplied by the client, before
	// they are validated against the schema.
	Variables map[string]interface{} `json:"variables"`
}
27
28type Config struct {
29 cacheSize int
30 upgrader websocket.Upgrader
31 recover graphql.RecoverFunc
32 errorPresenter graphql.ErrorPresenterFunc
33 resolverHook graphql.FieldMiddleware
34 requestHook graphql.RequestMiddleware
35 tracer graphql.Tracer
36 complexityLimit int
37 disableIntrospection bool
38 connectionKeepAlivePingInterval time.Duration
39}
40
41func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext {
42 reqCtx := graphql.NewRequestContext(doc, query, variables)
43 reqCtx.DisableIntrospection = c.disableIntrospection
44
45 if hook := c.recover; hook != nil {
46 reqCtx.Recover = hook
47 }
48
49 if hook := c.errorPresenter; hook != nil {
50 reqCtx.ErrorPresenter = hook
51 }
52
53 if hook := c.resolverHook; hook != nil {
54 reqCtx.ResolverMiddleware = hook
55 }
56
57 if hook := c.requestHook; hook != nil {
58 reqCtx.RequestMiddleware = hook
59 }
60
61 if hook := c.tracer; hook != nil {
62 reqCtx.Tracer = hook
63 } else {
64 reqCtx.Tracer = &graphql.NopTracer{}
65 }
66
67 if c.complexityLimit > 0 {
68 reqCtx.ComplexityLimit = c.complexityLimit
69 operationComplexity := complexity.Calculate(es, op, variables)
70 reqCtx.OperationComplexity = operationComplexity
71 }
72
73 return reqCtx
74}
75
76type Option func(cfg *Config)
77
78func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
79 return func(cfg *Config) {
80 cfg.upgrader = upgrader
81 }
82}
83
84func RecoverFunc(recover graphql.RecoverFunc) Option {
85 return func(cfg *Config) {
86 cfg.recover = recover
87 }
88}
89
90// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
91// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
92// implementation in graphql.DefaultErrorPresenter for an example.
93func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
94 return func(cfg *Config) {
95 cfg.errorPresenter = f
96 }
97}
98
99// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont
100// want clients introspecting the full schema.
101func IntrospectionEnabled(enabled bool) Option {
102 return func(cfg *Config) {
103 cfg.disableIntrospection = !enabled
104 }
105}
106
107// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
108// If a query is submitted that exceeds the limit, a 422 status code will be returned.
109func ComplexityLimit(limit int) Option {
110 return func(cfg *Config) {
111 cfg.complexityLimit = limit
112 }
113}
114
115// ResolverMiddleware allows you to define a function that will be called around every resolver,
116// useful for logging.
117func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
118 return func(cfg *Config) {
119 if cfg.resolverHook == nil {
120 cfg.resolverHook = middleware
121 return
122 }
123
124 lastResolve := cfg.resolverHook
125 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
126 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
127 return middleware(ctx, next)
128 })
129 }
130 }
131}
132
133// RequestMiddleware allows you to define a function that will be called around the root request,
134// after the query has been parsed. This is useful for logging
135func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
136 return func(cfg *Config) {
137 if cfg.requestHook == nil {
138 cfg.requestHook = middleware
139 return
140 }
141
142 lastResolve := cfg.requestHook
143 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
144 return lastResolve(ctx, func(ctx context.Context) []byte {
145 return middleware(ctx, next)
146 })
147 }
148 }
149}
150
151// Tracer allows you to add a request/resolver tracer that will be called around the root request,
152// calling resolver. This is useful for tracing
153func Tracer(tracer graphql.Tracer) Option {
154 return func(cfg *Config) {
155 if cfg.tracer == nil {
156 cfg.tracer = tracer
157
158 } else {
159 lastResolve := cfg.tracer
160 cfg.tracer = &tracerWrapper{
161 tracer1: lastResolve,
162 tracer2: tracer,
163 }
164 }
165
166 opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
167 ctx = tracer.StartOperationExecution(ctx)
168 resp := next(ctx)
169 tracer.EndOperationExecution(ctx)
170
171 return resp
172 })
173 opt(cfg)
174 }
175}
176
177type tracerWrapper struct {
178 tracer1 graphql.Tracer
179 tracer2 graphql.Tracer
180}
181
182func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context {
183 ctx = tw.tracer1.StartOperationParsing(ctx)
184 ctx = tw.tracer2.StartOperationParsing(ctx)
185 return ctx
186}
187
188func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) {
189 tw.tracer2.EndOperationParsing(ctx)
190 tw.tracer1.EndOperationParsing(ctx)
191}
192
193func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context {
194 ctx = tw.tracer1.StartOperationValidation(ctx)
195 ctx = tw.tracer2.StartOperationValidation(ctx)
196 return ctx
197}
198
199func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) {
200 tw.tracer2.EndOperationValidation(ctx)
201 tw.tracer1.EndOperationValidation(ctx)
202}
203
204func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context {
205 ctx = tw.tracer1.StartOperationExecution(ctx)
206 ctx = tw.tracer2.StartOperationExecution(ctx)
207 return ctx
208}
209
210func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context {
211 ctx = tw.tracer1.StartFieldExecution(ctx, field)
212 ctx = tw.tracer2.StartFieldExecution(ctx, field)
213 return ctx
214}
215
216func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context {
217 ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc)
218 ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc)
219 return ctx
220}
221
222func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context {
223 ctx = tw.tracer1.StartFieldChildExecution(ctx)
224 ctx = tw.tracer2.StartFieldChildExecution(ctx)
225 return ctx
226}
227
228func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) {
229 tw.tracer2.EndFieldExecution(ctx)
230 tw.tracer1.EndFieldExecution(ctx)
231}
232
233func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {
234 tw.tracer2.EndOperationExecution(ctx)
235 tw.tracer1.EndOperationExecution(ctx)
236}
237
238// CacheSize sets the maximum size of the query cache.
239// If size is less than or equal to 0, the cache is disabled.
240func CacheSize(size int) Option {
241 return func(cfg *Config) {
242 cfg.cacheSize = size
243 }
244}
245
// DefaultCacheSize is the query cache size GraphQL starts with before any
// options are applied.
const (
	DefaultCacheSize = 1000
)
247
248// WebsocketKeepAliveDuration allows you to reconfigure the keepAlive behavior.
249// By default, keep-alive is disabled.
250func WebsocketKeepAliveDuration(duration time.Duration) Option {
251 return func(cfg *Config) {
252 cfg.connectionKeepAlivePingInterval = duration
253 }
254}
255
256func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
257 cfg := &Config{
258 cacheSize: DefaultCacheSize,
259 upgrader: websocket.Upgrader{
260 ReadBufferSize: 1024,
261 WriteBufferSize: 1024,
262 },
263 }
264
265 for _, option := range options {
266 option(cfg)
267 }
268
269 var cache *lru.Cache
270 if cfg.cacheSize > 0 {
271 var err error
272 cache, err = lru.New(DefaultCacheSize)
273 if err != nil {
274 // An error is only returned for non-positive cache size
275 // and we already checked for that.
276 panic("unexpected error creating cache: " + err.Error())
277 }
278 }
279 if cfg.tracer == nil {
280 cfg.tracer = &graphql.NopTracer{}
281 }
282
283 handler := &graphqlHandler{
284 cfg: cfg,
285 cache: cache,
286 exec: exec,
287 }
288
289 return handler.ServeHTTP
290}
291
292var _ http.Handler = (*graphqlHandler)(nil)
293
294type graphqlHandler struct {
295 cfg *Config
296 cache *lru.Cache
297 exec graphql.ExecutableSchema
298}
299
300func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
301 if r.Method == http.MethodOptions {
302 w.Header().Set("Allow", "OPTIONS, GET, POST")
303 w.WriteHeader(http.StatusOK)
304 return
305 }
306
307 if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
308 connectWs(gh.exec, w, r, gh.cfg)
309 return
310 }
311
312 var reqParams params
313 switch r.Method {
314 case http.MethodGet:
315 reqParams.Query = r.URL.Query().Get("query")
316 reqParams.OperationName = r.URL.Query().Get("operationName")
317
318 if variables := r.URL.Query().Get("variables"); variables != "" {
319 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
320 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
321 return
322 }
323 }
324 case http.MethodPost:
325 if err := jsonDecode(r.Body, &reqParams); err != nil {
326 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
327 return
328 }
329 default:
330 w.WriteHeader(http.StatusMethodNotAllowed)
331 return
332 }
333 w.Header().Set("Content-Type", "application/json")
334
335 ctx := r.Context()
336
337 var doc *ast.QueryDocument
338 var cacheHit bool
339 if gh.cache != nil {
340 val, ok := gh.cache.Get(reqParams.Query)
341 if ok {
342 doc = val.(*ast.QueryDocument)
343 cacheHit = true
344 }
345 }
346
347 ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{
348 Query: reqParams.Query,
349 CachedDoc: doc,
350 })
351 if gqlErr != nil {
352 sendError(w, http.StatusUnprocessableEntity, gqlErr)
353 return
354 }
355
356 ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{
357 Doc: doc,
358 OperationName: reqParams.OperationName,
359 CacheHit: cacheHit,
360 R: r,
361 Variables: reqParams.Variables,
362 })
363 if len(listErr) != 0 {
364 sendError(w, http.StatusUnprocessableEntity, listErr...)
365 return
366 }
367
368 if gh.cache != nil && !cacheHit {
369 gh.cache.Add(reqParams.Query, doc)
370 }
371
372 reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars)
373 ctx = graphql.WithRequestContext(ctx, reqCtx)
374
375 defer func() {
376 if err := recover(); err != nil {
377 userErr := reqCtx.Recover(ctx, err)
378 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
379 }
380 }()
381
382 if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit {
383 sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit)
384 return
385 }
386
387 switch op.Operation {
388 case ast.Query:
389 b, err := json.Marshal(gh.exec.Query(ctx, op))
390 if err != nil {
391 panic(err)
392 }
393 w.Write(b)
394 case ast.Mutation:
395 b, err := json.Marshal(gh.exec.Mutation(ctx, op))
396 if err != nil {
397 panic(err)
398 }
399 w.Write(b)
400 default:
401 sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
402 }
403}
404
405type parseOperationArgs struct {
406 Query string
407 CachedDoc *ast.QueryDocument
408}
409
410func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) {
411 ctx = gh.cfg.tracer.StartOperationParsing(ctx)
412 defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }()
413
414 if args.CachedDoc != nil {
415 return ctx, args.CachedDoc, nil
416 }
417
418 doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query})
419 if gqlErr != nil {
420 return ctx, nil, gqlErr
421 }
422
423 return ctx, doc, nil
424}
425
426type validateOperationArgs struct {
427 Doc *ast.QueryDocument
428 OperationName string
429 CacheHit bool
430 R *http.Request
431 Variables map[string]interface{}
432}
433
434func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) {
435 ctx = gh.cfg.tracer.StartOperationValidation(ctx)
436 defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }()
437
438 if !args.CacheHit {
439 listErr := validator.Validate(gh.exec.Schema(), args.Doc)
440 if len(listErr) != 0 {
441 return ctx, nil, nil, listErr
442 }
443 }
444
445 op := args.Doc.Operations.ForName(args.OperationName)
446 if op == nil {
447 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)}
448 }
449
450 if op.Operation != ast.Query && args.R.Method == http.MethodGet {
451 return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")}
452 }
453
454 vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables)
455 if err != nil {
456 return ctx, nil, nil, gqlerror.List{err}
457 }
458
459 return ctx, op, vars, nil
460}
461
// jsonDecode reads one JSON value from r into val. Numbers that end up in
// interface{} values are kept as json.Number rather than converted to float64.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()

	err := decoder.Decode(val)
	return err
}
467
468func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
469 w.WriteHeader(code)
470 b, err := json.Marshal(&graphql.Response{Errors: errors})
471 if err != nil {
472 panic(err)
473 }
474 w.Write(b)
475}
476
477func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
478 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
479}