websocket.go

  1package handler
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log"
  9	"net/http"
 10	"sync"
 11
 12	"github.com/99designs/gqlgen/graphql"
 13	"github.com/gorilla/websocket"
 14	"github.com/vektah/gqlparser"
 15	"github.com/vektah/gqlparser/ast"
 16	"github.com/vektah/gqlparser/gqlerror"
 17	"github.com/vektah/gqlparser/validator"
 18)
 19
 20const (
 21	connectionInitMsg      = "connection_init"      // Client -> Server
 22	connectionTerminateMsg = "connection_terminate" // Client -> Server
 23	startMsg               = "start"                // Client -> Server
 24	stopMsg                = "stop"                 // Client -> Server
 25	connectionAckMsg       = "connection_ack"       // Server -> Client
 26	connectionErrorMsg     = "connection_error"     // Server -> Client
 27	dataMsg                = "data"                 // Server -> Client
 28	errorMsg               = "error"                // Server -> Client
 29	completeMsg            = "complete"             // Server -> Client
 30	//connectionKeepAliveMsg = "ka"                 // Server -> Client  TODO: keepalives
 31)
 32
// operationMessage is the envelope for every frame exchanged over the
// websocket, in both directions. Type is one of the *Msg constants, ID ties
// together the frames that belong to one operation, and Payload holds the
// type-specific body as raw JSON that is decoded later by the handler for
// that type.
//
// NOTE: field order determines the key order json.Marshal emits on the wire.
type operationMessage struct {
	Payload json.RawMessage `json:"payload,omitempty"`
	ID      string          `json:"id,omitempty"`
	Type    string          `json:"type"`
}
 38
// wsConnection is the per-socket state of one upgraded websocket that serves
// GraphQL operations.
type wsConnection struct {
	// ctx is the context the connection was opened with; every operation's
	// context is derived from it.
	// NOTE(review): storing a context in a struct is discouraged in Go.
	ctx context.Context
	// conn is the underlying socket. Writes to it are serialized through mu.
	conn *websocket.Conn
	exec graphql.ExecutableSchema
	// active maps an operation ID to the cancel func of its running
	// subscription. Guarded by mu.
	active map[string]context.CancelFunc
	// mu serializes socket writes and guards active.
	mu  sync.Mutex
	cfg *Config

	// initPayload is the decoded connection_init payload. It is nil when the
	// client sent none.
	initPayload InitPayload
}
 49
 50func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
 51	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 52		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 53	})
 54	if err != nil {
 55		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 56		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 57		return
 58	}
 59
 60	conn := wsConnection{
 61		active: map[string]context.CancelFunc{},
 62		exec:   exec,
 63		conn:   ws,
 64		ctx:    r.Context(),
 65		cfg:    cfg,
 66	}
 67
 68	if !conn.init() {
 69		return
 70	}
 71
 72	conn.run()
 73}
 74
 75func (c *wsConnection) init() bool {
 76	message := c.readOp()
 77	if message == nil {
 78		c.close(websocket.CloseProtocolError, "decoding error")
 79		return false
 80	}
 81
 82	switch message.Type {
 83	case connectionInitMsg:
 84		if len(message.Payload) > 0 {
 85			c.initPayload = make(InitPayload)
 86			err := json.Unmarshal(message.Payload, &c.initPayload)
 87			if err != nil {
 88				return false
 89			}
 90		}
 91
 92		c.write(&operationMessage{Type: connectionAckMsg})
 93	case connectionTerminateMsg:
 94		c.close(websocket.CloseNormalClosure, "terminated")
 95		return false
 96	default:
 97		c.sendConnectionError("unexpected message %s", message.Type)
 98		c.close(websocket.CloseProtocolError, "unexpected message")
 99		return false
100	}
101
102	return true
103}
104
105func (c *wsConnection) write(msg *operationMessage) {
106	c.mu.Lock()
107	c.conn.WriteJSON(msg)
108	c.mu.Unlock()
109}
110
111func (c *wsConnection) run() {
112	for {
113		message := c.readOp()
114		if message == nil {
115			return
116		}
117
118		switch message.Type {
119		case startMsg:
120			if !c.subscribe(message) {
121				return
122			}
123		case stopMsg:
124			c.mu.Lock()
125			closer := c.active[message.ID]
126			c.mu.Unlock()
127			if closer == nil {
128				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
129				continue
130			}
131
132			closer()
133		case connectionTerminateMsg:
134			c.close(websocket.CloseNormalClosure, "terminated")
135			return
136		default:
137			c.sendConnectionError("unexpected message %s", message.Type)
138			c.close(websocket.CloseProtocolError, "unexpected message")
139			return
140		}
141	}
142}
143
144func (c *wsConnection) subscribe(message *operationMessage) bool {
145	var reqParams params
146	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
147		c.sendConnectionError("invalid json")
148		return false
149	}
150
151	doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
152	if qErr != nil {
153		c.sendError(message.ID, qErr...)
154		return true
155	}
156
157	op := doc.Operations.ForName(reqParams.OperationName)
158	if op == nil {
159		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
160		return true
161	}
162
163	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
164	if err != nil {
165		c.sendError(message.ID, err)
166		return true
167	}
168	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
169	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
170
171	if c.initPayload != nil {
172		ctx = withInitPayload(ctx, c.initPayload)
173	}
174
175	if op.Operation != ast.Subscription {
176		var result *graphql.Response
177		if op.Operation == ast.Query {
178			result = c.exec.Query(ctx, op)
179		} else {
180			result = c.exec.Mutation(ctx, op)
181		}
182
183		c.sendData(message.ID, result)
184		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
185		return true
186	}
187
188	ctx, cancel := context.WithCancel(ctx)
189	c.mu.Lock()
190	c.active[message.ID] = cancel
191	c.mu.Unlock()
192	go func() {
193		defer func() {
194			if r := recover(); r != nil {
195				userErr := reqCtx.Recover(ctx, r)
196				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
197			}
198		}()
199		next := c.exec.Subscription(ctx, op)
200		for result := next(); result != nil; result = next() {
201			c.sendData(message.ID, result)
202		}
203
204		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
205
206		c.mu.Lock()
207		delete(c.active, message.ID)
208		c.mu.Unlock()
209		cancel()
210	}()
211
212	return true
213}
214
215func (c *wsConnection) sendData(id string, response *graphql.Response) {
216	b, err := json.Marshal(response)
217	if err != nil {
218		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
219		return
220	}
221
222	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
223}
224
225func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
226	var errs []error
227	for _, err := range errors {
228		errs = append(errs, err)
229	}
230	b, err := json.Marshal(errs)
231	if err != nil {
232		panic(err)
233	}
234	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
235}
236
237func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
238	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
239	if err != nil {
240		panic(err)
241	}
242
243	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
244}
245
246func (c *wsConnection) readOp() *operationMessage {
247	_, r, err := c.conn.NextReader()
248	if err != nil {
249		c.sendConnectionError("invalid json")
250		return nil
251	}
252	message := operationMessage{}
253	if err := jsonDecode(r, &message); err != nil {
254		c.sendConnectionError("invalid json")
255		return nil
256	}
257
258	return &message
259}
260
261func (c *wsConnection) close(closeCode int, message string) {
262	c.mu.Lock()
263	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
264	c.mu.Unlock()
265	_ = c.conn.Close()
266}