root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"io"
  7	"log/slog"
  8	"os"
  9	"sync"
 10	"time"
 11
 12	tea "github.com/charmbracelet/bubbletea/v2"
 13	"github.com/charmbracelet/crush/internal/app"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/format"
 17	"github.com/charmbracelet/crush/internal/llm/agent"
 18	"github.com/charmbracelet/crush/internal/log"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/tui"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/charmbracelet/fang"
 23	"github.com/charmbracelet/x/term"
 24	"github.com/spf13/cobra"
 25)
 26
 27var rootCmd = &cobra.Command{
 28	Use:   "crush",
 29	Short: "Terminal-based AI assistant for software development",
 30	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 31It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 32to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 33	Example: `
 34  # Run in interactive mode
 35  crush
 36
 37  # Run with debug logging
 38  crush -d
 39
 40  # Run with debug slog.in a specific directory
 41  crush -d -c /path/to/project
 42
 43  # Print version
 44  crush -v
 45
 46  # Run a single non-interactive prompt
 47  crush -p "Explain the use of context in Go"
 48
 49  # Run a single non-interactive prompt with JSON output format
 50  crush -p "Explain the use of context in Go" -f json
 51  `,
 52	RunE: func(cmd *cobra.Command, args []string) error {
 53		// Load the config
 54		debug, _ := cmd.Flags().GetBool("debug")
 55		cwd, _ := cmd.Flags().GetString("cwd")
 56		prompt, _ := cmd.Flags().GetString("prompt")
 57		outputFormat, _ := cmd.Flags().GetString("output-format")
 58		quiet, _ := cmd.Flags().GetBool("quiet")
 59
 60		// Validate format option
 61		if !format.IsValid(outputFormat) {
 62			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 63		}
 64
 65		if cwd != "" {
 66			err := os.Chdir(cwd)
 67			if err != nil {
 68				return fmt.Errorf("failed to change directory: %v", err)
 69			}
 70		}
 71		if cwd == "" {
 72			c, err := os.Getwd()
 73			if err != nil {
 74				return fmt.Errorf("failed to get current working directory: %v", err)
 75			}
 76			cwd = c
 77		}
 78
 79		cfg, err := config.Init(cwd, debug)
 80		if err != nil {
 81			return err
 82		}
 83
 84		// Create main context for the application
 85		ctx, cancel := context.WithCancel(context.Background())
 86		defer cancel()
 87
 88		// Connect DB, this will also run migrations
 89		conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
 90		if err != nil {
 91			return err
 92		}
 93
 94		app, err := app.New(ctx, conn, cfg)
 95		if err != nil {
 96			slog.Error(fmt.Sprintf("Failed to create app instance: %v", err))
 97			return err
 98		}
 99		// Defer shutdown here so it runs for both interactive and non-interactive modes
100		defer app.Shutdown()
101
102		// Initialize MCP tools early for both modes
103		initMCPTools(ctx, app, cfg)
104
105		prompt, err = maybePrependStdin(prompt)
106		if err != nil {
107			slog.Error(fmt.Sprintf("Failed to read from stdin: %v", err))
108			return err
109		}
110
111		// Non-interactive mode
112		if prompt != "" {
113			// Run non-interactive flow using the App method
114			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
115		}
116
117		// Set up the TUI
118		program := tea.NewProgram(
119			tui.New(app),
120			tea.WithAltScreen(),
121			tea.WithKeyReleases(),
122			tea.WithUniformKeyLayout(),
123			tea.WithMouseCellMotion(),            // Use cell motion instead of all motion to reduce event flooding
124			tea.WithFilter(tui.MouseEventFilter), // Filter mouse events based on focus state
125		)
126
127		// Setup the subscriptions, this will send services events to the TUI
128		ch, cancelSubs := setupSubscriptions(app, ctx)
129
130		// Create a context for the TUI message handler
131		tuiCtx, tuiCancel := context.WithCancel(ctx)
132		var tuiWg sync.WaitGroup
133		tuiWg.Add(1)
134
135		// Set up message handling for the TUI
136		go func() {
137			defer tuiWg.Done()
138			defer log.RecoverPanic("TUI-message-handler", func() {
139				attemptTUIRecovery(program)
140			})
141
142			for {
143				select {
144				case <-tuiCtx.Done():
145					slog.Info("TUI message handler shutting down")
146					return
147				case msg, ok := <-ch:
148					if !ok {
149						slog.Info("TUI message channel closed")
150						return
151					}
152					program.Send(msg)
153				}
154			}
155		}()
156
157		// Cleanup function for when the program exits
158		cleanup := func() {
159			// Shutdown the app
160			app.Shutdown()
161
162			// Cancel subscriptions first
163			cancelSubs()
164
165			// Then cancel TUI message handler
166			tuiCancel()
167
168			// Wait for TUI message handler to finish
169			tuiWg.Wait()
170
171			slog.Info("All goroutines cleaned up")
172		}
173
174		// Run the TUI
175		result, err := program.Run()
176		cleanup()
177
178		if err != nil {
179			slog.Error(fmt.Sprintf("TUI run error: %v", err))
180			return fmt.Errorf("TUI error: %v", err)
181		}
182
183		slog.Info(fmt.Sprintf("TUI exited with result: %v", result))
184		return nil
185	},
186}
187
188// attemptTUIRecovery tries to recover the TUI after a panic
189func attemptTUIRecovery(program *tea.Program) {
190	slog.Info("Attempting to recover TUI after panic")
191
192	// We could try to restart the TUI or gracefully exit
193	// For now, we'll just quit the program to avoid further issues
194	program.Quit()
195}
196
197func initMCPTools(ctx context.Context, app *app.App, cfg *config.Config) {
198	go func() {
199		defer log.RecoverPanic("MCP-goroutine", nil)
200
201		// Create a context with timeout for the initial MCP tools fetch
202		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
203		defer cancel()
204
205		// Set this up once with proper error handling
206		agent.GetMcpTools(ctxWithTimeout, app.Permissions, cfg)
207		slog.Info("MCP message handling goroutine exiting")
208	}()
209}
210
211func setupSubscriber[T any](
212	ctx context.Context,
213	wg *sync.WaitGroup,
214	name string,
215	subscriber func(context.Context) <-chan pubsub.Event[T],
216	outputCh chan<- tea.Msg,
217) {
218	wg.Add(1)
219	go func() {
220		defer wg.Done()
221		defer log.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
222
223		subCh := subscriber(ctx)
224
225		for {
226			select {
227			case event, ok := <-subCh:
228				if !ok {
229					slog.Info("subscription channel closed", "name", name)
230					return
231				}
232
233				var msg tea.Msg = event
234
235				select {
236				case outputCh <- msg:
237				case <-time.After(2 * time.Second):
238					slog.Warn("message dropped due to slow consumer", "name", name)
239				case <-ctx.Done():
240					slog.Info("subscription canceled", "name", name)
241					return
242				}
243			case <-ctx.Done():
244				slog.Info("subscription canceled", "name", name)
245				return
246			}
247		}
248	}()
249}
250
251func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
252	ch := make(chan tea.Msg, 100)
253
254	wg := sync.WaitGroup{}
255	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
256
257	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
258	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
259	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
260	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
261	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
262
263	cleanupFunc := func() {
264		slog.Info("Cancelling all subscriptions")
265		cancel() // Signal all goroutines to stop
266
267		waitCh := make(chan struct{})
268		go func() {
269			defer log.RecoverPanic("subscription-cleanup", nil)
270			wg.Wait()
271			close(waitCh)
272		}()
273
274		select {
275		case <-waitCh:
276			slog.Info("All subscription goroutines completed successfully")
277			close(ch) // Only close after all writers are confirmed done
278		case <-time.After(5 * time.Second):
279			slog.Warn("Timed out waiting for some subscription goroutines to complete")
280			close(ch)
281		}
282	}
283	return ch, cleanupFunc
284}
285
286func Execute() {
287	if err := fang.Execute(
288		context.Background(),
289		rootCmd,
290		fang.WithVersion(version.Version),
291	); err != nil {
292		os.Exit(1)
293	}
294}
295
296func init() {
297	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
298
299	rootCmd.Flags().BoolP("help", "h", false, "Help")
300	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
301	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
302
303	// Add format flag with validation logic
304	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
305		"Output format for non-interactive mode (text, json)")
306
307	// Add quiet flag to hide spinner in non-interactive mode
308	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
309
310	// Register custom validation for the format flag
311	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
312		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
313	})
314}
315
316func maybePrependStdin(prompt string) (string, error) {
317	if term.IsTerminal(os.Stdin.Fd()) {
318		return prompt, nil
319	}
320	fi, err := os.Stdin.Stat()
321	if err != nil {
322		return prompt, err
323	}
324	if fi.Mode()&os.ModeNamedPipe == 0 {
325		return prompt, nil
326	}
327	bts, err := io.ReadAll(os.Stdin)
328	if err != nil {
329		return prompt, err
330	}
331	return string(bts) + "\n\n" + prompt, nil
332}