root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"sync"
  8	"time"
  9
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/app"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/db"
 14	"github.com/charmbracelet/crush/internal/format"
 15	"github.com/charmbracelet/crush/internal/llm/agent"
 16	"github.com/charmbracelet/crush/internal/logging"
 17	"github.com/charmbracelet/crush/internal/pubsub"
 18	"github.com/charmbracelet/crush/internal/tui"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/spf13/cobra"
 21)
 22
 23var rootCmd = &cobra.Command{
 24	Use:   "crush",
 25	Short: "Terminal-based AI assistant for software development",
 26	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 27It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 28to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 29	Example: `
 30  # Run in interactive mode
 31  crush
 32
 33  # Run with debug logging
 34  crush -d
 35
 36  # Run with debug logging in a specific directory
 37  crush -d -c /path/to/project
 38
 39  # Print version
 40  crush -v
 41
 42  # Run a single non-interactive prompt
 43  crush -p "Explain the use of context in Go"
 44
 45  # Run a single non-interactive prompt with JSON output format
 46  crush -p "Explain the use of context in Go" -f json
 47  `,
 48	RunE: func(cmd *cobra.Command, args []string) error {
 49		// If the help flag is set, show the help message
 50		if cmd.Flag("help").Changed {
 51			cmd.Help()
 52			return nil
 53		}
 54		if cmd.Flag("version").Changed {
 55			fmt.Println(version.Version)
 56			return nil
 57		}
 58
 59		// Load the config
 60		debug, _ := cmd.Flags().GetBool("debug")
 61		cwd, _ := cmd.Flags().GetString("cwd")
 62		prompt, _ := cmd.Flags().GetString("prompt")
 63		outputFormat, _ := cmd.Flags().GetString("output-format")
 64		quiet, _ := cmd.Flags().GetBool("quiet")
 65
 66		// Validate format option
 67		if !format.IsValid(outputFormat) {
 68			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 69		}
 70
 71		if cwd != "" {
 72			err := os.Chdir(cwd)
 73			if err != nil {
 74				return fmt.Errorf("failed to change directory: %v", err)
 75			}
 76		}
 77		if cwd == "" {
 78			c, err := os.Getwd()
 79			if err != nil {
 80				return fmt.Errorf("failed to get current working directory: %v", err)
 81			}
 82			cwd = c
 83		}
 84		_, err := config.Load(cwd, debug)
 85		if err != nil {
 86			return err
 87		}
 88
 89		// Connect DB, this will also run migrations
 90		conn, err := db.Connect()
 91		if err != nil {
 92			return err
 93		}
 94
 95		// Create main context for the application
 96		ctx, cancel := context.WithCancel(context.Background())
 97		defer cancel()
 98
 99		app, err := app.New(ctx, conn)
100		if err != nil {
101			logging.Error("Failed to create app: %v", err)
102			return err
103		}
104		// Defer shutdown here so it runs for both interactive and non-interactive modes
105		defer app.Shutdown()
106
107		// Initialize MCP tools early for both modes
108		initMCPTools(ctx, app)
109
110		// Non-interactive mode
111		if prompt != "" {
112			// Run non-interactive flow using the App method
113			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
114		}
115
116		// Set up the TUI
117		program := tea.NewProgram(
118			tui.New(app),
119			tea.WithAltScreen(),
120		)
121
122		// Setup the subscriptions, this will send services events to the TUI
123		ch, cancelSubs := setupSubscriptions(app, ctx)
124
125		// Create a context for the TUI message handler
126		tuiCtx, tuiCancel := context.WithCancel(ctx)
127		var tuiWg sync.WaitGroup
128		tuiWg.Add(1)
129
130		// Set up message handling for the TUI
131		go func() {
132			defer tuiWg.Done()
133			defer logging.RecoverPanic("TUI-message-handler", func() {
134				attemptTUIRecovery(program)
135			})
136
137			for {
138				select {
139				case <-tuiCtx.Done():
140					logging.Info("TUI message handler shutting down")
141					return
142				case msg, ok := <-ch:
143					if !ok {
144						logging.Info("TUI message channel closed")
145						return
146					}
147					program.Send(msg)
148				}
149			}
150		}()
151
152		// Cleanup function for when the program exits
153		cleanup := func() {
154			// Shutdown the app
155			app.Shutdown()
156
157			// Cancel subscriptions first
158			cancelSubs()
159
160			// Then cancel TUI message handler
161			tuiCancel()
162
163			// Wait for TUI message handler to finish
164			tuiWg.Wait()
165
166			logging.Info("All goroutines cleaned up")
167		}
168
169		// Run the TUI
170		result, err := program.Run()
171		cleanup()
172
173		if err != nil {
174			logging.Error("TUI error: %v", err)
175			return fmt.Errorf("TUI error: %v", err)
176		}
177
178		logging.Info("TUI exited with result: %v", result)
179		return nil
180	},
181}
182
183// attemptTUIRecovery tries to recover the TUI after a panic
184func attemptTUIRecovery(program *tea.Program) {
185	logging.Info("Attempting to recover TUI after panic")
186
187	// We could try to restart the TUI or gracefully exit
188	// For now, we'll just quit the program to avoid further issues
189	program.Quit()
190}
191
192func initMCPTools(ctx context.Context, app *app.App) {
193	go func() {
194		defer logging.RecoverPanic("MCP-goroutine", nil)
195
196		// Create a context with timeout for the initial MCP tools fetch
197		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
198		defer cancel()
199
200		// Set this up once with proper error handling
201		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
202		logging.Info("MCP message handling goroutine exiting")
203	}()
204}
205
206func setupSubscriber[T any](
207	ctx context.Context,
208	wg *sync.WaitGroup,
209	name string,
210	subscriber func(context.Context) <-chan pubsub.Event[T],
211	outputCh chan<- tea.Msg,
212) {
213	wg.Add(1)
214	go func() {
215		defer wg.Done()
216		defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
217
218		subCh := subscriber(ctx)
219
220		for {
221			select {
222			case event, ok := <-subCh:
223				if !ok {
224					logging.Info("subscription channel closed", "name", name)
225					return
226				}
227
228				var msg tea.Msg = event
229
230				select {
231				case outputCh <- msg:
232				case <-time.After(2 * time.Second):
233					logging.Warn("message dropped due to slow consumer", "name", name)
234				case <-ctx.Done():
235					logging.Info("subscription cancelled", "name", name)
236					return
237				}
238			case <-ctx.Done():
239				logging.Info("subscription cancelled", "name", name)
240				return
241			}
242		}
243	}()
244}
245
246func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
247	ch := make(chan tea.Msg, 100)
248
249	wg := sync.WaitGroup{}
250	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
251
252	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
253	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
254	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
255	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
256	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
257
258	cleanupFunc := func() {
259		logging.Info("Cancelling all subscriptions")
260		cancel() // Signal all goroutines to stop
261
262		waitCh := make(chan struct{})
263		go func() {
264			defer logging.RecoverPanic("subscription-cleanup", nil)
265			wg.Wait()
266			close(waitCh)
267		}()
268
269		select {
270		case <-waitCh:
271			logging.Info("All subscription goroutines completed successfully")
272			close(ch) // Only close after all writers are confirmed done
273		case <-time.After(5 * time.Second):
274			logging.Warn("Timed out waiting for some subscription goroutines to complete")
275			close(ch)
276		}
277	}
278	return ch, cleanupFunc
279}
280
281func Execute() {
282	err := rootCmd.Execute()
283	if err != nil {
284		os.Exit(1)
285	}
286}
287
288func init() {
289	rootCmd.Flags().BoolP("help", "h", false, "Help")
290	rootCmd.Flags().BoolP("version", "v", false, "Version")
291	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
292	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
293	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
294
295	// Add format flag with validation logic
296	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
297		"Output format for non-interactive mode (text, json)")
298
299	// Add quiet flag to hide spinner in non-interactive mode
300	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
301
302	// Register custom validation for the format flag
303	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
304		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
305	})
306}