1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10
11 "github.com/99designs/gqlgen/complexity"
12 "github.com/99designs/gqlgen/graphql"
13 "github.com/gorilla/websocket"
14 "github.com/hashicorp/golang-lru"
15 "github.com/vektah/gqlparser"
16 "github.com/vektah/gqlparser/ast"
17 "github.com/vektah/gqlparser/gqlerror"
18 "github.com/vektah/gqlparser/validator"
19)
20
// params is the decoded form of a GraphQL-over-HTTP request. It is filled
// either from the URL query string (GET) or from the JSON request body (POST).
type params struct {
	// Query is the GraphQL document text.
	Query string `json:"query"`
	// OperationName selects which operation in Query to run.
	OperationName string `json:"operationName"`
	// Variables holds the raw, not-yet-validated operation variables.
	Variables map[string]interface{} `json:"variables"`
}
26
27type Config struct {
28 cacheSize int
29 upgrader websocket.Upgrader
30 recover graphql.RecoverFunc
31 errorPresenter graphql.ErrorPresenterFunc
32 resolverHook graphql.FieldMiddleware
33 requestHook graphql.RequestMiddleware
34 complexityLimit int
35}
36
37func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext {
38 reqCtx := graphql.NewRequestContext(doc, query, variables)
39 if hook := c.recover; hook != nil {
40 reqCtx.Recover = hook
41 }
42
43 if hook := c.errorPresenter; hook != nil {
44 reqCtx.ErrorPresenter = hook
45 }
46
47 if hook := c.resolverHook; hook != nil {
48 reqCtx.ResolverMiddleware = hook
49 }
50
51 if hook := c.requestHook; hook != nil {
52 reqCtx.RequestMiddleware = hook
53 }
54
55 return reqCtx
56}
57
58type Option func(cfg *Config)
59
60func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
61 return func(cfg *Config) {
62 cfg.upgrader = upgrader
63 }
64}
65
66func RecoverFunc(recover graphql.RecoverFunc) Option {
67 return func(cfg *Config) {
68 cfg.recover = recover
69 }
70}
71
72// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
73// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
74// implementation in graphql.DefaultErrorPresenter for an example.
75func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
76 return func(cfg *Config) {
77 cfg.errorPresenter = f
78 }
79}
80
81// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
82// If a query is submitted that exceeds the limit, a 422 status code will be returned.
83func ComplexityLimit(limit int) Option {
84 return func(cfg *Config) {
85 cfg.complexityLimit = limit
86 }
87}
88
89// ResolverMiddleware allows you to define a function that will be called around every resolver,
90// useful for tracing and logging.
91func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
92 return func(cfg *Config) {
93 if cfg.resolverHook == nil {
94 cfg.resolverHook = middleware
95 return
96 }
97
98 lastResolve := cfg.resolverHook
99 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
100 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
101 return middleware(ctx, next)
102 })
103 }
104 }
105}
106
107// RequestMiddleware allows you to define a function that will be called around the root request,
108// after the query has been parsed. This is useful for logging and tracing
109func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
110 return func(cfg *Config) {
111 if cfg.requestHook == nil {
112 cfg.requestHook = middleware
113 return
114 }
115
116 lastResolve := cfg.requestHook
117 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
118 return lastResolve(ctx, func(ctx context.Context) []byte {
119 return middleware(ctx, next)
120 })
121 }
122 }
123}
124
125// CacheSize sets the maximum size of the query cache.
126// If size is less than or equal to 0, the cache is disabled.
127func CacheSize(size int) Option {
128 return func(cfg *Config) {
129 cfg.cacheSize = size
130 }
131}
132
const (
	// DefaultCacheSize is the query-cache capacity GraphQL uses unless
	// CacheSize overrides it.
	DefaultCacheSize = 1000
)
134
135func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
136 cfg := Config{
137 cacheSize: DefaultCacheSize,
138 upgrader: websocket.Upgrader{
139 ReadBufferSize: 1024,
140 WriteBufferSize: 1024,
141 },
142 }
143
144 for _, option := range options {
145 option(&cfg)
146 }
147
148 var cache *lru.Cache
149 if cfg.cacheSize > 0 {
150 var err error
151 cache, err = lru.New(DefaultCacheSize)
152 if err != nil {
153 // An error is only returned for non-positive cache size
154 // and we already checked for that.
155 panic("unexpected error creating cache: " + err.Error())
156 }
157 }
158
159 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160 if r.Method == http.MethodOptions {
161 w.Header().Set("Allow", "OPTIONS, GET, POST")
162 w.WriteHeader(http.StatusOK)
163 return
164 }
165
166 if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
167 connectWs(exec, w, r, &cfg)
168 return
169 }
170
171 var reqParams params
172 switch r.Method {
173 case http.MethodGet:
174 reqParams.Query = r.URL.Query().Get("query")
175 reqParams.OperationName = r.URL.Query().Get("operationName")
176
177 if variables := r.URL.Query().Get("variables"); variables != "" {
178 if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil {
179 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
180 return
181 }
182 }
183 case http.MethodPost:
184 if err := jsonDecode(r.Body, &reqParams); err != nil {
185 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
186 return
187 }
188 default:
189 w.WriteHeader(http.StatusMethodNotAllowed)
190 return
191 }
192 w.Header().Set("Content-Type", "application/json")
193
194 var doc *ast.QueryDocument
195 if cache != nil {
196 val, ok := cache.Get(reqParams.Query)
197 if ok {
198 doc = val.(*ast.QueryDocument)
199 }
200 }
201 if doc == nil {
202 var qErr gqlerror.List
203 doc, qErr = gqlparser.LoadQuery(exec.Schema(), reqParams.Query)
204 if len(qErr) > 0 {
205 sendError(w, http.StatusUnprocessableEntity, qErr...)
206 return
207 }
208 if cache != nil {
209 cache.Add(reqParams.Query, doc)
210 }
211 }
212
213 op := doc.Operations.ForName(reqParams.OperationName)
214 if op == nil {
215 sendErrorf(w, http.StatusUnprocessableEntity, "operation %s not found", reqParams.OperationName)
216 return
217 }
218
219 if op.Operation != ast.Query && r.Method == http.MethodGet {
220 sendErrorf(w, http.StatusUnprocessableEntity, "GET requests only allow query operations")
221 return
222 }
223
224 vars, err := validator.VariableValues(exec.Schema(), op, reqParams.Variables)
225 if err != nil {
226 sendError(w, http.StatusUnprocessableEntity, err)
227 return
228 }
229 reqCtx := cfg.newRequestContext(doc, reqParams.Query, vars)
230 ctx := graphql.WithRequestContext(r.Context(), reqCtx)
231
232 defer func() {
233 if err := recover(); err != nil {
234 userErr := reqCtx.Recover(ctx, err)
235 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
236 }
237 }()
238
239 if cfg.complexityLimit > 0 {
240 queryComplexity := complexity.Calculate(exec, op, vars)
241 if queryComplexity > cfg.complexityLimit {
242 sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit)
243 return
244 }
245 }
246
247 switch op.Operation {
248 case ast.Query:
249 b, err := json.Marshal(exec.Query(ctx, op))
250 if err != nil {
251 panic(err)
252 }
253 w.Write(b)
254 case ast.Mutation:
255 b, err := json.Marshal(exec.Mutation(ctx, op))
256 if err != nil {
257 panic(err)
258 }
259 w.Write(b)
260 default:
261 sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
262 }
263 })
264}
265
// jsonDecode reads one JSON value from r into val. Numbers are decoded as
// json.Number rather than float64, so large integers keep their exact text.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	if err := decoder.Decode(val); err != nil {
		return err
	}
	return nil
}
271
272func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
273 w.WriteHeader(code)
274 b, err := json.Marshal(&graphql.Response{Errors: errors})
275 if err != nil {
276 panic(err)
277 }
278 w.Write(b)
279}
280
281func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
282 sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
283}