1package handler
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log"
8 "net/http"
9 "sync"
10
11 "github.com/gorilla/websocket"
12 "github.com/vektah/gqlgen/graphql"
13 "github.com/vektah/gqlgen/neelance/errors"
14 "github.com/vektah/gqlgen/neelance/query"
15 "github.com/vektah/gqlgen/neelance/validation"
16)
17
// Message types of the graphql-ws websocket sub-protocol, carried in
// operationMessage.Type.
const (
	// Sent by the client to the server.
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"

	// Sent by the server to the client.
	connectionAckMsg   = "connection_ack"
	connectionErrorMsg = "connection_error"
	dataMsg            = "data"
	errorMsg           = "error"
	completeMsg        = "complete"
	// connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
)
30
// operationMessage is the JSON envelope exchanged in both directions over the
// websocket. Type holds one of the *Msg constants, ID ties a message to a
// specific operation, and Payload keeps the type-specific body undecoded so it
// can be parsed once the type is known.
//
// Field order matters: encoding/json writes fields in declaration order.
type operationMessage struct {
	Payload json.RawMessage `json:"payload,omitempty"` // type-dependent body, left raw
	ID      string          `json:"id,omitempty"`      // operation identifier; omitted when empty
	Type    string          `json:"type"`              // message kind
}
36
37type wsConnection struct {
38 ctx context.Context
39 conn *websocket.Conn
40 exec graphql.ExecutableSchema
41 active map[string]context.CancelFunc
42 mu sync.Mutex
43 cfg *Config
44}
45
46func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
47 ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
48 "Sec-Websocket-Protocol": []string{"graphql-ws"},
49 })
50 if err != nil {
51 log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
52 sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
53 return
54 }
55
56 conn := wsConnection{
57 active: map[string]context.CancelFunc{},
58 exec: exec,
59 conn: ws,
60 ctx: r.Context(),
61 cfg: cfg,
62 }
63
64 if !conn.init() {
65 return
66 }
67
68 conn.run()
69}
70
71func (c *wsConnection) init() bool {
72 message := c.readOp()
73 if message == nil {
74 c.close(websocket.CloseProtocolError, "decoding error")
75 return false
76 }
77
78 switch message.Type {
79 case connectionInitMsg:
80 c.write(&operationMessage{Type: connectionAckMsg})
81 case connectionTerminateMsg:
82 c.close(websocket.CloseNormalClosure, "terminated")
83 return false
84 default:
85 c.sendConnectionError("unexpected message %s", message.Type)
86 c.close(websocket.CloseProtocolError, "unexpected message")
87 return false
88 }
89
90 return true
91}
92
93func (c *wsConnection) write(msg *operationMessage) {
94 c.mu.Lock()
95 c.conn.WriteJSON(msg)
96 c.mu.Unlock()
97}
98
99func (c *wsConnection) run() {
100 for {
101 message := c.readOp()
102 if message == nil {
103 return
104 }
105
106 switch message.Type {
107 case startMsg:
108 if !c.subscribe(message) {
109 return
110 }
111 case stopMsg:
112 c.mu.Lock()
113 closer := c.active[message.ID]
114 c.mu.Unlock()
115 if closer == nil {
116 c.sendError(message.ID, errors.Errorf("%s is not running, cannot stop", message.ID))
117 continue
118 }
119
120 closer()
121 case connectionTerminateMsg:
122 c.close(websocket.CloseNormalClosure, "terminated")
123 return
124 default:
125 c.sendConnectionError("unexpected message %s", message.Type)
126 c.close(websocket.CloseProtocolError, "unexpected message")
127 return
128 }
129 }
130}
131
132func (c *wsConnection) subscribe(message *operationMessage) bool {
133 var reqParams params
134 if err := json.Unmarshal(message.Payload, &reqParams); err != nil {
135 c.sendConnectionError("invalid json")
136 return false
137 }
138
139 doc, qErr := query.Parse(reqParams.Query)
140 if qErr != nil {
141 c.sendError(message.ID, qErr)
142 return true
143 }
144
145 errs := validation.Validate(c.exec.Schema(), doc)
146 if len(errs) != 0 {
147 c.sendError(message.ID, errs...)
148 return true
149 }
150
151 op, err := doc.GetOperation(reqParams.OperationName)
152 if err != nil {
153 c.sendError(message.ID, errors.Errorf("%s", err.Error()))
154 return true
155 }
156
157 reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)
158 ctx := graphql.WithRequestContext(c.ctx, reqCtx)
159
160 if op.Type != query.Subscription {
161 var result *graphql.Response
162 if op.Type == query.Query {
163 result = c.exec.Query(ctx, op)
164 } else {
165 result = c.exec.Mutation(ctx, op)
166 }
167
168 c.sendData(message.ID, result)
169 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
170 return true
171 }
172
173 ctx, cancel := context.WithCancel(ctx)
174 c.mu.Lock()
175 c.active[message.ID] = cancel
176 c.mu.Unlock()
177 go func() {
178 defer func() {
179 if r := recover(); r != nil {
180 userErr := reqCtx.Recover(ctx, r)
181 c.sendError(message.ID, &errors.QueryError{Message: userErr.Error()})
182 }
183 }()
184 next := c.exec.Subscription(ctx, op)
185 for result := next(); result != nil; result = next() {
186 c.sendData(message.ID, result)
187 }
188
189 c.write(&operationMessage{ID: message.ID, Type: completeMsg})
190
191 c.mu.Lock()
192 delete(c.active, message.ID)
193 c.mu.Unlock()
194 cancel()
195 }()
196
197 return true
198}
199
200func (c *wsConnection) sendData(id string, response *graphql.Response) {
201 b, err := json.Marshal(response)
202 if err != nil {
203 c.sendError(id, errors.Errorf("unable to encode json response: %s", err.Error()))
204 return
205 }
206
207 c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
208}
209
210func (c *wsConnection) sendError(id string, errors ...*errors.QueryError) {
211 var errs []error
212 for _, err := range errors {
213 errs = append(errs, err)
214 }
215 b, err := json.Marshal(errs)
216 if err != nil {
217 panic(err)
218 }
219 c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
220}
221
222func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
223 b, err := json.Marshal(&graphql.Error{Message: fmt.Sprintf(format, args...)})
224 if err != nil {
225 panic(err)
226 }
227
228 c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
229}
230
231func (c *wsConnection) readOp() *operationMessage {
232 message := operationMessage{}
233 if err := c.conn.ReadJSON(&message); err != nil {
234 c.sendConnectionError("invalid json")
235 return nil
236 }
237 return &message
238}
239
240func (c *wsConnection) close(closeCode int, message string) {
241 c.mu.Lock()
242 _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
243 c.mu.Unlock()
244 _ = c.conn.Close()
245}