graphql.go

  1package handler
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"net/http"
  8	"strings"
  9
 10	"github.com/gorilla/websocket"
 11	"github.com/vektah/gqlgen/graphql"
 12	"github.com/vektah/gqlgen/neelance/errors"
 13	"github.com/vektah/gqlgen/neelance/query"
 14	"github.com/vektah/gqlgen/neelance/validation"
 15)
 16
 17type params struct {
 18	Query         string                 `json:"query"`
 19	OperationName string                 `json:"operationName"`
 20	Variables     map[string]interface{} `json:"variables"`
 21}
 22
 23type Config struct {
 24	upgrader       websocket.Upgrader
 25	recover        graphql.RecoverFunc
 26	errorPresenter graphql.ErrorPresenterFunc
 27	resolverHook   graphql.ResolverMiddleware
 28	requestHook    graphql.RequestMiddleware
 29}
 30
 31func (c *Config) newRequestContext(doc *query.Document, query string, variables map[string]interface{}) *graphql.RequestContext {
 32	reqCtx := graphql.NewRequestContext(doc, query, variables)
 33	if hook := c.recover; hook != nil {
 34		reqCtx.Recover = hook
 35	}
 36
 37	if hook := c.errorPresenter; hook != nil {
 38		reqCtx.ErrorPresenter = hook
 39	}
 40
 41	if hook := c.resolverHook; hook != nil {
 42		reqCtx.ResolverMiddleware = hook
 43	}
 44
 45	if hook := c.requestHook; hook != nil {
 46		reqCtx.RequestMiddleware = hook
 47	}
 48
 49	return reqCtx
 50}
 51
 52type Option func(cfg *Config)
 53
 54func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
 55	return func(cfg *Config) {
 56		cfg.upgrader = upgrader
 57	}
 58}
 59
 60func RecoverFunc(recover graphql.RecoverFunc) Option {
 61	return func(cfg *Config) {
 62		cfg.recover = recover
 63	}
 64}
 65
 66// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
 67// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
 68// implementation in graphql.DefaultErrorPresenter for an example.
 69func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
 70	return func(cfg *Config) {
 71		cfg.errorPresenter = f
 72	}
 73}
 74
 75// ResolverMiddleware allows you to define a function that will be called around every resolver,
 76// useful for tracing and logging.
 77// It will only be called for user defined resolvers, any direct binding to models is assumed
 78// to cost nothing.
 79func ResolverMiddleware(middleware graphql.ResolverMiddleware) Option {
 80	return func(cfg *Config) {
 81		if cfg.resolverHook == nil {
 82			cfg.resolverHook = middleware
 83			return
 84		}
 85
 86		lastResolve := cfg.resolverHook
 87		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
 88			return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
 89				return middleware(ctx, next)
 90			})
 91		}
 92	}
 93}
 94
 95// RequestMiddleware allows you to define a function that will be called around the root request,
 96// after the query has been parsed. This is useful for logging and tracing
 97func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
 98	return func(cfg *Config) {
 99		if cfg.requestHook == nil {
100			cfg.requestHook = middleware
101			return
102		}
103
104		lastResolve := cfg.requestHook
105		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
106			return lastResolve(ctx, func(ctx context.Context) []byte {
107				return middleware(ctx, next)
108			})
109		}
110	}
111}
112
113func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
114	cfg := Config{
115		upgrader: websocket.Upgrader{
116			ReadBufferSize:  1024,
117			WriteBufferSize: 1024,
118		},
119	}
120
121	for _, option := range options {
122		option(&cfg)
123	}
124
125	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126		if r.Method == http.MethodOptions {
127			w.Header().Set("Allow", "OPTIONS, GET, POST")
128			w.WriteHeader(http.StatusOK)
129			return
130		}
131
132		if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
133			connectWs(exec, w, r, &cfg)
134			return
135		}
136
137		var reqParams params
138		switch r.Method {
139		case http.MethodGet:
140			reqParams.Query = r.URL.Query().Get("query")
141			reqParams.OperationName = r.URL.Query().Get("operationName")
142
143			if variables := r.URL.Query().Get("variables"); variables != "" {
144				if err := json.Unmarshal([]byte(variables), &reqParams.Variables); err != nil {
145					sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
146					return
147				}
148			}
149		case http.MethodPost:
150			if err := json.NewDecoder(r.Body).Decode(&reqParams); err != nil {
151				sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
152				return
153			}
154		default:
155			w.WriteHeader(http.StatusMethodNotAllowed)
156			return
157		}
158		w.Header().Set("Content-Type", "application/json")
159
160		doc, qErr := query.Parse(reqParams.Query)
161		if qErr != nil {
162			sendError(w, http.StatusUnprocessableEntity, qErr)
163			return
164		}
165
166		errs := validation.Validate(exec.Schema(), doc)
167		if len(errs) != 0 {
168			sendError(w, http.StatusUnprocessableEntity, errs...)
169			return
170		}
171
172		op, err := doc.GetOperation(reqParams.OperationName)
173		if err != nil {
174			sendErrorf(w, http.StatusUnprocessableEntity, err.Error())
175			return
176		}
177
178		reqCtx := cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)
179		ctx := graphql.WithRequestContext(r.Context(), reqCtx)
180
181		defer func() {
182			if err := recover(); err != nil {
183				userErr := reqCtx.Recover(ctx, err)
184				sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
185			}
186		}()
187
188		switch op.Type {
189		case query.Query:
190			b, err := json.Marshal(exec.Query(ctx, op))
191			if err != nil {
192				panic(err)
193			}
194			w.Write(b)
195		case query.Mutation:
196			b, err := json.Marshal(exec.Mutation(ctx, op))
197			if err != nil {
198				panic(err)
199			}
200			w.Write(b)
201		default:
202			sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
203		}
204	})
205}
206
207func sendError(w http.ResponseWriter, code int, errors ...*errors.QueryError) {
208	w.WriteHeader(code)
209	var errs []*graphql.Error
210	for _, err := range errors {
211		var locations []graphql.ErrorLocation
212		for _, l := range err.Locations {
213			fmt.Println(graphql.ErrorLocation(l))
214			locations = append(locations, graphql.ErrorLocation{
215				Line:   l.Line,
216				Column: l.Column,
217			})
218		}
219
220		errs = append(errs, &graphql.Error{
221			Message:   err.Message,
222			Path:      err.Path,
223			Locations: locations,
224		})
225	}
226	b, err := json.Marshal(&graphql.Response{Errors: errs})
227	if err != nil {
228		panic(err)
229	}
230	w.Write(b)
231}
232
233func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
234	sendError(w, code, &errors.QueryError{Message: fmt.Sprintf(format, args...)})
235}