1package client
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net"
9 "net/http"
10 "net/url"
11 stdpath "path"
12 "path/filepath"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/proto"
17 "github.com/charmbracelet/crush/internal/server"
18)
19
const (
	// DummyHost is the Host header value sent on requests that travel over
	// unix sockets or named pipes. Those transports have no real host name,
	// but HTTP requests still need one, so this placeholder is used instead.
	DummyHost = "api.crush.localhost"
)
22
// Client represents an RPC client connected to a Crush server. Construct it
// with [NewClient] or [DefaultClient]; the zero value has no HTTP client and
// cannot send requests.
type Client struct {
	// h is the HTTP client whose transport dials the server via network/addr.
	h *http.Client
	// path is the instance filesystem path reported by Path.
	path string
	// network is the dial network; "unix" and "npipe" are special-cased,
	// anything else is handed to net.Dialer as-is.
	network string
	// addr is the socket path, pipe name, or network address to dial.
	addr string
}
30
31// DefaultClient creates a new [Client] connected to the default server address.
32func DefaultClient(path string) (*Client, error) {
33 host, err := server.ParseHostURL(server.DefaultHost())
34 if err != nil {
35 return nil, err
36 }
37 return NewClient(path, host.Scheme, host.Host)
38}
39
40// NewClient creates a new [Client] connected to the server at the given
41// network and address.
42func NewClient(path, network, address string) (*Client, error) {
43 c := new(Client)
44 c.path = filepath.Clean(path)
45 c.network = network
46 c.addr = address
47 p := &http.Protocols{}
48 p.SetHTTP1(true)
49 p.SetUnencryptedHTTP2(true)
50 tr := http.DefaultTransport.(*http.Transport).Clone()
51 tr.Protocols = p
52 tr.DialContext = c.dialer
53 if c.network == "npipe" || c.network == "unix" {
54 // We don't need compression for local connections.
55 tr.DisableCompression = true
56 }
57 c.h = &http.Client{
58 Transport: tr,
59 Timeout: 0, // we need this to be 0 for long-lived connections and SSE streams
60 }
61 return c, nil
62}
63
64// Path returns the client's instance filesystem path.
65func (c *Client) Path() string {
66 return c.path
67}
68
69// GetGlobalConfig retrieves the server's configuration.
70func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
71 var cfg config.Config
72 rsp, err := c.get(ctx, "/config", nil, nil)
73 if err != nil {
74 return nil, err
75 }
76 defer rsp.Body.Close()
77 if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
78 return nil, err
79 }
80 return &cfg, nil
81}
82
83// Health checks the server's health status.
84func (c *Client) Health(ctx context.Context) error {
85 rsp, err := c.get(ctx, "/health", nil, nil)
86 if err != nil {
87 return err
88 }
89 defer rsp.Body.Close()
90 if rsp.StatusCode != http.StatusOK {
91 return fmt.Errorf("server health check failed: %s", rsp.Status)
92 }
93 return nil
94}
95
96// VersionInfo retrieves the server's version information.
97func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
98 var vi proto.VersionInfo
99 rsp, err := c.get(ctx, "version", nil, nil)
100 if err != nil {
101 return nil, err
102 }
103 defer rsp.Body.Close()
104 if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
105 return nil, err
106 }
107 return &vi, nil
108}
109
110// ShutdownServer sends a shutdown request to the server.
111func (c *Client) ShutdownServer(ctx context.Context) error {
112 rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
113 Command: "shutdown",
114 }), nil)
115 if err != nil {
116 return err
117 }
118 defer rsp.Body.Close()
119 if rsp.StatusCode != http.StatusOK {
120 return fmt.Errorf("server shutdown failed: %s", rsp.Status)
121 }
122 return nil
123}
124
125func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
126 d := net.Dialer{
127 Timeout: 30 * time.Second,
128 KeepAlive: 30 * time.Second,
129 }
130 // It's important to use the client's addr for npipe/unix and not the
131 // address param because the address param is always "localhost:port" for
132 // HTTP clients and npipe/unix don't have a concept of ports.
133 switch c.network {
134 case "npipe":
135 ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
136 defer cancel()
137 return dialPipeContext(ctx, c.addr)
138 case "unix":
139 return d.DialContext(ctx, "unix", c.addr)
140 default:
141 return d.DialContext(ctx, network, address)
142 }
143}
144
145func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
146 return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
147}
148
149func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
150 return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
151}
152
153func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
154 return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
155}
156
157func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
158 return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
159}
160
161func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
162 url := (&url.URL{
163 Path: stdpath.Join("/v1", path), // Right now, we only have v1
164 RawQuery: query.Encode(),
165 }).String()
166 req, err := c.buildReq(ctx, method, url, body, headers)
167 if err != nil {
168 return nil, err
169 }
170
171 rsp, err := c.doReq(req)
172 if err != nil {
173 return nil, err
174 }
175
176 // TODO: check server errors in the response body?
177
178 return rsp, nil
179}
180
181func (c *Client) doReq(req *http.Request) (*http.Response, error) {
182 rsp, err := c.h.Do(req)
183 if err != nil {
184 return nil, err
185 }
186 return rsp, nil
187}
188
189func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
190 r, err := http.NewRequestWithContext(ctx, method, url, body)
191 if err != nil {
192 return nil, err
193 }
194
195 for k, v := range headers {
196 r.Header[http.CanonicalHeaderKey(k)] = v
197 }
198
199 r.URL.Scheme = "http" // This is always http because we don't use TLS
200 r.URL.Host = c.addr
201 if c.network == "npipe" || c.network == "unix" {
202 // We use a dummy host for non-tcp connections.
203 r.Host = DummyHost
204 }
205
206 if body != nil && r.Header.Get("Content-Type") == "" {
207 r.Header.Set("Content-Type", "text/plain")
208 }
209
210 return r, nil
211}