websocket.go

  1package client
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"strings"
  7
  8	"github.com/gorilla/websocket"
  9	"github.com/vektah/gqlparser/gqlerror"
 10)
 11
 12const (
 13	connectionInitMsg = "connection_init" // Client -> Server
 14	startMsg          = "start"           // Client -> Server
 15	connectionAckMsg  = "connection_ack"  // Server -> Client
 16	dataMsg           = "data"            // Server -> Client
 17	errorMsg          = "error"           // Server -> Client
 18)
 19
 20type operationMessage struct {
 21	Payload json.RawMessage `json:"payload,omitempty"`
 22	ID      string          `json:"id,omitempty"`
 23	Type    string          `json:"type"`
 24}
 25
 26type Subscription struct {
 27	Close func() error
 28	Next  func(response interface{}) error
 29}
 30
 31func errorSubscription(err error) *Subscription {
 32	return &Subscription{
 33		Close: func() error { return nil },
 34		Next: func(response interface{}) error {
 35			return err
 36		},
 37	}
 38}
 39
 40func (p *Client) Websocket(query string, options ...Option) *Subscription {
 41	r := p.mkRequest(query, options...)
 42	requestBody, err := json.Marshal(r)
 43	if err != nil {
 44		return errorSubscription(fmt.Errorf("encode: %s", err.Error()))
 45	}
 46
 47	url := strings.Replace(p.url, "http://", "ws://", -1)
 48	url = strings.Replace(url, "https://", "wss://", -1)
 49
 50	c, _, err := websocket.DefaultDialer.Dial(url, nil)
 51	if err != nil {
 52		return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
 53	}
 54
 55	if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil {
 56		return errorSubscription(fmt.Errorf("init: %s", err.Error()))
 57	}
 58
 59	var ack operationMessage
 60	if err = c.ReadJSON(&ack); err != nil {
 61		return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
 62	}
 63	if ack.Type != connectionAckMsg {
 64		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
 65	}
 66
 67	if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
 68		return errorSubscription(fmt.Errorf("start: %s", err.Error()))
 69	}
 70
 71	return &Subscription{
 72		Close: c.Close,
 73		Next: func(response interface{}) error {
 74			var op operationMessage
 75			c.ReadJSON(&op)
 76			if op.Type != dataMsg {
 77				if op.Type == errorMsg {
 78					return fmt.Errorf(string(op.Payload))
 79				} else {
 80					return fmt.Errorf("expected data message, got %#v", op)
 81				}
 82			}
 83
 84			respDataRaw := map[string]interface{}{}
 85			err = json.Unmarshal(op.Payload, &respDataRaw)
 86			if err != nil {
 87				return fmt.Errorf("decode: %s", err.Error())
 88			}
 89
 90			if respDataRaw["errors"] != nil {
 91				var errs []*gqlerror.Error
 92				if err = unpack(respDataRaw["errors"], &errs); err != nil {
 93					return err
 94				}
 95				if len(errs) > 0 {
 96					return fmt.Errorf("errors: %s", errs)
 97				}
 98			}
 99
100			return unpack(respDataRaw["data"], response)
101		},
102	}
103}