1package cmd
2
import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"os"
	"strings"
	"sync"
	"time"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/tui"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/charmbracelet/fang"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
26
27var rootCmd = &cobra.Command{
28 Use: "crush",
29 Short: "Terminal-based AI assistant for software development",
30 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
31It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
32to assist developers in writing, debugging, and understanding code directly from the terminal.`,
33 Example: `
34 # Run in interactive mode
35 crush
36
37 # Run with debug logging
38 crush -d
39
40 # Run with debug slog.in a specific directory
41 crush -d -c /path/to/project
42
43 # Print version
44 crush -v
45
46 # Run a single non-interactive prompt
47 crush -p "Explain the use of context in Go"
48
49 # Run a single non-interactive prompt with JSON output format
50 crush -p "Explain the use of context in Go" -f json
51 `,
52 RunE: func(cmd *cobra.Command, args []string) error {
53 // Load the config
54 debug, _ := cmd.Flags().GetBool("debug")
55 cwd, _ := cmd.Flags().GetString("cwd")
56 prompt, _ := cmd.Flags().GetString("prompt")
57 outputFormat, _ := cmd.Flags().GetString("output-format")
58 quiet, _ := cmd.Flags().GetBool("quiet")
59
60 // Validate format option
61 if !format.IsValid(outputFormat) {
62 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
63 }
64
65 if cwd != "" {
66 err := os.Chdir(cwd)
67 if err != nil {
68 return fmt.Errorf("failed to change directory: %v", err)
69 }
70 }
71 if cwd == "" {
72 c, err := os.Getwd()
73 if err != nil {
74 return fmt.Errorf("failed to get current working directory: %v", err)
75 }
76 cwd = c
77 }
78
79 _, err := config.Init(cwd, debug)
80 if err != nil {
81 return err
82 }
83
84 // Create main context for the application
85 ctx, cancel := context.WithCancel(context.Background())
86 defer cancel()
87
88 // Connect DB, this will also run migrations
89 conn, err := db.Connect(ctx)
90 if err != nil {
91 return err
92 }
93
94 app, err := app.New(ctx, conn)
95 if err != nil {
96 slog.Error("Failed to create app: %v", err)
97 return err
98 }
99 // Defer shutdown here so it runs for both interactive and non-interactive modes
100 defer app.Shutdown()
101
102 // Initialize MCP tools early for both modes
103 initMCPTools(ctx, app)
104
105 prompt, err = maybePrependStdin(prompt)
106 if err != nil {
107 slog.Error("Failed to read stdin: %v", err)
108 return err
109 }
110
111 // Non-interactive mode
112 if prompt != "" {
113 // Run non-interactive flow using the App method
114 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
115 }
116
117 // Set up the TUI
118 program := tea.NewProgram(
119 tui.New(app),
120 tea.WithAltScreen(),
121 tea.WithKeyReleases(),
122 tea.WithUniformKeyLayout(),
123 )
124
125 // Setup the subscriptions, this will send services events to the TUI
126 ch, cancelSubs := setupSubscriptions(app, ctx)
127
128 // Create a context for the TUI message handler
129 tuiCtx, tuiCancel := context.WithCancel(ctx)
130 var tuiWg sync.WaitGroup
131 tuiWg.Add(1)
132
133 // Set up message handling for the TUI
134 go func() {
135 defer tuiWg.Done()
136 defer log.RecoverPanic("TUI-message-handler", func() {
137 attemptTUIRecovery(program)
138 })
139
140 for {
141 select {
142 case <-tuiCtx.Done():
143 slog.Info("TUI message handler shutting down")
144 return
145 case msg, ok := <-ch:
146 if !ok {
147 slog.Info("TUI message channel closed")
148 return
149 }
150 program.Send(msg)
151 }
152 }
153 }()
154
155 // Cleanup function for when the program exits
156 cleanup := func() {
157 // Shutdown the app
158 app.Shutdown()
159
160 // Cancel subscriptions first
161 cancelSubs()
162
163 // Then cancel TUI message handler
164 tuiCancel()
165
166 // Wait for TUI message handler to finish
167 tuiWg.Wait()
168
169 slog.Info("All goroutines cleaned up")
170 }
171
172 // Run the TUI
173 result, err := program.Run()
174 cleanup()
175
176 if err != nil {
177 slog.Error("TUI error: %v", err)
178 return fmt.Errorf("TUI error: %v", err)
179 }
180
181 slog.Info("TUI exited with result: %v", result)
182 return nil
183 },
184}
185
186// attemptTUIRecovery tries to recover the TUI after a panic
187func attemptTUIRecovery(program *tea.Program) {
188 slog.Info("Attempting to recover TUI after panic")
189
190 // We could try to restart the TUI or gracefully exit
191 // For now, we'll just quit the program to avoid further issues
192 program.Quit()
193}
194
195func initMCPTools(ctx context.Context, app *app.App) {
196 go func() {
197 defer log.RecoverPanic("MCP-goroutine", nil)
198
199 // Create a context with timeout for the initial MCP tools fetch
200 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
201 defer cancel()
202
203 // Set this up once with proper error handling
204 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
205 slog.Info("MCP message handling goroutine exiting")
206 }()
207}
208
209func setupSubscriber[T any](
210 ctx context.Context,
211 wg *sync.WaitGroup,
212 name string,
213 subscriber func(context.Context) <-chan pubsub.Event[T],
214 outputCh chan<- tea.Msg,
215) {
216 wg.Add(1)
217 go func() {
218 defer wg.Done()
219 defer log.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
220
221 subCh := subscriber(ctx)
222
223 for {
224 select {
225 case event, ok := <-subCh:
226 if !ok {
227 slog.Info("subscription channel closed", "name", name)
228 return
229 }
230
231 var msg tea.Msg = event
232
233 select {
234 case outputCh <- msg:
235 case <-time.After(2 * time.Second):
236 slog.Warn("message dropped due to slow consumer", "name", name)
237 case <-ctx.Done():
238 slog.Info("subscription cancelled", "name", name)
239 return
240 }
241 case <-ctx.Done():
242 slog.Info("subscription cancelled", "name", name)
243 return
244 }
245 }
246 }()
247}
248
249func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
250 ch := make(chan tea.Msg, 100)
251
252 wg := sync.WaitGroup{}
253 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
254
255 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
256 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
257 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
258 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
259 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
260
261 cleanupFunc := func() {
262 slog.Info("Cancelling all subscriptions")
263 cancel() // Signal all goroutines to stop
264
265 waitCh := make(chan struct{})
266 go func() {
267 defer log.RecoverPanic("subscription-cleanup", nil)
268 wg.Wait()
269 close(waitCh)
270 }()
271
272 select {
273 case <-waitCh:
274 slog.Info("All subscription goroutines completed successfully")
275 close(ch) // Only close after all writers are confirmed done
276 case <-time.After(5 * time.Second):
277 slog.Warn("Timed out waiting for some subscription goroutines to complete")
278 close(ch)
279 }
280 }
281 return ch, cleanupFunc
282}
283
284func Execute() {
285 if err := fang.Execute(
286 context.Background(),
287 rootCmd,
288 fang.WithVersion(version.Version),
289 ); err != nil {
290 os.Exit(1)
291 }
292}
293
294func init() {
295 rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
296
297 rootCmd.Flags().BoolP("help", "h", false, "Help")
298 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
299 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
300
301 // Add format flag with validation logic
302 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
303 "Output format for non-interactive mode (text, json)")
304
305 // Add quiet flag to hide spinner in non-interactive mode
306 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
307
308 // Register custom validation for the format flag
309 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
310 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
311 })
312}
313
314func maybePrependStdin(prompt string) (string, error) {
315 if term.IsTerminal(os.Stdin.Fd()) {
316 return prompt, nil
317 }
318 fi, err := os.Stdin.Stat()
319 if err != nil {
320 return prompt, err
321 }
322 if fi.Mode()&os.ModeNamedPipe == 0 {
323 return prompt, nil
324 }
325 bts, err := io.ReadAll(os.Stdin)
326 if err != nil {
327 return prompt, err
328 }
329 return string(bts) + "\n\n" + prompt, nil
330}