root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"sync"
  8	"time"
  9
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/app"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/db"
 14	"github.com/charmbracelet/crush/internal/format"
 15	"github.com/charmbracelet/crush/internal/llm/agent"
 16	"github.com/charmbracelet/crush/internal/logging"
 17	"github.com/charmbracelet/crush/internal/pubsub"
 18	"github.com/charmbracelet/crush/internal/tui"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/spf13/cobra"
 21)
 22
 23var rootCmd = &cobra.Command{
 24	Use:   "crush",
 25	Short: "Terminal-based AI assistant for software development",
 26	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 27It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 28to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 29	Example: `
 30  # Run in interactive mode
 31  crush
 32
 33  # Run with debug logging
 34  crush -d
 35
 36  # Run with debug logging in a specific directory
 37  crush -d -c /path/to/project
 38
 39  # Print version
 40  crush -v
 41
 42  # Run a single non-interactive prompt
 43  crush -p "Explain the use of context in Go"
 44
 45  # Run a single non-interactive prompt with JSON output format
 46  crush -p "Explain the use of context in Go" -f json
 47  `,
 48	RunE: func(cmd *cobra.Command, args []string) error {
 49		// If the help flag is set, show the help message
 50		if cmd.Flag("help").Changed {
 51			cmd.Help()
 52			return nil
 53		}
 54		if cmd.Flag("version").Changed {
 55			fmt.Println(version.Version)
 56			return nil
 57		}
 58
 59		// Load the config
 60		debug, _ := cmd.Flags().GetBool("debug")
 61		cwd, _ := cmd.Flags().GetString("cwd")
 62		prompt, _ := cmd.Flags().GetString("prompt")
 63		outputFormat, _ := cmd.Flags().GetString("output-format")
 64		quiet, _ := cmd.Flags().GetBool("quiet")
 65
 66		// Validate format option
 67		if !format.IsValid(outputFormat) {
 68			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 69		}
 70
 71		if cwd != "" {
 72			err := os.Chdir(cwd)
 73			if err != nil {
 74				return fmt.Errorf("failed to change directory: %v", err)
 75			}
 76		}
 77		if cwd == "" {
 78			c, err := os.Getwd()
 79			if err != nil {
 80				return fmt.Errorf("failed to get current working directory: %v", err)
 81			}
 82			cwd = c
 83		}
 84		_, err := config.Load(cwd, debug)
 85		if err != nil {
 86			return err
 87		}
 88
 89		// Connect DB, this will also run migrations
 90		conn, err := db.Connect()
 91		if err != nil {
 92			return err
 93		}
 94
 95		// Create main context for the application
 96		ctx, cancel := context.WithCancel(context.Background())
 97		defer cancel()
 98
 99		app, err := app.New(ctx, conn)
100		if err != nil {
101			logging.Error("Failed to create app: %v", err)
102			return err
103		}
104		// Defer shutdown here so it runs for both interactive and non-interactive modes
105		defer app.Shutdown()
106
107		// Initialize MCP tools early for both modes
108		initMCPTools(ctx, app)
109
110		// Non-interactive mode
111		if prompt != "" {
112			// Run non-interactive flow using the App method
113			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
114		}
115
116		// Set up the TUI
117		program := tea.NewProgram(
118			tui.New(app),
119			tea.WithAltScreen(),
120			tea.WithKeyReleases(),
121			tea.WithUniformKeyLayout(),
122		)
123
124		// Setup the subscriptions, this will send services events to the TUI
125		ch, cancelSubs := setupSubscriptions(app, ctx)
126
127		// Create a context for the TUI message handler
128		tuiCtx, tuiCancel := context.WithCancel(ctx)
129		var tuiWg sync.WaitGroup
130		tuiWg.Add(1)
131
132		// Set up message handling for the TUI
133		go func() {
134			defer tuiWg.Done()
135			defer logging.RecoverPanic("TUI-message-handler", func() {
136				attemptTUIRecovery(program)
137			})
138
139			for {
140				select {
141				case <-tuiCtx.Done():
142					logging.Info("TUI message handler shutting down")
143					return
144				case msg, ok := <-ch:
145					if !ok {
146						logging.Info("TUI message channel closed")
147						return
148					}
149					program.Send(msg)
150				}
151			}
152		}()
153
154		// Cleanup function for when the program exits
155		cleanup := func() {
156			// Shutdown the app
157			app.Shutdown()
158
159			// Cancel subscriptions first
160			cancelSubs()
161
162			// Then cancel TUI message handler
163			tuiCancel()
164
165			// Wait for TUI message handler to finish
166			tuiWg.Wait()
167
168			logging.Info("All goroutines cleaned up")
169		}
170
171		// Run the TUI
172		result, err := program.Run()
173		cleanup()
174
175		if err != nil {
176			logging.Error("TUI error: %v", err)
177			return fmt.Errorf("TUI error: %v", err)
178		}
179
180		logging.Info("TUI exited with result: %v", result)
181		return nil
182	},
183}
184
185// attemptTUIRecovery tries to recover the TUI after a panic
186func attemptTUIRecovery(program *tea.Program) {
187	logging.Info("Attempting to recover TUI after panic")
188
189	// We could try to restart the TUI or gracefully exit
190	// For now, we'll just quit the program to avoid further issues
191	program.Quit()
192}
193
194func initMCPTools(ctx context.Context, app *app.App) {
195	go func() {
196		defer logging.RecoverPanic("MCP-goroutine", nil)
197
198		// Create a context with timeout for the initial MCP tools fetch
199		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
200		defer cancel()
201
202		// Set this up once with proper error handling
203		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
204		logging.Info("MCP message handling goroutine exiting")
205	}()
206}
207
208func setupSubscriber[T any](
209	ctx context.Context,
210	wg *sync.WaitGroup,
211	name string,
212	subscriber func(context.Context) <-chan pubsub.Event[T],
213	outputCh chan<- tea.Msg,
214) {
215	wg.Add(1)
216	go func() {
217		defer wg.Done()
218		defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
219
220		subCh := subscriber(ctx)
221
222		for {
223			select {
224			case event, ok := <-subCh:
225				if !ok {
226					logging.Info("subscription channel closed", "name", name)
227					return
228				}
229
230				var msg tea.Msg = event
231
232				select {
233				case outputCh <- msg:
234				case <-time.After(2 * time.Second):
235					logging.Warn("message dropped due to slow consumer", "name", name)
236				case <-ctx.Done():
237					logging.Info("subscription cancelled", "name", name)
238					return
239				}
240			case <-ctx.Done():
241				logging.Info("subscription cancelled", "name", name)
242				return
243			}
244		}
245	}()
246}
247
248func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
249	ch := make(chan tea.Msg, 100)
250
251	wg := sync.WaitGroup{}
252	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
253
254	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
255	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
256	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
257	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
258	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
259	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
260
261	cleanupFunc := func() {
262		logging.Info("Cancelling all subscriptions")
263		cancel() // Signal all goroutines to stop
264
265		waitCh := make(chan struct{})
266		go func() {
267			defer logging.RecoverPanic("subscription-cleanup", nil)
268			wg.Wait()
269			close(waitCh)
270		}()
271
272		select {
273		case <-waitCh:
274			logging.Info("All subscription goroutines completed successfully")
275			close(ch) // Only close after all writers are confirmed done
276		case <-time.After(5 * time.Second):
277			logging.Warn("Timed out waiting for some subscription goroutines to complete")
278			close(ch)
279		}
280	}
281	return ch, cleanupFunc
282}
283
284func Execute() {
285	err := rootCmd.Execute()
286	if err != nil {
287		os.Exit(1)
288	}
289}
290
291func init() {
292	rootCmd.Flags().BoolP("help", "h", false, "Help")
293	rootCmd.Flags().BoolP("version", "v", false, "Version")
294	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
295	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
296	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
297
298	// Add format flag with validation logic
299	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
300		"Output format for non-interactive mode (text, json)")
301
302	// Add quiet flag to hide spinner in non-interactive mode
303	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
304
305	// Register custom validation for the format flag
306	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
307		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
308	})
309}