root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"os"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14
 15	tea "charm.land/bubbletea/v2"
 16	"charm.land/lipgloss/v2"
 17	"github.com/charmbracelet/colorprofile"
 18	"github.com/charmbracelet/crush/internal/app"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/db"
 21	"github.com/charmbracelet/crush/internal/event"
 22	"github.com/charmbracelet/crush/internal/projects"
 23	"github.com/charmbracelet/crush/internal/stringext"
 24	"github.com/charmbracelet/crush/internal/tui"
 25	"github.com/charmbracelet/crush/internal/ui/common"
 26	ui "github.com/charmbracelet/crush/internal/ui/model"
 27	"github.com/charmbracelet/crush/internal/version"
 28	"github.com/charmbracelet/fang"
 29	uv "github.com/charmbracelet/ultraviolet"
 30	"github.com/charmbracelet/x/ansi"
 31	"github.com/charmbracelet/x/exp/charmtone"
 32	"github.com/charmbracelet/x/term"
 33	"github.com/spf13/cobra"
 34)
 35
 36func init() {
 37	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 38	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 39	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 40	rootCmd.Flags().BoolP("help", "h", false, "Help")
 41	rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 42
 43	rootCmd.AddCommand(
 44		runCmd,
 45		dirsCmd,
 46		projectsCmd,
 47		updateProvidersCmd,
 48		logsCmd,
 49		schemaCmd,
 50		loginCmd,
 51	)
 52}
 53
 54var rootCmd = &cobra.Command{
 55	Use:   "crush",
 56	Short: "An AI assistant for software development",
 57	Long:  "An AI assistant for software development and similar tasks with direct access to the terminal",
 58	Example: `
 59# Run in interactive mode
 60crush
 61
 62# Run with debug logging
 63crush -d
 64
 65# Run with debug logging in a specific directory
 66crush -d -c /path/to/project
 67
 68# Run with custom data directory
 69crush -D /path/to/custom/.crush
 70
 71# Print version
 72crush -v
 73
 74# Run a single non-interactive prompt
 75crush run "Explain the use of context in Go"
 76
 77# Run in dangerous mode (auto-accept all permissions)
 78crush -y
 79  `,
 80	RunE: func(cmd *cobra.Command, args []string) error {
 81		app, err := setupAppWithProgressBar(cmd)
 82		if err != nil {
 83			return err
 84		}
 85		defer app.Shutdown()
 86
 87		event.AppInitialized()
 88
 89		// Set up the TUI.
 90		var env uv.Environ = os.Environ()
 91
 92		var model tea.Model
 93		if v, _ := strconv.ParseBool(env.Getenv("CRUSH_NEW_UI")); v {
 94			slog.Info("New UI in control!")
 95			com := common.DefaultCommon(app)
 96			ui := ui.New(com)
 97			ui.QueryVersion = shouldQueryTerminalVersion(env)
 98			model = ui
 99		} else {
100			ui := tui.New(app)
101			ui.QueryVersion = shouldQueryTerminalVersion(env)
102			model = ui
103		}
104		program := tea.NewProgram(
105			model,
106			tea.WithEnvironment(env),
107			tea.WithContext(cmd.Context()),
108			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
109		go app.Subscribe(program)
110
111		if _, err := program.Run(); err != nil {
112			event.Error(err)
113			slog.Error("TUI run error", "error", err)
114			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
115		}
116		return nil
117	},
118	PostRun: func(cmd *cobra.Command, args []string) {
119		event.AppExited()
120	},
121}
122
// heartbit is a heart-shaped ASCII-art logo styled with the charmtone Dolly
// foreground color. Execute renders it (through a colorprofile writer) and
// prepends it to the version template when stdout is a terminal.
var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
    ▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄▄▄
  ███████████  ███████████
████████████████████████████
████████████████████████████
██████████▀██████▀██████████
██████████ ██████ ██████████
▀▀██████▄████▄▄████▄██████▀▀
  ████████████████████████
    ████████████████████
       ▀▀██████████▀▀
           ▀▀▀▀▀▀
`)
136
// defaultVersionTemplate is a copy of cobra's built-in version template. Execute
// appends it after the rendered heartbit so the usual
// "<name> version <version>" line is still printed.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}{{printf \"version %s\" .Version}}\n"
140
// Execute runs the root command through fang and terminates the process with
// exit status 1 if the command returns an error.
func Execute() {
	// NOTE: this is admittedly hacky. Cobra offers no hook for customizing how
	// the version is printed, and PreRunE only runs after the version flag has
	// already been handled. As a workaround, when stdout is a terminal we build
	// a colorprofile writer for stdout, redirect its output into a buffer,
	// write the colored heartbit through it, and prepend the buffered result
	// to the version template.
	if term.IsTerminal(os.Stdout.Fd()) {
		var rendered bytes.Buffer
		profileWriter := colorprofile.NewWriter(os.Stdout, os.Environ())
		profileWriter.Forward = &rendered
		_, _ = profileWriter.WriteString(heartbit.String())
		rootCmd.SetVersionTemplate(rendered.String() + "\n" + defaultVersionTemplate)
	}

	err := fang.Execute(
		context.Background(),
		rootCmd,
		fang.WithVersion(version.Version),
		fang.WithNotifySignal(os.Interrupt),
	)
	if err != nil {
		os.Exit(1)
	}
}
165
// supportsProgressBar guesses, from environment variables, whether the
// terminal attached to stderr understands progress-bar escape sequences.
// It returns false when stderr is not a terminal; otherwise it recognizes
// Windows Terminal (WT_SESSION is set) and Ghostty (TERM_PROGRAM contains
// "ghostty", compared case-insensitively).
func supportsProgressBar() bool {
	if !term.IsTerminal(os.Stderr.Fd()) {
		return false
	}
	if _, inWindowsTerminal := os.LookupEnv("WT_SESSION"); inWindowsTerminal {
		return true
	}
	program := strings.ToLower(os.Getenv("TERM_PROGRAM"))
	return strings.Contains(program, "ghostty")
}
177
// setupAppWithProgressBar runs setupApp and, when supportsProgressBar reports
// that the terminal can show one, displays an indeterminate progress bar on
// stderr for the duration of the setup. The bar is reset when setupApp
// returns, whether it succeeded or failed.
func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
	if supportsProgressBar() {
		// The escape sequences are written verbatim: they are data, not
		// format strings, so they must not go through fmt's verb parsing.
		_, _ = io.WriteString(os.Stderr, ansi.SetIndeterminateProgressBar)
		defer func() { _, _ = io.WriteString(os.Stderr, ansi.ResetProgressBar) }()
	}

	return setupApp(cmd)
}
186
// setupApp handles the common setup logic for both interactive and
// non-interactive modes. In order, it:
//   - reads the --debug, --yolo and --data-dir flags and resolves the working
//     directory (see ResolveCwd, which may chdir);
//   - initializes the configuration, forcing Permissions.SkipRequests to the
//     value of --yolo;
//   - creates the data directory and its .gitignore;
//   - registers the project in the centralized projects list (best effort);
//   - connects to the database (which also runs migrations) and builds the
//     app instance;
//   - initializes metrics unless shouldEnableMetrics says otherwise.
//
// It returns the app instance, or a nil app and the first error encountered.
func setupApp(cmd *cobra.Command) (*app.App, error) {
	debug, _ := cmd.Flags().GetBool("debug")
	yolo, _ := cmd.Flags().GetBool("yolo")
	dataDir, _ := cmd.Flags().GetString("data-dir")
	ctx := cmd.Context()

	cwd, err := ResolveCwd(cmd)
	if err != nil {
		return nil, err
	}

	cfg, err := config.Init(cwd, dataDir, debug)
	if err != nil {
		return nil, err
	}

	if cfg.Permissions == nil {
		cfg.Permissions = &config.Permissions{}
	}
	cfg.Permissions.SkipRequests = yolo

	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
		return nil, err
	}

	// Register this project in the centralized projects list.
	if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
		slog.Warn("Failed to register project", "error", err)
		// Non-fatal: continue even if registration fails
	}

	// Connect to DB; this will also run migrations.
	conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
	if err != nil {
		return nil, err
	}

	appInstance, err := app.New(ctx, conn, cfg)
	if err != nil {
		slog.Error("Failed to create app instance", "error", err)
		// No app is returned, so nothing downstream will shut it down and
		// release the connection; close it here instead of leaking it.
		if closeErr := conn.Close(); closeErr != nil {
			slog.Warn("Failed to close database connection", "error", closeErr)
		}
		return nil, err
	}

	if shouldEnableMetrics() {
		event.Init()
	}

	return appInstance, nil
}
238
// shouldEnableMetrics reports whether metrics should be initialized. They are
// enabled by default and disabled when CRUSH_DISABLE_METRICS or DO_NOT_TRACK
// parses as a true boolean, or when the loaded config sets
// Options.DisableMetrics. The config is only consulted if neither environment
// variable opted out.
func shouldEnableMetrics() bool {
	optedOut := func(name string) bool {
		enabled, _ := strconv.ParseBool(os.Getenv(name))
		return enabled
	}
	return !optedOut("CRUSH_DISABLE_METRICS") &&
		!optedOut("DO_NOT_TRACK") &&
		!config.Get().Options.DisableMetrics
}
251
// MaybePrependStdin returns prompt with the contents of a piped or redirected
// standard input prepended to it, separated by a blank line.
//
// Stdin is left unread and prompt is returned unchanged when stdin is a
// terminal, or when it is neither a named pipe nor a regular file. If stdin is
// read but is empty, prompt is also returned unchanged. On error the original
// prompt is returned together with the error.
func MaybePrependStdin(prompt string) (string, error) {
	if term.IsTerminal(os.Stdin.Fd()) {
		return prompt, nil
	}
	fi, err := os.Stdin.Stat()
	if err != nil {
		return prompt, err
	}
	// Check if stdin is a named pipe ( | ) or regular file ( < ); any other
	// kind of file is not consumed.
	if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
		return prompt, nil
	}
	bts, err := io.ReadAll(os.Stdin)
	if err != nil {
		return prompt, err
	}
	if len(bts) == 0 {
		// Nothing was piped in; don't prefix the prompt with blank lines.
		return prompt, nil
	}
	return string(bts) + "\n\n" + prompt, nil
}
270
// ResolveCwd returns the directory crush should operate in.
//
// When the --cwd flag is set, the path is made absolute, the process changes
// into it (a process-wide side effect), and the absolute path is returned, so
// callers never see a relative path such as "." or "../project". Otherwise the
// current working directory is returned. Returned errors wrap the underlying
// OS error.
func ResolveCwd(cmd *cobra.Command) (string, error) {
	cwd, _ := cmd.Flags().GetString("cwd")
	if cwd == "" {
		wd, err := os.Getwd()
		if err != nil {
			return "", fmt.Errorf("failed to get current working directory: %w", err)
		}
		return wd, nil
	}

	// Resolve before chdir so a relative flag value is interpreted against
	// the directory crush was launched from.
	abs, err := filepath.Abs(cwd)
	if err != nil {
		return "", fmt.Errorf("failed to resolve directory %q: %w", cwd, err)
	}
	if err := os.Chdir(abs); err != nil {
		return "", fmt.Errorf("failed to change directory: %w", err)
	}
	return abs, nil
}
286
// createDotCrushDir ensures the data directory dir exists, creating it and any
// missing parents with 0o700 permissions. It also ensures dir contains a
// .gitignore whose single pattern "*" makes git ignore the directory's
// contents. An existing .gitignore is left untouched. Errors from directory or
// file creation are wrapped and returned.
func createDotCrushDir(dir string) error {
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return fmt.Errorf("failed to create data directory: %q %w", dir, err)
	}

	gitIgnorePath := filepath.Join(dir, ".gitignore")
	// Only write the file when it is missing. errors.Is replaces the legacy
	// os.IsNotExist, which does not see through wrapped errors.
	if _, err := os.Stat(gitIgnorePath); errors.Is(err, os.ErrNotExist) {
		if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
			return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
		}
	}

	return nil
}
301
// shouldQueryTerminalVersion decides the value assigned to the UI model's
// QueryVersion field. It returns true when TERM names a terminal known to
// support XTVERSION (alacritty, ghostty, kitty, rio, wezterm). Otherwise it
// returns true unless the session is over SSH (SSH_TTY is set) or
// TERM_PROGRAM contains "Apple".
func shouldQueryTerminalVersion(env uv.Environ) bool {
	// Terminals that do support XTVERSION are always queried.
	if stringext.ContainsAny(env.Getenv("TERM"), "alacritty", "ghostty", "kitty", "rio", "wezterm") {
		return true
	}
	// An unset TERM_PROGRAM yields "", which never contains "Apple", so it is
	// handled by the same check as any non-Apple program.
	termProg, _ := env.LookupEnv("TERM_PROGRAM")
	_, overSSH := env.LookupEnv("SSH_TTY")
	return !overSSH && !strings.Contains(termProg, "Apple")
}