1package cmd
2
import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/tui"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/spf13/cobra"
)
22
23var rootCmd = &cobra.Command{
24 Use: "crush",
25 Short: "Terminal-based AI assistant for software development",
26 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
27It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
28to assist developers in writing, debugging, and understanding code directly from the terminal.`,
29 Example: `
30 # Run in interactive mode
31 crush
32
33 # Run with debug logging
34 crush -d
35
36 # Run with debug logging in a specific directory
37 crush -d -c /path/to/project
38
39 # Print version
40 crush -v
41
42 # Run a single non-interactive prompt
43 crush -p "Explain the use of context in Go"
44
45 # Run a single non-interactive prompt with JSON output format
46 crush -p "Explain the use of context in Go" -f json
47 `,
48 RunE: func(cmd *cobra.Command, args []string) error {
49 // If the help flag is set, show the help message
50 if cmd.Flag("help").Changed {
51 cmd.Help()
52 return nil
53 }
54 if cmd.Flag("version").Changed {
55 fmt.Println(version.Version)
56 return nil
57 }
58
59 // Load the config
60 debug, _ := cmd.Flags().GetBool("debug")
61 cwd, _ := cmd.Flags().GetString("cwd")
62 prompt, _ := cmd.Flags().GetString("prompt")
63 outputFormat, _ := cmd.Flags().GetString("output-format")
64 quiet, _ := cmd.Flags().GetBool("quiet")
65
66 // Validate format option
67 if !format.IsValid(outputFormat) {
68 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
69 }
70
71 if cwd != "" {
72 err := os.Chdir(cwd)
73 if err != nil {
74 return fmt.Errorf("failed to change directory: %v", err)
75 }
76 }
77 if cwd == "" {
78 c, err := os.Getwd()
79 if err != nil {
80 return fmt.Errorf("failed to get current working directory: %v", err)
81 }
82 cwd = c
83 }
84 _, err := config.Load(cwd, debug)
85 if err != nil {
86 return err
87 }
88
89 // Connect DB, this will also run migrations
90 conn, err := db.Connect()
91 if err != nil {
92 return err
93 }
94
95 // Create main context for the application
96 ctx, cancel := context.WithCancel(context.Background())
97 defer cancel()
98
99 app, err := app.New(ctx, conn)
100 if err != nil {
101 logging.Error("Failed to create app: %v", err)
102 return err
103 }
104 // Defer shutdown here so it runs for both interactive and non-interactive modes
105 defer app.Shutdown()
106
107 // Initialize MCP tools early for both modes
108 initMCPTools(ctx, app)
109
110 // Non-interactive mode
111 if prompt != "" {
112 // Run non-interactive flow using the App method
113 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
114 }
115
116 // Set up the TUI
117 program := tea.NewProgram(
118 tui.New(app),
119 tea.WithAltScreen(),
120 tea.WithKeyReleases(),
121 tea.WithUniformKeyLayout(),
122 )
123
124 // Setup the subscriptions, this will send services events to the TUI
125 ch, cancelSubs := setupSubscriptions(app, ctx)
126
127 // Create a context for the TUI message handler
128 tuiCtx, tuiCancel := context.WithCancel(ctx)
129 var tuiWg sync.WaitGroup
130 tuiWg.Add(1)
131
132 // Set up message handling for the TUI
133 go func() {
134 defer tuiWg.Done()
135 defer logging.RecoverPanic("TUI-message-handler", func() {
136 attemptTUIRecovery(program)
137 })
138
139 for {
140 select {
141 case <-tuiCtx.Done():
142 logging.Info("TUI message handler shutting down")
143 return
144 case msg, ok := <-ch:
145 if !ok {
146 logging.Info("TUI message channel closed")
147 return
148 }
149 program.Send(msg)
150 }
151 }
152 }()
153
154 // Cleanup function for when the program exits
155 cleanup := func() {
156 // Shutdown the app
157 app.Shutdown()
158
159 // Cancel subscriptions first
160 cancelSubs()
161
162 // Then cancel TUI message handler
163 tuiCancel()
164
165 // Wait for TUI message handler to finish
166 tuiWg.Wait()
167
168 logging.Info("All goroutines cleaned up")
169 }
170
171 // Run the TUI
172 result, err := program.Run()
173 cleanup()
174
175 if err != nil {
176 logging.Error("TUI error: %v", err)
177 return fmt.Errorf("TUI error: %v", err)
178 }
179
180 logging.Info("TUI exited with result: %v", result)
181 return nil
182 },
183}
184
185// attemptTUIRecovery tries to recover the TUI after a panic
186func attemptTUIRecovery(program *tea.Program) {
187 logging.Info("Attempting to recover TUI after panic")
188
189 // We could try to restart the TUI or gracefully exit
190 // For now, we'll just quit the program to avoid further issues
191 program.Quit()
192}
193
194func initMCPTools(ctx context.Context, app *app.App) {
195 go func() {
196 defer logging.RecoverPanic("MCP-goroutine", nil)
197
198 // Create a context with timeout for the initial MCP tools fetch
199 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
200 defer cancel()
201
202 // Set this up once with proper error handling
203 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
204 logging.Info("MCP message handling goroutine exiting")
205 }()
206}
207
208func setupSubscriber[T any](
209 ctx context.Context,
210 wg *sync.WaitGroup,
211 name string,
212 subscriber func(context.Context) <-chan pubsub.Event[T],
213 outputCh chan<- tea.Msg,
214) {
215 wg.Add(1)
216 go func() {
217 defer wg.Done()
218 defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
219
220 subCh := subscriber(ctx)
221
222 for {
223 select {
224 case event, ok := <-subCh:
225 if !ok {
226 logging.Info("subscription channel closed", "name", name)
227 return
228 }
229
230 var msg tea.Msg = event
231
232 select {
233 case outputCh <- msg:
234 case <-time.After(2 * time.Second):
235 logging.Warn("message dropped due to slow consumer", "name", name)
236 case <-ctx.Done():
237 logging.Info("subscription cancelled", "name", name)
238 return
239 }
240 case <-ctx.Done():
241 logging.Info("subscription cancelled", "name", name)
242 return
243 }
244 }
245 }()
246}
247
248func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
249 ch := make(chan tea.Msg, 100)
250
251 wg := sync.WaitGroup{}
252 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
253
254 setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
255 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
256 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
257 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
258 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
259 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
260
261 cleanupFunc := func() {
262 logging.Info("Cancelling all subscriptions")
263 cancel() // Signal all goroutines to stop
264
265 waitCh := make(chan struct{})
266 go func() {
267 defer logging.RecoverPanic("subscription-cleanup", nil)
268 wg.Wait()
269 close(waitCh)
270 }()
271
272 select {
273 case <-waitCh:
274 logging.Info("All subscription goroutines completed successfully")
275 close(ch) // Only close after all writers are confirmed done
276 case <-time.After(5 * time.Second):
277 logging.Warn("Timed out waiting for some subscription goroutines to complete")
278 close(ch)
279 }
280 }
281 return ch, cleanupFunc
282}
283
284func Execute() {
285 err := rootCmd.Execute()
286 if err != nil {
287 os.Exit(1)
288 }
289}
290
291func init() {
292 rootCmd.Flags().BoolP("help", "h", false, "Help")
293 rootCmd.Flags().BoolP("version", "v", false, "Version")
294 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
295 rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
296 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
297
298 // Add format flag with validation logic
299 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
300 "Output format for non-interactive mode (text, json)")
301
302 // Add quiet flag to hide spinner in non-interactive mode
303 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
304
305 // Register custom validation for the format flag
306 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
307 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
308 })
309}