root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14
 15	tea "charm.land/bubbletea/v2"
 16	"charm.land/lipgloss/v2"
 17	"github.com/charmbracelet/colorprofile"
 18	"github.com/charmbracelet/crush/internal/app"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/db"
 21	"github.com/charmbracelet/crush/internal/event"
 22	"github.com/charmbracelet/crush/internal/projects"
 23	"github.com/charmbracelet/crush/internal/stringext"
 24	"github.com/charmbracelet/crush/internal/tui"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/charmbracelet/fang"
 27	uv "github.com/charmbracelet/ultraviolet"
 28	"github.com/charmbracelet/x/ansi"
 29	"github.com/charmbracelet/x/exp/charmtone"
 30	"github.com/charmbracelet/x/term"
 31	"github.com/spf13/cobra"
 32)
 33
 34func init() {
 35	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 36	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 37	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 38	rootCmd.Flags().BoolP("help", "h", false, "Help")
 39	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 40
 41	rootCmd.AddCommand(
 42		runCmd,
 43		dirsCmd,
 44		projectsCmd,
 45		updateProvidersCmd,
 46		logsCmd,
 47		schemaCmd,
 48		loginCmd,
 49	)
 50}
 51
 52var rootCmd = &cobra.Command{
 53	Use:   "crush",
 54	Short: "Terminal-based AI assistant for software development",
 55	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 56It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 57to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 58	Example: `
 59# Run in interactive mode
 60crush
 61
 62# Run with debug logging
 63crush -d
 64
 65# Run with debug logging in a specific directory
 66crush -d -c /path/to/project
 67
 68# Run with custom data directory
 69crush -D /path/to/custom/.crush
 70
 71# Print version
 72crush -v
 73
 74# Run a single non-interactive prompt
 75crush run "Explain the use of context in Go"
 76
 77# Run in dangerous mode (auto-accept all permissions)
 78crush -y
 79  `,
 80	RunE: func(cmd *cobra.Command, args []string) error {
 81		app, err := setupAppWithProgressBar(cmd)
 82		if err != nil {
 83			return err
 84		}
 85		defer app.Shutdown()
 86
 87		event.AppInitialized()
 88
 89		// Set up the TUI.
 90		var env uv.Environ = os.Environ()
 91		ui := tui.New(app)
 92		ui.QueryVersion = shouldQueryTerminalVersion(env)
 93
 94		program := tea.NewProgram(
 95			ui,
 96			tea.WithEnvironment(env),
 97			tea.WithContext(cmd.Context()),
 98			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
 99		go app.Subscribe(program)
100
101		if _, err := program.Run(); err != nil {
102			event.Error(err)
103			slog.Error("TUI run error", "error", err)
104			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
105		}
106		return nil
107	},
108	PostRun: func(cmd *cobra.Command, args []string) {
109		event.AppExited()
110	},
111}
112
113var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
114    ▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄▄▄
115  ███████████  ███████████
116████████████████████████████
117████████████████████████████
118██████████▀██████▀██████████
119██████████ ██████ ██████████
120▀▀██████▄████▄▄████▄██████▀▀
121  ████████████████████████
122    ████████████████████
123       ▀▀██████████▀▀
124           ▀▀▀▀▀▀
125`)
126
// defaultVersionTemplate reproduces cobra's built-in version template, so that
// Execute can prepend the heartbit banner to it and still print the usual
// "<name> version <version>" line.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}{{printf \"version %s\" .Version}}\n"
130
131func Execute() {
132	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
133	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
134	// finally prepend it in the version template.
135	// Unfortunately cobra doesn't give us a way to set a function to handle
136	// printing the version, and PreRunE runs after the version is already
137	// handled, so that doesn't work either.
138	// This is the only way I could find that works relatively well.
139	if term.IsTerminal(os.Stdout.Fd()) {
140		var b bytes.Buffer
141		w := colorprofile.NewWriter(os.Stdout, os.Environ())
142		w.Forward = &b
143		_, _ = w.WriteString(heartbit.String())
144		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
145	}
146	if err := fang.Execute(
147		context.Background(),
148		rootCmd,
149		fang.WithVersion(version.Version),
150		fang.WithNotifySignal(os.Interrupt),
151	); err != nil {
152		os.Exit(1)
153	}
154}
155
156// supportsProgressBar tries to determine whether the current terminal supports
157// progress bars by looking into environment variables.
158func supportsProgressBar() bool {
159	if !term.IsTerminal(os.Stderr.Fd()) {
160		return false
161	}
162	termProg := os.Getenv("TERM_PROGRAM")
163	_, isWindowsTerminal := os.LookupEnv("WT_SESSION")
164
165	return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
166}
167
168func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
169	if supportsProgressBar() {
170		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
171		defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
172	}
173
174	return setupApp(cmd)
175}
176
177// setupApp handles the common setup logic for both interactive and non-interactive modes.
178// It returns the app instance, config, cleanup function, and any error.
179func setupApp(cmd *cobra.Command) (*app.App, error) {
180	debug, _ := cmd.Flags().GetBool("debug")
181	yolo, _ := cmd.Flags().GetBool("yolo")
182	dataDir, _ := cmd.Flags().GetString("data-dir")
183	ctx := cmd.Context()
184
185	cwd, err := ResolveCwd(cmd)
186	if err != nil {
187		return nil, err
188	}
189
190	cfg, err := config.Init(cwd, dataDir, debug)
191	if err != nil {
192		return nil, err
193	}
194
195	if cfg.Permissions == nil {
196		cfg.Permissions = &config.Permissions{}
197	}
198	cfg.Permissions.SkipRequests = yolo
199
200	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
201		return nil, err
202	}
203
204	// Register this project in the centralized projects list.
205	if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
206		slog.Warn("Failed to register project", "error", err)
207		// Non-fatal: continue even if registration fails
208	}
209
210	// Connect to DB; this will also run migrations.
211	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
212	if err != nil {
213		return nil, err
214	}
215
216	appInstance, err := app.New(ctx, conn, cfg)
217	if err != nil {
218		slog.Error("Failed to create app instance", "error", err)
219		return nil, err
220	}
221
222	if shouldEnableMetrics() {
223		event.Init()
224	}
225
226	return appInstance, nil
227}
228
229func shouldEnableMetrics() bool {
230	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
231		return false
232	}
233	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
234		return false
235	}
236	if config.Get().Options.DisableMetrics {
237		return false
238	}
239	return true
240}
241
242func MaybePrependStdin(prompt string) (string, error) {
243	if term.IsTerminal(os.Stdin.Fd()) {
244		return prompt, nil
245	}
246	fi, err := os.Stdin.Stat()
247	if err != nil {
248		return prompt, err
249	}
250	// Check if stdin is a named pipe ( | ) or regular file ( < ).
251	if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
252		return prompt, nil
253	}
254	bts, err := io.ReadAll(os.Stdin)
255	if err != nil {
256		return prompt, err
257	}
258	return string(bts) + "\n\n" + prompt, nil
259}
260
261func ResolveCwd(cmd *cobra.Command) (string, error) {
262	cwd, _ := cmd.Flags().GetString("cwd")
263	if cwd != "" {
264		err := os.Chdir(cwd)
265		if err != nil {
266			return "", fmt.Errorf("failed to change directory: %v", err)
267		}
268		return cwd, nil
269	}
270	cwd, err := os.Getwd()
271	if err != nil {
272		return "", fmt.Errorf("failed to get current working directory: %v", err)
273	}
274	return cwd, nil
275}
276
// createDotCrushDir ensures the data directory dir exists, creating it and any
// missing parents with mode 0700. It also drops a ".gitignore" containing "*"
// into it so the directory's contents are ignored by git.
//
// An existing .gitignore is never overwritten. If the .gitignore cannot be
// stat'ed for a reason other than "does not exist" (e.g. permissions), it is
// skipped on a best-effort basis rather than failing setup.
func createDotCrushDir(dir string) error {
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
	}

	gitIgnorePath := filepath.Join(dir, ".gitignore")
	if _, err := os.Stat(gitIgnorePath); errors.Is(err, os.ErrNotExist) {
		if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
			return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
		}
	}

	return nil
}
291
292func shouldQueryTerminalVersion(env uv.Environ) bool {
293	termType := env.Getenv("TERM")
294	termProg, okTermProg := env.LookupEnv("TERM_PROGRAM")
295	_, okSSHTTY := env.LookupEnv("SSH_TTY")
296	return (!okTermProg && !okSSHTTY) ||
297		(!strings.Contains(termProg, "Apple") && !okSSHTTY) ||
298		// Terminals that do support XTVERSION.
299		stringext.ContainsAny(termType, "alacritty", "ghostty", "kitty", "rio", "wezterm")
300}