websocket.go

  1package handler
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log"
  9	"net/http"
 10	"sync"
 11	"time"
 12
 13	"github.com/99designs/gqlgen/graphql"
 14	"github.com/gorilla/websocket"
 15	lru "github.com/hashicorp/golang-lru"
 16	"github.com/vektah/gqlparser"
 17	"github.com/vektah/gqlparser/ast"
 18	"github.com/vektah/gqlparser/gqlerror"
 19	"github.com/vektah/gqlparser/validator"
 20)
 21
 22const (
 23	connectionInitMsg      = "connection_init"      // Client -> Server
 24	connectionTerminateMsg = "connection_terminate" // Client -> Server
 25	startMsg               = "start"                // Client -> Server
 26	stopMsg                = "stop"                 // Client -> Server
 27	connectionAckMsg       = "connection_ack"       // Server -> Client
 28	connectionErrorMsg     = "connection_error"     // Server -> Client
 29	dataMsg                = "data"                 // Server -> Client
 30	errorMsg               = "error"                // Server -> Client
 31	completeMsg            = "complete"             // Server -> Client
 32	connectionKeepAliveMsg = "ka"                   // Server -> Client
 33)
 34
 35type operationMessage struct {
 36	Payload json.RawMessage `json:"payload,omitempty"`
 37	ID      string          `json:"id,omitempty"`
 38	Type    string          `json:"type"`
 39}
 40
 41type wsConnection struct {
 42	ctx             context.Context
 43	conn            *websocket.Conn
 44	exec            graphql.ExecutableSchema
 45	active          map[string]context.CancelFunc
 46	mu              sync.Mutex
 47	cfg             *Config
 48	cache           *lru.Cache
 49	keepAliveTicker *time.Ticker
 50
 51	initPayload InitPayload
 52}
 53
 54func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
 55	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 56		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 57	})
 58	if err != nil {
 59		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 60		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 61		return
 62	}
 63
 64	conn := wsConnection{
 65		active: map[string]context.CancelFunc{},
 66		exec:   exec,
 67		conn:   ws,
 68		ctx:    r.Context(),
 69		cfg:    cfg,
 70		cache:  cache,
 71	}
 72
 73	if !conn.init() {
 74		return
 75	}
 76
 77	conn.run()
 78}
 79
 80func (c *wsConnection) init() bool {
 81	message := c.readOp()
 82	if message == nil {
 83		c.close(websocket.CloseProtocolError, "decoding error")
 84		return false
 85	}
 86
 87	switch message.Type {
 88	case connectionInitMsg:
 89		if len(message.Payload) > 0 {
 90			c.initPayload = make(InitPayload)
 91			err := json.Unmarshal(message.Payload, &c.initPayload)
 92			if err != nil {
 93				return false
 94			}
 95		}
 96
 97		if c.cfg.websocketInitFunc != nil {
 98			if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
 99				c.sendConnectionError(err.Error())
100				c.close(websocket.CloseNormalClosure, "terminated")
101				return false
102			}
103		}
104
105		c.write(&operationMessage{Type: connectionAckMsg})
106	case connectionTerminateMsg:
107		c.close(websocket.CloseNormalClosure, "terminated")
108		return false
109	default:
110		c.sendConnectionError("unexpected message %s", message.Type)
111		c.close(websocket.CloseProtocolError, "unexpected message")
112		return false
113	}
114
115	return true
116}
117
118func (c *wsConnection) write(msg *operationMessage) {
119	c.mu.Lock()
120	c.conn.WriteJSON(msg)
121	c.mu.Unlock()
122}
123
124func (c *wsConnection) run() {
125	// We create a cancellation that will shutdown the keep-alive when we leave
126	// this function.
127	ctx, cancel := context.WithCancel(c.ctx)
128	defer cancel()
129
130	// Create a timer that will fire every interval to keep the connection alive.
131	if c.cfg.connectionKeepAlivePingInterval != 0 {
132		c.mu.Lock()
133		c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
134		c.mu.Unlock()
135
136		go c.keepAlive(ctx)
137	}
138
139	for {
140		message := c.readOp()
141		if message == nil {
142			return
143		}
144
145		switch message.Type {
146		case startMsg:
147			if !c.subscribe(message) {
148				return
149			}
150		case stopMsg:
151			c.mu.Lock()
152			closer := c.active[message.ID]
153			c.mu.Unlock()
154			if closer == nil {
155				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
156				continue
157			}
158
159			closer()
160		case connectionTerminateMsg:
161			c.close(websocket.CloseNormalClosure, "terminated")
162			return
163		default:
164			c.sendConnectionError("unexpected message %s", message.Type)
165			c.close(websocket.CloseProtocolError, "unexpected message")
166			return
167		}
168	}
169}
170
171func (c *wsConnection) keepAlive(ctx context.Context) {
172	for {
173		select {
174		case <-ctx.Done():
175			c.keepAliveTicker.Stop()
176			return
177		case <-c.keepAliveTicker.C:
178			c.write(&operationMessage{Type: connectionKeepAliveMsg})
179		}
180	}
181}
182
183func (c *wsConnection) subscribe(message *operationMessage) bool {
184	var reqParams params
185	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
186		c.sendConnectionError("invalid json")
187		return false
188	}
189
190	var (
191		doc      *ast.QueryDocument
192		cacheHit bool
193	)
194	if c.cache != nil {
195		val, ok := c.cache.Get(reqParams.Query)
196		if ok {
197			doc = val.(*ast.QueryDocument)
198			cacheHit = true
199		}
200	}
201	if !cacheHit {
202		var qErr gqlerror.List
203		doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
204		if qErr != nil {
205			c.sendError(message.ID, qErr...)
206			return true
207		}
208		if c.cache != nil {
209			c.cache.Add(reqParams.Query, doc)
210		}
211	}
212
213	op := doc.Operations.ForName(reqParams.OperationName)
214	if op == nil {
215		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
216		return true
217	}
218
219	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
220	if err != nil {
221		c.sendError(message.ID, err)
222		return true
223	}
224	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
225	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
226
227	if c.initPayload != nil {
228		ctx = withInitPayload(ctx, c.initPayload)
229	}
230
231	if op.Operation != ast.Subscription {
232		var result *graphql.Response
233		if op.Operation == ast.Query {
234			result = c.exec.Query(ctx, op)
235		} else {
236			result = c.exec.Mutation(ctx, op)
237		}
238
239		c.sendData(message.ID, result)
240		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
241		return true
242	}
243
244	ctx, cancel := context.WithCancel(ctx)
245	c.mu.Lock()
246	c.active[message.ID] = cancel
247	c.mu.Unlock()
248	go func() {
249		defer func() {
250			if r := recover(); r != nil {
251				userErr := reqCtx.Recover(ctx, r)
252				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
253			}
254		}()
255		next := c.exec.Subscription(ctx, op)
256		for result := next(); result != nil; result = next() {
257			c.sendData(message.ID, result)
258		}
259
260		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
261
262		c.mu.Lock()
263		delete(c.active, message.ID)
264		c.mu.Unlock()
265		cancel()
266	}()
267
268	return true
269}
270
271func (c *wsConnection) sendData(id string, response *graphql.Response) {
272	b, err := json.Marshal(response)
273	if err != nil {
274		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
275		return
276	}
277
278	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
279}
280
281func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
282	var errs []error
283	for _, err := range errors {
284		errs = append(errs, err)
285	}
286	b, err := json.Marshal(errs)
287	if err != nil {
288		panic(err)
289	}
290	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
291}
292
293func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
294	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
295	if err != nil {
296		panic(err)
297	}
298
299	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
300}
301
302func (c *wsConnection) readOp() *operationMessage {
303	_, r, err := c.conn.NextReader()
304	if err != nil {
305		c.sendConnectionError("invalid json")
306		return nil
307	}
308	message := operationMessage{}
309	if err := jsonDecode(r, &message); err != nil {
310		c.sendConnectionError("invalid json")
311		return nil
312	}
313
314	return &message
315}
316
317func (c *wsConnection) close(closeCode int, message string) {
318	c.mu.Lock()
319	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
320	c.mu.Unlock()
321	_ = c.conn.Close()
322}