root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"io/fs"
 10	"log/slog"
 11	"net/url"
 12	"os"
 13	"os/exec"
 14	"path/filepath"
 15	"regexp"
 16	"strconv"
 17	"strings"
 18	"time"
 19
 20	tea "charm.land/bubbletea/v2"
 21	"charm.land/lipgloss/v2"
 22	"github.com/charmbracelet/colorprofile"
 23	"github.com/charmbracelet/crush/internal/client"
 24	"github.com/charmbracelet/crush/internal/config"
 25	"github.com/charmbracelet/crush/internal/event"
 26	"github.com/charmbracelet/crush/internal/proto"
 27	"github.com/charmbracelet/crush/internal/server"
 28	"github.com/charmbracelet/crush/internal/ui/common"
 29	ui "github.com/charmbracelet/crush/internal/ui/model"
 30	"github.com/charmbracelet/crush/internal/version"
 31	"github.com/charmbracelet/crush/internal/workspace"
 32	"github.com/charmbracelet/fang"
 33	uv "github.com/charmbracelet/ultraviolet"
 34	"github.com/charmbracelet/x/exp/charmtone"
 35	"github.com/charmbracelet/x/term"
 36	"github.com/spf13/cobra"
 37)
 38
 39var clientHost string
 40
 41func init() {
 42	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 43	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 44	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 45	rootCmd.PersistentFlags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 46	rootCmd.PersistentFlags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
 47	rootCmd.Flags().BoolP("help", "h", false, "Help")
 48
 49	rootCmd.AddCommand(
 50		runCmd,
 51		dirsCmd,
 52		projectsCmd,
 53		updateProvidersCmd,
 54		logsCmd,
 55		schemaCmd,
 56		loginCmd,
 57		statsCmd,
 58	)
 59}
 60
 61var rootCmd = &cobra.Command{
 62	Use:   "crush",
 63	Short: "A terminal-first AI assistant for software development",
 64	Long:  "A glamorous, terminal-first AI assistant for software development and adjacent tasks",
 65	Example: `
 66# Run in interactive mode
 67crush
 68
 69# Run non-interactively
 70crush run "Guess my 5 favorite PokΓ©mon"
 71
 72# Run a non-interactively with pipes and redirection
 73cat README.md | crush run "make this more glamorous" > GLAMOROUS_README.md
 74
 75# Run with debug logging in a specific directory
 76crush --debug --cwd /path/to/project
 77
 78# Run in yolo mode (auto-accept all permissions; use with care)
 79crush --yolo
 80
 81# Run with custom data directory
 82crush --data-dir /path/to/custom/.crush
 83  `,
 84	RunE: func(cmd *cobra.Command, args []string) error {
 85		c, ws, cleanup, err := connectToServer(cmd)
 86		if err != nil {
 87			return err
 88		}
 89		defer cleanup()
 90
 91		event.AppInitialized()
 92
 93		clientWs := workspace.NewClientWorkspace(c, *ws)
 94
 95		if ws.Config.IsConfigured() {
 96			if err := clientWs.InitCoderAgent(cmd.Context()); err != nil {
 97				slog.Error("Failed to initialize coder agent", "error", err)
 98			}
 99		}
100
101		com := common.DefaultCommon(clientWs)
102		model := ui.New(com)
103
104		var env uv.Environ = os.Environ()
105		program := tea.NewProgram(
106			model,
107			tea.WithEnvironment(env),
108			tea.WithContext(cmd.Context()),
109			tea.WithFilter(ui.MouseEventFilter),
110		)
111		go clientWs.Subscribe(program)
112
113		if _, err := program.Run(); err != nil {
114			event.Error(err)
115			slog.Error("TUI run error", "error", err)
116			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
117		}
118		return nil
119	},
120}
121
122var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
123    β–„β–„β–„β–„β–„β–„β–„β–„    β–„β–„β–„β–„β–„β–„β–„β–„
124  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
125β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
126β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
127β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
128β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
129β–€β–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–„β–ˆβ–ˆβ–ˆβ–ˆβ–„β–„β–ˆβ–ˆβ–ˆβ–ˆβ–„β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–€
130  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
131    β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
132       β–€β–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–€
133           β–€β–€β–€β–€β–€β–€
134`)
135
136// copied from cobra:
137const defaultVersionTemplate = `{{with .DisplayName}}{{printf "%s " .}}{{end}}{{printf "version %s" .Version}}
138`
139
140func Execute() {
141	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
142	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
143	// finally prepend it in the version template.
144	// Unfortunately cobra doesn't give us a way to set a function to handle
145	// printing the version, and PreRunE runs after the version is already
146	// handled, so that doesn't work either.
147	// This is the only way I could find that works relatively well.
148	if term.IsTerminal(os.Stdout.Fd()) {
149		var b bytes.Buffer
150		w := colorprofile.NewWriter(os.Stdout, os.Environ())
151		w.Forward = &b
152		_, _ = w.WriteString(heartbit.String())
153		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
154	}
155	if err := fang.Execute(
156		context.Background(),
157		rootCmd,
158		fang.WithVersion(version.Version),
159		fang.WithNotifySignal(os.Interrupt),
160	); err != nil {
161		os.Exit(1)
162	}
163}
164
165// supportsProgressBar tries to determine whether the current terminal supports
166// progress bars by looking into environment variables.
167func supportsProgressBar() bool {
168	if !term.IsTerminal(os.Stderr.Fd()) {
169		return false
170	}
171	termProg := os.Getenv("TERM_PROGRAM")
172	_, isWindowsTerminal := os.LookupEnv("WT_SESSION")
173
174	return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
175}
176
177// connectToServer ensures the server is running, creates a client and
178// workspace, and returns a cleanup function that deletes the workspace.
179func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) {
180	hostURL, err := server.ParseHostURL(clientHost)
181	if err != nil {
182		return nil, nil, nil, fmt.Errorf("invalid host URL: %v", err)
183	}
184
185	if err := ensureServer(cmd, hostURL); err != nil {
186		return nil, nil, nil, err
187	}
188
189	debug, _ := cmd.Flags().GetBool("debug")
190	yolo, _ := cmd.Flags().GetBool("yolo")
191	dataDir, _ := cmd.Flags().GetString("data-dir")
192	ctx := cmd.Context()
193
194	cwd, err := ResolveCwd(cmd)
195	if err != nil {
196		return nil, nil, nil, err
197	}
198
199	c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
200	if err != nil {
201		return nil, nil, nil, err
202	}
203
204	wsReq := proto.Workspace{
205		Path:    cwd,
206		DataDir: dataDir,
207		Debug:   debug,
208		YOLO:    yolo,
209		Version: version.Version,
210		Env:     os.Environ(),
211	}
212
213	ws, err := c.CreateWorkspace(ctx, wsReq)
214	if err != nil {
215		// The server socket may exist before the HTTP handler is ready.
216		// Retry a few times with a short backoff.
217		for range 5 {
218			select {
219			case <-ctx.Done():
220				return nil, nil, nil, ctx.Err()
221			case <-time.After(200 * time.Millisecond):
222			}
223			ws, err = c.CreateWorkspace(ctx, wsReq)
224			if err == nil {
225				break
226			}
227		}
228		if err != nil {
229			return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
230		}
231	}
232
233	if shouldEnableMetrics(ws.Config) {
234		event.Init()
235	}
236
237	cleanup := func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }
238	return c, ws, cleanup, nil
239}
240
241// ensureServer auto-starts a detached server if the socket file does not
242// exist. When the socket exists, it verifies that the running server
243// version matches the client; on mismatch it shuts down the old server
244// and starts a fresh one.
245func ensureServer(cmd *cobra.Command, hostURL *url.URL) error {
246	switch hostURL.Scheme {
247	case "unix", "npipe":
248		needsStart := false
249		if _, err := os.Stat(hostURL.Host); err != nil && errors.Is(err, fs.ErrNotExist) {
250			needsStart = true
251		} else if err == nil {
252			if err := restartIfStale(cmd, hostURL); err != nil {
253				slog.Warn("Failed to check server version, restarting", "error", err)
254				needsStart = true
255			}
256		}
257
258		if needsStart {
259			if err := startDetachedServer(cmd); err != nil {
260				return err
261			}
262		}
263
264		var err error
265		for range 10 {
266			_, err = os.Stat(hostURL.Host)
267			if err == nil {
268				break
269			}
270			select {
271			case <-cmd.Context().Done():
272				return cmd.Context().Err()
273			case <-time.After(100 * time.Millisecond):
274			}
275		}
276		if err != nil {
277			return fmt.Errorf("failed to initialize crush server: %v", err)
278		}
279	}
280
281	return nil
282}
283
284// restartIfStale checks whether the running server matches the current
285// client version. When they differ, it sends a shutdown command and
286// removes the stale socket so the caller can start a fresh server.
287func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error {
288	c, err := client.NewClient("", hostURL.Scheme, hostURL.Host)
289	if err != nil {
290		return err
291	}
292	vi, err := c.VersionInfo(cmd.Context())
293	if err != nil {
294		return err
295	}
296	if vi.Version == version.Version {
297		return nil
298	}
299	slog.Info("Server version mismatch, restarting",
300		"server", vi.Version,
301		"client", version.Version,
302	)
303	_ = c.ShutdownServer(cmd.Context())
304	// Give the old process a moment to release the socket.
305	for range 20 {
306		if _, err := os.Stat(hostURL.Host); errors.Is(err, fs.ErrNotExist) {
307			break
308		}
309		select {
310		case <-cmd.Context().Done():
311			return cmd.Context().Err()
312		case <-time.After(100 * time.Millisecond):
313		}
314	}
315	// Force-remove if the socket is still lingering.
316	_ = os.Remove(hostURL.Host)
317	return nil
318}
319
320var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)
321
322func startDetachedServer(cmd *cobra.Command) error {
323	exe, err := os.Executable()
324	if err != nil {
325		return fmt.Errorf("failed to get executable path: %v", err)
326	}
327
328	safeClientHost := safeNameRegexp.ReplaceAllString(clientHost, "_")
329	chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeClientHost)
330	if err := os.MkdirAll(chDir, 0o700); err != nil {
331		return fmt.Errorf("failed to create server working directory: %v", err)
332	}
333
334	cmdArgs := []string{"server"}
335	if clientHost != server.DefaultHost() {
336		cmdArgs = append(cmdArgs, "--host", clientHost)
337	}
338
339	c := exec.CommandContext(cmd.Context(), exe, cmdArgs...)
340	stdoutPath := filepath.Join(chDir, "stdout.log")
341	stderrPath := filepath.Join(chDir, "stderr.log")
342	detachProcess(c)
343
344	stdout, err := os.Create(stdoutPath)
345	if err != nil {
346		return fmt.Errorf("failed to create stdout log file: %v", err)
347	}
348	defer stdout.Close()
349	c.Stdout = stdout
350
351	stderr, err := os.Create(stderrPath)
352	if err != nil {
353		return fmt.Errorf("failed to create stderr log file: %v", err)
354	}
355	defer stderr.Close()
356	c.Stderr = stderr
357
358	if err := c.Start(); err != nil {
359		return fmt.Errorf("failed to start crush server: %v", err)
360	}
361
362	if err := c.Process.Release(); err != nil {
363		return fmt.Errorf("failed to detach crush server process: %v", err)
364	}
365
366	return nil
367}
368
369func shouldEnableMetrics(cfg *config.Config) bool {
370	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
371		return false
372	}
373	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
374		return false
375	}
376	if cfg.Options.DisableMetrics {
377		return false
378	}
379	return true
380}
381
382func MaybePrependStdin(prompt string) (string, error) {
383	if term.IsTerminal(os.Stdin.Fd()) {
384		return prompt, nil
385	}
386	fi, err := os.Stdin.Stat()
387	if err != nil {
388		return prompt, err
389	}
390	// Check if stdin is a named pipe ( | ) or regular file ( < ).
391	if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
392		return prompt, nil
393	}
394	bts, err := io.ReadAll(os.Stdin)
395	if err != nil {
396		return prompt, err
397	}
398	return string(bts) + "\n\n" + prompt, nil
399}
400
401func ResolveCwd(cmd *cobra.Command) (string, error) {
402	cwd, _ := cmd.Flags().GetString("cwd")
403	if cwd != "" {
404		err := os.Chdir(cwd)
405		if err != nil {
406			return "", fmt.Errorf("failed to change directory: %v", err)
407		}
408		return cwd, nil
409	}
410	cwd, err := os.Getwd()
411	if err != nil {
412		return "", fmt.Errorf("failed to get current working directory: %v", err)
413	}
414	return cwd, nil
415}