root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14
 15	tea "github.com/charmbracelet/bubbletea/v2"
 16	"github.com/charmbracelet/colorprofile"
 17	"github.com/charmbracelet/crush/internal/app"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/db"
 20	"github.com/charmbracelet/crush/internal/event"
 21	termutil "github.com/charmbracelet/crush/internal/term"
 22	"github.com/charmbracelet/crush/internal/tui"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/charmbracelet/fang"
 25	"github.com/charmbracelet/lipgloss/v2"
 26	uv "github.com/charmbracelet/ultraviolet"
 27	"github.com/charmbracelet/x/ansi"
 28	"github.com/charmbracelet/x/exp/charmtone"
 29	"github.com/charmbracelet/x/term"
 30	"github.com/spf13/cobra"
 31)
 32
 33func init() {
 34	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 35	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 36	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 37	rootCmd.Flags().BoolP("help", "h", false, "Help")
 38	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 39
 40	rootCmd.AddCommand(
 41		runCmd,
 42		dirsCmd,
 43		updateProvidersCmd,
 44		logsCmd,
 45		schemaCmd,
 46	)
 47}
 48
 49var rootCmd = &cobra.Command{
 50	Use:   "crush",
 51	Short: "Terminal-based AI assistant for software development",
 52	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 53It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 54to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 55	Example: `
 56# Run in interactive mode
 57crush
 58
 59# Run with debug logging
 60crush -d
 61
 62# Run with debug logging in a specific directory
 63crush -d -c /path/to/project
 64
 65# Run with custom data directory
 66crush -D /path/to/custom/.crush
 67
 68# Print version
 69crush -v
 70
 71# Run a single non-interactive prompt
 72crush run "Explain the use of context in Go"
 73
 74# Run in dangerous mode (auto-accept all permissions)
 75crush -y
 76  `,
 77	RunE: func(cmd *cobra.Command, args []string) error {
 78		app, err := setupAppWithProgressBar(cmd)
 79		if err != nil {
 80			return err
 81		}
 82		defer app.Shutdown()
 83
 84		event.AppInitialized()
 85
 86		// Set up the TUI.
 87		var env uv.Environ = os.Environ()
 88		ui := tui.New(app)
 89		ui.QueryVersion = shouldQueryTerminalVersion(env)
 90
 91		program := tea.NewProgram(
 92			ui,
 93			tea.WithEnvironment(env),
 94			tea.WithContext(cmd.Context()),
 95			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
 96		go app.Subscribe(program)
 97
 98		if _, err := program.Run(); err != nil {
 99			event.Error(err)
100			slog.Error("TUI run error", "error", err)
101			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
102		}
103		return nil
104	},
105	PostRun: func(cmd *cobra.Command, args []string) {
106		event.AppExited()
107	},
108}
109
110var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
111    ▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄▄▄
112  ███████████  ███████████
113████████████████████████████
114████████████████████████████
115██████████▀██████▀██████████
116██████████ ██████ ██████████
117▀▀██████▄████▄▄████▄██████▀▀
118  ████████████████████████
119    ████████████████████
120       ▀▀██████████▀▀
121           ▀▀▀▀▀▀
122`)
123
// defaultVersionTemplate is a copy of cobra's default version template.
// Execute appends it after the colored heartbit logo when building a custom
// version template.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}{{printf \"version %s\" .Version}}\n"
127
128func Execute() {
129	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
130	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
131	// finally prepend it in the version template.
132	// Unfortunately cobra doesn't give us a way to set a function to handle
133	// printing the version, and PreRunE runs after the version is already
134	// handled, so that doesn't work either.
135	// This is the only way I could find that works relatively well.
136	if term.IsTerminal(os.Stdout.Fd()) {
137		var b bytes.Buffer
138		w := colorprofile.NewWriter(os.Stdout, os.Environ())
139		w.Forward = &b
140		_, _ = w.WriteString(heartbit.String())
141		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
142	}
143	if err := fang.Execute(
144		context.Background(),
145		rootCmd,
146		fang.WithVersion(version.Version),
147		fang.WithNotifySignal(os.Interrupt),
148	); err != nil {
149		os.Exit(1)
150	}
151}
152
153func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
154	if termutil.SupportsProgressBar() {
155		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
156		defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
157	}
158
159	return setupApp(cmd)
160}
161
162// setupApp handles the common setup logic for both interactive and non-interactive modes.
163// It returns the app instance, config, cleanup function, and any error.
164func setupApp(cmd *cobra.Command) (*app.App, error) {
165	debug, _ := cmd.Flags().GetBool("debug")
166	yolo, _ := cmd.Flags().GetBool("yolo")
167	dataDir, _ := cmd.Flags().GetString("data-dir")
168	ctx := cmd.Context()
169
170	cwd, err := ResolveCwd(cmd)
171	if err != nil {
172		return nil, err
173	}
174
175	cfg, err := config.Init(cwd, dataDir, debug)
176	if err != nil {
177		return nil, err
178	}
179
180	if cfg.Permissions == nil {
181		cfg.Permissions = &config.Permissions{}
182	}
183	cfg.Permissions.SkipRequests = yolo
184
185	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
186		return nil, err
187	}
188
189	// Connect to DB; this will also run migrations.
190	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
191	if err != nil {
192		return nil, err
193	}
194
195	appInstance, err := app.New(ctx, conn, cfg)
196	if err != nil {
197		slog.Error("Failed to create app instance", "error", err)
198		return nil, err
199	}
200
201	if shouldEnableMetrics() {
202		event.Init()
203	}
204
205	return appInstance, nil
206}
207
208func shouldEnableMetrics() bool {
209	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
210		return false
211	}
212	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
213		return false
214	}
215	if config.Get().Options.DisableMetrics {
216		return false
217	}
218	return true
219}
220
221func MaybePrependStdin(prompt string) (string, error) {
222	if term.IsTerminal(os.Stdin.Fd()) {
223		return prompt, nil
224	}
225	fi, err := os.Stdin.Stat()
226	if err != nil {
227		return prompt, err
228	}
229	if fi.Mode()&os.ModeNamedPipe == 0 {
230		return prompt, nil
231	}
232	bts, err := io.ReadAll(os.Stdin)
233	if err != nil {
234		return prompt, err
235	}
236	return string(bts) + "\n\n" + prompt, nil
237}
238
239func ResolveCwd(cmd *cobra.Command) (string, error) {
240	cwd, _ := cmd.Flags().GetString("cwd")
241	if cwd != "" {
242		err := os.Chdir(cwd)
243		if err != nil {
244			return "", fmt.Errorf("failed to change directory: %v", err)
245		}
246		return cwd, nil
247	}
248	cwd, err := os.Getwd()
249	if err != nil {
250		return "", fmt.Errorf("failed to get current working directory: %v", err)
251	}
252	return cwd, nil
253}
254
// createDotCrushDir ensures the data directory dir exists. Any missing
// directories are created with 0o700 permissions. If the directory has no
// .gitignore yet, one containing "*" is written so git ignores everything in
// it. An existing .gitignore is left untouched.
func createDotCrushDir(dir string) error {
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
	}

	gitIgnorePath := filepath.Join(dir, ".gitignore")
	// Write the file only when it is missing. Any other Stat error skips
	// the write without failing setup.
	if _, err := os.Stat(gitIgnorePath); errors.Is(err, os.ErrNotExist) {
		if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
			return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
		}
	}

	return nil
}
269
270func shouldQueryTerminalVersion(env uv.Environ) bool {
271	termType := env.Getenv("TERM")
272	termProg, okTermProg := env.LookupEnv("TERM_PROGRAM")
273	_, okSSHTTY := env.LookupEnv("SSH_TTY")
274	return (!okTermProg && !okSSHTTY) ||
275		(!strings.Contains(termProg, "Apple") && !okSSHTTY) ||
276		// Terminals that do support XTVERSION.
277		strings.Contains(termType, "ghostty") ||
278		strings.Contains(termType, "wezterm") ||
279		strings.Contains(termType, "alacritty") ||
280		strings.Contains(termType, "kitty") ||
281		strings.Contains(termType, "rio")
282}