diff --git a/go.mod b/go.mod index a8467a055e4d78bc244061459e499d0473426afe..6075460c2ac79460cf586ab6752d15a715445812 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect diff --git a/go.sum b/go.sum index 8f27c3547b963695ad31ac7e03dbeebb0e9e612a..171527b9b755f5ffdd3592d43c4ebfbe17a1c82f 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/JohannesKaufmann/html-to-markdown v1.6.0 h1:04VXMiE50YYfCfLboJCLcgqF5 github.com/JohannesKaufmann/html-to-markdown v1.6.0/go.mod h1:NUI78lGg/a7vpEJTz/0uOcYMaibytE4BUOQS8k78yPQ= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/PuerkitoBio/goquery v1.11.0 h1:jZ7pwMQXIITcUXNH83LLk+txlaEy6NVOfTuP43xxfqw= github.com/PuerkitoBio/goquery v1.11.0/go.mod h1:wQHgxUOU3JGuj3oD/QFfxUdlzW6xPHfqyHre6VMY4DQ= diff --git a/internal/app/app.go b/internal/app/app.go index 4f353f1bf2037593976f84b19508e52b1019a028..a81cfeb9d88357ec80566513d00e7f3080f0ecd5 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -143,6 +143,11 @@ func (app *App) Config() *config.Config { return app.config } +// Events returns the events channel for the application. +func (app *App) Events() <-chan tea.Msg { + return app.events +} + // RunNonInteractive runs the application in non-interactive mode with the // given prompt, printing to stdout. func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt, largeModel, smallModel string, hideSpinner bool) error { diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000000000000000000000000000000000000..fcc38914a1dba0908a8f47d3be709b555c948171 --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,192 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + stdpath "path" + "path/filepath" + "time" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/server" +) + +// DummyHost is used to satisfy the http.Client's requirement for a URL. +const DummyHost = "api.crush.localhost" + +// Client represents an RPC client connected to a Crush server. +type Client struct { + h *http.Client + path string + network string + addr string +} + +// DefaultClient creates a new [Client] connected to the default server address. +func DefaultClient(path string) (*Client, error) { + host, err := server.ParseHostURL(server.DefaultHost()) + if err != nil { + return nil, err + } + return NewClient(path, host.Scheme, host.Host) +} + +// NewClient creates a new [Client] connected to the server at the given +// network and address. +func NewClient(path, network, address string) (*Client, error) { + c := new(Client) + c.path = filepath.Clean(path) + c.network = network + c.addr = address + p := &http.Protocols{} + p.SetHTTP1(true) + p.SetUnencryptedHTTP2(true) + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Protocols = p + tr.DialContext = c.dialer + if c.network == "npipe" || c.network == "unix" { + tr.DisableCompression = true + } + c.h = &http.Client{ + Transport: tr, + Timeout: 0, + } + return c, nil +} + +// Path returns the client's workspace filesystem path. +func (c *Client) Path() string { + return c.path +} + +// GetGlobalConfig retrieves the server's configuration. +func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) { + var cfg config.Config + rsp, err := c.get(ctx, "/config", nil, nil) + if err != nil { + return nil, err + } + defer rsp.Body.Close() + if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +// Health checks the server's health status. +func (c *Client) Health(ctx context.Context) error { + rsp, err := c.get(ctx, "/health", nil, nil) + if err != nil { + return err + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("server health check failed: %s", rsp.Status) + } + return nil +} + +// VersionInfo retrieves the server's version information. +func (c *Client) VersionInfo(ctx context.Context) (*proto.VersionInfo, error) { + var vi proto.VersionInfo + rsp, err := c.get(ctx, "version", nil, nil) + if err != nil { + return nil, err + } + defer rsp.Body.Close() + if err := json.NewDecoder(rsp.Body).Decode(&vi); err != nil { + return nil, err + } + return &vi, nil +} + +// ShutdownServer sends a shutdown request to the server. +func (c *Client) ShutdownServer(ctx context.Context) error { + rsp, err := c.post(ctx, "/control", nil, jsonBody(proto.ServerControl{ + Command: "shutdown", + }), nil) + if err != nil { + return err + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("server shutdown failed: %s", rsp.Status) + } + return nil +} + +func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + switch c.network { + case "npipe": + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + return dialPipeContext(ctx, c.addr) + case "unix": + return d.DialContext(ctx, "unix", c.addr) + default: + return d.DialContext(ctx, network, address) + } +} + +func (c *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodGet, path, query, nil, headers) +} + +func (c *Client) post(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodPost, path, query, body, headers) +} + +func (c *Client) delete(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers) +} + +func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + url := (&url.URL{ + Path: stdpath.Join("/v1", path), + RawQuery: query.Encode(), + }).String() + req, err := c.buildReq(ctx, method, url, body, headers) + if err != nil { + return nil, err + } + + rsp, err := c.h.Do(req) + if err != nil { + return nil, err + } + + return rsp, nil +} + +func (c *Client) buildReq(ctx context.Context, method, url string, body io.Reader, headers http.Header) (*http.Request, error) { + r, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + + for k, v := range headers { + r.Header[http.CanonicalHeaderKey(k)] = v + } + + r.URL.Scheme = "http" + r.URL.Host = c.addr + if c.network == "npipe" || c.network == "unix" { + r.Host = DummyHost + } + + if body != nil && r.Header.Get("Content-Type") == "" { + r.Header.Set("Content-Type", "text/plain") + } + + return r, nil +} diff --git a/internal/client/dial_other.go b/internal/client/dial_other.go new file mode 100644 index 0000000000000000000000000000000000000000..f2ba8569ba3326f2df82dc34bbf842eac30918d9 --- /dev/null +++ b/internal/client/dial_other.go @@ -0,0 +1,14 @@ +//go:build !windows +// +build !windows + +package client + +import ( + "context" + "net" + "syscall" +) + +func dialPipeContext(context.Context, string) (net.Conn, error) { + return nil, syscall.EAFNOSUPPORT +} diff --git a/internal/client/dial_windows.go b/internal/client/dial_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..750ce98b152c7583ce8e7889ef35c12098f6da8f --- /dev/null +++ b/internal/client/dial_windows.go @@ -0,0 +1,15 @@ +//go:build windows +// +build windows + +package client + +import ( + "context" + "net" + + "github.com/Microsoft/go-winio" +) + +func dialPipeContext(ctx context.Context, address string) (net.Conn, error) { + return winio.DialPipeContext(ctx, address) +} diff --git a/internal/client/proto.go b/internal/client/proto.go new file mode 100644 index 0000000000000000000000000000000000000000..0705f4ee3db77bd8500ebbb29b607f494298d959 --- /dev/null +++ b/internal/client/proto.go @@ -0,0 +1,490 @@ +package client + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" +) + +// CreateWorkspace creates a new workspace on the server. +func (c *Client) CreateWorkspace(ctx context.Context, ws proto.Workspace) (*proto.Workspace, error) { + rsp, err := c.post(ctx, "/workspaces", nil, jsonBody(ws), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return nil, fmt.Errorf("failed to create workspace: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to create workspace: status code %d", rsp.StatusCode) + } + var created proto.Workspace + if err := json.NewDecoder(rsp.Body).Decode(&created); err != nil { + return nil, fmt.Errorf("failed to decode workspace: %w", err) + } + return &created, nil +} + +// DeleteWorkspace deletes a workspace on the server. +func (c *Client) DeleteWorkspace(ctx context.Context, id string) error { + rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), nil, nil) + if err != nil { + return fmt.Errorf("failed to delete workspace: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to delete workspace: status code %d", rsp.StatusCode) + } + return nil +} + +// SubscribeEvents subscribes to server-sent events for a workspace. +func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, error) { + events := make(chan any, 100) + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/events", id), nil, http.Header{ + "Accept": []string{"text/event-stream"}, + "Cache-Control": []string{"no-cache"}, + "Connection": []string{"keep-alive"}, + }) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to events: %w", err) + } + + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to subscribe to events: status code %d", rsp.StatusCode) + } + + go func() { + defer rsp.Body.Close() + + scr := bufio.NewReader(rsp.Body) + for { + line, err := scr.ReadBytes('\n') + if errors.Is(err, io.EOF) { + break + } + if err != nil { + slog.Error("Reading from events stream", "error", err) + time.Sleep(time.Second * 2) + continue + } + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + + data, ok := bytes.CutPrefix(line, []byte("data:")) + if !ok { + slog.Warn("Invalid event format", "line", string(line)) + continue + } + + data = bytes.TrimSpace(data) + + var event pubsub.Event[any] + if err := json.Unmarshal(data, &event); err != nil { + slog.Error("Unmarshaling event", "error", err) + continue + } + + type alias pubsub.Event[any] + aux := &struct { + Payload json.RawMessage `json:"payload"` + *alias + }{ + alias: (*alias)(&event), + } + + if err := json.Unmarshal(data, &aux); err != nil { + slog.Error("Unmarshaling event payload", "error", err) + continue + } + + var p pubsub.Payload + if err := json.Unmarshal(aux.Payload, &p); err != nil { + slog.Error("Unmarshaling event payload", "error", err) + continue + } + + switch p.Type { + case pubsub.PayloadTypeLSPEvent: + var e pubsub.Event[proto.LSPEvent] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypeMCPEvent: + var e pubsub.Event[proto.MCPEvent] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypePermissionRequest: + var e pubsub.Event[proto.PermissionRequest] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypePermissionNotification: + var e pubsub.Event[proto.PermissionNotification] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypeMessage: + var e pubsub.Event[proto.Message] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypeSession: + var e pubsub.Event[proto.Session] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypeFile: + var e pubsub.Event[proto.File] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + case pubsub.PayloadTypeAgentEvent: + var e pubsub.Event[proto.AgentEvent] + _ = json.Unmarshal(data, &e) + sendEvent(ctx, events, e) + default: + slog.Warn("Unknown event type", "type", p.Type) + continue + } + } + }() + + return events, nil +} + +func sendEvent(ctx context.Context, evc chan any, ev any) { + slog.Info("Event received", "event", fmt.Sprintf("%T %+v", ev, ev)) + select { + case evc <- ev: + case <-ctx.Done(): + close(evc) + return + } +} + +// GetLSPDiagnostics retrieves LSP diagnostics for a specific LSP client. +func (c *Client) GetLSPDiagnostics(ctx context.Context, id string, lspName string) (map[protocol.DocumentURI][]protocol.Diagnostic, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/lsps/%s/diagnostics", id, lspName), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get LSP diagnostics: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get LSP diagnostics: status code %d", rsp.StatusCode) + } + var diagnostics map[protocol.DocumentURI][]protocol.Diagnostic + if err := json.NewDecoder(rsp.Body).Decode(&diagnostics); err != nil { + return nil, fmt.Errorf("failed to decode LSP diagnostics: %w", err) + } + return diagnostics, nil +} + +// GetLSPs retrieves the LSP client states for a workspace. +func (c *Client) GetLSPs(ctx context.Context, id string) (map[string]app.LSPClientInfo, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/lsps", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get LSPs: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get LSPs: status code %d", rsp.StatusCode) + } + var lsps map[string]app.LSPClientInfo + if err := json.NewDecoder(rsp.Body).Decode(&lsps); err != nil { + return nil, fmt.Errorf("failed to decode LSPs: %w", err) + } + return lsps, nil +} + +// GetAgentSessionQueuedPrompts retrieves the number of queued prompts for a +// session. +func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, id string, sessionID string) (int, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/prompts/queued", id, sessionID), nil, nil) + if err != nil { + return 0, fmt.Errorf("failed to get session agent queued prompts: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("failed to get session agent queued prompts: status code %d", rsp.StatusCode) + } + var count int + if err := json.NewDecoder(rsp.Body).Decode(&count); err != nil { + return 0, fmt.Errorf("failed to decode session agent queued prompts: %w", err) + } + return count, nil +} + +// ClearAgentSessionQueuedPrompts clears the queued prompts for a session. +func (c *Client) ClearAgentSessionQueuedPrompts(ctx context.Context, id string, sessionID string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/prompts/clear", id, sessionID), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to clear session agent queued prompts: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to clear session agent queued prompts: status code %d", rsp.StatusCode) + } + return nil +} + +// GetAgentInfo retrieves the agent status for a workspace. +func (c *Client) GetAgentInfo(ctx context.Context, id string) (*proto.AgentInfo, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get agent status: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get agent status: status code %d", rsp.StatusCode) + } + var info proto.AgentInfo + if err := json.NewDecoder(rsp.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("failed to decode agent status: %w", err) + } + return &info, nil +} + +// UpdateAgent triggers an agent model update on the server. +func (c *Client) UpdateAgent(ctx context.Context, id string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/update", id), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to update agent: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to update agent: status code %d", rsp.StatusCode) + } + return nil +} + +// SendMessage sends a message to the agent for a workspace. +func (c *Client) SendMessage(ctx context.Context, id string, sessionID, prompt string, attachments ...message.Attachment) error { + protoAttachments := make([]proto.Attachment, len(attachments)) + for i, a := range attachments { + protoAttachments[i] = proto.Attachment{ + FilePath: a.FilePath, + FileName: a.FileName, + MimeType: a.MimeType, + Content: a.Content, + } + } + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent", id), nil, jsonBody(proto.AgentMessage{ + SessionID: sessionID, + Prompt: prompt, + Attachments: protoAttachments, + }), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to send message to agent: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode) + } + return nil +} + +// GetAgentSessionInfo retrieves the agent session info for a workspace. +func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get session agent info: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get session agent info: status code %d", rsp.StatusCode) + } + var info proto.AgentSession + if err := json.NewDecoder(rsp.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("failed to decode session agent info: %w", err) + } + return &info, nil +} + +// AgentSummarizeSession requests a session summarization. +func (c *Client) AgentSummarizeSession(ctx context.Context, id string, sessionID string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/summarize", id, sessionID), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to summarize session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to summarize session: status code %d", rsp.StatusCode) + } + return nil +} + +// InitiateAgentProcessing triggers agent initialization on the server. +func (c *Client) InitiateAgentProcessing(ctx context.Context, id string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/init", id), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to initiate session agent processing: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to initiate session agent processing: status code %d", rsp.StatusCode) + } + return nil +} + +// ListMessages retrieves all messages for a session. +func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get messages: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get messages: status code %d", rsp.StatusCode) + } + var messages []message.Message + if err := json.NewDecoder(rsp.Body).Decode(&messages); err != nil { + return nil, fmt.Errorf("failed to decode messages: %w", err) + } + return messages, nil +} + +// GetSession retrieves a specific session. +func (c *Client) GetSession(ctx context.Context, id string, sessionID string) (*session.Session, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get session: status code %d", rsp.StatusCode) + } + var sess session.Session + if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil { + return nil, fmt.Errorf("failed to decode session: %w", err) + } + return &sess, nil +} + +// ListSessionHistoryFiles retrieves history files for a session. +func (c *Client) ListSessionHistoryFiles(ctx context.Context, id string, sessionID string) ([]history.File, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/history", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get session history files: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get session history files: status code %d", rsp.StatusCode) + } + var files []history.File + if err := json.NewDecoder(rsp.Body).Decode(&files); err != nil { + return nil, fmt.Errorf("failed to decode session history files: %w", err) + } + return files, nil +} + +// CreateSession creates a new session in a workspace. +func (c *Client) CreateSession(ctx context.Context, id string, title string) (*session.Session, error) { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, jsonBody(session.Session{Title: title}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to create session: status code %d", rsp.StatusCode) + } + var sess session.Session + if err := json.NewDecoder(rsp.Body).Decode(&sess); err != nil { + return nil, fmt.Errorf("failed to decode session: %w", err) + } + return &sess, nil +} + +// ListSessions lists all sessions in a workspace. +func (c *Client) ListSessions(ctx context.Context, id string) ([]session.Session, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get sessions: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get sessions: status code %d", rsp.StatusCode) + } + var sessions []session.Session + if err := json.NewDecoder(rsp.Body).Decode(&sessions); err != nil { + return nil, fmt.Errorf("failed to decode sessions: %w", err) + } + return sessions, nil +} + +// GrantPermission grants a permission on a workspace. +func (c *Client) GrantPermission(ctx context.Context, id string, req proto.PermissionGrant) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/permissions/grant", id), nil, jsonBody(req), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to grant permission: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to grant permission: status code %d", rsp.StatusCode) + } + return nil +} + +// SetPermissionsSkipRequests sets the skip-requests flag for a workspace. +func (c *Client) SetPermissionsSkipRequests(ctx context.Context, id string, skip bool) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/permissions/skip", id), nil, jsonBody(proto.PermissionSkipRequest{Skip: skip}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to set permissions skip requests: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to set permissions skip requests: status code %d", rsp.StatusCode) + } + return nil +} + +// GetPermissionsSkipRequests retrieves the skip-requests flag for a workspace. +func (c *Client) GetPermissionsSkipRequests(ctx context.Context, id string) (bool, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/permissions/skip", id), nil, nil) + if err != nil { + return false, fmt.Errorf("failed to get permissions skip requests: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return false, fmt.Errorf("failed to get permissions skip requests: status code %d", rsp.StatusCode) + } + var skip proto.PermissionSkipRequest + if err := json.NewDecoder(rsp.Body).Decode(&skip); err != nil { + return false, fmt.Errorf("failed to decode permissions skip requests: %w", err) + } + return skip.Skip, nil +} + +// GetConfig retrieves the workspace-specific configuration. +func (c *Client) GetConfig(ctx context.Context, id string) (*config.Config, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/config", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get config: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get config: status code %d", rsp.StatusCode) + } + var cfg config.Config + if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil { + return nil, fmt.Errorf("failed to decode config: %w", err) + } + return &cfg, nil +} + +func jsonBody(v any) *bytes.Buffer { + b := new(bytes.Buffer) + m, _ := json.Marshal(v) + b.Write(m) + return b +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 52ffda3fb09a0e6fdfb88084b80f7bdd261fb3c2..11e2ccc377f6ecc5b8dfca39f1b84609528c92a0 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -6,20 +6,29 @@ import ( "errors" "fmt" "io" + "io/fs" "log/slog" + "net/url" "os" + "os/exec" "path/filepath" + "regexp" "strconv" "strings" + "time" tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" "github.com/charmbracelet/colorprofile" "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/client" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" + "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/projects" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/server" "github.com/charmbracelet/crush/internal/ui/common" ui "github.com/charmbracelet/crush/internal/ui/model" "github.com/charmbracelet/crush/internal/version" @@ -31,6 +40,8 @@ import ( "github.com/spf13/cobra" ) +var clientHost string + func init() { rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory") rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory") @@ -38,6 +49,8 @@ func init() { rootCmd.Flags().BoolP("help", "h", false, "Help") rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)") + rootCmd.Flags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)") + rootCmd.AddCommand( runCmd, dirsCmd, @@ -228,6 +241,167 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { return appInstance, nil } +// setupClientApp sets up a client-based workspace via the server. It +// auto-starts a detached server process if the socket does not exist. +func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Workspace, error) { + debug, _ := cmd.Flags().GetBool("debug") + yolo, _ := cmd.Flags().GetBool("yolo") + dataDir, _ := cmd.Flags().GetString("data-dir") + ctx := cmd.Context() + + cwd, err := ResolveCwd(cmd) + if err != nil { + return nil, nil, err + } + + c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host) + if err != nil { + return nil, nil, err + } + + ws, err := c.CreateWorkspace(ctx, proto.Workspace{ + Path: cwd, + DataDir: dataDir, + Debug: debug, + YOLO: yolo, + Env: os.Environ(), + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to create workspace: %v", err) + } + + cfg, err := c.GetGlobalConfig(cmd.Context()) + if err != nil { + return nil, nil, fmt.Errorf("failed to get global config: %v", err) + } + + if shouldEnableMetrics(cfg) { + event.Init() + } + + return c, ws, nil +} + +// ensureServer auto-starts a detached server if the socket file does not +// exist. +func ensureServer(cmd *cobra.Command, hostURL *url.URL) error { + switch hostURL.Scheme { + case "unix", "npipe": + _, err := os.Stat(hostURL.Host) + if err != nil && errors.Is(err, fs.ErrNotExist) { + if err := startDetachedServer(cmd); err != nil { + return err + } + } + + for range 10 { + _, err = os.Stat(hostURL.Host) + if err == nil { + break + } + select { + case <-cmd.Context().Done(): + return cmd.Context().Err() + case <-time.After(100 * time.Millisecond): + } + } + if err != nil { + return fmt.Errorf("failed to initialize crush server: %v", err) + } + default: + // TCP: assume server is already running. + } + return nil +} + +// waitForHealth polls the server's health endpoint until it responds. +func waitForHealth(ctx context.Context, c *client.Client) error { + var err error + for range 10 { + err = c.Health(ctx) + if err == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(100 * time.Millisecond): + } + } + return fmt.Errorf("failed to connect to crush server: %v", err) +} + +// streamEvents forwards SSE events from the client to the TUI program. +func streamEvents(ctx context.Context, evc <-chan any, p *tea.Program) { + defer log.RecoverPanic("app.Subscribe", func() { + slog.Info("TUI subscription panic: attempting graceful shutdown") + p.Quit() + }) + + for { + select { + case <-ctx.Done(): + slog.Debug("TUI message handler shutting down") + return + case ev, ok := <-evc: + if !ok { + slog.Debug("TUI message channel closed") + return + } + p.Send(ev) + } + } +} + +var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`) + +func startDetachedServer(cmd *cobra.Command) error { + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %v", err) + } + + safeClientHost := safeNameRegexp.ReplaceAllString(clientHost, "_") + chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeClientHost) + if err := os.MkdirAll(chDir, 0o700); err != nil { + return fmt.Errorf("failed to create server working directory: %v", err) + } + + cmdArgs := []string{"server"} + if clientHost != server.DefaultHost() { + cmdArgs = append(cmdArgs, "--host", clientHost) + } + + c := exec.CommandContext(cmd.Context(), exe, cmdArgs...) + stdoutPath := filepath.Join(chDir, "stdout.log") + stderrPath := filepath.Join(chDir, "stderr.log") + detachProcess(c) + + stdout, err := os.Create(stdoutPath) + if err != nil { + return fmt.Errorf("failed to create stdout log file: %v", err) + } + defer stdout.Close() + c.Stdout = stdout + + stderr, err := os.Create(stderrPath) + if err != nil { + return fmt.Errorf("failed to create stderr log file: %v", err) + } + defer stderr.Close() + c.Stderr = stderr + + if err := c.Start(); err != nil { + return fmt.Errorf("failed to start crush server: %v", err) + } + + if err := c.Process.Release(); err != nil { + return fmt.Errorf("failed to detach crush server process: %v", err) + } + + return nil +} + func shouldEnableMetrics(cfg *config.Config) bool { if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v { return false diff --git a/internal/cmd/root_other.go b/internal/cmd/root_other.go new file mode 100644 index 0000000000000000000000000000000000000000..6d178a07a6e55c85c7fdd4d6a4d98d923aad5a71 --- /dev/null +++ b/internal/cmd/root_other.go @@ -0,0 +1,16 @@ +//go:build !windows +// +build !windows + +package cmd + +import ( + "os/exec" + "syscall" +) + +func detachProcess(c *exec.Cmd) { + if c.SysProcAttr == nil { + c.SysProcAttr = &syscall.SysProcAttr{} + } + c.SysProcAttr.Setsid = true +} diff --git a/internal/cmd/root_windows.go b/internal/cmd/root_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..134b6de258cd798ccfc2dbb7803099f11e31b052 --- /dev/null +++ b/internal/cmd/root_windows.go @@ -0,0 +1,18 @@ +//go:build windows +// +build windows + +package cmd + +import ( + "os/exec" + "syscall" + + "golang.org/x/sys/windows" +) + +func detachProcess(c *exec.Cmd) { + if c.SysProcAttr == nil { + c.SysProcAttr = &syscall.SysProcAttr{} + } + c.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP | windows.DETACHED_PROCESS +} diff --git a/internal/cmd/server.go b/internal/cmd/server.go new file mode 100644 index 0000000000000000000000000000000000000000..ce20d20b282f5ee364b886699c32d07594ca03fd --- /dev/null +++ b/internal/cmd/server.go @@ -0,0 +1,98 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "os/signal" + "time" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/server" + "github.com/spf13/cobra" +) + +var serverHost string + +func init() { + serverCmd.Flags().StringVarP(&serverHost, "host", "H", server.DefaultHost(), "Server host (TCP or Unix socket)") + rootCmd.AddCommand(serverCmd) +} + +var serverCmd = &cobra.Command{ + Use: "server", + Short: "Start the Crush server", + RunE: func(cmd *cobra.Command, _ []string) error { + dataDir, err := cmd.Flags().GetString("data-dir") + if err != nil { + return fmt.Errorf("failed to get data directory: %v", err) + } + debug, err := cmd.Flags().GetBool("debug") + if err != nil { + return fmt.Errorf("failed to get debug flag: %v", err) + } + + cfg, err := config.Load("", dataDir, debug) + if err != nil { + return fmt.Errorf("failed to load configuration: %v", err) + } + + handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }) + if debug { + handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + } + slog.SetDefault(slog.New(handler)) + + hostURL, err := server.ParseHostURL(serverHost) + if err != nil { + return fmt.Errorf("invalid server host: %v", err) + } + + srv := server.NewServer(cfg, hostURL.Scheme, hostURL.Host) + srv.SetLogger(slog.Default()) + slog.Info("Starting Crush server...", "addr", serverHost) + + errch := make(chan error, 1) + sigch := make(chan os.Signal, 1) + sigs := []os.Signal{os.Interrupt} + sigs = append(sigs, addSignals(sigs)...) + signal.Notify(sigch, sigs...) + + go func() { + errch <- srv.ListenAndServe() + }() + + select { + case <-sigch: + slog.Info("Received interrupt signal...") + case err = <-errch: + if err != nil && !errors.Is(err, server.ErrServerClosed) { + _ = srv.Close() + slog.Error("Server error", "error", err) + return fmt.Errorf("server error: %v", err) + } + } + + if errors.Is(err, server.ErrServerClosed) { + return nil + } + + ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second) + defer cancel() + + slog.Info("Shutting down...") + + if err := srv.Shutdown(ctx); err != nil { + slog.Error("Failed to shutdown server", "error", err) + return fmt.Errorf("failed to shutdown server: %v", err) + } + + return nil + }, +} diff --git a/internal/cmd/server_other.go b/internal/cmd/server_other.go new file mode 100644 index 0000000000000000000000000000000000000000..58b05629bf5b85a579ded6379dec53f555fb68e7 --- /dev/null +++ b/internal/cmd/server_other.go @@ -0,0 +1,13 @@ +//go:build !windows +// +build !windows + +package cmd + +import ( + "os" + "syscall" +) + +func addSignals(sigs []os.Signal) []os.Signal { + return append(sigs, syscall.SIGTERM) +} diff --git a/internal/cmd/server_windows.go b/internal/cmd/server_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..eff60b8b635128e08a6e27f4273520607f05a4c3 --- /dev/null +++ b/internal/cmd/server_windows.go @@ -0,0 +1,10 @@ +//go:build windows +// +build windows + +package cmd + +import "os" + +func addSignals(sigs []os.Signal) []os.Signal { + return sigs +} diff --git a/internal/config/load.go b/internal/config/load.go index 3fba44aa9142c52b8966b1dbe994cef0ae654c48..7bbc6b983439df387cfbc3430debf6b3a8f39bf5 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -721,6 +721,25 @@ func GlobalConfig() string { return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName)) } +// GlobalCacheDir returns the path to the global cache directory for the +// application. +func GlobalCacheDir() string { + if crushCache := os.Getenv("CRUSH_CACHE_DIR"); crushCache != "" { + return crushCache + } + if xdgCacheHome := os.Getenv("XDG_CACHE_HOME"); xdgCacheHome != "" { + return filepath.Join(xdgCacheHome, appName) + } + if runtime.GOOS == "windows" { + localAppData := cmp.Or( + os.Getenv("LOCALAPPDATA"), + filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"), + ) + return filepath.Join(localAppData, appName, "cache") + } + return filepath.Join(home.Dir(), ".cache", appName) +} + // GlobalConfigData returns the path to the main data directory for the application. // this config is used when the app overrides configurations instead of updating the global config. func GlobalConfigData() string { diff --git a/internal/proto/agent.go b/internal/proto/agent.go new file mode 100644 index 0000000000000000000000000000000000000000..1163b1d8bac629546c8ef6632b0fed6a780c09e5 --- /dev/null +++ b/internal/proto/agent.go @@ -0,0 +1,74 @@ +package proto + +import ( + "encoding/json" + "errors" +) + +// AgentEventType represents the type of agent event. +type AgentEventType string + +const ( + AgentEventTypeError AgentEventType = "error" + AgentEventTypeResponse AgentEventType = "response" + AgentEventTypeSummarize AgentEventType = "summarize" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (t AgentEventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (t *AgentEventType) UnmarshalText(text []byte) error { + *t = AgentEventType(text) + return nil +} + +// AgentEvent represents an event emitted by the agent. +type AgentEvent struct { + Type AgentEventType `json:"type"` + Message Message `json:"message"` + Error error `json:"error,omitempty"` + + // When summarizing. + SessionID string `json:"session_id,omitempty"` + Progress string `json:"progress,omitempty"` + Done bool `json:"done,omitempty"` +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (e AgentEvent) MarshalJSON() ([]byte, error) { + type Alias AgentEvent + return json.Marshal(&struct { + Error string `json:"error,omitempty"` + Alias + }{ + Error: func() string { + if e.Error != nil { + return e.Error.Error() + } + return "" + }(), + Alias: (Alias)(e), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (e *AgentEvent) UnmarshalJSON(data []byte) error { + type Alias AgentEvent + aux := &struct { + Error string `json:"error,omitempty"` + Alias + }{ + Alias: (Alias)(*e), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *e = AgentEvent(aux.Alias) + if aux.Error != "" { + e.Error = errors.New(aux.Error) + } + return nil +} diff --git a/internal/proto/history.go b/internal/proto/history.go new file mode 100644 index 0000000000000000000000000000000000000000..caf60a7127c817d7d893b8418307226509d2e10a --- /dev/null +++ b/internal/proto/history.go @@ -0,0 +1,12 @@ +package proto + +// File represents a file tracked in session history. +type File struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + Path string `json:"path"` + Content string `json:"content"` + Version int64 `json:"version"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} diff --git a/internal/proto/mcp.go b/internal/proto/mcp.go new file mode 100644 index 0000000000000000000000000000000000000000..e04f9ed8467890bc34859cd54272204ad65a9156 --- /dev/null +++ b/internal/proto/mcp.go @@ -0,0 +1,78 @@ +package proto + +import "fmt" + +// MCPState represents the current state of an MCP client. +type MCPState int + +const ( + MCPStateDisabled MCPState = iota + MCPStateStarting + MCPStateConnected + MCPStateError +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (s MCPState) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (s *MCPState) UnmarshalText(data []byte) error { + switch string(data) { + case "disabled": + *s = MCPStateDisabled + case "starting": + *s = MCPStateStarting + case "connected": + *s = MCPStateConnected + case "error": + *s = MCPStateError + default: + return fmt.Errorf("unknown mcp state: %s", data) + } + return nil +} + +// String returns the string representation of the MCPState. +func (s MCPState) String() string { + switch s { + case MCPStateDisabled: + return "disabled" + case MCPStateStarting: + return "starting" + case MCPStateConnected: + return "connected" + case MCPStateError: + return "error" + default: + return "unknown" + } +} + +// MCPEventType represents the type of MCP event. +type MCPEventType string + +const ( + MCPEventStateChanged MCPEventType = "state_changed" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (t MCPEventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (t *MCPEventType) UnmarshalText(data []byte) error { + *t = MCPEventType(data) + return nil +} + +// MCPEvent represents an event in the MCP system. +type MCPEvent struct { + Type MCPEventType `json:"type"` + Name string `json:"name"` + State MCPState `json:"state"` + Error error `json:"error,omitempty"` + ToolCount int `json:"tool_count,omitempty"` +} diff --git a/internal/proto/message.go b/internal/proto/message.go new file mode 100644 index 0000000000000000000000000000000000000000..f24cf80584a802cad34a91cf64895faf81e3a32d --- /dev/null +++ b/internal/proto/message.go @@ -0,0 +1,653 @@ +package proto + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "slices" + "time" + + "charm.land/catwalk/pkg/catwalk" +) + +// CreateMessageParams represents parameters for creating a message. +type CreateMessageParams struct { + Role MessageRole `json:"role"` + Parts []ContentPart `json:"parts"` + Model string `json:"model"` + Provider string `json:"provider,omitempty"` +} + +// Message represents a message in the proto layer. +type Message struct { + ID string `json:"id"` + Role MessageRole `json:"role"` + SessionID string `json:"session_id"` + Parts []ContentPart `json:"parts"` + Model string `json:"model"` + Provider string `json:"provider"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +// MessageRole represents the role of a message sender. +type MessageRole string + +const ( + Assistant MessageRole = "assistant" + User MessageRole = "user" + System MessageRole = "system" + Tool MessageRole = "tool" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (r MessageRole) MarshalText() ([]byte, error) { + return []byte(r), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (r *MessageRole) UnmarshalText(data []byte) error { + *r = MessageRole(data) + return nil +} + +// FinishReason represents why a message generation finished. +type FinishReason string + +const ( + FinishReasonEndTurn FinishReason = "end_turn" + FinishReasonMaxTokens FinishReason = "max_tokens" + FinishReasonToolUse FinishReason = "tool_use" + FinishReasonCanceled FinishReason = "canceled" + FinishReasonError FinishReason = "error" + FinishReasonPermissionDenied FinishReason = "permission_denied" + FinishReasonUnknown FinishReason = "unknown" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (fr FinishReason) MarshalText() ([]byte, error) { + return []byte(fr), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (fr *FinishReason) UnmarshalText(data []byte) error { + *fr = FinishReason(data) + return nil +} + +// ContentPart is a part of a message's content. +type ContentPart interface { + isPart() +} + +// ReasoningContent represents the reasoning/thinking part of a message. +type ReasoningContent struct { + Thinking string `json:"thinking"` + Signature string `json:"signature"` + StartedAt int64 `json:"started_at,omitempty"` + FinishedAt int64 `json:"finished_at,omitempty"` +} + +// String returns the thinking content as a string. +func (tc ReasoningContent) String() string { + return tc.Thinking +} + +func (ReasoningContent) isPart() {} + +// TextContent represents a text part of a message. +type TextContent struct { + Text string `json:"text"` +} + +// String returns the text content as a string. +func (tc TextContent) String() string { + return tc.Text +} + +func (TextContent) isPart() {} + +// ImageURLContent represents an image URL part of a message. +type ImageURLContent struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` +} + +// String returns the image URL as a string. +func (iuc ImageURLContent) String() string { + return iuc.URL +} + +func (ImageURLContent) isPart() {} + +// BinaryContent represents binary data in a message. +type BinaryContent struct { + Path string + MIMEType string + Data []byte +} + +// String returns a base64-encoded string of the binary data. +func (bc BinaryContent) String(p catwalk.InferenceProvider) string { + base64Encoded := base64.StdEncoding.EncodeToString(bc.Data) + if p == catwalk.InferenceProviderOpenAI { + return "data:" + bc.MIMEType + ";base64," + base64Encoded + } + return base64Encoded +} + +func (BinaryContent) isPart() {} + +// ToolCall represents a tool call in a message. +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Input string `json:"input"` + Type string `json:"type,omitempty"` + Finished bool `json:"finished,omitempty"` +} + +func (ToolCall) isPart() {} + +// ToolResult represents the result of a tool call. +type ToolResult struct { + ToolCallID string `json:"tool_call_id"` + Name string `json:"name"` + Content string `json:"content"` + Metadata string `json:"metadata"` + IsError bool `json:"is_error"` +} + +func (ToolResult) isPart() {} + +// Finish represents the end of a message generation. +type Finish struct { + Reason FinishReason `json:"reason"` + Time int64 `json:"time"` + Message string `json:"message,omitempty"` + Details string `json:"details,omitempty"` +} + +func (Finish) isPart() {} + +// MarshalJSON implements the [json.Marshaler] interface. +func (m Message) MarshalJSON() ([]byte, error) { + parts, err := MarshalParts(m.Parts) + if err != nil { + return nil, err + } + + type Alias Message + return json.Marshal(&struct { + Parts json.RawMessage `json:"parts"` + *Alias + }{ + Parts: json.RawMessage(parts), + Alias: (*Alias)(&m), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (m *Message) UnmarshalJSON(data []byte) error { + type Alias Message + aux := &struct { + Parts json.RawMessage `json:"parts"` + *Alias + }{ + Alias: (*Alias)(m), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + parts, err := UnmarshalParts([]byte(aux.Parts)) + if err != nil { + return err + } + + m.Parts = parts + return nil +} + +// Content returns the first text content part. +func (m *Message) Content() TextContent { + for _, part := range m.Parts { + if c, ok := part.(TextContent); ok { + return c + } + } + return TextContent{} +} + +// ReasoningContent returns the first reasoning content part. +func (m *Message) ReasoningContent() ReasoningContent { + for _, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + return c + } + } + return ReasoningContent{} +} + +// ImageURLContent returns all image URL content parts. +func (m *Message) ImageURLContent() []ImageURLContent { + imageURLContents := make([]ImageURLContent, 0) + for _, part := range m.Parts { + if c, ok := part.(ImageURLContent); ok { + imageURLContents = append(imageURLContents, c) + } + } + return imageURLContents +} + +// BinaryContent returns all binary content parts. +func (m *Message) BinaryContent() []BinaryContent { + binaryContents := make([]BinaryContent, 0) + for _, part := range m.Parts { + if c, ok := part.(BinaryContent); ok { + binaryContents = append(binaryContents, c) + } + } + return binaryContents +} + +// ToolCalls returns all tool call parts. +func (m *Message) ToolCalls() []ToolCall { + toolCalls := make([]ToolCall, 0) + for _, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + toolCalls = append(toolCalls, c) + } + } + return toolCalls +} + +// ToolResults returns all tool result parts. +func (m *Message) ToolResults() []ToolResult { + toolResults := make([]ToolResult, 0) + for _, part := range m.Parts { + if c, ok := part.(ToolResult); ok { + toolResults = append(toolResults, c) + } + } + return toolResults +} + +// IsFinished returns true if the message has a finish part. +func (m *Message) IsFinished() bool { + for _, part := range m.Parts { + if _, ok := part.(Finish); ok { + return true + } + } + return false +} + +// FinishPart returns the finish part if present. +func (m *Message) FinishPart() *Finish { + for _, part := range m.Parts { + if c, ok := part.(Finish); ok { + return &c + } + } + return nil +} + +// FinishReason returns the finish reason if present. +func (m *Message) FinishReason() FinishReason { + for _, part := range m.Parts { + if c, ok := part.(Finish); ok { + return c.Reason + } + } + return "" +} + +// IsThinking returns true if the message is currently in a thinking state. +func (m *Message) IsThinking() bool { + return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() +} + +// AppendContent appends text to the text content part. +func (m *Message) AppendContent(delta string) { + found := false + for i, part := range m.Parts { + if c, ok := part.(TextContent); ok { + m.Parts[i] = TextContent{Text: c.Text + delta} + found = true + } + } + if !found { + m.Parts = append(m.Parts, TextContent{Text: delta}) + } +} + +// AppendReasoningContent appends text to the reasoning content part. +func (m *Message) AppendReasoningContent(delta string) { + found := false + for i, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + m.Parts[i] = ReasoningContent{ + Thinking: c.Thinking + delta, + Signature: c.Signature, + StartedAt: c.StartedAt, + FinishedAt: c.FinishedAt, + } + found = true + } + } + if !found { + m.Parts = append(m.Parts, ReasoningContent{ + Thinking: delta, + StartedAt: time.Now().Unix(), + }) + } +} + +// AppendReasoningSignature appends a signature to the reasoning content part. +func (m *Message) AppendReasoningSignature(signature string) { + for i, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + m.Parts[i] = ReasoningContent{ + Thinking: c.Thinking, + Signature: c.Signature + signature, + StartedAt: c.StartedAt, + FinishedAt: c.FinishedAt, + } + return + } + } + m.Parts = append(m.Parts, ReasoningContent{Signature: signature}) +} + +// FinishThinking marks the reasoning content as finished. +func (m *Message) FinishThinking() { + for i, part := range m.Parts { + if c, ok := part.(ReasoningContent); ok { + if c.FinishedAt == 0 { + m.Parts[i] = ReasoningContent{ + Thinking: c.Thinking, + Signature: c.Signature, + StartedAt: c.StartedAt, + FinishedAt: time.Now().Unix(), + } + } + return + } + } +} + +// ThinkingDuration returns the duration of the thinking phase. +func (m *Message) ThinkingDuration() time.Duration { + reasoning := m.ReasoningContent() + if reasoning.StartedAt == 0 { + return 0 + } + + endTime := reasoning.FinishedAt + if endTime == 0 { + endTime = time.Now().Unix() + } + + return time.Duration(endTime-reasoning.StartedAt) * time.Second +} + +// FinishToolCall marks a tool call as finished. +func (m *Message) FinishToolCall(toolCallID string) { + for i, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + if c.ID == toolCallID { + m.Parts[i] = ToolCall{ + ID: c.ID, + Name: c.Name, + Input: c.Input, + Type: c.Type, + Finished: true, + } + return + } + } + } +} + +// AppendToolCallInput appends input to a tool call. +func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) { + for i, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + if c.ID == toolCallID { + m.Parts[i] = ToolCall{ + ID: c.ID, + Name: c.Name, + Input: c.Input + inputDelta, + Type: c.Type, + Finished: c.Finished, + } + return + } + } + } +} + +// AddToolCall adds or updates a tool call. +func (m *Message) AddToolCall(tc ToolCall) { + for i, part := range m.Parts { + if c, ok := part.(ToolCall); ok { + if c.ID == tc.ID { + m.Parts[i] = tc + return + } + } + } + m.Parts = append(m.Parts, tc) +} + +// SetToolCalls replaces all tool call parts. +func (m *Message) SetToolCalls(tc []ToolCall) { + parts := make([]ContentPart, 0) + for _, part := range m.Parts { + if _, ok := part.(ToolCall); ok { + continue + } + parts = append(parts, part) + } + m.Parts = parts + for _, toolCall := range tc { + m.Parts = append(m.Parts, toolCall) + } +} + +// AddToolResult adds a tool result. +func (m *Message) AddToolResult(tr ToolResult) { + m.Parts = append(m.Parts, tr) +} + +// SetToolResults adds multiple tool results. +func (m *Message) SetToolResults(tr []ToolResult) { + for _, toolResult := range tr { + m.Parts = append(m.Parts, toolResult) + } +} + +// AddFinish adds a finish part to the message. +func (m *Message) AddFinish(reason FinishReason, message, details string) { + for i, part := range m.Parts { + if _, ok := part.(Finish); ok { + m.Parts = slices.Delete(m.Parts, i, i+1) + break + } + } + m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details}) +} + +// AddImageURL adds an image URL part to the message. +func (m *Message) AddImageURL(url, detail string) { + m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail}) +} + +// AddBinary adds a binary content part to the message. +func (m *Message) AddBinary(mimeType string, data []byte) { + m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data}) +} + +type partType string + +const ( + reasoningType partType = "reasoning" + textType partType = "text" + imageURLType partType = "image_url" + binaryType partType = "binary" + toolCallType partType = "tool_call" + toolResultType partType = "tool_result" + finishType partType = "finish" +) + +type partWrapper struct { + Type partType `json:"type"` + Data ContentPart `json:"data"` +} + +// MarshalParts marshals content parts to JSON. +func MarshalParts(parts []ContentPart) ([]byte, error) { + wrappedParts := make([]partWrapper, len(parts)) + + for i, part := range parts { + var typ partType + + switch part.(type) { + case ReasoningContent: + typ = reasoningType + case TextContent: + typ = textType + case ImageURLContent: + typ = imageURLType + case BinaryContent: + typ = binaryType + case ToolCall: + typ = toolCallType + case ToolResult: + typ = toolResultType + case Finish: + typ = finishType + default: + return nil, fmt.Errorf("unknown part type: %T", part) + } + + wrappedParts[i] = partWrapper{ + Type: typ, + Data: part, + } + } + return json.Marshal(wrappedParts) +} + +// UnmarshalParts unmarshals content parts from JSON. +func UnmarshalParts(data []byte) ([]ContentPart, error) { + temp := []json.RawMessage{} + + if err := json.Unmarshal(data, &temp); err != nil { + return nil, err + } + + parts := make([]ContentPart, 0) + + for _, rawPart := range temp { + var wrapper struct { + Type partType `json:"type"` + Data json.RawMessage `json:"data"` + } + + if err := json.Unmarshal(rawPart, &wrapper); err != nil { + return nil, err + } + + switch wrapper.Type { + case reasoningType: + part := ReasoningContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case textType: + part := TextContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case imageURLType: + part := ImageURLContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case binaryType: + part := BinaryContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case toolCallType: + part := ToolCall{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case toolResultType: + part := ToolResult{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case finishType: + part := Finish{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + default: + return nil, fmt.Errorf("unknown part type: %s", wrapper.Type) + } + } + + return parts, nil +} + +// Attachment represents a file attachment. +type Attachment struct { + FilePath string `json:"file_path"` + FileName string `json:"file_name"` + MimeType string `json:"mime_type"` + Content []byte `json:"content"` +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (a Attachment) MarshalJSON() ([]byte, error) { + type Alias Attachment + return json.Marshal(&struct { + Content string `json:"content"` + *Alias + }{ + Content: base64.StdEncoding.EncodeToString(a.Content), + Alias: (*Alias)(&a), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (a *Attachment) UnmarshalJSON(data []byte) error { + type Alias Attachment + aux := &struct { + Content string `json:"content"` + *Alias + }{ + Alias: (*Alias)(a), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + content, err := base64.StdEncoding.DecodeString(aux.Content) + if err != nil { + return err + } + a.Content = content + return nil +} diff --git a/internal/proto/permission.go b/internal/proto/permission.go new file mode 100644 index 0000000000000000000000000000000000000000..5834de628e41a290d0bc391fbe3ead2505eb742a --- /dev/null +++ b/internal/proto/permission.go @@ -0,0 +1,141 @@ +package proto + +import ( + "encoding/json" +) + +// CreatePermissionRequest represents a request to create a permission. +type CreatePermissionRequest struct { + SessionID string `json:"session_id"` + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Description string `json:"description"` + Action string `json:"action"` + Params any `json:"params"` + Path string `json:"path"` +} + +// PermissionNotification represents a notification about a permission change. +type PermissionNotification struct { + ToolCallID string `json:"tool_call_id"` + Granted bool `json:"granted"` + Denied bool `json:"denied"` +} + +// PermissionRequest represents a pending permission request. +type PermissionRequest struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Description string `json:"description"` + Action string `json:"action"` + Params any `json:"params"` + Path string `json:"path"` +} + +// UnmarshalJSON implements the json.Unmarshaler interface. This is needed +// because the Params field is of type any, so we need to unmarshal it into +// its appropriate type based on the [PermissionRequest.ToolName]. +func (p *PermissionRequest) UnmarshalJSON(data []byte) error { + type Alias PermissionRequest + aux := &struct { + Params json.RawMessage `json:"params"` + *Alias + }{ + Alias: (*Alias)(p), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + params, err := unmarshalToolParams(p.ToolName, aux.Params) + if err != nil { + return err + } + p.Params = params + return nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. This is needed +// because the Params field is of type any, so we need to unmarshal it into +// its appropriate type based on the [CreatePermissionRequest.ToolName]. +func (p *CreatePermissionRequest) UnmarshalJSON(data []byte) error { + type Alias CreatePermissionRequest + aux := &struct { + Params json.RawMessage `json:"params"` + *Alias + }{ + Alias: (*Alias)(p), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + params, err := unmarshalToolParams(p.ToolName, aux.Params) + if err != nil { + return err + } + p.Params = params + return nil +} + +func unmarshalToolParams(toolName string, raw json.RawMessage) (any, error) { + switch toolName { + case BashToolName: + var params BashPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case DownloadToolName: + var params DownloadPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case EditToolName: + var params EditPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case WriteToolName: + var params WritePermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case MultiEditToolName: + var params MultiEditPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case FetchToolName: + var params FetchPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case ViewToolName: + var params ViewPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + case LSToolName: + var params LSPermissionsParams + if err := json.Unmarshal(raw, ¶ms); err != nil { + return nil, err + } + return params, nil + default: + // For unknown tools, keep the raw JSON as-is. + var generic map[string]any + if err := json.Unmarshal(raw, &generic); err != nil { + return nil, err + } + return generic, nil + } +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go new file mode 100644 index 0000000000000000000000000000000000000000..d7477b580a8a0027ac6af1874eec2a117f587901 --- /dev/null +++ b/internal/proto/proto.go @@ -0,0 +1,123 @@ +package proto + +import ( + "time" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/lsp" +) + +// Workspace represents a running app.App workspace with its associated +// resources and state. +type Workspace struct { + ID string `json:"id"` + Path string `json:"path"` + YOLO bool `json:"yolo,omitempty"` + Debug bool `json:"debug,omitempty"` + DataDir string `json:"data_dir,omitempty"` + Config *config.Config `json:"config,omitempty"` + Env []string `json:"env,omitempty"` +} + +// Error represents an error response. +type Error struct { + Message string `json:"message"` +} + +// AgentInfo represents information about the agent. +type AgentInfo struct { + IsBusy bool `json:"is_busy"` + Model catwalk.Model `json:"model"` +} + +// IsZero checks if the AgentInfo is zero-valued. +func (a AgentInfo) IsZero() bool { + return !a.IsBusy && a.Model.ID == "" +} + +// AgentMessage represents a message sent to the agent. +type AgentMessage struct { + SessionID string `json:"session_id"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` +} + +// AgentSession represents a session with its busy status. +type AgentSession struct { + Session + IsBusy bool `json:"is_busy"` +} + +// IsZero checks if the AgentSession is zero-valued. +func (a AgentSession) IsZero() bool { + return a == AgentSession{} +} + +// PermissionAction represents an action taken on a permission request. +type PermissionAction string + +const ( + PermissionAllow PermissionAction = "allow" + PermissionAllowForSession PermissionAction = "allow_session" + PermissionDeny PermissionAction = "deny" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (p PermissionAction) MarshalText() ([]byte, error) { + return []byte(p), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (p *PermissionAction) UnmarshalText(text []byte) error { + *p = PermissionAction(text) + return nil +} + +// PermissionGrant represents a permission grant request. +type PermissionGrant struct { + Permission PermissionRequest `json:"permission"` + Action PermissionAction `json:"action"` +} + +// PermissionSkipRequest represents a request to skip permission prompts. +type PermissionSkipRequest struct { + Skip bool `json:"skip"` +} + +// LSPEventType represents the type of LSP event. +type LSPEventType string + +const ( + LSPEventStateChanged LSPEventType = "state_changed" + LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (e LSPEventType) MarshalText() ([]byte, error) { + return []byte(e), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (e *LSPEventType) UnmarshalText(data []byte) error { + *e = LSPEventType(data) + return nil +} + +// LSPEvent represents an event in the LSP system. +type LSPEvent struct { + Type LSPEventType `json:"type"` + Name string `json:"name"` + State lsp.ServerState `json:"state"` + Error error `json:"error,omitempty"` + DiagnosticCount int `json:"diagnostic_count,omitempty"` +} + +// LSPClientInfo holds information about an LSP client's state. +type LSPClientInfo struct { + Name string `json:"name"` + State lsp.ServerState `json:"state"` + Error error `json:"error,omitempty"` + DiagnosticCount int `json:"diagnostic_count,omitempty"` + ConnectedAt time.Time `json:"connected_at"` +} diff --git a/internal/proto/server.go b/internal/proto/server.go new file mode 100644 index 0000000000000000000000000000000000000000..612772381a58aad04a5a3d1bc1216f6dd8882769 --- /dev/null +++ b/internal/proto/server.go @@ -0,0 +1,6 @@ +package proto + +// ServerControl represents a server control request. +type ServerControl struct { + Command string `json:"command"` +} diff --git a/internal/proto/session.go b/internal/proto/session.go new file mode 100644 index 0000000000000000000000000000000000000000..846ac592017e6ce447c6c6a94535d9317adad7d8 --- /dev/null +++ b/internal/proto/session.go @@ -0,0 +1,15 @@ +package proto + +// Session represents a session in the proto layer. +type Session struct { + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + SummaryMessageID string `json:"summary_message_id"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} diff --git a/internal/proto/tools.go b/internal/proto/tools.go new file mode 100644 index 0000000000000000000000000000000000000000..09774ac0a22b672ff7df81d968db21ef35517c02 --- /dev/null +++ b/internal/proto/tools.go @@ -0,0 +1,250 @@ +package proto + +// ToolResponseType represents the type of tool response. +type ToolResponseType string + +const ( + ToolResponseTypeText ToolResponseType = "text" + ToolResponseTypeImage ToolResponseType = "image" +) + +// ToolResponse represents a response from a tool. +type ToolResponse struct { + Type ToolResponseType `json:"type"` + Content string `json:"content"` + Metadata string `json:"metadata,omitempty"` + IsError bool `json:"is_error"` +} + +const BashToolName = "bash" + +// BashParams represents the parameters for the bash tool. +type BashParams struct { + Command string `json:"command"` + Timeout int `json:"timeout"` +} + +// BashPermissionsParams represents the permission parameters for the bash tool. +type BashPermissionsParams struct { + Command string `json:"command"` + Timeout int `json:"timeout"` +} + +// BashResponseMetadata represents the metadata for a bash tool response. +type BashResponseMetadata struct { + StartTime int64 `json:"start_time"` + EndTime int64 `json:"end_time"` + Output string `json:"output"` + WorkingDirectory string `json:"working_directory"` +} + +// DiagnosticsParams represents the parameters for the diagnostics tool. +type DiagnosticsParams struct { + FilePath string `json:"file_path"` +} + +const DownloadToolName = "download" + +// DownloadParams represents the parameters for the download tool. +type DownloadParams struct { + URL string `json:"url"` + FilePath string `json:"file_path"` + Timeout int `json:"timeout,omitempty"` +} + +// DownloadPermissionsParams represents the permission parameters for the download tool. +type DownloadPermissionsParams struct { + URL string `json:"url"` + FilePath string `json:"file_path"` + Timeout int `json:"timeout,omitempty"` +} + +const EditToolName = "edit" + +// EditParams represents the parameters for the edit tool. +type EditParams struct { + FilePath string `json:"file_path"` + OldString string `json:"old_string"` + NewString string `json:"new_string"` + ReplaceAll bool `json:"replace_all,omitempty"` +} + +// EditPermissionsParams represents the permission parameters for the edit tool. +type EditPermissionsParams struct { + FilePath string `json:"file_path"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` +} + +// EditResponseMetadata represents the metadata for an edit tool response. +type EditResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` +} + +const FetchToolName = "fetch" + +// FetchParams represents the parameters for the fetch tool. +type FetchParams struct { + URL string `json:"url"` + Format string `json:"format"` + Timeout int `json:"timeout,omitempty"` +} + +// FetchPermissionsParams represents the permission parameters for the fetch tool. +type FetchPermissionsParams struct { + URL string `json:"url"` + Format string `json:"format"` + Timeout int `json:"timeout,omitempty"` +} + +const GlobToolName = "glob" + +// GlobParams represents the parameters for the glob tool. +type GlobParams struct { + Pattern string `json:"pattern"` + Path string `json:"path"` +} + +// GlobResponseMetadata represents the metadata for a glob tool response. +type GlobResponseMetadata struct { + NumberOfFiles int `json:"number_of_files"` + Truncated bool `json:"truncated"` +} + +const GrepToolName = "grep" + +// GrepParams represents the parameters for the grep tool. +type GrepParams struct { + Pattern string `json:"pattern"` + Path string `json:"path"` + Include string `json:"include"` + LiteralText bool `json:"literal_text"` +} + +// GrepResponseMetadata represents the metadata for a grep tool response. +type GrepResponseMetadata struct { + NumberOfMatches int `json:"number_of_matches"` + Truncated bool `json:"truncated"` +} + +const LSToolName = "ls" + +// LSParams represents the parameters for the ls tool. +type LSParams struct { + Path string `json:"path"` + Ignore []string `json:"ignore"` +} + +// LSPermissionsParams represents the permission parameters for the ls tool. +type LSPermissionsParams struct { + Path string `json:"path"` + Ignore []string `json:"ignore"` +} + +// TreeNode represents a node in a directory tree. +type TreeNode struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + Children []*TreeNode `json:"children,omitempty"` +} + +// LSResponseMetadata represents the metadata for an ls tool response. +type LSResponseMetadata struct { + NumberOfFiles int `json:"number_of_files"` + Truncated bool `json:"truncated"` +} + +const MultiEditToolName = "multiedit" + +// MultiEditOperation represents a single edit operation in a multi-edit. +type MultiEditOperation struct { + OldString string `json:"old_string"` + NewString string `json:"new_string"` + ReplaceAll bool `json:"replace_all,omitempty"` +} + +// MultiEditParams represents the parameters for the multi-edit tool. +type MultiEditParams struct { + FilePath string `json:"file_path"` + Edits []MultiEditOperation `json:"edits"` +} + +// MultiEditPermissionsParams represents the permission parameters for the multi-edit tool. +type MultiEditPermissionsParams struct { + FilePath string `json:"file_path"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` +} + +// MultiEditResponseMetadata represents the metadata for a multi-edit tool response. +type MultiEditResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` + EditsApplied int `json:"edits_applied"` +} + +const SourcegraphToolName = "sourcegraph" + +// SourcegraphParams represents the parameters for the sourcegraph tool. +type SourcegraphParams struct { + Query string `json:"query"` + Count int `json:"count,omitempty"` + ContextWindow int `json:"context_window,omitempty"` + Timeout int `json:"timeout,omitempty"` +} + +// SourcegraphResponseMetadata represents the metadata for a sourcegraph tool response. +type SourcegraphResponseMetadata struct { + NumberOfMatches int `json:"number_of_matches"` + Truncated bool `json:"truncated"` +} + +const ViewToolName = "view" + +// ViewParams represents the parameters for the view tool. +type ViewParams struct { + FilePath string `json:"file_path"` + Offset int `json:"offset"` + Limit int `json:"limit"` +} + +// ViewPermissionsParams represents the permission parameters for the view tool. +type ViewPermissionsParams struct { + FilePath string `json:"file_path"` + Offset int `json:"offset"` + Limit int `json:"limit"` +} + +// ViewResponseMetadata represents the metadata for a view tool response. +type ViewResponseMetadata struct { + FilePath string `json:"file_path"` + Content string `json:"content"` +} + +const WriteToolName = "write" + +// WriteParams represents the parameters for the write tool. +type WriteParams struct { + FilePath string `json:"file_path"` + Content string `json:"content"` +} + +// WritePermissionsParams represents the permission parameters for the write tool. +type WritePermissionsParams struct { + FilePath string `json:"file_path"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` +} + +// WriteResponseMetadata represents the metadata for a write tool response. +type WriteResponseMetadata struct { + Diff string `json:"diff"` + Additions int `json:"additions"` + Removals int `json:"removals"` +} diff --git a/internal/proto/version.go b/internal/proto/version.go new file mode 100644 index 0000000000000000000000000000000000000000..b728a8b966068a7810f86aae74cfcc6f57e03d39 --- /dev/null +++ b/internal/proto/version.go @@ -0,0 +1,9 @@ +package proto + +// VersionInfo represents version information about the server. +type VersionInfo struct { + Version string `json:"version"` + Commit string `json:"commit"` + GoVersion string `json:"go_version"` + Platform string `json:"platform"` +} diff --git a/internal/pubsub/events.go b/internal/pubsub/events.go index 827158d52fd671aeda828c0383fce98850e27fc7..44963e3cfbdefc2ddc4657c293615df5329d885d 100644 --- a/internal/pubsub/events.go +++ b/internal/pubsub/events.go @@ -1,6 +1,9 @@ package pubsub -import "context" +import ( + "context" + "encoding/json" +) const ( CreatedEvent EventType = "created" @@ -8,20 +11,43 @@ const ( DeletedEvent EventType = "deleted" ) +// PayloadType identifies the type of event payload for discriminated +// deserialization over JSON. +type PayloadType = string + +const ( + PayloadTypeLSPEvent PayloadType = "lsp_event" + PayloadTypeMCPEvent PayloadType = "mcp_event" + PayloadTypePermissionRequest PayloadType = "permission_request" + PayloadTypePermissionNotification PayloadType = "permission_notification" + PayloadTypeMessage PayloadType = "message" + PayloadTypeSession PayloadType = "session" + PayloadTypeFile PayloadType = "file" + PayloadTypeAgentEvent PayloadType = "agent_event" +) + +// Payload wraps a discriminated JSON payload with a type tag. +type Payload struct { + Type PayloadType `json:"type"` + Payload json.RawMessage `json:"payload"` +} + +// Subscriber can subscribe to events of type T. type Subscriber[T any] interface { Subscribe(context.Context) <-chan Event[T] } type ( - // EventType identifies the type of event + // EventType identifies the type of event. EventType string - // Event represents an event in the lifecycle of a resource + // Event represents an event in the lifecycle of a resource. Event[T any] struct { - Type EventType - Payload T + Type EventType `json:"type"` + Payload T `json:"payload"` } + // Publisher can publish events of type T. Publisher[T any] interface { Publish(EventType, T) } diff --git a/internal/server/logging.go b/internal/server/logging.go new file mode 100644 index 0000000000000000000000000000000000000000..736e3d57cfb6697a07cc61a03c4157a42140df54 --- /dev/null +++ b/internal/server/logging.go @@ -0,0 +1,51 @@ +package server + +import ( + "log/slog" + "net/http" + "time" +) + +func (s *Server) loggingHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if s.logger == nil { + next.ServeHTTP(w, r) + return + } + + start := time.Now() + lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK} + s.logger.Debug("HTTP request", + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.String("remote_addr", r.RemoteAddr), + slog.String("user_agent", r.UserAgent()), + ) + + next.ServeHTTP(lrw, r) + duration := time.Since(start) + + s.logger.Debug("HTTP response", + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", lrw.statusCode), + slog.Duration("duration", duration), + slog.String("remote_addr", r.RemoteAddr), + slog.String("user_agent", r.UserAgent()), + ) + }) +} + +type loggingResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func (lrw *loggingResponseWriter) WriteHeader(code int) { + lrw.statusCode = code + lrw.ResponseWriter.WriteHeader(code) +} + +func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter { + return lrw.ResponseWriter +} diff --git a/internal/server/net_other.go b/internal/server/net_other.go new file mode 100644 index 0000000000000000000000000000000000000000..b1fba90cf306b45c5764eb3702d2da642122ca69 --- /dev/null +++ b/internal/server/net_other.go @@ -0,0 +1,10 @@ +//go:build !windows +// +build !windows + +package server + +import "net" + +func listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} diff --git a/internal/server/net_windows.go b/internal/server/net_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..fc1ed8e2c298b740b3611b42a7485cb880ff76ca --- /dev/null +++ b/internal/server/net_windows.go @@ -0,0 +1,24 @@ +//go:build windows +// +build windows + +package server + +import ( + "net" + + "github.com/Microsoft/go-winio" +) + +func listen(network, address string) (net.Listener, error) { + switch network { + case "npipe": + cfg := &winio.PipeConfig{ + MessageMode: true, + InputBufferSize: 65536, + OutputBufferSize: 65536, + } + return winio.ListenPipe(address, cfg) + default: + return net.Listen(network, address) + } +} diff --git a/internal/server/proto.go b/internal/server/proto.go new file mode 100644 index 0000000000000000000000000000000000000000..588a485e362e07e67d4785c31101590c75320382 --- /dev/null +++ b/internal/server/proto.go @@ -0,0 +1,656 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os" + "path/filepath" + "runtime" + + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/crush/internal/version" + "github.com/google/uuid" +) + +type controllerV1 struct { + *Server +} + +func (c *controllerV1) handleGetHealth(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetVersion(w http.ResponseWriter, _ *http.Request) { + jsonEncode(w, proto.VersionInfo{ + Version: version.Version, + Commit: version.Commit, + GoVersion: runtime.Version(), + Platform: fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH), + }) +} + +func (c *controllerV1) handlePostControl(w http.ResponseWriter, r *http.Request) { + var req proto.ServerControl + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + switch req.Command { + case "shutdown": + go func() { + slog.Info("Shutting down server...") + if err := c.Shutdown(context.Background()); err != nil { + slog.Error("Failed to shutdown server", "error", err) + } + }() + default: + c.logError(r, "Unknown command", "command", req.Command) + jsonError(w, http.StatusBadRequest, "unknown command") + return + } +} + +func (c *controllerV1) handleGetConfig(w http.ResponseWriter, _ *http.Request) { + jsonEncode(w, c.cfg) +} + +func (c *controllerV1) handleGetWorkspaces(w http.ResponseWriter, _ *http.Request) { + workspaces := []proto.Workspace{} + for _, ws := range c.workspaces.Seq2() { + workspaces = append(workspaces, proto.Workspace{ + ID: ws.id, + Path: ws.path, + YOLO: ws.cfg.Permissions != nil && ws.cfg.Permissions.SkipRequests, + DataDir: ws.cfg.Options.DataDirectory, + Debug: ws.cfg.Options.Debug, + Config: ws.cfg, + }) + } + jsonEncode(w, workspaces) +} + +func (c *controllerV1) handleGetWorkspaceLSPDiagnostics(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + lspName := r.PathValue("lsp") + var found bool + for name, client := range ws.LSPManager.Clients().Seq2() { + if name == lspName { + diagnostics := client.GetDiagnostics() + jsonEncode(w, diagnostics) + found = true + break + } + } + + if !found { + c.logError(r, "LSP client not found", "id", id, "lsp", lspName) + jsonError(w, http.StatusNotFound, "LSP client not found") + } +} + +func (c *controllerV1) handleGetWorkspaceLSPs(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + _, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + lspClients := app.GetLSPStates() + jsonEncode(w, lspClients) +} + +func (c *controllerV1) handleGetWorkspaceAgentSessionPromptQueued(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + queued := ws.App.AgentCoordinator.QueuedPrompts(sid) + jsonEncode(w, queued) +} + +func (c *controllerV1) handlePostWorkspaceAgentSessionPromptClear(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + ws.App.AgentCoordinator.ClearQueue(sid) + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetWorkspaceAgentSessionSummarize(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + if err := ws.App.AgentCoordinator.Summarize(r.Context(), sid); err != nil { + c.logError(r, "Failed to summarize session", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to summarize session") + return + } +} + +func (c *controllerV1) handlePostWorkspaceAgentSessionCancel(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + if ws.App.AgentCoordinator != nil { + ws.App.AgentCoordinator.Cancel(sid) + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetWorkspaceAgentSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + se, err := ws.App.Sessions.Get(r.Context(), sid) + if err != nil { + c.logError(r, "Failed to get session", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to get session") + return + } + + var isSessionBusy bool + if ws.App.AgentCoordinator != nil { + isSessionBusy = ws.App.AgentCoordinator.IsSessionBusy(sid) + } + + jsonEncode(w, proto.AgentSession{ + Session: proto.Session{ + ID: se.ID, + Title: se.Title, + }, + IsBusy: isSessionBusy, + }) +} + +func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + w.Header().Set("Accept", "application/json") + + var msg proto.AgentMessage + if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { + c.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if ws.App.AgentCoordinator == nil { + c.logError(r, "Agent coordinator not initialized", "id", id) + jsonError(w, http.StatusBadRequest, "agent coordinator not initialized") + return + } + + if _, err := ws.App.AgentCoordinator.Run(c.ctx, msg.SessionID, msg.Prompt); err != nil { + c.logError(r, "Failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID) + jsonError(w, http.StatusInternalServerError, "failed to enqueue message") + return + } +} + +func (c *controllerV1) handleGetWorkspaceAgent(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + var agentInfo proto.AgentInfo + if ws.App.AgentCoordinator != nil { + m := ws.App.AgentCoordinator.Model() + agentInfo = proto.AgentInfo{ + Model: m.CatwalkCfg, + IsBusy: ws.App.AgentCoordinator.IsBusy(), + } + } + jsonEncode(w, agentInfo) +} + +func (c *controllerV1) handlePostWorkspaceAgentUpdate(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + if err := ws.App.UpdateAgentModel(r.Context()); err != nil { + c.logError(r, "Failed to update agent model", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to update agent model") + return + } +} + +func (c *controllerV1) handlePostWorkspaceAgentInit(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + if err := ws.App.InitCoderAgent(r.Context()); err != nil { + c.logError(r, "Failed to initialize coder agent", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to initialize coder agent") + return + } +} + +func (c *controllerV1) handleGetWorkspaceSessionHistory(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + historyItems, err := ws.App.History.ListBySession(r.Context(), sid) + if err != nil { + c.logError(r, "Failed to list history", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to list history") + return + } + + jsonEncode(w, historyItems) +} + +func (c *controllerV1) handleGetWorkspaceSessionMessages(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + messages, err := ws.App.Messages.List(r.Context(), sid) + if err != nil { + c.logError(r, "Failed to list messages", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to list messages") + return + } + + jsonEncode(w, messages) +} + +func (c *controllerV1) handleGetWorkspaceSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sid := r.PathValue("sid") + sess, err := ws.App.Sessions.Get(r.Context(), sid) + if err != nil { + c.logError(r, "Failed to get session", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to get session") + return + } + + jsonEncode(w, sess) +} + +func (c *controllerV1) handlePostWorkspaceSessions(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + var args session.Session + if err := json.NewDecoder(r.Body).Decode(&args); err != nil { + c.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + sess, err := ws.App.Sessions.Create(r.Context(), args.Title) + if err != nil { + c.logError(r, "Failed to create session", "error", err, "id", id) + jsonError(w, http.StatusInternalServerError, "failed to create session") + return + } + + jsonEncode(w, sess) +} + +func (c *controllerV1) handleGetWorkspaceSessions(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + sessions, err := ws.App.Sessions.List(r.Context()) + if err != nil { + c.logError(r, "Failed to list sessions", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to list sessions") + return + } + + jsonEncode(w, sessions) +} + +func (c *controllerV1) handlePostWorkspacePermissionsGrant(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + var req proto.PermissionGrant + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + perm := permission.PermissionRequest{ + ID: req.Permission.ID, + SessionID: req.Permission.SessionID, + ToolCallID: req.Permission.ToolCallID, + ToolName: req.Permission.ToolName, + Description: req.Permission.Description, + Action: req.Permission.Action, + Params: req.Permission.Params, + Path: req.Permission.Path, + } + + switch req.Action { + case proto.PermissionAllow: + ws.App.Permissions.Grant(perm) + case proto.PermissionAllowForSession: + ws.App.Permissions.GrantPersistent(perm) + case proto.PermissionDeny: + ws.App.Permissions.Deny(perm) + default: + c.logError(r, "Invalid permission action", "action", req.Action) + jsonError(w, http.StatusBadRequest, "invalid permission action") + return + } +} + +func (c *controllerV1) handlePostWorkspacePermissionsSkip(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + var req proto.PermissionSkipRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + ws.App.Permissions.SetSkipRequests(req.Skip) +} + +func (c *controllerV1) handleGetWorkspacePermissionsSkip(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + skip := ws.App.Permissions.SkipRequests() + jsonEncode(w, proto.PermissionSkipRequest{Skip: skip}) +} + +func (c *controllerV1) handleGetWorkspaceProviders(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + providers, _ := config.Providers(ws.cfg) + jsonEncode(w, providers) +} + +func (c *controllerV1) handleGetWorkspaceEvents(w http.ResponseWriter, r *http.Request) { + flusher := http.NewResponseController(w) + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + events := ws.App.Events() + + for { + select { + case <-r.Context().Done(): + c.logDebug(r, "Stopping event stream") + return + case ev, ok := <-events: + if !ok { + return + } + c.logDebug(r, "Sending event", "event", fmt.Sprintf("%T %+v", ev, ev)) + data, err := json.Marshal(ev) + if err != nil { + c.logError(r, "Failed to marshal event", "error", err) + continue + } + + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + } +} + +func (c *controllerV1) handleGetWorkspaceConfig(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + jsonEncode(w, ws.cfg) +} + +func (c *controllerV1) handleDeleteWorkspaces(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if ok { + ws.App.Shutdown() + } + c.workspaces.Del(id) +} + +func (c *controllerV1) handleGetWorkspace(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ws, ok := c.workspaces.Get(id) + if !ok { + c.logError(r, "Workspace not found", "id", id) + jsonError(w, http.StatusNotFound, "workspace not found") + return + } + + jsonEncode(w, proto.Workspace{ + ID: ws.id, + Path: ws.path, + YOLO: ws.cfg.Permissions != nil && ws.cfg.Permissions.SkipRequests, + DataDir: ws.cfg.Options.DataDirectory, + Debug: ws.cfg.Options.Debug, + Config: ws.cfg, + }) +} + +func (c *controllerV1) handlePostWorkspaces(w http.ResponseWriter, r *http.Request) { + var args proto.Workspace + if err := json.NewDecoder(r.Body).Decode(&args); err != nil { + c.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if args.Path == "" { + c.logError(r, "Path is required") + jsonError(w, http.StatusBadRequest, "path is required") + return + } + + id := uuid.New().String() + cfg, err := config.Init(args.Path, args.DataDir, args.Debug) + if err != nil { + c.logError(r, "Failed to initialize config", "error", err) + jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err)) + return + } + + if cfg.Permissions == nil { + cfg.Permissions = &config.Permissions{} + } + cfg.Permissions.SkipRequests = args.YOLO + + if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil { + c.logError(r, "Failed to create data directory", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to create data directory") + return + } + + conn, err := db.Connect(c.ctx, cfg.Options.DataDirectory) + if err != nil { + c.logError(r, "Failed to connect to database", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to connect to database") + return + } + + appWorkspace, err := app.New(c.ctx, conn, cfg) + if err != nil { + slog.Error("Failed to create app workspace", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to create app workspace") + return + } + + ws := &Workspace{ + App: appWorkspace, + id: id, + path: args.Path, + cfg: cfg, + env: args.Env, + } + + c.workspaces.Set(id, ws) + jsonEncode(w, proto.Workspace{ + ID: id, + Path: args.Path, + DataDir: cfg.Options.DataDirectory, + Debug: cfg.Options.Debug, + YOLO: cfg.Permissions.SkipRequests, + Config: cfg, + Env: args.Env, + }) +} + +func createDotCrushDir(dir string) error { + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("failed to create data directory: %q %w", dir, err) + } + + gitIgnorePath := filepath.Join(dir, ".gitignore") + if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) { + if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil { + return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err) + } + } + + return nil +} + +func jsonEncode(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(v) +} + +func jsonError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(proto.Error{Message: message}) +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000000000000000000000000000000000000..38fa9109a5257b00cdfd1d7ab2b9fc7c7ad9fc6f --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,208 @@ +package server + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/http" + "net/url" + "os/user" + "runtime" + "strings" + + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" +) + +// ErrServerClosed is returned when the server is closed. +var ErrServerClosed = http.ErrServerClosed + +// Workspace represents a running [app.App] workspace with its associated +// resources and state. +type Workspace struct { + *app.App + ln net.Listener + cfg *config.Config + id string + path string + env []string +} + +// ParseHostURL parses a host URL into a [url.URL]. +func ParseHostURL(host string) (*url.URL, error) { + proto, addr, ok := strings.Cut(host, "://") + if !ok { + return nil, fmt.Errorf("invalid host format: %s", host) + } + + var basePath string + if proto == "tcp" { + parsed, err := url.Parse("tcp://" + addr) + if err != nil { + return nil, fmt.Errorf("invalid tcp address: %v", err) + } + addr = parsed.Host + basePath = parsed.Path + } + return &url.URL{ + Scheme: proto, + Host: addr, + Path: basePath, + }, nil +} + +// DefaultHost returns the default server host. +func DefaultHost() string { + sock := "crush.sock" + usr, err := user.Current() + if err == nil && usr.Uid != "" { + sock = fmt.Sprintf("crush-%s.sock", usr.Uid) + } + if runtime.GOOS == "windows" { + return fmt.Sprintf("npipe:////./pipe/%s", sock) + } + return fmt.Sprintf("unix:///tmp/%s", sock) +} + +// Server represents a Crush server bound to a specific address. +type Server struct { + // Addr can be a TCP address, a Unix socket path, or a Windows named pipe. + Addr string + network string + + h *http.Server + ln net.Listener + ctx context.Context + + // workspaces is a map of running applications managed by the server. + workspaces *csync.Map[string, *Workspace] + cfg *config.Config + logger *slog.Logger +} + +// SetLogger sets the logger for the server. +func (s *Server) SetLogger(logger *slog.Logger) { + s.logger = logger +} + +// DefaultServer returns a new [Server] with the default address. +func DefaultServer(cfg *config.Config) *Server { + hostURL, err := ParseHostURL(DefaultHost()) + if err != nil { + panic("invalid default host") + } + return NewServer(cfg, hostURL.Scheme, hostURL.Host) +} + +// NewServer creates a new [Server] with the given network and address. +func NewServer(cfg *config.Config, network, address string) *Server { + s := new(Server) + s.Addr = address + s.network = network + s.cfg = cfg + s.workspaces = csync.NewMap[string, *Workspace]() + s.ctx = context.Background() + + var p http.Protocols + p.SetHTTP1(true) + p.SetUnencryptedHTTP2(true) + c := &controllerV1{Server: s} + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/health", c.handleGetHealth) + mux.HandleFunc("GET /v1/version", c.handleGetVersion) + mux.HandleFunc("GET /v1/config", c.handleGetConfig) + mux.HandleFunc("POST /v1/control", c.handlePostControl) + mux.HandleFunc("GET /v1/workspaces", c.handleGetWorkspaces) + mux.HandleFunc("POST /v1/workspaces", c.handlePostWorkspaces) + mux.HandleFunc("DELETE /v1/workspaces/{id}", c.handleDeleteWorkspaces) + mux.HandleFunc("GET /v1/workspaces/{id}", c.handleGetWorkspace) + mux.HandleFunc("GET /v1/workspaces/{id}/config", c.handleGetWorkspaceConfig) + mux.HandleFunc("GET /v1/workspaces/{id}/events", c.handleGetWorkspaceEvents) + mux.HandleFunc("GET /v1/workspaces/{id}/providers", c.handleGetWorkspaceProviders) + mux.HandleFunc("GET /v1/workspaces/{id}/sessions", c.handleGetWorkspaceSessions) + mux.HandleFunc("POST /v1/workspaces/{id}/sessions", c.handlePostWorkspaceSessions) + mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession) + mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory) + mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages) + mux.HandleFunc("GET /v1/workspaces/{id}/lsps", c.handleGetWorkspaceLSPs) + mux.HandleFunc("GET /v1/workspaces/{id}/lsps/{lsp}/diagnostics", c.handleGetWorkspaceLSPDiagnostics) + mux.HandleFunc("GET /v1/workspaces/{id}/permissions/skip", c.handleGetWorkspacePermissionsSkip) + mux.HandleFunc("POST /v1/workspaces/{id}/permissions/skip", c.handlePostWorkspacePermissionsSkip) + mux.HandleFunc("POST /v1/workspaces/{id}/permissions/grant", c.handlePostWorkspacePermissionsGrant) + mux.HandleFunc("GET /v1/workspaces/{id}/agent", c.handleGetWorkspaceAgent) + mux.HandleFunc("POST /v1/workspaces/{id}/agent", c.handlePostWorkspaceAgent) + mux.HandleFunc("POST /v1/workspaces/{id}/agent/init", c.handlePostWorkspaceAgentInit) + mux.HandleFunc("POST /v1/workspaces/{id}/agent/update", c.handlePostWorkspaceAgentUpdate) + mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}", c.handleGetWorkspaceAgentSession) + mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel", c.handlePostWorkspaceAgentSessionCancel) + mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetWorkspaceAgentSessionPromptQueued) + mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostWorkspaceAgentSessionPromptClear) + mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/summarize", c.handleGetWorkspaceAgentSessionSummarize) + s.h = &http.Server{ + Protocols: &p, + Handler: s.loggingHandler(mux), + } + if network == "tcp" { + s.h.Addr = address + } + return s +} + +// Serve accepts incoming connections on the listener. +func (s *Server) Serve(ln net.Listener) error { + return s.h.Serve(ln) +} + +// ListenAndServe starts the server and begins accepting connections. +func (s *Server) ListenAndServe() error { + if s.ln != nil { + return fmt.Errorf("server already started") + } + ln, err := listen(s.network, s.Addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", s.Addr, err) + } + return s.Serve(ln) +} + +func (s *Server) closeListener() { + if s.ln != nil { + s.ln.Close() + s.ln = nil + } +} + +// Close force closes all listeners and connections. +func (s *Server) Close() error { + defer func() { s.closeListener() }() + return s.h.Close() +} + +// Shutdown gracefully shuts down the server without interrupting active +// connections. +func (s *Server) Shutdown(ctx context.Context) error { + defer func() { s.closeListener() }() + return s.h.Shutdown(ctx) +} + +func (s *Server) logDebug(r *http.Request, msg string, args ...any) { + if s.logger != nil { + s.logger.With( + slog.String("method", r.Method), + slog.String("url", r.URL.String()), + slog.String("remote_addr", r.RemoteAddr), + ).Debug(msg, args...) + } +} + +func (s *Server) logError(r *http.Request, msg string, args ...any) { + if s.logger != nil { + s.logger.With( + slog.String("method", r.Method), + slog.String("url", r.URL.String()), + slog.String("remote_addr", r.RemoteAddr), + ).Error(msg, args...) + } +} diff --git a/internal/ui/model/onboarding.go b/internal/ui/model/onboarding.go index 075067d75333fc539152f0041b4e5a3c2eed1c5e..deff34ced3481b0e4af65893bfebe0e66db0135e 100644 --- a/internal/ui/model/onboarding.go +++ b/internal/ui/model/onboarding.go @@ -57,7 +57,10 @@ func (m *UI) initializeProject() tea.Cmd { initialize := func() tea.Msg { initPrompt, err := agent.InitializePrompt(*cfg) if err != nil { - return util.InfoMsg{Type: util.InfoTypeError, Msg: err.Error()} + return util.InfoMsg{ + Type: util.InfoTypeError, + Msg: fmt.Sprintf("Failed to initialize project: %v", err), + } } return sendMessageMsg{Content: initPrompt} } diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 3e840faeffe1523eeb0346c07baa4f751733651d..da8663cff5d20c43da09ac70587457c5fda8fe5b 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -740,6 +740,9 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, cmd) } case util.InfoMsg: + if msg.Type == util.InfoTypeError { + slog.Error("Error reported", "error", msg.Msg) + } m.status.SetInfoMsg(msg) ttl := msg.TTL if ttl <= 0 { @@ -2753,7 +2756,7 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea. } return util.InfoMsg{ Type: util.InfoTypeError, - Msg: err.Error(), + Msg: fmt.Sprintf("Failed to run agent: %v", err), } } return nil diff --git a/internal/ui/util/util.go b/internal/ui/util/util.go index 7a53df7d1e4e676b3b142de9ec74deff614c8af2..b8cd107753009c9853709f61aeb4f99b19b71d14 100644 --- a/internal/ui/util/util.go +++ b/internal/ui/util/util.go @@ -4,7 +4,6 @@ package util import ( "context" "errors" - "log/slog" "os/exec" "time" @@ -23,7 +22,6 @@ func CmdHandler(msg tea.Msg) tea.Cmd { } func ReportError(err error) tea.Cmd { - slog.Error("Error reported", "error", err) return CmdHandler(NewErrorMsg(err)) } diff --git a/internal/version/version.go b/internal/version/version.go index 6faef3251ca071a0a210ac1bc2327ca848a73ad0..3eb4f74139a752c1567986a8b9344913d55f08b1 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -2,9 +2,12 @@ package version import "runtime/debug" -// Build-time parameters set via -ldflags +// Build-time parameters set via -ldflags. -var Version = "devel" +var ( + Version = "devel" + Commit = "unknown" +) // A user may install crush using `go install github.com/charmbracelet/crush@latest`. // without -ldflags, in which case the version above is unset. As a workaround