1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net"
  9	"net/http"
 10	"net/url"
 11	stdpath "path"
 12	"path/filepath"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/proto"
 17	"github.com/charmbracelet/crush/internal/server"
 18)
 19
 20// DummyHost is used to satisfy the http.Client's requirement for a URL.
 21const DummyHost = "api.crush.localhost"
 22
 23// Client represents an RPC client connected to a Crush server.
 24type Client struct {
 25	h       *http.Client
 26	path    string
 27	network string
 28	addr    string
 29}
 30
 31// DefaultClient creates a new [Client] connected to the default server address.
 32func DefaultClient(path string) (*Client, error) {
 33	host, err := server.ParseHostURL(server.DefaultHost())
 34	if err != nil {
 35		return nil, err
 36	}
 37	return NewClient(path, host.Scheme, host.Host)
 38}
 39
 40// NewClient creates a new [Client] connected to the server at the given
 41// network and address.
 42func NewClient(path, network, address string) (*Client, error) {
 43	c := new(Client)
 44	c.path = filepath.Clean(path)
 45	c.network = network
 46	c.addr = address
 47	p := &http.Protocols{}
 48	p.SetHTTP1(true)
 49	p.SetUnencryptedHTTP2(true)
 50	tr := http.DefaultTransport.(*http.Transport).Clone()
 51	tr.Protocols = p
 52	tr.DialContext = c.dialer
 53	if c.network == "npipe" || c.network == "unix" {
 54		// We don't need compression for local connections.
 55		tr.DisableCompression = true
 56	}
 57	c.h = &http.Client{
 58		Transport: tr,
 59		Timeout:   0, // we need this to be 0 for long-lived connections and SSE streams
 60	}
 61	return c, nil
 62}
 63
 64// Path returns the client's instance filesystem path.
 65func (c *Client) Path() string {
 66	return c.path
 67}
 68
 69// GetGlobalConfig retrieves the server's configuration.
 70func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
 71	var cfg config.Config
 72	rsp, err := c.get(ctx, "/config", nil, nil)
 73	if err != nil {
 74		return nil, err
 75	}
 76	defer rsp.Body.Close()
 77	if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
 78		return nil, err
 79	}
 80	return &cfg, nil
 81}
 82
 83// Health checks the server's health status.
 84func (c *Client) Health(ctx context.Context) error {
 85	rsp, err := c.get(ctx, "/health", nil, nil)
 86	if err != nil {
 87		return err
 88	}
 89	defer rsp.Body.Close()
 90	if rsp.StatusCode != http.StatusOK {
 91		return fmt.Errorf("server health check failed: %s", rsp.Status)
 92	}
 93	return nil
 94}
 95
 96// VersionInfo retrieves the server's version information.
 97func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
 98	var vi proto.VersionInfo
 99	rsp, err := c.get(ctx, "version", nil, nil)
100	if err != nil {
101		return nil, err
102	}
103	defer rsp.Body.Close()
104	if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
105		return nil, err
106	}
107	return &vi, nil
108}
109
110// ShutdownServer sends a shutdown request to the server.
111func (c *Client) ShutdownServer(ctx context.Context) error {
112	rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
113		Command: "shutdown",
114	}), nil)
115	if err != nil {
116		return err
117	}
118	defer rsp.Body.Close()
119	if rsp.StatusCode != http.StatusOK {
120		return fmt.Errorf("server shutdown failed: %s", rsp.Status)
121	}
122	return nil
123}
124
125func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
126	d := net.Dialer{
127		Timeout:   30 * time.Second,
128		KeepAlive: 30 * time.Second,
129	}
130	// It's important to use the client's addr for npipe/unix and not the
131	// address param because the address param is always "localhost:port" for
132	// HTTP clients and npipe/unix don't have a concept of ports.
133	switch c.network {
134	case "npipe":
135		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
136		defer cancel()
137		return dialPipeContext(ctx, c.addr)
138	case "unix":
139		return d.DialContext(ctx, "unix", c.addr)
140	default:
141		return d.DialContext(ctx, network, address)
142	}
143}
144
145func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
146	return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
147}
148
149func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
150	return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
151}
152
153func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
154	return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
155}
156
157func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
158	return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
159}
160
161func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
162	url := (&url.URL{
163		Path:     stdpath.Join("/v1", path), // Right now, we only have v1
164		RawQuery: query.Encode(),
165	}).String()
166	req, err := c.buildReq(ctx, method, url, body, headers)
167	if err != nil {
168		return nil, err
169	}
170
171	rsp, err := c.doReq(req)
172	if err != nil {
173		return nil, err
174	}
175
176	// TODO: check server errors in the response body?
177
178	return rsp, nil
179}
180
181func (c *Client) doReq(req *http.Request) (*http.Response, error) {
182	rsp, err := c.h.Do(req)
183	if err != nil {
184		return nil, err
185	}
186	return rsp, nil
187}
188
189func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
190	r, err := http.NewRequestWithContext(ctx, method, url, body)
191	if err != nil {
192		return nil, err
193	}
194
195	for k, v := range headers {
196		r.Header[http.CanonicalHeaderKey(k)] = v
197	}
198
199	r.URL.Scheme = "http" // This is always http because we don't use TLS
200	r.URL.Host = c.addr
201	if c.network == "npipe" || c.network == "unix" {
202		// We use a dummy host for non-tcp connections.
203		r.Host = DummyHost
204	}
205
206	if body != nil && r.Header.Get("Content-Type") == "" {
207		r.Header.Set("Content-Type", "text/plain")
208	}
209
210	return r, nil
211}