root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"sync"
  8	"time"
  9
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/app"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/db"
 14	"github.com/charmbracelet/crush/internal/format"
 15	"github.com/charmbracelet/crush/internal/llm/agent"
 16	"github.com/charmbracelet/crush/internal/logging"
 17	"github.com/charmbracelet/crush/internal/pubsub"
 18	"github.com/charmbracelet/crush/internal/tui"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/charmbracelet/fang"
 21	"github.com/spf13/cobra"
 22)
 23
 24var rootCmd = &cobra.Command{
 25	Use:   "crush",
 26	Short: "Terminal-based AI assistant for software development",
 27	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 28It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 29to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 30	Example: `
 31  # Run in interactive mode
 32  crush
 33
 34  # Run with debug logging
 35  crush -d
 36
 37  # Run with debug logging in a specific directory
 38  crush -d -c /path/to/project
 39
 40  # Print version
 41  crush -v
 42
 43  # Run a single non-interactive prompt
 44  crush -p "Explain the use of context in Go"
 45
 46  # Run a single non-interactive prompt with JSON output format
 47  crush -p "Explain the use of context in Go" -f json
 48  `,
 49	RunE: func(cmd *cobra.Command, args []string) error {
 50		// Load the config
 51		debug, _ := cmd.Flags().GetBool("debug")
 52		cwd, _ := cmd.Flags().GetString("cwd")
 53		prompt, _ := cmd.Flags().GetString("prompt")
 54		outputFormat, _ := cmd.Flags().GetString("output-format")
 55		quiet, _ := cmd.Flags().GetBool("quiet")
 56
 57		// Validate format option
 58		if !format.IsValid(outputFormat) {
 59			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 60		}
 61
 62		if cwd != "" {
 63			err := os.Chdir(cwd)
 64			if err != nil {
 65				return fmt.Errorf("failed to change directory: %v", err)
 66			}
 67		}
 68		if cwd == "" {
 69			c, err := os.Getwd()
 70			if err != nil {
 71				return fmt.Errorf("failed to get current working directory: %v", err)
 72			}
 73			cwd = c
 74		}
 75
 76		_, err := config.Init(cwd, debug)
 77		if err != nil {
 78			return err
 79		}
 80
 81		// Connect DB, this will also run migrations
 82		conn, err := db.Connect()
 83		if err != nil {
 84			return err
 85		}
 86
 87		// Create main context for the application
 88		ctx, cancel := context.WithCancel(context.Background())
 89		defer cancel()
 90
 91		app, err := app.New(ctx, conn)
 92		if err != nil {
 93			logging.Error("Failed to create app: %v", err)
 94			return err
 95		}
 96		// Defer shutdown here so it runs for both interactive and non-interactive modes
 97		defer app.Shutdown()
 98
 99		// Initialize MCP tools early for both modes
100		initMCPTools(ctx, app)
101
102		// Non-interactive mode
103		if prompt != "" {
104			// Run non-interactive flow using the App method
105			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
106		}
107
108		// Set up the TUI
109		program := tea.NewProgram(
110			tui.New(app),
111			tea.WithAltScreen(),
112			tea.WithKeyReleases(),
113			tea.WithUniformKeyLayout(),
114		)
115
116		// Setup the subscriptions, this will send services events to the TUI
117		ch, cancelSubs := setupSubscriptions(app, ctx)
118
119		// Create a context for the TUI message handler
120		tuiCtx, tuiCancel := context.WithCancel(ctx)
121		var tuiWg sync.WaitGroup
122		tuiWg.Add(1)
123
124		// Set up message handling for the TUI
125		go func() {
126			defer tuiWg.Done()
127			defer logging.RecoverPanic("TUI-message-handler", func() {
128				attemptTUIRecovery(program)
129			})
130
131			for {
132				select {
133				case <-tuiCtx.Done():
134					logging.Info("TUI message handler shutting down")
135					return
136				case msg, ok := <-ch:
137					if !ok {
138						logging.Info("TUI message channel closed")
139						return
140					}
141					program.Send(msg)
142				}
143			}
144		}()
145
146		// Cleanup function for when the program exits
147		cleanup := func() {
148			// Shutdown the app
149			app.Shutdown()
150
151			// Cancel subscriptions first
152			cancelSubs()
153
154			// Then cancel TUI message handler
155			tuiCancel()
156
157			// Wait for TUI message handler to finish
158			tuiWg.Wait()
159
160			logging.Info("All goroutines cleaned up")
161		}
162
163		// Run the TUI
164		result, err := program.Run()
165		cleanup()
166
167		if err != nil {
168			logging.Error("TUI error: %v", err)
169			return fmt.Errorf("TUI error: %v", err)
170		}
171
172		logging.Info("TUI exited with result: %v", result)
173		return nil
174	},
175}
176
177// attemptTUIRecovery tries to recover the TUI after a panic
178func attemptTUIRecovery(program *tea.Program) {
179	logging.Info("Attempting to recover TUI after panic")
180
181	// We could try to restart the TUI or gracefully exit
182	// For now, we'll just quit the program to avoid further issues
183	program.Quit()
184}
185
186func initMCPTools(ctx context.Context, app *app.App) {
187	go func() {
188		defer logging.RecoverPanic("MCP-goroutine", nil)
189
190		// Create a context with timeout for the initial MCP tools fetch
191		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
192		defer cancel()
193
194		// Set this up once with proper error handling
195		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
196		logging.Info("MCP message handling goroutine exiting")
197	}()
198}
199
200func setupSubscriber[T any](
201	ctx context.Context,
202	wg *sync.WaitGroup,
203	name string,
204	subscriber func(context.Context) <-chan pubsub.Event[T],
205	outputCh chan<- tea.Msg,
206) {
207	wg.Add(1)
208	go func() {
209		defer wg.Done()
210		defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
211
212		subCh := subscriber(ctx)
213
214		for {
215			select {
216			case event, ok := <-subCh:
217				if !ok {
218					logging.Info("subscription channel closed", "name", name)
219					return
220				}
221
222				var msg tea.Msg = event
223
224				select {
225				case outputCh <- msg:
226				case <-time.After(2 * time.Second):
227					logging.Warn("message dropped due to slow consumer", "name", name)
228				case <-ctx.Done():
229					logging.Info("subscription cancelled", "name", name)
230					return
231				}
232			case <-ctx.Done():
233				logging.Info("subscription cancelled", "name", name)
234				return
235			}
236		}
237	}()
238}
239
240func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
241	ch := make(chan tea.Msg, 100)
242
243	wg := sync.WaitGroup{}
244	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
245
246	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
247	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
248	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
249	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
250	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
251	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
252
253	cleanupFunc := func() {
254		logging.Info("Cancelling all subscriptions")
255		cancel() // Signal all goroutines to stop
256
257		waitCh := make(chan struct{})
258		go func() {
259			defer logging.RecoverPanic("subscription-cleanup", nil)
260			wg.Wait()
261			close(waitCh)
262		}()
263
264		select {
265		case <-waitCh:
266			logging.Info("All subscription goroutines completed successfully")
267			close(ch) // Only close after all writers are confirmed done
268		case <-time.After(5 * time.Second):
269			logging.Warn("Timed out waiting for some subscription goroutines to complete")
270			close(ch)
271		}
272	}
273	return ch, cleanupFunc
274}
275
276func Execute() {
277	if err := fang.Execute(
278		context.Background(),
279		rootCmd,
280		fang.WithVersion(version.Version),
281	); err != nil {
282		os.Exit(1)
283	}
284}
285
286func init() {
287	rootCmd.Flags().BoolP("help", "h", false, "Help")
288	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
289	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
290	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
291
292	// Add format flag with validation logic
293	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
294		"Output format for non-interactive mode (text, json)")
295
296	// Add quiet flag to hide spinner in non-interactive mode
297	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
298
299	// Register custom validation for the format flag
300	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
301		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
302	})
303}