1package cmd
2
3import (
4 "context"
5 "fmt"
6 "io"
7 "os"
8 "sync"
9 "time"
10
11 tea "github.com/charmbracelet/bubbletea/v2"
12 "github.com/charmbracelet/crush/internal/app"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/db"
15 "github.com/charmbracelet/crush/internal/format"
16 "github.com/charmbracelet/crush/internal/llm/agent"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/pubsub"
19 "github.com/charmbracelet/crush/internal/tui"
20 "github.com/charmbracelet/crush/internal/version"
21 "github.com/charmbracelet/fang"
22 "github.com/charmbracelet/x/term"
23 "github.com/spf13/cobra"
24)
25
26var rootCmd = &cobra.Command{
27 Use: "crush",
28 Short: "Terminal-based AI assistant for software development",
29 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
30It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
31to assist developers in writing, debugging, and understanding code directly from the terminal.`,
32 Example: `
33 # Run in interactive mode
34 crush
35
36 # Run with debug logging
37 crush -d
38
39 # Run with debug logging in a specific directory
40 crush -d -c /path/to/project
41
42 # Print version
43 crush -v
44
45 # Run a single non-interactive prompt
46 crush -p "Explain the use of context in Go"
47
48 # Run a single non-interactive prompt with JSON output format
49 crush -p "Explain the use of context in Go" -f json
50 `,
51 RunE: func(cmd *cobra.Command, args []string) error {
52 // Load the config
53 debug, _ := cmd.Flags().GetBool("debug")
54 cwd, _ := cmd.Flags().GetString("cwd")
55 prompt, _ := cmd.Flags().GetString("prompt")
56 outputFormat, _ := cmd.Flags().GetString("output-format")
57 quiet, _ := cmd.Flags().GetBool("quiet")
58
59 // Validate format option
60 if !format.IsValid(outputFormat) {
61 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
62 }
63
64 if cwd != "" {
65 err := os.Chdir(cwd)
66 if err != nil {
67 return fmt.Errorf("failed to change directory: %v", err)
68 }
69 }
70 if cwd == "" {
71 c, err := os.Getwd()
72 if err != nil {
73 return fmt.Errorf("failed to get current working directory: %v", err)
74 }
75 cwd = c
76 }
77
78 _, err := config.Init(cwd, debug)
79 if err != nil {
80 return err
81 }
82
83 // Create main context for the application
84 ctx, cancel := context.WithCancel(context.Background())
85 defer cancel()
86
87 // Connect DB, this will also run migrations
88 conn, err := db.Connect(ctx)
89 if err != nil {
90 return err
91 }
92
93 app, err := app.New(ctx, conn)
94 if err != nil {
95 logging.Error("Failed to create app: %v", err)
96 return err
97 }
98 // Defer shutdown here so it runs for both interactive and non-interactive modes
99 defer app.Shutdown()
100
101 // Initialize MCP tools early for both modes
102 initMCPTools(ctx, app)
103
104 prompt, err = maybePrependStdin(prompt)
105 if err != nil {
106 logging.Error("Failed to read stdin: %v", err)
107 return err
108 }
109
110 // Non-interactive mode
111 if prompt != "" {
112 // Run non-interactive flow using the App method
113 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
114 }
115
116 // Set up the TUI
117 program := tea.NewProgram(
118 tui.New(app),
119 tea.WithAltScreen(),
120 tea.WithKeyReleases(),
121 tea.WithUniformKeyLayout(),
122 )
123
124 // Setup the subscriptions, this will send services events to the TUI
125 ch, cancelSubs := setupSubscriptions(app, ctx)
126
127 // Create a context for the TUI message handler
128 tuiCtx, tuiCancel := context.WithCancel(ctx)
129 var tuiWg sync.WaitGroup
130 tuiWg.Add(1)
131
132 // Set up message handling for the TUI
133 go func() {
134 defer tuiWg.Done()
135 defer logging.RecoverPanic("TUI-message-handler", func() {
136 attemptTUIRecovery(program)
137 })
138
139 for {
140 select {
141 case <-tuiCtx.Done():
142 logging.Info("TUI message handler shutting down")
143 return
144 case msg, ok := <-ch:
145 if !ok {
146 logging.Info("TUI message channel closed")
147 return
148 }
149 program.Send(msg)
150 }
151 }
152 }()
153
154 // Cleanup function for when the program exits
155 cleanup := func() {
156 // Shutdown the app
157 app.Shutdown()
158
159 // Cancel subscriptions first
160 cancelSubs()
161
162 // Then cancel TUI message handler
163 tuiCancel()
164
165 // Wait for TUI message handler to finish
166 tuiWg.Wait()
167
168 logging.Info("All goroutines cleaned up")
169 }
170
171 // Run the TUI
172 result, err := program.Run()
173 cleanup()
174
175 if err != nil {
176 logging.Error("TUI error: %v", err)
177 return fmt.Errorf("TUI error: %v", err)
178 }
179
180 logging.Info("TUI exited with result: %v", result)
181 return nil
182 },
183}
184
185// attemptTUIRecovery tries to recover the TUI after a panic
186func attemptTUIRecovery(program *tea.Program) {
187 logging.Info("Attempting to recover TUI after panic")
188
189 // We could try to restart the TUI or gracefully exit
190 // For now, we'll just quit the program to avoid further issues
191 program.Quit()
192}
193
194func initMCPTools(ctx context.Context, app *app.App) {
195 go func() {
196 defer logging.RecoverPanic("MCP-goroutine", nil)
197
198 // Create a context with timeout for the initial MCP tools fetch
199 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
200 defer cancel()
201
202 // Set this up once with proper error handling
203 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
204 logging.Info("MCP message handling goroutine exiting")
205 }()
206}
207
208func setupSubscriber[T any](
209 ctx context.Context,
210 wg *sync.WaitGroup,
211 name string,
212 subscriber func(context.Context) <-chan pubsub.Event[T],
213 outputCh chan<- tea.Msg,
214) {
215 wg.Add(1)
216 go func() {
217 defer wg.Done()
218 defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
219
220 subCh := subscriber(ctx)
221
222 for {
223 select {
224 case event, ok := <-subCh:
225 if !ok {
226 logging.Info("subscription channel closed", "name", name)
227 return
228 }
229
230 var msg tea.Msg = event
231
232 select {
233 case outputCh <- msg:
234 case <-time.After(2 * time.Second):
235 logging.Warn("message dropped due to slow consumer", "name", name)
236 case <-ctx.Done():
237 logging.Info("subscription cancelled", "name", name)
238 return
239 }
240 case <-ctx.Done():
241 logging.Info("subscription cancelled", "name", name)
242 return
243 }
244 }
245 }()
246}
247
248func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
249 ch := make(chan tea.Msg, 100)
250
251 wg := sync.WaitGroup{}
252 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
253
254 setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
255 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
256 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
257 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
258 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
259 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
260
261 cleanupFunc := func() {
262 logging.Info("Cancelling all subscriptions")
263 cancel() // Signal all goroutines to stop
264
265 waitCh := make(chan struct{})
266 go func() {
267 defer logging.RecoverPanic("subscription-cleanup", nil)
268 wg.Wait()
269 close(waitCh)
270 }()
271
272 select {
273 case <-waitCh:
274 logging.Info("All subscription goroutines completed successfully")
275 close(ch) // Only close after all writers are confirmed done
276 case <-time.After(5 * time.Second):
277 logging.Warn("Timed out waiting for some subscription goroutines to complete")
278 close(ch)
279 }
280 }
281 return ch, cleanupFunc
282}
283
284func Execute() {
285 if err := fang.Execute(
286 context.Background(),
287 rootCmd,
288 fang.WithVersion(version.Version),
289 ); err != nil {
290 os.Exit(1)
291 }
292}
293
294func init() {
295 rootCmd.Flags().BoolP("help", "h", false, "Help")
296 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
297 rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
298 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
299
300 // Add format flag with validation logic
301 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
302 "Output format for non-interactive mode (text, json)")
303
304 // Add quiet flag to hide spinner in non-interactive mode
305 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
306
307 // Register custom validation for the format flag
308 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
309 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
310 })
311}
312
313func maybePrependStdin(prompt string) (string, error) {
314 if term.IsTerminal(os.Stdin.Fd()) {
315 return prompt, nil
316 }
317 fi, err := os.Stdin.Stat()
318 if err != nil {
319 return prompt, err
320 }
321 if fi.Mode()&os.ModeNamedPipe == 0 {
322 return prompt, nil
323 }
324 bts, err := io.ReadAll(os.Stdin)
325 if err != nil {
326 return prompt, err
327 }
328 return string(bts) + "\n\n" + prompt, nil
329}