client.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net"
  9	"net/http"
 10	"net/url"
 11	stdpath "path"
 12	"path/filepath"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/proto"
 17	"github.com/charmbracelet/crush/internal/server"
 18)
 19
// DummyHost is used to satisfy the http.Client's requirement for a URL.
// It is sent as the request Host for "npipe" and "unix" connections (see
// buildReq), where the server address is a pipe or socket rather than a
// host name.
const DummyHost = "api.crush.localhost"
 22
// Client represents an RPC client connected to a Crush server.
//
// Requests are sent as unencrypted HTTP through an http.Client whose
// transport dials via the dialer method.
type Client struct {
	// h is the underlying HTTP client; NewClient builds it with no overall
	// timeout.
	h       *http.Client
	// id is the instance identifier exposed through ID and SetID.
	id      string
	// path is the instance filesystem path, cleaned by NewClient.
	path    string
	// network selects how connections are dialed; "npipe" and "unix" are
	// handled specially, anything else uses the transport-supplied
	// network and address.
	network string
	// addr is the server address: the request URL host, and for
	// "npipe"/"unix" the pipe or socket that is dialed.
	addr    string
}
 31
 32// DefaultClient creates a new [Client] connected to the default server address.
 33func DefaultClient(path string) (*Client, error) {
 34	host, err := server.ParseHostURL(server.DefaultHost())
 35	if err != nil {
 36		return nil, err
 37	}
 38	return NewClient(path, host.Scheme, host.Host)
 39}
 40
 41// NewClient creates a new [Client] connected to the server at the given
 42// network and address.
 43func NewClient(path, network, address string) (*Client, error) {
 44	c := new(Client)
 45	c.path = filepath.Clean(path)
 46	c.network = network
 47	c.addr = address
 48	p := &http.Protocols{}
 49	p.SetHTTP1(true)
 50	p.SetUnencryptedHTTP2(true)
 51	tr := http.DefaultTransport.(*http.Transport).Clone()
 52	tr.Protocols = p
 53	tr.DialContext = c.dialer
 54	if c.network == "npipe" || c.network == "unix" {
 55		// We don't need compression for local connections.
 56		tr.DisableCompression = true
 57	}
 58	c.h = &http.Client{
 59		Transport: tr,
 60		Timeout:   0, // we need this to be 0 for long-lived connections and SSE streams
 61	}
 62	return c, nil
 63}
 64
// ID returns the client's instance unique identifier.
// It is the value last stored with SetID, or the empty string if SetID has
// not been called.
func (c *Client) ID() string {
	return c.id
}
 69
// SetID sets the client's instance unique identifier.
// The value is only stored on the client; no request is sent.
func (c *Client) SetID(id string) {
	c.id = id
}
 74
// Path returns the client's instance filesystem path.
// This is the path given to NewClient after filepath.Clean.
func (c *Client) Path() string {
	return c.path
}
 79
 80// GetGlobalConfig retrieves the server's configuration.
 81func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
 82	var cfg config.Config
 83	rsp, err := c.get(ctx, "/config", nil, nil)
 84	if err != nil {
 85		return nil, err
 86	}
 87	defer rsp.Body.Close()
 88	if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
 89		return nil, err
 90	}
 91	return &cfg, nil
 92}
 93
 94// Health checks the server's health status.
 95func (c *Client) Health(ctx context.Context) error {
 96	rsp, err := c.get(ctx, "/health", nil, nil)
 97	if err != nil {
 98		return err
 99	}
100	defer rsp.Body.Close()
101	if rsp.StatusCode != http.StatusOK {
102		return fmt.Errorf("server health check failed: %s", rsp.Status)
103	}
104	return nil
105}
106
107// VersionInfo retrieves the server's version information.
108func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
109	var vi proto.VersionInfo
110	rsp, err := c.get(ctx, "version", nil, nil)
111	if err != nil {
112		return nil, err
113	}
114	defer rsp.Body.Close()
115	if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
116		return nil, err
117	}
118	return &vi, nil
119}
120
121// ShutdownServer sends a shutdown request to the server.
122func (c *Client) ShutdownServer(ctx context.Context) error {
123	rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
124		Command: "shutdown",
125	}), nil)
126	if err != nil {
127		return err
128	}
129	defer rsp.Body.Close()
130	if rsp.StatusCode != http.StatusOK {
131		return fmt.Errorf("server shutdown failed: %s", rsp.Status)
132	}
133	return nil
134}
135
136func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
137	d := net.Dialer{
138		Timeout:   30 * time.Second,
139		KeepAlive: 30 * time.Second,
140	}
141	// It's important to use the client's addr for npipe/unix and not the
142	// address param because the address param is always "localhost:port" for
143	// HTTP clients and npipe/unix don't have a concept of ports.
144	switch c.network {
145	case "npipe":
146		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
147		defer cancel()
148		return dialPipeContext(ctx, c.addr)
149	case "unix":
150		return d.DialContext(ctx, "unix", c.addr)
151	default:
152		return d.DialContext(ctx, network, address)
153	}
154}
155
// get sends a GET request for path via sendReq, without a request body.
// query and headers are forwarded as given; callers in this file pass nil
// for both.
func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
}
159
// post sends a POST request for path via sendReq with the given body.
// query and headers are forwarded as given.
func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
}
163
// put sends a PUT request for path via sendReq with the given body.
// query and headers are forwarded as given.
func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
}
167
// delete sends a DELETE request for path via sendReq, without a request
// body. query and headers are forwarded as given.
func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
	return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
}
171
172func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
173	url := (&url.URL{
174		Path:     stdpath.Join("/v1", path), // Right now, we only have v1
175		RawQuery: query.Encode(),
176	}).String()
177	req, err := c.buildReq(ctx, method, url, body, headers)
178	if err != nil {
179		return nil, err
180	}
181
182	rsp, err := c.doReq(req)
183	if err != nil {
184		return nil, err
185	}
186
187	// TODO: check server errors in the response body?
188
189	return rsp, nil
190}
191
// doReq sends req with the client's underlying http.Client.
// When sending fails it returns a nil response together with the error, so
// callers never receive a response alongside a non-nil error.
func (c *Client) doReq(req *http.Request) (*http.Response, error) {
	rsp, err := c.h.Do(req)
	if err != nil {
		return nil, err
	}
	return rsp, nil
}
199
200func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
201	r, err := http.NewRequestWithContext(ctx, method, url, body)
202	if err != nil {
203		return nil, err
204	}
205
206	for k, v := range headers {
207		r.Header[http.CanonicalHeaderKey(k)] = v
208	}
209
210	r.URL.Scheme = "http" // This is always http because we don't use TLS
211	r.URL.Host = c.addr
212	if c.network == "npipe" || c.network == "unix" {
213		// We use a dummy host for non-tcp connections.
214		r.Host = DummyHost
215	}
216
217	if body != nil && r.Header.Get("Content-Type") == "" {
218		r.Header.Set("Content-Type", "text/plain")
219	}
220
221	return r, nil
222}