1package handler
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "sync"
11 "time"
12
13 "github.com/99designs/gqlgen/graphql"
14 "github.com/gorilla/websocket"
15 lru "github.com/hashicorp/golang-lru"
16 "github.com/vektah/gqlparser"
17 "github.com/vektah/gqlparser/ast"
18 "github.com/vektah/gqlparser/gqlerror"
19 "github.com/vektah/gqlparser/validator"
20)
21
// Message types sent by the client to the server.
const (
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"
)

// Message types sent by the server to the client.
const (
	connectionAckMsg       = "connection_ack"
	connectionErrorMsg     = "connection_error"
	dataMsg                = "data"
	errorMsg               = "error"
	completeMsg            = "complete"
	connectionKeepAliveMsg = "ka"
)
34
// operationMessage is the JSON envelope exchanged over the websocket in both
// directions.
type operationMessage struct {
	// Payload is the raw, still-encoded message body; omitted when empty.
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID identifies the operation the message belongs to; omitted when empty.
	ID string `json:"id,omitempty"`
	// Type is one of the *Msg constants declared in this file.
	Type string `json:"type"`
}
40
// wsConnection holds the state of a single upgraded websocket connection.
type wsConnection struct {
	// ctx is the context of the originating HTTP request; per-operation
	// contexts are derived from it.
	ctx context.Context
	// conn is the underlying websocket connection.
	conn *websocket.Conn
	// exec is the schema that operations are executed against.
	exec graphql.ExecutableSchema
	// active maps an operation ID to the cancel func of its running
	// subscription.
	active map[string]context.CancelFunc
	// mu serializes writes to conn and guards active.
	mu sync.Mutex
	// cfg is the handler configuration.
	cfg *Config
	// cache, when non-nil, stores parsed query documents keyed by query string.
	cache *lru.Cache
	// keepAliveTicker is created in run when a keep-alive interval is configured.
	keepAliveTicker *time.Ticker

	// initPayload is the decoded connection_init payload; nil when the client
	// sent none.
	initPayload InitPayload
}
53
54func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
55 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
56 "Sec-Websocket-Protocol": []string{"graphql-ws"},
57 })
58 if err != nil {
59 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
60 sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
61 return
62 }
63
64 conn := wsConnection{
65 active: map[string]context.CancelFunc{},
66 exec: exec,
67 conn: ws,
68 ctx: r.Context(),
69 cfg: cfg,
70 cache: cache,
71 }
72
73 if !conn.init() {
74 return
75 }
76
77 conn.run()
78}
79
80func (c *wsConnection) init() bool {
81 message := c.readOp()
82 if message == nil {
83 c.close(websocket.CloseProtocolError, "decoding error")
84 return false
85 }
86
87 switch message.Type {
88 case connectionInitMsg:
89 if len(message.Payload) > 0 {
90 c.initPayload = make(InitPayload)
91 err := json.Unmarshal(message.Payload, &c.initPayload)
92 if err != nil {
93 return false
94 }
95 }
96
97 if c.cfg.websocketInitFunc != nil {
98 if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
99 c.sendConnectionError(err.Error())
100 c.close(websocket.CloseNormalClosure, "terminated")
101 return false
102 }
103 }
104
105 c.write(&operationMessage{Type: connectionAckMsg})
106 case connectionTerminateMsg:
107 c.close(websocket.CloseNormalClosure, "terminated")
108 return false
109 default:
110 c.sendConnectionError("unexpected message %s", message.Type)
111 c.close(websocket.CloseProtocolError, "unexpected message")
112 return false
113 }
114
115 return true
116}
117
118func (c *wsConnection) write(msg *operationMessage) {
119 c.mu.Lock()
120 c.conn.WriteJSON(msg)
121 c.mu.Unlock()
122}
123
124func (c *wsConnection) run() {
125 // We create a cancellation that will shutdown the keep-alive when we leave
126 // this function.
127 ctx, cancel := context.WithCancel(c.ctx)
128 defer cancel()
129
130 // Create a timer that will fire every interval to keep the connection alive.
131 if c.cfg.connectionKeepAlivePingInterval != 0 {
132 c.mu.Lock()
133 c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
134 c.mu.Unlock()
135
136 go c.keepAlive(ctx)
137 }
138
139 for {
140 message := c.readOp()
141 if message == nil {
142 return
143 }
144
145 switch message.Type {
146 case startMsg:
147 if !c.subscribe(message) {
148 return
149 }
150 case stopMsg:
151 c.mu.Lock()
152 closer := c.active[message.ID]
153 c.mu.Unlock()
154 if closer == nil {
155 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
156 continue
157 }
158
159 closer()
160 case connectionTerminateMsg:
161 c.close(websocket.CloseNormalClosure, "terminated")
162 return
163 default:
164 c.sendConnectionError("unexpected message %s", message.Type)
165 c.close(websocket.CloseProtocolError, "unexpected message")
166 return
167 }
168 }
169}
170
171func (c *wsConnection) keepAlive(ctx context.Context) {
172 for {
173 select {
174 case <-ctx.Done():
175 c.keepAliveTicker.Stop()
176 return
177 case <-c.keepAliveTicker.C:
178 c.write(&operationMessage{Type: connectionKeepAliveMsg})
179 }
180 }
181}
182
183func (c *wsConnection) subscribe(message *operationMessage) bool {
184 var reqParams params
185 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
186 c.sendConnectionError("invalid json")
187 return false
188 }
189
190 var (
191 doc *ast.QueryDocument
192 cacheHit bool
193 )
194 if c.cache != nil {
195 val, ok := c.cache.Get(reqParams.Query)
196 if ok {
197 doc = val.(*ast.QueryDocument)
198 cacheHit = true
199 }
200 }
201 if !cacheHit {
202 var qErr gqlerror.List
203 doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
204 if qErr != nil {
205 c.sendError(message.ID, qErr...)
206 return true
207 }
208 if c.cache != nil {
209 c.cache.Add(reqParams.Query, doc)
210 }
211 }
212
213 op := doc.Operations.ForName(reqParams.OperationName)
214 if op == nil {
215 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
216 return true
217 }
218
219 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
220 if err != nil {
221 c.sendError(message.ID, err)
222 return true
223 }
224 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
225 ctx := graphql.WithRequestContext(c.ctx, reqCtx)
226
227 if c.initPayload != nil {
228 ctx = withInitPayload(ctx, c.initPayload)
229 }
230
231 if op.Operation != ast.Subscription {
232 var result *graphql.Response
233 if op.Operation == ast.Query {
234 result = c.exec.Query(ctx, op)
235 } else {
236 result = c.exec.Mutation(ctx, op)
237 }
238
239 c.sendData(message.ID, result)
240 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
241 return true
242 }
243
244 ctx, cancel := context.WithCancel(ctx)
245 c.mu.Lock()
246 c.active[message.ID] = cancel
247 c.mu.Unlock()
248 go func() {
249 defer func() {
250 if r := recover(); r != nil {
251 userErr := reqCtx.Recover(ctx, r)
252 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
253 }
254 }()
255 next := c.exec.Subscription(ctx, op)
256 for result := next(); result != nil; result = next() {
257 c.sendData(message.ID, result)
258 }
259
260 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
261
262 c.mu.Lock()
263 delete(c.active, message.ID)
264 c.mu.Unlock()
265 cancel()
266 }()
267
268 return true
269}
270
271func (c *wsConnection) sendData(id string, response *graphql.Response) {
272 b, err := json.Marshal(response)
273 if err != nil {
274 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
275 return
276 }
277
278 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
279}
280
281func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
282 var errs []error
283 for _, err := range errors {
284 errs = append(errs, err)
285 }
286 b, err := json.Marshal(errs)
287 if err != nil {
288 panic(err)
289 }
290 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
291}
292
293func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
294 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
295 if err != nil {
296 panic(err)
297 }
298
299 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
300}
301
302func (c *wsConnection) readOp() *operationMessage {
303 _, r, err := c.conn.NextReader()
304 if err != nil {
305 c.sendConnectionError("invalid json")
306 return nil
307 }
308 message := operationMessage{}
309 if err := jsonDecode(r, &message); err != nil {
310 c.sendConnectionError("invalid json")
311 return nil
312 }
313
314 return &message
315}
316
317func (c *wsConnection) close(closeCode int, message string) {
318 c.mu.Lock()
319 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
320 c.mu.Unlock()
321 _ = c.conn.Close()
322}