client.go

  1package client
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"net"
  8	"net/http"
  9	"path/filepath"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/proto"
 14	"github.com/charmbracelet/crush/internal/server"
 15)
 16
 17// DummyHost is used to satisfy the http.Client's requirement for a URL.
 18const DummyHost = "api.crush.localhost"
 19
 20// Client represents an RPC client connected to a Crush server.
 21type Client struct {
 22	h     *http.Client
 23	id    string
 24	path  string
 25	proto string
 26	addr  string
 27}
 28
 29// DefaultClient creates a new [Client] connected to the default server address.
 30func DefaultClient(path string) (*Client, error) {
 31	host, err := server.ParseHostURL(server.DefaultHost())
 32	if err != nil {
 33		return nil, err
 34	}
 35	return NewClient(path, host.Scheme, host.Host)
 36}
 37
 38// NewClient creates a new [Client] connected to the server at the given
 39// network and address.
 40func NewClient(path, network, address string) (*Client, error) {
 41	c := new(Client)
 42	c.path = filepath.Clean(path)
 43	c.proto = network
 44	c.addr = address
 45	p := &http.Protocols{}
 46	p.SetHTTP1(true)
 47	p.SetUnencryptedHTTP2(true)
 48	tr := http.DefaultTransport.(*http.Transport).Clone()
 49	tr.Protocols = p
 50	tr.DialContext = c.dialer
 51	if c.proto == "npipe" || c.proto == "unix" {
 52		// We don't need compression for local connections.
 53		tr.DisableCompression = true
 54	}
 55	c.h = &http.Client{
 56		Transport: tr,
 57		Timeout:   0, // we need this to be 0 for long-lived connections and SSE streams
 58	}
 59	return c, nil
 60}
 61
 62// ID returns the client's instance unique identifier.
 63func (c *Client) ID() string {
 64	return c.id
 65}
 66
 67// SetID sets the client's instance unique identifier.
 68func (c *Client) SetID(id string) {
 69	c.id = id
 70}
 71
 72// Path returns the client's instance filesystem path.
 73func (c *Client) Path() string {
 74	return c.path
 75}
 76
 77// GetGlobalConfig retrieves the server's configuration.
 78func (c *Client) GetGlobalConfig() (*config.Config, error) {
 79	var cfg config.Config
 80	rsp, err := c.h.Get("http://localhost/v1/config")
 81	if err != nil {
 82		return nil, err
 83	}
 84	defer rsp.Body.Close()
 85	if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
 86		return nil, err
 87	}
 88	return &cfg, nil
 89}
 90
 91// Health checks the server's health status.
 92func (c *Client) Health() error {
 93	rsp, err := c.h.Get("http://localhost/v1/health")
 94	if err != nil {
 95		return err
 96	}
 97	defer rsp.Body.Close()
 98	if rsp.StatusCode != http.StatusOK {
 99		return fmt.Errorf("server health check failed: %s", rsp.Status)
100	}
101	return nil
102}
103
104// VersionInfo retrieves the server's version information.
105func (c *Client) VersionInfo() (*proto.VersionInfo, error) {
106	var vi proto.VersionInfo
107	rsp, err := c.h.Get("http://localhost/v1/version")
108	if err != nil {
109		return nil, err
110	}
111	defer rsp.Body.Close()
112	if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
113		return nil, err
114	}
115	return &vi, nil
116}
117
118// ShutdownServer sends a shutdown request to the server.
119func (c *Client) ShutdownServer() error {
120	req, err := http.NewRequest("POST", "http://localhost/v1/control", jsonBody(proto.ServerControl{
121		Command: "shutdown",
122	}))
123	if err != nil {
124		return err
125	}
126	rsp, err := c.h.Do(req)
127	if err != nil {
128		return err
129	}
130	defer rsp.Body.Close()
131	if rsp.StatusCode != http.StatusOK {
132		return fmt.Errorf("server shutdown failed: %s", rsp.Status)
133	}
134	return nil
135}
136
137func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
138	d := net.Dialer{
139		Timeout:   30 * time.Second,
140		KeepAlive: 30 * time.Second,
141	}
142	switch c.proto {
143	case "npipe":
144		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
145		defer cancel()
146		return dialPipeContext(ctx, c.addr)
147	case "unix":
148		return d.DialContext(ctx, "unix", c.addr)
149	default:
150		return d.DialContext(ctx, network, address)
151	}
152}