diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 0c300e6b1eeb5f7297d627cf43f4b36a29771375..26684aae09a475a50eceeb0daafccacfb45584df 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -21,17 +21,22 @@ import ( fang "charm.land/fang/v2" "charm.land/lipgloss/v2" "github.com/charmbracelet/colorprofile" + "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/client" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" crushlog "github.com/charmbracelet/crush/internal/log" + "github.com/charmbracelet/crush/internal/projects" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/server" + "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/ui/common" ui "github.com/charmbracelet/crush/internal/ui/model" "github.com/charmbracelet/crush/internal/version" "github.com/charmbracelet/crush/internal/workspace" uv "github.com/charmbracelet/ultraviolet" + "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/x/exp/charmtone" "github.com/charmbracelet/x/term" "github.com/spf13/cobra" @@ -96,15 +101,14 @@ crush --continue sessionID, _ := cmd.Flags().GetString("session") continueLast, _ := cmd.Flags().GetBool("continue") - c, ws, cleanup, err := connectToServer(cmd) + ws, cleanup, err := setupWorkspaceWithProgressBar(cmd) if err != nil { return err } defer cleanup() - // Resolve session ID if provided. if sessionID != "" { - sess, err := resolveSessionByID(cmd.Context(), c, ws.ID, sessionID) + sess, err := resolveWorkspaceSessionID(cmd.Context(), ws, sessionID) if err != nil { return err } @@ -113,15 +117,7 @@ crush --continue event.AppInitialized() - clientWs := workspace.NewClientWorkspace(c, *ws) - - if ws.Config.IsConfigured() { - if err := clientWs.InitCoderAgent(cmd.Context()); err != nil { - slog.Error("Failed to initialize coder agent", "error", err) - } - } - - com := common.DefaultCommon(clientWs) + com := common.DefaultCommon(ws) model := ui.New(com, sessionID, continueLast) var env uv.Environ = os.Environ() @@ -131,7 +127,7 @@ crush --continue tea.WithContext(cmd.Context()), tea.WithFilter(ui.MouseEventFilter), ) - go clientWs.Subscribe(program) + go ws.Subscribe(program) if _, err := program.Run(); err != nil { event.Error(err) @@ -197,6 +193,120 @@ func supportsProgressBar() bool { return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty") } +// useClientServer returns true when the client/server architecture is +// enabled via the CRUSH_CLIENT_SERVER environment variable. +func useClientServer() bool { + v, _ := strconv.ParseBool(os.Getenv("CRUSH_CLIENT_SERVER")) + return v +} + +// setupWorkspaceWithProgressBar wraps setupWorkspace with an optional +// terminal progress bar shown during initialization. +func setupWorkspaceWithProgressBar(cmd *cobra.Command) (workspace.Workspace, func(), error) { + showProgress := supportsProgressBar() + if showProgress { + _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar) + } + + ws, cleanup, err := setupWorkspace(cmd) + + if showProgress { + _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) + } + + return ws, cleanup, err +} + +// setupWorkspace returns a Workspace and cleanup function. When +// CRUSH_CLIENT_SERVER=1, it connects to a server process and returns a +// ClientWorkspace. Otherwise it creates an in-process app.App and +// returns an AppWorkspace. +func setupWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error) { + if useClientServer() { + return setupClientServerWorkspace(cmd) + } + return setupLocalWorkspace(cmd) +} + +// setupLocalWorkspace creates an in-process app.App and wraps it in an +// AppWorkspace. +func setupLocalWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error) { + debug, _ := cmd.Flags().GetBool("debug") + yolo, _ := cmd.Flags().GetBool("yolo") + dataDir, _ := cmd.Flags().GetString("data-dir") + ctx := cmd.Context() + + cwd, err := ResolveCwd(cmd) + if err != nil { + return nil, nil, err + } + + store, err := config.Init(cwd, dataDir, debug) + if err != nil { + return nil, nil, err + } + + cfg := store.Config() + store.Overrides().SkipPermissionRequests = yolo + + if err := os.MkdirAll(cfg.Options.DataDirectory, 0o700); err != nil { + return nil, nil, fmt.Errorf("failed to create data directory: %q %w", cfg.Options.DataDirectory, err) + } + + gitIgnorePath := filepath.Join(cfg.Options.DataDirectory, ".gitignore") + if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) { + if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil { + return nil, nil, fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err) + } + } + + if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil { + slog.Warn("Failed to register project", "error", err) + } + + conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + if err != nil { + return nil, nil, err + } + + logFile := filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log") + crushlog.Setup(logFile, debug) + + appInstance, err := app.New(ctx, conn, store) + if err != nil { + _ = conn.Close() + slog.Error("Failed to create app instance", "error", err) + return nil, nil, err + } + + if shouldEnableMetrics(cfg) { + event.Init() + } + + ws := workspace.NewAppWorkspace(appInstance, store) + cleanup := func() { appInstance.Shutdown() } + return ws, cleanup, nil +} + +// setupClientServerWorkspace connects to a server process and wraps the +// result in a ClientWorkspace. +func setupClientServerWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error) { + c, protoWs, cleanupServer, err := connectToServer(cmd) + if err != nil { + return nil, nil, err + } + + clientWs := workspace.NewClientWorkspace(c, *protoWs) + + if protoWs.Config.IsConfigured() { + if err := clientWs.InitCoderAgent(cmd.Context()); err != nil { + slog.Error("Failed to initialize coder agent", "error", err) + } + } + + return clientWs, cleanupServer, nil +} + // connectToServer ensures the server is running, creates a client and // workspace, and returns a cleanup function that deletes the workspace. func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) { @@ -426,6 +536,38 @@ func MaybePrependStdin(prompt string) (string, error) { return string(bts) + "\n\n" + prompt, nil } +// resolveWorkspaceSessionID resolves a session ID that may be a full +// UUID, full hash, or hash prefix. Works against the Workspace +// interface so both local and client/server paths get hash prefix +// support. +func resolveWorkspaceSessionID(ctx context.Context, ws workspace.Workspace, id string) (session.Session, error) { + if sess, err := ws.GetSession(ctx, id); err == nil { + return sess, nil + } + + sessions, err := ws.ListSessions(ctx) + if err != nil { + return session.Session{}, err + } + + var matches []session.Session + for _, s := range sessions { + hash := session.HashID(s.ID) + if hash == id || strings.HasPrefix(hash, id) { + matches = append(matches, s) + } + } + + switch len(matches) { + case 0: + return session.Session{}, fmt.Errorf("session not found: %s", id) + case 1: + return matches[0], nil + default: + return session.Session{}, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches)) + } +} + func ResolveCwd(cmd *cobra.Command) (string, error) { cwd, _ := cmd.Flags().GetString("cwd") if cwd != "" { diff --git a/internal/cmd/run.go b/internal/cmd/run.go index fbe28f27de42f27c993aecfd2b0340df553dbcf2..054965f5b72441c0fd56e7876963292e813bb135 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -20,6 +20,7 @@ import ( "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/ui/anim" "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/crush/internal/workspace" "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/x/exp/charmtone" "github.com/charmbracelet/x/term" @@ -72,31 +73,9 @@ crush run --continue "Follow up on your last response" ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) defer cancel() - c, ws, cleanup, err := connectToServer(cmd) - if err != nil { - return err - } - defer cleanup() - - if sessionID != "" { - sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID) - if err != nil { - return err - } - sessionID = sess.ID - } - - if !ws.Config.IsConfigured() { - return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") - } - - if verbose { - slog.SetDefault(slog.New(log.New(os.Stderr))) - } - prompt := strings.Join(args, " ") - prompt, err = MaybePrependStdin(prompt) + prompt, err := MaybePrependStdin(prompt) if err != nil { slog.Error("Failed to read from stdin", "error", err) return err @@ -116,7 +95,48 @@ crush run --continue "Follow up on your last response" event.SetContinueLastSession(true) } - return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast) + if useClientServer() { + c, ws, cleanup, err := connectToServer(cmd) + if err != nil { + return err + } + defer cleanup() + + if sessionID != "" { + sess, err := resolveSessionByID(ctx, c, ws.ID, sessionID) + if err != nil { + return err + } + sessionID = sess.ID + } + + if !ws.Config.IsConfigured() { + return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") + } + + if verbose { + slog.SetDefault(slog.New(log.New(os.Stderr))) + } + + return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast) + } + + ws, cleanup, err := setupLocalWorkspace(cmd) + if err != nil { + return err + } + defer cleanup() + + if !ws.Config().IsConfigured() { + return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") + } + + if verbose { + slog.SetDefault(slog.New(log.New(os.Stderr))) + } + + appWs := ws.(*workspace.AppWorkspace) + return appWs.App().RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose, sessionID, useLast) }, } diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go new file mode 100644 index 0000000000000000000000000000000000000000..57b1228e7eacb28a16141283ee2703a33511bd18 --- /dev/null +++ b/internal/workspace/app_workspace.go @@ -0,0 +1,389 @@ +package workspace + +import ( + "context" + "errors" + "fmt" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent" + mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/commands" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" +) + +// AppWorkspace implements the Workspace interface by delegating +// directly to an in-process [app.App] instance. This is the default +// mode when the client/server architecture is not enabled. +type AppWorkspace struct { + app *app.App + store *config.ConfigStore +} + +// NewAppWorkspace creates a new AppWorkspace wrapping the given app +// and config store. +func NewAppWorkspace(a *app.App, store *config.ConfigStore) *AppWorkspace { + return &AppWorkspace{ + app: a, + store: store, + } +} + +// -- Sessions -- + +func (w *AppWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) { + return w.app.Sessions.Create(ctx, title) +} + +func (w *AppWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) { + return w.app.Sessions.Get(ctx, sessionID) +} + +func (w *AppWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) { + return w.app.Sessions.List(ctx) +} + +func (w *AppWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) { + return w.app.Sessions.Save(ctx, sess) +} + +func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) error { + return w.app.Sessions.Delete(ctx, sessionID) +} + +func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { + return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID) +} + +func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) { + return w.app.Sessions.ParseAgentToolSessionID(sessionID) +} + +// -- Messages -- + +func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + return w.app.Messages.List(ctx, sessionID) +} + +func (w *AppWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + return w.app.Messages.ListUserMessages(ctx, sessionID) +} + +func (w *AppWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) { + return w.app.Messages.ListAllUserMessages(ctx) +} + +// -- Agent -- + +func (w *AppWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error { + if w.app.AgentCoordinator == nil { + return errors.New("agent coordinator not initialized") + } + _, err := w.app.AgentCoordinator.Run(ctx, sessionID, prompt, attachments...) + return err +} + +func (w *AppWorkspace) AgentCancel(sessionID string) { + if w.app.AgentCoordinator != nil { + w.app.AgentCoordinator.Cancel(sessionID) + } +} + +func (w *AppWorkspace) AgentIsBusy() bool { + if w.app.AgentCoordinator == nil { + return false + } + return w.app.AgentCoordinator.IsBusy() +} + +func (w *AppWorkspace) AgentIsSessionBusy(sessionID string) bool { + if w.app.AgentCoordinator == nil { + return false + } + return w.app.AgentCoordinator.IsSessionBusy(sessionID) +} + +func (w *AppWorkspace) AgentModel() AgentModel { + if w.app.AgentCoordinator == nil { + return AgentModel{} + } + m := w.app.AgentCoordinator.Model() + return AgentModel{ + CatwalkCfg: m.CatwalkCfg, + ModelCfg: m.ModelCfg, + } +} + +func (w *AppWorkspace) AgentIsReady() bool { + return w.app.AgentCoordinator != nil +} + +func (w *AppWorkspace) AgentQueuedPrompts(sessionID string) int { + if w.app.AgentCoordinator == nil { + return 0 + } + return w.app.AgentCoordinator.QueuedPrompts(sessionID) +} + +func (w *AppWorkspace) AgentQueuedPromptsList(sessionID string) []string { + if w.app.AgentCoordinator == nil { + return nil + } + return w.app.AgentCoordinator.QueuedPromptsList(sessionID) +} + +func (w *AppWorkspace) AgentClearQueue(sessionID string) { + if w.app.AgentCoordinator != nil { + w.app.AgentCoordinator.ClearQueue(sessionID) + } +} + +func (w *AppWorkspace) AgentSummarize(ctx context.Context, sessionID string) error { + if w.app.AgentCoordinator == nil { + return errors.New("agent coordinator not initialized") + } + return w.app.AgentCoordinator.Summarize(ctx, sessionID) +} + +func (w *AppWorkspace) UpdateAgentModel(ctx context.Context) error { + return w.app.UpdateAgentModel(ctx) +} + +func (w *AppWorkspace) InitCoderAgent(ctx context.Context) error { + return w.app.InitCoderAgent(ctx) +} + +func (w *AppWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel { + return w.app.GetDefaultSmallModel(providerID) +} + +// -- Permissions -- + +func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) { + w.app.Permissions.Grant(perm) +} + +func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) { + w.app.Permissions.GrantPersistent(perm) +} + +func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) { + w.app.Permissions.Deny(perm) +} + +func (w *AppWorkspace) PermissionSkipRequests() bool { + return w.app.Permissions.SkipRequests() +} + +func (w *AppWorkspace) PermissionSetSkipRequests(skip bool) { + w.app.Permissions.SetSkipRequests(skip) +} + +// -- FileTracker -- + +func (w *AppWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) { + w.app.FileTracker.RecordRead(ctx, sessionID, path) +} + +func (w *AppWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time { + return w.app.FileTracker.LastReadTime(ctx, sessionID, path) +} + +func (w *AppWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) { + return w.app.FileTracker.ListReadFiles(ctx, sessionID) +} + +// -- History -- + +func (w *AppWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) { + return w.app.History.ListBySession(ctx, sessionID) +} + +// -- LSP -- + +func (w *AppWorkspace) LSPStart(ctx context.Context, path string) { + w.app.LSPManager.Start(ctx, path) +} + +func (w *AppWorkspace) LSPStopAll(ctx context.Context) { + w.app.LSPManager.StopAll(ctx) +} + +func (w *AppWorkspace) LSPGetStates() map[string]LSPClientInfo { + states := app.GetLSPStates() + result := make(map[string]LSPClientInfo, len(states)) + for k, v := range states { + result[k] = LSPClientInfo{ + Name: v.Name, + State: v.State, + Error: v.Error, + DiagnosticCount: v.DiagnosticCount, + ConnectedAt: v.ConnectedAt, + } + } + return result +} + +func (w *AppWorkspace) LSPGetDiagnosticCounts(name string) lsp.DiagnosticCounts { + state, ok := app.GetLSPState(name) + if !ok || state.Client == nil { + return lsp.DiagnosticCounts{} + } + return state.Client.GetDiagnosticCounts() +} + +// -- Config (read-only) -- + +func (w *AppWorkspace) Config() *config.Config { + return w.store.Config() +} + +func (w *AppWorkspace) WorkingDir() string { + return w.store.WorkingDir() +} + +func (w *AppWorkspace) Resolver() config.VariableResolver { + return w.store.Resolver() +} + +// -- Config mutations -- + +func (w *AppWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + return w.store.UpdatePreferredModel(scope, modelType, model) +} + +func (w *AppWorkspace) SetCompactMode(scope config.Scope, enabled bool) error { + return w.store.SetCompactMode(scope, enabled) +} + +func (w *AppWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error { + return w.store.SetProviderAPIKey(scope, providerID, apiKey) +} + +func (w *AppWorkspace) SetConfigField(scope config.Scope, key string, value any) error { + return w.store.SetConfigField(scope, key, value) +} + +func (w *AppWorkspace) RemoveConfigField(scope config.Scope, key string) error { + return w.store.RemoveConfigField(scope, key) +} + +func (w *AppWorkspace) ImportCopilot() (*oauth.Token, bool) { + return w.store.ImportCopilot() +} + +func (w *AppWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error { + return w.store.RefreshOAuthToken(ctx, scope, providerID) +} + +// -- Project lifecycle -- + +func (w *AppWorkspace) ProjectNeedsInitialization() (bool, error) { + return config.ProjectNeedsInitialization(w.store) +} + +func (w *AppWorkspace) MarkProjectInitialized() error { + return config.MarkProjectInitialized(w.store) +} + +func (w *AppWorkspace) InitializePrompt() (string, error) { + return agent.InitializePrompt(w.store) +} + +// -- MCP operations -- + +func (w *AppWorkspace) MCPGetStates() map[string]mcptools.ClientInfo { + return mcptools.GetStates() +} + +func (w *AppWorkspace) MCPRefreshPrompts(ctx context.Context, name string) { + mcptools.RefreshPrompts(ctx, name) +} + +func (w *AppWorkspace) MCPRefreshResources(ctx context.Context, name string) { + mcptools.RefreshResources(ctx, name) +} + +func (w *AppWorkspace) RefreshMCPTools(ctx context.Context, name string) { + mcptools.RefreshTools(ctx, w.store, name) +} + +func (w *AppWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) { + contents, err := mcptools.ReadResource(ctx, w.store, name, uri) + if err != nil { + return nil, err + } + result := make([]MCPResourceContents, len(contents)) + for i, c := range contents { + result[i] = MCPResourceContents{ + URI: c.URI, + MIMEType: c.MIMEType, + Text: c.Text, + Blob: c.Blob, + } + } + return result, nil +} + +func (w *AppWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) { + return commands.GetMCPPrompt(w.store, clientID, promptID, args) +} + +func (w *AppWorkspace) EnableDockerMCP(ctx context.Context) error { + mcpConfig, err := w.store.PrepareDockerMCPConfig() + if err != nil { + return err + } + + if err := mcptools.InitializeSingle(ctx, config.DockerMCPName, w.store); err != nil { + disableErr := mcptools.DisableSingle(w.store, config.DockerMCPName) + delete(w.store.Config().MCP, config.DockerMCPName) + return fmt.Errorf("failed to start docker MCP: %w", errors.Join(err, disableErr)) + } + + if err := w.store.PersistDockerMCPConfig(mcpConfig); err != nil { + disableErr := mcptools.DisableSingle(w.store, config.DockerMCPName) + delete(w.store.Config().MCP, config.DockerMCPName) + return fmt.Errorf("docker MCP started but failed to persist configuration: %w", errors.Join(err, disableErr)) + } + + return nil +} + +func (w *AppWorkspace) DisableDockerMCP() error { + if err := mcptools.DisableSingle(w.store, config.DockerMCPName); err != nil { + return fmt.Errorf("failed to disable docker MCP: %w", err) + } + return w.store.DisableDockerMCP() +} + +// -- Lifecycle -- + +func (w *AppWorkspace) Subscribe(program *tea.Program) { + w.app.Subscribe(program) +} + +func (w *AppWorkspace) Shutdown() { + w.app.Shutdown() +} + +// App returns the underlying app.App instance. +func (w *AppWorkspace) App() *app.App { + return w.app +} + +// Store returns the underlying config store. +func (w *AppWorkspace) Store() *config.ConfigStore { + return w.store +} + +// Compile-time check that AppWorkspace implements Workspace. +var _ Workspace = (*AppWorkspace)(nil)