websocket.go

  1package handler
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log"
  8	"net/http"
  9	"sync"
 10
 11	"github.com/gorilla/websocket"
 12	"github.com/vektah/gqlgen/graphql"
 13	"github.com/vektah/gqlgen/neelance/errors"
 14	"github.com/vektah/gqlgen/neelance/query"
 15	"github.com/vektah/gqlgen/neelance/validation"
 16)
 17
// Message types exchanged over the websocket in the Type field of an
// operationMessage. The trailing comment on each constant gives the direction
// in which the message travels.
const (
	connectionInitMsg      = "connection_init"      // Client -> Server
	connectionTerminateMsg = "connection_terminate" // Client -> Server
	startMsg               = "start"                // Client -> Server
	stopMsg                = "stop"                 // Client -> Server
	connectionAckMsg       = "connection_ack"       // Server -> Client
	connectionErrorMsg     = "connection_error"     // Server -> Client
	dataMsg                = "data"                 // Server -> Client
	errorMsg               = "error"                // Server -> Client
	completeMsg            = "complete"             // Server -> Client
	//connectionKeepAliveMsg = "ka"                 // Server -> Client  TODO: keepalives
)
 30
// operationMessage is the JSON envelope for every message read from or written
// to the websocket.
type operationMessage struct {
	// Payload is kept as raw JSON; its shape depends on Type (for a start
	// message it is decoded into the request params by subscribe).
	Payload json.RawMessage `json:"payload,omitempty"`
	// ID is the client-chosen operation id; it is echoed back on data, error
	// and complete messages and keys the set of active subscriptions.
	ID string `json:"id,omitempty"`
	// Type is one of the *Msg constants declared in this file.
	Type string `json:"type"`
}
 36
// wsConnection holds the state of a single upgraded websocket connection.
type wsConnection struct {
	// ctx is the parent context for every operation executed on this connection.
	ctx context.Context
	// conn is the underlying websocket.
	conn *websocket.Conn
	// exec is the schema used to validate and execute incoming operations.
	exec graphql.ExecutableSchema
	// active maps the id of each running subscription to the function that
	// cancels it. Access is guarded by mu.
	active map[string]context.CancelFunc
	// mu guards active and serialises writes to conn.
	mu sync.Mutex
	// cfg is the handler configuration; it supplies the upgrader and builds
	// the per-request context.
	cfg *Config
}
 45
 46func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
 47	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
 48		"Sec-Websocket-Protocol": []string{"graphql-ws"},
 49	})
 50	if err != nil {
 51		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
 52		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
 53		return
 54	}
 55
 56	conn := wsConnection{
 57		active: map[string]context.CancelFunc{},
 58		exec:   exec,
 59		conn:   ws,
 60		ctx:    r.Context(),
 61		cfg:    cfg,
 62	}
 63
 64	if !conn.init() {
 65		return
 66	}
 67
 68	conn.run()
 69}
 70
 71func (c *wsConnection) init() bool {
 72	message := c.readOp()
 73	if message == nil {
 74		c.close(websocket.CloseProtocolError, "decoding error")
 75		return false
 76	}
 77
 78	switch message.Type {
 79	case connectionInitMsg:
 80		c.write(&operationMessage{Type: connectionAckMsg})
 81	case connectionTerminateMsg:
 82		c.close(websocket.CloseNormalClosure, "terminated")
 83		return false
 84	default:
 85		c.sendConnectionError("unexpected message %s", message.Type)
 86		c.close(websocket.CloseProtocolError, "unexpected message")
 87		return false
 88	}
 89
 90	return true
 91}
 92
 93func (c *wsConnection) write(msg *operationMessage) {
 94	c.mu.Lock()
 95	c.conn.WriteJSON(msg)
 96	c.mu.Unlock()
 97}
 98
 99func (c *wsConnection) run() {
100	for {
101		message := c.readOp()
102		if message == nil {
103			return
104		}
105
106		switch message.Type {
107		case startMsg:
108			if !c.subscribe(message) {
109				return
110			}
111		case stopMsg:
112			c.mu.Lock()
113			closer := c.active[message.ID]
114			c.mu.Unlock()
115			if closer == nil {
116				c.sendError(message.ID, errors.Errorf("%s is not running, cannot stop", message.ID))
117				continue
118			}
119
120			closer()
121		case connectionTerminateMsg:
122			c.close(websocket.CloseNormalClosure, "terminated")
123			return
124		default:
125			c.sendConnectionError("unexpected message %s", message.Type)
126			c.close(websocket.CloseProtocolError, "unexpected message")
127			return
128		}
129	}
130}
131
132func (c *wsConnection) subscribe(message *operationMessage) bool {
133	var reqParams params
134	if err := json.Unmarshal(message.Payload, &reqParams); err != nil {
135		c.sendConnectionError("invalid json")
136		return false
137	}
138
139	doc, qErr := query.Parse(reqParams.Query)
140	if qErr != nil {
141		c.sendError(message.ID, qErr)
142		return true
143	}
144
145	errs := validation.Validate(c.exec.Schema(), doc)
146	if len(errs) != 0 {
147		c.sendError(message.ID, errs...)
148		return true
149	}
150
151	op, err := doc.GetOperation(reqParams.OperationName)
152	if err != nil {
153		c.sendError(message.ID, errors.Errorf("%s", err.Error()))
154		return true
155	}
156
157	reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)
158	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
159
160	if op.Type != query.Subscription {
161		var result *graphql.Response
162		if op.Type == query.Query {
163			result = c.exec.Query(ctx, op)
164		} else {
165			result = c.exec.Mutation(ctx, op)
166		}
167
168		c.sendData(message.ID, result)
169		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
170		return true
171	}
172
173	ctx, cancel := context.WithCancel(ctx)
174	c.mu.Lock()
175	c.active[message.ID] = cancel
176	c.mu.Unlock()
177	go func() {
178		defer func() {
179			if r := recover(); r != nil {
180				userErr := reqCtx.Recover(ctx, r)
181				c.sendError(message.ID, &errors.QueryError{Message: userErr.Error()})
182			}
183		}()
184		next := c.exec.Subscription(ctx, op)
185		for result := next(); result != nil; result = next() {
186			c.sendData(message.ID, result)
187		}
188
189		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
190
191		c.mu.Lock()
192		delete(c.active, message.ID)
193		c.mu.Unlock()
194		cancel()
195	}()
196
197	return true
198}
199
200func (c *wsConnection) sendData(id string, response *graphql.Response) {
201	b, err := json.Marshal(response)
202	if err != nil {
203		c.sendError(id, errors.Errorf("unable to encode json response: %s", err.Error()))
204		return
205	}
206
207	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
208}
209
210func (c *wsConnection) sendError(id string, errors ...*errors.QueryError) {
211	var errs []error
212	for _, err := range errors {
213		errs = append(errs, err)
214	}
215	b, err := json.Marshal(errs)
216	if err != nil {
217		panic(err)
218	}
219	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
220}
221
222func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
223	b, err := json.Marshal(&graphql.Error{Message: fmt.Sprintf(format, args...)})
224	if err != nil {
225		panic(err)
226	}
227
228	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
229}
230
231func (c *wsConnection) readOp() *operationMessage {
232	message := operationMessage{}
233	if err := c.conn.ReadJSON(&message); err != nil {
234		c.sendConnectionError("invalid json")
235		return nil
236	}
237	return &message
238}
239
240func (c *wsConnection) close(closeCode int, message string) {
241	c.mu.Lock()
242	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
243	c.mu.Unlock()
244	_ = c.conn.Close()
245}