1package client
2
3import (
4 "encoding/json"
5 "fmt"
6 "strings"
7
8 "github.com/gorilla/websocket"
9 "github.com/vektah/gqlparser/gqlerror"
10)
11
// Message "type" values used by the GraphQL-over-WebSocket exchange in this
// file. The names appear to follow the subscriptions-transport-ws protocol —
// NOTE(review): confirm against the server implementation in use.
const (
	// Sent by the client.
	connectionInitMsg = "connection_init"
	startMsg          = "start"

	// Sent by the server.
	connectionAckMsg = "connection_ack"
	dataMsg          = "data"
	errorMsg         = "error"
)
19
20type operationMessage struct {
21 Payload json.RawMessage `json:"payload,omitempty"`
22 ID string `json:"id,omitempty"`
23 Type string `json:"type"`
24}
25
// Subscription is a handle to a GraphQL subscription opened by this client.
// Both members are plain function values, so a subscription whose setup
// failed and one backed by a live connection share the same type.
type Subscription struct {
	// Close releases whatever backs the subscription.
	Close func() error

	// Next retrieves the next result into response, or reports an error.
	Next func(response interface{}) error
}
30
31func errorSubscription(err error) *Subscription {
32 return &Subscription{
33 Close: func() error { return nil },
34 Next: func(response interface{}) error {
35 return err
36 },
37 }
38}
39
40func (p *Client) Websocket(query string, options ...Option) *Subscription {
41 r := p.mkRequest(query, options...)
42 requestBody, err := json.Marshal(r)
43 if err != nil {
44 return errorSubscription(fmt.Errorf("encode: %s", err.Error()))
45 }
46
47 url := strings.Replace(p.url, "http://", "ws://", -1)
48 url = strings.Replace(url, "https://", "wss://", -1)
49
50 c, _, err := websocket.DefaultDialer.Dial(url, nil)
51 if err != nil {
52 return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
53 }
54
55 if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil {
56 return errorSubscription(fmt.Errorf("init: %s", err.Error()))
57 }
58
59 var ack operationMessage
60 if err = c.ReadJSON(&ack); err != nil {
61 return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
62 }
63 if ack.Type != connectionAckMsg {
64 return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
65 }
66
67 if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
68 return errorSubscription(fmt.Errorf("start: %s", err.Error()))
69 }
70
71 return &Subscription{
72 Close: c.Close,
73 Next: func(response interface{}) error {
74 var op operationMessage
75 c.ReadJSON(&op)
76 if op.Type != dataMsg {
77 if op.Type == errorMsg {
78 return fmt.Errorf(string(op.Payload))
79 } else {
80 return fmt.Errorf("expected data message, got %#v", op)
81 }
82 }
83
84 respDataRaw := map[string]interface{}{}
85 err = json.Unmarshal(op.Payload, &respDataRaw)
86 if err != nil {
87 return fmt.Errorf("decode: %s", err.Error())
88 }
89
90 if respDataRaw["errors"] != nil {
91 var errs []*gqlerror.Error
92 if err = unpack(respDataRaw["errors"], &errs); err != nil {
93 return err
94 }
95 if len(errs) > 0 {
96 return fmt.Errorf("errors: %s", errs)
97 }
98 }
99
100 return unpack(respDataRaw["data"], response)
101 },
102 }
103}