1package cmd
2
3import (
4 "context"
5 "fmt"
6 "os"
7 "os/signal"
8 "sync"
9 "syscall"
10 "time"
11
12 tea "github.com/charmbracelet/bubbletea/v2"
13 "github.com/charmbracelet/crush/internal/app"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/db"
16 "github.com/charmbracelet/crush/internal/format"
17 "github.com/charmbracelet/crush/internal/llm/agent"
18 "github.com/charmbracelet/crush/internal/logging"
19 "github.com/charmbracelet/crush/internal/pubsub"
20 "github.com/charmbracelet/crush/internal/tui"
21 "github.com/charmbracelet/crush/internal/version"
22 "github.com/charmbracelet/fang"
23 "github.com/spf13/cobra"
24)
25
26var rootCmd = &cobra.Command{
27 Use: "crush",
28 Short: "Terminal-based AI assistant for software development",
29 Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
30It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
31to assist developers in writing, debugging, and understanding code directly from the terminal.`,
32 Example: `
33 # Run in interactive mode
34 crush
35
36 # Run with debug logging
37 crush -d
38
39 # Run with debug logging in a specific directory
40 crush -d -c /path/to/project
41
42 # Print version
43 crush -v
44
45 # Run a single non-interactive prompt
46 crush -p "Explain the use of context in Go"
47
48 # Run a single non-interactive prompt with JSON output format
49 crush -p "Explain the use of context in Go" -f json
50 `,
51 RunE: func(cmd *cobra.Command, args []string) error {
52 // Load the config
53 debug, _ := cmd.Flags().GetBool("debug")
54 cwd, _ := cmd.Flags().GetString("cwd")
55 prompt, _ := cmd.Flags().GetString("prompt")
56 outputFormat, _ := cmd.Flags().GetString("output-format")
57 quiet, _ := cmd.Flags().GetBool("quiet")
58
59 // Validate format option
60 if !format.IsValid(outputFormat) {
61 return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
62 }
63
64 if cwd != "" {
65 err := os.Chdir(cwd)
66 if err != nil {
67 return fmt.Errorf("failed to change directory: %v", err)
68 }
69 }
70 if cwd == "" {
71 c, err := os.Getwd()
72 if err != nil {
73 return fmt.Errorf("failed to get current working directory: %v", err)
74 }
75 cwd = c
76 }
77 _, err := config.Load(cwd, debug)
78 if err != nil {
79 return err
80 }
81
82 // Connect DB, this will also run migrations
83 conn, err := db.Connect()
84 if err != nil {
85 return err
86 }
87
88 // Create main context for the application with signal handling
89 ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
90 defer cancel()
91
92 app, err := app.New(ctx, conn)
93 if err != nil {
94 logging.Error("Failed to create app: %v", err)
95 return err
96 }
97 // Defer shutdown here so it runs for both interactive and non-interactive modes
98 defer app.Shutdown()
99
100 // Initialize MCP tools early for both modes
101 initMCPTools(ctx, app)
102
103 // Non-interactive mode
104 if prompt != "" {
105 // Run non-interactive flow using the App method
106 return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
107 }
108
109 // Set up the TUI
110 program := tea.NewProgram(
111 tui.New(app),
112 tea.WithAltScreen(),
113 tea.WithKeyReleases(),
114 tea.WithUniformKeyLayout(),
115 )
116
117 // Setup the subscriptions, this will send services events to the TUI
118 ch, cancelSubs := setupSubscriptions(app, ctx)
119
120 // Create a context for the TUI message handler
121 tuiCtx, tuiCancel := context.WithCancel(ctx)
122 var tuiWg sync.WaitGroup
123 tuiWg.Add(1)
124
125 // Set up message handling for the TUI
126 go func() {
127 defer tuiWg.Done()
128 defer logging.RecoverPanic("TUI-message-handler", func() {
129 attemptTUIRecovery(program)
130 })
131
132 for {
133 select {
134 case <-tuiCtx.Done():
135 logging.Info("TUI message handler shutting down")
136 return
137 case msg, ok := <-ch:
138 if !ok {
139 logging.Info("TUI message channel closed")
140 return
141 }
142 program.Send(msg)
143 }
144 }
145 }()
146
147 // Cleanup function for when the program exits
148 cleanup := func() {
149 // Shutdown the app
150 app.Shutdown()
151
152 // Cancel subscriptions first
153 cancelSubs()
154
155 // Then cancel TUI message handler
156 tuiCancel()
157
158 // Wait for TUI message handler to finish
159 tuiWg.Wait()
160
161 logging.Info("All goroutines cleaned up")
162 }
163
164 // Run the TUI
165 result, err := program.Run()
166 cleanup()
167
168 if err != nil {
169 logging.Error("TUI error: %v", err)
170 return fmt.Errorf("TUI error: %v", err)
171 }
172
173 logging.Info("TUI exited with result: %v", result)
174 return nil
175 },
176}
177
178// attemptTUIRecovery tries to recover the TUI after a panic
179func attemptTUIRecovery(program *tea.Program) {
180 logging.Info("Attempting to recover TUI after panic")
181
182 // We could try to restart the TUI or gracefully exit
183 // For now, we'll just quit the program to avoid further issues
184 program.Quit()
185}
186
187func initMCPTools(ctx context.Context, app *app.App) {
188 go func() {
189 defer logging.RecoverPanic("MCP-goroutine", nil)
190
191 // Create a context with timeout for the initial MCP tools fetch
192 ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
193 defer cancel()
194
195 // Set this up once with proper error handling
196 agent.GetMcpTools(ctxWithTimeout, app.Permissions)
197 logging.Info("MCP message handling goroutine exiting")
198 }()
199}
200
201func setupSubscriber[T any](
202 ctx context.Context,
203 wg *sync.WaitGroup,
204 name string,
205 subscriber func(context.Context) <-chan pubsub.Event[T],
206 outputCh chan<- tea.Msg,
207) {
208 wg.Add(1)
209 go func() {
210 defer wg.Done()
211 defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
212
213 subCh := subscriber(ctx)
214
215 for {
216 select {
217 case event, ok := <-subCh:
218 if !ok {
219 logging.Info("subscription channel closed", "name", name)
220 return
221 }
222
223 var msg tea.Msg = event
224
225 select {
226 case outputCh <- msg:
227 case <-time.After(2 * time.Second):
228 logging.Warn("message dropped due to slow consumer", "name", name)
229 case <-ctx.Done():
230 logging.Info("subscription cancelled", "name", name)
231 return
232 }
233 case <-ctx.Done():
234 logging.Info("subscription cancelled", "name", name)
235 return
236 }
237 }
238 }()
239}
240
241func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
242 ch := make(chan tea.Msg, 100)
243
244 wg := sync.WaitGroup{}
245 ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
246
247 setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
248 setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
249 setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
250 setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
251 setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
252 setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
253
254 cleanupFunc := func() {
255 logging.Info("Cancelling all subscriptions")
256 cancel() // Signal all goroutines to stop
257
258 waitCh := make(chan struct{})
259 go func() {
260 defer logging.RecoverPanic("subscription-cleanup", nil)
261 wg.Wait()
262 close(waitCh)
263 }()
264
265 select {
266 case <-waitCh:
267 logging.Info("All subscription goroutines completed successfully")
268 close(ch) // Only close after all writers are confirmed done
269 case <-time.After(5 * time.Second):
270 logging.Warn("Timed out waiting for some subscription goroutines to complete")
271 close(ch)
272 }
273 }
274 return ch, cleanupFunc
275}
276
277func Execute() {
278 if err := fang.Execute(
279 context.Background(),
280 rootCmd,
281 fang.WithVersion(version.Version),
282 ); err != nil {
283 os.Exit(1)
284 }
285}
286
287func init() {
288 rootCmd.Flags().BoolP("help", "h", false, "Help")
289 rootCmd.Flags().BoolP("debug", "d", false, "Debug")
290 rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
291 rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
292
293 // Add format flag with validation logic
294 rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
295 "Output format for non-interactive mode (text, json)")
296
297 // Add quiet flag to hide spinner in non-interactive mode
298 rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
299
300 // Register custom validation for the format flag
301 rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
302 return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
303 })
304}