root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14	"time"
 15
 16	tea "charm.land/bubbletea/v2"
 17	"charm.land/lipgloss/v2"
 18	"github.com/charmbracelet/colorprofile"
 19	"github.com/charmbracelet/crush/internal/app"
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/db"
 22	"github.com/charmbracelet/crush/internal/event"
 23	"github.com/charmbracelet/crush/internal/stringext"
 24	termutil "github.com/charmbracelet/crush/internal/term"
 25	"github.com/charmbracelet/crush/internal/tui"
 26	tuiutil "github.com/charmbracelet/crush/internal/tui/util"
 27	"github.com/charmbracelet/crush/internal/update"
 28	"github.com/charmbracelet/crush/internal/version"
 29	"github.com/charmbracelet/fang"
 30	uv "github.com/charmbracelet/ultraviolet"
 31	"github.com/charmbracelet/x/ansi"
 32	"github.com/charmbracelet/x/exp/charmtone"
 33	"github.com/charmbracelet/x/term"
 34	"github.com/spf13/cobra"
 35)
 36
 37func init() {
 38	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 39	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 40	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 41	rootCmd.Flags().BoolP("help", "h", false, "Help")
 42	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 43
 44	rootCmd.AddCommand(
 45		runCmd,
 46		dirsCmd,
 47		updateCmd,
 48		updateProvidersCmd,
 49		logsCmd,
 50		schemaCmd,
 51		loginCmd,
 52	)
 53}
 54
 55var rootCmd = &cobra.Command{
 56	Use:   "crush",
 57	Short: "Terminal-based AI assistant for software development",
 58	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 59It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 60to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 61	Example: `
 62# Run in interactive mode
 63crush
 64
 65# Run with debug logging
 66crush -d
 67
 68# Run with debug logging in a specific directory
 69crush -d -c /path/to/project
 70
 71# Run with custom data directory
 72crush -D /path/to/custom/.crush
 73
 74# Print version
 75crush -v
 76
 77# Run a single non-interactive prompt
 78crush run "Explain the use of context in Go"
 79
 80# Run in dangerous mode (auto-accept all permissions)
 81crush -y
 82  `,
 83	RunE: func(cmd *cobra.Command, args []string) error {
 84		app, err := setupAppWithProgressBar(cmd)
 85		if err != nil {
 86			return err
 87		}
 88		defer app.Shutdown()
 89
 90		event.AppInitialized()
 91
 92		// Set up the TUI.
 93		var env uv.Environ = os.Environ()
 94		ui := tui.New(app)
 95		ui.QueryVersion = shouldQueryTerminalVersion(env)
 96
 97		program := tea.NewProgram(
 98			ui,
 99			tea.WithEnvironment(env),
100			tea.WithContext(cmd.Context()),
101			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
102		go app.Subscribe(program)
103
104		// Start async update check unless disabled.
105		go checkForUpdateAsync(cmd.Context(), program)
106
107		if _, err := program.Run(); err != nil {
108			event.Error(err)
109			slog.Error("TUI run error", "error", err)
110			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
111		}
112		return nil
113	},
114	PostRun: func(cmd *cobra.Command, args []string) {
115		event.AppExited()
116	},
117}
118
119var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
120    ▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄▄▄
121  ███████████  ███████████
122████████████████████████████
123████████████████████████████
124██████████▀██████▀██████████
125██████████ ██████ ██████████
126▀▀██████▄████▄▄████▄██████▀▀
127  ████████████████████████
128    ████████████████████
129       ▀▀██████████▀▀
130           ▀▀▀▀▀▀
131`)
132
// defaultVersionTemplate is cobra's stock version template (copied from
// cobra): it prints "version <Version>", prefixed by the display name when
// one is set, and ends with a newline. Execute may prepend the logo and
// append an update notice to it.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}" +
	"{{printf \"version %s\" .Version}}\n"
136
137func Execute() {
138	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
139	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
140	// finally prepend it in the version template.
141	// Unfortunately cobra doesn't give us a way to set a function to handle
142	// printing the version, and PreRunE runs after the version is already
143	// handled, so that doesn't work either.
144	// This is the only way I could find that works relatively well.
145	versionTemplate := defaultVersionTemplate
146	if term.IsTerminal(os.Stdout.Fd()) {
147		var b bytes.Buffer
148		w := colorprofile.NewWriter(os.Stdout, os.Environ())
149		w.Forward = &b
150		_, _ = w.WriteString(heartbit.String())
151		versionTemplate = b.String() + "\n" + defaultVersionTemplate
152	}
153
154	// Check if version flag is present and add update notification if available.
155	if hasVersionFlag() {
156		if updateMsg := checkForUpdateSync(); updateMsg != "" {
157			versionTemplate += updateMsg
158		}
159	}
160
161	rootCmd.SetVersionTemplate(versionTemplate)
162
163	if err := fang.Execute(
164		context.Background(),
165		rootCmd,
166		fang.WithVersion(version.Version),
167		fang.WithNotifySignal(os.Interrupt),
168	); err != nil {
169		os.Exit(1)
170	}
171}
172
// hasVersionFlag reports whether -v or --version appears on the command line.
//
// It inspects os.Args directly because it runs before cobra parses flags.
// The program name (os.Args[0]) is skipped, and scanning stops at a bare "--":
// everything after the terminator is a positional argument, not a flag, so
// e.g. `crush run -- -v` must not trigger the blocking update check.
func hasVersionFlag() bool {
	if len(os.Args) < 2 {
		return false
	}
	for _, arg := range os.Args[1:] {
		switch arg {
		case "--":
			return false
		case "-v", "--version":
			return true
		}
	}
	return false
}
182
// isAutoUpdateDisabled reports whether CRUSH_DISABLE_AUTO_UPDATE holds a true
// value as understood by strconv.ParseBool. Unset, empty, or unparsable
// values all count as "not disabled".
// Config is not loaded at this point (called before Execute), so only the
// environment variable is consulted.
func isAutoUpdateDisabled() bool {
	disabled, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_AUTO_UPDATE"))
	return disabled
}
192
193// checkForUpdateSync performs a synchronous update check with a short timeout.
194// Returns a formatted update message if an update is available, empty string otherwise.
195func checkForUpdateSync() string {
196	if isAutoUpdateDisabled() {
197		return ""
198	}
199
200	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
201	defer cancel()
202
203	info, err := update.Check(ctx, version.Version, update.Default)
204	if err != nil || !info.Available() {
205		return ""
206	}
207
208	if info.IsDevelopment() {
209		return fmt.Sprintf("\nThis is a development version of Crush. The latest stable release is v%s.\nRun 'crush update' to learn more.\n", info.Latest)
210	}
211
212	return fmt.Sprintf("\nUpdate available: v%s → v%s\nRun 'crush update apply' to install.\n", info.Current, info.Latest)
213}
214
215// checkForUpdateAsync checks for updates in the background and applies them if possible.
216func checkForUpdateAsync(ctx context.Context, program *tea.Program) {
217	// Check config (if loaded) or env var.
218	if isAutoUpdateDisabled() {
219		return
220	}
221	if cfg := config.Get(); cfg != nil && cfg.Options.DisableAutoUpdate {
222		return
223	}
224
225	checkCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
226	defer cancel()
227
228	info, err := update.Check(checkCtx, version.Version, update.Default)
229	if err != nil || !info.Available() || info.IsDevelopment() {
230		return
231	}
232
233	// Check install method.
234	method := update.DetectInstallMethod()
235	if !method.CanSelfUpdate() {
236		// Package manager install - show instructions.
237		program.Send(tuiutil.InfoMsg{
238			Type: tuiutil.InfoTypeUpdate,
239			Msg:  fmt.Sprintf("Update available: v%s → v%s. Run: %s", info.Current, info.Latest, method.UpdateInstructions()),
240			TTL:  30 * time.Second,
241		})
242		return
243	}
244
245	// Attempt self-update.
246	asset, err := update.FindAsset(info.Release.Assets)
247	if err != nil {
248		program.Send(tuiutil.InfoMsg{
249			Type: tuiutil.InfoTypeWarn,
250			Msg:  "Update available but failed to find asset. Run 'crush update' for details.",
251			TTL:  15 * time.Second,
252		})
253		return
254	}
255
256	binaryPath, err := update.Download(checkCtx, asset, info.Release)
257	if err != nil {
258		program.Send(tuiutil.InfoMsg{
259			Type: tuiutil.InfoTypeWarn,
260			Msg:  "Update download failed. Run 'crush update' for details.",
261			TTL:  15 * time.Second,
262		})
263		return
264	}
265	defer os.Remove(binaryPath)
266
267	if err := update.Apply(binaryPath); err != nil {
268		program.Send(tuiutil.InfoMsg{
269			Type: tuiutil.InfoTypeWarn,
270			Msg:  "Update failed to install. Run 'crush update' for details.",
271			TTL:  15 * time.Second,
272		})
273		return
274	}
275
276	// Success!
277	program.Send(tuiutil.InfoMsg{
278		Type: tuiutil.InfoTypeUpdate,
279		Msg:  fmt.Sprintf("Updated to v%s! Restart Crush to use the new version.", info.Latest),
280		TTL:  30 * time.Second,
281	})
282}
283
284func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
285	if termutil.SupportsProgressBar() {
286		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
287		defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
288	}
289
290	return setupApp(cmd)
291}
292
293// setupApp handles the common setup logic for both interactive and non-interactive modes.
294// It returns the app instance, config, cleanup function, and any error.
295func setupApp(cmd *cobra.Command) (*app.App, error) {
296	debug, _ := cmd.Flags().GetBool("debug")
297	yolo, _ := cmd.Flags().GetBool("yolo")
298	dataDir, _ := cmd.Flags().GetString("data-dir")
299	ctx := cmd.Context()
300
301	cwd, err := ResolveCwd(cmd)
302	if err != nil {
303		return nil, err
304	}
305
306	cfg, err := config.Init(cwd, dataDir, debug)
307	if err != nil {
308		return nil, err
309	}
310
311	if cfg.Permissions == nil {
312		cfg.Permissions = &config.Permissions{}
313	}
314	cfg.Permissions.SkipRequests = yolo
315
316	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
317		return nil, err
318	}
319
320	// Connect to DB; this will also run migrations.
321	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
322	if err != nil {
323		return nil, err
324	}
325
326	appInstance, err := app.New(ctx, conn, cfg)
327	if err != nil {
328		slog.Error("Failed to create app instance", "error", err)
329		return nil, err
330	}
331
332	if shouldEnableMetrics() {
333		event.Init()
334	}
335
336	return appInstance, nil
337}
338
339func shouldEnableMetrics() bool {
340	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
341		return false
342	}
343	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
344		return false
345	}
346	if config.Get().Options.DisableMetrics {
347		return false
348	}
349	return true
350}
351
352func MaybePrependStdin(prompt string) (string, error) {
353	if term.IsTerminal(os.Stdin.Fd()) {
354		return prompt, nil
355	}
356	fi, err := os.Stdin.Stat()
357	if err != nil {
358		return prompt, err
359	}
360	if fi.Mode()&os.ModeNamedPipe == 0 {
361		return prompt, nil
362	}
363	bts, err := io.ReadAll(os.Stdin)
364	if err != nil {
365		return prompt, err
366	}
367	return string(bts) + "\n\n" + prompt, nil
368}
369
370func ResolveCwd(cmd *cobra.Command) (string, error) {
371	cwd, _ := cmd.Flags().GetString("cwd")
372	if cwd != "" {
373		err := os.Chdir(cwd)
374		if err != nil {
375			return "", fmt.Errorf("failed to change directory: %v", err)
376		}
377		return cwd, nil
378	}
379	cwd, err := os.Getwd()
380	if err != nil {
381		return "", fmt.Errorf("failed to get current working directory: %v", err)
382	}
383	return cwd, nil
384}
385
// createDotCrushDir ensures the data directory dir exists (creating parents,
// mode 0o700) and contains a .gitignore that ignores everything ("*").
// An existing .gitignore is left untouched, and so is one whose presence
// cannot be determined. Only a definite "does not exist" triggers a write.
func createDotCrushDir(dir string) error {
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
	}

	ignoreFile := filepath.Join(dir, ".gitignore")
	if _, statErr := os.Stat(ignoreFile); !os.IsNotExist(statErr) {
		return nil
	}

	if err := os.WriteFile(ignoreFile, []byte("*\n"), 0o644); err != nil {
		return fmt.Errorf("failed to create .gitignore file: %q %w", ignoreFile, err)
	}
	return nil
}
400
401func shouldQueryTerminalVersion(env uv.Environ) bool {
402	termType := env.Getenv("TERM")
403	termProg, okTermProg := env.LookupEnv("TERM_PROGRAM")
404	_, okSSHTTY := env.LookupEnv("SSH_TTY")
405	return (!okTermProg && !okSSHTTY) ||
406		(!strings.Contains(termProg, "Apple") && !okSSHTTY) ||
407		// Terminals that do support XTVERSION.
408		stringext.ContainsAny(termType, "alacritty", "ghostty", "kitty", "rio", "wezterm")
409}