1package cmd
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "sync"
8 "time"
9
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/app"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/db"
14 "github.com/charmbracelet/crush/internal/format"
15 "github.com/charmbracelet/crush/internal/llm/agent"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/charmbracelet/crush/internal/pubsub"
18 "github.com/charmbracelet/crush/internal/tui"
19 "github.com/charmbracelet/crush/internal/version"
20 "github.com/charmbracelet/fang"
21 "github.com/spf13/cobra"
22)
23
24var rootCmd = &cobra.Command{
25 Use: "crush",
26 Short: "Terminal-based AI assistant for software development",
27 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
28It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
29to assist developers in writing, debugging, and understanding code directly from the terminal.`,
30 Example: `
31 # Run in interactive mode
32 crush
33
34 # Run with debug logging
35 crush -d
36
37 # Run with debug logging in a specific directory
38 crush -d -c /path/to/project
39
40 # Print version
41 crush -v
42
43 # Run a single non-interactive prompt
44 crush -p "Explain the use of context in Go"
45
46 # Run a single non-interactive prompt with JSON output format
47 crush -p "Explain the use of context in Go" -f json
48 `,
49 RunE: func(cmd *cobra.Command, args []string) error {
50 // Load the config
51 debug, _ := cmd.Flags().GetBool("debug")
52 cwd, _ := cmd.Flags().GetString("cwd")
53 prompt, _ := cmd.Flags().GetString("prompt")
54 outputFormat, _ := cmd.Flags().GetString("output-format")
55 quiet, _ := cmd.Flags().GetBool("quiet")
56
57 // Validate format option
58 if !format.IsValid(outputFormat) {
59 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
60 }
61
62 if cwd != "" {
63 err := os.Chdir(cwd)
64 if err != nil {
65 return fmt.Errorf("failed to change directory: %v", err)
66 }
67 }
68 if cwd == "" {
69 c, err := os.Getwd()
70 if err != nil {
71 return fmt.Errorf("failed to get current working directory: %v", err)
72 }
73 cwd = c
74 }
75 _, err := config.Load(cwd, debug)
76 if err != nil {
77 return err
78 }
79
80 // Create main context for the application
81 ctx, cancel := context.WithCancel(context.Background())
82 defer cancel()
83
84 // Connect DB, this will also run migrations
85 conn, err := db.Connect(ctx)
86 if err != nil {
87 return err
88 }
89
90 app, err := app.New(ctx, conn)
91 if err != nil {
92 logging.Error("Failed to create app: %v", err)
93 return err
94 }
95 // Defer shutdown here so it runs for both interactive and non-interactive modes
96 defer app.Shutdown()
97
98 // Initialize MCP tools early for both modes
99 initMCPTools(ctx, app)
100
101 // Non-interactive mode
102 if prompt != "" {
103 // Run non-interactive flow using the App method
104 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
105 }
106
107 // Set up the TUI
108 program := tea.NewProgram(
109 tui.New(app),
110 tea.WithAltScreen(),
111 tea.WithKeyReleases(),
112 tea.WithUniformKeyLayout(),
113 )
114
115 // Setup the subscriptions, this will send services events to the TUI
116 ch, cancelSubs := setupSubscriptions(app, ctx)
117
118 // Create a context for the TUI message handler
119 tuiCtx, tuiCancel := context.WithCancel(ctx)
120 var tuiWg sync.WaitGroup
121 tuiWg.Add(1)
122
123 // Set up message handling for the TUI
124 go func() {
125 defer tuiWg.Done()
126 defer logging.RecoverPanic("TUI-message-handler", func() {
127 attemptTUIRecovery(program)
128 })
129
130 for {
131 select {
132 case <-tuiCtx.Done():
133 logging.Info("TUI message handler shutting down")
134 return
135 case msg, ok := <-ch:
136 if !ok {
137 logging.Info("TUI message channel closed")
138 return
139 }
140 program.Send(msg)
141 }
142 }
143 }()
144
145 // Cleanup function for when the program exits
146 cleanup := func() {
147 // Shutdown the app
148 app.Shutdown()
149
150 // Cancel subscriptions first
151 cancelSubs()
152
153 // Then cancel TUI message handler
154 tuiCancel()
155
156 // Wait for TUI message handler to finish
157 tuiWg.Wait()
158
159 logging.Info("All goroutines cleaned up")
160 }
161
162 // Run the TUI
163 result, err := program.Run()
164 cleanup()
165
166 if err != nil {
167 logging.Error("TUI error: %v", err)
168 return fmt.Errorf("TUI error: %v", err)
169 }
170
171 logging.Info("TUI exited with result: %v", result)
172 return nil
173 },
174}
175
176// attemptTUIRecovery tries to recover the TUI after a panic
177func attemptTUIRecovery(program *tea.Program) {
178 logging.Info("Attempting to recover TUI after panic")
179
180 // We could try to restart the TUI or gracefully exit
181 // For now, we'll just quit the program to avoid further issues
182 program.Quit()
183}
184
185func initMCPTools(ctx context.Context, app *app.App) {
186 go func() {
187 defer logging.RecoverPanic("MCP-goroutine", nil)
188
189 // Create a context with timeout for the initial MCP tools fetch
190 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
191 defer cancel()
192
193 // Set this up once with proper error handling
194 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
195 logging.Info("MCP message handling goroutine exiting")
196 }()
197}
198
199func setupSubscriber[T any](
200 ctx context.Context,
201 wg *sync.WaitGroup,
202 name string,
203 subscriber func(context.Context) <-chan pubsub.Event[T],
204 outputCh chan<- tea.Msg,
205) {
206 wg.Add(1)
207 go func() {
208 defer wg.Done()
209 defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
210
211 subCh := subscriber(ctx)
212
213 for {
214 select {
215 case event, ok := <-subCh:
216 if !ok {
217 logging.Info("subscription channel closed", "name", name)
218 return
219 }
220
221 var msg tea.Msg = event
222
223 select {
224 case outputCh <- msg:
225 case <-time.After(2 * time.Second):
226 logging.Warn("message dropped due to slow consumer", "name", name)
227 case <-ctx.Done():
228 logging.Info("subscription cancelled", "name", name)
229 return
230 }
231 case <-ctx.Done():
232 logging.Info("subscription cancelled", "name", name)
233 return
234 }
235 }
236 }()
237}
238
239func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
240 ch := make(chan tea.Msg, 100)
241
242 wg := sync.WaitGroup{}
243 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
244
245 setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
246 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
247 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
248 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
249 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
250 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
251
252 cleanupFunc := func() {
253 logging.Info("Cancelling all subscriptions")
254 cancel() // Signal all goroutines to stop
255
256 waitCh := make(chan struct{})
257 go func() {
258 defer logging.RecoverPanic("subscription-cleanup", nil)
259 wg.Wait()
260 close(waitCh)
261 }()
262
263 select {
264 case <-waitCh:
265 logging.Info("All subscription goroutines completed successfully")
266 close(ch) // Only close after all writers are confirmed done
267 case <-time.After(5 * time.Second):
268 logging.Warn("Timed out waiting for some subscription goroutines to complete")
269 close(ch)
270 }
271 }
272 return ch, cleanupFunc
273}
274
275func Execute() {
276 if err := fang.Execute(
277 context.Background(),
278 rootCmd,
279 fang.WithVersion(version.Version),
280 ); err != nil {
281 os.Exit(1)
282 }
283}
284
285func init() {
286 rootCmd.Flags().BoolP("help", "h", false, "Help")
287 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
288 rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
289 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
290
291 // Add format flag with validation logic
292 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
293 "Output format for non-interactive mode (text, json)")
294
295 // Add quiet flag to hide spinner in non-interactive mode
296 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
297
298 // Register custom validation for the format flag
299 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
300 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
301 })
302}