websocket.go

  1package handler
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log"
  9	"net/http"
 10	"sync"
 11
 12	"github.com/99designs/gqlgen/graphql"
 13	"github.com/gorilla/websocket"
 14	"github.com/vektah/gqlparser"
 15	"github.com/vektah/gqlparser/ast"
 16	"github.com/vektah/gqlparser/gqlerror"
 17	"github.com/vektah/gqlparser/validator"
 18)
 19
 20const (
 21	connectionInitMsg      = "connection_init"      // Client -> Server
 22	connectionTerminateMsg = "connection_terminate" // Client -> Server
 23	startMsg               = "start"                // Client -> Server
 24	stopMsg                = "stop"                 // Client -> Server
 25	connectionAckMsg       = "connection_ack"       // Server -> Client
 26	connectionErrorMsg     = "connection_error"     // Server -> Client
 27	dataMsg                = "data"                 // Server -> Client
 28	errorMsg               = "error"                // Server -> Client
 29	completeMsg            = "complete"             // Server -> Client
 30	//connectionKeepAliveMsg = "ka"                 // Server -> Client  TODO: keepalives
 31)
 32
 33type operationMessage struct {
 34	Payload json.RawMessage `json:"payload,omitempty"`
 35	ID      string          `json:"id,omitempty"`
 36	Type    string          `json:"type"`
 37}
 38
 39type wsConnection struct {
 40	ctx    context.Context
 41	conn   *websocket.Conn
 42	exec   graphql.ExecutableSchema
 43	active map[string]context.CancelFunc
 44	mu     sync.Mutex
 45	cfg    *Config
 46}
 47
 48func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
 49	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 50		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 51	})
 52	if err != nil {
 53		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 54		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 55		return
 56	}
 57
 58	conn := wsConnection{
 59		active: map[string]context.CancelFunc{},
 60		exec:   exec,
 61		conn:   ws,
 62		ctx:    r.Context(),
 63		cfg:    cfg,
 64	}
 65
 66	if !conn.init() {
 67		return
 68	}
 69
 70	conn.run()
 71}
 72
 73func (c *wsConnection) init() bool {
 74	message := c.readOp()
 75	if message == nil {
 76		c.close(websocket.CloseProtocolError, "decoding error")
 77		return false
 78	}
 79
 80	switch message.Type {
 81	case connectionInitMsg:
 82		c.write(&operationMessage{Type: connectionAckMsg})
 83	case connectionTerminateMsg:
 84		c.close(websocket.CloseNormalClosure, "terminated")
 85		return false
 86	default:
 87		c.sendConnectionError("unexpected message %s", message.Type)
 88		c.close(websocket.CloseProtocolError, "unexpected message")
 89		return false
 90	}
 91
 92	return true
 93}
 94
 95func (c *wsConnection) write(msg *operationMessage) {
 96	c.mu.Lock()
 97	c.conn.WriteJSON(msg)
 98	c.mu.Unlock()
 99}
100
101func (c *wsConnection) run() {
102	for {
103		message := c.readOp()
104		if message == nil {
105			return
106		}
107
108		switch message.Type {
109		case startMsg:
110			if !c.subscribe(message) {
111				return
112			}
113		case stopMsg:
114			c.mu.Lock()
115			closer := c.active[message.ID]
116			c.mu.Unlock()
117			if closer == nil {
118				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
119				continue
120			}
121
122			closer()
123		case connectionTerminateMsg:
124			c.close(websocket.CloseNormalClosure, "terminated")
125			return
126		default:
127			c.sendConnectionError("unexpected message %s", message.Type)
128			c.close(websocket.CloseProtocolError, "unexpected message")
129			return
130		}
131	}
132}
133
134func (c *wsConnection) subscribe(message *operationMessage) bool {
135	var reqParams params
136	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
137		c.sendConnectionError("invalid json")
138		return false
139	}
140
141	doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
142	if qErr != nil {
143		c.sendError(message.ID, qErr...)
144		return true
145	}
146
147	op := doc.Operations.ForName(reqParams.OperationName)
148	if op == nil {
149		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
150		return true
151	}
152
153	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
154	if err != nil {
155		c.sendError(message.ID, err)
156		return true
157	}
158	reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars)
159	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
160
161	if op.Operation != ast.Subscription {
162		var result *graphql.Response
163		if op.Operation == ast.Query {
164			result = c.exec.Query(ctx, op)
165		} else {
166			result = c.exec.Mutation(ctx, op)
167		}
168
169		c.sendData(message.ID, result)
170		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
171		return true
172	}
173
174	ctx, cancel := context.WithCancel(ctx)
175	c.mu.Lock()
176	c.active[message.ID] = cancel
177	c.mu.Unlock()
178	go func() {
179		defer func() {
180			if r := recover(); r != nil {
181				userErr := reqCtx.Recover(ctx, r)
182				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
183			}
184		}()
185		next := c.exec.Subscription(ctx, op)
186		for result := next(); result != nil; result = next() {
187			c.sendData(message.ID, result)
188		}
189
190		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
191
192		c.mu.Lock()
193		delete(c.active, message.ID)
194		c.mu.Unlock()
195		cancel()
196	}()
197
198	return true
199}
200
201func (c *wsConnection) sendData(id string, response *graphql.Response) {
202	b, err := json.Marshal(response)
203	if err != nil {
204		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
205		return
206	}
207
208	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
209}
210
211func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
212	var errs []error
213	for _, err := range errors {
214		errs = append(errs, err)
215	}
216	b, err := json.Marshal(errs)
217	if err != nil {
218		panic(err)
219	}
220	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
221}
222
223func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
224	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
225	if err != nil {
226		panic(err)
227	}
228
229	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
230}
231
232func (c *wsConnection) readOp() *operationMessage {
233	_, r, err := c.conn.NextReader()
234	if err != nil {
235		c.sendConnectionError("invalid json")
236		return nil
237	}
238	message := operationMessage{}
239	if err := jsonDecode(r, &message); err != nil {
240		c.sendConnectionError("invalid json")
241		return nil
242	}
243
244	return &message
245}
246
247func (c *wsConnection) close(closeCode int, message string) {
248	c.mu.Lock()
249	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
250	c.mu.Unlock()
251	_ = c.conn.Close()
252}