From e70c8ba8f6dedd224ebaa712f374b9a54d10c669 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 12 Mar 2026 18:24:29 +0300 Subject: [PATCH] refactor: use client and server for workspace access and management --- internal/backend/agent.go | 32 +- internal/backend/config.go | 165 ++++++ internal/backend/events.go | 40 ++ internal/backend/filetracker.go | 37 ++ internal/backend/session.go | 43 +- internal/client/client.go | 4 + internal/client/config.go | 252 +++++++++ internal/client/proto.go | 374 ++++++++++++-- internal/cmd/root.go | 132 +++-- internal/proto/agent.go | 7 +- internal/proto/mcp.go | 107 +++- internal/proto/proto.go | 82 ++- internal/server/config.go | 292 +++++++++++ internal/server/events.go | 214 ++++++++ internal/server/proto.go | 163 +++++- internal/server/server.go | 27 + internal/ui/common/common.go | 19 +- internal/ui/dialog/api_key_input.go | 6 +- internal/ui/dialog/filepicker.go | 2 +- internal/ui/dialog/models.go | 2 +- internal/ui/dialog/oauth.go | 4 +- internal/ui/dialog/sessions.go | 10 +- internal/ui/model/header.go | 31 +- internal/ui/model/history.go | 4 +- internal/ui/model/landing.go | 10 +- internal/ui/model/lsp.go | 8 +- internal/ui/model/onboarding.go | 10 +- internal/ui/model/pills.go | 4 +- internal/ui/model/session.go | 8 +- internal/ui/model/sidebar.go | 10 +- internal/ui/model/ui.go | 125 +++-- internal/workspace/app_workspace.go | 370 +++++++++++++ internal/workspace/client_workspace.go | 690 +++++++++++++++++++++++++ internal/workspace/workspace.go | 150 ++++++ 34 files changed, 3172 insertions(+), 262 deletions(-) create mode 100644 internal/backend/config.go create mode 100644 internal/backend/filetracker.go create mode 100644 internal/client/config.go create mode 100644 internal/server/config.go create mode 100644 internal/server/events.go create mode 100644 internal/workspace/app_workspace.go create mode 100644 internal/workspace/client_workspace.go create mode 100644 internal/workspace/workspace.go diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 1859b3f67aa915c2115ea2f524e73a1e824e840a..8cc2ada26ca737bbe79fe989cc52e4ede9712dc5 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -3,6 +3,7 @@ package backend import ( "context" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" ) @@ -33,8 +34,10 @@ func (b *Backend) GetAgentInfo(workspaceID string) (proto.AgentInfo, error) { if ws.AgentCoordinator != nil { m := ws.AgentCoordinator.Model() agentInfo = proto.AgentInfo{ - Model: m.CatwalkCfg, - IsBusy: ws.AgentCoordinator.IsBusy(), + Model: m.CatwalkCfg, + ModelCfg: m.ModelCfg, + IsBusy: ws.AgentCoordinator.IsBusy(), + IsReady: true, } } return agentInfo, nil @@ -114,3 +117,28 @@ func (b *Backend) ClearQueue(workspaceID, sessionID string) error { } return nil } + +// QueuedPromptsList returns the list of queued prompt strings for a +// session. +func (b *Backend) QueuedPromptsList(workspaceID, sessionID string) ([]string, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return nil, err + } + + if ws.AgentCoordinator == nil { + return nil, nil + } + + return ws.AgentCoordinator.QueuedPromptsList(sessionID), nil +} + +// GetDefaultSmallModel returns the default small model for a provider. +func (b *Backend) GetDefaultSmallModel(workspaceID, providerID string) (config.SelectedModel, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return config.SelectedModel{}, err + } + + return ws.GetDefaultSmallModel(providerID), nil +} diff --git a/internal/backend/config.go b/internal/backend/config.go new file mode 100644 index 0000000000000000000000000000000000000000..27fa174fefa51bcd46432b56c07327c8f2c25047 --- /dev/null +++ b/internal/backend/config.go @@ -0,0 +1,165 @@ +package backend + +import ( + "context" + + "github.com/charmbracelet/crush/internal/agent" + mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/commands" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" +) + +// MCPResourceContents holds the contents of an MCP resource returned +// by the backend. +type MCPResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mime_type,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitempty"` +} + +// SetConfigField sets a key/value pair in the config file for the +// given scope. +func (b *Backend) SetConfigField(workspaceID string, scope config.Scope, key string, value any) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Cfg.SetConfigField(scope, key, value) +} + +// RemoveConfigField removes a key from the config file for the given +// scope. +func (b *Backend) RemoveConfigField(workspaceID string, scope config.Scope, key string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Cfg.RemoveConfigField(scope, key) +} + +// UpdatePreferredModel updates the preferred model for the given type +// and persists it to the config file at the given scope. +func (b *Backend) UpdatePreferredModel(workspaceID string, scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Cfg.UpdatePreferredModel(scope, modelType, model) +} + +// SetCompactMode sets the compact mode setting and persists it. +func (b *Backend) SetCompactMode(workspaceID string, scope config.Scope, enabled bool) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Cfg.SetCompactMode(scope, enabled) +} + +// SetProviderAPIKey sets the API key for a provider and persists it. +func (b *Backend) SetProviderAPIKey(workspaceID string, scope config.Scope, providerID string, apiKey any) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Cfg.SetProviderAPIKey(scope, providerID, apiKey) +} + +// ImportCopilot attempts to import a GitHub Copilot token from disk. +func (b *Backend) ImportCopilot(workspaceID string) (*oauth.Token, bool, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return nil, false, err + } + token, ok := ws.Cfg.ImportCopilot() + return token, ok, nil +} + +// RefreshOAuthToken refreshes the OAuth token for a provider. +func (b *Backend) RefreshOAuthToken(ctx context.Context, workspaceID string, scope config.Scope, providerID string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Cfg.RefreshOAuthToken(ctx, scope, providerID) +} + +// ProjectNeedsInitialization checks whether the project in this +// workspace needs initialization. +func (b *Backend) ProjectNeedsInitialization(workspaceID string) (bool, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return false, err + } + return config.ProjectNeedsInitialization(ws.Cfg) +} + +// MarkProjectInitialized marks the project as initialized. +func (b *Backend) MarkProjectInitialized(workspaceID string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return config.MarkProjectInitialized(ws.Cfg) +} + +// InitializePrompt builds the initialization prompt for the workspace. +func (b *Backend) InitializePrompt(workspaceID string) (string, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return "", err + } + return agent.InitializePrompt(ws.Cfg) +} + +// RefreshMCPTools refreshes the tools for a named MCP server. +func (b *Backend) RefreshMCPTools(ctx context.Context, workspaceID, name string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + mcptools.RefreshTools(ctx, ws.Cfg, name) + return nil +} + +// ReadMCPResource reads a resource from a named MCP server. +func (b *Backend) ReadMCPResource(ctx context.Context, workspaceID, name, uri string) ([]MCPResourceContents, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return nil, err + } + contents, err := mcptools.ReadResource(ctx, ws.Cfg, name, uri) + if err != nil { + return nil, err + } + result := make([]MCPResourceContents, len(contents)) + for i, c := range contents { + result[i] = MCPResourceContents{ + URI: c.URI, + MIMEType: c.MIMEType, + Text: c.Text, + Blob: c.Blob, + } + } + return result, nil +} + +// GetMCPPrompt retrieves a prompt from a named MCP server. +func (b *Backend) GetMCPPrompt(workspaceID, clientID, promptID string, args map[string]string) (string, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return "", err + } + return commands.GetMCPPrompt(ws.Cfg, clientID, promptID, args) +} + +// GetWorkingDir returns the working directory for a workspace. +func (b *Backend) GetWorkingDir(workspaceID string) (string, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return "", err + } + return ws.Cfg.WorkingDir(), nil +} diff --git a/internal/backend/events.go b/internal/backend/events.go index 06a63d0c7bba1e411469349341ce6b85dc628ed2..a91bad1d5322d1c0ed909b3239e9e97c0eb0c366 100644 --- a/internal/backend/events.go +++ b/internal/backend/events.go @@ -1,8 +1,11 @@ package backend import ( + "context" + tea "charm.land/bubbletea/v2" + mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" ) @@ -65,3 +68,40 @@ func (b *Backend) GetWorkspaceProviders(workspaceID string) (any, error) { providers, _ := config.Providers(ws.Cfg.Config()) return providers, nil } + +// LSPStart starts an LSP server for the given path. +func (b *Backend) LSPStart(ctx context.Context, workspaceID, path string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + + ws.LSPManager.Start(ctx, path) + return nil +} + +// LSPStopAll stops all LSP servers for a workspace. +func (b *Backend) LSPStopAll(ctx context.Context, workspaceID string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + + ws.LSPManager.StopAll(ctx) + return nil +} + +// MCPGetStates returns the current state of all MCP clients. +func (b *Backend) MCPGetStates(_ string) map[string]mcptools.ClientInfo { + return mcptools.GetStates() +} + +// MCPRefreshPrompts refreshes prompts for a named MCP client. +func (b *Backend) MCPRefreshPrompts(ctx context.Context, _ string, name string) { + mcptools.RefreshPrompts(ctx, name) +} + +// MCPRefreshResources refreshes resources for a named MCP client. +func (b *Backend) MCPRefreshResources(ctx context.Context, _ string, name string) { + mcptools.RefreshResources(ctx, name) +} diff --git a/internal/backend/filetracker.go b/internal/backend/filetracker.go new file mode 100644 index 0000000000000000000000000000000000000000..14ae99bc7fad5cac541e530cb94162d2218ccaac --- /dev/null +++ b/internal/backend/filetracker.go @@ -0,0 +1,37 @@ +package backend + +import ( + "context" + "time" +) + +// FileTrackerRecordRead records a file read for a session. +func (b *Backend) FileTrackerRecordRead(ctx context.Context, workspaceID, sessionID, path string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + + ws.FileTracker.RecordRead(ctx, sessionID, path) + return nil +} + +// FileTrackerLastReadTime returns the last read time for a file in a session. +func (b *Backend) FileTrackerLastReadTime(ctx context.Context, workspaceID, sessionID, path string) (time.Time, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return time.Time{}, err + } + + return ws.FileTracker.LastReadTime(ctx, sessionID, path), nil +} + +// FileTrackerListReadFiles returns the list of read files for a session. +func (b *Backend) FileTrackerListReadFiles(ctx context.Context, workspaceID, sessionID string) ([]string, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return nil, err + } + + return ws.FileTracker.ListReadFiles(ctx, sessionID) +} diff --git a/internal/backend/session.go b/internal/backend/session.go index 20592f6d9f4fdcd0afe95c54914403cb13d6277c..10e21ed8932ccbc990a525785166517cd231595c 100644 --- a/internal/backend/session.go +++ b/internal/backend/session.go @@ -3,6 +3,7 @@ package backend import ( "context" + "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/session" ) @@ -65,7 +66,7 @@ func (b *Backend) GetAgentSession(ctx context.Context, workspaceID, sessionID st } // ListSessionMessages returns all messages for a session. -func (b *Backend) ListSessionMessages(ctx context.Context, workspaceID, sessionID string) (any, error) { +func (b *Backend) ListSessionMessages(ctx context.Context, workspaceID, sessionID string) ([]message.Message, error) { ws, err := b.GetWorkspace(workspaceID) if err != nil { return nil, err @@ -83,3 +84,43 @@ func (b *Backend) ListSessionHistory(ctx context.Context, workspaceID, sessionID return ws.History.ListBySession(ctx, sessionID) } + +// SaveSession updates a session in the given workspace. +func (b *Backend) SaveSession(ctx context.Context, workspaceID string, sess session.Session) (session.Session, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return session.Session{}, err + } + + return ws.Sessions.Save(ctx, sess) +} + +// DeleteSession deletes a session from the given workspace. +func (b *Backend) DeleteSession(ctx context.Context, workspaceID, sessionID string) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + + return ws.Sessions.Delete(ctx, sessionID) +} + +// ListUserMessages returns user-role messages for a session. +func (b *Backend) ListUserMessages(ctx context.Context, workspaceID, sessionID string) ([]message.Message, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return nil, err + } + + return ws.Messages.ListUserMessages(ctx, sessionID) +} + +// ListAllUserMessages returns all user-role messages across sessions. +func (b *Backend) ListAllUserMessages(ctx context.Context, workspaceID string) ([]message.Message, error) { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return nil, err + } + + return ws.Messages.ListAllUserMessages(ctx) +} diff --git a/internal/client/client.go b/internal/client/client.go index fcc38914a1dba0908a8f47d3be709b555c948171..e97a0570e42e7176debf3e6ca4d91760483a197d 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -150,6 +150,10 @@ func (c *Client) delete(ctx context.Context, path string, query url.Values, head return c.sendReq(ctx, http.MethodDelete, path, query, nil, headers) } +func (c *Client) put(ctx context.Context, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + return c.sendReq(ctx, http.MethodPut, path, query, body, headers) +} + func (c *Client) sendReq(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { url := (&url.URL{ Path: stdpath.Join("/v1", path), diff --git a/internal/client/config.go b/internal/client/config.go new file mode 100644 index 0000000000000000000000000000000000000000..7589c4c9684670d84f22ab737b4da5c3c9e8478d --- /dev/null +++ b/internal/client/config.go @@ -0,0 +1,252 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" +) + +// SetConfigField sets a config key/value pair on the server. +func (c *Client) SetConfigField(ctx context.Context, id string, scope config.Scope, key string, value any) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/set", id), nil, jsonBody(struct { + Scope config.Scope `json:"scope"` + Key string `json:"key"` + Value any `json:"value"` + }{Scope: scope, Key: key, Value: value}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to set config field: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to set config field: status code %d", rsp.StatusCode) + } + return nil +} + +// RemoveConfigField removes a config key on the server. +func (c *Client) RemoveConfigField(ctx context.Context, id string, scope config.Scope, key string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/remove", id), nil, jsonBody(struct { + Scope config.Scope `json:"scope"` + Key string `json:"key"` + }{Scope: scope, Key: key}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to remove config field: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to remove config field: status code %d", rsp.StatusCode) + } + return nil +} + +// UpdatePreferredModel updates the preferred model on the server. +func (c *Client) UpdatePreferredModel(ctx context.Context, id string, scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/model", id), nil, jsonBody(struct { + Scope config.Scope `json:"scope"` + ModelType config.SelectedModelType `json:"model_type"` + Model config.SelectedModel `json:"model"` + }{Scope: scope, ModelType: modelType, Model: model}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to update preferred model: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to update preferred model: status code %d", rsp.StatusCode) + } + return nil +} + +// SetCompactMode sets compact mode on the server. +func (c *Client) SetCompactMode(ctx context.Context, id string, scope config.Scope, enabled bool) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/compact", id), nil, jsonBody(struct { + Scope config.Scope `json:"scope"` + Enabled bool `json:"enabled"` + }{Scope: scope, Enabled: enabled}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to set compact mode: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to set compact mode: status code %d", rsp.StatusCode) + } + return nil +} + +// SetProviderAPIKey sets a provider API key on the server. +func (c *Client) SetProviderAPIKey(ctx context.Context, id string, scope config.Scope, providerID string, apiKey any) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/provider-key", id), nil, jsonBody(struct { + Scope config.Scope `json:"scope"` + ProviderID string `json:"provider_id"` + APIKey any `json:"api_key"` + }{Scope: scope, ProviderID: providerID, APIKey: apiKey}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to set provider API key: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to set provider API key: status code %d", rsp.StatusCode) + } + return nil +} + +// ImportCopilot attempts to import a GitHub Copilot token on the +// server. +func (c *Client) ImportCopilot(ctx context.Context, id string) (*oauth.Token, bool, error) { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/import-copilot", id), nil, nil, nil) + if err != nil { + return nil, false, fmt.Errorf("failed to import copilot: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, false, fmt.Errorf("failed to import copilot: status code %d", rsp.StatusCode) + } + var result struct { + Token *oauth.Token `json:"token"` + Success bool `json:"success"` + } + if err := json.NewDecoder(rsp.Body).Decode(&result); err != nil { + return nil, false, fmt.Errorf("failed to decode import copilot response: %w", err) + } + return result.Token, result.Success, nil +} + +// RefreshOAuthToken refreshes an OAuth token for a provider on the +// server. +func (c *Client) RefreshOAuthToken(ctx context.Context, id string, scope config.Scope, providerID string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/config/refresh-oauth", id), nil, jsonBody(struct { + Scope config.Scope `json:"scope"` + ProviderID string `json:"provider_id"` + }{Scope: scope, ProviderID: providerID}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to refresh OAuth token: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to refresh OAuth token: status code %d", rsp.StatusCode) + } + return nil +} + +// ProjectNeedsInitialization checks if the project needs +// initialization. +func (c *Client) ProjectNeedsInitialization(ctx context.Context, id string) (bool, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/project/needs-init", id), nil, nil) + if err != nil { + return false, fmt.Errorf("failed to check project init: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return false, fmt.Errorf("failed to check project init: status code %d", rsp.StatusCode) + } + var result struct { + NeedsInit bool `json:"needs_init"` + } + if err := json.NewDecoder(rsp.Body).Decode(&result); err != nil { + return false, fmt.Errorf("failed to decode project init response: %w", err) + } + return result.NeedsInit, nil +} + +// MarkProjectInitialized marks the project as initialized on the +// server. +func (c *Client) MarkProjectInitialized(ctx context.Context, id string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/project/init", id), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to mark project initialized: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to mark project initialized: status code %d", rsp.StatusCode) + } + return nil +} + +// GetInitializePrompt retrieves the initialization prompt from the +// server. +func (c *Client) GetInitializePrompt(ctx context.Context, id string) (string, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/project/init-prompt", id), nil, nil) + if err != nil { + return "", fmt.Errorf("failed to get init prompt: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to get init prompt: status code %d", rsp.StatusCode) + } + var result struct { + Prompt string `json:"prompt"` + } + if err := json.NewDecoder(rsp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to decode init prompt response: %w", err) + } + return result.Prompt, nil +} + +// MCPResourceContents holds the contents of an MCP resource. +type MCPResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mime_type,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitempty"` +} + +// RefreshMCPTools refreshes tools for a named MCP server. +func (c *Client) RefreshMCPTools(ctx context.Context, id, name string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/mcp/refresh-tools", id), nil, jsonBody(struct { + Name string `json:"name"` + }{Name: name}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to refresh MCP tools: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to refresh MCP tools: status code %d", rsp.StatusCode) + } + return nil +} + +// ReadMCPResource reads a resource from a named MCP server. +func (c *Client) ReadMCPResource(ctx context.Context, id, name, uri string) ([]MCPResourceContents, error) { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/mcp/read-resource", id), nil, jsonBody(struct { + Name string `json:"name"` + URI string `json:"uri"` + }{Name: name, URI: uri}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return nil, fmt.Errorf("failed to read MCP resource: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to read MCP resource: status code %d", rsp.StatusCode) + } + var contents []MCPResourceContents + if err := json.NewDecoder(rsp.Body).Decode(&contents); err != nil { + return nil, fmt.Errorf("failed to decode MCP resource: %w", err) + } + return contents, nil +} + +// GetMCPPrompt retrieves a prompt from a named MCP server. +func (c *Client) GetMCPPrompt(ctx context.Context, id, clientID, promptID string, args map[string]string) (string, error) { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/mcp/get-prompt", id), nil, jsonBody(struct { + ClientID string `json:"client_id"` + PromptID string `json:"prompt_id"` + Args map[string]string `json:"args"` + }{ClientID: clientID, PromptID: promptID, Args: args}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return "", fmt.Errorf("failed to get MCP prompt: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to get MCP prompt: status code %d", rsp.StatusCode) + } + var result struct { + Prompt string `json:"prompt"` + } + if err := json.NewDecoder(rsp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to decode MCP prompt response: %w", err) + } + return result.Prompt, nil +} diff --git a/internal/client/proto.go b/internal/client/proto.go index 0705f4ee3db77bd8500ebbb29b607f494298d959..0bbbb02b3a9f82bfd859a66a7c61f3f5a7c210e0 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -10,9 +10,9 @@ import ( "io" "log/slog" "net/http" + "net/url" "time" - "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/message" @@ -39,6 +39,23 @@ func (c *Client) CreateWorkspace(ctx context.Context, ws proto.Workspace) (*prot return &created, nil } +// GetWorkspace retrieves a workspace from the server. +func (c *Client) GetWorkspace(ctx context.Context, id string) (*proto.Workspace, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get workspace: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get workspace: status code %d", rsp.StatusCode) + } + var ws proto.Workspace + if err := json.NewDecoder(rsp.Body).Decode(&ws); err != nil { + return nil, fmt.Errorf("failed to decode workspace: %w", err) + } + return &ws, nil +} + // DeleteWorkspace deletes a workspace on the server. func (c *Client) DeleteWorkspace(ctx context.Context, id string) error { rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), nil, nil) @@ -95,63 +112,44 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er data = bytes.TrimSpace(data) - var event pubsub.Event[any] - if err := json.Unmarshal(data, &event); err != nil { - slog.Error("Unmarshaling event", "error", err) - continue - } - - type alias pubsub.Event[any] - aux := &struct { - Payload json.RawMessage `json:"payload"` - *alias - }{ - alias: (*alias)(&event), - } - - if err := json.Unmarshal(data, &aux); err != nil { - slog.Error("Unmarshaling event payload", "error", err) - continue - } - var p pubsub.Payload - if err := json.Unmarshal(aux.Payload, &p); err != nil { - slog.Error("Unmarshaling event payload", "error", err) + if err := json.Unmarshal(data, &p); err != nil { + slog.Error("Unmarshaling event envelope", "error", err) continue } switch p.Type { case pubsub.PayloadTypeLSPEvent: var e pubsub.Event[proto.LSPEvent] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypeMCPEvent: var e pubsub.Event[proto.MCPEvent] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypePermissionRequest: var e pubsub.Event[proto.PermissionRequest] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypePermissionNotification: var e pubsub.Event[proto.PermissionNotification] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypeMessage: var e pubsub.Event[proto.Message] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypeSession: var e pubsub.Event[proto.Session] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypeFile: var e pubsub.Event[proto.File] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) case pubsub.PayloadTypeAgentEvent: var e pubsub.Event[proto.AgentEvent] - _ = json.Unmarshal(data, &e) + _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) default: slog.Warn("Unknown event type", "type", p.Type) @@ -191,7 +189,7 @@ func (c *Client) GetLSPDiagnostics(ctx context.Context, id string, lspName strin } // GetLSPs retrieves the LSP client states for a workspace. -func (c *Client) GetLSPs(ctx context.Context, id string) (map[string]app.LSPClientInfo, error) { +func (c *Client) GetLSPs(ctx context.Context, id string) (map[string]proto.LSPClientInfo, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/lsps", id), nil, nil) if err != nil { return nil, fmt.Errorf("failed to get LSPs: %w", err) @@ -200,13 +198,64 @@ func (c *Client) GetLSPs(ctx context.Context, id string) (map[string]app.LSPClie if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get LSPs: status code %d", rsp.StatusCode) } - var lsps map[string]app.LSPClientInfo + var lsps map[string]proto.LSPClientInfo if err := json.NewDecoder(rsp.Body).Decode(&lsps); err != nil { return nil, fmt.Errorf("failed to decode LSPs: %w", err) } return lsps, nil } +// MCPGetStates retrieves the MCP client states for a workspace. +func (c *Client) MCPGetStates(ctx context.Context, id string) (map[string]proto.MCPClientInfo, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/mcp/states", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get MCP states: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get MCP states: status code %d", rsp.StatusCode) + } + var states map[string]proto.MCPClientInfo + if err := json.NewDecoder(rsp.Body).Decode(&states); err != nil { + return nil, fmt.Errorf("failed to decode MCP states: %w", err) + } + return states, nil +} + +// MCPRefreshPrompts refreshes prompts for a named MCP client. +func (c *Client) MCPRefreshPrompts(ctx context.Context, id, name string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/mcp/refresh-prompts", id), nil, + jsonBody(struct { + Name string `json:"name"` + }{Name: name}), + http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to refresh MCP prompts: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to refresh MCP prompts: status code %d", rsp.StatusCode) + } + return nil +} + +// MCPRefreshResources refreshes resources for a named MCP client. +func (c *Client) MCPRefreshResources(ctx context.Context, id, name string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/mcp/refresh-resources", id), nil, + jsonBody(struct { + Name string `json:"name"` + }{Name: name}), + http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to refresh MCP resources: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to refresh MCP resources: status code %d", rsp.StatusCode) + } + return nil +} + // GetAgentSessionQueuedPrompts retrieves the number of queued prompts for a // session. func (c *Client) GetAgentSessionQueuedPrompts(ctx context.Context, id string, sessionID string) (int, error) { @@ -347,11 +396,11 @@ func (c *Client) ListMessages(ctx context.Context, id string, sessionID string) if rsp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get messages: status code %d", rsp.StatusCode) } - var messages []message.Message - if err := json.NewDecoder(rsp.Body).Decode(&messages); err != nil { + var protoMsgs []proto.Message + if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to decode messages: %w", err) } - return messages, nil + return protoToMessages(protoMsgs), nil } // GetSession retrieves a specific session. @@ -488,3 +537,258 @@ func jsonBody(v any) *bytes.Buffer { b.Write(m) return b } + +// SaveSession updates a session in a workspace. +func (c *Client) SaveSession(ctx context.Context, id string, sess session.Session) (*session.Session, error) { + rsp, err := c.put(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sess.ID), nil, jsonBody(sess), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return nil, fmt.Errorf("failed to save session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to save session: status code %d", rsp.StatusCode) + } + var saved session.Session + if err := json.NewDecoder(rsp.Body).Decode(&saved); err != nil { + return nil, fmt.Errorf("failed to decode session: %w", err) + } + return &saved, nil +} + +// DeleteSession deletes a session from a workspace. +func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string) error { + rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s", id, sessionID), nil, nil) + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to delete session: status code %d", rsp.StatusCode) + } + return nil +} + +// ListUserMessages retrieves user-role messages for a session. +func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]message.Message, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get user messages: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user messages: status code %d", rsp.StatusCode) + } + var protoMsgs []proto.Message + if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("failed to decode user messages: %w", err) + } + return protoToMessages(protoMsgs), nil +} + +// ListAllUserMessages retrieves all user-role messages across sessions. +func (c *Client) ListAllUserMessages(ctx context.Context, id string) ([]message.Message, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/messages/user", id), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get all user messages: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get all user messages: status code %d", rsp.StatusCode) + } + var protoMsgs []proto.Message + if err := json.NewDecoder(rsp.Body).Decode(&protoMsgs); err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("failed to decode all user messages: %w", err) + } + return protoToMessages(protoMsgs), nil +} + +// CancelAgentSession cancels an ongoing agent operation for a session. +func (c *Client) CancelAgentSession(ctx context.Context, id string, sessionID string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/cancel", id, sessionID), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to cancel agent session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to cancel agent session: status code %d", rsp.StatusCode) + } + return nil +} + +// GetAgentSessionQueuedPromptsList retrieves the list of queued prompt +// strings for a session. +func (c *Client) GetAgentSessionQueuedPromptsList(ctx context.Context, id string, sessionID string) ([]string, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s/prompts/list", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get queued prompts list: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get queued prompts list: status code %d", rsp.StatusCode) + } + var prompts []string + if err := json.NewDecoder(rsp.Body).Decode(&prompts); err != nil { + return nil, fmt.Errorf("failed to decode queued prompts list: %w", err) + } + return prompts, nil +} + +// GetDefaultSmallModel retrieves the default small model for a provider. +func (c *Client) GetDefaultSmallModel(ctx context.Context, id string, providerID string) (*config.SelectedModel, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/default-small-model", id), url.Values{"provider_id": []string{providerID}}, nil) + if err != nil { + return nil, fmt.Errorf("failed to get default small model: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get default small model: status code %d", rsp.StatusCode) + } + var model config.SelectedModel + if err := json.NewDecoder(rsp.Body).Decode(&model); err != nil { + return nil, fmt.Errorf("failed to decode default small model: %w", err) + } + return &model, nil +} + +// FileTrackerRecordRead records a file read for a session. +func (c *Client) FileTrackerRecordRead(ctx context.Context, id string, sessionID, path string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/filetracker/read", id), nil, jsonBody(struct { + SessionID string `json:"session_id"` + Path string `json:"path"` + }{SessionID: sessionID, Path: path}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to record file read: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to record file read: status code %d", rsp.StatusCode) + } + return nil +} + +// FileTrackerLastReadTime returns the last read time for a file in a +// session. +func (c *Client) FileTrackerLastReadTime(ctx context.Context, id string, sessionID, path string) (time.Time, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/filetracker/lastread", id), url.Values{ + "session_id": []string{sessionID}, + "path": []string{path}, + }, nil) + if err != nil { + return time.Time{}, fmt.Errorf("failed to get last read time: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return time.Time{}, fmt.Errorf("failed to get last read time: status code %d", rsp.StatusCode) + } + var t time.Time + if err := json.NewDecoder(rsp.Body).Decode(&t); err != nil { + return time.Time{}, fmt.Errorf("failed to decode last read time: %w", err) + } + return t, nil +} + +// FileTrackerListReadFiles returns the list of read files for a session. +func (c *Client) FileTrackerListReadFiles(ctx context.Context, id string, sessionID string) ([]string, error) { + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/filetracker/files", id, sessionID), nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get read files: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get read files: status code %d", rsp.StatusCode) + } + var files []string + if err := json.NewDecoder(rsp.Body).Decode(&files); err != nil { + return nil, fmt.Errorf("failed to decode read files: %w", err) + } + return files, nil +} + +// LSPStart starts an LSP server for a path. +func (c *Client) LSPStart(ctx context.Context, id string, path string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/lsps/start", id), nil, jsonBody(struct { + Path string `json:"path"` + }{Path: path}), http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to start LSP: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to start LSP: status code %d", rsp.StatusCode) + } + return nil +} + +// LSPStopAll stops all LSP servers for a workspace. +func (c *Client) LSPStopAll(ctx context.Context, id string) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/lsps/stop", id), nil, nil, nil) + if err != nil { + return fmt.Errorf("failed to stop LSPs: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to stop LSPs: status code %d", rsp.StatusCode) + } + return nil +} + +func protoToMessages(msgs []proto.Message) []message.Message { + out := make([]message.Message, len(msgs)) + for i, m := range msgs { + out[i] = protoToMessage(m) + } + return out +} + +func protoToMessage(m proto.Message) message.Message { + msg := message.Message{ + ID: m.ID, + SessionID: m.SessionID, + Role: message.MessageRole(m.Role), + Model: m.Model, + Provider: m.Provider, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } + + for _, p := range m.Parts { + switch v := p.(type) { + case proto.TextContent: + msg.Parts = append(msg.Parts, message.TextContent{Text: v.Text}) + case proto.ReasoningContent: + msg.Parts = append(msg.Parts, message.ReasoningContent{ + Thinking: v.Thinking, + Signature: v.Signature, + StartedAt: v.StartedAt, + FinishedAt: v.FinishedAt, + }) + case proto.ToolCall: + msg.Parts = append(msg.Parts, message.ToolCall{ + ID: v.ID, + Name: v.Name, + Input: v.Input, + Finished: v.Finished, + }) + case proto.ToolResult: + msg.Parts = append(msg.Parts, message.ToolResult{ + ToolCallID: v.ToolCallID, + Name: v.Name, + Content: v.Content, + IsError: v.IsError, + }) + case proto.Finish: + msg.Parts = append(msg.Parts, message.Finish{ + Reason: message.FinishReason(v.Reason), + Time: v.Time, + Message: v.Message, + Details: v.Details, + }) + case proto.ImageURLContent: + msg.Parts = append(msg.Parts, message.ImageURLContent{URL: v.URL, Detail: v.Detail}) + case proto.BinaryContent: + msg.Parts = append(msg.Parts, message.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data}) + } + } + + return msg +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index b5eea1cceef1ca0c4e97187093671900cde2d8dc..b224277edbade7398b0f8a1449ed053c80ec8408 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -25,13 +25,13 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" - "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/projects" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/server" "github.com/charmbracelet/crush/internal/ui/common" ui "github.com/charmbracelet/crush/internal/ui/model" "github.com/charmbracelet/crush/internal/version" + "github.com/charmbracelet/crush/internal/workspace" "github.com/charmbracelet/fang" uv "github.com/charmbracelet/ultraviolet" "github.com/charmbracelet/x/ansi" @@ -96,51 +96,33 @@ crush --data-dir /path/to/custom/.crush return err } - appInstance, err := setupAppWithProgressBar(cmd) + c, ws, err := setupClientApp(cmd, hostURL) if err != nil { return err } - defer appInstance.Shutdown() + defer func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }() - // Register the workspace with the server so it tracks active - // clients and auto-shuts down when the last one exits. - cwd, _ := ResolveCwd(cmd) - dataDir, _ := cmd.Flags().GetString("data-dir") - debug, _ := cmd.Flags().GetBool("debug") - yolo, _ := cmd.Flags().GetBool("yolo") + event.AppInitialized() - c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host) - if err != nil { - return fmt.Errorf("failed to create client: %v", err) - } + clientWs := workspace.NewClientWorkspace(c, *ws) - ws, err := c.CreateWorkspace(cmd.Context(), proto.Workspace{ - Path: cwd, - DataDir: dataDir, - Debug: debug, - YOLO: yolo, - Version: version.Version, - }) - if err != nil { - return fmt.Errorf("failed to register workspace: %v", err) + if ws.Config.IsConfigured() { + if err := clientWs.InitCoderAgent(cmd.Context()); err != nil { + slog.Error("Failed to initialize coder agent", "error", err) + } } - defer func() { _ = c.DeleteWorkspace(cmd.Context(), ws.ID) }() - - event.AppInitialized() - - // Set up the TUI. - var env uv.Environ = os.Environ() - com := common.DefaultCommon(appInstance) + com := common.DefaultCommon(clientWs) model := ui.New(com) + var env uv.Environ = os.Environ() program := tea.NewProgram( model, tea.WithEnvironment(env), tea.WithContext(cmd.Context()), - tea.WithFilter(ui.MouseEventFilter), // Filter mouse events based on focus state + tea.WithFilter(ui.MouseEventFilter), ) - go appInstance.Subscribe(program) + go clientWs.Subscribe(program) if _, err := program.Run(); err != nil { event.Error(err) @@ -295,18 +277,14 @@ func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *prot DataDir: dataDir, Debug: debug, YOLO: yolo, + Version: version.Version, Env: os.Environ(), }) if err != nil { return nil, nil, fmt.Errorf("failed to create workspace: %v", err) } - cfg, err := c.GetGlobalConfig(cmd.Context()) - if err != nil { - return nil, nil, fmt.Errorf("failed to get global config: %v", err) - } - - if shouldEnableMetrics(cfg) { + if shouldEnableMetrics(ws.Config) { event.Init() } @@ -314,18 +292,29 @@ func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *prot } // ensureServer auto-starts a detached server if the socket file does not -// exist. When connecting to an existing server, it waits for the health -// endpoint to respond. +// exist. When the socket exists, it verifies that the running server +// version matches the client; on mismatch it shuts down the old server +// and starts a fresh one. func ensureServer(cmd *cobra.Command, hostURL *url.URL) error { switch hostURL.Scheme { case "unix", "npipe": - _, err := os.Stat(hostURL.Host) - if err != nil && errors.Is(err, fs.ErrNotExist) { + needsStart := false + if _, err := os.Stat(hostURL.Host); err != nil && errors.Is(err, fs.ErrNotExist) { + needsStart = true + } else if err == nil { + if err := restartIfStale(cmd, hostURL); err != nil { + slog.Warn("Failed to check server version, restarting", "error", err) + needsStart = true + } + } + + if needsStart { if err := startDetachedServer(cmd); err != nil { return err } } + var err error for range 10 { _, err = os.Stat(hostURL.Host) if err == nil { @@ -345,43 +334,40 @@ func ensureServer(cmd *cobra.Command, hostURL *url.URL) error { return nil } -// waitForHealth polls the server's health endpoint until it responds. -func waitForHealth(ctx context.Context, c *client.Client) error { - var err error - for range 10 { - err = c.Health(ctx) - if err == nil { - return nil +// restartIfStale checks whether the running server matches the current +// client version. When they differ, it sends a shutdown command and +// removes the stale socket so the caller can start a fresh server. +func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error { + c, err := client.NewClient("", hostURL.Scheme, hostURL.Host) + if err != nil { + return err + } + vi, err := c.VersionInfo(cmd.Context()) + if err != nil { + return err + } + if vi.Version == version.Version { + return nil + } + slog.Info("Server version mismatch, restarting", + "server", vi.Version, + "client", version.Version, + ) + _ = c.ShutdownServer(cmd.Context()) + // Give the old process a moment to release the socket. + for range 20 { + if _, err := os.Stat(hostURL.Host); errors.Is(err, fs.ErrNotExist) { + break } select { - case <-ctx.Done(): - return ctx.Err() + case <-cmd.Context().Done(): + return cmd.Context().Err() case <-time.After(100 * time.Millisecond): } } - return fmt.Errorf("failed to connect to crush server: %v", err) -} - -// streamEvents forwards SSE events from the client to the TUI program. -func streamEvents(ctx context.Context, evc <-chan any, p *tea.Program) { - defer log.RecoverPanic("app.Subscribe", func() { - slog.Info("TUI subscription panic: attempting graceful shutdown") - p.Quit() - }) - - for { - select { - case <-ctx.Done(): - slog.Debug("TUI message handler shutting down") - return - case ev, ok := <-evc: - if !ok { - slog.Debug("TUI message channel closed") - return - } - p.Send(ev) - } - } + // Force-remove if the socket is still lingering. + _ = os.Remove(hostURL.Host) + return nil } var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`) diff --git a/internal/proto/agent.go b/internal/proto/agent.go index 1163b1d8bac629546c8ef6632b0fed6a780c09e5..2deb906afb24c8da0f774276229dc1bcc100a813 100644 --- a/internal/proto/agent.go +++ b/internal/proto/agent.go @@ -32,9 +32,10 @@ type AgentEvent struct { Error error `json:"error,omitempty"` // When summarizing. - SessionID string `json:"session_id,omitempty"` - Progress string `json:"progress,omitempty"` - Done bool `json:"done,omitempty"` + SessionID string `json:"session_id,omitempty"` + SessionTitle string `json:"session_title,omitempty"` + Progress string `json:"progress,omitempty"` + Done bool `json:"done,omitempty"` } // MarshalJSON implements the [json.Marshaler] interface. diff --git a/internal/proto/mcp.go b/internal/proto/mcp.go index e04f9ed8467890bc34859cd54272204ad65a9156..e7491e79b203a781dedbb1f0a9e5bbf2eda2766c 100644 --- a/internal/proto/mcp.go +++ b/internal/proto/mcp.go @@ -1,6 +1,10 @@ package proto -import "fmt" +import ( + "encoding/json" + "errors" + "fmt" +) // MCPState represents the current state of an MCP client. type MCPState int @@ -54,7 +58,10 @@ func (s MCPState) String() string { type MCPEventType string const ( - MCPEventStateChanged MCPEventType = "state_changed" + MCPEventStateChanged MCPEventType = "state_changed" + MCPEventToolsListChanged MCPEventType = "tools_list_changed" + MCPEventPromptsListChanged MCPEventType = "prompts_list_changed" + MCPEventResourcesListChanged MCPEventType = "resources_list_changed" ) // MarshalText implements the [encoding.TextMarshaler] interface. @@ -70,9 +77,95 @@ func (t *MCPEventType) UnmarshalText(data []byte) error { // MCPEvent represents an event in the MCP system. type MCPEvent struct { - Type MCPEventType `json:"type"` - Name string `json:"name"` - State MCPState `json:"state"` - Error error `json:"error,omitempty"` - ToolCount int `json:"tool_count,omitempty"` + Type MCPEventType `json:"type"` + Name string `json:"name"` + State MCPState `json:"state"` + Error error `json:"error,omitempty"` + ToolCount int `json:"tool_count,omitempty"` + PromptCount int `json:"prompt_count,omitempty"` + ResourceCount int `json:"resource_count,omitempty"` +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (e MCPEvent) MarshalJSON() ([]byte, error) { + type Alias MCPEvent + return json.Marshal(&struct { + Error string `json:"error,omitempty"` + Alias + }{ + Error: func() string { + if e.Error != nil { + return e.Error.Error() + } + return "" + }(), + Alias: (Alias)(e), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (e *MCPEvent) UnmarshalJSON(data []byte) error { + type Alias MCPEvent + aux := &struct { + Error string `json:"error,omitempty"` + Alias + }{ + Alias: (Alias)(*e), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *e = MCPEvent(aux.Alias) + if aux.Error != "" { + e.Error = errors.New(aux.Error) + } + return nil +} + +// MCPClientInfo is the wire-format representation of an MCP client's +// state, suitable for JSON transport between server and client. +type MCPClientInfo struct { + Name string `json:"name"` + State MCPState `json:"state"` + Error error `json:"error,omitempty"` + ToolCount int `json:"tool_count,omitempty"` + PromptCount int `json:"prompt_count,omitempty"` + ResourceCount int `json:"resource_count,omitempty"` + ConnectedAt int64 `json:"connected_at,omitempty"` +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (i MCPClientInfo) MarshalJSON() ([]byte, error) { + type Alias MCPClientInfo + return json.Marshal(&struct { + Error string `json:"error,omitempty"` + Alias + }{ + Error: func() string { + if i.Error != nil { + return i.Error.Error() + } + return "" + }(), + Alias: (Alias)(i), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (i *MCPClientInfo) UnmarshalJSON(data []byte) error { + type Alias MCPClientInfo + aux := &struct { + Error string `json:"error,omitempty"` + Alias + }{ + Alias: (Alias)(*i), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *i = MCPClientInfo(aux.Alias) + if aux.Error != "" { + i.Error = errors.New(aux.Error) + } + return nil } diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 25a0e409b308f1fd979f5f6bb7bd00619fad0b1e..9c84c6c8bf0c2f14da75933f8eebd7a36ca534ba 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -1,6 +1,8 @@ package proto import ( + "encoding/json" + "errors" "time" "charm.land/catwalk/pkg/catwalk" @@ -28,13 +30,15 @@ type Error struct { // AgentInfo represents information about the agent. type AgentInfo struct { - IsBusy bool `json:"is_busy"` - Model catwalk.Model `json:"model"` + IsBusy bool `json:"is_busy"` + IsReady bool `json:"is_ready"` + Model catwalk.Model `json:"model"` + ModelCfg config.SelectedModel `json:"model_cfg"` } // IsZero checks if the AgentInfo is zero-valued. func (a AgentInfo) IsZero() bool { - return !a.IsBusy && a.Model.ID == "" + return !a.IsBusy && !a.IsReady && a.Model.ID == "" } // AgentMessage represents a message sent to the agent. @@ -114,6 +118,42 @@ type LSPEvent struct { DiagnosticCount int `json:"diagnostic_count,omitempty"` } +// MarshalJSON implements the [json.Marshaler] interface. +func (e LSPEvent) MarshalJSON() ([]byte, error) { + type Alias LSPEvent + return json.Marshal(&struct { + Error string `json:"error,omitempty"` + Alias + }{ + Error: func() string { + if e.Error != nil { + return e.Error.Error() + } + return "" + }(), + Alias: (Alias)(e), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (e *LSPEvent) UnmarshalJSON(data []byte) error { + type Alias LSPEvent + aux := &struct { + Error string `json:"error,omitempty"` + Alias + }{ + Alias: (Alias)(*e), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *e = LSPEvent(aux.Alias) + if aux.Error != "" { + e.Error = errors.New(aux.Error) + } + return nil +} + // LSPClientInfo holds information about an LSP client's state. type LSPClientInfo struct { Name string `json:"name"` @@ -122,3 +162,39 @@ type LSPClientInfo struct { DiagnosticCount int `json:"diagnostic_count,omitempty"` ConnectedAt time.Time `json:"connected_at"` } + +// MarshalJSON implements the [json.Marshaler] interface. +func (i LSPClientInfo) MarshalJSON() ([]byte, error) { + type Alias LSPClientInfo + return json.Marshal(&struct { + Error string `json:"error,omitempty"` + Alias + }{ + Error: func() string { + if i.Error != nil { + return i.Error.Error() + } + return "" + }(), + Alias: (Alias)(i), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (i *LSPClientInfo) UnmarshalJSON(data []byte) error { + type Alias LSPClientInfo + aux := &struct { + Error string `json:"error,omitempty"` + Alias + }{ + Alias: (Alias)(*i), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *i = LSPClientInfo(aux.Alias) + if aux.Error != "" { + i.Error = errors.New(aux.Error) + } + return nil +} diff --git a/internal/server/config.go b/internal/server/config.go new file mode 100644 index 0000000000000000000000000000000000000000..b449ac0260207223b3a82a405df16259859dd66f --- /dev/null +++ b/internal/server/config.go @@ -0,0 +1,292 @@ +package server + +import ( + "encoding/json" + "net/http" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/proto" +) + +func (c *controllerV1) handlePostWorkspaceConfigSet(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Scope config.Scope `json:"scope"` + Key string `json:"key"` + Value any `json:"value"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.SetConfigField(id, req.Scope, req.Key, req.Value); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceConfigRemove(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Scope config.Scope `json:"scope"` + Key string `json:"key"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.RemoveConfigField(id, req.Scope, req.Key); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceConfigModel(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Scope config.Scope `json:"scope"` + ModelType config.SelectedModelType `json:"model_type"` + Model config.SelectedModel `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.UpdatePreferredModel(id, req.Scope, req.ModelType, req.Model); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceConfigCompact(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Scope config.Scope `json:"scope"` + Enabled bool `json:"enabled"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.SetCompactMode(id, req.Scope, req.Enabled); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceConfigProviderKey(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Scope config.Scope `json:"scope"` + ProviderID string `json:"provider_id"` + APIKey any `json:"api_key"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.SetProviderAPIKey(id, req.Scope, req.ProviderID, req.APIKey); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceConfigImportCopilot(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + token, ok, err := c.backend.ImportCopilot(id) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, struct { + Token any `json:"token"` + Success bool `json:"success"` + }{Token: token, Success: ok}) +} + +func (c *controllerV1) handlePostWorkspaceConfigRefreshOAuth(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Scope config.Scope `json:"scope"` + ProviderID string `json:"provider_id"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.RefreshOAuthToken(r.Context(), id, req.Scope, req.ProviderID); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetWorkspaceProjectNeedsInit(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + needs, err := c.backend.ProjectNeedsInitialization(id) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, struct { + NeedsInit bool `json:"needs_init"` + }{NeedsInit: needs}) +} + +func (c *controllerV1) handlePostWorkspaceProjectInit(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := c.backend.MarkProjectInitialized(id); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetWorkspaceProjectInitPrompt(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + prompt, err := c.backend.InitializePrompt(id) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, struct { + Prompt string `json:"prompt"` + }{Prompt: prompt}) +} + +func (c *controllerV1) handlePostWorkspaceMCPRefreshTools(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.RefreshMCPTools(r.Context(), id, req.Name); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceMCPReadResource(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Name string `json:"name"` + URI string `json:"uri"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + contents, err := c.backend.ReadMCPResource(r.Context(), id, req.Name, req.URI) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, contents) +} + +func (c *controllerV1) handlePostWorkspaceMCPGetPrompt(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + ClientID string `json:"client_id"` + PromptID string `json:"prompt_id"` + Args map[string]string `json:"args"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + prompt, err := c.backend.GetMCPPrompt(id, req.ClientID, req.PromptID, req.Args) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, struct { + Prompt string `json:"prompt"` + }{Prompt: prompt}) +} + +func (c *controllerV1) handleGetWorkspaceMCPStates(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + states := c.backend.MCPGetStates(id) + result := make(map[string]proto.MCPClientInfo, len(states)) + for k, v := range states { + result[k] = proto.MCPClientInfo{ + Name: v.Name, + State: proto.MCPState(v.State), + Error: v.Error, + ToolCount: v.Counts.Tools, + PromptCount: v.Counts.Prompts, + ResourceCount: v.Counts.Resources, + ConnectedAt: v.ConnectedAt.Unix(), + } + } + jsonEncode(w, result) +} + +func (c *controllerV1) handlePostWorkspaceMCPRefreshPrompts(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + c.backend.MCPRefreshPrompts(r.Context(), id, req.Name) + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceMCPRefreshResources(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + c.backend.MCPRefreshResources(r.Context(), id, req.Name) + w.WriteHeader(http.StatusOK) +} diff --git a/internal/server/events.go b/internal/server/events.go new file mode 100644 index 0000000000000000000000000000000000000000..752311666bb6fcc2b1efde4d037711eaafaa0162 --- /dev/null +++ b/internal/server/events.go @@ -0,0 +1,214 @@ +package server + +import ( + "encoding/json" + "fmt" + "log/slog" + + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" +) + +// wrapEvent converts a raw tea.Msg (a pubsub.Event[T] from the app +// event fan-in) into a pubsub.Payload envelope with the correct +// PayloadType discriminator and a proto-typed inner payload that has +// proper JSON tags. Returns nil if the event type is unrecognized. +func wrapEvent(ev any) *pubsub.Payload { + switch e := ev.(type) { + case pubsub.Event[app.LSPEvent]: + return envelope(pubsub.PayloadTypeLSPEvent, pubsub.Event[proto.LSPEvent]{ + Type: e.Type, + Payload: proto.LSPEvent{ + Type: proto.LSPEventType(e.Payload.Type), + Name: e.Payload.Name, + State: e.Payload.State, + Error: e.Payload.Error, + DiagnosticCount: e.Payload.DiagnosticCount, + }, + }) + case pubsub.Event[mcp.Event]: + return envelope(pubsub.PayloadTypeMCPEvent, pubsub.Event[proto.MCPEvent]{ + Type: e.Type, + Payload: proto.MCPEvent{ + Type: mcpEventTypeToProto(e.Payload.Type), + Name: e.Payload.Name, + State: proto.MCPState(e.Payload.State), + Error: e.Payload.Error, + ToolCount: e.Payload.Counts.Tools, + }, + }) + case pubsub.Event[permission.PermissionRequest]: + return envelope(pubsub.PayloadTypePermissionRequest, pubsub.Event[proto.PermissionRequest]{ + Type: e.Type, + Payload: proto.PermissionRequest{ + ID: e.Payload.ID, + SessionID: e.Payload.SessionID, + ToolCallID: e.Payload.ToolCallID, + ToolName: e.Payload.ToolName, + Description: e.Payload.Description, + Action: e.Payload.Action, + Path: e.Payload.Path, + Params: e.Payload.Params, + }, + }) + case pubsub.Event[permission.PermissionNotification]: + return envelope(pubsub.PayloadTypePermissionNotification, pubsub.Event[proto.PermissionNotification]{ + Type: e.Type, + Payload: proto.PermissionNotification{ + ToolCallID: e.Payload.ToolCallID, + Granted: e.Payload.Granted, + Denied: e.Payload.Denied, + }, + }) + case pubsub.Event[message.Message]: + return envelope(pubsub.PayloadTypeMessage, pubsub.Event[proto.Message]{ + Type: e.Type, + Payload: messageToProto(e.Payload), + }) + case pubsub.Event[session.Session]: + return envelope(pubsub.PayloadTypeSession, pubsub.Event[proto.Session]{ + Type: e.Type, + Payload: sessionToProto(e.Payload), + }) + case pubsub.Event[history.File]: + return envelope(pubsub.PayloadTypeFile, pubsub.Event[proto.File]{ + Type: e.Type, + Payload: fileToProto(e.Payload), + }) + case pubsub.Event[notify.Notification]: + return envelope(pubsub.PayloadTypeAgentEvent, pubsub.Event[proto.AgentEvent]{ + Type: e.Type, + Payload: proto.AgentEvent{ + SessionID: e.Payload.SessionID, + SessionTitle: e.Payload.SessionTitle, + Type: proto.AgentEventType(e.Payload.Type), + }, + }) + default: + slog.Warn("Unrecognized event type for SSE wrapping", "type", fmt.Sprintf("%T", ev)) + return nil + } +} + +// envelope marshals the inner event and wraps it in a pubsub.Payload. +func envelope(payloadType pubsub.PayloadType, inner any) *pubsub.Payload { + raw, err := json.Marshal(inner) + if err != nil { + slog.Error("Failed to marshal event payload", "error", err) + return nil + } + return &pubsub.Payload{ + Type: payloadType, + Payload: raw, + } +} + +func mcpEventTypeToProto(t mcp.EventType) proto.MCPEventType { + switch t { + case mcp.EventStateChanged: + return proto.MCPEventStateChanged + case mcp.EventToolsListChanged: + return proto.MCPEventToolsListChanged + case mcp.EventPromptsListChanged: + return proto.MCPEventPromptsListChanged + case mcp.EventResourcesListChanged: + return proto.MCPEventResourcesListChanged + default: + return proto.MCPEventStateChanged + } +} + +func sessionToProto(s session.Session) proto.Session { + return proto.Session{ + ID: s.ID, + ParentSessionID: s.ParentSessionID, + Title: s.Title, + SummaryMessageID: s.SummaryMessageID, + MessageCount: s.MessageCount, + PromptTokens: s.PromptTokens, + CompletionTokens: s.CompletionTokens, + Cost: s.Cost, + CreatedAt: s.CreatedAt, + UpdatedAt: s.UpdatedAt, + } +} + +func fileToProto(f history.File) proto.File { + return proto.File{ + ID: f.ID, + SessionID: f.SessionID, + Path: f.Path, + Content: f.Content, + Version: f.Version, + CreatedAt: f.CreatedAt, + UpdatedAt: f.UpdatedAt, + } +} + +func messageToProto(m message.Message) proto.Message { + msg := proto.Message{ + ID: m.ID, + SessionID: m.SessionID, + Role: proto.MessageRole(m.Role), + Model: m.Model, + Provider: m.Provider, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } + + for _, p := range m.Parts { + switch v := p.(type) { + case message.TextContent: + msg.Parts = append(msg.Parts, proto.TextContent{Text: v.Text}) + case message.ReasoningContent: + msg.Parts = append(msg.Parts, proto.ReasoningContent{ + Thinking: v.Thinking, + Signature: v.Signature, + StartedAt: v.StartedAt, + FinishedAt: v.FinishedAt, + }) + case message.ToolCall: + msg.Parts = append(msg.Parts, proto.ToolCall{ + ID: v.ID, + Name: v.Name, + Input: v.Input, + Finished: v.Finished, + }) + case message.ToolResult: + msg.Parts = append(msg.Parts, proto.ToolResult{ + ToolCallID: v.ToolCallID, + Name: v.Name, + Content: v.Content, + IsError: v.IsError, + }) + case message.Finish: + msg.Parts = append(msg.Parts, proto.Finish{ + Reason: proto.FinishReason(v.Reason), + Time: v.Time, + Message: v.Message, + Details: v.Details, + }) + case message.ImageURLContent: + msg.Parts = append(msg.Parts, proto.ImageURLContent{URL: v.URL, Detail: v.Detail}) + case message.BinaryContent: + msg.Parts = append(msg.Parts, proto.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data}) + } + } + + return msg +} + +func messagesToProto(msgs []message.Message) []proto.Message { + out := make([]proto.Message, len(msgs)) + for i, m := range msgs { + out[i] = messageToProto(m) + } + return out +} diff --git a/internal/server/proto.go b/internal/server/proto.go index b2ba0c769e1796785ceeca41659316928bd05335..966a08f04e1b36834593067703e6fd1e862535c7 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -124,7 +124,11 @@ func (c *controllerV1) handleGetWorkspaceEvents(w http.ResponseWriter, r *http.R return } c.server.logDebug(r, "Sending event", "event", fmt.Sprintf("%T %+v", ev, ev)) - data, err := json.Marshal(ev) + wrapped := wrapEvent(ev) + if wrapped == nil { + continue + } + data, err := json.Marshal(wrapped) if err != nil { c.server.logError(r, "Failed to marshal event", "error", err) continue @@ -143,7 +147,17 @@ func (c *controllerV1) handleGetWorkspaceLSPs(w http.ResponseWriter, r *http.Req c.handleError(w, r, err) return } - jsonEncode(w, states) + result := make(map[string]proto.LSPClientInfo, len(states)) + for k, v := range states { + result[k] = proto.LSPClientInfo{ + Name: v.Name, + State: v.State, + Error: v.Error, + DiagnosticCount: v.DiagnosticCount, + ConnectedAt: v.ConnectedAt, + } + } + jsonEncode(w, result) } func (c *controllerV1) handleGetWorkspaceLSPDiagnostics(w http.ResponseWriter, r *http.Request) { @@ -215,7 +229,128 @@ func (c *controllerV1) handleGetWorkspaceSessionMessages(w http.ResponseWriter, c.handleError(w, r, err) return } - jsonEncode(w, messages) + jsonEncode(w, messagesToProto(messages)) +} + +func (c *controllerV1) handlePutWorkspaceSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var sess session.Session + if err := json.NewDecoder(r.Body).Decode(&sess); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + saved, err := c.backend.SaveSession(r.Context(), id, sess) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, saved) +} + +func (c *controllerV1) handleDeleteWorkspaceSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + sid := r.PathValue("sid") + if err := c.backend.DeleteSession(r.Context(), id, sid); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetWorkspaceSessionUserMessages(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + sid := r.PathValue("sid") + messages, err := c.backend.ListUserMessages(r.Context(), id, sid) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, messagesToProto(messages)) +} + +func (c *controllerV1) handleGetWorkspaceAllUserMessages(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + messages, err := c.backend.ListAllUserMessages(r.Context(), id) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, messagesToProto(messages)) +} + +func (c *controllerV1) handleGetWorkspaceSessionFileTrackerFiles(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + sid := r.PathValue("sid") + files, err := c.backend.FileTrackerListReadFiles(r.Context(), id, sid) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, files) +} + +func (c *controllerV1) handlePostWorkspaceFileTrackerRead(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + SessionID string `json:"session_id"` + Path string `json:"path"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.FileTrackerRecordRead(r.Context(), id, req.SessionID, req.Path); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handleGetWorkspaceFileTrackerLastRead(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + sid := r.URL.Query().Get("session_id") + path := r.URL.Query().Get("path") + + t, err := c.backend.FileTrackerLastReadTime(r.Context(), id, sid, path) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, t) +} + +func (c *controllerV1) handlePostWorkspaceLSPStart(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + var req struct { + Path string `json:"path"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.LSPStart(r.Context(), id, req.Path); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) +} + +func (c *controllerV1) handlePostWorkspaceLSPStopAll(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := c.backend.LSPStopAll(r.Context(), id); err != nil { + c.handleError(w, r, err) + return + } + w.WriteHeader(http.StatusOK) } func (c *controllerV1) handleGetWorkspaceAgent(w http.ResponseWriter, r *http.Request) { @@ -313,6 +448,28 @@ func (c *controllerV1) handleGetWorkspaceAgentSessionSummarize(w http.ResponseWr } } +func (c *controllerV1) handleGetWorkspaceAgentSessionPromptList(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + sid := r.PathValue("sid") + prompts, err := c.backend.QueuedPromptsList(id, sid) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, prompts) +} + +func (c *controllerV1) handleGetWorkspaceAgentDefaultSmallModel(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + providerID := r.URL.Query().Get("provider_id") + model, err := c.backend.GetDefaultSmallModel(id, providerID) + if err != nil { + c.handleError(w, r, err) + return + } + jsonEncode(w, model) +} + func (c *controllerV1) handlePostWorkspacePermissionsGrant(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") diff --git a/internal/server/server.go b/internal/server/server.go index 94211f46c397b40e77891a51be6b2a48f5502018..5fb05015e495e6a80e9ea4762903ba79d5639d61 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -118,10 +118,19 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { mux.HandleFunc("GET /v1/workspaces/{id}/sessions", c.handleGetWorkspaceSessions) mux.HandleFunc("POST /v1/workspaces/{id}/sessions", c.handlePostWorkspaceSessions) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession) + mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession) + mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages) + mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages) + mux.HandleFunc("GET /v1/workspaces/{id}/messages/user", c.handleGetWorkspaceAllUserMessages) + mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/filetracker/files", c.handleGetWorkspaceSessionFileTrackerFiles) + mux.HandleFunc("POST /v1/workspaces/{id}/filetracker/read", c.handlePostWorkspaceFileTrackerRead) + mux.HandleFunc("GET /v1/workspaces/{id}/filetracker/lastread", c.handleGetWorkspaceFileTrackerLastRead) mux.HandleFunc("GET /v1/workspaces/{id}/lsps", c.handleGetWorkspaceLSPs) mux.HandleFunc("GET /v1/workspaces/{id}/lsps/{lsp}/diagnostics", c.handleGetWorkspaceLSPDiagnostics) + mux.HandleFunc("POST /v1/workspaces/{id}/lsps/start", c.handlePostWorkspaceLSPStart) + mux.HandleFunc("POST /v1/workspaces/{id}/lsps/stop", c.handlePostWorkspaceLSPStopAll) mux.HandleFunc("GET /v1/workspaces/{id}/permissions/skip", c.handleGetWorkspacePermissionsSkip) mux.HandleFunc("POST /v1/workspaces/{id}/permissions/skip", c.handlePostWorkspacePermissionsSkip) mux.HandleFunc("POST /v1/workspaces/{id}/permissions/grant", c.handlePostWorkspacePermissionsGrant) @@ -132,8 +141,26 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}", c.handleGetWorkspaceAgentSession) mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel", c.handlePostWorkspaceAgentSessionCancel) mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetWorkspaceAgentSessionPromptQueued) + mux.HandleFunc("GET /v1/workspaces/{id}/agent/sessions/{sid}/prompts/list", c.handleGetWorkspaceAgentSessionPromptList) mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostWorkspaceAgentSessionPromptClear) mux.HandleFunc("POST /v1/workspaces/{id}/agent/sessions/{sid}/summarize", c.handleGetWorkspaceAgentSessionSummarize) + mux.HandleFunc("GET /v1/workspaces/{id}/agent/default-small-model", c.handleGetWorkspaceAgentDefaultSmallModel) + mux.HandleFunc("POST /v1/workspaces/{id}/config/set", c.handlePostWorkspaceConfigSet) + mux.HandleFunc("POST /v1/workspaces/{id}/config/remove", c.handlePostWorkspaceConfigRemove) + mux.HandleFunc("POST /v1/workspaces/{id}/config/model", c.handlePostWorkspaceConfigModel) + mux.HandleFunc("POST /v1/workspaces/{id}/config/compact", c.handlePostWorkspaceConfigCompact) + mux.HandleFunc("POST /v1/workspaces/{id}/config/provider-key", c.handlePostWorkspaceConfigProviderKey) + mux.HandleFunc("POST /v1/workspaces/{id}/config/import-copilot", c.handlePostWorkspaceConfigImportCopilot) + mux.HandleFunc("POST /v1/workspaces/{id}/config/refresh-oauth", c.handlePostWorkspaceConfigRefreshOAuth) + mux.HandleFunc("GET /v1/workspaces/{id}/project/needs-init", c.handleGetWorkspaceProjectNeedsInit) + mux.HandleFunc("POST /v1/workspaces/{id}/project/init", c.handlePostWorkspaceProjectInit) + mux.HandleFunc("GET /v1/workspaces/{id}/project/init-prompt", c.handleGetWorkspaceProjectInitPrompt) + mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-tools", c.handlePostWorkspaceMCPRefreshTools) + mux.HandleFunc("POST /v1/workspaces/{id}/mcp/read-resource", c.handlePostWorkspaceMCPReadResource) + mux.HandleFunc("POST /v1/workspaces/{id}/mcp/get-prompt", c.handlePostWorkspaceMCPGetPrompt) + mux.HandleFunc("GET /v1/workspaces/{id}/mcp/states", c.handleGetWorkspaceMCPStates) + mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-prompts", c.handlePostWorkspaceMCPRefreshPrompts) + mux.HandleFunc("POST /v1/workspaces/{id}/mcp/refresh-resources", c.handlePostWorkspaceMCPRefreshResources) s.h = &http.Server{ Protocols: &p, Handler: s.loggingHandler(mux), diff --git a/internal/ui/common/common.go b/internal/ui/common/common.go index 143b20305464da33d2f350a36176bab0e45b85aa..8e00f0d0d2a74396df36e4b4d97762a9087087be 100644 --- a/internal/ui/common/common.go +++ b/internal/ui/common/common.go @@ -7,10 +7,10 @@ import ( tea "charm.land/bubbletea/v2" "github.com/atotto/clipboard" - "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/crush/internal/ui/util" + "github.com/charmbracelet/crush/internal/workspace" uv "github.com/charmbracelet/ultraviolet" ) @@ -22,26 +22,21 @@ var AllowedImageTypes = []string{".jpg", ".jpeg", ".png"} // Common defines common UI options and configurations. type Common struct { - App *app.App - Styles *styles.Styles + Workspace workspace.Workspace + Styles *styles.Styles } // Config returns the pure-data configuration associated with this [Common] instance. func (c *Common) Config() *config.Config { - return c.App.Config() -} - -// Store returns the config store associated with this [Common] instance. -func (c *Common) Store() *config.ConfigStore { - return c.App.Store() + return c.Workspace.Config() } // DefaultCommon returns the default common UI configurations. -func DefaultCommon(app *app.App) *Common { +func DefaultCommon(ws workspace.Workspace) *Common { s := styles.DefaultStyles() return &Common{ - App: app, - Styles: &s, + Workspace: ws, + Styles: &s, } } diff --git a/internal/ui/dialog/api_key_input.go b/internal/ui/dialog/api_key_input.go index cc37d742903d5a80bbcffcf1ff24fb24596dfccd..bc69526caaf2be452e029a6d66a56d3d8c83326f 100644 --- a/internal/ui/dialog/api_key_input.go +++ b/internal/ui/dialog/api_key_input.go @@ -296,7 +296,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { Type: m.provider.Type, BaseURL: m.provider.APIEndpoint, } - err := providerConfig.TestConnection(m.com.Store().Resolver()) + err := providerConfig.TestConnection(m.com.Workspace.Resolver()) // intentionally wait for at least 750ms to make sure the user sees the spinner elapsed := time.Since(start) @@ -312,9 +312,7 @@ func (m *APIKeyInput) verifyAPIKey() tea.Msg { } func (m *APIKeyInput) saveKeyAndContinue() Action { - store := m.com.Store() - - err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value()) + err := m.com.Workspace.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value()) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/dialog/filepicker.go b/internal/ui/dialog/filepicker.go index 78f82a05f7e2e0db7a9bb561fb1b6248d8045513..82fca6f47ddc9338057f7cdf027cce03d1be65ef 100644 --- a/internal/ui/dialog/filepicker.go +++ b/internal/ui/dialog/filepicker.go @@ -123,7 +123,7 @@ func (f *FilePicker) SetImageCapabilities(caps *common.Capabilities) { // WorkingDir returns the current working directory of the [FilePicker]. func (f *FilePicker) WorkingDir() string { - wd := f.com.Store().WorkingDir() + wd := f.com.Workspace.WorkingDir() if len(wd) > 0 { return wd } diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 434f699e91b4c227c4e54f6ff553affff76a1c43..0fdd710155dc11fe92139b97a8601f1ffe0e7d74 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -490,7 +490,7 @@ func (m *Models) setProviderItems() error { if len(validRecentItems) != len(recentItems) { // FIXME: Does this need to be here? Is it mutating the config during a read? - if err := m.com.Store().SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { + if err := m.com.Workspace.SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil { return fmt.Errorf("failed to update recent models: %w", err) } } diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go index 2803070381e65bd0380a8ddab5f256481c117c15..b8faa620807be4d2adb07f8e511cec6510050c09 100644 --- a/internal/ui/dialog/oauth.go +++ b/internal/ui/dialog/oauth.go @@ -373,9 +373,7 @@ func (d *OAuth) copyCodeAndOpenURL() tea.Cmd { } func (m *OAuth) saveKeyAndContinue() Action { - store := m.com.Store() - - err := store.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.token) + err := m.com.Workspace.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.token) if err != nil { return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))} } diff --git a/internal/ui/dialog/sessions.go b/internal/ui/dialog/sessions.go index 6f9b7724a796818c789e19ba9455c23e7e51c9b4..8f3ce81960e5170c0059d44261c1e79ca2bbfea8 100644 --- a/internal/ui/dialog/sessions.go +++ b/internal/ui/dialog/sessions.go @@ -61,7 +61,7 @@ func NewSessions(com *common.Common, selectedSessionID string) (*Session, error) s := new(Session) s.sessionsMode = sessionsModeNormal s.com = com - sessions, err := com.App.Sessions.List(context.TODO()) + sessions, err := com.Workspace.ListSessions(context.TODO()) if err != nil { return nil, err } @@ -349,7 +349,7 @@ func (s *Session) removeSession(id string) { func (s *Session) deleteSessionCmd(id string) tea.Cmd { return func() tea.Msg { - err := s.com.App.Sessions.Delete(context.TODO(), id) + err := s.com.Workspace.DeleteSession(context.TODO(), id) if err != nil { return util.NewErrorMsg(err) } @@ -385,7 +385,7 @@ func (s *Session) updateSession(session session.Session) { func (s *Session) updateSessionCmd(session session.Session) tea.Cmd { return func() tea.Msg { - _, err := s.com.App.Sessions.Save(context.TODO(), session) + _, err := s.com.Workspace.SaveSession(context.TODO(), session) if err != nil { return util.NewErrorMsg(err) } @@ -399,11 +399,11 @@ func (s *Session) isCurrentSessionBusy() bool { return false } - if s.com.App.AgentCoordinator == nil { + if !s.com.Workspace.AgentIsReady() { return false } - return s.com.App.AgentCoordinator.IsSessionBusy(sessionItem.ID()) + return s.com.Workspace.AgentIsSessionBusy(sessionItem.ID()) } // ShortHelp implements [help.KeyMap]. diff --git a/internal/ui/model/header.go b/internal/ui/model/header.go index 06bb4ff92981b28625efb11683081e29fc55a21e..f4e4fd49a00ebb0280ec3583ce7fa3ea6513bd40 100644 --- a/internal/ui/model/header.go +++ b/internal/ui/model/header.go @@ -6,9 +6,7 @@ import ( "charm.land/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/fsext" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/ui/styles" @@ -62,7 +60,7 @@ func (h *header) drawHeader( h.width = width h.compact = compact - if !compact || session == nil || h.com.App == nil { + if !compact || session == nil { uv.NewStyledString(h.logo).Draw(scr, area) return } @@ -75,10 +73,14 @@ func (h *header) drawHeader( b.WriteString(h.compactLogo) availDetailWidth := width - leftPadding - rightPadding - lipgloss.Width(b.String()) - minHeaderDiags - diagToDetailsSpacing + lspErrorCount := 0 + for _, info := range h.com.Workspace.LSPGetStates() { + lspErrorCount += info.DiagnosticCount + } details := renderHeaderDetails( h.com, session, - h.com.App.LSPManager.Clients(), + lspErrorCount, detailsOpen, availDetailWidth, ) @@ -108,7 +110,7 @@ func (h *header) drawHeader( func renderHeaderDetails( com *common.Common, session *session.Session, - lspClients *csync.Map[string, *lsp.Client], + lspErrorCount int, detailsOpen bool, availWidth int, ) string { @@ -116,20 +118,17 @@ func renderHeaderDetails( var parts []string - errorCount := 0 - for l := range lspClients.Seq() { - errorCount += l.GetDiagnosticCounts().Error - } - - if errorCount > 0 { - parts = append(parts, t.LSP.ErrorDiagnostic.Render(fmt.Sprintf("%s%d", styles.LSPErrorIcon, errorCount))) + if lspErrorCount > 0 { + parts = append(parts, t.LSP.ErrorDiagnostic.Render(fmt.Sprintf("%s%d", styles.LSPErrorIcon, lspErrorCount))) } agentCfg := com.Config().Agents[config.AgentCoder] model := com.Config().GetModelByType(agentCfg.Model) - percentage := (float64(session.CompletionTokens+session.PromptTokens) / float64(model.ContextWindow)) * 100 - formattedPercentage := t.Header.Percentage.Render(fmt.Sprintf("%d%%", int(percentage))) - parts = append(parts, formattedPercentage) + if model != nil && model.ContextWindow > 0 { + percentage := (float64(session.CompletionTokens+session.PromptTokens) / float64(model.ContextWindow)) * 100 + formattedPercentage := t.Header.Percentage.Render(fmt.Sprintf("%d%%", int(percentage))) + parts = append(parts, formattedPercentage) + } const keystroke = "ctrl+d" if detailsOpen { @@ -143,7 +142,7 @@ func renderHeaderDetails( metadata = dot + metadata const dirTrimLimit = 4 - cwd := fsext.DirTrim(fsext.PrettyPath(com.Store().WorkingDir()), dirTrimLimit) + cwd := fsext.DirTrim(fsext.PrettyPath(com.Workspace.WorkingDir()), dirTrimLimit) cwd = t.Header.WorkingDir.Render(cwd) result := cwd + metadata diff --git a/internal/ui/model/history.go b/internal/ui/model/history.go index 5d2284ab1756257cc06b76de4621849f1e3071ba..5a9f2810806e8b1916335baa59484d114fcb5310 100644 --- a/internal/ui/model/history.go +++ b/internal/ui/model/history.go @@ -22,9 +22,9 @@ func (m *UI) loadPromptHistory() tea.Cmd { var err error if m.session != nil { - messages, err = m.com.App.Messages.ListUserMessages(ctx, m.session.ID) + messages, err = m.com.Workspace.ListUserMessages(ctx, m.session.ID) } else { - messages, err = m.com.App.Messages.ListAllUserMessages(ctx) + messages, err = m.com.Workspace.ListAllUserMessages(ctx) } if err != nil { slog.Error("Failed to load prompt history", "error", err) diff --git a/internal/ui/model/landing.go b/internal/ui/model/landing.go index 72c2671ccd297f4bade087f6b2cb960f6c6a92a9..e78d03e2afb5bd826e7d5ffd5c4d571575fcf949 100644 --- a/internal/ui/model/landing.go +++ b/internal/ui/model/landing.go @@ -2,16 +2,16 @@ package model import ( "charm.land/lipgloss/v2" - "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/workspace" "github.com/charmbracelet/ultraviolet/layout" ) // selectedLargeModel returns the currently selected large language model from // the agent coordinator, if one exists. -func (m *UI) selectedLargeModel() *agent.Model { - if m.com.App.AgentCoordinator != nil { - model := m.com.App.AgentCoordinator.Model() +func (m *UI) selectedLargeModel() *workspace.AgentModel { + if m.com.Workspace.AgentIsReady() { + model := m.com.Workspace.AgentModel() return &model } return nil @@ -22,7 +22,7 @@ func (m *UI) selectedLargeModel() *agent.Model { func (m *UI) landingView() string { t := m.com.Styles width := m.layout.main.Dx() - cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width) + cwd := common.PrettyPath(t, m.com.Workspace.WorkingDir(), width) parts := []string{ cwd, diff --git a/internal/ui/model/lsp.go b/internal/ui/model/lsp.go index 1458d3402cbfa1536e5bef31f7d72ac5d58dddfe..c82a4a77b80a5ee258e6d5195bffd1bd7deb76c0 100644 --- a/internal/ui/model/lsp.go +++ b/internal/ui/model/lsp.go @@ -7,16 +7,16 @@ import ( "strings" "charm.land/lipgloss/v2" - "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/crush/internal/workspace" "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" ) // LSPInfo wraps LSP client information with diagnostic counts by severity. type LSPInfo struct { - app.LSPClientInfo + workspace.LSPClientInfo Diagnostics map[protocol.DiagnosticSeverity]int } @@ -25,14 +25,14 @@ type LSPInfo struct { func (m *UI) lspInfo(width, maxItems int, isSection bool) string { t := m.com.Styles - states := slices.SortedFunc(maps.Values(m.lspStates), func(a, b app.LSPClientInfo) int { + states := slices.SortedFunc(maps.Values(m.lspStates), func(a, b workspace.LSPClientInfo) int { return strings.Compare(a.Name, b.Name) }) var lsps []LSPInfo for _, state := range states { lspErrs := map[protocol.DiagnosticSeverity]int{} - if client, ok := m.com.App.LSPManager.Clients().Get(state.Name); ok { + if client, ok := m.com.Workspace.LSPGetClient(state.Name); ok { counts := client.GetDiagnosticCounts() lspErrs[protocol.SeverityError] = counts.Error lspErrs[protocol.SeverityWarning] = counts.Warning diff --git a/internal/ui/model/onboarding.go b/internal/ui/model/onboarding.go index f094ae957113a2d7cf6cad92d76cca7df82e32e3..c905660a0570d1402d85be8aed4e805c11506510 100644 --- a/internal/ui/model/onboarding.go +++ b/internal/ui/model/onboarding.go @@ -9,8 +9,6 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" - "github.com/charmbracelet/crush/internal/agent" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/ui/util" @@ -19,7 +17,7 @@ import ( // markProjectInitialized marks the current project as initialized in the config. func (m *UI) markProjectInitialized() tea.Msg { // TODO: handle error so we show it in the tui footer - err := config.MarkProjectInitialized(m.com.Store()) + err := m.com.Workspace.MarkProjectInitialized() if err != nil { slog.Error(err.Error()) } @@ -52,10 +50,8 @@ func (m *UI) initializeProject() tea.Cmd { if cmd := m.newSession(); cmd != nil { cmds = append(cmds, cmd) } - cfg := m.com.Store() - initialize := func() tea.Msg { - initPrompt, err := agent.InitializePrompt(cfg) + initPrompt, err := m.com.Workspace.InitializePrompt() if err != nil { return util.InfoMsg{ Type: util.InfoTypeError, @@ -81,7 +77,7 @@ func (m *UI) skipInitializeProject() tea.Cmd { // initializeView renders the project initialization prompt with Yes/No buttons. func (m *UI) initializeView() string { s := m.com.Styles.Initialize - cwd := home.Short(m.com.Store().WorkingDir()) + cwd := home.Short(m.com.Workspace.WorkingDir()) initFile := m.com.Config().Options.InitializeAs header := s.Header.Render("Would you like to initialize this project?") diff --git a/internal/ui/model/pills.go b/internal/ui/model/pills.go index 9b3307135aec89105adb895ef07bbe30484ec658..d2f843848956bdd5f4d05674da1d164caf1e04c6 100644 --- a/internal/ui/model/pills.go +++ b/internal/ui/model/pills.go @@ -249,8 +249,8 @@ func (m *UI) renderPills() { if todosFocused && hasIncomplete { expandedList = todoList(m.session.Todos, inProgressIcon, t, contentWidth) } else if queueFocused && hasQueue { - if m.com.App != nil && m.com.App.AgentCoordinator != nil { - queueItems := m.com.App.AgentCoordinator.QueuedPromptsList(m.session.ID) + if m.com.Workspace.AgentIsReady() { + queueItems := m.com.Workspace.AgentQueuedPromptsList(m.session.ID) expandedList = queueList(queueItems, t) } } diff --git a/internal/ui/model/session.go b/internal/ui/model/session.go index c043255c041c20523a2e14b85285bccc7ee7eeb1..62d6a050c38d06bdd16b595ffe16b436e7224c96 100644 --- a/internal/ui/model/session.go +++ b/internal/ui/model/session.go @@ -66,7 +66,7 @@ type SessionFile struct { // returns a sessionFilesLoadedMsg containing the processed session files. func (m *UI) loadSession(sessionID string) tea.Cmd { return func() tea.Msg { - session, err := m.com.App.Sessions.Get(context.Background(), sessionID) + session, err := m.com.Workspace.GetSession(context.Background(), sessionID) if err != nil { return util.ReportError(err) } @@ -76,7 +76,7 @@ func (m *UI) loadSession(sessionID string) tea.Cmd { return util.ReportError(err) } - readFiles, err := m.com.App.FileTracker.ListReadFiles(context.Background(), sessionID) + readFiles, err := m.com.Workspace.FileTrackerListReadFiles(context.Background(), sessionID) if err != nil { slog.Error("Failed to load read files for session", "error", err) } @@ -90,7 +90,7 @@ func (m *UI) loadSession(sessionID string) tea.Cmd { } func (m *UI) loadSessionFiles(sessionID string) ([]SessionFile, error) { - files, err := m.com.App.History.ListBySession(context.Background(), sessionID) + files, err := m.com.Workspace.ListSessionHistory(context.Background(), sessionID) if err != nil { return nil, err } @@ -241,7 +241,7 @@ func (m *UI) startLSPs(paths []string) tea.Cmd { return func() tea.Msg { ctx := context.Background() for _, path := range paths { - m.com.App.LSPManager.Start(ctx, path) + m.com.Workspace.LSPStart(ctx, path) } return nil } diff --git a/internal/ui/model/sidebar.go b/internal/ui/model/sidebar.go index 8849d86a8e1c8bda02092e3f165e85b8e32a8b1d..13e1797f3f2f7155a17d93bbe01ae7bb14e8246f 100644 --- a/internal/ui/model/sidebar.go +++ b/internal/ui/model/sidebar.go @@ -48,7 +48,11 @@ func (m *UI) modelInfo(width int) string { ModelContext: model.CatwalkCfg.ContextWindow, } } - return common.ModelInfo(m.com.Styles, model.CatwalkCfg.Name, providerName, reasoningInfo, modelContext, width) + var modelName string + if model != nil { + modelName = model.CatwalkCfg.Name + } + return common.ModelInfo(m.com.Styles, modelName, providerName, reasoningInfo, modelContext, width) } // getDynamicHeightLimits will give us the num of items to show in each section based on the hight @@ -112,7 +116,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { height := area.Dy() title := t.Muted.Width(width).MaxHeight(2).Render(m.session.Title) - cwd := common.PrettyPath(t, m.com.Store().WorkingDir(), width) + cwd := common.PrettyPath(t, m.com.Workspace.WorkingDir(), width) sidebarLogo := m.sidebarLogo if height < logoHeightBreakpoint { sidebarLogo = logo.SmallRender(m.com.Styles, width) @@ -138,7 +142,7 @@ func (m *UI) drawSidebar(scr uv.Screen, area uv.Rectangle) { lspSection := m.lspInfo(width, maxLSPs, true) mcpSection := m.mcpInfo(width, maxMCPs, true) - filesSection := m.filesInfo(m.com.Store().WorkingDir(), width, maxFiles, true) + filesSection := m.filesInfo(m.com.Workspace.WorkingDir(), width, maxFiles, true) uv.NewStyledString( lipgloss.NewStyle(). diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index f5a822d6f16ef6c79fc2fbf003bf2a5d689f5643..d572f63f07aa1975312a35e5bc52bb2e782c08d4 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -28,7 +28,6 @@ import ( "github.com/charmbracelet/crush/internal/agent/notify" agenttools "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/agent/tools/mcp" - "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/commands" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fsext" @@ -50,6 +49,7 @@ import ( "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/crush/internal/ui/util" "github.com/charmbracelet/crush/internal/version" + "github.com/charmbracelet/crush/internal/workspace" uv "github.com/charmbracelet/ultraviolet" "github.com/charmbracelet/ultraviolet/layout" "github.com/charmbracelet/ultraviolet/screen" @@ -195,7 +195,7 @@ type UI struct { } // lsp - lspStates map[string]app.LSPClientInfo + lspStates map[string]workspace.LSPClientInfo // mcp mcpStates map[string]mcp.ClientInfo @@ -294,7 +294,7 @@ func New(com *common.Common) *UI { completions: comp, attachments: attachments, todoSpinner: todoSpinner, - lspStates: make(map[string]app.LSPClientInfo), + lspStates: make(map[string]workspace.LSPClientInfo), mcpStates: make(map[string]mcp.ClientInfo), notifyBackend: notification.NoopBackend{}, notifyWindowFocused: true, @@ -317,7 +317,7 @@ func New(com *common.Common) *UI { desiredFocus := uiFocusEditor if !com.Config().IsConfigured() { desiredState = uiOnboarding - } else if n, _ := config.ProjectNeedsInitialization(com.Store()); n { + } else if n, _ := com.Workspace.ProjectNeedsInitialization(); n { desiredState = uiInitialize } @@ -415,7 +415,7 @@ func (m *UI) loadMCPrompts() tea.Msg { func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd if m.hasSession() && m.isAgentBusy() { - queueSize := m.com.App.AgentCoordinator.QueuedPrompts(m.session.ID) + queueSize := m.com.Workspace.AgentQueuedPrompts(m.session.ID) if queueSize != m.promptQueue { m.promptQueue = queueSize m.updateLayoutAndSize() @@ -450,7 +450,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.session = msg.session m.sessionFiles = msg.files cmds = append(cmds, m.startLSPs(msg.lspFilePaths())) - msgs, err := m.com.App.Messages.List(context.Background(), m.session.ID) + msgs, err := m.com.Workspace.ListMessages(context.Background(), m.session.ID) if err != nil { cmds = append(cmds, util.ReportError(err)) break @@ -567,8 +567,8 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.renderPills() case pubsub.Event[history.File]: cmds = append(cmds, m.handleFileEvent(msg.Payload)) - case pubsub.Event[app.LSPEvent]: - m.lspStates = app.GetLSPStates() + case pubsub.Event[workspace.LSPEvent]: + m.lspStates = m.com.Workspace.LSPGetStates() case pubsub.Event[mcp.Event]: switch msg.Payload.Type { case mcp.EventStateChanged: @@ -577,11 +577,11 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.loadMCPrompts, ) case mcp.EventPromptsListChanged: - return m, handleMCPPromptsEvent(msg.Payload.Name) + return m, handleMCPPromptsEvent(m.com.Workspace, msg.Payload.Name) case mcp.EventToolsListChanged: - return m, handleMCPToolsEvent(m.com.Store(), msg.Payload.Name) + return m, handleMCPToolsEvent(m.com.Workspace, msg.Payload.Name) case mcp.EventResourcesListChanged: - return m, handleMCPResourcesEvent(msg.Payload.Name) + return m, handleMCPResourcesEvent(m.com.Workspace, msg.Payload.Name) } case pubsub.Event[permission.PermissionRequest]: if cmd := m.openPermissionsDialog(msg.Payload); cmd != nil { @@ -830,7 +830,7 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } else { m.textarea.Placeholder = m.readyPlaceholder } - if m.com.App.Permissions.SkipRequests() { + if m.com.Workspace.PermissionSkipRequests() { m.textarea.Placeholder = "Yolo mode!" } } @@ -909,10 +909,10 @@ func (m *UI) loadNestedToolCalls(items []chat.MessageItem) { messageID := toolItem.MessageID() // Get the agent tool session ID. - agentSessionID := m.com.App.Sessions.CreateAgentToolSessionID(messageID, tc.ID) + agentSessionID := m.com.Workspace.CreateAgentToolSessionID(messageID, tc.ID) // Fetch nested messages. - nestedMsgs, err := m.com.App.Messages.List(context.Background(), agentSessionID) + nestedMsgs, err := m.com.Workspace.ListMessages(context.Background(), agentSessionID) if err != nil || len(nestedMsgs) == 0 { continue } @@ -1114,7 +1114,7 @@ func (m *UI) handleChildSessionMessage(event pubsub.Event[message.Message]) tea. // Check if this is an agent tool session and parse it. childSessionID := event.Payload.SessionID - _, toolCallID, ok := m.com.App.Sessions.ParseAgentToolSessionID(childSessionID) + _, toolCallID, ok := m.com.Workspace.ParseAgentToolSessionID(childSessionID) if !ok { return nil } @@ -1246,8 +1246,8 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { // Command dialog messages case dialog.ActionToggleYoloMode: - yolo := !m.com.App.Permissions.SkipRequests() - m.com.App.Permissions.SetSkipRequests(yolo) + yolo := !m.com.Workspace.PermissionSkipRequests() + m.com.Workspace.PermissionSetSkipRequests(yolo) m.setEditorPrompt(yolo) m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionNewSession: @@ -1265,7 +1265,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } cmds = append(cmds, func() tea.Msg { - err := m.com.App.AgentCoordinator.Summarize(context.Background(), msg.SessionID) + err := m.com.Workspace.AgentSummarize(context.Background(), msg.SessionID) if err != nil { return util.ReportError(err)() } @@ -1304,10 +1304,10 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.Think = !currentModel.Think - if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { return util.ReportError(err)() } - m.com.App.UpdateAgentModel(context.TODO()) + m.com.Workspace.UpdateAgentModel(context.TODO()) status := "disabled" if currentModel.Think { status = "enabled" @@ -1345,7 +1345,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { // Attempt to import GitHub Copilot tokens from VSCode if available. if isCopilot && !isConfigured() && !msg.ReAuthenticate { - m.com.Store().ImportCopilot() + m.com.Workspace.ImportCopilot() } if !isConfigured() || msg.ReAuthenticate { @@ -1356,18 +1356,18 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. - smallModel := m.com.App.GetDefaultSmallModel(providerID) - if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil { + smallModel := m.com.Workspace.GetDefaultSmallModel(providerID) + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } cmds = append(cmds, func() tea.Msg { - if err := m.com.App.UpdateAgentModel(context.TODO()); err != nil { + if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil { return util.ReportError(err) } @@ -1383,7 +1383,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { if isOnboarding { m.setState(uiLanding, uiFocusEditor) m.com.Config().SetupAgents() - if err := m.com.App.InitCoderAgent(context.TODO()); err != nil { + if err := m.com.Workspace.InitCoderAgent(context.TODO()); err != nil { cmds = append(cmds, util.ReportError(err)) } } @@ -1407,13 +1407,13 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { currentModel := cfg.Models[agentCfg.Model] currentModel.ReasoningEffort = msg.Effort - if err := m.com.Store().UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { cmds = append(cmds, util.ReportError(err)) break } cmds = append(cmds, func() tea.Msg { - m.com.App.UpdateAgentModel(context.TODO()) + m.com.Workspace.UpdateAgentModel(context.TODO()) return util.NewInfoMsg("Reasoning effort set to " + msg.Effort) }) m.dialog.CloseDialog(dialog.ReasoningID) @@ -1421,11 +1421,11 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { m.dialog.CloseDialog(dialog.PermissionsID) switch msg.Action { case dialog.PermissionAllow: - m.com.App.Permissions.Grant(msg.Permission) + m.com.Workspace.PermissionGrant(msg.Permission) case dialog.PermissionAllowForSession: - m.com.App.Permissions.GrantPersistent(msg.Permission) + m.com.Workspace.PermissionGrantPersistent(msg.Permission) case dialog.PermissionDeny: - m.com.App.Permissions.Deny(msg.Permission) + m.com.Workspace.PermissionDeny(msg.Permission) } case dialog.ActionFilePickerSelected: @@ -2019,7 +2019,7 @@ func (m *UI) View() tea.View { } v.MouseMode = tea.MouseModeCellMotion v.ReportFocus = m.caps.ReportFocusEvents - v.WindowTitle = "crush " + home.Short(m.com.Store().WorkingDir()) + v.WindowTitle = "crush " + home.Short(m.com.Workspace.WorkingDir()) canvas := uv.NewScreenBuffer(m.width, m.height) v.Cursor = m.Draw(canvas, canvas.Bounds()) @@ -2062,7 +2062,7 @@ func (m *UI) ShortHelp() []key.Binding { cancelBinding := k.Chat.Cancel if m.isCanceling { cancelBinding.SetHelp("esc", "press again to cancel") - } else if m.com.App.AgentCoordinator.QueuedPrompts(m.session.ID) > 0 { + } else if m.com.Workspace.AgentQueuedPrompts(m.session.ID) > 0 { cancelBinding.SetHelp("esc", "clear queue") } binds = append(binds, cancelBinding) @@ -2141,7 +2141,7 @@ func (m *UI) FullHelp() [][]key.Binding { cancelBinding := k.Chat.Cancel if m.isCanceling { cancelBinding.SetHelp("esc", "press again to cancel") - } else if m.com.App.AgentCoordinator.QueuedPrompts(m.session.ID) > 0 { + } else if m.com.Workspace.AgentQueuedPrompts(m.session.ID) > 0 { cancelBinding.SetHelp("esc", "clear queue") } binds = append(binds, []key.Binding{cancelBinding}) @@ -2258,7 +2258,7 @@ func (m *UI) FullHelp() [][]key.Binding { func (m *UI) toggleCompactMode() tea.Cmd { m.forceCompactMode = !m.forceCompactMode - err := m.com.Store().SetCompactMode(config.ScopeGlobal, m.forceCompactMode) + err := m.com.Workspace.SetCompactMode(config.ScopeGlobal, m.forceCompactMode) if err != nil { return util.ReportError(err) } @@ -2600,7 +2600,7 @@ func (m *UI) insertFileCompletion(path string) tea.Cmd { if m.hasSession() { // Skip attachment if file was already read and hasn't been modified. - lastRead := m.com.App.FileTracker.LastReadTime(context.Background(), m.session.ID, absPath) + lastRead := m.com.Workspace.FileTrackerLastReadTime(context.Background(), m.session.ID, absPath) if !lastRead.IsZero() { if info, err := os.Stat(path); err == nil && !info.ModTime().After(lastRead) { return nil @@ -2638,9 +2638,8 @@ func (m *UI) insertMCPResourceCompletion(item completions.ResourceCompletionValu } return func() tea.Msg { - contents, err := mcp.ReadResource( + contents, err := m.com.Workspace.ReadMCPResource( context.Background(), - m.com.Store(), item.MCPName, item.URI, ) @@ -2708,9 +2707,8 @@ func isWhitespace(b byte) bool { // isAgentBusy returns true if the agent coordinator exists and is currently // busy processing a request. func (m *UI) isAgentBusy() bool { - return m.com.App != nil && - m.com.App.AgentCoordinator != nil && - m.com.App.AgentCoordinator.IsBusy() + return m.com.Workspace.AgentIsReady() && + m.com.Workspace.AgentIsBusy() } // hasSession returns true if there is an active session with a valid ID. @@ -2767,13 +2765,13 @@ func (m *UI) cacheSidebarLogo(width int) { // sendMessage sends a message with the given content and attachments. func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.Cmd { - if m.com.App.AgentCoordinator == nil { + if !m.com.Workspace.AgentIsReady() { return util.ReportError(fmt.Errorf("coder agent is not initialized")) } var cmds []tea.Cmd if !m.hasSession() { - newSession, err := m.com.App.Sessions.Create(context.Background(), "New Session") + newSession, err := m.com.Workspace.CreateSession(context.Background(), "New Session") if err != nil { return util.ReportError(err) } @@ -2790,8 +2788,8 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea. ctx := context.Background() cmds = append(cmds, func() tea.Msg { for _, path := range m.sessionFileReads { - m.com.App.FileTracker.RecordRead(ctx, m.session.ID, path) - m.com.App.LSPManager.Start(ctx, path) + m.com.Workspace.FileTrackerRecordRead(ctx, m.session.ID, path) + m.com.Workspace.LSPStart(ctx, path) } return nil }) @@ -2799,7 +2797,7 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea. // Capture session ID to avoid race with main goroutine updating m.session. sessionID := m.session.ID cmds = append(cmds, func() tea.Msg { - _, err := m.com.App.AgentCoordinator.Run(context.Background(), sessionID, content, attachments...) + err := m.com.Workspace.AgentRun(context.Background(), sessionID, content, attachments...) if err != nil { isCancelErr := errors.Is(err, context.Canceled) isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied) @@ -2833,15 +2831,14 @@ func (m *UI) cancelAgent() tea.Cmd { return nil } - coordinator := m.com.App.AgentCoordinator - if coordinator == nil { + if !m.com.Workspace.AgentIsReady() { return nil } if m.isCanceling { // Second escape press - actually cancel the agent. m.isCanceling = false - coordinator.Cancel(m.session.ID) + m.com.Workspace.AgentCancel(m.session.ID) // Stop the spinning todo indicator. m.todoIsSpinning = false m.renderPills() @@ -2849,8 +2846,8 @@ func (m *UI) cancelAgent() tea.Cmd { } // Check if there are queued prompts - if so, clear the queue. - if coordinator.QueuedPrompts(m.session.ID) > 0 { - coordinator.ClearQueue(m.session.ID) + if m.com.Workspace.AgentQueuedPrompts(m.session.ID) > 0 { + m.com.Workspace.AgentClearQueue(m.session.ID) return nil } @@ -3071,7 +3068,7 @@ func (m *UI) newSession() tea.Cmd { agenttools.ResetCache() return tea.Batch( func() tea.Msg { - m.com.App.LSPManager.StopAll(context.Background()) + m.com.Workspace.LSPStopAll(context.Background()) return nil }, m.loadPromptHistory(), @@ -3302,7 +3299,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { lspSection := m.lspInfo(sectionWidth, maxItemsPerSection, false) mcpSection := m.mcpInfo(sectionWidth, maxItemsPerSection, false) - filesSection := m.filesInfo(m.com.Store().WorkingDir(), sectionWidth, maxItemsPerSection, false) + filesSection := m.filesInfo(m.com.Workspace.WorkingDir(), sectionWidth, maxItemsPerSection, false) sections := lipgloss.JoinHorizontal(lipgloss.Top, filesSection, " ", lspSection, " ", mcpSection) uv.NewStyledString( s.CompactDetails.View. @@ -3320,7 +3317,7 @@ func (m *UI) drawSessionDetails(scr uv.Screen, area uv.Rectangle) { func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string) tea.Cmd { load := func() tea.Msg { - prompt, err := commands.GetMCPPrompt(m.com.Store(), clientID, promptID, arguments) + prompt, err := m.com.Workspace.GetMCPPrompt(clientID, promptID, arguments) if err != nil { // TODO: make this better return util.ReportError(err)() @@ -3347,34 +3344,30 @@ func (m *UI) runMCPPrompt(clientID, promptID string, arguments map[string]string func (m *UI) handleStateChanged() tea.Cmd { return func() tea.Msg { - m.com.App.UpdateAgentModel(context.Background()) + m.com.Workspace.UpdateAgentModel(context.Background()) return mcpStateChangedMsg{ - states: mcp.GetStates(), + states: m.com.Workspace.MCPGetStates(), } } } -func handleMCPPromptsEvent(name string) tea.Cmd { +func handleMCPPromptsEvent(ws workspace.Workspace, name string) tea.Cmd { return func() tea.Msg { - mcp.RefreshPrompts(context.Background(), name) + ws.MCPRefreshPrompts(context.Background(), name) return nil } } -func handleMCPToolsEvent(cfg *config.ConfigStore, name string) tea.Cmd { +func handleMCPToolsEvent(ws workspace.Workspace, name string) tea.Cmd { return func() tea.Msg { - mcp.RefreshTools( - context.Background(), - cfg, - name, - ) + ws.RefreshMCPTools(context.Background(), name) return nil } } -func handleMCPResourcesEvent(name string) tea.Cmd { +func handleMCPResourcesEvent(ws workspace.Workspace, name string) tea.Cmd { return func() tea.Msg { - mcp.RefreshResources(context.Background(), name) + ws.MCPRefreshResources(context.Background(), name) return nil } } diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go new file mode 100644 index 0000000000000000000000000000000000000000..e78afdd0792df9a5c6673bfad1ec4465a28711e1 --- /dev/null +++ b/internal/workspace/app_workspace.go @@ -0,0 +1,370 @@ +package workspace + +import ( + "context" + "log/slog" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent" + mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/commands" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" +) + +// AppWorkspace wraps an in-process app.App to satisfy the Workspace +// interface. This is the default mode when no server is involved. +type AppWorkspace struct { + app *app.App +} + +// NewAppWorkspace creates a Workspace backed by a local app.App. +func NewAppWorkspace(a *app.App) *AppWorkspace { + return &AppWorkspace{app: a} +} + +// App returns the underlying app.App for callers that still need +// direct access during the migration period. +func (w *AppWorkspace) App() *app.App { + return w.app +} + +// -- Sessions -- + +func (w *AppWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) { + return w.app.Sessions.Create(ctx, title) +} + +func (w *AppWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) { + return w.app.Sessions.Get(ctx, sessionID) +} + +func (w *AppWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) { + return w.app.Sessions.List(ctx) +} + +func (w *AppWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) { + return w.app.Sessions.Save(ctx, sess) +} + +func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) error { + return w.app.Sessions.Delete(ctx, sessionID) +} + +func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { + return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID) +} + +func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) { + return w.app.Sessions.ParseAgentToolSessionID(sessionID) +} + +// -- Messages -- + +func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + return w.app.Messages.List(ctx, sessionID) +} + +func (w *AppWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + return w.app.Messages.ListUserMessages(ctx, sessionID) +} + +func (w *AppWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) { + return w.app.Messages.ListAllUserMessages(ctx) +} + +// -- Agent -- + +func (w *AppWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error { + if w.app.AgentCoordinator == nil { + return nil + } + _, err := w.app.AgentCoordinator.Run(ctx, sessionID, prompt, attachments...) + return err +} + +func (w *AppWorkspace) AgentCancel(sessionID string) { + if w.app.AgentCoordinator != nil { + w.app.AgentCoordinator.Cancel(sessionID) + } +} + +func (w *AppWorkspace) AgentIsBusy() bool { + if w.app.AgentCoordinator == nil { + return false + } + return w.app.AgentCoordinator.IsBusy() +} + +func (w *AppWorkspace) AgentIsSessionBusy(sessionID string) bool { + if w.app.AgentCoordinator == nil { + return false + } + return w.app.AgentCoordinator.IsSessionBusy(sessionID) +} + +func (w *AppWorkspace) AgentModel() AgentModel { + if w.app.AgentCoordinator == nil { + return AgentModel{} + } + m := w.app.AgentCoordinator.Model() + return AgentModel{ + CatwalkCfg: m.CatwalkCfg, + ModelCfg: m.ModelCfg, + } +} + +func (w *AppWorkspace) AgentIsReady() bool { + return w.app.AgentCoordinator != nil +} + +func (w *AppWorkspace) AgentQueuedPrompts(sessionID string) int { + if w.app.AgentCoordinator == nil { + return 0 + } + return w.app.AgentCoordinator.QueuedPrompts(sessionID) +} + +func (w *AppWorkspace) AgentQueuedPromptsList(sessionID string) []string { + if w.app.AgentCoordinator == nil { + return nil + } + return w.app.AgentCoordinator.QueuedPromptsList(sessionID) +} + +func (w *AppWorkspace) AgentClearQueue(sessionID string) { + if w.app.AgentCoordinator != nil { + w.app.AgentCoordinator.ClearQueue(sessionID) + } +} + +func (w *AppWorkspace) AgentSummarize(ctx context.Context, sessionID string) error { + if w.app.AgentCoordinator == nil { + return nil + } + return w.app.AgentCoordinator.Summarize(ctx, sessionID) +} + +func (w *AppWorkspace) UpdateAgentModel(ctx context.Context) error { + return w.app.UpdateAgentModel(ctx) +} + +func (w *AppWorkspace) InitCoderAgent(ctx context.Context) error { + return w.app.InitCoderAgent(ctx) +} + +func (w *AppWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel { + return w.app.GetDefaultSmallModel(providerID) +} + +// -- Permissions -- + +func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) { + w.app.Permissions.Grant(perm) +} + +func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) { + w.app.Permissions.GrantPersistent(perm) +} + +func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) { + w.app.Permissions.Deny(perm) +} + +func (w *AppWorkspace) PermissionSkipRequests() bool { + return w.app.Permissions.SkipRequests() +} + +func (w *AppWorkspace) PermissionSetSkipRequests(skip bool) { + w.app.Permissions.SetSkipRequests(skip) +} + +// -- FileTracker -- + +func (w *AppWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) { + w.app.FileTracker.RecordRead(ctx, sessionID, path) +} + +func (w *AppWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time { + return w.app.FileTracker.LastReadTime(ctx, sessionID, path) +} + +func (w *AppWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) { + return w.app.FileTracker.ListReadFiles(ctx, sessionID) +} + +// -- History -- + +func (w *AppWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) { + return w.app.History.ListBySession(ctx, sessionID) +} + +// -- LSP -- + +func (w *AppWorkspace) LSPStart(ctx context.Context, path string) { + w.app.LSPManager.Start(ctx, path) +} + +func (w *AppWorkspace) LSPStopAll(ctx context.Context) { + w.app.LSPManager.StopAll(ctx) +} + +func (w *AppWorkspace) LSPGetStates() map[string]LSPClientInfo { + states := app.GetLSPStates() + result := make(map[string]LSPClientInfo, len(states)) + for k, v := range states { + result[k] = LSPClientInfo{ + Name: v.Name, + State: v.State, + Error: v.Error, + DiagnosticCount: v.DiagnosticCount, + ConnectedAt: v.ConnectedAt, + } + } + return result +} + +func (w *AppWorkspace) LSPGetClient(name string) (*lsp.Client, bool) { + info, ok := app.GetLSPState(name) + if !ok { + return nil, false + } + return info.Client, true +} + +// -- Config (read-only) -- + +func (w *AppWorkspace) Config() *config.Config { + return w.app.Config() +} + +func (w *AppWorkspace) WorkingDir() string { + return w.app.Store().WorkingDir() +} + +func (w *AppWorkspace) Resolver() config.VariableResolver { + return w.app.Store().Resolver() +} + +// -- Config mutations -- + +func (w *AppWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + return w.app.Store().UpdatePreferredModel(scope, modelType, model) +} + +func (w *AppWorkspace) SetCompactMode(scope config.Scope, enabled bool) error { + return w.app.Store().SetCompactMode(scope, enabled) +} + +func (w *AppWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error { + return w.app.Store().SetProviderAPIKey(scope, providerID, apiKey) +} + +func (w *AppWorkspace) SetConfigField(scope config.Scope, key string, value any) error { + return w.app.Store().SetConfigField(scope, key, value) +} + +func (w *AppWorkspace) RemoveConfigField(scope config.Scope, key string) error { + return w.app.Store().RemoveConfigField(scope, key) +} + +func (w *AppWorkspace) ImportCopilot() (*oauth.Token, bool) { + return w.app.Store().ImportCopilot() +} + +func (w *AppWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error { + return w.app.Store().RefreshOAuthToken(ctx, scope, providerID) +} + +// -- Project lifecycle -- + +func (w *AppWorkspace) ProjectNeedsInitialization() (bool, error) { + return config.ProjectNeedsInitialization(w.app.Store()) +} + +func (w *AppWorkspace) MarkProjectInitialized() error { + return config.MarkProjectInitialized(w.app.Store()) +} + +func (w *AppWorkspace) InitializePrompt() (string, error) { + return agent.InitializePrompt(w.app.Store()) +} + +// -- MCP operations -- + +func (w *AppWorkspace) MCPGetStates() map[string]mcptools.ClientInfo { + return mcptools.GetStates() +} + +func (w *AppWorkspace) MCPRefreshPrompts(ctx context.Context, name string) { + mcptools.RefreshPrompts(ctx, name) +} + +func (w *AppWorkspace) MCPRefreshResources(ctx context.Context, name string) { + mcptools.RefreshResources(ctx, name) +} + +func (w *AppWorkspace) RefreshMCPTools(ctx context.Context, name string) { + mcptools.RefreshTools(ctx, w.app.Store(), name) +} + +func (w *AppWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) { + contents, err := mcptools.ReadResource(ctx, w.app.Store(), name, uri) + if err != nil { + return nil, err + } + result := make([]MCPResourceContents, len(contents)) + for i, c := range contents { + result[i] = MCPResourceContents{ + URI: c.URI, + MIMEType: c.MIMEType, + Text: c.Text, + Blob: c.Blob, + } + } + return result, nil +} + +func (w *AppWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) { + return commands.GetMCPPrompt(w.app.Store(), clientID, promptID, args) +} + +// -- Lifecycle -- + +func (w *AppWorkspace) Subscribe(program *tea.Program) { + defer log.RecoverPanic("AppWorkspace.Subscribe", func() { + slog.Info("TUI subscription panic: attempting graceful shutdown") + program.Quit() + }) + + for msg := range w.app.Events() { + switch ev := msg.(type) { + case pubsub.Event[app.LSPEvent]: + program.Send(pubsub.Event[LSPEvent]{ + Type: ev.Type, + Payload: LSPEvent{ + Type: LSPEventType(ev.Payload.Type), + Name: ev.Payload.Name, + State: ev.Payload.State, + Error: ev.Payload.Error, + DiagnosticCount: ev.Payload.DiagnosticCount, + }, + }) + default: + program.Send(msg) + } + } +} + +func (w *AppWorkspace) Shutdown() { + w.app.Shutdown() +} diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go new file mode 100644 index 0000000000000000000000000000000000000000..d61b1124d2df499aca2487640e292b7affb20ee3 --- /dev/null +++ b/internal/workspace/client_workspace.go @@ -0,0 +1,690 @@ +package workspace + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/client" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" +) + +// ClientWorkspace implements the Workspace interface by delegating all +// operations to a remote server via the client SDK. It caches the +// proto.Workspace returned at creation time and refreshes it after +// config-mutating operations. +type ClientWorkspace struct { + client *client.Client + + mu sync.RWMutex + ws proto.Workspace +} + +// NewClientWorkspace creates a new ClientWorkspace that proxies all +// operations through the given client SDK. The ws parameter is the +// proto.Workspace snapshot returned by the server at creation time. +func NewClientWorkspace(c *client.Client, ws proto.Workspace) *ClientWorkspace { + if ws.Config != nil { + ws.Config.SetupAgents() + } + return &ClientWorkspace{ + client: c, + ws: ws, + } +} + +// refreshWorkspace re-fetches the workspace from the server, updating +// the cached snapshot. Called after config-mutating operations. +func (w *ClientWorkspace) refreshWorkspace() { + updated, err := w.client.GetWorkspace(context.Background(), w.ws.ID) + if err != nil { + slog.Error("Failed to refresh workspace", "error", err) + return + } + if updated.Config != nil { + updated.Config.SetupAgents() + } + w.mu.Lock() + w.ws = *updated + w.mu.Unlock() +} + +// cached returns a snapshot of the cached workspace. +func (w *ClientWorkspace) cached() proto.Workspace { + w.mu.RLock() + defer w.mu.RUnlock() + return w.ws +} + +// workspaceID returns the cached workspace ID. +func (w *ClientWorkspace) workspaceID() string { + return w.cached().ID +} + +// -- Sessions -- + +func (w *ClientWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) { + sess, err := w.client.CreateSession(ctx, w.workspaceID(), title) + if err != nil { + return session.Session{}, err + } + return *sess, nil +} + +func (w *ClientWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) { + sess, err := w.client.GetSession(ctx, w.workspaceID(), sessionID) + if err != nil { + return session.Session{}, err + } + return *sess, nil +} + +func (w *ClientWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) { + return w.client.ListSessions(ctx, w.workspaceID()) +} + +func (w *ClientWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) { + saved, err := w.client.SaveSession(ctx, w.workspaceID(), sess) + if err != nil { + return session.Session{}, err + } + return *saved, nil +} + +func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) error { + return w.client.DeleteSession(ctx, w.workspaceID(), sessionID) +} + +func (w *ClientWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { + return fmt.Sprintf("%s$$%s", messageID, toolCallID) +} + +func (w *ClientWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) { + parts := strings.Split(sessionID, "$$") + if len(parts) != 2 { + return "", "", false + } + return parts[0], parts[1], true +} + +// -- Messages -- + +func (w *ClientWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + return w.client.ListMessages(ctx, w.workspaceID(), sessionID) +} + +func (w *ClientWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + return w.client.ListUserMessages(ctx, w.workspaceID(), sessionID) +} + +func (w *ClientWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) { + return w.client.ListAllUserMessages(ctx, w.workspaceID()) +} + +// -- Agent -- + +func (w *ClientWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error { + return w.client.SendMessage(ctx, w.workspaceID(), sessionID, prompt, attachments...) +} + +func (w *ClientWorkspace) AgentCancel(sessionID string) { + _ = w.client.CancelAgentSession(context.Background(), w.workspaceID(), sessionID) +} + +func (w *ClientWorkspace) AgentIsBusy() bool { + info, err := w.client.GetAgentInfo(context.Background(), w.workspaceID()) + if err != nil { + return false + } + return info.IsBusy +} + +func (w *ClientWorkspace) AgentIsSessionBusy(sessionID string) bool { + info, err := w.client.GetAgentSessionInfo(context.Background(), w.workspaceID(), sessionID) + if err != nil { + return false + } + return info.IsBusy +} + +func (w *ClientWorkspace) AgentModel() AgentModel { + info, err := w.client.GetAgentInfo(context.Background(), w.workspaceID()) + if err != nil { + return AgentModel{} + } + return AgentModel{ + CatwalkCfg: info.Model, + ModelCfg: info.ModelCfg, + } +} + +func (w *ClientWorkspace) AgentIsReady() bool { + info, err := w.client.GetAgentInfo(context.Background(), w.workspaceID()) + if err != nil { + return false + } + return info.IsReady +} + +func (w *ClientWorkspace) AgentQueuedPrompts(sessionID string) int { + count, err := w.client.GetAgentSessionQueuedPrompts(context.Background(), w.workspaceID(), sessionID) + if err != nil { + return 0 + } + return count +} + +func (w *ClientWorkspace) AgentQueuedPromptsList(sessionID string) []string { + prompts, err := w.client.GetAgentSessionQueuedPromptsList(context.Background(), w.workspaceID(), sessionID) + if err != nil { + return nil + } + return prompts +} + +func (w *ClientWorkspace) AgentClearQueue(sessionID string) { + _ = w.client.ClearAgentSessionQueuedPrompts(context.Background(), w.workspaceID(), sessionID) +} + +func (w *ClientWorkspace) AgentSummarize(ctx context.Context, sessionID string) error { + return w.client.AgentSummarizeSession(ctx, w.workspaceID(), sessionID) +} + +func (w *ClientWorkspace) UpdateAgentModel(ctx context.Context) error { + return w.client.UpdateAgent(ctx, w.workspaceID()) +} + +func (w *ClientWorkspace) InitCoderAgent(ctx context.Context) error { + return w.client.InitiateAgentProcessing(ctx, w.workspaceID()) +} + +func (w *ClientWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel { + model, err := w.client.GetDefaultSmallModel(context.Background(), w.workspaceID(), providerID) + if err != nil { + return config.SelectedModel{} + } + return *model +} + +// -- Permissions -- + +func (w *ClientWorkspace) PermissionGrant(perm permission.PermissionRequest) { + _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ + Permission: proto.PermissionRequest{ + ID: perm.ID, + SessionID: perm.SessionID, + ToolCallID: perm.ToolCallID, + ToolName: perm.ToolName, + Description: perm.Description, + Action: perm.Action, + Path: perm.Path, + Params: perm.Params, + }, + Action: proto.PermissionAllowForSession, + }) +} + +func (w *ClientWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) { + _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ + Permission: proto.PermissionRequest{ + ID: perm.ID, + SessionID: perm.SessionID, + ToolCallID: perm.ToolCallID, + ToolName: perm.ToolName, + Description: perm.Description, + Action: perm.Action, + Path: perm.Path, + Params: perm.Params, + }, + Action: proto.PermissionAllow, + }) +} + +func (w *ClientWorkspace) PermissionDeny(perm permission.PermissionRequest) { + _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ + Permission: proto.PermissionRequest{ + ID: perm.ID, + SessionID: perm.SessionID, + ToolCallID: perm.ToolCallID, + ToolName: perm.ToolName, + Description: perm.Description, + Action: perm.Action, + Path: perm.Path, + Params: perm.Params, + }, + Action: proto.PermissionDeny, + }) +} + +func (w *ClientWorkspace) PermissionSkipRequests() bool { + skip, err := w.client.GetPermissionsSkipRequests(context.Background(), w.workspaceID()) + if err != nil { + return false + } + return skip +} + +func (w *ClientWorkspace) PermissionSetSkipRequests(skip bool) { + _ = w.client.SetPermissionsSkipRequests(context.Background(), w.workspaceID(), skip) +} + +// -- FileTracker -- + +func (w *ClientWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) { + _ = w.client.FileTrackerRecordRead(ctx, w.workspaceID(), sessionID, path) +} + +func (w *ClientWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time { + t, err := w.client.FileTrackerLastReadTime(ctx, w.workspaceID(), sessionID, path) + if err != nil { + return time.Time{} + } + return t +} + +func (w *ClientWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) { + return w.client.FileTrackerListReadFiles(ctx, w.workspaceID(), sessionID) +} + +// -- History -- + +func (w *ClientWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) { + return w.client.ListSessionHistoryFiles(ctx, w.workspaceID(), sessionID) +} + +// -- LSP -- + +func (w *ClientWorkspace) LSPStart(ctx context.Context, path string) { + _ = w.client.LSPStart(ctx, w.workspaceID(), path) +} + +func (w *ClientWorkspace) LSPStopAll(ctx context.Context) { + _ = w.client.LSPStopAll(ctx, w.workspaceID()) +} + +func (w *ClientWorkspace) LSPGetStates() map[string]LSPClientInfo { + states, err := w.client.GetLSPs(context.Background(), w.workspaceID()) + if err != nil { + return nil + } + result := make(map[string]LSPClientInfo, len(states)) + for k, v := range states { + result[k] = LSPClientInfo{ + Name: v.Name, + State: v.State, + Error: v.Error, + DiagnosticCount: v.DiagnosticCount, + ConnectedAt: v.ConnectedAt, + } + } + return result +} + +func (w *ClientWorkspace) LSPGetClient(_ string) (*lsp.Client, bool) { + return nil, false +} + +// -- Config (read-only) -- + +func (w *ClientWorkspace) Config() *config.Config { + return w.cached().Config +} + +func (w *ClientWorkspace) WorkingDir() string { + return w.cached().Path +} + +func (w *ClientWorkspace) Resolver() config.VariableResolver { + // In client mode, variable resolution is handled server-side. + return nil +} + +// -- Config mutations -- + +func (w *ClientWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + err := w.client.UpdatePreferredModel(context.Background(), w.workspaceID(), scope, modelType, model) + if err == nil { + w.refreshWorkspace() + } + return err +} + +func (w *ClientWorkspace) SetCompactMode(scope config.Scope, enabled bool) error { + err := w.client.SetCompactMode(context.Background(), w.workspaceID(), scope, enabled) + if err == nil { + w.refreshWorkspace() + } + return err +} + +func (w *ClientWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error { + err := w.client.SetProviderAPIKey(context.Background(), w.workspaceID(), scope, providerID, apiKey) + if err == nil { + w.refreshWorkspace() + } + return err +} + +func (w *ClientWorkspace) SetConfigField(scope config.Scope, key string, value any) error { + err := w.client.SetConfigField(context.Background(), w.workspaceID(), scope, key, value) + if err == nil { + w.refreshWorkspace() + } + return err +} + +func (w *ClientWorkspace) RemoveConfigField(scope config.Scope, key string) error { + err := w.client.RemoveConfigField(context.Background(), w.workspaceID(), scope, key) + if err == nil { + w.refreshWorkspace() + } + return err +} + +func (w *ClientWorkspace) ImportCopilot() (*oauth.Token, bool) { + token, ok, err := w.client.ImportCopilot(context.Background(), w.workspaceID()) + if err != nil { + return nil, false + } + if ok { + w.refreshWorkspace() + } + return token, ok +} + +func (w *ClientWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error { + err := w.client.RefreshOAuthToken(ctx, w.workspaceID(), scope, providerID) + if err == nil { + w.refreshWorkspace() + } + return err +} + +// -- Project lifecycle -- + +func (w *ClientWorkspace) ProjectNeedsInitialization() (bool, error) { + return w.client.ProjectNeedsInitialization(context.Background(), w.workspaceID()) +} + +func (w *ClientWorkspace) MarkProjectInitialized() error { + return w.client.MarkProjectInitialized(context.Background(), w.workspaceID()) +} + +func (w *ClientWorkspace) InitializePrompt() (string, error) { + return w.client.GetInitializePrompt(context.Background(), w.workspaceID()) +} + +// -- MCP operations -- + +func (w *ClientWorkspace) MCPGetStates() map[string]mcp.ClientInfo { + states, err := w.client.MCPGetStates(context.Background(), w.workspaceID()) + if err != nil { + return nil + } + result := make(map[string]mcp.ClientInfo, len(states)) + for k, v := range states { + result[k] = mcp.ClientInfo{ + Name: v.Name, + State: mcp.State(v.State), + Error: v.Error, + Counts: mcp.Counts{ + Tools: v.ToolCount, + Prompts: v.PromptCount, + Resources: v.ResourceCount, + }, + ConnectedAt: time.Unix(v.ConnectedAt, 0), + } + } + return result +} + +func (w *ClientWorkspace) MCPRefreshPrompts(ctx context.Context, name string) { + _ = w.client.MCPRefreshPrompts(ctx, w.workspaceID(), name) +} + +func (w *ClientWorkspace) MCPRefreshResources(ctx context.Context, name string) { + _ = w.client.MCPRefreshResources(ctx, w.workspaceID(), name) +} + +func (w *ClientWorkspace) RefreshMCPTools(ctx context.Context, name string) { + _ = w.client.RefreshMCPTools(ctx, w.workspaceID(), name) +} + +func (w *ClientWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) { + contents, err := w.client.ReadMCPResource(ctx, w.workspaceID(), name, uri) + if err != nil { + return nil, err + } + result := make([]MCPResourceContents, len(contents)) + for i, c := range contents { + result[i] = MCPResourceContents{ + URI: c.URI, + MIMEType: c.MIMEType, + Text: c.Text, + Blob: c.Blob, + } + } + return result, nil +} + +func (w *ClientWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) { + return w.client.GetMCPPrompt(context.Background(), w.workspaceID(), clientID, promptID, args) +} + +// -- Lifecycle -- + +func (w *ClientWorkspace) Subscribe(program *tea.Program) { + defer log.RecoverPanic("ClientWorkspace.Subscribe", func() { + slog.Info("TUI subscription panic: attempting graceful shutdown") + program.Quit() + }) + + evc, err := w.client.SubscribeEvents(context.Background(), w.workspaceID()) + if err != nil { + slog.Error("Failed to subscribe to events", "error", err) + return + } + + for ev := range evc { + translated := translateEvent(ev) + if translated != nil { + program.Send(translated) + } + } +} + +func (w *ClientWorkspace) Shutdown() { + _ = w.client.DeleteWorkspace(context.Background(), w.workspaceID()) +} + +// translateEvent converts proto-typed SSE events into the domain types +// that the TUI's Update() method expects. +func translateEvent(ev any) tea.Msg { + switch e := ev.(type) { + case pubsub.Event[proto.LSPEvent]: + return pubsub.Event[LSPEvent]{ + Type: e.Type, + Payload: LSPEvent{ + Type: LSPEventType(e.Payload.Type), + Name: e.Payload.Name, + State: e.Payload.State, + Error: e.Payload.Error, + DiagnosticCount: e.Payload.DiagnosticCount, + }, + } + case pubsub.Event[proto.MCPEvent]: + return pubsub.Event[mcp.Event]{ + Type: e.Type, + Payload: mcp.Event{ + Type: protoToMCPEventType(e.Payload.Type), + Name: e.Payload.Name, + State: mcp.State(e.Payload.State), + Error: e.Payload.Error, + Counts: mcp.Counts{ + Tools: e.Payload.ToolCount, + Prompts: e.Payload.PromptCount, + Resources: e.Payload.ResourceCount, + }, + }, + } + case pubsub.Event[proto.PermissionRequest]: + return pubsub.Event[permission.PermissionRequest]{ + Type: e.Type, + Payload: permission.PermissionRequest{ + ID: e.Payload.ID, + SessionID: e.Payload.SessionID, + ToolCallID: e.Payload.ToolCallID, + ToolName: e.Payload.ToolName, + Description: e.Payload.Description, + Action: e.Payload.Action, + Path: e.Payload.Path, + Params: e.Payload.Params, + }, + } + case pubsub.Event[proto.PermissionNotification]: + return pubsub.Event[permission.PermissionNotification]{ + Type: e.Type, + Payload: permission.PermissionNotification{ + ToolCallID: e.Payload.ToolCallID, + Granted: e.Payload.Granted, + Denied: e.Payload.Denied, + }, + } + case pubsub.Event[proto.Message]: + return pubsub.Event[message.Message]{ + Type: e.Type, + Payload: protoToMessage(e.Payload), + } + case pubsub.Event[proto.Session]: + return pubsub.Event[session.Session]{ + Type: e.Type, + Payload: protoToSession(e.Payload), + } + case pubsub.Event[proto.File]: + return pubsub.Event[history.File]{ + Type: e.Type, + Payload: protoToFile(e.Payload), + } + case pubsub.Event[proto.AgentEvent]: + return pubsub.Event[notify.Notification]{ + Type: e.Type, + Payload: notify.Notification{ + SessionID: e.Payload.SessionID, + SessionTitle: e.Payload.SessionTitle, + Type: notify.Type(e.Payload.Type), + }, + } + default: + return ev.(tea.Msg) + } +} + +func protoToMCPEventType(t proto.MCPEventType) mcp.EventType { + switch t { + case proto.MCPEventStateChanged: + return mcp.EventStateChanged + case proto.MCPEventToolsListChanged: + return mcp.EventToolsListChanged + case proto.MCPEventPromptsListChanged: + return mcp.EventPromptsListChanged + case proto.MCPEventResourcesListChanged: + return mcp.EventResourcesListChanged + default: + return mcp.EventStateChanged + } +} + +func protoToSession(s proto.Session) session.Session { + return session.Session{ + ID: s.ID, + ParentSessionID: s.ParentSessionID, + Title: s.Title, + SummaryMessageID: s.SummaryMessageID, + MessageCount: s.MessageCount, + PromptTokens: s.PromptTokens, + CompletionTokens: s.CompletionTokens, + Cost: s.Cost, + CreatedAt: s.CreatedAt, + UpdatedAt: s.UpdatedAt, + } +} + +func protoToFile(f proto.File) history.File { + return history.File{ + ID: f.ID, + SessionID: f.SessionID, + Path: f.Path, + Content: f.Content, + Version: f.Version, + CreatedAt: f.CreatedAt, + UpdatedAt: f.UpdatedAt, + } +} + +func protoToMessage(m proto.Message) message.Message { + msg := message.Message{ + ID: m.ID, + SessionID: m.SessionID, + Role: message.MessageRole(m.Role), + Model: m.Model, + Provider: m.Provider, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } + + for _, p := range m.Parts { + switch v := p.(type) { + case proto.TextContent: + msg.Parts = append(msg.Parts, message.TextContent{Text: v.Text}) + case proto.ReasoningContent: + msg.Parts = append(msg.Parts, message.ReasoningContent{ + Thinking: v.Thinking, + Signature: v.Signature, + StartedAt: v.StartedAt, + FinishedAt: v.FinishedAt, + }) + case proto.ToolCall: + msg.Parts = append(msg.Parts, message.ToolCall{ + ID: v.ID, + Name: v.Name, + Input: v.Input, + Finished: v.Finished, + }) + case proto.ToolResult: + msg.Parts = append(msg.Parts, message.ToolResult{ + ToolCallID: v.ToolCallID, + Name: v.Name, + Content: v.Content, + IsError: v.IsError, + }) + case proto.Finish: + msg.Parts = append(msg.Parts, message.Finish{ + Reason: message.FinishReason(v.Reason), + Time: v.Time, + Message: v.Message, + Details: v.Details, + }) + case proto.ImageURLContent: + msg.Parts = append(msg.Parts, message.ImageURLContent{URL: v.URL, Detail: v.Detail}) + case proto.BinaryContent: + msg.Parts = append(msg.Parts, message.BinaryContent{Path: v.Path, MIMEType: v.MIMEType, Data: v.Data}) + } + } + + return msg +} diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go new file mode 100644 index 0000000000000000000000000000000000000000..eae106d0ff823446712316c1ba275af6ed67c6da --- /dev/null +++ b/internal/workspace/workspace.go @@ -0,0 +1,150 @@ +// Package workspace defines the Workspace interface used by all +// frontends (TUI, CLI) to interact with a running workspace. Two +// implementations exist: one wrapping a local app.App instance and one +// wrapping the HTTP client SDK. +package workspace + +import ( + "context" + "time" + + tea "charm.land/bubbletea/v2" + "charm.land/catwalk/pkg/catwalk" + mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" +) + +// LSPClientInfo holds information about an LSP client's state. This is +// the frontend-facing type; implementations translate from the +// underlying app or proto representation. +type LSPClientInfo struct { + Name string + State lsp.ServerState + Error error + DiagnosticCount int + ConnectedAt time.Time +} + +// LSPEventType represents the type of LSP event. +type LSPEventType string + +const ( + LSPEventStateChanged LSPEventType = "state_changed" + LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed" +) + +// LSPEvent represents an LSP event forwarded to the TUI. +type LSPEvent struct { + Type LSPEventType + Name string + State lsp.ServerState + Error error + DiagnosticCount int +} + +// AgentModel holds the model information exposed to the UI. +type AgentModel struct { + CatwalkCfg catwalk.Model + ModelCfg config.SelectedModel +} + +// Workspace is the main abstraction consumed by the TUI and CLI. It +// groups every operation a frontend needs to perform against a running +// workspace, regardless of whether the workspace is in-process or +// remote. +type Workspace interface { + // Sessions + CreateSession(ctx context.Context, title string) (session.Session, error) + GetSession(ctx context.Context, sessionID string) (session.Session, error) + ListSessions(ctx context.Context) ([]session.Session, error) + SaveSession(ctx context.Context, sess session.Session) (session.Session, error) + DeleteSession(ctx context.Context, sessionID string) error + CreateAgentToolSessionID(messageID, toolCallID string) string + ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) + + // Messages + ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) + ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) + ListAllUserMessages(ctx context.Context) ([]message.Message, error) + + // Agent + AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error + AgentCancel(sessionID string) + AgentIsBusy() bool + AgentIsSessionBusy(sessionID string) bool + AgentModel() AgentModel + AgentIsReady() bool + AgentQueuedPrompts(sessionID string) int + AgentQueuedPromptsList(sessionID string) []string + AgentClearQueue(sessionID string) + AgentSummarize(ctx context.Context, sessionID string) error + UpdateAgentModel(ctx context.Context) error + InitCoderAgent(ctx context.Context) error + GetDefaultSmallModel(providerID string) config.SelectedModel + + // Permissions + PermissionGrant(perm permission.PermissionRequest) + PermissionGrantPersistent(perm permission.PermissionRequest) + PermissionDeny(perm permission.PermissionRequest) + PermissionSkipRequests() bool + PermissionSetSkipRequests(skip bool) + + // FileTracker + FileTrackerRecordRead(ctx context.Context, sessionID, path string) + FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time + FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) + + // History + ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) + + // LSP + LSPStart(ctx context.Context, path string) + LSPStopAll(ctx context.Context) + LSPGetStates() map[string]LSPClientInfo + LSPGetClient(name string) (*lsp.Client, bool) + + // Config (read-only data) + Config() *config.Config + WorkingDir() string + Resolver() config.VariableResolver + + // Config mutations (proxied to server in client mode) + UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error + SetCompactMode(scope config.Scope, enabled bool) error + SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error + SetConfigField(scope config.Scope, key string, value any) error + RemoveConfigField(scope config.Scope, key string) error + ImportCopilot() (*oauth.Token, bool) + RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error + + // Project lifecycle + ProjectNeedsInitialization() (bool, error) + MarkProjectInitialized() error + InitializePrompt() (string, error) + + // MCP operations (server-side in client mode) + MCPGetStates() map[string]mcptools.ClientInfo + MCPRefreshPrompts(ctx context.Context, name string) + MCPRefreshResources(ctx context.Context, name string) + RefreshMCPTools(ctx context.Context, name string) + ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) + GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) + + // Events + Subscribe(program *tea.Program) + Shutdown() +} + +// MCPResourceContents holds the contents of an MCP resource. +type MCPResourceContents struct { + URI string `json:"uri"` + MIMEType string `json:"mime_type,omitempty"` + Text string `json:"text,omitempty"` + Blob []byte `json:"blob,omitempty"` +}