1package handler
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "sync"
11
12 "github.com/99designs/gqlgen/graphql"
13 "github.com/gorilla/websocket"
14 "github.com/vektah/gqlparser"
15 "github.com/vektah/gqlparser/ast"
16 "github.com/vektah/gqlparser/gqlerror"
17 "github.com/vektah/gqlparser/validator"
18)
19
// Message types of the graphql-ws protocol, grouped by which side sends them.
const (
	// Sent by the client.
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Sent by the server.
	connectionAckMsg   = "connection_ack"
	connectionErrorMsg = "connection_error"
	dataMsg            = "data"
	errorMsg           = "error"
	completeMsg        = "complete"

	// TODO: keepalives (server -> client)
	// connectionKeepAliveMsg = "ka"
)
32
// operationMessage is the JSON envelope used for every graphql-ws message,
// in both directions.
//
// NOTE: field declaration order determines the key order of the encoded JSON.
type operationMessage struct {
	// Payload is the type-specific body. It is kept as raw JSON so it can be
	// decoded after the message type has been dispatched on, and it is
	// omitted from the output when empty.
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID names the operation this message belongs to; omitted when empty
	// (e.g. for connection-level messages).
	ID string `json:"id,omitempty"`
	// Type is compared against / set from the *Msg constants.
	Type string `json:"type"`
}
38
39type wsConnection struct {
40 ctx context.Context
41 conn *websocket.Conn
42 exec graphql.ExecutableSchema
43 active map[string]context.CancelFunc
44 mu sync.Mutex
45 cfg *Config
46}
47
// connectWs upgrades r to a websocket speaking the graphql-ws subprotocol,
// performs the connection_init handshake, and then serves operation messages
// on the socket. It blocks until the connection is finished.
func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
		"Sec-Websocket-Protocol": []string{"graphql-ws"},
	})
	if err != nil {
		// gorilla's Upgrader has already replied to the client with an HTTP
		// error when Upgrade fails. Writing a second response here would only
		// trigger a superfluous WriteHeader and append a stray body.
		log.Printf("unable to upgrade %T to websocket: %s", w, err)
		return
	}

	conn := wsConnection{
		active: map[string]context.CancelFunc{},
		exec:   exec,
		conn:   ws,
		ctx:    r.Context(),
		cfg:    cfg,
	}

	// init closes the socket itself when the handshake fails.
	if !conn.init() {
		return
	}

	conn.run()
}
72
73func (c *wsConnection) init() bool {
74 message := c.readOp()
75 if message == nil {
76 c.close(websocket.CloseProtocolError, "decoding error")
77 return false
78 }
79
80 switch message.Type {
81 case connectionInitMsg:
82 c.write(&operationMessage{Type: connectionAckMsg})
83 case connectionTerminateMsg:
84 c.close(websocket.CloseNormalClosure, "terminated")
85 return false
86 default:
87 c.sendConnectionError("unexpected message %s", message.Type)
88 c.close(websocket.CloseProtocolError, "unexpected message")
89 return false
90 }
91
92 return true
93}
94
95func (c *wsConnection) write(msg *operationMessage) {
96 c.mu.Lock()
97 c.conn.WriteJSON(msg)
98 c.mu.Unlock()
99}
100
101func (c *wsConnection) run() {
102 for {
103 message := c.readOp()
104 if message == nil {
105 return
106 }
107
108 switch message.Type {
109 case startMsg:
110 if !c.subscribe(message) {
111 return
112 }
113 case stopMsg:
114 c.mu.Lock()
115 closer := c.active[message.ID]
116 c.mu.Unlock()
117 if closer == nil {
118 c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
119 continue
120 }
121
122 closer()
123 case connectionTerminateMsg:
124 c.close(websocket.CloseNormalClosure, "terminated")
125 return
126 default:
127 c.sendConnectionError("unexpected message %s", message.Type)
128 c.close(websocket.CloseProtocolError, "unexpected message")
129 return
130 }
131 }
132}
133
// subscribe handles a "start" message. It decodes the request parameters,
// parses the query against the schema, selects the requested operation and
// validates its variables. Then it acts by operation kind:
//   - Queries and mutations run synchronously and are answered with a "data"
//     message followed by "complete".
//   - Subscriptions run in a new goroutine. The goroutine streams "data"
//     messages until the result stream ends or the subscription is stopped
//     through its cancel function stored in c.active.
//
// It returns false only when the payload cannot be decoded, which the caller
// treats as fatal for the connection. GraphQL-level problems (parse errors,
// unknown operation, invalid variables) are sent to the client as "error"
// messages for message.ID, and true is returned.
func (c *wsConnection) subscribe(message *operationMessage) bool {
	var reqParams params
	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
		c.sendConnectionError("invalid json")
		return false
	}

	doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
	if qErr != nil {
		c.sendError(message.ID, qErr...)
		return true
	}

	op := doc.Operations.ForName(reqParams.OperationName)
	if op == nil {
		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
		return true
	}

	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
	if err != nil {
		c.sendError(message.ID, err)
		return true
	}
	reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars)
	ctx := graphql.WithRequestContext(c.ctx, reqCtx)

	if op.Operation != ast.Subscription {
		var result *graphql.Response
		if op.Operation == ast.Query {
			result = c.exec.Query(ctx, op)
		} else {
			result = c.exec.Mutation(ctx, op)
		}

		c.sendData(message.ID, result)
		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
		return true
	}

	ctx, cancel := context.WithCancel(ctx)
	c.mu.Lock()
	c.active[message.ID] = cancel
	c.mu.Unlock()
	go func() {
		// All cleanup lives in this deferred func so that it also runs when a
		// resolver panics. Otherwise the entry in c.active would go stale and
		// the derived context would stay uncancelled.
		defer func() {
			if r := recover(); r != nil {
				userErr := reqCtx.Recover(ctx, r)
				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
			}

			c.mu.Lock()
			delete(c.active, message.ID)
			c.mu.Unlock()
			cancel()
		}()

		next := c.exec.Subscription(ctx, op)
		for result := next(); result != nil; result = next() {
			c.sendData(message.ID, result)
		}

		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
	}()

	return true
}
200
// sendData encodes response as JSON and sends it as a "data" message for
// operation id. If encoding fails, an "error" message describing the failure
// is sent for id instead.
func (c *wsConnection) sendData(id string, response *graphql.Response) {
	payload, marshalErr := json.Marshal(response)
	if marshalErr == nil {
		c.write(&operationMessage{Type: dataMsg, ID: id, Payload: payload})
		return
	}

	c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", marshalErr.Error()))
}
210
211func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
212 var errs []error
213 for _, err := range errors {
214 errs = append(errs, err)
215 }
216 b, err := json.Marshal(errs)
217 if err != nil {
218 panic(err)
219 }
220 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
221}
222
// sendConnectionError formats a message in the style of fmt.Sprintf. It sends
// the message as a connection_error whose payload is a single GraphQL error
// object. The message carries no operation ID.
func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
	connErr := &gqlerror.Error{Message: fmt.Sprintf(format, args...)}

	payload, err := json.Marshal(connErr)
	if err != nil {
		// Only Message is set here, so encoding is not expected to fail.
		panic(err)
	}

	c.write(&operationMessage{Type: connectionErrorMsg, Payload: payload})
}
231
232func (c *wsConnection) readOp() *operationMessage {
233 _, r, err := c.conn.NextReader()
234 if err != nil {
235 c.sendConnectionError("invalid json")
236 return nil
237 }
238 message := operationMessage{}
239 if err := jsonDecode(r, &message); err != nil {
240 c.sendConnectionError("invalid json")
241 return nil
242 }
243
244 return &message
245}
246
247func (c *wsConnection) close(closeCode int, message string) {
248 c.mu.Lock()
249 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
250 c.mu.Unlock()
251 _ = c.conn.Close()
252}