1package client
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net"
9 "net/http"
10 "net/url"
11 stdpath "path"
12 "path/filepath"
13 "time"
14
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/proto"
17 "github.com/charmbracelet/crush/internal/server"
18 "github.com/google/uuid"
19)
20
const (
	// DummyHost is used to satisfy the http.Client's requirement for a URL.
	// It is sent as the Host header when the client talks over a local
	// transport ("npipe" or "unix"), where no real hostname applies.
	DummyHost = "api.crush.localhost"
)
23
// Client represents an RPC client connected to a Crush server.
//
// Construct it with [NewClient] or [DefaultClient]; the zero value has no
// HTTP client and is not usable.
type Client struct {
	// h is the underlying HTTP client whose transport dials via c.dialer.
	h *http.Client
	// path is the cleaned workspace filesystem path.
	path string
	// network selects the dial strategy ("npipe", "unix", or anything else
	// for a regular network dial).
	network string
	// addr is the server address (pipe name, socket path, or host:port).
	addr string
	// clientID is a per-process identifier minted at construction.
	clientID string
}
32
33// DefaultClient creates a new [Client] connected to the default server address.
34func DefaultClient(path string) (*Client, error) {
35 host, err := server.ParseHostURL(server.DefaultHost())
36 if err != nil {
37 return nil, err
38 }
39 return NewClient(path, host.Scheme, host.Host)
40}
41
42// NewClient creates a new [Client] connected to the server at the given
43// network and address.
44func NewClient(path, network, address string) (*Client, error) {
45 c := new(Client)
46 c.path = filepath.Clean(path)
47 c.network = network
48 c.addr = address
49 c.clientID = uuid.New().String()
50 p := &http.Protocols{}
51 p.SetHTTP1(true)
52 p.SetUnencryptedHTTP2(true)
53 tr := http.DefaultTransport.(*http.Transport).Clone()
54 tr.Protocols = p
55 tr.DialContext = c.dialer
56 if c.network == "npipe" || c.network == "unix" {
57 tr.DisableCompression = true
58 }
59 c.h = &http.Client{
60 Transport: tr,
61 Timeout: 0,
62 }
63 return c, nil
64}
65
66// Path returns the client's workspace filesystem path.
67func (c *Client) Path() string {
68 return c.path
69}
70
71// ClientID returns the per-process client ID minted in [NewClient].
72// The server uses it as a presence/coordination handle.
73func (c *Client) ClientID() string {
74 return c.clientID
75}
76
77// GetGlobalConfig retrieves the server's configuration.
78func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
79 var cfg config.Config
80 rsp, err := c.get(ctx, "/config", nil, nil)
81 if err != nil {
82 return nil, err
83 }
84 defer rsp.Body.Close()
85 if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
86 return nil, err
87 }
88 return &cfg, nil
89}
90
91// Health checks the server's health status.
92func (c *Client) Health(ctx context.Context) error {
93 rsp, err := c.get(ctx, "/health", nil, nil)
94 if err != nil {
95 return err
96 }
97 defer rsp.Body.Close()
98 if rsp.StatusCode != http.StatusOK {
99 return fmt.Errorf("server health check failed: %s", rsp.Status)
100 }
101 return nil
102}
103
104// VersionInfo retrieves the server's version information.
105func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
106 var vi proto.VersionInfo
107 rsp, err := c.get(ctx, "version", nil, nil)
108 if err != nil {
109 return nil, err
110 }
111 defer rsp.Body.Close()
112 if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
113 return nil, err
114 }
115 return &vi, nil
116}
117
118// ShutdownServer sends a shutdown request to the server.
119func (c *Client) ShutdownServer(ctx context.Context) error {
120 rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
121 Command: "shutdown",
122 }), nil)
123 if err != nil {
124 return err
125 }
126 defer rsp.Body.Close()
127 if rsp.StatusCode != http.StatusOK {
128 return fmt.Errorf("server shutdown failed: %s", rsp.Status)
129 }
130 return nil
131}
132
133// Dial opens a connection to the server using the same scheme-aware
134// logic the client uses for its HTTP transport. Exposed so callers can
135// reuse the dialer when they need to construct sibling HTTP transports
136// (e.g. a readiness probe in the CLI).
137func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
138 return c.dialer(ctx, network, address)
139}
140
141func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
142 d := net.Dialer{
143 Timeout: 30 * time.Second,
144 KeepAlive: 30 * time.Second,
145 }
146 switch c.network {
147 case "npipe":
148 ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
149 defer cancel()
150 return dialPipeContext(ctx, c.addr)
151 case "unix":
152 return d.DialContext(ctx, "unix", c.addr)
153 default:
154 return d.DialContext(ctx, network, address)
155 }
156}
157
158func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
159 return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
160}
161
162func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
163 return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
164}
165
166func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
167 return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
168}
169
170func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
171 return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
172}
173
174func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
175 url := (&url.URL{
176 Path: stdpath.Join("/v1", path),
177 RawQuery: query.Encode(),
178 }).String()
179 req, err := c.buildReq(ctx, method, url, body, headers)
180 if err != nil {
181 return nil, err
182 }
183
184 rsp, err := c.h.Do(req)
185 if err != nil {
186 return nil, err
187 }
188
189 return rsp, nil
190}
191
192func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
193 r, err := http.NewRequestWithContext(ctx, method, url, body)
194 if err != nil {
195 return nil, err
196 }
197
198 for k, v := range headers {
199 r.Header[http.CanonicalHeaderKey(k)] = v
200 }
201
202 r.URL.Scheme = "http"
203 r.URL.Host = c.addr
204 if c.network == "npipe" || c.network == "unix" {
205 r.Host = DummyHost
206 }
207
208 if body != nil && r.Header.Get("Content-Type") == "" {
209 r.Header.Set("Content-Type", "text/plain")
210 }
211
212 return r, nil
213}