root.go

  1package cmd
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7	"sync"
  8	"time"
  9
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/app"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/db"
 14	"github.com/charmbracelet/crush/internal/format"
 15	"github.com/charmbracelet/crush/internal/llm/agent"
 16	"github.com/charmbracelet/crush/internal/logging"
 17	"github.com/charmbracelet/crush/internal/pubsub"
 18	"github.com/charmbracelet/crush/internal/tui"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/charmbracelet/fang"
 21	"github.com/spf13/cobra"
 22)
 23
 24var rootCmd = &cobra.Command{
 25	Use:   "crush",
 26	Short: "Terminal-based AI assistant for software development",
 27	Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
 28It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
 29to assist developers in writing, debugging, and understanding code directly from the terminal.`,
 30	Example: `
 31  # Run in interactive mode
 32  crush
 33
 34  # Run with debug logging
 35  crush -d
 36
 37  # Run with debug logging in a specific directory
 38  crush -d -c /path/to/project
 39
 40  # Print version
 41  crush -v
 42
 43  # Run a single non-interactive prompt
 44  crush -p "Explain the use of context in Go"
 45
 46  # Run a single non-interactive prompt with JSON output format
 47  crush -p "Explain the use of context in Go" -f json
 48  `,
 49	RunE: func(cmd *cobra.Command, args []string) error {
 50		// Load the config
 51		debug, _ := cmd.Flags().GetBool("debug")
 52		cwd, _ := cmd.Flags().GetString("cwd")
 53		prompt, _ := cmd.Flags().GetString("prompt")
 54		outputFormat, _ := cmd.Flags().GetString("output-format")
 55		quiet, _ := cmd.Flags().GetBool("quiet")
 56
 57		// Validate format option
 58		if !format.IsValid(outputFormat) {
 59			return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
 60		}
 61
 62		if cwd != "" {
 63			err := os.Chdir(cwd)
 64			if err != nil {
 65				return fmt.Errorf("failed to change directory: %v", err)
 66			}
 67		}
 68		if cwd == "" {
 69			c, err := os.Getwd()
 70			if err != nil {
 71				return fmt.Errorf("failed to get current working directory: %v", err)
 72			}
 73			cwd = c
 74		}
 75		_, err := config.Load(cwd, debug)
 76		if err != nil {
 77			return err
 78		}
 79
 80		// Create main context for the application
 81		ctx, cancel := context.WithCancel(context.Background())
 82		defer cancel()
 83
 84		// Connect DB, this will also run migrations
 85		conn, err := db.Connect(ctx)
 86		if err != nil {
 87			return err
 88		}
 89
 90		app, err := app.New(ctx, conn)
 91		if err != nil {
 92			logging.Error("Failed to create app: %v", err)
 93			return err
 94		}
 95		// Defer shutdown here so it runs for both interactive and non-interactive modes
 96		defer app.Shutdown()
 97
 98		// Initialize MCP tools early for both modes
 99		initMCPTools(ctx, app)
100
101		// Non-interactive mode
102		if prompt != "" {
103			// Run non-interactive flow using the App method
104			return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
105		}
106
107		// Set up the TUI
108		program := tea.NewProgram(
109			tui.New(app),
110			tea.WithAltScreen(),
111			tea.WithKeyReleases(),
112			tea.WithUniformKeyLayout(),
113		)
114
115		// Setup the subscriptions, this will send services events to the TUI
116		ch, cancelSubs := setupSubscriptions(app, ctx)
117
118		// Create a context for the TUI message handler
119		tuiCtx, tuiCancel := context.WithCancel(ctx)
120		var tuiWg sync.WaitGroup
121		tuiWg.Add(1)
122
123		// Set up message handling for the TUI
124		go func() {
125			defer tuiWg.Done()
126			defer logging.RecoverPanic("TUI-message-handler", func() {
127				attemptTUIRecovery(program)
128			})
129
130			for {
131				select {
132				case <-tuiCtx.Done():
133					logging.Info("TUI message handler shutting down")
134					return
135				case msg, ok := <-ch:
136					if !ok {
137						logging.Info("TUI message channel closed")
138						return
139					}
140					program.Send(msg)
141				}
142			}
143		}()
144
145		// Cleanup function for when the program exits
146		cleanup := func() {
147			// Shutdown the app
148			app.Shutdown()
149
150			// Cancel subscriptions first
151			cancelSubs()
152
153			// Then cancel TUI message handler
154			tuiCancel()
155
156			// Wait for TUI message handler to finish
157			tuiWg.Wait()
158
159			logging.Info("All goroutines cleaned up")
160		}
161
162		// Run the TUI
163		result, err := program.Run()
164		cleanup()
165
166		if err != nil {
167			logging.Error("TUI error: %v", err)
168			return fmt.Errorf("TUI error: %v", err)
169		}
170
171		logging.Info("TUI exited with result: %v", result)
172		return nil
173	},
174}
175
176// attemptTUIRecovery tries to recover the TUI after a panic
177func attemptTUIRecovery(program *tea.Program) {
178	logging.Info("Attempting to recover TUI after panic")
179
180	// We could try to restart the TUI or gracefully exit
181	// For now, we'll just quit the program to avoid further issues
182	program.Quit()
183}
184
185func initMCPTools(ctx context.Context, app *app.App) {
186	go func() {
187		defer logging.RecoverPanic("MCP-goroutine", nil)
188
189		// Create a context with timeout for the initial MCP tools fetch
190		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
191		defer cancel()
192
193		// Set this up once with proper error handling
194		agent.GetMcpTools(ctxWithTimeout, app.Permissions)
195		logging.Info("MCP message handling goroutine exiting")
196	}()
197}
198
199func setupSubscriber[T any](
200	ctx context.Context,
201	wg *sync.WaitGroup,
202	name string,
203	subscriber func(context.Context) <-chan pubsub.Event[T],
204	outputCh chan<- tea.Msg,
205) {
206	wg.Add(1)
207	go func() {
208		defer wg.Done()
209		defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
210
211		subCh := subscriber(ctx)
212
213		for {
214			select {
215			case event, ok := <-subCh:
216				if !ok {
217					logging.Info("subscription channel closed", "name", name)
218					return
219				}
220
221				var msg tea.Msg = event
222
223				select {
224				case outputCh <- msg:
225				case <-time.After(2 * time.Second):
226					logging.Warn("message dropped due to slow consumer", "name", name)
227				case <-ctx.Done():
228					logging.Info("subscription cancelled", "name", name)
229					return
230				}
231			case <-ctx.Done():
232				logging.Info("subscription cancelled", "name", name)
233				return
234			}
235		}
236	}()
237}
238
239func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
240	ch := make(chan tea.Msg, 100)
241
242	wg := sync.WaitGroup{}
243	ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
244
245	setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
246	setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
247	setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
248	setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
249	setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
250	setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
251
252	cleanupFunc := func() {
253		logging.Info("Cancelling all subscriptions")
254		cancel() // Signal all goroutines to stop
255
256		waitCh := make(chan struct{})
257		go func() {
258			defer logging.RecoverPanic("subscription-cleanup", nil)
259			wg.Wait()
260			close(waitCh)
261		}()
262
263		select {
264		case <-waitCh:
265			logging.Info("All subscription goroutines completed successfully")
266			close(ch) // Only close after all writers are confirmed done
267		case <-time.After(5 * time.Second):
268			logging.Warn("Timed out waiting for some subscription goroutines to complete")
269			close(ch)
270		}
271	}
272	return ch, cleanupFunc
273}
274
275func Execute() {
276	if err := fang.Execute(
277		context.Background(),
278		rootCmd,
279		fang.WithVersion(version.Version),
280	); err != nil {
281		os.Exit(1)
282	}
283}
284
285func init() {
286	rootCmd.Flags().BoolP("help", "h", false, "Help")
287	rootCmd.Flags().BoolP("debug", "d", false, "Debug")
288	rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
289	rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
290
291	// Add format flag with validation logic
292	rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
293		"Output format for non-interactive mode (text, json)")
294
295	// Add quiet flag to hide spinner in non-interactive mode
296	rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
297
298	// Register custom validation for the format flag
299	rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
300		return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
301	})
302}