Detailed changes
@@ -10,10 +10,12 @@ import (
"charm.land/lipgloss/v2"
"github.com/atotto/clipboard"
hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
+ "github.com/charmbracelet/crush/internal/client"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/copilot"
"github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/charmbracelet/x/ansi"
"github.com/pkg/browser"
"github.com/spf13/cobra"
)
@@ -40,11 +42,17 @@ crush login copilot
},
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- app, err := setupAppWithProgressBar(cmd)
+ c, ws, cleanup, err := connectToServer(cmd)
if err != nil {
return err
}
- defer app.Shutdown()
+ defer cleanup()
+
+ progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
+ if progressEnabled && supportsProgressBar() {
+ _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
+ defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
+ }
provider := "hyper"
if len(args) > 0 {
@@ -52,16 +60,16 @@ crush login copilot
}
switch provider {
case "hyper":
- return loginHyper(app.Store())
+ return loginHyper(c, ws.ID)
case "copilot", "github", "github-copilot":
- return loginCopilot(app.Store())
+ return loginCopilot(cmd.Context(), c, ws.ID)
default:
return fmt.Errorf("unknown platform: %s", args[0])
}
},
}
-func loginHyper(cfg *config.ConfigStore) error {
+func loginHyper(c *client.Client, wsID string) error {
if !hyperp.Enabled() {
return fmt.Errorf("hyper not enabled")
}
@@ -112,8 +120,8 @@ func loginHyper(cfg *config.ConfigStore) error {
}
if err := cmp.Or(
- cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
- cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.oauth", token),
+ c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
+ c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
); err != nil {
return err
}
@@ -123,12 +131,15 @@ func loginHyper(cfg *config.ConfigStore) error {
return nil
}
-func loginCopilot(cfg *config.ConfigStore) error {
- ctx := getLoginContext()
+func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
+ loginCtx := getLoginContext()
- if cfg.HasConfigField(config.ScopeGlobal, "providers.copilot.oauth") {
- fmt.Println("You are already logged in to GitHub Copilot.")
- return nil
+ cfg, err := c.GetConfig(ctx, wsID)
+ if err == nil && cfg != nil {
+ if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
+ fmt.Println("You are already logged in to GitHub Copilot.")
+ return nil
+ }
}
diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
@@ -138,14 +149,14 @@ func loginCopilot(cfg *config.ConfigStore) error {
case hasDiskToken:
fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
- t, err := copilot.RefreshToken(ctx, diskToken)
+ t, err := copilot.RefreshToken(loginCtx, diskToken)
if err != nil {
return fmt.Errorf("unable to refresh token from disk: %w", err)
}
token = t
default:
fmt.Println("Requesting device code from GitHub...")
- dc, err := copilot.RequestDeviceCode(ctx)
+ dc, err := copilot.RequestDeviceCode(loginCtx)
if err != nil {
return err
}
@@ -159,7 +170,7 @@ func loginCopilot(cfg *config.ConfigStore) error {
fmt.Println()
fmt.Println("Waiting for authorization...")
- t, err := copilot.PollForToken(ctx, dc)
+ t, err := copilot.PollForToken(loginCtx, dc)
if err == copilot.ErrNotAvailable {
fmt.Println()
fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
@@ -177,8 +188,8 @@ func loginCopilot(cfg *config.ConfigStore) error {
}
if err := cmp.Or(
- cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
- cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.oauth", token),
+ c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
+ c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
); err != nil {
return err
}
@@ -20,12 +20,9 @@ import (
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/charmbracelet/colorprofile"
- "github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/client"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/event"
- "github.com/charmbracelet/crush/internal/projects"
"github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/server"
"github.com/charmbracelet/crush/internal/ui/common"
@@ -34,7 +31,6 @@ import (
"github.com/charmbracelet/crush/internal/workspace"
"github.com/charmbracelet/fang"
uv "github.com/charmbracelet/ultraviolet"
- "github.com/charmbracelet/x/ansi"
"github.com/charmbracelet/x/exp/charmtone"
"github.com/charmbracelet/x/term"
"github.com/spf13/cobra"
@@ -46,10 +42,9 @@ func init() {
rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
+ rootCmd.PersistentFlags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
+ rootCmd.PersistentFlags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
rootCmd.Flags().BoolP("help", "h", false, "Help")
- rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
-
- rootCmd.Flags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
rootCmd.AddCommand(
runCmd,
@@ -87,20 +82,11 @@ crush --yolo
crush --data-dir /path/to/custom/.crush
`,
RunE: func(cmd *cobra.Command, args []string) error {
- hostURL, err := server.ParseHostURL(clientHost)
- if err != nil {
- return fmt.Errorf("invalid host URL: %v", err)
- }
-
- if err := ensureServer(cmd, hostURL); err != nil {
- return err
- }
-
- c, ws, err := setupClientApp(cmd, hostURL)
+ c, ws, cleanup, err := connectToServer(cmd)
if err != nil {
return err
}
- defer func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }()
+ defer cleanup()
event.AppInitialized()
@@ -188,25 +174,18 @@ func supportsProgressBar() bool {
return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
}
-func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
- app, err := setupApp(cmd)
+// connectToServer ensures the server is running, creates a client and
+// workspace, and returns a cleanup function that deletes the workspace.
+func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) {
+ hostURL, err := server.ParseHostURL(clientHost)
if err != nil {
- return nil, err
+ return nil, nil, nil, fmt.Errorf("invalid host URL: %v", err)
}
- // Check if progress bar is enabled in config (defaults to true if nil)
- progressEnabled := app.Config().Options.Progress == nil || *app.Config().Options.Progress
- if progressEnabled && supportsProgressBar() {
- _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
- defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
+ if err := ensureServer(cmd, hostURL); err != nil {
+ return nil, nil, nil, err
}
- return app, nil
-}
-
-// setupApp handles the common setup logic for both interactive and non-interactive modes.
-// It returns the app instance, config, cleanup function, and any error.
-func setupApp(cmd *cobra.Command) (*app.App, error) {
debug, _ := cmd.Flags().GetBool("debug")
yolo, _ := cmd.Flags().GetBool("yolo")
dataDir, _ := cmd.Flags().GetString("data-dir")
@@ -214,95 +193,40 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
cwd, err := ResolveCwd(cmd)
if err != nil {
- return nil, err
- }
-
- store, err := config.Init(cwd, dataDir, debug)
- if err != nil {
- return nil, err
- }
-
- store.Overrides().SkipPermissionRequests = yolo
- cfg := store.Config()
-
- if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
- return nil, err
- }
-
- // Register this project in the centralized projects list.
- if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
- slog.Warn("Failed to register project", "error", err)
- // Non-fatal: continue even if registration fails
- }
-
- // Connect to DB; this will also run migrations.
- conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
- if err != nil {
- return nil, err
- }
-
- appInstance, err := app.New(ctx, conn, store)
- if err != nil {
- slog.Error("Failed to create app instance", "error", err)
- return nil, err
- }
-
- if shouldEnableMetrics(cfg) {
- event.Init()
- }
-
- return appInstance, nil
-}
-
-// setupClientApp sets up a client-based workspace via the server. It
-// auto-starts a detached server process if the socket does not exist.
-func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Workspace, error) {
- debug, _ := cmd.Flags().GetBool("debug")
- yolo, _ := cmd.Flags().GetBool("yolo")
- dataDir, _ := cmd.Flags().GetString("data-dir")
- ctx := cmd.Context()
-
- cwd, err := ResolveCwd(cmd)
- if err != nil {
- return nil, nil, err
+ return nil, nil, nil, err
}
c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
if err != nil {
- return nil, nil, err
+ return nil, nil, nil, err
}
- ws, err := c.CreateWorkspace(ctx, proto.Workspace{
+ wsReq := proto.Workspace{
Path: cwd,
DataDir: dataDir,
Debug: debug,
YOLO: yolo,
Version: version.Version,
Env: os.Environ(),
- })
+ }
+
+ ws, err := c.CreateWorkspace(ctx, wsReq)
if err != nil {
// The server socket may exist before the HTTP handler is ready.
// Retry a few times with a short backoff.
for range 5 {
select {
case <-ctx.Done():
- return nil, nil, ctx.Err()
+ return nil, nil, nil, ctx.Err()
case <-time.After(200 * time.Millisecond):
}
- ws, err = c.CreateWorkspace(ctx, proto.Workspace{
- Path: cwd,
- DataDir: dataDir,
- Debug: debug,
- YOLO: yolo,
- Version: version.Version,
- Env: os.Environ(),
- })
+ ws, err = c.CreateWorkspace(ctx, wsReq)
if err == nil {
break
}
}
if err != nil {
- return nil, nil, fmt.Errorf("failed to create workspace: %v", err)
+ return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
}
}
@@ -310,7 +234,8 @@ func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *prot
event.Init()
}
- return c, ws, nil
+ cleanup := func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }
+ return c, ws, cleanup, nil
}
// ensureServer auto-starts a detached server if the socket file does not
@@ -488,18 +413,3 @@ func ResolveCwd(cmd *cobra.Command) (string, error) {
}
return cwd, nil
}
-
-func createDotCrushDir(dir string) error {
- if err := os.MkdirAll(dir, 0o700); err != nil {
- return fmt.Errorf("failed to create data directory: %q %w", dir, err)
- }
-
- gitIgnorePath := filepath.Join(dir, ".gitignore")
- if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) {
- if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
- return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
- }
- }
-
- return nil
-}
@@ -7,9 +7,21 @@ import (
"os"
"os/signal"
"strings"
+ "time"
+ "charm.land/lipgloss/v2"
"charm.land/log/v2"
+ "github.com/charmbracelet/crush/internal/client"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/event"
+ "github.com/charmbracelet/crush/internal/format"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/charmbracelet/crush/internal/ui/anim"
+ "github.com/charmbracelet/crush/internal/ui/styles"
+ "github.com/charmbracelet/x/ansi"
+ "github.com/charmbracelet/x/exp/charmtone"
+ "github.com/charmbracelet/x/term"
"github.com/spf13/cobra"
)
@@ -47,13 +59,13 @@ crush run --verbose "Generate a README for this project"
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
defer cancel()
- app, err := setupApp(cmd)
+ c, ws, cleanup, err := connectToServer(cmd)
if err != nil {
return err
}
- defer app.Shutdown()
+ defer cleanup()
- if !app.Config().IsConfigured() {
+ if !ws.Config.IsConfigured() {
return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
}
@@ -76,7 +88,7 @@ crush run --verbose "Generate a README for this project"
event.SetNonInteractive(true)
event.AppInitialized()
- return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose)
+ return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose)
},
}
@@ -86,3 +98,312 @@ func init() {
runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
}
+
+// runNonInteractive executes the agent via the server and streams output
+// to stdout.
+func runNonInteractive(
+ ctx context.Context,
+ c *client.Client,
+ ws *proto.Workspace,
+ prompt, largeModel, smallModel string,
+ hideSpinner bool,
+) error {
+ slog.Info("Running in non-interactive mode")
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ if largeModel != "" || smallModel != "" {
+ if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
+ return fmt.Errorf("failed to override models: %w", err)
+ }
+ }
+
+ var (
+ spinner *format.Spinner
+ stdoutTTY bool
+ stderrTTY bool
+ stdinTTY bool
+ progress bool
+ )
+
+ stdoutTTY = term.IsTerminal(os.Stdout.Fd())
+ stderrTTY = term.IsTerminal(os.Stderr.Fd())
+ stdinTTY = term.IsTerminal(os.Stdin.Fd())
+ progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
+
+ if !hideSpinner && stderrTTY {
+ t := styles.DefaultStyles()
+
+ hasDarkBG := true
+ if stdinTTY && stdoutTTY {
+ hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
+ }
+ defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
+
+ spinner = format.NewSpinner(ctx, cancel, anim.Settings{
+ Size: 10,
+ Label: "Generating",
+ LabelColor: defaultFG,
+ GradColorA: t.Primary,
+ GradColorB: t.Secondary,
+ CycleColors: true,
+ })
+ spinner.Start()
+ }
+
+ stopSpinner := func() {
+ if !hideSpinner && spinner != nil {
+ spinner.Stop()
+ spinner = nil
+ }
+ }
+
+ // Wait for the agent to become ready (MCP init, etc).
+ if err := waitForAgent(ctx, c, ws.ID); err != nil {
+ stopSpinner()
+ return fmt.Errorf("agent not ready: %w", err)
+ }
+
+ // Force-update agent models so MCP tools are loaded.
+ if err := c.UpdateAgent(ctx, ws.ID); err != nil {
+ slog.Warn("Failed to update agent", "error", err)
+ }
+
+ defer stopSpinner()
+
+ sess, err := c.CreateSession(ctx, ws.ID, "non-interactive")
+ if err != nil {
+ return fmt.Errorf("failed to create session: %w", err)
+ }
+ slog.Info("Created session for non-interactive run", "session_id", sess.ID)
+
+ events, err := c.SubscribeEvents(ctx, ws.ID)
+ if err != nil {
+ return fmt.Errorf("failed to subscribe to events: %w", err)
+ }
+
+ if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
+ return fmt.Errorf("failed to send message: %w", err)
+ }
+
+ messageReadBytes := make(map[string]int)
+ var printed bool
+
+ defer func() {
+ if progress && stderrTTY {
+ _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
+ }
+ _, _ = fmt.Fprintln(os.Stdout)
+ }()
+
+ for {
+ if progress && stderrTTY {
+ _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
+ }
+
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ stopSpinner()
+ return nil
+ }
+
+ switch e := ev.(type) {
+ case pubsub.Event[proto.Message]:
+ msg := e.Payload
+ if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
+ continue
+ }
+ stopSpinner()
+
+ content := msg.Content().String()
+ readBytes := messageReadBytes[msg.ID]
+
+ if len(content) < readBytes {
+ slog.Error("Non-interactive: message content shorter than read bytes",
+ "message_length", len(content), "read_bytes", readBytes)
+ return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
+ }
+
+ part := content[readBytes:]
+ if readBytes == 0 {
+ part = strings.TrimLeft(part, " \t")
+ }
+ if printed || strings.TrimSpace(part) != "" {
+ printed = true
+ fmt.Fprint(os.Stdout, part)
+ }
+ messageReadBytes[msg.ID] = len(content)
+
+ if msg.IsFinished() {
+ return nil
+ }
+
+ case pubsub.Event[proto.AgentEvent]:
+ if e.Payload.Error != nil {
+ stopSpinner()
+ return fmt.Errorf("agent error: %w", e.Payload.Error)
+ }
+ }
+
+ case <-ctx.Done():
+ stopSpinner()
+ return ctx.Err()
+ }
+ }
+}
+
+// waitForAgent polls GetAgentInfo until the agent is ready, with a
+// timeout.
+func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
+ timeout := time.After(30 * time.Second)
+ for {
+ info, err := c.GetAgentInfo(ctx, wsID)
+ if err == nil && info.IsReady {
+ return nil
+ }
+ select {
+ case <-timeout:
+ if err != nil {
+ return fmt.Errorf("timeout waiting for agent: %w", err)
+ }
+ return fmt.Errorf("timeout waiting for agent readiness")
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-time.After(200 * time.Millisecond):
+ }
+ }
+}
+
+// overrideModels resolves model strings and updates the workspace
+// configuration via the server.
+func overrideModels(
+ ctx context.Context,
+ c *client.Client,
+ ws *proto.Workspace,
+ largeModel, smallModel string,
+) error {
+ cfg, err := c.GetConfig(ctx, ws.ID)
+ if err != nil {
+ return fmt.Errorf("failed to get config: %w", err)
+ }
+
+ providers := cfg.Providers.Copy()
+
+ largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
+
+ var largeProviderID string
+
+ if largeModel != "" {
+ found, err := validateModelMatches(largeMatches, largeModel, "large")
+ if err != nil {
+ return err
+ }
+ largeProviderID = found.provider
+ slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
+ if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
+ Provider: found.provider,
+ Model: found.modelID,
+ }); err != nil {
+ return fmt.Errorf("failed to set large model: %w", err)
+ }
+ }
+
+ switch {
+ case smallModel != "":
+ found, err := validateModelMatches(smallMatches, smallModel, "small")
+ if err != nil {
+ return err
+ }
+ slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
+ if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
+ Provider: found.provider,
+ Model: found.modelID,
+ }); err != nil {
+ return fmt.Errorf("failed to set small model: %w", err)
+ }
+
+ case largeModel != "":
+ sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
+ if err != nil {
+ slog.Warn("Failed to get default small model", "error", err)
+ } else if sm != nil {
+ if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
+ return fmt.Errorf("failed to set small model: %w", err)
+ }
+ }
+ }
+
+ return c.UpdateAgent(ctx, ws.ID)
+}
+
+type modelMatch struct {
+ provider string
+ modelID string
+}
+
+// findModelMatches searches providers for matching large/small model
+// strings.
+func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
+ largeFilter, largeID := parseModelString(largeModel)
+ smallFilter, smallID := parseModelString(smallModel)
+
+ var largeMatches, smallMatches []modelMatch
+ for name, provider := range providers {
+ if provider.Disable {
+ continue
+ }
+ for _, m := range provider.Models {
+ if matchesModel(largeID, largeFilter, m.ID, name) {
+ largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
+ }
+ if matchesModel(smallID, smallFilter, m.ID, name) {
+ smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
+ }
+ }
+ }
+ return largeMatches, smallMatches
+}
+
+// parseModelString splits "provider/model" into (provider, model) or
+// ("", model).
+func parseModelString(s string) (string, string) {
+ if s == "" {
+ return "", ""
+ }
+ if idx := strings.Index(s, "/"); idx >= 0 {
+ return s[:idx], s[idx+1:]
+ }
+ return "", s
+}
+
+// matchesModel returns true if the model ID matches the filter
+// criteria.
+func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
+ if wantID == "" {
+ return false
+ }
+ if wantProvider != "" && wantProvider != providerName {
+ return false
+ }
+ return strings.EqualFold(modelID, wantID)
+}
+
+// validateModelMatches ensures exactly one match exists.
+func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
+ switch {
+ case len(matches) == 0:
+ return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
+ case len(matches) > 1:
+ names := make([]string, len(matches))
+ for i, m := range matches {
+ names[i] = m.provider
+ }
+ return modelMatch{}, fmt.Errorf(
+ "%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
+ label, modelID, strings.Join(names, ", "),
+ )
+ }
+ return matches[0], nil
+}
@@ -1,370 +0,0 @@
-package workspace
-
-import (
- "context"
- "log/slog"
- "time"
-
- tea "charm.land/bubbletea/v2"
- "github.com/charmbracelet/crush/internal/agent"
- mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp"
- "github.com/charmbracelet/crush/internal/app"
- "github.com/charmbracelet/crush/internal/commands"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/oauth"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/session"
-)
-
-// AppWorkspace wraps an in-process app.App to satisfy the Workspace
-// interface. This is the default mode when no server is involved.
-type AppWorkspace struct {
- app *app.App
-}
-
-// NewAppWorkspace creates a Workspace backed by a local app.App.
-func NewAppWorkspace(a *app.App) *AppWorkspace {
- return &AppWorkspace{app: a}
-}
-
-// App returns the underlying app.App for callers that still need
-// direct access during the migration period.
-func (w *AppWorkspace) App() *app.App {
- return w.app
-}
-
-// -- Sessions --
-
-func (w *AppWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) {
- return w.app.Sessions.Create(ctx, title)
-}
-
-func (w *AppWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) {
- return w.app.Sessions.Get(ctx, sessionID)
-}
-
-func (w *AppWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) {
- return w.app.Sessions.List(ctx)
-}
-
-func (w *AppWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) {
- return w.app.Sessions.Save(ctx, sess)
-}
-
-func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) error {
- return w.app.Sessions.Delete(ctx, sessionID)
-}
-
-func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
- return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID)
-}
-
-func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
- return w.app.Sessions.ParseAgentToolSessionID(sessionID)
-}
-
-// -- Messages --
-
-func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
- return w.app.Messages.List(ctx, sessionID)
-}
-
-func (w *AppWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
- return w.app.Messages.ListUserMessages(ctx, sessionID)
-}
-
-func (w *AppWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) {
- return w.app.Messages.ListAllUserMessages(ctx)
-}
-
-// -- Agent --
-
-func (w *AppWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error {
- if w.app.AgentCoordinator == nil {
- return nil
- }
- _, err := w.app.AgentCoordinator.Run(ctx, sessionID, prompt, attachments...)
- return err
-}
-
-func (w *AppWorkspace) AgentCancel(sessionID string) {
- if w.app.AgentCoordinator != nil {
- w.app.AgentCoordinator.Cancel(sessionID)
- }
-}
-
-func (w *AppWorkspace) AgentIsBusy() bool {
- if w.app.AgentCoordinator == nil {
- return false
- }
- return w.app.AgentCoordinator.IsBusy()
-}
-
-func (w *AppWorkspace) AgentIsSessionBusy(sessionID string) bool {
- if w.app.AgentCoordinator == nil {
- return false
- }
- return w.app.AgentCoordinator.IsSessionBusy(sessionID)
-}
-
-func (w *AppWorkspace) AgentModel() AgentModel {
- if w.app.AgentCoordinator == nil {
- return AgentModel{}
- }
- m := w.app.AgentCoordinator.Model()
- return AgentModel{
- CatwalkCfg: m.CatwalkCfg,
- ModelCfg: m.ModelCfg,
- }
-}
-
-func (w *AppWorkspace) AgentIsReady() bool {
- return w.app.AgentCoordinator != nil
-}
-
-func (w *AppWorkspace) AgentQueuedPrompts(sessionID string) int {
- if w.app.AgentCoordinator == nil {
- return 0
- }
- return w.app.AgentCoordinator.QueuedPrompts(sessionID)
-}
-
-func (w *AppWorkspace) AgentQueuedPromptsList(sessionID string) []string {
- if w.app.AgentCoordinator == nil {
- return nil
- }
- return w.app.AgentCoordinator.QueuedPromptsList(sessionID)
-}
-
-func (w *AppWorkspace) AgentClearQueue(sessionID string) {
- if w.app.AgentCoordinator != nil {
- w.app.AgentCoordinator.ClearQueue(sessionID)
- }
-}
-
-func (w *AppWorkspace) AgentSummarize(ctx context.Context, sessionID string) error {
- if w.app.AgentCoordinator == nil {
- return nil
- }
- return w.app.AgentCoordinator.Summarize(ctx, sessionID)
-}
-
-func (w *AppWorkspace) UpdateAgentModel(ctx context.Context) error {
- return w.app.UpdateAgentModel(ctx)
-}
-
-func (w *AppWorkspace) InitCoderAgent(ctx context.Context) error {
- return w.app.InitCoderAgent(ctx)
-}
-
-func (w *AppWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel {
- return w.app.GetDefaultSmallModel(providerID)
-}
-
-// -- Permissions --
-
-func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) {
- w.app.Permissions.Grant(perm)
-}
-
-func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) {
- w.app.Permissions.GrantPersistent(perm)
-}
-
-func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) {
- w.app.Permissions.Deny(perm)
-}
-
-func (w *AppWorkspace) PermissionSkipRequests() bool {
- return w.app.Permissions.SkipRequests()
-}
-
-func (w *AppWorkspace) PermissionSetSkipRequests(skip bool) {
- w.app.Permissions.SetSkipRequests(skip)
-}
-
-// -- FileTracker --
-
-func (w *AppWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) {
- w.app.FileTracker.RecordRead(ctx, sessionID, path)
-}
-
-func (w *AppWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time {
- return w.app.FileTracker.LastReadTime(ctx, sessionID, path)
-}
-
-func (w *AppWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) {
- return w.app.FileTracker.ListReadFiles(ctx, sessionID)
-}
-
-// -- History --
-
-func (w *AppWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) {
- return w.app.History.ListBySession(ctx, sessionID)
-}
-
-// -- LSP --
-
-func (w *AppWorkspace) LSPStart(ctx context.Context, path string) {
- w.app.LSPManager.Start(ctx, path)
-}
-
-func (w *AppWorkspace) LSPStopAll(ctx context.Context) {
- w.app.LSPManager.StopAll(ctx)
-}
-
-func (w *AppWorkspace) LSPGetStates() map[string]LSPClientInfo {
- states := app.GetLSPStates()
- result := make(map[string]LSPClientInfo, len(states))
- for k, v := range states {
- result[k] = LSPClientInfo{
- Name: v.Name,
- State: v.State,
- Error: v.Error,
- DiagnosticCount: v.DiagnosticCount,
- ConnectedAt: v.ConnectedAt,
- }
- }
- return result
-}
-
-func (w *AppWorkspace) LSPGetClient(name string) (*lsp.Client, bool) {
- info, ok := app.GetLSPState(name)
- if !ok {
- return nil, false
- }
- return info.Client, true
-}
-
-// -- Config (read-only) --
-
-func (w *AppWorkspace) Config() *config.Config {
- return w.app.Config()
-}
-
-func (w *AppWorkspace) WorkingDir() string {
- return w.app.Store().WorkingDir()
-}
-
-func (w *AppWorkspace) Resolver() config.VariableResolver {
- return w.app.Store().Resolver()
-}
-
-// -- Config mutations --
-
-func (w *AppWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error {
- return w.app.Store().UpdatePreferredModel(scope, modelType, model)
-}
-
-func (w *AppWorkspace) SetCompactMode(scope config.Scope, enabled bool) error {
- return w.app.Store().SetCompactMode(scope, enabled)
-}
-
-func (w *AppWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error {
- return w.app.Store().SetProviderAPIKey(scope, providerID, apiKey)
-}
-
-func (w *AppWorkspace) SetConfigField(scope config.Scope, key string, value any) error {
- return w.app.Store().SetConfigField(scope, key, value)
-}
-
-func (w *AppWorkspace) RemoveConfigField(scope config.Scope, key string) error {
- return w.app.Store().RemoveConfigField(scope, key)
-}
-
-func (w *AppWorkspace) ImportCopilot() (*oauth.Token, bool) {
- return w.app.Store().ImportCopilot()
-}
-
-func (w *AppWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error {
- return w.app.Store().RefreshOAuthToken(ctx, scope, providerID)
-}
-
-// -- Project lifecycle --
-
-func (w *AppWorkspace) ProjectNeedsInitialization() (bool, error) {
- return config.ProjectNeedsInitialization(w.app.Store())
-}
-
-func (w *AppWorkspace) MarkProjectInitialized() error {
- return config.MarkProjectInitialized(w.app.Store())
-}
-
-func (w *AppWorkspace) InitializePrompt() (string, error) {
- return agent.InitializePrompt(w.app.Store())
-}
-
-// -- MCP operations --
-
-func (w *AppWorkspace) MCPGetStates() map[string]mcptools.ClientInfo {
- return mcptools.GetStates()
-}
-
-func (w *AppWorkspace) MCPRefreshPrompts(ctx context.Context, name string) {
- mcptools.RefreshPrompts(ctx, name)
-}
-
-func (w *AppWorkspace) MCPRefreshResources(ctx context.Context, name string) {
- mcptools.RefreshResources(ctx, name)
-}
-
-func (w *AppWorkspace) RefreshMCPTools(ctx context.Context, name string) {
- mcptools.RefreshTools(ctx, w.app.Store(), name)
-}
-
-func (w *AppWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) {
- contents, err := mcptools.ReadResource(ctx, w.app.Store(), name, uri)
- if err != nil {
- return nil, err
- }
- result := make([]MCPResourceContents, len(contents))
- for i, c := range contents {
- result[i] = MCPResourceContents{
- URI: c.URI,
- MIMEType: c.MIMEType,
- Text: c.Text,
- Blob: c.Blob,
- }
- }
- return result, nil
-}
-
-func (w *AppWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
- return commands.GetMCPPrompt(w.app.Store(), clientID, promptID, args)
-}
-
-// -- Lifecycle --
-
-func (w *AppWorkspace) Subscribe(program *tea.Program) {
- defer log.RecoverPanic("AppWorkspace.Subscribe", func() {
- slog.Info("TUI subscription panic: attempting graceful shutdown")
- program.Quit()
- })
-
- for msg := range w.app.Events() {
- switch ev := msg.(type) {
- case pubsub.Event[app.LSPEvent]:
- program.Send(pubsub.Event[LSPEvent]{
- Type: ev.Type,
- Payload: LSPEvent{
- Type: LSPEventType(ev.Payload.Type),
- Name: ev.Payload.Name,
- State: ev.Payload.State,
- Error: ev.Payload.Error,
- DiagnosticCount: ev.Payload.DiagnosticCount,
- },
- })
- default:
- program.Send(msg)
- }
- }
-}
-
-func (w *AppWorkspace) Shutdown() {
- w.app.Shutdown()
-}