diff --git a/internal/client/client.go b/internal/client/client.go index 4d7706b267d877cf1c577b3aac45729a91ae7ada..d4577eb836ae4991fb47a9ef37566c71c0d9ecd4 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -4,8 +4,11 @@ import ( "context" "encoding/json" "fmt" + "io" "net" "net/http" + "net/url" + stdpath "path" "path/filepath" "time" @@ -19,11 +22,11 @@ const DummyHost = "api.crush.localhost" // Client represents an RPC client connected to a Crush server. type Client struct { - h *http.Client - id string - path string - proto string - addr string + h *http.Client + id string + path string + network string + addr string } // DefaultClient creates a new [Client] connected to the default server address. @@ -40,7 +43,7 @@ func DefaultClient(path string) (*Client, error) { func NewClient(path, network, address string) (*Client, error) { c := new(Client) c.path = filepath.Clean(path) - c.proto = network + c.network = network c.addr = address p := &http.Protocols{} p.SetHTTP1(true) @@ -48,7 +51,7 @@ func NewClient(path, network, address string) (*Client, error) { tr := http.DefaultTransport.(*http.Transport).Clone() tr.Protocols = p tr.DialContext = c.dialer - if c.proto == "npipe" || c.proto == "unix" { + if c.network == "npipe" || c.network == "unix" { // We don't need compression for local connections. tr.DisableCompression = true } @@ -75,9 +78,9 @@ func (c *Client) Path() string { } // GetGlobalConfig retrieves the server's configuration. -func (c *Client) GetGlobalConfig() (*config.Config, error) { +func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) { var cfg config.Config - rsp, err := c.h.Get("http://localhost/v1/config") + rsp, err := c.get(ctx, "/config", nil, nil) if err != nil { return nil, err } @@ -89,8 +92,8 @@ func (c *Client) GetGlobalConfig() (*config.Config, error) { } // Health checks the server's health status. -func (c *Client) Health() error { - rsp, err := c.h.Get("http://localhost/v1/health") +func (c *Client) Health(ctx context.Context) error { + rsp, err := c.get(ctx, "/health", nil, nil) if err != nil { return err } @@ -102,9 +105,9 @@ func (c *Client) Health() error { } // VersionInfo retrieves the server's version information. -func (c *Client) VersionInfo() (*proto.VersionInfo, error) { +func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) { var vi proto.VersionInfo - rsp, err := c.h.Get("http://localhost/v1/version") + rsp, err := c.get(ctx, "version", nil, nil) if err != nil { return nil, err } @@ -116,14 +119,10 @@ func (c *Client) VersionInfo() (*proto.VersionInfo, error) { } // ShutdownServer sends a shutdown request to the server. -func (c *Client) ShutdownServer() error { - req, err := http.NewRequest("POST", "http://localhost/v1/control", jsonBody(proto.ServerControl{ +func (c *Client) ShutdownServer(ctx context.Context) error { + rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{ Command: "shutdown", - })) - if err != nil { - return err - } - rsp, err := c.h.Do(req) + }), nil) if err != nil { return err } @@ -139,7 +138,10 @@ func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } - switch c.proto { + // It's important to use the client's addr for npipe/unix and not the + // address param because the address param is always "localhost:port" for + // HTTP clients and npipe/unix don't have a concept of ports. + switch c.network { case "npipe": ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() @@ -150,3 +152,71 @@ func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, return d.DialContext(ctx, network, address) } } + +func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodGet, path, query, nil, headers) +} + +func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodPost, path, query, body, headers) +} + +func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodPut, path, query, body, headers) +} + +func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers) +} + +func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + url := (&url.URL{ + Path: stdpath.Join("/v1", path), // Right now, we only have v1 + RawQuery: query.Encode(), + }).String() + req, err := c.buildReq(ctx, method, url, body, headers) + if err != nil { + return nil, err + } + + rsp, err := c.doReq(req) + if err != nil { + return nil, err + } + + // TODO: check server errors in the response body? + + return rsp, nil +} + +func (c *Client) doReq(req *http.Request) (*http.Response, error) { + rsp, err := c.h.Do(req) + if err != nil { + return nil, err + } + return rsp, nil +} + +func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) { + r, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + + for k, v := range headers { + r.Header[http.CanonicalHeaderKey(k)] = v + } + + r.URL.Scheme = "http" // This is always http because we don't use TLS + r.URL.Host = c.addr + if c.network == "npipe" || c.network == "unix" { + // We use a dummy host for non-tcp connections. + r.Host = DummyHost + } + + if body != nil && r.Header.Get("Content-Type") == "" { + r.Header.Set("Content-Type", "text/plain") + } + + return r, nil +} diff --git a/internal/client/proto.go b/internal/client/proto.go index 09ad5d379dbdef985a31bbb0926fca9c81624887..c048ecbfbf1e3575e1cdeb71dc952a7425d95582 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -24,29 +24,22 @@ import ( func (c *Client) SubscribeEvents(ctx context.Context) (<-chan any, error) { events := make(chan any, 100) - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/events", c.id), nil) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/events", c.id), nil, http.Header{ + "Accept": []string{"text/event-stream"}, + "Cache-Control": []string{"no-cache"}, + "Connection": []string{"keep-alive"}, + }) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("failed to subscribe to events: %w", err) } - r.Header.Set("Accept", "text/event-stream") - r.Header.Set("Cache-Control", "no-cache") - r.Header.Set("Connection", "keep-alive") + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to subscribe to events: status code %d", rsp.StatusCode) + } go func() { - rsp, err := c.h.Do(r) - if err != nil { - slog.Error("subscribing to events", "error", err) - return - } - defer rsp.Body.Close() - if rsp.StatusCode != http.StatusOK { - slog.Error("subscribing to events", "status_code", rsp.StatusCode) - return - } - scr := bufio.NewReader(rsp.Body) for { line, err := scr.ReadBytes('\n') @@ -151,11 +144,7 @@ func sendEvent(ctx context.Context, evc chan any, ev any) { } func (c *Client) GetLSPDiagnostics(ctx context.Context, lsp string) (map[protocol.DocumentURI][]protocol.Diagnostic, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/lsps/%s/diagnostics", c.id, lsp), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/lsps/%s/diagnostics", c.id, lsp), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get LSP diagnostics: %w", err) } @@ -171,11 +160,7 @@ func (c *Client) GetLSPDiagnostics(ctx context.Context, lsp string) (map[protoco } func (c *Client) GetLSPs(ctx context.Context) (map[string]app.LSPClientInfo, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/lsps", c.id), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/lsps", c.id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get LSPs: %w", err) } @@ -191,11 +176,7 @@ func (c *Client) GetLSPs(ctx context.Context) (map[string]app.LSPClientInfo, err } func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, sessionID string) (int, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s/prompts/queued", c.id, sessionID), nil) - if err != nil { - return 0, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s/prompts/queued", c.id, sessionID), nil, nil) if err != nil { return 0, fmt.Errorf("failed to get session agent queued prompts: %w", err) } @@ -211,11 +192,7 @@ func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, sessionID str } func (c *Client) ClearAgentSessionQueuedPrompts(ctx context.Context, sessionID string) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s/prompts/clear", c.id, sessionID), nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s/prompts/clear", c.id, sessionID), nil, nil, nil) if err != nil { return fmt.Errorf("failed to clear session agent queued prompts: %w", err) } @@ -227,11 +204,7 @@ func (c *Client) ClearAgentSessionQueuedPrompts(ctx context.Context, sessionID s } func (c *Client) GetAgentInfo(ctx context.Context) (*proto.AgentInfo, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/agent", c.id), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/agent", c.id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get agent status: %w", err) } @@ -247,11 +220,7 @@ func (c *Client) GetAgentInfo(ctx context.Context) (*proto.AgentInfo, error) { } func (c *Client) UpdateAgent(ctx context.Context) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/update", c.id), nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/update", c.id), nil, nil, nil) if err != nil { return fmt.Errorf("failed to update agent: %w", err) } @@ -263,15 +232,11 @@ func (c *Client) UpdateAgent(ctx context.Context) error { } func (c *Client) SendMessage(ctx context.Context, sessionID, message string, attchments ...message.Attachment) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent", c.id), jsonBody(proto.AgentMessage{ + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent", c.id), nil, jsonBody(proto.AgentMessage{ SessionID: sessionID, Prompt: message, Attachments: attchments, - })) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + }), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return fmt.Errorf("failed to send message to agent: %w", err) } @@ -283,11 +248,7 @@ func (c *Client) SendMessage(ctx context.Context, sessionID, message string, att } func (c *Client) GetAgentSessionInfo(ctx context.Context, sessionID string) (*proto.AgentSession, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s", c.id, sessionID), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s", c.id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get session agent info: %w", err) } @@ -303,11 +264,7 @@ func (c *Client) GetAgentSessionInfo(ctx context.Context, sessionID string) (*pr } func (c *Client) AgentSummarizeSession(ctx context.Context, sessionID string) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/sessions/%s/summarize", c.id, sessionID), nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/sessions/%s/summarize", c.id, sessionID), nil, nil, nil) if err != nil { return fmt.Errorf("failed to summarize session: %w", err) } @@ -319,11 +276,7 @@ func (c *Client) AgentSummarizeSession(ctx context.Context, sessionID string) er } func (c *Client) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions/%s/messages", c.id, sessionID), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions/%s/messages", c.id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) } @@ -339,11 +292,7 @@ func (c *Client) ListMessages(ctx context.Context, sessionID string) ([]message. } func (c *Client) GetSession(ctx context.Context, sessionID string) (*session.Session, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions/%s", c.id, sessionID), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions/%s", c.id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get session: %w", err) } @@ -359,11 +308,7 @@ func (c *Client) GetSession(ctx context.Context, sessionID string) (*session.Ses } func (c *Client) InitiateAgentProcessing(ctx context.Context) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/agent/init", c.id), nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/agent/init", c.id), nil, nil, nil) if err != nil { return fmt.Errorf("failed to initiate session agent processing: %w", err) } @@ -375,11 +320,7 @@ func (c *Client) InitiateAgentProcessing(ctx context.Context) error { } func (c *Client) ListSessionHistoryFiles(ctx context.Context, sessionID string) ([]history.File, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions/%s/history", c.id, sessionID), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions/%s/history", c.id, sessionID), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get session history files: %w", err) } @@ -395,12 +336,7 @@ func (c *Client) ListSessionHistoryFiles(ctx context.Context, sessionID string) } func (c *Client) CreateSession(ctx context.Context, title string) (*session.Session, error) { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/sessions", c.id), jsonBody(session.Session{Title: title})) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - r.Header.Set("Content-Type", "application/json") - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/sessions", c.id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } @@ -416,11 +352,7 @@ func (c *Client) CreateSession(ctx context.Context, title string) (*session.Sess } func (c *Client) ListSessions(ctx context.Context) ([]session.Session, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/sessions", c.id), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/sessions", c.id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get sessions: %w", err) } @@ -436,12 +368,7 @@ func (c *Client) ListSessions(ctx context.Context) ([]session.Session, error) { } func (c *Client) GrantPermission(ctx context.Context, req proto.PermissionGrant) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/permissions/grant", c.id), jsonBody(req)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - r.Header.Set("Content-Type", "application/json") - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/permissions/grant", c.id), nil, jsonBody(req), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return fmt.Errorf("failed to grant permission: %w", err) } @@ -453,12 +380,7 @@ func (c *Client) GrantPermission(ctx context.Context, req proto.PermissionGrant) } func (c *Client) SetPermissionsSkipRequests(ctx context.Context, skip bool) error { - r, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://localhost/v1/instances/%s/permissions/skip", c.id), jsonBody(proto.PermissionSkipRequest{Skip: skip})) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - r.Header.Set("Content-Type", "application/json") - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, fmt.Sprintf("/instances/%s/permissions/skip", c.id), nil, jsonBody(proto.PermissionSkipRequest{Skip: skip}), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return fmt.Errorf("failed to set permissions skip requests: %w", err) } @@ -470,11 +392,7 @@ func (c *Client) SetPermissionsSkipRequests(ctx context.Context, skip bool) erro } func (c *Client) GetPermissionsSkipRequests(ctx context.Context) (bool, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/permissions/skip", c.id), nil) - if err != nil { - return false, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/permissions/skip", c.id), nil, nil) if err != nil { return false, fmt.Errorf("failed to get permissions skip requests: %w", err) } @@ -490,11 +408,7 @@ func (c *Client) GetPermissionsSkipRequests(ctx context.Context) (bool, error) { } func (c *Client) GetConfig(ctx context.Context) (*config.Config, error) { - r, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://localhost/v1/instances/%s/config", c.id), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.get(ctx, fmt.Sprintf("/instances/%s/config", c.id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get config: %w", err) } @@ -510,13 +424,7 @@ func (c *Client) GetConfig(ctx context.Context) (*config.Config, error) { } func (c *Client) CreateInstance(ctx context.Context, ins proto.Instance) (*proto.Instance, error) { - r, err := http.NewRequestWithContext(ctx, "POST", "http://localhost/v1/instances", jsonBody(ins)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - r.Header.Set("Content-Type", "application/json") - rsp, err := c.h.Do(r) + rsp, err := c.post(ctx, "instances", nil, jsonBody(ins), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return nil, fmt.Errorf("failed to create instance: %w", err) } @@ -532,11 +440,7 @@ func (c *Client) CreateInstance(ctx context.Context, ins proto.Instance) (*proto } func (c *Client) DeleteInstance(ctx context.Context, id string) error { - r, err := http.NewRequestWithContext(ctx, "DELETE", fmt.Sprintf("http://localhost/v1/instances/%s", id), nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - rsp, err := c.h.Do(r) + rsp, err := c.delete(ctx, fmt.Sprintf("/instances/%s", id), nil, nil) if err != nil { return fmt.Errorf("failed to delete instance: %w", err) } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 59e63632384150ca08bfeb6c73978ce9c3a4e20b..b1c31a116712bbe8afe8a660948fcacf5f1693ca 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -114,7 +114,7 @@ crush -y } for range 10 { - err = c.Health() + err = c.Health(cmd.Context()) if err == nil { break } @@ -227,7 +227,7 @@ func setupApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, error) { c.SetID(ins.ID) - cfg, err := c.GetGlobalConfig() + cfg, err := c.GetGlobalConfig(cmd.Context()) if err != nil { return nil, fmt.Errorf("failed to get global config: %v", err) }