From e44211db2573aeda9100334cfc76908f918ad0bc Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Fri, 13 Mar 2026 12:42:32 +0300 Subject: [PATCH] refactor: migrate run and login commands to use client API instead of direct app access --- internal/cmd/login.go | 45 ++-- internal/cmd/root.go | 134 ++-------- internal/cmd/run.go | 329 ++++++++++++++++++++++++- internal/workspace/app_workspace.go | 370 ---------------------------- 4 files changed, 375 insertions(+), 503 deletions(-) delete mode 100644 internal/workspace/app_workspace.go diff --git a/internal/cmd/login.go b/internal/cmd/login.go index c9acb12df19875f48b242bee96e377bf5548aacb..9ce9a3e28deb168f7a78b38417e8d93d02ae69ce 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -10,10 +10,12 @@ import ( "charm.land/lipgloss/v2" "github.com/atotto/clipboard" hyperp "github.com/charmbracelet/crush/internal/agent/hyper" + "github.com/charmbracelet/crush/internal/client" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/charmbracelet/x/ansi" "github.com/pkg/browser" "github.com/spf13/cobra" ) @@ -40,11 +42,17 @@ crush login copilot }, Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - app, err := setupAppWithProgressBar(cmd) + c, ws, cleanup, err := connectToServer(cmd) if err != nil { return err } - defer app.Shutdown() + defer cleanup() + + progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress + if progressEnabled && supportsProgressBar() { + _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar) + defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }() + } provider := "hyper" if len(args) > 0 { @@ -52,16 +60,16 @@ crush login copilot } switch provider { case "hyper": - return loginHyper(app.Store()) + return loginHyper(c, ws.ID) case "copilot", "github", "github-copilot": - return loginCopilot(app.Store()) + return loginCopilot(cmd.Context(), c, ws.ID) default: return fmt.Errorf("unknown platform: %s", args[0]) } }, } -func loginHyper(cfg *config.ConfigStore) error { +func loginHyper(c *client.Client, wsID string) error { if !hyperp.Enabled() { return fmt.Errorf("hyper not enabled") } @@ -112,8 +120,8 @@ func loginHyper(cfg *config.ConfigStore) error { } if err := cmp.Or( - cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken), - cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.oauth", token), + c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken), + c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token), ); err != nil { return err } @@ -123,12 +131,15 @@ func loginHyper(cfg *config.ConfigStore) error { return nil } -func loginCopilot(cfg *config.ConfigStore) error { - ctx := getLoginContext() +func loginCopilot(ctx context.Context, c *client.Client, wsID string) error { + loginCtx := getLoginContext() - if cfg.HasConfigField(config.ScopeGlobal, "providers.copilot.oauth") { - fmt.Println("You are already logged in to GitHub Copilot.") - return nil + cfg, err := c.GetConfig(ctx, wsID) + if err == nil && cfg != nil { + if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil { + fmt.Println("You are already logged in to GitHub Copilot.") + return nil + } } diskToken, hasDiskToken := copilot.RefreshTokenFromDisk() @@ -138,14 +149,14 @@ func loginCopilot(cfg *config.ConfigStore) error { case hasDiskToken: fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...") - t, err := copilot.RefreshToken(ctx, diskToken) + t, err := copilot.RefreshToken(loginCtx, diskToken) if err != nil { return fmt.Errorf("unable to refresh token from disk: %w", err) } token = t default: fmt.Println("Requesting device code from GitHub...") - dc, err := copilot.RequestDeviceCode(ctx) + dc, err := copilot.RequestDeviceCode(loginCtx) if err != nil { return err } @@ -159,7 +170,7 @@ func loginCopilot(cfg *config.ConfigStore) error { fmt.Println() fmt.Println("Waiting for authorization...") - t, err := copilot.PollForToken(ctx, dc) + t, err := copilot.PollForToken(loginCtx, dc) if err == copilot.ErrNotAvailable { fmt.Println() fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:") @@ -177,8 +188,8 @@ func loginCopilot(cfg *config.ConfigStore) error { } if err := cmp.Or( - cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken), - cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.oauth", token), + c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken), + c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token), ); err != nil { return err } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index ad9face77bd1f30853c4896ff63bf374438a0009..4f511af9490fa9af27a2f71f55f93698bb5f06d0 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -20,12 +20,9 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" "github.com/charmbracelet/colorprofile" - "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/client" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" - "github.com/charmbracelet/crush/internal/projects" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/server" "github.com/charmbracelet/crush/internal/ui/common" @@ -34,7 +31,6 @@ import ( "github.com/charmbracelet/crush/internal/workspace" "github.com/charmbracelet/fang" uv "github.com/charmbracelet/ultraviolet" - "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/x/exp/charmtone" "github.com/charmbracelet/x/term" "github.com/spf13/cobra" @@ -46,10 +42,9 @@ func init() { rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory") rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug") + rootCmd.PersistentFlags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)") + rootCmd.PersistentFlags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)") rootCmd.Flags().BoolP("help", "h", false, "Help") - rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)") - - rootCmd.Flags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)") rootCmd.AddCommand( runCmd, @@ -87,20 +82,11 @@ crush --yolo crush --data-dir /path/to/custom/.crush `, RunE: func(cmd *cobra.Command, args []string) error { - hostURL, err := server.ParseHostURL(clientHost) - if err != nil { - return fmt.Errorf("invalid host URL: %v", err) - } - - if err := ensureServer(cmd, hostURL); err != nil { - return err - } - - c, ws, err := setupClientApp(cmd, hostURL) + c, ws, cleanup, err := connectToServer(cmd) if err != nil { return err } - defer func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }() + defer cleanup() event.AppInitialized() @@ -188,25 +174,18 @@ func supportsProgressBar() bool { return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty") } -func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) { - app, err := setupApp(cmd) +// connectToServer ensures the server is running, creates a client and +// workspace, and returns a cleanup function that deletes the workspace. +func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) { + hostURL, err := server.ParseHostURL(clientHost) if err != nil { - return nil, err + return nil, nil, nil, fmt.Errorf("invalid host URL: %v", err) } - // Check if progress bar is enabled in config (defaults to true if nil) - progressEnabled := app.Config().Options.Progress == nil || *app.Config().Options.Progress - if progressEnabled && supportsProgressBar() { - _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar) - defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }() + if err := ensureServer(cmd, hostURL); err != nil { + return nil, nil, nil, err } - return app, nil -} - -// setupApp handles the common setup logic for both interactive and non-interactive modes. -// It returns the app instance, config, cleanup function, and any error. -func setupApp(cmd *cobra.Command) (*app.App, error) { debug, _ := cmd.Flags().GetBool("debug") yolo, _ := cmd.Flags().GetBool("yolo") dataDir, _ := cmd.Flags().GetString("data-dir") @@ -214,95 +193,40 @@ func setupApp(cmd *cobra.Command) (*app.App, error) { cwd, err := ResolveCwd(cmd) if err != nil { - return nil, err - } - - store, err := config.Init(cwd, dataDir, debug) - if err != nil { - return nil, err - } - - store.Overrides().SkipPermissionRequests = yolo - cfg := store.Config() - - if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil { - return nil, err - } - - // Register this project in the centralized projects list. - if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil { - slog.Warn("Failed to register project", "error", err) - // Non-fatal: continue even if registration fails - } - - // Connect to DB; this will also run migrations. - conn, err := db.Connect(ctx, cfg.Options.DataDirectory) - if err != nil { - return nil, err - } - - appInstance, err := app.New(ctx, conn, store) - if err != nil { - slog.Error("Failed to create app instance", "error", err) - return nil, err - } - - if shouldEnableMetrics(cfg) { - event.Init() - } - - return appInstance, nil -} - -// setupClientApp sets up a client-based workspace via the server. It -// auto-starts a detached server process if the socket does not exist. -func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Workspace, error) { - debug, _ := cmd.Flags().GetBool("debug") - yolo, _ := cmd.Flags().GetBool("yolo") - dataDir, _ := cmd.Flags().GetString("data-dir") - ctx := cmd.Context() - - cwd, err := ResolveCwd(cmd) - if err != nil { - return nil, nil, err + return nil, nil, nil, err } c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - ws, err := c.CreateWorkspace(ctx, proto.Workspace{ + wsReq := proto.Workspace{ Path: cwd, DataDir: dataDir, Debug: debug, YOLO: yolo, Version: version.Version, Env: os.Environ(), - }) + } + + ws, err := c.CreateWorkspace(ctx, wsReq) if err != nil { // The server socket may exist before the HTTP handler is ready. // Retry a few times with a short backoff. for range 5 { select { case <-ctx.Done(): - return nil, nil, ctx.Err() + return nil, nil, nil, ctx.Err() case <-time.After(200 * time.Millisecond): } - ws, err = c.CreateWorkspace(ctx, proto.Workspace{ - Path: cwd, - DataDir: dataDir, - Debug: debug, - YOLO: yolo, - Version: version.Version, - Env: os.Environ(), - }) + ws, err = c.CreateWorkspace(ctx, wsReq) if err == nil { break } } if err != nil { - return nil, nil, fmt.Errorf("failed to create workspace: %v", err) + return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err) } } @@ -310,7 +234,8 @@ func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *prot event.Init() } - return c, ws, nil + cleanup := func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) } + return c, ws, cleanup, nil } // ensureServer auto-starts a detached server if the socket file does not @@ -488,18 +413,3 @@ func ResolveCwd(cmd *cobra.Command) (string, error) { } return cwd, nil } - -func createDotCrushDir(dir string) error { - if err := os.MkdirAll(dir, 0o700); err != nil { - return fmt.Errorf("failed to create data directory: %q %w", dir, err) - } - - gitIgnorePath := filepath.Join(dir, ".gitignore") - if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) { - if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil { - return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err) - } - } - - return nil -} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 1119a83f993d08b3a2206104912ec993ec37a9e3..50264515dc7e5fd247f374e3673b306b5be196d0 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -7,9 +7,21 @@ import ( "os" "os/signal" "strings" + "time" + "charm.land/lipgloss/v2" "charm.land/log/v2" + "github.com/charmbracelet/crush/internal/client" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/event" + "github.com/charmbracelet/crush/internal/format" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/ui/anim" + "github.com/charmbracelet/crush/internal/ui/styles" + "github.com/charmbracelet/x/ansi" + "github.com/charmbracelet/x/exp/charmtone" + "github.com/charmbracelet/x/term" "github.com/spf13/cobra" ) @@ -47,13 +59,13 @@ crush run --verbose "Generate a README for this project" ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) defer cancel() - app, err := setupApp(cmd) + c, ws, cleanup, err := connectToServer(cmd) if err != nil { return err } - defer app.Shutdown() + defer cleanup() - if !app.Config().IsConfigured() { + if !ws.Config.IsConfigured() { return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively") } @@ -76,7 +88,7 @@ crush run --verbose "Generate a README for this project" event.SetNonInteractive(true) event.AppInitialized() - return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose) + return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose) }, } @@ -86,3 +98,312 @@ func init() { runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers") runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider") } + +// runNonInteractive executes the agent via the server and streams output +// to stdout. +func runNonInteractive( + ctx context.Context, + c *client.Client, + ws *proto.Workspace, + prompt, largeModel, smallModel string, + hideSpinner bool, +) error { + slog.Info("Running in non-interactive mode") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if largeModel != "" || smallModel != "" { + if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil { + return fmt.Errorf("failed to override models: %w", err) + } + } + + var ( + spinner *format.Spinner + stdoutTTY bool + stderrTTY bool + stdinTTY bool + progress bool + ) + + stdoutTTY = term.IsTerminal(os.Stdout.Fd()) + stderrTTY = term.IsTerminal(os.Stderr.Fd()) + stdinTTY = term.IsTerminal(os.Stdin.Fd()) + progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress + + if !hideSpinner && stderrTTY { + t := styles.DefaultStyles() + + hasDarkBG := true + if stdinTTY && stdoutTTY { + hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout) + } + defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase) + + spinner = format.NewSpinner(ctx, cancel, anim.Settings{ + Size: 10, + Label: "Generating", + LabelColor: defaultFG, + GradColorA: t.Primary, + GradColorB: t.Secondary, + CycleColors: true, + }) + spinner.Start() + } + + stopSpinner := func() { + if !hideSpinner && spinner != nil { + spinner.Stop() + spinner = nil + } + } + + // Wait for the agent to become ready (MCP init, etc). + if err := waitForAgent(ctx, c, ws.ID); err != nil { + stopSpinner() + return fmt.Errorf("agent not ready: %w", err) + } + + // Force-update agent models so MCP tools are loaded. + if err := c.UpdateAgent(ctx, ws.ID); err != nil { + slog.Warn("Failed to update agent", "error", err) + } + + defer stopSpinner() + + sess, err := c.CreateSession(ctx, ws.ID, "non-interactive") + if err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + slog.Info("Created session for non-interactive run", "session_id", sess.ID) + + events, err := c.SubscribeEvents(ctx, ws.ID) + if err != nil { + return fmt.Errorf("failed to subscribe to events: %w", err) + } + + if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + messageReadBytes := make(map[string]int) + var printed bool + + defer func() { + if progress && stderrTTY { + _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) + } + _, _ = fmt.Fprintln(os.Stdout) + }() + + for { + if progress && stderrTTY { + _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar) + } + + select { + case ev, ok := <-events: + if !ok { + stopSpinner() + return nil + } + + switch e := ev.(type) { + case pubsub.Event[proto.Message]: + msg := e.Payload + if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 { + continue + } + stopSpinner() + + content := msg.Content().String() + readBytes := messageReadBytes[msg.ID] + + if len(content) < readBytes { + slog.Error("Non-interactive: message content shorter than read bytes", + "message_length", len(content), "read_bytes", readBytes) + return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes) + } + + part := content[readBytes:] + if readBytes == 0 { + part = strings.TrimLeft(part, " \t") + } + if printed || strings.TrimSpace(part) != "" { + printed = true + fmt.Fprint(os.Stdout, part) + } + messageReadBytes[msg.ID] = len(content) + + if msg.IsFinished() { + return nil + } + + case pubsub.Event[proto.AgentEvent]: + if e.Payload.Error != nil { + stopSpinner() + return fmt.Errorf("agent error: %w", e.Payload.Error) + } + } + + case <-ctx.Done(): + stopSpinner() + return ctx.Err() + } + } +} + +// waitForAgent polls GetAgentInfo until the agent is ready, with a +// timeout. +func waitForAgent(ctx context.Context, c *client.Client, wsID string) error { + timeout := time.After(30 * time.Second) + for { + info, err := c.GetAgentInfo(ctx, wsID) + if err == nil && info.IsReady { + return nil + } + select { + case <-timeout: + if err != nil { + return fmt.Errorf("timeout waiting for agent: %w", err) + } + return fmt.Errorf("timeout waiting for agent readiness") + case <-ctx.Done(): + return ctx.Err() + case <-time.After(200 * time.Millisecond): + } + } +} + +// overrideModels resolves model strings and updates the workspace +// configuration via the server. +func overrideModels( + ctx context.Context, + c *client.Client, + ws *proto.Workspace, + largeModel, smallModel string, +) error { + cfg, err := c.GetConfig(ctx, ws.ID) + if err != nil { + return fmt.Errorf("failed to get config: %w", err) + } + + providers := cfg.Providers.Copy() + + largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel) + + var largeProviderID string + + if largeModel != "" { + found, err := validateModelMatches(largeMatches, largeModel, "large") + if err != nil { + return err + } + largeProviderID = found.provider + slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID) + if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{ + Provider: found.provider, + Model: found.modelID, + }); err != nil { + return fmt.Errorf("failed to set large model: %w", err) + } + } + + switch { + case smallModel != "": + found, err := validateModelMatches(smallMatches, smallModel, "small") + if err != nil { + return err + } + slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID) + if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{ + Provider: found.provider, + Model: found.modelID, + }); err != nil { + return fmt.Errorf("failed to set small model: %w", err) + } + + case largeModel != "": + sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID) + if err != nil { + slog.Warn("Failed to get default small model", "error", err) + } else if sm != nil { + if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil { + return fmt.Errorf("failed to set small model: %w", err) + } + } + } + + return c.UpdateAgent(ctx, ws.ID) +} + +type modelMatch struct { + provider string + modelID string +} + +// findModelMatches searches providers for matching large/small model +// strings. +func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) { + largeFilter, largeID := parseModelString(largeModel) + smallFilter, smallID := parseModelString(smallModel) + + var largeMatches, smallMatches []modelMatch + for name, provider := range providers { + if provider.Disable { + continue + } + for _, m := range provider.Models { + if matchesModel(largeID, largeFilter, m.ID, name) { + largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID}) + } + if matchesModel(smallID, smallFilter, m.ID, name) { + smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID}) + } + } + } + return largeMatches, smallMatches +} + +// parseModelString splits "provider/model" into (provider, model) or +// ("", model). +func parseModelString(s string) (string, string) { + if s == "" { + return "", "" + } + if idx := strings.Index(s, "/"); idx >= 0 { + return s[:idx], s[idx+1:] + } + return "", s +} + +// matchesModel returns true if the model ID matches the filter +// criteria. +func matchesModel(wantID, wantProvider, modelID, providerName string) bool { + if wantID == "" { + return false + } + if wantProvider != "" && wantProvider != providerName { + return false + } + return strings.EqualFold(modelID, wantID) +} + +// validateModelMatches ensures exactly one match exists. +func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) { + switch { + case len(matches) == 0: + return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID) + case len(matches) > 1: + names := make([]string, len(matches)) + for i, m := range matches { + names[i] = m.provider + } + return modelMatch{}, fmt.Errorf( + "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format", + label, modelID, strings.Join(names, ", "), + ) + } + return matches[0], nil +} diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go deleted file mode 100644 index e78afdd0792df9a5c6673bfad1ec4465a28711e1..0000000000000000000000000000000000000000 --- a/internal/workspace/app_workspace.go +++ /dev/null @@ -1,370 +0,0 @@ -package workspace - -import ( - "context" - "log/slog" - "time" - - tea "charm.land/bubbletea/v2" - "github.com/charmbracelet/crush/internal/agent" - mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp" - "github.com/charmbracelet/crush/internal/app" - "github.com/charmbracelet/crush/internal/commands" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/lsp" - "github.com/charmbracelet/crush/internal/message" - "github.com/charmbracelet/crush/internal/oauth" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/pubsub" - "github.com/charmbracelet/crush/internal/session" -) - -// AppWorkspace wraps an in-process app.App to satisfy the Workspace -// interface. This is the default mode when no server is involved. -type AppWorkspace struct { - app *app.App -} - -// NewAppWorkspace creates a Workspace backed by a local app.App. -func NewAppWorkspace(a *app.App) *AppWorkspace { - return &AppWorkspace{app: a} -} - -// App returns the underlying app.App for callers that still need -// direct access during the migration period. -func (w *AppWorkspace) App() *app.App { - return w.app -} - -// -- Sessions -- - -func (w *AppWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) { - return w.app.Sessions.Create(ctx, title) -} - -func (w *AppWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) { - return w.app.Sessions.Get(ctx, sessionID) -} - -func (w *AppWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) { - return w.app.Sessions.List(ctx) -} - -func (w *AppWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) { - return w.app.Sessions.Save(ctx, sess) -} - -func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) error { - return w.app.Sessions.Delete(ctx, sessionID) -} - -func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { - return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID) -} - -func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) { - return w.app.Sessions.ParseAgentToolSessionID(sessionID) -} - -// -- Messages -- - -func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { - return w.app.Messages.List(ctx, sessionID) -} - -func (w *AppWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) { - return w.app.Messages.ListUserMessages(ctx, sessionID) -} - -func (w *AppWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) { - return w.app.Messages.ListAllUserMessages(ctx) -} - -// -- Agent -- - -func (w *AppWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error { - if w.app.AgentCoordinator == nil { - return nil - } - _, err := w.app.AgentCoordinator.Run(ctx, sessionID, prompt, attachments...) - return err -} - -func (w *AppWorkspace) AgentCancel(sessionID string) { - if w.app.AgentCoordinator != nil { - w.app.AgentCoordinator.Cancel(sessionID) - } -} - -func (w *AppWorkspace) AgentIsBusy() bool { - if w.app.AgentCoordinator == nil { - return false - } - return w.app.AgentCoordinator.IsBusy() -} - -func (w *AppWorkspace) AgentIsSessionBusy(sessionID string) bool { - if w.app.AgentCoordinator == nil { - return false - } - return w.app.AgentCoordinator.IsSessionBusy(sessionID) -} - -func (w *AppWorkspace) AgentModel() AgentModel { - if w.app.AgentCoordinator == nil { - return AgentModel{} - } - m := w.app.AgentCoordinator.Model() - return AgentModel{ - CatwalkCfg: m.CatwalkCfg, - ModelCfg: m.ModelCfg, - } -} - -func (w *AppWorkspace) AgentIsReady() bool { - return w.app.AgentCoordinator != nil -} - -func (w *AppWorkspace) AgentQueuedPrompts(sessionID string) int { - if w.app.AgentCoordinator == nil { - return 0 - } - return w.app.AgentCoordinator.QueuedPrompts(sessionID) -} - -func (w *AppWorkspace) AgentQueuedPromptsList(sessionID string) []string { - if w.app.AgentCoordinator == nil { - return nil - } - return w.app.AgentCoordinator.QueuedPromptsList(sessionID) -} - -func (w *AppWorkspace) AgentClearQueue(sessionID string) { - if w.app.AgentCoordinator != nil { - w.app.AgentCoordinator.ClearQueue(sessionID) - } -} - -func (w *AppWorkspace) AgentSummarize(ctx context.Context, sessionID string) error { - if w.app.AgentCoordinator == nil { - return nil - } - return w.app.AgentCoordinator.Summarize(ctx, sessionID) -} - -func (w *AppWorkspace) UpdateAgentModel(ctx context.Context) error { - return w.app.UpdateAgentModel(ctx) -} - -func (w *AppWorkspace) InitCoderAgent(ctx context.Context) error { - return w.app.InitCoderAgent(ctx) -} - -func (w *AppWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel { - return w.app.GetDefaultSmallModel(providerID) -} - -// -- Permissions -- - -func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) { - w.app.Permissions.Grant(perm) -} - -func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) { - w.app.Permissions.GrantPersistent(perm) -} - -func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) { - w.app.Permissions.Deny(perm) -} - -func (w *AppWorkspace) PermissionSkipRequests() bool { - return w.app.Permissions.SkipRequests() -} - -func (w *AppWorkspace) PermissionSetSkipRequests(skip bool) { - w.app.Permissions.SetSkipRequests(skip) -} - -// -- FileTracker -- - -func (w *AppWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) { - w.app.FileTracker.RecordRead(ctx, sessionID, path) -} - -func (w *AppWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time { - return w.app.FileTracker.LastReadTime(ctx, sessionID, path) -} - -func (w *AppWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) { - return w.app.FileTracker.ListReadFiles(ctx, sessionID) -} - -// -- History -- - -func (w *AppWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) { - return w.app.History.ListBySession(ctx, sessionID) -} - -// -- LSP -- - -func (w *AppWorkspace) LSPStart(ctx context.Context, path string) { - w.app.LSPManager.Start(ctx, path) -} - -func (w *AppWorkspace) LSPStopAll(ctx context.Context) { - w.app.LSPManager.StopAll(ctx) -} - -func (w *AppWorkspace) LSPGetStates() map[string]LSPClientInfo { - states := app.GetLSPStates() - result := make(map[string]LSPClientInfo, len(states)) - for k, v := range states { - result[k] = LSPClientInfo{ - Name: v.Name, - State: v.State, - Error: v.Error, - DiagnosticCount: v.DiagnosticCount, - ConnectedAt: v.ConnectedAt, - } - } - return result -} - -func (w *AppWorkspace) LSPGetClient(name string) (*lsp.Client, bool) { - info, ok := app.GetLSPState(name) - if !ok { - return nil, false - } - return info.Client, true -} - -// -- Config (read-only) -- - -func (w *AppWorkspace) Config() *config.Config { - return w.app.Config() -} - -func (w *AppWorkspace) WorkingDir() string { - return w.app.Store().WorkingDir() -} - -func (w *AppWorkspace) Resolver() config.VariableResolver { - return w.app.Store().Resolver() -} - -// -- Config mutations -- - -func (w *AppWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { - return w.app.Store().UpdatePreferredModel(scope, modelType, model) -} - -func (w *AppWorkspace) SetCompactMode(scope config.Scope, enabled bool) error { - return w.app.Store().SetCompactMode(scope, enabled) -} - -func (w *AppWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error { - return w.app.Store().SetProviderAPIKey(scope, providerID, apiKey) -} - -func (w *AppWorkspace) SetConfigField(scope config.Scope, key string, value any) error { - return w.app.Store().SetConfigField(scope, key, value) -} - -func (w *AppWorkspace) RemoveConfigField(scope config.Scope, key string) error { - return w.app.Store().RemoveConfigField(scope, key) -} - -func (w *AppWorkspace) ImportCopilot() (*oauth.Token, bool) { - return w.app.Store().ImportCopilot() -} - -func (w *AppWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error { - return w.app.Store().RefreshOAuthToken(ctx, scope, providerID) -} - -// -- Project lifecycle -- - -func (w *AppWorkspace) ProjectNeedsInitialization() (bool, error) { - return config.ProjectNeedsInitialization(w.app.Store()) -} - -func (w *AppWorkspace) MarkProjectInitialized() error { - return config.MarkProjectInitialized(w.app.Store()) -} - -func (w *AppWorkspace) InitializePrompt() (string, error) { - return agent.InitializePrompt(w.app.Store()) -} - -// -- MCP operations -- - -func (w *AppWorkspace) MCPGetStates() map[string]mcptools.ClientInfo { - return mcptools.GetStates() -} - -func (w *AppWorkspace) MCPRefreshPrompts(ctx context.Context, name string) { - mcptools.RefreshPrompts(ctx, name) -} - -func (w *AppWorkspace) MCPRefreshResources(ctx context.Context, name string) { - mcptools.RefreshResources(ctx, name) -} - -func (w *AppWorkspace) RefreshMCPTools(ctx context.Context, name string) { - mcptools.RefreshTools(ctx, w.app.Store(), name) -} - -func (w *AppWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) { - contents, err := mcptools.ReadResource(ctx, w.app.Store(), name, uri) - if err != nil { - return nil, err - } - result := make([]MCPResourceContents, len(contents)) - for i, c := range contents { - result[i] = MCPResourceContents{ - URI: c.URI, - MIMEType: c.MIMEType, - Text: c.Text, - Blob: c.Blob, - } - } - return result, nil -} - -func (w *AppWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) { - return commands.GetMCPPrompt(w.app.Store(), clientID, promptID, args) -} - -// -- Lifecycle -- - -func (w *AppWorkspace) Subscribe(program *tea.Program) { - defer log.RecoverPanic("AppWorkspace.Subscribe", func() { - slog.Info("TUI subscription panic: attempting graceful shutdown") - program.Quit() - }) - - for msg := range w.app.Events() { - switch ev := msg.(type) { - case pubsub.Event[app.LSPEvent]: - program.Send(pubsub.Event[LSPEvent]{ - Type: ev.Type, - Payload: LSPEvent{ - Type: LSPEventType(ev.Payload.Type), - Name: ev.Payload.Name, - State: ev.Payload.State, - Error: ev.Payload.Error, - DiagnosticCount: ev.Payload.DiagnosticCount, - }, - }) - default: - program.Send(msg) - } - } -} - -func (w *AppWorkspace) Shutdown() { - w.app.Shutdown() -}