websocket.go

  1package handler
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log"
  9	"net/http"
 10	"sync"
 11	"time"
 12
 13	"github.com/99designs/gqlgen/graphql"
 14	"github.com/gorilla/websocket"
 15	"github.com/vektah/gqlparser"
 16	"github.com/vektah/gqlparser/ast"
 17	"github.com/vektah/gqlparser/gqlerror"
 18	"github.com/vektah/gqlparser/validator"
 19)
 20
 21const (
 22	connectionInitMsg      = "connection_init"      // Client -> Server
 23	connectionTerminateMsg = "connection_terminate" // Client -> Server
 24	startMsg               = "start"                // Client -> Server
 25	stopMsg                = "stop"                 // Client -> Server
 26	connectionAckMsg       = "connection_ack"       // Server -> Client
 27	connectionErrorMsg     = "connection_error"     // Server -> Client
 28	dataMsg                = "data"                 // Server -> Client
 29	errorMsg               = "error"                // Server -> Client
 30	completeMsg            = "complete"             // Server -> Client
 31	connectionKeepAliveMsg = "ka"                   // Server -> Client
 32)
 33
 34type operationMessage struct {
 35	Payload json.RawMessage `json:"payload,omitempty"`
 36	ID      string          `json:"id,omitempty"`
 37	Type    string          `json:"type"`
 38}
 39
 40type wsConnection struct {
 41	ctx             context.Context
 42	conn            *websocket.Conn
 43	exec            graphql.ExecutableSchema
 44	active          map[string]context.CancelFunc
 45	mu              sync.Mutex
 46	cfg             *Config
 47	keepAliveTicker *time.Ticker
 48
 49	initPayload InitPayload
 50}
 51
 52func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
 53	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 54		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 55	})
 56	if err != nil {
 57		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 58		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 59		return
 60	}
 61
 62	conn := wsConnection{
 63		active: map[string]context.CancelFunc{},
 64		exec:   exec,
 65		conn:   ws,
 66		ctx:    r.Context(),
 67		cfg:    cfg,
 68	}
 69
 70	if !conn.init() {
 71		return
 72	}
 73
 74	conn.run()
 75}
 76
 77func (c *wsConnection) init() bool {
 78	message := c.readOp()
 79	if message == nil {
 80		c.close(websocket.CloseProtocolError, "decoding error")
 81		return false
 82	}
 83
 84	switch message.Type {
 85	case connectionInitMsg:
 86		if len(message.Payload) > 0 {
 87			c.initPayload = make(InitPayload)
 88			err := json.Unmarshal(message.Payload, &c.initPayload)
 89			if err != nil {
 90				return false
 91			}
 92		}
 93
 94		c.write(&operationMessage{Type: connectionAckMsg})
 95	case connectionTerminateMsg:
 96		c.close(websocket.CloseNormalClosure, "terminated")
 97		return false
 98	default:
 99		c.sendConnectionError("unexpected message %s", message.Type)
100		c.close(websocket.CloseProtocolError, "unexpected message")
101		return false
102	}
103
104	return true
105}
106
107func (c *wsConnection) write(msg *operationMessage) {
108	c.mu.Lock()
109	c.conn.WriteJSON(msg)
110	c.mu.Unlock()
111}
112
113func (c *wsConnection) run() {
114	// We create a cancellation that will shutdown the keep-alive when we leave
115	// this function.
116	ctx, cancel := context.WithCancel(c.ctx)
117	defer cancel()
118
119	// Create a timer that will fire every interval to keep the connection alive.
120	if c.cfg.connectionKeepAlivePingInterval != 0 {
121		c.mu.Lock()
122		c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
123		c.mu.Unlock()
124
125		go c.keepAlive(ctx)
126	}
127
128	for {
129		message := c.readOp()
130		if message == nil {
131			return
132		}
133
134		switch message.Type {
135		case startMsg:
136			if !c.subscribe(message) {
137				return
138			}
139		case stopMsg:
140			c.mu.Lock()
141			closer := c.active[message.ID]
142			c.mu.Unlock()
143			if closer == nil {
144				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
145				continue
146			}
147
148			closer()
149		case connectionTerminateMsg:
150			c.close(websocket.CloseNormalClosure, "terminated")
151			return
152		default:
153			c.sendConnectionError("unexpected message %s", message.Type)
154			c.close(websocket.CloseProtocolError, "unexpected message")
155			return
156		}
157	}
158}
159
160func (c *wsConnection) keepAlive(ctx context.Context) {
161	for {
162		select {
163		case <-ctx.Done():
164			c.keepAliveTicker.Stop()
165			return
166		case <-c.keepAliveTicker.C:
167			c.write(&operationMessage{Type: connectionKeepAliveMsg})
168		}
169	}
170}
171
172func (c *wsConnection) subscribe(message *operationMessage) bool {
173	var reqParams params
174	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
175		c.sendConnectionError("invalid json")
176		return false
177	}
178
179	doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
180	if qErr != nil {
181		c.sendError(message.ID, qErr...)
182		return true
183	}
184
185	op := doc.Operations.ForName(reqParams.OperationName)
186	if op == nil {
187		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
188		return true
189	}
190
191	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
192	if err != nil {
193		c.sendError(message.ID, err)
194		return true
195	}
196	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
197	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
198
199	if c.initPayload != nil {
200		ctx = withInitPayload(ctx, c.initPayload)
201	}
202
203	if op.Operation != ast.Subscription {
204		var result *graphql.Response
205		if op.Operation == ast.Query {
206			result = c.exec.Query(ctx, op)
207		} else {
208			result = c.exec.Mutation(ctx, op)
209		}
210
211		c.sendData(message.ID, result)
212		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
213		return true
214	}
215
216	ctx, cancel := context.WithCancel(ctx)
217	c.mu.Lock()
218	c.active[message.ID] = cancel
219	c.mu.Unlock()
220	go func() {
221		defer func() {
222			if r := recover(); r != nil {
223				userErr := reqCtx.Recover(ctx, r)
224				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
225			}
226		}()
227		next := c.exec.Subscription(ctx, op)
228		for result := next(); result != nil; result = next() {
229			c.sendData(message.ID, result)
230		}
231
232		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
233
234		c.mu.Lock()
235		delete(c.active, message.ID)
236		c.mu.Unlock()
237		cancel()
238	}()
239
240	return true
241}
242
243func (c *wsConnection) sendData(id string, response *graphql.Response) {
244	b, err := json.Marshal(response)
245	if err != nil {
246		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
247		return
248	}
249
250	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
251}
252
253func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
254	var errs []error
255	for _, err := range errors {
256		errs = append(errs, err)
257	}
258	b, err := json.Marshal(errs)
259	if err != nil {
260		panic(err)
261	}
262	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
263}
264
265func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
266	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
267	if err != nil {
268		panic(err)
269	}
270
271	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
272}
273
274func (c *wsConnection) readOp() *operationMessage {
275	_, r, err := c.conn.NextReader()
276	if err != nil {
277		c.sendConnectionError("invalid json")
278		return nil
279	}
280	message := operationMessage{}
281	if err := jsonDecode(r, &message); err != nil {
282		c.sendConnectionError("invalid json")
283		return nil
284	}
285
286	return &message
287}
288
289func (c *wsConnection) close(closeCode int, message string) {
290	c.mu.Lock()
291	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
292	c.mu.Unlock()
293	_ = c.conn.Close()
294}