1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "strings"
9
10 "github.com/gorilla/websocket"
11 "github.com/vektah/gqlgen/graphql"
12 "github.com/vektah/gqlgen/neelance/errors"
13 "github.com/vektah/gqlgen/neelance/query"
14 "github.com/vektah/gqlgen/neelance/validation"
15)
16
// params is the decoded form of a GraphQL-over-HTTP request. The handler fills
// it either from the URL query string (GET) or from a JSON request body (POST).
type params struct {
	// Query is the GraphQL document text to execute.
	Query string `json:"query"`

	// OperationName selects which operation in Query to run; it is empty when
	// the client did not supply one.
	OperationName string `json:"operationName"`

	// Variables holds the decoded JSON "variables" object, if any.
	Variables map[string]interface{} `json:"variables"`
}
22
23type Config struct {
24 upgrader websocket.Upgrader
25 recover graphql.RecoverFunc
26 errorPresenter graphql.ErrorPresenterFunc
27 resolverHook graphql.ResolverMiddleware
28 requestHook graphql.RequestMiddleware
29}
30
31func (c *Config) newRequestContext(doc *query.Document, query string, variables map[string]interface{}) *graphql.RequestContext {
32 reqCtx := graphql.NewRequestContext(doc, query, variables)
33 if hook := c.recover; hook != nil {
34 reqCtx.Recover = hook
35 }
36
37 if hook := c.errorPresenter; hook != nil {
38 reqCtx.ErrorPresenter = hook
39 }
40
41 if hook := c.resolverHook; hook != nil {
42 reqCtx.ResolverMiddleware = hook
43 }
44
45 if hook := c.requestHook; hook != nil {
46 reqCtx.RequestMiddleware = hook
47 }
48
49 return reqCtx
50}
51
52type Option func(cfg *Config)
53
54func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
55 return func(cfg *Config) {
56 cfg.upgrader = upgrader
57 }
58}
59
60func RecoverFunc(recover graphql.RecoverFunc) Option {
61 return func(cfg *Config) {
62 cfg.recover = recover
63 }
64}
65
66// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides
67// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default
68// implementation in graphql.DefaultErrorPresenter for an example.
69func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
70 return func(cfg *Config) {
71 cfg.errorPresenter = f
72 }
73}
74
75// ResolverMiddleware allows you to define a function that will be called around every resolver,
76// useful for tracing and logging.
77// It will only be called for user defined resolvers, any direct binding to models is assumed
78// to cost nothing.
79func ResolverMiddleware(middleware graphql.ResolverMiddleware) Option {
80 return func(cfg *Config) {
81 if cfg.resolverHook == nil {
82 cfg.resolverHook = middleware
83 return
84 }
85
86 lastResolve := cfg.resolverHook
87 cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
88 return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) {
89 return middleware(ctx, next)
90 })
91 }
92 }
93}
94
95// RequestMiddleware allows you to define a function that will be called around the root request,
96// after the query has been parsed. This is useful for logging and tracing
97func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
98 return func(cfg *Config) {
99 if cfg.requestHook == nil {
100 cfg.requestHook = middleware
101 return
102 }
103
104 lastResolve := cfg.requestHook
105 cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
106 return lastResolve(ctx, func(ctx context.Context) []byte {
107 return middleware(ctx, next)
108 })
109 }
110 }
111}
112
113func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
114 cfg := Config{
115 upgrader: websocket.Upgrader{
116 ReadBufferSize: 1024,
117 WriteBufferSize: 1024,
118 },
119 }
120
121 for _, option := range options {
122 option(&cfg)
123 }
124
125 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126 if r.Method == http.MethodOptions {
127 w.Header().Set("Allow", "OPTIONS, GET, POST")
128 w.WriteHeader(http.StatusOK)
129 return
130 }
131
132 if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
133 connectWs(exec, w, r, &cfg)
134 return
135 }
136
137 var reqParams params
138 switch r.Method {
139 case http.MethodGet:
140 reqParams.Query = r.URL.Query().Get("query")
141 reqParams.OperationName = r.URL.Query().Get("operationName")
142
143 if variables := r.URL.Query().Get("variables"); variables != "" {
144 if err := json.Unmarshal([]byte(variables), &reqParams.Variables); err != nil {
145 sendErrorf(w, http.StatusBadRequest, "variables could not be decoded")
146 return
147 }
148 }
149 case http.MethodPost:
150 if err := json.NewDecoder(r.Body).Decode(&reqParams); err != nil {
151 sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error())
152 return
153 }
154 default:
155 w.WriteHeader(http.StatusMethodNotAllowed)
156 return
157 }
158 w.Header().Set("Content-Type", "application/json")
159
160 doc, qErr := query.Parse(reqParams.Query)
161 if qErr != nil {
162 sendError(w, http.StatusUnprocessableEntity, qErr)
163 return
164 }
165
166 errs := validation.Validate(exec.Schema(), doc)
167 if len(errs) != 0 {
168 sendError(w, http.StatusUnprocessableEntity, errs...)
169 return
170 }
171
172 op, err := doc.GetOperation(reqParams.OperationName)
173 if err != nil {
174 sendErrorf(w, http.StatusUnprocessableEntity, err.Error())
175 return
176 }
177
178 reqCtx := cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)
179 ctx := graphql.WithRequestContext(r.Context(), reqCtx)
180
181 defer func() {
182 if err := recover(); err != nil {
183 userErr := reqCtx.Recover(ctx, err)
184 sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error())
185 }
186 }()
187
188 switch op.Type {
189 case query.Query:
190 b, err := json.Marshal(exec.Query(ctx, op))
191 if err != nil {
192 panic(err)
193 }
194 w.Write(b)
195 case query.Mutation:
196 b, err := json.Marshal(exec.Mutation(ctx, op))
197 if err != nil {
198 panic(err)
199 }
200 w.Write(b)
201 default:
202 sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
203 }
204 })
205}
206
207func sendError(w http.ResponseWriter, code int, errors ...*errors.QueryError) {
208 w.WriteHeader(code)
209 var errs []*graphql.Error
210 for _, err := range errors {
211 var locations []graphql.ErrorLocation
212 for _, l := range err.Locations {
213 fmt.Println(graphql.ErrorLocation(l))
214 locations = append(locations, graphql.ErrorLocation{
215 Line: l.Line,
216 Column: l.Column,
217 })
218 }
219
220 errs = append(errs, &graphql.Error{
221 Message: err.Message,
222 Path: err.Path,
223 Locations: locations,
224 })
225 }
226 b, err := json.Marshal(&graphql.Response{Errors: errs})
227 if err != nil {
228 panic(err)
229 }
230 w.Write(b)
231}
232
233func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
234 sendError(w, code, &errors.QueryError{Message: fmt.Sprintf(format, args...)})
235}