websocket.go

  1package handler
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log"
  9	"net/http"
 10	"sync"
 11	"time"
 12
 13	"github.com/99designs/gqlgen/graphql"
 14	"github.com/gorilla/websocket"
 15	"github.com/hashicorp/golang-lru"
 16	"github.com/vektah/gqlparser"
 17	"github.com/vektah/gqlparser/ast"
 18	"github.com/vektah/gqlparser/gqlerror"
 19	"github.com/vektah/gqlparser/validator"
 20)
 21
 22const (
 23	connectionInitMsg      = "connection_init"      // Client -> Server
 24	connectionTerminateMsg = "connection_terminate" // Client -> Server
 25	startMsg               = "start"                // Client -> Server
 26	stopMsg                = "stop"                 // Client -> Server
 27	connectionAckMsg       = "connection_ack"       // Server -> Client
 28	connectionErrorMsg     = "connection_error"     // Server -> Client
 29	dataMsg                = "data"                 // Server -> Client
 30	errorMsg               = "error"                // Server -> Client
 31	completeMsg            = "complete"             // Server -> Client
 32	connectionKeepAliveMsg = "ka"                   // Server -> Client
 33)
 34
 35type operationMessage struct {
 36	Payload json.RawMessage `json:"payload,omitempty"`
 37	ID      string          `json:"id,omitempty"`
 38	Type    string          `json:"type"`
 39}
 40
 41type wsConnection struct {
 42	ctx             context.Context
 43	conn            *websocket.Conn
 44	exec            graphql.ExecutableSchema
 45	active          map[string]context.CancelFunc
 46	mu              sync.Mutex
 47	cfg             *Config
 48	cache           *lru.Cache
 49	keepAliveTicker *time.Ticker
 50
 51	initPayload InitPayload
 52}
 53
 54func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
 55	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 56		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 57	})
 58	if err != nil {
 59		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 60		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 61		return
 62	}
 63
 64	conn := wsConnection{
 65		active: map[string]context.CancelFunc{},
 66		exec:   exec,
 67		conn:   ws,
 68		ctx:    r.Context(),
 69		cfg:    cfg,
 70		cache:  cache,
 71	}
 72
 73	if !conn.init() {
 74		return
 75	}
 76
 77	conn.run()
 78}
 79
 80func (c *wsConnection) init() bool {
 81	message := c.readOp()
 82	if message == nil {
 83		c.close(websocket.CloseProtocolError, "decoding error")
 84		return false
 85	}
 86
 87	switch message.Type {
 88	case connectionInitMsg:
 89		if len(message.Payload) > 0 {
 90			c.initPayload = make(InitPayload)
 91			err := json.Unmarshal(message.Payload, &c.initPayload)
 92			if err != nil {
 93				return false
 94			}
 95		}
 96
 97		c.write(&operationMessage{Type: connectionAckMsg})
 98	case connectionTerminateMsg:
 99		c.close(websocket.CloseNormalClosure, "terminated")
100		return false
101	default:
102		c.sendConnectionError("unexpected message %s", message.Type)
103		c.close(websocket.CloseProtocolError, "unexpected message")
104		return false
105	}
106
107	return true
108}
109
110func (c *wsConnection) write(msg *operationMessage) {
111	c.mu.Lock()
112	c.conn.WriteJSON(msg)
113	c.mu.Unlock()
114}
115
116func (c *wsConnection) run() {
117	// We create a cancellation that will shutdown the keep-alive when we leave
118	// this function.
119	ctx, cancel := context.WithCancel(c.ctx)
120	defer cancel()
121
122	// Create a timer that will fire every interval to keep the connection alive.
123	if c.cfg.connectionKeepAlivePingInterval != 0 {
124		c.mu.Lock()
125		c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
126		c.mu.Unlock()
127
128		go c.keepAlive(ctx)
129	}
130
131	for {
132		message := c.readOp()
133		if message == nil {
134			return
135		}
136
137		switch message.Type {
138		case startMsg:
139			if !c.subscribe(message) {
140				return
141			}
142		case stopMsg:
143			c.mu.Lock()
144			closer := c.active[message.ID]
145			c.mu.Unlock()
146			if closer == nil {
147				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
148				continue
149			}
150
151			closer()
152		case connectionTerminateMsg:
153			c.close(websocket.CloseNormalClosure, "terminated")
154			return
155		default:
156			c.sendConnectionError("unexpected message %s", message.Type)
157			c.close(websocket.CloseProtocolError, "unexpected message")
158			return
159		}
160	}
161}
162
163func (c *wsConnection) keepAlive(ctx context.Context) {
164	for {
165		select {
166		case <-ctx.Done():
167			c.keepAliveTicker.Stop()
168			return
169		case <-c.keepAliveTicker.C:
170			c.write(&operationMessage{Type: connectionKeepAliveMsg})
171		}
172	}
173}
174
175func (c *wsConnection) subscribe(message *operationMessage) bool {
176	var reqParams params
177	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
178		c.sendConnectionError("invalid json")
179		return false
180	}
181
182	var (
183		doc      *ast.QueryDocument
184		cacheHit bool
185	)
186	if c.cache != nil {
187		val, ok := c.cache.Get(reqParams.Query)
188		if ok {
189			doc = val.(*ast.QueryDocument)
190			cacheHit = true
191		}
192	}
193	if !cacheHit {
194		var qErr gqlerror.List
195		doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
196		if qErr != nil {
197			c.sendError(message.ID, qErr...)
198			return true
199		}
200		if c.cache != nil {
201			c.cache.Add(reqParams.Query, doc)
202		}
203	}
204
205	op := doc.Operations.ForName(reqParams.OperationName)
206	if op == nil {
207		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
208		return true
209	}
210
211	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
212	if err != nil {
213		c.sendError(message.ID, err)
214		return true
215	}
216	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
217	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
218
219	if c.initPayload != nil {
220		ctx = withInitPayload(ctx, c.initPayload)
221	}
222
223	if op.Operation != ast.Subscription {
224		var result *graphql.Response
225		if op.Operation == ast.Query {
226			result = c.exec.Query(ctx, op)
227		} else {
228			result = c.exec.Mutation(ctx, op)
229		}
230
231		c.sendData(message.ID, result)
232		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
233		return true
234	}
235
236	ctx, cancel := context.WithCancel(ctx)
237	c.mu.Lock()
238	c.active[message.ID] = cancel
239	c.mu.Unlock()
240	go func() {
241		defer func() {
242			if r := recover(); r != nil {
243				userErr := reqCtx.Recover(ctx, r)
244				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
245			}
246		}()
247		next := c.exec.Subscription(ctx, op)
248		for result := next(); result != nil; result = next() {
249			c.sendData(message.ID, result)
250		}
251
252		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
253
254		c.mu.Lock()
255		delete(c.active, message.ID)
256		c.mu.Unlock()
257		cancel()
258	}()
259
260	return true
261}
262
263func (c *wsConnection) sendData(id string, response *graphql.Response) {
264	b, err := json.Marshal(response)
265	if err != nil {
266		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
267		return
268	}
269
270	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
271}
272
273func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
274	var errs []error
275	for _, err := range errors {
276		errs = append(errs, err)
277	}
278	b, err := json.Marshal(errs)
279	if err != nil {
280		panic(err)
281	}
282	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
283}
284
285func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
286	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
287	if err != nil {
288		panic(err)
289	}
290
291	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
292}
293
294func (c *wsConnection) readOp() *operationMessage {
295	_, r, err := c.conn.NextReader()
296	if err != nil {
297		c.sendConnectionError("invalid json")
298		return nil
299	}
300	message := operationMessage{}
301	if err := jsonDecode(r, &message); err != nil {
302		c.sendConnectionError("invalid json")
303		return nil
304	}
305
306	return &message
307}
308
309func (c *wsConnection) close(closeCode int, message string) {
310	c.mu.Lock()
311	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
312	c.mu.Unlock()
313	_ = c.conn.Close()
314}