1package cmd
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "sync"
8 "time"
9
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/app"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/db"
14 "github.com/charmbracelet/crush/internal/format"
15 "github.com/charmbracelet/crush/internal/llm/agent"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/charmbracelet/crush/internal/pubsub"
18 "github.com/charmbracelet/crush/internal/tui"
19 "github.com/charmbracelet/crush/internal/version"
20 "github.com/spf13/cobra"
21)
22
23var rootCmd = &cobra.Command{
24 Use: "crush",
25 Short: "Terminal-based AI assistant for software development",
26 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
27It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
28to assist developers in writing, debugging, and understanding code directly from the terminal.`,
29 Example: `
30 # Run in interactive mode
31 crush
32
33 # Run with debug logging
34 crush -d
35
36 # Run with debug logging in a specific directory
37 crush -d -c /path/to/project
38
39 # Print version
40 crush -v
41
42 # Run a single non-interactive prompt
43 crush -p "Explain the use of context in Go"
44
45 # Run a single non-interactive prompt with JSON output format
46 crush -p "Explain the use of context in Go" -f json
47 `,
48 RunE: func(cmd *cobra.Command, args []string) error {
49 // If the help flag is set, show the help message
50 if cmd.Flag("help").Changed {
51 cmd.Help()
52 return nil
53 }
54 if cmd.Flag("version").Changed {
55 fmt.Println(version.Version)
56 return nil
57 }
58
59 // Load the config
60 debug, _ := cmd.Flags().GetBool("debug")
61 cwd, _ := cmd.Flags().GetString("cwd")
62 prompt, _ := cmd.Flags().GetString("prompt")
63 outputFormat, _ := cmd.Flags().GetString("output-format")
64 quiet, _ := cmd.Flags().GetBool("quiet")
65
66 // Validate format option
67 if !format.IsValid(outputFormat) {
68 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
69 }
70
71 if cwd != "" {
72 err := os.Chdir(cwd)
73 if err != nil {
74 return fmt.Errorf("failed to change directory: %v", err)
75 }
76 }
77 if cwd == "" {
78 c, err := os.Getwd()
79 if err != nil {
80 return fmt.Errorf("failed to get current working directory: %v", err)
81 }
82 cwd = c
83 }
84 _, err := config.Load(cwd, debug)
85 if err != nil {
86 return err
87 }
88
89 // Connect DB, this will also run migrations
90 conn, err := db.Connect()
91 if err != nil {
92 return err
93 }
94
95 // Create main context for the application
96 ctx, cancel := context.WithCancel(context.Background())
97 defer cancel()
98
99 app, err := app.New(ctx, conn)
100 if err != nil {
101 logging.Error("Failed to create app: %v", err)
102 return err
103 }
104 // Defer shutdown here so it runs for both interactive and non-interactive modes
105 defer app.Shutdown()
106
107 // Initialize MCP tools early for both modes
108 initMCPTools(ctx, app)
109
110 // Non-interactive mode
111 if prompt != "" {
112 // Run non-interactive flow using the App method
113 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
114 }
115
116 // Set up the TUI
117 program := tea.NewProgram(
118 tui.New(app),
119 tea.WithAltScreen(),
120 )
121
122 // Setup the subscriptions, this will send services events to the TUI
123 ch, cancelSubs := setupSubscriptions(app, ctx)
124
125 // Create a context for the TUI message handler
126 tuiCtx, tuiCancel := context.WithCancel(ctx)
127 var tuiWg sync.WaitGroup
128 tuiWg.Add(1)
129
130 // Set up message handling for the TUI
131 go func() {
132 defer tuiWg.Done()
133 defer logging.RecoverPanic("TUI-message-handler", func() {
134 attemptTUIRecovery(program)
135 })
136
137 for {
138 select {
139 case <-tuiCtx.Done():
140 logging.Info("TUI message handler shutting down")
141 return
142 case msg, ok := <-ch:
143 if !ok {
144 logging.Info("TUI message channel closed")
145 return
146 }
147 program.Send(msg)
148 }
149 }
150 }()
151
152 // Cleanup function for when the program exits
153 cleanup := func() {
154 // Shutdown the app
155 app.Shutdown()
156
157 // Cancel subscriptions first
158 cancelSubs()
159
160 // Then cancel TUI message handler
161 tuiCancel()
162
163 // Wait for TUI message handler to finish
164 tuiWg.Wait()
165
166 logging.Info("All goroutines cleaned up")
167 }
168
169 // Run the TUI
170 result, err := program.Run()
171 cleanup()
172
173 if err != nil {
174 logging.Error("TUI error: %v", err)
175 return fmt.Errorf("TUI error: %v", err)
176 }
177
178 logging.Info("TUI exited with result: %v", result)
179 return nil
180 },
181}
182
183// attemptTUIRecovery tries to recover the TUI after a panic
184func attemptTUIRecovery(program *tea.Program) {
185 logging.Info("Attempting to recover TUI after panic")
186
187 // We could try to restart the TUI or gracefully exit
188 // For now, we'll just quit the program to avoid further issues
189 program.Quit()
190}
191
192func initMCPTools(ctx context.Context, app *app.App) {
193 go func() {
194 defer logging.RecoverPanic("MCP-goroutine", nil)
195
196 // Create a context with timeout for the initial MCP tools fetch
197 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
198 defer cancel()
199
200 // Set this up once with proper error handling
201 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
202 logging.Info("MCP message handling goroutine exiting")
203 }()
204}
205
206func setupSubscriber[T any](
207 ctx context.Context,
208 wg *sync.WaitGroup,
209 name string,
210 subscriber func(context.Context) <-chan pubsub.Event[T],
211 outputCh chan<- tea.Msg,
212) {
213 wg.Add(1)
214 go func() {
215 defer wg.Done()
216 defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
217
218 subCh := subscriber(ctx)
219
220 for {
221 select {
222 case event, ok := <-subCh:
223 if !ok {
224 logging.Info("subscription channel closed", "name", name)
225 return
226 }
227
228 var msg tea.Msg = event
229
230 select {
231 case outputCh <- msg:
232 case <-time.After(2 * time.Second):
233 logging.Warn("message dropped due to slow consumer", "name", name)
234 case <-ctx.Done():
235 logging.Info("subscription cancelled", "name", name)
236 return
237 }
238 case <-ctx.Done():
239 logging.Info("subscription cancelled", "name", name)
240 return
241 }
242 }
243 }()
244}
245
246func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
247 ch := make(chan tea.Msg, 100)
248
249 wg := sync.WaitGroup{}
250 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
251
252 setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
253 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
254 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
255 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
256 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
257
258 cleanupFunc := func() {
259 logging.Info("Cancelling all subscriptions")
260 cancel() // Signal all goroutines to stop
261
262 waitCh := make(chan struct{})
263 go func() {
264 defer logging.RecoverPanic("subscription-cleanup", nil)
265 wg.Wait()
266 close(waitCh)
267 }()
268
269 select {
270 case <-waitCh:
271 logging.Info("All subscription goroutines completed successfully")
272 close(ch) // Only close after all writers are confirmed done
273 case <-time.After(5 * time.Second):
274 logging.Warn("Timed out waiting for some subscription goroutines to complete")
275 close(ch)
276 }
277 }
278 return ch, cleanupFunc
279}
280
281func Execute() {
282 err := rootCmd.Execute()
283 if err != nil {
284 os.Exit(1)
285 }
286}
287
288func init() {
289 rootCmd.Flags().BoolP("help", "h", false, "Help")
290 rootCmd.Flags().BoolP("version", "v", false, "Version")
291 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
292 rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
293 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
294
295 // Add format flag with validation logic
296 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
297 "Output format for non-interactive mode (text, json)")
298
299 // Add quiet flag to hide spinner in non-interactive mode
300 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
301
302 // Register custom validation for the format flag
303 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
304 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
305 })
306}