@@ -4,8 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"net"
"net/http"
+ "net/url"
+ stdpath "path"
"path/filepath"
"time"
@@ -19,11 +22,11 @@ const DummyHost = "api.crush.localhost"
// Client represents an RPC client connected to a Crush server.
type Client struct {
- h *http.Client
- id string
- path string
- proto string
- addr string
+ h *http.Client
+ id string
+ path string
+ network string
+ addr string
}
// DefaultClient creates a new [Client] connected to the default server address.
@@ -40,7 +43,7 @@ func DefaultClient(path string) (*Client, error) {
func NewClient(path, network, address string) (*Client, error) {
c := new(Client)
c.path = filepath.Clean(path)
- c.proto = network
+ c.network = network
c.addr = address
p := &http.Protocols{}
p.SetHTTP1(true)
@@ -48,7 +51,7 @@ func NewClient(path, network, address string) (*Client, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Protocols = p
tr.DialContext = c.dialer
- if c.proto == "npipe" || c.proto == "unix" {
+ if c.network == "npipe" || c.network == "unix" {
// We don't need compression for local connections.
tr.DisableCompression = true
}
@@ -75,9 +78,9 @@ func (c *Client) Path() string {
}
// GetGlobalConfig retrieves the server's configuration.
-func (c *Client) GetGlobalConfig() (*config.Config, error) {
+func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
var cfg config.Config
- rsp, err := c.h.Get("http://localhost/v1/config")
+ rsp, err := c.get(ctx, "/config", nil, nil)
if err != nil {
return nil, err
}
@@ -89,8 +92,8 @@ func (c *Client) GetGlobalConfig() (*config.Config, error) {
}
// Health checks the server's health status.
-func (c *Client) Health() error {
- rsp, err := c.h.Get("http://localhost/v1/health")
+func (c *Client) Health(ctx context.Context) error {
+ rsp, err := c.get(ctx, "/health", nil, nil)
if err != nil {
return err
}
@@ -102,9 +105,9 @@ func (c *Client) Health() error {
}
// VersionInfo retrieves the server's version information.
-func (c *Client) VersionInfo() (*proto.VersionInfo, error) {
+func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
var vi proto.VersionInfo
- rsp, err := c.h.Get("http://localhost/v1/version")
+ rsp, err := c.get(ctx, "version", nil, nil)
if err != nil {
return nil, err
}
@@ -116,14 +119,10 @@ func (c *Client) VersionInfo() (*proto.VersionInfo, error) {
}
// ShutdownServer sends a shutdown request to the server.
-func (c *Client) ShutdownServer() error {
- req, err := http.NewRequest("POST", "http://localhost/v1/control", jsonBody(proto.ServerControl{
+func (c *Client) ShutdownServer(ctx context.Context) error {
+ rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
Command: "shutdown",
- }))
- if err != nil {
- return err
- }
- rsp, err := c.h.Do(req)
+ }), nil)
if err != nil {
return err
}
@@ -139,7 +138,10 @@ func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn,
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
- switch c.proto {
+ // It's important to use the client's addr for npipe/unix and not the
+ // address param because the address param is always "localhost:port" for
+ // HTTP clients and npipe/unix don't have a concept of ports.
+ switch c.network {
case "npipe":
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
@@ -150,3 +152,71 @@ func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn,
return d.DialContext(ctx, network, address)
}
}
+
+func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
+}
+
+func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
+}
+
+func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodPut, path, query, body, headers)
+}
+
+func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
+}
+
+func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
+ url := (&url.URL{
+ Path: stdpath.Join("/v1", path), // Right now, we only have v1
+ RawQuery: query.Encode(),
+ }).String()
+ req, err := c.buildReq(ctx, method, url, body, headers)
+ if err != nil {
+ return nil, err
+ }
+
+ rsp, err := c.doReq(req)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO: check server errors in the response body?
+
+ return rsp, nil
+}
+
+func (c *Client) doReq(req *http.Request) (*http.Response, error) {
+ rsp, err := c.h.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ return rsp, nil
+}
+
+func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
+ r, err := http.NewRequestWithContext(ctx, method, url, body)
+ if err != nil {
+ return nil, err
+ }
+
+ for k, v := range headers {
+ r.Header[http.CanonicalHeaderKey(k)] = v
+ }
+
+ r.URL.Scheme = "http" // This is always http because we don't use TLS
+ r.URL.Host = c.addr
+ if c.network == "npipe" || c.network == "unix" {
+ // We use a dummy host for non-tcp connections.
+ r.Host = DummyHost
+ }
+
+ if body != nil && r.Header.Get("Content-Type") == "" {
+ r.Header.Set("Content-Type", "text/plain")
+ }
+
+ return r, nil
+}
@@ -24,29 +24,22 @@ import (
func (c *Client) SubscribeEvents(ctx context.Context) (<-chan any, error) {
events := make(chan any, 100)
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/events", c.id), nil)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/events", c.id), nil, http.Header{
+ "Accept": []string{"text/event-stream"},
+ "Cache-Control": []string{"no-cache"},
+ "Connection": []string{"keep-alive"},
+ })
if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
+ return nil, fmt.Errorf("failed to subscribe to events: %w", err)
}
- r.Header.Set("Accept", "text/event-stream")
- r.Header.Set("Cache-Control", "no-cache")
- r.Header.Set("Connection", "keep-alive")
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to subscribe to events: status code %d", rsp.StatusCode)
+ }
go func() {
- rsp, err := c.h.Do(r)
- if err != nil {
- slog.Error("subscribing to events", "error", err)
- return
- }
-
defer rsp.Body.Close()
- if rsp.StatusCode != http.StatusOK {
- slog.Error("subscribing to events", "status_code", rsp.StatusCode)
- return
- }
-
scr := bufio.NewReader(rsp.Body)
for {
line, err := scr.ReadBytes('\n')
@@ -151,11 +144,7 @@ func sendEvent(ctx context.Context, evc chan any, ev any) {
}
func (c *Client) GetLSPDiagnostics(ctx context.Context, lsp string) (map[protocol.DocumentURI][]protocol.Diagnostic, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/lsps/%s/diagnostics", c.id, lsp), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/lsps/%s/diagnostics", c.id, lsp), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get LSP diagnostics: %w", err)
}
@@ -171,11 +160,7 @@ func (c *Client) GetLSPDiagnostics(ctx context.Context, lsp string) (map[protoco
}
func (c *Client) GetLSPs(ctx context.Context) (map[string]app.LSPClientInfo, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/lsps", c.id), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/lsps", c.id), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get LSPs: %w", err)
}
@@ -191,11 +176,7 @@ func (c *Client) GetLSPs(ctx context.Context) (map[string]app.LSPClientInfo, err
}
func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, sessionID string) (int, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s/prompts/queued", c.id, sessionID), nil)
- if err != nil {
- return 0, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s/prompts/queued", c.id, sessionID), nil, nil)
if err != nil {
return 0, fmt.Errorf("failed to get session agent queued prompts: %w", err)
}
@@ -211,11 +192,7 @@ func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, sessionID str
}
func (c *Client) ClearAgentSessionQueuedPrompts(ctx context.Context, sessionID string) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s/prompts/clear", c.id, sessionID), nil)
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s/prompts/clear", c.id, sessionID), nil, nil, nil)
if err != nil {
return fmt.Errorf("failed to clear session agent queued prompts: %w", err)
}
@@ -227,11 +204,7 @@ func (c *Client) ClearAgentSessionQueuedPrompts(ctx context.Context, sessionID s
}
func (c *Client) GetAgentInfo(ctx context.Context) (*proto.AgentInfo, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/agent", c.id), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/agent", c.id), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get agent status: %w", err)
}
@@ -247,11 +220,7 @@ func (c *Client) GetAgentInfo(ctx context.Context) (*proto.AgentInfo, error) {
}
func (c *Client) UpdateAgent(ctx context.Context) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/update", c.id), nil)
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/update", c.id), nil, nil, nil)
if err != nil {
return fmt.Errorf("failed to update agent: %w", err)
}
@@ -263,15 +232,11 @@ func (c *Client) UpdateAgent(ctx context.Context) error {
}
func (c *Client) SendMessage(ctx context.Context, sessionID, message string, attchments ...message.Attachment) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent", c.id), jsonBody(proto.AgentMessage{
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent", c.id), nil, jsonBody(proto.AgentMessage{
SessionID: sessionID,
Prompt: message,
Attachments: attchments,
- }))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ }), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return fmt.Errorf("failed to send message to agent: %w", err)
}
@@ -283,11 +248,7 @@ func (c *Client) SendMessage(ctx context.Context, sessionID, message string, att
}
func (c *Client) GetAgentSessionInfo(ctx context.Context, sessionID string) (*proto.AgentSession, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s", c.id, sessionID), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s", c.id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get session agent info: %w", err)
}
@@ -303,11 +264,7 @@ func (c *Client) GetAgentSessionInfo(ctx context.Context, sessionID string) (*pr
}
func (c *Client) AgentSummarizeSession(ctx context.Context, sessionID string) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s/summarize", c.id, sessionID), nil)
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s/summarize", c.id, sessionID), nil, nil, nil)
if err != nil {
return fmt.Errorf("failed to summarize session: %w", err)
}
@@ -319,11 +276,7 @@ func (c *Client) AgentSummarizeSession(ctx context.Context, sessionID string) er
}
func (c *Client) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions/%s/messages", c.id, sessionID), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions/%s/messages", c.id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get messages: %w", err)
}
@@ -339,11 +292,7 @@ func (c *Client) ListMessages(ctx context.Context, sessionID string) ([]message.
}
func (c *Client) GetSession(ctx context.Context, sessionID string) (*session.Session, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions/%s", c.id, sessionID), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions/%s", c.id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get session: %w", err)
}
@@ -359,11 +308,7 @@ func (c *Client) GetSession(ctx context.Context, sessionID string) (*session.Ses
}
func (c *Client) InitiateAgentProcessing(ctx context.Context) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/init", c.id), nil)
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/init", c.id), nil, nil, nil)
if err != nil {
return fmt.Errorf("failed to initiate session agent processing: %w", err)
}
@@ -375,11 +320,7 @@ func (c *Client) InitiateAgentProcessing(ctx context.Context) error {
}
func (c *Client) ListSessionHistoryFiles(ctx context.Context, sessionID string) ([]history.File, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions/%s/history", c.id, sessionID), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions/%s/history", c.id, sessionID), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get session history files: %w", err)
}
@@ -395,12 +336,7 @@ func (c *Client) ListSessionHistoryFiles(ctx context.Context, sessionID string)
}
func (c *Client) CreateSession(ctx context.Context, title string) (*session.Session, error) {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/sessions", c.id), jsonBody(session.Session{Title: title}))
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- r.Header.Set("Content-Type", "application/json")
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/sessions", c.id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
@@ -416,11 +352,7 @@ func (c *Client) CreateSession(ctx context.Context, title string) (*session.Sess
}
func (c *Client) ListSessions(ctx context.Context) ([]session.Session, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions", c.id), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions", c.id), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get sessions: %w", err)
}
@@ -436,12 +368,7 @@ func (c *Client) ListSessions(ctx context.Context) ([]session.Session, error) {
}
func (c *Client) GrantPermission(ctx context.Context, req proto.PermissionGrant) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/permissions/grant", c.id), jsonBody(req))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- r.Header.Set("Content-Type", "application/json")
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/permissions/grant", c.id), nil, jsonBody(req), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return fmt.Errorf("failed to grant permission: %w", err)
}
@@ -453,12 +380,7 @@ func (c *Client) GrantPermission(ctx context.Context, req proto.PermissionGrant)
}
func (c *Client) SetPermissionsSkipRequests(ctx context.Context, skip bool) error {
- r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/permissions/skip", c.id), jsonBody(proto.PermissionSkipRequest{Skip: skip}))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- r.Header.Set("Content-Type", "application/json")
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/permissions/skip", c.id), nil, jsonBody(proto.PermissionSkipRequest{Skip: skip}), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return fmt.Errorf("failed to set permissions skip requests: %w", err)
}
@@ -470,11 +392,7 @@ func (c *Client) SetPermissionsSkipRequests(ctx context.Context, skip bool) erro
}
func (c *Client) GetPermissionsSkipRequests(ctx context.Context) (bool, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/permissions/skip", c.id), nil)
- if err != nil {
- return false, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/permissions/skip", c.id), nil, nil)
if err != nil {
return false, fmt.Errorf("failed to get permissions skip requests: %w", err)
}
@@ -490,11 +408,7 @@ func (c *Client) GetPermissionsSkipRequests(ctx context.Context) (bool, error) {
}
func (c *Client) GetConfig(ctx context.Context) (*config.Config, error) {
- r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/config", c.id), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/config", c.id), nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get config: %w", err)
}
@@ -510,13 +424,7 @@ func (c *Client) GetConfig(ctx context.Context) (*config.Config, error) {
}
func (c *Client) CreateInstance(ctx context.Context, ins proto.Instance) (*proto.Instance, error) {
- r, err := http.NewRequestWithContext(ctx, "POST", "http://localhost/v1/instances", jsonBody(ins))
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- r.Header.Set("Content-Type", "application/json")
- rsp, err := c.h.Do(r)
+ rsp, err := c.post(ctx, "instances", nil, jsonBody(ins), http.Header{"Content-Type": []string{"application/json"}})
if err != nil {
return nil, fmt.Errorf("failed to create instance: %w", err)
}
@@ -532,11 +440,7 @@ func (c *Client) CreateInstance(ctx context.Context, ins proto.Instance) (*proto
}
func (c *Client) DeleteInstance(ctx context.Context, id string) error {
- r, err := http.NewRequestWithContext(ctx, "DELETE", fmt.Sprintf("http://localhost/v1/instances/%s", id), nil)
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- rsp, err := c.h.Do(r)
+ rsp, err := c.delete(ctx, fmt.Sprintf("/instances/%s", id), nil, nil)
if err != nil {
return fmt.Errorf("failed to delete instance: %w", err)
}