1package handler
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "sync"
11 "time"
12
13 "github.com/99designs/gqlgen/graphql"
14 "github.com/gorilla/websocket"
15 "github.com/hashicorp/golang-lru"
16 "github.com/vektah/gqlparser"
17 "github.com/vektah/gqlparser/ast"
18 "github.com/vektah/gqlparser/gqlerror"
19 "github.com/vektah/gqlparser/validator"
20)
21
// Message types of the "graphql-ws" websocket sub-protocol negotiated in
// connectWs.
const (
	// Sent by the client to the server.
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Sent by the server to the client.
	connectionAckMsg       = "connection_ack"
	connectionErrorMsg     = "connection_error"
	dataMsg                = "data"
	errorMsg               = "error"
	completeMsg            = "complete"
	connectionKeepAliveMsg = "ka"
)
34
// operationMessage is the JSON envelope exchanged in both directions over the
// websocket. encoding/json emits fields in declaration order, so the order
// Payload, ID, Type is part of the wire format and must not be shuffled.
type operationMessage struct {
	// Payload is the raw, message-specific body; omitted when empty.
	Payload json.RawMessage `json:"payload,omitempty"`

	// ID identifies the operation a message belongs to; omitted when empty.
	ID string `json:"id,omitempty"`

	// Type is one of the *Msg constants.
	Type string `json:"type"`
}
40
41type wsConnection struct {
42 ctx context.Context
43 conn *websocket.Conn
44 exec graphql.ExecutableSchema
45 active map[string]context.CancelFunc
46 mu sync.Mutex
47 cfg *Config
48 cache *lru.Cache
49 keepAliveTicker *time.Ticker
50
51 initPayload InitPayload
52}
53
54func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
55 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
56 "Sec-Websocket-Protocol": []string{"graphql-ws"},
57 })
58 if err != nil {
59 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
60 sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
61 return
62 }
63
64 conn := wsConnection{
65 active: map[string]context.CancelFunc{},
66 exec: exec,
67 conn: ws,
68 ctx: r.Context(),
69 cfg: cfg,
70 cache: cache,
71 }
72
73 if !conn.init() {
74 return
75 }
76
77 conn.run()
78}
79
80func (c *wsConnection) init() bool {
81 message := c.readOp()
82 if message == nil {
83 c.close(websocket.CloseProtocolError, "decoding error")
84 return false
85 }
86
87 switch message.Type {
88 case connectionInitMsg:
89 if len(message.Payload) > 0 {
90 c.initPayload = make(InitPayload)
91 err := json.Unmarshal(message.Payload, &c.initPayload)
92 if err != nil {
93 return false
94 }
95 }
96
97 c.write(&operationMessage{Type: connectionAckMsg})
98 case connectionTerminateMsg:
99 c.close(websocket.CloseNormalClosure, "terminated")
100 return false
101 default:
102 c.sendConnectionError("unexpected message %s", message.Type)
103 c.close(websocket.CloseProtocolError, "unexpected message")
104 return false
105 }
106
107 return true
108}
109
110func (c *wsConnection) write(msg *operationMessage) {
111 c.mu.Lock()
112 c.conn.WriteJSON(msg)
113 c.mu.Unlock()
114}
115
116func (c *wsConnection) run() {
117 // We create a cancellation that will shutdown the keep-alive when we leave
118 // this function.
119 ctx, cancel := context.WithCancel(c.ctx)
120 defer cancel()
121
122 // Create a timer that will fire every interval to keep the connection alive.
123 if c.cfg.connectionKeepAlivePingInterval != 0 {
124 c.mu.Lock()
125 c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
126 c.mu.Unlock()
127
128 go c.keepAlive(ctx)
129 }
130
131 for {
132 message := c.readOp()
133 if message == nil {
134 return
135 }
136
137 switch message.Type {
138 case startMsg:
139 if !c.subscribe(message) {
140 return
141 }
142 case stopMsg:
143 c.mu.Lock()
144 closer := c.active[message.ID]
145 c.mu.Unlock()
146 if closer == nil {
147 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
148 continue
149 }
150
151 closer()
152 case connectionTerminateMsg:
153 c.close(websocket.CloseNormalClosure, "terminated")
154 return
155 default:
156 c.sendConnectionError("unexpected message %s", message.Type)
157 c.close(websocket.CloseProtocolError, "unexpected message")
158 return
159 }
160 }
161}
162
163func (c *wsConnection) keepAlive(ctx context.Context) {
164 for {
165 select {
166 case <-ctx.Done():
167 c.keepAliveTicker.Stop()
168 return
169 case <-c.keepAliveTicker.C:
170 c.write(&operationMessage{Type: connectionKeepAliveMsg})
171 }
172 }
173}
174
175func (c *wsConnection) subscribe(message *operationMessage) bool {
176 var reqParams params
177 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
178 c.sendConnectionError("invalid json")
179 return false
180 }
181
182 var (
183 doc *ast.QueryDocument
184 cacheHit bool
185 )
186 if c.cache != nil {
187 val, ok := c.cache.Get(reqParams.Query)
188 if ok {
189 doc = val.(*ast.QueryDocument)
190 cacheHit = true
191 }
192 }
193 if !cacheHit {
194 var qErr gqlerror.List
195 doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
196 if qErr != nil {
197 c.sendError(message.ID, qErr...)
198 return true
199 }
200 if c.cache != nil {
201 c.cache.Add(reqParams.Query, doc)
202 }
203 }
204
205 op := doc.Operations.ForName(reqParams.OperationName)
206 if op == nil {
207 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
208 return true
209 }
210
211 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
212 if err != nil {
213 c.sendError(message.ID, err)
214 return true
215 }
216 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
217 ctx := graphql.WithRequestContext(c.ctx, reqCtx)
218
219 if c.initPayload != nil {
220 ctx = withInitPayload(ctx, c.initPayload)
221 }
222
223 if op.Operation != ast.Subscription {
224 var result *graphql.Response
225 if op.Operation == ast.Query {
226 result = c.exec.Query(ctx, op)
227 } else {
228 result = c.exec.Mutation(ctx, op)
229 }
230
231 c.sendData(message.ID, result)
232 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
233 return true
234 }
235
236 ctx, cancel := context.WithCancel(ctx)
237 c.mu.Lock()
238 c.active[message.ID] = cancel
239 c.mu.Unlock()
240 go func() {
241 defer func() {
242 if r := recover(); r != nil {
243 userErr := reqCtx.Recover(ctx, r)
244 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
245 }
246 }()
247 next := c.exec.Subscription(ctx, op)
248 for result := next(); result != nil; result = next() {
249 c.sendData(message.ID, result)
250 }
251
252 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
253
254 c.mu.Lock()
255 delete(c.active, message.ID)
256 c.mu.Unlock()
257 cancel()
258 }()
259
260 return true
261}
262
263func (c *wsConnection) sendData(id string, response *graphql.Response) {
264 b, err := json.Marshal(response)
265 if err != nil {
266 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
267 return
268 }
269
270 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
271}
272
273func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
274 var errs []error
275 for _, err := range errors {
276 errs = append(errs, err)
277 }
278 b, err := json.Marshal(errs)
279 if err != nil {
280 panic(err)
281 }
282 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
283}
284
285func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
286 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
287 if err != nil {
288 panic(err)
289 }
290
291 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
292}
293
294func (c *wsConnection) readOp() *operationMessage {
295 _, r, err := c.conn.NextReader()
296 if err != nil {
297 c.sendConnectionError("invalid json")
298 return nil
299 }
300 message := operationMessage{}
301 if err := jsonDecode(r, &message); err != nil {
302 c.sendConnectionError("invalid json")
303 return nil
304 }
305
306 return &message
307}
308
309func (c *wsConnection) close(closeCode int, message string) {
310 c.mu.Lock()
311 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
312 c.mu.Unlock()
313 _ = c.conn.Close()
314}