1package client
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net"
9 "net/http"
10 "net/url"
11 stdpath "path"
12 "path/filepath"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/proto"
17 "github.com/charmbracelet/crush/internal/server"
18)
19
// DummyHost is the Host header value used for requests sent over local,
// non-TCP transports ("unix" and "npipe"). On those transports the configured
// address is a socket or pipe path rather than a hostname, so this placeholder
// is sent as the request's Host instead.
const DummyHost = "api.crush.localhost"

// Client represents an RPC client connected to a Crush server over HTTP.
type Client struct {
	h       *http.Client // HTTP client whose transport dials through (*Client).dialer
	id      string       // instance identifier, assigned via SetID
	path    string       // instance filesystem path, cleaned with filepath.Clean
	network string       // dial network; "unix" and "npipe" are handled as local transports
	addr    string       // server address, or the socket/pipe path for local transports
}
31
32// DefaultClient creates a new [Client] connected to the default server address.
33func DefaultClient(path string) (*Client, error) {
34 host, err := server.ParseHostURL(server.DefaultHost())
35 if err != nil {
36 return nil, err
37 }
38 return NewClient(path, host.Scheme, host.Host)
39}
40
41// NewClient creates a new [Client] connected to the server at the given
42// network and address.
43func NewClient(path, network, address string) (*Client, error) {
44 c := new(Client)
45 c.path = filepath.Clean(path)
46 c.network = network
47 c.addr = address
48 p := &http.Protocols{}
49 p.SetHTTP1(true)
50 p.SetUnencryptedHTTP2(true)
51 tr := http.DefaultTransport.(*http.Transport).Clone()
52 tr.Protocols = p
53 tr.DialContext = c.dialer
54 if c.network == "npipe" || c.network == "unix" {
55 // We don't need compression for local connections.
56 tr.DisableCompression = true
57 }
58 c.h = &http.Client{
59 Transport: tr,
60 Timeout: 0, // we need this to be 0 for long-lived connections and SSE streams
61 }
62 return c, nil
63}
64
65// ID returns the client's instance unique identifier.
66func (c *Client) ID() string {
67 return c.id
68}
69
70// SetID sets the client's instance unique identifier.
71func (c *Client) SetID(id string) {
72 c.id = id
73}
74
75// Path returns the client's instance filesystem path.
76func (c *Client) Path() string {
77 return c.path
78}
79
80// GetGlobalConfig retrieves the server's configuration.
81func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
82 var cfg config.Config
83 rsp, err := c.get(ctx, "/config", nil, nil)
84 if err != nil {
85 return nil, err
86 }
87 defer rsp.Body.Close()
88 if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
89 return nil, err
90 }
91 return &cfg, nil
92}
93
94// Health checks the server's health status.
95func (c *Client) Health(ctx context.Context) error {
96 rsp, err := c.get(ctx, "/health", nil, nil)
97 if err != nil {
98 return err
99 }
100 defer rsp.Body.Close()
101 if rsp.StatusCode != http.StatusOK {
102 return fmt.Errorf("server health check failed: %s", rsp.Status)
103 }
104 return nil
105}
106
107// VersionInfo retrieves the server's version information.
108func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
109 var vi proto.VersionInfo
110 rsp, err := c.get(ctx, "version", nil, nil)
111 if err != nil {
112 return nil, err
113 }
114 defer rsp.Body.Close()
115 if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
116 return nil, err
117 }
118 return &vi, nil
119}
120
121// ShutdownServer sends a shutdown request to the server.
122func (c *Client) ShutdownServer(ctx context.Context) error {
123 rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
124 Command: "shutdown",
125 }), nil)
126 if err != nil {
127 return err
128 }
129 defer rsp.Body.Close()
130 if rsp.StatusCode != http.StatusOK {
131 return fmt.Errorf("server shutdown failed: %s", rsp.Status)
132 }
133 return nil
134}
135
136func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
137 d := net.Dialer{
138 Timeout: 30 * time.Second,
139 KeepAlive: 30 * time.Second,
140 }
141 // It's important to use the client's addr for npipe/unix and not the
142 // address param because the address param is always "localhost:port" for
143 // HTTP clients and npipe/unix don't have a concept of ports.
144 switch c.network {
145 case "npipe":
146 ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
147 defer cancel()
148 return dialPipeContext(ctx, c.addr)
149 case "unix":
150 return d.DialContext(ctx, "unix", c.addr)
151 default:
152 return d.DialContext(ctx, network, address)
153 }
154}
155
156func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
157 return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
158}
159
160func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
161 return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
162}
163
164func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
165 return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
166}
167
168func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
169 return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
170}
171
172func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
173 url := (&url.URL{
174 Path: stdpath.Join("/v1", path), // Right now, we only have v1
175 RawQuery: query.Encode(),
176 }).String()
177 req, err := c.buildReq(ctx, method, url, body, headers)
178 if err != nil {
179 return nil, err
180 }
181
182 rsp, err := c.doReq(req)
183 if err != nil {
184 return nil, err
185 }
186
187 // TODO: check server errors in the response body?
188
189 return rsp, nil
190}
191
192func (c *Client) doReq(req *http.Request) (*http.Response, error) {
193 rsp, err := c.h.Do(req)
194 if err != nil {
195 return nil, err
196 }
197 return rsp, nil
198}
199
200func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
201 r, err := http.NewRequestWithContext(ctx, method, url, body)
202 if err != nil {
203 return nil, err
204 }
205
206 for k, v := range headers {
207 r.Header[http.CanonicalHeaderKey(k)] = v
208 }
209
210 r.URL.Scheme = "http" // This is always http because we don't use TLS
211 r.URL.Host = c.addr
212 if c.network == "npipe" || c.network == "unix" {
213 // We use a dummy host for non-tcp connections.
214 r.Host = DummyHost
215 }
216
217 if body != nil && r.Header.Get("Content-Type") == "" {
218 r.Header.Set("Content-Type", "text/plain")
219 }
220
221 return r, nil
222}