1package cmd
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "sync"
8 "time"
9
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/crush/internal/app"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/db"
14 "github.com/charmbracelet/crush/internal/format"
15 "github.com/charmbracelet/crush/internal/llm/agent"
16 "github.com/charmbracelet/crush/internal/logging"
17 "github.com/charmbracelet/crush/internal/pubsub"
18 "github.com/charmbracelet/crush/internal/tui"
19 "github.com/charmbracelet/crush/internal/version"
20 "github.com/charmbracelet/fang"
21 "github.com/spf13/cobra"
22)
23
24var rootCmd = &cobra.Command{
25 Use: "crush",
26 Short: "Terminal-based AI assistant for software development",
27 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
28It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
29to assist developers in writing, debugging, and understanding code directly from the terminal.`,
30 Example: `
31 # Run in interactive mode
32 crush
33
34 # Run with debug logging
35 crush -d
36
37 # Run with debug logging in a specific directory
38 crush -d -c /path/to/project
39
40 # Print version
41 crush -v
42
43 # Run a single non-interactive prompt
44 crush -p "Explain the use of context in Go"
45
46 # Run a single non-interactive prompt with JSON output format
47 crush -p "Explain the use of context in Go" -f json
48 `,
49 RunE: func(cmd *cobra.Command, args []string) error {
50 // Load the config
51 debug, _ := cmd.Flags().GetBool("debug")
52 cwd, _ := cmd.Flags().GetString("cwd")
53 prompt, _ := cmd.Flags().GetString("prompt")
54 outputFormat, _ := cmd.Flags().GetString("output-format")
55 quiet, _ := cmd.Flags().GetBool("quiet")
56
57 // Validate format option
58 if !format.IsValid(outputFormat) {
59 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
60 }
61
62 if cwd != "" {
63 err := os.Chdir(cwd)
64 if err != nil {
65 return fmt.Errorf("failed to change directory: %v", err)
66 }
67 }
68 if cwd == "" {
69 c, err := os.Getwd()
70 if err != nil {
71 return fmt.Errorf("failed to get current working directory: %v", err)
72 }
73 cwd = c
74 }
75
76 _, err := config.Init(cwd, debug)
77 if err != nil {
78 return err
79 }
80
81 // Connect DB, this will also run migrations
82 conn, err := db.Connect()
83 if err != nil {
84 return err
85 }
86
87 // Create main context for the application
88 ctx, cancel := context.WithCancel(context.Background())
89 defer cancel()
90
91 app, err := app.New(ctx, conn)
92 if err != nil {
93 logging.Error("Failed to create app: %v", err)
94 return err
95 }
96 // Defer shutdown here so it runs for both interactive and non-interactive modes
97 defer app.Shutdown()
98
99 // Initialize MCP tools early for both modes
100 initMCPTools(ctx, app)
101
102 // Non-interactive mode
103 if prompt != "" {
104 // Run non-interactive flow using the App method
105 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
106 }
107
108 // Set up the TUI
109 program := tea.NewProgram(
110 tui.New(app),
111 tea.WithAltScreen(),
112 tea.WithKeyReleases(),
113 tea.WithUniformKeyLayout(),
114 )
115
116 // Setup the subscriptions, this will send services events to the TUI
117 ch, cancelSubs := setupSubscriptions(app, ctx)
118
119 // Create a context for the TUI message handler
120 tuiCtx, tuiCancel := context.WithCancel(ctx)
121 var tuiWg sync.WaitGroup
122 tuiWg.Add(1)
123
124 // Set up message handling for the TUI
125 go func() {
126 defer tuiWg.Done()
127 defer logging.RecoverPanic("TUI-message-handler", func() {
128 attemptTUIRecovery(program)
129 })
130
131 for {
132 select {
133 case <-tuiCtx.Done():
134 logging.Info("TUI message handler shutting down")
135 return
136 case msg, ok := <-ch:
137 if !ok {
138 logging.Info("TUI message channel closed")
139 return
140 }
141 program.Send(msg)
142 }
143 }
144 }()
145
146 // Cleanup function for when the program exits
147 cleanup := func() {
148 // Shutdown the app
149 app.Shutdown()
150
151 // Cancel subscriptions first
152 cancelSubs()
153
154 // Then cancel TUI message handler
155 tuiCancel()
156
157 // Wait for TUI message handler to finish
158 tuiWg.Wait()
159
160 logging.Info("All goroutines cleaned up")
161 }
162
163 // Run the TUI
164 result, err := program.Run()
165 cleanup()
166
167 if err != nil {
168 logging.Error("TUI error: %v", err)
169 return fmt.Errorf("TUI error: %v", err)
170 }
171
172 logging.Info("TUI exited with result: %v", result)
173 return nil
174 },
175}
176
177// attemptTUIRecovery tries to recover the TUI after a panic
178func attemptTUIRecovery(program *tea.Program) {
179 logging.Info("Attempting to recover TUI after panic")
180
181 // We could try to restart the TUI or gracefully exit
182 // For now, we'll just quit the program to avoid further issues
183 program.Quit()
184}
185
186func initMCPTools(ctx context.Context, app *app.App) {
187 go func() {
188 defer logging.RecoverPanic("MCP-goroutine", nil)
189
190 // Create a context with timeout for the initial MCP tools fetch
191 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
192 defer cancel()
193
194 // Set this up once with proper error handling
195 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
196 logging.Info("MCP message handling goroutine exiting")
197 }()
198}
199
200func setupSubscriber[T any](
201 ctx context.Context,
202 wg *sync.WaitGroup,
203 name string,
204 subscriber func(context.Context) <-chan pubsub.Event[T],
205 outputCh chan<- tea.Msg,
206) {
207 wg.Add(1)
208 go func() {
209 defer wg.Done()
210 defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
211
212 subCh := subscriber(ctx)
213
214 for {
215 select {
216 case event, ok := <-subCh:
217 if !ok {
218 logging.Info("subscription channel closed", "name", name)
219 return
220 }
221
222 var msg tea.Msg = event
223
224 select {
225 case outputCh <- msg:
226 case <-time.After(2 * time.Second):
227 logging.Warn("message dropped due to slow consumer", "name", name)
228 case <-ctx.Done():
229 logging.Info("subscription cancelled", "name", name)
230 return
231 }
232 case <-ctx.Done():
233 logging.Info("subscription cancelled", "name", name)
234 return
235 }
236 }
237 }()
238}
239
240func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
241 ch := make(chan tea.Msg, 100)
242
243 wg := sync.WaitGroup{}
244 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
245
246 setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
247 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
248 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
249 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
250 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
251 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
252
253 cleanupFunc := func() {
254 logging.Info("Cancelling all subscriptions")
255 cancel() // Signal all goroutines to stop
256
257 waitCh := make(chan struct{})
258 go func() {
259 defer logging.RecoverPanic("subscription-cleanup", nil)
260 wg.Wait()
261 close(waitCh)
262 }()
263
264 select {
265 case <-waitCh:
266 logging.Info("All subscription goroutines completed successfully")
267 close(ch) // Only close after all writers are confirmed done
268 case <-time.After(5 * time.Second):
269 logging.Warn("Timed out waiting for some subscription goroutines to complete")
270 close(ch)
271 }
272 }
273 return ch, cleanupFunc
274}
275
276func Execute() {
277 if err := fang.Execute(
278 context.Background(),
279 rootCmd,
280 fang.WithVersion(version.Version),
281 ); err != nil {
282 os.Exit(1)
283 }
284}
285
286func init() {
287 rootCmd.Flags().BoolP("help", "h", false, "Help")
288 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
289 rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
290 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
291
292 // Add format flag with validation logic
293 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
294 "Output format for non-interactive mode (text, json)")
295
296 // Add quiet flag to hide spinner in non-interactive mode
297 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
298
299 // Register custom validation for the format flag
300 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
301 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
302 })
303}