1package handler
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "sync"
11
12 "github.com/99designs/gqlgen/graphql"
13 "github.com/gorilla/websocket"
14 "github.com/vektah/gqlparser"
15 "github.com/vektah/gqlparser/ast"
16 "github.com/vektah/gqlparser/gqlerror"
17 "github.com/vektah/gqlparser/validator"
18)
19
// Message types of the "graphql-ws" websocket subprotocol, grouped by the
// direction in which they travel.
const (
	// Client -> Server
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Server -> Client
	connectionAckMsg   = "connection_ack"
	connectionErrorMsg = "connection_error"
	dataMsg            = "data"
	errorMsg           = "error"
	completeMsg        = "complete"
	//connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
)
32
// operationMessage is the envelope for every message exchanged over the
// "graphql-ws" websocket subprotocol, in both directions.
type operationMessage struct {
	// Payload is the message body, kept as raw JSON so it can be decoded
	// according to Type (request params for "start", init params for
	// "connection_init", an encoded response for "data", ...). It is omitted
	// from the encoding when empty.
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID names the operation the message belongs to. Connection-level
	// messages sent by the server ("connection_ack", "connection_error")
	// leave it empty, so it is omitted from their encoding.
	ID string `json:"id,omitempty"`
	// Type is one of the *Msg constants, e.g. "start", "data" or "complete".
	Type string `json:"type"`
}
38
// wsConnection holds the state of a single websocket client connection.
type wsConnection struct {
	// ctx is the parent context of every operation executed on this
	// connection.
	ctx  context.Context
	conn *websocket.Conn
	exec graphql.ExecutableSchema
	// active maps the ID of each running subscription to the function that
	// cancels it. Guarded by mu.
	active map[string]context.CancelFunc
	// mu guards active and serialises writes to conn, which happen both from
	// the read loop and from subscription goroutines.
	mu  sync.Mutex
	cfg *Config

	// initPayload is the decoded payload of the client's "connection_init"
	// message; it stays nil when that message carried no payload.
	initPayload InitPayload
}
49
50func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
51 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
52 "Sec-Websocket-Protocol": []string{"graphql-ws"},
53 })
54 if err != nil {
55 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
56 sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
57 return
58 }
59
60 conn := wsConnection{
61 active: map[string]context.CancelFunc{},
62 exec: exec,
63 conn: ws,
64 ctx: r.Context(),
65 cfg: cfg,
66 }
67
68 if !conn.init() {
69 return
70 }
71
72 conn.run()
73}
74
75func (c *wsConnection) init() bool {
76 message := c.readOp()
77 if message == nil {
78 c.close(websocket.CloseProtocolError, "decoding error")
79 return false
80 }
81
82 switch message.Type {
83 case connectionInitMsg:
84 if len(message.Payload) > 0 {
85 c.initPayload = make(InitPayload)
86 err := json.Unmarshal(message.Payload, &c.initPayload)
87 if err != nil {
88 return false
89 }
90 }
91
92 c.write(&operationMessage{Type: connectionAckMsg})
93 case connectionTerminateMsg:
94 c.close(websocket.CloseNormalClosure, "terminated")
95 return false
96 default:
97 c.sendConnectionError("unexpected message %s", message.Type)
98 c.close(websocket.CloseProtocolError, "unexpected message")
99 return false
100 }
101
102 return true
103}
104
105func (c *wsConnection) write(msg *operationMessage) {
106 c.mu.Lock()
107 c.conn.WriteJSON(msg)
108 c.mu.Unlock()
109}
110
111func (c *wsConnection) run() {
112 for {
113 message := c.readOp()
114 if message == nil {
115 return
116 }
117
118 switch message.Type {
119 case startMsg:
120 if !c.subscribe(message) {
121 return
122 }
123 case stopMsg:
124 c.mu.Lock()
125 closer := c.active[message.ID]
126 c.mu.Unlock()
127 if closer == nil {
128 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
129 continue
130 }
131
132 closer()
133 case connectionTerminateMsg:
134 c.close(websocket.CloseNormalClosure, "terminated")
135 return
136 default:
137 c.sendConnectionError("unexpected message %s", message.Type)
138 c.close(websocket.CloseProtocolError, "unexpected message")
139 return
140 }
141 }
142}
143
144func (c *wsConnection) subscribe(message *operationMessage) bool {
145 var reqParams params
146 if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
147 c.sendConnectionError("invalid json")
148 return false
149 }
150
151 doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
152 if qErr != nil {
153 c.sendError(message.ID, qErr...)
154 return true
155 }
156
157 op := doc.Operations.ForName(reqParams.OperationName)
158 if op == nil {
159 c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
160 return true
161 }
162
163 vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
164 if err != nil {
165 c.sendError(message.ID, err)
166 return true
167 }
168 reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
169 ctx := graphql.WithRequestContext(c.ctx, reqCtx)
170
171 if c.initPayload != nil {
172 ctx = withInitPayload(ctx, c.initPayload)
173 }
174
175 if op.Operation != ast.Subscription {
176 var result *graphql.Response
177 if op.Operation == ast.Query {
178 result = c.exec.Query(ctx, op)
179 } else {
180 result = c.exec.Mutation(ctx, op)
181 }
182
183 c.sendData(message.ID, result)
184 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
185 return true
186 }
187
188 ctx, cancel := context.WithCancel(ctx)
189 c.mu.Lock()
190 c.active[message.ID] = cancel
191 c.mu.Unlock()
192 go func() {
193 defer func() {
194 if r := recover(); r != nil {
195 userErr := reqCtx.Recover(ctx, r)
196 c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
197 }
198 }()
199 next := c.exec.Subscription(ctx, op)
200 for result := next(); result != nil; result = next() {
201 c.sendData(message.ID, result)
202 }
203
204 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
205
206 c.mu.Lock()
207 delete(c.active, message.ID)
208 c.mu.Unlock()
209 cancel()
210 }()
211
212 return true
213}
214
215func (c *wsConnection) sendData(id string, response *graphql.Response) {
216 b, err := json.Marshal(response)
217 if err != nil {
218 c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
219 return
220 }
221
222 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
223}
224
225func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
226 var errs []error
227 for _, err := range errors {
228 errs = append(errs, err)
229 }
230 b, err := json.Marshal(errs)
231 if err != nil {
232 panic(err)
233 }
234 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
235}
236
237func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
238 b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
239 if err != nil {
240 panic(err)
241 }
242
243 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
244}
245
246func (c *wsConnection) readOp() *operationMessage {
247 _, r, err := c.conn.NextReader()
248 if err != nil {
249 c.sendConnectionError("invalid json")
250 return nil
251 }
252 message := operationMessage{}
253 if err := jsonDecode(r, &message); err != nil {
254 c.sendConnectionError("invalid json")
255 return nil
256 }
257
258 return &message
259}
260
261func (c *wsConnection) close(closeCode int, message string) {
262 c.mu.Lock()
263 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
264 c.mu.Unlock()
265 _ = c.conn.Close()
266}