root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"io"
  7	"log/slog"
  8	"os"
  9	"sync"
 10	"time"
 11
 12	tea "github.com/charmbracelet/bubbletea/v2"
 13	"github.com/charmbracelet/crush/internal/app"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/format"
 17	"github.com/charmbracelet/crush/internal/llm/agent"
 18	"github.com/charmbracelet/crush/internal/log"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/tui"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/charmbracelet/fang"
 23	"github.com/charmbracelet/x/term"
 24	"github.com/spf13/cobra"
 25)
 26
 27var rootCmd = &cobra.Command{
 28	Use:   "crush",
 29	Short: "Terminal-based AI assistant for software development",
 30	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 31It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 32to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 33	Example: `
 34  # Run in interactive mode
 35  crush
 36
 37  # Run with debug logging
 38  crush -d
 39
 40  # Run with debug slog.in a specific directory
 41  crush -d -c /path/to/project
 42
 43  # Print version
 44  crush -v
 45
 46  # Run a single non-interactive prompt
 47  crush -p "Explain the use of context in Go"
 48
 49  # Run a single non-interactive prompt with JSON output format
 50  crush -p "Explain the use of context in Go" -f json
 51  `,
 52	RunE: func(cmd *cobra.Command, args []string) error {
 53		// Load the config
 54		debug, _ := cmd.Flags().GetBool("debug")
 55		cwd, _ := cmd.Flags().GetString("cwd")
 56		prompt, _ := cmd.Flags().GetString("prompt")
 57		outputFormat, _ := cmd.Flags().GetString("output-format")
 58		quiet, _ := cmd.Flags().GetBool("quiet")
 59
 60		// Validate format option
 61		if !format.IsValid(outputFormat) {
 62			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 63		}
 64
 65		if cwd != "" {
 66			err := os.Chdir(cwd)
 67			if err != nil {
 68				return fmt.Errorf("failed to change directory: %v", err)
 69			}
 70		}
 71		if cwd == "" {
 72			c, err := os.Getwd()
 73			if err != nil {
 74				return fmt.Errorf("failed to get current working directory: %v", err)
 75			}
 76			cwd = c
 77		}
 78
 79		cfg, err := config.Init(cwd, debug)
 80		if err != nil {
 81			return err
 82		}
 83
 84		// Create main context for the application
 85		ctx, cancel := context.WithCancel(context.Background())
 86		defer cancel()
 87
 88		// Connect DB, this will also run migrations
 89		conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
 90		if err != nil {
 91			return err
 92		}
 93
 94		app, err := app.New(ctx, conn, cfg)
 95		if err != nil {
 96			slog.Error(fmt.Sprintf("Failed to create app instance: %v", err))
 97			return err
 98		}
 99		// Defer shutdown here so it runs for both interactive and non-interactive modes
100		defer app.Shutdown()
101
102		// Initialize MCP tools early for both modes
103		initMCPTools(ctx, app, cfg)
104
105		prompt, err = maybePrependStdin(prompt)
106		if err != nil {
107			slog.Error(fmt.Sprintf("Failed to read from stdin: %v", err))
108			return err
109		}
110
111		// Non-interactive mode
112		if prompt != "" {
113			// Run non-interactive flow using the App method
114			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
115		}
116
117		// Set up the TUI
118		program := tea.NewProgram(
119			tui.New(app),
120			tea.WithAltScreen(),
121			tea.WithKeyReleases(),
122			tea.WithUniformKeyLayout(),
123		)
124
125		// Setup the subscriptions, this will send services events to the TUI
126		ch, cancelSubs := setupSubscriptions(app, ctx)
127
128		// Create a context for the TUI message handler
129		tuiCtx, tuiCancel := context.WithCancel(ctx)
130		var tuiWg sync.WaitGroup
131		tuiWg.Add(1)
132
133		// Set up message handling for the TUI
134		go func() {
135			defer tuiWg.Done()
136			defer log.RecoverPanic("TUI-message-handler", func() {
137				attemptTUIRecovery(program)
138			})
139
140			for {
141				select {
142				case <-tuiCtx.Done():
143					slog.Info("TUI message handler shutting down")
144					return
145				case msg, ok := <-ch:
146					if !ok {
147						slog.Info("TUI message channel closed")
148						return
149					}
150					program.Send(msg)
151				}
152			}
153		}()
154
155		// Cleanup function for when the program exits
156		cleanup := func() {
157			// Shutdown the app
158			app.Shutdown()
159
160			// Cancel subscriptions first
161			cancelSubs()
162
163			// Then cancel TUI message handler
164			tuiCancel()
165
166			// Wait for TUI message handler to finish
167			tuiWg.Wait()
168
169			slog.Info("All goroutines cleaned up")
170		}
171
172		// Run the TUI
173		result, err := program.Run()
174		cleanup()
175
176		if err != nil {
177			slog.Error(fmt.Sprintf("TUI run error: %v", err))
178			return fmt.Errorf("TUI error: %v", err)
179		}
180
181		slog.Info(fmt.Sprintf("TUI exited with result: %v", result))
182		return nil
183	},
184}
185
186// attemptTUIRecovery tries to recover the TUI after a panic
187func attemptTUIRecovery(program *tea.Program) {
188	slog.Info("Attempting to recover TUI after panic")
189
190	// We could try to restart the TUI or gracefully exit
191	// For now, we'll just quit the program to avoid further issues
192	program.Quit()
193}
194
195func initMCPTools(ctx context.Context, app *app.App, cfg *config.Config) {
196	go func() {
197		defer log.RecoverPanic("MCP-goroutine", nil)
198
199		// Create a context with timeout for the initial MCP tools fetch
200		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
201		defer cancel()
202
203		// Set this up once with proper error handling
204		agent.GetMcpTools(ctxWithTimeout, app.Permissions, cfg)
205		slog.Info("MCP message handling goroutine exiting")
206	}()
207}
208
209func setupSubscriber[T any](
210	ctx context.Context,
211	wg *sync.WaitGroup,
212	name string,
213	subscriber func(context.Context) <-chan pubsub.Event[T],
214	outputCh chan<- tea.Msg,
215) {
216	wg.Add(1)
217	go func() {
218		defer wg.Done()
219		defer log.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
220
221		subCh := subscriber(ctx)
222
223		for {
224			select {
225			case event, ok := <-subCh:
226				if !ok {
227					slog.Info("subscription channel closed", "name", name)
228					return
229				}
230
231				var msg tea.Msg = event
232
233				select {
234				case outputCh <- msg:
235				case <-time.After(2 * time.Second):
236					slog.Warn("message dropped due to slow consumer", "name", name)
237				case <-ctx.Done():
238					slog.Info("subscription cancelled", "name", name)
239					return
240				}
241			case <-ctx.Done():
242				slog.Info("subscription cancelled", "name", name)
243				return
244			}
245		}
246	}()
247}
248
249func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
250	ch := make(chan tea.Msg, 100)
251
252	wg := sync.WaitGroup{}
253	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
254
255	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
256	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
257	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
258	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
259	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
260
261	cleanupFunc := func() {
262		slog.Info("Cancelling all subscriptions")
263		cancel() // Signal all goroutines to stop
264
265		waitCh := make(chan struct{})
266		go func() {
267			defer log.RecoverPanic("subscription-cleanup", nil)
268			wg.Wait()
269			close(waitCh)
270		}()
271
272		select {
273		case <-waitCh:
274			slog.Info("All subscription goroutines completed successfully")
275			close(ch) // Only close after all writers are confirmed done
276		case <-time.After(5 * time.Second):
277			slog.Warn("Timed out waiting for some subscription goroutines to complete")
278			close(ch)
279		}
280	}
281	return ch, cleanupFunc
282}
283
284func Execute() {
285	if err := fang.Execute(
286		context.Background(),
287		rootCmd,
288		fang.WithVersion(version.Version),
289	); err != nil {
290		os.Exit(1)
291	}
292}
293
294func init() {
295	rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
296
297	rootCmd.Flags().BoolP("help", "h", false, "Help")
298	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
299	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
300
301	// Add format flag with validation logic
302	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
303		"Output format for non-interactive mode (text, json)")
304
305	// Add quiet flag to hide spinner in non-interactive mode
306	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
307
308	// Register custom validation for the format flag
309	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
310		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
311	})
312}
313
314func maybePrependStdin(prompt string) (string, error) {
315	if term.IsTerminal(os.Stdin.Fd()) {
316		return prompt, nil
317	}
318	fi, err := os.Stdin.Stat()
319	if err != nil {
320		return prompt, err
321	}
322	if fi.Mode()&os.ModeNamedPipe == 0 {
323		return prompt, nil
324	}
325	bts, err := io.ReadAll(os.Stdin)
326	if err != nil {
327		return prompt, err
328	}
329	return string(bts) + "\n\n" + prompt, nil
330}