1package handler
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "sync"
11 "time"
12
13 "github.com/99designs/gqlgen/graphql"
14 "github.com/gorilla/websocket"
15 "github.com/vektah/gqlparser"
16 "github.com/vektah/gqlparser/ast"
17 "github.com/vektah/gqlparser/gqlerror"
18 "github.com/vektah/gqlparser/validator"
19)
20
// Message types of the graphql-ws websocket subprotocol, grouped by which
// side of the connection sends them.
const (
	// Sent by the client to the server.
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Sent by the server to the client.
	connectionAckMsg       = "connection_ack"
	connectionErrorMsg     = "connection_error"
	dataMsg                = "data"
	errorMsg               = "error"
	completeMsg            = "complete"
	connectionKeepAliveMsg = "ka"
)
33
// operationMessage is the JSON envelope used for every graphql-ws message
// read from or written to the client. encoding/json emits fields in
// declaration order, so the field order here fixes the key order on the wire.
type operationMessage struct {
	// Payload is the raw, message-specific JSON body; omitted when empty.
	Payload json.RawMessage `json:"payload,omitempty"`

	// ID identifies the operation the message belongs to; omitted for
	// connection-level messages such as connection_ack or ka.
	ID string `json:"id,omitempty"`

	// Type names the kind of message (see the *Msg constants).
	Type string `json:"type"`
}
39
// wsConnection holds the state of one graphql-ws websocket session.
type wsConnection struct {
	// ctx is the originating HTTP request's context; every operation's
	// context is derived from it.
	ctx context.Context

	conn *websocket.Conn
	exec graphql.ExecutableSchema

	// active maps the IDs of running subscriptions to the cancel funcs that
	// stop them. Accessed under mu.
	active map[string]context.CancelFunc

	// mu serialises writes to conn (write, close) and guards active.
	mu sync.Mutex

	cfg *Config

	// keepAliveTicker drives keep-alive ("ka") messages. It stays nil unless
	// run finds a non-zero keep-alive ping interval in cfg.
	keepAliveTicker *time.Ticker

	// initPayload is the decoded payload of the client's connection_init
	// message, or nil if that message carried no payload.
	initPayload InitPayload
}
51
52func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
53 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
54 "Sec-Websocket-Protocol": []string{"graphql-ws"},
55 })
56 if err != nil {
57 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
58 sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
59 return
60 }
61
62 conn := wsConnection{
63 active: map[string]context.CancelFunc{},
64 exec: exec,
65 conn: ws,
66 ctx: r.Context(),
67 cfg: cfg,
68 }
69
70 if !conn.init() {
71 return
72 }
73
74 conn.run()
75}
76
77func (c *wsConnection) init() bool {
78 message := c.readOp()
79 if message == nil {
80 c.close(websocket.CloseProtocolError, "decoding error")
81 return false
82 }
83
84 switch message.Type {
85 case connectionInitMsg:
86 if len(message.Payload) > 0 {
87 c.initPayload = make(InitPayload)
88 err := json.Unmarshal(message.Payload, &c.initPayload)
89 if err != nil {
90 return false
91 }
92 }
93
94 c.write(&operationMessage{Type: connectionAckMsg})
95 case connectionTerminateMsg:
96 c.close(websocket.CloseNormalClosure, "terminated")
97 return false
98 default:
99 c.sendConnectionError("unexpected message %s", message.Type)
100 c.close(websocket.CloseProtocolError, "unexpected message")
101 return false
102 }
103
104 return true
105}
106
107func (c *wsConnection) write(msg *operationMessage) {
108 c.mu.Lock()
109 c.conn.WriteJSON(msg)
110 c.mu.Unlock()
111}
112
113func (c *wsConnection) run() {
114 // We create a cancellation that will shutdown the keep-alive when we leave
115 // this function.
116 ctx, cancel := context.WithCancel(c.ctx)
117 defer cancel()
118
119 // Create a timer that will fire every interval to keep the connection alive.
120 if c.cfg.connectionKeepAlivePingInterval != 0 {
121 c.mu.Lock()
122 c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
123 c.mu.Unlock()
124
125 go c.keepAlive(ctx)
126 }
127
128 for {
129 message := c.readOp()
130 if message == nil {
131 return
132 }
133
134 switch message.Type {
135 case startMsg:
136 if !c.subscribe(message) {
137 return
138 }
139 case stopMsg:
140 c.mu.Lock()
141 closer := c.active[message.ID]
142 c.mu.Unlock()
143 if closer == nil {
144 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
145 continue
146 }
147
148 closer()
149 case connectionTerminateMsg:
150 c.close(websocket.CloseNormalClosure, "terminated")
151 return
152 default:
153 c.sendConnectionError("unexpected message %s", message.Type)
154 c.close(websocket.CloseProtocolError, "unexpected message")
155 return
156 }
157 }
158}
159
160func (c *wsConnection) keepAlive(ctx context.Context) {
161 for {
162 select {
163 case <-ctx.Done():
164 c.keepAliveTicker.Stop()
165 return
166 case <-c.keepAliveTicker.C:
167 c.write(&operationMessage{Type: connectionKeepAliveMsg})
168 }
169 }
170}
171
172func (c *wsConnection) subscribe(message *operationMessage) bool {
173 var reqParams params
174 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
175 c.sendConnectionError("invalid json")
176 return false
177 }
178
179 doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
180 if qErr != nil {
181 c.sendError(message.ID, qErr...)
182 return true
183 }
184
185 op := doc.Operations.ForName(reqParams.OperationName)
186 if op == nil {
187 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
188 return true
189 }
190
191 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
192 if err != nil {
193 c.sendError(message.ID, err)
194 return true
195 }
196 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
197 ctx := graphql.WithRequestContext(c.ctx, reqCtx)
198
199 if c.initPayload != nil {
200 ctx = withInitPayload(ctx, c.initPayload)
201 }
202
203 if op.Operation != ast.Subscription {
204 var result *graphql.Response
205 if op.Operation == ast.Query {
206 result = c.exec.Query(ctx, op)
207 } else {
208 result = c.exec.Mutation(ctx, op)
209 }
210
211 c.sendData(message.ID, result)
212 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
213 return true
214 }
215
216 ctx, cancel := context.WithCancel(ctx)
217 c.mu.Lock()
218 c.active[message.ID] = cancel
219 c.mu.Unlock()
220 go func() {
221 defer func() {
222 if r := recover(); r != nil {
223 userErr := reqCtx.Recover(ctx, r)
224 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
225 }
226 }()
227 next := c.exec.Subscription(ctx, op)
228 for result := next(); result != nil; result = next() {
229 c.sendData(message.ID, result)
230 }
231
232 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
233
234 c.mu.Lock()
235 delete(c.active, message.ID)
236 c.mu.Unlock()
237 cancel()
238 }()
239
240 return true
241}
242
243func (c *wsConnection) sendData(id string, response *graphql.Response) {
244 b, err := json.Marshal(response)
245 if err != nil {
246 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
247 return
248 }
249
250 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
251}
252
253func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
254 var errs []error
255 for _, err := range errors {
256 errs = append(errs, err)
257 }
258 b, err := json.Marshal(errs)
259 if err != nil {
260 panic(err)
261 }
262 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
263}
264
265func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
266 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
267 if err != nil {
268 panic(err)
269 }
270
271 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
272}
273
274func (c *wsConnection) readOp() *operationMessage {
275 _, r, err := c.conn.NextReader()
276 if err != nil {
277 c.sendConnectionError("invalid json")
278 return nil
279 }
280 message := operationMessage{}
281 if err := jsonDecode(r, &message); err != nil {
282 c.sendConnectionError("invalid json")
283 return nil
284 }
285
286 return &message
287}
288
289func (c *wsConnection) close(closeCode int, message string) {
290 c.mu.Lock()
291 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
292 c.mu.Unlock()
293 _ = c.conn.Close()
294}