1package handler
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "sync"
11 "time"
12
13 "github.com/99designs/gqlgen/graphql"
14 "github.com/gorilla/websocket"
15 lru "github.com/hashicorp/golang-lru"
16 "github.com/vektah/gqlparser"
17 "github.com/vektah/gqlparser/ast"
18 "github.com/vektah/gqlparser/gqlerror"
19 "github.com/vektah/gqlparser/validator"
20)
21
// Values carried in operationMessage.Type, grouped by the direction in which
// the message travels.
const (
	// Client -> Server
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Server -> Client
	connectionAckMsg       = "connection_ack"
	connectionErrorMsg     = "connection_error"
	dataMsg                = "data"
	errorMsg               = "error"
	completeMsg            = "complete"
	connectionKeepAliveMsg = "ka"
)
34
// operationMessage is the JSON envelope used for every message this file
// reads from or writes to the websocket.
type operationMessage struct {
	// Payload is the raw, still-encoded message body. It is left out of the
	// JSON output when empty.
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID names the operation the message belongs to. It is left out of the
	// JSON output when empty (connection-level messages set no ID).
	ID string `json:"id,omitempty"`
	// Type is one of the *Msg constants and is always encoded.
	Type string `json:"type"`
}
40
41type wsConnection struct {
42 ctx context.Context
43 conn *websocket.Conn
44 exec graphql.ExecutableSchema
45 active map[string]context.CancelFunc
46 mu sync.Mutex
47 cfg *Config
48 cache *lru.Cache
49 keepAliveTicker *time.Ticker
50
51 initPayload InitPayload
52}
53
54func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
55 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
56 "Sec-Websocket-Protocol": []string{"graphql-ws"},
57 })
58 if err != nil {
59 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
60 sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
61 return
62 }
63
64 conn := wsConnection{
65 active: map[string]context.CancelFunc{},
66 exec: exec,
67 conn: ws,
68 ctx: r.Context(),
69 cfg: cfg,
70 cache: cache,
71 }
72
73 if !conn.init() {
74 return
75 }
76
77 conn.run()
78}
79
80func (c *wsConnection) init() bool {
81 message := c.readOp()
82 if message == nil {
83 c.close(websocket.CloseProtocolError, "decoding error")
84 return false
85 }
86
87 switch message.Type {
88 case connectionInitMsg:
89 if len(message.Payload) > 0 {
90 c.initPayload = make(InitPayload)
91 err := json.Unmarshal(message.Payload, &c.initPayload)
92 if err != nil {
93 return false
94 }
95 }
96
97 if c.cfg.websocketInitFunc != nil {
98 if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
99 c.sendConnectionError(err.Error())
100 c.close(websocket.CloseNormalClosure, "terminated")
101 return false
102 }
103 }
104
105 c.write(&operationMessage{Type: connectionAckMsg})
106 c.write(&operationMessage{Type: connectionKeepAliveMsg})
107 case connectionTerminateMsg:
108 c.close(websocket.CloseNormalClosure, "terminated")
109 return false
110 default:
111 c.sendConnectionError("unexpected message %s", message.Type)
112 c.close(websocket.CloseProtocolError, "unexpected message")
113 return false
114 }
115
116 return true
117}
118
119func (c *wsConnection) write(msg *operationMessage) {
120 c.mu.Lock()
121 c.conn.WriteJSON(msg)
122 c.mu.Unlock()
123}
124
125func (c *wsConnection) run() {
126 // We create a cancellation that will shutdown the keep-alive when we leave
127 // this function.
128 ctx, cancel := context.WithCancel(c.ctx)
129 defer cancel()
130
131 // Create a timer that will fire every interval to keep the connection alive.
132 if c.cfg.connectionKeepAlivePingInterval != 0 {
133 c.mu.Lock()
134 c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
135 c.mu.Unlock()
136
137 go c.keepAlive(ctx)
138 }
139
140 for {
141 message := c.readOp()
142 if message == nil {
143 return
144 }
145
146 switch message.Type {
147 case startMsg:
148 if !c.subscribe(message) {
149 return
150 }
151 case stopMsg:
152 c.mu.Lock()
153 closer := c.active[message.ID]
154 c.mu.Unlock()
155 if closer == nil {
156 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
157 continue
158 }
159
160 closer()
161 case connectionTerminateMsg:
162 c.close(websocket.CloseNormalClosure, "terminated")
163 return
164 default:
165 c.sendConnectionError("unexpected message %s", message.Type)
166 c.close(websocket.CloseProtocolError, "unexpected message")
167 return
168 }
169 }
170}
171
172func (c *wsConnection) keepAlive(ctx context.Context) {
173 for {
174 select {
175 case <-ctx.Done():
176 c.keepAliveTicker.Stop()
177 return
178 case <-c.keepAliveTicker.C:
179 c.write(&operationMessage{Type: connectionKeepAliveMsg})
180 }
181 }
182}
183
184func (c *wsConnection) subscribe(message *operationMessage) bool {
185 var reqParams params
186 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
187 c.sendConnectionError("invalid json")
188 return false
189 }
190
191 var (
192 doc *ast.QueryDocument
193 cacheHit bool
194 )
195 if c.cache != nil {
196 val, ok := c.cache.Get(reqParams.Query)
197 if ok {
198 doc = val.(*ast.QueryDocument)
199 cacheHit = true
200 }
201 }
202 if !cacheHit {
203 var qErr gqlerror.List
204 doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
205 if qErr != nil {
206 c.sendError(message.ID, qErr...)
207 return true
208 }
209 if c.cache != nil {
210 c.cache.Add(reqParams.Query, doc)
211 }
212 }
213
214 op := doc.Operations.ForName(reqParams.OperationName)
215 if op == nil {
216 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
217 return true
218 }
219
220 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
221 if err != nil {
222 c.sendError(message.ID, err)
223 return true
224 }
225 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
226 ctx := graphql.WithRequestContext(c.ctx, reqCtx)
227
228 if c.initPayload != nil {
229 ctx = withInitPayload(ctx, c.initPayload)
230 }
231
232 if op.Operation != ast.Subscription {
233 var result *graphql.Response
234 if op.Operation == ast.Query {
235 result = c.exec.Query(ctx, op)
236 } else {
237 result = c.exec.Mutation(ctx, op)
238 }
239
240 c.sendData(message.ID, result)
241 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
242 return true
243 }
244
245 ctx, cancel := context.WithCancel(ctx)
246 c.mu.Lock()
247 c.active[message.ID] = cancel
248 c.mu.Unlock()
249 go func() {
250 defer func() {
251 if r := recover(); r != nil {
252 userErr := reqCtx.Recover(ctx, r)
253 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
254 }
255 }()
256 next := c.exec.Subscription(ctx, op)
257 for result := next(); result != nil; result = next() {
258 c.sendData(message.ID, result)
259 }
260
261 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
262
263 c.mu.Lock()
264 delete(c.active, message.ID)
265 c.mu.Unlock()
266 cancel()
267 }()
268
269 return true
270}
271
272func (c *wsConnection) sendData(id string, response *graphql.Response) {
273 b, err := json.Marshal(response)
274 if err != nil {
275 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
276 return
277 }
278
279 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
280}
281
282func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
283 errs := make([]error, len(errors))
284 for i, err := range errors {
285 errs[i] = err
286 }
287 b, err := json.Marshal(errs)
288 if err != nil {
289 panic(err)
290 }
291 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
292}
293
294func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
295 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
296 if err != nil {
297 panic(err)
298 }
299
300 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
301}
302
303func (c *wsConnection) readOp() *operationMessage {
304 _, r, err := c.conn.NextReader()
305 if err != nil {
306 c.sendConnectionError("invalid json")
307 return nil
308 }
309 message := operationMessage{}
310 if err := jsonDecode(r, &message); err != nil {
311 c.sendConnectionError("invalid json")
312 return nil
313 }
314
315 return &message
316}
317
318func (c *wsConnection) close(closeCode int, message string) {
319 c.mu.Lock()
320 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
321 c.mu.Unlock()
322 _ = c.conn.Close()
323}