root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14
 15	tea "charm.land/bubbletea/v2"
 16	"charm.land/lipgloss/v2"
 17	"github.com/charmbracelet/colorprofile"
 18	"github.com/charmbracelet/crush/internal/app"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/db"
 21	"github.com/charmbracelet/crush/internal/event"
 22	"github.com/charmbracelet/crush/internal/projects"
 23	"github.com/charmbracelet/crush/internal/tui"
 24	"github.com/charmbracelet/crush/internal/ui/common"
 25	ui "github.com/charmbracelet/crush/internal/ui/model"
 26	"github.com/charmbracelet/crush/internal/version"
 27	"github.com/charmbracelet/fang"
 28	uv "github.com/charmbracelet/ultraviolet"
 29	"github.com/charmbracelet/x/ansi"
 30	"github.com/charmbracelet/x/exp/charmtone"
 31	xstrings "github.com/charmbracelet/x/exp/strings"
 32	"github.com/charmbracelet/x/term"
 33	"github.com/spf13/cobra"
 34)
 35
 36// kittyTerminals defines terminals supporting querying capabilities.
 37var kittyTerminals = []string{"alacritty", "ghostty", "kitty", "rio", "wezterm"}
 38
 39func init() {
 40	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 41	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 42	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 43	rootCmd.Flags().BoolP("help", "h", false, "Help")
 44	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 45
 46	rootCmd.AddCommand(
 47		runCmd,
 48		dirsCmd,
 49		projectsCmd,
 50		updateProvidersCmd,
 51		logsCmd,
 52		schemaCmd,
 53		loginCmd,
 54		statsCmd,
 55	)
 56}
 57
 58var rootCmd = &cobra.Command{
 59	Use:   "crush",
 60	Short: "An AI assistant for software development",
 61	Long:  "An AI assistant for software development and similar tasks with direct access to the terminal",
 62	Example: `
 63# Run in interactive mode
 64crush
 65
 66# Run with debug logging
 67crush -d
 68
 69# Run with debug logging in a specific directory
 70crush -d -c /path/to/project
 71
 72# Run with custom data directory
 73crush -D /path/to/custom/.crush
 74
 75# Print version
 76crush -v
 77
 78# Run a single non-interactive prompt
 79crush run "Explain the use of context in Go"
 80
 81# Run in dangerous mode (auto-accept all permissions)
 82crush -y
 83  `,
 84	RunE: func(cmd *cobra.Command, args []string) error {
 85		app, err := setupAppWithProgressBar(cmd)
 86		if err != nil {
 87			return err
 88		}
 89		defer app.Shutdown()
 90
 91		event.AppInitialized()
 92
 93		// Set up the TUI.
 94		var env uv.Environ = os.Environ()
 95
 96		newUI := true
 97		if v, err := strconv.ParseBool(env.Getenv("CRUSH_NEW_UI")); err == nil {
 98			newUI = v
 99		}
100
101		var model tea.Model
102		if newUI {
103			slog.Info("New UI in control!")
104			com := common.DefaultCommon(app)
105			ui := ui.New(com)
106			model = ui
107		} else {
108			ui := tui.New(app)
109			ui.QueryVersion = shouldQueryCapabilities(env)
110			model = ui
111		}
112		program := tea.NewProgram(
113			model,
114			tea.WithEnvironment(env),
115			tea.WithContext(cmd.Context()),
116			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
117		go app.Subscribe(program)
118
119		if _, err := program.Run(); err != nil {
120			event.Error(err)
121			slog.Error("TUI run error", "error", err)
122			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
123		}
124		return nil
125	},
126	PostRun: func(cmd *cobra.Command, args []string) {
127		event.AppExited()
128	},
129}
130
131var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
132    ▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄▄▄
133  ███████████  ███████████
134████████████████████████████
135████████████████████████████
136██████████▀██████▀██████████
137██████████ ██████ ██████████
138▀▀██████▄████▄▄████▄██████▀▀
139  ████████████████████████
140    ████████████████████
141       ▀▀██████████▀▀
142           ▀▀▀▀▀▀
143`)
144
// defaultVersionTemplate is cobra's default version template, copied here so
// Execute can prepend the rendered heartbit logo to it.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}{{printf \"version %s\" .Version}}\n"
148
149func Execute() {
150	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
151	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
152	// finally prepend it in the version template.
153	// Unfortunately cobra doesn't give us a way to set a function to handle
154	// printing the version, and PreRunE runs after the version is already
155	// handled, so that doesn't work either.
156	// This is the only way I could find that works relatively well.
157	if term.IsTerminal(os.Stdout.Fd()) {
158		var b bytes.Buffer
159		w := colorprofile.NewWriter(os.Stdout, os.Environ())
160		w.Forward = &b
161		_, _ = w.WriteString(heartbit.String())
162		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
163	}
164	if err := fang.Execute(
165		context.Background(),
166		rootCmd,
167		fang.WithVersion(version.Version),
168		fang.WithNotifySignal(os.Interrupt),
169	); err != nil {
170		os.Exit(1)
171	}
172}
173
174// supportsProgressBar tries to determine whether the current terminal supports
175// progress bars by looking into environment variables.
176func supportsProgressBar() bool {
177	if !term.IsTerminal(os.Stderr.Fd()) {
178		return false
179	}
180	termProg := os.Getenv("TERM_PROGRAM")
181	_, isWindowsTerminal := os.LookupEnv("WT_SESSION")
182
183	return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
184}
185
186func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
187	app, err := setupApp(cmd)
188	if err != nil {
189		return nil, err
190	}
191
192	// Check if progress bar is enabled in config (defaults to true if nil)
193	progressEnabled := app.Config().Options.Progress == nil || *app.Config().Options.Progress
194	if progressEnabled && supportsProgressBar() {
195		_, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
196		defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
197	}
198
199	return app, nil
200}
201
202// setupApp handles the common setup logic for both interactive and non-interactive modes.
203// It returns the app instance, config, cleanup function, and any error.
204func setupApp(cmd *cobra.Command) (*app.App, error) {
205	debug, _ := cmd.Flags().GetBool("debug")
206	yolo, _ := cmd.Flags().GetBool("yolo")
207	dataDir, _ := cmd.Flags().GetString("data-dir")
208	ctx := cmd.Context()
209
210	cwd, err := ResolveCwd(cmd)
211	if err != nil {
212		return nil, err
213	}
214
215	cfg, err := config.Init(cwd, dataDir, debug)
216	if err != nil {
217		return nil, err
218	}
219
220	if cfg.Permissions == nil {
221		cfg.Permissions = &config.Permissions{}
222	}
223	cfg.Permissions.SkipRequests = yolo
224
225	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
226		return nil, err
227	}
228
229	// Register this project in the centralized projects list.
230	if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
231		slog.Warn("Failed to register project", "error", err)
232		// Non-fatal: continue even if registration fails
233	}
234
235	// Connect to DB; this will also run migrations.
236	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
237	if err != nil {
238		return nil, err
239	}
240
241	appInstance, err := app.New(ctx, conn, cfg)
242	if err != nil {
243		slog.Error("Failed to create app instance", "error", err)
244		return nil, err
245	}
246
247	if shouldEnableMetrics() {
248		event.Init()
249	}
250
251	return appInstance, nil
252}
253
254func shouldEnableMetrics() bool {
255	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
256		return false
257	}
258	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
259		return false
260	}
261	if config.Get().Options.DisableMetrics {
262		return false
263	}
264	return true
265}
266
267func MaybePrependStdin(prompt string) (string, error) {
268	if term.IsTerminal(os.Stdin.Fd()) {
269		return prompt, nil
270	}
271	fi, err := os.Stdin.Stat()
272	if err != nil {
273		return prompt, err
274	}
275	// Check if stdin is a named pipe ( | ) or regular file ( < ).
276	if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
277		return prompt, nil
278	}
279	bts, err := io.ReadAll(os.Stdin)
280	if err != nil {
281		return prompt, err
282	}
283	return string(bts) + "\n\n" + prompt, nil
284}
285
286func ResolveCwd(cmd *cobra.Command) (string, error) {
287	cwd, _ := cmd.Flags().GetString("cwd")
288	if cwd != "" {
289		err := os.Chdir(cwd)
290		if err != nil {
291			return "", fmt.Errorf("failed to change directory: %v", err)
292		}
293		return cwd, nil
294	}
295	cwd, err := os.Getwd()
296	if err != nil {
297		return "", fmt.Errorf("failed to get current working directory: %v", err)
298	}
299	return cwd, nil
300}
301
// createDotCrushDir ensures the data directory dir exists, creating it and
// any missing parents with permission bits 0o700 (subject to umask). It also
// ensures dir contains a ".gitignore" holding "*\n", so the directory's
// contents are ignored by git. An existing .gitignore is left untouched.
func createDotCrushDir(dir string) error {
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
	}

	gitIgnorePath := filepath.Join(dir, ".gitignore")
	// errors.Is is the recommended replacement for os.IsNotExist: it also
	// matches wrapped errors. Any other Stat error (e.g. permissions) is
	// treated as best-effort: the .gitignore is a convenience, so it is
	// simply skipped.
	if _, err := os.Stat(gitIgnorePath); errors.Is(err, os.ErrNotExist) {
		if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
			return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
		}
	}

	return nil
}
316
317// TODO: Remove me after dropping the old TUI.
318func shouldQueryCapabilities(env uv.Environ) bool {
319	const osVendorTypeApple = "Apple"
320	termType := env.Getenv("TERM")
321	termProg, okTermProg := env.LookupEnv("TERM_PROGRAM")
322	_, okSSHTTY := env.LookupEnv("SSH_TTY")
323	if okTermProg && strings.Contains(termProg, osVendorTypeApple) {
324		return false
325	}
326	return (!okTermProg && !okSSHTTY) ||
327		(!strings.Contains(termProg, osVendorTypeApple) && !okSSHTTY) ||
328		// Terminals that do support XTVERSION.
329		xstrings.ContainsAnyOf(termType, kittyTerminals...)
330}