1package cmd
2
3import (
4 "context"
5 "fmt"
6 "io"
7 "log/slog"
8 "os"
9 "sync"
10 "time"
11
12 tea "github.com/charmbracelet/bubbletea/v2"
13 "github.com/charmbracelet/crush/internal/app"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/db"
16 "github.com/charmbracelet/crush/internal/format"
17 "github.com/charmbracelet/crush/internal/llm/agent"
18 "github.com/charmbracelet/crush/internal/log"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/tui"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/charmbracelet/fang"
23 "github.com/charmbracelet/x/term"
24 "github.com/spf13/cobra"
25)
26
27var rootCmd = &cobra.Command{
28 Use: "crush",
29 Short: "Terminal-based AI assistant for software development",
30 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
31It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
32to assist developers in writing, debugging, and understanding code directly from the terminal.`,
33 Example: `
34 # Run in interactive mode
35 crush
36
37 # Run with debug logging
38 crush -d
39
40 # Run with debug slog.in a specific directory
41 crush -d -c /path/to/project
42
43 # Print version
44 crush -v
45
46 # Run a single non-interactive prompt
47 crush -p "Explain the use of context in Go"
48
49 # Run a single non-interactive prompt with JSON output format
50 crush -p "Explain the use of context in Go" -f json
51 `,
52 RunE: func(cmd *cobra.Command, args []string) error {
53 // Load the config
54 debug, _ := cmd.Flags().GetBool("debug")
55 cwd, _ := cmd.Flags().GetString("cwd")
56 prompt, _ := cmd.Flags().GetString("prompt")
57 outputFormat, _ := cmd.Flags().GetString("output-format")
58 quiet, _ := cmd.Flags().GetBool("quiet")
59
60 // Validate format option
61 if !format.IsValid(outputFormat) {
62 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
63 }
64
65 if cwd != "" {
66 err := os.Chdir(cwd)
67 if err != nil {
68 return fmt.Errorf("failed to change directory: %v", err)
69 }
70 }
71 if cwd == "" {
72 c, err := os.Getwd()
73 if err != nil {
74 return fmt.Errorf("failed to get current working directory: %v", err)
75 }
76 cwd = c
77 }
78
79 cfg, err := config.Init(cwd, debug)
80 if err != nil {
81 return err
82 }
83
84 // Create main context for the application
85 ctx, cancel := context.WithCancel(context.Background())
86 defer cancel()
87
88 // Connect DB, this will also run migrations
89 conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
90 if err != nil {
91 return err
92 }
93
94 app, err := app.New(ctx, conn, cfg)
95 if err != nil {
96 slog.Error(fmt.Sprintf("Failed to create app instance: %v", err))
97 return err
98 }
99 // Defer shutdown here so it runs for both interactive and non-interactive modes
100 defer app.Shutdown()
101
102 // Initialize MCP tools early for both modes
103 initMCPTools(ctx, app, cfg)
104
105 prompt, err = maybePrependStdin(prompt)
106 if err != nil {
107 slog.Error(fmt.Sprintf("Failed to read from stdin: %v", err))
108 return err
109 }
110
111 // Non-interactive mode
112 if prompt != "" {
113 // Run non-interactive flow using the App method
114 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
115 }
116
117 // Set up the TUI
118 program := tea.NewProgram(
119 tui.New(app),
120 tea.WithAltScreen(),
121 tea.WithKeyReleases(),
122 tea.WithUniformKeyLayout(),
123 tea.WithMouseCellMotion(), // Use cell motion instead of all motion to reduce event flooding
124 tea.WithFilter(tui.MouseEventFilter), // Filter mouse events based on focus state
125 )
126
127 // Setup the subscriptions, this will send services events to the TUI
128 ch, cancelSubs := setupSubscriptions(app, ctx)
129
130 // Create a context for the TUI message handler
131 tuiCtx, tuiCancel := context.WithCancel(ctx)
132 var tuiWg sync.WaitGroup
133 tuiWg.Add(1)
134
135 // Set up message handling for the TUI
136 go func() {
137 defer tuiWg.Done()
138 defer log.RecoverPanic("TUI-message-handler", func() {
139 attemptTUIRecovery(program)
140 })
141
142 for {
143 select {
144 case <-tuiCtx.Done():
145 slog.Info("TUI message handler shutting down")
146 return
147 case msg, ok := <-ch:
148 if !ok {
149 slog.Info("TUI message channel closed")
150 return
151 }
152 program.Send(msg)
153 }
154 }
155 }()
156
157 // Cleanup function for when the program exits
158 cleanup := func() {
159 // Shutdown the app
160 app.Shutdown()
161
162 // Cancel subscriptions first
163 cancelSubs()
164
165 // Then cancel TUI message handler
166 tuiCancel()
167
168 // Wait for TUI message handler to finish
169 tuiWg.Wait()
170
171 slog.Info("All goroutines cleaned up")
172 }
173
174 // Run the TUI
175 result, err := program.Run()
176 cleanup()
177
178 if err != nil {
179 slog.Error(fmt.Sprintf("TUI run error: %v", err))
180 return fmt.Errorf("TUI error: %v", err)
181 }
182
183 slog.Info(fmt.Sprintf("TUI exited with result: %v", result))
184 return nil
185 },
186}
187
188// attemptTUIRecovery tries to recover the TUI after a panic
189func attemptTUIRecovery(program *tea.Program) {
190 slog.Info("Attempting to recover TUI after panic")
191
192 // We could try to restart the TUI or gracefully exit
193 // For now, we'll just quit the program to avoid further issues
194 program.Quit()
195}
196
197func initMCPTools(ctx context.Context, app *app.App, cfg *config.Config) {
198 go func() {
199 defer log.RecoverPanic("MCP-goroutine", nil)
200
201 // Create a context with timeout for the initial MCP tools fetch
202 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
203 defer cancel()
204
205 // Set this up once with proper error handling
206 agent.GetMcpTools(ctxWithTimeout, app.Permissions, cfg)
207 slog.Info("MCP message handling goroutine exiting")
208 }()
209}
210
211func setupSubscriber[T any](
212 ctx context.Context,
213 wg *sync.WaitGroup,
214 name string,
215 subscriber func(context.Context) <-chan pubsub.Event[T],
216 outputCh chan<- tea.Msg,
217) {
218 wg.Add(1)
219 go func() {
220 defer wg.Done()
221 defer log.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
222
223 subCh := subscriber(ctx)
224
225 for {
226 select {
227 case event, ok := <-subCh:
228 if !ok {
229 slog.Info("subscription channel closed", "name", name)
230 return
231 }
232
233 var msg tea.Msg = event
234
235 select {
236 case outputCh <- msg:
237 case <-time.After(2 * time.Second):
238 slog.Warn("message dropped due to slow consumer", "name", name)
239 case <-ctx.Done():
240 slog.Info("subscription canceled", "name", name)
241 return
242 }
243 case <-ctx.Done():
244 slog.Info("subscription canceled", "name", name)
245 return
246 }
247 }
248 }()
249}
250
251func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
252 ch := make(chan tea.Msg, 100)
253
254 wg := sync.WaitGroup{}
255 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
256
257 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
258 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
259 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
260 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
261 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
262
263 cleanupFunc := func() {
264 slog.Info("Cancelling all subscriptions")
265 cancel() // Signal all goroutines to stop
266
267 waitCh := make(chan struct{})
268 go func() {
269 defer log.RecoverPanic("subscription-cleanup", nil)
270 wg.Wait()
271 close(waitCh)
272 }()
273
274 select {
275 case <-waitCh:
276 slog.Info("All subscription goroutines completed successfully")
277 close(ch) // Only close after all writers are confirmed done
278 case <-time.After(5 * time.Second):
279 slog.Warn("Timed out waiting for some subscription goroutines to complete")
280 close(ch)
281 }
282 }
283 return ch, cleanupFunc
284}
285
286func Execute() {
287 if err := fang.Execute(
288 context.Background(),
289 rootCmd,
290 fang.WithVersion(version.Version),
291 ); err != nil {
292 os.Exit(1)
293 }
294}
295
296func init() {
297 rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
298
299 rootCmd.Flags().BoolP("help", "h", false, "Help")
300 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
301 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
302
303 // Add format flag with validation logic
304 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
305 "Output format for non-interactive mode (text, json)")
306
307 // Add quiet flag to hide spinner in non-interactive mode
308 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
309
310 // Register custom validation for the format flag
311 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
312 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
313 })
314}
315
316func maybePrependStdin(prompt string) (string, error) {
317 if term.IsTerminal(os.Stdin.Fd()) {
318 return prompt, nil
319 }
320 fi, err := os.Stdin.Stat()
321 if err != nil {
322 return prompt, err
323 }
324 if fi.Mode()&os.ModeNamedPipe == 0 {
325 return prompt, nil
326 }
327 bts, err := io.ReadAll(os.Stdin)
328 if err != nil {
329 return prompt, err
330 }
331 return string(bts) + "\n\n" + prompt, nil
332}