root.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"io/fs"
 10	"log/slog"
 11	"net/url"
 12	"os"
 13	"os/exec"
 14	"path/filepath"
 15	"regexp"
 16	"strconv"
 17	"strings"
 18	"time"
 19
 20	tea "charm.land/bubbletea/v2"
 21	fang "charm.land/fang/v2"
 22	"charm.land/lipgloss/v2"
 23	"github.com/charmbracelet/colorprofile"
 24	"github.com/charmbracelet/crush/internal/client"
 25	"github.com/charmbracelet/crush/internal/config"
 26	"github.com/charmbracelet/crush/internal/event"
 27	crushlog "github.com/charmbracelet/crush/internal/log"
 28	"github.com/charmbracelet/crush/internal/proto"
 29	"github.com/charmbracelet/crush/internal/server"
 30	"github.com/charmbracelet/crush/internal/ui/common"
 31	ui "github.com/charmbracelet/crush/internal/ui/model"
 32	"github.com/charmbracelet/crush/internal/version"
 33	"github.com/charmbracelet/crush/internal/workspace"
 34	uv "github.com/charmbracelet/ultraviolet"
 35	"github.com/charmbracelet/x/exp/charmtone"
 36	"github.com/charmbracelet/x/term"
 37	"github.com/spf13/cobra"
 38)
 39
 40var clientHost string
 41
 42func init() {
 43	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
 44	rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
 45	rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
 46	rootCmd.PersistentFlags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
 47	rootCmd.PersistentFlags().StringVarP(&clientHost, "host", "H", server.DefaultHost(), "Connect to a specific crush server host (for advanced users)")
 48	rootCmd.Flags().BoolP("help", "h", false, "Help")
 49
 50	rootCmd.AddCommand(
 51		runCmd,
 52		dirsCmd,
 53		projectsCmd,
 54		updateProvidersCmd,
 55		logsCmd,
 56		schemaCmd,
 57		loginCmd,
 58		statsCmd,
 59		sessionCmd,
 60	)
 61}
 62
 63var rootCmd = &cobra.Command{
 64	Use:   "crush",
 65	Short: "A terminal-first AI assistant for software development",
 66	Long:  "A glamorous, terminal-first AI assistant for software development and adjacent tasks",
 67	Example: `
 68# Run in interactive mode
 69crush
 70
 71# Run non-interactively
 72crush run "Guess my 5 favorite PokΓ©mon"
 73
 74# Run a non-interactively with pipes and redirection
 75cat README.md | crush run "make this more glamorous" > GLAMOROUS_README.md
 76
 77# Run with debug logging in a specific directory
 78crush --debug --cwd /path/to/project
 79
 80# Run in yolo mode (auto-accept all permissions; use with care)
 81crush --yolo
 82
 83# Run with custom data directory
 84crush --data-dir /path/to/custom/.crush
 85  `,
 86	RunE: func(cmd *cobra.Command, args []string) error {
 87		c, ws, cleanup, err := connectToServer(cmd)
 88		if err != nil {
 89			return err
 90		}
 91		defer cleanup()
 92
 93		event.AppInitialized()
 94
 95		clientWs := workspace.NewClientWorkspace(c, *ws)
 96
 97		if ws.Config.IsConfigured() {
 98			if err := clientWs.InitCoderAgent(cmd.Context()); err != nil {
 99				slog.Error("Failed to initialize coder agent", "error", err)
100			}
101		}
102
103		com := common.DefaultCommon(clientWs)
104		model := ui.New(com)
105
106		var env uv.Environ = os.Environ()
107		program := tea.NewProgram(
108			model,
109			tea.WithEnvironment(env),
110			tea.WithContext(cmd.Context()),
111			tea.WithFilter(ui.MouseEventFilter),
112		)
113		go clientWs.Subscribe(program)
114
115		if _, err := program.Run(); err != nil {
116			event.Error(err)
117			slog.Error("TUI run error", "error", err)
118			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
119		}
120		return nil
121	},
122}
123
124var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
125    β–„β–„β–„β–„β–„β–„β–„β–„    β–„β–„β–„β–„β–„β–„β–„β–„
126  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
127β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
128β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
129β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
130β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
131β–€β–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–„β–ˆβ–ˆβ–ˆβ–ˆβ–„β–„β–ˆβ–ˆβ–ˆβ–ˆβ–„β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–€
132  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
133    β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
134       β–€β–€β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–€β–€
135           β–€β–€β–€β–€β–€β–€
136`)
137
// defaultVersionTemplate mirrors cobra's built-in version template
// ("<name> version <version>" followed by a newline). It is reproduced here so
// Execute can prepend the logo to it via SetVersionTemplate.
const defaultVersionTemplate = "{{with .DisplayName}}{{printf \"%s \" .}}{{end}}{{printf \"version %s\" .Version}}\n"
141
142func Execute() {
143	// NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
144	// it forward to a bytes.Buffer, write the colored heartbit to it, and then
145	// finally prepend it in the version template.
146	// Unfortunately cobra doesn't give us a way to set a function to handle
147	// printing the version, and PreRunE runs after the version is already
148	// handled, so that doesn't work either.
149	// This is the only way I could find that works relatively well.
150	if term.IsTerminal(os.Stdout.Fd()) {
151		var b bytes.Buffer
152		w := colorprofile.NewWriter(os.Stdout, os.Environ())
153		w.Forward = &b
154		_, _ = w.WriteString(heartbit.String())
155		rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
156	}
157	if err := fang.Execute(
158		context.Background(),
159		rootCmd,
160		fang.WithVersion(version.Version),
161		fang.WithNotifySignal(os.Interrupt),
162	); err != nil {
163		os.Exit(1)
164	}
165}
166
167// supportsProgressBar tries to determine whether the current terminal supports
168// progress bars by looking into environment variables.
169func supportsProgressBar() bool {
170	if !term.IsTerminal(os.Stderr.Fd()) {
171		return false
172	}
173	termProg := os.Getenv("TERM_PROGRAM")
174	_, isWindowsTerminal := os.LookupEnv("WT_SESSION")
175
176	return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
177}
178
179// connectToServer ensures the server is running, creates a client and
180// workspace, and returns a cleanup function that deletes the workspace.
181func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func(), error) {
182	hostURL, err := server.ParseHostURL(clientHost)
183	if err != nil {
184		return nil, nil, nil, fmt.Errorf("invalid host URL: %v", err)
185	}
186
187	if err := ensureServer(cmd, hostURL); err != nil {
188		return nil, nil, nil, err
189	}
190
191	debug, _ := cmd.Flags().GetBool("debug")
192	yolo, _ := cmd.Flags().GetBool("yolo")
193	dataDir, _ := cmd.Flags().GetString("data-dir")
194	ctx := cmd.Context()
195
196	cwd, err := ResolveCwd(cmd)
197	if err != nil {
198		return nil, nil, nil, err
199	}
200
201	c, err := client.NewClient(cwd, hostURL.Scheme, hostURL.Host)
202	if err != nil {
203		return nil, nil, nil, err
204	}
205
206	wsReq := proto.Workspace{
207		Path:    cwd,
208		DataDir: dataDir,
209		Debug:   debug,
210		YOLO:    yolo,
211		Version: version.Version,
212		Env:     os.Environ(),
213	}
214
215	ws, err := c.CreateWorkspace(ctx, wsReq)
216	if err != nil {
217		// The server socket may exist before the HTTP handler is ready.
218		// Retry a few times with a short backoff.
219		for range 5 {
220			select {
221			case <-ctx.Done():
222				return nil, nil, nil, ctx.Err()
223			case <-time.After(200 * time.Millisecond):
224			}
225			ws, err = c.CreateWorkspace(ctx, wsReq)
226			if err == nil {
227				break
228			}
229		}
230		if err != nil {
231			return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
232		}
233	}
234
235	if shouldEnableMetrics(ws.Config) {
236		event.Init()
237	}
238
239	if ws.Config != nil {
240		logFile := filepath.Join(ws.Config.Options.DataDirectory, "logs", "crush.log")
241		crushlog.Setup(logFile, debug)
242	}
243
244	cleanup := func() { _ = c.DeleteWorkspace(context.Background(), ws.ID) }
245	return c, ws, cleanup, nil
246}
247
248// ensureServer auto-starts a detached server if the socket file does not
249// exist. When the socket exists, it verifies that the running server
250// version matches the client; on mismatch it shuts down the old server
251// and starts a fresh one.
252func ensureServer(cmd *cobra.Command, hostURL *url.URL) error {
253	switch hostURL.Scheme {
254	case "unix", "npipe":
255		needsStart := false
256		if _, err := os.Stat(hostURL.Host); err != nil && errors.Is(err, fs.ErrNotExist) {
257			needsStart = true
258		} else if err == nil {
259			if err := restartIfStale(cmd, hostURL); err != nil {
260				slog.Warn("Failed to check server version, restarting", "error", err)
261				needsStart = true
262			}
263		}
264
265		if needsStart {
266			if err := startDetachedServer(cmd); err != nil {
267				return err
268			}
269		}
270
271		var err error
272		for range 10 {
273			_, err = os.Stat(hostURL.Host)
274			if err == nil {
275				break
276			}
277			select {
278			case <-cmd.Context().Done():
279				return cmd.Context().Err()
280			case <-time.After(100 * time.Millisecond):
281			}
282		}
283		if err != nil {
284			return fmt.Errorf("failed to initialize crush server: %v", err)
285		}
286	}
287
288	return nil
289}
290
291// restartIfStale checks whether the running server matches the current
292// client version. When they differ, it sends a shutdown command and
293// removes the stale socket so the caller can start a fresh server.
294func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error {
295	c, err := client.NewClient("", hostURL.Scheme, hostURL.Host)
296	if err != nil {
297		return err
298	}
299	vi, err := c.VersionInfo(cmd.Context())
300	if err != nil {
301		return err
302	}
303	if vi.Version == version.Version {
304		return nil
305	}
306	slog.Info("Server version mismatch, restarting",
307		"server", vi.Version,
308		"client", version.Version,
309	)
310	_ = c.ShutdownServer(cmd.Context())
311	// Give the old process a moment to release the socket.
312	for range 20 {
313		if _, err := os.Stat(hostURL.Host); errors.Is(err, fs.ErrNotExist) {
314			break
315		}
316		select {
317		case <-cmd.Context().Done():
318			return cmd.Context().Err()
319		case <-time.After(100 * time.Millisecond):
320		}
321	}
322	// Force-remove if the socket is still lingering.
323	_ = os.Remove(hostURL.Host)
324	return nil
325}
326
327var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)
328
329func startDetachedServer(cmd *cobra.Command) error {
330	exe, err := os.Executable()
331	if err != nil {
332		return fmt.Errorf("failed to get executable path: %v", err)
333	}
334
335	safeClientHost := safeNameRegexp.ReplaceAllString(clientHost, "_")
336	chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeClientHost)
337	if err := os.MkdirAll(chDir, 0o700); err != nil {
338		return fmt.Errorf("failed to create server working directory: %v", err)
339	}
340
341	cmdArgs := []string{"server"}
342	if clientHost != server.DefaultHost() {
343		cmdArgs = append(cmdArgs, "--host", clientHost)
344	}
345
346	c := exec.CommandContext(cmd.Context(), exe, cmdArgs...)
347	stdoutPath := filepath.Join(chDir, "stdout.log")
348	stderrPath := filepath.Join(chDir, "stderr.log")
349	detachProcess(c)
350
351	stdout, err := os.Create(stdoutPath)
352	if err != nil {
353		return fmt.Errorf("failed to create stdout log file: %v", err)
354	}
355	defer stdout.Close()
356	c.Stdout = stdout
357
358	stderr, err := os.Create(stderrPath)
359	if err != nil {
360		return fmt.Errorf("failed to create stderr log file: %v", err)
361	}
362	defer stderr.Close()
363	c.Stderr = stderr
364
365	if err := c.Start(); err != nil {
366		return fmt.Errorf("failed to start crush server: %v", err)
367	}
368
369	if err := c.Process.Release(); err != nil {
370		return fmt.Errorf("failed to detach crush server process: %v", err)
371	}
372
373	return nil
374}
375
376func shouldEnableMetrics(cfg *config.Config) bool {
377	if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
378		return false
379	}
380	if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
381		return false
382	}
383	if cfg.Options.DisableMetrics {
384		return false
385	}
386	return true
387}
388
389func MaybePrependStdin(prompt string) (string, error) {
390	if term.IsTerminal(os.Stdin.Fd()) {
391		return prompt, nil
392	}
393	fi, err := os.Stdin.Stat()
394	if err != nil {
395		return prompt, err
396	}
397	// Check if stdin is a named pipe ( | ) or regular file ( < ).
398	if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
399		return prompt, nil
400	}
401	bts, err := io.ReadAll(os.Stdin)
402	if err != nil {
403		return prompt, err
404	}
405	return string(bts) + "\n\n" + prompt, nil
406}
407
408func ResolveCwd(cmd *cobra.Command) (string, error) {
409	cwd, _ := cmd.Flags().GetString("cwd")
410	if cwd != "" {
411		err := os.Chdir(cwd)
412		if err != nil {
413			return "", fmt.Errorf("failed to change directory: %v", err)
414		}
415		return cwd, nil
416	}
417	cwd, err := os.Getwd()
418	if err != nil {
419		return "", fmt.Errorf("failed to get current working directory: %v", err)
420	}
421	return cwd, nil
422}