refactor: migrate run and login commands to use client API instead of direct app access

Ayman Bagabas created

Change summary

internal/cmd/login.go               |  45 ++-
internal/cmd/root.go                | 134 +---------
internal/cmd/run.go                 | 329 +++++++++++++++++++++++++++
internal/workspace/app_workspace.go | 370 -------------------------------
4 files changed, 375 insertions(+), 503 deletions(-)

Detailed changes

internal/cmd/login.go 🔗

@@ -10,10 +10,12 @@ import (
 	"charm.land/lipgloss/v2"
 	"github.com/atotto/clipboard"
 	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
+	"github.com/charmbracelet/crush/internal/client"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/oauth"
 	"github.com/charmbracelet/crush/internal/oauth/copilot"
 	"github.com/charmbracelet/crush/internal/oauth/hyper"
+	"github.com/charmbracelet/x/ansi"
 	"github.com/pkg/browser"
 	"github.com/spf13/cobra"
 )
@@ -40,11 +42,17 @@ crush login copilot
 	},
 	Args: cobra.MaximumNArgs(1),
 	RunE: func(cmd *cobra.Command, args []string) error {
-		app, err := setupAppWithProgressBar(cmd)
+		c, ws, cleanup, err := connectToServer(cmd)
 		if err != nil {
 			return err
 		}
-		defer app.Shutdown()
+		defer cleanup()
+
+		progressEnabled := ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
+		if progressEnabled && supportsProgressBar() {
+			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
+			defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
+		}
 
 		provider := "hyper"
 		if len(args) > 0 {
@@ -52,16 +60,16 @@ crush login copilot
 		}
 		switch provider {
 		case "hyper":
-			return loginHyper(app.Store())
+			return loginHyper(c, ws.ID)
 		case "copilot", "github", "github-copilot":
-			return loginCopilot(app.Store())
+			return loginCopilot(cmd.Context(), c, ws.ID)
 		default:
 			return fmt.Errorf("unknown platform: %s", args[0])
 		}
 	},
 }
 
-func loginHyper(cfg *config.ConfigStore) error {
+func loginHyper(c *client.Client, wsID string) error {
 	if !hyperp.Enabled() {
 		return fmt.Errorf("hyper not enabled")
 	}
@@ -112,8 +120,8 @@ func loginHyper(cfg *config.ConfigStore) error {
 	}
 
 	if err := cmp.Or(
-		cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
-		cfg.SetConfigField(config.ScopeGlobal, "providers.hyper.oauth", token),
+		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.api_key", token.AccessToken),
+		c.SetConfigField(ctx, wsID, config.ScopeGlobal, "providers.hyper.oauth", token),
 	); err != nil {
 		return err
 	}
@@ -123,12 +131,15 @@ func loginHyper(cfg *config.ConfigStore) error {
 	return nil
 }
 
-func loginCopilot(cfg *config.ConfigStore) error {
-	ctx := getLoginContext()
+func loginCopilot(ctx context.Context, c *client.Client, wsID string) error {
+	loginCtx := getLoginContext()
 
-	if cfg.HasConfigField(config.ScopeGlobal, "providers.copilot.oauth") {
-		fmt.Println("You are already logged in to GitHub Copilot.")
-		return nil
+	cfg, err := c.GetConfig(ctx, wsID)
+	if err == nil && cfg != nil {
+		if pc, ok := cfg.Providers.Get("copilot"); ok && pc.OAuthToken != nil {
+			fmt.Println("You are already logged in to GitHub Copilot.")
+			return nil
+		}
 	}
 
 	diskToken, hasDiskToken := copilot.RefreshTokenFromDisk()
@@ -138,14 +149,14 @@ func loginCopilot(cfg *config.ConfigStore) error {
 	case hasDiskToken:
 		fmt.Println("Found existing GitHub Copilot token on disk. Using it to authenticate...")
 
-		t, err := copilot.RefreshToken(ctx, diskToken)
+		t, err := copilot.RefreshToken(loginCtx, diskToken)
 		if err != nil {
 			return fmt.Errorf("unable to refresh token from disk: %w", err)
 		}
 		token = t
 	default:
 		fmt.Println("Requesting device code from GitHub...")
-		dc, err := copilot.RequestDeviceCode(ctx)
+		dc, err := copilot.RequestDeviceCode(loginCtx)
 		if err != nil {
 			return err
 		}
@@ -159,7 +170,7 @@ func loginCopilot(cfg *config.ConfigStore) error {
 		fmt.Println()
 		fmt.Println("Waiting for authorization...")
 
-		t, err := copilot.PollForToken(ctx, dc)
+		t, err := copilot.PollForToken(loginCtx, dc)
 		if err == copilot.ErrNotAvailable {
 			fmt.Println()
 			fmt.Println("GitHub Copilot is unavailable for this account. To signup, go to the following page:")
@@ -177,8 +188,8 @@ func loginCopilot(cfg *config.ConfigStore) error {
 	}
 
 	if err := cmp.Or(
-		cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
-		cfg.SetConfigField(config.ScopeGlobal, "providers.copilot.oauth", token),
+		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.api_key", token.AccessToken),
+		c.SetConfigField(loginCtx, wsID, config.ScopeGlobal, "providers.copilot.oauth", token),
 	); err != nil {
 		return err
 	}

internal/cmd/root.go 🔗

@@ -20,12 +20,9 @@ import (
 	tea "charm.land/bubbletea/v2"
 	"charm.land/lipgloss/v2"
 	"github.com/charmbracelet/colorprofile"
-	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/client"
 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/event"
-	"github.com/charmbracelet/crush/internal/projects"
 	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/server"
 	"github.com/charmbracelet/crush/internal/ui/common"
@@ -34,7 +31,6 @@ import (
 	"github.com/charmbracelet/crush/internal/workspace"
 	"github.com/charmbracelet/fang"
 	uv "github.com/charmbracelet/ultraviolet"
-	"github.com/charmbracelet/x/ansi"
 	"github.com/charmbracelet/x/exp/charmtone"
 	"github.com/charmbracelet/x/term"
 	"github.com/spf13/cobra"
@@ -46,10 +42,9 @@ func init() {
 	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
+	rootCmd.PersistentFlags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
+	rootCmd.PersistentFlags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
 	rootCmd.Flags().BoolP("help", "h", false, "Help")
-	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
-
-	rootCmd.Flags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
 
 	rootCmd.AddCommand(
 		runCmd,
@@ -87,20 +82,11 @@ crush --yolo
 crush --data-dir /path/to/custom/.crush
   `,
 	RunE: func(cmd *cobra.Command, args []string) error {
-		hostURL, err := server.ParseHostURL(clientHost)
-		if err != nil {
-			return fmt.Errorf("invalid host URL: %v", err)
-		}
-
-		if err := ensureServer(cmd, hostURL); err != nil {
-			return err
-		}
-
-		c, ws, err := setupClientApp(cmd, hostURL)
+		c, ws, cleanup, err := connectToServer(cmd)
 		if err != nil {
 			return err
 		}
-		defer func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }()
+		defer cleanup()
 
 		event.AppInitialized()
 
@@ -188,25 +174,18 @@ func supportsProgressBar() bool {
 	return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
 }
 
-func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
-	app, err := setupApp(cmd)
+// connectToServer ensures the server is running, creates a client and
+// workspace, and returns a cleanup function that deletes the workspace.
+func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) {
+	hostURL, err := server.ParseHostURL(clientHost)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, fmt.Errorf("invalid host URL: %v", err)
 	}
 
-	// Check if progress bar is enabled in config (defaults to true if nil)
-	progressEnabled := app.Config().Options.Progress == nil || *app.Config().Options.Progress
-	if progressEnabled && supportsProgressBar() {
-		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
-		defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
+	if err := ensureServer(cmd, hostURL); err != nil {
+		return nil, nil, nil, err
 	}
 
-	return app, nil
-}
-
-// setupApp handles the common setup logic for both interactive and non-interactive modes.
-// It returns the app instance, config, cleanup function, and any error.
-func setupApp(cmd *cobra.Command) (*app.App, error) {
 	debug, _ := cmd.Flags().GetBool("debug")
 	yolo, _ := cmd.Flags().GetBool("yolo")
 	dataDir, _ := cmd.Flags().GetString("data-dir")
@@ -214,95 +193,40 @@ func setupApp(cmd *cobra.Command) (*app.App, error) {
 
 	cwd, err := ResolveCwd(cmd)
 	if err != nil {
-		return nil, err
-	}
-
-	store, err := config.Init(cwd, dataDir, debug)
-	if err != nil {
-		return nil, err
-	}
-
-	store.Overrides().SkipPermissionRequests = yolo
-	cfg := store.Config()
-
-	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
-		return nil, err
-	}
-
-	// Register this project in the centralized projects list.
-	if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
-		slog.Warn("Failed to register project", "error", err)
-		// Non-fatal: continue even if registration fails
-	}
-
-	// Connect to DB; this will also run migrations.
-	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
-	if err != nil {
-		return nil, err
-	}
-
-	appInstance, err := app.New(ctx, conn, store)
-	if err != nil {
-		slog.Error("Failed to create app instance", "error", err)
-		return nil, err
-	}
-
-	if shouldEnableMetrics(cfg) {
-		event.Init()
-	}
-
-	return appInstance, nil
-}
-
-// setupClientApp sets up a client-based workspace via the server. It
-// auto-starts a detached server process if the socket does not exist.
-func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *proto.Workspace, error) {
-	debug, _ := cmd.Flags().GetBool("debug")
-	yolo, _ := cmd.Flags().GetBool("yolo")
-	dataDir, _ := cmd.Flags().GetString("data-dir")
-	ctx := cmd.Context()
-
-	cwd, err := ResolveCwd(cmd)
-	if err != nil {
-		return nil, nil, err
+		return nil, nil, nil, err
 	}
 
 	c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
 	if err != nil {
-		return nil, nil, err
+		return nil, nil, nil, err
 	}
 
-	ws, err := c.CreateWorkspace(ctx, proto.Workspace{
+	wsReq := proto.Workspace{
 		Path:    cwd,
 		DataDir: dataDir,
 		Debug:   debug,
 		YOLO:    yolo,
 		Version: version.Version,
 		Env:     os.Environ(),
-	})
+	}
+
+	ws, err := c.CreateWorkspace(ctx, wsReq)
 	if err != nil {
 		// The server socket may exist before the HTTP handler is ready.
 		// Retry a few times with a short backoff.
 		for range 5 {
 			select {
 			case <-ctx.Done():
-				return nil, nil, ctx.Err()
+				return nil, nil, nil, ctx.Err()
 			case <-time.After(200 * time.Millisecond):
 			}
-			ws, err = c.CreateWorkspace(ctx, proto.Workspace{
-				Path:    cwd,
-				DataDir: dataDir,
-				Debug:   debug,
-				YOLO:    yolo,
-				Version: version.Version,
-				Env:     os.Environ(),
-			})
+			ws, err = c.CreateWorkspace(ctx, wsReq)
 			if err == nil {
 				break
 			}
 		}
 		if err != nil {
-			return nil, nil, fmt.Errorf("failed to create workspace: %v", err)
+			return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
 		}
 	}
 
@@ -310,7 +234,8 @@ func setupClientApp(cmd *cobra.Command, hostURL *url.URL) (*client.Client, *prot
 		event.Init()
 	}
 
-	return c, ws, nil
+	cleanup := func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }
+	return c, ws, cleanup, nil
 }
 
 // ensureServer auto-starts a detached server if the socket file does not
@@ -488,18 +413,3 @@ func ResolveCwd(cmd *cobra.Command) (string, error) {
 	}
 	return cwd, nil
 }
-
-func createDotCrushDir(dir string) error {
-	if err := os.MkdirAll(dir, 0o700); err != nil {
-		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
-	}
-
-	gitIgnorePath := filepath.Join(dir, ".gitignore")
-	if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) {
-		if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
-			return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
-		}
-	}
-
-	return nil
-}

internal/cmd/run.go 🔗

@@ -7,9 +7,21 @@ import (
 	"os"
 	"os/signal"
 	"strings"
+	"time"
 
+	"charm.land/lipgloss/v2"
 	"charm.land/log/v2"
+	"github.com/charmbracelet/crush/internal/client"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/event"
+	"github.com/charmbracelet/crush/internal/format"
+	"github.com/charmbracelet/crush/internal/proto"
+	"github.com/charmbracelet/crush/internal/pubsub"
+	"github.com/charmbracelet/crush/internal/ui/anim"
+	"github.com/charmbracelet/crush/internal/ui/styles"
+	"github.com/charmbracelet/x/ansi"
+	"github.com/charmbracelet/x/exp/charmtone"
+	"github.com/charmbracelet/x/term"
 	"github.com/spf13/cobra"
 )
 
@@ -47,13 +59,13 @@ crush run --verbose "Generate a README for this project"
 		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
 		defer cancel()
 
-		app, err := setupApp(cmd)
+		c, ws, cleanup, err := connectToServer(cmd)
 		if err != nil {
 			return err
 		}
-		defer app.Shutdown()
+		defer cleanup()
 
-		if !app.Config().IsConfigured() {
+		if !ws.Config.IsConfigured() {
 			return fmt.Errorf("no providers configured - please run 'crush' to set up a provider interactively")
 		}
 
@@ -76,7 +88,7 @@ crush run --verbose "Generate a README for this project"
 		event.SetNonInteractive(true)
 		event.AppInitialized()
 
-		return app.RunNonInteractive(ctx, os.Stdout, prompt, largeModel, smallModel, quiet || verbose)
+		return runNonInteractive(ctx, c, ws, prompt, largeModel, smallModel, quiet || verbose)
 	},
 }
 
@@ -86,3 +98,312 @@ func init() {
 	runCmd.Flags().StringP("model", "m", "", "Model to use. Accepts 'model' or 'provider/model' to disambiguate models with the same name across providers")
 	runCmd.Flags().String("small-model", "", "Small model to use. If not provided, uses the default small model for the provider")
 }
+
+// runNonInteractive executes the agent via the server and streams output
+// to stdout.
+func runNonInteractive(
+	ctx context.Context,
+	c *client.Client,
+	ws *proto.Workspace,
+	prompt, largeModel, smallModel string,
+	hideSpinner bool,
+) error {
+	slog.Info("Running in non-interactive mode")
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	if largeModel != "" || smallModel != "" {
+		if err := overrideModels(ctx, c, ws, largeModel, smallModel); err != nil {
+			return fmt.Errorf("failed to override models: %w", err)
+		}
+	}
+
+	var (
+		spinner   *format.Spinner
+		stdoutTTY bool
+		stderrTTY bool
+		stdinTTY  bool
+		progress  bool
+	)
+
+	stdoutTTY = term.IsTerminal(os.Stdout.Fd())
+	stderrTTY = term.IsTerminal(os.Stderr.Fd())
+	stdinTTY = term.IsTerminal(os.Stdin.Fd())
+	progress = ws.Config.Options.Progress == nil || *ws.Config.Options.Progress
+
+	if !hideSpinner && stderrTTY {
+		t := styles.DefaultStyles()
+
+		hasDarkBG := true
+		if stdinTTY && stdoutTTY {
+			hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, os.Stdout)
+		}
+		defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
+
+		spinner = format.NewSpinner(ctx, cancel, anim.Settings{
+			Size:        10,
+			Label:       "Generating",
+			LabelColor:  defaultFG,
+			GradColorA:  t.Primary,
+			GradColorB:  t.Secondary,
+			CycleColors: true,
+		})
+		spinner.Start()
+	}
+
+	stopSpinner := func() {
+		if !hideSpinner && spinner != nil {
+			spinner.Stop()
+			spinner = nil
+		}
+	}
+
+	// Wait for the agent to become ready (MCP init, etc).
+	if err := waitForAgent(ctx, c, ws.ID); err != nil {
+		stopSpinner()
+		return fmt.Errorf("agent not ready: %w", err)
+	}
+
+	// Force-update agent models so MCP tools are loaded.
+	if err := c.UpdateAgent(ctx, ws.ID); err != nil {
+		slog.Warn("Failed to update agent", "error", err)
+	}
+
+	defer stopSpinner()
+
+	sess, err := c.CreateSession(ctx, ws.ID, "non-interactive")
+	if err != nil {
+		return fmt.Errorf("failed to create session: %w", err)
+	}
+	slog.Info("Created session for non-interactive run", "session_id", sess.ID)
+
+	events, err := c.SubscribeEvents(ctx, ws.ID)
+	if err != nil {
+		return fmt.Errorf("failed to subscribe to events: %w", err)
+	}
+
+	if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil {
+		return fmt.Errorf("failed to send message: %w", err)
+	}
+
+	messageReadBytes := make(map[string]int)
+	var printed bool
+
+	defer func() {
+		if progress && stderrTTY {
+			_, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
+		}
+		_, _ = fmt.Fprintln(os.Stdout)
+	}()
+
+	for {
+		if progress && stderrTTY {
+			_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
+		}
+
+		select {
+		case ev, ok := <-events:
+			if !ok {
+				stopSpinner()
+				return nil
+			}
+
+			switch e := ev.(type) {
+			case pubsub.Event[proto.Message]:
+				msg := e.Payload
+				if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 {
+					continue
+				}
+				stopSpinner()
+
+				content := msg.Content().String()
+				readBytes := messageReadBytes[msg.ID]
+
+				if len(content) < readBytes {
+					slog.Error("Non-interactive: message content shorter than read bytes",
+						"message_length", len(content), "read_bytes", readBytes)
+					return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
+				}
+
+				part := content[readBytes:]
+				if readBytes == 0 {
+					part = strings.TrimLeft(part, " \t")
+				}
+				if printed || strings.TrimSpace(part) != "" {
+					printed = true
+					fmt.Fprint(os.Stdout, part)
+				}
+				messageReadBytes[msg.ID] = len(content)
+
+				if msg.IsFinished() {
+					return nil
+				}
+
+			case pubsub.Event[proto.AgentEvent]:
+				if e.Payload.Error != nil {
+					stopSpinner()
+					return fmt.Errorf("agent error: %w", e.Payload.Error)
+				}
+			}
+
+		case <-ctx.Done():
+			stopSpinner()
+			return ctx.Err()
+		}
+	}
+}
+
+// waitForAgent polls GetAgentInfo until the agent is ready, with a
+// timeout.
+func waitForAgent(ctx context.Context, c *client.Client, wsID string) error {
+	timeout := time.After(30 * time.Second)
+	for {
+		info, err := c.GetAgentInfo(ctx, wsID)
+		if err == nil && info.IsReady {
+			return nil
+		}
+		select {
+		case <-timeout:
+			if err != nil {
+				return fmt.Errorf("timeout waiting for agent: %w", err)
+			}
+			return fmt.Errorf("timeout waiting for agent readiness")
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(200 * time.Millisecond):
+		}
+	}
+}
+
+// overrideModels resolves model strings and updates the workspace
+// configuration via the server.
+func overrideModels(
+	ctx context.Context,
+	c *client.Client,
+	ws *proto.Workspace,
+	largeModel, smallModel string,
+) error {
+	cfg, err := c.GetConfig(ctx, ws.ID)
+	if err != nil {
+		return fmt.Errorf("failed to get config: %w", err)
+	}
+
+	providers := cfg.Providers.Copy()
+
+	largeMatches, smallMatches := findModelMatches(providers, largeModel, smallModel)
+
+	var largeProviderID string
+
+	if largeModel != "" {
+		found, err := validateModelMatches(largeMatches, largeModel, "large")
+		if err != nil {
+			return err
+		}
+		largeProviderID = found.provider
+		slog.Info("Overriding large model", "provider", found.provider, "model", found.modelID)
+		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeLarge, config.SelectedModel{
+			Provider: found.provider,
+			Model:    found.modelID,
+		}); err != nil {
+			return fmt.Errorf("failed to set large model: %w", err)
+		}
+	}
+
+	switch {
+	case smallModel != "":
+		found, err := validateModelMatches(smallMatches, smallModel, "small")
+		if err != nil {
+			return err
+		}
+		slog.Info("Overriding small model", "provider", found.provider, "model", found.modelID)
+		if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, config.SelectedModel{
+			Provider: found.provider,
+			Model:    found.modelID,
+		}); err != nil {
+			return fmt.Errorf("failed to set small model: %w", err)
+		}
+
+	case largeModel != "":
+		sm, err := c.GetDefaultSmallModel(ctx, ws.ID, largeProviderID)
+		if err != nil {
+			slog.Warn("Failed to get default small model", "error", err)
+		} else if sm != nil {
+			if err := c.UpdatePreferredModel(ctx, ws.ID, config.ScopeWorkspace, config.SelectedModelTypeSmall, *sm); err != nil {
+				return fmt.Errorf("failed to set small model: %w", err)
+			}
+		}
+	}
+
+	return c.UpdateAgent(ctx, ws.ID)
+}
+
+type modelMatch struct {
+	provider string
+	modelID  string
+}
+
+// findModelMatches searches providers for matching large/small model
+// strings.
+func findModelMatches(providers map[string]config.ProviderConfig, largeModel, smallModel string) ([]modelMatch, []modelMatch) {
+	largeFilter, largeID := parseModelString(largeModel)
+	smallFilter, smallID := parseModelString(smallModel)
+
+	var largeMatches, smallMatches []modelMatch
+	for name, provider := range providers {
+		if provider.Disable {
+			continue
+		}
+		for _, m := range provider.Models {
+			if matchesModel(largeID, largeFilter, m.ID, name) {
+				largeMatches = append(largeMatches, modelMatch{provider: name, modelID: m.ID})
+			}
+			if matchesModel(smallID, smallFilter, m.ID, name) {
+				smallMatches = append(smallMatches, modelMatch{provider: name, modelID: m.ID})
+			}
+		}
+	}
+	return largeMatches, smallMatches
+}
+
+// parseModelString splits "provider/model" into (provider, model) or
+// ("", model).
+func parseModelString(s string) (string, string) {
+	if s == "" {
+		return "", ""
+	}
+	if idx := strings.Index(s, "/"); idx >= 0 {
+		return s[:idx], s[idx+1:]
+	}
+	return "", s
+}
+
+// matchesModel returns true if the model ID matches the filter
+// criteria.
+func matchesModel(wantID, wantProvider, modelID, providerName string) bool {
+	if wantID == "" {
+		return false
+	}
+	if wantProvider != "" && wantProvider != providerName {
+		return false
+	}
+	return strings.EqualFold(modelID, wantID)
+}
+
+// validateModelMatches ensures exactly one match exists.
+func validateModelMatches(matches []modelMatch, modelID, label string) (modelMatch, error) {
+	switch {
+	case len(matches) == 0:
+		return modelMatch{}, fmt.Errorf("%s model %q not found", label, modelID)
+	case len(matches) > 1:
+		names := make([]string, len(matches))
+		for i, m := range matches {
+			names[i] = m.provider
+		}
+		return modelMatch{}, fmt.Errorf(
+			"%s model: model %q found in multiple providers: %s. Please specify provider using 'provider/model' format",
+			label, modelID, strings.Join(names, ", "),
+		)
+	}
+	return matches[0], nil
+}

internal/workspace/app_workspace.go 🔗

@@ -1,370 +0,0 @@
-package workspace
-
-import (
-	"context"
-	"log/slog"
-	"time"
-
-	tea "charm.land/bubbletea/v2"
-	"github.com/charmbracelet/crush/internal/agent"
-	mcptools "github.com/charmbracelet/crush/internal/agent/tools/mcp"
-	"github.com/charmbracelet/crush/internal/app"
-	"github.com/charmbracelet/crush/internal/commands"
-	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/history"
-	"github.com/charmbracelet/crush/internal/log"
-	"github.com/charmbracelet/crush/internal/lsp"
-	"github.com/charmbracelet/crush/internal/message"
-	"github.com/charmbracelet/crush/internal/oauth"
-	"github.com/charmbracelet/crush/internal/permission"
-	"github.com/charmbracelet/crush/internal/pubsub"
-	"github.com/charmbracelet/crush/internal/session"
-)
-
-// AppWorkspace wraps an in-process app.App to satisfy the Workspace
-// interface. This is the default mode when no server is involved.
-type AppWorkspace struct {
-	app *app.App
-}
-
-// NewAppWorkspace creates a Workspace backed by a local app.App.
-func NewAppWorkspace(a *app.App) *AppWorkspace {
-	return &AppWorkspace{app: a}
-}
-
-// App returns the underlying app.App for callers that still need
-// direct access during the migration period.
-func (w *AppWorkspace) App() *app.App {
-	return w.app
-}
-
-// -- Sessions --
-
-func (w *AppWorkspace) CreateSession(ctx context.Context, title string) (session.Session, error) {
-	return w.app.Sessions.Create(ctx, title)
-}
-
-func (w *AppWorkspace) GetSession(ctx context.Context, sessionID string) (session.Session, error) {
-	return w.app.Sessions.Get(ctx, sessionID)
-}
-
-func (w *AppWorkspace) ListSessions(ctx context.Context) ([]session.Session, error) {
-	return w.app.Sessions.List(ctx)
-}
-
-func (w *AppWorkspace) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) {
-	return w.app.Sessions.Save(ctx, sess)
-}
-
-func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) error {
-	return w.app.Sessions.Delete(ctx, sessionID)
-}
-
-func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
-	return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID)
-}
-
-func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
-	return w.app.Sessions.ParseAgentToolSessionID(sessionID)
-}
-
-// -- Messages --
-
-func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
-	return w.app.Messages.List(ctx, sessionID)
-}
-
-func (w *AppWorkspace) ListUserMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
-	return w.app.Messages.ListUserMessages(ctx, sessionID)
-}
-
-func (w *AppWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Message, error) {
-	return w.app.Messages.ListAllUserMessages(ctx)
-}
-
-// -- Agent --
-
-func (w *AppWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error {
-	if w.app.AgentCoordinator == nil {
-		return nil
-	}
-	_, err := w.app.AgentCoordinator.Run(ctx, sessionID, prompt, attachments...)
-	return err
-}
-
-func (w *AppWorkspace) AgentCancel(sessionID string) {
-	if w.app.AgentCoordinator != nil {
-		w.app.AgentCoordinator.Cancel(sessionID)
-	}
-}
-
-func (w *AppWorkspace) AgentIsBusy() bool {
-	if w.app.AgentCoordinator == nil {
-		return false
-	}
-	return w.app.AgentCoordinator.IsBusy()
-}
-
-func (w *AppWorkspace) AgentIsSessionBusy(sessionID string) bool {
-	if w.app.AgentCoordinator == nil {
-		return false
-	}
-	return w.app.AgentCoordinator.IsSessionBusy(sessionID)
-}
-
-func (w *AppWorkspace) AgentModel() AgentModel {
-	if w.app.AgentCoordinator == nil {
-		return AgentModel{}
-	}
-	m := w.app.AgentCoordinator.Model()
-	return AgentModel{
-		CatwalkCfg: m.CatwalkCfg,
-		ModelCfg:   m.ModelCfg,
-	}
-}
-
-func (w *AppWorkspace) AgentIsReady() bool {
-	return w.app.AgentCoordinator != nil
-}
-
-func (w *AppWorkspace) AgentQueuedPrompts(sessionID string) int {
-	if w.app.AgentCoordinator == nil {
-		return 0
-	}
-	return w.app.AgentCoordinator.QueuedPrompts(sessionID)
-}
-
-func (w *AppWorkspace) AgentQueuedPromptsList(sessionID string) []string {
-	if w.app.AgentCoordinator == nil {
-		return nil
-	}
-	return w.app.AgentCoordinator.QueuedPromptsList(sessionID)
-}
-
-func (w *AppWorkspace) AgentClearQueue(sessionID string) {
-	if w.app.AgentCoordinator != nil {
-		w.app.AgentCoordinator.ClearQueue(sessionID)
-	}
-}
-
-func (w *AppWorkspace) AgentSummarize(ctx context.Context, sessionID string) error {
-	if w.app.AgentCoordinator == nil {
-		return nil
-	}
-	return w.app.AgentCoordinator.Summarize(ctx, sessionID)
-}
-
-func (w *AppWorkspace) UpdateAgentModel(ctx context.Context) error {
-	return w.app.UpdateAgentModel(ctx)
-}
-
-func (w *AppWorkspace) InitCoderAgent(ctx context.Context) error {
-	return w.app.InitCoderAgent(ctx)
-}
-
-func (w *AppWorkspace) GetDefaultSmallModel(providerID string) config.SelectedModel {
-	return w.app.GetDefaultSmallModel(providerID)
-}
-
-// -- Permissions --
-
-func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) {
-	w.app.Permissions.Grant(perm)
-}
-
-func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) {
-	w.app.Permissions.GrantPersistent(perm)
-}
-
-func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) {
-	w.app.Permissions.Deny(perm)
-}
-
-func (w *AppWorkspace) PermissionSkipRequests() bool {
-	return w.app.Permissions.SkipRequests()
-}
-
-func (w *AppWorkspace) PermissionSetSkipRequests(skip bool) {
-	w.app.Permissions.SetSkipRequests(skip)
-}
-
-// -- FileTracker --
-
-func (w *AppWorkspace) FileTrackerRecordRead(ctx context.Context, sessionID, path string) {
-	w.app.FileTracker.RecordRead(ctx, sessionID, path)
-}
-
-func (w *AppWorkspace) FileTrackerLastReadTime(ctx context.Context, sessionID, path string) time.Time {
-	return w.app.FileTracker.LastReadTime(ctx, sessionID, path)
-}
-
-func (w *AppWorkspace) FileTrackerListReadFiles(ctx context.Context, sessionID string) ([]string, error) {
-	return w.app.FileTracker.ListReadFiles(ctx, sessionID)
-}
-
-// -- History --
-
-func (w *AppWorkspace) ListSessionHistory(ctx context.Context, sessionID string) ([]history.File, error) {
-	return w.app.History.ListBySession(ctx, sessionID)
-}
-
-// -- LSP --
-
-func (w *AppWorkspace) LSPStart(ctx context.Context, path string) {
-	w.app.LSPManager.Start(ctx, path)
-}
-
-func (w *AppWorkspace) LSPStopAll(ctx context.Context) {
-	w.app.LSPManager.StopAll(ctx)
-}
-
-func (w *AppWorkspace) LSPGetStates() map[string]LSPClientInfo {
-	states := app.GetLSPStates()
-	result := make(map[string]LSPClientInfo, len(states))
-	for k, v := range states {
-		result[k] = LSPClientInfo{
-			Name:            v.Name,
-			State:           v.State,
-			Error:           v.Error,
-			DiagnosticCount: v.DiagnosticCount,
-			ConnectedAt:     v.ConnectedAt,
-		}
-	}
-	return result
-}
-
-func (w *AppWorkspace) LSPGetClient(name string) (*lsp.Client, bool) {
-	info, ok := app.GetLSPState(name)
-	if !ok {
-		return nil, false
-	}
-	return info.Client, true
-}
-
-// -- Config (read-only) --
-
-func (w *AppWorkspace) Config() *config.Config {
-	return w.app.Config()
-}
-
-func (w *AppWorkspace) WorkingDir() string {
-	return w.app.Store().WorkingDir()
-}
-
-func (w *AppWorkspace) Resolver() config.VariableResolver {
-	return w.app.Store().Resolver()
-}
-
-// -- Config mutations --
-
-func (w *AppWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error {
-	return w.app.Store().UpdatePreferredModel(scope, modelType, model)
-}
-
-func (w *AppWorkspace) SetCompactMode(scope config.Scope, enabled bool) error {
-	return w.app.Store().SetCompactMode(scope, enabled)
-}
-
-func (w *AppWorkspace) SetProviderAPIKey(scope config.Scope, providerID string, apiKey any) error {
-	return w.app.Store().SetProviderAPIKey(scope, providerID, apiKey)
-}
-
-func (w *AppWorkspace) SetConfigField(scope config.Scope, key string, value any) error {
-	return w.app.Store().SetConfigField(scope, key, value)
-}
-
-func (w *AppWorkspace) RemoveConfigField(scope config.Scope, key string) error {
-	return w.app.Store().RemoveConfigField(scope, key)
-}
-
-func (w *AppWorkspace) ImportCopilot() (*oauth.Token, bool) {
-	return w.app.Store().ImportCopilot()
-}
-
-func (w *AppWorkspace) RefreshOAuthToken(ctx context.Context, scope config.Scope, providerID string) error {
-	return w.app.Store().RefreshOAuthToken(ctx, scope, providerID)
-}
-
-// -- Project lifecycle --
-
-func (w *AppWorkspace) ProjectNeedsInitialization() (bool, error) {
-	return config.ProjectNeedsInitialization(w.app.Store())
-}
-
-func (w *AppWorkspace) MarkProjectInitialized() error {
-	return config.MarkProjectInitialized(w.app.Store())
-}
-
-func (w *AppWorkspace) InitializePrompt() (string, error) {
-	return agent.InitializePrompt(w.app.Store())
-}
-
-// -- MCP operations --
-
-func (w *AppWorkspace) MCPGetStates() map[string]mcptools.ClientInfo {
-	return mcptools.GetStates()
-}
-
-func (w *AppWorkspace) MCPRefreshPrompts(ctx context.Context, name string) {
-	mcptools.RefreshPrompts(ctx, name)
-}
-
-func (w *AppWorkspace) MCPRefreshResources(ctx context.Context, name string) {
-	mcptools.RefreshResources(ctx, name)
-}
-
-func (w *AppWorkspace) RefreshMCPTools(ctx context.Context, name string) {
-	mcptools.RefreshTools(ctx, w.app.Store(), name)
-}
-
-func (w *AppWorkspace) ReadMCPResource(ctx context.Context, name, uri string) ([]MCPResourceContents, error) {
-	contents, err := mcptools.ReadResource(ctx, w.app.Store(), name, uri)
-	if err != nil {
-		return nil, err
-	}
-	result := make([]MCPResourceContents, len(contents))
-	for i, c := range contents {
-		result[i] = MCPResourceContents{
-			URI:      c.URI,
-			MIMEType: c.MIMEType,
-			Text:     c.Text,
-			Blob:     c.Blob,
-		}
-	}
-	return result, nil
-}
-
-func (w *AppWorkspace) GetMCPPrompt(clientID, promptID string, args map[string]string) (string, error) {
-	return commands.GetMCPPrompt(w.app.Store(), clientID, promptID, args)
-}
-
-// -- Lifecycle --
-
-func (w *AppWorkspace) Subscribe(program *tea.Program) {
-	defer log.RecoverPanic("AppWorkspace.Subscribe", func() {
-		slog.Info("TUI subscription panic: attempting graceful shutdown")
-		program.Quit()
-	})
-
-	for msg := range w.app.Events() {
-		switch ev := msg.(type) {
-		case pubsub.Event[app.LSPEvent]:
-			program.Send(pubsub.Event[LSPEvent]{
-				Type: ev.Type,
-				Payload: LSPEvent{
-					Type:            LSPEventType(ev.Payload.Type),
-					Name:            ev.Payload.Name,
-					State:           ev.Payload.State,
-					Error:           ev.Payload.Error,
-					DiagnosticCount: ev.Payload.DiagnosticCount,
-				},
-			})
-		default:
-			program.Send(msg)
-		}
-	}
-}
-
-func (w *AppWorkspace) Shutdown() {
-	w.app.Shutdown()
-}