app.go

  1package app
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"os/exec"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	tea "github.com/charmbracelet/bubbletea/v2"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/db"
 19	"github.com/charmbracelet/crush/internal/format"
 20	"github.com/charmbracelet/crush/internal/history"
 21	"github.com/charmbracelet/crush/internal/llm/agent"
 22	"github.com/charmbracelet/crush/internal/log"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24
 25	"github.com/charmbracelet/crush/internal/lsp"
 26	"github.com/charmbracelet/crush/internal/lsp/watcher"
 27	"github.com/charmbracelet/crush/internal/message"
 28	"github.com/charmbracelet/crush/internal/permission"
 29	"github.com/charmbracelet/crush/internal/session"
 30)
 31
 32type App struct {
 33	Sessions    session.Service
 34	Messages    message.Service
 35	History     history.Service
 36	Permissions permission.Service
 37
 38	CoderAgent agent.Service
 39
 40	LSPClients map[string]*lsp.Client
 41
 42	clientsMutex sync.RWMutex
 43
 44	watcherCancelFuncs *csync.Slice[context.CancelFunc]
 45	lspWatcherWG       sync.WaitGroup
 46
 47	config *config.Config
 48
 49	serviceEventsWG *sync.WaitGroup
 50	eventsCtx       context.Context
 51	events          chan tea.Msg
 52	tuiWG           *sync.WaitGroup
 53
 54	// global context and cleanup functions
 55	globalCtx    context.Context
 56	cleanupFuncs []func() error
 57}
 58
 // isGitRepo reports whether the process's current working directory is inside
// a git work tree. It runs `git rev-parse --is-inside-work-tree` and treats any
// failure (git missing, not a repository, non-zero exit) as false.
func isGitRepo() bool {
	cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--is-inside-work-tree")
	out, err := cmd.CombinedOutput()
	if err != nil {
		return false
	}
	return strings.TrimSpace(string(out)) == "true"
}
 68
 69// New initializes a new applcation instance.
 70func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 71	q := db.New(conn)
 72	sessions := session.NewService(q)
 73	messages := message.NewService(q)
 74	files := history.NewService(q, conn)
 75	skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
 76	allowedTools := []string{}
 77	if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
 78		allowedTools = cfg.Permissions.AllowedTools
 79	}
 80
 81	app := &App{
 82		Sessions:    sessions,
 83		Messages:    messages,
 84		History:     files,
 85		Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
 86		LSPClients:  make(map[string]*lsp.Client),
 87
 88		globalCtx: ctx,
 89
 90		config: cfg,
 91
 92		watcherCancelFuncs: csync.NewSlice[context.CancelFunc](),
 93
 94		events:          make(chan tea.Msg, 100),
 95		serviceEventsWG: &sync.WaitGroup{},
 96		tuiWG:           &sync.WaitGroup{},
 97	}
 98
 99	app.setupEvents()
100
101	// Start the global watcher only if this is a git repository
102	if isGitRepo() {
103		if err := watcher.Start(); err != nil {
104			return nil, fmt.Errorf("app: %w", err)
105		}
106	} else {
107		slog.Warn("Not starting global watcher: not a git repository")
108	}
109
110	// Initialize LSP clients in the background.
111	app.initLSPClients(ctx)
112
113	// cleanup database upon app shutdown
114	app.cleanupFuncs = append(app.cleanupFuncs, conn.Close)
115
116	// TODO: remove the concept of agent config, most likely.
117	if cfg.IsConfigured() {
118		if err := app.InitCoderAgent(); err != nil {
119			return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
120		}
121	} else {
122		slog.Warn("No agent configuration found")
123	}
124	return app, nil
125}
126
127// Config returns the application configuration.
128func (app *App) Config() *config.Config {
129	return app.config
130}
131
132// RunNonInteractive handles the execution flow when a prompt is provided via
133// CLI flag.
134func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
135	slog.Info("Running in non-interactive mode")
136
137	ctx, cancel := context.WithCancel(ctx)
138	defer cancel()
139
140	// Start spinner if not in quiet mode.
141	var spinner *format.Spinner
142	if !quiet {
143		spinner = format.NewSpinner(ctx, cancel, "Generating")
144		spinner.Start()
145	}
146
147	// Helper function to stop spinner once.
148	stopSpinner := func() {
149		if !quiet && spinner != nil {
150			spinner.Stop()
151			spinner = nil
152		}
153	}
154	defer stopSpinner()
155
156	const maxPromptLengthForTitle = 100
157	titlePrefix := "Non-interactive: "
158	var titleSuffix string
159
160	if len(prompt) > maxPromptLengthForTitle {
161		titleSuffix = prompt[:maxPromptLengthForTitle] + "..."
162	} else {
163		titleSuffix = prompt
164	}
165	title := titlePrefix + titleSuffix
166
167	sess, err := app.Sessions.Create(ctx, title)
168	if err != nil {
169		return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
170	}
171	slog.Info("Created session for non-interactive run", "session_id", sess.ID)
172
173	// Automatically approve all permission requests for this non-interactive session
174	app.Permissions.AutoApproveSession(sess.ID)
175
176	done, err := app.CoderAgent.Run(ctx, sess.ID, prompt)
177	if err != nil {
178		return fmt.Errorf("failed to start agent processing stream: %w", err)
179	}
180
181	messageEvents := app.Messages.Subscribe(ctx)
182	messageReadBytes := make(map[string]int)
183
184	for {
185		select {
186		case result := <-done:
187			stopSpinner()
188
189			if result.Error != nil {
190				if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
191					slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID)
192					return nil
193				}
194				return fmt.Errorf("agent processing failed: %w", result.Error)
195			}
196
197			msgContent := result.Message.Content().String()
198			readBts := messageReadBytes[result.Message.ID]
199
200			if len(msgContent) < readBts {
201				slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts)
202				return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts)
203			}
204			fmt.Println(msgContent[readBts:])
205			messageReadBytes[result.Message.ID] = len(msgContent)
206
207			slog.Info("Non-interactive: run completed", "session_id", sess.ID)
208			return nil
209
210		case event := <-messageEvents:
211			msg := event.Payload
212			if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
213				stopSpinner()
214
215				content := msg.Content().String()
216				readBytes := messageReadBytes[msg.ID]
217
218				if len(content) < readBytes {
219					slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(content), "read_bytes", readBytes)
220					return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
221				}
222
223				part := content[readBytes:]
224				fmt.Print(part)
225				messageReadBytes[msg.ID] = len(content)
226			}
227
228		case <-ctx.Done():
229			stopSpinner()
230			return ctx.Err()
231		}
232	}
233}
234
235func (app *App) UpdateAgentModel() error {
236	return app.CoderAgent.UpdateModel()
237}
238
239func (app *App) setupEvents() {
240	ctx, cancel := context.WithCancel(app.globalCtx)
241	app.eventsCtx = ctx
242	setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
243	setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
244	setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
245	setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
246	setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
247	setupSubscriber(ctx, app.serviceEventsWG, "mcp", agent.SubscribeMCPEvents, app.events)
248	setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
249	cleanupFunc := func() error {
250		cancel()
251		app.serviceEventsWG.Wait()
252		return nil
253	}
254	app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
255}
256
257func setupSubscriber[T any](
258	ctx context.Context,
259	wg *sync.WaitGroup,
260	name string,
261	subscriber func(context.Context) <-chan pubsub.Event[T],
262	outputCh chan<- tea.Msg,
263) {
264	wg.Go(func() {
265		subCh := subscriber(ctx)
266		for {
267			select {
268			case event, ok := <-subCh:
269				if !ok {
270					slog.Debug("subscription channel closed", "name", name)
271					return
272				}
273				var msg tea.Msg = event
274				select {
275				case outputCh <- msg:
276				case <-time.After(2 * time.Second):
277					slog.Warn("message dropped due to slow consumer", "name", name)
278				case <-ctx.Done():
279					slog.Debug("subscription cancelled", "name", name)
280					return
281				}
282			case <-ctx.Done():
283				slog.Debug("subscription cancelled", "name", name)
284				return
285			}
286		}
287	})
288}
289
290func (app *App) InitCoderAgent() error {
291	coderAgentCfg := app.config.Agents["coder"]
292	if coderAgentCfg.ID == "" {
293		return fmt.Errorf("coder agent configuration is missing")
294	}
295	var err error
296	app.CoderAgent, err = agent.NewAgent(
297		app.globalCtx,
298		coderAgentCfg,
299		app.Permissions,
300		app.Sessions,
301		app.Messages,
302		app.History,
303		app.LSPClients,
304	)
305	if err != nil {
306		slog.Error("Failed to create coder agent", "err", err)
307		return err
308	}
309
310	// Add MCP client cleanup to shutdown process
311	app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients)
312
313	setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
314	return nil
315}
316
317// Subscribe sends events to the TUI as tea.Msgs.
318func (app *App) Subscribe(program *tea.Program) {
319	defer log.RecoverPanic("app.Subscribe", func() {
320		slog.Info("TUI subscription panic: attempting graceful shutdown")
321		program.Quit()
322	})
323
324	app.tuiWG.Add(1)
325	tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
326	app.cleanupFuncs = append(app.cleanupFuncs, func() error {
327		slog.Debug("Cancelling TUI message handler")
328		tuiCancel()
329		app.tuiWG.Wait()
330		return nil
331	})
332	defer app.tuiWG.Done()
333
334	for {
335		select {
336		case <-tuiCtx.Done():
337			slog.Debug("TUI message handler shutting down")
338			return
339		case msg, ok := <-app.events:
340			if !ok {
341				slog.Debug("TUI message channel closed")
342				return
343			}
344			program.Send(msg)
345		}
346	}
347}
348
349// Shutdown performs a graceful shutdown of the application.
350func (app *App) Shutdown() {
351	if app.CoderAgent != nil {
352		app.CoderAgent.CancelAll()
353	}
354
355	for cancel := range app.watcherCancelFuncs.Seq() {
356		cancel()
357	}
358
359	// Wait for all LSP watchers to finish.
360	app.lspWatcherWG.Wait()
361
362	// Get all LSP clients.
363	app.clientsMutex.RLock()
364	clients := make(map[string]*lsp.Client, len(app.LSPClients))
365	maps.Copy(clients, app.LSPClients)
366	app.clientsMutex.RUnlock()
367
368	// Shutdown all LSP clients.
369	for name, client := range clients {
370		shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
371		if err := client.Close(shutdownCtx); err != nil {
372			slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
373		}
374		cancel()
375	}
376
377	// Shutdown the global watcher
378	watcher.Shutdown()
379
380	// Call call cleanup functions.
381	for _, cleanup := range app.cleanupFuncs {
382		if cleanup != nil {
383			if err := cleanup(); err != nil {
384				slog.Error("Failed to cleanup app properly on shutdown", "error", err)
385			}
386		}
387	}
388}