websocket.go

  1package handler
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log"
  9	"net/http"
 10	"sync"
 11	"time"
 12
 13	"github.com/99designs/gqlgen/graphql"
 14	"github.com/gorilla/websocket"
 15	lru "github.com/hashicorp/golang-lru"
 16	"github.com/vektah/gqlparser"
 17	"github.com/vektah/gqlparser/ast"
 18	"github.com/vektah/gqlparser/gqlerror"
 19	"github.com/vektah/gqlparser/validator"
 20)
 21
 22const (
 23	connectionInitMsg      = "connection_init"      // Client -> Server
 24	connectionTerminateMsg = "connection_terminate" // Client -> Server
 25	startMsg               = "start"                // Client -> Server
 26	stopMsg                = "stop"                 // Client -> Server
 27	connectionAckMsg       = "connection_ack"       // Server -> Client
 28	connectionErrorMsg     = "connection_error"     // Server -> Client
 29	dataMsg                = "data"                 // Server -> Client
 30	errorMsg               = "error"                // Server -> Client
 31	completeMsg            = "complete"             // Server -> Client
 32	connectionKeepAliveMsg = "ka"                   // Server -> Client
 33)
 34
// operationMessage is the JSON envelope of every graphql-ws message exchanged
// on the websocket, in both directions.
type operationMessage struct {
	// Payload is the message body. It is kept raw so it can be decoded
	// according to Type by the handler that receives it.
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID identifies the operation a message belongs to; it is empty for
	// connection-level messages such as connection_ack or connection_error.
	ID   string `json:"id,omitempty"`
	Type string `json:"type"`
}
 40
// wsConnection holds the state of one graphql-ws websocket connection.
type wsConnection struct {
	// ctx is the originating HTTP request's context (set by connectWs); the
	// keep-alive context and every operation context are derived from it.
	ctx  context.Context
	conn *websocket.Conn
	exec graphql.ExecutableSchema
	// active maps an operation ID to the cancel func of its running
	// subscription. Entries are added in subscribe, looked up on "stop", and
	// removed when the subscription goroutine finishes. Guarded by mu.
	active map[string]context.CancelFunc
	// mu guards active and serializes every write to conn.
	mu  sync.Mutex
	cfg *Config
	// cache, when non-nil, stores parsed query documents keyed by query text.
	cache *lru.Cache
	// keepAliveTicker is created by run when a keep-alive interval is
	// configured, and is stopped by keepAlive.
	keepAliveTicker *time.Ticker

	// initPayload is the decoded payload of the connection_init message, or
	// nil if that message carried no payload.
	initPayload InitPayload
}
 53
 54func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
 55	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 56		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 57	})
 58	if err != nil {
 59		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 60		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 61		return
 62	}
 63
 64	conn := wsConnection{
 65		active: map[string]context.CancelFunc{},
 66		exec:   exec,
 67		conn:   ws,
 68		ctx:    r.Context(),
 69		cfg:    cfg,
 70		cache:  cache,
 71	}
 72
 73	if !conn.init() {
 74		return
 75	}
 76
 77	conn.run()
 78}
 79
 80func (c *wsConnection) init() bool {
 81	message := c.readOp()
 82	if message == nil {
 83		c.close(websocket.CloseProtocolError, "decoding error")
 84		return false
 85	}
 86
 87	switch message.Type {
 88	case connectionInitMsg:
 89		if len(message.Payload) > 0 {
 90			c.initPayload = make(InitPayload)
 91			err := json.Unmarshal(message.Payload, &c.initPayload)
 92			if err != nil {
 93				return false
 94			}
 95		}
 96
 97		if c.cfg.websocketInitFunc != nil {
 98			if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
 99				c.sendConnectionError(err.Error())
100				c.close(websocket.CloseNormalClosure, "terminated")
101				return false
102			}
103		}
104
105		c.write(&operationMessage{Type: connectionAckMsg})
106		c.write(&operationMessage{Type: connectionKeepAliveMsg})
107	case connectionTerminateMsg:
108		c.close(websocket.CloseNormalClosure, "terminated")
109		return false
110	default:
111		c.sendConnectionError("unexpected message %s", message.Type)
112		c.close(websocket.CloseProtocolError, "unexpected message")
113		return false
114	}
115
116	return true
117}
118
119func (c *wsConnection) write(msg *operationMessage) {
120	c.mu.Lock()
121	c.conn.WriteJSON(msg)
122	c.mu.Unlock()
123}
124
125func (c *wsConnection) run() {
126	// We create a cancellation that will shutdown the keep-alive when we leave
127	// this function.
128	ctx, cancel := context.WithCancel(c.ctx)
129	defer cancel()
130
131	// Create a timer that will fire every interval to keep the connection alive.
132	if c.cfg.connectionKeepAlivePingInterval != 0 {
133		c.mu.Lock()
134		c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
135		c.mu.Unlock()
136
137		go c.keepAlive(ctx)
138	}
139
140	for {
141		message := c.readOp()
142		if message == nil {
143			return
144		}
145
146		switch message.Type {
147		case startMsg:
148			if !c.subscribe(message) {
149				return
150			}
151		case stopMsg:
152			c.mu.Lock()
153			closer := c.active[message.ID]
154			c.mu.Unlock()
155			if closer == nil {
156				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
157				continue
158			}
159
160			closer()
161		case connectionTerminateMsg:
162			c.close(websocket.CloseNormalClosure, "terminated")
163			return
164		default:
165			c.sendConnectionError("unexpected message %s", message.Type)
166			c.close(websocket.CloseProtocolError, "unexpected message")
167			return
168		}
169	}
170}
171
172func (c *wsConnection) keepAlive(ctx context.Context) {
173	for {
174		select {
175		case <-ctx.Done():
176			c.keepAliveTicker.Stop()
177			return
178		case <-c.keepAliveTicker.C:
179			c.write(&operationMessage{Type: connectionKeepAliveMsg})
180		}
181	}
182}
183
184func (c *wsConnection) subscribe(message *operationMessage) bool {
185	var reqParams params
186	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
187		c.sendConnectionError("invalid json")
188		return false
189	}
190
191	var (
192		doc      *ast.QueryDocument
193		cacheHit bool
194	)
195	if c.cache != nil {
196		val, ok := c.cache.Get(reqParams.Query)
197		if ok {
198			doc = val.(*ast.QueryDocument)
199			cacheHit = true
200		}
201	}
202	if !cacheHit {
203		var qErr gqlerror.List
204		doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
205		if qErr != nil {
206			c.sendError(message.ID, qErr...)
207			return true
208		}
209		if c.cache != nil {
210			c.cache.Add(reqParams.Query, doc)
211		}
212	}
213
214	op := doc.Operations.ForName(reqParams.OperationName)
215	if op == nil {
216		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
217		return true
218	}
219
220	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
221	if err != nil {
222		c.sendError(message.ID, err)
223		return true
224	}
225	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
226	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
227
228	if c.initPayload != nil {
229		ctx = withInitPayload(ctx, c.initPayload)
230	}
231
232	if op.Operation != ast.Subscription {
233		var result *graphql.Response
234		if op.Operation == ast.Query {
235			result = c.exec.Query(ctx, op)
236		} else {
237			result = c.exec.Mutation(ctx, op)
238		}
239
240		c.sendData(message.ID, result)
241		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
242		return true
243	}
244
245	ctx, cancel := context.WithCancel(ctx)
246	c.mu.Lock()
247	c.active[message.ID] = cancel
248	c.mu.Unlock()
249	go func() {
250		defer func() {
251			if r := recover(); r != nil {
252				userErr := reqCtx.Recover(ctx, r)
253				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
254			}
255		}()
256		next := c.exec.Subscription(ctx, op)
257		for result := next(); result != nil; result = next() {
258			c.sendData(message.ID, result)
259		}
260
261		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
262
263		c.mu.Lock()
264		delete(c.active, message.ID)
265		c.mu.Unlock()
266		cancel()
267	}()
268
269	return true
270}
271
272func (c *wsConnection) sendData(id string, response *graphql.Response) {
273	b, err := json.Marshal(response)
274	if err != nil {
275		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
276		return
277	}
278
279	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
280}
281
282func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
283	errs := make([]error, len(errors))
284	for i, err := range errors {
285		errs[i] = err
286	}
287	b, err := json.Marshal(errs)
288	if err != nil {
289		panic(err)
290	}
291	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
292}
293
294func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
295	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
296	if err != nil {
297		panic(err)
298	}
299
300	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
301}
302
303func (c *wsConnection) readOp() *operationMessage {
304	_, r, err := c.conn.NextReader()
305	if err != nil {
306		c.sendConnectionError("invalid json")
307		return nil
308	}
309	message := operationMessage{}
310	if err := jsonDecode(r, &message); err != nil {
311		c.sendConnectionError("invalid json")
312		return nil
313	}
314
315	return &message
316}
317
318func (c *wsConnection) close(closeCode int, message string) {
319	c.mu.Lock()
320	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
321	c.mu.Unlock()
322	_ = c.conn.Close()
323}