root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14
 15	tea "charm.land/bubbletea/v2"
 16	"charm.land/lipgloss/v2"
 17	"github.com/charmbracelet/colorprofile"
 18	"github.com/charmbracelet/crush/internal/app"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/db"
 21	"github.com/charmbracelet/crush/internal/event"
 22	"github.com/charmbracelet/crush/internal/stringext"
 23	termutil "github.com/charmbracelet/crush/internal/term"
 24	"github.com/charmbracelet/crush/internal/tui"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/charmbracelet/fang"
 27	uv "github.com/charmbracelet/ultraviolet"
 28	"github.com/charmbracelet/x/ansi"
 29	"github.com/charmbracelet/x/exp/charmtone"
 30	"github.com/charmbracelet/x/term"
 31	"github.com/spf13/cobra"
 32)
 33
 34func init() {
 35	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 36	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 37	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 38	rootCmd.Flags().BoolP("help", "h", false, "Help")
 39	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 40
 41	rootCmd.AddCommand(
 42		runCmd,
 43		dirsCmd,
 44		updateProvidersCmd,
 45		logsCmd,
 46		schemaCmd,
 47		loginCmd,
 48	)
 49}
 50
 51var rootCmd = &cobra.Command{
 52	Use:   "crush",
 53	Short: "Terminal-based AI assistant for software development",
 54	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 55It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 56to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 57	Example: `
 58# Run in interactive mode
 59crush
 60
 61# Run with debug logging
 62crush -d
 63
 64# Run with debug logging in a specific directory
 65crush -d -c /path/to/project
 66
 67# Run with custom data directory
 68crush -D /path/to/custom/.crush
 69
 70# Print version
 71crush -v
 72
 73# Run a single non-interactive prompt
 74crush run "Explain the use of context in Go"
 75
 76# Run in dangerous mode (auto-accept all permissions)
 77crush -y
 78  `,
 79	RunE: func(cmd *cobra.Command, args []string) error {
 80		app, err := setupAppWithProgressBar(cmd)
 81		if err != nil {
 82			return err
 83		}
 84		defer app.Shutdown()
 85
 86		event.AppInitialized()
 87
 88		// Set up the TUI.
 89		var env uv.Environ = os.Environ()
 90		ui := tui.New(app)
 91		ui.QueryVersion = shouldQueryTerminalVersion(env)
 92
 93		program := tea.NewProgram(
 94			ui,
 95			tea.WithEnvironment(env),
 96			tea.WithContext(cmd.Context()),
 97			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
 98		go app.Subscribe(program)
 99
100		if _, err := program.Run(); err != nil {
101			event.Error(err)
102			slog.Error("TUI run error", "error", err)
103			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
104		}
105		return nil
106	},
107	PostRun: func(cmd *cobra.Command, args []string) {
108		event.AppExited()
109	},
110}
111
// heartbit is the heart logo, styled with the charmtone Dolly foreground
// color. Execute renders it through a color-profile writer and prepends the
// result to the version template when stdout is a terminal.
var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
    ▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄▄▄
  ███████████  ███████████
████████████████████████████
████████████████████████████
██████████▀██████▀██████████
██████████ ██████ ██████████
▀▀██████▄████▄▄████▄██████▀▀
  ████████████████████████
    ████████████████████
       ▀▀██████████▀▀
           ▀▀▀▀▀▀
`)
125
// defaultVersionTemplate is a copy of cobra's built-in version template.
// Execute appends it after the colored heartbit. It renders
// "<DisplayName> version <Version>" (the display name is omitted when empty),
// followed by a newline.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}" +
	"{{printf \"version %s\" .Version}}\n"
129
130func Execute() {
131	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
132	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
133	// finally prepend it in the version template.
134	// Unfortunately cobra doesn't give us a way to set a function to handle
135	// printing the version, and PreRunE runs after the version is already
136	// handled, so that doesn't work either.
137	// This is the only way I could find that works relatively well.
138	if term.IsTerminal(os.Stdout.Fd()) {
139		var b bytes.Buffer
140		w := colorprofile.NewWriter(os.Stdout, os.Environ())
141		w.Forward = &b
142		_, _ = w.WriteString(heartbit.String())
143		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
144	}
145	if err := fang.Execute(
146		context.Background(),
147		rootCmd,
148		fang.WithVersion(version.Version),
149		fang.WithNotifySignal(os.Interrupt),
150	); err != nil {
151		os.Exit(1)
152	}
153}
154
155func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
156	if termutil.SupportsProgressBar() {
157		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
158		defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
159	}
160
161	return setupApp(cmd)
162}
163
164// setupApp handles the common setup logic for both interactive and non-interactive modes.
165// It returns the app instance, config, cleanup function, and any error.
166func setupApp(cmd *cobra.Command) (*app.App, error) {
167	debug, _ := cmd.Flags().GetBool("debug")
168	yolo, _ := cmd.Flags().GetBool("yolo")
169	dataDir, _ := cmd.Flags().GetString("data-dir")
170	ctx := cmd.Context()
171
172	cwd, err := ResolveCwd(cmd)
173	if err != nil {
174		return nil, err
175	}
176
177	cfg, err := config.Init(cwd, dataDir, debug)
178	if err != nil {
179		return nil, err
180	}
181
182	if cfg.Permissions == nil {
183		cfg.Permissions = &config.Permissions{}
184	}
185	cfg.Permissions.SkipRequests = yolo
186
187	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
188		return nil, err
189	}
190
191	// Connect to DB; this will also run migrations.
192	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
193	if err != nil {
194		return nil, err
195	}
196
197	appInstance, err := app.New(ctx, conn, cfg)
198	if err != nil {
199		slog.Error("Failed to create app instance", "error", err)
200		return nil, err
201	}
202
203	if shouldEnableMetrics() {
204		event.Init()
205	}
206
207	return appInstance, nil
208}
209
210func shouldEnableMetrics() bool {
211	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
212		return false
213	}
214	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
215		return false
216	}
217	if config.Get().Options.DisableMetrics {
218		return false
219	}
220	return true
221}
222
223func MaybePrependStdin(prompt string) (string, error) {
224	if term.IsTerminal(os.Stdin.Fd()) {
225		return prompt, nil
226	}
227	fi, err := os.Stdin.Stat()
228	if err != nil {
229		return prompt, err
230	}
231	if fi.Mode()&os.ModeNamedPipe == 0 {
232		return prompt, nil
233	}
234	bts, err := io.ReadAll(os.Stdin)
235	if err != nil {
236		return prompt, err
237	}
238	return string(bts) + "\n\n" + prompt, nil
239}
240
241func ResolveCwd(cmd *cobra.Command) (string, error) {
242	cwd, _ := cmd.Flags().GetString("cwd")
243	if cwd != "" {
244		err := os.Chdir(cwd)
245		if err != nil {
246			return "", fmt.Errorf("failed to change directory: %v", err)
247		}
248		return cwd, nil
249	}
250	cwd, err := os.Getwd()
251	if err != nil {
252		return "", fmt.Errorf("failed to get current working directory: %v", err)
253	}
254	return cwd, nil
255}
256
// createDotCrushDir ensures dir exists, creating it (and any missing parents)
// with permission bits 0o700 if needed.
//
// It also writes a .gitignore containing "*" into dir when no such file is
// present. If os.Stat on an existing .gitignore reports anything other than
// "not exist", the file is left untouched.
func createDotCrushDir(dir string) error {
	err := os.MkdirAll(dir, 0o700)
	if err != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
	}

	ignoreFile := filepath.Join(dir, ".gitignore")
	_, statErr := os.Stat(ignoreFile)
	if !os.IsNotExist(statErr) {
		// Already present (or not stat-able for another reason): keep it.
		return nil
	}

	err = os.WriteFile(ignoreFile, []byte("*\n"), 0o644)
	if err != nil {
		return fmt.Errorf("failed to create .gitignore file: %q %w", ignoreFile, err)
	}
	return nil
}
271
272func shouldQueryTerminalVersion(env uv.Environ) bool {
273	termType := env.Getenv("TERM")
274	termProg, okTermProg := env.LookupEnv("TERM_PROGRAM")
275	_, okSSHTTY := env.LookupEnv("SSH_TTY")
276	return (!okTermProg && !okSSHTTY) ||
277		(!strings.Contains(termProg, "Apple") && !okSSHTTY) ||
278		// Terminals that do support XTVERSION.
279		stringext.ContainsAny(termType, "alacritty", "ghostty", "kitty", "rio", "wezterm")
280}