root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"os/signal"
  8	"sync"
  9	"syscall"
 10	"time"
 11
 12	tea "github.com/charmbracelet/bubbletea/v2"
 13	"github.com/charmbracelet/crush/internal/app"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/format"
 17	"github.com/charmbracelet/crush/internal/llm/agent"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/tui"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/charmbracelet/fang"
 23	"github.com/spf13/cobra"
 24)
 25
 26var rootCmd = &cobra.Command{
 27	Use:   "crush",
 28	Short: "Terminal-based AI assistant for software development",
 29	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 30It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 31to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 32	Example: `
 33  # Run in interactive mode
 34  crush
 35
 36  # Run with debug logging
 37  crush -d
 38
 39  # Run with debug logging in a specific directory
 40  crush -d -c /path/to/project
 41
 42  # Print version
 43  crush -v
 44
 45  # Run a single non-interactive prompt
 46  crush -p "Explain the use of context in Go"
 47
 48  # Run a single non-interactive prompt with JSON output format
 49  crush -p "Explain the use of context in Go" -f json
 50  `,
 51	RunE: func(cmd *cobra.Command, args []string) error {
 52		// Load the config
 53		debug, _ := cmd.Flags().GetBool("debug")
 54		cwd, _ := cmd.Flags().GetString("cwd")
 55		prompt, _ := cmd.Flags().GetString("prompt")
 56		outputFormat, _ := cmd.Flags().GetString("output-format")
 57		quiet, _ := cmd.Flags().GetBool("quiet")
 58
 59		// Validate format option
 60		if !format.IsValid(outputFormat) {
 61			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 62		}
 63
 64		if cwd != "" {
 65			err := os.Chdir(cwd)
 66			if err != nil {
 67				return fmt.Errorf("failed to change directory: %v", err)
 68			}
 69		}
 70		if cwd == "" {
 71			c, err := os.Getwd()
 72			if err != nil {
 73				return fmt.Errorf("failed to get current working directory: %v", err)
 74			}
 75			cwd = c
 76		}
 77		_, err := config.Load(cwd, debug)
 78		if err != nil {
 79			return err
 80		}
 81
 82		// Connect DB, this will also run migrations
 83		conn, err := db.Connect()
 84		if err != nil {
 85			return err
 86		}
 87
 88		// Create main context for the application with signal handling
 89		ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
 90		defer cancel()
 91
 92		app, err := app.New(ctx, conn)
 93		if err != nil {
 94			logging.Error("Failed to create app: %v", err)
 95			return err
 96		}
 97		// Defer shutdown here so it runs for both interactive and non-interactive modes
 98		defer app.Shutdown()
 99
100		// Initialize MCP tools early for both modes
101		initMCPTools(ctx, app)
102
103		// Non-interactive mode
104		if prompt != "" {
105			// Run non-interactive flow using the App method
106			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
107		}
108
109		// Set up the TUI
110		program := tea.NewProgram(
111			tui.New(app),
112			tea.WithAltScreen(),
113			tea.WithKeyReleases(),
114			tea.WithUniformKeyLayout(),
115		)
116
117		// Setup the subscriptions, this will send services events to the TUI
118		ch, cancelSubs := setupSubscriptions(app, ctx)
119
120		// Create a context for the TUI message handler
121		tuiCtx, tuiCancel := context.WithCancel(ctx)
122		var tuiWg sync.WaitGroup
123		tuiWg.Add(1)
124
125		// Set up message handling for the TUI
126		go func() {
127			defer tuiWg.Done()
128			defer logging.RecoverPanic("TUI-message-handler", func() {
129				attemptTUIRecovery(program)
130			})
131
132			for {
133				select {
134				case <-tuiCtx.Done():
135					logging.Info("TUI message handler shutting down")
136					return
137				case msg, ok := <-ch:
138					if !ok {
139						logging.Info("TUI message channel closed")
140						return
141					}
142					program.Send(msg)
143				}
144			}
145		}()
146
147		// Cleanup function for when the program exits
148		cleanup := func() {
149			// Shutdown the app
150			app.Shutdown()
151
152			// Cancel subscriptions first
153			cancelSubs()
154
155			// Then cancel TUI message handler
156			tuiCancel()
157
158			// Wait for TUI message handler to finish
159			tuiWg.Wait()
160
161			logging.Info("All goroutines cleaned up")
162		}
163
164		// Run the TUI
165		result, err := program.Run()
166		cleanup()
167
168		if err != nil {
169			logging.Error("TUI error: %v", err)
170			return fmt.Errorf("TUI error: %v", err)
171		}
172
173		logging.Info("TUI exited with result: %v", result)
174		return nil
175	},
176}
177
178// attemptTUIRecovery tries to recover the TUI after a panic
179func attemptTUIRecovery(program *tea.Program) {
180	logging.Info("Attempting to recover TUI after panic")
181
182	// We could try to restart the TUI or gracefully exit
183	// For now, we'll just quit the program to avoid further issues
184	program.Quit()
185}
186
187func initMCPTools(ctx context.Context, app *app.App) {
188	go func() {
189		defer logging.RecoverPanic("MCP-goroutine", nil)
190
191		// Create a context with timeout for the initial MCP tools fetch
192		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
193		defer cancel()
194
195		// Set this up once with proper error handling
196		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
197		logging.Info("MCP message handling goroutine exiting")
198	}()
199}
200
201func setupSubscriber[T any](
202	ctx context.Context,
203	wg *sync.WaitGroup,
204	name string,
205	subscriber func(context.Context) <-chan pubsub.Event[T],
206	outputCh chan<- tea.Msg,
207) {
208	wg.Add(1)
209	go func() {
210		defer wg.Done()
211		defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
212
213		subCh := subscriber(ctx)
214
215		for {
216			select {
217			case event, ok := <-subCh:
218				if !ok {
219					logging.Info("subscription channel closed", "name", name)
220					return
221				}
222
223				var msg tea.Msg = event
224
225				select {
226				case outputCh <- msg:
227				case <-time.After(2 * time.Second):
228					logging.Warn("message dropped due to slow consumer", "name", name)
229				case <-ctx.Done():
230					logging.Info("subscription cancelled", "name", name)
231					return
232				}
233			case <-ctx.Done():
234				logging.Info("subscription cancelled", "name", name)
235				return
236			}
237		}
238	}()
239}
240
241func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
242	ch := make(chan tea.Msg, 100)
243
244	wg := sync.WaitGroup{}
245	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
246
247	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
248	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
249	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
250	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
251	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
252	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
253
254	cleanupFunc := func() {
255		logging.Info("Cancelling all subscriptions")
256		cancel() // Signal all goroutines to stop
257
258		waitCh := make(chan struct{})
259		go func() {
260			defer logging.RecoverPanic("subscription-cleanup", nil)
261			wg.Wait()
262			close(waitCh)
263		}()
264
265		select {
266		case <-waitCh:
267			logging.Info("All subscription goroutines completed successfully")
268			close(ch) // Only close after all writers are confirmed done
269		case <-time.After(5 * time.Second):
270			logging.Warn("Timed out waiting for some subscription goroutines to complete")
271			close(ch)
272		}
273	}
274	return ch, cleanupFunc
275}
276
277func Execute() {
278	if err := fang.Execute(
279		context.Background(),
280		rootCmd,
281		fang.WithVersion(version.Version),
282	); err != nil {
283		os.Exit(1)
284	}
285}
286
287func init() {
288	rootCmd.Flags().BoolP("help", "h", false, "Help")
289	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
290	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
291	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
292
293	// Add format flag with validation logic
294	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
295		"Output format for non-interactive mode (text, json)")
296
297	// Add quiet flag to hide spinner in non-interactive mode
298	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
299
300	// Register custom validation for the format flag
301	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
302		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
303	})
304}