graphql.go

  1package handler
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10
 11	"github.com/99designs/gqlgen/complexity"
 12	"github.com/99designs/gqlgen/graphql"
 13	"github.com/gorilla/websocket"
 14	"github.com/hashicorp/golang-lru"
 15	"github.com/vektah/gqlparser"
 16	"github.com/vektah/gqlparser/ast"
 17	"github.com/vektah/gqlparser/gqlerror"
 18	"github.com/vektah/gqlparser/validator"
 19)
 20
 21type params struct {
 22	Query         string                 `json:"query"`
 23	OperationName string                 `json:"operationName"`
 24	Variables     map[string]interface{} `json:"variables"`
 25}
 26
 27type Config struct {
 28	cacheSize       int
 29	upgrader        websocket.Upgrader
 30	recover         graphql.RecoverFunc
 31	errorPresenter  graphql.ErrorPresenterFunc
 32	resolverHook    graphql.FieldMiddleware
 33	requestHook     graphql.RequestMiddleware
 34	complexityLimit int
 35}
 36
 37func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext {
 38	reqCtx := graphql.NewRequestContext(doc, query, variables)
 39	if hook := c.recover; hook != nil {
 40		reqCtx.Recover = hook
 41	}
 42
 43	if hook := c.errorPresenter; hook != nil {
 44		reqCtx.ErrorPresenter = hook
 45	}
 46
 47	if hook := c.resolverHook; hook != nil {
 48		reqCtx.ResolverMiddleware = hook
 49	}
 50
 51	if hook := c.requestHook; hook != nil {
 52		reqCtx.RequestMiddleware = hook
 53	}
 54
 55	return reqCtx
 56}
 57
 58type Option func(cfg *Config)
 59
 60func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
 61	return func(cfg *Config) {
 62		cfg.upgrader = upgrader
 63	}
 64}
 65
 66func RecoverFunc(recover graphql.RecoverFunc) Option {
 67	return func(cfg *Config) {
 68		cfg.recover = recover
 69	}
 70}
 71
 72// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
 73// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
 74// implementation in graphql.DefaultErrorPresenter for an example.
 75func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
 76	return func(cfg *Config) {
 77		cfg.errorPresenter = f
 78	}
 79}
 80
 81// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
 82// If a query is submitted that exceeds the limit, a 422 status code will be returned.
 83func ComplexityLimit(limit int) Option {
 84	return func(cfg *Config) {
 85		cfg.complexityLimit = limit
 86	}
 87}
 88
 89// ResolverMiddleware allows you to define a function that will be called around every resolver,
 90// useful for tracing and logging.
 91func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
 92	return func(cfg *Config) {
 93		if cfg.resolverHook == nil {
 94			cfg.resolverHook = middleware
 95			return
 96		}
 97
 98		lastResolve := cfg.resolverHook
 99		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
100			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
101				return middleware(ctx, next)
102			})
103		}
104	}
105}
106
107// RequestMiddleware allows you to define a function that will be called around the root request,
108// after the query has been parsed. This is useful for logging and tracing
109func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
110	return func(cfg *Config) {
111		if cfg.requestHook == nil {
112			cfg.requestHook = middleware
113			return
114		}
115
116		lastResolve := cfg.requestHook
117		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
118			return lastResolve(ctx, func(ctx context.Context) []byte {
119				return middleware(ctx, next)
120			})
121		}
122	}
123}
124
125// CacheSize sets the maximum size of the query cache.
126// If size is less than or equal to 0, the cache is disabled.
127func CacheSize(size int) Option {
128	return func(cfg *Config) {
129		cfg.cacheSize = size
130	}
131}
132
// DefaultCacheSize is the query-cache size GraphQL starts with when no
// CacheSize option is supplied.
const DefaultCacheSize = 1000
134
135func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
136	cfg := Config{
137		cacheSize: DefaultCacheSize,
138		upgrader: websocket.Upgrader{
139			ReadBufferSize:  1024,
140			WriteBufferSize: 1024,
141		},
142	}
143
144	for _, option := range options {
145		option(&cfg)
146	}
147
148	var cache *lru.Cache
149	if cfg.cacheSize > 0 {
150		var err error
151		cache, err = lru.New(DefaultCacheSize)
152		if err != nil {
153			// An error is only returned for non-positive cache size
154			// and we already checked for that.
155			panic("unexpected error creating cache: " + err.Error())
156		}
157	}
158
159	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160		if r.Method == http.MethodOptions {
161			w.Header().Set("Allow", "OPTIONS, GET, POST")
162			w.WriteHeader(http.StatusOK)
163			return
164		}
165
166		if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
167			connectWs(exec, w, r, &cfg)
168			return
169		}
170
171		var reqParams params
172		switch r.Method {
173		case http.MethodGet:
174			reqParams.Query = r.URL.Query().Get("query")
175			reqParams.OperationName = r.URL.Query().Get("operationName")
176
177			if variables := r.URL.Query().Get("variables"); variables != "" {
178				if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
179					sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
180					return
181				}
182			}
183		case http.MethodPost:
184			if err := jsonDecode(r.Body, &reqParams); err != nil {
185				sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
186				return
187			}
188		default:
189			w.WriteHeader(http.StatusMethodNotAllowed)
190			return
191		}
192		w.Header().Set("Content-Type", "application/json")
193
194		var doc *ast.QueryDocument
195		if cache != nil {
196			val, ok := cache.Get(reqParams.Query)
197			if ok {
198				doc = val.(*ast.QueryDocument)
199			}
200		}
201		if doc == nil {
202			var qErr gqlerror.List
203			doc, qErr = gqlparser.LoadQuery(exec.Schema(), reqParams.Query)
204			if len(qErr) > 0 {
205				sendError(w, http.StatusUnprocessableEntity, qErr...)
206				return
207			}
208			if cache != nil {
209				cache.Add(reqParams.Query, doc)
210			}
211		}
212
213		op := doc.Operations.ForName(reqParams.OperationName)
214		if op == nil {
215			sendErrorf(w, http.StatusUnprocessableEntity, "operation %s not found", reqParams.OperationName)
216			return
217		}
218
219		if op.Operation != ast.Query && r.Method == http.MethodGet {
220			sendErrorf(w, http.StatusUnprocessableEntity, "GET requests only allow query operations")
221			return
222		}
223
224		vars, err := validator.VariableValues(exec.Schema(), op, reqParams.Variables)
225		if err != nil {
226			sendError(w, http.StatusUnprocessableEntity, err)
227			return
228		}
229		reqCtx := cfg.newRequestContext(doc, reqParams.Query, vars)
230		ctx := graphql.WithRequestContext(r.Context(), reqCtx)
231
232		defer func() {
233			if err := recover(); err != nil {
234				userErr := reqCtx.Recover(ctx, err)
235				sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
236			}
237		}()
238
239		if cfg.complexityLimit > 0 {
240			queryComplexity := complexity.Calculate(exec, op, vars)
241			if queryComplexity > cfg.complexityLimit {
242				sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit)
243				return
244			}
245		}
246
247		switch op.Operation {
248		case ast.Query:
249			b, err := json.Marshal(exec.Query(ctx, op))
250			if err != nil {
251				panic(err)
252			}
253			w.Write(b)
254		case ast.Mutation:
255			b, err := json.Marshal(exec.Mutation(ctx, op))
256			if err != nil {
257				panic(err)
258			}
259			w.Write(b)
260		default:
261			sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
262		}
263	})
264}
265
// jsonDecode reads one JSON value from r into val. Numbers destined for
// interface{} values are kept as json.Number instead of float64, so large
// integers are not rounded.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	err := decoder.Decode(val)
	return err
}
271
272func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
273	w.WriteHeader(code)
274	b, err := json.Marshal(&graphql.Response{Errors: errors})
275	if err != nil {
276		panic(err)
277	}
278	w.Write(b)
279}
280
281func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
282	sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
283}