Detailed changes
@@ -79,6 +79,7 @@ require (
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
+ github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/cascadia v1.3.3 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
@@ -34,6 +34,8 @@ github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5
github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
+github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
+github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk=
github.com/PuerkitoBio/goquery v1.11.0 h1:jZ7pwMQXIITcUXNH83LLk+txlaEy6NVOfTuP43xxfqw=
github.com/PuerkitoBio/goquery v1.11.0/go.mod h1:wQHgxUOU3JGuj3oD/QFfxUdlzW6xPHfqyHre6VMY4DQ=
@@ -143,6 +143,11 @@ func (app *App) Config() *config.Config {
return app.config
}
+// Events returns the events channel for the application.
+func (app *App) Events() <-chan tea.Msg {
+ return app.events
+}
+
// RunNonInteractive runs the application in non-interactive mode with the
// given prompt, printing to stdout.
func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool) error {
@@ -0,0 +1,192 @@
+package client
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ stdpath "path"
+ "path/filepath"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/server"
+)
+
+// DummyHost is used to satisfy the http.Client's requirement for a URL.
+const DummyHost = "api.crush.localhost"
+
+// Client represents an RPC client connected to a Crush server.
+type Client struct {
+ h *http.Client
+ path string
+ network string
+ addr string
+}
+
+// DefaultClient creates a new [Client] connected to the default server address.
+func DefaultClient(path string) (*Client, error) {
+ host, err := server.ParseHostURL(server.DefaultHost())
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(path, host.Scheme, host.Host)
+}
+
+// NewClient creates a new [Client] connected to the server at the given
+// network and address.
+func NewClient(path, network, address string) (*Client, error) {
+ c := new(Client)
+ c.path = filepath.Clean(path)
+ c.network = network
+ c.addr = address
+ p := &http.Protocols{}
+ p.SetHTTP1(true)
+ p.SetUnencryptedHTTP2(true)
+ tr := http.DefaultTransport.(*http.Transport).Clone()
+ tr.Protocols = p
+ tr.DialContext = c.dialer
+ if c.network == "npipe" || c.network == "unix" {
+ tr.DisableCompression = true
+ }
+ c.h = &http.Client{
+ Transport: tr,
+ Timeout: 0,
+ }
+ return c, nil
+}
+
+// Path returns the client's workspace filesystem path.
+func (c *Client) Path() string {
+ return c.path
+}
+
+// GetGlobalConfig retrieves the server's configuration.
+func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
+ var cfg config.Config
+ rsp, err := c.get(ctx, "/config", nil, nil)
+ if err != nil {
+ return nil, err
+ }
+ defer rsp.Body.Close()
+ if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
+ return nil, err
+ }
+ return &cfg, nil
+}
+
+// Health checks the server's health status.
+func (c *Client) Health(ctx context.Context) error {
+ rsp, err := c.get(ctx, "/health", nil, nil)
+ if err != nil {
+ return err
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("server health check failed: %s", rsp.Status)
+ }
+ return nil
+}
+
+// VersionInfo retrieves the server's version information.
+func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) {
+ var vi proto.VersionInfo
+ rsp, err := c.get(ctx, "version", nil, nil)
+ if err != nil {
+ return nil, err
+ }
+ defer rsp.Body.Close()
+ if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil {
+ return nil, err
+ }
+ return &vi, nil
+}
+
+// ShutdownServer sends a shutdown request to the server.
+func (c *Client) ShutdownServer(ctx context.Context) error {
+ rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{
+ Command: "shutdown",
+ }), nil)
+ if err != nil {
+ return err
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("server shutdown failed: %s", rsp.Status)
+ }
+ return nil
+}
+
+func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
+ d := net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+ switch c.network {
+ case "npipe":
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ return dialPipeContext(ctx, c.addr)
+ case "unix":
+ return d.DialContext(ctx, "unix", c.addr)
+ default:
+ return d.DialContext(ctx, network, address)
+ }
+}
+
+func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodGet, path, query, nil, headers)
+}
+
+func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodPost, path, query, body, headers)
+}
+
+func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) {
+ return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers)
+}
+
+func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) {
+ url := (&url.URL{
+ Path: stdpath.Join("/v1", path),
+ RawQuery: query.Encode(),
+ }).String()
+ req, err := c.buildReq(ctx, method, url, body, headers)
+ if err != nil {
+ return nil, err
+ }
+
+ rsp, err := c.h.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ return rsp, nil
+}
+
+func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) {
+ r, err := http.NewRequestWithContext(ctx, method, url, body)
+ if err != nil {
+ return nil, err
+ }
+
+ for k, v := range headers {
+ r.Header[http.CanonicalHeaderKey(k)] = v
+ }
+
+ r.URL.Scheme = "http"
+ r.URL.Host = c.addr
+ if c.network == "npipe" || c.network == "unix" {
+ r.Host = DummyHost
+ }
+
+ if body != nil && r.Header.Get("Content-Type") == "" {
+ r.Header.Set("Content-Type", "text/plain")
+ }
+
+ return r, nil
+}
@@ -0,0 +1,14 @@
+//go:build !windows
+// +build !windows
+
+package client
+
+import (
+ "context"
+ "net"
+ "syscall"
+)
+
+func dialPipeContext(context.Context, string) (net.Conn, error) {
+ return nil, syscall.EAFNOSUPPORT
+}
@@ -0,0 +1,15 @@
+//go:build windows
+// +build windows
+
+package client
+
+import (
+ "context"
+ "net"
+
+ "github.com/Microsoft/go-winio"
+)
+
+func dialPipeContext(ctx context.Context, address string) (net.Conn, error) {
+ return winio.DialPipeContext(ctx, address)
+}
@@ -0,0 +1,490 @@
+package client
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/history"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/session"
+ "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
+)
+
+// CreateWorkspace creates a new workspace on the server.
+func (c *Client) CreateWorkspace(ctx context.Context, ws proto.Workspace) (*proto.Workspace, error) {
+ rsp, err := c.post(ctx, "/workspaces", nil, jsonBody(ws), http.Header{"Content-Type": []string{"application/json"}})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create workspace: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to create workspace: status code %d", rsp.StatusCode)
+ }
+ var created proto.Workspace
+ if err := json.NewDecoder(rsp.Body).Decode(&created); err != nil {
+ return nil, fmt.Errorf("failed to decode workspace: %w", err)
+ }
+ return &created, nil
+}
+
+// DeleteWorkspace deletes a workspace on the server.
+func (c *Client) DeleteWorkspace(ctx context.Context, id string) error {
+ rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), nil, nil)
+ if err != nil {
+ return fmt.Errorf("failed to delete workspace: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to delete workspace: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// SubscribeEvents subscribes to server-sent events for a workspace.
+func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, error) {
+ events := make(chan any, 100)
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/events", id), nil, http.Header{
+ "Accept": []string{"text/event-stream"},
+ "Cache-Control": []string{"no-cache"},
+ "Connection": []string{"keep-alive"},
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to subscribe to events: %w", err)
+ }
+
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to subscribe to events: status code %d", rsp.StatusCode)
+ }
+
+ go func() {
+ defer rsp.Body.Close()
+
+ scr := bufio.NewReader(rsp.Body)
+ for {
+ line, err := scr.ReadBytes('\n')
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ if err != nil {
+ slog.Error("Reading from events stream", "error", err)
+ time.Sleep(time.Second * 2)
+ continue
+ }
+ line = bytes.TrimSpace(line)
+ if len(line) == 0 {
+ continue
+ }
+
+ data, ok := bytes.CutPrefix(line, []byte("data:"))
+ if !ok {
+ slog.Warn("Invalid event format", "line", string(line))
+ continue
+ }
+
+ data = bytes.TrimSpace(data)
+
+ var event pubsub.Event[any]
+ if err := json.Unmarshal(data, &event); err != nil {
+ slog.Error("Unmarshaling event", "error", err)
+ continue
+ }
+
+ type alias pubsub.Event[any]
+ aux := &struct {
+ Payload json.RawMessage `json:"payload"`
+ *alias
+ }{
+ alias: (*alias)(&event),
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ slog.Error("Unmarshaling event payload", "error", err)
+ continue
+ }
+
+ var p pubsub.Payload
+ if err := json.Unmarshal(aux.Payload, &p); err != nil {
+ slog.Error("Unmarshaling event payload", "error", err)
+ continue
+ }
+
+ switch p.Type {
+ case pubsub.PayloadTypeLSPEvent:
+ var e pubsub.Event[proto.LSPEvent]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypeMCPEvent:
+ var e pubsub.Event[proto.MCPEvent]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypePermissionRequest:
+ var e pubsub.Event[proto.PermissionRequest]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypePermissionNotification:
+ var e pubsub.Event[proto.PermissionNotification]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypeMessage:
+ var e pubsub.Event[proto.Message]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypeSession:
+ var e pubsub.Event[proto.Session]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypeFile:
+ var e pubsub.Event[proto.File]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ case pubsub.PayloadTypeAgentEvent:
+ var e pubsub.Event[proto.AgentEvent]
+ _ = json.Unmarshal(data, &e)
+ sendEvent(ctx, events, e)
+ default:
+ slog.Warn("Unknown event type", "type", p.Type)
+ continue
+ }
+ }
+ }()
+
+ return events, nil
+}
+
+func sendEvent(ctx context.Context, evc chan any, ev any) {
+ slog.Info("Event received", "event", fmt.Sprintf("%T %+v", ev, ev))
+ select {
+ case evc <- ev:
+ case <-ctx.Done():
+ close(evc)
+ return
+ }
+}
+
+// GetLSPDiagnostics retrieves LSP diagnostics for a specific LSP client.
+func (c *Client) GetLSPDiagnostics(ctx context.Context, id string, lspName string) (map[protocol.DocumentURI][]protocol.Diagnostic, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/lsps/%s/diagnostics", id, lspName), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get LSP diagnostics: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get LSP diagnostics: status code %d", rsp.StatusCode)
+ }
+ var diagnostics map[protocol.DocumentURI][]protocol.Diagnostic
+ if err := json.NewDecoder(rsp.Body).Decode(&diagnostics); err != nil {
+ return nil, fmt.Errorf("failed to decode LSP diagnostics: %w", err)
+ }
+ return diagnostics, nil
+}
+
+// GetLSPs retrieves the LSP client states for a workspace.
+func (c *Client) GetLSPs(ctx context.Context, id string) (map[string]app.LSPClientInfo, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/lsps", id), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get LSPs: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get LSPs: status code %d", rsp.StatusCode)
+ }
+ var lsps map[string]app.LSPClientInfo
+ if err := json.NewDecoder(rsp.Body).Decode(&lsps); err != nil {
+ return nil, fmt.Errorf("failed to decode LSPs: %w", err)
+ }
+ return lsps, nil
+}
+
+// GetAgentSessionQueuedPrompts retrieves the number of queued prompts for a
+// session.
+func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, id string, sessionID string) (int, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/prompts/queued", id, sessionID), nil, nil)
+ if err != nil {
+ return 0, fmt.Errorf("failed to get session agent queued prompts: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return 0, fmt.Errorf("failed to get session agent queued prompts: status code %d", rsp.StatusCode)
+ }
+ var count int
+ if err := json.NewDecoder(rsp.Body).Decode(&count); err != nil {
+ return 0, fmt.Errorf("failed to decode session agent queued prompts: %w", err)
+ }
+ return count, nil
+}
+
+// ClearAgentSessionQueuedPrompts clears the queued prompts for a session.
+func (c *Client) ClearAgentSessionQueuedPrompts(ctx context.Context, id string, sessionID string) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/prompts/clear", id, sessionID), nil, nil, nil)
+ if err != nil {
+ return fmt.Errorf("failed to clear session agent queued prompts: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to clear session agent queued prompts: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// GetAgentInfo retrieves the agent status for a workspace.
+func (c *Client) GetAgentInfo(ctx context.Context, id string) (*proto.AgentInfo, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent", id), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get agent status: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get agent status: status code %d", rsp.StatusCode)
+ }
+ var info proto.AgentInfo
+ if err := json.NewDecoder(rsp.Body).Decode(&info); err != nil {
+ return nil, fmt.Errorf("failed to decode agent status: %w", err)
+ }
+ return &info, nil
+}
+
+// UpdateAgent triggers an agent model update on the server.
+func (c *Client) UpdateAgent(ctx context.Context, id string) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/update", id), nil, nil, nil)
+ if err != nil {
+ return fmt.Errorf("failed to update agent: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to update agent: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// SendMessage sends a message to the agent for a workspace.
+func (c *Client) SendMessage(ctx context.Context, id string, sessionID, prompt string, attachments ...message.Attachment) error {
+ protoAttachments := make([]proto.Attachment, len(attachments))
+ for i, a := range attachments {
+ protoAttachments[i] = proto.Attachment{
+ FilePath: a.FilePath,
+ FileName: a.FileName,
+ MimeType: a.MimeType,
+ Content: a.Content,
+ }
+ }
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent", id), nil, jsonBody(proto.AgentMessage{
+ SessionID: sessionID,
+ Prompt: prompt,
+ Attachments: protoAttachments,
+ }), http.Header{"Content-Type": []string{"application/json"}})
+ if err != nil {
+ return fmt.Errorf("failed to send message to agent: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// GetAgentSessionInfo retrieves the agent session info for a workspace.
+func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get session agent info: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get session agent info: status code %d", rsp.StatusCode)
+ }
+ var info proto.AgentSession
+ if err := json.NewDecoder(rsp.Body).Decode(&info); err != nil {
+ return nil, fmt.Errorf("failed to decode session agent info: %w", err)
+ }
+ return &info, nil
+}
+
+// AgentSummarizeSession requests a session summarization.
+func (c *Client) AgentSummarizeSession(ctx context.Context, id string, sessionID string) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/summarize", id, sessionID), nil, nil, nil)
+ if err != nil {
+ return fmt.Errorf("failed to summarize session: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to summarize session: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// InitiateAgentProcessing triggers agent initialization on the server.
+func (c *Client) InitiateAgentProcessing(ctx context.Context, id string) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/init", id), nil, nil, nil)
+ if err != nil {
+ return fmt.Errorf("failed to initiate session agent processing: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to initiate session agent processing: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// ListMessages retrieves all messages for a session.
+func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages", id, sessionID), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get messages: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get messages: status code %d", rsp.StatusCode)
+ }
+ var messages []message.Message
+ if err := json.NewDecoder(rsp.Body).Decode(&messages); err != nil {
+ return nil, fmt.Errorf("failed to decode messages: %w", err)
+ }
+ return messages, nil
+}
+
+// GetSession retrieves a specific session.
+func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*session.Session, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sessionID), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get session: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get session: status code %d", rsp.StatusCode)
+ }
+ var sess session.Session
+ if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil {
+ return nil, fmt.Errorf("failed to decode session: %w", err)
+ }
+ return &sess, nil
+}
+
+// ListSessionHistoryFiles retrieves history files for a session.
+func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]history.File, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/history", id, sessionID), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get session history files: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get session history files: status code %d", rsp.StatusCode)
+ }
+ var files []history.File
+ if err := json.NewDecoder(rsp.Body).Decode(&files); err != nil {
+ return nil, fmt.Errorf("failed to decode session history files: %w", err)
+ }
+ return files, nil
+}
+
+// CreateSession creates a new session in a workspace.
+func (c *Client) CreateSession(ctx context.Context, id string, title string) (*session.Session, error) {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create session: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to create session: status code %d", rsp.StatusCode)
+ }
+ var sess session.Session
+ if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil {
+ return nil, fmt.Errorf("failed to decode session: %w", err)
+ }
+ return &sess, nil
+}
+
+// ListSessions lists all sessions in a workspace.
+func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get sessions: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get sessions: status code %d", rsp.StatusCode)
+ }
+ var sessions []session.Session
+ if err := json.NewDecoder(rsp.Body).Decode(&sessions); err != nil {
+ return nil, fmt.Errorf("failed to decode sessions: %w", err)
+ }
+ return sessions, nil
+}
+
+// GrantPermission grants a permission on a workspace.
+func (c *Client) GrantPermission(ctx context.Context, id string, req proto.PermissionGrant) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/permissions/grant", id), nil, jsonBody(req), http.Header{"Content-Type": []string{"application/json"}})
+ if err != nil {
+ return fmt.Errorf("failed to grant permission: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to grant permission: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// SetPermissionsSkipRequests sets the skip-requests flag for a workspace.
+func (c *Client) SetPermissionsSkipRequests(ctx context.Context, id string, skip bool) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/permissions/skip", id), nil, jsonBody(proto.PermissionSkipRequest{Skip: skip}), http.Header{"Content-Type": []string{"application/json"}})
+ if err != nil {
+ return fmt.Errorf("failed to set permissions skip requests: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to set permissions skip requests: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
+// GetPermissionsSkipRequests retrieves the skip-requests flag for a workspace.
+func (c *Client) GetPermissionsSkipRequests(ctx context.Context, id string) (bool, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/permissions/skip", id), nil, nil)
+ if err != nil {
+ return false, fmt.Errorf("failed to get permissions skip requests: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return false, fmt.Errorf("failed to get permissions skip requests: status code %d", rsp.StatusCode)
+ }
+ var skip proto.PermissionSkipRequest
+ if err := json.NewDecoder(rsp.Body).Decode(&skip); err != nil {
+ return false, fmt.Errorf("failed to decode permissions skip requests: %w", err)
+ }
+ return skip.Skip, nil
+}
+
+// GetConfig retrieves the workspace-specific configuration.
+func (c *Client) GetConfig(ctx context.Context, id string) (*config.Config, error) {
+ rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/config", id), nil, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get config: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get config: status code %d", rsp.StatusCode)
+ }
+ var cfg config.Config
+ if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
+ return nil, fmt.Errorf("failed to decode config: %w", err)
+ }
+ return &cfg, nil
+}
+
+func jsonBody(v any) *bytes.Buffer {
+ b := new(bytes.Buffer)
+ m, _ := json.Marshal(v)
+ b.Write(m)
+ return b
+}
@@ -6,20 +6,29 @@ import (
"errors"
"fmt"
"io"
+ "io/fs"
"log/slog"
+ "net/url"
"os"
+ "os/exec"
"path/filepath"
+ "regexp"
"strconv"
"strings"
+ "time"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/charmbracelet/colorprofile"
"github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/client"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/event"
+ "github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/projects"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/server"
"github.com/charmbracelet/crush/internal/ui/common"
ui "github.com/charmbracelet/crush/internal/ui/model"
"github.com/charmbracelet/crush/internal/version"
@@ -31,6 +40,8 @@ import (
"github.com/spf13/cobra"
)
+var clientHost string
+
func init() {
rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
@@ -38,6 +49,8 @@ func init() {
rootCmd.Flags().BoolP("help", "h", false, "Help")
rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
+ rootCmd.Flags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
+
rootCmd.AddCommand(
runCmd,
dirsCmd,
@@ -228,6 +241,167 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
return appInstance, nil
}
+// setupClientApp sets up a client-based workspace via the server. It
+// auto-starts a detached server process if the socket does not exist.
+func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Workspace, error) {
+ debug, _ := cmd.Flags().GetBool("debug")
+ yolo, _ := cmd.Flags().GetBool("yolo")
+ dataDir, _ := cmd.Flags().GetString("data-dir")
+ ctx := cmd.Context()
+
+ cwd, err := ResolveCwd(cmd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ ws, err := c.CreateWorkspace(ctx, proto.Workspace{
+ Path: cwd,
+ DataDir: dataDir,
+ Debug: debug,
+ YOLO: yolo,
+ Env: os.Environ(),
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create workspace: %v", err)
+ }
+
+ cfg, err := c.GetGlobalConfig(cmd.Context())
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to get global config: %v", err)
+ }
+
+ if shouldEnableMetrics(cfg) {
+ event.Init()
+ }
+
+ return c, ws, nil
+}
+
+// ensureServer auto-starts a detached server if the socket file does not
+// exist.
+func ensureServer(cmd *cobra.Command, hostURL *url.URL) error {
+ switch hostURL.Scheme {
+ case "unix", "npipe":
+ _, err := os.Stat(hostURL.Host)
+ if err != nil && errors.Is(err, fs.ErrNotExist) {
+ if err := startDetachedServer(cmd); err != nil {
+ return err
+ }
+ }
+
+ for range 10 {
+ _, err = os.Stat(hostURL.Host)
+ if err == nil {
+ break
+ }
+ select {
+ case <-cmd.Context().Done():
+ return cmd.Context().Err()
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+ if err != nil {
+ return fmt.Errorf("failed to initialize crush server: %v", err)
+ }
+ default:
+ // TCP: assume server is already running.
+ }
+ return nil
+}
+
+// waitForHealth polls the server's health endpoint until it responds.
+func waitForHealth(ctx context.Context, c *client.Client) error {
+ var err error
+ for range 10 {
+ err = c.Health(ctx)
+ if err == nil {
+ return nil
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+ return fmt.Errorf("failed to connect to crush server: %v", err)
+}
+
+// streamEvents forwards SSE events from the client to the TUI program.
+func streamEvents(ctx context.Context, evc <-chan any, p *tea.Program) {
+ defer log.RecoverPanic("app.Subscribe", func() {
+ slog.Info("TUI subscription panic: attempting graceful shutdown")
+ p.Quit()
+ })
+
+ for {
+ select {
+ case <-ctx.Done():
+ slog.Debug("TUI message handler shutting down")
+ return
+ case ev, ok := <-evc:
+ if !ok {
+ slog.Debug("TUI message channel closed")
+ return
+ }
+ p.Send(ev)
+ }
+ }
+}
+
+var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)
+
+func startDetachedServer(cmd *cobra.Command) error {
+ exe, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("failed to get executable path: %v", err)
+ }
+
+ safeClientHost := safeNameRegexp.ReplaceAllString(clientHost, "_")
+ chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeClientHost)
+ if err := os.MkdirAll(chDir, 0o700); err != nil {
+ return fmt.Errorf("failed to create server working directory: %v", err)
+ }
+
+ cmdArgs := []string{"server"}
+ if clientHost != server.DefaultHost() {
+ cmdArgs = append(cmdArgs, "--host", clientHost)
+ }
+
+ c := exec.CommandContext(cmd.Context(), exe, cmdArgs...)
+ stdoutPath := filepath.Join(chDir, "stdout.log")
+ stderrPath := filepath.Join(chDir, "stderr.log")
+ detachProcess(c)
+
+ stdout, err := os.Create(stdoutPath)
+ if err != nil {
+ return fmt.Errorf("failed to create stdout log file: %v", err)
+ }
+ defer stdout.Close()
+ c.Stdout = stdout
+
+ stderr, err := os.Create(stderrPath)
+ if err != nil {
+ return fmt.Errorf("failed to create stderr log file: %v", err)
+ }
+ defer stderr.Close()
+ c.Stderr = stderr
+
+ if err := c.Start(); err != nil {
+ return fmt.Errorf("failed to start crush server: %v", err)
+ }
+
+ if err := c.Process.Release(); err != nil {
+ return fmt.Errorf("failed to detach crush server process: %v", err)
+ }
+
+ return nil
+}
+
func shouldEnableMetrics(cfg *config.Config) bool {
if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
return false
@@ -0,0 +1,16 @@
+//go:build !windows
+// +build !windows
+
+package cmd
+
+import (
+ "os/exec"
+ "syscall"
+)
+
+func detachProcess(c *exec.Cmd) {
+ if c.SysProcAttr == nil {
+ c.SysProcAttr = &syscall.SysProcAttr{}
+ }
+ c.SysProcAttr.Setsid = true
+}
@@ -0,0 +1,18 @@
+//go:build windows
+// +build windows
+
+package cmd
+
+import (
+ "os/exec"
+ "syscall"
+
+ "golang.org/x/sys/windows"
+)
+
+func detachProcess(c *exec.Cmd) {
+ if c.SysProcAttr == nil {
+ c.SysProcAttr = &syscall.SysProcAttr{}
+ }
+ c.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP | windows.DETACHED_PROCESS
+}
@@ -0,0 +1,98 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "os/signal"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/server"
+ "github.com/spf13/cobra"
+)
+
+var serverHost string
+
+func init() {
+ serverCmd.Flags().StringVarP(&serverHost, "host", "H", server.DefaultHost(), "Server host (TCP or Unix socket)")
+ rootCmd.AddCommand(serverCmd)
+}
+
+var serverCmd = &cobra.Command{
+ Use: "server",
+ Short: "Start the Crush server",
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ dataDir, err := cmd.Flags().GetString("data-dir")
+ if err != nil {
+ return fmt.Errorf("failed to get data directory: %v", err)
+ }
+ debug, err := cmd.Flags().GetBool("debug")
+ if err != nil {
+ return fmt.Errorf("failed to get debug flag: %v", err)
+ }
+
+ cfg, err := config.Load("", dataDir, debug)
+ if err != nil {
+ return fmt.Errorf("failed to load configuration: %v", err)
+ }
+
+ handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ })
+ if debug {
+ handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ Level: slog.LevelDebug,
+ })
+ }
+ slog.SetDefault(slog.New(handler))
+
+ hostURL, err := server.ParseHostURL(serverHost)
+ if err != nil {
+ return fmt.Errorf("invalid server host: %v", err)
+ }
+
+ srv := server.NewServer(cfg, hostURL.Scheme, hostURL.Host)
+ srv.SetLogger(slog.Default())
+ slog.Info("Starting Crush server...", "addr", serverHost)
+
+ errch := make(chan error, 1)
+ sigch := make(chan os.Signal, 1)
+ sigs := []os.Signal{os.Interrupt}
+ sigs = append(sigs, addSignals(sigs)...)
+ signal.Notify(sigch, sigs...)
+
+ go func() {
+ errch <- srv.ListenAndServe()
+ }()
+
+ select {
+ case <-sigch:
+ slog.Info("Received interrupt signal...")
+ case err = <-errch:
+ if err != nil && !errors.Is(err, server.ErrServerClosed) {
+ _ = srv.Close()
+ slog.Error("Server error", "error", err)
+ return fmt.Errorf("server error: %v", err)
+ }
+ }
+
+ if errors.Is(err, server.ErrServerClosed) {
+ return nil
+ }
+
+ ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
+ defer cancel()
+
+ slog.Info("Shutting down...")
+
+ if err := srv.Shutdown(ctx); err != nil {
+ slog.Error("Failed to shutdown server", "error", err)
+ return fmt.Errorf("failed to shutdown server: %v", err)
+ }
+
+ return nil
+ },
+}
@@ -0,0 +1,13 @@
+//go:build !windows
+// +build !windows
+
+package cmd
+
+import (
+ "os"
+ "syscall"
+)
+
+func addSignals(sigs []os.Signal) []os.Signal {
+ return append(sigs, syscall.SIGTERM)
+}
@@ -0,0 +1,10 @@
+//go:build windows
+// +build windows
+
+package cmd
+
+import "os"
+
+func addSignals(sigs []os.Signal) []os.Signal {
+ return sigs
+}
@@ -721,6 +721,25 @@ func GlobalConfig() string {
return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName))
}
+// GlobalCacheDir returns the path to the global cache directory for the
+// application.
+func GlobalCacheDir() string {
+ if crushCache := os.Getenv("CRUSH_CACHE_DIR"); crushCache != "" {
+ return crushCache
+ }
+ if xdgCacheHome := os.Getenv("XDG_CACHE_HOME"); xdgCacheHome != "" {
+ return filepath.Join(xdgCacheHome, appName)
+ }
+ if runtime.GOOS == "windows" {
+ localAppData := cmp.Or(
+ os.Getenv("LOCALAPPDATA"),
+ filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"),
+ )
+ return filepath.Join(localAppData, appName, "cache")
+ }
+ return filepath.Join(home.Dir(), ".cache", appName)
+}
+
// GlobalConfigData returns the path to the main data directory for the application.
// this config is used when the app overrides configurations instead of updating the global config.
func GlobalConfigData() string {
@@ -0,0 +1,74 @@
+package proto
+
+import (
+ "encoding/json"
+ "errors"
+)
+
+// AgentEventType represents the type of agent event.
+type AgentEventType string
+
+const (
+ AgentEventTypeError AgentEventType = "error"
+ AgentEventTypeResponse AgentEventType = "response"
+ AgentEventTypeSummarize AgentEventType = "summarize"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (t AgentEventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (t *AgentEventType) UnmarshalText(text []byte) error {
+ *t = AgentEventType(text)
+ return nil
+}
+
+// AgentEvent represents an event emitted by the agent.
+type AgentEvent struct {
+ Type AgentEventType `json:"type"`
+ Message Message `json:"message"`
+ Error error `json:"error,omitempty"`
+
+ // When summarizing.
+ SessionID string `json:"session_id,omitempty"`
+ Progress string `json:"progress,omitempty"`
+ Done bool `json:"done,omitempty"`
+}
+
+// MarshalJSON implements the [json.Marshaler] interface.
+func (e AgentEvent) MarshalJSON() ([]byte, error) {
+ type Alias AgentEvent
+ return json.Marshal(&struct {
+ Error string `json:"error,omitempty"`
+ Alias
+ }{
+ Error: func() string {
+ if e.Error != nil {
+ return e.Error.Error()
+ }
+ return ""
+ }(),
+ Alias: (Alias)(e),
+ })
+}
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface.
+func (e *AgentEvent) UnmarshalJSON(data []byte) error {
+ type Alias AgentEvent
+ aux := &struct {
+ Error string `json:"error,omitempty"`
+ Alias
+ }{
+ Alias: (Alias)(*e),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ *e = AgentEvent(aux.Alias)
+ if aux.Error != "" {
+ e.Error = errors.New(aux.Error)
+ }
+ return nil
+}
@@ -0,0 +1,12 @@
+package proto
+
+// File represents a file tracked in session history.
+type File struct {
+ ID string `json:"id"`
+ SessionID string `json:"session_id"`
+ Path string `json:"path"`
+ Content string `json:"content"`
+ Version int64 `json:"version"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
@@ -0,0 +1,78 @@
+package proto
+
+import "fmt"
+
+// MCPState represents the current state of an MCP client.
+type MCPState int
+
+const (
+ MCPStateDisabled MCPState = iota
+ MCPStateStarting
+ MCPStateConnected
+ MCPStateError
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (s MCPState) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (s *MCPState) UnmarshalText(data []byte) error {
+ switch string(data) {
+ case "disabled":
+ *s = MCPStateDisabled
+ case "starting":
+ *s = MCPStateStarting
+ case "connected":
+ *s = MCPStateConnected
+ case "error":
+ *s = MCPStateError
+ default:
+ return fmt.Errorf("unknown mcp state: %s", data)
+ }
+ return nil
+}
+
+// String returns the string representation of the MCPState.
+func (s MCPState) String() string {
+ switch s {
+ case MCPStateDisabled:
+ return "disabled"
+ case MCPStateStarting:
+ return "starting"
+ case MCPStateConnected:
+ return "connected"
+ case MCPStateError:
+ return "error"
+ default:
+ return "unknown"
+ }
+}
+
+// MCPEventType represents the type of MCP event.
+type MCPEventType string
+
+const (
+ MCPEventStateChanged MCPEventType = "state_changed"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (t MCPEventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (t *MCPEventType) UnmarshalText(data []byte) error {
+ *t = MCPEventType(data)
+ return nil
+}
+
+// MCPEvent represents an event in the MCP system.
+type MCPEvent struct {
+ Type MCPEventType `json:"type"`
+ Name string `json:"name"`
+ State MCPState `json:"state"`
+ Error error `json:"error,omitempty"`
+ ToolCount int `json:"tool_count,omitempty"`
+}
@@ -0,0 +1,653 @@
+package proto
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "slices"
+ "time"
+
+ "charm.land/catwalk/pkg/catwalk"
+)
+
+// CreateMessageParams represents parameters for creating a message.
+type CreateMessageParams struct {
+ Role MessageRole `json:"role"`
+ Parts []ContentPart `json:"parts"`
+ Model string `json:"model"`
+ Provider string `json:"provider,omitempty"`
+}
+
+// Message represents a message in the proto layer.
+type Message struct {
+ ID string `json:"id"`
+ Role MessageRole `json:"role"`
+ SessionID string `json:"session_id"`
+ Parts []ContentPart `json:"parts"`
+ Model string `json:"model"`
+ Provider string `json:"provider"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+// MessageRole represents the role of a message sender.
+type MessageRole string
+
+const (
+ Assistant MessageRole = "assistant"
+ User MessageRole = "user"
+ System MessageRole = "system"
+ Tool MessageRole = "tool"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (r MessageRole) MarshalText() ([]byte, error) {
+ return []byte(r), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (r *MessageRole) UnmarshalText(data []byte) error {
+ *r = MessageRole(data)
+ return nil
+}
+
+// FinishReason represents why a message generation finished.
+type FinishReason string
+
+const (
+ FinishReasonEndTurn FinishReason = "end_turn"
+ FinishReasonMaxTokens FinishReason = "max_tokens"
+ FinishReasonToolUse FinishReason = "tool_use"
+ FinishReasonCanceled FinishReason = "canceled"
+ FinishReasonError FinishReason = "error"
+ FinishReasonPermissionDenied FinishReason = "permission_denied"
+ FinishReasonUnknown FinishReason = "unknown"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (fr FinishReason) MarshalText() ([]byte, error) {
+ return []byte(fr), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (fr *FinishReason) UnmarshalText(data []byte) error {
+ *fr = FinishReason(data)
+ return nil
+}
+
+// ContentPart is a part of a message's content.
+type ContentPart interface {
+ isPart()
+}
+
+// ReasoningContent represents the reasoning/thinking part of a message.
+type ReasoningContent struct {
+ Thinking string `json:"thinking"`
+ Signature string `json:"signature"`
+ StartedAt int64 `json:"started_at,omitempty"`
+ FinishedAt int64 `json:"finished_at,omitempty"`
+}
+
+// String returns the thinking content as a string.
+func (tc ReasoningContent) String() string {
+ return tc.Thinking
+}
+
+func (ReasoningContent) isPart() {}
+
+// TextContent represents a text part of a message.
+type TextContent struct {
+ Text string `json:"text"`
+}
+
+// String returns the text content as a string.
+func (tc TextContent) String() string {
+ return tc.Text
+}
+
+func (TextContent) isPart() {}
+
+// ImageURLContent represents an image URL part of a message.
+type ImageURLContent struct {
+ URL string `json:"url"`
+ Detail string `json:"detail,omitempty"`
+}
+
+// String returns the image URL as a string.
+func (iuc ImageURLContent) String() string {
+ return iuc.URL
+}
+
+func (ImageURLContent) isPart() {}
+
+// BinaryContent represents binary data in a message.
+type BinaryContent struct {
+ Path string
+ MIMEType string
+ Data []byte
+}
+
+// String returns a base64-encoded string of the binary data.
+func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
+ base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
+ if p == catwalk.InferenceProviderOpenAI {
+ return "data:" + bc.MIMEType + ";base64," + base64Encoded
+ }
+ return base64Encoded
+}
+
+func (BinaryContent) isPart() {}
+
+// ToolCall represents a tool call in a message.
+type ToolCall struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Type string `json:"type,omitempty"`
+ Finished bool `json:"finished,omitempty"`
+}
+
+func (ToolCall) isPart() {}
+
+// ToolResult represents the result of a tool call.
+type ToolResult struct {
+ ToolCallID string `json:"tool_call_id"`
+ Name string `json:"name"`
+ Content string `json:"content"`
+ Metadata string `json:"metadata"`
+ IsError bool `json:"is_error"`
+}
+
+func (ToolResult) isPart() {}
+
+// Finish represents the end of a message generation.
+type Finish struct {
+ Reason FinishReason `json:"reason"`
+ Time int64 `json:"time"`
+ Message string `json:"message,omitempty"`
+ Details string `json:"details,omitempty"`
+}
+
+func (Finish) isPart() {}
+
+// MarshalJSON implements the [json.Marshaler] interface.
+func (m Message) MarshalJSON() ([]byte, error) {
+ parts, err := MarshalParts(m.Parts)
+ if err != nil {
+ return nil, err
+ }
+
+ type Alias Message
+ return json.Marshal(&struct {
+ Parts json.RawMessage `json:"parts"`
+ *Alias
+ }{
+ Parts: json.RawMessage(parts),
+ Alias: (*Alias)(&m),
+ })
+}
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface.
+func (m *Message) UnmarshalJSON(data []byte) error {
+ type Alias Message
+ aux := &struct {
+ Parts json.RawMessage `json:"parts"`
+ *Alias
+ }{
+ Alias: (*Alias)(m),
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ parts, err := UnmarshalParts([]byte(aux.Parts))
+ if err != nil {
+ return err
+ }
+
+ m.Parts = parts
+ return nil
+}
+
+// Content returns the first text content part.
+func (m *Message) Content() TextContent {
+ for _, part := range m.Parts {
+ if c, ok := part.(TextContent); ok {
+ return c
+ }
+ }
+ return TextContent{}
+}
+
+// ReasoningContent returns the first reasoning content part.
+func (m *Message) ReasoningContent() ReasoningContent {
+ for _, part := range m.Parts {
+ if c, ok := part.(ReasoningContent); ok {
+ return c
+ }
+ }
+ return ReasoningContent{}
+}
+
+// ImageURLContent returns all image URL content parts.
+func (m *Message) ImageURLContent() []ImageURLContent {
+ imageURLContents := make([]ImageURLContent, 0)
+ for _, part := range m.Parts {
+ if c, ok := part.(ImageURLContent); ok {
+ imageURLContents = append(imageURLContents, c)
+ }
+ }
+ return imageURLContents
+}
+
+// BinaryContent returns all binary content parts.
+func (m *Message) BinaryContent() []BinaryContent {
+ binaryContents := make([]BinaryContent, 0)
+ for _, part := range m.Parts {
+ if c, ok := part.(BinaryContent); ok {
+ binaryContents = append(binaryContents, c)
+ }
+ }
+ return binaryContents
+}
+
+// ToolCalls returns all tool call parts.
+func (m *Message) ToolCalls() []ToolCall {
+ toolCalls := make([]ToolCall, 0)
+ for _, part := range m.Parts {
+ if c, ok := part.(ToolCall); ok {
+ toolCalls = append(toolCalls, c)
+ }
+ }
+ return toolCalls
+}
+
+// ToolResults returns all tool result parts.
+func (m *Message) ToolResults() []ToolResult {
+ toolResults := make([]ToolResult, 0)
+ for _, part := range m.Parts {
+ if c, ok := part.(ToolResult); ok {
+ toolResults = append(toolResults, c)
+ }
+ }
+ return toolResults
+}
+
+// IsFinished returns true if the message has a finish part.
+func (m *Message) IsFinished() bool {
+ for _, part := range m.Parts {
+ if _, ok := part.(Finish); ok {
+ return true
+ }
+ }
+ return false
+}
+
+// FinishPart returns the finish part if present.
+func (m *Message) FinishPart() *Finish {
+ for _, part := range m.Parts {
+ if c, ok := part.(Finish); ok {
+ return &c
+ }
+ }
+ return nil
+}
+
+// FinishReason returns the finish reason if present.
+func (m *Message) FinishReason() FinishReason {
+ for _, part := range m.Parts {
+ if c, ok := part.(Finish); ok {
+ return c.Reason
+ }
+ }
+ return ""
+}
+
+// IsThinking returns true if the message is currently in a thinking state.
+func (m *Message) IsThinking() bool {
+ return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
+}
+
+// AppendContent appends text to the text content part.
+func (m *Message) AppendContent(delta string) {
+ found := false
+ for i, part := range m.Parts {
+ if c, ok := part.(TextContent); ok {
+ m.Parts[i] = TextContent{Text: c.Text + delta}
+ found = true
+ }
+ }
+ if !found {
+ m.Parts = append(m.Parts, TextContent{Text: delta})
+ }
+}
+
+// AppendReasoningContent appends text to the reasoning content part.
+func (m *Message) AppendReasoningContent(delta string) {
+ found := false
+ for i, part := range m.Parts {
+ if c, ok := part.(ReasoningContent); ok {
+ m.Parts[i] = ReasoningContent{
+ Thinking: c.Thinking + delta,
+ Signature: c.Signature,
+ StartedAt: c.StartedAt,
+ FinishedAt: c.FinishedAt,
+ }
+ found = true
+ }
+ }
+ if !found {
+ m.Parts = append(m.Parts, ReasoningContent{
+ Thinking: delta,
+ StartedAt: time.Now().Unix(),
+ })
+ }
+}
+
+// AppendReasoningSignature appends a signature to the reasoning content part.
+func (m *Message) AppendReasoningSignature(signature string) {
+ for i, part := range m.Parts {
+ if c, ok := part.(ReasoningContent); ok {
+ m.Parts[i] = ReasoningContent{
+ Thinking: c.Thinking,
+ Signature: c.Signature + signature,
+ StartedAt: c.StartedAt,
+ FinishedAt: c.FinishedAt,
+ }
+ return
+ }
+ }
+ m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
+}
+
+// FinishThinking marks the reasoning content as finished.
+func (m *Message) FinishThinking() {
+ for i, part := range m.Parts {
+ if c, ok := part.(ReasoningContent); ok {
+ if c.FinishedAt == 0 {
+ m.Parts[i] = ReasoningContent{
+ Thinking: c.Thinking,
+ Signature: c.Signature,
+ StartedAt: c.StartedAt,
+ FinishedAt: time.Now().Unix(),
+ }
+ }
+ return
+ }
+ }
+}
+
+// ThinkingDuration returns the duration of the thinking phase.
+func (m *Message) ThinkingDuration() time.Duration {
+ reasoning := m.ReasoningContent()
+ if reasoning.StartedAt == 0 {
+ return 0
+ }
+
+ endTime := reasoning.FinishedAt
+ if endTime == 0 {
+ endTime = time.Now().Unix()
+ }
+
+ return time.Duration(endTime-reasoning.StartedAt) * time.Second
+}
+
+// FinishToolCall marks a tool call as finished.
+func (m *Message) FinishToolCall(toolCallID string) {
+ for i, part := range m.Parts {
+ if c, ok := part.(ToolCall); ok {
+ if c.ID == toolCallID {
+ m.Parts[i] = ToolCall{
+ ID: c.ID,
+ Name: c.Name,
+ Input: c.Input,
+ Type: c.Type,
+ Finished: true,
+ }
+ return
+ }
+ }
+ }
+}
+
+// AppendToolCallInput appends input to a tool call.
+func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
+ for i, part := range m.Parts {
+ if c, ok := part.(ToolCall); ok {
+ if c.ID == toolCallID {
+ m.Parts[i] = ToolCall{
+ ID: c.ID,
+ Name: c.Name,
+ Input: c.Input + inputDelta,
+ Type: c.Type,
+ Finished: c.Finished,
+ }
+ return
+ }
+ }
+ }
+}
+
+// AddToolCall adds or updates a tool call.
+func (m *Message) AddToolCall(tc ToolCall) {
+ for i, part := range m.Parts {
+ if c, ok := part.(ToolCall); ok {
+ if c.ID == tc.ID {
+ m.Parts[i] = tc
+ return
+ }
+ }
+ }
+ m.Parts = append(m.Parts, tc)
+}
+
+// SetToolCalls replaces all tool call parts.
+func (m *Message) SetToolCalls(tc []ToolCall) {
+ parts := make([]ContentPart, 0)
+ for _, part := range m.Parts {
+ if _, ok := part.(ToolCall); ok {
+ continue
+ }
+ parts = append(parts, part)
+ }
+ m.Parts = parts
+ for _, toolCall := range tc {
+ m.Parts = append(m.Parts, toolCall)
+ }
+}
+
+// AddToolResult adds a tool result.
+func (m *Message) AddToolResult(tr ToolResult) {
+ m.Parts = append(m.Parts, tr)
+}
+
+// SetToolResults adds multiple tool results.
+func (m *Message) SetToolResults(tr []ToolResult) {
+ for _, toolResult := range tr {
+ m.Parts = append(m.Parts, toolResult)
+ }
+}
+
+// AddFinish adds a finish part to the message.
+func (m *Message) AddFinish(reason FinishReason, message, details string) {
+ for i, part := range m.Parts {
+ if _, ok := part.(Finish); ok {
+ m.Parts = slices.Delete(m.Parts, i, i+1)
+ break
+ }
+ }
+ m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
+}
+
+// AddImageURL adds an image URL part to the message.
+func (m *Message) AddImageURL(url, detail string) {
+ m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
+}
+
+// AddBinary adds a binary content part to the message.
+func (m *Message) AddBinary(mimeType string, data []byte) {
+ m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
+}
+
+type partType string
+
+const (
+ reasoningType partType = "reasoning"
+ textType partType = "text"
+ imageURLType partType = "image_url"
+ binaryType partType = "binary"
+ toolCallType partType = "tool_call"
+ toolResultType partType = "tool_result"
+ finishType partType = "finish"
+)
+
+type partWrapper struct {
+ Type partType `json:"type"`
+ Data ContentPart `json:"data"`
+}
+
+// MarshalParts marshals content parts to JSON.
+func MarshalParts(parts []ContentPart) ([]byte, error) {
+ wrappedParts := make([]partWrapper, len(parts))
+
+ for i, part := range parts {
+ var typ partType
+
+ switch part.(type) {
+ case ReasoningContent:
+ typ = reasoningType
+ case TextContent:
+ typ = textType
+ case ImageURLContent:
+ typ = imageURLType
+ case BinaryContent:
+ typ = binaryType
+ case ToolCall:
+ typ = toolCallType
+ case ToolResult:
+ typ = toolResultType
+ case Finish:
+ typ = finishType
+ default:
+ return nil, fmt.Errorf("unknown part type: %T", part)
+ }
+
+ wrappedParts[i] = partWrapper{
+ Type: typ,
+ Data: part,
+ }
+ }
+ return json.Marshal(wrappedParts)
+}
+
+// UnmarshalParts unmarshals content parts from JSON.
+func UnmarshalParts(data []byte) ([]ContentPart, error) {
+ temp := []json.RawMessage{}
+
+ if err := json.Unmarshal(data, &temp); err != nil {
+ return nil, err
+ }
+
+ parts := make([]ContentPart, 0)
+
+ for _, rawPart := range temp {
+ var wrapper struct {
+ Type partType `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }
+
+ if err := json.Unmarshal(rawPart, &wrapper); err != nil {
+ return nil, err
+ }
+
+ switch wrapper.Type {
+ case reasoningType:
+ part := ReasoningContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case textType:
+ part := TextContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case imageURLType:
+ part := ImageURLContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case binaryType:
+ part := BinaryContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case toolCallType:
+ part := ToolCall{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case toolResultType:
+ part := ToolResult{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case finishType:
+ part := Finish{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ default:
+ return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
+ }
+ }
+
+ return parts, nil
+}
+
+// Attachment represents a file attachment.
+type Attachment struct {
+ FilePath string `json:"file_path"`
+ FileName string `json:"file_name"`
+ MimeType string `json:"mime_type"`
+ Content []byte `json:"content"`
+}
+
+// MarshalJSON implements the [json.Marshaler] interface.
+func (a Attachment) MarshalJSON() ([]byte, error) {
+ type Alias Attachment
+ return json.Marshal(&struct {
+ Content string `json:"content"`
+ *Alias
+ }{
+ Content: base64.StdEncoding.EncodeToString(a.Content),
+ Alias: (*Alias)(&a),
+ })
+}
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface.
+func (a *Attachment) UnmarshalJSON(data []byte) error {
+ type Alias Attachment
+ aux := &struct {
+ Content string `json:"content"`
+ *Alias
+ }{
+ Alias: (*Alias)(a),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ content, err := base64.StdEncoding.DecodeString(aux.Content)
+ if err != nil {
+ return err
+ }
+ a.Content = content
+ return nil
+}
@@ -0,0 +1,141 @@
+package proto
+
+import (
+ "encoding/json"
+)
+
+// CreatePermissionRequest represents a request to create a permission.
+type CreatePermissionRequest struct {
+ SessionID string `json:"session_id"`
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Description string `json:"description"`
+ Action string `json:"action"`
+ Params any `json:"params"`
+ Path string `json:"path"`
+}
+
+// PermissionNotification represents a notification about a permission change.
+type PermissionNotification struct {
+ ToolCallID string `json:"tool_call_id"`
+ Granted bool `json:"granted"`
+ Denied bool `json:"denied"`
+}
+
+// PermissionRequest represents a pending permission request.
+type PermissionRequest struct {
+ ID string `json:"id"`
+ SessionID string `json:"session_id"`
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Description string `json:"description"`
+ Action string `json:"action"`
+ Params any `json:"params"`
+ Path string `json:"path"`
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface. This is needed
+// because the Params field is of type any, so we need to unmarshal it into
+// its appropriate type based on the [PermissionRequest.ToolName].
+func (p *PermissionRequest) UnmarshalJSON(data []byte) error {
+ type Alias PermissionRequest
+ aux := &struct {
+ Params json.RawMessage `json:"params"`
+ *Alias
+ }{
+ Alias: (*Alias)(p),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ params, err := unmarshalToolParams(p.ToolName, aux.Params)
+ if err != nil {
+ return err
+ }
+ p.Params = params
+ return nil
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface. This is needed
+// because the Params field is of type any, so we need to unmarshal it into
+// its appropriate type based on the [CreatePermissionRequest.ToolName].
+func (p *CreatePermissionRequest) UnmarshalJSON(data []byte) error {
+ type Alias CreatePermissionRequest
+ aux := &struct {
+ Params json.RawMessage `json:"params"`
+ *Alias
+ }{
+ Alias: (*Alias)(p),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ params, err := unmarshalToolParams(p.ToolName, aux.Params)
+ if err != nil {
+ return err
+ }
+ p.Params = params
+ return nil
+}
+
+func unmarshalToolParams(toolName string, raw json.RawMessage) (any, error) {
+ switch toolName {
+ case BashToolName:
+ var params BashPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case DownloadToolName:
+ var params DownloadPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case EditToolName:
+ var params EditPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case WriteToolName:
+ var params WritePermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case MultiEditToolName:
+ var params MultiEditPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case FetchToolName:
+ var params FetchPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case ViewToolName:
+ var params ViewPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ case LSToolName:
+ var params LSPermissionsParams
+ if err := json.Unmarshal(raw, ¶ms); err != nil {
+ return nil, err
+ }
+ return params, nil
+ default:
+ // For unknown tools, keep the raw JSON as-is.
+ var generic map[string]any
+ if err := json.Unmarshal(raw, &generic); err != nil {
+ return nil, err
+ }
+ return generic, nil
+ }
+}
@@ -0,0 +1,123 @@
+package proto
+
+import (
+ "time"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/lsp"
+)
+
+// Workspace represents a running app.App workspace with its associated
+// resources and state.
+type Workspace struct {
+ ID string `json:"id"`
+ Path string `json:"path"`
+ YOLO bool `json:"yolo,omitempty"`
+ Debug bool `json:"debug,omitempty"`
+ DataDir string `json:"data_dir,omitempty"`
+ Config *config.Config `json:"config,omitempty"`
+ Env []string `json:"env,omitempty"`
+}
+
+// Error represents an error response.
+type Error struct {
+ Message string `json:"message"`
+}
+
+// AgentInfo represents information about the agent.
+type AgentInfo struct {
+ IsBusy bool `json:"is_busy"`
+ Model catwalk.Model `json:"model"`
+}
+
+// IsZero checks if the AgentInfo is zero-valued.
+func (a AgentInfo) IsZero() bool {
+ return !a.IsBusy && a.Model.ID == ""
+}
+
+// AgentMessage represents a message sent to the agent.
+type AgentMessage struct {
+ SessionID string `json:"session_id"`
+ Prompt string `json:"prompt"`
+ Attachments []Attachment `json:"attachments,omitempty"`
+}
+
+// AgentSession represents a session with its busy status.
+type AgentSession struct {
+ Session
+ IsBusy bool `json:"is_busy"`
+}
+
+// IsZero checks if the AgentSession is zero-valued.
+func (a AgentSession) IsZero() bool {
+ return a == AgentSession{}
+}
+
+// PermissionAction represents an action taken on a permission request.
+type PermissionAction string
+
+const (
+ PermissionAllow PermissionAction = "allow"
+ PermissionAllowForSession PermissionAction = "allow_session"
+ PermissionDeny PermissionAction = "deny"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (p PermissionAction) MarshalText() ([]byte, error) {
+ return []byte(p), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (p *PermissionAction) UnmarshalText(text []byte) error {
+ *p = PermissionAction(text)
+ return nil
+}
+
+// PermissionGrant represents a permission grant request.
+type PermissionGrant struct {
+ Permission PermissionRequest `json:"permission"`
+ Action PermissionAction `json:"action"`
+}
+
+// PermissionSkipRequest represents a request to skip permission prompts.
+type PermissionSkipRequest struct {
+ Skip bool `json:"skip"`
+}
+
+// LSPEventType represents the type of LSP event.
+type LSPEventType string
+
+const (
+ LSPEventStateChanged LSPEventType = "state_changed"
+ LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (e LSPEventType) MarshalText() ([]byte, error) {
+ return []byte(e), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (e *LSPEventType) UnmarshalText(data []byte) error {
+ *e = LSPEventType(data)
+ return nil
+}
+
+// LSPEvent represents an event in the LSP system.
+type LSPEvent struct {
+ Type LSPEventType `json:"type"`
+ Name string `json:"name"`
+ State lsp.ServerState `json:"state"`
+ Error error `json:"error,omitempty"`
+ DiagnosticCount int `json:"diagnostic_count,omitempty"`
+}
+
+// LSPClientInfo holds information about an LSP client's state.
+type LSPClientInfo struct {
+ Name string `json:"name"`
+ State lsp.ServerState `json:"state"`
+ Error error `json:"error,omitempty"`
+ DiagnosticCount int `json:"diagnostic_count,omitempty"`
+ ConnectedAt time.Time `json:"connected_at"`
+}
@@ -0,0 +1,6 @@
+package proto
+
+// ServerControl represents a server control request.
+type ServerControl struct {
+ Command string `json:"command"`
+}
@@ -0,0 +1,15 @@
+package proto
+
+// Session represents a session in the proto layer.
+type Session struct {
+ ID string `json:"id"`
+ ParentSessionID string `json:"parent_session_id"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ SummaryMessageID string `json:"summary_message_id"`
+ Cost float64 `json:"cost"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
@@ -0,0 +1,250 @@
+package proto
+
+// ToolResponseType represents the type of tool response.
+type ToolResponseType string
+
+const (
+ ToolResponseTypeText ToolResponseType = "text"
+ ToolResponseTypeImage ToolResponseType = "image"
+)
+
+// ToolResponse represents a response from a tool.
+type ToolResponse struct {
+ Type ToolResponseType `json:"type"`
+ Content string `json:"content"`
+ Metadata string `json:"metadata,omitempty"`
+ IsError bool `json:"is_error"`
+}
+
+const BashToolName = "bash"
+
+// BashParams represents the parameters for the bash tool.
+type BashParams struct {
+ Command string `json:"command"`
+ Timeout int `json:"timeout"`
+}
+
+// BashPermissionsParams represents the permission parameters for the bash tool.
+type BashPermissionsParams struct {
+ Command string `json:"command"`
+ Timeout int `json:"timeout"`
+}
+
+// BashResponseMetadata represents the metadata for a bash tool response.
+type BashResponseMetadata struct {
+ StartTime int64 `json:"start_time"`
+ EndTime int64 `json:"end_time"`
+ Output string `json:"output"`
+ WorkingDirectory string `json:"working_directory"`
+}
+
+// DiagnosticsParams represents the parameters for the diagnostics tool.
+type DiagnosticsParams struct {
+ FilePath string `json:"file_path"`
+}
+
+const DownloadToolName = "download"
+
+// DownloadParams represents the parameters for the download tool.
+type DownloadParams struct {
+ URL string `json:"url"`
+ FilePath string `json:"file_path"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+// DownloadPermissionsParams represents the permission parameters for the download tool.
+type DownloadPermissionsParams struct {
+ URL string `json:"url"`
+ FilePath string `json:"file_path"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+const EditToolName = "edit"
+
+// EditParams represents the parameters for the edit tool.
+type EditParams struct {
+ FilePath string `json:"file_path"`
+ OldString string `json:"old_string"`
+ NewString string `json:"new_string"`
+ ReplaceAll bool `json:"replace_all,omitempty"`
+}
+
+// EditPermissionsParams represents the permission parameters for the edit tool.
+type EditPermissionsParams struct {
+ FilePath string `json:"file_path"`
+ OldContent string `json:"old_content,omitempty"`
+ NewContent string `json:"new_content,omitempty"`
+}
+
+// EditResponseMetadata represents the metadata for an edit tool response.
+type EditResponseMetadata struct {
+ Additions int `json:"additions"`
+ Removals int `json:"removals"`
+ OldContent string `json:"old_content,omitempty"`
+ NewContent string `json:"new_content,omitempty"`
+}
+
+const FetchToolName = "fetch"
+
+// FetchParams represents the parameters for the fetch tool.
+type FetchParams struct {
+ URL string `json:"url"`
+ Format string `json:"format"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+// FetchPermissionsParams represents the permission parameters for the fetch tool.
+type FetchPermissionsParams struct {
+ URL string `json:"url"`
+ Format string `json:"format"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+const GlobToolName = "glob"
+
+// GlobParams represents the parameters for the glob tool.
+type GlobParams struct {
+ Pattern string `json:"pattern"`
+ Path string `json:"path"`
+}
+
+// GlobResponseMetadata represents the metadata for a glob tool response.
+type GlobResponseMetadata struct {
+ NumberOfFiles int `json:"number_of_files"`
+ Truncated bool `json:"truncated"`
+}
+
+const GrepToolName = "grep"
+
+// GrepParams represents the parameters for the grep tool.
+type GrepParams struct {
+ Pattern string `json:"pattern"`
+ Path string `json:"path"`
+ Include string `json:"include"`
+ LiteralText bool `json:"literal_text"`
+}
+
+// GrepResponseMetadata represents the metadata for a grep tool response.
+type GrepResponseMetadata struct {
+ NumberOfMatches int `json:"number_of_matches"`
+ Truncated bool `json:"truncated"`
+}
+
+const LSToolName = "ls"
+
+// LSParams represents the parameters for the ls tool.
+type LSParams struct {
+ Path string `json:"path"`
+ Ignore []string `json:"ignore"`
+}
+
+// LSPermissionsParams represents the permission parameters for the ls tool.
+type LSPermissionsParams struct {
+ Path string `json:"path"`
+ Ignore []string `json:"ignore"`
+}
+
+// TreeNode represents a node in a directory tree.
+type TreeNode struct {
+ Name string `json:"name"`
+ Path string `json:"path"`
+ Type string `json:"type"`
+ Children []*TreeNode `json:"children,omitempty"`
+}
+
+// LSResponseMetadata represents the metadata for an ls tool response.
+type LSResponseMetadata struct {
+ NumberOfFiles int `json:"number_of_files"`
+ Truncated bool `json:"truncated"`
+}
+
+const MultiEditToolName = "multiedit"
+
+// MultiEditOperation represents a single edit operation in a multi-edit.
+type MultiEditOperation struct {
+ OldString string `json:"old_string"`
+ NewString string `json:"new_string"`
+ ReplaceAll bool `json:"replace_all,omitempty"`
+}
+
+// MultiEditParams represents the parameters for the multi-edit tool.
+type MultiEditParams struct {
+ FilePath string `json:"file_path"`
+ Edits []MultiEditOperation `json:"edits"`
+}
+
+// MultiEditPermissionsParams represents the permission parameters for the multi-edit tool.
+type MultiEditPermissionsParams struct {
+ FilePath string `json:"file_path"`
+ OldContent string `json:"old_content,omitempty"`
+ NewContent string `json:"new_content,omitempty"`
+}
+
+// MultiEditResponseMetadata represents the metadata for a multi-edit tool response.
+type MultiEditResponseMetadata struct {
+ Additions int `json:"additions"`
+ Removals int `json:"removals"`
+ OldContent string `json:"old_content,omitempty"`
+ NewContent string `json:"new_content,omitempty"`
+ EditsApplied int `json:"edits_applied"`
+}
+
+const SourcegraphToolName = "sourcegraph"
+
+// SourcegraphParams represents the parameters for the sourcegraph tool.
+type SourcegraphParams struct {
+ Query string `json:"query"`
+ Count int `json:"count,omitempty"`
+ ContextWindow int `json:"context_window,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+// SourcegraphResponseMetadata represents the metadata for a sourcegraph tool response.
+type SourcegraphResponseMetadata struct {
+ NumberOfMatches int `json:"number_of_matches"`
+ Truncated bool `json:"truncated"`
+}
+
+const ViewToolName = "view"
+
+// ViewParams represents the parameters for the view tool.
+type ViewParams struct {
+ FilePath string `json:"file_path"`
+ Offset int `json:"offset"`
+ Limit int `json:"limit"`
+}
+
+// ViewPermissionsParams represents the permission parameters for the view tool.
+type ViewPermissionsParams struct {
+ FilePath string `json:"file_path"`
+ Offset int `json:"offset"`
+ Limit int `json:"limit"`
+}
+
+// ViewResponseMetadata represents the metadata for a view tool response.
+type ViewResponseMetadata struct {
+ FilePath string `json:"file_path"`
+ Content string `json:"content"`
+}
+
+const WriteToolName = "write"
+
+// WriteParams represents the parameters for the write tool.
+type WriteParams struct {
+ FilePath string `json:"file_path"`
+ Content string `json:"content"`
+}
+
+// WritePermissionsParams represents the permission parameters for the write tool.
+type WritePermissionsParams struct {
+ FilePath string `json:"file_path"`
+ OldContent string `json:"old_content,omitempty"`
+ NewContent string `json:"new_content,omitempty"`
+}
+
+// WriteResponseMetadata represents the metadata for a write tool response.
+type WriteResponseMetadata struct {
+ Diff string `json:"diff"`
+ Additions int `json:"additions"`
+ Removals int `json:"removals"`
+}
@@ -0,0 +1,9 @@
+package proto
+
+// VersionInfo represents version information about the server.
+type VersionInfo struct {
+ Version string `json:"version"`
+ Commit string `json:"commit"`
+ GoVersion string `json:"go_version"`
+ Platform string `json:"platform"`
+}
@@ -1,6 +1,9 @@
package pubsub
-import "context"
+import (
+ "context"
+ "encoding/json"
+)
const (
CreatedEvent EventType = "created"
@@ -8,20 +11,43 @@ const (
DeletedEvent EventType = "deleted"
)
+// PayloadType identifies the type of event payload for discriminated
+// deserialization over JSON.
+type PayloadType = string
+
+const (
+ PayloadTypeLSPEvent PayloadType = "lsp_event"
+ PayloadTypeMCPEvent PayloadType = "mcp_event"
+ PayloadTypePermissionRequest PayloadType = "permission_request"
+ PayloadTypePermissionNotification PayloadType = "permission_notification"
+ PayloadTypeMessage PayloadType = "message"
+ PayloadTypeSession PayloadType = "session"
+ PayloadTypeFile PayloadType = "file"
+ PayloadTypeAgentEvent PayloadType = "agent_event"
+)
+
+// Payload wraps a discriminated JSON payload with a type tag.
+type Payload struct {
+ Type PayloadType `json:"type"`
+ Payload json.RawMessage `json:"payload"`
+}
+
+// Subscriber can subscribe to events of type T.
type Subscriber[T any] interface {
Subscribe(context.Context) <-chan Event[T]
}
type (
- // EventType identifies the type of event
+ // EventType identifies the type of event.
EventType string
- // Event represents an event in the lifecycle of a resource
+ // Event represents an event in the lifecycle of a resource.
Event[T any] struct {
- Type EventType
- Payload T
+ Type EventType `json:"type"`
+ Payload T `json:"payload"`
}
+ // Publisher can publish events of type T.
Publisher[T any] interface {
Publish(EventType, T)
}
@@ -0,0 +1,51 @@
+package server
+
+import (
+ "log/slog"
+ "net/http"
+ "time"
+)
+
+func (s *Server) loggingHandler(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if s.logger == nil {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ start := time.Now()
+ lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
+ s.logger.Debug("HTTP request",
+ slog.String("method", r.Method),
+ slog.String("path", r.URL.Path),
+ slog.String("remote_addr", r.RemoteAddr),
+ slog.String("user_agent", r.UserAgent()),
+ )
+
+ next.ServeHTTP(lrw, r)
+ duration := time.Since(start)
+
+ s.logger.Debug("HTTP response",
+ slog.String("method", r.Method),
+ slog.String("path", r.URL.Path),
+ slog.Int("status", lrw.statusCode),
+ slog.Duration("duration", duration),
+ slog.String("remote_addr", r.RemoteAddr),
+ slog.String("user_agent", r.UserAgent()),
+ )
+ })
+}
+
+type loggingResponseWriter struct {
+ http.ResponseWriter
+ statusCode int
+}
+
+func (lrw *loggingResponseWriter) WriteHeader(code int) {
+ lrw.statusCode = code
+ lrw.ResponseWriter.WriteHeader(code)
+}
+
+func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
+ return lrw.ResponseWriter
+}
@@ -0,0 +1,10 @@
+//go:build !windows
+// +build !windows
+
+package server
+
+import "net"
+
+func listen(network, address string) (net.Listener, error) {
+ return net.Listen(network, address)
+}
@@ -0,0 +1,24 @@
+//go:build windows
+// +build windows
+
+package server
+
+import (
+ "net"
+
+ "github.com/Microsoft/go-winio"
+)
+
+func listen(network, address string) (net.Listener, error) {
+ switch network {
+ case "npipe":
+ cfg := &winio.PipeConfig{
+ MessageMode: true,
+ InputBufferSize: 65536,
+ OutputBufferSize: 65536,
+ }
+ return winio.ListenPipe(address, cfg)
+ default:
+ return net.Listen(network, address)
+ }
+}
@@ -0,0 +1,656 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+ "path/filepath"
+ "runtime"
+
+ "github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/session"
+ "github.com/charmbracelet/crush/internal/version"
+ "github.com/google/uuid"
+)
+
+type controllerV1 struct {
+ *Server
+}
+
+func (c *controllerV1) handleGetHealth(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+}
+
+func (c *controllerV1) handleGetVersion(w http.ResponseWriter, _ *http.Request) {
+ jsonEncode(w, proto.VersionInfo{
+ Version: version.Version,
+ Commit: version.Commit,
+ GoVersion: runtime.Version(),
+ Platform: fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH),
+ })
+}
+
+func (c *controllerV1) handlePostControl(w http.ResponseWriter, r *http.Request) {
+ var req proto.ServerControl
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ c.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ switch req.Command {
+ case "shutdown":
+ go func() {
+ slog.Info("Shutting down server...")
+ if err := c.Shutdown(context.Background()); err != nil {
+ slog.Error("Failed to shutdown server", "error", err)
+ }
+ }()
+ default:
+ c.logError(r, "Unknown command", "command", req.Command)
+ jsonError(w, http.StatusBadRequest, "unknown command")
+ return
+ }
+}
+
+func (c *controllerV1) handleGetConfig(w http.ResponseWriter, _ *http.Request) {
+ jsonEncode(w, c.cfg)
+}
+
+func (c *controllerV1) handleGetWorkspaces(w http.ResponseWriter, _ *http.Request) {
+ workspaces := []proto.Workspace{}
+ for _, ws := range c.workspaces.Seq2() {
+ workspaces = append(workspaces, proto.Workspace{
+ ID: ws.id,
+ Path: ws.path,
+ YOLO: ws.cfg.Permissions != nil && ws.cfg.Permissions.SkipRequests,
+ DataDir: ws.cfg.Options.DataDirectory,
+ Debug: ws.cfg.Options.Debug,
+ Config: ws.cfg,
+ })
+ }
+ jsonEncode(w, workspaces)
+}
+
+func (c *controllerV1) handleGetWorkspaceLSPDiagnostics(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ lspName := r.PathValue("lsp")
+ var found bool
+ for name, client := range ws.LSPManager.Clients().Seq2() {
+ if name == lspName {
+ diagnostics := client.GetDiagnostics()
+ jsonEncode(w, diagnostics)
+ found = true
+ break
+ }
+ }
+
+ if !found {
+ c.logError(r, "LSP client not found", "id", id, "lsp", lspName)
+ jsonError(w, http.StatusNotFound, "LSP client not found")
+ }
+}
+
+func (c *controllerV1) handleGetWorkspaceLSPs(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ _, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ lspClients := app.GetLSPStates()
+ jsonEncode(w, lspClients)
+}
+
+func (c *controllerV1) handleGetWorkspaceAgentSessionPromptQueued(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ queued := ws.App.AgentCoordinator.QueuedPrompts(sid)
+ jsonEncode(w, queued)
+}
+
+func (c *controllerV1) handlePostWorkspaceAgentSessionPromptClear(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ ws.App.AgentCoordinator.ClearQueue(sid)
+ w.WriteHeader(http.StatusOK)
+}
+
+func (c *controllerV1) handleGetWorkspaceAgentSessionSummarize(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ if err := ws.App.AgentCoordinator.Summarize(r.Context(), sid); err != nil {
+ c.logError(r, "Failed to summarize session", "error", err, "id", id, "sid", sid)
+ jsonError(w, http.StatusInternalServerError, "failed to summarize session")
+ return
+ }
+}
+
+func (c *controllerV1) handlePostWorkspaceAgentSessionCancel(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ if ws.App.AgentCoordinator != nil {
+ ws.App.AgentCoordinator.Cancel(sid)
+ }
+ w.WriteHeader(http.StatusOK)
+}
+
+func (c *controllerV1) handleGetWorkspaceAgentSession(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ se, err := ws.App.Sessions.Get(r.Context(), sid)
+ if err != nil {
+ c.logError(r, "Failed to get session", "error", err, "id", id, "sid", sid)
+ jsonError(w, http.StatusInternalServerError, "failed to get session")
+ return
+ }
+
+ var isSessionBusy bool
+ if ws.App.AgentCoordinator != nil {
+ isSessionBusy = ws.App.AgentCoordinator.IsSessionBusy(sid)
+ }
+
+ jsonEncode(w, proto.AgentSession{
+ Session: proto.Session{
+ ID: se.ID,
+ Title: se.Title,
+ },
+ IsBusy: isSessionBusy,
+ })
+}
+
+func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ w.Header().Set("Accept", "application/json")
+
+ var msg proto.AgentMessage
+ if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
+ c.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ if ws.App.AgentCoordinator == nil {
+ c.logError(r, "Agent coordinator not initialized", "id", id)
+ jsonError(w, http.StatusBadRequest, "agent coordinator not initialized")
+ return
+ }
+
+ if _, err := ws.App.AgentCoordinator.Run(c.ctx, msg.SessionID, msg.Prompt); err != nil {
+ c.logError(r, "Failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID)
+ jsonError(w, http.StatusInternalServerError, "failed to enqueue message")
+ return
+ }
+}
+
+func (c *controllerV1) handleGetWorkspaceAgent(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ var agentInfo proto.AgentInfo
+ if ws.App.AgentCoordinator != nil {
+ m := ws.App.AgentCoordinator.Model()
+ agentInfo = proto.AgentInfo{
+ Model: m.CatwalkCfg,
+ IsBusy: ws.App.AgentCoordinator.IsBusy(),
+ }
+ }
+ jsonEncode(w, agentInfo)
+}
+
+func (c *controllerV1) handlePostWorkspaceAgentUpdate(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ if err := ws.App.UpdateAgentModel(r.Context()); err != nil {
+ c.logError(r, "Failed to update agent model", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to update agent model")
+ return
+ }
+}
+
+func (c *controllerV1) handlePostWorkspaceAgentInit(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ if err := ws.App.InitCoderAgent(r.Context()); err != nil {
+ c.logError(r, "Failed to initialize coder agent", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to initialize coder agent")
+ return
+ }
+}
+
+func (c *controllerV1) handleGetWorkspaceSessionHistory(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ historyItems, err := ws.App.History.ListBySession(r.Context(), sid)
+ if err != nil {
+ c.logError(r, "Failed to list history", "error", err, "id", id, "sid", sid)
+ jsonError(w, http.StatusInternalServerError, "failed to list history")
+ return
+ }
+
+ jsonEncode(w, historyItems)
+}
+
+func (c *controllerV1) handleGetWorkspaceSessionMessages(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ messages, err := ws.App.Messages.List(r.Context(), sid)
+ if err != nil {
+ c.logError(r, "Failed to list messages", "error", err, "id", id, "sid", sid)
+ jsonError(w, http.StatusInternalServerError, "failed to list messages")
+ return
+ }
+
+ jsonEncode(w, messages)
+}
+
+func (c *controllerV1) handleGetWorkspaceSession(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sid := r.PathValue("sid")
+ sess, err := ws.App.Sessions.Get(r.Context(), sid)
+ if err != nil {
+ c.logError(r, "Failed to get session", "error", err, "id", id, "sid", sid)
+ jsonError(w, http.StatusInternalServerError, "failed to get session")
+ return
+ }
+
+ jsonEncode(w, sess)
+}
+
+func (c *controllerV1) handlePostWorkspaceSessions(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ var args session.Session
+ if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
+ c.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ sess, err := ws.App.Sessions.Create(r.Context(), args.Title)
+ if err != nil {
+ c.logError(r, "Failed to create session", "error", err, "id", id)
+ jsonError(w, http.StatusInternalServerError, "failed to create session")
+ return
+ }
+
+ jsonEncode(w, sess)
+}
+
+func (c *controllerV1) handleGetWorkspaceSessions(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ sessions, err := ws.App.Sessions.List(r.Context())
+ if err != nil {
+ c.logError(r, "Failed to list sessions", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to list sessions")
+ return
+ }
+
+ jsonEncode(w, sessions)
+}
+
+func (c *controllerV1) handlePostWorkspacePermissionsGrant(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ var req proto.PermissionGrant
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ c.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ perm := permission.PermissionRequest{
+ ID: req.Permission.ID,
+ SessionID: req.Permission.SessionID,
+ ToolCallID: req.Permission.ToolCallID,
+ ToolName: req.Permission.ToolName,
+ Description: req.Permission.Description,
+ Action: req.Permission.Action,
+ Params: req.Permission.Params,
+ Path: req.Permission.Path,
+ }
+
+ switch req.Action {
+ case proto.PermissionAllow:
+ ws.App.Permissions.Grant(perm)
+ case proto.PermissionAllowForSession:
+ ws.App.Permissions.GrantPersistent(perm)
+ case proto.PermissionDeny:
+ ws.App.Permissions.Deny(perm)
+ default:
+ c.logError(r, "Invalid permission action", "action", req.Action)
+ jsonError(w, http.StatusBadRequest, "invalid permission action")
+ return
+ }
+}
+
+func (c *controllerV1) handlePostWorkspacePermissionsSkip(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ var req proto.PermissionSkipRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ c.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ ws.App.Permissions.SetSkipRequests(req.Skip)
+}
+
+func (c *controllerV1) handleGetWorkspacePermissionsSkip(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ skip := ws.App.Permissions.SkipRequests()
+ jsonEncode(w, proto.PermissionSkipRequest{Skip: skip})
+}
+
+func (c *controllerV1) handleGetWorkspaceProviders(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ providers, _ := config.Providers(ws.cfg)
+ jsonEncode(w, providers)
+}
+
+func (c *controllerV1) handleGetWorkspaceEvents(w http.ResponseWriter, r *http.Request) {
+ flusher := http.NewResponseController(w)
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+
+ events := ws.App.Events()
+
+ for {
+ select {
+ case <-r.Context().Done():
+ c.logDebug(r, "Stopping event stream")
+ return
+ case ev, ok := <-events:
+ if !ok {
+ return
+ }
+ c.logDebug(r, "Sending event", "event", fmt.Sprintf("%T %+v", ev, ev))
+ data, err := json.Marshal(ev)
+ if err != nil {
+ c.logError(r, "Failed to marshal event", "error", err)
+ continue
+ }
+
+ fmt.Fprintf(w, "data: %s\n\n", data)
+ flusher.Flush()
+ }
+ }
+}
+
+func (c *controllerV1) handleGetWorkspaceConfig(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ jsonEncode(w, ws.cfg)
+}
+
+func (c *controllerV1) handleDeleteWorkspaces(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if ok {
+ ws.App.Shutdown()
+ }
+ c.workspaces.Del(id)
+}
+
+func (c *controllerV1) handleGetWorkspace(w http.ResponseWriter, r *http.Request) {
+ id := r.PathValue("id")
+ ws, ok := c.workspaces.Get(id)
+ if !ok {
+ c.logError(r, "Workspace not found", "id", id)
+ jsonError(w, http.StatusNotFound, "workspace not found")
+ return
+ }
+
+ jsonEncode(w, proto.Workspace{
+ ID: ws.id,
+ Path: ws.path,
+ YOLO: ws.cfg.Permissions != nil && ws.cfg.Permissions.SkipRequests,
+ DataDir: ws.cfg.Options.DataDirectory,
+ Debug: ws.cfg.Options.Debug,
+ Config: ws.cfg,
+ })
+}
+
+func (c *controllerV1) handlePostWorkspaces(w http.ResponseWriter, r *http.Request) {
+ var args proto.Workspace
+ if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
+ c.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ if args.Path == "" {
+ c.logError(r, "Path is required")
+ jsonError(w, http.StatusBadRequest, "path is required")
+ return
+ }
+
+ id := uuid.New().String()
+ cfg, err := config.Init(args.Path, args.DataDir, args.Debug)
+ if err != nil {
+ c.logError(r, "Failed to initialize config", "error", err)
+ jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err))
+ return
+ }
+
+ if cfg.Permissions == nil {
+ cfg.Permissions = &config.Permissions{}
+ }
+ cfg.Permissions.SkipRequests = args.YOLO
+
+ if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
+ c.logError(r, "Failed to create data directory", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to create data directory")
+ return
+ }
+
+ conn, err := db.Connect(c.ctx, cfg.Options.DataDirectory)
+ if err != nil {
+ c.logError(r, "Failed to connect to database", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to connect to database")
+ return
+ }
+
+ appWorkspace, err := app.New(c.ctx, conn, cfg)
+ if err != nil {
+ slog.Error("Failed to create app workspace", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to create app workspace")
+ return
+ }
+
+ ws := &Workspace{
+ App: appWorkspace,
+ id: id,
+ path: args.Path,
+ cfg: cfg,
+ env: args.Env,
+ }
+
+ c.workspaces.Set(id, ws)
+ jsonEncode(w, proto.Workspace{
+ ID: id,
+ Path: args.Path,
+ DataDir: cfg.Options.DataDirectory,
+ Debug: cfg.Options.Debug,
+ YOLO: cfg.Permissions.SkipRequests,
+ Config: cfg,
+ Env: args.Env,
+ })
+}
+
+func createDotCrushDir(dir string) error {
+ if err := os.MkdirAll(dir, 0o700); err != nil {
+ return fmt.Errorf("failed to create data directory: %q %w", dir, err)
+ }
+
+ gitIgnorePath := filepath.Join(dir, ".gitignore")
+ if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) {
+ if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
+ return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
+ }
+ }
+
+ return nil
+}
+
+func jsonEncode(w http.ResponseWriter, v any) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(v)
+}
+
+func jsonError(w http.ResponseWriter, status int, message string) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(status)
+ _ = json.NewEncoder(w).Encode(proto.Error{Message: message})
+}
@@ -0,0 +1,208 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "net"
+ "net/http"
+ "net/url"
+ "os/user"
+ "runtime"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/csync"
+)
+
+// ErrServerClosed is returned when the server is closed.
+var ErrServerClosed = http.ErrServerClosed
+
+// Workspace represents a running [app.App] workspace with its associated
+// resources and state.
+type Workspace struct {
+ *app.App
+ ln net.Listener
+ cfg *config.Config
+ id string
+ path string
+ env []string
+}
+
+// ParseHostURL parses a host URL into a [url.URL].
+func ParseHostURL(host string) (*url.URL, error) {
+ proto, addr, ok := strings.Cut(host, "://")
+ if !ok {
+ return nil, fmt.Errorf("invalid host format: %s", host)
+ }
+
+ var basePath string
+ if proto == "tcp" {
+ parsed, err := url.Parse("tcp://" + addr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid tcp address: %v", err)
+ }
+ addr = parsed.Host
+ basePath = parsed.Path
+ }
+ return &url.URL{
+ Scheme: proto,
+ Host: addr,
+ Path: basePath,
+ }, nil
+}
+
+// DefaultHost returns the default server host.
+func DefaultHost() string {
+ sock := "crush.sock"
+ usr, err := user.Current()
+ if err == nil && usr.Uid != "" {
+ sock = fmt.Sprintf("crush-%s.sock", usr.Uid)
+ }
+ if runtime.GOOS == "windows" {
+ return fmt.Sprintf("npipe:////./pipe/%s", sock)
+ }
+ return fmt.Sprintf("unix:///tmp/%s", sock)
+}
+
+// Server represents a Crush server bound to a specific address.
+type Server struct {
+ // Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
+ Addr string
+ network string
+
+ h *http.Server
+ ln net.Listener
+ ctx context.Context
+
+ // workspaces is a map of running applications managed by the server.
+ workspaces *csync.Map[string, *Workspace]
+ cfg *config.Config
+ logger *slog.Logger
+}
+
+// SetLogger sets the logger for the server.
+func (s *Server) SetLogger(logger *slog.Logger) {
+ s.logger = logger
+}
+
+// DefaultServer returns a new [Server] with the default address.
+func DefaultServer(cfg *config.Config) *Server {
+ hostURL, err := ParseHostURL(DefaultHost())
+ if err != nil {
+ panic("invalid default host")
+ }
+ return NewServer(cfg, hostURL.Scheme, hostURL.Host)
+}
+
+// NewServer creates a new [Server] with the given network and address.
+func NewServer(cfg *config.Config, network, address string) *Server {
+ s := new(Server)
+ s.Addr = address
+ s.network = network
+ s.cfg = cfg
+ s.workspaces = csync.NewMap[string, *Workspace]()
+ s.ctx = context.Background()
+
+ var p http.Protocols
+ p.SetHTTP1(true)
+ p.SetUnencryptedHTTP2(true)
+ c := &controllerV1{Server: s}
+ mux := http.NewServeMux()
+ mux.HandleFunc("GET /v1/health", c.handleGetHealth)
+ mux.HandleFunc("GET /v1/version", c.handleGetVersion)
+ mux.HandleFunc("GET /v1/config", c.handleGetConfig)
+ mux.HandleFunc("POST /v1/control", c.handlePostControl)
+ mux.HandleFunc("GET /v1/workspaces", c.handleGetWorkspaces)
+ mux.HandleFunc("POST /v1/workspaces", c.handlePostWorkspaces)
+ mux.HandleFunc("DELETE /v1/workspaces/{id}", c.handleDeleteWorkspaces)
+ mux.HandleFunc("GET /v1/workspaces/{id}", c.handleGetWorkspace)
+ mux.HandleFunc("GET /v1/workspaces/{id}/config", c.handleGetWorkspaceConfig)
+ mux.HandleFunc("GET /v1/workspaces/{id}/events", c.handleGetWorkspaceEvents)
+ mux.HandleFunc("GET /v1/workspaces/{id}/providers", c.handleGetWorkspaceProviders)
+ mux.HandleFunc("GET /v1/workspaces/{id}/sessions", c.handleGetWorkspaceSessions)
+ mux.HandleFunc("POST /v1/workspaces/{id}/sessions", c.handlePostWorkspaceSessions)
+ mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession)
+ mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory)
+ mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages)
+ mux.HandleFunc("GET /v1/workspaces/{id}/lsps", c.handleGetWorkspaceLSPs)
+ mux.HandleFunc("GET /v1/workspaces/{id}/lsps/{lsp}/diagnostics", c.handleGetWorkspaceLSPDiagnostics)
+ mux.HandleFunc("GET /v1/workspaces/{id}/permissions/skip", c.handleGetWorkspacePermissionsSkip)
+ mux.HandleFunc("POST /v1/workspaces/{id}/permissions/skip", c.handlePostWorkspacePermissionsSkip)
+ mux.HandleFunc("POST /v1/workspaces/{id}/permissions/grant", c.handlePostWorkspacePermissionsGrant)
+ mux.HandleFunc("GET /v1/workspaces/{id}/agent", c.handleGetWorkspaceAgent)
+ mux.HandleFunc("POST /v1/workspaces/{id}/agent", c.handlePostWorkspaceAgent)
+ mux.HandleFunc("POST /v1/workspaces/{id}/agent/init", c.handlePostWorkspaceAgentInit)
+ mux.HandleFunc("POST /v1/workspaces/{id}/agent/update", c.handlePostWorkspaceAgentUpdate)
+ mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}", c.handleGetWorkspaceAgentSession)
+ mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel", c.handlePostWorkspaceAgentSessionCancel)
+ mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetWorkspaceAgentSessionPromptQueued)
+ mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostWorkspaceAgentSessionPromptClear)
+ mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/summarize", c.handleGetWorkspaceAgentSessionSummarize)
+ s.h = &http.Server{
+ Protocols: &p,
+ Handler: s.loggingHandler(mux),
+ }
+ if network == "tcp" {
+ s.h.Addr = address
+ }
+ return s
+}
+
+// Serve accepts incoming connections on the listener.
+func (s *Server) Serve(ln net.Listener) error {
+ return s.h.Serve(ln)
+}
+
+// ListenAndServe starts the server and begins accepting connections.
+func (s *Server) ListenAndServe() error {
+ if s.ln != nil {
+ return fmt.Errorf("server already started")
+ }
+ ln, err := listen(s.network, s.Addr)
+ if err != nil {
+ return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
+ }
+ return s.Serve(ln)
+}
+
+func (s *Server) closeListener() {
+ if s.ln != nil {
+ s.ln.Close()
+ s.ln = nil
+ }
+}
+
+// Close force closes all listeners and connections.
+func (s *Server) Close() error {
+ defer func() { s.closeListener() }()
+ return s.h.Close()
+}
+
+// Shutdown gracefully shuts down the server without interrupting active
+// connections.
+func (s *Server) Shutdown(ctx context.Context) error {
+ defer func() { s.closeListener() }()
+ return s.h.Shutdown(ctx)
+}
+
+func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
+ if s.logger != nil {
+ s.logger.With(
+ slog.String("method", r.Method),
+ slog.String("url", r.URL.String()),
+ slog.String("remote_addr", r.RemoteAddr),
+ ).Debug(msg, args...)
+ }
+}
+
+func (s *Server) logError(r *http.Request, msg string, args ...any) {
+ if s.logger != nil {
+ s.logger.With(
+ slog.String("method", r.Method),
+ slog.String("url", r.URL.String()),
+ slog.String("remote_addr", r.RemoteAddr),
+ ).Error(msg, args...)
+ }
+}
@@ -57,7 +57,10 @@ func (m *UI) initializeProject() tea.Cmd {
initialize := func() tea.Msg {
initPrompt, err := agent.InitializePrompt(*cfg)
if err != nil {
- return util.InfoMsg{Type: util.InfoTypeError, Msg: err.Error()}
+ return util.InfoMsg{
+ Type: util.InfoTypeError,
+ Msg: fmt.Sprintf("Failed to initialize project: %v", err),
+ }
}
return sendMessageMsg{Content: initPrompt}
}
@@ -740,6 +740,9 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
cmds = append(cmds, cmd)
}
case util.InfoMsg:
+ if msg.Type == util.InfoTypeError {
+ slog.Error("Error reported", "error", msg.Msg)
+ }
m.status.SetInfoMsg(msg)
ttl := msg.TTL
if ttl <= 0 {
@@ -2753,7 +2756,7 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.
}
return util.InfoMsg{
Type: util.InfoTypeError,
- Msg: err.Error(),
+ Msg: fmt.Sprintf("Failed to run agent: %v", err),
}
}
return nil
@@ -4,7 +4,6 @@ package util
import (
"context"
"errors"
- "log/slog"
"os/exec"
"time"
@@ -23,7 +22,6 @@ func CmdHandler(msg tea.Msg) tea.Cmd {
}
func ReportError(err error) tea.Cmd {
- slog.Error("Error reported", "error", err)
return CmdHandler(NewErrorMsg(err))
}
@@ -2,9 +2,12 @@ package version
import "runtime/debug"
-// Build-time parameters set via -ldflags
+// Build-time parameters set via -ldflags.
-var Version = "devel"
+var (
+ Version = "devel"
+ Commit = "unknown"
+)
// A user may install crush using `go install github.com/charmbracelet/crush@latest`.
// without -ldflags, in which case the version above is unset. As a workaround