client.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net"
  9	"net/http"
 10	"net/url"
 11	stdpath "path"
 12	"path/filepath"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/proto"
 17	"github.com/charmbracelet/crush/internal/server"
 18	"github.com/google/uuid"
 19)
 20
// DummyHost is a placeholder host name used to satisfy the http.Client's
// requirement for a URL host. In this file, buildReq assigns it as the
// request Host when the client talks to the server over the "unix" or
// "npipe" network, where the dial address (a socket or pipe path) is not a
// meaningful HTTP host name.
const DummyHost = "api.crush.localhost"
 23
// Client represents an RPC client connected to a Crush server.
//
// Construct one with [NewClient] or [DefaultClient]; those populate every
// field, including the HTTP client whose transport dials the server.
type Client struct {
	// h performs every request; NewClient wires its transport's
	// DialContext to the dialer method.
	h        *http.Client
	// path is the workspace filesystem path, cleaned with filepath.Clean.
	path     string
	// network selects the dial strategy: "unix" and "npipe" are
	// special-cased by dialer and buildReq; any other value dials the
	// network/address handed to the transport.
	network  string
	// addr is the server address. It is used as the request URL host and,
	// for "unix"/"npipe", as the dial target.
	addr     string
	// clientID is a random UUID minted once per Client in NewClient.
	clientID string
}
 32
 33// DefaultClient creates a new [Client] connected to the default server address.
 34func DefaultClient(path string) (*Client, error) {
 35	host, err := server.ParseHostURL(server.DefaultHost())
 36	if err != nil {
 37		return nil, err
 38	}
 39	return NewClient(path, host.Scheme, host.Host)
 40}
 41
 42// NewClient creates a new [Client] connected to the server at the given
 43// network and address.
 44func NewClient(path, network, address string) (*Client, error) {
 45	c := new(Client)
 46	c.path = filepath.Clean(path)
 47	c.network = network
 48	c.addr = address
 49	c.clientID = uuid.New().String()
 50	p := &http.Protocols{}
 51	p.SetHTTP1(true)
 52	p.SetUnencryptedHTTP2(true)
 53	tr := http.DefaultTransport.(*http.Transport).Clone()
 54	tr.Protocols = p
 55	tr.DialContext = c.dialer
 56	if c.network == "npipe" || c.network == "unix" {
 57		tr.DisableCompression = true
 58	}
 59	c.h = &http.Client{
 60		Transport: tr,
 61		Timeout:   0,
 62	}
 63	return c, nil
 64}
 65
// Path returns the client's workspace filesystem path, as cleaned by
// filepath.Clean in [NewClient].
func (c *Client) Path() string {
	return c.path
}
 70
// ClientID returns the random client ID minted once per [Client] in
// [NewClient]; separate Client values get distinct IDs.
// The server uses it as a presence/coordination handle.
func (c *Client) ClientID() string {
	return c.clientID
}
 76
 77// GetGlobalConfig retrieves the server's configuration.
 78func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
 79	var cfg config.Config
 80	rsp, err := c.get(ctx, "/config", nil, nil)
 81	if err != nil {
 82		return nil, err
 83	}
 84	defer rsp.Body.Close()
 85	if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
 86		return nil, err
 87	}
 88	return &cfg, nil
 89}
 90
 91// Health checks the server's health status.
 92func (c *Client) Health(ctx context.Context) error {
 93	rsp, err := c.get(ctx, "/health", nil, nil)
 94	if err != nil {
 95		return err
 96	}
 97	defer rsp.Body.Close()
 98	if rsp.StatusCode != http.StatusOK {
 99		return fmt.Errorf("server health check failed: %s", rsp.Status)
100	}
101	return nil
102}
103
104// VersionInfo retrieves the server's version information.
105func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
106	var vi proto.VersionInfo
107	rsp, err := c.get(ctx, "version", nil, nil)
108	if err != nil {
109		return nil, err
110	}
111	defer rsp.Body.Close()
112	if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
113		return nil, err
114	}
115	return &vi, nil
116}
117
118// ShutdownServer sends a shutdown request to the server.
119func (c *Client) ShutdownServer(ctx context.Context) error {
120	rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
121		Command: "shutdown",
122	}), nil)
123	if err != nil {
124		return err
125	}
126	defer rsp.Body.Close()
127	if rsp.StatusCode != http.StatusOK {
128		return fmt.Errorf("server shutdown failed: %s", rsp.Status)
129	}
130	return nil
131}
132
// Dial opens a connection to the server using the same scheme-aware
// logic the client uses for its HTTP transport. Exposed so callers can
// reuse the dialer when they need to construct sibling HTTP transports
// (e.g. a readiness probe in the CLI).
//
// For clients on the "unix" or "npipe" network, the network and address
// arguments are ignored and the client's configured address is dialed
// instead; otherwise they are passed through to a net.Dialer.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
	return c.dialer(ctx, network, address)
}
140
141func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
142	d := net.Dialer{
143		Timeout:   30 * time.Second,
144		KeepAlive: 30 * time.Second,
145	}
146	switch c.network {
147	case "npipe":
148		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
149		defer cancel()
150		return dialPipeContext(ctx, c.addr)
151	case "unix":
152		return d.DialContext(ctx, "unix", c.addr)
153	default:
154		return d.DialContext(ctx, network, address)
155	}
156}
157
// get sends a GET request for path (joined under the "/v1" prefix by
// sendReq) with the given query and headers. The caller must close the
// returned response body.
func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
}

// post sends a POST request for path with the given query, body and headers.
// The caller must close the returned response body.
func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
}

// delete sends a DELETE request for path with the given query and headers.
// The caller must close the returned response body.
func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
}

// put sends a PUT request for path with the given query, body and headers.
// The caller must close the returned response body.
func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
}
173
174func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
175	url := (&url.URL{
176		Path:     stdpath.Join("/v1", path),
177		RawQuery: query.Encode(),
178	}).String()
179	req, err := c.buildReq(ctx, method, url, body, headers)
180	if err != nil {
181		return nil, err
182	}
183
184	rsp, err := c.h.Do(req)
185	if err != nil {
186		return nil, err
187	}
188
189	return rsp, nil
190}
191
192func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
193	r, err := http.NewRequestWithContext(ctx, method, url, body)
194	if err != nil {
195		return nil, err
196	}
197
198	for k, v := range headers {
199		r.Header[http.CanonicalHeaderKey(k)] = v
200	}
201
202	r.URL.Scheme = "http"
203	r.URL.Host = c.addr
204	if c.network == "npipe" || c.network == "unix" {
205		r.Host = DummyHost
206	}
207
208	if body != nil && r.Header.Get("Content-Type") == "" {
209		r.Header.Set("Content-Type", "text/plain")
210	}
211
212	return r, nil
213}