root.go

  1package cmd
  2
import (
	"context"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/app"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/db"
	"github.com/charmbracelet/crush/internal/format"
	"github.com/charmbracelet/crush/internal/llm/agent"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/tui"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/charmbracelet/fang"
	"github.com/charmbracelet/x/term"
	"github.com/spf13/cobra"
)
 25
 26var rootCmd = &cobra.Command{
 27	Use:   "crush",
 28	Short: "Terminal-based AI assistant for software development",
 29	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 30It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 31to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 32	Example: `
 33  # Run in interactive mode
 34  crush
 35
 36  # Run with debug logging
 37  crush -d
 38
 39  # Run with debug logging in a specific directory
 40  crush -d -c /path/to/project
 41
 42  # Print version
 43  crush -v
 44
 45  # Run a single non-interactive prompt
 46  crush -p "Explain the use of context in Go"
 47
 48  # Run a single non-interactive prompt with JSON output format
 49  crush -p "Explain the use of context in Go" -f json
 50  `,
 51	RunE: func(cmd *cobra.Command, args []string) error {
 52		// Load the config
 53		debug, _ := cmd.Flags().GetBool("debug")
 54		cwd, _ := cmd.Flags().GetString("cwd")
 55		prompt, _ := cmd.Flags().GetString("prompt")
 56		outputFormat, _ := cmd.Flags().GetString("output-format")
 57		quiet, _ := cmd.Flags().GetBool("quiet")
 58
 59		// Validate format option
 60		if !format.IsValid(outputFormat) {
 61			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 62		}
 63
 64		if cwd != "" {
 65			err := os.Chdir(cwd)
 66			if err != nil {
 67				return fmt.Errorf("failed to change directory: %v", err)
 68			}
 69		}
 70		if cwd == "" {
 71			c, err := os.Getwd()
 72			if err != nil {
 73				return fmt.Errorf("failed to get current working directory: %v", err)
 74			}
 75			cwd = c
 76		}
 77
 78		_, err := config.Init(cwd, debug)
 79		if err != nil {
 80			return err
 81		}
 82
 83		// Create main context for the application
 84		ctx, cancel := context.WithCancel(context.Background())
 85		defer cancel()
 86
 87		// Connect DB, this will also run migrations
 88		conn, err := db.Connect(ctx)
 89		if err != nil {
 90			return err
 91		}
 92
 93		app, err := app.New(ctx, conn)
 94		if err != nil {
 95			logging.Error("Failed to create app: %v", err)
 96			return err
 97		}
 98		// Defer shutdown here so it runs for both interactive and non-interactive modes
 99		defer app.Shutdown()
100
101		// Initialize MCP tools early for both modes
102		initMCPTools(ctx, app)
103
104		prompt, err = maybePrependStdin(prompt)
105		if err != nil {
106			logging.Error("Failed to read stdin: %v", err)
107			return err
108		}
109
110		// Non-interactive mode
111		if prompt != "" {
112			// Run non-interactive flow using the App method
113			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
114		}
115
116		// Set up the TUI
117		program := tea.NewProgram(
118			tui.New(app),
119			tea.WithAltScreen(),
120			tea.WithKeyReleases(),
121			tea.WithUniformKeyLayout(),
122		)
123
124		// Setup the subscriptions, this will send services events to the TUI
125		ch, cancelSubs := setupSubscriptions(app, ctx)
126
127		// Create a context for the TUI message handler
128		tuiCtx, tuiCancel := context.WithCancel(ctx)
129		var tuiWg sync.WaitGroup
130		tuiWg.Add(1)
131
132		// Set up message handling for the TUI
133		go func() {
134			defer tuiWg.Done()
135			defer logging.RecoverPanic("TUI-message-handler", func() {
136				attemptTUIRecovery(program)
137			})
138
139			for {
140				select {
141				case <-tuiCtx.Done():
142					logging.Info("TUI message handler shutting down")
143					return
144				case msg, ok := <-ch:
145					if !ok {
146						logging.Info("TUI message channel closed")
147						return
148					}
149					program.Send(msg)
150				}
151			}
152		}()
153
154		// Cleanup function for when the program exits
155		cleanup := func() {
156			// Shutdown the app
157			app.Shutdown()
158
159			// Cancel subscriptions first
160			cancelSubs()
161
162			// Then cancel TUI message handler
163			tuiCancel()
164
165			// Wait for TUI message handler to finish
166			tuiWg.Wait()
167
168			logging.Info("All goroutines cleaned up")
169		}
170
171		// Run the TUI
172		result, err := program.Run()
173		cleanup()
174
175		if err != nil {
176			logging.Error("TUI error: %v", err)
177			return fmt.Errorf("TUI error: %v", err)
178		}
179
180		logging.Info("TUI exited with result: %v", result)
181		return nil
182	},
183}
184
185// attemptTUIRecovery tries to recover the TUI after a panic
186func attemptTUIRecovery(program *tea.Program) {
187	logging.Info("Attempting to recover TUI after panic")
188
189	// We could try to restart the TUI or gracefully exit
190	// For now, we'll just quit the program to avoid further issues
191	program.Quit()
192}
193
194func initMCPTools(ctx context.Context, app *app.App) {
195	go func() {
196		defer logging.RecoverPanic("MCP-goroutine", nil)
197
198		// Create a context with timeout for the initial MCP tools fetch
199		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
200		defer cancel()
201
202		// Set this up once with proper error handling
203		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
204		logging.Info("MCP message handling goroutine exiting")
205	}()
206}
207
208func setupSubscriber[T any](
209	ctx context.Context,
210	wg *sync.WaitGroup,
211	name string,
212	subscriber func(context.Context) <-chan pubsub.Event[T],
213	outputCh chan<- tea.Msg,
214) {
215	wg.Add(1)
216	go func() {
217		defer wg.Done()
218		defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
219
220		subCh := subscriber(ctx)
221
222		for {
223			select {
224			case event, ok := <-subCh:
225				if !ok {
226					logging.Info("subscription channel closed", "name", name)
227					return
228				}
229
230				var msg tea.Msg = event
231
232				select {
233				case outputCh <- msg:
234				case <-time.After(2 * time.Second):
235					logging.Warn("message dropped due to slow consumer", "name", name)
236				case <-ctx.Done():
237					logging.Info("subscription cancelled", "name", name)
238					return
239				}
240			case <-ctx.Done():
241				logging.Info("subscription cancelled", "name", name)
242				return
243			}
244		}
245	}()
246}
247
248func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
249	ch := make(chan tea.Msg, 100)
250
251	wg := sync.WaitGroup{}
252	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
253
254	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
255	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
256	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
257	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
258	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
259	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
260
261	cleanupFunc := func() {
262		logging.Info("Cancelling all subscriptions")
263		cancel() // Signal all goroutines to stop
264
265		waitCh := make(chan struct{})
266		go func() {
267			defer logging.RecoverPanic("subscription-cleanup", nil)
268			wg.Wait()
269			close(waitCh)
270		}()
271
272		select {
273		case <-waitCh:
274			logging.Info("All subscription goroutines completed successfully")
275			close(ch) // Only close after all writers are confirmed done
276		case <-time.After(5 * time.Second):
277			logging.Warn("Timed out waiting for some subscription goroutines to complete")
278			close(ch)
279		}
280	}
281	return ch, cleanupFunc
282}
283
284func Execute() {
285	if err := fang.Execute(
286		context.Background(),
287		rootCmd,
288		fang.WithVersion(version.Version),
289	); err != nil {
290		os.Exit(1)
291	}
292}
293
294func init() {
295	rootCmd.Flags().BoolP("help", "h", false, "Help")
296	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
297	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
298	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
299
300	// Add format flag with validation logic
301	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
302		"Output format for non-interactive mode (text, json)")
303
304	// Add quiet flag to hide spinner in non-interactive mode
305	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
306
307	// Register custom validation for the format flag
308	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
309		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
310	})
311}
312
313func maybePrependStdin(prompt string) (string, error) {
314	if term.IsTerminal(os.Stdin.Fd()) {
315		return prompt, nil
316	}
317	fi, err := os.Stdin.Stat()
318	if err != nil {
319		return prompt, err
320	}
321	if fi.Mode()&os.ModeNamedPipe == 0 {
322		return prompt, nil
323	}
324	bts, err := io.ReadAll(os.Stdin)
325	if err != nil {
326		return prompt, err
327	}
328	return string(bts) + "\n\n" + prompt, nil
329}