root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"io/fs"
 10	"log/slog"
 11	"net/url"
 12	"os"
 13	"os/exec"
 14	"path/filepath"
 15	"regexp"
 16	"strconv"
 17	"strings"
 18	"time"
 19
 20	tea "charm.land/bubbletea/v2"
 21	fang "charm.land/fang/v2"
 22	"charm.land/lipgloss/v2"
 23	"github.com/charmbracelet/colorprofile"
 24	"github.com/charmbracelet/crush/internal/app"
 25	"github.com/charmbracelet/crush/internal/client"
 26	"github.com/charmbracelet/crush/internal/config"
 27	"github.com/charmbracelet/crush/internal/db"
 28	"github.com/charmbracelet/crush/internal/event"
 29	crushlog "github.com/charmbracelet/crush/internal/log"
 30	"github.com/charmbracelet/crush/internal/projects"
 31	"github.com/charmbracelet/crush/internal/proto"
 32	"github.com/charmbracelet/crush/internal/server"
 33	"github.com/charmbracelet/crush/internal/session"
 34	"github.com/charmbracelet/crush/internal/ui/common"
 35	ui "github.com/charmbracelet/crush/internal/ui/model"
 36	"github.com/charmbracelet/crush/internal/version"
 37	"github.com/charmbracelet/crush/internal/workspace"
 38	uv "github.com/charmbracelet/ultraviolet"
 39	"github.com/charmbracelet/x/ansi"
 40	"github.com/charmbracelet/x/exp/charmtone"
 41	"github.com/charmbracelet/x/term"
 42	"github.com/spf13/cobra"
 43)
 44
// clientHost holds the value of the persistent --host/-H flag (bound in
// init, defaulting to server.DefaultHost()). It is read when parsing the
// server host URL and when spawning a detached server.
var clientHost string
 46
 47func init() {
 48	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 49	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 50	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 51	rootCmd.PersistentFlags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
 52	rootCmd.Flags().BoolP("help", "h", false, "Help")
 53	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 54	rootCmd.Flags().StringP("session", "s", "", "Continue a previous session by ID")
 55	rootCmd.Flags().BoolP("continue", "C", false, "Continue the most recent session")
 56	rootCmd.MarkFlagsMutuallyExclusive("session", "continue")
 57
 58	rootCmd.AddCommand(
 59		runCmd,
 60		dirsCmd,
 61		projectsCmd,
 62		updateProvidersCmd,
 63		logsCmd,
 64		schemaCmd,
 65		loginCmd,
 66		statsCmd,
 67		sessionCmd,
 68	)
 69}
 70
 71var rootCmd = &cobra.Command{
 72	Use:   "crush",
 73	Short: "A terminal-first AI assistant for software development",
 74	Long:  "A glamorous, terminal-first AI assistant for software development and adjacent tasks",
 75	Example: `
 76# Run in interactive mode
 77crush
 78
 79# Run non-interactively
 80crush run "Guess my 5 favorite PokΓ©mon"
 81
 82# Run a non-interactively with pipes and redirection
 83cat README.md | crush run "make this more glamorous" > GLAMOROUS_README.md
 84
 85# Run with debug logging in a specific directory
 86crush --debug --cwd /path/to/project
 87
 88# Run in yolo mode (auto-accept all permissions; use with care)
 89crush --yolo
 90
 91# Run with custom data directory
 92crush --data-dir /path/to/custom/.crush
 93
 94# Continue a previous session
 95crush --session {session-id}
 96
 97# Continue the most recent session
 98crush --continue
 99  `,
100	RunE: func(cmd *cobra.Command, args []string) error {
101		sessionID, _ := cmd.Flags().GetString("session")
102		continueLast, _ := cmd.Flags().GetBool("continue")
103
104		ws, cleanup, err := setupWorkspaceWithProgressBar(cmd)
105		if err != nil {
106			return err
107		}
108		defer cleanup()
109
110		if sessionID != "" {
111			sess, err := resolveWorkspaceSessionID(cmd.Context(), ws, sessionID)
112			if err != nil {
113				return err
114			}
115			sessionID = sess.ID
116		}
117
118		event.AppInitialized()
119
120		com := common.DefaultCommon(ws)
121		model := ui.New(com, sessionID, continueLast)
122
123		var env uv.Environ = os.Environ()
124		program := tea.NewProgram(
125			model,
126			tea.WithEnvironment(env),
127			tea.WithContext(cmd.Context()),
128			tea.WithFilter(ui.MouseEventFilter),
129		)
130		go ws.Subscribe(program)
131
132		if _, err := program.Run(); err != nil {
133			event.Error(err)
134			slog.Error("TUI run error", "error", err)
135			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
136		}
137		return nil
138	},
139}
140
141var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
142    β–„β–„β–„β–„β–„β–„β–„β–„    β–„β–„β–„β–„β–„β–„β–„β–„
143  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
144β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
145β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
146β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
147β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
148β–€β–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–„β–ˆβ–ˆβ–ˆβ–ˆβ–„β–„β–ˆβ–ˆβ–ˆβ–ˆβ–„β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–€
149  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
150    β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
151       β–€β–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–€
152           β–€β–€β–€β–€β–€β–€
153`)
154
// defaultVersionTemplate is the version template copied from cobra. It is
// kept here so Execute can prepend the heartbit art to it and install the
// result via SetVersionTemplate.
const defaultVersionTemplate = `{{with .DisplayName}}{{printf "%s " .}}{{end}}{{printf "version %s" .Version}}
`
158
159func Execute() {
160	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
161	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
162	// finally prepend it in the version template.
163	// Unfortunately cobra doesn't give us a way to set a function to handle
164	// printing the version, and PreRunE runs after the version is already
165	// handled, so that doesn't work either.
166	// This is the only way I could find that works relatively well.
167	if term.IsTerminal(os.Stdout.Fd()) {
168		var b bytes.Buffer
169		w := colorprofile.NewWriter(os.Stdout, os.Environ())
170		w.Forward = &b
171		_, _ = w.WriteString(heartbit.String())
172		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
173	}
174	if err := fang.Execute(
175		context.Background(),
176		rootCmd,
177		fang.WithVersion(version.Version),
178		fang.WithNotifySignal(os.Interrupt),
179	); err != nil {
180		os.Exit(1)
181	}
182}
183
184// supportsProgressBar tries to determine whether the current terminal supports
185// progress bars by looking into environment variables.
186func supportsProgressBar() bool {
187	if !term.IsTerminal(os.Stderr.Fd()) {
188		return false
189	}
190	termProg := os.Getenv("TERM_PROGRAM")
191	_, isWindowsTerminal := os.LookupEnv("WT_SESSION")
192
193	return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
194}
195
// useClientServer reports whether the client/server architecture has been
// switched on through the CRUSH_CLIENT_SERVER environment variable. Any
// value strconv.ParseBool does not accept as true (including an unset or
// malformed value) yields false.
func useClientServer() bool {
	enabled, err := strconv.ParseBool(os.Getenv("CRUSH_CLIENT_SERVER"))
	return err == nil && enabled
}
202
203// setupWorkspaceWithProgressBar wraps setupWorkspace with an optional
204// terminal progress bar shown during initialization.
205func setupWorkspaceWithProgressBar(cmd *cobra.Command) (workspace.Workspace, func(), error) {
206	showProgress := supportsProgressBar()
207	if showProgress {
208		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
209	}
210
211	ws, cleanup, err := setupWorkspace(cmd)
212
213	if showProgress {
214		_, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
215	}
216
217	return ws, cleanup, err
218}
219
220// setupWorkspace returns a Workspace and cleanup function. When
221// CRUSH_CLIENT_SERVER=1, it connects to a server process and returns a
222// ClientWorkspace. Otherwise it creates an in-process app.App and
223// returns an AppWorkspace.
224func setupWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error) {
225	if useClientServer() {
226		return setupClientServerWorkspace(cmd)
227	}
228	return setupLocalWorkspace(cmd)
229}
230
231// setupLocalWorkspace creates an in-process app.App and wraps it in an
232// AppWorkspace.
233func setupLocalWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error) {
234	debug, _ := cmd.Flags().GetBool("debug")
235	yolo, _ := cmd.Flags().GetBool("yolo")
236	dataDir, _ := cmd.Flags().GetString("data-dir")
237	ctx := cmd.Context()
238
239	cwd, err := ResolveCwd(cmd)
240	if err != nil {
241		return nil, nil, err
242	}
243
244	store, err := config.Init(cwd, dataDir, debug)
245	if err != nil {
246		return nil, nil, err
247	}
248
249	cfg := store.Config()
250	store.Overrides().SkipPermissionRequests = yolo
251
252	if err := os.MkdirAll(cfg.Options.DataDirectory, 0o700); err != nil {
253		return nil, nil, fmt.Errorf("failed to create data directory: %q %w", cfg.Options.DataDirectory, err)
254	}
255
256	gitIgnorePath := filepath.Join(cfg.Options.DataDirectory, ".gitignore")
257	if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) {
258		if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
259			return nil, nil, fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
260		}
261	}
262
263	if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
264		slog.Warn("Failed to register project", "error", err)
265	}
266
267	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
268	if err != nil {
269		return nil, nil, err
270	}
271
272	logFile := filepath.Join(cfg.Options.DataDirectory, "logs", "crush.log")
273	crushlog.Setup(logFile, debug)
274
275	appInstance, err := app.New(ctx, conn, store)
276	if err != nil {
277		_ = conn.Close()
278		slog.Error("Failed to create app instance", "error", err)
279		return nil, nil, err
280	}
281
282	if shouldEnableMetrics(cfg) {
283		event.Init()
284	}
285
286	ws := workspace.NewAppWorkspace(appInstance, store)
287	cleanup := func() { appInstance.Shutdown() }
288	return ws, cleanup, nil
289}
290
291// setupClientServerWorkspace connects to a server process and wraps the
292// result in a ClientWorkspace.
293func setupClientServerWorkspace(cmd *cobra.Command) (workspace.Workspace, func(), error) {
294	c, protoWs, cleanupServer, err := connectToServer(cmd)
295	if err != nil {
296		return nil, nil, err
297	}
298
299	clientWs := workspace.NewClientWorkspace(c, *protoWs)
300
301	if protoWs.Config.IsConfigured() {
302		if err := clientWs.InitCoderAgent(cmd.Context()); err != nil {
303			slog.Error("Failed to initialize coder agent", "error", err)
304		}
305	}
306
307	return clientWs, cleanupServer, nil
308}
309
310// connectToServer ensures the server is running, creates a client and
311// workspace, and returns a cleanup function that deletes the workspace.
312func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) {
313	hostURL, err := server.ParseHostURL(clientHost)
314	if err != nil {
315		return nil, nil, nil, fmt.Errorf("invalid host URL: %v", err)
316	}
317
318	if err := ensureServer(cmd, hostURL); err != nil {
319		return nil, nil, nil, err
320	}
321
322	debug, _ := cmd.Flags().GetBool("debug")
323	yolo, _ := cmd.Flags().GetBool("yolo")
324	dataDir, _ := cmd.Flags().GetString("data-dir")
325	ctx := cmd.Context()
326
327	cwd, err := ResolveCwd(cmd)
328	if err != nil {
329		return nil, nil, nil, err
330	}
331
332	c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
333	if err != nil {
334		return nil, nil, nil, err
335	}
336
337	wsReq := proto.Workspace{
338		Path:    cwd,
339		DataDir: dataDir,
340		Debug:   debug,
341		YOLO:    yolo,
342		Version: version.Version,
343		Env:     os.Environ(),
344	}
345
346	ws, err := c.CreateWorkspace(ctx, wsReq)
347	if err != nil {
348		// The server socket may exist before the HTTP handler is ready.
349		// Retry a few times with a short backoff.
350		for range 5 {
351			select {
352			case <-ctx.Done():
353				return nil, nil, nil, ctx.Err()
354			case <-time.After(200 * time.Millisecond):
355			}
356			ws, err = c.CreateWorkspace(ctx, wsReq)
357			if err == nil {
358				break
359			}
360		}
361		if err != nil {
362			return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
363		}
364	}
365
366	if shouldEnableMetrics(ws.Config) {
367		event.Init()
368	}
369
370	if ws.Config != nil {
371		logFile := filepath.Join(ws.Config.Options.DataDirectory, "logs", "crush.log")
372		crushlog.Setup(logFile, debug)
373	}
374
375	cleanup := func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }
376	return c, ws, cleanup, nil
377}
378
379// ensureServer auto-starts a detached server if the socket file does not
380// exist. When the socket exists, it verifies that the running server
381// version matches the client; on mismatch it shuts down the old server
382// and starts a fresh one.
383func ensureServer(cmd *cobra.Command, hostURL *url.URL) error {
384	switch hostURL.Scheme {
385	case "unix", "npipe":
386		needsStart := false
387		if _, err := os.Stat(hostURL.Host); err != nil && errors.Is(err, fs.ErrNotExist) {
388			needsStart = true
389		} else if err == nil {
390			if err := restartIfStale(cmd, hostURL); err != nil {
391				slog.Warn("Failed to check server version, restarting", "error", err)
392				needsStart = true
393			}
394		}
395
396		if needsStart {
397			if err := startDetachedServer(cmd); err != nil {
398				return err
399			}
400		}
401
402		var err error
403		for range 10 {
404			_, err = os.Stat(hostURL.Host)
405			if err == nil {
406				break
407			}
408			select {
409			case <-cmd.Context().Done():
410				return cmd.Context().Err()
411			case <-time.After(100 * time.Millisecond):
412			}
413		}
414		if err != nil {
415			return fmt.Errorf("failed to initialize crush server: %v", err)
416		}
417	}
418
419	return nil
420}
421
422// restartIfStale checks whether the running server matches the current
423// client version. When they differ, it sends a shutdown command and
424// removes the stale socket so the caller can start a fresh server.
425func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error {
426	c, err := client.NewClient("", hostURL.Scheme, hostURL.Host)
427	if err != nil {
428		return err
429	}
430	vi, err := c.VersionInfo(cmd.Context())
431	if err != nil {
432		return err
433	}
434	if vi.Version == version.Version {
435		return nil
436	}
437	slog.Info("Server version mismatch, restarting",
438		"server", vi.Version,
439		"client", version.Version,
440	)
441	_ = c.ShutdownServer(cmd.Context())
442	// Give the old process a moment to release the socket.
443	for range 20 {
444		if _, err := os.Stat(hostURL.Host); errors.Is(err, fs.ErrNotExist) {
445			break
446		}
447		select {
448		case <-cmd.Context().Done():
449			return cmd.Context().Err()
450		case <-time.After(100 * time.Millisecond):
451		}
452	}
453	// Force-remove if the socket is still lingering.
454	_ = os.Remove(hostURL.Host)
455	return nil
456}
457
// safeNameRegexp matches any single character other than ASCII letters,
// digits, '.', '_' and '-'. startDetachedServer replaces each match with
// "_" to turn the host string into a directory name.
var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)
459
460func startDetachedServer(cmd *cobra.Command) error {
461	exe, err := os.Executable()
462	if err != nil {
463		return fmt.Errorf("failed to get executable path: %v", err)
464	}
465
466	safeClientHost := safeNameRegexp.ReplaceAllString(clientHost, "_")
467	chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeClientHost)
468	if err := os.MkdirAll(chDir, 0o700); err != nil {
469		return fmt.Errorf("failed to create server working directory: %v", err)
470	}
471
472	cmdArgs := []string{"server"}
473	if clientHost != server.DefaultHost() {
474		cmdArgs = append(cmdArgs, "--host", clientHost)
475	}
476
477	c := exec.CommandContext(cmd.Context(), exe, cmdArgs...)
478	stdoutPath := filepath.Join(chDir, "stdout.log")
479	stderrPath := filepath.Join(chDir, "stderr.log")
480	detachProcess(c)
481
482	stdout, err := os.Create(stdoutPath)
483	if err != nil {
484		return fmt.Errorf("failed to create stdout log file: %v", err)
485	}
486	defer stdout.Close()
487	c.Stdout = stdout
488
489	stderr, err := os.Create(stderrPath)
490	if err != nil {
491		return fmt.Errorf("failed to create stderr log file: %v", err)
492	}
493	defer stderr.Close()
494	c.Stderr = stderr
495
496	if err := c.Start(); err != nil {
497		return fmt.Errorf("failed to start crush server: %v", err)
498	}
499
500	if err := c.Process.Release(); err != nil {
501		return fmt.Errorf("failed to detach crush server process: %v", err)
502	}
503
504	return nil
505}
506
507func shouldEnableMetrics(cfg *config.Config) bool {
508	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
509		return false
510	}
511	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
512		return false
513	}
514	if cfg.Options.DisableMetrics {
515		return false
516	}
517	return true
518}
519
520func MaybePrependStdin(prompt string) (string, error) {
521	if term.IsTerminal(os.Stdin.Fd()) {
522		return prompt, nil
523	}
524	fi, err := os.Stdin.Stat()
525	if err != nil {
526		return prompt, err
527	}
528	// Check if stdin is a named pipe ( | ) or regular file ( < ).
529	if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
530		return prompt, nil
531	}
532	bts, err := io.ReadAll(os.Stdin)
533	if err != nil {
534		return prompt, err
535	}
536	return string(bts) + "\n\n" + prompt, nil
537}
538
539// resolveWorkspaceSessionID resolves a session ID that may be a full
540// UUID, full hash, or hash prefix. Works against the Workspace
541// interface so both local and client/server paths get hash prefix
542// support.
543func resolveWorkspaceSessionID(ctx context.Context, ws workspace.Workspace, id string) (session.Session, error) {
544	if sess, err := ws.GetSession(ctx, id); err == nil {
545		return sess, nil
546	}
547
548	sessions, err := ws.ListSessions(ctx)
549	if err != nil {
550		return session.Session{}, err
551	}
552
553	var matches []session.Session
554	for _, s := range sessions {
555		hash := session.HashID(s.ID)
556		if hash == id || strings.HasPrefix(hash, id) {
557			matches = append(matches, s)
558		}
559	}
560
561	switch len(matches) {
562	case 0:
563		return session.Session{}, fmt.Errorf("session not found: %s", id)
564	case 1:
565		return matches[0], nil
566	default:
567		return session.Session{}, fmt.Errorf("session ID %q is ambiguous (%d matches)", id, len(matches))
568	}
569}
570
571func ResolveCwd(cmd *cobra.Command) (string, error) {
572	cwd, _ := cmd.Flags().GetString("cwd")
573	if cwd != "" {
574		err := os.Chdir(cwd)
575		if err != nil {
576			return "", fmt.Errorf("failed to change directory: %v", err)
577		}
578		return cwd, nil
579	}
580	cwd, err := os.Getwd()
581	if err != nil {
582		return "", fmt.Errorf("failed to get current working directory: %v", err)
583	}
584	return cwd, nil
585}